Source code for torchopt.diff.implicit.nn.module
# Copyright 2022-2024 MetaOPT Team. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""The base class for differentiable implicit meta-gradient models."""
# pylint: disable=redefined-builtin
from __future__ import annotations
import abc
import functools
import inspect
import itertools
from typing import Any, Iterable
import functorch
import torch
from torchopt.diff.implicit.decorator import custom_root
from torchopt.nn.module import MetaGradientModule
from torchopt.nn.stateless import reparametrize, swap_state
from torchopt.typing import LinearSolver, TupleOfTensors
__all__ = ['ImplicitMetaGradientModule']
def _stateless_objective_fn(
flat_params: TupleOfTensors,
flat_meta_params: TupleOfTensors,
params_names: Iterable[str],
meta_params_names: Iterable[str],
self: ImplicitMetaGradientModule,
/,
*input: Any,
**kwargs: Any,
) -> torch.Tensor:
with reparametrize(
self,
itertools.chain(
zip(params_names, flat_params),
zip(meta_params_names, flat_meta_params),
),
):
return self.objective(*input, **kwargs)
def _stateless_optimality_fn(
flat_params: TupleOfTensors,
flat_meta_params: TupleOfTensors,
params_names: Iterable[str],
meta_params_names: Iterable[str],
self: ImplicitMetaGradientModule,
/,
*input: Any,
**kwargs: Any,
) -> TupleOfTensors:
with reparametrize(
self,
itertools.chain(
zip(params_names, flat_params),
zip(meta_params_names, flat_meta_params),
),
):
return self.optimality(*input, **kwargs)
def make_optimality_from_objective(
cls: type[ImplicitMetaGradientModule],
) -> type[ImplicitMetaGradientModule]:
"""Derive the optimality function of the objective function."""
static_super_objective = inspect.getattr_static(ImplicitMetaGradientModule, 'objective')
static_cls_objective = inspect.getattr_static(cls, 'objective', static_super_objective)
if static_cls_objective is static_super_objective:
raise TypeError('The objective function is not defined.')
def optimality(self: ImplicitMetaGradientModule, *input: Any, **kwargs: Any) -> TupleOfTensors:
named_params = tuple(self.named_parameters())
named_meta_params = tuple(self.named_meta_parameters())
if len(named_params) == 0:
raise RuntimeError('The module has no parameters.')
if len(named_meta_params) == 0:
raise RuntimeError('The module has no meta-parameters.')
params_names, flat_params = tuple(zip(*named_params))
meta_params_names, flat_meta_params = tuple(zip(*named_meta_params))
objective_grad_fn = functorch.grad(_stateless_objective_fn, argnums=0)
return objective_grad_fn(
flat_params,
flat_meta_params,
params_names,
meta_params_names,
self,
*input,
**kwargs,
)
cls.optimality = optimality # type: ignore[method-assign]
return cls
def enable_implicit_gradients(
cls: type[ImplicitMetaGradientModule],
) -> type[ImplicitMetaGradientModule]:
"""Enable implicit gradients for the :func:`solve` method."""
cls_solve = cls.solve
if getattr(cls_solve, '__implicit_gradients_enabled__', False):
raise TypeError('Implicit gradients are already enabled for the `solve` method.')
solve_kwargs = {'solve': cls.linear_solve} if cls.linear_solve is not None else {}
@custom_root(_stateless_optimality_fn, argnums=1, has_aux=True, **solve_kwargs)
def stateless_solver_fn(
# pylint: disable=unused-argument
flat_params: TupleOfTensors,
flat_meta_params: TupleOfTensors,
params_names: Iterable[str],
meta_params_names: Iterable[str],
# pylint: enable=unused-argument
self: ImplicitMetaGradientModule,
/,
*input: Any,
**kwargs: Any,
) -> tuple[TupleOfTensors, Any]:
"""Solve the optimization problem."""
output = cls_solve(self, *input, **kwargs)
flat_optimal_params = tuple(p.detach_() for p in self.parameters())
return flat_optimal_params, output
@functools.wraps(cls_solve)
def wrapped(self: ImplicitMetaGradientModule, *input: Any, **kwargs: Any) -> Any:
"""Solve the optimization problem."""
named_params = tuple(self.named_parameters())
named_meta_params = tuple(self.named_meta_parameters())
if len(named_params) == 0:
raise RuntimeError('The module has no parameters.')
if len(named_meta_params) == 0:
raise RuntimeError('The module has no meta-parameters.')
params_names, flat_params = tuple(zip(*named_params))
meta_params_names, flat_meta_params = tuple(zip(*named_meta_params))
flat_optimal_params, output = stateless_solver_fn(
flat_params,
flat_meta_params,
params_names,
meta_params_names,
self,
*input,
**kwargs,
)
swap_state(self, zip(params_names, flat_optimal_params))
return output
wrapped.__implicit_gradients_enabled__ = True # type: ignore[attr-defined]
cls.solve = wrapped # type: ignore[method-assign]
return cls
[docs]
class ImplicitMetaGradientModule(MetaGradientModule, metaclass=abc.ABCMeta):
"""The base class for differentiable implicit meta-gradient models."""
_custom_optimality: bool
_custom_objective: bool
linear_solve: LinearSolver | None
[docs]
def __init_subclass__(cls, linear_solve: LinearSolver | None = None) -> None:
"""Validate and initialize the subclass."""
super().__init_subclass__()
cls.linear_solve = linear_solve
static_super_optimality = inspect.getattr_static(ImplicitMetaGradientModule, 'optimality')
static_super_objective = inspect.getattr_static(ImplicitMetaGradientModule, 'objective')
static_cls_optimality = inspect.getattr_static(cls, 'optimality')
static_cls_objective = inspect.getattr_static(cls, 'objective')
cls._custom_optimality = static_cls_optimality is not static_super_optimality
cls._custom_objective = static_cls_objective is not static_super_objective
if cls._custom_optimality:
if isinstance(static_cls_optimality, staticmethod):
raise TypeError('method optimality() must not be a staticmethod.')
if isinstance(static_cls_optimality, classmethod):
raise TypeError('method optimality() must not be a classmethod.')
if not callable(static_cls_optimality):
raise TypeError('method optimality() must be callable.')
elif not cls._custom_objective:
raise TypeError(
'ImplicitMetaGradientModule requires either an optimality() method or an objective() method',
)
else:
if isinstance(static_cls_objective, staticmethod):
raise TypeError('method objective() must not be a staticmethod.')
if isinstance(static_cls_objective, classmethod):
raise TypeError('method objective() must not be a classmethod.')
if not callable(static_cls_objective):
raise TypeError('method objective() must be callable.')
make_optimality_from_objective(cls)
enable_implicit_gradients(cls)
[docs]
@abc.abstractmethod
def solve(self, *input: Any, **kwargs: Any) -> Any:
"""Solve the inner optimization problem.
.. warning::
For gradient-based optimization methods, the parameter inputs should be explicitly
specified in the :func:`torch.autograd.backward` function as argument ``inputs``.
Otherwise, if not provided, the gradient is accumulated into all the leaf Tensors
(including the meta-parameters) that were used to compute the objective output.
Alternatively, please use :func:`torch.autograd.grad` instead.
Examples:
.. code-block:: python
def solve(self, batch, labels):
parameters = tuple(self.parameters())
optimizer = torch.optim.Adam(parameters, lr=1e-3)
with torch.enable_grad():
for _ in range(100):
loss = self.objective(batch, labels)
optimizer.zero_grad()
# Only update the `.grad` attribute for parameters
# and leave the meta-parameters unchanged
loss.backward(inputs=parameters)
optimizer.step()
return self
"""
raise NotImplementedError # update parameters
[docs]
def optimality(self, *input: Any, **kwargs: Any) -> TupleOfTensors:
r"""Compute the optimality residual.
This method stands for the optimality residual to the optimal parameters after solving the
inner optimization problem (:meth:`solve`), i.e.:
.. code-block:: python
module.solve(*input, **kwargs)
module.optimality(*input, **kwargs) # -> 0
1. For gradient-based optimization, the :meth:`optimality` function is the KKT condition,
usually it is the gradients of the :meth:`objective` function with respect to the module
parameters (not the meta-parameters). If this method is not implemented, it will be
automatically derived from the gradient of the :meth:`objective` function.
.. math::
\text{optimality residual} = \nabla_{\boldsymbol{x}} f (\boldsymbol{x}, \boldsymbol{\theta}) \to \boldsymbol{0}
where :math:`\boldsymbol{x}` is the joint vector of the module parameters and
:math:`\boldsymbol{\theta}` is the joint vector of the meta-parameters.
References:
- Karush-Kuhn-Tucker (KKT) conditions: https://en.wikipedia.org/wiki/Karush-Kuhn-Tucker_conditions
2. For fixed point iteration, the :meth:`optimality` function can be the residual of the
parameters between iterations, i.e.:
.. math::
\text{optimality residual} = f (\boldsymbol{x}, \boldsymbol{\theta}) - \boldsymbol{x} \to \boldsymbol{0}
where :math:`\boldsymbol{x}` is the joint vector of the module parameters and
:math:`\boldsymbol{\theta}` is the joint vector of the meta-parameters.
Returns:
A tuple of tensors, the optimality residual to the optimal parameters after solving the
inner optimization problem. The returned tensors should correspond to the outputs of
`tuple(self.parameters())`.
""" # pylint: disable=line-too-long
raise NotImplementedError
[docs]
def objective(self, *input: Any, **kwargs: Any) -> torch.Tensor:
"""Compute the objective function value.
This method is used to calculate the :meth:`optimality` if it is not implemented.
Otherwise, this method is optional.
Returns:
A scalar tensor (``dim=0``), the objective function value.
"""
raise NotImplementedError