IT3Engine

class torch_ttt.engine.it3_engine.IT3Engine(model: Module, features_layer_name: str, embeder: Module | None = None, combine_fn: Callable[[Tensor, Tensor], Tensor] | None = None, distance_fn: Callable[[Tensor, Tensor], Tensor] | None = None, optimization_parameters: Dict[str, Any] = {})

IT³: Idempotent Test-Time Training

A domain-agnostic test-time training method that adapts model predictions by enforcing idempotence, i.e., by ensuring that applying the model again, with its first prediction fed back as an extra input, yields a consistent output.
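
Schematically, write f_θ(x; e) for the model run on x with an embedding e merged into the features of layer features_layer_name by combine_fn, g for the embeder and d for distance_fn. The idempotence objective can then be sketched as below. This assumes the first pass runs with no prediction injected, and it summarizes the idea rather than transcribing the implementation.

```latex
y_0 = f_\theta(x), \qquad
y_1 = f_\theta\bigl(x;\, g(y_0)\bigr), \qquad
\mathcal{L}_{\mathrm{IT}^3}(x) = d\bigl(y_1,\, y_0\bigr)
```

Lowering this loss pushes the second application of the model to agree with the first.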

Parameters:
  • model (torch.nn.Module) – A pre-trained model to be adapted with IT3.

  • features_layer_name (str) – Name of the layer at which the embedded prediction is injected during TTT.

  • embeder (torch.nn.Module, optional) – Module to embed the initial prediction. If None, it will be created dynamically.

  • combine_fn (Callable, optional) – Function that combines the hidden features with the prediction embedding. If None, broadcast addition is used.

  • distance_fn (Callable, optional) – Distance function used to compare successive predictions. Defaults to MSE loss.

  • optimization_parameters (dict) – Parameters controlling adaptation (e.g., learning rate, optimizer choice).
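
The combine_fn and distance_fn arguments only need to follow the documented Callable[[Tensor, Tensor], Tensor] shape. The sketch below shows one hypothetical pair: scaled_add, its 0.5 weight and l1_distance are illustrative choices, not part of torch_ttt, and scaled_add assumes it receives (features, embedding) in that order.

```python
import torch
import torch.nn.functional as F


def scaled_add(features: torch.Tensor, embedding: torch.Tensor) -> torch.Tensor:
    """Hypothetical combine_fn: broadcast addition with a down-weighted embedding.

    Assumes the arguments arrive as (hidden features, prediction embedding).
    The 0.5 factor is an arbitrary illustrative value.
    """
    return features + 0.5 * embedding


def l1_distance(pred_a: torch.Tensor, pred_b: torch.Tensor) -> torch.Tensor:
    """Hypothetical distance_fn: mean absolute difference of two predictions."""
    return F.l1_loss(pred_a, pred_b)


features = torch.ones(4, 16, 8, 8)   # 4D hidden features (N, C, H, W)
embedding = torch.ones(4, 16, 1, 1)  # broadcasts over H and W
combined = scaled_add(features, embedding)
print(combined.shape)  # torch.Size([4, 16, 8, 8])
print(l1_distance(torch.zeros(2, 3), torch.ones(2, 3)).item())  # 1.0
```

Both would be passed by keyword, e.g. IT3Engine(model, "encoder", combine_fn=scaled_add, distance_fn=l1_distance).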

Warning

The module with name features_layer_name must exist in the model.
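
Because a wrong layer name is the documented failure mode, it can help to check the name before building the engine. A minimal sketch using torch.nn.Module.named_modules(); check_layer_exists is a hypothetical helper, not part of torch_ttt. Nested submodules appear under dotted names such as "encoder.0".

```python
from collections import OrderedDict

import torch.nn as nn


def check_layer_exists(model: nn.Module, name: str) -> nn.Module:
    """Hypothetical helper: return submodule `name` or raise a descriptive error."""
    modules = dict(model.named_modules())
    if name not in modules:
        available = sorted(key for key in modules if key)  # skip the root ""
        raise ValueError(f"no module named {name!r}; available: {available}")
    return modules[name]


# Toy model whose first stage is registered under the name "encoder".
model = nn.Sequential(OrderedDict(
    encoder=nn.Sequential(nn.Linear(10, 32), nn.ReLU()),
    head=nn.Linear(32, 3),
))
encoder = check_layer_exists(model, "encoder")  # found
# check_layer_exists(model, "decoder") would raise ValueError
```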

Note

If embeder and combine_fn are not provided, they are constructed on the fly to match the dimensions of the predictions and of the injection layer; 2D and 4D tensors are supported.
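
The on-the-fly construction described in the note can be sketched as follows, assuming 2D predictions of shape (N, K) and an embedder that is a single linear layer. make_embedder and broadcast_add are hypothetical names, and the library's actual construction may differ.

```python
import torch
import torch.nn as nn


def make_embedder(pred_dim: int, feat_channels: int) -> nn.Module:
    """Hypothetical embedder: maps a (N, pred_dim) prediction to (N, feat_channels)."""
    return nn.Linear(pred_dim, feat_channels)


def broadcast_add(features: torch.Tensor, embedding: torch.Tensor) -> torch.Tensor:
    """Add a (N, C) embedding to 2D (N, C) or 4D (N, C, H, W) features."""
    if features.dim() == 4:
        embedding = embedding[:, :, None, None]  # (N, C, 1, 1) broadcasts over H, W
    elif features.dim() != 2:
        raise ValueError(f"unsupported feature rank: {features.dim()}")
    return features + embedding


prediction = torch.randn(4, 10)                     # e.g. logits for 10 classes
embedder = make_embedder(prediction.shape[1], 16)   # 16 = channels of the injection layer
emb = embedder(prediction)                          # (4, 16)
out_2d = broadcast_add(torch.randn(4, 16), emb)         # (4, 16)
out_4d = broadcast_add(torch.randn(4, 16, 8, 8), emb)   # (4, 16, 8, 8)
```

Reshaping the embedding to (N, C, 1, 1) adds the same per-channel offset at every spatial location of a 4D feature map.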

Example:

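import torch

from torch_ttt.engine.it3_engine import IT3Engine

# Placeholders: MyModel must contain a submodule named "encoder";
# train_loader, test_loader, criterion and alpha are user-defined.
model = MyModel()
engine = IT3Engine(model, "encoder")
optimizer = torch.optim.Adam(engine.parameters(), lr=1e-4)

# Training: add the weighted TTT loss to the task loss
engine.train()
for inputs, labels in train_loader:
    optimizer.zero_grad()
    outputs, loss_ttt = engine(inputs, target=labels)
    loss = criterion(outputs, labels) + alpha * loss_ttt
    loss.backward()
    optimizer.step()

# Inference: the engine returns predictions together with the TTT loss
engine.eval()
for inputs in test_loader:
    outputs, loss_ttt = engine(inputs)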
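
The example above leaves MyModel, train_loader, test_loader, criterion and alpha undefined. One possible set of stand-ins is sketched below, with random data and arbitrary sizes; every name and value here is a hypothetical placeholder, chosen only so that MyModel has a submodule named "encoder" and both loops can run.

```python
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset


class MyModel(nn.Module):
    """Hypothetical stand-in: a small classifier with a submodule named "encoder"."""

    def __init__(self, in_dim: int = 20, hidden: int = 64, num_classes: int = 5):
        super().__init__()
        self.encoder = nn.Sequential(nn.Linear(in_dim, hidden), nn.ReLU())
        self.head = nn.Linear(hidden, num_classes)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.head(self.encoder(x))


# Random data, only to make the loops executable.
x, y = torch.randn(64, 20), torch.randint(0, 5, (64,))
train_loader = DataLoader(TensorDataset(x, y), batch_size=16, shuffle=True)
# The inference loop unpacks bare inputs, so iterate over a plain tensor.
test_loader = DataLoader(torch.randn(32, 20), batch_size=16)
criterion = nn.CrossEntropyLoss()
alpha = 0.1  # weight of the TTT loss; an arbitrary illustrative value
```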

Reference:

“IT³: Idempotent Test-Time Training”, Nikita Durasov, Assaf Shocher, Doruk Oner, Gal Chechik, Alexei A. Efros, Pascal Fua. ICML 2025.


ttt_forward(inputs, target=None)

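Runs the model twice on inputs with gradients enabled: the first pass produces an initial prediction, the second pass receives that prediction back through the embeder, and distance_fn compares the two to give the TTT loss used for adaptation. target is optional and defaults to None.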
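
To make the two-pass idea concrete, here is a simplified, self-contained re-implementation using a PyTorch forward hook. It is a sketch under stated assumptions, not the library's code: pass 1 runs without injection, pass 2 adds the embedded pass-1 prediction to the output of layer, and the pass-1 prediction is detached so that it acts as a fixed target. two_pass_idempotence_loss is a hypothetical name.

```python
from typing import Callable, Optional, Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F


def two_pass_idempotence_loss(
    model: nn.Module,
    layer: nn.Module,
    embedder: nn.Module,
    inputs: torch.Tensor,
    distance_fn: Callable[[torch.Tensor, torch.Tensor], torch.Tensor] = F.mse_loss,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Hypothetical IT3-style step: returns (second prediction, idempotence loss)."""
    injected: dict = {"emb": None}

    def inject(module: nn.Module, args, output: torch.Tensor) -> Optional[torch.Tensor]:
        # Returning a tensor from a forward hook replaces the layer's output;
        # returning None leaves it unchanged.
        if injected["emb"] is None:
            return None
        return output + injected["emb"]

    handle = layer.register_forward_hook(inject)
    try:
        with torch.enable_grad():  # keep gradients even if called under no_grad
            y0 = model(inputs)                       # pass 1: no injection
            injected["emb"] = embedder(y0.detach())  # embed the first prediction
            y1 = model(inputs)                       # pass 2: prediction fed back
            loss = distance_fn(y1, y0.detach())      # y0 treated as a fixed target
    finally:
        handle.remove()  # leave the model unmodified afterwards
    return y1, loss


# Usage on a toy MLP: inject at the ReLU, whose output has 16 features.
model = nn.Sequential(nn.Linear(8, 16), nn.ReLU(), nn.Linear(16, 3))
embedder = nn.Linear(3, 16)
y1, loss = two_pass_idempotence_loss(model, model[1], embedder, torch.randn(5, 8))
loss.backward()
```

Detaching y0 is one design choice among several; it means gradients reach the model and the embedder only through the second pass.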