LiteFFN: How We Made LTX-2 19B Run Faster with Low-Rank Magic

LiteFFN post header image

TL;DR

We implemented an efficient technique to accelerate the feed-forward networks (FFN) in LTX-2, Lightricks' 19-billion-parameter video generation model. The result: an 11.5% peak memory reduction and 22.5% faster transformer computation, which translates into 7.6% faster end-to-end inference with minimal quality loss.

The key insight: as attention mechanisms get faster (thanks to Flash Attention), FFN layers become the new bottleneck, and we can compress them using low-rank decomposition plus quantization.

This work builds on the theoretical foundation of SVDQuant, with practical implementation for H100 GPUs and real-world video generation workloads. LiteFFN repository: https://github.com/moonmath-ai/LiteFFN.

The Problem: Attention Got Fast, FFN Didn't

If you've been following the AI video generation space, you've probably heard of models like LTX-2 and OpenAI's Sora. These models can generate stunning video from text descriptions, but they are computationally expensive.

LTX-2 is a 19-billion parameter Diffusion Transformer that can generate native 4K video without upscaling, up to 50 FPS, synchronized audio in the same forward pass, and up to 20 seconds of high-fidelity video.

That power comes at a cost: even on an H100 GPU, generation can take [X] seconds and requires [X] GB of memory (or 12 GB minimum with FP8 quantization). This is manageable for research, but a major blocker for scaled production deployment.

Where Does the Time Go?

Transformer-based models have two main computational blocks:

  • Attention layers: capture relationships between different parts of the input
  • Feed-forward networks (FFN): process each position independently

The community has heavily optimized attention (Flash Attention 1-3, Sage Attention, LiteAttention), but FFN layers have been far less optimized.

As attention gets faster, FFN becomes the bottleneck. On LTX-2 inference, FFN is nearly 50% of transformer runtime, and this is the gap we targeted.

The Solution: Compress the FFN Smartly

Our approach combines two ideas:

  • Low-rank decomposition: approximate large weight matrices with smaller ones
  • Quantization: store the remaining error in fewer bits

The Intuition

In LTX-2, FFN matrices can be as large as 8192x2048 with tens of millions of parameters. Not all parameters carry equal signal, and many capture overlapping structure.

Low-rank decomposition represents one large matrix as the product of two smaller matrices, preserving most useful behavior at much lower compute and memory cost. Quantization then captures the leftover correction terms efficiently.[1]
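To make the savings concrete, here is a small sketch (our own illustration, not code from the LiteFFN repo) of the parameter count for a rank-r factorization at the 8192x2048 shape mentioned above, plus a toy truncated SVD:

```python
import numpy as np

# Parameter count: a rank-r factorization of an [out, in] matrix replaces
# out*in weights with r*(out + in) weights across the two factors.
out_dim, in_dim, r = 8192, 2048, 64
full_params = out_dim * in_dim            # 16,777,216
lowrank_params = r * (out_dim + in_dim)   # 655,360 -> ~25x fewer

# Behavior: for a matrix that truly is low-rank, truncated SVD recovers it
# exactly. Real weight matrices are only approximately low-rank, which is
# why the leftover residual still needs to be captured (via quantization).
rng = np.random.default_rng(0)
W = rng.standard_normal((512, 128)) @ rng.standard_normal((128, 256))  # rank <= 128
U, S, Vt = np.linalg.svd(W, full_matrices=False)
W_r = (U[:, :128] * S[:128]) @ Vt[:128, :]
assert np.allclose(W_r, W, atol=1e-6)
```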

How It Works

Step 1: Collect Calibration Data

We run the model on a sample of real prompts and collect per-layer input statistics to characterize typical activation structure.

For each FFN layer, we compute:

R = E[x * x^T]

For 2048-dimensional inputs, each autocorrelation matrix has about 4.2 million entries (2048^2). Across many layers, memory grows quickly, so we compute R incrementally:

# Incremental accumulation of R = E[x x^T] over calibration batches
import numpy as np

dim = 2048
R_sum = np.zeros((dim, dim))
N = 0

for batch in calibration_data:
    x = get_activations(batch)   # shape: [batch, seq_len, dim]
    x_flat = x.reshape(-1, dim)  # flatten to [n_samples, dim]
    R_sum += x_flat.T @ x_flat   # accumulate outer products
    N += x_flat.shape[0]         # count tokens seen so far

R = R_sum / N                    # average -> E[x x^T]

Step 2: Compute Effective Weights

We use calibration statistics to transform the original weight matrix into a form that is easier to compress:

W_effective = W @ R^(1/2)

This reweights columns by how heavily real activations exercise them, so the subsequent SVD spends its rank budget on the directions that matter for actual workloads.
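Since R is an average of outer products x x^T, it is symmetric positive semi-definite, so its square root can be computed via an eigendecomposition. A minimal sketch (our naming, not the repo's API):

```python
import numpy as np

def matrix_sqrt(R, eps=1e-8):
    """Symmetric PSD square root via eigendecomposition.

    R is symmetric PSD by construction (an average of outer products),
    so eigh applies and the eigenvalues are (numerically near) non-negative.
    """
    eigvals, eigvecs = np.linalg.eigh(R)
    eigvals = np.clip(eigvals, eps, None)   # guard tiny negatives from round-off
    return (eigvecs * np.sqrt(eigvals)) @ eigvecs.T

# Quick check on a synthetic autocorrelation matrix
rng = np.random.default_rng(0)
x = rng.standard_normal((1000, 16))
R = x.T @ x / len(x)
R_half = matrix_sqrt(R)
assert np.allclose(R_half @ R_half, R, atol=1e-6)
```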

Step 3: Decompose

W_effective = U @ S @ V^T

We keep the top-r singular components:

W_low_rank_eff = U[:, :r] @ diag(S[:r]) @ V^T[:r, :]

Mapping back through the inverse square root gives the low-rank approximation of W itself, in the original input space:

W_low_rank = W_low_rank_eff @ R^(-1/2)

Step 4: Handle the Remainder

Remainder = W - W_low_rank

This residual is quantized to 4-bit or 8-bit formats.
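Steps 2 through 4 can be sketched end-to-end in NumPy (function and variable names are ours, not the repo's; the remainder is left unquantized here for clarity):

```python
import numpy as np

def decompose(W, R, r, eps=1e-8):
    """Split W into a rank-r part (informed by input statistics R)
    plus a dense remainder. Illustrative sketch, not the repo's API."""
    # Whitening transform from the calibration autocorrelation
    eigvals, eigvecs = np.linalg.eigh(R)
    eigvals = np.clip(eigvals, eps, None)
    R_half = (eigvecs * np.sqrt(eigvals)) @ eigvecs.T
    R_inv_half = (eigvecs / np.sqrt(eigvals)) @ eigvecs.T

    # SVD of the effective (whitened) weights; keep the top-r components
    U, S, Vt = np.linalg.svd(W @ R_half, full_matrices=False)
    A = U[:, :r] * S[:r]          # [out, r], singular values folded in
    B = Vt[:r, :] @ R_inv_half    # [r, in], back in the original input space

    remainder = W - A @ B         # what quantization must capture
    return A, B, remainder

rng = np.random.default_rng(0)
W = rng.standard_normal((64, 32))
x = rng.standard_normal((1000, 32))
R = x.T @ x / len(x)
A, B, rem = decompose(W, R, r=8)
# Low-rank path plus remainder reconstructs W exactly (before quantization)
assert np.allclose(A @ B + rem, W, atol=1e-6)
```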

Step 5: Inference

def forward(x):
    # Low-rank path: two small full-precision matmuls
    # (B projects dim -> r, A projects r -> dim_out)
    y_lowrank = A @ (B @ x)

    # Remainder path: one matmul with the quantized residual weights
    y_remainder = Q_quantized @ x

    return y_lowrank + y_remainder + bias
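With the remainder kept in full precision, the two-path forward reproduces the dense layer exactly; quantizing the remainder is what introduces the (small) approximation error. A self-contained sketch, using a plain SVD split for brevity:

```python
import numpy as np

rng = np.random.default_rng(0)
dim_in, dim_out, r = 32, 64, 8

W = rng.standard_normal((dim_out, dim_in))
bias = rng.standard_normal(dim_out)

# Plain (calibration-free) SVD split, for illustration only
U, S, Vt = np.linalg.svd(W, full_matrices=False)
A = U[:, :r] * S[:r]   # [dim_out, r]
B = Vt[:r, :]          # [r, dim_in]
Q = W - A @ B          # remainder (would be quantized in LiteFFN)

def forward(x):
    # Low-rank path + remainder path, as in the pseudocode above
    return A @ (B @ x) + Q @ x + bias

x = rng.standard_normal(dim_in)
dense = W @ x + bias
assert np.allclose(forward(x), dense, atol=1e-6)
```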

Quantization Options

Current production format on Hopper:

FP8 E4M3

  • Bits: 8 (4 exponent, 3 mantissa)
  • Range: approximately +/-448
  • Pros: native H100 tensor core support, strong accuracy
  • Cons: 2x memory versus FP4
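The +/-448 range can be checked by enumerating the encoding directly. A quick sketch of the E4M3 "FN" (finite-only) variant used by H100 tensor cores, which trades infinities for extra range:

```python
# Enumerate finite E4M3FN magnitudes: 4 exponent bits (bias 7), 3 mantissa bits.
# In the FN encoding, exponent=0b1111 with mantissa=0b111 is NaN, so the
# largest finite value is 1.75 * 2^8 = 448.
values = set()
for e in range(16):                         # exponent field
    for m in range(8):                      # mantissa field
        if e == 15 and m == 7:
            continue                        # NaN encoding
        if e == 0:
            v = (m / 8) * 2 ** (1 - 7)      # subnormals
        else:
            v = (1 + m / 8) * 2 ** (e - 7)  # normals
        values.add(v)

print(max(values))  # 448.0
```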

Planned formats for LiteFFN on Blackwell and future Hopper support include NVFP4 E2M1 and MXFP4.

The Fun Part: Multiply-Free FP4

On hardware without native FP4 support, NVFP4's limited value set lets us replace floating-point multiplies with shift/add operations.

| FP4 Value | Operation | Instructions |
|-----------|-----------------------|-----------------|
| 0 | return 0 | none |
| 0.5 | x >> 1 | 1 shift |
| 1 | x | identity |
| 1.5 | x + (x >> 1) | 1 shift, 1 add |
| 2 | x << 1 | 1 shift |
| 3 | (x << 1) + x | 1 shift, 1 add |
| 4 | x << 2 | 1 shift |
| 6 | (x << 2) + (x << 1) | 2 shifts, 1 add |
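The table above translates directly into code. A sketch of the dispatch (integer inputs, e.g. fixed-point activations; the >>1 cases truncate the low bit on odd inputs):

```python
# Shift/add emulation of the NVFP4 magnitude set {0, 0.5, 1, 1.5, 2, 3, 4, 6}.
def mul_fp4(x: int, v: float) -> int:
    ops = {
        0:   lambda x: 0,
        0.5: lambda x: x >> 1,
        1:   lambda x: x,
        1.5: lambda x: x + (x >> 1),
        2:   lambda x: x << 1,
        3:   lambda x: (x << 1) + x,
        4:   lambda x: x << 2,
        6:   lambda x: (x << 2) + (x << 1),
    }
    return ops[v](x)

# For even inputs every product is exact:
assert all(mul_fp4(12, v) == int(12 * v) for v in (0, 0.5, 1, 1.5, 2, 3, 4, 6))
```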

Results

Average Runtime

| Group | Transformer Mean (s) | Min (s) | Max (s) | Std (s) | Transformer % Faster | Decode Mean (s) | Save (s) | E2E Total (s) | E2E % Faster |
|----------|-------|-------|-------|-------|--------|-------|-------|--------|-------|
| baseline | 4.520 | 4.460 | 4.650 | 0.070 | 0.00% | 3.710 | 5.100 | 13.330 | 0.00% |
| liteffn | 3.500 | 3.490 | 3.520 | 0.010 | 22.57% | 3.710 | 5.100 | 12.310 | 7.65% |

Allocated VRAM Reduction

| Configuration | Average (MB) | Peak (MB) | Peak Relative to Baseline |
|-----------------|-----------|-----------|------|
| Original (FP16) | 59,919.00 | 65,663.00 | 100% |
| r=64 | 44,517.00 | 58,433.00 | 89.0% |
LiteFFN quality comparison figure

Video Samples (baseline vs LiteFFN r=32 / r=64 / r=512)

Click a thumbnail to play the video.

| Baseline | LiteFFN r=32 | LiteFFN r=64 | LiteFFN r=512 |
|----------|--------------|--------------|---------------|
| Baseline – water droplet | r=32 – water droplet | r=64 – water droplet | r=512 – water droplet |
| Baseline – jetpack | r=32 – jetpack | r=64 – jetpack | r=512 – jetpack |
| Baseline – cats boxing | r=32 – cats boxing | r=64 – cats boxing | r=512 – cats boxing |
| Baseline – Rhine river | r=32 – Rhine river | r=64 – Rhine river | r=512 – Rhine river |
| Baseline – underwater | r=32 – underwater | r=64 – underwater | r=512 – underwater |

Quality Measurements

PSNR averaged over different seeds, per prompt (dB)[2]. Higher is better; 20 dB corresponds to a mean squared error of roughly 10^-2 on a unit-range signal:

| Prompt | baseline vs r=32 | baseline vs r=64 | baseline vs r=512 |
|--------|--------|--------|--------|
| a-dramatic-underwater-scene-featuring-a-person-s | 19.823 | 19.822 | 20.413 |
| a-man-in-a-sleek-modern-jetpack-flying-upwards-t | 20.872 | 21.517 | 21.163 |
| a-serene-view-of-the-banks-of-the-rhine-river-sh | 20.277 | 20.208 | 19.735 |
| a-single-water-droplet-falls-from-a-height-movin | 27.827 | 27.375 | 30.680 |
| two-anthropomorphic-cats-boxing-in-a-well-lit-ar | 18.646 | 20.924 | 21.188 |
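The 20 dB to 10^-2 MSE correspondence noted above follows directly from the PSNR definition, assuming a unit peak signal:

```python
import math

def psnr(mse: float, peak: float = 1.0) -> float:
    """Peak signal-to-noise ratio in dB for a given mean squared error."""
    return 10 * math.log10(peak ** 2 / mse)

# Sanity check: MSE of 1e-2 on a unit-range signal is exactly 20 dB.
assert abs(psnr(1e-2) - 20.0) < 1e-9
```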

Prompt PSNR Summary

| Prompt | Total seeds | PSNR > 20 dB | PSNR > 17 dB |
|--------|----|----|----|
| a-dramatic-underwater-scene-featuring-a-person-s | 30 | 10 | 30 |
| a-man-in-a-sleek-modern-jetpack-flying-upwards-t | 30 | 30 | 30 |
| a-serene-view-of-the-banks-of-the-rhine-river-sh | 30 | 20 | 30 |
| a-single-water-droplet-falls-from-a-height-movin | 30 | 30 | 30 |
| two-anthropomorphic-cats-boxing-in-a-well-lit-ar | 30 | 20 | 30 |

Performance Benchmark on Shapes Captured from LTX-Video

Multiplier columns are speedup ratios vs baseline linear (>1 faster, <1 slower).

Units:

  • per-shape rows: us
  • TOTAL row: ms

Column glossary:

  • Cfg: FFN projection shape (w1 = up-proj, w2 = down-proj).
  • M: flattened activation rows for that GEMM shape.
  • Count: number of calls for that shape in the captured workload.
  • Lin: baseline nn.Linear latency.
  • TE: Transformer Engine linear latency.
  • PT: LiteFFN PyTorch path latency.
  • CUDA: LiteFFN CUDA path latency.
  • TE_x / PT_x / CUDA_x: speedup multiplier vs baseline linear (>1 is faster, <1 is slower).
  • TOTAL: count-weighted aggregate across listed shapes.
LiteFFN benchmark plot[3]
| Cfg | M | Count | Lin | TE | PT | CUDA | TE_x | PT_x | CUDA_x |
|-------|-------|------|------|------|------|------|--------|--------|--------|
| w2 | 1400 | 336 | 392 | 260 | 298 | 186 | 1.508x | 1.315x | 2.108x |
| w1 | 1400 | 336 | 378 | 273 | 301 | 182 | 1.385x | 1.256x | 2.077x |
| w2 | 2450 | 336 | 597 | 437 | 501 | 317 | 1.366x | 1.192x | 1.883x |
| w1 | 2450 | 336 | 562 | 455 | 511 | 302 | 1.235x | 1.100x | 1.861x |
| w2 | 5600 | 480 | 1213 | 994 | 1215 | 762 | 1.220x | 0.998x | 1.592x |
| w1 | 5600 | 480 | 1141 | 1097 | 1198 | 712 | 1.040x | 0.952x | 1.603x |
| w2 | 9800 | 144 | 2008 | 1763 | 2074 | 1261 | 1.139x | 0.968x | 1.592x |
| w1 | 9800 | 144 | 1946 | 1868 | 2042 | 1202 | 1.042x | 0.953x | 1.619x |
| w2 | 10850 | 336 | 2277 | 2110 | 2281 | 1395 | 1.079x | 0.998x | 1.632x |
| w1 | 10850 | 336 | 2035 | 2156 | 2215 | 1324 | 0.944x | 0.919x | 1.537x |
| w2 | 22400 | 144 | 4889 | 4506 | 4596 | 3018 | 1.085x | 1.064x | 1.620x |
| w1 | 22400 | 144 | 4714 | 4581 | 4471 | 2805 | 1.029x | 1.054x | 1.681x |
| w2 | 43400 | 144 | 9275 | 8285 | 8845 | 5772 | 1.119x | 1.049x | 1.607x |
| w1 | 43400 | 144 | 8847 | 8460 | 8669 | 5328 | 1.046x | 1.021x | 1.660x |
| TOTAL | - | 3840 | 7788 | 7158 | 7631 | 4744 | 1.088x | 1.021x | 1.642x |

Future Work

  • Attention projection decomposition (Q/K/V/O)
  • Adaptive per-layer rank selection from singular value decay
  • Dynamic rank by denoising timestep
  • More optimized FP4 kernels with warp-level tuning
  • Extension to other video models, world models, and VLMs

Acknowledgments

This work would not exist without SVDQuant. We also thank the Lightricks team for open-sourcing LTX-2 and the Flash Attention authors for pushing efficient attention forward.

References