TTTPPEngine

class torch_ttt.engine.ttt_pp_engine.TTTPPEngine(model: Module, features_layer_name: str, contrastive_head: Module | None = None, contrastive_criterion: Module = ContrastiveLoss(), contrastive_transform: Callable | None = None, scale_cov: float = 0.1, scale_mu: float = 0.1, scale_c_cov: float = 0.1, scale_c_mu: float = 0.1, optimization_parameters: Dict[str, Any] = {})

TTT++ approach: combines a feature-alignment loss with a SimCLR contrastive loss.

Parameters:
  • model (torch.nn.Module) – Model to be trained with TTT.

  • features_layer_name (str) – The name of the layer from which the features are extracted.

  • contrastive_head (torch.nn.Module, optional) – The head used to compute the SimCLR loss.

  • contrastive_criterion (torch.nn.Module, optional) – The loss function used for SimCLR.

  • contrastive_transform (callable, optional) – A transformation, or a composition of transformations, applied to the input images to generate augmented views for contrastive learning.

  • scale_cov (float) – The scale factor for the covariance loss.

  • scale_mu (float) – The scale factor for the mean loss.

  • scale_c_cov (float) – The scale factor for the contrastive covariance loss.

  • scale_c_mu (float) – The scale factor for the contrastive mean loss.

  • optimization_parameters (dict) – The optimization parameters for the engine.
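
As a rough illustration of how the scale factors above might enter the objective, the sketch below combines a mean-alignment term and a covariance-alignment term into a single weighted loss. It follows the general feature-alignment idea of TTT++ and is an assumption about the form of the loss, not the library's actual implementation; `alignment_loss` and its arguments are hypothetical names.

```python
import torch


def alignment_loss(feats, ref_mu, ref_cov, scale_mu=0.1, scale_cov=0.1):
    """Hypothetical helper: weighted distance between batch and reference statistics."""
    mu = feats.mean(dim=0)                      # batch mean, shape (D,)
    cov = torch.cov(feats.T)                    # batch covariance, shape (D, D)
    loss_mu = torch.sum((mu - ref_mu) ** 2)     # squared L2 distance between means
    loss_cov = torch.sum((cov - ref_cov) ** 2)  # squared Frobenius distance
    return scale_mu * loss_mu + scale_cov * loss_cov


torch.manual_seed(0)
feats = torch.randn(64, 8)      # stand-in for features from the chosen layer
ref_mu = feats.mean(dim=0)      # stand-in reference statistics
ref_cov = torch.cov(feats.T)

# Identical statistics give a zero loss; shifting the features moves the mean only.
loss_same = alignment_loss(feats, ref_mu, ref_cov)
loss_shifted = alignment_loss(feats + 1.0, ref_mu, ref_cov)
```

Under the same assumption, an analogous term computed on the contrastive-head features would be weighted by `scale_c_mu` and `scale_c_cov`.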

Warning

The model must contain a module named features_layer_name.
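
One way to check a layer name before constructing the engine is to look it up among the model's registered submodules. The sketch below uses only `torch.nn.Module.named_modules()`; `MyModel` and `has_layer` are hypothetical, and whether the engine resolves names in exactly this way is an assumption.

```python
import torch
import torch.nn as nn


class MyModel(nn.Module):
    """Hypothetical two-layer model used only for illustration."""

    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(16, 8)
        self.fc2 = nn.Linear(8, 4)

    def forward(self, x):
        return self.fc2(torch.relu(self.fc1(x)))


def has_layer(model: nn.Module, name: str) -> bool:
    """Return True if `name` is a registered submodule of `model`."""
    return name in dict(model.named_modules())


model = MyModel()
# Skip the empty name, which refers to the root module itself.
available = [name for name, _ in model.named_modules() if name]
```

Nested modules appear with dotted names (for example `block.0.conv`), so listing `available` first can help when picking a value for `features_layer_name`.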

Example:

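import torch
from torch_ttt.engine.ttt_pp_engine import TTTPPEngine

# MyModel, train_loader, and test_loader are placeholders for your own model and data.
model = MyModel()
engine = TTTPPEngine(model, "fc1")  # "fc1" must name a module inside the model
optimizer = torch.optim.Adam(engine.parameters(), lr=1e-4)
criterion = torch.nn.CrossEntropyLoss()  # placeholder supervised loss
alpha = 0.1  # placeholder weight for the TTT++ loss

# Training: optimize the supervised loss plus the weighted TTT++ loss
engine.train()
for inputs, labels in train_loader:
    optimizer.zero_grad()
    outputs, loss_ttt = engine(inputs)
    loss = criterion(outputs, labels) + alpha * loss_ttt
    loss.backward()
    optimizer.step()

# Compute reference statistics for feature alignment
engine.compute_statistics(train_loader)

# Inference
engine.eval()
for inputs, labels in test_loader:
    outputs, loss_ttt = engine(inputs)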

Reference:

“TTT++: When Does Self-Supervised Test-Time Training Fail or Thrive?”, Yuejiang Liu, Parth Kothari, Bastien van Delft, Baptiste Bellot-Gurlet, Taylor Mordan, Alexandre Alahi

Paper link: PDF

compute_statistics(dataloader: DataLoader) → None

Extract and compute reference statistics for features and contrastive features.

Parameters:

dataloader (DataLoader) – The dataloader used for extracting features. Batches may be tuples of tensors, in which case the first element is expected to be the input tensor.

Raises:

ValueError – If the dataloader is empty or features have mismatched dimensions.
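
To make the idea of reference statistics concrete, the sketch below collects one layer's outputs with a forward hook and reduces them to a mean vector and a covariance matrix. It is a minimal illustration under stated assumptions, not the library's implementation: `reference_statistics` is a hypothetical helper, and the toy model and random data are stand-ins.

```python
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset


def reference_statistics(model, layer_name, dataloader):
    """Hypothetical helper: mean and covariance of one layer's outputs."""
    layer = dict(model.named_modules())[layer_name]
    collected = []
    # The hook stores each batch of layer outputs, flattened to (batch, D).
    handle = layer.register_forward_hook(
        lambda module, inputs, output: collected.append(output.detach().flatten(1))
    )
    model.eval()  # note: this switches the model to evaluation mode
    with torch.no_grad():
        for batch in dataloader:
            # Accept (inputs, labels) batches as well as bare input tensors.
            inputs = batch[0] if isinstance(batch, (list, tuple)) else batch
            model(inputs)
    handle.remove()
    if not collected:
        raise ValueError("dataloader is empty")
    feats = torch.cat(collected, dim=0)
    return feats.mean(dim=0), torch.cov(feats.T)


torch.manual_seed(0)
model = nn.Sequential(nn.Linear(16, 8), nn.ReLU(), nn.Linear(8, 4))
dataset = TensorDataset(torch.randn(32, 16), torch.zeros(32))
loader = DataLoader(dataset, batch_size=8)

# "0" is the name nn.Sequential gives to its first submodule.
mu, cov = reference_statistics(model, "0", loader)
```

Holding every feature batch in memory keeps the sketch short; for large datasets, accumulating running sums would be a more economical design.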

ttt_forward(inputs) → Tuple[Tensor, Tensor]

Forward pass of the model.

Parameters:

inputs (torch.Tensor) – Input tensor.

Returns:

A tuple containing the model prediction and the TTT++ loss value.