EataEngine
- class torch_ttt.engine.eata_engine.EataEngine(model: Module, optimization_parameters: Dict[str, Any] | None = None, e_margin: float = 2.4538776394910684, d_margin: float = 0.05, fisher_alpha: float = 2000.0, fishers: Dict[str, Tuple[Tensor, Tensor]] | None = None)
EATA: Efficient Test-Time Adaptation without Forgetting.
EATA improves test-time robustness by minimizing entropy on confident, diverse predictions, while preserving prior knowledge via Fisher regularization. Only BatchNorm affine parameters are updated.
- Parameters:
model (torch.nn.Module) – Model to be adapted at test-time.
optimization_parameters (dict, optional) – Optimizer configuration.
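e_margin (float) – Entropy threshold: predictions whose entropy exceeds this value are treated as uncertain and filtered out. The default equals ln(1000)/2 − 1, which corresponds to a 1000-class setting.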
d_margin (float) – Cosine similarity threshold to filter redundant samples.
fisher_alpha (float) – Weight for Fisher-based regularization.
fishers (dict, optional) – Precomputed Fisher information for regularization.
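To make the roles of e_margin and d_margin concrete, here is a minimal sketch of the two filters in plain PyTorch. It is not taken from the torch_ttt source: the helper names softmax_entropy and select_samples, and the mean_probs running average, are hypothetical and only illustrate the idea of keeping confident, non-redundant samples.

```python
import math

import torch
import torch.nn.functional as F


def softmax_entropy(logits: torch.Tensor) -> torch.Tensor:
    """Per-sample Shannon entropy of the softmax distribution."""
    log_probs = logits.log_softmax(dim=1)
    return -(log_probs.exp() * log_probs).sum(dim=1)


def select_samples(logits, e_margin, d_margin, mean_probs=None):
    """Return (mask, entropy); mask marks samples that pass both filters.

    mean_probs is a hypothetical running average of earlier softmax
    outputs. When it is None, the redundancy filter is skipped.
    """
    entropy = softmax_entropy(logits)
    keep = entropy < e_margin  # drop uncertain (high-entropy) predictions
    if mean_probs is not None:
        probs = logits.softmax(dim=1)
        cos = F.cosine_similarity(mean_probs.unsqueeze(0), probs, dim=1)
        keep &= cos.abs() < d_margin  # drop predictions close to the average
    return keep, entropy


num_classes = 1000
e_margin = math.log(num_classes) / 2 - 1  # matches the documented default

logits = torch.zeros(2, num_classes)
logits[0, 3] = 20.0  # sample 0: confident prediction for class 3
# sample 1 keeps all-zero logits, i.e. a uniform, maximally uncertain prediction

mask, entropy = select_samples(logits, e_margin, d_margin=0.05)
```

In this toy batch only the confident sample passes the entropy filter. Passing a mean_probs vector additionally rejects samples whose softmax output points in nearly the same direction as the average of what has already been seen.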
- Example:
    import torch
    from torch_ttt.engine.eata_engine import EataEngine

    model = MyModel()
    engine = EataEngine(model, {"lr": 1e-3})
    optimizer = torch.optim.SGD(engine.parameters(), lr=1e-3)

    # Training
    engine.train()
    for inputs, labels in train_loader:
        optimizer.zero_grad()
        outputs, loss_ttt = engine(inputs)
        loss = criterion(outputs, labels) + alpha * loss_ttt
        loss.backward()
        optimizer.step()

    # Compute diagonal Fisher Information
    engine.compute_fishers(train_loader)

    # Inference
    engine.eval()
    for inputs, labels in test_loader:
        output, loss_ttt = engine(inputs)
Reference:
“Efficient Test-Time Model Adaptation without Forgetting”, Shuaicheng Niu, Jiaxiang Wu, Yifan Zhang, Yaofo Chen, Shijian Zheng, Peilin Zhao, Mingkui Tan. ICML 2022.
Paper link: PDF
- collect_bn_params()
Collects affine parameters (scale and shift) from all BatchNorm2d layers.
- Returns:
List of parameters (weight and bias) to adapt during test-time training.
Corresponding parameter names for tracking or regularization.
- Return type:
Tuple[List[torch.nn.Parameter], List[str]]
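As an illustration of what such a collection step can look like, the sketch below walks a model and gathers the weight and bias of every BatchNorm2d layer together with their dotted names. The function collect_bn_affine is a hypothetical stand-in written for this example, not the library's own code, and the dotted naming scheme is an assumption.

```python
from typing import List, Tuple

import torch.nn as nn


def collect_bn_affine(model: nn.Module) -> Tuple[List[nn.Parameter], List[str]]:
    """Gather BatchNorm2d scale/shift parameters and their dotted names."""
    params, names = [], []
    for module_name, module in model.named_modules():
        if isinstance(module, nn.BatchNorm2d):
            # recurse=False: only parameters owned directly by this layer.
            for param_name, param in module.named_parameters(recurse=False):
                if param_name in ("weight", "bias"):
                    params.append(param)
                    names.append(f"{module_name}.{param_name}")
    return params, names


model = nn.Sequential(
    nn.Conv2d(3, 8, kernel_size=3),
    nn.BatchNorm2d(8),
    nn.ReLU(),
    nn.Conv2d(8, 4, kernel_size=3),
    nn.BatchNorm2d(4, affine=False),  # no affine parameters, so it is skipped
)
params, names = collect_bn_affine(model)
print(names)  # ['1.weight', '1.bias']
```

Layers built with affine=False contribute nothing, since they have no scale or shift to adapt.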
- compute_fishers(data_loader: DataLoader, loss_fn: Module | None = None, lr: float = 0.001)
Estimate the diagonal Fisher Information and store it in self.fishers.
- Parameters:
data_loader – DataLoader over source data.
loss_fn – Loss to use for computing Fisher (default: CrossEntropy).
lr – Learning rate for dummy optimizer (used to collect parameters).
logger – Optional logger to print progress.
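For intuition, a diagonal Fisher estimate can be sketched as the average squared gradient of the loss with respect to each adapted parameter, and the regularizer as a Fisher-weighted squared distance from the stored parameter values. The code below is a sketch under stated assumptions, not the torch_ttt implementation: estimate_diag_fisher and fisher_penalty are hypothetical helpers, the (fisher, reference value) tuple layout is inferred from the documented type of fishers, and using the model's own predictions as targets is an assumption.

```python
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset


def estimate_diag_fisher(model, params, names, data_loader, loss_fn=None):
    """Average squared gradients over the loader.

    Returns {name: (fisher_diagonal, reference_value)} (assumed layout).
    """
    loss_fn = loss_fn or nn.CrossEntropyLoss()
    sums = {n: torch.zeros_like(p) for n, p in zip(names, params)}
    num_batches = 0
    model.eval()
    for inputs, _ in data_loader:
        model.zero_grad()
        logits = model(inputs)
        targets = logits.argmax(dim=1)  # assumption: pseudo-labels from the model
        loss_fn(logits, targets).backward()
        for n, p in zip(names, params):
            sums[n] += p.grad.detach() ** 2
        num_batches += 1
    model.zero_grad()
    return {
        n: (sums[n] / num_batches, p.detach().clone())
        for n, p in zip(names, params)
    }


def fisher_penalty(params, names, fishers, fisher_alpha):
    """Fisher-weighted squared drift from the stored reference values."""
    drift = sum(
        (fishers[n][0] * (p - fishers[n][1]) ** 2).sum()
        for n, p in zip(names, params)
    )
    return fisher_alpha * drift


torch.manual_seed(0)
model = nn.Sequential(
    nn.Conv2d(3, 4, kernel_size=3),
    nn.BatchNorm2d(4),
    nn.ReLU(),
    nn.AdaptiveAvgPool2d(1),
    nn.Flatten(),
    nn.Linear(4, 5),
)
bn = model[1]
params, names = [bn.weight, bn.bias], ["1.weight", "1.bias"]
dataset = TensorDataset(torch.randn(8, 3, 8, 8), torch.zeros(8, dtype=torch.long))
loader = DataLoader(dataset, batch_size=4)

fishers = estimate_diag_fisher(model, params, names, loader)
penalty = fisher_penalty(params, names, fishers, fisher_alpha=2000.0)
```

Right after estimation the penalty is zero, because the parameters still equal their stored reference values. It grows as adaptation moves parameters with large Fisher values away from those references.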