DeYOEngine
- class torch_ttt.engine.deyo_engine.DeYOEngine(model: Module, optimization_parameters: Dict[str, Any] | None = None, e_margin: float = 3.4538776394910684, plpd_thresh: float = 0.2, ent_norm: float = 2.763102111592855, patch_len: int = 4, reweight_ent: float = 1.0, reweight_plpd: float = 1.0)
DeYO: Destroy Your Object – Test-Time Adaptation with PLPD.
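DeYO adapts models at test time by combining entropy minimization with the Pseudo-Label Probability Difference (PLPD), which measures how much the probability of the predicted class drops when the input's patches are shuffled. It filters out uncertain samples (high entropy) and unstable ones (low PLPD), then updates the normalization layers using only the remaining confident samples.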
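The patch-shuffling transform behind PLPD can be sketched as follows. This is a minimal illustration, not torch-ttt's implementation: `shuffle_patches` and `compute_plpd` are hypothetical helpers, the sketch assumes the image height and width are divisible by `patch_len` (an implementation may resize the input first), and it computes PLPD as the probability of the intact image's predicted class minus that class's probability after shuffling.

```python
import torch
import torch.nn as nn


def shuffle_patches(x: torch.Tensor, patch_len: int = 4) -> torch.Tensor:
    """Randomly permute a patch_len x patch_len grid of patches in each image.

    Hypothetical helper; assumes H and W are divisible by patch_len.
    """
    b, c, h, w = x.shape
    ph, pw = h // patch_len, w // patch_len
    n = patch_len * patch_len
    # (B, C, H, W) -> (B, n, C, ph, pw): one entry per patch, row-major grid order
    patches = (
        x.reshape(b, c, patch_len, ph, patch_len, pw)
        .permute(0, 2, 4, 1, 3, 5)
        .reshape(b, n, c, ph, pw)
    )
    # Draw an independent random patch order for every sample in the batch
    perm = torch.argsort(torch.rand(b, n, device=x.device), dim=1)
    patches = patches[torch.arange(b, device=x.device).unsqueeze(1), perm]
    # Reassemble the shuffled grid into (B, C, H, W)
    return (
        patches.reshape(b, patch_len, patch_len, c, ph, pw)
        .permute(0, 3, 1, 4, 2, 5)
        .reshape(b, c, h, w)
    )


@torch.no_grad()
def compute_plpd(model: nn.Module, x: torch.Tensor, patch_len: int = 4) -> torch.Tensor:
    """Drop in the pseudo-label probability caused by shuffling patches (one value per sample)."""
    probs = model(x).softmax(dim=1)
    probs_shuffled = model(shuffle_patches(x, patch_len)).softmax(dim=1)
    pseudo = probs.argmax(dim=1, keepdim=True)  # pseudo-label from the intact image
    return (probs.gather(1, pseudo) - probs_shuffled.gather(1, pseudo)).squeeze(1)
```

A classifier that relies on object shape should lose confidence after shuffling (high PLPD), while a model that ignores spatial layout, such as global average pooling followed by a linear layer, yields a PLPD of essentially zero.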
- Parameters:
model (torch.nn.Module) – Model to be adapted at test-time.
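optimization_parameters (dict, optional) – Optimizer configuration, e.g. {"lr": 1e-3}.
e_margin (float) – Entropy threshold; samples whose prediction entropy exceeds it are treated as uncertain and excluded. The default, 0.5·ln(1000) ≈ 3.454, is half the maximum entropy of a 1000-class prediction.
plpd_thresh (float) – PLPD threshold; samples whose PLPD falls below it are treated as unstable and excluded.
ent_norm (float) – Normalization constant for entropy weighting. The default, 0.4·ln(1000) ≈ 2.763, is 40% of the maximum entropy of a 1000-class prediction.
patch_len (int) – Number of patches per spatial dimension used when shuffling the input, so each input is split into patch_len × patch_len patches (16 by default).
reweight_ent (float) – Scaling factor for the entropy-based sample weight.
reweight_plpd (float) – Scaling factor for the PLPD-based sample weight.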
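The default `e_margin` and `ent_norm` values are fractions of ln(1000), the largest possible entropy for a 1000-class classifier such as an ImageNet model. The sketch below reproduces them. Scaling both by ln(C) for a C-class model is an assumption that mirrors these defaults, not a rule stated in this documentation, and `deyo_thresholds` is a hypothetical helper.

```python
import math


def deyo_thresholds(num_classes: int) -> tuple[float, float]:
    """Return (e_margin, ent_norm) as 50% and 40% of the maximum entropy ln(C)."""
    max_entropy = math.log(num_classes)  # entropy of a uniform prediction over C classes
    return 0.5 * max_entropy, 0.4 * max_entropy


e_margin, ent_norm = deyo_thresholds(1000)
print(f"{e_margin:.4f} {ent_norm:.4f}")  # 3.4539 2.7631
```

For a 10-class model, `deyo_thresholds(10)` gives roughly 1.151 and 0.921, which could be passed as `e_margin` and `ent_norm`.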
- Example:
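```python
import torch
from torch_ttt.engine.deyo_engine import DeYOEngine

model = MyModel()  # placeholder: your torch.nn.Module
engine = DeYOEngine(model, {"lr": 1e-3})
optimizer = torch.optim.Adam(engine.parameters(), lr=1e-3)
criterion = torch.nn.CrossEntropyLoss()
alpha = 1.0  # illustrative weight of the test-time loss during training

# Training (train_loader is a placeholder DataLoader)
engine.train()
for inputs, labels in train_loader:
    optimizer.zero_grad()
    outputs, loss_ttt = engine(inputs)
    loss = criterion(outputs, labels) + alpha * loss_ttt
    loss.backward()
    optimizer.step()

# Inference (test_loader is a placeholder DataLoader)
engine.eval()
for inputs, labels in test_loader:
    outputs, loss_ttt = engine(inputs)
```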
Reference:
“Entropy is not Enough for Test-Time Adaptation: From the Perspective of Disentangled Factors”, Jonghyun Lee, Dahuin Jung, Saehyung Lee, Junsung Park, Juhyeon Shin, Uiwon Hwang, Sungroh Yoon
Paper link: PDF
- ttt_forward(inputs: Tensor) → Tuple[Tensor, Tensor]
Forward pass and loss computation for test-time adaptation.
Selects confident and stable samples using entropy and PLPD filtering. Computes a weighted entropy loss over reliable inputs for adaptation.
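The selection and weighting step can be sketched on a batch of logits. The weights below, `reweight_ent * exp(ent_norm - H) + reweight_plpd * exp(PLPD)`, are modeled on the authors' public DeYO code as we understand it; whether `DeYOEngine.ttt_forward` uses exactly this form is an assumption, and `deyo_loss` is a hypothetical function.

```python
import torch


def deyo_loss(
    logits: torch.Tensor,
    logits_shuffled: torch.Tensor,
    e_margin: float = 3.4538776394910684,
    plpd_thresh: float = 0.2,
    ent_norm: float = 2.763102111592855,
    reweight_ent: float = 1.0,
    reweight_plpd: float = 1.0,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Hypothetical DeYO-style loss: weighted entropy over confident, shape-sensitive samples.

    logits: outputs on the original batch (gradients flow through these).
    logits_shuffled: outputs on the patch-shuffled batch (treated as constants).
    Returns (loss, mask), where mask marks the selected samples.
    """
    log_probs = logits.log_softmax(dim=1)
    probs = log_probs.exp()
    entropy = -(probs * log_probs).sum(dim=1)  # per-sample prediction entropy

    # PLPD: drop in the pseudo-label probability after shuffling patches (no gradient)
    with torch.no_grad():
        pseudo = probs.argmax(dim=1, keepdim=True)
        probs_shuffled = logits_shuffled.softmax(dim=1)
        plpd = (probs.gather(1, pseudo) - probs_shuffled.gather(1, pseudo)).squeeze(1)

    # Keep confident (low-entropy) samples whose prediction depends on object shape
    mask = (entropy.detach() < e_margin) & (plpd > plpd_thresh)
    if not mask.any():
        return logits.sum() * 0.0, mask  # zero loss that is still differentiable

    ent_sel = entropy[mask]
    # Confident, high-PLPD samples get larger weights; the weights carry no gradient
    coeff = reweight_ent * torch.exp(ent_norm - ent_sel.detach()) + reweight_plpd * torch.exp(plpd[mask])
    return (coeff * ent_sel).mean(), mask
```

Detaching the weights keeps them as fixed per-sample multipliers, so the gradient only pushes the selected samples toward lower entropy.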
- Parameters:
inputs (torch.Tensor) – Batch of input images, shape (B, C, H, W).
- Returns:
A tuple of the current model predictions and the adaptation loss (an entropy loss over the selected samples, weighted by entropy and PLPD).