TTTEngine

class torch_ttt.engine.ttt_engine.TTTEngine(model: Module, features_layer_name: str, angle_head: Module | None = None, angle_criterion: Module | None = None, optimization_parameters: Dict[str, Any] = {})

The original test-time training (TTT) approach, which uses image rotation prediction as its self-supervised auxiliary task.

Parameters:
  • model (torch.nn.Module) – Model to be trained with TTT.

  • features_layer_name (str) – The name of the layer from which the features are extracted.

  • angle_head (torch.nn.Module, optional) – The head that predicts the rotation angles.

  • angle_criterion (torch.nn.Module, optional) – The loss function for the rotation angles.

  • optimization_parameters (dict, optional) – The optimization parameters for the engine. Defaults to an empty dict.

Warning

The model must contain a module whose name matches features_layer_name.
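One way to check this requirement before constructing the engine is to look the name up among the model's named submodules. The sketch below assumes the engine resolves features_layer_name the same way PyTorch's named_modules() reports names, which this page does not confirm; TinyNet and has_layer are hypothetical names used only for illustration.

```python
import torch.nn as nn


class TinyNet(nn.Module):
    """Hypothetical model used only to illustrate the layer-name lookup."""

    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(8, 16)
        self.fc2 = nn.Linear(16, 4)

    def forward(self, x):
        return self.fc2(self.fc1(x).relu())


def has_layer(model: nn.Module, layer_name: str) -> bool:
    """Return True if `layer_name` is the name of a submodule of `model`."""
    # named_modules() yields (name, module) pairs, including nested
    # submodules under dotted names such as "block.0.conv".
    return layer_name in dict(model.named_modules())


model = TinyNet()
fc1_present = has_layer(model, "fc1")          # "fc1" is defined above
encoder_present = has_layer(model, "encoder")  # no such submodule exists
```

Nested layers appear under dotted names in named_modules(), so printing the keys of that dictionary is a quick way to find a candidate value for features_layer_name.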

Note

angle_head and angle_criterion are optional arguments and can be user-defined. If not provided, the default shallow head and the torch.nn.CrossEntropyLoss() loss function are used.
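A user-defined head and criterion could look like the sketch below. It is an illustration under assumptions, not the library's default head: AngleHead is a hypothetical name, and the feature size of 16 and hidden width of 64 are placeholders, since the shape produced by the chosen features layer depends on the model.

```python
import torch
import torch.nn as nn


class AngleHead(nn.Module):
    """Hypothetical rotation head: flattens features and predicts 4 angle classes."""

    def __init__(self, in_features: int, num_angles: int = 4):
        super().__init__()
        self.net = nn.Sequential(
            nn.Flatten(),                 # accept (N, C, H, W) or (N, D) features
            nn.Linear(in_features, 64),   # hidden width 64 is an arbitrary choice
            nn.ReLU(),
            nn.Linear(64, num_angles),    # one logit per rotation class
        )

    def forward(self, features: torch.Tensor) -> torch.Tensor:
        return self.net(features)


# Assumed feature size of 16; adjust to the output of your features layer.
angle_head = AngleHead(in_features=16)
# Any module mapping (logits, class labels) to a scalar loss fits this role.
angle_criterion = nn.CrossEntropyLoss(label_smoothing=0.1)

# Stand-alone check with random features and rotation labels.
features = torch.randn(8, 16)
angle_labels = torch.randint(0, 4, (8,))
logits = angle_head(features)
loss = angle_criterion(logits, angle_labels)
```

According to the signature above, these objects would be passed to the constructor as the angle_head and angle_criterion keyword arguments.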

Note

The original TTT implementation uses a four-class classification task, corresponding to image rotations of 0°, 90°, 180°, and 270°.
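The four-class rotation task can be sketched with torch.rot90. This is a minimal illustration of how rotated inputs and their class labels might be built, not necessarily how the engine does it internally; make_rotation_batch is a hypothetical helper, and it assumes square images in (N, C, H, W) layout so that all rotations share one shape.

```python
import torch


def make_rotation_batch(images: torch.Tensor):
    """Build the 4-way rotation task for a batch of square images.

    `images` has shape (N, C, H, W) with H == W. Returns a tensor of
    shape (4N, C, H, W) and integer labels of shape (4N,), where label k
    means the image was rotated by k * 90 degrees.
    """
    # Rotate in the spatial plane (dims 2 and 3) by k quarter turns.
    rotated = [torch.rot90(images, k=k, dims=(2, 3)) for k in range(4)]
    # The first N outputs are unrotated (label 0), the next N are label 1, etc.
    labels = torch.arange(4).repeat_interleave(images.shape[0])
    return torch.cat(rotated, dim=0), labels


images = torch.randn(2, 3, 4, 4)
rotated_images, angle_labels = make_rotation_batch(images)
```

The labels returned here are class indices, which is the target format torch.nn.CrossEntropyLoss expects.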

Example:

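import torch
from torch_ttt.engine.ttt_engine import TTTEngine

model = MyModel()  # any torch.nn.Module with a submodule named "fc1"
engine = TTTEngine(model, "fc1")
optimizer = torch.optim.Adam(engine.parameters(), lr=1e-4)
criterion = torch.nn.CrossEntropyLoss()  # supervised loss (assumed; choose one for your task)
alpha = 1.0  # weight of the TTT loss (assumed value)

# Training: optimize the supervised loss together with the rotation loss
engine.train()
for inputs, labels in train_loader:
    optimizer.zero_grad()
    outputs, loss_ttt = engine(inputs)
    loss = criterion(outputs, labels) + alpha * loss_ttt
    loss.backward()
    optimizer.step()

# Inference
engine.eval()
for inputs, labels in test_loader:
    outputs, loss_ttt = engine(inputs)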

Reference:

“Test-Time Training with Self-Supervision for Generalization under Distribution Shifts”, Yu Sun, Xiaolong Wang, Zhuang Liu, John Miller, Alexei A. Efros, Moritz Hardt

Paper link: PDF

ttt_forward(inputs) → Tuple[Tensor, Tensor]

Forward pass of the model.

Parameters:

inputs (torch.Tensor) – Input tensor.

Returns:

A tuple containing the current model prediction and the rotation loss value.