Quickly instantiate the Monte Carlo bag-likelihood generator and inspect its outputs for a handful of trials.

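This notebook mirrors the minimal setup from `fitting.ipynb`: load the Healey & Kahana (2014) dataset, build the CMR model factory, initialize `MemorySearchLikelihoodFnGenerator`, and evaluate `base_predict_trials` on a small subset of trials. Note that in the run recorded below, every selected trial receives an estimated probability of 0. The reported negative log-likelihood is therefore just the floor, 16 × −log(`lb`) ≈ 255.08, rather than an informative fit.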

Code
import matplotlib.pyplot as plt
import os

import jax.numpy as jnp
from jax import random

from jaxcmr.loss.set_likelihood_monte_carlo import MemorySearchLikelihoodFnGenerator
from jaxcmr.math import lb
from jaxcmr.helpers import (
    find_project_root,
    generate_trial_mask,
    import_from_string,
    load_data,
)
def save_figure(figure_dir, figure_str, suffix=None):
    """Tighten the current figure's layout, optionally save it as a PNG, then display it.

    If ``figure_str`` is falsy (e.g. the empty string), nothing is written and the
    figure is only shown. Otherwise ``figure_dir`` is created if missing and the
    figure is saved to ``<figure_dir>/<figure_str>[_<suffix>].png`` at 600 dpi with
    a tight bounding box before being shown. The ``_<suffix>`` part is included
    only when ``suffix`` is truthy.
    """
    plt.tight_layout()
    if figure_str:
        os.makedirs(figure_dir, exist_ok=True)
        tag = f"_{suffix}" if suffix else ""
        target = os.path.join(figure_dir, f"{figure_str}{tag}.png")
        plt.savefig(target, bbox_inches="tight", dpi=600)
    plt.show()
Code
# --- Dataset and trial selection ----------------------------------------------
data_path = "data/HealeyKahana2014.h5"  # joined onto the project root when loaded
max_subjects = 1
trial_query = "data['listtype'] == -1"  # handed to generate_trial_mask as a string
trial_count = 16  # keep only the first N trials that match trial_query

# --- Figure output ------------------------------------------------------------
figure_dir = "results/figures"  # joined onto the project root later
figure_str = ""  # empty -> save_figure only displays, never writes a file

# --- Model construction (dotted import paths) ---------------------------------
model_factory_path = "jaxcmr.models.cmr.make_factory"
component_paths = dict(
    mfc_create_fn="jaxcmr.components.linear_memory.init_mfc",
    mcf_create_fn="jaxcmr.components.linear_memory.init_mcf",
    context_create_fn="jaxcmr.components.context.init",
    termination_policy_create_fn="jaxcmr.components.termination.NoStopTermination",
)

# --- Monte Carlo settings -----------------------------------------------------
simulation_count = 100  # assigned to the likelihood generator's simulation_count
seed = 0  # seeds the generator's base PRNG key

# --- CMR parameters -----------------------------------------------------------
base_params = dict(
    encoding_drift_rate=0.9,
    start_drift_rate=0.13,
    recall_drift_rate=0.8,
    primacy_scale=3.43,
    primacy_decay=1.43,
    learning_rate=0.4,
    choice_sensitivity=63.5,
    item_support=13.0,
    shared_support=13.0,
    allow_repeated_recalls=False,
    learn_after_context_update=False,
)
Code
# Anchor the configured relative paths at the repository root.
project_root = find_project_root()
figure_dir = os.path.join(project_root, figure_dir)

# Load the dataset (limited to max_subjects) and build a boolean mask of the
# trials that satisfy trial_query.
data = load_data(os.path.join(project_root, data_path), max_subjects)
trial_mask = generate_trial_mask(data, trial_query)

# Indices of the matching trials, truncated to the first trial_count of them.
trial_indices = jnp.nonzero(trial_mask)[0][:trial_count]
trial_indices
Array([ 0,  1,  2,  3, 23, 25, 26, 30, 34, 39, 41, 45, 51, 55, 59, 63],      dtype=int32)
Code
# Resolve the model factory and each component constructor from its dotted import path,
# then build the model factory class from those components.
make_factory = import_from_string(model_factory_path)
component_fns = {name: import_from_string(path) for name, path in component_paths.items()}
factory_cls = make_factory(**component_fns)

# Construct the Monte Carlo likelihood generator over the full loaded dataset.
# NOTE(review): the meaning of the third positional argument (passed as None) is not
# visible from this notebook — confirm against the class signature.
loss_generator = MemorySearchLikelihoodFnGenerator(factory_cls, data, None)
# Override the generator's Monte Carlo settings after construction.
# NOTE(review): this assumes the generator reads these attributes at call time; if
# __init__ already captured them (e.g. in compiled functions), these overrides may be
# ignored — confirm.
loss_generator.simulation_count = simulation_count
loss_generator.base_key = random.PRNGKey(seed)
# Split base_key into one key per row of data["recalls"] (presumably one per trial).
# This must follow the base_key assignment because it reads loss_generator.base_key.
loss_generator.trial_keys = random.split(loss_generator.base_key, data["recalls"].shape[0])
# Display the generator object as the cell output.
# NOTE(review): the recorded output names the module jaxcmr.compare_simulation_loss,
# not the jaxcmr.loss.set_likelihood_monte_carlo import path — it may be stale; re-run to confirm.
loss_generator
<jaxcmr.compare_simulation_loss.MemorySearchLikelihoodFnGenerator at 0x3378e8b90>
Code
# Evaluate the generator on the selected trials under base_params. The result is
# presumably one estimated probability per entry of trial_indices (the recorded output
# has 16 values, matching trial_count) — confirm against base_predict_trials.
# NOTE(review): every value in the recorded run is 0, i.e. no informative estimate.
probabilities = loss_generator.base_predict_trials(trial_indices, base_params)
probabilities
Array([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],      dtype=float32)
Code
# Total negative log-likelihood over the selected trials. Each probability is offset by
# `lb` before the log so that zero estimates give a finite term instead of -inf.
loss_value = -jnp.sum(jnp.log(probabilities + lb))
print(f"Monte Carlo negative log-likelihood: {float(loss_value):.6f}")

# A trial whose estimate is exactly 0 contributes only -log(lb), i.e. the floor rather
# than real model evidence. When every trial is 0, the loss above is just
# n * -log(lb) and says nothing about the fit. Report how many trials hit the floor so
# this degenerate case is visible.
zero_trial_count = int(jnp.sum(probabilities == 0))
print(f"Trials with zero estimated probability: {zero_trial_count}/{probabilities.size}")
Monte Carlo negative log-likelihood: 255.078156