TorchOpt Optimizer

Optimizer(params, impl)

A base class for classic optimizers that is similar to torch.optim.Optimizer.

MetaOptimizer(module, impl)

The base class for high-level differentiable optimizers.

Optimizer

class torchopt.Optimizer(params, impl)[source]

Bases: object

A base class for classic optimizers that is similar to torch.optim.Optimizer.

The init() function.

Parameters:
  • params (iterable of torch.Tensor) – An iterable of torch.Tensors. Specifies what tensors should be optimized.

  • impl (GradientTransformation) – A low-level optimizer function; it could be an optimizer function provided by alias.py or a customized chain provided by combine.py. Note that using Optimizer(sgd()) or Optimizer(chain(sgd())) is equivalent to torchopt.SGD.

__init__(params, impl)[source]

The init() function.

Parameters:
  • params (iterable of torch.Tensor) – An iterable of torch.Tensors. Specifies what tensors should be optimized.

  • impl (GradientTransformation) – A low-level optimizer function; it could be an optimizer function provided by alias.py or a customized chain provided by combine.py. Note that using Optimizer(sgd()) or Optimizer(chain(sgd())) is equivalent to torchopt.SGD.

zero_grad(set_to_none=False)[source]

Sets the gradients of all optimized torch.Tensors to zero.

The behavior is similar to torch.optim.Optimizer.zero_grad().

Parameters:

set_to_none (bool) – Instead of setting to zero, set the grads to None.

Return type:

None

state_dict()[source]

Returns the state of the optimizer.

Return type:

Tuple[Union[Tensor, Tuple[TensorTree, ...], List[TensorTree], Dict[Any, TensorTree], Deque[TensorTree], CustomTreeNode[TensorTree]], ...]

load_state_dict(state_dict)[source]

Loads the optimizer state.

Parameters:

state_dict (Sequence[Union[Tensor, Tuple[TensorTree, ...], List[TensorTree], Dict[Any, TensorTree], Deque[TensorTree], CustomTreeNode[TensorTree]]]) – Optimizer state. Should be an object returned from a call to state_dict().

Return type:

None

step(closure=None)[source]

Performs a single optimization step.

The behavior is similar to torch.optim.Optimizer.step().

Parameters:

closure (callable, optional) – A closure that reevaluates the model and returns the loss.

Return type:

Optional[Tensor]

add_param_group(params)[source]

Add a param group to the optimizer’s param_groups.

Return type:

None
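
A minimal usage sketch (the toy model and data below are illustrative, not part of the API):

import torch
import torch.nn as nn
import torchopt

net = nn.Linear(4, 1)
# Equivalent to torchopt.SGD(net.parameters(), lr=0.1)
optimizer = torchopt.Optimizer(net.parameters(), torchopt.sgd(lr=0.1))

xs, ys = torch.randn(8, 4), torch.randn(8, 1)
for _ in range(10):
    loss = nn.functional.mse_loss(net(xs), ys)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()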

MetaOptimizer

class torchopt.MetaOptimizer(module, impl)[source]

Bases: object

The base class for high-level differentiable optimizers.

The init() function.

Parameters:
  • module (Module) – (nn.Module) A network whose parameters should be optimized.

  • impl (GradientTransformation) – A low-level optimizer function; it could be an optimizer function provided by alias.py or a customized chain provided by combine.py. Note that using MetaOptimizer(sgd(moment_requires_grad=True)) or MetaOptimizer(chain(sgd(moment_requires_grad=True))) is equivalent to torchopt.MetaSGD.

__init__(module, impl)[source]

The init() function.

Parameters:
  • module (Module) – (nn.Module) A network whose parameters should be optimized.

  • impl (GradientTransformation) – A low-level optimizer function; it could be an optimizer function provided by alias.py or a customized chain provided by combine.py. Note that using MetaOptimizer(sgd(moment_requires_grad=True)) or MetaOptimizer(chain(sgd(moment_requires_grad=True))) is equivalent to torchopt.MetaSGD.

step(loss)[source]

Compute the gradients of the loss with respect to the network parameters and update the network parameters.

The graph of the derivative will be constructed, allowing computation of higher-order derivative products. We use the differentiable optimizer (pass argument inplace=False) to scale the gradients and update the network parameters without modifying tensors in-place.

Parameters:

loss (Tensor) – (torch.Tensor) The loss that is used to compute the gradients with respect to the network parameters.

Return type:

None

add_param_group(module)[source]

Add a param group to the optimizer’s state_groups.

Return type:

None

state_dict()[source]

Extract the references of the optimizer states.

Note that the states are references, so any in-place operations will change the states inside MetaOptimizer at the same time.

Return type:

Tuple[Union[Tensor, Tuple[TensorTree, ...], List[TensorTree], Dict[Any, TensorTree], Deque[TensorTree], CustomTreeNode[TensorTree]], ...]

load_state_dict(state_dict)[source]

Load the references of the optimizer states.

Return type:

None
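
A minimal sketch of a differentiable inner loop (the toy network and meta-parameter are illustrative):

import torch
import torch.nn as nn
import torchopt

net = nn.Linear(4, 1)
meta_weight = torch.tensor(1.0, requires_grad=True)  # an illustrative meta-parameter
# Equivalent to torchopt.MetaOptimizer(net, torchopt.sgd(lr=0.1, moment_requires_grad=True))
inner_optimizer = torchopt.MetaSGD(net, lr=0.1)

xs, ys = torch.randn(8, 4), torch.randn(8, 1)
for _ in range(5):
    inner_loss = meta_weight * nn.functional.mse_loss(net(xs), ys)
    inner_optimizer.step(inner_loss)  # differentiable update; the graph is kept

outer_loss = nn.functional.mse_loss(net(xs), ys)
outer_loss.backward()  # gradients flow through the inner loop to meta_weight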


Functional Optimizers

FuncOptimizer(impl, *[, inplace])

A wrapper class to hold the functional optimizer.

adam([lr, betas, eps, weight_decay, ...])

The functional Adam optimizer.

sgd(lr[, momentum, dampening, weight_decay, ...])

The functional version of the canonical Stochastic Gradient Descent optimizer.

rmsprop([lr, alpha, eps, weight_decay, ...])

The functional version of the RMSProp optimizer.

adamw([lr, betas, eps, weight_decay, ...])

Adam with weight decay regularization.

Wrapper for Function Optimizer

class torchopt.FuncOptimizer(impl, *, inplace=False)[source]

Bases: object

A wrapper class to hold the functional optimizer.

This wrapper makes it easier to maintain the optimizer states. The optimizer states are held by the wrapper internally. The wrapper provides a step() function to compute the gradients and update the parameters.

See also

The init() function.

Parameters:
  • impl (GradientTransformation) – A low-level optimizer function; it could be an optimizer function provided by alias.py or a customized chain provided by combine.py.

  • inplace (optional) – (default: False) The default value of inplace for each optimization update.

__init__(impl, *, inplace=False)[source]

The init() function.

Parameters:
  • impl (GradientTransformation) – A low-level optimizer function; it could be an optimizer function provided by alias.py or a customized chain provided by combine.py.

  • inplace (optional) – (default: False) The default value of inplace for each optimization update.

step(loss, params, inplace=None)[source]

Compute the gradients of the loss with respect to the network parameters and update the network parameters.

The graph of the derivative will be constructed, allowing computation of higher-order derivative products. We use the differentiable optimizer (pass argument inplace=False) to scale the gradients and update the network parameters without modifying tensors in-place.

Parameters:
  • loss (Tensor) – (torch.Tensor) The loss that is used to compute the gradients with respect to the network parameters.

  • params (Union[Tensor, Tuple[TensorTree, ...], List[TensorTree], Dict[Any, TensorTree], Deque[TensorTree], CustomTreeNode[TensorTree]]) – (tree of torch.Tensor) A tree of torch.Tensors. Specifies what tensors should be optimized.

  • inplace (optional) – (default: None) Whether to update the parameters in-place. If None, use the default value specified in the constructor.

Return type:

Union[Tensor, Tuple[TensorTree, ...], List[TensorTree], Dict[Any, TensorTree], Deque[TensorTree], CustomTreeNode[TensorTree]]

state_dict()[source]

Extract the references of the optimizer states.

Note that the states are references, so any in-place operations will change the states inside FuncOptimizer at the same time.

Return type:

Union[Tensor, Tuple[TensorTree, ...], List[TensorTree], Dict[Any, TensorTree], Deque[TensorTree], CustomTreeNode[TensorTree]]

load_state_dict(state_dict)[source]

Load the references of the optimizer states.

Return type:

None
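
A minimal usage sketch (the toy objective is illustrative):

import torch
import torchopt

params = torch.zeros(4, requires_grad=True)
optimizer = torchopt.FuncOptimizer(torchopt.adam(lr=0.1))

target = torch.ones(4)
for _ in range(10):
    loss = torch.sum((params - target) ** 2)
    # The wrapper computes the gradients and keeps the optimizer state internally.
    params = optimizer.step(loss, params)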

Functional Adam Optimizer

torchopt.adam(lr=0.001, betas=(0.9, 0.999), eps=1e-08, weight_decay=0.0, *, eps_root=0.0, moment_requires_grad=False, maximize=False, use_accelerated_op=False)[source]

The functional Adam optimizer.

Adam is an SGD variant with learning rate adaptation. The learning rate used for each weight is computed from estimates of first- and second-order moments of the gradients (using suitable exponential moving averages).

Parameters:
  • lr (Union[float, Callable[[Union[Tensor, float, int, bool]], Union[Tensor, float, int, bool]]]) – (default: 1e-3) This is a fixed global scaling factor.

  • betas (Tuple[float, float]) – (default: (0.9, 0.999)) Coefficients used for computing running averages of gradient and its square.

  • eps (float) – (default: 1e-8) A small constant applied to denominator outside of the square root (as in the Adam paper) to avoid dividing by zero when rescaling.

  • weight_decay (float) – (default: 0.0) Weight decay, add L2 penalty to parameters.

  • eps_root (float) – (default: 0.0) A small constant applied to denominator inside the square root (as in RMSProp), to avoid dividing by zero when rescaling. This is needed for example when computing (meta-)gradients through Adam.

  • moment_requires_grad (bool) – (default: False) If True, the momentum terms will be created with flag requires_grad=True. This flag is often used in meta-learning algorithms.

  • maximize (bool) – (default: False) Maximize the params based on the objective, instead of minimizing.

  • use_accelerated_op (bool) – (default: False) If True use our implemented fused operator.

Return type:

GradientTransformation

Returns:

The corresponding GradientTransformation instance.

See also

The functional optimizer wrapper torchopt.FuncOptimizer.
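
A minimal sketch of the low-level functional API (the parameter tree and objective are illustrative):

import torch
import torchopt

optimizer = torchopt.adam(lr=0.1)  # a GradientTransformation
params = {'w': torch.zeros(3, requires_grad=True)}
opt_state = optimizer.init(params)  # initialize the moment accumulators

loss = torch.sum((params['w'] - 1.0) ** 2)
grads = {'w': torch.autograd.grad(loss, params['w'])[0]}
updates, opt_state = optimizer.update(grads, opt_state, params=params)
params = torchopt.apply_updates(params, updates)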

Functional AdamW Optimizer

torchopt.adamw(lr=0.001, betas=(0.9, 0.999), eps=1e-08, weight_decay=0.01, *, eps_root=0.0, mask=None, moment_requires_grad=False, maximize=False, use_accelerated_op=False)[source]

Adam with weight decay regularization.

AdamW uses weight decay to regularize learning towards small weights, as this leads to better generalization. In SGD you can also use L2 regularization to implement this as an additive loss term; however, L2 regularization does not behave as intended for adaptive gradient algorithms such as Adam.

References

  • Loshchilov and Hutter, 2019: Decoupled Weight Decay Regularization. https://arxiv.org/abs/1711.05101

Parameters:
  • lr (Union[float, Callable[[Union[Tensor, float, int, bool]], Union[Tensor, float, int, bool]]]) – (default: 1e-3) This is a fixed global scaling factor.

  • betas (Tuple[float, float]) – (default: (0.9, 0.999)) Coefficients used for computing running averages of gradient and its square.

  • eps (float) – (default: 1e-8) A small constant applied to denominator outside of the square root (as in the Adam paper) to avoid dividing by zero when rescaling.

  • weight_decay (float) – (default: 1e-2) Strength of the weight decay regularization. Note that this weight decay is multiplied with the learning rate. This is consistent with other frameworks such as PyTorch, but different from (Loshchilov et al, 2019) where the weight decay is only multiplied with the “schedule multiplier”, but not the base learning rate.

  • eps_root (float) – (default: 0.0) A small constant applied to denominator inside the square root (as in RMSProp), to avoid dividing by zero when rescaling. This is needed for example when computing (meta-)gradients through Adam.

  • mask (Optional[Union[Any, Callable[[Union[Tensor, Tuple[TensorTree, ...], List[TensorTree], Dict[Any, TensorTree], Deque[TensorTree], CustomTreeNode[TensorTree]]], Any]]]) – (default: None) A tree with the same structure as (or a prefix of) the params PyTree, or a Callable that returns such a pytree given the params/updates. The leaves should be booleans: True for leaves/subtrees you want to apply the weight decay to, and False for those you want to skip. Note that the Adam gradient transformations are applied to all parameters.

  • moment_requires_grad (bool) – (default: False) If True, the momentum terms will be created with flag requires_grad=True. This flag is often used in meta-learning algorithms.

  • maximize (bool) – (default: False) Maximize the params based on the objective, instead of minimizing.

  • use_accelerated_op (bool) – (default: False) If True use our implemented fused operator.

Return type:

GradientTransformation

Returns:

The corresponding GradientTransformation instance.

See also

The functional optimizer wrapper torchopt.FuncOptimizer.

Functional SGD Optimizer

torchopt.sgd(lr, momentum=0.0, dampening=0.0, weight_decay=0.0, nesterov=False, *, moment_requires_grad=False, maximize=False)[source]

The functional version of the canonical Stochastic Gradient Descent optimizer.

This implements stochastic gradient descent. It also includes support for momentum and Nesterov acceleration, as these are standard practice when using stochastic gradient descent to train deep neural networks.

Parameters:
  • lr (Union[float, Callable[[Union[Tensor, float, int, bool]], Union[Tensor, float, int, bool]]]) – This is a fixed global scaling factor.

  • momentum (float) – (default: 0.0) The decay rate used by the momentum term. The momentum is not used when it is set to 0.0.

  • weight_decay (float) – (default: 0.0) Weight decay, add L2 penalty to parameters.

  • dampening (float) – (default: 0.0) Dampening for momentum.

  • nesterov (bool) – (default: False) Whether to use Nesterov momentum.

  • moment_requires_grad (bool) – (default: False) If True, the momentum terms will be created with flag requires_grad=True. This flag is often used in meta-learning algorithms.

  • maximize (bool) – (default: False) Maximize the params based on the objective, instead of minimizing.

Return type:

GradientTransformation

Returns:

The corresponding GradientTransformation instance.

See also

The functional optimizer wrapper torchopt.FuncOptimizer.

Functional RMSProp Optimizer

torchopt.rmsprop(lr=0.01, alpha=0.99, eps=1e-08, weight_decay=0.0, momentum=0.0, centered=False, *, initial_scale=0.0, nesterov=False, maximize=False)[source]

The functional version of the RMSProp optimizer.

RMSProp is an SGD variant with learning rate adaptation. The learning rate used for each weight is scaled by a suitable estimate of the magnitude of the gradients on previous steps. Several variants of RMSProp can be found in the literature. This alias provides an easy-to-configure RMSProp optimizer that can be used to switch between several of these variants.

Parameters:
  • lr (Union[float, Callable[[Union[Tensor, float, int, bool]], Union[Tensor, float, int, bool]]]) – (default: 1e-2) This is a fixed global scaling factor.

  • alpha (float) – (default: 0.99) Smoothing constant, the decay used to track the magnitude of previous gradients.

  • eps (float) – (default: 1e-8) A small numerical constant to avoid dividing by zero when rescaling.

  • weight_decay (float) – (default: 0.0) Weight decay, add L2 penalty to parameters.

  • momentum (float) – (default: 0.0) The decay rate used by the momentum term. The momentum is not used when it is set to 0.0.

  • centered (bool) – (default: False) If True, use the variance of the past gradients to rescale the latest gradients.

  • initial_scale (float) – (default: 0.0) Initialization of accumulators tracking the magnitude of previous updates. PyTorch uses 0.0, TensorFlow 1.x uses 1.0. When reproducing results from a paper, verify the value used by the authors.

  • nesterov (bool) – (default: False) Whether to use Nesterov momentum.

  • maximize (bool) – (default: False) Maximize the params based on the objective, instead of minimizing.

Return type:

GradientTransformation

Returns:

The corresponding GradientTransformation instance.

See also

The functional optimizer wrapper torchopt.FuncOptimizer.


Classic Optimizers

Adam(params, lr[, betas, eps, weight_decay, ...])

The classic Adam optimizer.

SGD(params, lr[, momentum, weight_decay, ...])

The classic SGD optimizer.

RMSProp(params[, lr, alpha, eps, ...])

The classic RMSProp optimizer.

AdamW(params[, lr, betas, eps, ...])

The classic AdamW optimizer.

Classic Adam Optimizer

class torchopt.Adam(params, lr, betas=(0.9, 0.999), eps=1e-08, weight_decay=0.0, *, eps_root=0.0, maximize=False, use_accelerated_op=False)[source]

Bases: Optimizer

The classic Adam optimizer.

See also

The init() function.

Parameters:
  • params (Iterable[Tensor]) – (iterable of torch.Tensor) An iterable of torch.Tensors. Specifies what tensors should be optimized.

  • lr (Union[float, Callable[[Union[Tensor, float, int, bool]], Union[Tensor, float, int, bool]]]) – (default: 1e-3) This is a fixed global scaling factor.

  • betas (Tuple[float, float]) – (default: (0.9, 0.999)) Coefficients used for computing running averages of gradient and its square.

  • eps (float) – (default: 1e-8) A small constant applied to denominator outside of the square root (as in the Adam paper) to avoid dividing by zero when rescaling.

  • weight_decay (float) – (default: 0.0) Weight decay, add L2 penalty to parameters.

  • eps_root (float) – (default: 0.0) A small constant applied to denominator inside the square root (as in RMSProp), to avoid dividing by zero when rescaling. This is needed for example when computing (meta-)gradients through Adam.

  • maximize (bool) – (default: False) Maximize the params based on the objective, instead of minimizing.

  • use_accelerated_op (bool) – (default: False) If True use our implemented fused operator.
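
The classic optimizers follow the familiar torch.optim workflow; a minimal sketch (the toy model is illustrative):

import torch
import torch.nn as nn
import torchopt

net = nn.Linear(4, 1)
optimizer = torchopt.Adam(net.parameters(), lr=1e-3)

loss = net(torch.randn(8, 4)).pow(2).mean()
optimizer.zero_grad()
loss.backward()
optimizer.step()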

Classic AdamW Optimizer

class torchopt.AdamW(params, lr=0.001, betas=(0.9, 0.999), eps=1e-08, weight_decay=0.01, *, eps_root=0.0, mask=None, maximize=False, use_accelerated_op=False)[source]

Bases: Optimizer

The classic AdamW optimizer.

See also

The init() function.

Parameters:
  • params (Iterable[Tensor]) – (iterable of torch.Tensor) An iterable of torch.Tensors. Specifies what tensors should be optimized.

  • lr (Union[float, Callable[[Union[Tensor, float, int, bool]], Union[Tensor, float, int, bool]]]) – (default: 1e-3) This is a fixed global scaling factor.

  • betas (Tuple[float, float]) – (default: (0.9, 0.999)) Coefficients used for computing running averages of gradient and its square.

  • eps (float) – (default: 1e-8) A small constant applied to denominator outside of the square root (as in the Adam paper) to avoid dividing by zero when rescaling.

  • weight_decay (float) – (default: 1e-2) Strength of the weight decay regularization. Note that this weight decay is multiplied with the learning rate. This is consistent with other frameworks such as PyTorch, but different from (Loshchilov et al, 2019) where the weight decay is only multiplied with the “schedule multiplier”, but not the base learning rate.

  • eps_root (float) – (default: 0.0) A small constant applied to denominator inside the square root (as in RMSProp), to avoid dividing by zero when rescaling. This is needed for example when computing (meta-)gradients through Adam.

  • mask (Optional[Union[Any, Callable[[Union[Tensor, Tuple[TensorTree, ...], List[TensorTree], Dict[Any, TensorTree], Deque[TensorTree], CustomTreeNode[TensorTree]]], Any]]]) – (default: None) A tree with the same structure as (or a prefix of) the params PyTree, or a Callable that returns such a pytree given the params/updates. The leaves should be booleans: True for leaves/subtrees you want to apply the weight decay to, and False for those you want to skip. Note that the Adam gradient transformations are applied to all parameters.

  • maximize (bool) – (default: False) Maximize the params based on the objective, instead of minimizing.

  • use_accelerated_op (bool) – (default: False) If True use our implemented fused operator.

Classic SGD Optimizer

class torchopt.SGD(params, lr, momentum=0.0, weight_decay=0.0, dampening=0.0, nesterov=False, maximize=False)[source]

Bases: Optimizer

The classic SGD optimizer.

See also

The init() function.

Parameters:
  • params (Iterable[Tensor]) – (iterable of torch.Tensor) An iterable of torch.Tensors. Specifies what tensors should be optimized.

  • lr (Union[float, Callable[[Union[Tensor, float, int, bool]], Union[Tensor, float, int, bool]]]) – This is a fixed global scaling factor.

  • momentum (float) – (default: 0.0) The decay rate used by the momentum term. The momentum is not used when it is set to 0.0.

  • weight_decay (float) – (default: 0.0) Weight decay, add L2 penalty to parameters.

  • dampening (float) – (default: 0.0) Dampening for momentum.

  • nesterov (bool) – (default: False) Whether to use Nesterov momentum.

  • maximize (bool) – (default: False) Maximize the params based on the objective, instead of minimizing.

Classic RMSProp Optimizer

class torchopt.RMSProp(params, lr=0.01, alpha=0.99, eps=1e-08, weight_decay=0.0, momentum=0.0, centered=False, *, initial_scale=0.0, nesterov=False, maximize=False)[source]

Bases: Optimizer

The classic RMSProp optimizer.

See also

The init() function.

Parameters:
  • params (Iterable[Tensor]) – (iterable of torch.Tensor) An iterable of torch.Tensors. Specifies what Tensors should be optimized.

  • lr (Union[float, Callable[[Union[Tensor, float, int, bool]], Union[Tensor, float, int, bool]]]) – (default: 1e-2) This is a fixed global scaling factor.

  • alpha (float) – (default: 0.99) Smoothing constant, the decay used to track the magnitude of previous gradients.

  • eps (float) – (default: 1e-8) A small numerical constant to avoid dividing by zero when rescaling.

  • weight_decay (float) – (default: 0.0) Weight decay, add L2 penalty to parameters.

  • momentum (float) – (default: 0.0) The decay rate used by the momentum term. The momentum is not used when it is set to 0.0.

  • centered (bool) – (default: False) If True, use the variance of the past gradients to rescale the latest gradients.

  • initial_scale (float) – (default: 0.0) Initialization of accumulators tracking the magnitude of previous updates. PyTorch uses 0.0, TensorFlow 1.x uses 1.0. When reproducing results from a paper, verify the value used by the authors.

  • nesterov (bool) – (default: False) Whether to use Nesterov momentum.

  • maximize (bool) – (default: False) Maximize the params based on the objective, instead of minimizing.


Differentiable Meta-Optimizers

MetaAdam(module[, lr, betas, eps, ...])

The differentiable Adam optimizer.

MetaSGD(module, lr[, momentum, ...])

The differentiable Stochastic Gradient Descent optimizer.

MetaRMSProp(module[, lr, alpha, eps, ...])

The differentiable RMSProp optimizer.

MetaAdamW(module[, lr, betas, eps, ...])

The differentiable AdamW optimizer.

Differentiable Meta-Adam Optimizer

class torchopt.MetaAdam(module, lr=0.001, betas=(0.9, 0.999), eps=1e-08, weight_decay=0.0, *, eps_root=0.0, moment_requires_grad=True, maximize=False, use_accelerated_op=False)[source]

Bases: MetaOptimizer

The differentiable Adam optimizer.

See also

The init() function.

Parameters:
  • module (Module) – (nn.Module) A network whose parameters should be optimized.

  • lr (Union[float, Callable[[Union[Tensor, float, int, bool]], Union[Tensor, float, int, bool]]]) – (default: 1e-3) This is a fixed global scaling factor.

  • betas (Tuple[float, float]) – (default: (0.9, 0.999)) Coefficients used for computing running averages of gradient and its square.

  • eps (float) – (default: 1e-8) A small constant applied to denominator outside of the square root (as in the Adam paper) to avoid dividing by zero when rescaling.

  • weight_decay (float) – (default: 0.0) Weight decay, add L2 penalty to parameters.

  • eps_root (float) – (default: 0.0) A small constant applied to denominator inside the square root (as in RMSProp), to avoid dividing by zero when rescaling. This is needed for example when computing (meta-)gradients through Adam.

  • moment_requires_grad (bool) – (default: True) If True, the momentum terms will be created with flag requires_grad=True. This flag is often used in meta-learning algorithms.

  • maximize (bool) – (default: False) Maximize the params based on the objective, instead of minimizing.

  • use_accelerated_op (bool) – (default: False) If True use our implemented fused operator.

Differentiable Meta-AdamW Optimizer

class torchopt.MetaAdamW(module, lr=0.001, betas=(0.9, 0.999), eps=1e-08, weight_decay=0.01, *, eps_root=0.0, mask=None, moment_requires_grad=False, maximize=False, use_accelerated_op=False)[source]

Bases: MetaOptimizer

The differentiable AdamW optimizer.

See also

The init() function.

Parameters:
  • module (Module) – (nn.Module) A network whose parameters should be optimized.

  • lr (Union[float, Callable[[Union[Tensor, float, int, bool]], Union[Tensor, float, int, bool]]]) – (default: 1e-3) This is a fixed global scaling factor.

  • betas (Tuple[float, float]) – (default: (0.9, 0.999)) Coefficients used for computing running averages of gradient and its square.

  • eps (float) – (default: 1e-8) A small constant applied to denominator outside of the square root (as in the Adam paper) to avoid dividing by zero when rescaling.

  • weight_decay (float) – (default: 1e-2) Strength of the weight decay regularization. Note that this weight decay is multiplied with the learning rate. This is consistent with other frameworks such as PyTorch, but different from (Loshchilov et al, 2019) where the weight decay is only multiplied with the “schedule multiplier”, but not the base learning rate.

  • eps_root (float) – (default: 0.0) A small constant applied to denominator inside the square root (as in RMSProp), to avoid dividing by zero when rescaling. This is needed for example when computing (meta-)gradients through Adam.

  • mask (Optional[Union[Any, Callable[[Union[Tensor, Tuple[TensorTree, ...], List[TensorTree], Dict[Any, TensorTree], Deque[TensorTree], CustomTreeNode[TensorTree]]], Any]]]) – (default: None) A tree with the same structure as (or a prefix of) the params PyTree, or a Callable that returns such a pytree given the params/updates. The leaves should be booleans: True for leaves/subtrees you want to apply the weight decay to, and False for those you want to skip. Note that the Adam gradient transformations are applied to all parameters.

  • moment_requires_grad (bool) – (default: False) If True, the momentum terms will be created with flag requires_grad=True. This flag is often used in meta-learning algorithms.

  • maximize (bool) – (default: False) Maximize the params based on the objective, instead of minimizing.

  • use_accelerated_op (bool) – (default: False) If True use our implemented fused operator.

Differentiable Meta-SGD Optimizer

class torchopt.MetaSGD(module, lr, momentum=0.0, weight_decay=0.0, dampening=0.0, nesterov=False, moment_requires_grad=True, maximize=False)[source]

Bases: MetaOptimizer

The differentiable Stochastic Gradient Descent optimizer.

See also

The init() function.

Parameters:
  • module (Module) – (nn.Module) A network whose parameters should be optimized.

  • lr (Union[float, Callable[[Union[Tensor, float, int, bool]], Union[Tensor, float, int, bool]]]) – This is a fixed global scaling factor.

  • momentum (float) – (default: 0.0) The decay rate used by the momentum term. The momentum is not used when it is set to 0.0.

  • weight_decay (float) – (default: 0.0) Weight decay, add L2 penalty to parameters.

  • dampening (float) – (default: 0.0) Dampening for momentum.

  • nesterov (bool) – (default: False) Whether to use Nesterov momentum.

  • moment_requires_grad (bool) – (default: True) If True, the momentum terms will be created with flag requires_grad=True. This flag is often used in meta-learning algorithms.

  • maximize (bool) – (default: False) Maximize the params based on the objective, instead of minimizing.

Differentiable Meta-RMSProp Optimizer

class torchopt.MetaRMSProp(module, lr=0.01, alpha=0.99, eps=1e-08, weight_decay=0.0, momentum=0.0, centered=False, *, initial_scale=0.0, nesterov=False, maximize=False)[source]

Bases: MetaOptimizer

The differentiable RMSProp optimizer.

See also

The init() function.

Parameters:
  • module (Module) – (nn.Module) A network whose parameters should be optimized.

  • lr (Union[float, Callable[[Union[Tensor, float, int, bool]], Union[Tensor, float, int, bool]]]) – (default: 1e-2) This is a fixed global scaling factor.

  • alpha (float) – (default: 0.99) Smoothing constant, the decay used to track the magnitude of previous gradients.

  • eps (float) – (default: 1e-8) A small numerical constant to avoid dividing by zero when rescaling.

  • weight_decay (float) – (default: 0.0) Weight decay, add L2 penalty to parameters.

  • momentum (float) – (default: 0.0) The decay rate used by the momentum term. The momentum is not used when it is set to 0.0.

  • centered (bool) – (default: False) If True, use the variance of the past gradients to rescale the latest gradients.

  • initial_scale (float) – (default: 0.0) Initialization of accumulators tracking the magnitude of previous updates. PyTorch uses 0.0, TensorFlow 1.x uses 1.0. When reproducing results from a paper, verify the value used by the authors.

  • nesterov (bool) – (default: False) Whether to use Nesterov momentum.

  • maximize (bool) – (default: False) Maximize the params based on the objective, instead of minimizing.


Implicit differentiation

custom_root(optimality_fn, argnums[, ...])

Decorator for adding implicit differentiation to a root solver.

nn.ImplicitMetaGradientModule()

The base class for differentiable implicit meta-gradient models.

Custom solvers

torchopt.diff.implicit.custom_root(optimality_fn, argnums, has_aux=False, solve=linear_solve.solve_normal_cg())[source]

Decorator for adding implicit differentiation to a root solver.

This wrapper should be used as a decorator:

def optimality_fn(optimal_params, ...):
    ...
    return residual

@custom_root(optimality_fn, argnums=argnums)
def solver_fn(params, arg1, arg2, ...):
    ...
    return optimal_params

optimal_params = solver_fn(init_params, ...)

The first argument to optimality_fn and solver_fn is preserved as the parameter input. The argnums argument refers to the indices of the variables in solver_fn’s signature. For example, setting argnums=(1, 2) will compute the gradient of optimal_params with respect to arg1 and arg2 in the above snippet. Note that, except for the first argument, the keyword arguments of optimality_fn should be a subset of those of solver_fn. As a best practice, optimality_fn should have the same signature as solver_fn.

Parameters:
  • optimality_fn (Callable[..., Union[Tensor, Sequence[Tensor]]]) – (callable) An equation function, optimality_fn(params, *args). The invariant is optimality_fn(solution, *args) == 0 at the solution / root.

  • argnums (Union[int, Tuple[int, ...]]) – (int or tuple of ints) Specifies the arguments to compute gradients with respect to. It can be an integer or a tuple of integers, which refer to the zero-based indices of the arguments of the solver_fn(params, *args) function. The argument params is included in the count and is indexed as argnums=0.

  • has_aux (bool) – (default: False) Whether the decorated solver function returns auxiliary data.

  • solve (Callable[..., Union[Tensor, Sequence[Tensor]]]) – (callable, optional, default: linear_solve.solve_normal_cg()) A linear solver of the form solve(matvec, b).

Return type:

Callable[[Callable[..., Union[Tensor, Sequence[Tensor], Tuple[Union[Tensor, Sequence[Tensor]], Any]]]], Callable[..., Union[Tensor, Sequence[Tensor], Tuple[Union[Tensor, Sequence[Tensor]], Any]]]]

Returns:

A solver function decorator, i.e., custom_root(optimality_fn)(solver_fn).
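
A minimal end-to-end sketch under the pattern above (the quadratic inner problem is illustrative): the inner solution of f(x, theta) = (x - theta)^2 is x* = theta, so the implicit gradient of x* with respect to theta is 1.

import torch
from torchopt.diff.implicit import custom_root

def optimality_fn(x, theta):
    # Stationarity condition of f(x, theta) = (x - theta) ** 2
    return 2 * (x - theta)

@custom_root(optimality_fn, argnums=1)
def solver_fn(x_init, theta):
    # Closed-form inner solution, computed outside the autograd graph
    return theta.detach().clone()

theta = torch.tensor(3.0, requires_grad=True)
x_star = solver_fn(torch.zeros(()), theta)
x_star.backward()  # theta.grad == 1.0 via implicit differentiation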

Implicit Meta-Gradient Module

class torchopt.diff.implicit.nn.ImplicitMetaGradientModule[source]

Bases: MetaGradientModule

The base class for differentiable implicit meta-gradient models.

Initializes internal Module state, shared by both nn.Module and ScriptModule.

linear_solve: Optional[Callable[[Callable[[Union[Tensor, Tuple[TensorTree, ...], List[TensorTree], Dict[Any, TensorTree], Deque[TensorTree], CustomTreeNode[TensorTree]]], Union[Tensor, Tuple[TensorTree, ...], List[TensorTree], Dict[Any, TensorTree], Deque[TensorTree], CustomTreeNode[TensorTree]]], Union[Tensor, Tuple[TensorTree, ...], List[TensorTree], Dict[Any, TensorTree], Deque[TensorTree], CustomTreeNode[TensorTree]]], Union[Tensor, Tuple[TensorTree, ...], List[TensorTree], Dict[Any, TensorTree], Deque[TensorTree], CustomTreeNode[TensorTree]]]]
classmethod __init_subclass__(linear_solve=None)[source]

Validates and initializes the subclass.

Return type:

None

solve(*input, **kwargs)[source]

Solves the inner optimization problem.

Warning

For gradient-based optimization methods, the parameter inputs should be explicitly specified in the torch.autograd.backward() function as the argument inputs. If not provided, the gradient will be accumulated into all the leaf tensors (including the meta-parameters) that were used to compute the objective output. Alternatively, please use torch.autograd.grad() instead.

Example:

def solve(self, batch, labels):
    parameters = tuple(self.parameters())
    optimizer = torch.optim.Adam(parameters, lr=1e-3)
    with torch.enable_grad():
        for _ in range(100):
            loss = self.objective(batch, labels)
            optimizer.zero_grad()
            # Only update the `.grad` attribute for parameters
            # and leave the meta-parameters unchanged
            loss.backward(inputs=parameters)
            optimizer.step()
    return self

Return type:

Any

optimality(*input, **kwargs)[source]

Computes the optimality residual.

This method computes the optimality residual at the optimal parameters after solving the inner optimization problem (solve()), i.e.:

module.solve(*input, **kwargs)
module.optimality(*input, **kwargs)  # -> 0

1. For gradient-based optimization, the optimality() function corresponds to the KKT condition; usually it is the gradient of the objective() function with respect to the module parameters (not the meta-parameters). If this method is not implemented, it will be automatically derived from the gradient of the objective() function.

\[\text{optimality residual} = \nabla_{\boldsymbol{x}} f (\boldsymbol{x}, \boldsymbol{\theta}) \to \boldsymbol{0}\]

where \(\boldsymbol{x}\) is the joint vector of the module parameters and \(\boldsymbol{\theta}\) is the joint vector of the meta-parameters.

2. For fixed point iteration, the optimality() function can be the residual of the parameters between iterations, i.e.:

\[\text{optimality residual} = f (\boldsymbol{x}, \boldsymbol{\theta}) - \boldsymbol{x} \to \boldsymbol{0}\]

where \(\boldsymbol{x}\) is the joint vector of the module parameters and \(\boldsymbol{\theta}\) is the joint vector of the meta-parameters.

Return type:

Union[Tensor, Tuple[Union[Tensor, Tuple[TensorTree, ...], List[TensorTree], Dict[Any, TensorTree], Deque[TensorTree], CustomTreeNode[TensorTree]], ...], List[Union[Tensor, Tuple[TensorTree, ...], List[TensorTree], Dict[Any, TensorTree], Deque[TensorTree], CustomTreeNode[TensorTree]]], Dict[Any, Union[Tensor, Tuple[TensorTree, ...], List[TensorTree], Dict[Any, TensorTree], Deque[TensorTree], CustomTreeNode[TensorTree]]], Deque[Union[Tensor, Tuple[TensorTree, ...], List[TensorTree], Dict[Any, TensorTree], Deque[TensorTree], CustomTreeNode[TensorTree]]], CustomTreeNode[Union[Tensor, Tuple[TensorTree, ...], List[TensorTree], Dict[Any, TensorTree], Deque[TensorTree], CustomTreeNode[TensorTree]]]]

Returns:

A tree of tensors, the optimality residual at the optimal parameters after solving the inner optimization problem.

objective(*input, **kwargs)[source]

Computes the objective function value.

This method is used to calculate the optimality() if it is not implemented. Otherwise, this method is optional.

Return type:

Tensor

Returns:

A scalar tensor (dim=0), the objective function value.
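
A minimal subclass sketch (the inner problem is illustrative; it assumes that assigning a tensor with requires_grad=True in __init__ registers it as a meta-parameter, following the MetaGradientModule conventions):

import torch
import torch.nn as nn
from torchopt.diff.implicit.nn import ImplicitMetaGradientModule

class InnerSolver(ImplicitMetaGradientModule):
    def __init__(self, meta_param):
        super().__init__()
        self.meta_param = meta_param  # meta-parameter (assumed auto-registered)
        self.x = nn.Parameter(torch.zeros_like(meta_param))  # inner parameter

    def objective(self):
        # Inner problem: argmin_x ||x - meta_param||^2, whose solution is x* = meta_param.
        # optimality() is derived automatically from this objective.
        return torch.sum((self.x - self.meta_param) ** 2)

    def solve(self):
        params = tuple(self.parameters())
        optimizer = torch.optim.SGD(params, lr=0.5)
        with torch.enable_grad():
            for _ in range(50):
                loss = self.objective()
                optimizer.zero_grad()
                loss.backward(inputs=params)
                optimizer.step()
        return self

meta_param = torch.ones(3, requires_grad=True)
inner = InnerSolver(meta_param)
inner.solve()
torch.sum(inner.x).backward()  # meta_param.grad via implicit differentiation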

Linear system solvers

solve_cg(**kwargs)

A wrapper that returns a solver function to solve A x = b using conjugate gradient.

solve_normal_cg(**kwargs)

A wrapper that returns a solver function to solve A^T A x = A^T b using conjugate gradient.

solve_inv(**kwargs)

A wrapper that returns a solver function to solve A x = b using matrix inversion.

Indirect solvers

torchopt.linear_solve.solve_cg(**kwargs)[source]

A wrapper that returns a solver function to solve A x = b using conjugate gradient.

This assumes that A is a Hermitian, positive definite matrix.

Parameters:
  • ridge – Optional ridge regularization. Solves the equation for (A + ridge * I) @ x = b.

  • init – Optional initialization to be used by conjugate gradient.

  • **kwargs – Additional keyword arguments for the conjugate gradient solver torchopt.linalg.cg().

Returns:

A solver function with signature (matvec, b) -> x that solves A x = b using conjugate gradient where matvec(v) = A v.

See also

Conjugate gradient iteration torchopt.linalg.cg().

Example:

>>> A = {'a': torch.eye(5, 5), 'b': torch.eye(3, 3)}
>>> x = {'a': torch.randn(5), 'b': torch.randn(3)}
>>> def matvec(x: TensorTree) -> TensorTree:
...     return {'a': A['a'] @ x['a'], 'b': A['b'] @ x['b']}
>>> b = matvec(x)
>>> solver = solve_cg(init={'a': torch.zeros(5), 'b': torch.zeros(3)})
>>> x_hat = solver(matvec, b)
>>> assert torch.allclose(x_hat['a'], x['a']) and torch.allclose(x_hat['b'], x['b'])
torchopt.linear_solve.solve_normal_cg(**kwargs)[source]

A wrapper that returns a solver function to solve A^T A x = A^T b using conjugate gradient.

This can be used to solve A x = b using conjugate gradient when A is not Hermitian positive definite.

Parameters:
  • ridge – Optional ridge regularization. Solves the equation for (A.T @ A + ridge * I) @ x = A.T @ b.

  • init – Optional initialization to be used by normal conjugate gradient.

  • **kwargs – Additional keyword arguments for the conjugate gradient solver torchopt.linalg.cg().

Returns:

A solver function with signature (matvec, b) -> x that solves A^T A x = A^T b using conjugate gradient where matvec(v) = A v.

See also

Conjugate gradient iteration torchopt.linalg.cg().

Example:

>>> A = {'a': torch.randn(5, 5), 'b': torch.randn(3, 3)}
>>> x = {'a': torch.randn(5), 'b': torch.randn(3)}
>>> def matvec(x: TensorTree) -> TensorTree:
...     return {'a': A['a'] @ x['a'], 'b': A['b'] @ x['b']}
>>> b = matvec(x)
>>> solver = solve_normal_cg(init={'a': torch.zeros(5), 'b': torch.zeros(3)})
>>> x_hat = solver(matvec, b)
>>> assert torch.allclose(x_hat['a'], x['a']) and torch.allclose(x_hat['b'], x['b'])
torchopt.linear_solve.solve_inv(**kwargs)[source]

A wrapper that returns a solver function to solve A x = b using matrix inversion.

If ns = False, this assumes the matrix A is a constant matrix and will materialize it in memory.

Parameters:
  • ridge – Optional ridge regularization. Solves the equation for (A + ridge * I) @ x = b.

  • ns – Whether to use Neumann Series matrix inversion approximation. If False, materialize the matrix A in memory and use torch.linalg.solve() instead.

  • **kwargs – Additional keyword arguments for the Neumann Series matrix inversion approximation solver torchopt.linalg.ns().

Returns:

A solver function with signature (matvec, b) -> x that solves A x = b using matrix inversion where matvec(v) = A v.

See also

Neumann Series matrix inversion approximation torchopt.linalg.ns().

Example:

>>> A = {'a': torch.eye(5, 5), 'b': torch.eye(3, 3)}
>>> x = {'a': torch.randn(5), 'b': torch.randn(3)}
>>> def matvec(x: TensorTree) -> TensorTree:
...     return {'a': A['a'] @ x['a'], 'b': A['b'] @ x['b']}
>>> b = matvec(x)
>>> solver = solve_inv(ns=True, maxiter=10)
>>> x_hat = solver(matvec, b)
>>> assert torch.allclose(x_hat['a'], x['a']) and torch.allclose(x_hat['b'], x['b'])

Optimizer Hooks

register_hook(hook)

Stateless identity transformation that leaves input gradients untouched.

zero_nan_hook(g)

A zero-NaN hook that replaces NaN with zero.

nan_to_num_hook([nan, posinf, neginf])

Returns a nan-to-num hook that replaces nan / +inf / -inf with the given numbers.

Hook

torchopt.hook.register_hook(hook)[source]

Stateless identity transformation that leaves input gradients untouched.

This function passes through the gradient updates unchanged.

Return type:

GradientTransformation

Returns:

An (init_fn, update_fn) tuple.

torchopt.hook.zero_nan_hook(g)[source]

A zero-NaN hook that replaces NaN with zero.

Return type:

Tensor

torchopt.hook.nan_to_num_hook(nan=0.0, posinf=None, neginf=None)[source]

Returns a nan-to-num hook that replaces nan / +inf / -inf with the given numbers.

Return type:

Callable[[Tensor], Tensor]
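
A minimal sketch that chains a hook into an optimizer (the toy model is illustrative):

import torch
import torch.nn as nn
import torchopt
from torchopt import hook

net = nn.Linear(4, 1)
impl = torchopt.combine.chain(
    hook.register_hook(hook.zero_nan_hook),  # replace NaN gradients with zeros during backward
    torchopt.adam(lr=1e-3),
)
optimizer = torchopt.Optimizer(net.parameters(), impl)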


Gradient Transformation

clip_grad_norm(max_norm[, norm_type, ...])

Clips gradient norm of an iterable of parameters.

nan_to_num([nan, posinf, neginf])

Replaces nan / +inf / -inf values in the updates with the given numbers.

Transforms

torchopt.clip_grad_norm(max_norm, norm_type=2.0, error_if_nonfinite=False)[source]

Clips gradient norm of an iterable of parameters.

Parameters:
  • max_norm (float or int) – The maximum allowed norm of the updates.

  • norm_type (float or int) – Type of the used p-norm. Can be 'inf' for infinity norm.

  • error_if_nonfinite (bool) – If True, an error is thrown if the total norm of the gradients from updates is nan, inf, or -inf.

Return type:

GradientTransformation

Returns:

An (init_fn, update_fn) tuple.

torchopt.nan_to_num(nan=0.0, posinf=None, neginf=None)[source]

Replaces nan / +inf / -inf values in the updates with the given numbers.

Return type:

GradientTransformation

Returns:

An (init_fn, update_fn) tuple.
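
A minimal sketch combining these transformations with an optimizer via torchopt.combine.chain() (the values are illustrative):

import torchopt

impl = torchopt.combine.chain(
    torchopt.nan_to_num(nan=0.0),           # sanitize the updates first
    torchopt.clip_grad_norm(max_norm=1.0),  # then clip them by global norm
    torchopt.adam(lr=1e-3),                 # finally scale with Adam
)
# impl can be passed to torchopt.Optimizer(params, impl) or used functionally.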

Optimizer Schedules

linear_schedule(init_value, end_value, ...)

Alias of polynomial_schedule() with power=1 (a linear transition), provided for convenience.

polynomial_schedule(init_value, end_value, ...)

Constructs a schedule with polynomial transition from init to end value.

Schedules

torchopt.schedule.linear_schedule(init_value, end_value, transition_steps, transition_begin=0)[source]

Alias of polynomial_schedule() with power=1 (a linear transition), provided for convenience.

Return type:

Callable[[Union[Tensor, float, int, bool]], Union[Tensor, float, int, bool]]

torchopt.schedule.polynomial_schedule(init_value, end_value, power, transition_steps, transition_begin=0)[source]

Constructs a schedule with polynomial transition from init to end value.

Parameters:
  • init_value (Union[float, int, bool]) – Initial value for the scalar to be annealed.

  • end_value (Union[float, int, bool]) – End value of the scalar to be annealed.

  • power (Union[float, int, bool]) – The power of the polynomial used to transition from init to end.

  • transition_steps (int) – Number of steps over which annealing takes place. The scalar starts changing at transition_begin steps and completes the transition by transition_begin + transition_steps steps. If transition_steps <= 0, the entire annealing process is disabled and the value is held fixed at init_value.

  • transition_begin (int) – Must be positive. After how many steps to start annealing (before this many steps the scalar value is held fixed at init_value).

Returns:

A function that maps step counts to values.

Return type:

schedule
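
Since the lr arguments of the functional optimizers accept a callable schedule, a minimal sketch (the values are illustrative):

import torchopt

# Decay linearly from 1e-2 to 1e-4 over 1000 steps, starting at step 0.
lr_schedule = torchopt.schedule.linear_schedule(
    init_value=1e-2, end_value=1e-4, transition_steps=1000, transition_begin=0
)
impl = torchopt.sgd(lr=lr_schedule)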

Apply Parameter Updates

apply_updates(params, updates, *[, inplace])

Applies an update to the corresponding parameters.

Apply Updates

torchopt.apply_updates(params, updates, *, inplace=True)[source]

Applies an update to the corresponding parameters.

This is a utility function that applies an update to a set of parameters and then returns the updated parameters to the caller. As an example, the update may be a gradient transformed by a sequence of GradientTransformations. This function is exposed for convenience, but it just adds updates and parameters; you may also apply updates to parameters manually, using tree_map() (e.g., if you want to manipulate updates in custom ways before applying them).

Parameters:
  • params (TensorTree) – A tree of parameters.

  • updates (TensorTree) – A tree of updates; the tree structure and the shape of the leaf nodes must match those of params.

  • inplace (bool) – (default: True) If True, apply the updates to the parameters in-place.

Return type:

Union[Tensor, Tuple[Union[Tensor, Tuple[TensorTree, ...], List[TensorTree], Dict[Any, TensorTree], Deque[TensorTree], CustomTreeNode[TensorTree]], ...], List[Union[Tensor, Tuple[TensorTree, ...], List[TensorTree], Dict[Any, TensorTree], Deque[TensorTree], CustomTreeNode[TensorTree]]], Dict[Any, Union[Tensor, Tuple[TensorTree, ...], List[TensorTree], Dict[Any, TensorTree], Deque[TensorTree], CustomTreeNode[TensorTree]]], Deque[Union[Tensor, Tuple[TensorTree, ...], List[TensorTree], Dict[Any, TensorTree], Deque[TensorTree], CustomTreeNode[TensorTree]]], CustomTreeNode[Union[Tensor, Tuple[TensorTree, ...], List[TensorTree], Dict[Any, TensorTree], Deque[TensorTree], CustomTreeNode[TensorTree]]]]

Returns:

Updated parameters, with same structure, shape and type as params.
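
A minimal sketch of the full functional loop described above (the toy objective is illustrative):

import torch
import torchopt

params = (torch.zeros(3, requires_grad=True),)
optimizer = torchopt.sgd(lr=0.1)
opt_state = optimizer.init(params)

for _ in range(10):
    loss = torch.sum((params[0] - 1.0) ** 2)
    grads = torch.autograd.grad(loss, params)
    updates, opt_state = optimizer.update(grads, opt_state, params=params)
    params = torchopt.apply_updates(params, updates)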

Combining Optimizers

chain(*transformations)

Applies a list of chainable update transformations.

Chain

torchopt.combine.chain(*transformations)[source]

Applies a list of chainable update transformations.

Given a sequence of chainable transforms, chain() returns an init_fn() that constructs a state by concatenating the states of the individual transforms, and returns an update_fn() which chains the update transformations feeding the appropriate state to each.

Parameters:

*transformations (GradientTransformation) – A sequence of chainable (init_fn, update_fn) tuples.

Return type:

GradientTransformation

Returns:

A single (init_fn, update_fn) tuple.

General Utilities

extract_state_dict()

Extract target state.

recover_state_dict(target, state)

Recover state.

stop_gradient(target)

Stop the gradient for the input object.

Extract State Dict

torchopt.extract_state_dict(target: Module, *, by: Literal['reference', 'copy', 'deepcopy', 'ref', 'clone', 'deepclone'] = 'reference', device: Optional[Union[device, str, int]] = None, with_buffers: bool = True, enable_visual: bool = False, visual_prefix: str = '') ModuleState[source]
torchopt.extract_state_dict(target: MetaOptimizer, *, by: Literal['reference', 'copy', 'deepcopy', 'ref', 'clone', 'deepclone'] = 'reference', device: Optional[Union[device, str, int]] = None, with_buffers: bool = True, enable_visual: bool = False, visual_prefix: str = '') Tuple[Union[Tensor, Tuple[Union[Tensor, Tuple[TensorTree, ...], List[TensorTree], Dict[Any, TensorTree], Deque[TensorTree], CustomTreeNode[TensorTree]], ...], List[Union[Tensor, Tuple[TensorTree, ...], List[TensorTree], Dict[Any, TensorTree], Deque[TensorTree], CustomTreeNode[TensorTree]]], Dict[Any, Union[Tensor, Tuple[TensorTree, ...], List[TensorTree], Dict[Any, TensorTree], Deque[TensorTree], CustomTreeNode[TensorTree]]], Deque[Union[Tensor, Tuple[TensorTree, ...], List[TensorTree], Dict[Any, TensorTree], Deque[TensorTree], CustomTreeNode[TensorTree]]], CustomTreeNode[Union[Tensor, Tuple[TensorTree, ...], List[TensorTree], Dict[Any, TensorTree], Deque[TensorTree], CustomTreeNode[TensorTree]]]], ...]

Extract target state.

Since a tensor uses grad_fn to connect itself with the previous computation graph, the backpropagated gradient will flow over the tensor and continue to flow to the tensors connected by grad_fn. Some algorithms require manually detaching tensors from the computation graph.

Note that the extracted state is a reference, which means any in-place operator will affect the target that the state is extracted from.

Parameters:
  • target (Union[Module, MetaOptimizer]) – It could be a nn.Module or torchopt.MetaOptimizer.

  • by (Literal['reference', 'copy', 'deepcopy', 'ref', 'clone', 'deepclone']) – The extract policy for the tensors in the target.

    - 'reference': The extracted tensors will be references to the original tensors.

    - 'copy': The extracted tensors will be clones of the original tensors. This makes the copied tensors have a grad_fn of <CloneBackward> pointing to the original tensors.

    - 'deepcopy': The extracted tensors will be deep-copied from the original tensors. The deep-copied tensors will be detached from the original computation graph.

  • device (Optional[Union[device, str, int]]) – If specified, move the extracted state to the specified device.

  • with_buffers (bool) – Extract the buffers together with the parameters; this argument is only used if the input target is nn.Module.

  • detach_buffers (bool) – Whether to detach the references to the buffers; this argument is only used if the input target is nn.Module and by='reference'.

  • enable_visual (bool) – Add additional annotations, which could be used in computation graph visualization. Currently, this flag only has an effect on nn.Module, but we will support torchopt.MetaOptimizer later.

  • visual_prefix (str) – Prefix for the visualization annotations.

Return type:

Union[ModuleState, Tuple[Union[Tensor, Tuple[Union[Tensor, Tuple[TensorTree, ...], List[TensorTree], Dict[Any, TensorTree], Deque[TensorTree], CustomTreeNode[TensorTree]], ...], List[Union[Tensor, Tuple[TensorTree, ...], List[TensorTree], Dict[Any, TensorTree], Deque[TensorTree], CustomTreeNode[TensorTree]]], Dict[Any, Union[Tensor, Tuple[TensorTree, ...], List[TensorTree], Dict[Any, TensorTree], Deque[TensorTree], CustomTreeNode[TensorTree]]], Deque[Union[Tensor, Tuple[TensorTree, ...], List[TensorTree], Dict[Any, TensorTree], Deque[TensorTree], CustomTreeNode[TensorTree]]], CustomTreeNode[Union[Tensor, Tuple[TensorTree, ...], List[TensorTree], Dict[Any, TensorTree], Deque[TensorTree], CustomTreeNode[TensorTree]]]], ...]]

Returns:

The state extracted from the input object.

Recover State Dict

torchopt.recover_state_dict(target, state)[source]

Recover state.

This function is compatible with the state extracted by extract_state_dict().

Note that the recovering process is not in-place, so the tensors of the object will not be modified.

Parameters:
  • target (Union[Module, MetaOptimizer]) – The target to recover; it could be an nn.Module or a torchopt.MetaOptimizer.

  • state – The state to recover from, an object returned from a call to extract_state_dict().

Return type:

None

Stop Gradient

torchopt.stop_gradient(target)[source]

Stop the gradient for the input object.

Since a tensor uses grad_fn to connect itself with the previous computation graph, the backpropagated gradient will flow over the tensor and continue to flow to the tensors connected by grad_fn. Some algorithms require manually detaching tensors from the computation graph.

Note that the stop_gradient() operation is in-place.

Parameters:

target – The target to be detached from the computation graph; it could be an nn.Module, a torchopt.MetaOptimizer, a state extracted by extract_state_dict(), or a plain tree of tensors.

Return type:

None
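
A minimal MAML-style sketch of checkpointing an inner loop with these utilities (the toy model is illustrative):

import torch
import torch.nn as nn
import torchopt

net = nn.Linear(4, 1)
optimizer = torchopt.MetaSGD(net, lr=0.1)

# Snapshot (by reference) before the inner loop.
net_state = torchopt.extract_state_dict(net)
optim_state = torchopt.extract_state_dict(optimizer)

loss = net(torch.randn(8, 4)).pow(2).mean()
optimizer.step(loss)  # differentiable inner update
# ... compute an outer loss and backpropagate to the meta-parameters here ...

# Roll the model and the optimizer back for the next task.
torchopt.recover_state_dict(net, net_state)
torchopt.recover_state_dict(optimizer, optim_state)
# Alternatively, torchopt.stop_gradient(net) detaches the model in-place.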

Visualizing Gradient Flow

make_dot(var[, params, show_attrs, ...])

Produces Graphviz representation of PyTorch autograd graph.

Make Dot

torchopt.visual.make_dot(var, params=None, show_attrs=False, show_saved=False, max_attr_chars=50)[source]

Produces Graphviz representation of PyTorch autograd graph.

If a node represents a backward function, it is gray. Otherwise, the node represents a tensor and is either blue, orange, or green:

  • Blue

Reachable leaf tensors that require grad (tensors whose grad fields will be populated during backward()).

  • Orange

    Saved tensors of custom autograd functions as well as those saved by built-in backward nodes.

  • Green

Tensors passed in as outputs.

  • Dark green

    If any output is a view, we represent its base tensor with a dark green node.

Parameters:
  • var (Union[Tensor, Sequence[Tensor]]) – Output tensor.

  • params (Optional[Union[Mapping[str, Tensor], ModuleState, Generator, Iterable[Union[Mapping[str, Tensor], ModuleState, Generator]]]]) – ([dict of (name, tensor) or state_dict]) Parameters to add names to nodes that require grad.

  • show_attrs (bool) – Whether to display non-tensor attributes of backward nodes (Requires PyTorch version >= 1.9)

  • show_saved (bool) – Whether to display saved tensor nodes that are not saved by custom autograd functions. Saved tensor nodes for custom functions, if present, are always displayed. (Requires PyTorch version >= 1.9)

  • max_attr_chars (int) – If show_attrs is True, sets max number of characters to display for any given attribute.

Return type:

Digraph
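
A minimal sketch (the toy model is illustrative):

import torch
import torch.nn as nn
from torchopt import visual

net = nn.Linear(4, 1)
loss = net(torch.randn(8, 4)).pow(2).mean()

# Name the parameter nodes and the loss node in the rendered graph.
graph = visual.make_dot(loss, params=[dict(net.named_parameters()), {'loss': loss}])
graph.render('graph', format='svg')  # writes graph.svg with Graphviz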