🤖 AI Summary
In end-to-end training of retrieval-augmented generation (RAG) models, the retrieved passages are discrete latent variables that must be marginalized out, and the usual approximations to this marginalization yield biased or high-variance gradient estimates. To address this, we propose JSA-RAG, a framework built on the Joint Stochastic Approximation (JSA) algorithm, a stochastic extension of the Expectation-Maximization (EM) algorithm. JSA-RAG jointly optimizes the retriever and generator without reparameterization or reinforcement learning, substantially reducing gradient variance and stabilizing the marginalization. Compared with top-K approximation and variational RAG (VRAG), JSA-RAG achieves a better trade-off between generation quality and retrieval accuracy. Extensive experiments on five datasets spanning two tasks, open-domain question answering and knowledge-grounded dialogue, show that JSA-RAG significantly outperforms the baselines in retrieval precision, answer faithfulness, and end-to-end convergence stability, supporting both the effectiveness and the generalizability of its gradient estimation strategy.
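To make the bias of top-K marginalization concrete, here is a minimal toy sketch. It assumes a four-passage knowledge base with hand-picked retriever probabilities p(h|x) and generator likelihoods p(y|x,h) (illustrative numbers, not results from the paper), and it assumes the top-K variant renormalizes the retriever probabilities over the K kept passages. The function names are hypothetical.

```python
import math


def log_marginal_exact(prior, lik):
    """log p(y|x) = log sum_h p(h|x) p(y|x,h), summed over every passage h."""
    return math.log(sum(p * lk for p, lk in zip(prior, lik)))


def log_marginal_topk(prior, lik, k):
    """Top-K approximation: keep the k passages with the highest retriever
    probability p(h|x), renormalize their probabilities, and sum over them only."""
    top = sorted(range(len(prior)), key=lambda i: prior[i], reverse=True)[:k]
    z = sum(prior[i] for i in top)
    return math.log(sum(prior[i] / z * lik[i] for i in top))


# Hypothetical toy numbers: the passage that best explains y (index 3)
# has the lowest retriever probability, so top-2 truncation drops it.
prior = [0.4, 0.3, 0.2, 0.1]   # p(h|x), retriever
lik = [0.01, 0.02, 0.05, 0.9]  # p(y|x,h), generator

print(math.exp(log_marginal_exact(prior, lik)))    # exact marginal
print(math.exp(log_marginal_topk(prior, lik, 2)))  # top-2 estimate
```

With these numbers the exact marginal is 0.11, while the top-2 estimate is about 0.0143. Worse, the dropped passage receives no weight at all, so after renormalization its retriever score gets no gradient signal, which is the kind of bias that motivates sampling-based alternatives.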
📝 Abstract
Retrieval-augmented generation (RAG) has become a widely recognized paradigm for combining parametric memory with non-parametric memory. A RAG model consists of two serially connected components: a retriever and a generator. A major challenge in end-to-end optimization of a RAG model is that it requires marginalizing over the relevant passages from a knowledge base, which are modeled as discrete latent variables. Traditional top-K marginalization and variational RAG (VRAG) suffer from biased or high-variance gradient estimates. In this paper, we propose and develop end-to-end training of RAG based on joint stochastic approximation (JSA), referred to as JSA-RAG. The JSA algorithm is a stochastic extension of the EM (expectation-maximization) algorithm and is particularly powerful for estimating discrete latent variable models. Extensive experiments on five datasets covering two tasks (open-domain question answering and knowledge-grounded dialog) show that JSA-RAG significantly outperforms both vanilla RAG and VRAG. Further analysis demonstrates the efficacy of JSA-RAG from the perspectives of generation, retrieval, and low-variance gradient estimation.
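In the general JSA algorithm, the sampling step is a Metropolis independence sampler (MIS): an auxiliary model q(h|x,y), which also sees the output y, proposes a passage h', and the proposal is accepted with probability min{1, w(h')/w(h)}, where w(h) = p(h|x) p(y|x,h) / q(h|x,y). The chain therefore targets the true posterior p(h|x,y), and the accepted passage is treated as observed when taking gradient steps. The toy sketch below illustrates only that mechanism under strong simplifying assumptions; it is not the paper's implementation. The retriever p(h|x) and generator p(y|x,h) are fixed hypothetical tables over four passages, only a hypothetical tabular q is learned, the decaying step size is an arbitrary choice, and all function names are illustrative.

```python
import math
import random


def softmax(logits):
    """Numerically stable softmax over a list of floats."""
    m = max(logits)
    exps = [math.exp(v - m) for v in logits]
    total = sum(exps)
    return [e / total for e in exps]


def mis_step(h_cur, prior, lik, proposal, rng):
    """One Metropolis independence sampler step.

    Target: p(h|x,y), proportional to prior[h] * lik[h], i.e. p(h|x) p(y|x,h).
    Proposal: proposal[h], i.e. q(h|x,y), independent of the current state.
    """
    h_new = rng.choices(range(len(proposal)), weights=proposal)[0]

    def weight(h):
        # Importance weight w(h) = p(h|x) p(y|x,h) / q(h|x,y)
        return prior[h] * lik[h] / proposal[h]

    # The first draw (no current state yet) is always accepted.
    if h_cur is None or rng.random() < min(1.0, weight(h_new) / weight(h_cur)):
        return h_new
    return h_cur


def jsa_train(prior, lik, steps=20000, seed=0):
    """Toy JSA loop for a single (x, y) pair.

    Simplification: prior (retriever) and lik (generator) stay fixed; only the
    hypothetical tabular posterior model q(h|x,y) is learned. Returns the learned
    q and the fraction of MIS samples that landed on each passage.
    """
    rng = random.Random(seed)
    q_logits = [0.0] * len(prior)
    counts = [0] * len(prior)
    h = None
    for t in range(steps):
        q = softmax(q_logits)
        h = mis_step(h, prior, lik, q, rng)
        counts[h] += 1
        # Robbins-Monro step size: the sum diverges, the sum of squares converges.
        lr = 50.0 / (100.0 + t)
        # Ascend log q(h|x,y) at the accepted sample: d/dlogits = onehot(h) - q.
        for i in range(len(q_logits)):
            q_logits[i] += lr * ((1.0 if i == h else 0.0) - q[i])
    return softmax(q_logits), [c / steps for c in counts]


# Hypothetical toy tables: p(h|x) and p(y|x,h) for four passages.
q, freq = jsa_train([0.4, 0.3, 0.2, 0.1], [0.01, 0.02, 0.05, 0.9])
print(q, freq)
```

Because q is fitted to samples drawn from the MIS chain rather than to its own samples, it is pulled toward the true posterior, which for these toy tables is roughly (0.036, 0.055, 0.091, 0.818). In full JSA the same accepted samples would also drive gradient steps on the retriever and generator, which this sketch deliberately leaves fixed.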