ActMADEngine

class torch_ttt.engine.actmad_engine.ActMADEngine(model: Module, features_layer_names: List[str] | str, optimization_parameters: Dict[str, Any] = {})

ActMAD approach: multi-level pixel-wise feature alignment.
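
As described in the referenced paper, alignment here means matching the first- and second-order statistics of activations at each selected layer. The objective can be sketched as follows; the notation is chosen for illustration and is not taken from this page:

```latex
\[
\mathcal{L}_{\text{ActMAD}}
  = \sum_{l \in \mathcal{S}}
    \left(
      \left\lVert \mu_l(B) - \hat{\mu}_l \right\rVert_1
      + \left\lVert \sigma_l^2(B) - \hat{\sigma}_l^2 \right\rVert_1
    \right)
\]
```

Here \(\mathcal{S}\) is the set of selected layers, \(\mu_l(B)\) and \(\sigma_l^2(B)\) are the mean and variance of the layer-\(l\) activations over the test batch \(B\), computed separately at every feature location (hence "pixel-wise"), and \(\hat{\mu}_l\), \(\hat{\sigma}_l^2\) are the reference statistics gathered on training data.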

Parameters:
  • model (torch.nn.Module) – Model to be trained with TTT.

  • features_layer_names (List[str] | str) – Name, or list of names, of the layers whose features are used for alignment.

  • optimization_parameters (dict) – The optimization parameters for the engine.
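
This page does not say how layer names are resolved. A reasonable assumption is that they match the module names reported by `named_modules()`. Under that assumption, candidate values for features_layer_names can be listed as in the sketch below; `MyModel` is a hypothetical two-layer model used only for illustration.

```python
import torch
from torch import nn


class MyModel(nn.Module):
    """Hypothetical two-layer model used only for illustration."""

    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(8, 16)
        self.fc2 = nn.Linear(16, 4)

    def forward(self, x):
        return self.fc2(torch.relu(self.fc1(x)))


model = MyModel()

# named_modules() yields (name, module) pairs; the root module has an empty name.
layer_names = [name for name, _ in model.named_modules() if name]
print(layer_names)
```

For nested models the names are dotted paths such as `backbone.layer1`, so printing them first is a cheap way to avoid typos.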

Example:

import torch
from torch_ttt.engine.actmad_engine import ActMADEngine

model = MyModel()
engine = ActMADEngine(model, ["fc1", "fc2"])
optimizer = torch.optim.Adam(engine.parameters(), lr=1e-4)

# Training (criterion is the task loss and alpha weights the TTT loss; both are user-defined)
engine.train()
for inputs, labels in train_loader:
    optimizer.zero_grad()
    outputs, loss_ttt = engine(inputs)
    loss = criterion(outputs, labels) + alpha * loss_ttt
    loss.backward()
    optimizer.step()

# Compute statistics for features alignment
engine.compute_statistics(train_loader)

# Inference
engine.eval()
for inputs, labels in test_loader:
    output, loss_ttt = engine(inputs)

Reference:

“ActMAD: Activation Matching to Align Distributions for Test-Time Training”, M. Jehanzeb Mirza, Pol Jane Soneira, Wei Lin, Mateusz Kozinski, Horst Possegger, Horst Bischof

compute_statistics(dataloader: DataLoader) -> None

Extract and compute reference statistics for features.

Parameters:

dataloader (DataLoader) – The dataloader used for extracting features. Batches may be tuples of tensors, in which case the first element is expected to be the input tensor.

Raises:

ValueError – If the dataloader is empty or features have mismatched dimensions.
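
How the engine gathers these statistics internally is not shown on this page. The sketch below illustrates one possible approach using forward hooks: it records the outputs of the named layers over a pass through the data and reduces them to a per-location mean and variance. The helper `collect_reference_stats` is hypothetical and not part of torch_ttt, and keeping every activation in memory is a simplification that suits only small datasets.

```python
import torch
from torch import nn


def collect_reference_stats(model, layer_names, batches):
    """Hypothetical helper: per-layer activation mean and variance over all samples."""
    modules = dict(model.named_modules())
    captured = {name: [] for name in layer_names}
    handles = []
    for name in layer_names:
        # Bind `name` via a default argument so each hook writes to its own list.
        def hook(_module, _inputs, output, name=name):
            captured[name].append(output.detach())

        handles.append(modules[name].register_forward_hook(hook))

    model.eval()
    with torch.no_grad():
        for batch in batches:
            # Accept a bare tensor or a tuple/list whose first element is the input.
            inputs = batch[0] if isinstance(batch, (tuple, list)) else batch
            model(inputs)

    for handle in handles:
        handle.remove()

    if any(len(feats) == 0 for feats in captured.values()):
        raise ValueError("No features were collected; is the dataloader empty?")

    stats = {}
    for name, feats in captured.items():
        feats = torch.cat(feats, dim=0)  # concatenate all samples along the batch axis
        stats[name] = (feats.mean(dim=0), feats.var(dim=0, unbiased=False))
    return stats


# Usage with a toy model; nn.Sequential names its children "0", "1", "2".
model = nn.Sequential(nn.Linear(8, 16), nn.ReLU(), nn.Linear(16, 4))
batches = [(torch.randn(5, 8), torch.zeros(5)) for _ in range(3)]
stats = collect_reference_stats(model, ["0", "2"], batches)
mean, var = stats["0"]
```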

ttt_forward(inputs) -> Tuple[Tensor, Tensor]

Forward pass of the model, which also computes the feature alignment loss.

Parameters:

inputs (torch.Tensor) – Input tensor.

Returns:

The current model prediction and the feature alignment loss value.
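
To make the second return value concrete, the sketch below computes an alignment loss from a batch of features and stored reference statistics. It assumes an L1 gap averaged over feature locations and summed over layers; the function `alignment_loss` and the dictionary layout are hypothetical and may differ from what the engine does internally.

```python
import torch


def alignment_loss(features, ref_stats):
    """Sum over layers of mean L1 gaps between batch and reference mean/variance."""
    loss = torch.zeros(())
    for name, feats in features.items():
        ref_mean, ref_var = ref_stats[name]
        batch_mean = feats.mean(dim=0)
        batch_var = feats.var(dim=0, unbiased=False)
        loss = loss + (batch_mean - ref_mean).abs().mean()
        loss = loss + (batch_var - ref_var).abs().mean()
    return loss


# One layer, two samples, two feature locations.
feats = {"fc1": torch.tensor([[1.0, 2.0], [3.0, 4.0]])}
matching = {"fc1": (torch.tensor([2.0, 3.0]), torch.tensor([1.0, 1.0]))}
shifted = {"fc1": (torch.tensor([0.0, 0.0]), torch.tensor([1.0, 1.0]))}

loss_matching = alignment_loss(feats, matching).item()  # statistics agree
loss_shifted = alignment_loss(feats, shifted).item()    # means differ
```

The loss is zero when the batch statistics equal the reference and grows as the test features drift away from them, which is the signal used to adapt the model at test time.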