fastcxt

Fast pairwise coalescence time inference with Mamba state-space models

Single-pass · Built-in uncertainty · Mutation-rate conditioned · O(n) with tree topology

Under active development

APIs, documentation, and results may change without notice. Not yet recommended for production use.

fastcxt predicts pairwise time to most recent common ancestor (TMRCA) from genotype data using a bidirectional Mamba encoder-decoder. It replaces the autoregressive transformer from cxt with a single-pass architecture that produces means and calibrated variances for all genomic windows in one forward pass — no stochastic sampling, no post-hoc correction.

Quick Start
Install and run inference in five minutes. See Quickstart.

Tutorial
Full pipeline: simulate, train, infer, visualize. See End-to-End Tutorial.

Algorithm
Bidirectional Mamba, FiLM conditioning, Beta-NLL loss. See Architecture.

cxt vs fastcxt
Architecture comparison and migration guide.

Mosquito Protocol
Ag1000G analysis: inversions, karyotypes, selection scans. See Mosquito Analysis Protocol.

Figure Gallery
Publication-quality plots from the Ag1000G analysis.

Demography
IICR estimation and Ne(t) from TMRCA distributions. See Demographic Inference.

Geographic
Maps, sparklines, and spatial TMRCA patterns. See Visualization.

API Reference
Full Python API for all modules.

How it works

1. Build SFS features
Site frequency spectrum in XOR/XNOR channels from a genotype matrix — same representation as cxt.
2. Single forward pass
Bidirectional Mamba encoder reads the full sequence, decoder outputs (μ, log σ²) for every window.
3. FiLM conditioning
Mutation rate injected via learned scale/shift at each encoder layer — no post-hoc correction needed.
4. Calibrated uncertainty
Beta-NLL loss fits the variance alongside the mean. The interval is formed on the log scale and then exponentiated: 95% CI = exp(μ ± 1.96σ), where σ = √σ².
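Steps 3 and 4 can be sketched in a few lines of plain Python. This is an illustration under stated assumptions, not fastcxt's implementation: the function names `film`, `beta_nll`, and `tmrca_interval` are hypothetical, the weighting exponent `beta=0.5` is an assumed default, and μ and log σ² are assumed to describe a Gaussian over log-TMRCA (which is what the exp in the interval formula suggests).

```python
import math

def film(features, scale, shift):
    """FiLM: per-channel affine modulation, scale * h + shift.

    In a real model, scale and shift would be produced by a small learned
    network from the mutation rate; here they are passed in directly.
    """
    return [s * h + b for h, s, b in zip(features, scale, shift)]

def beta_nll(mu, log_var, target, beta=0.5):
    """Per-element beta-NLL: Gaussian NLL weighted by var**beta.

    In an autograd framework the var**beta weight is detached
    (stop-gradient); only the loss value is computed here.
    The constant 0.5 * log(2 * pi) is dropped.
    """
    var = math.exp(log_var)
    nll = 0.5 * (log_var + (target - mu) ** 2 / var)
    return (var ** beta) * nll

def tmrca_interval(mu, log_var, z=1.96):
    """Return (lower, median, upper) after exponentiating mu -/+ z * sigma."""
    sigma = math.exp(0.5 * log_var)  # sigma = sqrt(sigma^2)
    return math.exp(mu - z * sigma), math.exp(mu), math.exp(mu + z * sigma)

# Example: predicted log-TMRCA mean of log(10_000) with sigma = 0.5
lo, med, hi = tmrca_interval(math.log(10_000), math.log(0.25))
print(f"lower={lo:.0f} median={med:.0f} upper={hi:.0f}")
```

Because the interval is symmetric on the log scale, it is asymmetric around the median once exponentiated: the upper bound sits further from the median than the lower bound does.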

Minimal example

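import tskit
from fastcxt.translate import translate_from_genotype_matrix

# NOTE: `model` must be a trained fastcxt model loaded beforehand;
# loading it is not shown in this snippet.

ts = tskit.load("data.trees")
gm = ts.genotype_matrix().T  # tskit returns (sites, samples); transpose to (samples, sites)
positions = ts.tables.sites.position  # genomic coordinate of each site

# Sample-index pairs to compare, and ten 100 kb windows covering the first 1 Mb
pairs = [(0, 1), (0, 2), (1, 2)]
blocks = [(i, i + 100_000) for i in range(0, 1_000_000, 100_000)]

# Single forward pass: returns predicted means and variances, plus an index map
means, variances, index_map = translate_from_genotype_matrix(
    gm, positions, model,
    blocks=blocks, pivot_pairs=pairs,
    mutation_rate=3.5e-9, device="cuda:0",
    batch_size=128, build_workers=64,
)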
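
For larger inputs, the hand-written `pairs` and `blocks` lists above can be generated instead. The helpers below are hypothetical conveniences, not part of fastcxt, and they assume that `pivot_pairs` takes plain sample-index tuples and that `blocks` takes half-open `(start, end)` windows, as the example suggests.

```python
from itertools import combinations

def all_pairs(n_samples):
    """Every unordered pair of sample indices (i, j) with i < j."""
    return list(combinations(range(n_samples), 2))

def fixed_blocks(sequence_length, block_size):
    """Consecutive (start, end) windows of exactly block_size bases.

    A trailing remainder shorter than block_size is dropped.
    """
    return [
        (start, start + block_size)
        for start in range(0, sequence_length - block_size + 1, block_size)
    ]

pairs = all_pairs(3)
blocks = fixed_blocks(1_000_000, 100_000)
print(pairs)
print(len(blocks), blocks[0], blocks[-1])
```

The number of pairs grows as n(n-1)/2, so for large sample sets it may be preferable to pass a chosen subset of pairs rather than all of them.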