🤖 AI Summary
Existing deep test-time memory RNNs (e.g., Titans, TTT) offer linear scalability but suffer from prohibitively slow training and low hardware utilization, which limits practical deployment. Their parallelization is constrained by a trade-off in the choice of chunk size: large chunks accelerate training but degrade model performance, whereas small chunks preserve accuracy at the cost of severe computational inefficiency. This work introduces TNT, a two-stage training paradigm that decouples training efficiency from inference performance. Stage I is an efficiency-focused pre-training phase built on a hierarchical memory: a global module processes large chunks for long-range context, parallel local modules handle fine-grained detail, and periodic resets of the local memory states enable context parallelism. Stage II is a brief fine-tuning phase in which only the local memory modules are adapted to a smaller, higher-resolution chunk size. Evaluated on Titans and TTT, TNT trains up to 17× faster than the most accurate baseline configuration while also improving accuracy, removing a key barrier to scaling RNN-based models.
📝 Abstract
Recurrent neural networks (RNNs) with deep test-time memorization modules, such as Titans and TTT, represent a promising, linearly-scaling paradigm distinct from Transformers. While these expressive models do not yet match the peak performance of state-of-the-art Transformers, their potential has been largely untapped due to prohibitively slow training and low hardware utilization. Existing parallelization methods force a fundamental conflict governed by the chunksize hyperparameter: large chunks boost speed but degrade performance, necessitating a fixed, suboptimal compromise. To solve this challenge, we introduce TNT, a novel training paradigm that decouples training efficiency from inference performance through a two-stage process. Stage one is an efficiency-focused pre-training phase utilizing a hierarchical memory. A global module processes large, hardware-friendly chunks for long-range context, while multiple parallel local modules handle fine-grained details. Crucially, by periodically resetting local memory states, we break sequential dependencies to enable massive context parallelization. Stage two is a brief fine-tuning phase where only the local memory modules are adapted to a smaller, high-resolution chunksize, maximizing accuracy with minimal overhead. Evaluated on Titans and TTT models, TNT achieves a substantial acceleration in training speed (up to 17 times faster than the most accurate baseline configuration) while simultaneously improving model accuracy. This improvement removes a critical scalability barrier, establishing a practical foundation for developing expressive RNNs and facilitating future work to close the performance gap with Transformers.
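The claim that periodic resets break sequential dependencies can be illustrated with a toy sketch. This is not the TNT implementation: the scalar decaying state below is only a stand-in for a local memory module, the global module and the deep memory updates are not modeled, and the function names and the `decay` and `segment_len` parameters are hypothetical. The sketch only shows that a recurrent scan with periodic resets gives the same outputs as scanning each segment on its own, which is what would allow segments to be processed in parallel.

```python
from typing import List


def scan_with_resets(xs: List[float], segment_len: int, decay: float = 0.9) -> List[float]:
    """Sequential reference: one pass over all tokens, resetting the state
    at the start of every segment (toy stand-in for a local memory)."""
    state = 0.0
    out = []
    for t, x in enumerate(xs):
        if t % segment_len == 0:
            state = 0.0  # periodic reset: nothing carries over from earlier segments
        state = decay * state + x
        out.append(state)
    return out


def scan_segment(segment: List[float], decay: float = 0.9) -> List[float]:
    """Scan one segment from a fresh state; needs no information from other segments."""
    state = 0.0
    out = []
    for x in segment:
        state = decay * state + x
        out.append(state)
    return out


def scan_segments_independently(xs: List[float], segment_len: int, decay: float = 0.9) -> List[float]:
    """Split the sequence into segments and scan each one on its own.
    The calls are independent, so they could be dispatched to separate
    workers or devices; here they simply run one after another."""
    segments = [xs[i:i + segment_len] for i in range(0, len(xs), segment_len)]
    per_segment = [scan_segment(seg, decay) for seg in segments]
    return [y for seg in per_segment for y in seg]


if __name__ == "__main__":
    xs = [float(i % 5) for i in range(16)]
    sequential = scan_with_resets(xs, segment_len=4)
    independent = scan_segments_independently(xs, segment_len=4)
    assert sequential == independent
    print("outputs match:", sequential == independent)
```

In this toy setting the segment length plays the role of the reset period. Without the reset, each segment would need the final state of the previous one, and the per-segment scans could no longer run independently.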