🤖 AI Summary
Agent-based LLMs generate tree-structured token trajectories during task execution, but conventional training methods decompose these into independent linear sequences, leading to redundant computation of shared prefixes and suboptimal efficiency.
Method: This paper introduces Tree Training, a novel training paradigm that natively incorporates tree-structured trajectories into a unified computational framework. It employs Tree Packing to enable cross-branch sharing of prefix computations, Gradient Restoration to ensure accurate gradient backpropagation through shared nodes, and a dedicated shared-prefix caching mechanism.
Contribution/Results: The method supports both supervised fine-tuning (SFT) and reinforcement learning (RL), achieving up to 3.9× training speedup across multiple open-source agent LLMs. It significantly improves throughput and GPU utilization, offering a scalable, efficient training paradigm for large-scale agent-based LLMs.
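The core idea of Tree Packing can be sketched in a few lines: flatten the trajectory tree into one packed token sequence so each shared prefix appears exactly once, and pair it with an ancestor-only attention mask so tokens never attend across sibling branches. This is a minimal illustrative sketch, not the paper's implementation; the function names, data layout, and the `segments`/`children` representation are my own assumptions.

```python
# Hypothetical sketch of Tree Packing (names and data layout are illustrative,
# not from the paper): lay out a tree of token branches as one packed sequence
# plus an ancestor-only attention mask, so each shared prefix is computed once.

def tree_pack(segments, children):
    """segments[i]: token list of tree node i; children[i]: its child node ids.
    Returns (packed_tokens, token_parent), where token_parent[t] is the index
    of the previous token on t's root-to-leaf path (-1 for the first token)."""
    packed, parent = [], []

    def visit(node, prev):
        for tok in segments[node]:
            parent.append(prev)
            prev = len(packed)
            packed.append(tok)
        for child in children.get(node, []):
            visit(child, prev)  # every child branch reuses the shared prefix
    visit(0, -1)
    return packed, parent

def ancestor_mask(token_parent):
    """mask[q][k] is True iff token k lies on token q's path to the root,
    so tokens from sibling branches are never visible to each other."""
    n = len(token_parent)
    mask = [[False] * n for _ in range(n)]
    for q in range(n):
        k = q
        while k != -1:
            mask[q][k] = True
            k = token_parent[k]
    return mask

# A prompt "A B" that branches into "C" and "D E":
segments = {0: ["A", "B"], 1: ["C"], 2: ["D", "E"]}
children = {0: [1, 2]}
packed, parent = tree_pack(segments, children)
# The prefix "A B" is stored once, instead of twice as in the naive
# linearization ["A","B","C"] and ["A","B","D","E"].
```

In a real trainer this mask would be materialized as a block-sparse attention bias (or position-id/segment-id scheme) rather than a dense boolean matrix, but the ancestor-visibility rule is the same.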
📄 Abstract
In agentic LLM scenarios, an agent's interaction process during a single rollout often exhibits branching behaviors. Due to memory retrieval and concurrent tool executions at certain decision points, the token trajectory of one task evolves into a tree-like structure rather than a linear sequence. However, current training pipelines decompose such tree-structured trajectories into separate linear segments, treating each branch as an independent sequence. As a result, shared prefixes across these branches are repeatedly recomputed during both forward and backward passes. To address this inefficiency, we propose Tree Training, a paradigm that computes each shared prefix only once and reuses its intermediate results across related branches during both forward and backward passes, substantially improving computation efficiency in large-scale agentic training. This is achieved via (i) Tree Packing, which efficiently reuses shared computations across trajectories, and (ii) Gradient Restoration, which ensures correct gradient propagation across reused prefixes. Experiments on multiple open-source models demonstrate up to 3.9x reduction in total training time, enabling more efficient agentic LLM SFT and RL training.
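The need for Gradient Restoration can be shown with a toy numeric example (my own, not the paper's): when one shared prefix activation feeds several branches, the correct gradient at that activation is the sum of every branch's upstream gradient, so backpropagating through the shared prefix once with the accumulated gradient reproduces exactly what the naive per-branch recomputation would yield.

```python
# Toy manual-backprop illustration (assumed, not the paper's code) of gradient
# accumulation at a shared prefix: prefix activation h = w * x feeds two
# branches, each with its own squared-error loss against a target.

def branch_loss_grad(h, target):
    loss = (h - target) ** 2
    dloss_dh = 2.0 * (h - target)   # upstream gradient this branch sends to h
    return loss, dloss_dh

w, x = 0.5, 3.0          # "prefix" parameter and input
targets = [1.0, -2.0]    # one target per branch

# Naive training: recompute the prefix for every branch, sum parameter grads.
naive_dw = 0.0
for t in targets:
    h = w * x                       # prefix recomputed per branch
    _, dh = branch_loss_grad(h, t)
    naive_dw += dh * x              # chain rule through the prefix each time

# Packed training: compute the prefix once, accumulate upstream gradients at h,
# then backpropagate through the shared prefix a single time.
h = w * x                           # prefix computed once
dh_total = sum(branch_loss_grad(h, t)[1] for t in targets)
packed_dw = dh_total * x

assert abs(naive_dw - packed_dw) < 1e-12  # identical parameter gradients
```

The design point is that the saving compounds: the prefix's forward and backward passes each run once instead of once per branch, while the accumulated upstream gradient keeps the result bit-for-bit equivalent to independent linear sequences.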