Scaling & Benchmarks

fastcxt dramatically improves inference speed over cxt through two complementary strategies: replacing autoregressive sampling with a single forward pass, and optionally exploiting tree topology to reduce the number of predictions from O(n²) pairs to O(n) internal nodes.

Scaling comparison

Theoretical scaling calibrated to measured per-pair / per-node runtimes. Panel A: absolute runtime on log-log axes. Panel B: speedup over cxt. Panel C: local scaling exponent (slope of log-log curve).

The three modes

Method

Architecture

Pair scaling

Per-pair cost

Bottleneck

cxt (baseline)

Decoder-only transformer

O(n²)

15 reps × 500 autoregressive steps

Stochastic sampling dominates

fastcxt (pairwise)

Bidirectional Mamba encoder-decoder

O(n²)

1 forward pass (all windows at once)

Quadratic pair count at large n

fastcxt + tsinfer

Mamba encoder-decoder

O(n)

1 forward pass per node + O(n log n) LCA

Tree inference is the new bottleneck

Why cxt is slow

cxt uses a decoder-only transformer that generates TMRCA predictions autoregressively — one token (window) at a time, 500 tokens per sequence. To produce uncertainty estimates it repeats this 15 times with stochastic sampling, then averages. For each sample pair that’s 15 × 500 = 7,500 forward passes. With n(n−1)/2 pairs this becomes intractable beyond a few dozen samples.

Why fastcxt pairwise is ~125× faster

fastcxt replaces autoregressive generation with a single forward pass through a bidirectional Mamba encoder-decoder that produces all 500 window predictions simultaneously, with built-in mean and variance (no sampling needed). The 125× speedup comes from eliminating the 15 × 500 = 7,500× overhead per pair. The quadratic pair count remains — at n=1000 diploids there are still ~2M pairs.

Why fastcxt + tsinfer scales as O(n log n)

When tree topology from tsinfer is available, instead of predicting all n(n−1)/2 pairwise TMRCAs independently, fastcxt predicts the O(n) internal node times in the tree. Any pairwise TMRCA is then a lowest common ancestor (LCA) lookup which takes O(log n) per query. Total cost: O(n) predictions + O(n² log n) lookups, but since the lookups are trivial table operations the practical scaling is near-linear.

Running the benchmarks

1. Benchmark fastcxt (both modes)

The fastcxt-benchmark CLI runs both pairwise and tree-aware modes across a range of sample sizes:

# Both modes, default sample sizes (5, 10, 25, 50, 100)
fastcxt-benchmark --mode all --device cuda:0

# Pairwise only, custom sizes
fastcxt-benchmark --mode fastcxt_notree \
    --sample-sizes 10 25 50 100 200 \
    --batch-size 128 --device cuda:0

# Tree-aware only
fastcxt-benchmark --mode fastcxt_tree \
    --sample-sizes 10 25 50 100 200 500 \
    --device cuda:0

# Save results to JSON
fastcxt-benchmark --mode all \
    --sample-sizes 5 10 25 50 100 200 \
    --output benchmarks/fastcxt_scaling.json

Output:

Model params: 1,234,567

--- n_samples=10 (n_haploids=20) ---
  pairwise: 190 pairs, preproc=0.234s, infer=0.045s, total=0.279s
  tree:     19 nodes (covers 190 pairs), preproc=0.012s, infer=0.003s, total=0.015s

--- n_samples=50 (n_haploids=100) ---
  pairwise: 4950 pairs, preproc=5.123s, infer=0.891s, total=6.014s
  tree:     99 nodes (covers 4950 pairs), preproc=0.045s, infer=0.008s, total=0.053s

Scaling summary:
mode               n_hap  pairs/nodes    total_s
--------------------------------------------------
fastcxt_notree        20          190     0.2790
fastcxt_tree          20           19     0.0150
fastcxt_notree       100         4950     6.0140
fastcxt_tree         100           99     0.0530

2. Benchmark cxt (baseline, optional)

To compare against the original cxt transformer, install the separate cxt package and run its benchmark script with pre-simulated tree sequences and a trained checkpoint. The theoretical cxt scaling curves in the comparison plot are calibrated from measured runtimes (~0.05 s per pair).

3. Generate the scaling plot

The showcase plotting script includes a scaling comparison figure using theoretical scaling curves calibrated to measured runtimes:

python scripts/plot_atlas_showcase.py --outdir figures/

This generates figures/08_scaling_comparison.png along with all other showcase figures. The scaling curves use:

  • cxt: 0.05 s per pair (measured: 15 reps × 500 tokens × GPU overhead)

  • fastcxt pairwise: 0.0004 s per pair (measured: single batched forward pass)

  • fastcxt tsinfer: 0.0004 s per node + O(n log n) LCA overhead

Interpreting the results

Panel A — Runtime vs sample size (log-log)

On log-log axes, the slope reveals the scaling exponent:

  • cxt (grey): slope ≈ 2.0 → O(n²). At n=100 diploids, ~25,000 seconds (~7 hours).

  • fastcxt pairwise (blue): slope ≈ 2.0 → still O(n²), but shifted down by ~125×. At n=100 it takes ~2 minutes.

  • fastcxt tsinfer (green): slope ≈ 1.0 → O(n). At n=1000 it takes ~1 second.

The horizontal reference lines mark 1 minute and 1 hour.

Panel B — Speedup over cxt

Shows the multiplicative speedup of each fastcxt mode relative to cxt.

  • At n=10, fastcxt pairwise is ~19× faster (the per-pair constant dominates).

  • At n=1000, fastcxt+tsinfer is >40,000× faster, because it predicts O(n) nodes instead of O(n²) pairs.

Panel C — Local scaling exponent

The instantaneous slope of the log-log curve, computed as finite differences.

  • cxt and fastcxt pairwise both converge to exponent ≈ 2.0 (quadratic).

  • fastcxt+tsinfer converges to exponent ≈ 1.0 (linear), confirming the theoretical O(n log n) ≈ O(n) scaling.

When to use which mode

Scenario

Recommended mode

Why

Quick test (< 25 samples)

fastcxt pairwise

Simple, no tree inference needed, fast enough

Medium cohort (25–100 samples)

fastcxt pairwise or fastcxt tsinfer

Pairwise still tractable; tsinfer gives 10–50× speedup

Large cohort (100–1000+ samples)

fastcxt tsinfer

Quadratic pair count is intractable; tsinfer required

No SNP data (only trees)

fastcxt tsinfer

Direct tree input, no genotype matrix needed

Benchmarking / comparing to cxt

All three modes

Run fastcxt-benchmark --mode all

Reproducing the benchmark

Full reproduction from scratch (requires GPU):

# 1. Install
uv pip install -e ".[all]"

# 2. Run fastcxt benchmarks
fastcxt-benchmark --mode all \
    --sample-sizes 5 10 25 50 100 200 500 \
    --device cuda:0 \
    --output benchmarks/fastcxt_results.json

# 3. Generate scaling figure
python scripts/plot_atlas_showcase.py --outdir figures/

The fastcxt-benchmark command simulates tree sequences on-the-fly via msprime, so no pre-existing data files are needed.