MemoEngine

class torch_ttt.engine.memo_engine.MemoEngine(model: Module, optimization_parameters: Dict[str, Any] = {}, augmentations=None, n_augmentations: int = 8, prior_strength: float = 16)

MEMO: Test-Time Robustness via Adaptation and Augmentation.

For each test sample, generates several augmented views and adapts the model by minimizing the entropy of its marginal prediction, i.e. the class probabilities averaged across the augmented views.
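The objective described above can be written as a small PyTorch function. This is a sketch that assumes the logits for one test sample's augmented views are stacked along the first dimension; `marginal_entropy` is a hypothetical helper, not a `torch_ttt` API.

```python
import math

import torch
import torch.nn.functional as F


def marginal_entropy(logits: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: entropy of the prediction averaged over views.

    logits: shape (n_augmentations, num_classes), one row per augmented
    view of the same test sample.
    """
    # Per-view log-probabilities (log space for numerical stability).
    log_probs = F.log_softmax(logits, dim=-1)
    # log of the mean probability: logsumexp over views minus log(n).
    n = logits.shape[0]
    avg_log_probs = torch.logsumexp(log_probs, dim=0) - math.log(n)
    # H(p_bar) = -sum_c p_bar_c * log p_bar_c
    return -(avg_log_probs.exp() * avg_log_probs).sum()


# Views that confidently agree give a marginal entropy near 0.
agree = torch.tensor([[10.0, 0.0, 0.0]] * 3)
# Each view is confident on its own, but they disagree, so the marginal
# is close to uniform and its entropy is near log(3).
disagree = torch.tensor([[10.0, 0.0, 0.0],
                         [0.0, 10.0, 0.0],
                         [0.0, 0.0, 10.0]])
print(marginal_entropy(agree).item())
print(marginal_entropy(disagree).item())
```

Averaging probabilities before taking the entropy is what distinguishes this from per-view entropy minimization: the loss is only low when the views agree with each other, not merely when each view is confident.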

Parameters:
  • model (torch.nn.Module) – The model to adapt.

  • optimization_parameters (dict) – Hyperparameters for adaptation.

  • n_augmentations (int) – Number of augmented views per input sample.
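With `n_augmentations` views per sample, a batch of inputs has to be expanded into views and the model's outputs regrouped per sample before averaging. The sketch below shows one way to do this batching in PyTorch; `flip_and_noise`, `expand_views`, and `group_logits` are hypothetical helpers, not part of `torch_ttt`, and the engine may organize this differently.

```python
import torch


def flip_and_noise(x: torch.Tensor) -> torch.Tensor:
    # Hypothetical stand-in for the engine's `augmentations`: flip each
    # image horizontally with probability 0.5, then add small Gaussian noise.
    flip = torch.rand(x.shape[0], device=x.device) < 0.5
    out = torch.where(flip[:, None, None, None], x.flip(-1), x)
    return out + 0.05 * torch.randn_like(out)


def expand_views(x: torch.Tensor, n_augmentations: int, augment) -> torch.Tensor:
    # (B, C, H, W) -> (B * n, C, H, W). repeat_interleave keeps the views of
    # each sample contiguous: rows i*n .. i*n + n - 1 belong to sample i.
    return augment(x.repeat_interleave(n_augmentations, dim=0))


def group_logits(logits: torch.Tensor, batch_size: int,
                 n_augmentations: int) -> torch.Tensor:
    # (B * n, K) -> (B, n, K), so averaging over dim=1 is per sample.
    return logits.reshape(batch_size, n_augmentations, -1)


x = torch.randn(2, 3, 8, 8)
views = expand_views(x, 8, flip_and_noise)
print(views.shape)  # torch.Size([16, 3, 8, 8])
fake_logits = torch.randn(16, 10)
print(group_logits(fake_logits, 2, 8).shape)  # torch.Size([2, 8, 10])
```

Keeping each sample's views contiguous is what lets a plain `reshape` recover the groups; `x.repeat(n, 1, 1, 1)` would instead tile the whole batch and require reshaping to `(n, B, K)`.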

Reference:

“MEMO: Test Time Robustness via Adaptation and Augmentation”, Marvin Zhang, Sergey Levine, Chelsea Finn. NeurIPS 2022.

ttt_forward(inputs: Tensor) → Tuple[Tensor, Tensor]

Performs the MEMO test-time forward pass.

Parameters:

inputs (Tensor) – Input tensor.

Returns:

(final prediction logits, adaptation loss)

Return type:

Tuple[Tensor, Tensor]
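A minimal sketch of what one MEMO-style adaptation step could look like, assuming a single test input, a caller-supplied `augment` callable, and a standard `torch.optim` optimizer. `memo_step` is a hypothetical function, not the `torch_ttt` implementation of `ttt_forward`; in particular it ignores `prior_strength` and any BatchNorm handling.

```python
import math
from typing import Callable, Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F


def memo_step(model: nn.Module, optimizer: torch.optim.Optimizer,
              x: torch.Tensor, augment: Callable[[torch.Tensor], torch.Tensor],
              n_augmentations: int = 8) -> Tuple[torch.Tensor, torch.Tensor]:
    """Hypothetical MEMO-style step; returns (final logits, adaptation loss)."""
    assert x.shape[0] == 1, "sketch assumes a single test sample"
    # Build n augmented copies of the one input.
    views = augment(x.repeat(n_augmentations, *([1] * (x.dim() - 1))))
    # Marginal entropy: average the per-view probabilities, then take entropy.
    log_probs = F.log_softmax(model(views), dim=-1)
    avg_log_probs = torch.logsumexp(log_probs, dim=0) - math.log(n_augmentations)
    loss = -(avg_log_probs.exp() * avg_log_probs).sum()

    # One gradient step on the marginal entropy.
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    # Predict on the original, un-augmented input with the updated weights.
    with torch.no_grad():
        final_logits = model(x)
    return final_logits, loss.detach()


torch.manual_seed(0)
model = nn.Sequential(nn.Flatten(), nn.Linear(3 * 8 * 8, 10))
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
add_noise = lambda t: t + 0.1 * torch.randn_like(t)  # hypothetical augmentation
w_before = model[1].weight.detach().clone()
x = torch.randn(1, 3, 8, 8)
logits, loss = memo_step(model, optimizer, x, add_noise)
print(logits.shape)  # torch.Size([1, 10])
```

In the MEMO paper each test point is handled independently, so the original weights are restored after every prediction; with this sketch a caller could save `copy.deepcopy(model.state_dict())` before the step and pass it to `model.load_state_dict(...)` afterwards.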