TTTEngine
- class torch_ttt.engine.ttt_engine.TTTEngine(model: Module, features_layer_name: str, angle_head: Module | None = None, angle_criterion: Module | None = None, optimization_parameters: Dict[str, Any] = {})
Original image rotation-based test-time training approach.
- Parameters:
model (torch.nn.Module) – Model to be trained with TTT.
features_layer_name (str) – The name of the layer from which the features are extracted.
angle_head (torch.nn.Module, optional) – The head that predicts the rotation angles.
angle_criterion (torch.nn.Module, optional) – The loss function for the rotation angles.
optimization_parameters (dict) – The optimization parameters for the engine.
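Because features_layer_name has to match a module that exists in the model, it can help to list the available names before constructing the engine. The sketch below uses PyTorch's named_modules() for this. TinyNet is a hypothetical model made up for illustration, and whether TTTEngine resolves names in exactly this way is an assumption, not something this page states.

```python
import torch.nn as nn


class TinyNet(nn.Module):
    """Hypothetical two-layer model, used only to illustrate name lookup."""

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 8, kernel_size=3, padding=1)
        self.fc1 = nn.Linear(8, 10)

    def forward(self, x):
        x = self.conv1(x).mean(dim=(2, 3))  # global average pool -> (N, 8)
        return self.fc1(x)


model = TinyNet()

# named_modules() yields (name, module) pairs; the root module has the
# empty name "", so it is filtered out here.
layer_names = [name for name, _ in model.named_modules() if name]
print(layer_names)  # ['conv1', 'fc1']

# Fail early if the chosen name is not a module of the model.
features_layer_name = "fc1"
if features_layer_name not in layer_names:
    raise ValueError(f"{features_layer_name!r} not found; choose from {layer_names}")
```

For nested models, named_modules() returns dotted names such as "backbone.layer1"; check which form the engine expects before relying on them.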
Warning
The module with the name features_layer_name should be present in the model.
Note
angle_head and angle_criterion are optional arguments and can be user-defined. If not provided, the default shallow head and the torch.nn.CrossEntropyLoss() loss function are used.
Note
The original TTT implementation uses a four-class classification task, corresponding to image rotations of 0°, 90°, 180°, and 270°.
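The four-class rotation task described above can be sketched with torch.rot90. This is a minimal illustration of the pretext task, not the library's internal implementation: make_rotation_batch is a hypothetical helper, the images are assumed to be square (so all rotated copies share one shape), and the random logits stand in for the output of an angle head.

```python
import torch


def make_rotation_batch(images: torch.Tensor):
    """Build the rotation pretext batch (hypothetical helper).

    images: tensor of shape (N, C, H, W) with H == W.
    Returns rotated copies of shape (4N, C, H, W) and labels of shape (4N,),
    where label k means a rotation of k * 90 degrees.
    """
    n = images.shape[0]
    # Rotate the whole batch in the spatial plane for k = 0, 1, 2, 3.
    rotated = [torch.rot90(images, k, dims=(2, 3)) for k in range(4)]
    # Labels follow the concatenation order: N zeros, N ones, N twos, N threes.
    labels = torch.arange(4).repeat_interleave(n)
    return torch.cat(rotated, dim=0), labels


images = torch.randn(2, 3, 8, 8)
rotated, labels = make_rotation_batch(images)

# Stand-in for angle-head predictions: one logit per rotation class.
head_logits = torch.randn(rotated.shape[0], 4)
loss = torch.nn.CrossEntropyLoss()(head_logits, labels)
```

Because the labels come from the rotation itself, this loss needs no ground-truth annotations, which is what makes it usable on unlabeled test inputs.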
- Example:
    import torch
    from torch_ttt.engine.ttt_engine import TTTEngine

    # MyModel, criterion, alpha, train_loader, and test_loader are user-defined.
    model = MyModel()
    engine = TTTEngine(model, "fc1")
    optimizer = torch.optim.Adam(engine.parameters(), lr=1e-4)

    # Training: combine the supervised loss with the weighted TTT loss.
    engine.train()
    for inputs, labels in train_loader:
        optimizer.zero_grad()
        outputs, loss_ttt = engine(inputs)
        loss = criterion(outputs, labels) + alpha * loss_ttt
        loss.backward()
        optimizer.step()

    # Inference
    engine.eval()
    for inputs, labels in test_loader:
        output, loss_ttt = engine(inputs)
Reference:
“Test-Time Training with Self-Supervision for Generalization under Distribution Shifts”, Yu Sun, Xiaolong Wang, Zhuang Liu, John Miller, Alexei A. Efros, Moritz Hardt