Linear Memory

Associative memory matrices for CMR

CMR uses two linear associative memories: MFC (item → context) and MCF (context → item). This notebook shows how to initialize, update, and probe these memories.

Learning Rule

Linear associative memories associate input and output patterns via Hebbian learning:

\[\Delta M = \gamma \cdot \mathbf{x}_{out} \mathbf{x}_{in}^T\]

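where \(\gamma\) is the learning rate and \(\mathbf{x}_{out} \mathbf{x}_{in}^T\) is the outer product of the output and input patterns. Retrieval with an input probe \(\mathbf{x}\) is a matrix-vector product:

\[\mathbf{y} = M \mathbf{x}\]

The LinearMemory objects below store \(S = M^T\): rows index input units and columns index output units (see the axis labels of the plots), so a probe computes \(\mathbf{y}^T = \mathbf{x}^T S\). For example, mcf.state[1, 0] is the strength from context unit 1 to item 0.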
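A minimal NumPy sketch of the two equations, independent of jaxcmr. The helper names hebbian_update and retrieve are hypothetical, and M follows the equations' (output, input) orientation rather than the transposed storage used by LinearMemory. Because retrieval is linear, cueing with the studied input returns \(\gamma \lVert \mathbf{x}_{in} \rVert^2 \mathbf{x}_{out}\), and a weaker cue returns a proportionally weaker output.

```python
import numpy as np


def hebbian_update(M, x_in, x_out, gamma):
    """Return M + gamma * x_out x_in^T, with M shaped (output units, input units)."""
    return M + gamma * np.outer(x_out, x_in)


def retrieve(M, x_in):
    """Retrieval as the matrix-vector product y = M x."""
    return M @ x_in


x_in = np.array([1.0, 0.0, 0.0])  # input pattern: 3 units
x_out = np.array([0.0, 1.0])      # output pattern: 2 units
M = np.zeros((2, 3))              # empty memory, shape (out, in)

M = hebbian_update(M, x_in, x_out, gamma=0.5)
y = retrieve(M, x_in)               # 0.5 * (x_in . x_in) * x_out = 0.5 * x_out
y_weak = retrieve(M, 0.5 * x_in)    # half-strength cue -> half-strength output
```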

Setup

Code
import jax.numpy as jnp
import matplotlib.pyplot as plt
import numpy as np

from jaxcmr.components.linear_memory import LinearMemory
from jaxcmr.typing import Memory

Parameters

Code
item_count = 10
context_size = item_count + 1
learning_rate = 0.5
item_support = 0.1
shared_support = 0.05

MFC Initialization

Item-to-context memory (MFC) maps items to context units. Initially, each item is associated only with its own context unit (item i with unit i + 1; unit 0 is reserved for the start-of-list context), with strength 1 - learning_rate:

Code
mfc = LinearMemory.init_mfc(item_count, context_size, learning_rate)

plt.figure(figsize=(8, 6))
plt.imshow(mfc.state, cmap='viridis', aspect='auto')
plt.colorbar(label='Association Strength')
plt.xlabel('Context Unit')
plt.ylabel('Item')
plt.title('MFC: Initial State')
plt.tight_layout()
plt.show()
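The initial MFC pattern can be written out directly in NumPy. This is a sketch under an assumption, not jaxcmr's code: it assumes rows are items, columns are context units, unit 0 is the start-of-list unit, and item i maps to unit i + 1 with strength 1 - learning_rate. np.eye with k=1 shifts the identity one column to the right, which produces exactly that offset.

```python
import numpy as np

item_count = 10
context_size = item_count + 1
learning_rate = 0.5

# Assumed layout: rows = items, columns = context units, column 0 = start-of-list.
# k=1 puts the nonzero entries at (i, i + 1): item i -> context unit i + 1.
mfc_sketch = np.eye(item_count, context_size, k=1) * (1 - learning_rate)
```

If the assumption matches init_mfc, np.allclose(mfc_sketch, np.asarray(mfc.state)) should be True; a False result means init_mfc uses a different layout than assumed here.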

MCF Initialization

Context-to-item memory (MCF) maps context to item activations. Initially:

  • Each in-list context unit is associated with every item with strength shared_support

  • Each in-list context unit has an additional association of strength item_support with its own item (unit i + 1 with item i)

  • The start-of-list context unit (index 0) has no initial associations

Code
mcf = LinearMemory.init_mcf(item_count, context_size, item_support, shared_support)

plt.figure(figsize=(8, 6))
plt.imshow(mcf.state, cmap='viridis', aspect='auto')
plt.colorbar(label='Association Strength')
plt.xlabel('Item')
plt.ylabel('Context Unit')
plt.title('MCF: Initial State')
plt.tight_layout()
plt.show()
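The three rules above can be assembled by hand as a NumPy sketch. It follows the text literally, so each in-list unit's own item receives shared_support + item_support; if init_mcf combines the two supports differently, the diagonal values will differ. Rows are context units and columns are items, matching the plot.

```python
import numpy as np

item_count = 10
item_support = 0.1
shared_support = 0.05

# Every in-list context unit supports every item a little...
in_list = np.full((item_count, item_count), shared_support)
# ...and its own item (unit i + 1 -> item i) a little more.
in_list += np.eye(item_count) * item_support
# The start-of-list unit (row 0) starts with no associations.
start_row = np.zeros((1, item_count))

mcf_sketch = np.vstack([start_row, in_list])  # shape (context_size, item_count)
```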

Association and Retrieval

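Associate context unit 1 with item 0 in MCF, then probe the updated memory with the same context unit. Because the probe is one-hot, the item activations are simply row 1 of the updated matrix: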

Code
# Create item and context patterns
items = jnp.eye(item_count)
contexts = jnp.eye(context_size)

# Associate context unit 1 with item 0
mcf_updated = mcf.associate(contexts[1], items[0], learning_rate)

# Probe with context unit 1
activation = mcf_updated.probe(contexts[1])

print("Item activations when probing with context unit 1:")
print(f"  {activation}")
print(f"  Most activated item: {int(activation.argmax())}")
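The same association and probe can be worked through by hand to see what the printed activations should contain. This NumPy sketch rebuilds MCF under the stated assumption that each in-list unit's own item starts at shared_support + item_support, stores the matrix as [context, item], and applies one Hebbian update at entry (1, 0). The probe then returns row 1: item 0 at 0.05 + 0.1 + 0.5 and every other item at shared_support.

```python
import numpy as np

item_count, context_size = 10, 11
learning_rate, item_support, shared_support = 0.5, 0.1, 0.05

# Assumed initial MCF: zero start-of-list row, then shared + item support on the diagonal.
in_list = np.full((item_count, item_count), shared_support) + item_support * np.eye(item_count)
mcf0 = np.vstack([np.zeros((1, item_count)), in_list])

contexts = np.eye(context_size)
items = np.eye(item_count)

# Hebbian update in [input, output] storage: add learning_rate * outer(context, item).
mcf1 = mcf0 + learning_rate * np.outer(contexts[1], items[0])

# A one-hot probe reads out one row of the matrix.
activation = contexts[1] @ mcf1
```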

Association Matrix Evolution During Study

Watch how MCF evolves as items are studied. In full CMR, each studied item is associated with the current, gradually drifting context state; for simplicity, this demo associates each item only with its own context unit (item i with unit i + 1):

Code
# Simulate studying items and forming associations
mcf_states = [mcf.state.copy()]
current_mcf = mcf

for i in range(item_count):
    # Simplified: associate context unit i+1 with item i
    current_mcf = current_mcf.associate(contexts[i + 1], items[i], learning_rate)
    mcf_states.append(current_mcf.state.copy())

# Visualize evolution
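fig, axes = plt.subplots(2, 3, figsize=(12, 8))
positions = [0, 2, 4, 6, 8, 10]
# Shared color scale sized to the data, so strengthened entries are not clipped
vmax = max(float(state.max()) for state in mcf_states)

for ax, pos in zip(axes.flat, positions):
    im = ax.imshow(mcf_states[pos], cmap='viridis', aspect='auto', vmin=0, vmax=vmax)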
    ax.set_title(f'After {pos} items')
    ax.set_xlabel('Item')
    ax.set_ylabel('Context Unit')

plt.suptitle('MCF Association Matrix Evolution')
plt.tight_layout()
plt.show()
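Because every study event in this simplified scheme uses a different one-hot context unit, the updates land on disjoint entries and simply add up. The NumPy sketch below checks that the total change after all ten events equals learning_rate times an identity shifted down one row (entries (i + 1, i)), which is the brightening sub-diagonal in the panels. With drifting context, as in full CMR, the updates would overlap and this closed form would no longer hold.

```python
import numpy as np

item_count, context_size, learning_rate = 10, 11, 0.5
contexts, items = np.eye(context_size), np.eye(item_count)

# Accumulate the Hebbian updates from the study loop, without the initial MCF.
delta = np.zeros((context_size, item_count))
for i in range(item_count):
    delta += learning_rate * np.outer(contexts[i + 1], items[i])

# k=-1 places ones at (i + 1, i): each update touches exactly one entry.
closed_form = learning_rate * np.eye(context_size, item_count, k=-1)
```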

Association Growth

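Track the association from context unit 1 to item 0 across study. It grows only at the first study event, when item 0 is paired with context unit 1; every later event writes to a different entry, so the strength then stays flat: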

Code
# Track association strength for item 0 from context unit 1
association_strengths = [state[1, 0] for state in mcf_states]

plt.figure(figsize=(8, 4))
plt.plot(range(len(association_strengths)), association_strengths, 'o-', linewidth=2, markersize=8)
plt.axhline(y=shared_support + item_support, color='r', linestyle='--', label='Initial (shared + item support)')
plt.xlabel('Items Studied')
plt.ylabel('Association Strength')
plt.title('MCF[context_1, item_0] Association Growth')
plt.legend()
plt.grid(True, alpha=0.3)
plt.tight_layout()
plt.show()
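The plotted curve can be predicted in closed form. This short sketch assumes, like the dashed reference line, that MCF[1, 0] starts at shared_support + item_support; study event n pairs context unit n with item n - 1, so only event 1 adds learning_rate to this entry.

```python
import numpy as np

item_count = 10
learning_rate = 0.5
shared_support, item_support = 0.05, 0.1

# Assumed starting strength of MCF[1, 0], matching the dashed line in the plot.
initial = shared_support + item_support

# After n study events, MCF[1, 0] has received learning_rate once if n >= 1.
expected = np.array(
    [initial + learning_rate * (n >= 1) for n in range(item_count + 1)]
)
```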

Protocol Compliance

LinearMemory implements the Memory protocol:

Code
from nbdev.showdoc import show_doc

show_doc(Memory)
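To illustrate what structural protocol compliance means in Python, here is a hypothetical minimal protocol, MemoryLike, declaring only the two methods this notebook calls (associate and probe), plus a toy NumPy class that satisfies it. Both names are illustrative; jaxcmr's actual Memory protocol may declare additional members, which the show_doc output above lists. The toy class returns a new object from associate instead of mutating, mirroring how mcf.associate(...) is used earlier.

```python
from typing import Protocol, runtime_checkable

import numpy as np


@runtime_checkable
class MemoryLike(Protocol):
    """Hypothetical minimal protocol: only the methods used in this notebook."""

    def associate(self, in_pattern: np.ndarray, out_pattern: np.ndarray,
                  learning_rate: float) -> "MemoryLike": ...

    def probe(self, in_pattern: np.ndarray) -> np.ndarray: ...


class ToyLinearMemory:
    """Toy linear memory with state indexed [input unit, output unit]."""

    def __init__(self, state: np.ndarray):
        self.state = state

    def associate(self, in_pattern, out_pattern, learning_rate):
        # Hebbian outer-product update, returned as a new object (no mutation).
        return ToyLinearMemory(self.state + learning_rate * np.outer(in_pattern, out_pattern))

    def probe(self, in_pattern):
        # Retrieval: input pattern times the [input, output] state matrix.
        return in_pattern @ self.state


mem = ToyLinearMemory(np.zeros((3, 2)))
# runtime_checkable isinstance only checks that the methods exist, not their signatures.
is_compliant = isinstance(mem, MemoryLike)
```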