MaskedTTTEngine#

class torch_ttt.engine.masked_ttt_engine.MaskedTTTEngine(model: Module, mask_token_id: int, features_layer_name: str, mask_prob: float = 0.15, skip_tokens: List[int] = [], optimization_parameters: Dict[str, Any] = {})[source]#

Masked token prediction-based test-time training engine.

This engine performs masked language modeling (MLM) as a self-supervised auxiliary task during inference. It randomly masks input tokens with probability mask_prob (except those listed in skip_tokens) and trains the model to predict the original tokens from intermediate features.

Parameters:
  • model (torch.nn.Module) – The model with a Transformer block to adapt at test time.

  • mask_token_id (int) – The token ID used for masking input tokens.

  • features_layer_name (str) – Name of the intermediate layer from which features are extracted for the masked-token prediction task.

  • mask_prob (float, optional) – Probability of masking each token. Default is 0.15.

  • skip_tokens (list, optional) – List of token IDs that are never masked (e.g., special tokens). Default is an empty list.

  • optimization_parameters (dict, optional) – Optimization parameters for the engine. Default is an empty dict.
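
To make the masking behavior concrete, the sketch below shows BERT-style random masking with the parameters above. It is a conceptual illustration, not the engine's internal implementation; the helper name random_mask and the token IDs are hypothetical.

import torch

def random_mask(input_ids, mask_token_id, mask_prob=0.15, skip_tokens=()):
    """Replace each token with mask_token_id with probability mask_prob,
    leaving token IDs listed in skip_tokens untouched."""
    input_ids = input_ids.clone()
    # Independent Bernoulli(mask_prob) decision per position.
    mask = torch.rand(input_ids.shape, device=input_ids.device) < mask_prob
    # Never mask tokens that should be skipped (e.g., special tokens).
    for token_id in skip_tokens:
        mask &= input_ids != token_id
    input_ids[mask] = mask_token_id
    return input_ids

# Toy batch; 101/102 are protected (BERT-style [CLS]/[SEP]), 103 is the mask token.
ids = torch.tensor([[101, 2023, 2003, 1037, 7953, 102]])
masked_ids = random_mask(ids, mask_token_id=103, mask_prob=0.15, skip_tokens=[101, 102])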

Warning

The module with the name specified by features_layer_name must exist within the model.
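
If you are unsure which name to pass, the candidate names can be listed with PyTorch's named_modules(); the small stand-in model below exists only to show what such names look like.

import torch.nn as nn

# Stand-in Transformer used only to illustrate layer naming.
encoder_layer = nn.TransformerEncoderLayer(d_model=64, nhead=4, batch_first=True)
model = nn.TransformerEncoder(encoder_layer, num_layers=2)

# features_layer_name must match one of these names exactly.
for name, _ in model.named_modules():
    print(name)  # e.g. "layers.0", "layers.0.self_attn", "layers.1.linear1", ...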

Example:

import torch

from torch_ttt.engine.masked_ttt_engine import MaskedTTTEngine

model = MyTransformerModel()
engine = MaskedTTTEngine(model, mask_token_id=103, features_layer_name="encoder.layer.11.output")

optimizer = torch.optim.Adam(engine.parameters(), lr=1e-4)

# Test-time training (important: MaskedTTTEngine is applied to an already pretrained model)
engine.train()
for batch in test_loader:
    optimizer.zero_grad()
    outputs, loss_ttt = engine.ttt_forward(batch)
    loss_ttt.backward()
    optimizer.step()

# Inference
engine.eval()
for batch in test_loader:
    outputs, loss_ttt = engine.ttt_forward(batch)
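
For intuition, loss_ttt corresponds to a masked-token prediction loss, i.e., cross-entropy computed only on the positions that were masked. The sketch below is conceptual, not the engine's internal code; logits, labels, and masked are assumed names with toy shapes.

import torch
import torch.nn.functional as F

# Toy shapes: batch of 2 sequences of length 6 over a 30-token vocabulary.
logits = torch.randn(2, 6, 30)                # MLM-head predictions
labels = torch.randint(0, 30, (2, 6))         # original (unmasked) token IDs
masked = torch.zeros(2, 6, dtype=torch.bool)  # which positions were masked
masked[:, 2] = True                           # pretend the third token was masked

# Cross-entropy restricted to masked positions, as in standard MLM.
loss_ttt = F.cross_entropy(logits[masked], labels[masked])
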
Reference:

This approach is inspired by masked language modeling and related test-time training methods, e.g., masked autoencoders (MAE) and self-supervised Transformers.