🤖 AI Summary
This work proposes a novel approach to enhancing the robustness and out-of-distribution generalization of Transformers without compromising in-distribution performance. Training is formulated as a constrained optimization problem, introducing, for the first time, a layer-wise monotonic descent constraint that requires the intermediate representations at successive layers to progressively reduce the expected loss. The constraint is enforced through a primal-dual training mechanism, yielding an Unrolled Transformer architecture that explicitly mimics the behavior of iterative optimization algorithms. Experimental results demonstrate that the proposed method significantly improves robustness and out-of-distribution generalization on both video denoising and text classification tasks, while maintaining competitive performance on in-distribution data.
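Schematically, the constrained problem described above can be written as follows (the notation here is ours, not necessarily the paper's): let $\ell_k(\theta)$ denote the loss evaluated at the layer-$k$ intermediate representation of an $L$-layer network with parameters $\theta$. The layer-wise descent constraint then reads

$$
\min_{\theta} \; \mathbb{E}\big[\ell_L(\theta)\big]
\quad \text{s.t.} \quad
\mathbb{E}\big[\ell_k(\theta)\big] \le \mathbb{E}\big[\ell_{k-1}(\theta)\big],
\qquad k = 1, \dots, L,
$$

and a primal-dual scheme optimizes the associated Lagrangian
$\mathcal{L}(\theta, \lambda) = \mathbb{E}[\ell_L] + \sum_k \lambda_k \big(\mathbb{E}[\ell_k] - \mathbb{E}[\ell_{k-1}]\big)$
by descent on $\theta$ and projected ascent on the multipliers $\lambda_k \ge 0$.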
📝 Abstract
We introduce a constrained optimization framework for training transformers that behave like descent algorithms. Specifically, we enforce layer-wise descent constraints on the objective function and replace standard empirical risk minimization (ERM) with a primal-dual training scheme. This approach yields models whose intermediate representations decrease the loss monotonically in expectation across layers. We apply our method to both unrolled transformer architectures and conventional pretrained transformers on video denoising and text classification tasks. Across these settings, we observe that constrained transformers achieve stronger robustness to perturbations and better out-of-distribution generalization, while preserving in-distribution performance.
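The primal-dual scheme can be illustrated with a minimal sketch. This is not the paper's implementation: the toy "unrolled" model (each layer applies a learned gradient-like step), the finite-difference gradients, and all hyperparameters are our own assumptions, chosen only to make the alternating primal descent / dual ascent structure concrete.

```python
import numpy as np

rng = np.random.default_rng(0)

def layer_losses(w, x, y):
    """Per-layer squared losses of a toy unrolled model where layer k
    applies a learned gradient-like step h <- h - w_k * (h - y)."""
    h = x
    losses = []
    for wk in w:
        h = h - wk * (h - y)                     # hypothetical unrolled layer
        losses.append(np.mean((h - y) ** 2))
    return np.array(losses)

def lagrangian(w, lam, x, y, eps=1e-3):
    """L(w, lam) = final loss + sum_k lam_k * (loss_k - loss_{k-1} + eps)."""
    losses = layer_losses(w, x, y)
    slack = losses[1:] - losses[:-1] + eps       # descent-constraint violations
    return losses[-1] + np.dot(lam, slack), slack

def primal_dual_train(num_layers=4, steps=300, lr_w=0.05, lr_lam=0.1):
    x = rng.normal(size=64)
    y = np.zeros(64)                             # toy denoising target
    w = rng.uniform(0.0, 0.2, size=num_layers)   # per-layer step sizes
    lam = np.zeros(num_layers - 1)               # one multiplier per constraint
    for _ in range(steps):
        # primal step: finite-difference gradient descent on the Lagrangian
        base, _ = lagrangian(w, lam, x, y)
        grad = np.zeros_like(w)
        for k in range(num_layers):
            wp = w.copy()
            wp[k] += 1e-5
            grad[k] = (lagrangian(wp, lam, x, y)[0] - base) / 1e-5
        w -= lr_w * grad
        # dual step: projected gradient ascent on the multipliers
        _, slack = lagrangian(w, lam, x, y)
        lam = np.maximum(0.0, lam + lr_lam * slack)
    return w, lam, layer_losses(w, x, y)

w, lam, losses = primal_dual_train()
```

After training, `losses` is non-increasing across layers: each intermediate representation reduces the loss relative to the previous one, which is the behavior the layer-wise constraint is meant to enforce.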