# Source code for torch_ttt.engine.tent_engine

from typing import Any, Dict, Optional, Tuple

import torch

from torch_ttt.engine.base_engine import BaseEngine
from torch_ttt.engine_registry import EngineRegistry

__all__ = ["TentEngine"]

@EngineRegistry.register("tent")
class TentEngine(BaseEngine):
    """**TENT**: Fully test-time adaptation by entropy minimization.

    TENT adapts models at inference by minimizing prediction entropy,
    encouraging confident outputs on unlabeled data. It updates only
    BatchNorm affine parameters and requires no labels or training
    supervision.

    Args:
        model (torch.nn.Module): Model to be adapted at test-time.
        optimization_parameters (dict, optional): Optimizer configuration
            for adaptation (e.g. learning rate). Defaults to an empty dict.

    :Example:

    .. code-block:: python

        from torch_ttt.engine.tent_engine import TentEngine

        model = MyModel()
        engine = TentEngine(model, {"lr": 1e-3})
        optimizer = torch.optim.Adam(engine.parameters(), lr=1e-3)

        # Training
        engine.train()
        for inputs, labels in train_loader:
            optimizer.zero_grad()
            outputs, loss_ttt = engine(inputs)
            loss = criterion(outputs, labels) + alpha * loss_ttt
            loss.backward()
            optimizer.step()

        # Inference
        engine.eval()
        for inputs, labels in test_loader:
            output, loss_ttt = engine(inputs)

    Reference:

        "Tent: Fully Test-Time Adaptation by Entropy Minimization",
        Dequan Wang, Evan Shelhamer, Shaoteng Liu, Bruno Olshausen, Trevor Darrell

        Paper link: `PDF <https://arxiv.org/pdf/2006.10726.pdf>`_
    """

    def __init__(
        self,
        model: torch.nn.Module,
        optimization_parameters: Optional[Dict[str, Any]] = None,
    ):
        super().__init__()
        self.model = model
        # Guard against the shared-mutable-default pitfall: the previous
        # `= {}` default was one dict object shared by every instance.
        self.optimization_parameters = (
            optimization_parameters if optimization_parameters is not None else {}
        )

        # Tent adapts only affine parameters in BatchNorm; the model is kept
        # in train mode so BatchNorm uses current-batch statistics.
        self.model.train()
        self._configure_bn()

    def _configure_bn(self):
        """Freeze everything except BatchNorm2d affine parameters.

        BatchNorm2d layers keep gradients enabled and stop tracking running
        statistics (so normalization always uses the test batch), per the
        TENT recipe. All other modules' own parameters are frozen.

        NOTE(review): only ``BatchNorm2d`` is adapted here; 1d/3d BatchNorm
        layers, if present, end up frozen — confirm this is intended.
        """
        for module in self.model.modules():
            if isinstance(module, torch.nn.BatchNorm2d):
                module.requires_grad_(True)
                module.track_running_stats = False
            else:
                # recurse=False: only this module's direct parameters, so a
                # parent container does not re-freeze its BatchNorm children.
                for param in module.parameters(recurse=False):
                    param.requires_grad = False
[docs] def ttt_forward(self, inputs) -> Tuple[torch.Tensor, torch.Tensor]: """Forward pass of the model. Args: inputs (torch.Tensor): Input tensor. Returns: The current model prediction and the entropy loss value. """ outputs = self.model(inputs) probs = torch.nn.functional.softmax(outputs, dim=1) log_probs = torch.nn.functional.log_softmax(outputs, dim=1) entropy = -torch.sum(probs * log_probs, dim=1).mean() return outputs, entropy