MSE-Based Loss Functions

Fitting to summary statistics via simulation

MSE-based loss functions compare observed summary statistics to simulated predictions, measuring fit via mean squared error. This approach fits to aggregate behavioral patterns rather than raw recall sequences.

Why MSE?

Sometimes you want to fit to shapes rather than sequences:

  • Match the serial position curve profile
  • Fit to category-specific recall rates
  • Target specific behavioral measures

MSE-based fitting:

  • Directly optimizes the statistic of interest
  • Works when the likelihood is intractable
  • Captures aggregate patterns

General Framework

The MSE loss compares observed and simulated analysis results:

\[L(\theta) = \frac{1}{n} \sum_{i=1}^{n} \left( y_i^{obs} - y_i^{sim}(\theta) \right)^2\]

Where:

  • \(y^{obs}\) = the observed statistic (e.g., SPC values)
  • \(y^{sim}(\theta)\) = the same statistic computed from model simulations with parameters \(\theta\)
  • \(n\) = the number of data points in the statistic
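
A worked instance of the formula, with made-up numbers purely for illustration:

Code
import jax.numpy as jnp

# Observed vs. simulated SPC over n = 4 serial positions (illustrative values)
y_obs = jnp.array([0.9, 0.6, 0.5, 0.8])
y_sim = jnp.array([0.8, 0.5, 0.5, 0.7])

# Squared errors: 0.01, 0.01, 0.00, 0.01 -> mean = 0.0075
loss = jnp.mean((y_obs - y_sim) ** 2)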

Simulation-Based Estimation

Since \(y^{sim}(\theta)\) requires sampling from the model:

  1. Simulate: Generate recall chains from the model
  2. Analyze: Compute the target statistic on simulated data
  3. Compare: Calculate MSE between observed and simulated
Code
# Pseudocode: the optimizer evaluates this loss for each candidate
for params in candidate_parameter_sets:
    simulated_recalls = simulate(model, params)   # 1. sample recall sequences
    y_sim = analysis_function(simulated_recalls)  # 2. compute the target statistic
    loss = mean((y_obs - y_sim) ** 2)             # 3. compare to the observed statistic

Variants

SPC MSE

Module: jaxcmr.loss.spc_mse

Fits to the Serial Position Curve (probability of recall by presentation position).

Code
from jaxcmr.loss.spc_mse import MemorySearchSpcMseFnGenerator

Mathematical Specification:

The SPC is defined as: \[SPC_i = \frac{1}{T} \sum_{t=1}^{T} \mathbf{1}[i \in R_t]\]

where \(T\) is the number of trials and \(\mathbf{1}[i \in R_t]\) indicates whether the item at position \(i\) was recalled in trial \(t\).

The loss: \[L_{SPC}(\theta) = \frac{1}{L} \sum_{i=1}^{L} \left( SPC_i^{obs} - SPC_i^{sim}(\theta) \right)^2\]
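
For reference, the empirical SPC can be computed directly from a recalls matrix. A minimal sketch, assuming recalls hold 1-indexed study positions with 0 as padding (matching the analysis-function protocol later on this page); the spc helper is hypothetical, not the library's API:

Code
import jax.numpy as jnp

def spc(recalls, list_length):
    """Empirical SPC: per-position recall probability across trials.

    Assumes `recalls` is (trials, recall_events) with 1-indexed study
    positions and 0 as padding (hypothetical helper, illustrative only).
    """
    positions = jnp.arange(1, list_length + 1)                      # (L,)
    # (trials, L, recall_events) -> was position i recalled in trial t?
    recalled = (recalls[:, None, :] == positions[None, :, None]).any(axis=-1)
    return recalled.astype(jnp.float32).mean(axis=0)                # (L,)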

Parameters:

Parameter         Default  Description
simulation_count  20       Simulations per trial

Usage:

Code
from jaxcmr.fitting import ScipyDE
from jaxcmr.loss.spc_mse import MemorySearchSpcMseFnGenerator

fitter = ScipyDE(
    dataset=data,
    features=None,
    base_params=fixed_params,
    model_factory=model_factory,
    loss_fn_generator=MemorySearchSpcMseFnGenerator,
    hyperparams={"bounds": free_params, "num_steps": 1000},
)

results = fitter.fit(trial_mask)

Category SPC MSE

Module: jaxcmr.loss.cat_spc_mse

Fits to category-filtered Serial Position Curves—separate SPCs for each stimulus category.

Code
from jaxcmr.loss.cat_spc_mse import MemorySearchCatSpcMseFnGenerator

Mathematical Specification:

Category-specific SPC: \[SPC_{c,i} = \frac{1}{T_{c,i}} \sum_{t:\, cat_{t,i} = c} \mathbf{1}[i \in R_t]\]

where \(cat_{t,i}\) is the category of the item at position \(i\) in trial \(t\) (the dataset's condition field) and \(T_{c,i}\) is the number of trials in which that position holds a category-\(c\) item.

The loss: \[L_{cat}(\theta) = \frac{1}{C \cdot L} \sum_{c=1}^{C} \sum_{i=1}^{L} \left( SPC_{c,i}^{obs} - SPC_{c,i}^{sim}(\theta) \right)^2\]
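
A minimal sketch of the category-filtered computation, under the same data-layout assumptions as the spc sketch above; cat_spc is hypothetical, not the library's API:

Code
import jax.numpy as jnp

def cat_spc(recalls, condition, category, list_length):
    """SPC restricted to positions whose item belongs to `category`.

    `condition` is (trials, items) holding category codes, as described
    under Data Requirement below (hypothetical helper, illustrative only).
    """
    positions = jnp.arange(1, list_length + 1)
    recalled = (recalls[:, None, :] == positions[None, :, None]).any(axis=-1)
    in_cat = condition == category                 # (trials, items)
    # Average recall only over trials where position i holds a category-c item
    return (recalled & in_cat).astype(jnp.float32).sum(axis=0) / in_cat.sum(axis=0)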

Parameters:

Parameter         Default  Description
simulation_count  10       Simulations per trial
category_values   [1, 2]   Category codes in the dataset

Data Requirement:

The dataset must include a condition field indicating item categories:

Code
dataset["condition"]  # Shape: (trials, items), values in category_values

Usage:

Code
from jaxcmr.loss.cat_spc_mse import MemorySearchCatSpcMseFnGenerator

fitter = ScipyDE(
    dataset=data,  # Must have 'condition' field
    features=None,
    base_params=fixed_params,
    model_factory=model_factory,
    loss_fn_generator=MemorySearchCatSpcMseFnGenerator,
    hyperparams={"bounds": free_params, "num_steps": 1000},
)

General MSE

Module: jaxcmr.experimental.mse_loss

Fits to any analysis function you provide.

Code
from jaxcmr.experimental.mse_loss import MemorySearchMseFnGenerator

Parameters:

Parameter         Default   Description
simulation_count  20        Simulations per trial
analysis_fn       Required  RecallAnalysisFn callback

Analysis Function Protocol:

Code
from jaxtyping import Array, Float, Integer

def my_analysis(
    recalls: Integer[Array, "trials recall_events"],
    presentations: Integer[Array, "trials study_events"],
) -> Float[Array, "..."]:
    """Compute a summary statistic from recall data."""
    # Your analysis here; return a fixed-shape array of floats
    return statistic_array

Usage:

Code
import jax.numpy as jnp

from jaxcmr.experimental.mse_loss import MemorySearchMseFnGenerator

def custom_analysis(recalls, pres):
    """Example: compute the mean number of recalls per trial."""
    return jnp.mean(jnp.sum(recalls > 0, axis=1))

# Create generator with custom analysis
generator = MemorySearchMseFnGenerator(
    model_create_fn, dataset, features,
    analysis_fn=custom_analysis,
)
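
Wiring the configured generator into ScipyDE is not shown in the module itself; one plausible approach, assuming loss_fn_generator accepts any callable with the built-in generators' call signature, is to bind analysis_fn ahead of time with functools.partial:

Code
from functools import partial

from jaxcmr.fitting import ScipyDE

# Bind the analysis function so ScipyDE can instantiate the generator
# the same way it does the built-in loss classes (assumed interface)
loss_gen = partial(MemorySearchMseFnGenerator, analysis_fn=custom_analysis)

fitter = ScipyDE(
    dataset=data,
    features=None,
    base_params=fixed_params,
    model_factory=model_factory,
    loss_fn_generator=loss_gen,
    hyperparams={"bounds": free_params, "num_steps": 1000},
)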

Comparison Table

Variant           Target Statistic       Default simulation_count  Data Requirements
SPC MSE           Serial position curve  20                        Standard
Category SPC MSE  Category-filtered SPC  10                        condition field
General MSE       User-defined           20                        Depends on analysis

Computational Considerations

Gradient Estimation

MSE losses rely on stochastic, simulation-based evaluation:

  • Each evaluation samples from the model
  • Loss values (and any gradient estimates) are noisy Monte Carlo quantities
  • Fitting may require more iterations than likelihood-based approaches

Simulation Count Trade-off

Higher simulation_count  Lower simulation_count
Lower variance           Higher variance
Slower per evaluation    Faster per evaluation
Smoother optimization    Noisier optimization

Typical values: 10-50 depending on task complexity.

Optimization Stability

Because each loss evaluation is stochastic:

  • Consider using more optimizer iterations
  • Use best_of > 1 for multiple restarts
  • Monitor convergence carefully; a quick noise check is sketched below
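
One way to gauge the noise before committing to a simulation_count is to evaluate the loss repeatedly at fixed parameters; loss_fn and candidate_params below are hypothetical handles, not library API:

Code
import numpy as np

# Evaluate the stochastic loss 20 times at the same parameter values
losses = [float(loss_fn(candidate_params)) for _ in range(20)]
print(f"loss mean = {np.mean(losses):.4f}, sd = {np.std(losses):.4f}")

# If the sd is large relative to loss differences between candidate fits,
# increase simulation_count or best_of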

When to Use MSE vs Likelihood

Prefer MSE When:

  • Shape matching: Primary goal is matching aggregate patterns
  • Specific statistics: Testing predictions for particular measures
  • Likelihood intractable: Model doesn’t provide easy probability computation
  • Robustness: Less sensitive to outlier trials

Prefer Likelihood When:

  • Full information: Want to use all temporal dynamics
  • Model comparison: AIC/BIC require likelihood values
  • Statistical inference: Confidence intervals, hypothesis tests
  • Efficiency: Likelihood often converges faster

Usage Example

SPC MSE Fitting

Code
from jaxcmr.fitting import ScipyDE
from jaxcmr.loss.spc_mse import MemorySearchSpcMseFnGenerator
from jaxcmr.models.cmr import make_factory
from jaxcmr.components.linear_memory import init_mfc, init_mcf
from jaxcmr.components.context import init as init_context
from jaxcmr.components.termination import PositionalTermination

# Create model factory
model_factory = make_factory(
    init_mfc, init_mcf, init_context, PositionalTermination
)

# Define parameters
fixed_params = {
    "allow_repeated_recalls": False,
    "learn_after_context_update": True,
}

free_params = {
    "encoding_drift_rate": [0.0, 1.0],
    "start_drift_rate": [0.0, 1.0],
    "recall_drift_rate": [0.0, 1.0],
    "primacy_scale": [0.0, 10.0],
    "primacy_decay": [0.0, 5.0],
}

# Fit to SPC shape
fitter = ScipyDE(
    dataset=data,
    features=None,
    base_params=fixed_params,
    model_factory=model_factory,
    loss_fn_generator=MemorySearchSpcMseFnGenerator,
    hyperparams={
        "bounds": free_params,
        "num_steps": 1500,  # More iterations for MSE
        "best_of": 3,
    },
)

results = fitter.fit(trial_mask)

Verifying the Fit

After fitting, compare observed and simulated SPCs:

Code
from jax import random

from jaxcmr.simulation import simulate_h5_from_h5
from jaxcmr.analyses.spc import plot_spc

# RNG key for the simulation
key = random.PRNGKey(0)

# Simulate from fitted parameters
sim = simulate_h5_from_h5(
    model_factory, data, None,
    results["fits"], trial_mask,
    experiment_count=100, rng=key,
)

# Compare
plot_spc(
    datasets=[data, sim],
    trial_masks=[trial_mask, trial_mask],
    labels=["Data", "Model"],
)

Limitations

No Likelihood Values

MSE fitting doesn't produce the likelihood values needed for:

  • AIC/BIC model comparison
  • Bayes factors
  • Likelihood ratio tests

For model comparison with MSE, use cross-validation or predictive accuracy.
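
A hedged sketch of one such check: fit on half the trials and evaluate on the held-out half (the even/odd trial split is illustrative only):

Code
import numpy as np

# Split trials into train/test halves
n_trials = len(trial_mask)
train_mask = trial_mask & (np.arange(n_trials) % 2 == 0)
test_mask = trial_mask & (np.arange(n_trials) % 2 == 1)

# Fit on the training half only
results = fitter.fit(train_mask)

# Then simulate with the fitted parameters and compare observed vs.
# simulated SPCs on test_mask, as in "Verifying the Fit" above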

Simulation Noise

Each evaluation is stochastic:

  • The same parameters give slightly different losses
  • The optimization landscape is noisy
  • Runs may converge to different optima

Target Selection

The choice of summary statistic determines which patterns the model is pressured to reproduce:

  • SPC fitting may ignore temporal contiguity
  • Category SPC may ignore within-category dynamics
  • Choose statistics that capture theoretically important patterns; combining several into one target is sketched below
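
If a single statistic is too narrow, multiple statistics can be concatenated into one target vector for the General MSE loss. A minimal sketch reusing the hypothetical spc helper from earlier; combined_analysis is illustrative, not library API:

Code
import jax.numpy as jnp

def combined_analysis(recalls, presentations):
    """Concatenate the SPC with the mean recall count so the loss is
    sensitive to both curve shape and overall recall volume."""
    curve = spc(recalls, list_length=presentations.shape[1])
    n_recalled = jnp.sum(recalls > 0, axis=1).mean().astype(jnp.float32)
    return jnp.concatenate([curve, n_recalled[None]])

In practice the components would need rescaling so that one term (here, a raw recall count) does not dominate the squared error.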