Fitting, simulation, and model comparison

Model evaluation measures how well models explain observed data. jaxcmr provides a modular evaluation pipeline with interchangeable loss functions, optimizers, and comparison metrics.

Evaluation Pipeline

Data → Loss Function → Optimizer → Fitted Parameters → Simulation → Comparison
| Stage | Purpose | Key Components |
|-------|---------|----------------|
| Loss | Quantify model-data discrepancy | LossFnGenerator protocol |
| Fitting | Find optimal parameters | FittingAlgorithm protocol |
| Simulation | Generate predictions | simulate_h5_from_h5 |
| Comparison | Compare models | AICc, BIC, t-tests |

Component Architecture

All evaluation components follow protocols defined in jaxcmr/typing.py:
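```python
from __future__ import annotations  # lets FitResult stay an unevaluated name here

from typing import Callable, Protocol

from jax import Array
from jaxtyping import Bool, Float, Int

# FitResult (the return type of a fit) is not shown here; see Protocols.

# Loss function generator: binds the selected trials and base parameters,
# then returns a function mapping free-parameter values to a scalar loss
class LossFnGenerator(Protocol):
    def __call__(
        self,
        trial_indices: Int[Array, "trials"],
        parameters: dict[str, Array],
        free_param_names: list[str],
    ) -> Callable[[Float[Array, "params"]], Float[Array, ""]]: ...

# Fitting algorithm: finds optimal parameters for the masked trials
class FittingAlgorithm(Protocol):
    def fit(
        self,
        trial_mask: Bool[Array, "subjects trials"],
    ) -> FitResult: ...
```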

See Protocols for complete definitions.
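The sketch below makes the LossFnGenerator contract concrete. It shows a minimal generator for a hypothetical toy model in which every study item is recalled independently with probability recall_prob. Plain Python lists and floats stand in for the protocol's JAX arrays. TOY_DATA, recall_prob, and toy_loss_fn_generator are illustrative names and are not part of jaxcmr. What matters is the shape: the generator binds the selected trials and base parameters once, and the returned closure maps a vector of free-parameter values to a scalar loss.

```python
import math
from typing import Callable, Sequence

# Hypothetical toy data: one row per trial, 1 = item recalled, 0 = not recalled.
TOY_DATA = [
    [1, 1, 0, 0],
    [1, 0, 1, 0],
    [0, 1, 1, 1],
]


def toy_loss_fn_generator(
    trial_indices: Sequence[int],
    parameters: dict[str, float],
    free_param_names: list[str],
) -> Callable[[Sequence[float]], float]:
    """Return a loss over the free-parameter vector for the selected trials."""
    selected = [TOY_DATA[i] for i in trial_indices]

    def loss(free_values: Sequence[float]) -> float:
        # Overwrite base parameters with the candidate free values, by name.
        params = {**parameters, **dict(zip(free_param_names, free_values))}
        p = params["recall_prob"]
        # Negative log-likelihood of independent Bernoulli recall events.
        nll = 0.0
        for trial in selected:
            for recalled in trial:
                nll -= math.log(p if recalled else 1.0 - p)
        return nll

    return loss


loss = toy_loss_fn_generator([0, 1, 2], {"recall_prob": 0.5}, ["recall_prob"])
# The loss is lowest at the observed recall rate, 7/12.
print(loss([0.5]), loss([7 / 12]))
```

In this sketch, an optimizer only ever sees a flat vector of free values. That is what allows one fitting routine to be paired with different loss generators.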

Loss Functions

| Loss Function | Module | Description |
|---------------|--------|-------------|
| Sequence Likelihood | loss.sequence_likelihood | Standard negative log-likelihood of recall sequences |
| Transform Likelihood | loss.transform_sequence_likelihood | Likelihood with pluggable masking |
| Set Permutation | loss.set_permutation_likelihood | Monte Carlo likelihood for unordered recall |
| SPC MSE | loss.spc_mse | Mean squared error between serial position curves |
| Category SPC MSE | loss.cat_spc_mse | SPC MSE on category-filtered curves |

See Loss Functions for detailed documentation.

Fitting Algorithms

| Algorithm | Module | Description |
|-----------|--------|-------------|
| ScipyDE | fitting | Differential evolution via SciPy |

See Fitting for usage patterns.
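ScipyDE hands the search over to SciPy's differential evolution. To show roughly what that optimizer does with a loss function, the sketch below implements a small DE/rand/1/bin loop in plain Python. It is not ScipyDE's or SciPy's implementation. The toy quadratic loss, with its minimum at made-up "true" values, stands in for a real likelihood. The bounds dict follows the format of the Quick Start's hyperparams, and the parameter names (pop_size, generations, mutation, crossover) are illustrative.

```python
import random
from typing import Callable, Sequence


def differential_evolution(
    loss: Callable[[Sequence[float]], float],
    bounds: Sequence[tuple[float, float]],
    pop_size: int = 20,
    generations: int = 300,
    mutation: float = 0.5,
    crossover: float = 0.9,
    seed: int = 0,
) -> tuple[list[float], float]:
    """Minimize `loss` inside box `bounds` with the DE/rand/1/bin scheme."""
    rng = random.Random(seed)
    dim = len(bounds)
    # Start from a population spread uniformly within the bounds.
    pop = [[rng.uniform(lo, hi) for lo, hi in bounds] for _ in range(pop_size)]
    scores = [loss(x) for x in pop]
    for _ in range(generations):
        for i in range(pop_size):
            # Mutation: combine three distinct members other than i.
            a, b, c = rng.sample([j for j in range(pop_size) if j != i], 3)
            # Binomial crossover; one coordinate always comes from the mutant.
            forced = rng.randrange(dim)
            trial = []
            for d, (lo, hi) in enumerate(bounds):
                if d == forced or rng.random() < crossover:
                    value = pop[a][d] + mutation * (pop[b][d] - pop[c][d])
                    trial.append(min(max(value, lo), hi))  # clip to bounds
                else:
                    trial.append(pop[i][d])
            # Selection: the trial replaces member i if it is no worse.
            trial_score = loss(trial)
            if trial_score <= scores[i]:
                pop[i], scores[i] = trial, trial_score
    best = min(range(pop_size), key=lambda k: scores[k])
    return pop[best], scores[best]


# Usage with the Quick Start's bounds format and a toy quadratic loss.
bounds = {"encoding_drift_rate": [0.0, 1.0], "recall_drift_rate": [0.0, 1.0]}
names = list(bounds)
target = {"encoding_drift_rate": 0.3, "recall_drift_rate": 0.7}  # made up


def toy_loss(x: Sequence[float]) -> float:
    return sum((v - target[n]) ** 2 for n, v in zip(names, x))


best, best_loss = differential_evolution(toy_loss, [tuple(bounds[n]) for n in names])
print(dict(zip(names, best)), best_loss)
```

Differential evolution uses only loss values, never gradients. Because it maintains a population, it is less prone than a purely local method to settling in a poor local minimum. The cost is many loss evaluations per generation.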

Quick Start

```python
from jax import random

from jaxcmr.analyses.spc import plot_spc
from jaxcmr.fitting import ScipyDE
from jaxcmr.loss.sequence_likelihood import MemorySearchLikelihoodFnGenerator
from jaxcmr.models.cmr import make_factory
from jaxcmr.simulation import simulate_h5_from_h5

# Assumes `data` (a loaded dataset) and `trial_mask` (a boolean mask
# selecting the trials to fit) are already defined.

# 1. Create model factory
factory = make_factory(...)

# 2. Define fixed parameters and the bounds of the free parameters
fixed_params = {"allow_repeated_recalls": False}  # ... plus other fixed parameters
hyperparams = {
    "bounds": {
        "encoding_drift_rate": [0.0, 1.0],
        "recall_drift_rate": [0.0, 1.0],
    },
    "num_steps": 1000,
}

# 3. Fit model parameters to the data
fitter = ScipyDE(
    dataset=data,
    features=None,
    base_params=fixed_params,
    model_factory=factory,
    loss_fn_generator=MemorySearchLikelihoodFnGenerator,
    hyperparams=hyperparams,
)
results = fitter.fit(trial_mask)

# 4. Simulate recall data from the fitted parameters
sim = simulate_h5_from_h5(
    factory, data, None, results["fits"],
    trial_mask, experiment_count=100, rng=random.PRNGKey(0),
)

# 5. Compare simulated and observed serial position curves
plot_spc([data, sim], [trial_mask, trial_mask], ["Data", "Model"])
```
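Step 5 compares model and data visually. For a quantitative comparison between models, the AICc and BIC listed in the pipeline table can be computed from three quantities: each fitted model's minimized negative log-likelihood (NLL), its number of free parameters k, and the number of observations n. The sketch uses the textbook formulas. It leaves open which quantity jaxcmr counts as n and whether it sums fits over subjects before comparing; both are assumptions. The NLL values in the usage lines are made up.

```python
import math


def aic(nll: float, k: int) -> float:
    """Akaike information criterion from a minimized negative log-likelihood."""
    return 2 * nll + 2 * k


def aicc(nll: float, k: int, n: int) -> float:
    """AIC with the small-sample correction; requires n > k + 1."""
    if n <= k + 1:
        raise ValueError("AICc needs more observations than k + 1")
    return aic(nll, k) + 2 * k * (k + 1) / (n - k - 1)


def bic(nll: float, k: int, n: int) -> float:
    """Bayesian information criterion; penalizes k by log(n) instead of 2."""
    return 2 * nll + k * math.log(n)


# Hypothetical fits: model_b has one extra free parameter and a lower NLL.
n_obs = 200
fits = {"model_a": (512.3, 4), "model_b": (510.9, 5)}
for name, (nll, k) in fits.items():
    # Lower is better for both criteria.
    print(name, round(aicc(nll, k, n_obs), 2), round(bic(nll, k, n_obs), 2))
```

With these made-up numbers, AICc slightly favors model_b. BIC, whose penalty per parameter is log(200) rather than 2, favors model_a.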

Section Contents

  • Protocols — LossFnGenerator and FittingAlgorithm protocols
  • Loss Functions
  • Fitting
    • Overview — Fitting algorithms and patterns
    • ScipyDE — Differential evolution optimizer
  • Simulation — Generating predictions from fitted parameters
  • Model Comparison — AICc, t-tests, winner analysis, and Bayesian model selection

Choosing an Approach

For Standard Free Recall

Use Sequence Likelihood with ScipyDE:
  • Captures full temporal dynamics (order, transitions)
  • Maximum likelihood estimation
  • Per-subject parameter fitting
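A sequence likelihood captures order and transitions because it scores each recall conditional on everything recalled before it. The sketch below shows that structure with a hypothetical toy model, not CMR:

  • The first recall is uniform over the list.
  • Later recalls favor items at small lags from the previous recall, controlled by tau.
  • A fixed stop_prob governs termination.
  • Already-recalled items are excluded.

The names toy_event_probs and sequence_nll are illustrative. The sketch also treats stopping as an observed event; whether jaxcmr's sequence_likelihood does exactly this is an assumption.

```python
import math
from typing import Sequence


def toy_event_probs(
    list_length: int, history: Sequence[int], stop_prob: float, tau: float
) -> list[float]:
    """P(next event | history) over items 0..list_length-1, with 'stop' last."""
    if len(history) == list_length:
        return [0.0] * list_length + [1.0]  # nothing left: stopping is certain
    weights = []
    for item in range(list_length):
        if item in history:
            weights.append(0.0)  # no repeated recalls
        elif not history:
            weights.append(1.0)  # first recall: uniform over the list
        else:
            lag = abs(item - history[-1])
            weights.append(math.exp(-lag / tau))  # nearby items favoured
    total = sum(weights)
    return [(1.0 - stop_prob) * w / total for w in weights] + [stop_prob]


def sequence_nll(
    recalls: Sequence[int], list_length: int, stop_prob: float, tau: float
) -> float:
    """Negative log-likelihood of one recall sequence, including its stop."""
    nll = 0.0
    history: list[int] = []
    for item in recalls:
        nll -= math.log(toy_event_probs(list_length, history, stop_prob, tau)[item])
        history.append(item)
    # Stopping after the last recall is itself an observed event.
    nll -= math.log(toy_event_probs(list_length, history, stop_prob, tau)[-1])
    return nll


# A forward-order sequence is more likely under this model than a jumpy one.
print(sequence_nll([0, 1, 2], 4, stop_prob=0.2, tau=1.0))
print(sequence_nll([0, 3, 1], 4, stop_prob=0.2, tau=1.0))
```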

For Unordered Recall

Use Set Permutation Likelihood:
  • When only the recall set is recorded, not its order
  • Final free recall paradigms
  • Monte Carlo sampling of recall orders avoids enumerating all n! possible orderings
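For unordered recall, the likelihood of a recalled set is the sum of the sequence probabilities of every order that could have produced it: n! terms for n items. A Monte Carlo estimate replaces that sum with n! times the mean probability of uniformly sampled orders. The sketch below checks the estimate against exact enumeration for a small set. It uses a hypothetical order-sensitive toy model (toy_sequence_prob) in place of CMR and ignores termination. Uniform sampling of orders is the simplest such estimator; that set_permutation_likelihood samples this way is an assumption.

```python
import math
import random
from itertools import permutations
from typing import Callable, Sequence

SequenceProb = Callable[[Sequence[int]], float]


def toy_sequence_prob(order: Sequence[int], list_length: int = 5, tau: float = 1.0) -> float:
    """Hypothetical model: uniform first recall, then small lags favoured."""
    prob = 1.0 / list_length
    remaining = set(range(list_length)) - {order[0]}
    for prev, item in zip(order, order[1:]):
        weights = {j: math.exp(-abs(j - prev) / tau) for j in remaining}
        prob *= weights[item] / sum(weights.values())
        remaining.discard(item)
    return prob


def exact_set_prob(items: Sequence[int], seq_prob: SequenceProb) -> float:
    """Sum over all n! orders; only feasible for small sets."""
    return sum(seq_prob(order) for order in permutations(items))


def mc_set_prob(
    items: Sequence[int], seq_prob: SequenceProb, num_samples: int, seed: int = 0
) -> float:
    """Unbiased estimate: n! times the mean probability of random orders."""
    rng = random.Random(seed)
    order = list(items)
    total = 0.0
    for _ in range(num_samples):
        rng.shuffle(order)  # each shuffle yields a uniformly random order
        total += seq_prob(order)
    return math.factorial(len(order)) * total / num_samples


recalled_set = [0, 1, 2]
print(exact_set_prob(recalled_set, toy_sequence_prob))
print(mc_set_prob(recalled_set, toy_sequence_prob, num_samples=20_000))
```

The estimator's variance grows when a few orders dominate the sum. The number of samples therefore trades accuracy against cost.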

For Shape Matching

Use SPC MSE or Category SPC MSE:
  • Fits to aggregate curves
  • Useful when likelihood is computationally prohibitive
  • Ignores recall order information
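Shape-matching losses compare summary curves rather than sequences. The sketch below computes a serial position curve, the proportion of trials on which each study position was recalled, and the mean squared error between two such curves. It assumes recalls are stored as rows of 1-indexed study positions padded with 0; that data layout is an assumption. The functions spc and spc_mse are illustrative and are not jaxcmr's loss.spc_mse. The usage lines show why this loss ignores order: two datasets that recall the same items in different orders produce identical curves.

```python
from typing import Sequence


def spc(recalls: Sequence[Sequence[int]], list_length: int) -> list[float]:
    """Proportion of trials recalling each serial position (0 = padding)."""
    counts = [0] * list_length
    for trial in recalls:
        for position in set(trial) - {0}:
            counts[position - 1] += 1
    return [count / len(recalls) for count in counts]


def spc_mse(
    data_recalls: Sequence[Sequence[int]],
    model_recalls: Sequence[Sequence[int]],
    list_length: int,
) -> float:
    """Mean squared difference between the two serial position curves."""
    data_curve = spc(data_recalls, list_length)
    model_curve = spc(model_recalls, list_length)
    return sum((d - m) ** 2 for d, m in zip(data_curve, model_curve)) / list_length


observed = [[1, 2, 0], [3, 1, 0]]
reordered = [[2, 1, 0], [1, 3, 0]]  # same items, different recall order
print(spc(observed, 3))                  # [1.0, 0.5, 0.5]
print(spc_mse(observed, reordered, 3))   # 0.0
```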