TorchOpt Optimizer

Optimizer(params, impl)

A base class for classic optimizers that is similar to torch.optim.Optimizer.

MetaOptimizer(module, impl)

The base class for high-level differentiable optimizers.

Optimizer

class torchopt.Optimizer(params, impl)[source]

Bases: object

A base class for classic optimizers that is similar to torch.optim.Optimizer.

Initialize the optimizer.

Parameters
  • params (iterable of torch.Tensor) – An iterable of torch.Tensors. Specifies what tensors should be optimized.

  • impl (GradientTransformation) – A low-level optimizer function. It can be an optimizer function provided in torchopt.alias or a customized torchopt.chain()ed transformation. Note that using Optimizer(sgd()) or Optimizer(chain(sgd())) is equivalent to torchopt.SGD.

__init__(params, impl)[source]

Initialize the optimizer.

Parameters
  • params (iterable of torch.Tensor) – An iterable of torch.Tensors. Specifies what tensors should be optimized.

  • impl (GradientTransformation) – A low-level optimizer function. It can be an optimizer function provided in torchopt.alias or a customized torchopt.chain()ed transformation. Note that using Optimizer(sgd()) or Optimizer(chain(sgd())) is equivalent to torchopt.SGD.

zero_grad(set_to_none=False)[source]

Set the gradients of all optimized torch.Tensors to zero.

The behavior is similar to torch.optim.Optimizer.zero_grad().

Parameters

set_to_none (bool, optional) – Instead of setting to zero, set the grads to None. (default: False)

Return type

None

state_dict()[source]

Return the state of the optimizer.

Return type

tuple[OptState, …]

load_state_dict(state_dict)[source]

Load the optimizer state.

Parameters

state_dict (sequence of tree of Tensor) – Optimizer state. Should be an object returned from a call to state_dict().

Return type

None

step(closure=None)[source]

Perform a single optimization step.

The behavior is similar to torch.optim.Optimizer.step().

Parameters

closure (callable or None, optional) – A closure that reevaluates the model and returns the loss. Optional for most optimizers. (default: None)

Return type

Optional[Tensor]

add_param_group(params)[source]

Add a param group to the optimizer’s param_groups.

Return type

None
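
The relationship between Optimizer and its impl argument can be pictured with a small plain-Python sketch. It is not TorchOpt's implementation: ToyTransformation, toy_sgd, toy_chain, and ToyOptimizer are hypothetical stand-ins, parameters are plain floats, and gradients are passed in explicitly instead of being read from .grad. It only illustrates why wrapping a single transformation in a chain gives the same updates as using it directly.

```python
from typing import Callable, List, NamedTuple


class ToyTransformation(NamedTuple):
    """Hypothetical stand-in for a gradient transformation: an (init, update) pair."""
    init: Callable[[List[float]], object]
    update: Callable[[List[float], object], tuple]


def toy_sgd(lr: float) -> ToyTransformation:
    """Stateless SGD: scale every gradient by -lr."""
    def init(params):
        return None

    def update(grads, state):
        return [-lr * g for g in grads], state

    return ToyTransformation(init, update)


def toy_chain(*transforms: ToyTransformation) -> ToyTransformation:
    """Apply several transformations one after another."""
    def init(params):
        return [t.init(params) for t in transforms]

    def update(grads, states):
        new_states = []
        for t, s in zip(transforms, states):
            grads, s = t.update(grads, s)
            new_states.append(s)
        return grads, new_states

    return ToyTransformation(init, update)


class ToyOptimizer:
    """Holds the parameters and the transformation state, like a classic optimizer."""
    def __init__(self, params: List[float], impl: ToyTransformation):
        self.params = params
        self.impl = impl
        self.state = impl.init(params)

    def step(self, grads: List[float]) -> None:
        updates, self.state = self.impl.update(grads, self.state)
        for i, u in enumerate(updates):
            self.params[i] += u  # in-place update of the stored parameters


direct = ToyOptimizer([1.0, -2.0], toy_sgd(lr=0.1))
chained = ToyOptimizer([1.0, -2.0], toy_chain(toy_sgd(lr=0.1)))
direct.step([0.5, 0.5])
chained.step([0.5, 0.5])
```

Both toy optimizers end up with the same parameters, because a chain with one element simply forwards to that element.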

MetaOptimizer

class torchopt.MetaOptimizer(module, impl)[source]

Bases: object

The base class for high-level differentiable optimizers.

Initialize the meta-optimizer.

Parameters
  • module (nn.Module) – A network whose parameters should be optimized.

  • impl (GradientTransformation) – A low-level optimizer function. It can be an optimizer function provided in torchopt.alias or a customized torchopt.chain()ed transformation. Note that using MetaOptimizer(sgd(moment_requires_grad=True)) or MetaOptimizer(chain(sgd(moment_requires_grad=True))) is equivalent to torchopt.MetaSGD.

__init__(module, impl)[source]

Initialize the meta-optimizer.

Parameters
  • module (nn.Module) – A network whose parameters should be optimized.

  • impl (GradientTransformation) – A low-level optimizer function. It can be an optimizer function provided in torchopt.alias or a customized torchopt.chain()ed transformation. Note that using MetaOptimizer(sgd(moment_requires_grad=True)) or MetaOptimizer(chain(sgd(moment_requires_grad=True))) is equivalent to torchopt.MetaSGD.

step(loss)[source]

Compute the gradients of the loss with respect to the network parameters and update the network parameters.

The graph of the derivative is constructed, which allows higher-order derivative products to be computed. The differentiable optimizer (called with inplace=False) scales the gradients and updates the network parameters without modifying tensors in-place.

Parameters

loss (torch.Tensor) – The loss used to compute the gradients with respect to the network parameters.

Return type

None

add_param_group(module)[source]

Add a param group to the optimizer’s state_groups.

Return type

None

state_dict()[source]

Extract the references of the optimizer states.

Note that the states are references, so any in-place operations will change the states inside MetaOptimizer at the same time.

Return type

tuple[OptState, …]

load_state_dict(state_dict)[source]

Load the references of the optimizer states.

Return type

None
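
The note on references can be illustrated without TorchOpt. In the sketch below, ToyStatefulOptimizer is a hypothetical class whose state_dict() returns the internal list itself and not a copy, so an in-place change made by the caller is visible inside the optimizer. Taking a deep copy first is one way to keep a snapshot isolated.

```python
import copy


class ToyStatefulOptimizer:
    """Hypothetical optimizer that exposes its state by reference."""

    def __init__(self):
        self._state = [0.0, 0.0]  # e.g. one momentum value per parameter

    def state_dict(self):
        return (self._state,)  # a tuple holding a reference, not a copy

    def load_state_dict(self, state_dict):
        (self._state,) = state_dict  # stores the reference it is given


opt = ToyStatefulOptimizer()
snapshot = copy.deepcopy(opt.state_dict())  # independent copy
live = opt.state_dict()                     # shares storage with the optimizer

live[0][0] = 1.5  # in-place edit through the reference

internal_after_edit = opt.state_dict()[0][0]  # the optimizer sees the edit
snapshot_value = snapshot[0][0]               # the deep copy does not

opt.load_state_dict(snapshot)  # restore the earlier state
restored_value = opt.state_dict()[0][0]
```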


Functional Optimizers

FuncOptimizer(impl, *[, inplace])

A wrapper class to hold the functional optimizer.

adadelta([lr, rho, eps, weight_decay, ...])

Create a functional version of the AdaDelta optimizer.

adagrad([lr, lr_decay, weight_decay, ...])

Create a functional version of the AdaGrad optimizer.

adam([lr, betas, eps, weight_decay, ...])

Create a functional version of the Adam optimizer.

adamw([lr, betas, eps, weight_decay, ...])

Create a functional version of the Adam optimizer with weight decay regularization.

adamax([lr, betas, eps, weight_decay, ...])

Create a functional version of the AdaMax optimizer.

radam([lr, betas, eps, weight_decay, ...])

Create a functional version of the RAdam optimizer.

rmsprop([lr, alpha, eps, weight_decay, ...])

Create a functional version of the RMSProp optimizer.

sgd(lr[, momentum, dampening, weight_decay, ...])

Create a functional version of the canonical Stochastic Gradient Descent optimizer.

Wrapper for Functional Optimizer

class torchopt.FuncOptimizer(impl, *, inplace=False)[source]

Bases: object

A wrapper class to hold the functional optimizer.

This wrapper makes it easier to maintain the optimizer states. The optimizer states are held by the wrapper internally. The wrapper provides a step() function to compute the gradients and update the parameters.

Initialize the functional optimizer wrapper.

Parameters
  • impl (GradientTransformation) – A low-level optimizer function. It can be an optimizer function provided in torchopt.alias or a customized torchopt.chain()ed transformation.

  • inplace (bool, optional) – The default value of inplace for each optimization update. (default: False)

__init__(impl, *, inplace=False)[source]

Initialize the functional optimizer wrapper.

Parameters
  • impl (GradientTransformation) – A low-level optimizer function. It can be an optimizer function provided in torchopt.alias or a customized torchopt.chain()ed transformation.

  • inplace (bool, optional) – The default value of inplace for each optimization update. (default: False)

step(loss, params, inplace=None)[source]

Compute the gradients of the loss with respect to the network parameters and update the network parameters.

The graph of the derivative is constructed, which allows higher-order derivative products to be computed. The differentiable optimizer (called with inplace=False) scales the gradients and updates the network parameters without modifying tensors in-place.

Parameters
  • loss (Tensor) – The loss used to compute the gradients with respect to the network parameters.

  • params (tree of Tensor) – A tree of torch.Tensors. Specifies what tensors should be optimized.

  • inplace (bool or None, optional) – Whether to update the parameters in-place. If None, use the default value specified in the constructor. (default: None)

Return type

Params

state_dict()[source]

Extract the references of the optimizer states.

Note that the states are references, so any in-place operations will change the states inside FuncOptimizer at the same time.

Return type

OptState

load_state_dict(state_dict)[source]

Load the references of the optimizer states.

Return type

None
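
How the constructor default and the per-call inplace argument interact can be sketched in plain Python. ToyFuncOptimizer is hypothetical, takes gradients directly instead of a loss, and uses a fixed SGD rule. Only the logic where inplace=None falls back to the constructor default mirrors the description above.

```python
from typing import List, Optional


class ToyFuncOptimizer:
    """Hypothetical functional optimizer wrapper with a default inplace setting."""

    def __init__(self, lr: float, *, inplace: bool = False):
        self.lr = lr
        self.inplace = inplace  # default used when step() gets inplace=None

    def step(self, grads: List[float], params: List[float],
             inplace: Optional[bool] = None) -> List[float]:
        if inplace is None:
            inplace = self.inplace
        if inplace:
            for i, g in enumerate(grads):
                params[i] -= self.lr * g  # mutate the caller's list
            return params
        # Out-of-place: build and return a new list, leave the input untouched.
        return [p - self.lr * g for p, g in zip(params, grads)]


opt = ToyFuncOptimizer(lr=0.5)  # constructor default: inplace=False
params = [1.0, 2.0]
new_params = opt.step([1.0, 1.0], params)  # falls back to the default
params_after_first_call = list(params)     # copy for inspection
same_params = opt.step([1.0, 1.0], params, inplace=True)  # per-call override
```

The first call returns a new list and leaves params unchanged. The second call mutates params and returns the same object.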

Functional AdaDelta Optimizer

torchopt.adadelta(lr=0.001, rho=0.9, eps=1e-06, weight_decay=0.0, *, moment_requires_grad=False)[source]

Create a functional version of the AdaDelta optimizer.

Adadelta is a per-dimension learning rate method for gradient descent.

Parameters
  • lr (float or callable, optional) – This is a fixed global scaling factor or a learning rate scheduler. (default: 1e-3)

  • rho (float, optional) – Coefficients used for computing running averages of gradient and its square. (default: 0.9)

  • eps (float, optional) – A small constant applied to the square root (as in the Adadelta paper) to avoid dividing by zero when rescaling. (default: 1e-6)

  • weight_decay (float, optional) – Weight decay, add L2 penalty to parameters. (default: 0.0)

  • moment_requires_grad (bool, optional) – If True, the momentums are created with requires_grad=True. This flag is often used in meta-learning algorithms. (default: False)

Return type

GradientTransformation

Returns

The corresponding GradientTransformation instance.

See also

The functional optimizer wrapper torchopt.FuncOptimizer.

Functional AdaGrad Optimizer

torchopt.adagrad(lr=0.01, lr_decay=0.0, weight_decay=0.0, initial_accumulator_value=0.0, eps=1e-10, *, maximize=False)[source]

Create a functional version of the AdaGrad optimizer.

AdaGrad is an algorithm for gradient based optimization that anneals the learning rate for each parameter during the course of training.

Warning

AdaGrad’s main limitation is the monotonic accumulation of squared gradients in the denominator. Since every term is non-negative, the sum keeps growing during training, and the learning rate eventually becomes very small.
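
The shrinking step size can be seen in a scalar sketch. The update form below (accumulate g squared, divide by the square root plus eps) is the textbook AdaGrad rule. Where exactly eps is added is an assumption here, and lr_decay and weight_decay are left out.

```python
import math


def adagrad_effective_rates(grads, lr=0.01, initial_accumulator_value=0.0, eps=1e-10):
    """Return the effective per-step learning rate lr / (sqrt(acc) + eps)."""
    acc = initial_accumulator_value
    effective = []
    for g in grads:
        acc += g * g  # squared gradients only ever add up
        effective.append(lr / (math.sqrt(acc) + eps))
    return effective


# A constant gradient of 1.0 for 100 steps.
rates = adagrad_effective_rates([1.0] * 100)
first_rate, last_rate = rates[0], rates[-1]
```

With a constant gradient the effective rate falls from about lr at the first step to about lr / 10 after 100 steps, and it never recovers.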

References

Duchi et al., 2011: https://jmlr.org/papers/v12/duchi11a.html

Parameters
  • lr (float or callable, optional) – This is a fixed global scaling factor or a learning rate scheduler. (default: 1e-2)

  • lr_decay (float, optional) – Learning rate decay. (default: 0.0)

  • weight_decay (float, optional) – Weight decay, add L2 penalty to parameters. (default: 0.0)

  • initial_accumulator_value (float, optional) – Initial value for the accumulator. (default: 0.0)

  • eps (float, optional) – A small constant applied to denominator outside of the square root (as in the Adam paper) to avoid dividing by zero when rescaling. (default: 1e-10)

  • maximize (bool, optional) – Maximize the params based on the objective, instead of minimizing. (default: False)

Return type

GradientTransformation

Returns

The corresponding GradientTransformation instance.

See also

The functional optimizer wrapper torchopt.FuncOptimizer.

Functional Adam Optimizer

torchopt.adam(lr=0.001, betas=(0.9, 0.999), eps=1e-08, weight_decay=0.0, *, eps_root=0.0, moment_requires_grad=False, maximize=False, use_accelerated_op=False)[source]

Create a functional version of the Adam optimizer.

Adam is an SGD variant with learning rate adaptation. The learning rate used for each weight is computed from estimates of first- and second-order moments of the gradients (using suitable exponential moving averages).

Parameters
  • lr (float or callable, optional) – This is a fixed global scaling factor or a learning rate scheduler. (default: 1e-3)

  • betas (tuple of float, optional) – Coefficients used for computing running averages of gradient and its square. (default: (0.9, 0.999))

  • eps (float, optional) – A small constant applied to denominator outside of the square root (as in the Adam paper) to avoid dividing by zero when rescaling. (default: 1e-8)

  • weight_decay (float, optional) – Weight decay, add L2 penalty to parameters. (default: 0.0)

  • eps_root (float, optional) – A small constant applied to denominator inside the square root (as in RMSProp), to avoid dividing by zero when rescaling. This is needed for example when computing (meta-)gradients through Adam. (default: 0.0)

  • moment_requires_grad (bool, optional) – If True, the momentums are created with requires_grad=True. This flag is often used in meta-learning algorithms. (default: False)

  • maximize (bool, optional) – Maximize the params based on the objective, instead of minimizing. (default: False)

  • use_accelerated_op (bool, optional) – If True, use the fused operator implemented in TorchOpt. (default: False)

Return type

GradientTransformation

Returns

The corresponding GradientTransformation instance.

See also

The functional optimizer wrapper torchopt.FuncOptimizer.
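
The two epsilon parameters enter the denominator at different places. The scalar sketch below uses the commonly documented form update = m_hat / (sqrt(v_hat + eps_root) + eps). Treat the exact formula as an assumption, not a transcription of TorchOpt's code. The helper dsqrt is hypothetical and shows why eps_root matters for meta-gradients: the derivative of sqrt(v) grows without bound as v approaches 0, and eps_root keeps it finite.

```python
import math


def adam_first_step(g, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, eps_root=0.0):
    """One bias-corrected Adam step from zero-initialized moments (scalar case)."""
    b1, b2 = betas
    m = (1 - b1) * g      # first-moment estimate after step 1
    v = (1 - b2) * g * g  # second-moment estimate after step 1
    m_hat = m / (1 - b1)  # bias correction
    v_hat = v / (1 - b2)
    return -lr * m_hat / (math.sqrt(v_hat + eps_root) + eps)


def dsqrt(v, eps_root):
    """Derivative of sqrt(v + eps_root) with respect to v."""
    return 0.5 / math.sqrt(v + eps_root)


step = adam_first_step(g=0.5)                  # close to -lr * sign(g)
slope_tiny_v = dsqrt(1e-16, eps_root=0.0)      # very large
slope_with_root = dsqrt(1e-16, eps_root=1e-8)  # bounded by 0.5 / sqrt(eps_root)
```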

Functional AdamW Optimizer

torchopt.adamw(lr=0.001, betas=(0.9, 0.999), eps=1e-08, weight_decay=0.01, *, eps_root=0.0, mask=None, moment_requires_grad=False, maximize=False, use_accelerated_op=False)[source]

Create a functional version of the Adam optimizer with weight decay regularization.

AdamW uses weight decay to regularize learning towards small weights, as this leads to better generalization. With SGD, the same effect can be obtained by adding an L2 regularization term to the loss. However, L2 regularization does not behave as intended for adaptive gradient algorithms such as Adam.

Parameters
  • lr (float or callable, optional) – This is a fixed global scaling factor or a learning rate scheduler. (default: 1e-3)

  • betas (tuple of float, optional) – Coefficients used for computing running averages of gradient and its square. (default: (0.9, 0.999))

  • eps (float, optional) – A small constant applied to denominator outside of the square root (as in the Adam paper) to avoid dividing by zero when rescaling. (default: 1e-8)

  • weight_decay (float, optional) – Strength of the weight decay regularization. Note that this weight decay is multiplied with the learning rate. This is consistent with other frameworks such as PyTorch, but different from (Loshchilov et al., 2019) where the weight decay is only multiplied with the “schedule multiplier”, but not the base learning rate. (default: 1e-2)

  • eps_root (float, optional) – A small constant applied to denominator inside the square root (as in RMSProp), to avoid dividing by zero when rescaling. This is needed for example when computing (meta-)gradients through Adam. (default: 0.0)

  • mask (tree of Tensor, callable, or None, optional) – A tree with same structure as (or a prefix of) the params pytree, or a function that returns such a pytree given the params/updates. The leaves should be booleans, True for leaves/subtrees you want to apply the weight decay to, and False for those you want to skip. Note that the Adam gradient transformations are applied to all parameters. (default: None)

  • moment_requires_grad (bool, optional) – If True, the momentums are created with requires_grad=True. This flag is often used in meta-learning algorithms. (default: False)

  • maximize (bool, optional) – Maximize the params based on the objective, instead of minimizing. (default: False)

  • use_accelerated_op (bool, optional) – If True, use the fused operator implemented in TorchOpt. (default: False)

Return type

GradientTransformation

Returns

The corresponding GradientTransformation instance.

See also

The functional optimizer wrapper torchopt.FuncOptimizer.
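
The difference between an L2 penalty and decoupled weight decay can be made concrete with a scalar sketch of the first Adam step, where the bias-corrected direction reduces to g / (|g| + eps). The function names are hypothetical and the code is an illustration under that simplification, not TorchOpt's implementation. Following the weight_decay description above, the decay term is multiplied by the learning rate.

```python
import math


def adam_direction(g, eps=1e-8):
    """First bias-corrected Adam direction from zero moments: g / (|g| + eps)."""
    return g / (math.sqrt(g * g) + eps)


def l2_in_gradient(p, g, lr, weight_decay):
    """Adam with an L2 penalty: the decay term is added to the gradient first."""
    return p - lr * adam_direction(g + weight_decay * p)


def decoupled_decay(p, g, lr, weight_decay):
    """AdamW-style: the decay term is added after the adaptive rescaling
    and is multiplied by the learning rate together with the Adam direction."""
    return p - lr * (adam_direction(g) + weight_decay * p)


p, g, lr = 1.0, 0.5, 0.1
coupled_small = l2_in_gradient(p, g, lr, weight_decay=0.01)
coupled_large = l2_in_gradient(p, g, lr, weight_decay=0.1)
decoupled_small = decoupled_decay(p, g, lr, weight_decay=0.01)
decoupled_large = decoupled_decay(p, g, lr, weight_decay=0.1)
```

In this setting the L2 variant lands on almost the same parameter value for both decay strengths, because the adaptive rescaling normalizes the penalty away. The decoupled variant shrinks the parameter in proportion to weight_decay.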

Functional AdaMax Optimizer

torchopt.adamax(lr=0.001, betas=(0.9, 0.999), eps=1e-06, weight_decay=0.0, *, moment_requires_grad=False)[source]

Create a functional version of the AdaMax optimizer.

Parameters
  • lr (float or callable, optional) – This is a fixed global scaling factor or a learning rate scheduler. (default: 1e-3)

  • betas (tuple of float, optional) – Coefficients used for computing running averages of gradient and its square. (default: (0.9, 0.999))

  • eps (float, optional) – A small constant applied to the square root (as in the AdaMax paper) to avoid dividing by zero when rescaling. (default: 1e-6)

  • weight_decay (float, optional) – Weight decay, add L2 penalty to parameters. (default: 0.0)

  • moment_requires_grad (bool, optional) – If True, the momentums are created with requires_grad=True. This flag is often used in meta-learning algorithms. (default: False)

Return type

GradientTransformation

Returns

The corresponding GradientTransformation instance.

See also

The functional optimizer wrapper torchopt.FuncOptimizer.

Functional RAdam Optimizer

torchopt.radam(lr=0.001, betas=(0.9, 0.999), eps=1e-06, weight_decay=0.0, *, moment_requires_grad=False)[source]

Create a functional version of the RAdam optimizer.

RAdam is a variant of Adam that rectifies the variance of the adaptive learning rate.

Parameters
  • lr (float or callable, optional) – This is a fixed global scaling factor or a learning rate scheduler. (default: 1e-3)

  • betas (tuple of float, optional) – Coefficients used for computing running averages of gradient and its square. (default: (0.9, 0.999))

  • eps (float, optional) – A small constant applied to the square root (as in the RAdam paper) to avoid dividing by zero when rescaling. (default: 1e-6)

  • weight_decay (float, optional) – Weight decay, add L2 penalty to parameters. (default: 0.0)

  • moment_requires_grad (bool, optional) – If True, the momentums are created with requires_grad=True. This flag is often used in meta-learning algorithms. (default: False)

Return type

GradientTransformation

Returns

The corresponding GradientTransformation instance.

See also

The functional optimizer wrapper torchopt.FuncOptimizer.

Functional RMSProp Optimizer

torchopt.rmsprop(lr=0.01, alpha=0.99, eps=1e-08, weight_decay=0.0, momentum=0.0, centered=False, *, initial_scale=0.0, nesterov=False, maximize=False)[source]

Create a functional version of the RMSProp optimizer.

RMSProp is an SGD variant with learning rate adaptation. The learning rate used for each weight is scaled by a suitable estimate of the magnitude of the gradients on previous steps. Several variants of RMSProp can be found in the literature. This alias provides an easy to configure RMSProp optimizer that can be used to switch between several of these variants.

Parameters
  • lr (float or callable, optional) – This is a fixed global scaling factor or a learning rate scheduler. (default: 1e-2)

  • alpha (float, optional) – Smoothing constant, the decay used to track the magnitude of previous gradients. (default: 0.99)

  • eps (float, optional) – A small numerical constant to avoid dividing by zero when rescaling. (default: 1e-8)

  • weight_decay (float, optional) – Weight decay, add L2 penalty to parameters. (default: 0.0)

  • momentum (float, optional) – The decay rate used by the momentum term. The momentum is not used when it is set to 0.0. (default: 0.0)

  • centered (bool, optional) – If True, use the variance of the past gradients to rescale the latest gradients. (default: False)

  • initial_scale (float, optional) – Initialization of accumulators tracking the magnitude of previous updates. PyTorch uses 0.0, TensorFlow 1.x uses 1.0. When reproducing results from a paper, verify the value used by the authors. (default: 0.0)

  • nesterov (bool, optional) – Whether to use Nesterov momentum. (default: False)

  • maximize (bool, optional) – Maximize the params based on the objective, instead of minimizing. (default: False)

Return type

GradientTransformation

Returns

The corresponding GradientTransformation instance.

See also

The functional optimizer wrapper torchopt.FuncOptimizer.
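
The effect of initial_scale and centered on the very first step can be sketched for a single scalar gradient. The placement of eps outside the square root follows the PyTorch convention and is an assumption about TorchOpt. Momentum, Nesterov momentum, and weight decay are left out, and rmsprop_first_step is a hypothetical helper.

```python
import math


def rmsprop_first_step(g, lr=0.01, alpha=0.99, eps=1e-8,
                       centered=False, initial_scale=0.0):
    """One scalar RMSProp step; returns the parameter update."""
    nu = alpha * initial_scale + (1 - alpha) * g * g  # running mean of g^2
    if centered:
        mu = (1 - alpha) * g  # running mean of g, started at 0
        nu = nu - mu * mu     # estimate of the gradient variance
    return -lr * g / (math.sqrt(nu) + eps)


pytorch_like = rmsprop_first_step(g=1.0, initial_scale=0.0)
tf1_like = rmsprop_first_step(g=1.0, initial_scale=1.0)
centered_step = rmsprop_first_step(g=1.0, centered=True)
```

With initial_scale=0.0 the first update is roughly ten times larger than with initial_scale=1.0, which is why the value used in a paper is worth checking when reproducing results.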

Functional SGD Optimizer

torchopt.sgd(lr, momentum=0.0, dampening=0.0, weight_decay=0.0, nesterov=False, *, moment_requires_grad=False, maximize=False)[source]

Create a functional version of the canonical Stochastic Gradient Descent optimizer.

This implements stochastic gradient descent. It also supports momentum and Nesterov acceleration, as both are standard practice when using stochastic gradient descent to train deep neural networks.

Parameters
  • lr (float or callable) – This is a fixed global scaling factor or a learning rate scheduler.

  • momentum (float, optional) – The decay rate used by the momentum term. The momentum is not used when it is set to 0.0. (default: 0.0)

  • weight_decay (float, optional) – Weight decay, add L2 penalty to parameters. (default: 0.0)

  • dampening (float, optional) – Dampening for momentum. (default: 0.0)

  • nesterov (bool, optional) – Whether to use Nesterov momentum. (default: False)

  • moment_requires_grad (bool, optional) – If True, the momentums are created with requires_grad=True. This flag is often used in meta-learning algorithms. (default: False)

  • maximize (bool, optional) – Maximize the params based on the objective, instead of minimizing. (default: False)

Return type

GradientTransformation

Returns

The corresponding GradientTransformation instance.

See also

The functional optimizer wrapper torchopt.FuncOptimizer.
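
How momentum, dampening, and nesterov combine can be sketched for scalar gradients. The rule below is the one documented for torch.optim.SGD (the buffer starts as the first gradient). That TorchOpt follows it exactly is an assumption, and weight_decay and maximize are left out. The helper sgd_updates is hypothetical.

```python
def sgd_updates(grads, lr, momentum=0.0, dampening=0.0, nesterov=False):
    """Return the list of parameter updates for a sequence of scalar gradients."""
    buf = None
    updates = []
    for g in grads:
        if momentum != 0.0:
            if buf is None:
                buf = g  # first step: the buffer starts as the raw gradient
            else:
                buf = momentum * buf + (1.0 - dampening) * g
            g = g + momentum * buf if nesterov else buf
        updates.append(-lr * g)
    return updates


plain = sgd_updates([1.0, 1.0], lr=0.1)
heavy_ball = sgd_updates([1.0, 1.0], lr=0.1, momentum=0.9)
nesterov_steps = sgd_updates([1.0, 1.0], lr=0.1, momentum=0.9, nesterov=True)
```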


Classic Optimizers

AdaDelta(params[, lr, rho, eps, weight_decay])

The classic AdaDelta optimizer.

Adadelta

alias of AdaDelta

AdaGrad(params[, lr, lr_decay, ...])

The classic AdaGrad optimizer.

Adagrad

alias of AdaGrad

Adam(params[, lr, betas, eps, weight_decay, ...])

The classic Adam optimizer.

AdamW(params[, lr, betas, eps, ...])

The classic AdamW optimizer.

AdaMax(params[, lr, betas, eps, weight_decay])

The classic AdaMax optimizer.

Adamax

alias of AdaMax

RAdam(params[, lr, betas, eps, weight_decay])

The classic RAdam optimizer.

RMSProp(params[, lr, alpha, eps, ...])

The classic RMSProp optimizer.

SGD(params, lr[, momentum, weight_decay, ...])

The classic SGD optimizer.

Classic AdaDelta Optimizer

class torchopt.AdaDelta(params, lr=1.0, rho=0.9, eps=1e-06, weight_decay=0.0)[source]

Bases: Optimizer

The classic AdaDelta optimizer.

See also

  • The functional AdaDelta optimizer: torchopt.adadelta().

  • The differentiable meta-AdaDelta optimizer: torchopt.MetaAdaDelta.

Initialize the AdaDelta optimizer.

Parameters
  • params (iterable of Tensor) – An iterable of torch.Tensors. Specifies what tensors should be optimized.

  • lr (float or callable, optional) – This is a fixed global scaling factor or a learning rate scheduler. (default: 1e-3)

  • rho (float, optional) – Coefficients used for computing running averages of gradient and its square. (default: 0.9)

  • eps (float, optional) – A small constant applied to the square root (as in the AdaDelta paper) to avoid dividing by zero when rescaling. (default: 1e-6)

  • weight_decay (float, optional) – Weight decay, add L2 penalty to parameters. (default: 0.0)

Classic AdaGrad Optimizer

class torchopt.AdaGrad(params, lr=0.01, lr_decay=0.0, weight_decay=0.0, initial_accumulator_value=0.0, eps=1e-10, *, maximize=False)[source]

Bases: Optimizer

The classic AdaGrad optimizer.

Initialize the AdaGrad optimizer.

Parameters
  • params (iterable of Tensor) – An iterable of torch.Tensors. Specifies what tensors should be optimized.

  • lr (float or callable, optional) – This is a fixed global scaling factor or a learning rate scheduler. (default: 1e-2)

  • lr_decay (float, optional) – Learning rate decay. (default: 0.0)

  • weight_decay (float, optional) – Weight decay, add L2 penalty to parameters. (default: 0.0)

  • initial_accumulator_value (float, optional) – Initial value for the accumulator. (default: 0.0)

  • eps (float, optional) – A small constant applied to denominator outside of the square root (as in the Adam paper) to avoid dividing by zero when rescaling. (default: 1e-10)

  • maximize (bool, optional) – Maximize the params based on the objective, instead of minimizing. (default: False)

Classic Adam Optimizer

class torchopt.Adam(params, lr=0.001, betas=(0.9, 0.999), eps=1e-08, weight_decay=0.0, *, eps_root=0.0, maximize=False, use_accelerated_op=False)[source]

Bases: Optimizer

The classic Adam optimizer.

Initialize the Adam optimizer.

Parameters
  • params (iterable of Tensor) – An iterable of torch.Tensors. Specifies what tensors should be optimized.

  • lr (float or callable, optional) – This is a fixed global scaling factor or a learning rate scheduler. (default: 1e-3)

  • betas (tuple of float, optional) – Coefficients used for computing running averages of gradient and its square. (default: (0.9, 0.999))

  • eps (float, optional) – A small constant applied to denominator outside of the square root (as in the Adam paper) to avoid dividing by zero when rescaling. (default: 1e-8)

  • weight_decay (float, optional) – Weight decay, add L2 penalty to parameters. (default: 0.0)

  • eps_root (float, optional) – A small constant applied to denominator inside the square root (as in RMSProp), to avoid dividing by zero when rescaling. This is needed for example when computing (meta-)gradients through Adam. (default: 0.0)

  • moment_requires_grad (bool, optional) – If True, the momentums are created with requires_grad=True. This flag is often used in meta-learning algorithms. (default: False)

  • maximize (bool, optional) – Maximize the params based on the objective, instead of minimizing. (default: False)

  • use_accelerated_op (bool, optional) – If True, use the fused operator implemented in TorchOpt. (default: False)

Classic AdamW Optimizer

class torchopt.AdamW(params, lr=0.001, betas=(0.9, 0.999), eps=1e-08, weight_decay=0.01, *, eps_root=0.0, mask=None, maximize=False, use_accelerated_op=False)[source]

Bases: Optimizer

The classic AdamW optimizer.

Initialize the AdamW optimizer.

Parameters
  • params (iterable of Tensor) – An iterable of torch.Tensors. Specifies what tensors should be optimized.

  • lr (float or callable, optional) – This is a fixed global scaling factor or a learning rate scheduler. (default: 1e-3)

  • betas (tuple of float, optional) – Coefficients used for computing running averages of gradient and its square. (default: (0.9, 0.999))

  • eps (float, optional) – A small constant applied to denominator outside of the square root (as in the Adam paper) to avoid dividing by zero when rescaling. (default: 1e-8)

  • weight_decay (float, optional) – Strength of the weight decay regularization. Note that this weight decay is multiplied with the learning rate. This is consistent with other frameworks such as PyTorch, but different from (Loshchilov et al., 2019) where the weight decay is only multiplied with the “schedule multiplier”, but not the base learning rate. (default: 1e-2)

  • eps_root (float, optional) – A small constant applied to denominator inside the square root (as in RMSProp), to avoid dividing by zero when rescaling. This is needed for example when computing (meta-)gradients through Adam. (default: 0.0)

  • mask (tree of Tensor, callable, or None, optional) – A tree with same structure as (or a prefix of) the params pytree, or a function that returns such a pytree given the params/updates. The leaves should be booleans, True for leaves/subtrees you want to apply the weight decay to, and False for those you want to skip. Note that the Adam gradient transformations are applied to all parameters. (default: None)

  • moment_requires_grad (bool, optional) – If True, the momentums are created with requires_grad=True. This flag is often used in meta-learning algorithms. (default: False)

  • maximize (bool, optional) – Maximize the params based on the objective, instead of minimizing. (default: False)

  • use_accelerated_op (bool, optional) – If True, use the fused operator implemented in TorchOpt. (default: False)

Classic AdaMax Optimizer

class torchopt.AdaMax(params, lr=0.002, betas=(0.9, 0.999), eps=1e-08, weight_decay=0.0)[source]

Bases: Optimizer

The classic AdaMax optimizer.

Initialize the AdaMax optimizer.

Parameters
  • params (iterable of Tensor) – An iterable of torch.Tensors. Specifies what tensors should be optimized.

  • lr (float or callable, optional) – This is a fixed global scaling factor or a learning rate scheduler. (default: 1e-3)

  • betas (tuple of float, optional) – Coefficients used for computing running averages of gradient and its square. (default: (0.9, 0.999))

  • eps (float, optional) – A small constant applied to the square root (as in the AdaMax paper) to avoid dividing by zero when rescaling. (default: 1e-6)

  • weight_decay (float, optional) – Weight decay, add L2 penalty to parameters. (default: 0.0)

Classic RAdam Optimizer

class torchopt.RAdam(params, lr=0.001, betas=(0.9, 0.999), eps=1e-08, weight_decay=0.0)[source]

Bases: Optimizer

The classic RAdam optimizer.

Initialize the RAdam optimizer.

Parameters
  • params (iterable of Tensor) – An iterable of torch.Tensors. Specifies what tensors should be optimized.

  • lr (float or callable, optional) – This is a fixed global scaling factor or a learning rate scheduler. (default: 1e-3)

  • betas (tuple of float, optional) – Coefficients used for computing running averages of gradient and its square. (default: (0.9, 0.999))

  • eps (float, optional) – A small constant applied to the square root (as in the RAdam paper) to avoid dividing by zero when rescaling. (default: 1e-6)

  • weight_decay (float, optional) – Weight decay, add L2 penalty to parameters. (default: 0.0)

Classic RMSProp Optimizer

class torchopt.RMSProp(params, lr=0.01, alpha=0.99, eps=1e-08, weight_decay=0.0, momentum=0.0, centered=False, *, initial_scale=0.0, nesterov=False, maximize=False)[source]

Bases: Optimizer

The classic RMSProp optimizer.

Initialize the RMSProp optimizer.

Parameters
  • params (iterable of Tensor) – An iterable of torch.Tensors. Specifies what tensors should be optimized.

  • lr (float or callable, optional) – This is a fixed global scaling factor or a learning rate scheduler. (default: 1e-2)

  • alpha (float, optional) – Smoothing constant, the decay used to track the magnitude of previous gradients. (default: 0.99)

  • eps (float, optional) – A small numerical constant to avoid dividing by zero when rescaling. (default: 1e-8)

  • weight_decay (float, optional) – Weight decay, add L2 penalty to parameters. (default: 0.0)

  • momentum (float, optional) – The decay rate used by the momentum term. The momentum is not used when it is set to 0.0. (default: 0.0)

  • centered (bool, optional) – If True, use the variance of the past gradients to rescale the latest gradients. (default: False)

  • initial_scale (float, optional) – Initialization of accumulators tracking the magnitude of previous updates. PyTorch uses 0.0, TensorFlow 1.x uses 1.0. When reproducing results from a paper, verify the value used by the authors. (default: 0.0)

  • nesterov (bool, optional) – Whether to use Nesterov momentum. (default: False)

  • maximize (bool, optional) – Maximize the params based on the objective, instead of minimizing. (default: False)

Classic SGD Optimizer

class torchopt.SGD(params, lr, momentum=0.0, weight_decay=0.0, dampening=0.0, nesterov=False, maximize=False)[source]

Bases: Optimizer

The classic SGD optimizer.

Initialize the SGD optimizer.

Parameters
  • params (iterable of Tensor) – An iterable of torch.Tensors. Specifies what tensors should be optimized.

  • lr (float or callable) – This is a fixed global scaling factor or a learning rate scheduler.

  • momentum (float, optional) – The decay rate used by the momentum term. The momentum is not used when it is set to 0.0. (default: 0.0)

  • weight_decay (float, optional) – Weight decay, add L2 penalty to parameters. (default: 0.0)

  • dampening (float, optional) – Dampening for momentum. (default: 0.0)

  • nesterov (bool, optional) – Whether to use Nesterov momentum. (default: False)

  • moment_requires_grad (bool, optional) – If True, the momentums are created with requires_grad=True. This flag is often used in meta-learning algorithms. (default: False)

  • maximize (bool, optional) – Maximize the params based on the objective, instead of minimizing. (default: False)


Differentiable Meta-Optimizers

MetaAdaDelta(module[, lr, rho, eps, ...])

The differentiable AdaDelta optimizer.

MetaAdadelta

alias of MetaAdaDelta

MetaAdaGrad(module[, lr, lr_decay, ...])

The differentiable AdaGrad optimizer.

MetaAdagrad

alias of MetaAdaGrad

MetaAdam(module[, lr, betas, eps, ...])

The differentiable Adam optimizer.

MetaAdamW(module[, lr, betas, eps, ...])

The differentiable AdamW optimizer.

MetaAdaMax(module[, lr, betas, eps, ...])

The differentiable AdaMax optimizer.

MetaAdamax

alias of MetaAdaMax

MetaRAdam(module[, lr, betas, eps, ...])

The differentiable RAdam optimizer.

MetaRMSProp(module[, lr, alpha, eps, ...])

The differentiable RMSProp optimizer.

MetaSGD(module, lr[, momentum, ...])

The differentiable Stochastic Gradient Descent optimizer.

Differentiable Meta-AdaDelta Optimizer

class torchopt.MetaAdaDelta(module, lr=1.0, rho=0.9, eps=1e-06, weight_decay=0.0, *, moment_requires_grad=True)[source]

Bases: MetaOptimizer

The differentiable AdaDelta optimizer.

See also

  • The functional AdaDelta optimizer: torchopt.adadelta().

  • The classic AdaDelta optimizer: torchopt.Adadelta.

Initialize the meta AdaDelta optimizer.

Parameters
  • module (nn.Module) – A network whose parameters should be optimized.

  • lr (float or callable, optional) – This is a fixed global scaling factor or a learning rate scheduler. (default: 1.0)

  • rho (float, optional) – Coefficients used for computing running averages of gradient and its square. (default: 0.9)

  • eps (float, optional) – A small constant applied to the square root (as in the AdaDelta paper) to avoid dividing by zero when rescaling. (default: 1e-6)

  • weight_decay (float, optional) – Weight decay, add L2 penalty to parameters. (default: 0.0)

  • moment_requires_grad (bool, optional) – If True the momentums will be created with flag requires_grad=True, this flag is often used in Meta-Learning algorithms. (default: True)

Differentiable Meta-AdaGrad Optimizer

class torchopt.MetaAdaGrad(module, lr=0.01, lr_decay=0.0, weight_decay=0.0, initial_accumulator_value=0.0, eps=1e-10, *, maximize=False)[source]

Bases: MetaOptimizer

The differentiable AdaGrad optimizer.

See also

  • The functional AdaGrad optimizer: torchopt.adagrad().

  • The classic AdaGrad optimizer: torchopt.Adagrad.

Initialize the meta AdaGrad optimizer.

Parameters
  • module (nn.Module) – A network whose parameters should be optimized.

  • lr (float or callable, optional) – This is a fixed global scaling factor or a learning rate scheduler. (default: 1e-2)

  • lr_decay (float, optional) – Learning rate decay. (default: 0.0)

  • weight_decay (float, optional) – Weight decay, add L2 penalty to parameters. (default: 0.0)

  • initial_accumulator_value (float, optional) – Initial value for the accumulator. (default: 0.0)

  • eps (float, optional) – A small constant applied to denominator outside of the square root (as in the Adam paper) to avoid dividing by zero when rescaling. (default: 1e-10)

  • maximize (bool, optional) – Maximize the params based on the objective, instead of minimizing. (default: False)

Differentiable Meta-Adam Optimizer

class torchopt.MetaAdam(module, lr=0.001, betas=(0.9, 0.999), eps=1e-08, weight_decay=0.0, *, eps_root=0.0, moment_requires_grad=True, maximize=False, use_accelerated_op=False)[source]

Bases: MetaOptimizer

The differentiable Adam optimizer.

See also

Initialize the meta-Adam optimizer.

Parameters
  • module (nn.Module) – A network whose parameters should be optimized.

  • lr (float or callable, optional) – This is a fixed global scaling factor or a learning rate scheduler. (default: 1e-3)

  • betas (tuple of float, optional) – Coefficients used for computing running averages of gradient and its square. (default: (0.9, 0.999))

  • eps (float, optional) – A small constant applied to denominator outside of the square root (as in the Adam paper) to avoid dividing by zero when rescaling. (default: 1e-8)

  • weight_decay (float, optional) – Weight decay, add L2 penalty to parameters. (default: 0.0)

  • eps_root (float, optional) – A small constant applied to denominator inside the square root (as in RMSProp), to avoid dividing by zero when rescaling. This is needed for example when computing (meta-)gradients through Adam. (default: 0.0)

  • moment_requires_grad (bool, optional) – If True the momentums will be created with flag requires_grad=True, this flag is often used in Meta-Learning algorithms. (default: True)

  • maximize (bool, optional) – Maximize the params based on the objective, instead of minimizing. (default: False)

  • use_accelerated_op (bool, optional) – If True use our implemented fused operator. (default: False)
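
The difference between eps and eps_root can be made concrete with a scalar sketch. It assumes the rescaled update has the form m_hat / (sqrt(v_hat + eps_root) + eps), which is how the two parameter descriptions above read; the function names are hypothetical and the code is not the library's implementation.

```python
import math


def adam_direction(m_hat, v_hat, eps=1e-8, eps_root=0.0):
    """Rescaled update: eps sits outside the square root, eps_root inside it."""
    return m_hat / (math.sqrt(v_hat + eps_root) + eps)


def d_direction_d_v(m_hat, v_hat, eps=1e-8, eps_root=0.0):
    """Analytic derivative of adam_direction with respect to v_hat."""
    root = math.sqrt(v_hat + eps_root)
    if root == 0.0:
        # The square root has an unbounded slope at zero.
        return -math.copysign(math.inf, m_hat)
    return -m_hat / (2.0 * root * (root + eps) ** 2)


no_root = d_direction_d_v(1.0, 0.0)                   # eps_root = 0.0
with_root = d_direction_d_v(1.0, 0.0, eps_root=1e-8)  # small positive eps_root
```

The forward value is finite in both cases because of eps, but only a positive eps_root keeps the derivative finite when the second-moment estimate is exactly zero, which is the situation that arises when differentiating through the update.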

Differentiable Meta-AdamW Optimizer

class torchopt.MetaAdamW(module, lr=0.001, betas=(0.9, 0.999), eps=1e-08, weight_decay=0.01, *, eps_root=0.0, mask=None, moment_requires_grad=False, maximize=False, use_accelerated_op=False)[source]

Bases: MetaOptimizer

The differentiable AdamW optimizer.

See also

Initialize the meta-AdamW optimizer.

Parameters
  • module (nn.Module) – A network whose parameters should be optimized.

  • lr (float or callable, optional) – This is a fixed global scaling factor or a learning rate scheduler. (default: 1e-3)

  • betas (tuple of float, optional) – Coefficients used for computing running averages of gradient and its square. (default: (0.9, 0.999))

  • eps (float, optional) – A small constant applied to denominator outside of the square root (as in the Adam paper) to avoid dividing by zero when rescaling. (default: 1e-8)

  • weight_decay (float, optional) – Strength of the weight decay regularization. Note that this weight decay is multiplied with the learning rate. This is consistent with other frameworks such as PyTorch, but different from (Loshchilov et al., 2019) where the weight decay is only multiplied with the “schedule multiplier”, but not the base learning rate. (default: 1e-2)

  • eps_root (float, optional) – A small constant applied to denominator inside the square root (as in RMSProp), to avoid dividing by zero when rescaling. This is needed for example when computing (meta-)gradients through Adam. (default: 0.0)

  • mask (tree of Tensor, callable, or None, optional) – A tree with same structure as (or a prefix of) the params pytree, or a function that returns such a pytree given the params/updates. The leaves should be booleans, True for leaves/subtrees you want to apply the weight decay to, and False for those you want to skip. Note that the Adam gradient transformations are applied to all parameters. (default: None)

  • moment_requires_grad (bool, optional) – If True the momentums will be created with flag requires_grad=True, this flag is often used in Meta-Learning algorithms. (default: False)

  • maximize (bool, optional) – Maximize the params based on the objective, instead of minimizing. (default: False)

  • use_accelerated_op (bool, optional) – If True use our implemented fused operator. (default: False)
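
The effect of mask and of multiplying the weight decay by the learning rate can be sketched with plain dictionaries standing in for the parameter tree. The helper name add_masked_weight_decay and the numeric values are hypothetical, and the Adam-rescaled gradients are replaced by fixed stand-in numbers.

```python
def add_masked_weight_decay(updates, params, weight_decay, mask=None):
    """Add weight_decay * param to every update whose mask entry is True."""
    out = {}
    for name, update in updates.items():
        decay = mask is None or mask[name]
        out[name] = update + weight_decay * params[name] if decay else update
    return out


params = {'weight': 2.0, 'bias': 1.0}
adam_updates = {'weight': 0.5, 'bias': 0.5}   # stand-ins for Adam-rescaled gradients
mask = {'weight': True, 'bias': False}        # decay the weight, skip the bias
lr = 0.1

decayed = add_masked_weight_decay(adam_updates, params, weight_decay=0.01, mask=mask)
# The learning rate scales the whole update, so it also scales the decay term.
new_params = {name: params[name] - lr * decayed[name] for name in params}
```

Both entries still receive the Adam-style update; the mask only controls where the decay term is added.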

Differentiable Meta-AdaMax Optimizer

class torchopt.MetaAdaMax(module, lr=0.002, betas=(0.9, 0.999), eps=1e-08, weight_decay=0.0, *, moment_requires_grad=True)[source]

Bases: MetaOptimizer

The differentiable AdaMax optimizer.

See also

  • The functional AdaMax optimizer: torchopt.adamax().

  • The classic AdaMax optimizer: torchopt.Adamax.

Initialize the meta AdaMax optimizer.

Parameters
  • module (nn.Module) – A network whose parameters should be optimized.

  • lr (float or callable, optional) – This is a fixed global scaling factor or a learning rate scheduler. (default: 2e-3)

  • betas (tuple of float, optional) – Coefficients used for computing running averages of gradient and its square. (default: (0.9, 0.999))

  • eps (float, optional) – A small constant applied to the square root (as in the AdaMax paper) to avoid dividing by zero when rescaling. (default: 1e-8)

  • weight_decay (float, optional) – Weight decay, add L2 penalty to parameters. (default: 0.0)

  • moment_requires_grad (bool, optional) – If True the momentums will be created with flag requires_grad=True, this flag is often used in Meta-Learning algorithms. (default: True)

Differentiable Meta-RAdam Optimizer

class torchopt.MetaRAdam(module, lr=0.001, betas=(0.9, 0.999), eps=1e-08, weight_decay=0.0, *, moment_requires_grad=True)[source]

Bases: MetaOptimizer

The differentiable RAdam optimizer.

See also

  • The functional RAdam optimizer: torchopt.radam().

  • The classic RAdam optimizer: torchopt.RAdam.

Initialize the meta-RAdam optimizer.

Parameters
  • module (nn.Module) – A network whose parameters should be optimized.

  • lr (float or callable, optional) – This is a fixed global scaling factor or a learning rate scheduler. (default: 1e-3)

  • betas (tuple of float, optional) – Coefficients used for computing running averages of gradient and its square. (default: (0.9, 0.999))

  • eps (float, optional) – A small constant applied to the square root (as in the RAdam paper) to avoid dividing by zero when rescaling. (default: 1e-8)

  • weight_decay (float, optional) – Weight decay, add L2 penalty to parameters. (default: 0.0)

  • moment_requires_grad (bool, optional) – If True the momentums will be created with flag requires_grad=True, this flag is often used in Meta-Learning algorithms. (default: True)

Differentiable Meta-RMSProp Optimizer

class torchopt.MetaRMSProp(module, lr=0.01, alpha=0.99, eps=1e-08, weight_decay=0.0, momentum=0.0, centered=False, *, initial_scale=0.0, nesterov=False, maximize=False)[source]

Bases: MetaOptimizer

The differentiable RMSProp optimizer.

See also

Initialize the meta-RMSProp optimizer.

Parameters
  • module (nn.Module) – A network whose parameters should be optimized.

  • lr (float or callable, optional) – This is a fixed global scaling factor or a learning rate scheduler. (default: 1e-2)

  • alpha (float, optional) – Smoothing constant, the decay used to track the magnitude of previous gradients. (default: 0.99)

  • eps (float, optional) – A small numerical constant to avoid dividing by zero when rescaling. (default: 1e-8)

  • weight_decay (float, optional) – Weight decay, add L2 penalty to parameters. (default: 0.0)

  • momentum (float, optional) – The decay rate used by the momentum term. The momentum is not used when it is set to 0.0. (default: 0.0)

  • centered (bool, optional) – If True, use the variance of the past gradients to rescale the latest gradients. (default: False)

  • initial_scale (float, optional) – Initialization of accumulators tracking the magnitude of previous updates. PyTorch uses 0.0, TensorFlow 1.x uses 1.0. When reproducing results from a paper, verify the value used by the authors. (default: 0.0)

  • nesterov (bool, optional) – Whether to use Nesterov momentum. (default: False)

  • maximize (bool, optional) – Maximize the params based on the objective, instead of minimizing. (default: False)
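
Why initial_scale matters when reproducing results can be seen from the very first step of an uncentered, momentum-free update on a scalar. This is a rough sketch: the function name is hypothetical, and adding eps outside the square root is an assumption about the placement rather than a statement about the library.

```python
import math


def rmsprop_direction(grad, nu, alpha=0.99, eps=1e-8):
    """Return (rescaled_grad, new_nu) for one uncentered scalar RMSProp step.

    Assumption: eps is added outside the square root.
    """
    nu = alpha * nu + (1.0 - alpha) * grad * grad
    return grad / (math.sqrt(nu) + eps), nu


first_zero, _ = rmsprop_direction(1.0, nu=0.0)   # accumulator starts at initial_scale = 0.0
first_one, _ = rmsprop_direction(1.0, nu=1.0)    # accumulator starts at initial_scale = 1.0
```

With a zero-initialized accumulator the first rescaled gradient is roughly ten times larger than with an accumulator initialized to one, so the two conventions can produce visibly different early training behaviour.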

Differentiable Meta-SGD Optimizer

class torchopt.MetaSGD(module, lr, momentum=0.0, weight_decay=0.0, dampening=0.0, nesterov=False, moment_requires_grad=True, maximize=False)[source]

Bases: MetaOptimizer

The differentiable Stochastic Gradient Descent optimizer.

See also

Initialize the meta-SGD optimizer.

Parameters
  • module (nn.Module) – A network whose parameters should be optimized.

  • lr (float or callable) – This is a fixed global scaling factor or a learning rate scheduler.

  • momentum (float, optional) – The decay rate used by the momentum term. The momentum is not used when it is set to 0.0. (default: 0.0)

  • weight_decay (float, optional) – Weight decay, add L2 penalty to parameters. (default: 0.0)

  • dampening (float, optional) – Dampening for momentum. (default: 0.0)

  • nesterov (bool, optional) – Whether to use Nesterov momentum. (default: False)

  • moment_requires_grad (bool, optional) – If True the momentums will be created with flag requires_grad=True, this flag is often used in Meta-Learning algorithms. (default: True)

  • maximize (bool, optional) – Maximize the params based on the objective, instead of minimizing. (default: False)
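
What "differentiable" means for an optimizer step, and why a momentum buffer that requires gradients can be useful, is sketched below on a scalar problem without any autograd. The inner loss, the constants, and all function names are hypothetical; the chain-rule gradients are checked against central differences.

```python
LR, MOMENTUM, TARGET = 0.1, 0.9, 0.0


def inner_step(x, buf, theta):
    """One momentum-SGD step on the inner loss 0.5 * (x - theta)**2."""
    grad = x - theta
    new_buf = MOMENTUM * buf + grad
    return x - LR * new_buf


def outer_loss(x0, buf0, theta):
    """Outer (meta) loss evaluated on the updated parameter."""
    return 0.5 * (inner_step(x0, buf0, theta) - TARGET) ** 2


def central_diff(fn, value, h=1e-6):
    return (fn(value + h) - fn(value - h)) / (2.0 * h)


x0, buf0, theta = 1.0, 0.5, 2.0
x1 = inner_step(x0, buf0, theta)

# x1 = x0 - LR * (MOMENTUM * buf0 + x0 - theta), so by the chain rule:
grad_theta = (x1 - TARGET) * LR                 # d outer / d theta
grad_buf0 = (x1 - TARGET) * (-LR * MOMENTUM)    # d outer / d initial momentum

fd_theta = central_diff(lambda t: outer_loss(x0, buf0, t), theta)
fd_buf0 = central_diff(lambda b: outer_loss(x0, b, theta), buf0)
```

The outer loss depends on the initial momentum buffer through the update, so a gradient with respect to it exists and can be used by a meta-learning algorithm.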


Implicit Differentiation

custom_root(optimality_fn, argnums[, ...])

Return a decorator for adding implicit differentiation to a root solver.

nn.ImplicitMetaGradientModule(*args, **kwargs)

The base class for differentiable implicit meta-gradient models.

Custom Solvers

torchopt.diff.implicit.custom_root(optimality_fn, argnums, has_aux=False, solve=None)[source]

Return a decorator for adding implicit differentiation to a root solver.

This wrapper should be used as a decorator:

def optimality_fn(optimal_params, ...):
    ...
    return residual

@custom_root(optimality_fn, argnums=argnums)
def solver_fn(params, arg1, arg2, ...):
    ...
    return optimal_params

optimal_params = solver_fn(init_params, ...)

The first argument to optimality_fn and solver_fn is reserved for the parameter input. The argnums argument refers to the indices of the variables in solver_fn’s signature. For example, setting argnums=(1, 2) will compute the gradient of optimal_params with respect to arg1 and arg2 in the above snippet. Note that, apart from the first argument, the keyword arguments of optimality_fn should be a subset of those of solver_fn. As a best practice, optimality_fn should have the same signature as solver_fn.

Parameters
  • optimality_fn (callable) – An equation function, optimality_fn(params, *args). The invariant is optimality_fn(solution, *args) == 0 at the solution (the root).

  • argnums (int or tuple of int) – Specifies arguments to compute gradients with respect to. The argnums can be an integer or a tuple of integers, which refer to the zero-based indices of the arguments of the solver_fn(params, *args) function. The argument params is included in the count and has index 0 (argnums=0).

  • has_aux (bool, optional) – Whether the decorated solver function returns auxiliary data. (default: False)

  • solve (callable, optional) – A linear solver of the form solve(matvec, b). (default: linear_solve.solve_normal_cg())

Return type

Callable[[Callable[..., Union[Tensor, Sequence[Tensor], tuple[Union[Tensor, Sequence[Tensor]], Any]]]], Callable[..., Union[Tensor, Sequence[Tensor], tuple[Union[Tensor, Sequence[Tensor]], Any]]]]

Returns

A solver function decorator, i.e., custom_root(optimality_fn)(solver_fn).
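
The idea behind implicit differentiation of a root can be illustrated without the library. The sketch below assumes a scalar inner loss 0.5 * (x - theta)**2 + 0.5 * lam * x**2 and uses the implicit function theorem, d x*/d theta = -(dF/dx)^-1 * dF/dtheta, with partial derivatives taken by central differences. The helper implicit_grad and the constants are hypothetical and do not reflect how custom_root is implemented.

```python
def optimality_fn(x, theta, lam):
    """Gradient of the inner loss with respect to x; zero at the solution."""
    return (x - theta) + lam * x


def solver_fn(x, theta, lam, lr=0.1, steps=500):
    """Plain gradient descent; nothing in this loop is differentiated."""
    for _ in range(steps):
        x = x - lr * optimality_fn(x, theta, lam)
    return x


def implicit_grad(x_star, theta, lam, h=1e-6):
    """d x*/d theta from the implicit function theorem applied to optimality_fn."""
    dF_dx = (optimality_fn(x_star + h, theta, lam)
             - optimality_fn(x_star - h, theta, lam)) / (2.0 * h)
    dF_dtheta = (optimality_fn(x_star, theta + h, lam)
                 - optimality_fn(x_star, theta - h, lam)) / (2.0 * h)
    return -dF_dtheta / dF_dx


theta, lam = 3.0, 0.5
x_star = solver_fn(0.0, theta, lam)           # closed form: theta / (1 + lam)
dx_dtheta = implicit_grad(x_star, theta, lam)  # closed form: 1 / (1 + lam)
```

In the terms used above, theta sits at index 1 of solver_fn's signature, so this derivative corresponds to argnums=1. The gradient is obtained from the optimality condition alone, without unrolling the 500 solver iterations.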

Implicit Meta-Gradient Module

class torchopt.diff.implicit.nn.ImplicitMetaGradientModule(*args, **kwargs)[source]

Bases: MetaGradientModule

The base class for differentiable implicit meta-gradient models.

Initialize a new module instance.

linear_solve: LinearSolver | None
classmethod __init_subclass__(linear_solve=None)[source]

Validate and initialize the subclass.

Return type

None

abstract solve(*input, **kwargs)[source]

Solve the inner optimization problem.

Return type

Any

Warning

For gradient-based optimization methods, the parameter inputs should be explicitly specified in the torch.autograd.backward() function as argument inputs. Otherwise, if not provided, the gradient is accumulated into all the leaf Tensors (including the meta-parameters) that were used to compute the objective output. Alternatively, please use torch.autograd.grad() instead.

Examples

def solve(self, batch, labels):
    parameters = tuple(self.parameters())
    optimizer = torch.optim.Adam(parameters, lr=1e-3)
    with torch.enable_grad():
        for _ in range(100):
            loss = self.objective(batch, labels)
            optimizer.zero_grad()
            # Only update the `.grad` attribute for parameters
            # and leave the meta-parameters unchanged
            loss.backward(inputs=parameters)
            optimizer.step()
    return self

optimality(*input, **kwargs)[source]

Compute the optimality residual.

This method returns the optimality residual at the optimal parameters obtained by solving the inner optimization problem (solve()); the residual should vanish there, i.e.:

module.solve(*input, **kwargs)
module.optimality(*input, **kwargs)  # -> 0

1. For gradient-based optimization, the optimality() function expresses the KKT condition; usually it is the gradient of the objective() function with respect to the module parameters (not the meta-parameters). If this method is not implemented, it will be automatically derived from the gradient of the objective() function.

\[\text{optimality residual} = \nabla_{\boldsymbol{x}} f (\boldsymbol{x}, \boldsymbol{\theta}) \to \boldsymbol{0}\]

where \(\boldsymbol{x}\) is the joint vector of the module parameters and \(\boldsymbol{\theta}\) is the joint vector of the meta-parameters.

References

2. For fixed point iteration, the optimality() function can be the residual of the parameters between iterations, i.e.:

\[\text{optimality residual} = f (\boldsymbol{x}, \boldsymbol{\theta}) - \boldsymbol{x} \to \boldsymbol{0}\]

where \(\boldsymbol{x}\) is the joint vector of the module parameters and \(\boldsymbol{\theta}\) is the joint vector of the meta-parameters.

Return type

Tuple[Tensor, ...]

Returns

A tuple of tensors, the optimality residual to the optimal parameters after solving the inner optimization problem. The returned tensors should correspond to the outputs of tuple(self.parameters()).
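
For the fixed-point case described above, the residual can be illustrated on a scalar contraction. The map f(x, theta) = 0.5 * x + theta and the function names are hypothetical choices for this sketch and do not involve the module class itself.

```python
def fixed_point_map(x, theta):
    """A contraction in x with fixed point x* = 2 * theta."""
    return 0.5 * x + theta


def optimality(x, theta):
    """Fixed-point residual f(x, theta) - x; zero at the solution."""
    return fixed_point_map(x, theta) - x


def solve(theta, x=0.0, iters=100):
    """Iterate the map until it has (numerically) stopped moving."""
    for _ in range(iters):
        x = fixed_point_map(x, theta)
    return x


theta = 1.5
x_star = solve(theta)
residual = optimality(x_star, theta)
```

Before convergence the residual is non-zero (for example at x = 0.0 it equals theta), which is what makes it usable as an optimality condition.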

objective(*input, **kwargs)[source]

Compute the objective function value.

This method is used to calculate the optimality() if it is not implemented. Otherwise, this method is optional.

Return type

Tensor

Returns

A scalar tensor (dim=0), the objective function value.

__abstractmethods__ = frozenset({'solve'})
__annotations__ = {'__call__': 'Callable[..., Any]', '_backward_hooks': 'Dict[int, Callable]', '_backward_pre_hooks': 'Dict[int, Callable]', '_buffers': 'Dict[str, Optional[Tensor]]', '_compiled_call_impl': 'Optional[Callable]', '_custom_objective': 'bool', '_custom_optimality': 'bool', '_forward_hooks': 'Dict[int, Callable]', '_forward_hooks_always_called': 'Dict[int, bool]', '_forward_hooks_with_kwargs': 'Dict[int, bool]', '_forward_pre_hooks': 'Dict[int, Callable]', '_forward_pre_hooks_with_kwargs': 'Dict[int, bool]', '_is_full_backward_hook': 'Optional[bool]', '_load_state_dict_post_hooks': 'Dict[int, Callable]', '_load_state_dict_pre_hooks': 'Dict[int, Callable]', '_meta_inputs': 'MetaInputsContainer', '_meta_modules': 'dict[str, nn.Module | None]', '_meta_parameters': 'TensorContainer', '_modules': "Dict[str, Optional['Module']]", '_non_persistent_buffers_set': 'Set[str]', '_parameters': 'Dict[str, Optional[Parameter]]', '_state_dict_hooks': 'Dict[int, Callable]', '_state_dict_pre_hooks': 'Dict[int, Callable]', '_version': 'int', 'call_super_init': 'bool', 'dump_patches': 'bool', 'forward': 'Callable[..., Any]', 'linear_solve': 'LinearSolver | None', 'training': 'bool'}

Linear System Solvers

solve_cg(**kwargs)

Return a solver function to solve A x = b using conjugate gradient.

solve_normal_cg(**kwargs)

Return a solver function to solve A^T A x = A^T b using conjugate gradient.

solve_inv(**kwargs)

Return a solver function to solve A x = b using matrix inversion.

Indirect Solvers

torchopt.linear_solve.solve_cg(**kwargs)[source]

Return a solver function to solve A x = b using conjugate gradient.

This assumes that A is a Hermitian, positive definite matrix.

Parameters
  • ridge (float or None, optional) – Optional ridge regularization. If provided, solves the equation for A x + ridge x = b. (default: None)

  • init (Tensor, tree of Tensor, or None, optional) – Optional initialization to be used by conjugate gradient. If None, uses zero initialization. (default: None)

  • **kwargs (Any) – Additional keyword arguments for the conjugate gradient solver torchopt.linalg.cg().

Return type

Callable[[Callable[[TensorTree], TensorTree], TensorTree], TensorTree]

Returns

A solver function with signature (matvec, b) -> x that solves A x = b using conjugate gradient where matvec(v) = A v.

See also

Conjugate gradient iteration torchopt.linalg.cg().

Examples

>>> A = {'a': torch.eye(5, 5), 'b': torch.eye(3, 3)}
>>> x = {'a': torch.randn(5), 'b': torch.randn(3)}
>>> def matvec(x: TensorTree) -> TensorTree:
...     return {'a': A['a'] @ x['a'], 'b': A['b'] @ x['b']}
>>> b = matvec(x)
>>> solver = solve_cg(init={'a': torch.zeros(5), 'b': torch.zeros(3)})
>>> x_hat = solver(matvec, b)
>>> assert torch.allclose(x_hat['a'], x['a']) and torch.allclose(x_hat['b'], x['b'])
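
What a solver with the (matvec, b) -> x signature does internally can be sketched in plain Python on flat lists. This is a textbook conjugate gradient loop written for illustration, not the code of torchopt.linalg.cg(); the maxiter and tol arguments and their defaults are assumptions of the sketch.

```python
def dot(u, v):
    return sum(ui * vi for ui, vi in zip(u, v))


def cg(matvec, b, ridge=None, init=None, maxiter=50, tol=1e-12):
    """Solve A x = b (or A x + ridge x = b) for a symmetric positive definite A."""
    def op(v):
        out = matvec(v)
        return out if ridge is None else [o + ridge * vi for o, vi in zip(out, v)]

    x = [0.0] * len(b) if init is None else list(init)
    r = [bi - ai for bi, ai in zip(b, op(x))]   # residual b - A x
    p = list(r)                                 # first search direction
    rs = dot(r, r)
    for _ in range(maxiter):
        if rs <= tol:                           # squared residual norm small enough
            break
        Ap = op(p)
        step = rs / dot(p, Ap)
        x = [xi + step * pi for xi, pi in zip(x, p)]
        r = [ri - step * api for ri, api in zip(r, Ap)]
        rs_new = dot(r, r)
        p = [ri + (rs_new / rs) * pi for ri, pi in zip(r, p)]
        rs = rs_new
    return x


A = [[4.0, 1.0], [1.0, 3.0]]                    # symmetric positive definite


def matvec(v):
    return [dot(row, v) for row in A]


x_hat = cg(matvec, [1.0, 2.0])                  # exact solution: [1/11, 7/11]
```

Only products with A are needed, so the matrix never has to be formed explicitly when matvec is available as a function.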

torchopt.linear_solve.solve_normal_cg(**kwargs)[source]

Return a solver function to solve A^T A x = A^T b using conjugate gradient.

This can be used to solve A x = b using conjugate gradient when A is not Hermitian positive definite.

Parameters
  • ridge (float or None, optional) – Optional ridge regularization. If provided, solves the equation for A^T A x + ridge x = A^T b. (default: None)

  • init (Tensor, tree of Tensor, or None, optional) – Optional initialization to be used by conjugate gradient. If None, uses zero initialization. (default: None)

  • **kwargs (Any) – Additional keyword arguments for the conjugate gradient solver torchopt.linalg.cg().

Return type

Callable[[Callable[[TensorTree], TensorTree], TensorTree], TensorTree]

Returns

A solver function with signature (matvec, b) -> x that solves A^T A x = A^T b using conjugate gradient where matvec(v) = A v.

See also

Conjugate gradient iteration torchopt.linalg.cg().

Examples

>>> A = {'a': torch.randn(5, 5), 'b': torch.randn(3, 3)}
>>> x = {'a': torch.randn(5), 'b': torch.randn(3)}
>>> def matvec(x: TensorTree) -> TensorTree:
...     return {'a': A['a'] @ x['a'], 'b': A['b'] @ x['b']}
>>> b = matvec(x)
>>> solver = solve_normal_cg(init={'a': torch.zeros(5), 'b': torch.zeros(3)})
>>> x_hat = solver(matvec, b)
>>> assert torch.allclose(x_hat['a'], x['a']) and torch.allclose(x_hat['b'], x['b'])
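
Why the normal equations make conjugate gradient applicable to a non-symmetric A can be checked on a small example. In this sketch the transpose product rmatvec is written out by hand, which is an assumption made for illustration; how the library obtains the product with A^T is not shown here.

```python
A = [[2.0, 1.0], [0.0, 1.0]]   # not symmetric, so plain conjugate gradient does not apply


def dot(u, v):
    return sum(ui * vi for ui, vi in zip(u, v))


def matvec(v):
    """A v"""
    return [dot(row, v) for row in A]


def rmatvec(v):
    """A^T v, written out by hand for this sketch."""
    return [sum(A[i][j] * v[i] for i in range(len(A))) for j in range(len(A[0]))]


def normal_matvec(v, ridge=0.0):
    """(A^T A + ridge I) v, the operator handed to conjugate gradient."""
    return [n + ridge * vi for n, vi in zip(rmatvec(matvec(v)), v)]


x_true = [1.0, -2.0]
b = matvec(x_true)
rhs = rmatvec(b)               # right-hand side A^T b

u, v = [1.0, 0.0], [0.0, 1.0]
is_symmetric = dot(u, normal_matvec(v)) == dot(normal_matvec(u), v)
```

The operator A^T A is symmetric and, for an invertible A, positive definite, so a conjugate gradient solver can be applied to normal_matvec and rhs even though it could not be applied to A directly.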

torchopt.linear_solve.solve_inv(**kwargs)[source]

Return a solver function to solve A x = b using matrix inversion.

If ns = False, this assumes the matrix A is a constant matrix and will materialize it in memory.

Parameters
  • ridge (float or None, optional) – Optional ridge regularization. If provided, solves the equation for A x + ridge x = b. (default: None)

  • ns (bool, optional) – Whether to use Neumann Series matrix inversion approximation. If False, materialize the matrix A in memory and use torch.linalg.solve() instead. (default: False)

  • **kwargs (Any) – Additional keyword arguments for the Neumann Series matrix inversion approximation solver torchopt.linalg.ns().

Return type

Callable[[Callable[[TensorTree], TensorTree], TensorTree], TensorTree]

Returns

A solver function with signature (matvec, b) -> x that solves A x = b using matrix inversion where matvec(v) = A v.

See also

Neumann Series matrix inversion approximation torchopt.linalg.ns().

Examples

>>> A = {'a': torch.eye(5, 5), 'b': torch.eye(3, 3)}
>>> x = {'a': torch.randn(5), 'b': torch.randn(3)}
>>> def matvec(x: TensorTree) -> TensorTree:
...     return {'a': A['a'] @ x['a'], 'b': A['b'] @ x['b']}
>>> b = matvec(x)
>>> solver = solve_inv(ns=True, maxiter=10)
>>> x_hat = solver(matvec, b)
>>> assert torch.allclose(x_hat['a'], x['a']) and torch.allclose(x_hat['b'], x['b'])
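
The Neumann series approximation behind ns=True can be sketched in plain Python as the truncated sum of (I - A)^k b over k. The function name and the fixed iteration count are assumptions, any scaling factor the library may apply is omitted, and the series only converges when the spectral radius of I - A is below one.

```python
def neumann_solve(matvec, b, maxiter=50):
    """Approximate A^{-1} b by the truncated series sum_k (I - A)^k b."""
    x = list(b)        # k = 0 term
    term = list(b)
    for _ in range(maxiter):
        av = matvec(term)
        term = [t - a for t, a in zip(term, av)]   # apply (I - A) to the previous term
        x = [xi + ti for xi, ti in zip(x, term)]
    return x


A = [[0.9, 0.1], [0.0, 0.8]]   # eigenvalues of I - A are 0.1 and 0.2


def matvec(v):
    return [sum(a * vi for a, vi in zip(row, v)) for row in A]


b = [1.0, 2.0]
x_hat = neumann_solve(matvec, b)
```

As with conjugate gradient, only matrix-vector products are required, so A is never materialized or inverted explicitly.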

Zero-Order Differentiation

zero_order(distribution[, method, argnums, ...])

Return a decorator for applying zero-order differentiation.

nn.ZeroOrderGradientModule(*args, **kwargs)

The base class for zero-order gradient models.

Decorators

torchopt.diff.zero_order.zero_order(distribution, method='naive', argnums=(0,), num_samples=1, sigma=1.0)[source]

Return a decorator for applying zero-order differentiation.

Parameters
  • distribution (callable or Samplable) – A samplable object that has method samplable.sample(sample_shape) or a function that takes the shape as input and returns a shaped batch of samples. This is used to sample perturbations from the given distribution. The distribution should be spherically symmetric.

  • method (str, optional) – The algorithm to use. The currently supported algorithms are 'naive', 'forward', and 'antithetic'. (default: 'naive')

  • argnums (int or tuple of int, optional) – Specifies arguments to compute gradients with respect to. (default: (0,))

  • num_samples (int, optional) – The number of samples to average when estimating the gradient. (default: 1)

  • sigma (float, optional) – The standard deviation of the perturbation. (default: 1.0)

Return type

Callable[[Callable[..., Tensor]], Callable[..., Tensor]]

Returns

A function decorator that enables zero-order gradient estimation.
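
To make the idea of zero-order gradient estimation concrete, here is a minimal plain-PyTorch sketch of an antithetic estimator. It is a generic evolution-strategies style estimator written for illustration, not the implementation used by zero_order(); the function name antithetic_grad_estimate is hypothetical, and Gaussian perturbations are assumed.

```python
import torch


def antithetic_grad_estimate(fn, x, sigma=0.1, num_samples=1):
    """Monte Carlo estimate of the gradient of fn at x using function values only."""
    estimate = torch.zeros_like(x)
    for _ in range(num_samples):
        u = torch.randn_like(x)  # spherically symmetric perturbation
        # Symmetric finite difference along the random direction u.
        diff = fn(x + sigma * u) - fn(x - sigma * u)
        estimate += diff / (2.0 * sigma) * u
    return estimate / num_samples


torch.manual_seed(0)
x = torch.tensor([1.0, 2.0, 3.0])
grad_hat = antithetic_grad_estimate(lambda v: (v ** 2).sum(), x, sigma=0.1, num_samples=5000)
```

For this quadratic objective the exact gradient is 2 * x, and the estimate approaches it as num_samples grows.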

Zero-order Gradient Module

class torchopt.diff.zero_order.nn.ZeroOrderGradientModule(*args, **kwargs)[source]

Bases: Module, Samplable

The base class for zero-order gradient models.

Initializes internal Module state, shared by both nn.Module and ScriptModule.

classmethod __init_subclass__(method='naive', num_samples=1, sigma=1.0)[source]

Validate and initialize the subclass.

Return type

None

abstract forward(*args, **kwargs)[source]

Do the forward pass of the model.

Return type

Tensor

abstract sample(sample_shape=torch.Size([]))[source]

Generate a sample_shape shaped sample or sample_shape shaped batch of samples if the distribution parameters are batched.

Return type

Union[Tensor, Sequence[Union[Tensor, float, int, bool]]]


Optimizer Hooks

register_hook(hook)

Stateless identity transformation that leaves input gradients untouched.

zero_nan_hook(g)

Replace nan with zero.

nan_to_num_hook([nan, posinf, neginf])

Return a nan to num hook to replace nan / +inf / -inf with the given numbers.

Hook

torchopt.hook.register_hook(hook)[source]

Stateless identity transformation that leaves input gradients untouched.

This function passes through the gradient updates unchanged.

Return type

GradientTransformation

Returns

An (init_fn, update_fn) tuple.

torchopt.hook.zero_nan_hook(g)[source]

Replace nan with zero.

Return type

Tensor

torchopt.hook.nan_to_num_hook(nan=0.0, posinf=None, neginf=None)[source]

Return a nan to num hook to replace nan / +inf / -inf with the given numbers.

Return type

Callable[[Tensor], Tensor]
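
The effect of such a hook can be sketched with PyTorch's own Tensor.register_hook() and torch.nan_to_num(). This is an illustration under the assumption that the hook is a plain function from gradient to gradient; make_nan_to_num_hook is a hypothetical stand-in, not the TorchOpt function.

```python
import torch


def make_nan_to_num_hook(nan=0.0, posinf=None, neginf=None):
    """Return a gradient hook that replaces nan / +inf / -inf entries."""
    def hook(grad):
        return torch.nan_to_num(grad, nan=nan, posinf=posinf, neginf=neginf)
    return hook


x = torch.tensor([1.0, 2.0], requires_grad=True)
x.register_hook(make_nan_to_num_hook(nan=0.0))

# The gradient of sum(x * scale) with respect to x is scale, which contains a nan.
scale = torch.tensor([float('nan'), 3.0])
(x * scale).sum().backward()
```

After the backward pass, x.grad holds the sanitized gradient because the value returned by the hook replaces the incoming one.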


Gradient Transformation

clip_grad_norm(max_norm[, norm_type, ...])

Clip gradient norm of an iterable of parameters.

nan_to_num([nan, posinf, neginf])

Replace nan / +inf / -inf values in the updates with the given numbers.

Transforms

torchopt.clip_grad_norm(max_norm, norm_type=2.0, error_if_nonfinite=False)[source]

Clip gradient norm of an iterable of parameters.

Parameters
  • max_norm (float) – The maximum allowed norm of the gradients; updates whose total norm exceeds it are rescaled.

  • norm_type (float, optional) – Type of the used p-norm. Can be 'inf' for infinity norm. (default: 2.0)

  • error_if_nonfinite (bool, optional) – If True, an error is thrown if the total norm of the gradients from updates is nan, inf, or -inf. (default: False)

Return type

GradientTransformation

Returns

An (init_fn, update_fn) tuple.
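
Global-norm clipping can be sketched in a few lines of plain PyTorch. The helper clip_by_global_norm and its eps constant are assumptions made for illustration; the sketch follows the usual recipe of computing one norm over all gradients and rescaling every gradient by the same coefficient.

```python
import torch


def clip_by_global_norm(grads, max_norm, norm_type=2.0, eps=1e-6):
    """Rescale all gradients so that their combined p-norm is at most max_norm."""
    norms = torch.stack([torch.linalg.vector_norm(g, ord=norm_type) for g in grads])
    total_norm = torch.linalg.vector_norm(norms, ord=norm_type)
    # The coefficient is capped at 1 so that small gradients are left unchanged.
    clip_coef = torch.clamp(max_norm / (total_norm + eps), max=1.0)
    return [g * clip_coef for g in grads], total_norm


grads = [torch.tensor([3.0]), torch.tensor([4.0])]
clipped, total_norm = clip_by_global_norm(grads, max_norm=1.0)
```

Here the combined 2-norm is 5, so both gradients are scaled by roughly 0.2.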

torchopt.nan_to_num(nan=0.0, posinf=None, neginf=None)[source]

Replace nan / +inf / -inf values in the updates with the given numbers.

Return type

GradientTransformation

Returns

An (init_fn, update_fn) tuple.

Optimizer Schedules

linear_schedule(init_value, end_value, ...)

Alias polynomial schedule to linear schedule for convenience.

polynomial_schedule(init_value, end_value, ...)

Construct a schedule with polynomial transition from init to end value.

Schedules

torchopt.schedule.linear_schedule(init_value, end_value, transition_steps, transition_begin=0)[source]

Alias polynomial schedule to linear schedule for convenience.

Return type

Callable[[Union[Tensor, float, int, bool]], Union[Tensor, float, int, bool]]

torchopt.schedule.polynomial_schedule(init_value, end_value, power, transition_steps, transition_begin=0)[source]

Construct a schedule with polynomial transition from init to end value.

Parameters
  • init_value (float or Tensor) – Initial value for the scalar to be annealed.

  • end_value (float or Tensor) – End value of the scalar to be annealed.

  • power (float or Tensor) – The power of the polynomial used to transition from init to end.

  • transition_steps (int) – Number of steps over which annealing takes place, the scalar starts changing at transition_begin steps and completes the transition by transition_begin + transition_steps steps. If transition_steps <= 0, then the entire annealing process is disabled and the value is held fixed at init_value.

  • transition_begin (int, optional) – Must be non-negative. After how many steps to start annealing (before this many steps the scalar value is held fixed at init_value). (default: 0)

Returns

A function that maps step counts to values.

Return type

schedule
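
The parameter descriptions above can be turned into a small pure-Python sketch. The formula used here is the common polynomial-decay form and is an assumption about the schedule's exact shape; the names polynomial_schedule_sketch and linear_schedule_sketch are hypothetical.

```python
def polynomial_schedule_sketch(init_value, end_value, power, transition_steps, transition_begin=0):
    """Return a function mapping a step count to an annealed value."""
    if transition_steps <= 0:
        # Annealing disabled: the value stays at init_value.
        return lambda count: init_value
    transition_begin = max(transition_begin, 0)

    def schedule(count):
        # Steps before transition_begin count as 0; steps past the end are capped.
        clipped = min(max(count - transition_begin, 0), transition_steps)
        frac = 1.0 - clipped / transition_steps
        return (init_value - end_value) * frac ** power + end_value

    return schedule


def linear_schedule_sketch(init_value, end_value, transition_steps, transition_begin=0):
    """A linear schedule is the polynomial schedule with power 1."""
    return polynomial_schedule_sketch(init_value, end_value, 1, transition_steps, transition_begin)


lr = linear_schedule_sketch(1.0, 0.0, transition_steps=10, transition_begin=5)
values = [lr(step) for step in (0, 5, 10, 15, 100)]
```

The value is held at init_value until step 5, reaches the midpoint at step 10, and stays at end_value from step 15 onward.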

Apply Parameter Updates

apply_updates(params, updates, *[, inplace])

Apply an update to the corresponding parameters.

Apply Updates

torchopt.apply_updates(params, updates, *, inplace=True)[source]

Apply an update to the corresponding parameters.

This is a utility function that applies an update to a set of parameters, and then returns the updated parameters to the caller. As an example, the update may be a gradient transformed by a sequence of GradientTransformations. This function is exposed for convenience, but it just adds updates and parameters; you may also apply updates to parameters manually, using tree_map() (e.g. if you want to manipulate updates in custom ways before applying them).

Parameters
  • params (tree of Tensor) – A tree of parameters.

  • updates (tree of Tensor) – A tree of updates, the tree structure and the shape of the leaf nodes must match that of params.

  • inplace (bool, optional) – If True, update params in place. (default: True)

Return type

TensorTree (a Tensor, or an arbitrarily nested tuple, list, dict, deque, or CustomTreeNode whose leaves are Tensors)

Returns

Updated parameters, with same structure, shape and type as params.
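
Since the function only adds updates to parameters leaf by leaf, the manual alternative mentioned above can be sketched with a tiny tree_map helper. The helper below is a hypothetical pure-Python stand-in that handles dicts, lists and tuples only; it is not the tree_map() used by the library, and plain floats stand in for tensors.

```python
def tree_map(fn, tree, *rest):
    """Apply fn to corresponding leaves of trees that share the same structure."""
    if isinstance(tree, dict):
        return {key: tree_map(fn, tree[key], *(r[key] for r in rest)) for key in tree}
    if isinstance(tree, (list, tuple)):
        mapped = [tree_map(fn, *leaves) for leaves in zip(tree, *rest)]
        return type(tree)(mapped)
    return fn(tree, *rest)


def apply_updates_sketch(params, updates):
    """Out-of-place update: return a new tree holding params + updates."""
    return tree_map(lambda p, u: p + u, params, updates)


params = {'w': [1.0, 2.0], 'b': (0.5,)}
updates = {'w': [0.5, -1.0], 'b': (0.25,)}
new_params = apply_updates_sketch(params, updates)
```

Because the leaves are combined with +, the same helper also works when the leaves are tensors.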

Combining Optimizers

chain(*transformations)

Apply a list of chainable update transformations.

Chain

torchopt.combine.chain(*transformations)[source]

Apply a list of chainable update transformations.

Given a sequence of chainable transforms, chain() returns an init_fn() that constructs a state by concatenating the states of the individual transforms, and returns an update_fn() which chains the update transformations feeding the appropriate state to each.

Parameters

*transformations (iterable of GradientTransformation) – A sequence of chainable (init_fn, update_fn) tuples.

Return type

GradientTransformation

Returns

A single (init_fn, update_fn) tuple.
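
The chaining behaviour described above can be sketched in pure Python. This is a simplified illustration: update functions here take only (updates, state), plain floats stand in for tensors, and scale, trace and chain_sketch are hypothetical helpers rather than the library's implementations.

```python
def scale(factor):
    """Stateless transformation that multiplies every update by factor."""
    def init_fn(params):
        return ()

    def update_fn(updates, state):
        return [factor * u for u in updates], state

    return init_fn, update_fn


def trace(decay):
    """Stateful transformation that accumulates a momentum-like running sum."""
    def init_fn(params):
        return [0.0 for _ in params]

    def update_fn(updates, state):
        new_state = [decay * m + u for m, u in zip(state, updates)]
        return new_state, new_state

    return init_fn, update_fn


def chain_sketch(*transformations):
    """Combine (init_fn, update_fn) pairs into a single pair."""
    init_fns = [t[0] for t in transformations]
    update_fns = [t[1] for t in transformations]

    def init_fn(params):
        # The combined state is the tuple of the individual states.
        return tuple(fn(params) for fn in init_fns)

    def update_fn(updates, state):
        new_state = []
        for fn, sub_state in zip(update_fns, state):
            # Each transformation receives the output of the previous one.
            updates, sub_state = fn(updates, sub_state)
            new_state.append(sub_state)
        return updates, tuple(new_state)

    return init_fn, update_fn


init_fn, update_fn = chain_sketch(trace(0.5), scale(-1.0))
state = init_fn([0.0, 0.0])
updates, state = update_fn([1.0, 2.0], state)
updates, state = update_fn([1.0, 2.0], state)
```

After two steps with the same gradients, the momentum state is [1.5, 3.0] and the final updates are its negation.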

Distributed Utilities

Initialization and Synchronization

auto_init_rpc([worker_init_fn, ...])

Return a decorator to automatically initialize RPC on the decorated function.

barrier([worker_names])

Synchronize local and remote RPC processes.

torchopt.distributed.auto_init_rpc(worker_init_fn=None, worker_name_format=<function default_worker_name_format>, *, backend=None, rpc_backend_options=None)[source]

Return a decorator to automatically initialize RPC on the decorated function.

Return type

Callable[[TypeVar(F, bound= Callable[..., Any])], TypeVar(F, bound= Callable[..., Any])]

torchopt.distributed.barrier(worker_names=None)[source]

Synchronize local and remote RPC processes.

This will block until all local and remote RPC processes specified under worker_names reach this method to wait for all outstanding work to complete.

Parameters

worker_names (iterable of str or None, optional) – The set of workers to synchronize. If None, all workers. (default: None)

Return type

None

Process group information

get_world_info()

Get the world information.

get_world_rank()

Get the global world rank of the current worker.

get_rank()

Get the global world rank of the current worker.

get_world_size()

Get the world size.

get_local_rank()

Get the local rank of the current worker on the current node.

get_local_world_size()

Get the local world size on the current node.

get_worker_id([id])

Get the worker id from the given id.

torchopt.distributed.get_world_info()[source]

Get the world information.

Return type

WorldInfo

torchopt.distributed.get_world_rank()[source]

Get the global world rank of the current worker.

Return type

int

torchopt.distributed.get_rank()

Get the global world rank of the current worker.

Return type

int

torchopt.distributed.get_world_size()[source]

Get the world size.

Return type

int

torchopt.distributed.get_local_rank()[source]

Get the local rank of the current worker on the current node.

Return type

int

torchopt.distributed.get_local_world_size()[source]

Get the local world size on the current node.

Return type

int

torchopt.distributed.get_worker_id(id=None)[source]

Get the worker id from the given id.

Return type

int

Worker selection

on_rank(*ranks)

Return a decorator to mark a function to be executed only on given ranks.

not_on_rank(*ranks)

Return a decorator to mark a function to be executed only on ranks other than the given ones.

rank_zero_only(func)

Return a decorator to mark a function to be executed only on rank zero.

rank_non_zero_only(func)

Return a decorator to mark a function to be executed only on ranks other than zero.

torchopt.distributed.on_rank(*ranks)[source]

Return a decorator to mark a function to be executed only on given ranks.

Return type

Callable[[TypeVar(F, bound= Callable[..., Any])], TypeVar(F, bound= Callable[..., Any])]

torchopt.distributed.not_on_rank(*ranks)[source]

Return a decorator to mark a function to be executed only on ranks other than the given ones.

Return type

Callable[[TypeVar(F, bound= Callable[..., Any])], TypeVar(F, bound= Callable[..., Any])]

torchopt.distributed.rank_zero_only(func)[source]

Return a decorator to mark a function to be executed only on rank zero.

Return type

TypeVar(F, bound= Callable[..., Any])

torchopt.distributed.rank_non_zero_only(func)[source]

Return a decorator to mark a function to be executed only on ranks other than zero.

Return type

TypeVar(F, bound= Callable[..., Any])
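
The worker-selection decorators above all follow the same pattern, which can be sketched in pure Python. This sketch assumes that the current rank is available from a RANK environment variable and that a skipped call returns None; both are assumptions for illustration, and get_rank_sketch and on_rank_sketch are hypothetical names.

```python
import functools
import os


def get_rank_sketch():
    # Assumption: the launcher exports the global rank as the RANK environment variable.
    return int(os.environ.get('RANK', '0'))


def on_rank_sketch(*ranks):
    """Return a decorator that runs the function only on the given ranks."""
    def decorator(func):
        @functools.wraps(func)
        def wrapped(*args, **kwargs):
            if get_rank_sketch() in ranks:
                return func(*args, **kwargs)
            return None  # the call is skipped on every other rank

        return wrapped

    return decorator


@on_rank_sketch(0)
def log_message(msg):
    return f'[rank 0] {msg}'


# Simulate the same call being made on two different workers.
os.environ['RANK'] = '0'
on_zero = log_message('saving checkpoint')
os.environ['RANK'] = '1'
on_one = log_message('saving checkpoint')
```

A rank-zero-only decorator is then just on_rank_sketch(0) applied directly to the function.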

Remote Procedure Call (RPC)

remote_async_call(func, *[, args, kwargs, ...])

Asynchronously do an RPC on remote workers and return a torch.Future instance at the current worker.

remote_sync_call(func, *[, args, kwargs, ...])

Do an RPC synchronously on remote workers and return the result to the current worker.

torchopt.distributed.remote_async_call(func, *, args=None, kwargs=None, partitioner=None, reducer=None, timeout=-1.0)[source]

Asynchronously do an RPC on remote workers and return a torch.Future instance at the current worker.

Parameters
  • func (callable) – The function to call.

  • args (tuple of object or None, optional) – The arguments to pass to the function. (default: None)

  • kwargs (dict[str, object] or None, optional) – The keyword arguments to pass to the function. (default: None)

  • partitioner (int, str, or callable, optional) – A partitioner that partitions the arguments to multiple workers. (default: batch_partitioner())

  • reducer (callable or None, optional) – A reducer that reduces the results from multiple workers. If None, do not reduce the results. (default: None)

  • timeout (float, optional) – The timeout for the RPC call. (default: rpc.api.UNSET_RPC_TIMEOUT)

Return type

Union[Future[list[TypeVar(T)]], Future[TypeVar(U)]]

Returns

A torch.Future instance for the result. The result is at the current worker.

torchopt.distributed.remote_sync_call(func, *, args=None, kwargs=None, partitioner=None, reducer=None, timeout=-1.0)[source]

Do an RPC synchronously on remote workers and return the result to the current worker.

Parameters
  • func (callable) – The function to call.

  • args (tuple of object or None, optional) – The arguments to pass to the function. (default: None)

  • kwargs (dict[str, object] or None, optional) – The keyword arguments to pass to the function. (default: None)

  • partitioner (int, str, or callable, optional) – A partitioner that partitions the arguments to multiple workers. (default: batch_partitioner())

  • reducer (callable or None, optional) – A reducer that reduces the results from multiple workers. If None, do not reduce the results. (default: None)

  • timeout (float, optional) – The timeout for the RPC call. (default: rpc.api.UNSET_RPC_TIMEOUT)

Return type

Union[list[TypeVar(T)], TypeVar(U)]

Returns

The result of the RPC call. The result is at the current worker.

Predefined partitioners and reducers

dim_partitioner([dim, exclusive, keepdim, ...])

Partition a batch of inputs along a given dimension.

batch_partitioner

Partitioner class that partitions a batch of inputs along a given dimension.

mean_reducer(results)

Reduce the results by averaging them.

sum_reducer(results)

Reduce the results by summing them.

torchopt.distributed.dim_partitioner(dim=0, *, exclusive=False, keepdim=True, workers=None)[source]

Partition a batch of inputs along a given dimension.

All tensors in the args and kwargs will be partitioned along the dimension dim, while the non-tensor values will be broadcast to all partitions.

Parameters
  • dim (int, optional) – The dimension to partition. (default: 0)

  • exclusive (bool, optional) – Whether to partition the batch exclusively. If True, the batch will be partitioned into batch_size partitions, where batch_size is the size of the batch along the given dimension, and each batch sample will be assigned to a separate RPC call. If False, the batch will be partitioned into min(batch_size, num_workers) partitions, where num_workers is the number of workers in the world; when batch_size > num_workers, multiple batch samples may be forwarded in a single RPC call. (default: False)

  • keepdim (bool, optional) – Whether to keep the partitioned dimension. If True, keep the batch dimension. If False, use select instead of slicing; this should be used with exclusive=True. (default: True)

  • workers (sequence of int or str, or None, optional) – The workers to partition the batch to. If None, the batch will be partitioned to all workers in the world. (default: None)

Return type

Callable[..., Sequence[Tuple[int, Optional[Tuple[Any, ...]], Optional[Dict[str, Any]]]]]

Returns

A partition function.

torchopt.distributed.batch_partitioner(*args: Any, **kwargs: Any) -> list[tuple[int, Args | None, KwArgs | None]]

Partitioner for the batch dimension. Divides and replicates the arguments to all workers along the first dimension.

The batch will be partitioned into min(batch_size, num_workers) partitions, where num_workers is the number of workers in the world. When batch_size > num_workers, multiple batch samples may be forwarded in a single RPC call.

All tensors in the args and kwargs will be partitioned along the first dimension, while the non-tensor values will be broadcast to all partitions.

torchopt.distributed.mean_reducer(results)[source]

Reduce the results by averaging them.

Return type

Tensor

torchopt.distributed.sum_reducer(results)[source]

Reduce the results by summing them.

Return type

Tensor
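
The partition-then-reduce flow described in this section can be sketched locally with plain PyTorch, without any RPC. The helpers below are hypothetical stand-ins: the partitioner only splits a single tensor along dimension 0, and the reducers assume that every worker returns a tensor of the same shape.

```python
import torch


def batch_partition_sketch(batch, num_workers):
    """Split a batch along dim 0 into min(batch_size, num_workers) parts."""
    num_parts = min(batch.shape[0], num_workers)
    return list(torch.tensor_split(batch, num_parts, dim=0))


def mean_reducer_sketch(results):
    return torch.mean(torch.stack(results), dim=0)


def sum_reducer_sketch(results):
    return torch.sum(torch.stack(results), dim=0)


batch = torch.arange(6, dtype=torch.float32).reshape(6, 1)
parts = batch_partition_sketch(batch, num_workers=3)

# Each "worker" is simulated by a local call on its partition.
results = [part.sum() for part in parts]
total = sum_reducer_sketch(results)
average = mean_reducer_sketch(results)
```

With six samples and three workers, each partition holds two samples, so several samples are processed in a single call.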

Function parallelization wrappers

parallelize([partitioner, reducer, timeout])

Return a decorator for parallelizing a function.

parallelize_async([partitioner, reducer, ...])

Return a decorator for parallelizing a function.

parallelize_sync([partitioner, reducer, timeout])

Return a decorator for parallelizing a function.

torchopt.distributed.parallelize(partitioner=None, reducer=None, timeout=-1.0)[source]

Return a decorator for parallelizing a function.

This decorator can be used to parallelize a function call across multiple workers.

Parameters
  • partitioner (int, str, or callable, optional) – A partitioner that partitions the arguments to multiple workers. (default: batch_partitioner())

  • reducer (callable or None, optional) – A reducer that reduces the results from multiple workers. If None, do not reduce the results. (default: None)

  • timeout (float, optional) – The timeout for the RPC call. (default: rpc.api.UNSET_RPC_TIMEOUT)

Return type

Callable[[Callable[..., TypeVar(T)]], Callable[..., Union[list[TypeVar(T)], TypeVar(U)]]]

Returns

The decorator function.

torchopt.distributed.parallelize_async(partitioner=None, reducer=None, timeout=-1.0)[source]

Return a decorator for parallelizing a function.

This decorator can be used to parallelize a function call across multiple workers. The function will be called asynchronously on remote workers. The decorated function will return a torch.Future instance of the result.

Parameters
  • partitioner (int, str, or callable, optional) – A partitioner that partitions the arguments to multiple workers. (default: batch_partitioner())

  • reducer (callable or None, optional) – A reducer that reduces the results from multiple workers. If None, do not reduce the results. (default: None)

  • timeout (float, optional) – The timeout for the RPC call. (default: rpc.api.UNSET_RPC_TIMEOUT)

Return type

Callable[[Callable[..., TypeVar(T)]], Callable[..., Union[Future[list[TypeVar(T)]], Future[TypeVar(U)]]]]

Returns

The decorator function.

torchopt.distributed.parallelize_sync(partitioner=None, reducer=None, timeout=-1.0)

Return a decorator for parallelizing a function.

This decorator can be used to parallelize a function call across multiple workers.

Parameters
  • partitioner (int, str, or callable, optional) – A partitioner that partitions the arguments to multiple workers. (default: batch_partitioner())

  • reducer (callable or None, optional) – A reducer that reduces the results from multiple workers. If None, do not reduce the results. (default: None)

  • timeout (float, optional) – The timeout for the RPC call. (default: rpc.api.UNSET_RPC_TIMEOUT)

Return type

Callable[[Callable[..., TypeVar(T)]], Callable[..., Union[list[TypeVar(T)], TypeVar(U)]]]

Returns

The decorator function.
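
The shape of these decorators can be sketched with a local thread pool standing in for remote workers. Everything here is an assumption made for illustration: parallelize_local is a hypothetical name, and the partitioner is assumed to return a list of (args, kwargs) pairs rather than the (worker_id, args, kwargs) triples used with RPC.

```python
import functools
from concurrent.futures import ThreadPoolExecutor


def parallelize_local(partitioner, reducer=None, max_workers=4):
    """Return a decorator that fans a call out over partitions and gathers the results."""
    def decorator(func):
        @functools.wraps(func)
        def wrapped(*args, **kwargs):
            partitions = partitioner(*args, **kwargs)
            with ThreadPoolExecutor(max_workers=max_workers) as pool:
                futures = [pool.submit(func, *a, **kw) for a, kw in partitions]
                results = [future.result() for future in futures]
            # Without a reducer, the list of per-partition results is returned.
            return results if reducer is None else reducer(results)

        return wrapped

    return decorator


def split_in_two(values):
    mid = len(values) // 2
    return [((values[:mid],), {}), ((values[mid:],), {})]


@parallelize_local(partitioner=split_in_two, reducer=sum)
def sum_of_squares(values):
    return sum(v * v for v in values)


result = sum_of_squares([1, 2, 3, 4])
```

An asynchronous variant would return the futures (or a combined future) instead of waiting for the results inside the wrapper.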

Distributed Autograd

context()

Context object to wrap forward and backward passes when using distributed autograd.

get_gradients(context_id)

Retrieves a map from Tensor to the appropriate gradient for that Tensor accumulated in the provided context corresponding to the given context_id as part of the distributed autograd backward pass.

backward(autograd_ctx_id, tensors[, ...])

Perform distributed backward pass for local parameters.

grad(autograd_ctx_id, outputs, inputs[, ...])

Compute and return the sum of gradients of outputs with respect to the inputs.

torchopt.distributed.autograd.context()[source]

Context object to wrap forward and backward passes when using distributed autograd. The context_id generated in the with statement is required to uniquely identify a distributed backward pass on all workers. Each worker stores metadata associated with this context_id, which is required to correctly execute a distributed autograd pass.

Example

>>> import torch
>>> import torch.distributed.autograd as dist_autograd
>>> import torch.distributed.rpc as rpc
>>> # Assumes RPC is already initialized and a worker named "worker1" exists.
>>> with dist_autograd.context() as context_id:
...     t1 = torch.rand((3, 3), requires_grad=True)
...     t2 = torch.rand((3, 3), requires_grad=True)
...     loss = rpc.rpc_sync("worker1", torch.add, args=(t1, t2)).sum()
...     dist_autograd.backward(context_id, [loss])

torchopt.distributed.autograd.get_gradients(context_id: int) -> Dict[Tensor, Tensor]

Retrieves a map from Tensor to the appropriate gradient for that Tensor accumulated in the provided context corresponding to the given context_id as part of the distributed autograd backward pass.

Parameters

context_id (int) – The autograd context id for which we should retrieve the gradients.

Returns

A map where the key is the Tensor and the value is the associated gradient for that Tensor.

Example

>>> import torch
>>> import torch.distributed.autograd as dist_autograd
>>> with dist_autograd.context() as context_id:
...     t1 = torch.rand((3, 3), requires_grad=True)
...     t2 = torch.rand((3, 3), requires_grad=True)
...     loss = t1 + t2
...     dist_autograd.backward(context_id, [loss.sum()])
...     grads = dist_autograd.get_gradients(context_id)
...     print(grads[t1])
...     print(grads[t2])

torchopt.distributed.autograd.backward(autograd_ctx_id, tensors, retain_graph=False, inputs=None)[source]

Perform distributed backward pass for local parameters.

Compute the sum of gradients of given tensors with respect to graph leaves.

Parameters
  • autograd_ctx_id (int) – The autograd context id.

  • tensors (Tensor or sequence of Tensor) – Tensors of which the derivative will be computed.

  • retain_graph (bool, optional) – If False, the graph used to compute the grad will be freed. Note that in nearly all cases setting this option to True is not needed and often can be worked around in a much more efficient way. (default: False)

  • inputs (Tensor, sequence of Tensor, or None, optional) – Inputs w.r.t. which the gradient will be accumulated into .grad. All other Tensors will be ignored. If not provided, the gradient is accumulated into all the leaf Tensors that were used to compute the tensors. (default: None)

Return type

None

torchopt.distributed.autograd.grad(autograd_ctx_id, outputs, inputs, retain_graph=False, allow_unused=False)[source]

Compute and return the sum of gradients of outputs with respect to the inputs.

Parameters
  • autograd_ctx_id (int) – The autograd context id.

  • outputs (Tensor or sequence of Tensor) – Outputs of the differentiated function.

  • inputs (Tensor or sequence of Tensor) – Inputs w.r.t. which the gradient will be returned (and not accumulated into .grad).

  • retain_graph (bool, optional) – If False, the graph used to compute the grad will be freed. Note that in nearly all cases setting this option to True is not needed and often can be worked around in a much more efficient way. (default: False)

  • allow_unused (bool, optional) – If False, specifying inputs that were not used when computing outputs (and therefore their grad is always zero) is an error. (default: False)

Return type

Tuple[Optional[Tensor], ...]
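
The return convention can be illustrated with the local counterpart torch.autograd.grad(), which needs no RPC setup. This is only an analogy for the semantics of inputs and allow_unused; it does not exercise the distributed function documented above.

```python
import torch

x = torch.tensor([1.0, 2.0], requires_grad=True)
unused = torch.tensor([5.0], requires_grad=True)
loss = (x ** 2).sum()

# Gradients are returned as a tuple instead of being accumulated into .grad.
# With allow_unused=True, inputs that did not contribute to loss yield None.
grads = torch.autograd.grad(loss, [x, unused], allow_unused=True)
```

Without allow_unused=True, passing the unused tensor as an input would raise an error.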

General Utilities

extract_state_dict(target, *[, by, device, ...])

Extract target state.

recover_state_dict(target, state)

Recover state.

stop_gradient(target)

Stop the gradient for the input object.

Extract State Dict

torchopt.extract_state_dict(target, *, by='reference', device=None, with_buffers=True, detach_buffers=False, enable_visual=False, visual_prefix='')[source]

Extract target state.

Since a tensor uses grad_fn to connect itself with the previous computation graph, the backpropagated gradient will flow through the tensor and continue to the tensors that are connected to it by grad_fn. Some algorithms require manually detaching tensors from the computation graph.

Note that the extracted state is a reference, which means any in-place operation will affect the target that the state is extracted from.

Parameters
  • target (nn.Module or MetaOptimizer) – It could be a nn.Module or torchopt.MetaOptimizer.

  • by (str, optional) – The extract policy of tensors in the target. (default: 'reference')

    • 'reference': The extracted tensors will be references to the original tensors.

    • 'copy': The extracted tensors will be clones of the original tensors. The cloned tensors have a <CloneBackward> grad_fn that points to the original tensors.

    • 'deepcopy': The extracted tensors will be deep-copied from the original tensors. The deep-copied tensors will be detached from the original computation graph.

  • device (Device or None, optional) – If specified, move the extracted state to the specified device. (default: None)

  • with_buffers (bool, optional) – Extract buffer together with parameters, this argument is only used if the input target is nn.Module. (default: True)

  • detach_buffers (bool, optional) – Whether to detach the reference to the buffers, this argument is only used if the input target is nn.Module and by='reference'. (default: False)

  • enable_visual (bool, optional) – Add additional annotations, which could be used in computation graph visualization. Currently, this flag only has effect on nn.Module but we will support torchopt.MetaOptimizer later. (default: False)

  • visual_prefix (str, optional) – Prefix for the visualization annotations. (default: '')

Return type

ModuleState | tuple[TensorTree, ...]

Returns

The state extracted from the input object.
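
The three extract policies differ in how the result relates to the computation graph, which can be shown on a single tensor in plain PyTorch. This is a sketch of the tensor-level behaviour only, using clone() and detach() as assumed equivalents of the 'copy' and 'deepcopy' policies; it does not call extract_state_dict().

```python
import torch

w = torch.tensor([1.0, 2.0], requires_grad=True)
h = w * 3.0  # a non-leaf tensor connected to w through grad_fn

by_reference = h                 # same object; in-place changes are shared
by_copy = h.clone()              # new tensor, still connected to the graph
by_deepcopy = h.detach().clone() # new tensor, cut off from the graph

# Gradients still reach w through the clone.
(grad_w,) = torch.autograd.grad(by_copy.sum(), w)
```

Only the detached copy can be modified freely without affecting gradients of the original graph.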

Recover State Dict

torchopt.recover_state_dict(target, state)[source]

Recover state.

This function is compatible with the state returned by extract_state_dict().

Note that the recovering process is not in-place, so the tensors of the object will not be modified.

Parameters
  • target (nn.Module or MetaOptimizer) – The target whose state needs to be recovered.

  • state (ModuleState or sequence of tree of Tensor) – The recovering state.

Return type

None

Stop Gradient

torchopt.stop_gradient(target)[source]

Stop the gradient for the input object.

Since a tensor uses grad_fn to connect itself with the previous computation graph, the backpropagated gradient will flow through the tensor and continue to the tensors that are connected to it by grad_fn. Some algorithms require manually detaching tensors from the computation graph.

Note that the stop_gradient() operation is in-place.

Parameters

target (ModuleState, nn.Module, MetaOptimizer, or tree of Tensor) – The target to be detached from the computation graph. It could be a nn.Module, a torchopt.MetaOptimizer, the state of a torchopt.MetaOptimizer, or just a plain list of tensors.

Return type

None
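
At the level of a single tensor, an in-place stop-gradient can be sketched with Tensor.detach_(). This is a minimal illustration of the effect, assuming the tensor should keep requiring gradients afterwards; it does not call stop_gradient() itself.

```python
import torch

w = torch.tensor([1.0, 2.0], requires_grad=True)
h = w * 3.0
had_grad_fn = h.grad_fn is not None

# Detach h in place: it becomes a leaf and no longer points back to w.
h.detach_().requires_grad_(True)

loss = (h ** 2).sum()
loss.backward()
```

The backward pass now stops at h, so h.grad is populated while w.grad stays empty.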

Visualizing Gradient Flow

make_dot(var[, params, show_attrs, ...])

Produce Graphviz representation of PyTorch autograd graph.

Make Dot

torchopt.visual.make_dot(var, params=None, show_attrs=False, show_saved=False, max_attr_chars=50)[source]

Produce Graphviz representation of PyTorch autograd graph.

If a node represents a backward function, it is gray. Otherwise, the node represents a tensor and is either blue, orange, or green:

  • Blue

    Reachable leaf tensors that require grad (tensors whose grad fields will be populated during backward()).

  • Orange

    Saved tensors of custom autograd functions as well as those saved by built-in backward nodes.

  • Green

    Tensors passed in as outputs.

  • Dark green

    If any output is a view, we represent its base tensor with a dark green node.

Parameters
  • var (Tensor or sequence of Tensor) – Output tensor.

  • params (dict[str, Tensor], ModuleState, iterable of tuple[str, Tensor], or None, optional) – Parameters used to add names to the nodes that require grad. (default: None)

  • show_attrs (bool, optional) – Whether to display non-tensor attributes of backward nodes. (default: False)

  • show_saved (bool, optional) – Whether to display saved tensor nodes that are not saved by custom autograd functions. Saved tensor nodes for custom functions, if present, are always displayed. (default: False)

  • max_attr_chars (int, optional) – If show_attrs is True, sets max number of characters to display for any given attribute. (default: 50)

Return type

Digraph