Scaling & Benchmarks
fastcxt dramatically improves inference speed over cxt through two complementary strategies: replacing autoregressive sampling with a single forward pass, and optionally exploiting tree topology to reduce the number of predictions from O(n²) pairs to O(n) internal nodes.
Theoretical scaling calibrated to measured per-pair / per-node runtimes. Panel A: absolute runtime on log-log axes. Panel B: speedup over cxt. Panel C: local scaling exponent (slope of log-log curve).
The three modes

| Method | Architecture | Pair scaling | Per-pair cost | Bottleneck |
|---|---|---|---|---|
| cxt (baseline) | Decoder-only transformer | O(n²) | 15 reps × 500 autoregressive steps | Stochastic sampling dominates |
| fastcxt (pairwise) | Bidirectional Mamba encoder-decoder | O(n²) | 1 forward pass (all windows at once) | Quadratic pair count at large n |
| fastcxt + tsinfer | Mamba encoder-decoder | O(n) | 1 forward pass per node + O(n log n) LCA | Tree inference is the new bottleneck |
Why cxt is slow
cxt uses a decoder-only transformer that generates TMRCA predictions autoregressively: one token (window) at a time, 500 tokens per sequence. To produce uncertainty estimates it repeats this 15 times with stochastic sampling, then averages the results. For each sample pair that is 15 × 500 = 7,500 forward passes. With n(n−1)/2 pairs, this becomes intractable beyond a few dozen samples.
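The pass count described above can be tallied in a few lines of Python. This is a back-of-the-envelope sketch, not cxt code: the function names are hypothetical, and the only inputs are the 15 repetitions and 500 autoregressive steps quoted in the text.

```python
def n_pairs(n_haploids: int) -> int:
    """Number of unordered haplotype pairs, n(n-1)/2."""
    return n_haploids * (n_haploids - 1) // 2


def cxt_forward_passes(n_haploids: int, reps: int = 15, steps: int = 500) -> int:
    """Forward passes needed by autoregressive sampling: pairs x reps x steps."""
    return n_pairs(n_haploids) * reps * steps


# 20, 100 and 2000 haplotypes (10, 50 and 1000 diploid samples).
for n_hap in (20, 100, 2000):
    print(f"{n_hap:>5} haplotypes: {n_pairs(n_hap):>9,} pairs, "
          f"{cxt_forward_passes(n_hap):>14,} forward passes")
```

Because the pair count grows quadratically, the 7,500 passes per pair are multiplied by 190 pairs at 20 haplotypes but by 1,999,000 pairs at 2,000 haplotypes.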
Why fastcxt pairwise is ~125× faster
fastcxt replaces autoregressive generation with a single forward pass through a bidirectional Mamba encoder-decoder that produces all 500 window predictions simultaneously, with built-in mean and variance (no sampling needed). The ~125× speedup is the ratio of the measured per-pair costs (0.05 s for cxt versus 0.0004 s for fastcxt) and comes from eliminating the 15 × 500 = 7,500 sequential forward passes per pair. The quadratic pair count remains: at n=1000 diploids there are still ~2M pairs.
Why fastcxt + tsinfer scales as O(n log n)
When tree topology from tsinfer is available, fastcxt does not predict
all n(n−1)/2 pairwise TMRCAs independently; it predicts the O(n)
internal node times in the tree. Any pairwise TMRCA is then a
lowest common ancestor (LCA) lookup, which takes O(log n) per query.
Total cost: O(n) predictions plus the LCA lookups. Retrieving every
pairwise TMRCA takes O(n² log n) lookup operations in total, but because
each lookup is a trivial table operation rather than a model call, the
practical scaling is near-linear.
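The lookup step can be sketched in plain Python. This is an illustration under stated assumptions, not fastcxt's implementation: the class name `LcaIndex`, the `parent` and `node_time` arrays, and the toy tree are all hypothetical, and binary lifting is simply one standard technique that gives O(n log n) preprocessing with O(log n) per query.

```python
class LcaIndex:
    """Binary-lifting LCA index over a rooted tree stored as a parent array."""

    def __init__(self, parent, node_time):
        n = len(parent)
        self.node_time = node_time
        self.levels = max(1, n.bit_length())  # 2**levels exceeds any depth

        # Depth of every node (the root has parent -1 and depth 0).
        self.depth = [-1] * n
        for v in range(n):
            path, u = [], v
            while u != -1 and self.depth[u] == -1:
                path.append(u)
                u = parent[u]
            d = self.depth[u] if u != -1 else -1
            for w in reversed(path):
                d += 1
                self.depth[w] = d

        # up[k][v] is the 2**k-th ancestor of v (the root maps to itself).
        self.up = [[p if p != -1 else v for v, p in enumerate(parent)]]
        for k in range(1, self.levels):
            prev = self.up[k - 1]
            self.up.append([prev[prev[v]] for v in range(n)])

    def lca(self, u, v):
        if self.depth[u] < self.depth[v]:
            u, v = v, u
        diff, k = self.depth[u] - self.depth[v], 0
        while diff:  # lift the deeper node to the depth of the other
            if diff & 1:
                u = self.up[k][u]
            diff >>= 1
            k += 1
        if u == v:
            return u
        for k in reversed(range(self.levels)):  # lift both to just below the LCA
            if self.up[k][u] != self.up[k][v]:
                u, v = self.up[k][u], self.up[k][v]
        return self.up[0][u]

    def tmrca(self, u, v):
        """Pairwise TMRCA = predicted time of the lowest common ancestor."""
        return self.node_time[self.lca(u, v)]


# Toy tree: leaves 0-3, internal nodes 4 (over 0,1), 5 (over 2,3), root 6.
parent = [4, 4, 5, 5, 6, 6, -1]
node_time = [0.0, 0.0, 0.0, 0.0, 1.0, 2.5, 4.0]  # hypothetical predicted node times
index = LcaIndex(parent, node_time)
print(index.tmrca(0, 1), index.tmrca(2, 3), index.tmrca(0, 2))  # 1.0 2.5 4.0
```

Only the three internal node times need a model prediction here; all six pairwise TMRCAs among the four leaves are then read off by lookup.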
Running the benchmarks
1. Benchmark fastcxt (both modes)
The fastcxt-benchmark CLI runs both pairwise and tree-aware modes
across a range of sample sizes:

```bash
# Both modes, default sample sizes (5, 10, 25, 50, 100)
fastcxt-benchmark --mode all --device cuda:0

# Pairwise only, custom sizes
fastcxt-benchmark --mode fastcxt_notree \
    --sample-sizes 10 25 50 100 200 \
    --batch-size 128 --device cuda:0

# Tree-aware only
fastcxt-benchmark --mode fastcxt_tree \
    --sample-sizes 10 25 50 100 200 500 \
    --device cuda:0

# Save results to JSON
fastcxt-benchmark --mode all \
    --sample-sizes 5 10 25 50 100 200 \
    --output benchmarks/fastcxt_scaling.json
```
Output:

```text
Model params: 1,234,567
--- n_samples=10 (n_haploids=20) ---
  pairwise: 190 pairs, preproc=0.234s, infer=0.045s, total=0.279s
  tree: 19 nodes (covers 190 pairs), preproc=0.012s, infer=0.003s, total=0.015s
--- n_samples=50 (n_haploids=100) ---
  pairwise: 4950 pairs, preproc=5.123s, infer=0.891s, total=6.014s
  tree: 99 nodes (covers 4950 pairs), preproc=0.045s, infer=0.008s, total=0.053s
Scaling summary:
mode             n_hap  pairs/nodes  total_s
--------------------------------------------------
fastcxt_notree      20          190   0.2790
fastcxt_tree        20           19   0.0150
fastcxt_notree     100         4950   6.0140
fastcxt_tree       100           99   0.0530
```
2. Benchmark cxt (baseline, optional)
To compare against the original cxt transformer, install the separate cxt package and run its benchmark script with pre-simulated tree sequences and a trained checkpoint. The theoretical cxt scaling curves in the comparison plot are calibrated from measured runtimes (~0.05 s per pair).
3. Generate the scaling plot
The showcase plotting script includes a scaling comparison figure using theoretical scaling curves calibrated to measured runtimes:

```bash
python scripts/plot_atlas_showcase.py --outdir figures/
```

This generates figures/08_scaling_comparison.png along with all
other showcase figures. The scaling curves use:

- cxt: 0.05 s per pair (measured: 15 reps × 500 tokens × GPU overhead)
- fastcxt pairwise: 0.0004 s per pair (measured: single batched forward pass)
- fastcxt tsinfer: 0.0004 s per node + O(n log n) LCA overhead
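The calibration constants listed above translate into a simple runtime model. The sketch below is an assumption-laden illustration, not the plotting script's code: the function names are hypothetical, n is counted in haplotypes, the tree is assumed to have n − 1 internal nodes (as in the benchmark output, where 20 haplotypes give 19 nodes), and the LCA coefficient `lca_s_per_unit` is a made-up parameter that defaults to zero because the source does not state its value.

```python
import math

CXT_S_PER_PAIR = 0.05         # cxt, seconds per pair (from the calibration list)
FASTCXT_S_PER_PAIR = 0.0004   # fastcxt pairwise, seconds per pair
FASTCXT_S_PER_NODE = 0.0004   # fastcxt + tsinfer, seconds per internal node


def n_pairs(n: int) -> int:
    """Unordered pairs among n haplotypes."""
    return n * (n - 1) // 2


def runtime_cxt(n: int) -> float:
    return CXT_S_PER_PAIR * n_pairs(n)


def runtime_pairwise(n: int) -> float:
    return FASTCXT_S_PER_PAIR * n_pairs(n)


def runtime_tree(n: int, lca_s_per_unit: float = 0.0) -> float:
    # lca_s_per_unit is hypothetical: cost per unit of n * log2(n) LCA work.
    return FASTCXT_S_PER_NODE * (n - 1) + lca_s_per_unit * n * math.log2(n)


print(f"{'n_hap':>6} {'cxt_s':>12} {'pairwise_s':>12} {'tree_s':>10}")
for n in (20, 100, 200, 1000):
    print(f"{n:>6} {runtime_cxt(n):>12.2f} {runtime_pairwise(n):>12.4f} "
          f"{runtime_tree(n):>10.4f}")
```

In this model the ratio between the cxt and pairwise curves is the constant 0.05 / 0.0004 = 125 at every n, while the gap to the tree-aware curve widens as n grows.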
Interpreting the results
Panel A: Runtime vs sample size (log-log)
On log-log axes, the slope reveals the scaling exponent:

- cxt (grey): slope ≈ 2.0 → O(n²). At the calibrated 0.05 s per pair, n=1000 haplotypes (499,500 pairs) take ~25,000 seconds (~7 hours).
- fastcxt pairwise (blue): slope ≈ 2.0 → still O(n²), but shifted down by ~125×. At n=1000 haplotypes it takes ~200 seconds (~3 minutes).
- fastcxt tsinfer (green): slope ≈ 1.0 → O(n). At n=1000 it takes ~1 second.

The horizontal reference lines mark 1 minute and 1 hour.
Panel B: Speedup over cxt
Shows the multiplicative speedup of each fastcxt mode relative to cxt.

- At n=10, fastcxt pairwise is ~19× faster (the per-pair constant dominates).
- At n=1000, fastcxt+tsinfer is >40,000× faster, because it predicts O(n) nodes instead of O(n²) pairs.
Panel C: Local scaling exponent
The instantaneous slope of the log-log curve, computed as finite differences.

- cxt and fastcxt pairwise both converge to exponent ≈ 2.0 (quadratic).
- fastcxt+tsinfer converges to exponent ≈ 1.0 (linear), confirming the theoretical O(n log n) ≈ O(n) scaling.
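A finite-difference slope of this kind can be computed as follows. This is a minimal sketch rather than the plotting script's code: the function name `local_exponents` is hypothetical, and the inputs are the two totals per mode from the example benchmark output (20 and 100 haplotypes), which is too few points to show convergence.

```python
import math


def local_exponents(ns, times):
    """Slope of log(time) against log(n) between consecutive measurements."""
    points = list(zip(ns, times))
    return [
        (math.log(t1) - math.log(t0)) / (math.log(n1) - math.log(n0))
        for (n0, t0), (n1, t1) in zip(points, points[1:])
    ]


n_hap = [20, 100]
pairwise_total_s = [0.279, 6.014]  # fastcxt_notree totals from the example output
tree_total_s = [0.015, 0.053]      # fastcxt_tree totals from the example output

print("pairwise:", [round(e, 2) for e in local_exponents(n_hap, pairwise_total_s)])
print("tree:    ", [round(e, 2) for e in local_exponents(n_hap, tree_total_s)])
```

On those example totals the pairwise exponent comes out close to 2 and the tree-aware exponent below 1. With more sample sizes the same function returns one exponent per consecutive pair of points, which is what a local-slope panel plots.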
When to use which mode

| Scenario | Recommended mode | Why |
|---|---|---|
| Quick test (< 25 samples) | | Simple, no tree inference needed, fast enough |
| Medium cohort (25–100 samples) | | Pairwise still tractable; tsinfer gives 10–50× speedup |
| Large cohort (100–1000+ samples) | | Quadratic pair count is intractable; tsinfer required |
| No SNP data (only trees) | | Direct tree input, no genotype matrix needed |
| Benchmarking / comparing to cxt | All three modes | Run |
Reproducing the benchmark
Full reproduction from scratch (requires GPU):

```bash
# 1. Install
uv pip install -e ".[all]"

# 2. Run fastcxt benchmarks
fastcxt-benchmark --mode all \
    --sample-sizes 5 10 25 50 100 200 500 \
    --device cuda:0 \
    --output benchmarks/fastcxt_results.json

# 3. Generate scaling figure
python scripts/plot_atlas_showcase.py --outdir figures/
```
The fastcxt-benchmark command simulates tree sequences on-the-fly via
msprime, so no pre-existing data files are needed.