Source code for torch_ttt.engine.masked_ttt_engine

from contextlib import contextmanager
import torch
from copy import deepcopy
from typing import Tuple, Dict, Any, List, Optional
from torch_ttt.engine.base_engine import BaseEngine
from torch_ttt.engine_registry import EngineRegistry

__all__ = ["MaskedTTTEngine"]


@EngineRegistry.register("masked_ttt")
class MaskedTTTEngine(BaseEngine):
    r"""Masked token prediction-based **test-time training** engine.

    This engine performs masked language modeling (MLM) as a self-supervised
    auxiliary task during inference. It randomly masks input tokens (except
    those in :attr:`skip_tokens`) and trains the model to predict them from
    intermediate features.

    Args:
        model (torch.nn.Module): The model with a Transformer block to adapt at test time.
        mask_token_id (int): The token ID used for masking input tokens.
        features_layer_name (str): Name of the intermediate layer from which logits are extracted.
        mask_prob (float, optional): Probability of masking each token. Default is 0.15.
        skip_tokens (list, optional): List of token IDs to skip when applying the mask.
        optimization_parameters (dict, optional): Dictionary of optimization parameters.

    Warning:
        The module with the name specified by :attr:`features_layer_name` must exist
        within the model.

    :Example:

    .. code-block:: python

        from torch_ttt.engine.masked_ttt_engine import MaskedTTTEngine

        model = MyTransformerModel()
        engine = MaskedTTTEngine(
            model,
            mask_token_id=103,
            features_layer_name="encoder.layer.11.output",
        )
        optimizer = torch.optim.Adam(engine.parameters(), lr=1e-4)

        # Training with TTT (note: MaskedTTTEngine can be applied to
        # already pretrained models)
        engine.train()
        for batch in test_loader:
            optimizer.zero_grad()
            outputs, loss_ttt = engine.ttt_forward(batch)
            loss_ttt.backward()
            optimizer.step()

        # Inference
        engine.eval()
        for batch in test_loader:
            outputs, loss_ttt = engine.ttt_forward(batch)

    Reference:
        This approach is inspired by masked language modeling and related TTT
        methods, e.g., MAE and self-supervised transformers.
    """

    def __init__(
        self,
        model: torch.nn.Module,
        mask_token_id: int,
        features_layer_name: str,
        mask_prob: float = 0.15,
        skip_tokens: Optional[List[int]] = None,
        optimization_parameters: Optional[Dict[str, Any]] = None,
    ):
        super().__init__()
        self.model = model
        self.mask_token_id = mask_token_id
        self.features_layer_name = features_layer_name
        self.mask_prob = mask_prob
        # Avoid mutable default arguments: fall back to fresh empty containers.
        self.skip_tokens = skip_tokens if skip_tokens is not None else []
        self.optimization_parameters = (
            optimization_parameters if optimization_parameters is not None else {}
        )
        self.loss = torch.nn.CrossEntropyLoss()

        # Locate and store a reference to the target module.
        self.target_module = None
        for name, module in model.named_modules():
            if name == features_layer_name:
                self.target_module = module
                break

        if self.target_module is None:
            raise ValueError(f"Module '{features_layer_name}' not found in the model.")

    def _mask_tokens(self, input_ids):
        """Randomly mask input tokens, skipping any IDs in ``skip_tokens``.

        Args:
            input_ids: (batch_size, seq_len)

        Returns:
            masked_input_ids: (batch_size, seq_len)
            mask: (batch_size, seq_len) boolean mask of the replaced positions
        """
        masked_input_ids = input_ids.clone()
        device = masked_input_ids.device

        # Sample candidate positions with probability ``mask_prob``.
        mask = torch.rand(masked_input_ids.shape, device=device) < self.mask_prob

        # Never mask the tokens listed in ``skip_tokens`` (e.g., padding, special tokens).
        for skip_token in self.skip_tokens:
            mask = mask & (masked_input_ids != skip_token)

        masked_input_ids[mask] = self.mask_token_id
        return masked_input_ids, mask

    def ttt_forward(self, inputs) -> Tuple[torch.Tensor, torch.Tensor]:
        # TODO: don't make the copy
        inputs_copy = deepcopy(inputs)

        input_ids = inputs_copy["input_ids"]
        masked_input_ids, mask = self._mask_tokens(input_ids)
        inputs_copy["input_ids"] = masked_input_ids

        with self.__capture_hook() as features_hook:
            outputs = self.model(**inputs_copy)

        features = features_hook.output

        if mask.sum() == 0:
            # No tokens were masked: return a differentiable, non-informative zero loss.
            loss = features.mean() * 0
        else:
            loss = self.loss(features[mask], input_ids[mask])

        return outputs, loss

    @contextmanager
    def __capture_hook(self):
        """Context manager to capture features via a forward hook."""

        class OutputHook:
            def __init__(self):
                self.output = None

            def hook(self, module, input, output):
                self.output = output

        features_hook = OutputHook()
        hook_handle = self.target_module.register_forward_hook(features_hook.hook)

        try:
            yield features_hook
        finally:
            hook_handle.remove()
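
A minimal end-to-end sketch of how this engine is wired up. The toy model, vocabulary size, and the MASK_ID/PAD_ID token IDs below are illustrative assumptions, not part of torch_ttt; the only requirement is that the layer named by features_layer_name emits per-token vocabulary logits, since they are fed to the cross-entropy loss against the original token IDs.

    import torch
    from torch_ttt.engine.masked_ttt_engine import MaskedTTTEngine

    VOCAB_SIZE = 100
    MASK_ID = 99  # hypothetical [MASK] token ID
    PAD_ID = 0    # hypothetical padding token ID

    class ToyMLM(torch.nn.Module):
        """Tiny model whose ``head`` layer emits per-token vocabulary logits."""

        def __init__(self):
            super().__init__()
            self.embed = torch.nn.Embedding(VOCAB_SIZE, 32)
            self.head = torch.nn.Linear(32, VOCAB_SIZE)

        def forward(self, input_ids):
            return self.head(self.embed(input_ids))  # (batch, seq_len, VOCAB_SIZE)

    engine = MaskedTTTEngine(
        ToyMLM(),
        mask_token_id=MASK_ID,
        features_layer_name="head",  # the forward hook captures this layer's output
        skip_tokens=[PAD_ID],
    )

    # Token IDs in [1, VOCAB_SIZE - 2] avoid the PAD and MASK IDs above.
    batch = {"input_ids": torch.randint(1, VOCAB_SIZE - 1, (2, 16))}
    outputs, loss_ttt = engine.ttt_forward(batch)
    print(loss_ttt)  # cross-entropy over the ~15% of positions that were masked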