MemoEngine
- class torch_ttt.engine.memo_engine.MemoEngine(model: Module, optimization_parameters: Dict[str, Any] = {}, augmentations=None, n_augmentations: int = 8, prior_strength: float = 16)
MEMO: Test-Time Robustness via Adaptation and Augmentation.
For each test sample, applies multiple random augmentations and adapts the model by minimizing the entropy of the predicted class distribution averaged across the augmented views (the marginal entropy).
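The objective above is the marginal entropy H(p̄), where p̄(y) = (1/B) Σᵢ p(y | x̃ᵢ) averages the model's predictions over B augmented views x̃ᵢ of one test sample. A minimal PyTorch sketch of that loss follows; the function name `marginal_entropy` is hypothetical and is not part of torch_ttt's documented API.

```python
import math

import torch
import torch.nn.functional as F


def marginal_entropy(logits: torch.Tensor) -> torch.Tensor:
    """Entropy of the averaged predictive distribution over augmented views.

    logits: shape (n_augmentations, num_classes), the model outputs for the
    augmented copies of a single test sample (hypothetical helper).
    """
    n_views = logits.shape[0]
    # Per-view log-probabilities log p(y | x_i).
    log_probs = F.log_softmax(logits, dim=-1)
    # log p_bar(y) = logsumexp_i log p(y | x_i) - log B  (numerically stable mean)
    avg_log_probs = torch.logsumexp(log_probs, dim=0) - math.log(n_views)
    # H(p_bar) = -sum_y p_bar(y) * log p_bar(y)
    return -(avg_log_probs.exp() * avg_log_probs).sum()
```

Views that confidently agree give a value near 0, while views that are confident in different classes average to a flat distribution and give a high value; minimizing this loss therefore rewards predictions that are both confident and consistent across augmentations.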
- Parameters:
model (torch.nn.Module) – The model to adapt.
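optimization_parameters (dict) – Hyperparameters for test-time adaptation, e.g. {"lr": 1e-3} as in the example below. Defaults to an empty dict.
n_augmentations (int) – Number of augmented views generated per test sample. Defaults to 8.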
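The signature also exposes `augmentations` and `prior_strength`, which are not described above. One plausible reading, assumed here and not confirmed by this page, is that `prior_strength` is the prior N in the BatchNorm-statistics mixing rule of Schneider et al. (2020), where the stored training statistics count as N pseudo-samples against n test samples. The sketch below illustrates only that rule; `mix_bn_statistics` is a hypothetical helper, not torch_ttt code.

```python
import torch


def mix_bn_statistics(
    train_mean: torch.Tensor,
    train_var: torch.Tensor,
    test_batch: torch.Tensor,
    prior_strength: float = 16.0,
):
    """Blend stored BatchNorm statistics with those of a test batch (hypothetical helper).

    mean = N/(N+n) * train_mean + n/(N+n) * test_mean
    var  = N/(N+n) * train_var  + n/(N+n) * test_var
    N is prior_strength (assumed meaning), n is the number of test samples.
    test_batch has shape (n, C) or (n, C, H, W); statistics are per channel C.
    """
    # Reduce over every dimension except the channel dimension (dim 1).
    dims = tuple(d for d in range(test_batch.dim()) if d != 1)
    test_mean = test_batch.mean(dim=dims)
    # Biased (population) variance, as BatchNorm uses for normalization.
    centered = test_batch - test_batch.mean(dim=dims, keepdim=True)
    test_var = (centered ** 2).mean(dim=dims)
    n = test_batch.shape[0]
    w_train = prior_strength / (prior_strength + n)
    w_test = n / (prior_strength + n)
    mean = w_train * train_mean + w_test * test_mean
    var = w_train * train_var + w_test * test_var
    return mean, var
```

With n = 16 test samples and a prior of 16, training and test statistics get equal weight; a prior of 0 uses the test statistics alone, and a very large prior keeps the training statistics.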
- Example:
    import torch
    import torch.nn as nn
    from torch_ttt.engine.memo_engine import MemoEngine

    # MyModel, train_loader and test_loader are placeholders for your own model and data.
    model = MyModel()
    engine = MemoEngine(model, {"lr": 1e-3}, n_augmentations=4)
    optimizer = torch.optim.Adam(engine.parameters(), lr=1e-3)
    criterion = nn.CrossEntropyLoss()
    alpha = 1.0  # weight of the test-time loss; illustrative value

    # Training
    engine.train()
    for inputs, labels in train_loader:
        optimizer.zero_grad()
        outputs, loss_ttt = engine(inputs)
        loss = criterion(outputs, labels) + alpha * loss_ttt
        loss.backward()
        optimizer.step()

    # Inference
    engine.eval()
    for inputs, labels in test_loader:
        outputs, loss_ttt = engine(inputs)
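To make the adaptation described at the top of this page concrete, here is a minimal standalone sketch of one MEMO update on a single test sample. It is not torch_ttt's implementation: `memo_step` and `toy_augment` are hypothetical names, the flip-plus-noise augmentation is only a stand-in for whatever the `augmentations` argument supplies, and the learning rate mirrors the `{"lr": 1e-3}` used above.

```python
import math

import torch
import torch.nn as nn
import torch.nn.functional as F


def toy_augment(x: torch.Tensor) -> torch.Tensor:
    """Hypothetical augmentation: random horizontal flip plus small Gaussian noise."""
    if torch.rand(()) < 0.5:
        x = torch.flip(x, dims=[-1])
    return x + 0.05 * torch.randn_like(x)


def memo_step(model: nn.Module, x: torch.Tensor, optimizer: torch.optim.Optimizer,
              n_augmentations: int = 8, augment=toy_augment):
    """One MEMO-style update on a single test sample x of shape (C, H, W)."""
    # Stack n_augmentations augmented views of the same sample into one batch.
    views = torch.stack([augment(x) for _ in range(n_augmentations)])
    log_probs = F.log_softmax(model(views), dim=-1)
    # Marginal entropy of the prediction averaged over the views.
    avg_log_probs = torch.logsumexp(log_probs, dim=0) - math.log(n_augmentations)
    loss = -(avg_log_probs.exp() * avg_log_probs).sum()
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    # Predict on the original (unaugmented) input with the updated parameters.
    with torch.no_grad():
        return model(x.unsqueeze(0)), loss.detach()
```

The MEMO paper adapts each test point independently, restoring the original parameters before the next one; this sketch leaves that reset to the caller (for example by reloading a saved `state_dict()`), and this page does not state whether `MemoEngine` resets between batches.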
Reference:
“MEMO: Test-Time Robustness via Adaptation and Augmentation”, Marvin Zhang, Sergey Levine, Chelsea Finn
Paper link: PDF