MemoEngine
- class torch_ttt.engine.memo_engine.MemoEngine(model: Module, optimization_parameters: Dict[str, Any] = {}, augmentations=None, n_augmentations: int = 8, prior_strength: float = 16)
MEMO: Test-Time Robustness via Adaptation and Augmentation.
Applies n_augmentations random augmentations to each test sample and adapts the model by minimizing the entropy of the prediction averaged across the augmented views (the marginal output distribution).
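The objective described above can be sketched in plain PyTorch. This is a minimal illustration under stated assumptions, not the torch_ttt implementation: `marginal_entropy`, `memo_step`, the toy linear model, and the Gaussian-noise `augment` function are hypothetical stand-ins. The sketch also assumes SGD as the optimizer and adapts on one test input at a time.

```python
import math

import torch
import torch.nn as nn
import torch.nn.functional as F


def marginal_entropy(logits: torch.Tensor) -> torch.Tensor:
    """Entropy of the prediction averaged over augmented views.

    logits: (n_views, n_classes) outputs for augmented copies of ONE sample.
    """
    log_probs = F.log_softmax(logits, dim=-1)
    # Log of the mean probability over views, computed stably in log space.
    avg_log_probs = torch.logsumexp(log_probs, dim=0) - math.log(logits.shape[0])
    return -(avg_log_probs.exp() * avg_log_probs).sum()


def memo_step(model, x, augment, optimizer, n_augmentations=8):
    """One hypothetical MEMO-style update on a single test input x (C, H, W)."""
    # Build a batch of augmented views of the same input.
    views = torch.stack([augment(x) for _ in range(n_augmentations)])
    loss = marginal_entropy(model(views))
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    # Predict on the original (un-augmented) input with the updated weights.
    with torch.no_grad():
        pred = model(x.unsqueeze(0)).argmax(dim=-1)
    return pred, loss.item()


if __name__ == "__main__":
    torch.manual_seed(0)
    # Toy classifier and noise "augmentation"; both are placeholders.
    model = nn.Sequential(nn.Flatten(), nn.Linear(3 * 8 * 8, 10))
    opt = torch.optim.SGD(model.parameters(), lr=0.01)

    def augment(img: torch.Tensor) -> torch.Tensor:
        return img + 0.1 * torch.randn_like(img)

    pred, loss = memo_step(model, torch.randn(3, 8, 8), augment, opt)
    print(pred, loss)
```

Averaging probabilities before taking the entropy (rather than averaging per-view entropies) rewards predictions that are both confident and consistent across augmentations. The sketch does not reset the weights between test inputs; how MemoEngine handles that is not covered by this page.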
- Parameters:
model (torch.nn.Module) – The model to adapt.
optimization_parameters (dict) – Hyperparameters for adaptation.
n_augmentations (int) – Number of augmented views per input sample.
- Reference:
“MEMO: Test Time Robustness via Adaptation and Augmentation”, Marvin Zhang, Sergey Levine, Chelsea Finn. NeurIPS 2022.