The RecallDataset structure

jaxcmr expects data in a specific format defined by the RecallDataset TypedDict.

RecallDataset Structure

Code
# Render the RecallDataset TypedDict's docstring using nbdev's show_doc helper.
from nbdev.showdoc import show_doc
from jaxcmr.typing import RecallDataset

show_doc(RecallDataset)

RecallDataset

A typed dictionary representing a dataset for free or serial recall experiments. Each key maps to a 2D integer array of shape (n_trials, ?). Rows correspond to trials; columns vary by field. Meaningful values start from 1; zeros indicate unused or padding entries.

Required fields:

  • subject: Subject IDs (one per trial).
  • listLength: The length of the list presented in each trial.
  • pres_itemids: Cross-list item IDs presented in each trial (points to a global word pool).
  • pres_itemnos: Within-list item numbers (1-based indices; 0 indicates padding).
  • rec_itemids: Cross-list item IDs corresponding to items recalled.
  • recalls: Within-list item numbers for recalled items (1-based indices; 0 indicates padding).

Additional fields may be included as needed; declare them with NotRequired[...].

Indexing Convention

All item indices are 1-indexed:

  • Position 1 = first item in the list

  • Position 0 = padding (unused entry)

Loading Data

Use load_data from jaxcmr.helpers which handles HDF5 structure automatically:

Code
from jaxcmr.helpers import find_project_root, load_dataimport os# Load a datasetproject_root = find_project_root()data_path = os.path.join(project_root, "data/HealeyKahana2014.h5")dataset = load_data(data_path)# Inspect the structureprint("Available fields:")for key, arr in dataset.items():    print(f"  {key}: shape={arr.shape}, dtype={arr.dtype}")
Code
import numpy as np

# Examine a single trial
trial_idx = 0
print(f"\nTrial {trial_idx}:")
# Column 0 holds the value; subject and listLength rows are width-1 (one value per trial).
print(f"  Subject: {dataset['subject'][trial_idx, 0]}")
print(f"  List length: {dataset['listLength'][trial_idx, 0]}")
# Full rows include trailing zeros, which are padding (see indexing convention above).
print(f"  Presented items: {dataset['pres_itemnos'][trial_idx]}")
print(f"  Recalls: {dataset['recalls'][trial_idx]}")

# Count non-padding recalls
recalls = dataset['recalls'][trial_idx]
# Values > 0 are real recalls; zeros pad the row out to its fixed width.
n_recalled = np.sum(recalls > 0)
print(f"  Items recalled: {n_recalled}")

Trial 0:
  Subject: 1
  List length: 20
  Presented items: [ 1  2  3  4  5  6  7  8  9 10 11 12 13 14 15 16 17 18 19 20  0  0  0  0
  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0]
  Recalls: [20 19 13 18  1  9  2 17 16  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0
  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0]
  Items recalled: 9

Validating Data

Before analysis, verify your data meets the expected format:

Code
def validate_dataset(dataset):
    """Check that a dataset has the required fields and valid values.

    Parameters:
        dataset: RecallDataset-style mapping of field names to 2D integer
            arrays of shape (n_trials, ?).

    Returns:
        True if validation passes.

    Raises:
        ValueError: If a required field is missing or a field's trial count
            disagrees with the 'subject' field.
    """
    required = ['subject', 'listLength', 'pres_itemnos', 'recalls']

    # Check required fields exist
    missing = [f for f in required if f not in dataset]
    if missing:
        raise ValueError(f"Missing required fields: {missing}")

    n_trials = dataset['subject'].shape[0]

    # Check shapes are consistent (all fields must have one row per trial)
    for field in required:
        if dataset[field].shape[0] != n_trials:
            raise ValueError(f"Field '{field}' has inconsistent trial count")

    # Check indexing convention (1-indexed, 0 for padding).
    # reshape(-1) rather than squeeze(): squeeze() on a (1, 1) array yields a
    # 0-d scalar that cannot be indexed, breaking single-trial datasets.
    list_lengths = np.asarray(dataset['listLength']).reshape(-1)
    for i in range(min(10, n_trials)):  # Spot check first 10 trials
        ll = list_lengths[i]
        recalls = dataset['recalls'][i]
        # Valid recalls are within-list positions in [1, ll]; zeros are padding.
        valid_recalls = recalls[(recalls > 0) & (recalls <= ll)]
        if len(valid_recalls) != np.sum(recalls > 0):
            print(f"Warning: Trial {i} has recall values outside [1, {ll}]")

    print(f"Dataset validated: {n_trials} trials")
    return True

validate_dataset(dataset)
Dataset validated: 3600 trials
True

Converting from CSV

If your data is in CSV format, convert it to the expected structure:

Code
import h5py
import pandas as pd

def csv_to_hdf5(csv_path, output_path, list_length, max_recalls=None):
    """
    Convert a CSV with columns [subject, trial, position, recalled] to HDF5.

    Expected CSV format:
        subject,trial,position,recalled
        1,1,1,1
        1,1,2,0
        ...
    Where 'recalled' indicates recall order (0 = not recalled, 1+ = order).

    Parameters:
        csv_path: Path to the input CSV file.
        output_path: Path for the output HDF5 file.
        list_length: Number of items presented on every trial.
        max_recalls: Width of the recalls array; defaults to list_length.

    Raises:
        ValueError: If any trial has more recalls than max_recalls allows.
    """
    df = pd.read_csv(csv_path)

    # Group by subject and trial; each group holds one trial's rows.
    grouped = df.groupby(['subject', 'trial'])
    n_trials = len(grouped)
    # Use an explicit None check (not `or`) so max_recalls=0 is honored.
    if max_recalls is None:
        max_recalls = list_length

    # Initialize arrays; zeros double as the padding value per the convention.
    subjects = np.zeros((n_trials, 1), dtype=np.int32)
    list_lengths = np.full((n_trials, 1), list_length, dtype=np.int32)
    pres_itemnos = np.zeros((n_trials, list_length), dtype=np.int32)
    recalls = np.zeros((n_trials, max_recalls), dtype=np.int32)

    for i, ((subj, trial), group) in enumerate(grouped):
        subjects[i, 0] = subj
        # Items are presented in serial order 1..list_length on every trial.
        pres_itemnos[i] = np.arange(1, list_length + 1)

        # Extract recall sequence, sorted into recall order.
        recalled = group[group['recalled'] > 0].sort_values('recalled')
        recall_positions = recalled['position'].values
        if len(recall_positions) > max_recalls:
            # Fail with a clear message instead of NumPy's opaque broadcast error.
            raise ValueError(
                f"Trial ({subj}, {trial}) has {len(recall_positions)} recalls, "
                f"which exceeds max_recalls={max_recalls}"
            )
        recalls[i, :len(recall_positions)] = recall_positions

    # Save to HDF5
    with h5py.File(output_path, 'w') as f:
        f.create_dataset('subject', data=subjects)
        f.create_dataset('listLength', data=list_lengths)
        f.create_dataset('pres_itemnos', data=pres_itemnos)
        f.create_dataset('pres_itemids', data=pres_itemnos)  # Same if no word pool
        f.create_dataset('recalls', data=recalls)

    print(f"Saved {n_trials} trials to {output_path}")

# Example usage (uncomment to run):
# csv_to_hdf5('my_experiment.csv', 'my_experiment.h5', list_length=16)

Available Datasets

jaxcmr includes several classic free recall datasets in the data/ directory:

| Dataset | Description |
|---------|-------------|
| HealeyKahana2014.h5 | Large-scale free recall study |
| HowardKahana2005.h5 | Free recall with varying list lengths |
| LohnasKahana2014.h5 | Free recall with repetitions |