Source code for torch_ttt.engine.ttt_pp_engine

import torch
from typing import Tuple, Optional, Callable, Dict, Any
from contextlib import contextmanager

from torchvision import transforms
from torch.utils.data import DataLoader
from torch_ttt.engine.base_engine import BaseEngine
from torch_ttt.engine_registry import EngineRegistry
from torch_ttt.loss.contrastive_loss import ContrastiveLoss
from torch_ttt.utils.augmentations import RandomResizedCrop

__all__ = ["TTTPPEngine"]


# TODO: finish this class
@EngineRegistry.register("ttt_pp")
class TTTPPEngine(BaseEngine):
    """**TTT++** approach: feature alignment-based + SimCLR loss.

    Args:
        model (torch.nn.Module): Model to be trained with TTT.
        features_layer_name (str): The name of the layer from which the features are extracted.
        contrastive_head (torch.nn.Module, optional): The head that is used for SimCLR's loss.
        contrastive_criterion (torch.nn.Module, optional): The loss function used for SimCLR.
        contrastive_transform (callable): A transformation or a composition of transformations
            applied to the input images to generate augmented views for contrastive learning.
        scale_cov (float): The scale factor for the covariance loss.
        scale_mu (float): The scale factor for the mean loss.
        scale_c_cov (float): The scale factor for the contrastive covariance loss.
        scale_c_mu (float): The scale factor for the contrastive mean loss.
        optimization_parameters (dict): The optimization parameters for the engine.

    Warning:
        The module with the name :attr:`features_layer_name` should be present in the model.

    :Example:

    .. code-block:: python

        from torch_ttt.engine.ttt_pp_engine import TTTPPEngine

        model = MyModel()
        engine = TTTPPEngine(model, "fc1")
        optimizer = torch.optim.Adam(engine.parameters(), lr=1e-4)

        # Training
        engine.train()
        for inputs, labels in train_loader:
            optimizer.zero_grad()
            outputs, loss_ttt = engine(inputs)
            loss = criterion(outputs, labels) + alpha * loss_ttt
            loss.backward()
            optimizer.step()

        # Compute statistics for features alignment
        engine.compute_statistics(train_loader)

        # Inference
        engine.eval()
        for inputs, labels in test_loader:
            output, loss_ttt = engine(inputs)

    Reference:
        "TTT++: When Does Self-Supervised Test-Time Training Fail or Thrive?",
        Yuejiang Liu, Parth Kothari, Bastien van Delft, Baptiste Bellot-Gurlet,
        Taylor Mordan, Alexandre Alahi

        Paper link: `PDF <https://proceedings.neurips.cc/paper/2021/hash/b618c3210e934362ac261db280128c22-Abstract.html>`_
    """

    def __init__(
        self,
        model: torch.nn.Module,
        features_layer_name: str,
        contrastive_head: Optional[torch.nn.Module] = None,
        contrastive_criterion: torch.nn.Module = ContrastiveLoss(),
        contrastive_transform: Optional[Callable] = None,
        scale_cov: float = 0.1,
        scale_mu: float = 0.1,
        scale_c_cov: float = 0.1,
        scale_c_mu: float = 0.1,
        optimization_parameters: Dict[str, Any] = {},
    ) -> None:
        super().__init__()
        self.model = model
        self.features_layer_name = features_layer_name
        self.contrastive_head = contrastive_head
        self.contrastive_criterion = (
            contrastive_criterion if contrastive_criterion else ContrastiveLoss()
        )
        self.scale_cov = scale_cov
        self.scale_mu = scale_mu
        self.scale_c_cov = scale_c_cov
        self.scale_c_mu = scale_c_mu
        self.contrastive_transform = contrastive_transform
        self.reference_cov = None
        self.reference_mean = None
        self.reference_c_cov = None
        self.reference_c_mean = None
        self.optimization_parameters = optimization_parameters

        # Locate and store the reference to the target module
        self.target_module = None
        for name, module in model.named_modules():
            if name == features_layer_name:
                self.target_module = module
                break

        if self.target_module is None:
            raise ValueError(f"Module '{features_layer_name}' not found in the model.")

        # Validate that the target module is a Linear layer
        if not isinstance(self.target_module, torch.nn.Linear):
            raise TypeError(
                f"Module '{features_layer_name}' is expected to be of type 'torch.nn.Linear', "
                f"but found type '{type(self.target_module).__name__}'."
            )

        if contrastive_transform is None:
            # default SimCLR augmentation
            self.contrastive_transform = transforms.Compose(
                [
                    RandomResizedCrop(scale=(0.2, 1.0)),
                    transforms.RandomGrayscale(p=0.2),
                    transforms.RandomApply([transforms.GaussianBlur(5)], p=0.3),
                    transforms.RandomHorizontalFlip(),
                ]
            )

    def __build_contrastive_head(self, features) -> torch.nn.Module:
        """Build the contrastive projection head."""
        device = next(self.model.parameters()).device
        if len(features.shape) == 2:
            return torch.nn.Sequential(
                torch.nn.Linear(features.shape[1], 16),
                torch.nn.ReLU(),
                torch.nn.Linear(16, 16),
                torch.nn.ReLU(),
                torch.nn.Linear(16, 16),
            ).to(device)
        raise ValueError("Features should be a 2D tensor.")
    def ttt_forward(self, inputs) -> Tuple[torch.Tensor, torch.Tensor]:
        """Forward pass of the model.

        Args:
            inputs (torch.Tensor): Input tensor.

        Returns:
            The model prediction and the TTT++ loss value.
        """
        # reset reference statistics during training
        if self.training:
            self.reference_cov = None
            self.reference_mean = None
            self.reference_c_cov = None
            self.reference_c_mean = None

        contrastive_inputs = torch.cat(
            [self.contrastive_transform(inputs), self.contrastive_transform(inputs)], dim=0
        )

        # extract features for contrastive loss
        with self.__capture_hook() as features_hook:
            _ = self.model(contrastive_inputs)
            features = features_hook.output

        # Build the contrastive head if not already built
        if self.contrastive_head is None:
            self.contrastive_head = self.__build_contrastive_head(features)

        contrastive_features = self.contrastive_head(features)
        contrastive_features = contrastive_features.view(2, len(inputs), -1).transpose(0, 1)
        loss = self.contrastive_criterion(contrastive_features)

        # make inference for a final prediction
        with self.__capture_hook() as features_hook:
            outputs = self.model(inputs)
            features = features_hook.output

        # compute alignment loss only during test
        if not self.training:
            if (
                self.reference_cov is None
                or self.reference_mean is None
                or self.reference_c_cov is None
                or self.reference_c_mean is None
            ):
                raise ValueError(
                    "Reference statistics are not computed. Please call the `compute_statistics` method."
                )

            # compute features alignment loss
            # (an illustrative sketch of this alignment term is given at the end of this module)
            cov_ext = self.__covariance(features)
            mu_ext = features.mean(dim=0)

            d = self.reference_cov.shape[0]
            loss += self.scale_cov * (self.reference_cov - cov_ext).pow(2).sum() / (4.0 * d**2)
            loss += self.scale_mu * (self.reference_mean - mu_ext).pow(2).mean()

            # compute contrastive features alignment loss
            c_features = self.contrastive_head(features)
            cov_ext = self.__covariance(c_features)
            mu_ext = c_features.mean(dim=0)

            d = self.reference_c_cov.shape[0]
            loss += self.scale_c_cov * (self.reference_c_cov - cov_ext).pow(2).sum() / (4.0 * d**2)
            loss += self.scale_c_mu * (self.reference_c_mean - mu_ext).pow(2).mean()

        return outputs, loss
    @staticmethod
    def __covariance(features):
        """Legacy wrapper to maintain compatibility in the engine."""
        from torch_ttt.utils.math import compute_covariance

        return compute_covariance(features, dim=0)
    def compute_statistics(self, dataloader: DataLoader) -> None:
        """Extract and compute reference statistics for features and contrastive features.

        Args:
            dataloader (DataLoader): The dataloader used for extracting features. It can return
                tuples of tensors, with the first element expected to be the input tensor.

        Raises:
            ValueError: If the dataloader is empty or features have mismatched dimensions.
        """
        self.model.eval()

        feat_stack = []
        c_feat_stack = []

        with torch.no_grad():
            device = next(self.model.parameters()).device
            for sample in dataloader:
                if len(sample) < 1:
                    raise ValueError("Dataloader returned an empty batch.")

                inputs = sample[0].to(device)

                with self.__capture_hook() as features_hook:
                    _ = self.model(inputs)
                    feat = features_hook.output

                # Initialize the contrastive head if not already initialized
                if self.contrastive_head is None:
                    self.contrastive_head = self.__build_contrastive_head(feat)

                # Compute contrastive features
                contrastive_feat = self.contrastive_head(feat)

                feat_stack.append(feat.cpu())
                c_feat_stack.append(contrastive_feat.cpu())

        # compute features statistics
        feat_all = torch.cat(feat_stack)
        feat_cov = self.__covariance(feat_all)
        feat_mean = feat_all.mean(dim=0)

        self.reference_cov = feat_cov.to(device)
        self.reference_mean = feat_mean.to(device)

        # compute contrastive features statistics
        feat_all = torch.cat(c_feat_stack)
        feat_cov = self.__covariance(feat_all)
        feat_mean = feat_all.mean(dim=0)

        self.reference_c_cov = feat_cov.to(device)
        self.reference_c_mean = feat_mean.to(device)
    @contextmanager
    def __capture_hook(self):
        """Context manager to capture features via a forward hook."""

        class OutputHook:
            def __init__(self):
                self.output = None

            def hook(self, module, input, output):
                self.output = output

        features_hook = OutputHook()
        hook_handle = self.target_module.register_forward_hook(features_hook.hook)

        try:
            yield features_hook
        finally:
            hook_handle.remove()
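

# --- Illustrative sketch (not part of the engine) ---------------------------------
# `__covariance` above delegates to `torch_ttt.utils.math.compute_covariance`. The
# helper below is a minimal sketch of what such a batch covariance computation
# typically looks like; the name `_covariance_sketch` and the unbiased (n - 1)
# normalization are illustrative assumptions, not the actual torch_ttt implementation.
def _covariance_sketch(features: torch.Tensor) -> torch.Tensor:
    """Return the (d, d) sample covariance of an (n, d) feature batch."""
    n = features.shape[0]
    centered = features - features.mean(dim=0, keepdim=True)  # remove the feature-wise mean
    return centered.t() @ centered / max(n - 1, 1)  # unbiased estimate, guarded for n == 1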
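

# --- Illustrative sketch (not part of the engine) ---------------------------------
# `ttt_forward` penalizes the gap between reference and test-batch statistics with a
# CORAL-style term, ||C_ref - C_test||_F^2 / (4 d^2), plus a mean-squared gap between
# the feature means. The standalone function below mirrors that arithmetic in
# isolation; its name, signature, and the inline covariance are assumptions made for
# illustration only.
def _alignment_loss_sketch(
    reference_cov: torch.Tensor,
    reference_mean: torch.Tensor,
    features: torch.Tensor,
    scale_cov: float = 0.1,
    scale_mu: float = 0.1,
) -> torch.Tensor:
    """Feature-alignment penalty between reference statistics and a test batch."""
    n, d = features.shape
    centered = features - features.mean(dim=0, keepdim=True)
    cov_ext = centered.t() @ centered / max(n - 1, 1)  # batch covariance (assumed form)
    mu_ext = features.mean(dim=0)
    cov_term = (reference_cov - cov_ext).pow(2).sum() / (4.0 * d**2)  # scaled Frobenius gap
    mu_term = (reference_mean - mu_ext).pow(2).mean()  # mean gap
    return scale_cov * cov_term + scale_mu * mu_term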
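

# --- Illustrative sketch (not part of the engine) ---------------------------------
# The `__capture_hook` context manager relies on a standard PyTorch forward hook to
# grab the output of the layer named `features_layer_name`. The snippet below shows
# the same pattern on a toy two-layer model; the model and the hooked layer are
# hypothetical and chosen only to keep the example self-contained.
def _capture_hook_demo() -> torch.Tensor:
    model = torch.nn.Sequential(torch.nn.Linear(8, 16), torch.nn.ReLU(), torch.nn.Linear(16, 4))
    captured = {}

    def hook(module, args, output):
        captured["features"] = output  # stash the intermediate activation

    handle = model[0].register_forward_hook(hook)  # hook the first Linear layer
    try:
        _ = model(torch.randn(2, 8))
    finally:
        handle.remove()  # always detach the hook, as the engine does
    return captured["features"]  # shape (2, 16)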
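

# --- Illustrative usage sketch (not part of the engine) ----------------------------
# One common test-time training protocol: compute reference statistics on source data,
# then take a few gradient steps on the self-supervised TTT++ loss for each test batch
# before predicting. This sketch extends the docstring example; `model`, `train_loader`,
# `test_loader`, the layer name "fc1", and the number of adaptation steps are
# hypothetical placeholders, not a prescribed protocol.
def _test_time_adaptation_sketch(model, train_loader, test_loader, num_steps: int = 1):
    engine = TTTPPEngine(model, "fc1")  # "fc1" must be a Linear layer inside `model`
    engine.compute_statistics(train_loader)  # reference feature statistics

    optimizer = torch.optim.Adam(engine.parameters(), lr=1e-4)
    engine.eval()
    predictions = []
    for inputs, _ in test_loader:
        for _ in range(num_steps):  # adapt on the unsupervised TTT++ loss
            optimizer.zero_grad()
            _, loss_ttt = engine(inputs)
            loss_ttt.backward()
            optimizer.step()
        with torch.no_grad():
            outputs, _ = engine(inputs)
        predictions.append(outputs)
    return predictions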