🤖 AI Summary
To address the fundamental mismatch between teacher-forced training and autoregressive inference, along with the state-management complexity and error-prone bookkeeping it creates in sequence modeling, this paper introduces a unified, streamable neural network layer API. The core innovation is an explicit state interface coupled with a `step()` method, enabling numerically identical behavior under both parallel (training-time) and incremental (inference-time) execution within the same layer. By abstracting temporal state (including KV caches, convolutional buffers, and RNN hidden states), the API supports declarative, composable layer design in JAX and TensorFlow 2. An open-source library provides a comprehensive suite of streamable primitive layers and composition utilities. This framework significantly lowers the barrier to developing and deploying streaming models while ensuring production-grade correctness, efficiency, and maintainability.
📝 Abstract
We introduce a neural network layer API and library for sequence modeling, designed for easy creation of sequence models that can be executed both layer-by-layer (e.g., teacher-forced training) and step-by-step (e.g., autoregressive sampling). To achieve this, layers define an explicit representation of their state over time (e.g., a Transformer KV cache, a convolution buffer, an RNN hidden state) and a step method that evolves that state, tested to give results identical to a stateless layer-wise invocation. These and other aspects of the SequenceLayers contract enable complex models to be immediately streamable, mitigate a wide range of common bugs arising in both streaming and parallel sequence processing, and can be implemented in any deep learning library. A composable and declarative API, along with a comprehensive suite of layers and combinators, streamlines the construction of production-scale models from simple streamable components while preserving strong correctness guarantees. Our current implementations of SequenceLayers (JAX, TensorFlow 2) are available at https://github.com/google/sequence-layers.
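To make the layer/step contract concrete, here is a minimal sketch of the idea in plain Python with NumPy. The class and method names (`layer`, `step`, `initial_state`) are illustrative assumptions, not the real SequenceLayers API: a toy causal 1-D convolution exposes a parallel path over the whole sequence and an incremental path that consumes one timestep plus an explicit buffer state, and the two are asserted to agree.

```python
import numpy as np

class CausalConv:
    """Toy streamable layer: a causal 1-D convolution.

    Hypothetical sketch of the contract described in the abstract
    (names are illustrative, not the actual SequenceLayers API).
    """

    def __init__(self, kernel):
        self.kernel = np.asarray(kernel, dtype=np.float64)  # shape [k]

    def initial_state(self):
        # Buffer of k-1 zeros, mirroring the left zero-padding in layer().
        return np.zeros(len(self.kernel) - 1)

    def layer(self, x):
        # Parallel (training-time) execution over the full sequence.
        padded = np.concatenate([self.initial_state(), x])
        k = len(self.kernel)
        return np.array([padded[t:t + k] @ self.kernel for t in range(len(x))])

    def step(self, x_t, state):
        # Incremental (inference-time) execution: one timestep plus state.
        window = np.concatenate([state, [x_t]])
        y_t = window @ self.kernel
        return y_t, window[1:]  # new state drops the oldest input

conv = CausalConv(kernel=[0.5, 0.3, 0.2])
x = np.array([1.0, 2.0, 3.0, 4.0])

# Layer-wise invocation.
y_parallel = conv.layer(x)

# Step-wise invocation, threading the explicit state.
state = conv.initial_state()
y_stream = []
for x_t in x:
    y_t, state = conv.step(x_t, state)
    y_stream.append(y_t)

# The contract: both execution modes give identical results.
assert np.allclose(y_parallel, y_stream)
```

The key design point mirrored here is that state is an explicit value threaded through `step`, not hidden mutable internals, which is what lets the same layer be tested for parallel/incremental equivalence and composed into larger streamable models.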