JAX patterns for memory modeling

jaxcmr uses JAX for fast, reproducible simulations. This page covers jaxcmr-specific patterns. For JAX fundamentals, see the JAX Quickstart and Sharp Bits.

Why JAX for Memory Modeling?

Memory models require:

  • Fast simulation — Thousands of trials to evaluate a parameter set

  • Parallel fitting — Multiple subjects, multiple models

  • Reproducibility — Identical seeds yield identical results

JAX addresses all three: jit compiles functions to fast machine code, vmap vectorizes them across trials or subjects, and explicit PRNGKey handling makes every random draw a deterministic function of its key.
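As a minimal sketch of the three pieces working together, the block below uses a toy noisy_strength function (a hypothetical stand-in, not part of jaxcmr) and checks that rerunning with the same keys reproduces the same values:

```python
import jax
import jax.numpy as jnp

def noisy_strength(key, base):
    """Toy stand-in for one simulated trial: base strength plus Gaussian noise."""
    return base + 0.1 * jax.random.normal(key)

# vmap maps over the keys (axis 0) and broadcasts `base`; jit compiles the result.
simulate_batch = jax.jit(jax.vmap(noisy_strength, in_axes=(0, None)))

keys = jax.random.split(jax.random.PRNGKey(42), 1000)
first_run = simulate_batch(keys, 0.5)
second_run = simulate_batch(keys, 0.5)

# Same keys in, same values out: randomness depends only on the key.
print(first_run.shape, bool(jnp.array_equal(first_run, second_run)))
```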

Models as PyTrees

jaxcmr models are JAX PyTrees—nested structures that work with JAX transformations:

Code
import jax
from jaxcmr.models.cmr import CMR

params = {
    "encoding_drift_rate": 0.5,
    "start_drift_rate": 0.5,
    "recall_drift_rate": 0.5,
    "learning_rate": 0.5,
    "primacy_scale": 2.0,
    "primacy_decay": 0.8,
    "shared_support": 0.05,
    "item_support": 0.25,
    "choice_sensitivity": 0.6,
    "stop_probability_scale": 0.05,
    "stop_probability_growth": 0.2,
}
model = CMR(list_length=16, parameters=params)

# Inspect model structure
jax.tree_util.tree_map(lambda x: x.shape if hasattr(x, 'shape') else type(x).__name__, model)

Model methods return new instances rather than mutating state, enabling safe use inside JIT.
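To illustrate the return-a-new-instance pattern without relying on jaxcmr internals, here is a sketch built on a hypothetical ToyContext class. It is a NamedTuple, which JAX treats as a PyTree automatically; the update rule is a simple blend chosen for illustration and is not the CMR context equation.

```python
from typing import NamedTuple

import jax
import jax.numpy as jnp

class ToyContext(NamedTuple):
    """Hypothetical stand-in for a model; NamedTuples are PyTrees out of the box."""
    state: jnp.ndarray
    drift_rate: float

    def experience(self, item_input):
        # Blend old state with the new input and return a NEW instance.
        new_state = (1 - self.drift_rate) * self.state + self.drift_rate * item_input
        return self._replace(state=new_state)

@jax.jit
def present(model, item_input):
    # The whole model passes through jit as a PyTree argument and return value.
    return model.experience(item_input)

before = ToyContext(state=jnp.zeros(3), drift_rate=0.5)
after = present(before, jnp.array([1.0, 0.0, 0.0]))
print(before.state, after.state)  # `before` is unchanged
```

Because nothing is mutated, the same `before` value can be reused across calls without one call affecting another.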

Batching Trials with vmap

Use vmap to simulate multiple trials in parallel. Each trial needs its own random key:

Code
import jax.numpy as jnp

def simulate_single_trial(key, presentations):
    """Simulate one trial given item presentations."""
    # key is unused in this deterministic example; it is passed so that
    # any sampling step can draw from a per-trial key.
    model = CMR(list_length=len(presentations), parameters=params)
    for item_id in presentations:
        model = model.experience(item_id)
    model = model.start_retrieving()
    return model.outcome_probabilities()

# Batch of 100 trials, each with 16 items
n_trials = 100
keys = jax.random.split(jax.random.PRNGKey(0), n_trials)
presentations = jnp.tile(jnp.arange(1, 17), (n_trials, 1))

# Vectorize across trials
batch_simulate = jax.vmap(simulate_single_trial)
all_probs = batch_simulate(keys, presentations)
print(f"Shape: {all_probs.shape}")  # (trials, outcomes)
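The per-trial keys matter once a step actually samples. The sketch below shows one way to draw an outcome for each trial from a batch of probability vectors. The sample_outcome helper, the outcome count of 17, and the uniform placeholder probabilities are all assumptions for illustration; in practice the probabilities would come from the model.

```python
import jax
import jax.numpy as jnp

def sample_outcome(key, probs):
    """Draw one outcome index from a probability vector using this trial's key."""
    return jax.random.categorical(key, jnp.log(probs))

n_trials, n_outcomes = 100, 17  # hypothetical sizes
trial_keys = jax.random.split(jax.random.PRNGKey(0), n_trials)
# Placeholder probabilities (uniform); substitute model output here.
trial_probs = jnp.full((n_trials, n_outcomes), 1.0 / n_outcomes)

sample_batch = jax.vmap(sample_outcome)
outcomes = sample_batch(trial_keys, trial_probs)
repeat = sample_batch(trial_keys, trial_probs)
print(outcomes.shape, bool(jnp.array_equal(outcomes, repeat)))
```

Reusing one key for every trial would make all trials draw identically, which is why the keys are split before batching.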

JIT for Fitting

Wrap loss functions in jit for fast parameter optimization:

Code
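from functools import partial

# param_names holds strings, which JIT cannot trace, so mark it static.
# Static arguments must be hashable: pass a tuple, not a list.
@partial(jax.jit, static_argnames=("param_names",))
def compute_loss(param_values, param_names, data):
    """Example JIT-compiled loss function."""
    params = dict(zip(param_names, param_values))
    # ... compute negative log-likelihood or MSE
    return 0.0  # placeholder

# First call compiles; subsequent calls with the same shapes are fast
# loss = compute_loss(values, tuple(names), data)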
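For a filled-in version of the idea, here is a sketch of a JIT-compiled negative log-likelihood over a softmax choice rule. The function name, the supports and choices arrays, and the use of choice_sensitivity as a softmax scale are assumptions made for illustration, not jaxcmr's likelihood. Passing parameters as a dict PyTree is one way to avoid string arguments altogether.

```python
import jax
import jax.numpy as jnp

@jax.jit
def toy_negative_log_likelihood(params, supports, choices):
    """Sum of -log P(choice) under a softmax over scaled supports."""
    logits = params["choice_sensitivity"] * supports
    log_probs = jax.nn.log_softmax(logits, axis=-1)
    # Pick out the log-probability of the observed choice on each trial.
    chosen = jnp.take_along_axis(log_probs, choices[:, None], axis=-1)
    return -jnp.sum(chosen)

supports = jnp.array([[1.0, 2.0, 3.0], [3.0, 1.0, 2.0]])  # made-up values
choices = jnp.array([2, 0])                                # made-up observations

loss = toy_negative_log_likelihood({"choice_sensitivity": 0.6}, supports, choices)
print(float(loss))
```

Only the dict's array leaves are traced, so changing parameter values between calls does not trigger recompilation as long as the keys and shapes stay the same.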

Avoiding Recompilation

JIT recompiles when input shapes change. In jaxcmr:

  • Keep list lengths constant within a fitting run—pad shorter lists if needed

  • Use fixed recall counts by padding recall sequences to max length

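  • Mark arguments that must be concrete at trace time (such as list length) with static_argnums or static_argnames; each distinct static value compiles separately, so keep them constant within a run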
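The effect of padding can be sketched by counting how often a jitted function is traced. The pad_recalls helper and the convention that 0 marks "no recall" are assumptions for this example; the counter works because Python side effects inside a jitted function run only during tracing.

```python
import jax
import jax.numpy as jnp

trace_log = {"count": 0}

@jax.jit
def count_recalls(recalls):
    trace_log["count"] += 1  # runs only when JAX traces (compiles) the function
    return jnp.sum(recalls > 0)

def pad_recalls(recalls, max_length):
    """Right-pad a recall sequence with zeros (assumed to mean 'no recall')."""
    padded = jnp.zeros(max_length, dtype=jnp.int32)
    return padded.at[: len(recalls)].set(jnp.asarray(recalls, dtype=jnp.int32))

# Unpadded: three different lengths, so three separate compilations.
for trial in ([3, 1], [5, 2, 4], [7]):
    count_recalls(jnp.array(trial, dtype=jnp.int32))
unpadded_traces = trace_log["count"]

# Padded to a common length of 8: one compilation serves every trial.
trace_log["count"] = 0
for trial in ([3, 1, 6, 2], [5, 2, 4, 8, 1], [7, 9, 3, 4, 5, 6]):
    count_recalls(pad_recalls(trial, 8))
padded_traces = trace_log["count"]

print(unpadded_traces, padded_traces)
```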

Common Errors

ConcretizationTypeError: You tried to branch on a traced value inside JIT. Use jax.lax.cond instead of if.
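A small sketch of the two usual replacements for a Python if on a traced value: jax.lax.cond for choosing between branches on a scalar predicate, and jnp.where for elementwise selection. The function names are illustrative only.

```python
import jax
import jax.numpy as jnp

@jax.jit
def clip_negative_branching(x):
    # `if x < 0:` would fail here because x is a tracer with no concrete value.
    return jax.lax.cond(x < 0, lambda v: jnp.zeros_like(v), lambda v: v, x)

@jax.jit
def clip_negative_elementwise(x):
    # jnp.where evaluates both sides and selects per element.
    return jnp.where(x < 0, 0.0, x)

print(clip_negative_branching(jnp.array(-2.0)))
print(clip_negative_elementwise(jnp.array([-1.0, 2.0])))
```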

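Tracer leaked (UnexpectedTracerError): A traced array escaped the JIT scope, usually by being stored in global state or on an object attribute during tracing. A print statement inside JIT does not leak a tracer, but it shows the tracer object rather than runtime values; use jax.debug.print to see values.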
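One way to stay clear of leaks is to return any intermediate you need instead of stashing it outside the function, and to use jax.debug.print for inspection. The normalize_support function below is a hypothetical example of that pattern.

```python
import jax
import jax.numpy as jnp

@jax.jit
def normalize_support(support):
    total = jnp.sum(support)
    # Prints the runtime value; a plain print(total) would show a tracer instead.
    jax.debug.print("total support: {t}", t=total)
    # Return intermediates rather than appending them to a global list.
    return support / total, total

probs, total = normalize_support(jnp.array([1.0, 3.0]))
print(probs, total)
```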

Out of memory: Large batches exceed device memory. Process trials in smaller chunks, or run them sequentially with jax.lax.map or jax.lax.scan instead of vmap.
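Chunking can be sketched as a plain Python loop around a jitted, vmapped function. The trial_statistic function is a toy stand-in for a full simulation, and simulate_in_chunks is a hypothetical helper that assumes the number of trials is a multiple of the chunk size, so every chunk has the same shape and the function compiles once.

```python
import jax
import jax.numpy as jnp

def trial_statistic(key):
    """Toy per-trial computation standing in for a full simulation."""
    return jnp.mean(jax.random.uniform(key, (16,)))

batched = jax.jit(jax.vmap(trial_statistic))

def simulate_in_chunks(keys, chunk_size):
    # Assumes len(keys) is a multiple of chunk_size (equal shapes, one compile).
    chunks = [batched(keys[i:i + chunk_size]) for i in range(0, len(keys), chunk_size)]
    return jnp.concatenate(chunks)

keys = jax.random.split(jax.random.PRNGKey(0), 1000)
chunked = simulate_in_chunks(keys, 250)  # peak memory scales with 250 trials
full = batched(keys)                     # peak memory scales with 1000 trials
print(chunked.shape, bool(jnp.allclose(chunked, full)))
```

Because each trial depends only on its own key, chunking changes memory use but not the results.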