MaskedTTTEngine

class torch_ttt.engine.masked_ttt_engine.MaskedTTTEngine(model: Module, mask_token_id: int, features_layer_name: str, mask_prob: float = 0.15, skip_tokens: List[int] = [], optimization_parameters: Dict[str, Any] = {})

Masked token prediction-based test-time training engine.

This engine performs self-supervised adaptation at inference time: it randomly masks input tokens (except those listed in skip_tokens) and trains the model to reconstruct them from intermediate features.
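The masking step can be pictured with the framework-free sketch below. It is a generic BERT-style illustration, not torch-ttt's internal code: the function name mask_tokens, the rng argument, and the IGNORE_INDEX = -100 label convention (PyTorch's default ignore_index for cross-entropy) are hypothetical choices made for this example. Each token not in skip_tokens is replaced by mask_token_id with probability mask_prob, and the original token is kept as the reconstruction target.

```python
import random
from typing import List, Optional, Sequence, Tuple

# Hypothetical convention: label value for positions excluded from the loss
# (matches PyTorch's default ignore_index for cross-entropy).
IGNORE_INDEX = -100


def mask_tokens(
    input_ids: Sequence[int],
    mask_token_id: int,
    mask_prob: float = 0.15,
    skip_tokens: Sequence[int] = (),
    rng: Optional[random.Random] = None,
) -> Tuple[List[int], List[int]]:
    """Randomly replace tokens with mask_token_id (hypothetical helper).

    Returns (masked_ids, labels): labels hold the original token at masked
    positions and IGNORE_INDEX everywhere else.
    """
    rng = rng or random.Random()
    skip = set(skip_tokens)  # token IDs that must never be masked
    masked_ids: List[int] = []
    labels: List[int] = []
    for tok in input_ids:
        if tok not in skip and rng.random() < mask_prob:
            masked_ids.append(mask_token_id)  # hide the token from the model
            labels.append(tok)                # but keep it as the target
        else:
            masked_ids.append(tok)            # leave the token unchanged
            labels.append(IGNORE_INDEX)       # and exclude it from the loss
    return masked_ids, labels


if __name__ == "__main__":
    ids = [101, 7592, 2088, 102]  # hypothetical token IDs
    masked, labels = mask_tokens(
        ids, mask_token_id=103, mask_prob=0.5,
        skip_tokens=[101, 102], rng=random.Random(0),
    )
    print(masked, labels)
```

Real BERT pretraining additionally swaps some selected tokens for random or unchanged tokens (the 80/10/10 scheme); whether MaskedTTTEngine does this is not stated here, so the sketch keeps to plain replacement.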

Parameters:
  • model (torch.nn.Module) – The model with a Transformer block to adapt at test time.

  • mask_token_id (int) – The token ID used for masking input tokens.

  • features_layer_name (str) – Name of the intermediate layer from which logits are extracted.

  • mask_prob (float, optional) – Probability of masking each token. Default is 0.15.

  • skip_tokens (list, optional) – List of token IDs that are never masked, such as special tokens. Default is [] (no tokens skipped).

Warning

The module with the name specified by features_layer_name must exist within the model.

Warning

This method is designed for BERT-style transformer models that support masked token prediction.

Example:

import torch
from torch_ttt.engine.masked_ttt_engine import MaskedTTTEngine

model = MyTransformerModel()  # placeholder for your BERT-style model
# 103 is the [MASK] token ID in the bert-base-uncased vocabulary
engine = MaskedTTTEngine(model, mask_token_id=103, features_layer_name="encoder.layer.11.output")

optimizer = torch.optim.Adam(engine.parameters(), lr=1e-4)

# Training with TTT
# (Important! MaskedTTTEngine can be applied to already pretrained models)
engine.train()
for batch in test_loader:
    optimizer.zero_grad()
    outputs, loss_ttt = engine.ttt_forward(batch)
    loss_ttt.backward()
    optimizer.step()

# Inference
engine.eval()
for batch in test_loader:
    outputs, loss_ttt = engine.ttt_forward(batch)

Reference:

“Test-Time Training with Masked Autoencoders”, Yossi Gandelsman, Yu Sun, Xinlei Chen, Alexei A. Efros


ttt_forward(inputs) → Tuple[Tensor, Tensor]

Performs a masked token prediction forward pass for test-time training.

Parameters:

inputs (dict) – A dictionary containing model inputs (must include input_ids).

Returns:

A tuple (outputs, loss_ttt) containing the model outputs and the cross-entropy loss computed on the masked tokens.
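The loss described above is cross-entropy restricted to masked positions. The framework-free sketch below shows that reduction under stated assumptions: per-position logits over the vocabulary, and labels that use the hypothetical value -100 (IGNORE_INDEX) for unmasked positions. It is not torch-ttt's code; the function name masked_cross_entropy and the choice to return 0.0 when nothing is masked are hypothetical. For inputs with at least one masked position it computes the same quantity as torch.nn.functional.cross_entropy with ignore_index=-100 and the default mean reduction.

```python
import math
from typing import Sequence

IGNORE_INDEX = -100  # hypothetical label for unmasked positions


def masked_cross_entropy(
    logits: Sequence[Sequence[float]], labels: Sequence[int]
) -> float:
    """Mean cross-entropy over positions whose label is not IGNORE_INDEX."""
    total, count = 0.0, 0
    for row, label in zip(logits, labels):
        if label == IGNORE_INDEX:
            continue  # unmasked position: contributes nothing to the loss
        m = max(row)  # subtract the max for a numerically stable log-sum-exp
        log_z = m + math.log(sum(math.exp(x - m) for x in row))
        total += log_z - row[label]  # equals -log(softmax(row)[label])
        count += 1
    if count == 0:
        return 0.0  # hypothetical choice: no masked tokens gives zero loss
    return total / count  # average over masked positions only


if __name__ == "__main__":
    # Two positions over a 3-token vocabulary; only the second is masked.
    print(masked_cross_entropy([[0.1, 0.2, 0.3], [2.0, 0.5, -1.0]], [-100, 0]))
```

Averaging only over masked positions keeps the loss scale independent of how many tokens happened to be masked in a given batch.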