Source code for torch_ttt.engine.ttt_engine

from contextlib import contextmanager
from typing import Tuple, Dict, Any, Optional

import torch
from torchvision.transforms import functional as F

from torch_ttt.engine.base_engine import BaseEngine
from torch_ttt.engine_registry import EngineRegistry

# Public names exported by this module (used by ``from ... import *``).
__all__ = ["TTTEngine"]


@EngineRegistry.register("ttt")
class TTTEngine(BaseEngine):
    r"""Original image rotation-based **test-time training** approach.

    Args:
        model (torch.nn.Module): Model to be trained with TTT.
        features_layer_name (str): The name of the layer from which the features are extracted.
        angle_head (torch.nn.Module, optional): The head that predicts the rotation angles.
        angle_criterion (torch.nn.Module, optional): The loss function for the rotation angles.
        optimization_parameters (dict, optional): The optimization parameters for the engine.
            A new empty dictionary is used when omitted.

    Raises:
        ValueError: If no module named :attr:`features_layer_name` exists in :attr:`model`.

    Warning:
        The module with the name :attr:`features_layer_name` should be present in the model.

    Note:
        :attr:`angle_head` and :attr:`angle_criterion` are optional arguments and can be
        user-defined. If not provided, the default shallow head and the
        :meth:`torch.nn.CrossEntropyLoss()` loss function are used.

    Note:
        When :attr:`angle_head` is not provided, the default head is created lazily during the
        first forward pass, because its input size is read from the captured features. Its
        parameters do not exist before that pass, so an optimizer constructed earlier from
        ``engine.parameters()`` will not contain them.

    Note:
        The original `TTT <https://github.com/yueatsprograms/ttt_cifar_release/blob/acac817fb7615850d19a8f8e79930240c9afe8b5/utils/rotation.py#L27>`_
        implementation uses a four-class classification task, corresponding to image rotations
        of 0°, 90°, 180°, and 270°.

    :Example:

    .. code-block:: python

        from torch_ttt.engine.ttt_engine import TTTEngine

        model = MyModel()
        engine = TTTEngine(model, "fc1")
        optimizer = torch.optim.Adam(engine.parameters(), lr=1e-4)

        # Training
        engine.train()
        for inputs, labels in train_loader:
            optimizer.zero_grad()
            outputs, loss_ttt = engine(inputs)
            loss = criterion(outputs, labels) + alpha * loss_ttt
            loss.backward()
            optimizer.step()

        # Inference
        engine.eval()
        for inputs, labels in test_loader:
            output, loss_ttt = engine(inputs)

    Reference:

        "Test-Time Training with Self-Supervision for Generalization under Distribution Shifts",
        Yu Sun, Xiaolong Wang, Zhuang Liu, John Miller, Alexei A. Efros, Moritz Hardt

        Paper link: `PDF <http://proceedings.mlr.press/v119/sun20b/sun20b.pdf>`_

    """

    def __init__(
        self,
        model: torch.nn.Module,
        features_layer_name: str,
        angle_head: Optional[torch.nn.Module] = None,
        angle_criterion: Optional[torch.nn.Module] = None,
        optimization_parameters: Optional[Dict[str, Any]] = None,
    ) -> None:
        super().__init__()
        self.model = model
        self.angle_head = angle_head
        # Compare against None explicitly so a user-supplied module is never
        # discarded just because it happens to evaluate as falsy.
        self.angle_criterion = (
            torch.nn.CrossEntropyLoss() if angle_criterion is None else angle_criterion
        )
        self.features_layer_name = features_layer_name
        # A fresh dict per instance: a `{}` default in the signature would be
        # one object shared (and mutated) across every engine.
        self.optimization_parameters = (
            {} if optimization_parameters is None else optimization_parameters
        )

        # Locate and store the reference to the target module
        self.target_module = next(
            (
                module
                for name, module in model.named_modules()
                if name == features_layer_name
            ),
            None,
        )
        if self.target_module is None:
            raise ValueError(f"Module '{features_layer_name}' not found in the model.")

    def ttt_forward(self, inputs: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Forward pass of the model.

        Args:
            inputs (torch.Tensor): Input tensor.

        Returns:
            Returns the current model prediction and rotation loss value.

        Raises:
            RuntimeError: If the layer named by ``features_layer_name`` produced no output
                during the forward pass, so there are no features to feed the angle head.
        """
        # has to dynamically register a hook to get the features and then remove it
        # need this for deepcopying the engine, see https://github.com/pytorch/pytorch/pull/103001
        with self.__capture_hook() as features_hook:
            # Original forward pass, intact
            outputs = self.model(inputs)

            # See original code: https://github.com/yueatsprograms/ttt_cifar_release/blob/acac817fb7615850d19a8f8e79930240c9afe8b5/main.py#L69
            rotated_inputs, rotation_labels = self.__rotate_inputs(inputs)
            # Only the hooked features of this second pass are needed; the
            # model output for the rotated batch is discarded.
            self.model(rotated_inputs)
            features = features_hook.output

        if features is None:
            raise RuntimeError(
                f"Module '{self.features_layer_name}' produced no output during the "
                "forward pass; cannot compute the rotation loss."
            )

        # Build angle head if not already built
        if self.angle_head is None:
            self.angle_head = self.__build_angle_head(features)

        # move angle head to the same device as the features
        self.angle_head.to(features.device)
        angles = self.angle_head(features)

        # Compute rotation loss
        rotation_loss = self.angle_criterion(angles, rotation_labels)
        return outputs, rotation_loss

    # Follow this code (expand case): https://github.com/yueatsprograms/ttt_cifar_release/blob/acac817fb7615850d19a8f8e79930240c9afe8b5/utils/rotation.py#L27
    def __rotate_inputs(self, inputs: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Rotate the input images by 0, 90, 180, and 270 degrees.

        Returns the four copies concatenated along the batch dimension (original
        first, then 90, 180, 270) together with the matching class labels
        0, 1, 2, 3, each repeated ``batch_size`` times. Both tensors are placed on
        the device of the model's first parameter.
        """
        device = next(self.model.parameters()).device
        batch_size = inputs.shape[0]

        rotated = [F.rotate(inputs, angle) for angle in (90, 180, 270)]
        stacked = torch.cat([inputs, *rotated], dim=0)

        # Labels are created directly on the target device: [0]*B + [1]*B + [2]*B + [3]*B.
        labels = torch.arange(4, dtype=torch.long, device=device).repeat_interleave(batch_size)
        return stacked.to(device), labels

    def __build_angle_head(self, features: torch.Tensor) -> torch.nn.Module:
        """Build the default angle head for the given feature tensor.

        The head type is chosen from the number of feature dimensions: 2 -> MLP,
        4 -> 2D convolutional head, 5 -> 3D convolutional head. Every head ends in
        four outputs, one per rotation class.

        Raises:
            ValueError: If ``features`` has a number of dimensions other than 2, 4 or 5.
        """
        device = next(self.model.parameters()).device
        ndim = features.dim()

        # See original implementation: https://github.com/yueatsprograms/ttt_cifar_release/blob/acac817fb7615850d19a8f8e79930240c9afe8b5/utils/test_helpers.py#L33C10-L33C39
        if ndim == 2:
            head = torch.nn.Sequential(
                torch.nn.Linear(features.shape[1], 16),
                torch.nn.ReLU(),
                torch.nn.Linear(16, 8),
                torch.nn.ReLU(),
                torch.nn.Linear(8, 4),
            )
        # See original implementation: https://github.com/yueatsprograms/ttt_cifar_release/blob/acac817fb7615850d19a8f8e79930240c9afe8b5/models/SSHead.py#L29
        elif ndim == 4:
            head = torch.nn.Sequential(
                torch.nn.Conv2d(features.shape[1], 16, 3),
                torch.nn.ReLU(),
                torch.nn.Conv2d(16, 4, 3),
                torch.nn.AdaptiveAvgPool2d((1, 1)),  # Global Average Pooling
                torch.nn.Flatten(),
            )
        elif ndim == 5:
            # For 3D inputs (batch, channels, depth, height, width)
            head = torch.nn.Sequential(
                torch.nn.Conv3d(features.shape[1], 16, kernel_size=3),
                torch.nn.ReLU(),
                torch.nn.Conv3d(16, 4, kernel_size=3),
                torch.nn.AdaptiveAvgPool3d((1, 1, 1)),  # Global Average Pooling
                torch.nn.Flatten(),
            )
        else:
            raise ValueError("Invalid input tensor shape.")

        return head.to(device)

    @contextmanager
    def __capture_hook(self):
        """Context manager to capture features via a forward hook.

        Registers a forward hook on the target module for the duration of the
        ``with`` block and yields an object whose ``output`` attribute holds the
        output of the most recent forward call of that module (``None`` until the
        module has run). The hook is always removed on exit.
        """

        class OutputHook:
            def __init__(self):
                self.output = None

            def hook(self, module, inputs, output):
                # Keep only the latest output; earlier captures are overwritten.
                self.output = output

        features_hook = OutputHook()
        hook_handle = self.target_module.register_forward_hook(features_hook.hook)

        try:
            yield features_hook
        finally:
            hook_handle.remove()