# Source code for torch_ttt.engine.actmad_engine
import torch
from typing import List, Dict, Any, Tuple, Union
from contextlib import contextmanager
from torch.utils.data import DataLoader
from torch_ttt.engine.base_engine import BaseEngine
from torch_ttt.engine_registry import EngineRegistry
__all__ = ["ActMADEngine"]
# TODO: add cuda support
@EngineRegistry.register("actmad")
class ActMADEngine(BaseEngine):
    """**ActMAD** approach: multi-level pixel-wise feature alignment.

    Args:
        model (torch.nn.Module): Model to be trained with TTT.
        features_layer_names (List[str] | str): List of layer names to be used for feature alignment.
        optimization_parameters (dict): The optimization parameters for the engine.

    :Example:

    .. code-block:: python

        from torch_ttt.engine.actmad_engine import ActMADEngine

        model = MyModel()
        engine = ActMADEngine(model, ["fc1", "fc2"])
        optimizer = torch.optim.Adam(engine.parameters(), lr=1e-4)

        # Training
        engine.train()
        for inputs, labels in train_loader:
            optimizer.zero_grad()
            outputs, loss_ttt = engine(inputs)
            loss = criterion(outputs, labels) + alpha * loss_ttt
            loss.backward()
            optimizer.step()

        # Compute statistics for features alignment
        engine.compute_statistics(train_loader)

        # Inference
        engine.eval()
        for inputs, labels in test_loader:
            output, loss_ttt = engine(inputs)

    Reference:
        "ActMAD: Activation Matching to Align Distributions for Test-Time Training",
        M. Jehanzeb Mirza, Pol Jane Soneira, Wei Lin, Mateusz Kozinski,
        Horst Possegger, Horst Bischof

        Paper link: `PDF <https://proceedings.neurips.cc/paper/2021/hash/b618c3210e934362ac261db280128c22-Abstract.html>`_
    """

    def __init__(
        self,
        model: torch.nn.Module,
        features_layer_names: Union[List[str], str],
        optimization_parameters: Union[Dict[str, Any], None] = None,
    ):
        super().__init__()
        self.model = model
        # `None` default (instead of a mutable `{}` default argument) avoids the
        # shared-dict-across-instances pitfall; behavior for callers is unchanged.
        self.optimization_parameters = optimization_parameters if optimization_parameters is not None else {}

        # Normalize a single layer name to a list so the rest of the engine
        # can iterate uniformly.
        if isinstance(features_layer_names, str):
            features_layer_names = [features_layer_names]
        self.features_layer_names = features_layer_names

        # Resolve each requested layer name to its module in one pass over
        # `named_modules()`, failing fast on unknown names.
        named_modules = dict(model.named_modules())
        self.target_modules = []
        for layer_name in self.features_layer_names:
            if layer_name not in named_modules:
                raise ValueError(f"Layer {layer_name} does not exist in the model.")
            self.target_modules.append(named_modules[layer_name])

        # Reference (source-distribution) statistics per target layer;
        # populated by `compute_statistics`.
        self.reference_mean = None
        self.reference_var = None

    def ttt_forward(self, inputs) -> Tuple[torch.Tensor, torch.Tensor]:
        """Forward pass of the model.

        Args:
            inputs (torch.Tensor): Input tensor.

        Returns:
            Returns the current model prediction and the feature-alignment loss value.

        Raises:
            ValueError: If called in eval mode before `compute_statistics`.
        """
        with self.__capture_hook() as features_hooks:
            outputs = self.model(inputs)

        features = [hook.output for hook in features_hooks]

        # The alignment loss is only meaningful at test time; skip it during training.
        if self.training:
            return outputs, 0

        if self.reference_var is None or self.reference_mean is None:
            raise ValueError(
                "Reference statistics are not computed. Please call `compute_statistics` method."
            )

        l1_loss = torch.nn.L1Loss(reduction='mean')
        # Per-layer batch statistics (mean/variance over the batch dimension).
        features_means = [torch.mean(feature, dim=0) for feature in features]
        features_vars = [torch.var(feature, dim=0) for feature in features]

        loss = 0
        for mean, var, ref_mean, ref_var in zip(
            features_means, features_vars, self.reference_mean, self.reference_var
        ):
            loss += l1_loss(mean, ref_mean)
            loss += l1_loss(var, ref_var)

        return outputs, loss

    def compute_statistics(self, dataloader: DataLoader) -> None:
        """Extract and compute reference statistics for features.

        Args:
            dataloader (DataLoader): The dataloader used for extracting features. It can return tuples of tensors, with the first element expected to be the input tensor.

        Raises:
            ValueError: If the dataloader is empty or features have mismatched dimensions.
        """
        self.model.eval()
        feat_stack = [[] for _ in self.target_modules]

        # TODO: compute variance in a more memory-efficient (streaming) way
        with torch.no_grad():
            device = next(self.model.parameters()).device
            for sample in dataloader:
                if len(sample) < 1:
                    raise ValueError("Dataloader returned an empty batch.")
                inputs = sample[0].to(device)
                with self.__capture_hook() as features_hooks:
                    _ = self.model(inputs)
                # Stash activations on CPU so the full feature set does not
                # accumulate on the model's device.
                for stack, hook in zip(feat_stack, features_hooks):
                    stack.append(hook.output.cpu())

        # An empty dataloader would otherwise surface as an opaque torch.cat error.
        if any(len(feats) == 0 for feats in feat_stack):
            raise ValueError("Cannot compute statistics: the dataloader is empty.")

        # Concatenate once per layer (the original concatenated twice), then
        # derive the per-feature mean/variance and move them back to `device`.
        self.reference_mean = []
        self.reference_var = []
        for feats in feat_stack:
            all_feats = torch.cat(feats)
            self.reference_mean.append(torch.mean(all_feats, dim=0).to(device))
            self.reference_var.append(torch.var(all_feats, dim=0).to(device))

    @contextmanager
    def __capture_hook(self):
        """Context manager to capture features via a forward hook."""

        class OutputHook:
            """Holds the most recent forward output of a single module."""

            def __init__(self):
                self.output = None

            def hook(self, module, input, output):
                self.output = output

        hook_handles = []
        features_hooks = []
        for module in self.target_modules:
            output_hook = OutputHook()
            features_hooks.append(output_hook)
            hook_handles.append(module.register_forward_hook(output_hook.hook))

        try:
            yield features_hooks
        finally:
            # Always detach hooks so repeated calls don't stack duplicates.
            for handle in hook_handles:
                handle.remove()