jaxcmr analyses follow a consistent pattern, so you can write custom analyses that integrate with the rest of the library.
Analysis Pattern
Most analyses follow this structure:
```python
import jax.numpy as jnp
from jaxtyping import Array, Integer, Float


def my_analysis(
    recalls: Integer[Array, "trials max_recalls"],
    list_length: int,
    **kwargs,
) -> Float[Array, "..."]:
    """
    Compute my custom analysis.

    Args:
        recalls: Recall sequences (1-indexed, 0 = padding)
        list_length: Number of items in each list
        **kwargs: Additional parameters

    Returns:
        Analysis results
    """
    # Your implementation here
    ...
```
Example: First Recall Probability
Probability of recalling each serial position first:
```python
def first_recall_probability(
    recalls: Integer[Array, "trials max_recalls"],
    list_length: int,
) -> Float[Array, "list_length"]:
    """Probability each position is recalled first."""
    # Get first recall for each trial
    first_recalls = recalls[:, 0]

    # Count how often each position is first
    counts = jnp.zeros(list_length)
    for pos in range(1, list_length + 1):
        counts = counts.at[pos - 1].set(jnp.sum(first_recalls == pos))

    # Normalize
    return counts / jnp.sum(counts)
```
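To make the behavior concrete, the sketch below repeats the same logic and applies it to a small made-up dataset. The recall matrix is invented for illustration, and the type hints use plain `jax.Array` instead of `jaxtyping` annotations so the snippet runs with JAX alone.

```python
import jax
import jax.numpy as jnp


def first_recall_probability(recalls: jax.Array, list_length: int) -> jax.Array:
    """Probability each position is recalled first (same logic as above)."""
    first_recalls = recalls[:, 0]
    counts = jnp.zeros(list_length)
    for pos in range(1, list_length + 1):
        counts = counts.at[pos - 1].set(jnp.sum(first_recalls == pos))
    return counts / jnp.sum(counts)


# Hypothetical data: 4 trials, list length 3, 1-indexed positions, 0 = padding.
recalls = jnp.array(
    [
        [3, 1, 2, 0],  # position 3 recalled first
        [1, 2, 0, 0],  # position 1 recalled first
        [3, 2, 1, 0],  # position 3 recalled first
        [0, 0, 0, 0],  # nothing recalled on this trial
    ]
)

probs = first_recall_probability(recalls, list_length=3)
print(probs)  # approximately [0.333, 0.0, 0.667]
```

Because the counts are normalized by their own sum, the trial with no recalls drops out of the denominator: the result is the probability of each position being first among trials that contain at least one recall. If no trial contains a recall, the division yields NaN.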
JIT Compatibility
For performance, ensure your analysis is JIT-compatible:
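A minimal sketch of one way to do this, not necessarily how jaxcmr itself does it: mark `list_length` as a static argument so it can determine array shapes, and replace the Python loop with a single broadcast comparison so the loop is not unrolled at trace time. The function name `first_recall_probability_jit` and the guard against zero total recalls are illustrative additions.

```python
from functools import partial

import jax
import jax.numpy as jnp


@partial(jax.jit, static_argnames="list_length")
def first_recall_probability_jit(recalls: jax.Array, list_length: int) -> jax.Array:
    """JIT-compiled first recall probability (illustrative variant)."""
    first_recalls = recalls[:, 0]

    # list_length is static, so this array has a known shape at trace time.
    positions = jnp.arange(1, list_length + 1)

    # Compare every first recall against every position: (trials, list_length).
    matches = first_recalls[:, None] == positions[None, :]
    counts = jnp.sum(matches, axis=0)

    # Assumption: return zeros rather than NaN when no trial has a recall.
    total = jnp.sum(counts)
    return jnp.where(total > 0, counts / jnp.maximum(total, 1), 0.0)


# Hypothetical data: 1-indexed positions, 0 = padding.
recalls = jnp.array([[3, 1, 2, 0], [1, 2, 0, 0], [3, 2, 1, 0], [0, 0, 0, 0]])
jit_probs = first_recall_probability_jit(recalls, list_length=3)
empty_probs = first_recall_probability_jit(jnp.zeros((2, 4), dtype=int), list_length=3)
```

Passing `list_length` by keyword or by position both work. Each distinct value of a static argument triggers a fresh compilation, so this fits best when the list length is fixed across calls.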