EataEngine

class torch_ttt.engine.eata_engine.EataEngine(model: Module, optimization_parameters: Dict[str, Any] | None = None, e_margin: float = 2.4538776394910684, d_margin: float = 0.05, fisher_alpha: float = 2000.0, fishers: Dict[str, Tuple[Tensor, Tensor]] | None = None)

EATA: Efficient Test-Time Adaptation without Forgetting.

Reference:

“Efficient Test-Time Model Adaptation without Forgetting” (ICML 2022) Niu et al.
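The constructor arguments correspond to the three ingredients described in the EATA paper. `e_margin` keeps only low-entropy (reliable) test samples. `d_margin` drops samples whose prediction is nearly redundant with a running average of earlier predictions (absolute cosine similarity of at least `d_margin`). `fisher_alpha` scales an EWC-style penalty built from `fishers`, which pulls important weights back toward their source values. The default `e_margin` equals ln(1000)/2 - 1 ≈ 2.4539. Below is a minimal sketch of that objective under stated assumptions, not `EataEngine`'s internal code: the function `eata_loss`, its return values, and the 0.9 momentum of the running average are hypothetical.

```python
import math
from typing import Dict, Optional, Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F

# Default entropy margin: ln(1000) / 2 - 1 (1000 classes, as for ImageNet).
E_MARGIN = math.log(1000) / 2 - 1


def eata_loss(
    logits: torch.Tensor,
    model: nn.Module,
    prev_probs: Optional[torch.Tensor],
    fishers: Optional[Dict[str, Tuple[torch.Tensor, torch.Tensor]]] = None,
    e_margin: float = E_MARGIN,
    d_margin: float = 0.05,
    fisher_alpha: float = 2000.0,
    momentum: float = 0.9,  # hypothetical running-average momentum
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
    """Hypothetical sketch of the EATA objective (not torch_ttt's implementation).

    Returns the loss and the updated running mean of predicted probabilities.
    """
    probs = logits.softmax(dim=1)
    entropy = -(probs * logits.log_softmax(dim=1)).sum(dim=1)

    # 1) Reliability filter: keep samples whose entropy is below e_margin.
    keep = entropy < e_margin

    # 2) Redundancy filter: drop samples too similar to the running mean prediction.
    if prev_probs is not None:
        cos = F.cosine_similarity(prev_probs.unsqueeze(0), probs, dim=1)
        keep = keep & (cos.abs() < d_margin)

    selected_ent = entropy[keep]
    selected_probs = probs[keep].detach()

    # Update the running mean of predictions using only the kept samples.
    if selected_probs.size(0) > 0:
        batch_mean = selected_probs.mean(dim=0)
        if prev_probs is None:
            new_probs = batch_mean
        else:
            new_probs = momentum * prev_probs + (1 - momentum) * batch_mean
    else:
        new_probs = prev_probs

    # 3) Re-weight kept samples by exp(e_margin - H); weights are constants.
    if selected_ent.numel() > 0:
        weights = torch.exp(e_margin - selected_ent.detach())
        loss = (weights * selected_ent).mean()
    else:
        loss = logits.sum() * 0.0  # nothing selected: zero loss, graph kept

    # 4) EWC-style anti-forgetting penalty: alpha * sum F * (theta - theta_src)^2.
    if fishers is not None:
        for name, param in model.named_parameters():
            if name in fishers:
                fisher, anchor = fishers[name]
                loss = loss + fisher_alpha * (fisher * (param - anchor) ** 2).sum()

    return loss, new_probs
```

Weighting each kept sample by exp(e_margin - H) gives more confident predictions a larger share of the entropy loss, and detaching the weights keeps them out of the gradient.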

collect_bn_params()

Collect the affine scale (weight) and shift (bias) parameters of every BatchNorm2d layer in the model.
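Collecting these parameters amounts to walking the module tree and picking the affine weight and bias of each BatchNorm2d layer. A minimal sketch, assuming a hypothetical helper `collect_bn_affine_params` that returns a `(params, names)` pair; the actual return value of `collect_bn_params` is not documented here and may differ.

```python
from typing import List, Tuple

import torch.nn as nn


def collect_bn_affine_params(model: nn.Module) -> Tuple[List[nn.Parameter], List[str]]:
    """Hypothetical helper: gather the affine weight/bias of every BatchNorm2d layer."""
    params: List[nn.Parameter] = []
    names: List[str] = []
    for module_name, module in model.named_modules():
        # Layers built with affine=False have no learnable scale or shift.
        if isinstance(module, nn.BatchNorm2d) and module.affine:
            for param_name, param in module.named_parameters(recurse=False):
                # param_name is "weight" (scale, gamma) or "bias" (shift, beta).
                params.append(param)
                names.append(f"{module_name}.{param_name}")
    return params, names
```

The returned list can be passed directly to an optimizer such as `torch.optim.SGD`, so that only the normalization layers are updated during adaptation.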

compute_fishers(data_loader: DataLoader, loss_fn: Module | None = None, lr: float = 0.001)

Estimate the diagonal of the Fisher information matrix on source data and store it in self.fishers.

Parameters:
  • data_loader – DataLoader over source data.

  • loss_fn – Loss used to compute the Fisher information (default: cross-entropy).

  • lr – Learning rate of the dummy optimizer (used only to collect parameters).
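A diagonal Fisher estimate is commonly computed as the average of squared gradients of the loss with respect to each parameter. The sketch below follows that recipe under stated assumptions: the helper `estimate_diag_fisher` is hypothetical, batches are assumed to be `(inputs, labels)` pairs, gradients are taken per batch rather than per sample, and the targets are the model's own argmax predictions (as in the EATA reference code). `compute_fishers` may differ, for example by using the loader's labels. Each entry pairs the Fisher diagonal with a snapshot of the source weights, matching the `Dict[str, Tuple[Tensor, Tensor]]` type of the `fishers` constructor argument.

```python
from typing import Dict, Optional, Tuple

import torch
import torch.nn as nn
from torch.utils.data import DataLoader


def estimate_diag_fisher(
    model: nn.Module,
    data_loader: DataLoader,
    loss_fn: Optional[nn.Module] = None,
) -> Dict[str, Tuple[torch.Tensor, torch.Tensor]]:
    """Hypothetical sketch: diagonal Fisher as the mean squared per-batch gradient."""
    loss_fn = loss_fn if loss_fn is not None else nn.CrossEntropyLoss()
    sums: Dict[str, torch.Tensor] = {}
    num_batches = 0
    for batch in data_loader:
        inputs = batch[0]  # assumes (inputs, labels) batches; labels are unused
        model.zero_grad()
        logits = model(inputs)
        targets = logits.argmax(dim=1)  # pseudo-labels: the model's own predictions
        loss_fn(logits, targets).backward()
        num_batches += 1
        for name, param in model.named_parameters():
            if param.grad is not None:
                sq = param.grad.detach() ** 2
                sums[name] = sums[name] + sq if name in sums else sq
    model.zero_grad()
    if num_batches == 0:
        return {}
    # Pair each Fisher diagonal with a snapshot of the source parameter values.
    params = dict(model.named_parameters())
    return {
        name: (total / num_batches, params[name].detach().clone())
        for name, total in sums.items()
    }
```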