TTTPPEngine
- class torch_ttt.engine.ttt_pp_engine.TTTPPEngine(model: Module, features_layer_name: str, contrastive_head: Module | None = None, contrastive_criterion: Module = ContrastiveLoss(), contrastive_transform: Callable | None = None, scale_cov: float = 0.1, scale_mu: float = 0.1, scale_c_cov: float = 0.1, scale_c_mu: float = 0.1, optimization_parameters: Dict[str, Any] = {})
TTT++ approach: feature alignment combined with a SimCLR contrastive loss.
- Parameters:
model (torch.nn.Module) – Model to be trained with TTT.
features_layer_name (str) – The name of the layer from which the features are extracted.
contrastive_head (torch.nn.Module, optional) – The head that is used for SimCLR’s Loss.
contrastive_criterion (torch.nn.Module, optional) – The loss function used for SimCLR.
contrastive_transform (callable, optional) – A transformation or a composition of transformations applied to the input images to generate augmented views for contrastive learning.
scale_cov (float) – The scale factor for the covariance loss.
scale_mu (float) – The scale factor for the mean loss.
scale_c_cov (float) – The scale factor for the contrastive covariance loss.
scale_c_mu (float) – The scale factor for the contrastive mean loss.
optimization_parameters (dict) – The optimization parameters for the engine.
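To illustrate what the mean and covariance scale factors weight, here is a minimal sketch of a feature alignment loss. It assumes the loss is a weighted sum of squared distances between batch statistics and reference statistics; the helper name `alignment_loss` and the exact form of the loss are assumptions for illustration, not the engine's actual implementation.

```python
import torch


def alignment_loss(feats, ref_mu, ref_cov, scale_mu=0.1, scale_cov=0.1):
    """Hypothetical helper: penalize the gap between batch and reference statistics."""
    mu = feats.mean(dim=0)
    centered = feats - mu
    # Unbiased sample covariance of the batch features, shape (D, D).
    cov = centered.T @ centered / (feats.shape[0] - 1)
    loss_mu = torch.sum((mu - ref_mu) ** 2)      # squared distance between means
    loss_cov = torch.sum((cov - ref_cov) ** 2)   # squared Frobenius distance
    return scale_mu * loss_mu + scale_cov * loss_cov


torch.manual_seed(0)
feats = torch.randn(32, 8)
# Reference statistics taken from the same batch, so the loss is close to zero.
ref_mu = feats.mean(dim=0)
ref_cov = torch.cov(feats.T)
loss_same = alignment_loss(feats, ref_mu, ref_cov)
# Shifting every feature by 1 changes only the mean term.
loss_shifted = alignment_loss(feats + 1.0, ref_mu, ref_cov)
```

In this sketch, a constant shift of the features leaves the covariance term untouched, so only `scale_mu` affects the resulting loss.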
Warning
The module with the name features_layer_name should be present in the model.
- Example:
import torch
from torch_ttt.engine.ttt_pp_engine import TTTPPEngine

model = MyModel()
engine = TTTPPEngine(model, "fc1")
optimizer = torch.optim.Adam(engine.parameters(), lr=1e-4)

# Training
engine.train()
for inputs, labels in train_loader:
    optimizer.zero_grad()
    outputs, loss_ttt = engine(inputs)
    loss = criterion(outputs, labels) + alpha * loss_ttt
    loss.backward()
    optimizer.step()

# Compute statistics for features alignment
engine.compute_statistics(train_loader)

# Inference
engine.eval()
for inputs, labels in test_loader:
    output, loss_ttt = engine(inputs)
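Because the engine expects features_layer_name to match a module in the model, it can help to check the name before constructing the engine. The sketch below uses PyTorch's `named_modules()`; the helper `has_layer` and the toy `MyModel` are hypothetical and only stand in for the user's own model.

```python
import torch
import torch.nn as nn


def has_layer(model: nn.Module, layer_name: str) -> bool:
    """Return True if `layer_name` is a named submodule of `model`."""
    return layer_name in dict(model.named_modules())


class MyModel(nn.Module):  # hypothetical stand-in for the user's model
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(4, 8)
        self.fc2 = nn.Linear(8, 2)

    def forward(self, x):
        return self.fc2(torch.relu(self.fc1(x)))


model = MyModel()
found = has_layer(model, "fc1")    # "fc1" is defined above
missing = has_layer(model, "fc3")  # no such submodule
```

For nested models, `named_modules()` reports dotted names such as `encoder.layer1`, so the same check applies to submodules at any depth.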
Reference:
“TTT++: When Does Self-Supervised Test-Time Training Fail or Thrive?”, Yuejiang Liu, Parth Kothari, Bastien van Delft, Baptiste Bellot-Gurlet, Taylor Mordan, Alexandre Alahi
Paper link: PDF
- compute_statistics(dataloader: DataLoader) → None
Extract and compute reference statistics for features and contrastive features.
- Parameters:
dataloader (DataLoader) – The dataloader used for extracting features. Batches may be tuples of tensors, in which case the first element is expected to be the input tensor.
- Raises:
ValueError – If the dataloader is empty or features have mismatched dimensions.
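As a rough illustration of what computing reference statistics can involve, the sketch below collects one layer's outputs with a forward hook and returns their mean and covariance. The helper `reference_statistics` is hypothetical and is not the engine's code; it assumes that batches are tensors or tuples whose first element is the input, and it raises ValueError when no batch is produced.

```python
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset


def reference_statistics(model, layer_name, dataloader):
    """Hypothetical helper: mean and covariance of one layer's outputs."""
    captured = []
    layer = dict(model.named_modules())[layer_name]
    # A forward hook records the layer output on every forward pass.
    handle = layer.register_forward_hook(
        lambda module, inputs, output: captured.append(output.detach())
    )
    model.eval()
    with torch.no_grad():
        for batch in dataloader:
            # Batches may be tuples or lists; the first element is the input.
            inputs = batch[0] if isinstance(batch, (tuple, list)) else batch
            model(inputs)
    handle.remove()
    if not captured:
        raise ValueError("Dataloader is empty; no features were extracted.")
    feats = torch.cat(captured).flatten(start_dim=1)
    return feats.mean(dim=0), torch.cov(feats.T)


torch.manual_seed(0)
model = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 2))
dataset = TensorDataset(torch.randn(64, 4), torch.zeros(64))
loader = DataLoader(dataset, batch_size=16)
# In nn.Sequential, submodules are named by index, so "0" is the first Linear.
mu, cov = reference_statistics(model, "0", loader)
```

Here `mu` has one entry per feature and `cov` is a square matrix of the same dimension. Note that the sketch leaves the model in evaluation mode after it returns.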