Cut Your Losses in Large-Vocabulary Language Models

📅 2024-11-13
🏛️ arXiv.org
📈 Citations: 1
Influential: 0
🤖 AI Summary
In large-vocabulary language model training, cross-entropy loss computation dominates memory consumption because the full logit matrix must be materialized; for small models, this single layer can consume an order of magnitude more memory than the rest of the model combined. To address this, the paper proposes Cut Cross-Entropy (CCE): an online fusion of the classifier matrix multiplication and the log-sum-exp reduction performed directly in on-chip ("flash") memory, so the full logit matrix never reaches global GPU memory. Custom kernels compute only the logit of the correct token together with a numerically stable log-sum-exp over all logits, and a gradient filter exploits the inherent sparsity of softmax to skip contributions that fall below numerical precision. On Gemma 2 (2B), CCE reduces loss-layer memory from 24 GB to 1 MB and total training-time classifier-head memory from 28 GB to 1 GB, with no degradation in training throughput or convergence.

📝 Abstract
As language models grow ever larger, so do their vocabularies. This has shifted the memory footprint of LLMs during training disproportionately to one single layer: the cross-entropy in the loss computation. Cross-entropy builds up a logit matrix with entries for each pair of input tokens and vocabulary items and, for small models, consumes an order of magnitude more memory than the rest of the LLM combined. We propose Cut Cross-Entropy (CCE), a method that computes the cross-entropy loss without materializing the logits for all tokens into global memory. Rather, CCE only computes the logit for the correct token and evaluates the log-sum-exp over all logits on the fly. We implement a custom kernel that performs the matrix multiplications and the log-sum-exp reduction over the vocabulary in flash memory, making global memory consumption for the cross-entropy computation negligible. This has a dramatic effect. Taking the Gemma 2 (2B) model as an example, CCE reduces the memory footprint of the loss computation from 24 GB to 1 MB, and the total training-time memory consumption of the classifier head from 28 GB to 1 GB. To improve the throughput of CCE, we leverage the inherent sparsity of softmax and propose to skip elements of the gradient computation that have a negligible (i.e., below numerical precision) contribution to the gradient. Experiments demonstrate that the dramatic reduction in memory consumption is accomplished without sacrificing training speed or convergence.
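The core idea in the abstract — compute only the correct-token logit and fold all remaining logits into a log-sum-exp evaluated on the fly — can be sketched as a chunked loop in NumPy. This is a simplified illustration, not the paper's implementation: CCE fuses these steps into a single GPU kernel, and the function name and chunk size below are illustrative choices.

```python
import numpy as np

def chunked_cross_entropy(h, E, targets, chunk=4):
    """Cross-entropy loss without materializing the full [N, |V|] logit matrix.

    h: [N, D] hidden states, E: [|V|, D] classifier weights, targets: [N].
    Only a [N, chunk]-sized slice of logits exists at any time; the
    log-sum-exp over the vocabulary is accumulated online.
    """
    # Logit of the correct token only: one dot product per position.
    correct = np.einsum('nd,nd->n', h, E[targets])
    # Running log-sum-exp, updated one vocabulary chunk at a time.
    lse = np.full(h.shape[0], -np.inf)
    for s in range(0, E.shape[0], chunk):
        logits = h @ E[s:s + chunk].T                      # [N, chunk] slice
        m = np.maximum(lse, logits.max(axis=-1))           # stable rescaling
        lse = m + np.log(np.exp(lse - m)
                         + np.exp(logits - m[:, None]).sum(axis=-1))
    return float((lse - correct).mean())
```

Peak memory is governed by the chunk size rather than the vocabulary size, which is the effect the paper achieves at kernel level by keeping each slice in on-chip memory.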
Problem

Research questions and friction points this paper is trying to address.

Cross-entropy loss materializes a full logit matrix, dominating training-time memory for large vocabularies
For small models, the loss layer can consume an order of magnitude more memory than the rest of the LLM combined
Memory must be reduced without sacrificing training speed or convergence
Innovation

Methods, ideas, or system contributions that make the work stand out.

CCE computes the cross-entropy loss without materializing logits in global memory.
A custom kernel fuses the matrix multiplication and log-sum-exp reduction in flash memory.
Softmax sparsity lets the gradient computation skip negligible terms, preserving throughput.
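The softmax-sparsity point above rests on a simple observation: for the peaked distributions typical in training, most softmax probabilities are too small to affect the gradient at the working precision. A rough NumPy illustration follows; the function name and the threshold `eps` (chosen here near the half-precision step size around 1.0) are hypothetical stand-ins for the paper's precision-aware criterion.

```python
import numpy as np

def negligible_softmax_fraction(logits, eps=2.0 ** -11):
    """Fraction of softmax entries below a numerical-precision threshold.

    A CCE-style kernel can skip gradient blocks whose softmax
    probabilities all fall below such a threshold, since their
    contribution is lost to rounding anyway.
    """
    m = logits.max(axis=-1, keepdims=True)     # stabilize the softmax
    p = np.exp(logits - m)
    p /= p.sum(axis=-1, keepdims=True)
    return float((p < eps).mean())
```

On a sharply peaked distribution nearly all entries fall below the threshold, which is why skipping them recovers the throughput lost to the online reduction.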