Source code for torch_ttt.engine.actmad_engine

import torch
from typing import List, Dict, Any, Tuple, Union
from contextlib import contextmanager

from torch.utils.data import DataLoader
from torch_ttt.engine.base_engine import BaseEngine
from torch_ttt.engine_registry import EngineRegistry

__all__ = ["ActMADEngine"]


@EngineRegistry.register("actmad")
class ActMADEngine(BaseEngine):
    """**ActMAD** approach: multi-level pixel-wise feature alignment.

    ActMAD adapts models at test-time by aligning activation statistics (means
    and variances) of the test inputs to those from clean training data, across
    multiple layers of the network. It requires no labels or auxiliary tasks,
    and is applicable to any architecture and task.

    Args:
        model (torch.nn.Module): Model to be trained with TTT.
        features_layer_names (List[str] | str): List of layer names to be used
            for feature alignment.
        optimization_parameters (dict): The optimization parameters for the
            engine. Defaults to an empty dict.

    :Example:

    .. code-block:: python

        from torch_ttt.engine.actmad_engine import ActMADEngine

        model = MyModel()
        engine = ActMADEngine(model, ["fc1", "fc2"])
        optimizer = torch.optim.Adam(engine.parameters(), lr=1e-4)

        # Training
        engine.train()
        for inputs, labels in train_loader:
            optimizer.zero_grad()
            outputs, loss_ttt = engine(inputs)
            loss = criterion(outputs, labels) + alpha * loss_ttt
            loss.backward()
            optimizer.step()

        # Compute statistics for features alignment
        engine.compute_statistics(train_loader)

        # Inference
        engine.eval()
        for inputs, labels in test_loader:
            output, loss_ttt = engine(inputs)

    Reference:

        "ActMAD: Activation Matching to Align Distributions for
        Test-Time-Training", M. Jehanzeb Mirza, Pol Jane Soneira, Wei Lin,
        Mateusz Kozinski, Horst Possegger, Horst Bischof (CVPR 2023).

        Paper link: `PDF <https://arxiv.org/abs/2211.12870>`_
    """

    def __init__(
        self,
        model: torch.nn.Module,
        features_layer_names: Union[List[str], str],
        optimization_parameters: Union[Dict[str, Any], None] = None,
    ):
        """Initialize the engine and resolve the target feature layers.

        Raises:
            ValueError: If any requested layer name is not found in ``model``.
        """
        super().__init__()
        self.model = model
        # NOTE: a ``None`` default replaces the original mutable ``{}`` default,
        # which would have been shared between every engine instance.
        self.optimization_parameters = (
            {} if optimization_parameters is None else optimization_parameters
        )

        if isinstance(features_layer_names, str):
            features_layer_names = [features_layer_names]
        self.features_layer_names = features_layer_names

        # Resolve layer names to modules with a single named_modules() pass
        # instead of re-scanning the whole model for every requested layer.
        modules_by_name = dict(model.named_modules())
        self.target_modules = []
        for layer_name in self.features_layer_names:
            if layer_name not in modules_by_name:
                raise ValueError(f"Layer {layer_name} does not exist in the model.")
            self.target_modules.append(modules_by_name[layer_name])

        # Reference statistics are filled in later by ``compute_statistics``.
        self.reference_mean = None
        self.reference_var = None

    def ttt_forward(self, inputs) -> Tuple[torch.Tensor, torch.Tensor]:
        """Forward pass of the model.

        Args:
            inputs (torch.Tensor): Input tensor.

        Returns:
            The current model prediction and the alignment loss based on
            activation statistics (``0`` while in training mode).

        Raises:
            ValueError: If called in eval mode before ``compute_statistics``.
        """
        with self.__capture_hook() as features_hooks:
            outputs = self.model(inputs)
            features = [hook.output for hook in features_hooks]

        # The alignment loss is only meaningful at test time; skip it while
        # the engine is in training mode.
        if self.training:
            return outputs, 0

        if self.reference_var is None or self.reference_mean is None:
            raise ValueError(
                "Reference statistics are not computed. Please call `compute_statistics` method."
            )

        l1_loss = torch.nn.L1Loss(reduction='mean')
        # Per-layer statistics over the batch dimension of the test inputs.
        features_means = [torch.mean(feature, dim=0) for feature in features]
        features_vars = [torch.var(feature, dim=0) for feature in features]

        loss = 0
        for mean, var, ref_mean, ref_var in zip(
            features_means, features_vars, self.reference_mean, self.reference_var
        ):
            loss += l1_loss(mean, ref_mean)
            loss += l1_loss(var, ref_var)

        return outputs, loss

    def compute_statistics(self, dataloader: DataLoader) -> None:
        """Extract and compute reference statistics for features.

        Args:
            dataloader (DataLoader): The dataloader used for extracting
                features. It can return tuples of tensors, with the first
                element expected to be the input tensor.

        Raises:
            ValueError: If the dataloader is empty or features have mismatched
                dimensions.
        """
        self.model.eval()

        feat_stack = [[] for _ in self.target_modules]

        # TODO: compute variance in a more memory-efficient (streaming) way;
        # currently all features are stacked on the CPU before reduction.
        with torch.no_grad():
            device = next(self.model.parameters()).device
            for sample in dataloader:
                if len(sample) < 1:
                    raise ValueError("Dataloader returned an empty batch.")

                inputs = sample[0].to(device)

                with self.__capture_hook() as features_hooks:
                    _ = self.model(inputs)
                    # Move captured activations to CPU so large datasets do
                    # not exhaust device memory while accumulating.
                    features = [hook.output.cpu() for hook in features_hooks]

                for i, feature in enumerate(features):
                    feat_stack[i].append(feature)

        # Per-layer mean/var over all collected batches, moved back to the
        # model's device for use in the test-time alignment loss.
        self.reference_mean = [
            torch.mean(torch.cat(feat), dim=0).to(device) for feat in feat_stack
        ]
        self.reference_var = [
            torch.var(torch.cat(feat), dim=0).to(device) for feat in feat_stack
        ]

    @contextmanager
    def __capture_hook(self):
        """Context manager to capture features via a forward hook."""

        class OutputHook:
            # Stores the most recent forward output of the hooked module.
            def __init__(self):
                self.output = None

            def hook(self, module, input, output):
                self.output = output

        hook_handles = []
        features_hooks = []
        for module in self.target_modules:
            hook = OutputHook()
            features_hooks.append(hook)
            hook_handle = module.register_forward_hook(hook.hook)
            hook_handles.append(hook_handle)

        try:
            yield features_hooks
        finally:
            # Always detach the hooks so repeated calls don't accumulate them.
            for handle in hook_handles:
                handle.remove()