🤖 AI Summary
This work addresses instability in the encoder–decoder interface of language models under conventional weight tying, which hampers optimization and makes post-training interventions less predictable. The authors propose Pseudo-Inverse Tying (PIT), which treats the input embedding and output projection as coupled projections of a shared latent token memory, enforcing pseudo-inverse consistency to enhance stability. PIT introduces a shared memory structure initialized via polar decomposition or random orthogonal matrices and incorporates a fully learnable symmetric positive-definite transformation in the latent space, thereby avoiding explicit pseudo-inverse computation and additional vocabulary-sized parameters. By integrating numerically stable techniques, including thin polar decomposition, Cholesky factorization, and triangular solves, the approach enables efficient and consistent embedding and unembedding operations. Experiments on edge-deployable models ranging from 256M to 1.3B parameters demonstrate markedly improved training stability, stronger inter-layer semantic consistency, and substantially reduced side effects from post-hoc interventions.
📝 Abstract
Weight tying is widely used in compact language models to reduce parameters by sharing the token table between the input embedding and the output projection. However, weight sharing does not guarantee a stable token interface: during training, the correspondence between encoding tokens into hidden states and decoding hidden states into logits can drift, worsening optimization sensitivity and making post-training interventions such as editing, patching, and lightweight adaptation less predictable. We propose Pseudo-Inverse Tying (PIT), which synchronizes embedding and unembedding as coupled projections of a shared latent token memory, guaranteeing a pseudo-inverse-consistent interface throughout training. PIT maintains an orthonormal shared memory, obtained by thin polar decomposition for teacher initialization or random orthonormal initialization from scratch, and introduces a fully learned symmetric positive definite hidden-space transform parameterized via a Cholesky factor. The output head applies this transform to hidden states before the vocabulary projection, while the embedding applies the inverse transform to token vectors using stable triangular solves, avoiding explicit pseudo-inverse recomputation and any vocabulary-sized auxiliary parameters. We evaluate PIT on on-device models spanning 256M–1.3B parameters across pretraining and adaptation, and consistently observe improved training stability, stronger layerwise semantic consistency, and substantially reduced side effects from post-training interventions.
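The tied interface the abstract describes can be sketched in a few lines. The following is a toy NumPy/SciPy illustration under stated assumptions, not the authors' implementation: the dimensions, the random orthonormal initialization, and the `embed`/`unembed` names are ours, and the Cholesky factor is fixed rather than learned.

```python
import numpy as np
from scipy.linalg import solve_triangular

rng = np.random.default_rng(0)
V, d = 32, 8  # toy vocabulary size and hidden width (illustrative only)

# Shared latent token memory M (V x d) with orthonormal columns.
# From scratch the paper uses a random orthonormal init; for teacher
# initialization it uses a thin polar decomposition of the teacher's table.
M, _ = np.linalg.qr(rng.standard_normal((V, d)))

# SPD hidden-space transform S = L @ L.T, parameterized by a
# lower-triangular Cholesky factor L with a positive diagonal.
L = np.tril(rng.standard_normal((d, d)))
np.fill_diagonal(L, np.abs(np.diag(L)) + 0.5)

def unembed(h):
    # Output head: apply S to the hidden state, then project to logits.
    return M @ (L @ (L.T @ h))  # logits = M S h

def embed(token_id):
    # Embedding: apply S^{-1} to the token's memory row via two stable
    # triangular solves; no explicit inverse or pseudo-inverse is formed.
    y = solve_triangular(L, M[token_id], lower=True)   # solve L y = m_t
    return solve_triangular(L.T, y, lower=False)       # solve L^T x = y

# Pseudo-inverse consistency: with W_out = M S and W_in = S^{-1} M^T,
# W_in is exactly the Moore-Penrose pseudo-inverse of W_out.
W_out = M @ L @ L.T                                    # (V x d) unembedding
W_in = np.stack([embed(t) for t in range(V)], axis=1)  # (d x V) embedding
assert np.allclose(W_in, np.linalg.pinv(W_out))
assert np.allclose(W_in @ W_out, np.eye(d))            # round-trip is identity in hidden space
```

Because `M` has orthonormal columns, the composition of embedding and unembedding satisfies `W_in @ W_out = S^{-1} M^T M S = I` by construction, so the pseudo-inverse relation holds at every step without ever computing a pseudo-inverse explicitly.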