ActMADEngine

class torch_ttt.engine.actmad_engine.ActMADEngine(model: Module, features_layer_names: List[str] | str, optimization_parameters: Dict[str, Any] = {})

ActMAD approach: multi-level pixel-wise feature alignment.

ActMAD adapts models at test time by aligning the activation statistics (means and variances) of test inputs with reference statistics computed on clean training data, across multiple layers of the network. It requires no labels or auxiliary tasks and is applicable to any architecture and task.
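Written out, the alignment objective looks roughly like the following. This is a sketch under stated assumptions: statistics are taken per position ("pixel-wise") over the B samples of a test batch, the distance is L1, $\mathcal{S}$ is the set of layers named in features_layer_names, $a_l$ is the activation of layer $l$, and the hatted quantities are the reference statistics from compute_statistics. The exact reduction used by torch_ttt may differ.

```latex
\mu_l(x;\theta) = \frac{1}{B}\sum_{b=1}^{B} a_l(x_b;\theta), \qquad
\sigma^2_l(x;\theta) = \frac{1}{B}\sum_{b=1}^{B} \big(a_l(x_b;\theta) - \mu_l(x;\theta)\big)^2

\mathcal{L}(\theta) = \sum_{l \in \mathcal{S}} \Big( \big\lVert \mu_l(x;\theta) - \hat{\mu}_l \big\rVert_1 + \big\lVert \sigma^2_l(x;\theta) - \hat{\sigma}^2_l \big\rVert_1 \Big)
```

Both statistics are elementwise over the feature map, so every channel and spatial location is matched separately rather than being pooled into a single mean per channel.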

Parameters:
  • model (torch.nn.Module) – Model to be trained with TTT.

  • features_layer_names (List[str] | str) – A single layer name or a list of layer names whose activations are used for feature alignment.

  • optimization_parameters (dict) – Optimization parameters used by the engine; defaults to an empty dict.
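Layer names are presumably the qualified names reported by `model.named_modules()`, e.g. `"fc1"`, or `"block.0"` for the first module inside an `nn.Sequential` attribute. The sketch below assumes that convention and shows how a single name or a list of names could be resolved to submodules. `resolve_layers` and `TinyNet` are hypothetical names for illustration; this is not the torch_ttt implementation.

```python
from typing import Dict, List, Union

import torch.nn as nn


def resolve_layers(
    model: nn.Module, features_layer_names: Union[List[str], str]
) -> Dict[str, nn.Module]:
    """Hypothetical helper: map layer names to submodules of `model`."""
    # Accept a single name as well as a list of names.
    if isinstance(features_layer_names, str):
        names = [features_layer_names]
    else:
        names = list(features_layer_names)
    modules = dict(model.named_modules())  # qualified name -> submodule
    missing = [n for n in names if n not in modules]
    if missing:
        raise KeyError(f"unknown layer names: {missing}")
    return {n: modules[n] for n in names}


class TinyNet(nn.Module):  # hypothetical model used only for this example
    def __init__(self) -> None:
        super().__init__()
        self.fc1 = nn.Linear(4, 8)
        self.fc2 = nn.Linear(8, 2)
        self.block = nn.Sequential(nn.Linear(2, 2), nn.ReLU())

    def forward(self, x):
        return self.block(self.fc2(self.fc1(x).relu()))


layers = resolve_layers(TinyNet(), ["fc1", "block.0"])
print(list(layers))  # ['fc1', 'block.0']
```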

Example:

import torch
from torch_ttt.engine.actmad_engine import ActMADEngine

# MyModel, train_loader, test_loader, criterion and alpha are placeholders for
# your own model, data loaders, task loss and weight of the TTT loss.
model = MyModel()
engine = ActMADEngine(model, ["fc1", "fc2"])
optimizer = torch.optim.Adam(engine.parameters(), lr=1e-4)

# Training: task loss plus weighted alignment loss
engine.train()
for inputs, labels in train_loader:
    optimizer.zero_grad()
    outputs, loss_ttt = engine(inputs)
    loss = criterion(outputs, labels) + alpha * loss_ttt
    loss.backward()
    optimizer.step()

# Compute reference statistics for feature alignment on clean training data
engine.compute_statistics(train_loader)

# Inference: labels are not needed at test time
engine.eval()
for inputs, _ in test_loader:
    outputs, loss_ttt = engine(inputs)

Reference:

“ActMAD: Activation Matching to Align Distributions for Test-Time Training”, M. Jehanzeb Mirza, Pol Jane Soneira, Wei Lin, Mateusz Kozinski, Horst Possegger, Horst Bischof

Paper link: PDF

compute_statistics(dataloader: DataLoader) → None

Extracts features from the given dataloader and computes the reference statistics (means and variances) of the selected layers.

Parameters:

dataloader (DataLoader) – The dataloader used for extracting features. Each batch may be a tuple of tensors, in which case the first element is taken as the input tensor.

Raises:

ValueError – If the dataloader is empty or features have mismatched dimensions.
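The statistics pass can be sketched in plain Python as a single sweep that accumulates per-position sums and sums of squares, taking the first element of tuple batches and raising ValueError in the two documented cases. `compute_reference_stats` and its `extract` argument (a stand-in for running the model up to the selected layers) are hypothetical; the real method works on tensors and may reduce differently.

```python
from typing import Callable, Iterable, List, Optional, Tuple


def compute_reference_stats(
    dataloader: Iterable,
    extract: Callable[[object], List[List[float]]],
) -> Tuple[List[float], List[float]]:
    """Per-position mean and population variance of features over a dataset.

    `extract` is a hypothetical feature extractor: it maps a batch of inputs
    to a list of flattened feature vectors, one per sample.
    """
    count = 0
    sums: Optional[List[float]] = None
    sq_sums: Optional[List[float]] = None
    for batch in dataloader:
        # Batches may be tuples such as (inputs, labels); keep the first element.
        inputs = batch[0] if isinstance(batch, tuple) else batch
        for row in extract(inputs):
            if sums is None or sq_sums is None:
                sums = [0.0] * len(row)
                sq_sums = [0.0] * len(row)
            if len(row) != len(sums):
                raise ValueError("features have mismatched dimensions")
            for j, value in enumerate(row):
                sums[j] += value
                sq_sums[j] += value * value
            count += 1
    if sums is None or sq_sums is None:
        raise ValueError("dataloader is empty")
    means = [s / count for s in sums]
    # Var[x] = E[x^2] - E[x]^2; simpler but less numerically stable than Welford.
    variances = [q / count - m * m for q, m in zip(sq_sums, means)]
    return means, variances


loader = [([[1.0, 2.0], [3.0, 4.0]], [0, 1])]  # one (inputs, labels) batch
means, variances = compute_reference_stats(loader, extract=lambda x: x)
print(means, variances)  # [2.0, 3.0] [1.0, 1.0]
```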

ttt_forward(inputs) → Tuple[Tensor, Tensor]

Runs a forward pass of the model and computes the alignment loss.

Parameters:

inputs (torch.Tensor) – Input tensor.

Returns:

The current model prediction and the alignment loss based on activation statistics.
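A minimal PyTorch sketch of what a forward pass of this kind could compute, assuming forward hooks capture the activations of the named layers, statistics are taken per position over the batch dimension, and the loss is an L1 distance with mean reduction. `actmad_forward` and the layout of the reference-statistics dictionary are hypothetical; this is not the torch_ttt implementation.

```python
from typing import Dict, List, Tuple

import torch
import torch.nn as nn


def actmad_forward(
    model: nn.Module,
    layer_names: List[str],
    ref_stats: Dict[str, Tuple[torch.Tensor, torch.Tensor]],
    inputs: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Hypothetical: return (outputs, alignment loss) for one test batch."""
    modules = dict(model.named_modules())
    captured: Dict[str, torch.Tensor] = {}
    # Record each named layer's output; `name=name` binds the loop variable.
    handles = [
        modules[name].register_forward_hook(
            lambda _mod, _inp, out, name=name: captured.__setitem__(name, out)
        )
        for name in layer_names
    ]
    try:
        outputs = model(inputs)
    finally:
        for handle in handles:
            handle.remove()  # do not leave hooks attached to the model
    loss = torch.zeros((), device=inputs.device)
    for name in layer_names:
        feats = captured[name]
        mean = feats.mean(dim=0)  # per-position mean over the batch
        var = feats.var(dim=0, unbiased=False)  # per-position population variance
        ref_mean, ref_var = ref_stats[name]
        loss = loss + (mean - ref_mean).abs().mean() + (var - ref_var).abs().mean()
    return outputs, loss


torch.manual_seed(0)
model = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 2))
clean = torch.randn(32, 4)
with torch.no_grad():  # reference statistics from "clean" data
    h0 = model[0](clean)
    h2 = model[2](model[1](h0))
ref = {
    "0": (h0.mean(0), h0.var(0, unbiased=False)),
    "2": (h2.mean(0), h2.var(0, unbiased=False)),
}

# One adaptation step on a shifted batch.
shifted = clean * 2.0 + 1.0
optimizer = torch.optim.SGD(model.parameters(), lr=1e-2)
outputs, loss = actmad_forward(model, ["0", "2"], ref, shifted)
optimizer.zero_grad()
loss.backward()
optimizer.step()
```

Minimizing the returned loss with respect to the model parameters pulls the statistics of the test batch toward the clean reference; on the clean batch itself the loss is zero.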