How Reinforcement Learning After Next-Token Prediction Facilitates Learning

📅 2025-10-13
🤖 AI Summary
Large language models (LLMs) trained with standard autoregressive next-token prediction struggle to efficiently solve tasks requiring long-range dependencies, such as computing the parity of d-bit strings. Method: The paper studies a reinforcement learning (RL) post-training stage applied after next-token-prediction pretraining. Training data mixes short and long chain-of-thought (CoT) sequences, and inference uses extended response generation; a simplified autoregressive linear model makes the resulting generalization theoretically interpretable. Contribution/Results: The authors prove that when the proportion of long CoT examples is not exponentially small in d, models following this recipe learn d-bit parity efficiently. Experiments on Llama-family models show substantial gains over purely autoregressive baselines when long sequences are rare, and reproduce RL post-training's generalization benefits on mixture variations of common mathematical reasoning benchmarks. The approach offers a pathway to reducing the statistical and computational resources required for complex reasoning tasks.

📝 Abstract
Recent advances in reasoning domains with neural networks have primarily been enabled by a training recipe that optimizes Large Language Models, previously trained to predict the next-token in a sequence, with reinforcement learning algorithms. We introduce a framework to study the success of this paradigm, and we theoretically expose the optimization mechanisms by which reinforcement learning improves over next-token prediction in this setting. We study learning from mixture distributions of short and long "chain-of-thought" sequences encoding a single task. In particular, when the task consists of predicting the parity of $d$ bits and long sequences are rare, we show how reinforcement learning after next-token prediction enables autoregressive transformers to generalize, whereas mere next-token prediction requires extreme statistical or computational resources to do so. We further explain how reinforcement learning leverages increased test-time computation, manifested in longer responses, to facilitate this learning process. In a simplified setting, we theoretically prove that autoregressive linear models following this training recipe can efficiently learn to predict the parity of $d$ bits as long as the proportion of long demonstrations in the data mix is not exponentially small in the input dimension $d$. Finally, we demonstrate these same phenomena in other settings, including the post-training of Llama-series models on mixture variations of common mathematical reasoning benchmarks.
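The mixture distribution studied in the abstract can be sketched concretely. The snippet below builds a dataset of d-bit parity demonstrations in which a small fraction `p_long` are long chain-of-thought sequences that spell out the running parity, and the rest are short input-to-answer pairs. All names and the string format here are illustrative assumptions, not the paper's exact data encoding.

```python
import random

def parity_examples(d=8, n=1000, p_long=0.05, seed=0):
    """Build a mixture of short and long chain-of-thought demonstrations
    for the d-bit parity task. p_long is the proportion of long CoT
    sequences (the quantity the theory requires to be not exponentially
    small in d). Format is illustrative only."""
    rng = random.Random(seed)
    data = []
    for _ in range(n):
        bits = [rng.randint(0, 1) for _ in range(d)]
        target = sum(bits) % 2
        if rng.random() < p_long:
            # Long CoT: emit the running parity after each bit.
            partial, steps = 0, []
            for b in bits:
                partial ^= b
                steps.append(str(partial))
            seq = " ".join(map(str, bits)) + " : " + " ".join(steps) + f" -> {target}"
        else:
            # Short CoT: the input followed directly by the answer.
            seq = " ".join(map(str, bits)) + f" -> {target}"
        data.append(seq)
    return data
```

With the defaults, roughly 5% of the sequences contain the intermediate reasoning steps, mirroring the "rare long sequences" regime the paper analyzes.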
Problem

Research questions and friction points this paper is trying to address.

Why does reinforcement learning after next-token prediction generalize when long reasoning sequences are rare?
Next-token prediction alone requires extreme statistical or computational resources on parity tasks
A theoretical framework is needed to explain RL's advantage over pure next-token prediction
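The friction with parity comes from its extreme long-range dependency: the answer is the XOR of all d bits, so flipping any single input bit flips the target, and no proper subset of the bits determines the answer. A minimal check of this property:

```python
from functools import reduce
from operator import xor

def parity(bits):
    """Parity of a bit string: the XOR of all its bits."""
    return reduce(xor, bits, 0)

bits = [1, 0, 1, 1, 0, 0, 1, 0]
base = parity(bits)

# Flipping any single bit flips the parity, so there is no local
# shortcut for a next-token predictor to exploit: every bit matters.
assert all(
    parity(bits[:i] + [1 - bits[i]] + bits[i + 1:]) != base
    for i in range(len(bits))
)
```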
Innovation

Methods, ideas, or system contributions that make the work stand out.

RL post-training applied on top of models pretrained with next-token prediction
Training on mixture distributions of short and long chain-of-thought sequences
Proof of efficient generalization even when long reasoning sequences are rare in the data mix
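The RL stage relies on a verifiable reward: a sampled response earns credit only if its final answer matches the true parity, which lets policy-gradient updates upweight the longer, step-by-step responses that tend to be correct. The reward below is a sketch under that assumption; the paper's exact reward and RL algorithm may differ.

```python
def reward(prompt_bits, response):
    """Verifiable reward for RL post-training on parity: 1.0 if the
    final token of the response parses as the correct parity of the
    prompt bits, else 0.0. A simplified sketch, not the paper's
    exact reward."""
    try:
        answer = int(response.strip().split()[-1])
    except ValueError:
        return 0.0
    return float(answer == sum(prompt_bits) % 2)

bits = [1, 0, 1, 1]
# A short guessed answer and a long response ending in the running
# parity are scored identically; RL only checks the final answer.
short_resp = "1"
long_resp = "1 1 0 1 -> 1"
```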