jaxcmr expects data in a specific format defined by the RecallDataset TypedDict.
RecallDataset Structure
Code
from nbdev.showdoc import show_doc
from jaxcmr.typing import RecallDataset

# Display the documentation for the RecallDataset TypedDict inline in the notebook.
show_doc(RecallDataset)
RecallDataset
A typed dictionary representing a dataset for free or serial recall experiments. Each key maps to a 2D integer array of shape (n_trials, ?): rows correspond to trials, and the number of columns varies by field. Values are 1-indexed; zeros mark unused or padding entries.
Required fields:
subject: Subject IDs (one per trial).
listLength: The length of the list presented in each trial.
pres_itemids: Cross-list item IDs presented in each trial (points to a global word pool).
Before analysis, verify your data meets the expected format:
Code
import numpy as np


def validate_dataset(dataset, max_trials_checked=10):
    """Check that a dataset has the required fields and valid values.

    The function performs three checks:

    - ``subject``, ``listLength``, ``pres_itemnos`` and ``recalls`` are present.
    - All four fields have the same number of rows as ``subject``.
    - Spot-check: the leading trials are scanned for positive recall values
      greater than that trial's list length.

    Out-of-range recalls only print a warning; they do not fail validation.

    Args:
        dataset: Mapping from field name to an array with one row per trial.
        max_trials_checked: Number of leading trials to spot-check for
            out-of-range recall values. Defaults to 10 (the original
            behaviour). Pass ``None`` to check every trial.

    Returns:
        True when the field and row-count checks pass.

    Raises:
        ValueError: If a required field is missing, or a field's row count
            differs from that of ``subject``.
    """
    required = ["subject", "listLength", "pres_itemnos", "recalls"]

    # Check required fields exist
    missing = [field for field in required if field not in dataset]
    if missing:
        raise ValueError(f"Missing required fields: {missing}")

    n_trials = dataset["subject"].shape[0]

    # Check shapes are consistent
    for field in required:
        if dataset[field].shape[0] != n_trials:
            raise ValueError(f"Field '{field}' has inconsistent trial count")

    # Check indexing convention (1-indexed, 0 for padding).
    # reshape(-1) rather than squeeze(): squeeze() turns a single-trial (1, 1)
    # listLength array into a 0-d array, which cannot be indexed by trial.
    list_lengths = np.asarray(dataset["listLength"]).reshape(-1)
    if max_trials_checked is None:
        n_checked = n_trials
    else:
        n_checked = min(max_trials_checked, n_trials)

    for i in range(n_checked):
        ll = list_lengths[i]
        recalls = np.asarray(dataset["recalls"][i])
        positive = recalls > 0
        # Every positive entry must be <= the list length. Zeros (padding)
        # and non-positive codes are deliberately not checked here.
        if np.sum(positive & (recalls <= ll)) != np.sum(positive):
            print(f"Warning: Trial {i} has recall values outside [1, {ll}]")

    print(f"Dataset validated: {n_trials} trials")
    return True


validate_dataset(dataset)
Dataset validated: 3600 trials
True
Converting from CSV
If your data is in CSV format, convert it to the expected structure:
Code
import h5py
import numpy as np
import pandas as pd


def csv_to_hdf5(csv_path, output_path, list_length, max_recalls=None):
    """
    Convert a CSV with columns [subject, trial, position, recalled] to HDF5.

    Expected CSV format (one row per presented item per trial):

        subject,trial,position,recalled
        1,1,1,1
        1,1,2,0
        ...

    ``position`` is the item's 1-indexed serial position. ``recalled`` gives
    recall order: 0 = not recalled, 1+ = the output position at which the
    item was recalled.

    Args:
        csv_path: Path to the input CSV file.
        output_path: Path of the HDF5 file to write (opened in ``'w'`` mode,
            so an existing file is overwritten).
        list_length: Number of items presented in every trial.
        max_recalls: Number of columns in the ``recalls`` array. When not
            given (or falsy), ``list_length`` is used.

    Raises:
        ValueError: If a trial has more recalled items than ``max_recalls``.

    Output:
        The file receives ``subject``, ``listLength``, ``pres_itemnos``,
        ``pres_itemids`` and ``recalls`` datasets. Each has one row per
        (subject, trial) group, and unused recall slots are left as 0.
    """
    df = pd.read_csv(csv_path)

    # Group by subject and trial: each group becomes one output row.
    grouped = df.groupby(["subject", "trial"])
    n_trials = len(grouped)
    max_recalls = max_recalls or list_length

    # Initialize arrays (0 = padding, per the dataset convention).
    subjects = np.zeros((n_trials, 1), dtype=np.int32)
    list_lengths = np.full((n_trials, 1), list_length, dtype=np.int32)
    # Without a word pool, every trial presents positions 1..list_length.
    pres_itemnos = np.tile(
        np.arange(1, list_length + 1, dtype=np.int32), (n_trials, 1)
    )
    recalls = np.zeros((n_trials, max_recalls), dtype=np.int32)

    for i, ((subj, trial), group) in enumerate(grouped):
        subjects[i, 0] = subj

        # Extract recall sequence: recalled items sorted by output position.
        recalled = group[group["recalled"] > 0].sort_values("recalled")
        recall_positions = recalled["position"].to_numpy()
        if len(recall_positions) > max_recalls:
            raise ValueError(
                f"Subject {subj}, trial {trial} has {len(recall_positions)} "
                f"recalls, but max_recalls is {max_recalls}"
            )
        recalls[i, : len(recall_positions)] = recall_positions

    # Save to HDF5
    with h5py.File(output_path, "w") as f:
        f.create_dataset("subject", data=subjects)
        f.create_dataset("listLength", data=list_lengths)
        f.create_dataset("pres_itemnos", data=pres_itemnos)
        f.create_dataset("pres_itemids", data=pres_itemnos)  # Same if no word pool
        f.create_dataset("recalls", data=recalls)

    print(f"Saved {n_trials} trials to {output_path}")


# Example usage (uncomment to run):
# csv_to_hdf5('my_experiment.csv', 'my_experiment.h5', list_length=16)
Available Datasets

jaxcmr includes several classic free recall datasets in the `data/` directory:

| Dataset | Description |
|---------|-------------|
| HealeyKahana2014.h5 | Large-scale free recall study |
| HowardKahana2005.h5 | Free recall with varying list lengths |
| LohnasKahana2014.h5 | Free recall with repetitions |