Parallax Boosts Attention: Learned Covariance Fixes Softmax Lag

The Transformer’s attention mechanism has changed little since 2017, and most efficiency work tries to replace softmax altogether, which often leads to complex rewrites and limited gains. A new approach, called Parallax, takes a different path: it keeps the familiar softmax core and adds a lightweight correction branch that models key-value covariance. Because it learns a projection matrix instead of solving a per-query linear system, Parallax removes the need for a heavy conjugate-gradient solver and, with it, the high I/O, precision issues, and tricky regularization trade-off that such a solver brings. The result is a mechanism that deliberately adds compute but reuses the same memory stream as FlashAttention, doubling arithmetic intensity and making the operation more compute-bound, which is exactly where modern GPU kernels excel.
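To make the arithmetic-intensity claim concrete, here is a back-of-the-envelope sketch in Python. It is not taken from the Parallax paper or kernel. The cost model is an assumption: one decode query, K and V each read once, and only the two big matrix products counted. The 2x FLOP multiplier simply encodes the "doubling" stated above, and the function names are hypothetical.

```python
def decode_attention_cost(num_keys, head_dim, bytes_per_elem=2, flop_multiplier=1.0):
    """Rough FLOPs and bytes for one query attending over a cached K/V stream.

    Assumed cost model (illustrative only): ignores the query read, the
    output write, and the softmax itself.
    """
    # Bytes: K and V are each streamed once (num_keys x head_dim elements each).
    bytes_moved = 2 * num_keys * head_dim * bytes_per_elem
    # FLOPs: q.K^T and p.V each cost about 2 * num_keys * head_dim.
    flops = flop_multiplier * 4 * num_keys * head_dim
    return flops, bytes_moved


def arithmetic_intensity(flops, bytes_moved):
    """FLOPs performed per byte moved from memory."""
    return flops / bytes_moved


# Baseline softmax attention, 16-bit K/V cache.
base_flops, base_bytes = decode_attention_cost(num_keys=8192, head_dim=128)

# Assumption: the correction branch doubles the FLOPs on the same K/V tiles.
corr_flops, corr_bytes = decode_attention_cost(
    num_keys=8192, head_dim=128, flop_multiplier=2.0
)

base_ai = arithmetic_intensity(base_flops, base_bytes)
corr_ai = arithmetic_intensity(corr_flops, corr_bytes)
print(f"baseline: {base_ai:.2f} FLOP/byte, with correction: {corr_ai:.2f} FLOP/byte")
```

Under these assumptions the bytes moved are identical in both cases, so any extra work done on the shared stream raises FLOPs per byte directly. That is the sense in which the operation shifts toward being compute-bound.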

In experiments on synthetic tasks and LLM pretraining at 0.6B and 1.7B scales, Parallax trained with the Muon optimizer achieved lower perplexity and higher downstream accuracy than standard Transformers, Mamba, Gated DeltaNet, MesaNet, and Kimi Delta Attention. Ablations show that the gain comes from the mechanism itself, not just extra parameters or compute. However, the advantage is tightly linked to Muon: under AdamW the correction branch is largely suppressed and the benefit shrinks. Parallax also produces attention scores that can go negative or exceed one, which lets it subtract irrelevant content and reduces the attention sink on the first token.
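The toy Python sketch below shows how a covariance-style correction on top of softmax can yield weights below zero or above one. The formula is a hypothetical form chosen for illustration, not the confirmed Parallax equation: each softmax weight p_i is scaled by 1 + (k_i - k_bar) . rho, where k_bar is the softmax-weighted mean key and rho = R q comes from a projection matrix R. The matrix R, the helper names, and all numbers are assumptions.

```python
import math


def softmax(scores):
    """Numerically stable softmax over a list of floats."""
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]


def matvec(matrix, vec):
    """Plain-Python matrix-vector product."""
    return [sum(a * b for a, b in zip(row, vec)) for row in matrix]


def corrected_weights(q, keys, R):
    """Return (softmax weights, corrected weights) for one query.

    Hypothetical correction: w_i = p_i * (1 + (k_i - k_bar) . rho), rho = R q.
    """
    d = len(q)
    scores = [sum(q[j] * k[j] for j in range(d)) for k in keys]
    p = softmax(scores)
    # Softmax-weighted mean key; centering on it keeps the weights summing to 1.
    k_bar = [sum(p_i * k[j] for p_i, k in zip(p, keys)) for j in range(d)]
    rho = matvec(R, q)
    w = []
    for p_i, k in zip(p, keys):
        centered_dot = sum((k[j] - k_bar[j]) * rho[j] for j in range(d))
        w.append(p_i * (1.0 + centered_dot))
    return p, w


q = [1.0, 0.0]
keys = [[1.0, 0.0], [0.0, 1.0], [-1.0, 0.0]]
R = [[3.0, 0.0], [0.0, 0.0]]  # made-up stand-in for a learned projection

p, w = corrected_weights(q, keys, R)
print("softmax weights:  ", [round(x, 3) for x in p])
print("corrected weights:", [round(x, 3) for x in w])
```

In this sketch the corrected weights still sum to one, because the centered keys average to zero under the softmax weights. The best-matching key ends up above one and the poorly matching keys go negative, so their values are subtracted from the output. Setting R to all zeros recovers plain softmax exactly.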

Strengths:
– Keeps softmax intact, so a pretrained model can be adapted by adding the projection matrix and fine‑tuning.
– Adds no extra I/O per iteration by sharing the FlashAttention key‑value stream.
– Increases arithmetic intensity, with a prototype decode kernel matching or beating FlashAttention 2/3 on Hopper GPUs.
– Shows consistent perplexity and downstream wins under parameter‑matched and compute‑matched controls.

Weaknesses/open questions:
– Performance depends heavily on Muon; the edge fades with AdamW.
– The exact reason for this optimizer interaction remains unclear.
– Results stop at 1.7B scale, without MoE, longer contexts, or larger runs.
– The benefit weakens during phases in which weight decay is itself decayed.

#AI #ML #Attention #LLM #Optimizer #GPU