# Source code for torch_ttt.engine.actmad_engine

import torch
from typing import List, Dict, Any, Tuple, Union
from contextlib import contextmanager

from torch.utils.data import DataLoader
from torch_ttt.engine.base_engine import BaseEngine
from torch_ttt.engine_registry import EngineRegistry

__all__ = ["ActMADEngine"]


# TODO: add cuda support
@EngineRegistry.register("actmad")
class ActMADEngine(BaseEngine):
    """**ActMAD** approach: multi-level pixel-wise feature alignment.

    Args:
        model (torch.nn.Module): Model to be trained with TTT.
        features_layer_names (List[str] | str): List of layer names to be used for feature alignment.
        optimization_parameters (dict): The optimization parameters for the engine.

    :Example:

    .. code-block:: python

        from torch_ttt.engine.actmad_engine import ActMADEngine

        model = MyModel()
        engine = ActMADEngine(model, ["fc1", "fc2"])
        optimizer = torch.optim.Adam(engine.parameters(), lr=1e-4)

        # Training
        engine.train()
        for inputs, labels in train_loader:
            optimizer.zero_grad()
            outputs, loss_ttt = engine(inputs)
            loss = criterion(outputs, labels) + alpha * loss_ttt
            loss.backward()
            optimizer.step()

        # Compute statistics for features alignment
        engine.compute_statistics(train_loader)

        # Inference
        engine.eval()
        for inputs, labels in test_loader:
            output, loss_ttt = engine(inputs)

    Reference:
        "ActMAD: Activation Matching to Align Distributions for Test-Time Training",
        M. Jehanzeb Mirza, Pol Jane Soneira, Wei Lin, Mateusz Kozinski,
        Horst Possegger, Horst Bischof

        Paper link: `PDF <https://openaccess.thecvf.com/content/CVPR2023/papers/Mirza_ActMAD_Activation_Matching_To_Align_Distributions_for_Test-Time_Training_CVPR_2023_paper.pdf>`_
    """

    def __init__(
        self,
        model: torch.nn.Module,
        features_layer_names: Union[List[str], str],
        optimization_parameters: Union[Dict[str, Any], None] = None,
    ):
        """Initialize the engine and resolve the monitored layers.

        Args:
            model (torch.nn.Module): Model to be trained with TTT.
            features_layer_names (List[str] | str): Layer name(s) whose
                activations are aligned at test time.
            optimization_parameters (dict, optional): Optimization parameters
                for the engine. Defaults to an empty dict.

        Raises:
            ValueError: If any requested layer name does not exist in the model.
        """
        super().__init__()
        self.model = model

        # Accept a single layer name as shorthand for a one-element list.
        self.features_layer_names = features_layer_names
        if isinstance(features_layer_names, str):
            self.features_layer_names = [features_layer_names]

        # A `None` sentinel avoids the shared-mutable-default-argument pitfall
        # of `= {}` (one dict object shared across all instances).
        self.optimization_parameters = (
            optimization_parameters if optimization_parameters is not None else {}
        )

        # Resolve layer names to module objects with a single pass over the
        # model instead of re-scanning `named_modules()` per requested layer.
        named_modules = dict(self.model.named_modules())
        self.target_modules = []
        for layer_name in self.features_layer_names:
            if layer_name not in named_modules:
                raise ValueError(f"Layer {layer_name} does not exist in the model.")
            self.target_modules.append(named_modules[layer_name])

        # Reference statistics; populated by `compute_statistics`.
        self.reference_mean = None
        self.reference_var = None
[docs] def ttt_forward(self, inputs) -> Tuple[torch.Tensor, torch.Tensor]: """Forward pass of the model. Args: inputs (torch.Tensor): Input tensor. Returns: Returns the current model prediction and rotation loss value. """ with self.__capture_hook() as features_hooks: outputs = self.model(inputs) features = [hook.output for hook in features_hooks] # don't compute loss during training if self.training: return outputs, 0 if self.reference_var is None or self.reference_mean is None: raise ValueError( "Reference statistics are not computed. Please call `compute_statistics` method." ) l1_loss = torch.nn.L1Loss(reduction='mean') features_means = [torch.mean(feature, dim=0) for feature in features] features_vars = [torch.var(feature, dim=0) for feature in features] loss = 0 for i in range(len(self.target_modules)): print(features_means[i].device, self.reference_mean[i].device) loss += l1_loss(features_means[i], self.reference_mean[i]) loss += l1_loss(features_vars[i], self.reference_var[i]) return outputs, loss
[docs] def compute_statistics(self, dataloader: DataLoader) -> None: """Extract and compute reference statistics for features. Args: dataloader (DataLoader): The dataloader used for extracting features. It can return tuples of tensors, with the first element expected to be the input tensor. Raises: ValueError: If the dataloader is empty or features have mismatched dimensions. """ self.model.eval() feat_stack = [[] for _ in self.target_modules] # TODO: compute variance in more memory efficient way with torch.no_grad(): device = next(self.model.parameters()).device for sample in dataloader: if len(sample) < 1: raise ValueError("Dataloader returned an empty batch.") inputs = sample[0].to(device) with self.__capture_hook() as features_hooks: _ = self.model(inputs) features = [hook.output.cpu() for hook in features_hooks] for i, feature in enumerate(features): feat_stack[i].append(feature) # Compute mean and variance self.reference_mean = [torch.mean(torch.cat(feat), dim=0).to(device) for feat in feat_stack] self.reference_var = [torch.var(torch.cat(feat), dim=0).to(device) for feat in feat_stack]
@contextmanager def __capture_hook(self): """Context manager to capture features via a forward hook.""" class OutputHook: def __init__(self): self.output = None def hook(self, module, input, output): self.output = output hook_handels = [] features_hooks = [] for module in self.target_modules: hook = OutputHook() features_hooks.append(hook) hook_handle = module.register_forward_hook(hook.hook) hook_handels.append(hook_handle) try: yield features_hooks finally: for hook in hook_handels: hook.remove()