jaxcmr uses JAX for fast, reproducible simulations. This page covers jaxcmr-specific patterns. For JAX fundamentals, see the JAX Quickstart and Sharp Bits.
Why JAX for Memory Modeling?
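Memory models require fast simulation: evaluating a single parameter set can take thousands of simulated trials. JAX supplies this through JIT compilation and vectorization with vmap.
jaxcmr models follow JAX's functional style. Model methods return new instances rather than mutating state, which makes them safe to use inside JIT.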
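Here is a minimal sketch of the functional-update pattern. It uses a hypothetical toy model, `ToyContextModel`, with a simplified drift rule; it is not jaxcmr's CMR. Because `experience` returns a new instance via `NamedTuple._replace`, the model is a pytree. That means it can be threaded through `jax.lax.scan` inside `jit`:

```python
from typing import NamedTuple

import jax
import jax.numpy as jnp


class ToyContextModel(NamedTuple):
    """Hypothetical toy model (not jaxcmr's CMR) holding state as immutable arrays."""
    context: jnp.ndarray  # current context vector
    drift_rate: float     # how strongly each studied item pulls context toward itself

    def experience(self, item_vector):
        # Compute the new context and return a *new* model; self is never mutated
        new_context = (1.0 - self.drift_rate) * self.context + self.drift_rate * item_vector
        new_context = new_context / jnp.linalg.norm(new_context)
        return self._replace(context=new_context)


@jax.jit
def study_list(model, item_vectors):
    # Each scan step receives the current model and returns the updated instance
    def step(m, item):
        return m.experience(item), None

    final_model, _ = jax.lax.scan(step, model, item_vectors)
    return final_model


model = ToyContextModel(context=jnp.eye(4)[0], drift_rate=0.5)
studied = study_list(model, jnp.eye(4)[1:])  # study three one-hot items
print(model.context)    # unchanged: [1. 0. 0. 0.]
print(studied.context)  # context after studying the three items
```

Because the original `model` is untouched, the same starting state can be reused across trials or parameter evaluations without copying.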
Batching Trials with vmap
Use vmap to simulate many trials in parallel. Any trial step that samples randomly needs its own random key, so split one key into one key per trial:
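```python
import jax
import jax.numpy as jnp

# Assumes CMR (from jaxcmr) and a parameter dict `params` are already defined;
# the exact import path depends on your jaxcmr version.


def simulate_single_trial(key, presentations):
    """Simulate one trial given item presentations."""
    # `key` is unused here because outcome_probabilities() computes probabilities
    # rather than sampling; keep it in the signature for steps that sample randomly.
    model = CMR(list_length=len(presentations), parameters=params)
    for item_id in presentations:
        model = model.experience(item_id)
    model = model.start_retrieving()
    return model.outcome_probabilities()


# Batch of 100 trials, each with 16 items (item ids 1..16)
n_trials = 100
keys = jax.random.split(jax.random.PRNGKey(0), n_trials)
presentations = jnp.tile(jnp.arange(1, 17), (n_trials, 1))

# Vectorize across trials: vmap maps over the leading axis of keys and presentations
batch_simulate = jax.vmap(simulate_single_trial)
all_probs = batch_simulate(keys, presentations)
print(f"Shape: {all_probs.shape}")  # (trials, outcomes)
```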
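The key-per-trial pattern is easier to see in a self-contained sketch where the trial actually consumes randomness. `simulate_toy_trial` is hypothetical. It draws a random study order and then one recall from recency-weighted logits, standing in for any jaxcmr step that samples:

```python
import jax
import jax.numpy as jnp


def simulate_toy_trial(key, items):
    """Hypothetical stochastic trial (not part of jaxcmr)."""
    order_key, recall_key = jax.random.split(key)  # independent streams within the trial
    study_order = jax.random.permutation(order_key, items)
    # Toy recall rule: later study positions get higher logits (recency-like)
    logits = jnp.arange(items.shape[0], dtype=jnp.float32)
    first_recall = study_order[jax.random.categorical(recall_key, logits)]
    return study_order, first_recall


n_trials = 100
list_length = 16
keys = jax.random.split(jax.random.PRNGKey(0), n_trials)  # one key per trial
items = jnp.tile(jnp.arange(1, list_length + 1), (n_trials, 1))

orders, first_recalls = jax.vmap(simulate_toy_trial)(keys, items)
print(orders.shape, first_recalls.shape)  # (100, 16) (100,)
```

Passing the same key to every trial would make all 100 study orders identical. Splitting gives each trial its own random stream while keeping the whole batch reproducible from a single seed.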
JIT for Fitting
Wrap loss functions in jit for fast parameter optimization:
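```python
from functools import partial

import jax
import jax.numpy as jnp


# param_names holds strings, which are not valid JAX types, so it must be static
# (pass it as a hashable tuple, e.g. ("rate", "noise")).
@partial(jax.jit, static_argnames="param_names")
def compute_loss(param_values, param_names, data):
    """Example JIT-compiled loss function."""
    params = dict(zip(param_names, param_values))
    # ... compute negative log-likelihood or MSE from params and data
    return 0.0  # placeholder


# First call compiles; later calls with the same shapes and param_names reuse it
# loss = compute_loss(values, names, data)
```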
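The sketch below makes the compile-once behavior visible by counting how often a hypothetical loss, `toy_loss`, is traced. The Python counter only increments while JAX traces the function. A second call with the same input shapes and the same static `param_names` tuple therefore reuses the compiled code:

```python
from functools import partial

import jax
import jax.numpy as jnp

trace_count = 0  # incremented only while JAX traces the function


@partial(jax.jit, static_argnames="param_names")
def toy_loss(param_values, param_names, data):
    """Hypothetical loss: mean squared error between a 'rate' parameter and data."""
    global trace_count
    trace_count += 1  # Python side effect: runs at trace time, not on cached calls
    params = dict(zip(param_names, param_values))
    return jnp.mean((data - params["rate"]) ** 2)


names = ("rate", "noise")  # a tuple, so it is hashable and can be static
data = jnp.array([0.2, 0.4, 0.6])

loss_a = toy_loss(jnp.array([0.5, 1.0]), names, data)  # traces and compiles
loss_b = toy_loss(jnp.array([0.3, 1.0]), names, data)  # reuses the compiled code
print(trace_count)  # 1
```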
Avoiding Recompilation
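JIT compiles a new version of a function whenever an input's shape or dtype changes, or a static argument takes a new value. In jaxcmr:
Keep list lengths constant within a fitting run; pad shorter lists if needed
Use fixed recall counts by padding recall sequences to the maximum length
Use static_argnums (or static_argnames) only for arguments that must be concrete Python values, such as parameter names; every new static value triggers a recompilation, so keep them fixed across a fitting run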
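Here is a sketch of padding recall sequences to a fixed length. It assumes 1-indexed item ids with 0 marking "no recall"; that convention was chosen for this example, so check how your data is encoded. `MAX_RECALLS`, `pad_recalls`, and `count_recalls` are hypothetical helpers:

```python
import jax
import jax.numpy as jnp
import numpy as np

MAX_RECALLS = 16  # hypothetical fixed length (e.g. the list length)
trace_count = 0


def pad_recalls(recalls, max_len=MAX_RECALLS):
    """Right-pad a 1-indexed recall sequence with zeros (0 = no recall)."""
    padded = np.zeros(max_len, dtype=np.int32)
    padded[: len(recalls)] = recalls
    return padded


@jax.jit
def count_recalls(padded_recalls):
    global trace_count
    trace_count += 1  # runs only when JAX traces a new input shape/dtype
    return jnp.sum(padded_recalls > 0)


# Trials with different numbers of recalls share one padded shape: one compile
n1 = count_recalls(pad_recalls([3, 4, 5]))
n2 = count_recalls(pad_recalls([16, 1, 2, 3, 9]))
print(int(n1), int(n2), trace_count)  # 3 5 1

# Without padding, each new length is a new shape and forces another trace
count_recalls(jnp.array([3, 4, 5], dtype=jnp.int32))
print(trace_count)  # 2
```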
Common Errors
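ConcretizationTypeError: You used Python control flow (if, while) on a traced value inside JIT. Use jax.lax.cond instead of if, or jnp.where for elementwise selection.
Tracer leaked (UnexpectedTracerError): A traced array escaped the JIT scope, usually by being stored in a global variable or an outer-scope container. A plain print inside JIT does not leak a tracer, but it shows the abstract tracer rather than its value; use jax.debug.print to see runtime values.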
Out of memory: Large batches exceed device memory. Process trials in smaller chunks or use jax.lax.scan.
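The fixes above can be combined in one sketch. `clipped_rate`, `debug_rate`, `toy_trial`, and `simulate_in_chunks` are all hypothetical examples. `jax.lax.cond` replaces a Python `if`, and `jax.debug.print` shows runtime values inside JIT. Chunking pairs `jax.lax.map` (built on `jax.lax.scan`) with `vmap`, so intermediate values for only one chunk are held at a time:

```python
import jax
import jax.numpy as jnp


@jax.jit
def clipped_rate(rate):
    # `if rate > 1.0:` fails here because rate is a tracer; lax.cond picks a branch at run time
    return jax.lax.cond(rate > 1.0, lambda r: jnp.ones_like(r), lambda r: r, rate)


@jax.jit
def debug_rate(rate):
    # print(rate) would show an abstract tracer; jax.debug.print shows the runtime value
    jax.debug.print("rate = {r}", r=rate)
    return rate * 2.0


def simulate_in_chunks(trial_fn, keys, chunk_size):
    """Run trial_fn over keys in fixed-size chunks to bound peak memory.

    Assumes keys.shape[0] is a multiple of chunk_size (pad the keys otherwise).
    """
    n_chunks = keys.shape[0] // chunk_size
    chunked = keys.reshape((n_chunks, chunk_size) + keys.shape[1:])
    # lax.map runs chunks sequentially; vmap parallelizes trials within a chunk
    results = jax.lax.map(jax.vmap(trial_fn), chunked)
    return results.reshape((n_chunks * chunk_size,) + results.shape[2:])


def toy_trial(key):
    """Hypothetical trial: draw 16 uniform values standing in for recall probabilities."""
    return jax.random.uniform(key, (16,))


keys = jax.random.split(jax.random.PRNGKey(0), 1000)
out = simulate_in_chunks(toy_trial, keys, chunk_size=100)
print(out.shape)  # (1000, 16)
```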