Improving End-to-End Training of Retrieval-Augmented Generation Models via Joint Stochastic Approximation

📅 2025-08-25
📈 Citations: 0
Influential: 0
🤖 AI Summary
In end-to-end training of retrieval-augmented generation (RAG) models, marginalizing discrete retrieved passages—treated as latent variables—introduces biased and high-variance gradient estimates. To address this, we propose JSA-RAG, a framework inspired by the Expectation-Maximization (EM) algorithm that employs a Joint Stochastic Approximation (JSA) mechanism. JSA-RAG enables coupled optimization of the retriever and generator without reparameterization or reinforcement learning, substantially reducing gradient variance and improving marginalization stability. Compared to top-K approximation and variational RAG (VRAG), JSA-RAG achieves a superior trade-off between generation quality and retrieval accuracy. Extensive experiments across five open-domain question answering and knowledge-intensive dialogue datasets demonstrate that JSA-RAG significantly outperforms baselines in retrieval precision, answer faithfulness, and end-to-end convergence stability—validating both the effectiveness and generalizability of its gradient estimation strategy.

📝 Abstract
Retrieval-augmented generation (RAG) has become a widely recognized paradigm for combining parametric memory with non-parametric memories. A RAG model consists of two serially connected components: a retriever and a generator. A major challenge in end-to-end optimization of the RAG model is that it requires marginalization over relevant passages (modeled as discrete latent variables) from a knowledge base. Traditional top-K marginalization and variational RAG (VRAG) suffer from biased or high-variance gradient estimates. In this paper, we propose and develop joint stochastic approximation (JSA) based end-to-end training of RAG, referred to as JSA-RAG. The JSA algorithm is a stochastic extension of the expectation-maximization (EM) algorithm and is particularly powerful for estimating discrete latent variable models. Extensive experiments on five datasets covering two tasks (open-domain question answering and knowledge-grounded dialogs) show that JSA-RAG significantly outperforms both vanilla RAG and VRAG. Further analysis demonstrates the efficacy of JSA-RAG from the perspectives of generation, retrieval, and low-variance gradient estimation.
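The marginalization the abstract refers to can be written out as follows; the notation here is a standard formulation assumed for illustration, not copied from the paper:

```latex
% Marginal likelihood of output y given input x, with the retrieved
% passage z treated as a discrete latent variable over knowledge base Z:
\log p_{\theta,\phi}(y \mid x)
  = \log \sum_{z \in \mathcal{Z}} p_\theta(z \mid x)\, p_\phi(y \mid x, z)

% The exact gradient is an expectation under the intractable posterior:
\nabla \log p_{\theta,\phi}(y \mid x)
  = \mathbb{E}_{p(z \mid x, y)}\!\left[
      \nabla \log \bigl( p_\theta(z \mid x)\, p_\phi(y \mid x, z) \bigr)
    \right]
```

Top-K marginalization truncates the sum over $\mathcal{Z}$, which biases the gradient; score-function (REINFORCE-style) estimators are unbiased but high-variance. A stochastic-EM approach instead approximates the posterior expectation with MCMC samples, yielding ordinary log-likelihood gradients for both components.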
Problem

Research questions and friction points this paper is trying to address.

Optimizing end-to-end training for retrieval-augmented generation models
Addressing biased gradient estimates in discrete latent variable marginalization
Improving performance on question answering and knowledge-grounded dialogues
Innovation

Methods, ideas, or system contributions that make the work stand out.

Joint stochastic approximation for RAG training
Stochastic EM extension for discrete latent variables
Reduces gradient variance in end-to-end optimization
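As a rough illustration of the idea in these bullets, the toy sketch below runs a Metropolis independence sampler whose target is the passage posterior p(z|x,y) and whose proposal is the retriever distribution, the general mechanism underlying JSA-style training. All names, sizes, and toy distributions here are hypothetical stand-ins, not the paper's implementation:

```python
import numpy as np

rng = np.random.default_rng(0)
K = 5  # toy knowledge base of K passages (hypothetical size)

# Stand-ins for model scores: retriever prior p(z|x) and generator
# log-likelihood log p(y|x,z) for each candidate passage z.
retriever_logits = rng.normal(size=K)
gen_loglik = rng.normal(size=K)

def softmax(v):
    e = np.exp(v - v.max())
    return e / e.sum()

def mis_step(z_prev, proposal, prior, gen_loglik, rng):
    """One Metropolis independence sampler step targeting
    p(z|x,y) ∝ p(z|x) * p(y|x,z), with the retriever as proposal."""
    z_new = int(rng.choice(K, p=proposal))
    # Importance weight w(z) = p(z|x) * p(y|x,z) / q(z|x), in log space.
    def log_w(z):
        return np.log(prior[z]) + gen_loglik[z] - np.log(proposal[z])
    accept = np.log(rng.uniform()) < min(0.0, log_w(z_new) - log_w(z_prev))
    return z_new if accept else z_prev

proposal = softmax(retriever_logits)  # proposal == prior in this toy
prior = proposal
z = int(rng.choice(K, p=proposal))
for _ in range(100):
    z = mis_step(z, proposal, prior, gen_loglik, rng)

# Given an accepted posterior sample z, both modules receive ordinary
# log-likelihood gradients: d/dθ log p(z|x) for the retriever and
# d/dφ log p(y|x,z) for the generator -- no reparameterization trick
# and no REINFORCE estimator is needed.
```

Because the accepted sample is (asymptotically) drawn from the true posterior, the resulting gradients avoid both the truncation bias of top-K marginalization and the variance of score-function estimators.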
Hongyu Cao
Speech Processing and Machine Intelligence (SPMI) Lab, Tsinghua University, China
Yuxuan Wu
Embry-Riddle Aeronautical University
Yucheng Cai
Speech Processing and Machine Intelligence (SPMI) Lab, Tsinghua University, China
Xianyu Zhao
TasiTech Co., Ltd., China
Zhijian Ou
Speech Processing and Machine Intelligence (SPMI) Lab, Tsinghua University, China