Source code for torch_ttt.engine.tent_engine

from typing import Any, Dict, Optional, Tuple

import torch

from torch_ttt.engine.base_engine import BaseEngine
from torch_ttt.engine_registry import EngineRegistry

__all__ = ["TentEngine"]

@EngineRegistry.register("tent")
class TentEngine(BaseEngine):
    """**TENT**: Fully test-time adaptation by entropy minimization.

    On construction the wrapped model is switched to training mode and every
    parameter is frozen except those owned directly by ``BatchNorm2d`` layers.
    :meth:`ttt_forward` returns the model outputs together with the mean
    prediction entropy, which serves as the adaptation loss.

    Args:
        model (torch.nn.Module): The model to adapt. It is modified in place
            (training mode, ``requires_grad`` flags, BatchNorm settings).
        optimization_parameters (dict, optional): Hyperparameters for
            adaptation. When omitted, each engine gets its own empty dict.

    Reference:
        "TENT: Fully Test-Time Adaptation by Entropy Minimization"
        Dequan Wang, Evan Shelhamer, et al.
    """

    def __init__(
        self,
        model: torch.nn.Module,
        optimization_parameters: Optional[Dict[str, Any]] = None,
    ):
        super().__init__()
        self.model = model
        # A literal ``{}`` default would be a single dict shared by every
        # engine created without this argument; allocate one per instance.
        # A dict supplied by the caller is stored as-is (not copied).
        self.optimization_parameters = (
            {} if optimization_parameters is None else optimization_parameters
        )

        # Tent adapts only affine parameters in BatchNorm
        self.model.train()
        self._configure_bn()

    def _configure_bn(self):
        """Leave only ``BatchNorm2d`` parameters trainable.

        Every module of the wrapped model is visited. ``BatchNorm2d`` layers
        get gradients enabled and ``track_running_stats`` switched off; for
        any other module, the parameters it owns directly are frozen.
        """
        for module in self.model.modules():
            if isinstance(module, torch.nn.BatchNorm2d):
                module.requires_grad_(True)
                # NOTE(review): flipping this flag after construction appears
                # to leave the existing running_mean/running_var buffers in
                # place, so batch statistics would only be used while the
                # layer stays in training mode -- confirm the model is never
                # switched to eval during adaptation.
                module.track_running_stats = False
            else:
                # recurse=False: child modules are visited by the outer loop
                # themselves, so only directly owned parameters are touched.
                for param in module.parameters(recurse=False):
                    param.requires_grad = False

    def ttt_forward(self, inputs) -> Tuple[torch.Tensor, torch.Tensor]:
        """Forward pass and entropy loss computation.

        Args:
            inputs: Batch passed unchanged to the wrapped model.

        Returns:
            Tuple[torch.Tensor, torch.Tensor]: The raw model outputs and the
            entropy of their softmax over ``dim=1``, averaged over all
            remaining dimensions into a scalar.
        """
        # presumably ``outputs`` are (batch, classes) logits -- verify
        # against the wrapped model; the code only fixes dim=1 as the
        # distribution axis.
        outputs = self.model(inputs)
        probs = torch.nn.functional.softmax(outputs, dim=1)
        log_probs = torch.nn.functional.log_softmax(outputs, dim=1)
        entropy = -torch.sum(probs * log_probs, dim=1).mean()
        return outputs, entropy