ActMADEngine
- class torch_ttt.engine.actmad_engine.ActMADEngine(model: Module, features_layer_names: List[str] | str, optimization_parameters: Dict[str, Any] = {})
ActMAD approach: multi-level pixel-wise feature alignment.
ActMAD adapts models at test-time by aligning activation statistics (means and variances) of the test inputs to those from clean training data, across multiple layers of the network. It requires no labels or auxiliary tasks, and is applicable to any architecture and task.
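The alignment objective can be illustrated in plain PyTorch. This is a minimal sketch, not the torch_ttt implementation. It makes two assumptions. First, following the "pixel-wise" description, statistics are taken over the batch dimension only, so every channel and spatial position keeps its own mean and variance. Second, per-layer discrepancies are mean absolute (L1) differences summed across layers. The function name `actmad_alignment_loss` is hypothetical.

```python
from typing import List

import torch
import torch.nn.functional as F


def actmad_alignment_loss(
    test_features: List[torch.Tensor],
    ref_means: List[torch.Tensor],
    ref_vars: List[torch.Tensor],
) -> torch.Tensor:
    """Hypothetical sketch of a multi-level, pixel-wise ActMAD-style loss.

    Each tensor in ``test_features`` is one layer's activation for a test
    batch, shaped (N, ...). Mean and variance are computed over the batch
    dimension only, so each channel/spatial position is aligned separately.
    """
    # Start from a zero scalar on the same device/dtype as the features.
    loss = test_features[0].new_zeros(())
    for feat, mu_ref, var_ref in zip(test_features, ref_means, ref_vars):
        mu = feat.mean(dim=0)
        var = feat.var(dim=0, unbiased=False)
        # Mean absolute (L1) difference between test and reference statistics.
        loss = loss + F.l1_loss(mu, mu_ref) + F.l1_loss(var, var_ref)
    return loss
```

Minimizing this quantity with respect to the model parameters pulls the test-time activation statistics toward the reference ones without using any labels.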
- Parameters:
model (torch.nn.Module) – Model to be trained with TTT.
features_layer_names (List[str] | str) – Name, or list of names, of the layers whose activations are used for feature alignment.
optimization_parameters (dict) – The optimization parameters for the engine.
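`features_layer_names` refers to submodules by name. One plausible way to resolve such names and capture their activations is `Module.named_modules()` combined with forward hooks. The sketch below assumes that approach; `FeatureRecorder` is a hypothetical helper and is not part of torch_ttt.

```python
from typing import Dict, List, Union

import torch
import torch.nn as nn


class FeatureRecorder:
    """Hypothetical helper: records the outputs of named submodules."""

    def __init__(self, model: nn.Module, layer_names: Union[List[str], str]):
        # Accept a single name or a list of names, mirroring the documented type.
        if isinstance(layer_names, str):
            layer_names = [layer_names]
        modules = dict(model.named_modules())
        missing = [name for name in layer_names if name not in modules]
        if missing:
            raise ValueError(f"Unknown layer names: {missing}")
        self.layer_names = layer_names
        self.features: Dict[str, torch.Tensor] = {}
        # One forward hook per requested layer; each stores the latest output.
        self.handles = [
            modules[name].register_forward_hook(self._make_hook(name))
            for name in layer_names
        ]

    def _make_hook(self, name: str):
        def hook(module, inputs, output):
            self.features[name] = output
        return hook

    def remove(self) -> None:
        # Detach all hooks so the model runs without recording afterwards.
        for handle in self.handles:
            handle.remove()
```

Names follow the dotted paths returned by `named_modules()`, so for an `nn.Sequential` they are `"0"`, `"1"`, and so on, while for a custom module they match attribute names such as `"fc1"`.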
- Example:
```python
import torch

from torch_ttt.engine.actmad_engine import ActMADEngine

# MyModel, train_loader, test_loader, criterion and alpha are user-defined placeholders.
model = MyModel()
engine = ActMADEngine(model, ["fc1", "fc2"])
optimizer = torch.optim.Adam(engine.parameters(), lr=1e-4)

# Training
engine.train()
for inputs, labels in train_loader:
    optimizer.zero_grad()
    outputs, loss_ttt = engine(inputs)
    loss = criterion(outputs, labels) + alpha * loss_ttt
    loss.backward()
    optimizer.step()

# Compute statistics for feature alignment
engine.compute_statistics(train_loader)

# Inference
engine.eval()
for inputs, labels in test_loader:
    outputs, loss_ttt = engine(inputs)
```
Reference:
“ActMAD: Activation Matching to Align Distributions for Test-Time Training”, M. Jehanzeb Mirza, Pol Jane Soneira, Wei Lin, Mateusz Kozinski, Horst Possegger, Horst Bischof
- compute_statistics(dataloader: DataLoader) → None
Extracts features from the given data and computes the reference statistics (per-layer means and variances) that test-time activations are aligned to.
- Parameters:
dataloader (DataLoader) – The dataloader used for extracting features. It may yield tuples of tensors, in which case the first element is expected to be the input tensor.
- Raises:
ValueError – If the dataloader is empty or features have mismatched dimensions.
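The behaviour documented for `compute_statistics` can be sketched as a single pass that accumulates per-position sums and sums of squares. This sketch rests on several assumptions. It treats a tuple or list batch as having the input first. It raises `ValueError` on an empty loader or on features whose shapes change between batches. It computes population variance over all samples. The function `compute_reference_statistics` and its `extract_features` callback are hypothetical. The documented method returns `None`, so it presumably stores its statistics on the engine instead of returning them.

```python
from typing import Callable, List, Tuple

import torch
from torch.utils.data import DataLoader, TensorDataset


@torch.no_grad()
def compute_reference_statistics(
    dataloader: DataLoader,
    extract_features: Callable[[torch.Tensor], List[torch.Tensor]],
) -> Tuple[List[torch.Tensor], List[torch.Tensor]]:
    """Hypothetical sketch: per-layer, per-position mean and variance."""
    sums, sq_sums = None, None
    count = 0
    for batch in dataloader:
        # Batches may be tuples/lists such as (inputs, labels); use the first element.
        inputs = batch[0] if isinstance(batch, (tuple, list)) else batch
        # Accumulate in float64 to limit round-off in E[x^2] - E[x]^2.
        feats = [f.double() for f in extract_features(inputs)]
        if sums is None:
            sums = [f.sum(dim=0) for f in feats]
            sq_sums = [(f * f).sum(dim=0) for f in feats]
        else:
            if [f.shape[1:] for f in feats] != [s.shape for s in sums]:
                raise ValueError("Features have mismatched dimensions across batches.")
            for i, f in enumerate(feats):
                sums[i] += f.sum(dim=0)
                sq_sums[i] += (f * f).sum(dim=0)
        # Assumes every feature keeps the batch dimension of the inputs.
        count += inputs.shape[0]
    if sums is None:
        raise ValueError("The dataloader is empty.")
    means = [s / count for s in sums]
    # Population variance; clamp removes tiny negative values from round-off.
    variances = [(sq / count - m * m).clamp_min(0.0) for sq, m in zip(sq_sums, means)]
    return means, variances
```

Accumulating sums rather than storing every activation keeps memory constant in the size of the dataset, at the cost of the slightly less stable E[x²] − E[x]² variance formula.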