EataEngine
- class torch_ttt.engine.eata_engine.EataEngine(model: Module, optimization_parameters: Dict[str, Any] | None = None, e_margin: float = 2.4538776394910684, d_margin: float = 0.05, fisher_alpha: float = 2000.0, fishers: Dict[str, Tuple[Tensor, Tensor]] | None = None)
EATA: Efficient Test-Time Adaptation without Forgetting.
- Reference:
Niu et al., “Efficient Test-Time Model Adaptation without Forgetting”, ICML 2022.
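The sketch below illustrates, in plain Python, the per-sample selection and weighting rule described in the EATA paper. It is not EataEngine's code. Only samples whose prediction entropy is below e_margin are kept; of those, samples whose probability vector has cosine similarity of at least d_margin with a running mean of previously selected predictions are dropped as redundant; each survivor is then weighted by 1 / exp(H(x) - e_margin). The helper names and the 0.9 moving-average momentum are assumptions. Note that the default e_margin in the signature is numerically equal to ln(1000)/2 - 1.

```python
import math
from typing import List, Optional, Sequence, Tuple


def entropy(probs: Sequence[float]) -> float:
    """Shannon entropy (natural log) of one probability vector."""
    return -sum(p * math.log(p) for p in probs if p > 0.0)


def cosine_similarity(a: Sequence[float], b: Sequence[float]) -> float:
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    return dot / (norm_a * norm_b) if norm_a > 0.0 and norm_b > 0.0 else 0.0


def eata_select(
    batch_probs: Sequence[Sequence[float]],
    running_probs: Optional[List[float]],
    e_margin: float = math.log(1000) / 2 - 1,  # ~2.4539, the signature default
    d_margin: float = 0.05,
    momentum: float = 0.9,  # assumed moving-average momentum
) -> Tuple[List[int], List[float], Optional[List[float]]]:
    """Hypothetical helper: choose which samples to adapt on and weight them.

    Returns (selected indices, per-sample weights, updated running mean of
    the selected samples' probability vectors).
    """
    # 1) Reliability filter: keep only low-entropy (confident) predictions.
    reliable = [i for i, p in enumerate(batch_probs) if entropy(p) < e_margin]
    # 2) Redundancy filter: drop samples too similar to past selections.
    if running_probs is None:
        selected = reliable
    else:
        selected = [
            i for i in reliable
            if abs(cosine_similarity(running_probs, batch_probs[i])) < d_margin
        ]
    # 3) Weight = 1 / exp(H(x) - e_margin): lower entropy gives a larger weight.
    weights = [1.0 / math.exp(entropy(batch_probs[i]) - e_margin) for i in selected]
    # 4) Update the running mean using the selected samples only.
    if not selected:
        return selected, weights, running_probs
    num_classes = len(batch_probs[0])
    batch_mean = [
        sum(batch_probs[i][k] for i in selected) / len(selected)
        for k in range(num_classes)
    ]
    if running_probs is None:
        return selected, weights, batch_mean
    updated = [momentum * r + (1.0 - momentum) * m
               for r, m in zip(running_probs, batch_mean)]
    return selected, weights, updated
```

Under these assumptions, the adaptation objective would be the weighted mean of the selected samples' entropies plus the Fisher-based anti-forgetting penalty scaled by fisher_alpha.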
- compute_fishers(data_loader: DataLoader, loss_fn: Module | None = None, lr: float = 0.001)
Estimate the diagonal of the Fisher information matrix and store it in self.fishers.
- Parameters:
data_loader – DataLoader over source data.
loss_fn – Loss used to compute the Fisher information (default: cross-entropy).
lr – Learning rate for the dummy optimizer (used to collect parameters).
logger – Optional logger to print progress. Note that logger does not appear in the compute_fishers signature above.
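As a hedged illustration of what a diagonal Fisher estimate is, here is a pure-Python sketch for a toy linear softmax classifier. It uses analytic cross-entropy gradients instead of autograd so the arithmetic stays visible. It assumes the model's own argmax predictions serve as labels and that each Fisher estimate is kept alongside a copy of the source parameters, mirroring the (Tensor, Tensor) tuples in the fishers argument. Whether compute_fishers follows exactly these choices is not stated above, and all names here are hypothetical.

```python
import math
from typing import List, Sequence, Tuple

Matrix = List[List[float]]


def softmax(logits: Sequence[float]) -> List[float]:
    shift = max(logits)  # subtract the max for numerical stability
    exps = [math.exp(z - shift) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]


def estimate_diagonal_fisher(
    weights: Matrix, batches: Sequence[Sequence[Sequence[float]]]
) -> Tuple[Matrix, Matrix]:
    """Hypothetical helper for a linear classifier, logits = weights @ x.

    Returns (fisher, anchor): the mean over batches of the squared
    batch-averaged cross-entropy gradient, and a copy of the weights.
    """
    num_classes, dim = len(weights), len(weights[0])
    fisher = [[0.0] * dim for _ in range(num_classes)]
    for batch in batches:
        grad = [[0.0] * dim for _ in range(num_classes)]
        for x in batch:
            logits = [sum(w * xi for w, xi in zip(row, x)) for row in weights]
            probs = softmax(logits)
            label = max(range(num_classes), key=probs.__getitem__)  # pseudo-label
            for k in range(num_classes):
                # d(cross-entropy)/d(logit_k) = p_k - 1[k == label]
                delta = probs[k] - (1.0 if k == label else 0.0)
                for j in range(dim):
                    grad[k][j] += delta * x[j] / len(batch)
        # Accumulate the element-wise squared gradient of this batch.
        for k in range(num_classes):
            for j in range(dim):
                fisher[k][j] += grad[k][j] ** 2
    fisher = [[f / len(batches) for f in row] for row in fisher]
    anchor = [list(row) for row in weights]  # source parameters to stay close to
    return fisher, anchor


def fisher_penalty(weights: Matrix, fisher: Matrix, anchor: Matrix,
                   fisher_alpha: float = 2000.0) -> float:
    """EWC-style regularizer: fisher_alpha * sum of F * (w - w_anchor)^2."""
    return fisher_alpha * sum(
        f * (w - w0) ** 2
        for w_row, f_row, a_row in zip(weights, fisher, anchor)
        for w, f, w0 in zip(w_row, f_row, a_row)
    )
```

Squaring the batch-averaged gradient, rather than each per-sample gradient, is a cheaper approximation. A per-sample variant would square every sample's gradient before averaging.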