```python
# Pseudocode
for params in optimization:
    simulated_recalls = simulate(model, params)
    y_sim = analysis_function(simulated_recalls)
    loss = mean((y_obs - y_sim) ** 2)
```

Fitting to summary statistics via simulation
MSE-based loss functions measure fit as the mean squared error between observed summary statistics and the model's simulated predictions. This approach fits aggregate behavioral patterns rather than raw recall sequences.
Sometimes you want to fit to shapes rather than sequences:

- Match the serial position curve profile
- Fit to category-specific recall rates
- Target specific behavioral measures

MSE-based fitting:

- Directly optimizes the statistic of interest
- Works when the likelihood is intractable
- Captures aggregate patterns
The MSE loss compares observed and simulated analysis results:
\[L(\theta) = \frac{1}{n} \sum_{i=1}^{n} \left( y_i^{obs} - y_i^{sim}(\theta) \right)^2\]
Where:

- \(y^{obs}\) = observed statistic (e.g., SPC values)
- \(y^{sim}(\theta)\) = simulated statistic from the model with parameters \(\theta\)
- \(n\) = number of data points in the statistic
Since \(y^{sim}(\theta)\) requires sampling from the model, the loss is stochastic: each evaluation simulates recall sequences, applies the analysis function, and compares the result to the observed statistic.
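A minimal sketch of one such evaluation, assuming hypothetical `simulate_fn` and `analysis_fn` callables (these names are illustrative, not part of the library's API):

```python
import jax.numpy as jnp

def mse_loss(params, observed_stat, simulate_fn, analysis_fn, rng_keys):
    # simulate_fn(params, key) -> simulated recall sequences (hypothetical signature)
    # analysis_fn(recalls) -> summary statistic, e.g. an SPC vector
    y_sims = jnp.stack([analysis_fn(simulate_fn(params, key)) for key in rng_keys])
    y_sim = y_sims.mean(axis=0)  # average the statistic over simulations
    return jnp.mean((observed_stat - y_sim) ** 2)
```

Averaging the statistic over several simulated runs per evaluation is what the simulation_count parameters below control.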
Module: jaxcmr.loss.spc_mse
Fits to the Serial Position Curve (probability of recall by presentation position).
Mathematical Specification:
The SPC is defined as: \[SPC_i = \frac{1}{T} \sum_{t=1}^{T} \mathbf{1}[i \in R_t]\]
where \(\mathbf{1}[i \in R_t]\) indicates whether the item at position \(i\) was recalled in trial \(t\), and \(T\) is the number of trials.
The loss: \[L_{SPC}(\theta) = \frac{1}{L} \sum_{i=1}^{L} \left( SPC_i^{obs} - SPC_i^{sim}(\theta) \right)^2\] where \(L\) is the list length.
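As a rough illustration (not the library's implementation), the SPC and its MSE can be computed from 0-padded matrices of recalled serial positions:

```python
import jax.numpy as jnp

def spc(recalls, list_length):
    # recalls: (trials, output positions) array of recalled serial positions, 0-padded
    positions = jnp.arange(1, list_length + 1)
    recalled = (recalls[:, :, None] == positions[None, None, :]).any(axis=1)
    return recalled.astype(float).mean(axis=0)  # recall probability per presentation position

def spc_mse(obs_recalls, sim_recalls, list_length):
    return jnp.mean((spc(obs_recalls, list_length) - spc(sim_recalls, list_length)) ** 2)
```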
Parameters:
| Parameter | Default | Description |
|---|---|---|
| simulation_count | 20 | Simulations per trial |
Usage:
```python
from jaxcmr.fitting import ScipyDE
from jaxcmr.loss.spc_mse import MemorySearchSpcMseFnGenerator

fitter = ScipyDE(
    dataset=data,
    features=None,
    base_params=fixed_params,
    model_factory=model_factory,
    loss_fn_generator=MemorySearchSpcMseFnGenerator,
    hyperparams={"bounds": free_params, "num_steps": 1000},
)
results = fitter.fit(trial_mask)
```

Module: jaxcmr.loss.cat_spc_mse
Fits to category-filtered Serial Position Curves—separate SPCs for each stimulus category.
Mathematical Specification:
Category-specific SPC: \[SPC_{c,i} = \frac{1}{T_c} \sum_{t: cat_t = c} \mathbf{1}[i \in R_t]\]
where \(T_c\) is the number of trials in category \(c\).
The loss: \[L_{cat}(\theta) = \frac{1}{C \cdot L} \sum_{c=1}^{C} \sum_{i=1}^{L} \left( SPC_{c,i}^{obs} - SPC_{c,i}^{sim}(\theta) \right)^2\]
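As a rough illustration (again, not the library's implementation), this amounts to computing the SPC separately for each category's trials, reusing the spc helper sketched above:

```python
import jax.numpy as jnp

def category_spc_mse(obs_recalls, sim_recalls, trial_categories, list_length,
                     category_values=(1, 2)):
    # trial_categories: one category code per trial; spc() is the helper sketched above
    losses = []
    for c in category_values:
        mask = trial_categories == c
        obs = spc(obs_recalls[mask], list_length)
        sim = spc(sim_recalls[mask], list_length)
        losses.append(jnp.mean((obs - sim) ** 2))
    return jnp.mean(jnp.array(losses))
```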
Parameters:
| Parameter | Default | Description |
|---|---|---|
| simulation_count | 10 | Simulations per trial |
| category_values | [1, 2] | Category codes in dataset |
Data Requirement:
The dataset must include a condition field indicating item categories.
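For illustration, a minimal in-memory dataset might look like the sketch below; apart from condition, the field names and shapes are assumptions, not the library's required schema:

```python
import numpy as np

data = {
    "pres_itemnos": np.array([[1, 2, 3, 4], [5, 6, 7, 8]]),  # presented items per trial
    "recalls":      np.array([[4, 1, 0, 0], [2, 3, 0, 0]]),  # recalled positions, 0-padded
    "condition":    np.array([[1], [2]]),                    # category code for each trial
}
```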
Usage:
```python
from jaxcmr.loss.cat_spc_mse import MemorySearchCatSpcMseFnGenerator

fitter = ScipyDE(
    dataset=data,  # Must have 'condition' field
    features=None,
    base_params=fixed_params,
    model_factory=model_factory,
    loss_fn_generator=MemorySearchCatSpcMseFnGenerator,
    hyperparams={"bounds": free_params, "num_steps": 1000},
)
```

Module: jaxcmr.experimental.mse_loss
Fits to any analysis function you provide.
Parameters:
| Parameter | Default | Description |
|---|---|---|
| simulation_count | 20 | Simulations per trial |
| analysis_fn | Required | RecallAnalysisFn callback |
Analysis Function Protocol:
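The exact protocol type lives in the library; as a sketch inferred from the usage example below, the callback takes recall sequences and presentation lists and returns the statistic to fit:

```python
from typing import Protocol

import jax

class RecallAnalysisFn(Protocol):
    # Sketch only: argument and return types inferred from the usage example below.
    def __call__(self, recalls: jax.Array, pres: jax.Array) -> jax.Array: ...
```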
Usage:
```python
import jax.numpy as jnp

from jaxcmr.experimental.mse_loss import MemorySearchMseFnGenerator

def custom_analysis(recalls, pres):
    """Example: compute mean number of recalls per trial."""
    return jnp.mean(jnp.sum(recalls > 0, axis=1))

# Create generator with custom analysis
generator = MemorySearchMseFnGenerator(
    model_create_fn, dataset, features,
    analysis_fn=custom_analysis,
)
```

The three loss variants compare as follows:

| Variant | Target Statistic | simulation_count | Data Requirements |
|---|---|---|---|
| SPC MSE | Serial position curve | 20 | Standard |
| Category SPC MSE | Category-filtered SPC | 10 | condition field |
| General MSE | User-defined | 20 | Depends on analysis |
MSE losses use simulation-based gradients:

- Each evaluation samples from the model
- Gradients are noisy (Monte Carlo)
- May require more iterations than likelihood-based fitting
| Higher simulation_count | Lower simulation_count |
|---|---|
| Lower variance | Higher variance |
| Slower per evaluation | Faster per evaluation |
| Smoother optimization | Noisier optimization |
Typical values: 10-50 depending on task complexity.
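A toy illustration of this tradeoff (pure NumPy, unrelated to any jaxcmr internals): the spread of an averaged statistic shrinks roughly as one over the square root of the number of simulations.

```python
import numpy as np

rng = np.random.default_rng(0)
for simulation_count in (5, 20, 80):
    # Spread of a statistic averaged over `simulation_count` noisy simulations
    estimates = [
        rng.normal(loc=0.5, scale=0.1, size=simulation_count).mean()
        for _ in range(1000)
    ]
    print(simulation_count, float(np.std(estimates)))
```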
Due to stochastic gradients:

- Consider using more optimizer iterations
- Use best_of > 1 for multiple restarts
- Monitor convergence carefully
Complete example:

```python
from jaxcmr.fitting import ScipyDE
from jaxcmr.loss.spc_mse import MemorySearchSpcMseFnGenerator
from jaxcmr.models.cmr import make_factory
from jaxcmr.components.linear_memory import init_mfc, init_mcf
from jaxcmr.components.context import init as init_context
from jaxcmr.components.termination import PositionalTermination

# Create model factory
model_factory = make_factory(
    init_mfc, init_mcf, init_context, PositionalTermination
)

# Define parameters
fixed_params = {
    "allow_repeated_recalls": False,
    "learn_after_context_update": True,
}
free_params = {
    "encoding_drift_rate": [0.0, 1.0],
    "start_drift_rate": [0.0, 1.0],
    "recall_drift_rate": [0.0, 1.0],
    "primacy_scale": [0.0, 10.0],
    "primacy_decay": [0.0, 5.0],
}

# Fit to SPC shape
fitter = ScipyDE(
    dataset=data,
    features=None,
    base_params=fixed_params,
    model_factory=model_factory,
    loss_fn_generator=MemorySearchSpcMseFnGenerator,
    hyperparams={
        "bounds": free_params,
        "num_steps": 1500,  # More iterations for MSE
        "best_of": 3,
    },
)
results = fitter.fit(trial_mask)
```

After fitting, compare observed and simulated SPCs:
```python
from jaxcmr.simulation import simulate_h5_from_h5
from jaxcmr.analyses.spc import plot_spc

# Simulate from fitted parameters
sim = simulate_h5_from_h5(
    model_factory, data, None,
    results["fits"], trial_mask,
    experiment_count=100, rng=key,
)

# Compare
plot_spc(
    datasets=[data, sim],
    trial_masks=[trial_mask, trial_mask],
    labels=["Data", "Model"],
)
```

MSE fitting doesn't produce the likelihood values needed for:

- AIC/BIC model comparison
- Bayes factors
- Likelihood ratio tests
For model comparison with MSE, use cross-validation or predictive accuracy.
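One way to set this up, sketched under the assumption that trial_mask is a boolean trial selector and reusing the fitter and simulation helpers shown above (the split itself is illustrative, not a library feature):

```python
import numpy as np

# Hold out half of the trials: fit on one half, then check how well simulations
# from the fitted parameters reproduce the held-out half's summary statistics.
rng = np.random.default_rng(0)
trial_ids = rng.permutation(np.flatnonzero(trial_mask))
train_mask = np.zeros_like(trial_mask, dtype=bool)
test_mask = np.zeros_like(trial_mask, dtype=bool)
train_mask[trial_ids[: len(trial_ids) // 2]] = True
test_mask[trial_ids[len(trial_ids) // 2 :]] = True

results = fitter.fit(train_mask)      # fit on the training half only
sim = simulate_h5_from_h5(            # simulate from the fitted parameters
    model_factory, data, None,
    results["fits"], test_mask,
    experiment_count=100, rng=key,
)
# Compare held-out and simulated SPCs, e.g. with plot_spc or an SPC MSE.
```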
Each evaluation is stochastic:

- Same parameters give slightly different losses
- Optimization landscape is noisy
- May converge to different optima across runs
The choice of summary statistic affects what the model learns:

- SPC fitting may ignore contiguity
- Category SPC may ignore within-category dynamics
- Choose statistics that capture theoretically important patterns