Computing basic recall metrics

Before computing detailed analyses, you often need basic metrics about recall performance.

Basic Metrics

Recall Count

Number of recalls per trial, counting the positive entries in the `recalls` matrix (zeros are treated as padding):

import jax.numpy as jnp

def recall_count(recalls):
    """Count positive (valid) recall entries per trial; 0 marks padding."""
    return jnp.sum(recalls > 0, axis=1)

counts = recall_count(dataset["recalls"])
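To see what `recall_count` returns, here is a toy example. It assumes the layout the function implies: each row of `recalls` is one trial, recalled items are stored as positive (1-indexed) serial positions, and unused slots are padded with 0. The array is made-up illustration data, not taken from any real dataset.

```python
import jax.numpy as jnp

def recall_count(recalls):
    """Count positive (valid) recall entries per trial; 0 marks padding."""
    return jnp.sum(recalls > 0, axis=1)

# Hypothetical toy data: 3 trials, up to 4 recall slots, 0 = padding
toy_recalls = jnp.array([
    [3, 1, 2, 0],  # three items recalled
    [4, 0, 0, 0],  # one item recalled
    [0, 0, 0, 0],  # nothing recalled
])

print(recall_count(toy_recalls))  # [3 1 0]
```

Because the count is a plain reduction over a fixed-shape array, the function can be wrapped in `jax.jit` without changes.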

Recall Probability

Proportion of list items recalled in each trial (the recall count divided by the list length):

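import jax.numpy as jnp

def recall_probability(recalls, list_length):
    """Compute recall probability (proportion of list items recalled) per trial."""
    counts = jnp.sum(recalls > 0, axis=1)
    # Flatten an (n_trials, 1) list-length column (or a scalar) so it broadcasts against counts
    return counts / jnp.asarray(list_length).reshape(-1)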
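`recall_probability` counts every positive entry, so if a trial's recall sequence can contain repeated recalls of the same item, the result overstates the proportion of items recalled and can even exceed 1. A minimal sketch of a variant that counts each studied item at most once, assuming the same 0-padded, 1-indexed serial-position coding; the function name `unique_recall_probability` and its `max_list_length` parameter are hypothetical, introduced here for illustration.

```python
import jax.numpy as jnp

def unique_recall_probability(recalls, list_length, max_list_length):
    """Proportion of studied items recalled at least once per trial.

    Hypothetical variant: repeated recalls of the same serial position
    are counted once; zeros and negative codes never match a position.
    """
    positions = jnp.arange(1, max_list_length + 1)  # serial positions 1..L
    # hits[t, p] is True if position p + 1 appears anywhere in trial t's recalls
    hits = (recalls[:, :, None] == positions[None, None, :]).any(axis=1)
    counts = jnp.sum(hits, axis=1)
    return counts / jnp.asarray(list_length).reshape(-1)

# Hypothetical toy data: trial 0 recalls item 2 twice
toy_recalls = jnp.array([
    [2, 1, 2, 0],
    [4, 3, 1, 0],
])
toy_list_length = jnp.array([[4], [4]])  # one list length per trial

print(unique_recall_probability(toy_recalls, toy_list_length, 4))
```

For trial 0, counting every positive entry would give 3/4 = 0.75, while the unique count gives 2/4 = 0.5. Under `jax.jit`, `max_list_length` has to be marked static (for example with `static_argnums`) because it determines an array shape.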

Filtering Trials

You may want to filter trials by subject, session, or other criteria:

# Keep only trials from subject 1 (squeeze drops the singleton axis so the mask is 1-D)
mask = dataset["subject"].squeeze() == 1
subject_recalls = dataset["recalls"][mask]
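Boolean-mask indexing like this works in eager JAX but not inside `jax.jit`, because the number of selected trials (and so the output shape) depends on the data. When several fields need the same filter, a small helper keeps them aligned. A minimal sketch, assuming the dataset is a dict of arrays whose first axis indexes trials and that a `session` field exists alongside `subject`; the helper `filter_trials`, the `session` field, and the toy values are hypothetical.

```python
import jax.numpy as jnp

def filter_trials(dataset, mask):
    """Apply one boolean trial mask to every array in the dataset dict.

    Hypothetical helper: assumes every field has trials on axis 0.
    """
    return {key: value[mask] for key, value in dataset.items()}

# Hypothetical toy dataset with 4 trials; per-trial fields stored as columns
toy_dataset = {
    "subject": jnp.array([[1], [1], [2], [2]]),
    "session": jnp.array([[1], [2], [1], [2]]),
    "recalls": jnp.array([[1, 2, 0], [3, 0, 0], [2, 1, 3], [0, 0, 0]]),
}

# Combine criteria with elementwise &; parenthesize each comparison
# because & binds more tightly than ==
mask = (toy_dataset["subject"].squeeze() == 1) & (toy_dataset["session"].squeeze() == 2)
subset = filter_trials(toy_dataset, mask)

print(subset["recalls"])  # [[3 0 0]]
```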

Next Steps

With data loaded and validated, you can compute:
- Serial Position Curve: recall by presentation position
- Conditional Response Probability: temporal contiguity