EataEngine

class torch_ttt.engine.eata_engine.EataEngine(model: Module, optimization_parameters: Dict[str, Any] | None = None, e_margin: float = 2.4538776394910684, d_margin: float = 0.05, fisher_alpha: float = 2000.0, fishers: Dict[str, Tuple[Tensor, Tensor]] | None = None)

EATA: Efficient Test-Time Adaptation without Forgetting.

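EATA improves test-time robustness by minimizing prediction entropy only on reliable (low-entropy) and non-redundant test samples, while a Fisher-weighted regularizer preserves knowledge from the source domain by penalizing changes to important parameters. Only the affine parameters of BatchNorm layers are updated.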
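For reference, the objective described in the EATA paper can be written as below. The notation is this note's own: p(x) is the softmax prediction, m a moving average of past predictions, θ̃ the adapted BatchNorm affine parameters, θ^o their source values and ω_i their diagonal Fisher values. E_0, ε and β correspond to e_margin, d_margin and fisher_alpha. Averaging over the retained samples B' (rather than summing) follows the EATA reference code and is an assumption about how this engine reduces the loss; there, the weight 1/exp(E(x) − E_0) is also treated as a constant during backpropagation.

```latex
p(x) = \operatorname{softmax}\big(f_\Theta(x)\big), \qquad
E(x) = -\sum_{c=1}^{C} p_c(x)\,\log p_c(x)

B' = \big\{\, x \in B \;:\; E(x) < E_0,\ \cos\big(p(x), m\big) < \epsilon \,\big\}

\mathcal{L}(\tilde\theta) = \frac{1}{|B'|} \sum_{x \in B'} \frac{E(x)}{\exp\big(E(x) - E_0\big)}
  + \beta \sum_{i} \omega_i \big(\tilde\theta_i - \theta_i^{o}\big)^2
```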

Parameters:
  • model (torch.nn.Module) – Model to be adapted at test-time.

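  • optimization_parameters (dict, optional) – Optimizer configuration, e.g. {"lr": 1e-3}.

  • e_margin (float) – Entropy threshold: samples whose prediction entropy is not below this value are treated as unreliable and excluded from the adaptation loss.

  • d_margin (float) – Cosine-similarity threshold: samples whose predicted class distribution is too similar to the running average of recent predictions are treated as redundant and excluded.

  • fisher_alpha (float) – Weight of the Fisher-based anti-forgetting regularizer.

  • fishers (dict, optional) – Precomputed diagonal Fisher information keyed by parameter name; each value is a pair of tensors (in EATA, the Fisher estimate and the parameter's source value). It can also be filled in later with compute_fishers().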
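The default e_margin numerically equals ln(1000)/2 − 1 ≈ 2.454, where ln(1000) is the maximum entropy of a 1000-class (ImageNet-sized) prediction. For a model with a different number of classes, one common convention (an assumption here, not a documented rule of this engine) is to set the threshold as a fraction of ln C. The helper entropy_margin below is hypothetical and only illustrates that arithmetic.

```python
import math


def entropy_margin(num_classes: int, fraction: float) -> float:
    """Hypothetical helper: a fraction of the maximum entropy ln(num_classes)."""
    # A uniform prediction over num_classes classes has entropy ln(num_classes).
    return fraction * math.log(num_classes)


# The default e_margin in the EataEngine signature equals ln(1000) / 2 - 1.
default_e_margin = math.log(1000) / 2 - 1
print(default_e_margin)

# Illustrative threshold for a 10-class model (the fraction 0.4 is an assumption).
ten_class_margin = entropy_margin(10, 0.4)
print(ten_class_margin)
```

The resulting value could then be passed as EataEngine(model, e_margin=ten_class_margin).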

Example:

import torch
from torch_ttt.engine.eata_engine import EataEngine

# MyModel, train_loader and test_loader are user-defined placeholders.
model = MyModel()
engine = EataEngine(model, {"lr": 1e-3})
optimizer = torch.optim.SGD(engine.parameters(), lr=1e-3)
criterion = torch.nn.CrossEntropyLoss()
alpha = 1.0  # weight of the test-time loss (illustrative value)

# Training
engine.train()
for inputs, labels in train_loader:
    optimizer.zero_grad()
    outputs, loss_ttt = engine(inputs)
    loss = criterion(outputs, labels) + alpha * loss_ttt
    loss.backward()
    optimizer.step()

# Estimate the diagonal Fisher information on source data
engine.compute_fishers(train_loader)

# Inference
engine.eval()
for inputs, _ in test_loader:
    outputs, loss_ttt = engine(inputs)

Reference:

“Efficient Test-Time Model Adaptation without Forgetting”, Shuaicheng Niu, Jiaxiang Wu, Yifan Zhang, Yaofo Chen, Shijian Zheng, Peilin Zhao, Mingkui Tan. ICML 2022.


collect_bn_params()

Collects affine parameters (scale and shift) from all BatchNorm2d layers.

Returns:

  • List of parameters (weight and bias) to adapt during test-time training.

  • Corresponding parameter names for tracking or regularization.

Return type:

Tuple[List[torch.nn.Parameter], List[str]]
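A minimal sketch of what collecting BatchNorm2d affine parameters involves, written against plain PyTorch. collect_bn_affine_params is a hypothetical stand-in, not torch_ttt's implementation, and the "module_name.param_name" naming scheme is an assumption chosen to match named_parameters().

```python
from typing import List, Tuple

import torch.nn as nn


def collect_bn_affine_params(model: nn.Module) -> Tuple[List[nn.Parameter], List[str]]:
    """Sketch: gather the scale (weight) and shift (bias) of every BatchNorm2d layer."""
    params: List[nn.Parameter] = []
    names: List[str] = []
    for module_name, module in model.named_modules():
        if isinstance(module, nn.BatchNorm2d):
            # recurse=False yields only this layer's own weight and bias.
            for param_name, param in module.named_parameters(recurse=False):
                if param_name in ("weight", "bias"):
                    params.append(param)
                    names.append(f"{module_name}.{param_name}")
    return params, names


model = nn.Sequential(nn.Conv2d(3, 8, 3), nn.BatchNorm2d(8), nn.ReLU())
params, names = collect_bn_affine_params(model)
print(names)  # ['1.weight', '1.bias']
```

The returned list can be handed straight to an optimizer, for example torch.optim.SGD(params, lr=1e-3), so that only the BatchNorm scale and shift are updated.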

compute_fishers(data_loader: DataLoader, loss_fn: Module | None = None, lr: float = 0.001)

Estimates the diagonal Fisher information on source data and stores it in self.fishers.

Parameters:
  • data_loader – DataLoader over source data.

  • loss_fn – Loss to use for computing Fisher (default: CrossEntropy).

  • lr – Learning rate of the placeholder optimizer that is used only to collect the parameters.
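A minimal sketch of a diagonal Fisher estimate, assuming the common "empirical Fisher" recipe: average the squared gradients of the supervised loss over batches, then store each estimate together with a frozen copy of the parameter, matching the Dict[str, Tuple[Tensor, Tensor]] layout of fishers. estimate_diag_fisher and the toy model are hypothetical; torch_ttt's own routine may differ in details such as how targets are chosen.

```python
from typing import Dict, Iterable, Optional, Tuple

import torch
import torch.nn as nn


def estimate_diag_fisher(
    model: nn.Module,
    params: Dict[str, nn.Parameter],
    batches: Iterable[Tuple[torch.Tensor, torch.Tensor]],
    loss_fn: Optional[nn.Module] = None,
) -> Dict[str, Tuple[torch.Tensor, torch.Tensor]]:
    """Empirical diagonal Fisher: squared gradients of the loss, averaged over batches."""
    if loss_fn is None:
        loss_fn = nn.CrossEntropyLoss()  # mirrors the documented default
    sq_grad_sums = {name: torch.zeros_like(p) for name, p in params.items()}
    num_batches = 0
    for inputs, labels in batches:
        model.zero_grad()
        loss = loss_fn(model(inputs), labels)
        loss.backward()
        for name, p in params.items():
            if p.grad is not None:
                sq_grad_sums[name] += p.grad.detach() ** 2
        num_batches += 1
    model.zero_grad()  # leave no stale gradients behind
    # Pair each Fisher diagonal with a frozen copy of the parameter's source value.
    return {
        name: (sq_grad_sums[name] / max(num_batches, 1), p.detach().clone())
        for name, p in params.items()
    }


torch.manual_seed(0)
model = nn.Sequential(nn.Linear(4, 8), nn.BatchNorm1d(8), nn.ReLU(), nn.Linear(8, 3))
bn_params = {n: p for n, p in model.named_parameters() if n.startswith("1.")}
batches = [(torch.randn(16, 4), torch.randint(0, 3, (16,))) for _ in range(5)]
fishers = estimate_diag_fisher(model, bn_params, batches)
```

Storing the source value next to each Fisher estimate lets the regularizer later penalize drift without keeping a full copy of the original model.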

ttt_forward(inputs) → Tuple[Tensor, Tensor]

Runs the forward pass and computes the test-time adaptation loss.

Parameters:

inputs (torch.Tensor) – Input batch.

Returns:

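The model's prediction for the batch, and the adaptation loss: the entropy of the retained (reliable and non-redundant) samples, reweighted by confidence, plus the Fisher regularization term.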
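The adaptation loss can be sketched as follows. This follows the EATA paper and its reference code rather than torch_ttt's source: eata_loss_sketch, the running-average argument prob_avg and the 0.9 momentum are assumptions for illustration, and params stands for the BatchNorm affine parameters being adapted.

```python
from typing import Dict, Optional, Tuple

import torch
import torch.nn.functional as F


def eata_loss_sketch(
    logits: torch.Tensor,
    params: Dict[str, torch.Tensor],
    fishers: Dict[str, Tuple[torch.Tensor, torch.Tensor]],
    prob_avg: Optional[torch.Tensor],
    e_margin: float,
    d_margin: float,
    fisher_alpha: float,
    momentum: float = 0.9,
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
    """Filtered, reweighted entropy plus Fisher penalty; also returns the updated average."""
    probs = logits.softmax(dim=1)
    entropy = -(probs * logits.log_softmax(dim=1)).sum(dim=1)

    # 1) Reliability: keep only low-entropy (confident) samples.
    keep = entropy < e_margin
    # 2) Redundancy: drop samples too similar to the running average prediction.
    #    Probabilities are non-negative, so the cosine similarity lies in [0, 1].
    if prob_avg is not None:
        cos = F.cosine_similarity(probs, prob_avg.unsqueeze(0), dim=1)
        keep = keep & (cos < d_margin)

    selected = entropy[keep]
    if selected.numel() > 0:
        # 3) Weight 1 / exp(E - E0): lower entropy gives a larger, constant weight.
        weights = torch.exp(e_margin - selected.detach())
        loss = (weights * selected).mean()
    else:
        loss = logits.sum() * 0.0  # nothing retained: zero loss that keeps the graph

    # 4) Anti-forgetting: Fisher-weighted squared drift from the source parameters.
    for name, p in params.items():
        fisher, source = fishers[name]
        loss = loss + fisher_alpha * (fisher * (p - source) ** 2).sum()

    # Exponential moving average of the retained predictions.
    if keep.any():
        batch_mean = probs[keep].detach().mean(dim=0)
        if prob_avg is None:
            prob_avg = batch_mean
        else:
            prob_avg = momentum * prob_avg + (1 - momentum) * batch_mean
    return loss, prob_avg
```

Feeding the returned prob_avg back in on the next batch is what makes the redundancy filter act across batches: a prediction identical to the running average has cosine similarity 1 and is skipped.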