IT3Engine
- class torch_ttt.engine.it3_engine.IT3Engine(model: Module, features_layer_name: str, embeder: Module | None = None, combine_fn: Callable[[Tensor, Tensor], Tensor] | None = None, distance_fn: Callable[[Tensor, Tensor], Tensor] | None = None, optimization_parameters: Dict[str, Any] = {})
IT³: Idempotent Test-Time Training
A domain-agnostic test-time training method that adapts model predictions by enforcing idempotence—ensuring that repeated applications of the model yield consistent outputs.
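Conceptually, the engine compares the model's first prediction with a second prediction produced while the first one is fed back into the network, and penalizes the discrepancy. A minimal sketch of that objective, assuming a hypothetical f(x, cond=...) signature that stands in for the engine's feature-injection mechanism (not part of the library's API):

import torch.nn.functional as F

# Sketch of the idempotence objective. In IT3Engine, y1 is embedded and
# injected into the layer named by features_layer_name; the generic cond
# argument below is only a placeholder for that mechanism.
def idempotence_loss(f, x):
    y1 = f(x)           # first prediction
    y2 = f(x, cond=y1)  # second prediction, conditioned on the first
    return F.mse_loss(y2, y1.detach())  # zero when f is idempotent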
- Parameters:
model (torch.nn.Module) – A pre-trained model to be adapted with IT3.
features_layer_name (str) – Name of the layer to inject features into during TTT.
embeder (torch.nn.Module, optional) – Module to embed the initial prediction. If None, it will be created dynamically.
combine_fn (Callable, optional) – Function to combine the hidden features and embedding. If None, uses broadcast addition.
distance_fn (Callable, optional) – Distance function used to compare successive predictions. Defaults to MSE loss. A sketch of custom hooks follows this parameter list.
optimization_parameters (dict) – Parameters controlling adaptation (e.g., learning rate, optimizer choice).
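Both hooks can be swapped out. A sketch with assumed placeholders (MyModel and the "encoder" layer name are illustrative, as in the example below):

import torch.nn.functional as F
from torch_ttt.engine.it3_engine import IT3Engine

model = MyModel()  # placeholder model exposing a layer named "encoder"

# Scaled broadcast addition instead of the default plain addition.
def combine_fn(features, embedding):
    return features + 0.5 * embedding

# L1 distance between successive predictions instead of the default MSE.
def distance_fn(y_new, y_prev):
    return F.l1_loss(y_new, y_prev)

engine = IT3Engine(model, "encoder", combine_fn=combine_fn, distance_fn=distance_fn)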
Warning
The module with name features_layer_name must exist in the model.
Note
If embeder and combine_fn are not provided, they are constructed on-the-fly to match the dimensions of predictions and the injection layer (supporting 2D/4D tensors).
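For intuition, the broadcast-addition default behaves roughly like the following sketch (an assumption about shapes, not the library's exact code):

# An embedding of shape (B, C) is broadcast-added over features of
# shape (B, C, H, W) by appending singleton dimensions.
def default_combine(features, embedding):
    while embedding.dim() < features.dim():
        embedding = embedding.unsqueeze(-1)  # (B, C) -> (B, C, 1, 1)
    return features + embedding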
Example:
import torch
from torch_ttt.engine.it3_engine import IT3Engine

model = MyModel()
engine = IT3Engine(model, "encoder")
optimizer = torch.optim.Adam(engine.parameters(), lr=1e-4)

# Training
engine.train()
for inputs, labels in train_loader:
    optimizer.zero_grad()
    outputs, loss_ttt = engine(inputs, target=labels)
    loss = criterion(outputs, labels) + alpha * loss_ttt
    loss.backward()
    optimizer.step()

# Inference
engine.eval()
for inputs in test_loader:
    outputs, loss_ttt = engine(inputs)
Reference:
“IT³: Idempotent Test-Time Training”, Nikita Durasov, Assaf Shocher, Doruk Oner, Gal Chechik, Alexei A. Efros, Pascal Fua. ICML 2025.