import jax.numpy as jnp
import matplotlib.pyplot as plt
import numpy as np
from jaxcmr.components.linear_memory import LinearMemory
from jaxcmr.typing import Memory
Associative memory matrices for CMR
CMR uses two linear associative memories: MFC (item → context) and MCF (context → item). This notebook shows how to create, initialize, and update these memories.
Linear associative memories link input and output patterns through Hebbian (outer-product) learning:
\[\Delta M = \gamma \cdot \mathbf{x}_{out} \mathbf{x}_{in}^T\]
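where \(\mathbf{x}_{in}\) and \(\mathbf{x}_{out}\) are the input and output patterns and \(\gamma\) is the learning rate. Retrieval (probing) multiplies the memory matrix by an input cue \(\mathbf{x}\):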
\[\mathbf{y} = M \mathbf{x}\]
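A minimal sketch of these two operations in plain jax.numpy. It is not the LinearMemory implementation; the helpers associate and probe below are hypothetical stand-ins. It stores the matrix with rows indexed by the input pattern, the orientation this notebook's MCF plots appear to use (context units on rows), so the update is \(\gamma\,\mathbf{x}_{in}\mathbf{x}_{out}^T\) and retrieval is \(\mathbf{x}_{in}^T M\), the transposes of the equations above.

```python
import jax.numpy as jnp


def associate(state, in_pattern, out_pattern, learning_rate):
    # Hypothetical helper: Hebbian outer-product update with rows indexing
    # the input pattern (transposed relative to the equation above).
    return state + learning_rate * jnp.outer(in_pattern, out_pattern)


def probe(state, in_pattern):
    # Hypothetical helper: retrieval as the cue-weighted sum of rows.
    return in_pattern @ state


state = jnp.zeros((3, 4))                 # 3 input units, 4 output units
cue = jnp.array([0.0, 1.0, 0.0])          # one-hot cue on input unit 1
target = jnp.array([1.0, 0.0, 0.0, 0.0])  # one-hot target on output unit 0

state = associate(state, cue, target, learning_rate=0.5)
retrieved = probe(state, cue)                       # equals 0.5 * target
unrelated = probe(state, jnp.array([1.0, 0.0, 0.0]))  # cue never trained: zeros
```

Because updates add linearly, associating the same pair twice doubles the stored strength.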
Item-to-context memory (MFC) maps items to their associated context units. Initially, each item is linked to its own unique context unit with strength 1 - learning_rate.
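That initial MFC can be sketched directly with jax.numpy. This is not LinearMemory.init_mfc: the helper name, the layout (items on rows, context units on columns), and the mapping of item i to context unit i + 1 (reserving unit 0 for the start-of-list context, as in the MCF description) are all assumptions.

```python
import jax.numpy as jnp


def init_mfc_sketch(item_count, context_size, learning_rate):
    # Hypothetical helper: item i -> context unit i + 1 with strength
    # 1 - learning_rate; unit 0 (start of list) receives no item input.
    return jnp.eye(item_count, context_size, k=1) * (1.0 - learning_rate)


mfc_state = init_mfc_sketch(item_count=3, context_size=4, learning_rate=0.5)
# Probing with item 2 retrieves only its own context unit (index 3).
context_for_item_2 = jnp.eye(3)[2] @ mfc_state
```

Each row has a single nonzero entry, so distinct items retrieve orthogonal context patterns.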
Context-to-item memory (MCF) maps context to item activations. Initially:
Each in-list context unit is associated with all items via shared_support
Each in-list context unit also has an additional association of strength item_support with one item (context unit i + 1 with item i)
Start-of-list context unit (index 0) has no initial associations
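A jax.numpy sketch of that initial MCF, shaped like the plot below (context units on rows, items on columns). The helper name is hypothetical, and it assumes context_size == item_count + 1 with in-list unit i + 1 carrying the extra item_support for item i; it is not the LinearMemory.init_mcf source.

```python
import jax.numpy as jnp


def init_mcf_sketch(item_count, context_size, item_support, shared_support):
    # Hypothetical helper; assumes context_size == item_count + 1.
    state = jnp.zeros((context_size, item_count))
    # In-list context units (rows 1..item_count) support every item equally.
    state = state.at[1:item_count + 1, :].set(shared_support)
    # Extra item-specific support: context unit i + 1 -> item i.
    idx = jnp.arange(item_count)
    state = state.at[idx + 1, idx].add(item_support)
    # Row 0 (start-of-list unit) stays zero: no initial associations.
    return state


mcf_state = init_mcf_sketch(item_count=3, context_size=4,
                            item_support=0.5, shared_support=0.25)
```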
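# Illustrative parameters (assumed values; the notebook's setup cell is not shown)
item_count = 10                # the evolution plot below needs at least 10 items
context_size = item_count + 1  # assumed: start-of-list unit plus one unit per item
learning_rate = 0.3
item_support = 0.2
shared_support = 0.05
mcf = LinearMemory.init_mcf(item_count, context_size, item_support, shared_support)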
plt.figure(figsize=(8, 6))
plt.imshow(mcf.state, cmap='viridis', aspect='auto')
plt.colorbar(label='Association Strength')
plt.xlabel('Item')
plt.ylabel('Context Unit')
plt.title('MCF: Initial State')
plt.tight_layout()
plt.show()
Associate a context pattern with an item, then probe the memory with the same context pattern.
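The expected readout can be worked out by hand, assuming the initialization described above and assuming probe returns the cue-weighted sum of the rows of mcf.state (consistent with the plot, which puts context units on rows). Writing \(s_{shared}\) and \(s_{item}\) for shared_support and item_support and \(\gamma\) for learning_rate, the one-hot cue for context unit 1 reads out row 1 of the updated matrix:

\[
\text{activation}_j = \begin{cases} s_{shared} + s_{item} + \gamma, & j = 0 \\ s_{shared}, & j \neq 0 \end{cases}
\]

so item 0 should be the most activated item whenever \(s_{item} + \gamma > 0\).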
# Create item and context patterns
items = jnp.eye(item_count)
contexts = jnp.eye(context_size)
# Associate context unit 1 with item 0
mcf_updated = mcf.associate(contexts[1], items[0], learning_rate)
# Probe with context unit 1
activation = mcf_updated.probe(contexts[1])
print("Item activations when probing with context unit 1:")
print(f" {activation}")
print(f" Most activated item: {activation.argmax()}")
Next, watch how MCF evolves as items are studied. In the full model, each studied item is associated with the current context state; for simplicity, this example associates each item only with its own context unit.
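Because both the cue and the target in this simplified loop are one-hot vectors, the accumulated updates have a closed form. Writing \(\mathbf{e}_j\) for the \(j\)-th unit vector (in context space or item space as appropriate) and indexing the matrix as it is plotted (context units on rows, items on columns, the transpose of the \(\Delta M\) convention above), the state after \(k\) study events is

\[
M_{CF}^{(k)} = M_{CF}^{(0)} + \gamma \sum_{i=0}^{k-1} \mathbf{e}_{i+1}\,\mathbf{e}_{i}^T
\]

Each study event therefore raises exactly one entry, at row \(i+1\) and column \(i\), by \(\gamma\); every other entry keeps its initial value.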
# Simulate studying items and forming associations
mcf_states = [mcf.state.copy()]
current_mcf = mcf
for i in range(item_count):
    # Simplified: associate context unit i+1 with item i
    current_mcf = current_mcf.associate(contexts[i + 1], items[i], learning_rate)
    mcf_states.append(current_mcf.state.copy())
# Visualize evolution
fig, axes = plt.subplots(2, 3, figsize=(12, 8))
positions = [0, 2, 4, 6, 8, 10]  # number of items studied; requires item_count >= 10
for ax, pos in zip(axes.flat, positions):
    im = ax.imshow(mcf_states[pos], cmap='viridis', aspect='auto', vmin=0, vmax=0.6)
    ax.set_title(f'After {pos} items')
    ax.set_xlabel('Item')
    ax.set_ylabel('Context Unit')
plt.suptitle('MCF Association Matrix Evolution')
plt.tight_layout()
plt.show()
Track how the association between context unit 1 and item 0 grows across study positions.
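The shape of this curve can be predicted before plotting. Under the simplified one-hot pairing, entry [1, 0] gains \(\gamma\) once, when item 0 is studied, and then stays flat because later study events use other context units. The self-contained sketch below reproduces that trajectory with plain jax.numpy outer products instead of LinearMemory; the parameter values are illustrative assumptions, not the notebook's.

```python
import jax.numpy as jnp

# Illustrative values (assumed, not taken from the notebook).
item_count, learning_rate = 4, 0.5
shared_support, item_support = 0.25, 0.5
context_size = item_count + 1

# Initial MCF in the plotted layout: context units on rows, items on columns.
state = jnp.zeros((context_size, item_count))
state = state.at[1:, :].set(shared_support)       # row 0 (start of list) stays 0
idx = jnp.arange(item_count)
state = state.at[idx + 1, idx].add(item_support)  # context unit i + 1 -> item i

contexts = jnp.eye(context_size)
items = jnp.eye(item_count)
trajectory = [float(state[1, 0])]
for i in range(item_count):
    # Same simplified pairing as the notebook: context unit i + 1 with item i.
    state = state + learning_rate * jnp.outer(contexts[i + 1], items[i])
    trajectory.append(float(state[1, 0]))

print(trajectory)  # [0.75, 1.25, 1.25, 1.25, 1.25]
```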
# Track association strength for item 0 from context unit 1
association_strengths = [state[1, 0] for state in mcf_states]
plt.figure(figsize=(8, 4))
plt.plot(range(len(association_strengths)), association_strengths, 'o-', linewidth=2, markersize=8)
plt.axhline(y=shared_support + item_support, color='r', linestyle='--', label='Initial (shared + item support)')
plt.xlabel('Items Studied')
plt.ylabel('Association Strength')
plt.title('MCF[context_1, item_0] Association Growth')
plt.legend()
plt.grid(True, alpha=0.3)
plt.tight_layout()
plt.show()
LinearMemory implements the Memory protocol: