Quickly instantiate the Monte Carlo bag-likelihood generator and inspect its outputs for a handful of trials.
The notebook mirrors the minimal setup from fitting.ipynb: load the Healey & Kahana (2014) dataset, build the CMR factory, initialize MemorySearchLikelihoodFnGenerator, and evaluate base_predict_trials on a small trial subset.
Code
import matplotlib.pyplot as plt
import os
import jax.numpy as jnp
from jax import random
from jaxcmr.loss.set_likelihood_monte_carlo import MemorySearchLikelihoodFnGenerator
from jaxcmr.math import lb
from jaxcmr.helpers import (
find_project_root,
generate_trial_mask,
import_from_string,
load_data,
)
def save_figure(figure_dir, figure_str, suffix=None):
    """Save the current matplotlib figure as a PNG, or just display it.

    Args:
        figure_dir: Directory the PNG is written to (created if missing).
        figure_str: Base filename without extension; falsy means "show only,
            do not save".
        suffix: Optional tag appended to the filename as ``_{suffix}``.
    """
    plt.tight_layout()
    if not figure_str:
        # No filename requested: display interactively and skip saving.
        plt.show()
        return
    os.makedirs(figure_dir, exist_ok=True)
    # NOTE(review): the exported source had stray spaces inside these
    # f-strings (e.g. f"_ { suffix} "), which would embed literal spaces in
    # the saved filename; normalized to the intended "name_suffix.png" form.
    suffix_str = f"_{suffix}" if suffix else ""
    figure_path = os.path.join(figure_dir, f"{figure_str}{suffix_str}.png")
    plt.savefig(figure_path, bbox_inches="tight", dpi=600)
    plt.show()
Code
# --- Run configuration -------------------------------------------------------
# Dataset and output locations (joined onto the project root later).
data_path = "data/HealeyKahana2014.h5"
figure_dir = "results/figures"
figure_str = ""  # empty -> figures are displayed rather than saved

# Trial selection: the query string picks trials, capped at the first subject
# and the first 16 matching trials.
trial_query = "data['listtype'] == -1"
max_subjects = 1
trial_count = 16

# CMR factory entry point and the pluggable component constructors it takes.
model_factory_path = "jaxcmr.models.cmr.make_factory"
component_paths = dict(
    mfc_create_fn="jaxcmr.components.linear_memory.init_mfc",
    mcf_create_fn="jaxcmr.components.linear_memory.init_mcf",
    context_create_fn="jaxcmr.components.context.init",
    termination_policy_create_fn="jaxcmr.components.termination.NoStopTermination",
)

# Monte Carlo settings plus a fixed parameter set for the likelihood check.
simulation_count = 100
seed = 0
base_params = dict(
    encoding_drift_rate=0.9,
    start_drift_rate=0.13,
    recall_drift_rate=0.8,
    primacy_scale=3.43,
    primacy_decay=1.43,
    learning_rate=0.4,
    choice_sensitivity=63.5,
    item_support=13.0,
    shared_support=13.0,
    allow_repeated_recalls=False,
    learn_after_context_update=False,
)
Code
# Resolve on-disk paths relative to the project root.
project_root = find_project_root()
figure_dir = os.path.join(project_root, figure_dir)

# Load at most `max_subjects` subjects, then mask down to the queried trials.
data = load_data(os.path.join(project_root, data_path), max_subjects)
trial_mask = generate_trial_mask(data, trial_query)

# Indices of trials passing the mask, truncated to the first `trial_count`.
matching_indices = jnp.where(trial_mask)[0]
trial_indices = matching_indices[:trial_count]
trial_indices
Array([ 0, 1, 2, 3, 23, 25, 26, 30, 34, 39, 41, 45, 51, 55, 59, 63], dtype=int32)
Code
# Resolve the factory and each component constructor from their dotted paths,
# then assemble the Monte Carlo likelihood generator.
make_factory = import_from_string(model_factory_path)
component_fns = {
    name: import_from_string(path) for name, path in component_paths.items()
}
factory_cls = make_factory(**component_fns)

loss_generator = MemorySearchLikelihoodFnGenerator(factory_cls, data, None)
loss_generator.simulation_count = simulation_count

# Derive one PRNG key per trial in the dataset from a single base seed.
loss_generator.base_key = random.PRNGKey(seed)
total_trials = data["recalls"].shape[0]
loss_generator.trial_keys = random.split(loss_generator.base_key, total_trials)
loss_generator
<jaxcmr.compare_simulation_loss.MemorySearchLikelihoodFnGenerator at 0x3378e8b90>
Code
# Monte Carlo probability estimate for each selected trial under base_params.
# NOTE(review): with these parameters every estimate comes back 0.0 (see the
# printed Array below) — presumably no simulated recall sequence matched the
# observed one; confirm simulation_count is large enough to be informative.
probabilities = loss_generator.base_predict_trials(trial_indices, base_params)
probabilities
Array([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.], dtype=float32)
Code
# Negative log-likelihood over the selected trials; adding `lb` keeps log(0)
# finite when a Monte Carlo probability estimate is exactly zero.
log_probs = jnp.log(probabilities + lb)
loss_value = -log_probs.sum()
print(f"Monte Carlo negative log-likelihood: {float(loss_value):.6f} ")
Monte Carlo negative log-likelihood: 255.078156