Source code for torch_ttt.engine.masked_ttt_engine

from contextlib import contextmanager
import torch
from copy import deepcopy
from typing import Tuple, Dict, Any, List, Optional
from torch_ttt.engine.base_engine import BaseEngine
from torch_ttt.engine_registry import EngineRegistry

__all__ = ["MaskedTTTEngine"]


@EngineRegistry.register("masked_ttt")
class MaskedTTTEngine(BaseEngine):
    r"""Masked token prediction-based **test-time training** engine.

    This engine performs self-supervised adaptation at inference by randomly
    masking input tokens (except those in `skip_tokens`) and training the model
    to reconstruct them using intermediate features.

    Args:
        model (torch.nn.Module): The model with a Transformer block to adapt at test time.
        mask_token_id (int): The token ID used for masking input tokens.
        features_layer_name (str): Name of the intermediate layer from which logits are extracted.
        mask_prob (float, optional): Probability of masking each token. Default is 0.15.
        skip_tokens (list, optional): List of token IDs to skip when applying the mask.
        optimization_parameters (dict, optional): Hyperparameters for the test-time optimization.

    Warning:
        The module with the name specified by :attr:`features_layer_name` must exist within the model.

    Warning:
        This method is designed for BERT-style transformer models that support masked token prediction.

    :Example:

    .. code-block:: python

        from torch_ttt.engine.masked_ttt_engine import MaskedTTTEngine

        model = MyTransformerModel()
        engine = MaskedTTTEngine(
            model,
            mask_token_id=103,
            features_layer_name="encoder.layer.11.output",
        )
        optimizer = torch.optim.Adam(engine.parameters(), lr=1e-4)

        # Training with TTT
        # (Important! MaskedTTTEngine can be applied to already pretrained models)
        engine.train()
        for batch in test_loader:
            optimizer.zero_grad()
            outputs, loss_ttt = engine.ttt_forward(batch)
            loss_ttt.backward()
            optimizer.step()

        # Inference
        engine.eval()
        for batch in test_loader:
            outputs, loss_ttt = engine.ttt_forward(batch)

    Reference:
        "Test-Time Training with Masked Autoencoders",
        Yossi Gandelsman, Yu Sun, Xinlei Chen, Alexei A. Efros

        Paper link: `PDF <https://papers.neurips.cc/paper_files/paper/2022/file/bcdec1c2d60f94a93b6e36f937aa0530-Paper-Conference.pdf>`_
    """

    def __init__(
        self,
        model: torch.nn.Module,
        mask_token_id: int,
        features_layer_name: str,
        mask_prob: float = 0.15,
        skip_tokens: Optional[List[int]] = None,
        optimization_parameters: Optional[Dict[str, Any]] = None,
    ):
        super().__init__()
        self.model = model
        self.mask_token_id = mask_token_id
        self.features_layer_name = features_layer_name
        self.mask_prob = mask_prob
        # Avoid mutable default arguments: fall back to fresh containers.
        self.skip_tokens = skip_tokens if skip_tokens is not None else []
        self.optimization_parameters = optimization_parameters if optimization_parameters is not None else {}
        self.loss = torch.nn.CrossEntropyLoss()

        # Locate and store the reference to the target module.
        self.target_module = None
        for name, module in model.named_modules():
            if name == features_layer_name:
                self.target_module = module
                break

        if self.target_module is None:
            raise ValueError(f"Module '{features_layer_name}' not found in the model.")

    def _mask_tokens(self, input_ids):
        """
        Randomly masks tokens in the input tensor for masked token prediction.

        Args:
            input_ids (torch.Tensor): Input token IDs of shape (batch_size, seq_len).

        Returns:
            Tuple[torch.Tensor, torch.Tensor]:
                - masked_input_ids: Token IDs with some positions replaced by `mask_token_id`.
                - mask: Boolean tensor indicating which positions were masked.
        """
        masked_input_ids = input_ids.clone()
        device = masked_input_ids.device

        mask = torch.rand(masked_input_ids.shape, device=device) < self.mask_prob
        for skip_token in self.skip_tokens:
            mask = mask & (masked_input_ids != skip_token)

        masked_input_ids[mask] = self.mask_token_id
        return masked_input_ids, mask
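    # Hand-traced sketch of ``_mask_tokens`` (the token IDs below are hypothetical,
    # BERT-style values; assume mask_token_id=103 and skip_tokens=[101, 102]):
    #
    #   input_ids        = [[101, 7592, 2088,  102]]      # [CLS] hello world [SEP]
    #   mask             = [[False, True, False, False]]  # skip_tokens are never masked
    #   masked_input_ids = [[101,  103, 2088,  102]]      # "hello" replaced by [MASK]
    #
    # The returned boolean ``mask`` marks exactly the positions the model must reconstruct.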
    def ttt_forward(self, inputs) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Performs a masked token prediction forward pass for test-time training.

        Args:
            inputs (dict): A dictionary containing model inputs (must include `input_ids`).

        Returns:
            Tuple[torch.Tensor, torch.Tensor]: The model outputs and the
            cross-entropy loss computed on the masked tokens.
        """
        # TODO: avoid making the copy
        inputs_copy = deepcopy(inputs)

        inputs_idx = inputs_copy["input_ids"]
        masked_inputs_idx, mask = self._mask_tokens(inputs_idx)
        inputs_copy["input_ids"] = masked_inputs_idx

        with self.__capture_hook() as features_hook:
            outputs = self.model(**inputs_copy)

        features = features_hook.output

        if mask.sum() == 0:
            # No tokens were masked: return a differentiable, non-informative zero loss
            # instead of evaluating CrossEntropyLoss on an empty selection.
            loss = features.mean() * 0
        else:
            loss = self.loss(features[mask], inputs_idx[mask])

        return outputs, loss
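    # Note on the empty-mask guard above: with ``mask_prob=0.15`` and short
    # sequences, the chance that no position is selected is non-negligible,
    # e.g. a 10-token sequence survives unmasked with probability
    # 0.85**10 ≈ 0.20; the guard keeps such batches from yielding a NaN loss.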
    @contextmanager
    def __capture_hook(self):
        """
        Context manager to capture intermediate features via a forward hook.

        Yields:
            OutputHook: An object with an `.output` attribute containing the captured tensor.
        """

        class OutputHook:
            def __init__(self):
                self.output = None

            def hook(self, module, input, output):
                self.output = output

        features_hook = OutputHook()
        hook_handle = self.target_module.register_forward_hook(features_hook.hook)

        try:
            yield features_hook
        finally:
            # Always detach the hook, even if the forward pass raises.
            hook_handle.remove()
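
# A minimal, self-contained usage sketch. ``ToyMaskedLM`` below is a hypothetical
# stand-in (it is not part of torch_ttt); it only needs a named submodule whose
# output has shape (batch, seq_len, vocab_size), so the features captured from
# ``features_layer_name`` can be read as per-token logits.
if __name__ == "__main__":

    class ToyMaskedLM(torch.nn.Module):
        def __init__(self, vocab_size: int = 100, hidden: int = 32):
            super().__init__()
            self.embed = torch.nn.Embedding(vocab_size, hidden)
            layer = torch.nn.TransformerEncoderLayer(d_model=hidden, nhead=4, batch_first=True)
            self.encoder = torch.nn.TransformerEncoder(layer, num_layers=1)
            self.lm_head = torch.nn.Linear(hidden, vocab_size)

        def forward(self, input_ids):
            return self.lm_head(self.encoder(self.embed(input_ids)))

    model = ToyMaskedLM()
    engine = MaskedTTTEngine(
        model,
        mask_token_id=0,                # reserved ID used as [MASK] in this toy vocab
        features_layer_name="lm_head",  # capture per-token logits from this module
        skip_tokens=[1, 2],             # e.g. special tokens that must never be masked
    )
    optimizer = torch.optim.Adam(engine.parameters(), lr=1e-4)

    # One test-time training step on a random batch of token IDs.
    engine.train()
    batch = {"input_ids": torch.randint(3, 100, (2, 16))}
    optimizer.zero_grad()
    outputs, loss_ttt = engine.ttt_forward(batch)
    loss_ttt.backward()
    optimizer.step()
    print(f"TTT loss: {loss_ttt.item():.4f}")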