MemoEngine#

class torch_ttt.engine.memo_engine.MemoEngine(model: Module, optimization_parameters: Dict[str, Any] = {}, augmentations=None, n_augmentations: int = 8, prior_strength: float = 16)[source]#

MEMO: Test-Time Robustness via Adaptation and Augmentation.

Applies multiple augmentations per test sample and adapts the model by minimizing the entropy of the average prediction across augmentations.
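For intuition, the objective is the entropy of the marginal prediction over the augmented views. Below is a minimal sketch of that computation; the marginal_entropy helper is illustrative and not part of the engine's public API.

import math

import torch
import torch.nn.functional as F

def marginal_entropy(logits: torch.Tensor) -> torch.Tensor:
    # logits: (n_augmentations, n_classes), one row per augmented view.
    log_probs = F.log_softmax(logits, dim=-1)
    # Log of the mean probability across views (numerically stable).
    avg_log_probs = torch.logsumexp(log_probs, dim=0) - math.log(logits.shape[0])
    # Entropy of the marginal (averaged) prediction.
    return -(avg_log_probs.exp() * avg_log_probs).sum()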

Parameters:
  • model (torch.nn.Module) – The model to adapt.

  • optimization_parameters (dict) – Hyperparameters for the test-time optimizer (e.g. "lr").

  • augmentations (list, optional) – Augmentation transforms to sample from; if None, a default set is used (see the sketch after this list).

  • n_augmentations (int) – Number of augmented views per input sample.

  • prior_strength (float) – Strength of the prior placed on the source batch-norm statistics when adapting normalization layers.
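A sketch of supplying custom augmentations, assuming the augmentations argument accepts a list of callables applied to each input sample (this format is an assumption, not confirmed by the signature alone):

from torchvision import transforms

from torch_ttt.engine.memo_engine import MemoEngine

# Assumption: `augmentations` takes a list of callables applied per sample.
augs = [
    transforms.RandomHorizontalFlip(),
    transforms.RandomResizedCrop(32, scale=(0.8, 1.0)),
]
engine = MemoEngine(MyModel(), {"lr": 1e-3}, augmentations=augs, n_augmentations=8)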

Example:

import torch
from torch_ttt.engine.memo_engine import MemoEngine

model = MyModel()
engine = MemoEngine(model, {"lr": 1e-3}, n_augmentations=4)
optimizer = torch.optim.Adam(engine.parameters(), lr=1e-3)
criterion = torch.nn.CrossEntropyLoss()
alpha = 0.1  # weight of the auxiliary test-time-training loss (illustrative value)

# Training
engine.train()
for inputs, labels in train_loader:
    optimizer.zero_grad()
    outputs, loss_ttt = engine(inputs)
    loss = criterion(outputs, labels) + alpha * loss_ttt
    loss.backward()
    optimizer.step()

# Inference
engine.eval()
for inputs, labels in test_loader:
    outputs, loss_ttt = engine(inputs)

Reference:

“MEMO: Test Time Robustness via Adaptation and Augmentation”, Marvin Zhang, Sergey Levine, Chelsea Finn

Paper link: https://arxiv.org/abs/2110.09506

ttt_forward(inputs: Tensor) → Tuple[Tensor, Tensor][source]#

Performs the MEMO test-time forward pass.

Parameters:

inputs (Tensor) – Input tensor.

Returns:

(final prediction logits averaged over the augmented views, marginal-entropy adaptation loss)

Return type:

Tuple[Tensor, Tensor]
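The returned loss can drive explicit test-time updates. A minimal sketch, assuming you step an optimizer on the adaptation loss yourself (engine, optimizer, and test_loader as in the example above):

engine.eval()
for inputs, _ in test_loader:
    optimizer.zero_grad()
    outputs, loss_ttt = engine(inputs)  # calls ttt_forward under the hood
    loss_ttt.backward()                 # adapt on the marginal-entropy loss alone
    optimizer.step()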