# Copyright 2022-2024 MetaOPT Team. All Rights Reserved.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Utilities for TorchOpt."""
from __future__ import annotations
import copy
import itertools
from typing import TYPE_CHECKING, Any, Literal, NamedTuple, Sequence, cast, overload
from typing_extensions import TypeAlias # Python 3.10+
import torch
import torch.nn as nn
from torchopt import pytree
from torchopt.typing import Device, ModuleTensorContainers, OptState, TensorContainer, TensorTree
from torchopt.optim.meta.base import MetaOptimizer
# Public API of this module.
__all__ = [
    'ModuleState',
    'stop_gradient',
    'extract_state_dict',
    'recover_state_dict',
    'module_clone',
    'module_detach_',
]
class ModuleState(NamedTuple):
    """Snapshot of the tensor containers that make up a module's state.

    Attributes:
        params: Tuple of parameter containers.
        buffers: Tuple of buffer containers.
        visual_contents: Optional extra mapping, ``None`` unless supplied. It is not read anywhere
            in this file; presumably it is filled when visualization is requested (TODO confirm).
        detach_buffers: When true, :func:`recover_state_dict` replaces the stored buffers with
            detached clones before restoring them.
    """

    params: tuple[TensorContainer, ...]  # parameter containers
    buffers: tuple[TensorContainer, ...]  # buffer containers
    visual_contents: dict | None = None  # optional extra data, defaults to None
    detach_buffers: bool = False  # whether buffers are detached on recovery
# Accepted values for the ``by`` argument of the cloning helpers. ``'ref'`` is normalized to
# ``'reference'``, while ``'clone'`` / ``'deepclone'`` are normalized to ``'copy'`` / ``'deepcopy'``
# (see the normalization at the top of ``module_clone``).
CopyMode: TypeAlias = Literal[
    'reference',
    'copy',
    'deepcopy',
    'ref',
    'clone',
    'deepclone',
]
def stop_gradient(target: ModuleState | nn.Module | MetaOptimizer | TensorTree) -> None:
    """Stop the gradient for the input object.

    Since a tensor uses ``grad_fn`` to connect itself with the previous computation graph, the
    backpropagated gradient will flow through the tensor and continue to flow to the tensors that
    are connected by ``grad_fn``. Some algorithms require manually detaching tensors from the
    computation graph.

    Note that the :func:`stop_gradient` operation is in-place.

    Args:
        target (ModuleState, nn.Module, MetaOptimizer, or tree of Tensor): The target to be
            detached from the computation graph. It could be a :class:`nn.Module`,
            :class:`torchopt.MetaOptimizer`, state of the :class:`torchopt.MetaOptimizer`, or just
            a plain list of tensors.
    """
    # pylint: disable-next=import-outside-toplevel
    from torchopt.optim.meta.base import MetaOptimizer

    def fn_(obj: Any) -> None:
        if isinstance(obj, torch.Tensor):
            # ``detach_`` clears ``requires_grad``; restore the original flag so trainable
            # tensors stay trainable after being cut from the graph.
            requires_grad = obj.requires_grad
            obj.detach_().requires_grad_(requires_grad)

    if isinstance(target, ModuleState):
        true_target = cast(TensorTree, (target.params, target.buffers))
    elif isinstance(target, nn.Module):
        true_target = cast(TensorTree, tuple(target.parameters()))
    elif isinstance(target, MetaOptimizer):
        true_target = cast(TensorTree, target.state_dict())
    else:
        true_target = cast(TensorTree, target)  # tree of tensors

    pytree.tree_map_(fn_, true_target)
# Typing-only overload stubs. NOTE(review): the runtime (non-``@overload``) implementation of
# ``extract_state_dict`` must be defined after these stubs; without it, calls reach typing's
# placeholder, which raises ``NotImplementedError``.
@overload
def extract_state_dict(  # pylint: disable=too-many-arguments
    target: nn.Module,
    by: CopyMode = 'reference',
    device: Device | None = None,
    with_buffers: bool = True,
    detach_buffers: bool = False,
    enable_visual: bool = False,
    visual_prefix: str = '',
) -> ModuleState:  # pragma: no cover
    ...


@overload
def extract_state_dict(
    target: MetaOptimizer,
    by: CopyMode = 'reference',
    device: Device | None = None,
) -> tuple[OptState, ...]:  # pragma: no cover
    ...
def extract_module_containers(
    module: nn.Module,
    with_buffers: bool = True,
) -> tuple[ModuleTensorContainers, ModuleTensorContainers]:
    """Extract the references to the containers of parameters and buffers from a module.

    Each module in the hierarchy (the root included) that owns at least one parameter contributes
    its ``_parameters`` dictionary exactly once. Likewise, each module that owns at least one
    buffer contributes its ``_buffers`` dictionary exactly once. The dictionaries are returned by
    reference, not copied.

    Args:
        module (nn.Module): The module to inspect.
        with_buffers (bool, optional): Whether to also collect buffer containers. When
            :data:`False`, the second element of the result is an empty tuple.
            (default: :data:`True`)

    Returns:
        A pair ``(params, buffers)`` of tuples of containers.

    Raises:
        RuntimeError: If ``module`` is not an :class:`nn.Module`.
    """
    if not isinstance(module, nn.Module):
        raise RuntimeError(f'Unexpected class of {module}')

    params: list[TensorContainer] = []
    buffers: list[TensorContainer] = []

    def update_container(container: list[TensorContainer], items: TensorContainer) -> None:
        if len(items) > 0:
            container.append(items)  # we need references to original dictionaries

    # ``nn.Module.modules()`` yields ``module`` itself first and every distinct submodule once
    # (shared submodules are deduplicated), so no container is collected twice.
    # pylint: disable=protected-access
    for submodule in module.modules():
        update_container(params, submodule._parameters)  # type: ignore[arg-type]
        if with_buffers:
            update_container(buffers, submodule._buffers)
    return tuple(params), tuple(buffers)
def recover_state_dict(
    target: nn.Module | MetaOptimizer,
    state: ModuleState | Sequence[OptState],
) -> None:
    """Recover state.

    This function is compatible with :func:`extract_state_dict`.

    Note that the recovering process is not in-place, so the tensors of the object will not be
    modified. Instead, the entries of the target's parameter and buffer containers are rebound
    to the tensors held in ``state``.

    Args:
        target (nn.Module or MetaOptimizer): Target that needs to be recovered.
        state (ModuleState or sequence of tree of Tensor): The recovering state.

    Raises:
        TypeError: If ``target`` is neither an :class:`nn.Module` nor a
            :class:`torchopt.MetaOptimizer`.
    """
    # pylint: disable-next=import-outside-toplevel
    from torchopt.optim.meta.base import MetaOptimizer

    if isinstance(target, nn.Module):
        params, buffers, *_ = state = cast(ModuleState, state)
        params_containers, buffers_containers = extract_module_containers(target, with_buffers=True)

        if state.detach_buffers:

            def clone_detach_(t: torch.Tensor) -> torch.Tensor:
                if isinstance(t, nn.Parameter):
                    return nn.Parameter(t.clone().detach_(), requires_grad=t.requires_grad)
                return t.clone().detach_().requires_grad_(t.requires_grad)

            buffers = pytree.tree_map(clone_detach_, buffers)  # type: ignore[assignment,arg-type]

        # Containers are paired positionally with the saved state. NOTE(review): ``zip`` stops at
        # the shorter side, so ``state`` is presumably expected to come from a module with the
        # same structure as ``target``.
        for tgt, src in itertools.chain(
            zip(params_containers, params),
            zip(buffers_containers, buffers),
        ):
            tgt.update(src)
    elif isinstance(target, MetaOptimizer):
        state = cast(Sequence[OptState], state)
        target.load_state_dict(state)
    else:
        raise TypeError(f'Unexpected class of {target}')
# Typing-only overload stubs; the runtime implementation follows below.
@overload
def module_clone(
    target: nn.Module,
    by: CopyMode = 'reference',
    detach_buffers: bool = False,
    device: Device | None = None,
) -> nn.Module:  # pragma: no cover
    ...


@overload
def module_clone(
    target: MetaOptimizer,
    by: CopyMode = 'reference',
    detach_buffers: bool = False,
    device: Device | None = None,
) -> MetaOptimizer:  # pragma: no cover
    ...


@overload
def module_clone(
    target: TensorTree,
    by: CopyMode = 'reference',
    detach_buffers: bool = False,
    device: Device | None = None,
) -> TensorTree:  # pragma: no cover
    ...
# pylint: disable-next=too-many-locals
def module_clone(  # noqa: C901
    target: nn.Module | MetaOptimizer | TensorTree,
    by: CopyMode = 'reference',
    detach_buffers: bool = False,
    device: Device | None = None,
) -> nn.Module | MetaOptimizer | TensorTree:
    """Clone a module.

    Args:
        target (nn.Module, MetaOptimizer, or tree of Tensor): The target to be cloned.
        by (str, optional): The extract policy of tensors in the target.
            (default: :const:`'reference'`)
            - :const:`'reference'` (or :const:`'ref'`): The extracted tensors will be references
              to the original tensors.
            - :const:`'copy'` (or :const:`'clone'`): The extracted tensors will be clones of the
              original tensors. This gives the copied tensors a ``<CloneBackward>`` ``grad_fn``
              that points to the original tensors.
            - :const:`'deepcopy'` (or :const:`'deepclone'`): The extracted tensors will be
              deep-copied from the original tensors. The deep-copied tensors are detached from
              the original computation graph.
        detach_buffers (bool, optional): Whether to detach the reference to the buffers. This
            argument is only used if the input target is :class:`nn.Module` and
            ``by='reference'``. (default: :const:`False`)
        device (Device or None, optional): If specified, move the cloned module to the specified
            device. (default: :const:`None`)

    Returns:
        The cloned module.

    Raises:
        AssertionError: If ``by`` is not one of the accepted values (checked with ``assert``).
    """
    assert by in ('reference', 'copy', 'deepcopy', 'ref', 'clone', 'deepclone')
    # Normalize aliases: 'clone' -> 'copy', 'deepclone' -> 'deepcopy', 'ref' -> 'reference'.
    by = by.replace('clone', 'copy')
    by = 'reference' if by == 'ref' else by
    if device is not None:
        device = torch.device(device)

    # pylint: disable-next=import-outside-toplevel
    from torchopt.optim.meta.base import MetaOptimizer

    if isinstance(target, (nn.Module, MetaOptimizer)):
        if isinstance(target, nn.Module):
            containers = cast(TensorTree, extract_module_containers(target, with_buffers=True))
        else:
            containers = cast(TensorTree, target.state_dict())
        tensors = pytree.tree_leaves(containers)
        # Seed deepcopy's memo with the tensors themselves so the structural copy shares them
        # rather than duplicating them; they are then replaced by ``recover_state_dict`` below.
        memo = {id(t): t for t in tensors}
        cloned = copy.deepcopy(target, memo=memo)
        if isinstance(target, nn.Module):
            state = extract_state_dict(  # type: ignore[call-overload]
                target,
                by=by,
                device=device,
                with_buffers=True,
                detach_buffers=detach_buffers,
            )
        else:
            state = extract_state_dict(  # type: ignore[call-overload]
                target,
                by=by,
                device=device,
            )
        recover_state_dict(cloned, state)
        return cloned

    # Tree of tensors
    def move(t: torch.Tensor) -> torch.Tensor:
        # Only relocate when a device was requested.
        return t if device is None else t.to(device=device)

    def clone(t: torch.Tensor) -> torch.Tensor:
        return move(t.clone())

    def clone_detach_(t: torch.Tensor) -> torch.Tensor:
        copied = move(t.clone()).detach_()
        if isinstance(t, nn.Parameter):
            return nn.Parameter(copied, requires_grad=t.requires_grad)
        return copied.requires_grad_(t.requires_grad)

    if by == 'reference':
        replicate = move
    elif by == 'copy':
        replicate = clone
    else:
        replicate = clone_detach_

    return pytree.tree_map(replicate, cast(TensorTree, target))
# Typing-only overload stubs; the runtime implementation follows below.
@overload
def module_detach_(target: ModuleState) -> ModuleState:  # pragma: no cover
    ...


@overload
def module_detach_(target: nn.Module) -> nn.Module:  # pragma: no cover
    ...


@overload
def module_detach_(target: MetaOptimizer) -> MetaOptimizer:  # pragma: no cover
    ...


@overload
def module_detach_(target: TensorTree) -> TensorTree:  # pragma: no cover
    ...
def module_detach_(
    target: ModuleState | nn.Module | MetaOptimizer | TensorTree,
) -> ModuleState | nn.Module | MetaOptimizer | TensorTree:
    """Detach a module from the computation graph.

    The detach is performed in-place by :func:`stop_gradient`. The same object is returned so
    the call can be chained.

    Args:
        target (ModuleState, nn.Module, MetaOptimizer, or tree of Tensor): The
            target to be detached.

    Returns:
        The detached module (the same object as ``target``).
    """
    stop_gradient(target)
    return target