Source code for torch_ttt.engine.memo_engine

import torch
import random
import torchvision.transforms as transforms
import numpy as np
from types import MethodType
from typing import Any, Dict, Optional, Tuple
from torch_ttt.engine.base_engine import BaseEngine
from torch_ttt.engine_registry import EngineRegistry

__all__ = ["MemoEngine"]

@EngineRegistry.register("memo")
class MemoEngine(BaseEngine):
    """**MEMO**: Test-Time Robustness via Adaptation and Augmentation.

    Applies multiple augmentations per test sample and adapts the model by
    minimizing the entropy of the average prediction across augmentations.

    Args:
        model (torch.nn.Module): The model to adapt.
        optimization_parameters (dict): Hyperparameters for adaptation.
        augmentations (callable, optional): Augmentation pipeline applied to
            each input; a default pipeline is used when ``None``.
        n_augmentations (int): Number of augmented views per input sample.
        prior_strength (float): Weight of the source batch-norm statistics
            relative to the test-batch statistics.

    Reference:
        "MEMO: Test Time Robustness via Adaptation and Augmentation",
        Marvin Zhang, Sergey Levine, Chelsea Finn.
    """

    def __init__(
        self,
        model: torch.nn.Module,
        optimization_parameters: Optional[Dict[str, Any]] = None,
        augmentations=None,
        n_augmentations: int = 8,
        prior_strength: float = 16.0,
    ):
        super().__init__()
        self.model = model
        # Normalize ``None`` to an empty dict to avoid a mutable default argument.
        self.optimization_parameters = optimization_parameters or {}
        self.augmentations = augmentations
        self.n_augmentations = n_augmentations
        self.prior_strength = prior_strength

        if augmentations is None:
            self.augmentations = transforms.Compose([
                transforms.RandomHorizontalFlip(p=0.5),
                transforms.RandomRotation(degrees=10),
                transforms.RandomAffine(
                    degrees=0,
                    translate=(0.05, 0.05),
                    scale=(0.95, 1.05),
                ),
                AddGaussianNoise(std=0.02, p=0.5),
                RandomGaussianBlur(kernel_size=3, sigma=(0.1, 1.5), p=0.5),
            ])

        self.model.train()
        self._configure_bn()

    def _configure_bn(self):
        for module in self.model.modules():
            # See https://github.com/zhangmarvin/memo/blob/228b2908d271c954ef8bf19cf143ede3b2fa8e3e/imagenet-exps/utils/train_helpers.py#L56C36-L56C95
            if isinstance(module, torch.nn.BatchNorm2d):
                module.prior = float(self.prior_strength) / float(self.prior_strength + 1)
                # Bind the replacement forward to this instance so ``self`` is passed.
                module.forward = MethodType(_modified_bn_forward, module)

    def _marginal_entropy(self, outputs):
        # Per-view log-probabilities, then the log of the mean probability
        # across views, computed stably in log space.
        logits = outputs - outputs.logsumexp(dim=-1, keepdim=True)
        avg_logits = logits.logsumexp(dim=0) - np.log(logits.shape[0])
        min_real = torch.finfo(avg_logits.dtype).min
        avg_logits = torch.clamp(avg_logits, min=min_real)
        return -(avg_logits * torch.exp(avg_logits)).sum(dim=-1).mean(), avg_logits
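
    # Note: ``_marginal_entropy`` evaluates, in log space for numerical
    # stability, the entropy H(p_bar) of the averaged prediction
    # p_bar = (1/N) * sum_i softmax(logits_i) over the N augmented views;
    # minimizing it rewards confident, augmentation-consistent predictions.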
    def ttt_forward(self, inputs: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Performs the MEMO test-time forward pass.

        Args:
            inputs (Tensor): Input tensor.

        Returns:
            Tuple[Tensor, Tensor]: (final prediction logits, adaptation loss)
        """
        # Build n_augmentations randomly augmented views of the batch and
        # flatten them into a single forward pass.
        augmented_inputs = [self.augmentations(inputs) for _ in range(self.n_augmentations)]
        augmented_inputs = torch.stack(augmented_inputs)
        augmented_inputs = augmented_inputs.view(-1, *augmented_inputs.shape[2:])
        logits = self.model(augmented_inputs)
        logits = logits.view(self.n_augmentations, -1, logits.shape[-1])
        loss, logits = self._marginal_entropy(logits)
        return logits, loss
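
# Example usage (a minimal sketch, not part of the module API; ``MyClassifier``,
# ``test_loader``, and the learning rate are illustrative assumptions):
#
#     engine = MemoEngine(MyClassifier(), n_augmentations=8)
#     optimizer = torch.optim.SGD(engine.model.parameters(), lr=1e-3)
#     for inputs, _ in test_loader:
#         optimizer.zero_grad()
#         logits, loss = engine.ttt_forward(inputs)
#         loss.backward()
#         optimizer.step()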

def _modified_bn_forward(self, input):
    # https://github.com/bethgelab/robustness/blob/main/robusta/batchnorm/bn.py#L175
    est_mean = torch.zeros(self.running_mean.shape, device=self.running_mean.device)
    est_var = torch.ones(self.running_var.shape, device=self.running_var.device)
    # Estimate mean/var from the current batch (momentum=1.0 overwrites est_* in place).
    torch.nn.functional.batch_norm(input, est_mean, est_var, None, None, True, 1.0, self.eps)
    # Blend source statistics with test-batch statistics according to the prior.
    running_mean = self.prior * self.running_mean + (1 - self.prior) * est_mean
    running_var = self.prior * self.running_var + (1 - self.prior) * est_var
    return torch.nn.functional.batch_norm(
        input, running_mean, running_var, self.weight, self.bias, False, 0, self.eps
    )


# Optional noise injection
class AddGaussianNoise(torch.nn.Module):
    def __init__(self, std=0.03, p=0.5):
        super().__init__()
        self.std = std
        self.p = p

    def forward(self, x):
        # Add Gaussian noise with probability p, only in training mode.
        if self.training and random.random() < self.p:
            return x + torch.randn_like(x) * self.std
        return x


# Optional Gaussian blur (for tensors)
class RandomGaussianBlur(torch.nn.Module):
    def __init__(self, kernel_size=5, sigma=(0.1, 2.0), p=0.5):
        super().__init__()
        self.kernel_size = kernel_size
        self.sigma = sigma
        self.p = p

    def forward(self, x):
        # Blur with a randomly drawn sigma, with probability p, in training mode.
        if self.training and random.random() < self.p:
            sigma = random.uniform(*self.sigma)
            return transforms.functional.gaussian_blur(x, self.kernel_size, sigma)
        return x
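
# Illustrative check of the BN prior (a demonstration under assumed values, not
# part of the module): with prior_strength = 16, the blended statistics are
# 16/17 source stats plus 1/17 test-batch stats.
#
#     bn = torch.nn.BatchNorm2d(3)
#     bn.prior = 16.0 / 17.0
#     bn.forward = MethodType(_modified_bn_forward, bn)
#     out = bn(torch.randn(4, 3, 8, 8))  # normalized with blended statistics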