MaskedTTTEngine
- class torch_ttt.engine.masked_ttt_engine.MaskedTTTEngine(model: Module, mask_token_id: int, features_layer_name: str, mask_prob: float = 0.15, skip_tokens: List[int] = [], optimization_parameters: Dict[str, Any] = {})
Masked token prediction-based test-time training engine.
This engine performs masked language modeling (MLM) as a self-supervised auxiliary task during inference. It randomly masks input tokens (except those in skip_tokens) and trains the model to predict them using intermediate features.
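The masking step can be pictured with a small sketch (an illustration of the idea only, not the engine's internal implementation; the random_mask helper and its signature are hypothetical):

import torch

def random_mask(input_ids, mask_token_id, mask_prob=0.15, skip_tokens=()):
    # Hypothetical helper illustrating the masking idea: each token is replaced
    # by mask_token_id with probability mask_prob, except tokens listed in skip_tokens.
    candidates = torch.ones_like(input_ids, dtype=torch.bool)
    for tok in skip_tokens:
        candidates &= input_ids != tok
    mask = (torch.rand_like(input_ids, dtype=torch.float) < mask_prob) & candidates
    masked_ids = input_ids.clone()
    masked_ids[mask] = mask_token_id
    return masked_ids, mask

The model's predictions at the masked positions are then compared against the original tokens (typically with a cross-entropy loss), giving the self-supervised loss returned as loss_ttt in the example below.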
- Parameters:
model (torch.nn.Module) – The model with a Transformer block to adapt at test time.
mask_token_id (int) – The token ID used for masking input tokens.
features_layer_name (str) – Name of the intermediate layer from which logits are extracted.
mask_prob (float, optional) – Probability of masking each token. Default is 0.15.
skip_tokens (list, optional) – List of token IDs to skip when applying the mask. Default is an empty list.
optimization_parameters (dict, optional) – Optimization parameters for the engine. Default is an empty dict.
Warning
The module with the name specified by features_layer_name must exist within the model.
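To verify that features_layer_name refers to an existing submodule, the model's module names can be listed (a quick check, assuming a standard torch.nn.Module):

# Print every submodule name; features_layer_name must match one of them exactly.
for name, _ in model.named_modules():
    print(name)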
- Example:
import torch
from torch_ttt.engine.masked_ttt_engine import MaskedTTTEngine

model = MyTransformerModel()
engine = MaskedTTTEngine(
    model,
    mask_token_id=103,
    features_layer_name="encoder.layer.11.output",
)
optimizer = torch.optim.Adam(engine.parameters(), lr=1e-4)

# Training with TTT (Important! MaskedTTTEngine can be applied to already pretrained models)
engine.train()
for batch in test_loader:
    optimizer.zero_grad()
    outputs, loss_ttt = engine.ttt_forward(batch)
    loss_ttt.backward()
    optimizer.step()

# Inference
engine.eval()
for batch in test_loader:
    outputs, loss_ttt = engine.ttt_forward(batch)
- Reference:
This approach is inspired by masked language modeling (as used in BERT pre-training) and related test-time training methods, e.g., masked autoencoders (MAE) and self-supervised Transformers.