```python
from nbdev.showdoc import show_doc
from jaxcmr.typing import MemorySearch

show_doc(MemorySearch)
```

How models are composed from protocols
jaxcmr models are composed from three core components, each defined by a protocol. Because model code depends only on those protocols, one implementation of a component can be swapped for another without changing the model code.
All memory search models implement the MemorySearch protocol.
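A minimal sketch of structural typing with Python's typing.Protocol, to show what "implementing a protocol" means in practice. The method names experience, start_retrieving, and retrieve come from the usage examples on this page. Their signatures and return types, any further members of the real MemorySearch protocol, and the MemorySearchLike and ToyModel classes are assumptions for illustration, not jaxcmr code.

```python
from dataclasses import dataclass, replace
from typing import Protocol, runtime_checkable


@runtime_checkable
class MemorySearchLike(Protocol):
    """Hypothetical subset of a memory search interface (not jaxcmr's definition)."""

    def experience(self, item: int) -> "MemorySearchLike": ...

    def start_retrieving(self) -> "MemorySearchLike": ...

    def retrieve(self, item: int) -> "MemorySearchLike": ...


@dataclass(frozen=True)
class ToyModel:
    """Immutable toy model: each method returns an updated copy,
    mirroring the `model = model.experience(item)` style used on this page."""

    studied: tuple = ()
    recalled: tuple = ()
    retrieving: bool = False

    def experience(self, item: int) -> "ToyModel":
        return replace(self, studied=self.studied + (item,))

    def start_retrieving(self) -> "ToyModel":
        return replace(self, retrieving=True)

    def retrieve(self, item: int) -> "ToyModel":
        return replace(self, recalled=self.recalled + (item,))


model = ToyModel()
for item in range(1, 4):
    model = model.experience(item)
model = model.start_retrieving().retrieve(3)

print(isinstance(model, MemorySearchLike))  # True: structural match, no inheritance
print(model.studied, model.recalled)        # (1, 2, 3) (3,)
```

ToyModel never inherits from MemorySearchLike; having matching methods is enough. Note that isinstance on a runtime_checkable protocol only checks that the methods exist, not that their signatures match.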
Models are built from three component types:
Context: maintains a temporal context state that drifts during study and retrieval.
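For intuition, context drift is often written in the CMR literature as c_i = ρ c_{i-1} + β c_in, where β is a drift rate and ρ is chosen to keep the context vector at unit length. The sketch below implements that rule on plain Python lists. It is an assumption that TemporalContext uses exactly this form, and how it maps encoding_drift_rate, start_drift_rate, and recall_drift_rate onto β is not shown here. The one-hot item layout and the start-of-list unit are illustrative choices.

```python
import math

# Toy sketch of CMR-style context drift (not jaxcmr's TemporalContext).
# Update rule: c_new = rho * c + beta * c_in, with rho chosen so that
# c_new stays unit length when c and c_in are unit vectors.


def dot(a, b):
    return sum(x * y for x, y in zip(a, b))


def normalize(v):
    norm = math.sqrt(dot(v, v))
    return [x / norm for x in v]


def drift(context, context_input, beta):
    c_in = normalize(context_input)
    cdot = dot(context, c_in)
    rho = math.sqrt(1 + beta**2 * (cdot**2 - 1)) - beta * cdot
    return [rho * c + beta * x for c, x in zip(context, c_in)]


# Hypothetical layout: unit 0 is a start-of-list unit, units 1..n are item units.
n_items = 4
context = [1.0] + [0.0] * n_items
for item in range(1, n_items + 1):
    item_input = [0.0] * (n_items + 1)
    item_input[item] = 1.0  # one-hot (orthonormal) item features
    context = drift(context, item_input, beta=0.5)  # beta plays the role of a drift rate

print([round(x, 3) for x in context])  # later items keep more weight (recency)
print(round(math.sqrt(dot(context, context)), 6))  # 1.0
```

Because each new input is blended in while older content is scaled by ρ, recently studied items dominate the current context, which is the recency gradient CMR relies on.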
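Memory: stores and retrieves associations. A CMR model uses two memories, one for item-to-context associations (MFC) and one for context-to-item associations (MCF).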
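A common way to implement a Memory component is a linear associative matrix: learning adds a scaled outer product of an input pattern and an output pattern, and retrieval multiplies a cue by the matrix. The class below is a hypothetical, plain-Python sketch of that idea with illustrative method names (associate, probe); it is not jaxcmr's LinearMemory API, and the context vectors are made up. In the MCF direction the cue is a context vector and the output is support for each item; MFC runs the other way.

```python
class ToyLinearMemory:
    """Hypothetical linear associative memory (not jaxcmr's LinearMemory).

    Associations live in a matrix `state` with shape (n_in, n_out).
    """

    def __init__(self, n_in, n_out, initial=0.0):
        self.state = [[initial] * n_out for _ in range(n_in)]

    def associate(self, in_pattern, out_pattern, learning_rate):
        """Hebbian update: state[i][j] += learning_rate * in_pattern[i] * out_pattern[j]."""
        for i, x in enumerate(in_pattern):
            for j, y in enumerate(out_pattern):
                self.state[i][j] += learning_rate * x * y

    def probe(self, in_pattern):
        """Retrieval as a vector-matrix product: out[j] = sum_i in_pattern[i] * state[i][j]."""
        n_out = len(self.state[0])
        return [sum(x * row[j] for x, row in zip(in_pattern, self.state)) for j in range(n_out)]


# MCF direction: context cues -> item support. Context vectors are illustrative.
items = [[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0]]
contexts = [[0.9, 0.1, 0.0], [0.1, 0.9, 0.1], [0.0, 0.1, 0.9]]

mcf = ToyLinearMemory(n_in=3, n_out=3)
for context, item in zip(contexts, items):
    mcf.associate(context, item, learning_rate=0.5)

support = mcf.probe(contexts[2])
print([round(s, 3) for s in support])  # [0.005, 0.09, 0.41]: item 3 gets the most support
```

This toy mutates its matrix in place for brevity, whereas the models on this page return updated copies after each operation.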
TerminationPolicy: determines when to stop recalling.
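In the original CMR formulation, the probability of stopping at output position j is θ_s · exp(j · θ_r), so stopping becomes more likely as recall proceeds. The stop_probability_scale and stop_probability_growth entries in the parameter dictionary further down are presumably θ_s and θ_r, but that mapping, and the exact behavior of PositionalTermination and NoStopTermination, are assumptions here. The sketch contrasts a positional rule with a never-stop rule behind one hypothetical interface (StopRule, ToyPositionalStop, ToyNoStop are illustrative names).

```python
import math
from typing import List, Protocol


class StopRule(Protocol):
    """Hypothetical termination interface (not jaxcmr's TerminationPolicy)."""

    def stop_probability(self, output_position: int, items_remaining: int) -> float: ...


class ToyPositionalStop:
    """Stop probability grows exponentially with output position: scale * exp(growth * j)."""

    def __init__(self, scale: float, growth: float):
        self.scale = scale
        self.growth = growth

    def stop_probability(self, output_position: int, items_remaining: int) -> float:
        if items_remaining == 0:
            return 1.0  # nothing left to recall
        return min(1.0, self.scale * math.exp(self.growth * output_position))


class ToyNoStop:
    """Never stops voluntarily; recall ends only once every item has been recalled."""

    def stop_probability(self, output_position: int, items_remaining: int) -> float:
        return 1.0 if items_remaining == 0 else 0.0


def stop_curve(rule: StopRule, list_length: int, n_outputs: int) -> List[float]:
    """Stop probability before each of the first n_outputs recalls (one new item per recall)."""
    return [rule.stop_probability(j, list_length - j) for j in range(n_outputs)]


positional = ToyPositionalStop(scale=0.05, growth=0.2)
print([round(p, 3) for p in stop_curve(positional, list_length=16, n_outputs=4)])
print(stop_curve(ToyNoStop(), list_length=16, n_outputs=4))  # [0.0, 0.0, 0.0, 0.0]
```

The min(1.0, ...) clamp keeps the positional rule a valid probability once the exponential term exceeds 1.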
A CMR model combines these components:
```text
MemorySearch
├── Context (temporal context state)
├── Memory (MFC: item → context)
├── Memory (MCF: context → item)
└── TerminationPolicy (when to stop)
```
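The tree can be read as a record holding one instance of each component, each produced by a factory. A minimal sketch of that composition pattern is below. Only termination_policy_create_fn appears in the example on this page; the other *_create_fn names, the (list_length, parameters) factory signature, and the ComposedModel, Component, and build_model names are hypothetical, and Component is a placeholder that records a label instead of real state.

```python
from dataclasses import dataclass
from typing import Callable, Mapping


@dataclass(frozen=True)
class Component:
    """Placeholder component that only records which factory built it."""

    label: str
    list_length: int


# A factory takes (list_length, parameters) and returns a component (assumed signature).
Factory = Callable[[int, Mapping[str, float]], Component]


def make_context(list_length, params):
    return Component("context", list_length)


def make_memory(list_length, params):
    return Component("linear memory", list_length)


def make_positional_stop(list_length, params):
    return Component("positional termination", list_length)


def make_no_stop(list_length, params):
    return Component("no-stop termination", list_length)


@dataclass(frozen=True)
class ComposedModel:
    """Mirrors the tree: one context, two memories, one termination policy."""

    context: Component
    mfc: Component  # item -> context
    mcf: Component  # context -> item
    termination_policy: Component


def build_model(
    list_length: int,
    params: Mapping[str, float],
    context_create_fn: Factory = make_context,
    mfc_create_fn: Factory = make_memory,
    mcf_create_fn: Factory = make_memory,
    termination_policy_create_fn: Factory = make_positional_stop,
) -> ComposedModel:
    """Build each component from its factory; swapping one factory swaps one component."""
    return ComposedModel(
        context=context_create_fn(list_length, params),
        mfc=mfc_create_fn(list_length, params),
        mcf=mcf_create_fn(list_length, params),
        termination_policy=termination_policy_create_fn(list_length, params),
    )


default = build_model(16, {})
variant = build_model(16, {}, termination_policy_create_fn=make_no_stop)
print(variant.termination_policy.label)  # no-stop termination
print(default.context == variant.context and default.mcf == variant.mcf)  # True
```

Passing factories rather than ready-made instances lets the model decide when and with which sizes and parameters each component is created, while the caller still chooses which implementation is used.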
CMR accepts factory functions for creating each component. The example below varies only the termination policy:
```python
from jaxcmr.models.cmr import CMR
from jaxcmr.components.termination import PositionalTermination, NoStopTermination

params = {
    "encoding_drift_rate": 0.5,
    "start_drift_rate": 0.5,
    "recall_drift_rate": 0.5,
    "learning_rate": 0.5,
    "primacy_scale": 2.0,
    "primacy_decay": 0.8,
    "shared_support": 0.05,
    "item_support": 0.25,
    "choice_sensitivity": 0.6,
    "stop_probability_scale": 0.05,
    "stop_probability_growth": 0.2,
}

# Standard CMR with positional termination
model_positional = CMR(
    list_length=16,
    parameters=params,
    termination_policy_create_fn=PositionalTermination,
)

# CMR with no-stop termination (recalls until exhausted)
model_no_stop = CMR(
    list_length=16,
    parameters=params,
    termination_policy_create_fn=NoStopTermination,
)

print("Positional termination model created")
print("No-stop termination model created")
```

Different termination policies produce different stopping behavior:
```python
# Study items 1 through 16
for item_id in range(1, 17):
    model_positional = model_positional.experience(item_id)
    model_no_stop = model_no_stop.experience(item_id)

# Start retrieval
model_positional = model_positional.start_retrieving()
model_no_stop = model_no_stop.start_retrieving()

# Compare stop probabilities before the first recall
print("Stop probabilities at output position 0:")
print(f"  Positional: {model_positional.termination_policy.stop_probability(model_positional):.3f}")
print(f"  No-stop:    {model_no_stop.termination_policy.stop_probability(model_no_stop):.3f}")

# Recall the three most recently studied items (16, 15, 14), then check how the stop probability changes
for i in range(3):
    model_positional = model_positional.retrieve(16 - i)
    model_no_stop = model_no_stop.retrieve(16 - i)

print("\nAfter 3 recalls:")
print(f"  Positional: {model_positional.termination_policy.stop_probability(model_positional):.3f}")
print(f"  No-stop:    {model_no_stop.termination_policy.stop_probability(model_no_stop):.3f}")
```

jaxcmr provides these built-in implementations:

| Component | Implementations |
|---|---|
| Context | TemporalContext |
| Memory | LinearMemory (MFC, MCF variants) |
| Termination | PositionalTermination, NoStopTermination, SupportRatioTermination, RetrievalDependentTermination |
Model variants in jaxcmr.models often customize one or more of these components.
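To illustrate what customizing a single component can look like, the sketch below writes a recall loop once against a hypothetical termination interface and runs it with two policies, one of them a made-up StopAfterK variant. None of these names are jaxcmr classes, and drawing items uniformly at random is a deliberate simplification standing in for CMR's context-driven retrieval competition.

```python
import random
from typing import List, Protocol, Sequence


class Termination(Protocol):
    """Hypothetical termination interface used by the recall loop below."""

    def stop_probability(self, n_recalled: int, n_remaining: int) -> float: ...


class NeverStop:
    def stop_probability(self, n_recalled: int, n_remaining: int) -> float:
        return 0.0


class StopAfterK:
    """Hypothetical custom policy: keep recalling until k items are out, then stop."""

    def __init__(self, k: int):
        self.k = k

    def stop_probability(self, n_recalled: int, n_remaining: int) -> float:
        return 1.0 if n_recalled >= self.k else 0.0


def simulate_recall(items: Sequence[int], policy: Termination, rng: random.Random) -> List[int]:
    """Recall loop written once against the protocol; only the policy object changes.

    Items are drawn uniformly at random as a stand-in for context-driven competition.
    """
    remaining = list(items)
    recalled: List[int] = []
    while remaining and rng.random() >= policy.stop_probability(len(recalled), len(remaining)):
        recalled.append(remaining.pop(rng.randrange(len(remaining))))
    return recalled


studied = range(1, 17)
print(len(simulate_recall(studied, NeverStop(), random.Random(0))))    # 16
print(len(simulate_recall(studied, StopAfterK(5), random.Random(0))))  # 5
```

Because rng.random() lies in [0, 1), a stop probability of 0 never ends recall early and a stop probability of 1 always does, so the two runs differ only because of the policy object passed in.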