Fitting, simulation, and model comparison

Model evaluation measures how well models explain observed data. jaxcmr provides a modular evaluation pipeline with interchangeable loss functions, optimizers, and comparison metrics:

```
Data → Loss Function → Optimizer → Fitted Parameters → Simulation → Comparison
```
| Stage | Purpose | Key components |
|---|---|---|
| Simulation | Simulate from fitted parameters | `simulate_h5_from_h5` |
| Comparison | Compare models | AICc, BIC, t-tests |

## Component Architecture

All evaluation components follow protocols defined in `jaxcmr/typing.py`:

```python
from typing import Callable, Protocol

from jaxtyping import Array, Bool, Float, Int


# Loss function generator creates trial-specific loss functions
class LossFnGenerator(Protocol):
    def __call__(
        self,
        trial_indices: Int[Array, "trials"],
        parameters: dict[str, Array],
        free_param_names: list[str],
    ) -> Callable[[Float[Array, "params"]], Float[Array, ""]]: ...


# Fitting algorithm finds optimal parameters
class FittingAlgorithm(Protocol):
    def fit(
        self,
        trial_mask: Bool[Array, "subjects trials"],
    ) -> "FitResult": ...  # FitResult: result type documented under Protocols
```

See Protocols for complete definitions.
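To make the `LossFnGenerator` call shape concrete, here is a hypothetical toy generator written in NumPy rather than JAX. It binds the data at construction (an assumed design), selects the requested trials, and returns a function of the optimizer's flat free-parameter vector. A Gaussian likelihood stands in for a memory-search model, so the names `GaussianLossFnGenerator`, `mu`, and `sigma` are illustrative only.

```python
from typing import Callable

import numpy as np


class GaussianLossFnGenerator:
    """Toy generator with the same call shape as LossFnGenerator (hypothetical).

    The stand-in "model" predicts each trial's scalar observation as
    Normal(mu, sigma); a real generator would wrap a memory-search model.
    """

    def __init__(self, observations):
        # Data is bound once at construction (assumed design).
        self.observations = np.asarray(observations, dtype=float)

    def __call__(
        self,
        trial_indices: np.ndarray,
        parameters: dict[str, float],
        free_param_names: list[str],
    ) -> Callable[[np.ndarray], float]:
        selected = self.observations[trial_indices]  # only the requested trials

        def loss_fn(x: np.ndarray) -> float:
            # Overwrite base parameters with the optimizer's flat vector,
            # matching entries of x to names by position.
            params = {**parameters, **dict(zip(free_param_names, x))}
            mu, sigma = params["mu"], params["sigma"]
            z = (selected - mu) / sigma
            # Gaussian negative log-likelihood summed over selected trials.
            nll = 0.5 * z**2 + np.log(sigma) + 0.5 * np.log(2 * np.pi)
            return float(np.sum(nll))

        return loss_fn


generator = GaussianLossFnGenerator([1.0, 2.0, 3.0, 10.0])
loss_fn = generator(np.array([0, 1, 2]), {"mu": 0.0, "sigma": 1.0}, ["mu"])
loss_at_two = loss_fn(np.array([2.0]))
```

Only `mu` is free in this call, so the returned function takes a length-1 vector while `sigma` keeps its base value.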
### Loss Functions

| Loss Function | Module | Description |
|---|---|---|
| Sequence Likelihood | `loss.sequence_likelihood` | Standard negative log-likelihood |
| Transform Likelihood | `loss.transform_sequence_likelihood` | Likelihood with pluggable masking |
| Set Permutation | `loss.set_permutation_likelihood` | Monte Carlo for unordered recall |
| SPC MSE | `loss.spc_mse` | Serial position curve fitting |
| Category SPC MSE | `loss.cat_spc_mse` | Category-filtered SPC |
See Loss Functions for detailed documentation.
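The likelihood-based losses return negative log-likelihoods, which the comparison stage can turn into information criteria such as AICc and BIC. A minimal sketch of the standard formulas follows; the function names are hypothetical, and what counts as the number of observations `n` (for example recall events or trials) is an assumption left to the caller.

```python
import math


def aic(nll: float, k: int) -> float:
    """Akaike information criterion from a minimized negative log-likelihood."""
    return 2 * k + 2 * nll


def aicc(nll: float, k: int, n: int) -> float:
    """AIC with the small-sample correction; needs n > k + 1."""
    if n <= k + 1:
        raise ValueError("AICc requires n > k + 1")
    return aic(nll, k) + 2 * k * (k + 1) / (n - k - 1)


def bic(nll: float, k: int, n: int) -> float:
    """Bayesian information criterion."""
    return k * math.log(n) + 2 * nll


# Hypothetical fit: NLL of 100 with 4 free parameters and 50 observations.
scores = {
    "AIC": aic(100.0, 4),
    "AICc": aicc(100.0, 4, 50),
    "BIC": bic(100.0, 4, 50),
}
```

Lower values favor a model. Both criteria penalize the parameter count `k`; compared with plain AIC, BIC's penalty is heavier whenever ln n > 2.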
### Fitting Algorithms

| Algorithm | Module | Description |
|---|---|---|
| ScipyDE | `fitting` | Differential evolution via SciPy |
See Fitting for usage patterns.
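ScipyDE is built on SciPy's differential evolution. The sketch below is a hypothetical helper, not jaxcmr's implementation: it shows how a `{name: [low, high]}` bounds dictionary, like the `hyperparams["bounds"]` entry in the workflow below, can be flattened into the ordered vector SciPy searches over and then mapped back to named parameters. A quadratic stands in for a real loss function.

```python
from typing import Callable

import numpy as np
from scipy.optimize import differential_evolution


def fit_with_de(
    loss_fn: Callable[[np.ndarray], float],
    bounds: dict[str, list[float]],
    maxiter: int = 1000,
) -> tuple[dict[str, float], float]:
    """Minimize loss_fn over the box given by a {name: [low, high]} dict.

    Hypothetical helper: returns the best parameters keyed by name and
    the loss value they achieve.
    """
    free_param_names = list(bounds)  # dict order fixes the vector layout
    result = differential_evolution(
        loss_fn,
        [tuple(bounds[name]) for name in free_param_names],
        maxiter=maxiter,
    )
    best = {name: float(v) for name, v in zip(free_param_names, result.x)}
    return best, float(result.fun)


# Toy loss with a known minimum at (0.3, 0.7).
target = np.array([0.3, 0.7])
fits, best_loss = fit_with_de(
    lambda x: float(np.sum((x - target) ** 2)),
    {"encoding_drift_rate": [0.0, 1.0], "recall_drift_rate": [0.0, 1.0]},
)
```

Ordering parameters by the dictionary's key order keeps the vector layout consistent between the optimizer and the loss function, the same role `free_param_names` plays in `LossFnGenerator`.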
## Example Workflow

A typical workflow fits a model, simulates from the fitted parameters, and compares the simulation with the data:

```python
from jax import random

from jaxcmr.analyses.spc import plot_spc
from jaxcmr.fitting import ScipyDE
from jaxcmr.loss.sequence_likelihood import MemorySearchLikelihoodFnGenerator
from jaxcmr.models.cmr import make_factory
from jaxcmr.simulation import simulate_h5_from_h5

# `data` (the recall dataset) and `trial_mask` (trials to include) are
# assumed to be loaded beforehand.

# 1. Create model factory
factory = make_factory(...)

# 2. Define parameters: fixed values plus search bounds for free parameters
fixed_params = {
    "allow_repeated_recalls": False,
    # ... other fixed parameters
}
hyperparams = {
    "bounds": {
        "encoding_drift_rate": [0.0, 1.0],
        "recall_drift_rate": [0.0, 1.0],
    },
    "num_steps": 1000,
}

# 3. Fit model
fitter = ScipyDE(
    dataset=data,
    features=None,
    base_params=fixed_params,
    model_factory=factory,
    loss_fn_generator=MemorySearchLikelihoodFnGenerator,
    hyperparams=hyperparams,
)
results = fitter.fit(trial_mask)

# 4. Simulate from the fitted parameters
sim = simulate_h5_from_h5(
    factory, data, None, results["fits"],
    trial_mask, experiment_count=100, rng=random.PRNGKey(0),
)

# 5. Compare serial position curves of data and simulation
plot_spc([data, sim], [trial_mask, trial_mask], ["Data", "Model"])
```

## Choosing a Loss Function

**Use Sequence Likelihood with ScipyDE:**

- Captures full temporal dynamics (recall order and transitions)
- Maximum likelihood estimation
- Per-subject parameter fitting
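Sequence likelihood scores each recall conditional on everything recalled before it, plus the final decision to stop. A minimal NumPy sketch, assuming the model has already produced one outcome distribution per recall step, with index 0 reserved for stopping and index i for study item i; that layout and the name `sequence_nll` are illustrative assumptions rather than jaxcmr's API.

```python
import numpy as np


def sequence_nll(step_probs: np.ndarray, recalls: list[int]) -> float:
    """Negative log-likelihood of one recall sequence (hypothetical sketch).

    step_probs[t] is the model's outcome distribution at recall step t:
    index 0 = stop recalling, index i = recall study item i (1-indexed).
    The stop event is scored at the step after the last recall.
    """
    outcomes = list(recalls) + [0]  # append the terminating stop event
    if len(outcomes) > step_probs.shape[0]:
        raise ValueError("need one outcome distribution per recall step")
    probs = step_probs[np.arange(len(outcomes)), outcomes]
    return float(-np.sum(np.log(probs)))


# Hypothetical outcome probabilities for a 3-item list; each row sums to 1
# and already-recalled items get probability 0 (no repeats).
step_probs = np.array([
    [0.1, 0.2, 0.5, 0.2],  # step 0: item 2 most likely
    [0.2, 0.1, 0.0, 0.7],  # step 1: item 2 already recalled
    [0.8, 0.2, 0.0, 0.0],  # step 2: items 2 and 3 already recalled
])
nll = sequence_nll(step_probs, [2, 3])  # -log[P(2) * P(3 | 2) * P(stop | 2, 3)]
```

The conditioning on earlier recalls lives in how the model produces each row; the loss itself only reads off the probability of what actually happened at each step.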
**Use Set Permutation Likelihood:**

- When only the recall set (not its order) is recorded
- Final free recall paradigms
- Monte Carlo sampling over recall orders handles the factorial number of possible orderings
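The factorial cost comes from summing over every ordering of the recalled set. A minimal pure-Python sketch of the Monte Carlo idea, assuming a `seq_prob` callable that scores a complete recall order (here a hypothetical toy, `toy_seq_prob`) and uniform sampling of orderings; whether `loss.set_permutation_likelihood` samples this way is an assumption. The exact sum is included only to check the estimate on a small set.

```python
import math
import random
from itertools import permutations
from typing import Callable, Sequence

SeqProb = Callable[[tuple[int, ...]], float]


def set_likelihood_exact(seq_prob: SeqProb, recalled: Sequence[int]) -> float:
    """Sum P(order) over all n! orderings of the recalled items."""
    return sum(seq_prob(order) for order in permutations(recalled))


def set_likelihood_mc(
    seq_prob: SeqProb, recalled: Sequence[int], num_samples: int, seed: int = 0
) -> float:
    """Monte Carlo estimate: n! times the mean P(order) over uniform orderings."""
    rng = random.Random(seed)
    items = list(recalled)
    total = 0.0
    for _ in range(num_samples):
        order = items[:]
        rng.shuffle(order)  # uniform random permutation
        total += seq_prob(tuple(order))
    return math.factorial(len(items)) * total / num_samples


def toy_seq_prob(order: tuple[int, ...]) -> float:
    """Hypothetical sequence model: each backward step halves the probability."""
    backward_steps = sum(a > b for a, b in zip(order, order[1:]))
    return 0.01 * 0.5**backward_steps


exact = set_likelihood_exact(toy_seq_prob, [1, 2, 3])
estimate = set_likelihood_mc(toy_seq_prob, [1, 2, 3], num_samples=20_000)
```

Because orderings are drawn uniformly, n! times the sample mean is an unbiased estimate of the exact sum. With three items enumeration is cheap, but for larger sets only the sampled version stays tractable.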
**Use SPC MSE or Category SPC MSE:**

- Fits to aggregate curves
- Useful when likelihood is computationally prohibitive
- Ignores recall order information
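Aggregate-curve losses discard order by construction. A minimal NumPy sketch of an SPC-based mean squared error, assuming recalls are stored as a trials × output-position array of 1-indexed study positions padded with 0, and that the loss compares the observed curve with one computed from model-simulated recalls. The function names and the padding convention are assumptions, not jaxcmr's documented API.

```python
import numpy as np


def serial_position_curve(recalls: np.ndarray, list_length: int) -> np.ndarray:
    """Proportion of trials on which each study position was recalled.

    recalls: (trials, max_recalls) array of 1-indexed study positions,
    padded with 0 (assumed convention; 0 never matches a position).
    """
    positions = np.arange(1, list_length + 1)
    # recalled[t, p] is True if position p + 1 appears anywhere in trial t
    recalled = (recalls[:, :, None] == positions[None, None, :]).any(axis=1)
    return recalled.mean(axis=0)


def spc_mse(observed: np.ndarray, simulated: np.ndarray, list_length: int) -> float:
    """Mean squared difference between observed and simulated SPCs."""
    diff = serial_position_curve(observed, list_length) - serial_position_curve(
        simulated, list_length
    )
    return float(np.mean(diff**2))


observed = np.array([[1, 2, 0], [3, 1, 0]])   # hypothetical recall data
simulated = np.array([[1, 0, 0], [1, 0, 0]])  # hypothetical model output
loss = spc_mse(observed, simulated, list_length=3)
```

Reversing each trial's recalls leaves `serial_position_curve` unchanged, which is exactly the order information this family of losses ignores.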