Computing basic recall metrics
Before running detailed analyses, you often need a few basic measures of recall performance.
Number of correct recalls per trial:
```python
import jax.numpy as jnp

def recall_count(recalls):
    """Count the valid (positive) recall entries in each trial."""
    # One row per trial; entries <= 0 (e.g. padding zeros) are not counted.
    return jnp.sum(recalls > 0, axis=1)

counts = recall_count(dataset["recalls"])  # shape: (n_trials,)
```
Proportion of items recalled:
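The sketch below makes three assumptions. The recall matrix has the same layout as above: positive study positions, with zeros as padding. A hypothetical `list_length` argument gives the number of items studied per trial. The helper `recall_proportion` is an illustrative name, not part of any library.

```python
import jax.numpy as jnp

def recall_proportion(recalls, list_length):
    """Fraction of studied items recalled on each trial (hypothetical helper).

    Assumes `recalls` has one row per trial, with positive entries for
    recalled study positions and 0 as padding. `list_length` may be a
    scalar or an array of shape (n_trials,).
    """
    # Positive entries per trial, as in recall_count.
    counts = jnp.sum(recalls > 0, axis=1)
    # Divide by the number of studied items (broadcasts per trial).
    return counts / list_length

# Toy data: two trials of a 4-item list.
toy_recalls = jnp.array([
    [3, 1, 2, 0],
    [4, 0, 0, 0],
])
props = recall_proportion(toy_recalls, 4)
mean_prop = props.mean()  # average proportion recalled across trials
```

This counts every positive entry. If your data can contain repeated recalls of the same item, you would need to count unique study positions instead, so that repeats do not inflate the proportion.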
You may want to filter trials by subject, session, or other criteria:
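One way to filter is to build a boolean mask over trials and index every field with it. This minimal sketch makes several assumptions:

* `dataset` is a dict of arrays whose first axis indexes trials.
* The `subject` and `session` fields are stored as `(n_trials, 1)` columns.
* The field values and the `filter_trials` helper are hypothetical.

```python
import jax.numpy as jnp

def filter_trials(dataset, mask):
    """Keep only the trials (rows) where `mask` is True (hypothetical helper)."""
    # Boolean indexing selects rows along the first (trial) axis of each field.
    return {key: value[mask] for key, value in dataset.items()}

# Hypothetical toy dataset: every field has one row per trial.
dataset = {
    "subject": jnp.array([[1], [1], [2]]),
    "session": jnp.array([[1], [2], [1]]),
    "recalls": jnp.array([[3, 1, 0], [2, 0, 0], [1, 2, 3]]),
}

# Combine criteria with &: subject 1, session 2.
mask = (dataset["subject"][:, 0] == 1) & (dataset["session"][:, 0] == 2)
subset = filter_trials(dataset, mask)
```

Boolean-mask indexing produces arrays whose shape depends on the data. It therefore works on concrete arrays outside `jax.jit` but not inside a jitted function. Filter first, then pass the subset to jitted analyses.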