ActMADEngine
- class torch_ttt.engine.actmad_engine.ActMADEngine(model: Module, features_layer_names: List[str] | str, optimization_parameters: Dict[str, Any] = {})
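ActMAD approach: multi-level, pixel-wise feature alignment. Statistics of intermediate activations from several layers are aligned, at every feature position, between test data and the reference (training) data.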
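The one-line summary above can be made concrete with a small sketch. It assumes the formulation from the ActMAD paper: for each chosen layer, the mean and variance of the activations are taken across the batch separately at each feature position ("pixel-wise"), and the loss sums the L1 distances to the stored reference statistics over all layers. It is framework-free Python, and the helper names (`pixelwise_stats`, `actmad_loss`) are hypothetical; torch_ttt's own code may use a different reduction or variance estimator.

```python
from typing import List, Sequence, Tuple

# One layer's activations for a batch: N samples, each flattened to D positions.
Batch = List[List[float]]
Stats = Tuple[List[float], List[float]]


def pixelwise_stats(batch: Batch) -> Stats:
    """Mean and (biased) variance across the batch at every feature position."""
    n = len(batch)
    if n == 0:
        raise ValueError("empty batch")
    dim = len(batch[0])
    means = [sum(sample[d] for sample in batch) / n for d in range(dim)]
    variances = [
        sum((sample[d] - means[d]) ** 2 for sample in batch) / n for d in range(dim)
    ]
    return means, variances


def l1(a: Sequence[float], b: Sequence[float]) -> float:
    """Sum of absolute differences between two equally long vectors."""
    return sum(abs(x - y) for x, y in zip(a, b))


def actmad_loss(test_layers: List[Batch], ref_stats: List[Stats]) -> float:
    """Assumed ActMAD objective: L1 mean gap + L1 variance gap, summed over layers."""
    if len(test_layers) != len(ref_stats):
        raise ValueError("expected one reference entry per layer")
    loss = 0.0
    for batch, (ref_mean, ref_var) in zip(test_layers, ref_stats):
        mean, var = pixelwise_stats(batch)
        loss += l1(mean, ref_mean) + l1(var, ref_var)
    return loss


if __name__ == "__main__":
    reference = pixelwise_stats([[0.0, 0.0], [2.0, 2.0]])  # means [1, 1], variances [1, 1]
    shifted_test_batch = [[1.0, 3.0], [3.0, 5.0]]  # means [2, 4], variances [1, 1]
    print(actmad_loss([shifted_test_batch], [reference]))  # 4.0
```

Only the means differ in the usage example, so the whole loss comes from the mean term: |2 - 1| + |4 - 1| = 4.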
- Parameters:
model (torch.nn.Module) – Model to be trained with TTT.
features_layer_names (List[str] | str) – Name of a single layer, or a list of layer names, whose activations are used for feature alignment.
optimization_parameters (dict) – Optimization parameters for the engine; defaults to an empty dict.
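Because features_layer_names accepts either one name or a list, an implementation has to normalize it before attaching to layers. The sketch below is a hypothetical helper (`normalize_layer_names` is not part of torch_ttt) that turns a single string into a one-element list and rejects names the model does not contain. With PyTorch, `available` would typically be the keys of `dict(model.named_modules())`.

```python
from typing import Iterable, List, Union


def normalize_layer_names(names: Union[List[str], str], available: Iterable[str]) -> List[str]:
    """Hypothetical helper: accept one name or a list, and validate every name."""
    # A bare string is treated as a single layer name, not as a sequence of characters.
    name_list = [names] if isinstance(names, str) else list(names)
    if not name_list:
        raise ValueError("features_layer_names must not be empty")
    known = set(available)
    missing = [name for name in name_list if name not in known]
    if missing:
        raise ValueError(f"unknown layer names: {missing}")
    return name_list


if __name__ == "__main__":
    module_names = ["", "fc1", "fc2"]  # "" is the root module in named_modules()
    print(normalize_layer_names("fc1", module_names))  # ['fc1']
    print(normalize_layer_names(["fc1", "fc2"], module_names))  # ['fc1', 'fc2']
```

Checking the names up front fails fast on a typo instead of silently aligning no features.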
- Example:
```python
import torch

from torch_ttt.engine.actmad_engine import ActMADEngine

# MyModel, train_loader, test_loader, criterion and alpha are placeholders
# that you define for your own task.
model = MyModel()
engine = ActMADEngine(model, ["fc1", "fc2"])
optimizer = torch.optim.Adam(engine.parameters(), lr=1e-4)

# Training: the TTT loss is added to the task loss with weight alpha
engine.train()
for inputs, labels in train_loader:
    optimizer.zero_grad()
    outputs, loss_ttt = engine(inputs)
    loss = criterion(outputs, labels) + alpha * loss_ttt
    loss.backward()
    optimizer.step()

# Compute reference statistics for feature alignment
engine.compute_statistics(train_loader)

# Inference
engine.eval()
for inputs, labels in test_loader:
    outputs, loss_ttt = engine(inputs)
```
Reference:
“ActMAD: Activation Matching to Align Distributions for Test-Time Training”, M. Jehanzeb Mirza, Pol Jane Soneira, Wei Lin, Mateusz Kozinski, Horst Possegger, Horst Bischof
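- compute_statistics(dataloader: DataLoader) → None
Run the model over the dataloader, extract activations from the layers named in features_layer_names, and compute the reference statistics against which test-time features are aligned.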
- Parameters:
dataloader (DataLoader) – The dataloader used to extract features. Batches may be tuples of tensors, in which case the first element is taken as the input tensor.
- Raises:
ValueError – If the dataloader is empty or features have mismatched dimensions.
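The docstring does not spell out which statistics are stored. The following is a minimal, framework-free sketch assuming per-position means and variances (as in ActMAD), accumulated in one pass with running sums. `compute_reference_statistics` and the `extract_features` callback are hypothetical names, not torch_ttt's API. The sketch mirrors the documented contract: tuple batches contribute their first element as the input, and an empty loader or mismatched feature dimensions raise ValueError.

```python
from typing import Any, Callable, Iterable, List, Tuple

Stats = Tuple[List[float], List[float]]  # (per-position means, per-position variances)


def compute_reference_statistics(
    dataloader: Iterable[Any],
    extract_features: Callable[[Any], List[List[List[float]]]],
) -> List[Stats]:
    """Accumulate per-position mean and variance of each layer over a whole loader.

    extract_features(inputs) is a hypothetical callback that returns, for each
    hooked layer, a batch of flattened activations (N samples x D positions).
    """
    sums: List[List[float]] = []
    sq_sums: List[List[float]] = []
    count = 0
    for batch in dataloader:
        # Documented behavior: for tuple batches, the first element is the input.
        inputs = batch[0] if isinstance(batch, tuple) else batch
        layers = extract_features(inputs)
        if not layers:
            raise ValueError("no feature layers were extracted")
        if not sums:
            sums = [[0.0] * len(layer[0]) for layer in layers]
            sq_sums = [[0.0] * len(layer[0]) for layer in layers]
        if len(layers) != len(sums):
            raise ValueError("number of feature layers changed between batches")
        for idx, layer in enumerate(layers):
            for sample in layer:
                if len(sample) != len(sums[idx]):
                    raise ValueError(f"mismatched feature dimension in layer {idx}")
                for d, value in enumerate(sample):
                    sums[idx][d] += value
                    sq_sums[idx][d] += value * value
        count += len(layers[0])  # assumes every layer sees the same batch size
    if count == 0:
        raise ValueError("dataloader is empty")
    stats: List[Stats] = []
    for layer_sums, layer_sq in zip(sums, sq_sums):
        means = [s / count for s in layer_sums]
        # Var = E[x^2] - E[x]^2, clamped at zero to absorb rounding error.
        variances = [max(sq / count - m * m, 0.0) for sq, m in zip(layer_sq, means)]
        stats.append((means, variances))
    return stats


if __name__ == "__main__":
    loader = [([[0.0, 0.0]], "label"), ([[2.0, 2.0]], "label")]
    print(compute_reference_statistics(loader, lambda x: [x]))  # [([1.0, 1.0], [1.0, 1.0])]
```

Running sums keep memory constant regardless of dataset size, at the cost of some numerical precision for large activations; a two-pass or Welford-style update would be more robust if that matters.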