LiteAttention Calibration Tool: How to Manage Your Error Budget for Maximum Performance
TL;DR
- Per-layer, per-timestep threshold tuning: Easily set different thresholds for each LiteAttention instance and timestep to extract the maximum performance from a model.
- Simple to use: The tool makes threshold optimization straightforward without complex manual tuning.
- Learn more: LiteAttention GitHub - https://github.com/moonmath-ai/LiteAttention
LiteAttention skips attention tiles whose contribution falls below a threshold in log2 scale. A natural first question is: what threshold should you use?
The simplest approach, a single threshold for the entire model, already delivers meaningful speedups. In our early experiments, setting a fixed threshold across all layers and timesteps yielded sparsity levels of 50% or more with acceptable output quality. But this leaves performance on the table.
The reason is that not all attention is created equal. In a diffusion transformer like LTX-2, there are 48 attention layers, each processing a different level of abstraction. Some layers attend broadly (distributing attention weight across many tiles), while others focus sharply on a few key regions. A threshold that is conservative enough for the most sensitive layer will be too conservative for the rest.
The same is true across diffusion timesteps. Early timesteps operate on nearly pure noise: the attention pattern is broad and unstable. As denoising progresses, attention sharpens: the model increasingly “knows” where to look, and tiles that were initially relevant become redundant. Empirically, we observed that sparsity only increases across timesteps: a tile discarded at timestep t is almost always discarded at all subsequent timesteps. This means that later timesteps can tolerate much more aggressive thresholds than earlier ones.[1] (This observation is consistent with findings from other works. Papers such as Sparse-VideoGen and Radial Attention have independently documented that attention sparsity in video diffusion models varies significantly across layers and timesteps, and that exploiting this variance is key to efficient inference.)
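This monotone behavior can be illustrated with a toy model (pure Python, not the LiteAttention kernel). The decay model below is an assumption made for illustration: if each tile's log2 contribution only decreases as denoising sharpens attention, the skip set at timestep t is contained in the skip set at every later timestep.

```python
# Toy illustration of monotone sparsity across timesteps (not the real kernel):
# if each tile's log2 contribution decays as attention sharpens, a tile that
# falls below the threshold stays below it, so the skip set only grows.
import random

random.seed(0)
threshold = -8.0                      # log2-scale skip threshold
n_tiles, n_steps = 64, 10

# Assumed decay model: every tile's log2 contribution drops a bit each step.
contrib = [random.uniform(-12.0, 0.0) for _ in range(n_tiles)]
skip_sets = []
for t in range(n_steps):
    skip_sets.append({i for i, c in enumerate(contrib) if c < threshold})
    contrib = [c - random.uniform(0.0, 1.0) for c in contrib]

# Once discarded, a tile stays discarded: each skip set contains the previous.
assert all(a <= b for a, b in zip(skip_sets, skip_sets[1:]))
```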
So the question becomes: how do we assign a per-layer, per-timestep threshold to the entire model, without requiring the user to manually tune hundreds of values?
The Calibration Idea
The goal of calibration is to reduce this problem to a single scalar under the user's control, the target error, while letting the system find the right threshold for every (layer, timestep) pair automatically.
The insight behind calibration is local: for each LiteAttention instance at each timestep, we can measure how much a candidate threshold changes the attention output, and search for the threshold that keeps this change within a budget. This measurement is cheap because it happens inside a single attention call; there is no need to propagate through the rest of the model.
This is a heuristic: it assumes that controlling local error at each layer is a reasonable proxy for controlling global output quality. In practice this works well: the calibrated thresholds produce outputs that are close to full-attention results, both numerically and visually.
The Algorithm
Calibration runs as a special mode during a single denoising pass. Instead of using a fixed threshold, each LiteAttention module performs a binary search over the threshold space [−20.0, 0.0] (in log2 scale) to find the value that produces the desired target error.
Each step of the binary search works as follows:
- Run attention with the candidate threshold, reading from the current skip list (inherited from the previous timestep). This produces output_A and writes a new skip list reflecting the candidate threshold.
- Run attention again with the same threshold, but now reading from the skip list that was just generated. This produces output_B.
- Measure the error between output_A and output_B.[2] (Three error metrics are supported: relative L1 (default), cosine dissimilarity, and RMSE. The default target error is 0.01, i.e. 1% relative L1 deviation.)
The error captures how much the skip list would change between consecutive timesteps at this threshold. If the error is small, the threshold is stable: the skip pattern it produces is self-consistent. If the error is large, the threshold is too aggressive: the skip list is fluctuating between what should be computed and what should be skipped.
The binary search converges when the measured error is within 10% of the target, or after 30 iterations. In practice, convergence happens in ~10 iterations thanks to the monotone relationship between threshold and error.[3] (The output produced at the converged threshold is used as the attention output for the rest of the transformer forward pass, so that subsequent values are computed on the correct input.)
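The search loop can be sketched in a few lines. Everything below is illustrative: measured_error stands in for the paired attention passes (the output_A vs. output_B comparison), and only the search logic mirrors the description above; the actual implementation lives inside each LiteAttention module.

```python
def calibrate_threshold(measured_error, target_error=0.01,
                        lo=-20.0, hi=0.0, max_iters=30, rel_tol=0.10):
    """Binary search over the log2 threshold space [-20, 0] for a value
    whose measured error lands within rel_tol (10%) of target_error.

    measured_error(threshold) stands in for the two attention passes
    described above; it must be monotone increasing in the threshold
    (more aggressive skipping -> more error).
    """
    threshold = lo
    for _ in range(max_iters):
        threshold = (lo + hi) / 2.0
        err = measured_error(threshold)
        if abs(err - target_error) <= rel_tol * target_error:
            break                # converged: error within 10% of target
        if err > target_error:
            hi = threshold       # too aggressive: back off toward -20
        else:
            lo = threshold       # too conservative: push toward 0
    return threshold

# Synthetic monotone error curve standing in for real measurements:
# ~0.001 at threshold -20, rising to 1.0 at threshold 0.
toy_error = lambda t: 2.0 ** (t / 2.0)
t = calibrate_threshold(toy_error)
assert abs(toy_error(t) - 0.01) <= 0.001 + 1e-9
```

With a monotone error curve the interval halves every iteration, which is why real runs settle in roughly 10 steps rather than hitting the 30-iteration cap.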
After calibration completes, each module has recorded one threshold per timestep. These are saved to a TOML file and can be loaded for all subsequent inference runs; no recalibration needed.
Early Timesteps: The disabled_steps Heuristic
One practical finding from calibration is that the first few denoising timesteps are special. The input is nearly pure noise, attention patterns are chaotic, and there is no meaningful skip list to build on. Running calibration on these steps produces unstable thresholds that don't generalize.
The solution is simple: skip the first N timesteps entirely. The disabled_steps parameter prepends N steps where LiteAttention runs as regular FlashAttention, no skipping, no skip list updates. The calibrated (or constant) threshold kicks in only after the skip list has had a chance to warm up from stable attention patterns. This is configured as a single integer and applies uniformly to all layers.
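The resulting per-timestep schedule can be pictured with a small sketch. build_threshold_schedule is a hypothetical helper, not part of the library; it mirrors the TOML layout in which 0.0 entries mark the disabled warm-up steps.

```python
def build_threshold_schedule(calibrated, disabled_steps):
    """Hypothetical helper: prepend `disabled_steps` no-skip entries
    (recorded as 0.0 in the TOML, where a disabled config runs full
    attention) before the calibrated per-timestep thresholds."""
    return [0.0] * disabled_steps + list(calibrated)

# Values illustrative, matching the shape of the config file below.
schedule = build_threshold_schedule([-4.2, -3.8, -5.1, -6.3], disabled_steps=3)
assert schedule == [0.0, 0.0, 0.0, -4.2, -3.8, -5.1, -6.3]
```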
The Registry: Managing Thresholds at Scale
The per-layer, per-timestep threshold configuration is managed by the LiteAttentionRegistry, a framework that discovers all LiteAttention modules in a model and coordinates their configuration.
Discovery
The registry walks model.named_modules() and collects every LiteAttention instance by its dotted path (e.g., transformer_blocks.0.attn1.lite_attention, transformer_blocks.47.attn1.lite_attention). In LTX-2, this discovers all 48 video self-attention layers. Other attentions (text cross-attention, audio, cross-modal) are not LiteAttention instances and are ignored.
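The discovery step amounts to filtering named_modules() by type. The sketch below uses a minimal stand-in module tree rather than torch, so it is self-contained; the real registry applies the same filter to an actual PyTorch model.

```python
# Minimal stand-in for torch.nn.Module / named_modules(), so the discovery
# logic can be shown without a torch dependency.
class Module:
    def named_modules(self, prefix=""):
        yield prefix, self
        for name, child in vars(self).items():
            if isinstance(child, list):          # crude ModuleList stand-in
                pairs = [(f"{name}.{i}", m) for i, m in enumerate(child)]
            elif isinstance(child, Module):
                pairs = [(name, child)]
            else:
                continue
            for child_name, mod in pairs:
                dotted = f"{prefix}.{child_name}" if prefix else child_name
                yield from mod.named_modules(dotted)

class LiteAttention(Module):                     # stand-in for the real module
    pass

class Attention(Module):
    def __init__(self):
        self.lite_attention = LiteAttention()

class Block(Module):
    def __init__(self):
        self.attn1 = Attention()

class Model(Module):
    def __init__(self):
        self.transformer_blocks = [Block() for _ in range(48)]

def discover(model):
    # Collect every LiteAttention instance by its dotted path.
    return {name: mod for name, mod in model.named_modules()
            if isinstance(mod, LiteAttention)}

found = discover(Model())
assert len(found) == 48
assert "transformer_blocks.0.attn1.lite_attention" in found
```

Because discovery is type-based, the other attention variants (text cross-attention, audio, cross-modal) fall through the isinstance filter automatically.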
Four Modes
The registry supports four modes, selectable via a single parameter:
| Mode | Behavior |
|---|---|
| const | Same fixed threshold for all modules and timesteps |
| calib | Run calibration; save per-module, per-timestep thresholds to TOML |
| load | Load previously calibrated thresholds from TOML |
| disable | Run LiteAttention kernel with no skipping (baseline) |
The typical workflow is: run once with mode="calib", then switch to mode="load" for all subsequent runs. For advanced use, the registry also accepts a fully flexible config, mixing calibrated and constant thresholds and controlling each module and timestep separately.
Usage
From a consuming pipeline, the entire setup is a few lines:
```python
# Calibration run (once)
registry = LiteAttentionRegistry.from_model(
    model,
    mode="calib",
    filename="my_model_calibrated.toml",
    calib_config={"target_error": 0.01, "metric": "L1"},
    disabled_steps=3,
)

# ... run the denoising loop ...

registry.save()  # writes per-module, per-timestep thresholds
```

```python
# Production runs (every time after)
registry = LiteAttentionRegistry.from_model(
    model,
    mode="load",
    filename="my_model_calibrated.toml",
    disabled_steps=3,
)

# ... run the denoising loop - thresholds are applied automatically
```
The Config File
The calibration output is a human-readable TOML file. Each section corresponds to a module, and the threshold field is a list with one value per timestep:
```toml
["transformer_blocks.0.attn1.lite_attention"]
_type = "LiteAttentionRunConfig"
threshold = [0.0, 0.0, 0.0, -4.2, -3.8, -5.1, -6.3, ...]

["transformer_blocks.1.attn1.lite_attention"]
_type = "LiteAttentionRunConfig"
threshold = [0.0, 0.0, 0.0, -7.1, -5.5, -4.8, -3.2, ...]
```
The leading 0.0 values correspond to disabled_steps (where LiteAttentionDisabledConfig runs as full attention). The remaining values are the calibrated thresholds; note how they vary across layers and timesteps. Users can also edit this file directly.
Integration with LTX-2[4]
See our integration example repo: https://github.com/moonmath-ai/LTX-2-LiteAttention
LTX-2 is a two-stage video diffusion pipeline: a low-resolution stage followed by a spatial upsampling stage. Both stages share the same 48-layer transformer, but the attention patterns differ between stages (different resolutions, different noise levels).
The pipeline creates separate registries for each stage, saving to my_model.stage_1.toml and my_model.stage_2.toml. Calibration therefore produces independent thresholds for each stage; the low-res stage might tolerate more aggressive skipping than the high-res one.
The integration required minimal code changes. LiteAttention modules are instantiated inside the Attention class (one per video self-attention layer), and the registry hooks into them non-invasively via PyTorch's named_modules(). The pipeline creates the registry before the denoising loop and calls save() after it: roughly 10 lines of glue code.
Calibration Cost
Calibration adds overhead to a single inference run. Each binary search iteration runs the attention kernel twice, and with up to 30 iterations per module per timestep, this adds up. In practice, calibration takes roughly 20% to 5x longer than a normal inference run, depending on the model size and number of timesteps.
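A rough tally of the extra kernel calls makes the overhead concrete. The timestep count here is an assumed example (it depends on the sampler); the layer count, iteration cap, and two-calls-per-iteration figure come from the description above.

```python
layers = 48          # LiteAttention modules in LTX-2
timesteps = 40       # assumed example; depends on the sampler
max_iters = 30       # binary search cap
calls_per_iter = 2   # output_A and output_B

worst_case_extra_calls = layers * timesteps * max_iters * calls_per_iter
typical_extra_calls = layers * timesteps * 10 * calls_per_iter  # ~10 iters
assert worst_case_extra_calls == 115_200
assert typical_extra_calls == 38_400
```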
But this cost is paid once. The calibrated thresholds are saved to a TOML file that is loaded in all subsequent runs with zero overhead beyond a file read at startup.
Debugging Infrastructure
The registry also serves as an entry point for research and analysis. The enable_capture() method instruments all LiteAttention modules to record diagnostic data during inference: the fraction of tiles that were computed, attention maps, skip lists, and running statistics, on a user-selected subset of modules, timesteps, and heads.
The captured data can be visualized offline in various ways. This infrastructure is useful for understanding attention behavior in general: discovering which layers are sparse, how sparsity evolves across timesteps, and whether the skip list is making sensible decisions.
Looking Forward
Calibration as described here is a local heuristic: it optimizes each layer independently. A natural next step is to account for error propagation: a small error in an early layer might amplify through subsequent layers, while a larger error in a late layer might be harmless.
The registry framework is designed with this kind of analysis in mind. By capturing attention patterns and output statistics across the entire model, we can study how errors propagate and where the model is most sensitive. This opens the door to more sophisticated calibration strategies, potentially assigning not just thresholds but entirely different optimization techniques (quantization, decomposition, estimation) at different points in the model, based on measured sensitivity.
Work such as Radial Attention has begun exploring this direction, recognizing that different model locations may benefit from fundamentally different optimization approaches. The registry infrastructure we've built provides the observability needed to pursue this in practice.
Check out the LiteAttention repository: https://github.com/moonmath-ai/LiteAttention