Conditional Distance-CRP

Compute Distance-CRP while selectively excluding transitions via a tabulation mask.

The conditional Distance-CRP extends the standard Distance-CRP with a _should_tabulate mask that controls which transitions are counted. Distance bins are computed from the cosine similarity of semantic embeddings. Where the mask is False, the analysis still updates its internal state, but the transition tallies are not incremented.
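
The masking behavior described above can be sketched in plain NumPy. This is a toy illustration, not the jaxcmr implementation: the function names, the use of 1 minus cosine similarity as the distance, the convention that a mask entry governs the transition into that recall, and the rule that a bin counts as "possible" when at least one not-yet-recalled item falls in it are all assumptions made for the example.

```python
import numpy as np

def cosine_distance_matrix(features):
    """Pairwise cosine distance (1 - cosine similarity) between rows."""
    unit = features / np.linalg.norm(features, axis=1, keepdims=True)
    return 1.0 - unit @ unit.T

def masked_dist_crp(recalls, should_tabulate, distances, bin_edges):
    """Toy conditional Distance-CRP tallies for a single trial.

    recalls: 1-indexed item ids, with 0 as padding (assumed convention).
    should_tabulate: one boolean per recall position.
    """
    n_bins = len(bin_edges) - 1
    actual = np.zeros(n_bins)
    possible = np.zeros(n_bins)
    available = np.ones(distances.shape[0], dtype=bool)
    prev = None
    for item, tabulate in zip(recalls, should_tabulate):
        if item == 0:  # padding marks the end of the trial
            break
        idx = item - 1
        if prev is not None and tabulate and available[idx]:
            # Bins reachable from the previous item among remaining candidates
            candidates = np.flatnonzero(available)
            cand_bins = np.digitize(distances[prev, candidates], bin_edges[1:-1])
            possible += np.bincount(np.unique(cand_bins), minlength=n_bins)
            actual[np.digitize(distances[prev, idx], bin_edges[1:-1])] += 1
        # State updates happen whether or not the transition was tallied
        available[idx] = False
        prev = idx
    return actual, possible

# Four toy items with 2-D "embeddings"
toy_features = np.array([[1.0, 0.0], [1.0, 1.0], [0.0, 1.0], [-1.0, 0.0]])
toy_distances = cosine_distance_matrix(toy_features)
toy_edges = np.array([0.0, 0.5, 1.5, 2.0])
toy_recalls = np.array([1, 2, 3, 0])

actual_all, possible_all = masked_dist_crp(
    toy_recalls, np.array([True, True, True, False]), toy_distances, toy_edges
)
actual_masked, possible_masked = masked_dist_crp(
    toy_recalls, np.array([True, False, True, False]), toy_distances, toy_edges
)
```

In the masked call, the transition into the second recall is not tallied, yet the third recall is still scored relative to the second item, which is the sense in which state updates while tallies do not.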

Workflow

# Imports
import os
import warnings

import matplotlib.pyplot as plt
import numpy as np

from jaxcmr.analyses.conditional_distcrp import plot_dist_crp
from jaxcmr.helpers import find_project_root, generate_trial_mask, load_data, save_figure

# Suppress all warnings to keep the notebook output clean
warnings.filterwarnings("ignore")

# Notebook parameters
data_path = "data/HealeyKahana2014.h5"
features_path = "data/peers-all-mpnet-base-v2.npy"
figure_dir = "results/figures"
figure_str = ""
ylim = None
trial_query = "data['listtype'] == -1"

# Resolve paths, load the data, and build the tabulation mask
project_root = find_project_root()
figure_dir = os.path.join(project_root, figure_dir)
data_path = os.path.join(project_root, data_path)
features_path = os.path.join(project_root, features_path)
data = load_data(data_path)
data["_should_tabulate"] = data["recalls"] > 0
features = np.load(features_path).astype(np.float32)
trial_mask = generate_trial_mask(data, trial_query)

# Plot the conditional Distance-CRP and save the figure
plot_dist_crp(data, trial_mask, features=features)
if ylim is not None:
    for ax in plt.gcf().axes:
        ax.set_ylim(ylim)
save_figure(figure_dir, figure_str)

Interpretation

The x-axis shows semantic distance bins; the y-axis shows conditional transition probability.

  • Downward slope: semantic contiguity — transitions favor semantically similar items.
  • Flat curve: semantic similarity does not predict transitions.
  • Compare with the standard Distance-CRP to assess the effect of excluding masked transitions.
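
The slope reading above can be made concrete with a least-squares fit. The following is a minimal sketch, assuming bin centers and CRP values are available as arrays. `crp_slope` is a hypothetical helper, not part of jaxcmr, and the numbers are made-up illustrative values rather than results from this dataset.

```python
import numpy as np

def crp_slope(bin_centers, crp):
    """Least-squares slope of CRP against distance, skipping empty (NaN) bins."""
    bin_centers = np.asarray(bin_centers, dtype=float)
    crp = np.asarray(crp, dtype=float)
    keep = np.isfinite(crp)
    slope, _intercept = np.polyfit(bin_centers[keep], crp[keep], deg=1)
    return slope

# Made-up curves for illustration only
bin_centers = np.array([0.1, 0.3, 0.5, 0.7, 0.9])
toy_contiguity = np.array([0.30, 0.25, np.nan, 0.15, 0.10])  # one empty bin
toy_flat = np.array([0.20, 0.20, 0.20, 0.20, 0.20])

slope_contiguity = crp_slope(bin_centers, toy_contiguity)  # negative slope
slope_flat = crp_slope(bin_centers, toy_flat)              # approximately zero
```

Fitting the same slope to the standard and conditional curves gives one simple number for comparing them, though a single slope ignores any nonlinearity in the curve.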

API Details

Notebook parameters

  • data_path — path to an HDF5 file containing a RecallDataset.
  • features_path — path to a numpy file containing semantic embeddings (one row per item in the pool).
  • figure_dir — directory for saving figures.
  • figure_str — base filename for the saved figure. Leave empty to display without saving.
  • ylim — y-axis limits as a tuple, or None for automatic scaling.
  • trial_query — a Python expression evaluated against the dataset to select trials.
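
A trial query such as "data['listtype'] == -1" reads as a boolean expression over the dataset's per-trial fields. The sketch below shows one way such a string could be turned into a trial mask. It assumes the dataset behaves like a dict of arrays with shape (n_trials, 1); `toy_trial_mask` is a hypothetical stand-in and may differ from how jaxcmr's generate_trial_mask works internally.

```python
import numpy as np

def toy_trial_mask(data, trial_query):
    """Hypothetical stand-in: evaluate a query string with `data` in scope."""
    result = eval(trial_query, {"__builtins__": {}}, {"data": data, "np": np})
    # Flatten an (n_trials, 1) comparison result to (n_trials,)
    return np.asarray(result, dtype=bool).reshape(-1)

# Toy dataset with four trials
toy_data = {"listtype": np.array([[-1], [0], [-1], [2]])}
toy_mask = toy_trial_mask(toy_data, "data['listtype'] == -1")
```

Because the query is evaluated as Python, it should only come from a trusted source such as the notebook's own parameter cell.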

The _should_tabulate mask is constructed in the data-loading cell as data["recalls"] > 0, which marks every non-zero entry of recalls for tabulation. Modify this expression to change which transitions are excluded.
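
As an illustration of modifying the mask, the sketch below builds an alternative that also excludes the first two output positions. It uses a toy recalls array and assumes that recalls has shape (n_trials, n_output_positions) with 0 as padding, and that each mask entry governs the transition into that recall. Both are assumptions of the example, and the variant itself is hypothetical.

```python
import numpy as np

# Toy recalls: two trials, five output positions, 0 = padding (assumed)
toy_recalls = np.array([
    [3, 1, 4, 0, 0],
    [2, 5, 1, 3, 0],
])
toy_data = {"recalls": toy_recalls}

# Default from the notebook: tabulate every non-zero recall entry
default_mask = toy_data["recalls"] > 0

# Hypothetical variant: additionally skip output positions 0 and 1
output_position = np.arange(toy_recalls.shape[1])
late_mask = (toy_data["recalls"] > 0) & (output_position >= 2)

toy_data["_should_tabulate"] = late_mask
```

The mask keeps the same shape as recalls, so any boolean expression that broadcasts to that shape can be substituted.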