TorchOpt Optimizer

- Optimizer: A base class for classic optimizers that is similar to torch.optim.Optimizer.
- MetaOptimizer: The base class for high-level differentiable optimizers.

Optimizer
- class torchopt.Optimizer(params, impl)[source]
Bases:
object
A base class for classic optimizers that is similar to torch.optim.Optimizer.
Initialize the optimizer.
- Parameters:
params (iterable of torch.Tensor) – An iterable of torch.Tensors. Specifies what tensors should be optimized.
impl (GradientTransformation) – A low-level optimizer function. It could be an optimizer function provided in torchopt.alias or a customized torchopt.chain()-ed transformation. Note that using Optimizer(sgd()) or Optimizer(chain(sgd())) is equivalent to torchopt.SGD.
- __init__(params, impl)[source]
Initialize the optimizer.
- Parameters:
params (iterable of torch.Tensor) – An iterable of torch.Tensors. Specifies what tensors should be optimized.
impl (GradientTransformation) – A low-level optimizer function. It could be an optimizer function provided in torchopt.alias or a customized torchopt.chain()-ed transformation. Note that using Optimizer(sgd()) or Optimizer(chain(sgd())) is equivalent to torchopt.SGD.
- zero_grad(set_to_none=False)[source]
Set the gradients of all optimized torch.Tensors to zero.
The behavior is similar to torch.optim.Optimizer.zero_grad().
- load_state_dict(state_dict)[source]
Load the optimizer state.
- Parameters:
state_dict (sequence of tree of Tensor) – Optimizer state. Should be an object returned from a call to state_dict().
- Return type:
None
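A minimal usage sketch of the base Optimizer class wrapping a functional transformation; the linear model and data below are illustrative placeholders, not part of the API.

import torch
import torch.nn as nn
import torchopt

net = nn.Linear(4, 1)
# Optimizer(params, sgd(lr=0.1)) behaves like torchopt.SGD(params, lr=0.1).
optimizer = torchopt.Optimizer(net.parameters(), torchopt.sgd(lr=0.1))

x, y = torch.randn(8, 4), torch.randn(8, 1)
loss = ((net(x) - y) ** 2).mean()

optimizer.zero_grad()  # clear the gradients of the optimized tensors
loss.backward()        # populate the .grad attributes
optimizer.step()       # apply the gradient transformation to the parameters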
MetaOptimizer
- class torchopt.MetaOptimizer(module, impl)[source]
Bases:
object
The base class for high-level differentiable optimizers.
Initialize the meta-optimizer.
- Parameters:
module (nn.Module) – A network whose parameters should be optimized.
impl (GradientTransformation) – A low-level optimizer function. It could be an optimizer function provided in torchopt.alias or a customized torchopt.chain()-ed transformation. Note that using MetaOptimizer(sgd(moment_requires_grad=True)) or MetaOptimizer(chain(sgd(moment_requires_grad=True))) is equivalent to torchopt.MetaSGD.
- __init__(module, impl)[source]
Initialize the meta-optimizer.
- Parameters:
module (nn.Module) – A network whose parameters should be optimized.
impl (GradientTransformation) – A low-level optimizer function. It could be an optimizer function provided in torchopt.alias or a customized torchopt.chain()-ed transformation. Note that using MetaOptimizer(sgd(moment_requires_grad=True)) or MetaOptimizer(chain(sgd(moment_requires_grad=True))) is equivalent to torchopt.MetaSGD.
- step(loss)[source]
Compute the gradients of the loss with respect to the network parameters and update the network parameters.
The graph of the derivative will be constructed, allowing computation of higher-order derivative products. We use the differentiable optimizer (pass argument inplace=False) to scale the gradients and update the network parameters without modifying tensors in-place.
- Parameters:
loss (torch.Tensor) – The loss that is used to compute the gradients to the network parameters.
- Return type:
None
- state_dict()[source]
Extract the references of the optimizer states.
Note that the states are references, so any in-place operations will change the states inside MetaOptimizer at the same time.
- Return type:
tuple[OptState, …]
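A short sketch of the differentiable update performed by MetaOptimizer; the network and data are placeholders, and the outer objective is only indicated.

import torch
import torch.nn as nn
import torchopt

net = nn.Linear(4, 1)
# Equivalent to torchopt.MetaSGD(net, lr=0.1).
optimizer = torchopt.MetaOptimizer(net, torchopt.sgd(lr=0.1, moment_requires_grad=True))

x, y = torch.randn(8, 4), torch.randn(8, 1)
inner_loss = ((net(x) - y) ** 2).mean()
optimizer.step(inner_loss)  # differentiable update: parameters are replaced, not mutated in-place

# An outer (meta) loss computed after the update can be backpropagated through the step above.
outer_loss = ((net(x) - y) ** 2).mean()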
Functional Optimizers

- FuncOptimizer: A wrapper class to hold the functional optimizer.
- adadelta(): Create a functional version of the AdaDelta optimizer.
- adagrad(): Create a functional version of the AdaGrad optimizer.
- adam(): Create a functional version of the Adam optimizer.
- adamw(): Create a functional version of the Adam optimizer with weight decay regularization.
- adamax(): Create a functional version of the AdaMax optimizer.
- radam(): Create a functional version of the RAdam optimizer.
- rmsprop(): Create a functional version of the RMSProp optimizer.
- sgd(): Create a functional version of the canonical Stochastic Gradient Descent optimizer.

Wrapper for Functional Optimizer
- class torchopt.FuncOptimizer(impl, *, inplace=False)[source]
Bases:
object
A wrapper class to hold the functional optimizer.
This wrapper makes it easier to maintain the optimizer states. The optimizer states are held by the wrapper internally. The wrapper provides a step() function to compute the gradients and update the parameters.
See also
The functional AdaDelta optimizer: torchopt.adadelta().
The functional AdaGrad optimizer: torchopt.adagrad().
The functional Adam optimizer: torchopt.adam().
The functional AdamW optimizer: torchopt.adamw().
The functional AdaMax optimizer: torchopt.adamax().
The functional RAdam optimizer: torchopt.radam().
The functional RMSprop optimizer: torchopt.rmsprop().
The functional SGD optimizer: torchopt.sgd().
Initialize the functional optimizer wrapper.
- Parameters:
- step(loss, params, inplace=None)[source]
Compute the gradients of the loss with respect to the network parameters and update the network parameters.
The graph of the derivative will be constructed, allowing computation of higher-order derivative products. We use the differentiable optimizer (pass argument inplace=False) to scale the gradients and update the network parameters without modifying tensors in-place.
- Parameters:
loss (Tensor) – The loss that is used to compute the gradients to the network parameters.
params (tree of Tensor) – A tree of torch.Tensors. Specifies what tensors should be optimized.
inplace (bool or None, optional) – Whether to update the parameters in-place. If None, use the default value specified in the constructor. (default: None)
- Return type:
Params
- state_dict()[source]
Extract the references of the optimizer states.
Note that the states are references, so any in-place operations will change the states inside FuncOptimizer at the same time.
- Return type:
OptState
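A brief sketch of the wrapper in a stateless training step; torch.func.functional_call (PyTorch >= 2.0) is used here only to run the forward pass with an explicit parameter tree, and the model and data are placeholders.

import torch
import torch.nn as nn
from torch.func import functional_call
import torchopt

net = nn.Linear(4, 1)
params = dict(net.named_parameters())

# The optimizer state is kept inside the wrapper, so only the parameters are threaded through.
optimizer = torchopt.FuncOptimizer(torchopt.adam(lr=1e-3))

x, y = torch.randn(8, 4), torch.randn(8, 1)
pred = functional_call(net, params, x)   # stateless forward pass with explicit params
loss = ((pred - y) ** 2).mean()
params = optimizer.step(loss, params)    # returns the updated parameter tree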
Functional AdaDelta Optimizer
- torchopt.adadelta(lr=0.001, rho=0.9, eps=1e-06, weight_decay=0.0, *, moment_requires_grad=False)[source]
Create a functional version of the AdaDelta optimizer.
Adadelta is a per-dimension learning rate method for gradient descent.
References
Zeiler, 2012: https://arxiv.org/abs/1212.5701
- Parameters:
lr (float or callable, optional) – This is a fixed global scaling factor or a learning rate scheduler. (default: 1e-3)
rho (float, optional) – Coefficient used for computing running averages of the gradient and its square. (default: 0.9)
eps (float, optional) – A small constant applied to the square root (as in the AdaDelta paper) to avoid dividing by zero when rescaling. (default: 1e-6)
weight_decay (float, optional) – Weight decay, add L2 penalty to parameters. (default: 0.0)
moment_requires_grad (bool, optional) – If True, the momentums will be created with flag requires_grad=True. This flag is often used in Meta-Learning algorithms. (default: False)
- Return type:
GradientTransformation
- Returns:
The corresponding
GradientTransformation
instance.
See also
The functional optimizer wrapper
torchopt.FuncOptimizer
.
Functional AdaGrad Optimizer
- torchopt.adagrad(lr=0.01, lr_decay=0.0, weight_decay=0.0, initial_accumulator_value=0.0, eps=1e-10, *, maximize=False)[source]
Create a functional version of the AdaGrad optimizer.
AdaGrad is an algorithm for gradient based optimization that anneals the learning rate for each parameter during the course of training.
Warning
AdaGrad’s main limit is the monotonic accumulation of squared gradients in the denominator. Since all terms are > 0, the sum keeps growing during training, and the learning rate eventually becomes very small.
References
Duchi et al., 2011: https://jmlr.org/papers/v12/duchi11a.html
- Parameters:
lr (float or callable, optional) – This is a fixed global scaling factor or a learning rate scheduler. (default: 1e-2)
lr_decay (float, optional) – Learning rate decay. (default: 0.0)
weight_decay (float, optional) – Weight decay, add L2 penalty to parameters. (default: 0.0)
initial_accumulator_value (float, optional) – Initial value for the accumulator. (default: 0.0)
eps (float, optional) – A small constant applied to the denominator outside of the square root (as in the Adam paper) to avoid dividing by zero when rescaling. (default: 1e-10)
maximize (bool, optional) – Maximize the params based on the objective, instead of minimizing. (default: False)
- Return type:
GradientTransformation
- Returns:
The corresponding
GradientTransformation
instance.
See also
The functional optimizer wrapper
torchopt.FuncOptimizer
.
Functional Adam Optimizer
- torchopt.adam(lr=0.001, betas=(0.9, 0.999), eps=1e-08, weight_decay=0.0, *, eps_root=0.0, moment_requires_grad=False, maximize=False, use_accelerated_op=False)[source]
Create a functional version of the Adam optimizer.
Adam is an SGD variant with learning rate adaptation. The learning rate used for each weight is computed from estimates of first- and second-order moments of the gradients (using suitable exponential moving averages).
References
Kingma et al., 2014: https://arxiv.org/abs/1412.6980
- Parameters:
lr (float or callable, optional) – This is a fixed global scaling factor or a learning rate scheduler. (default: 1e-3)
betas (tuple of float, optional) – Coefficients used for computing running averages of the gradient and its square. (default: (0.9, 0.999))
eps (float, optional) – A small constant applied to the denominator outside of the square root (as in the Adam paper) to avoid dividing by zero when rescaling. (default: 1e-8)
weight_decay (float, optional) – Weight decay, add L2 penalty to parameters. (default: 0.0)
eps_root (float, optional) – A small constant applied to the denominator inside the square root (as in RMSProp), to avoid dividing by zero when rescaling. This is needed, for example, when computing (meta-)gradients through Adam. (default: 0.0)
moment_requires_grad (bool, optional) – If True, the momentums will be created with flag requires_grad=True. This flag is often used in Meta-Learning algorithms. (default: False)
maximize (bool, optional) – Maximize the params based on the objective, instead of minimizing. (default: False)
use_accelerated_op (bool, optional) – If True, use our implemented fused operator. (default: False)
- Return type:
GradientTransformation
- Returns:
The corresponding
GradientTransformation
instance.
See also
The functional optimizer wrapper
torchopt.FuncOptimizer
.
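A minimal sketch of the functional init/update/apply pattern with adam(); the toy parameters and loss below are placeholders.

import torch
import torchopt

# Any pytree of tensors works as the parameter container; a tuple is used here.
w = torch.randn(4, 1, requires_grad=True)
b = torch.zeros(1, requires_grad=True)
params = (w, b)
x, y = torch.randn(8, 4), torch.randn(8, 1)

optimizer = torchopt.adam(lr=1e-3)   # a GradientTransformation
opt_state = optimizer.init(params)   # initialize the optimizer state for this parameter tree

loss = ((x @ params[0] + params[1] - y) ** 2).mean()
grads = torch.autograd.grad(loss, params)          # gradients with the same tree structure
updates, opt_state = optimizer.update(grads, opt_state)
params = torchopt.apply_updates(params, updates)   # apply the scaled updates to the parameters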
Functional AdamW Optimizer
- torchopt.adamw(lr=0.001, betas=(0.9, 0.999), eps=1e-08, weight_decay=0.01, *, eps_root=0.0, mask=None, moment_requires_grad=False, maximize=False, use_accelerated_op=False)[source]
Create a functional version of the Adam optimizer with weight decay regularization.
AdamW uses weight decay to regularize learning towards small weights, as this leads to better generalization. In SGD you can also use L2 regularization to implement this as an additive loss term, however L2 regularization does not behave as intended for adaptive gradient algorithms such as Adam.
References
Loshchilov et al., 2019: https://arxiv.org/abs/1711.05101
- Parameters:
lr (float or callable, optional) – This is a fixed global scaling factor or a learning rate scheduler. (default: 1e-3)
betas (tuple of float, optional) – Coefficients used for computing running averages of the gradient and its square. (default: (0.9, 0.999))
eps (float, optional) – A small constant applied to the denominator outside of the square root (as in the Adam paper) to avoid dividing by zero when rescaling. (default: 1e-8)
weight_decay (float, optional) – Strength of the weight decay regularization. Note that this weight decay is multiplied with the learning rate. This is consistent with other frameworks such as PyTorch, but different from (Loshchilov et al., 2019), where the weight decay is only multiplied with the “schedule multiplier”, but not the base learning rate. (default: 1e-2)
eps_root (float, optional) – A small constant applied to the denominator inside the square root (as in RMSProp), to avoid dividing by zero when rescaling. This is needed, for example, when computing (meta-)gradients through Adam. (default: 0.0)
mask (tree of Tensor, callable, or None, optional) – A tree with the same structure as (or a prefix of) the params pytree, or a function that returns such a pytree given the params/updates. The leaves should be booleans: True for leaves/subtrees you want to apply the weight decay to, and False for those you want to skip. Note that the Adam gradient transformations are applied to all parameters. (default: None)
moment_requires_grad (bool, optional) – If True, the momentums will be created with flag requires_grad=True. This flag is often used in Meta-Learning algorithms. (default: False)
maximize (bool, optional) – Maximize the params based on the objective, instead of minimizing. (default: False)
use_accelerated_op (bool, optional) – If True, use our implemented fused operator. (default: False)
- Return type:
GradientTransformation
- Returns:
The corresponding
GradientTransformation
instance.
See also
The functional optimizer wrapper
torchopt.FuncOptimizer
.
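A hedged sketch of using the mask argument to restrict weight decay to a subset of the parameter tree (here, the weight but not the bias); the dict layout and loss are illustrative.

import torch
import torchopt

params = {
    'weight': torch.randn(4, 4, requires_grad=True),
    'bias': torch.zeros(4, requires_grad=True),
}
# Boolean pytree with the same structure as params: decay the weight, skip the bias.
mask = {'weight': True, 'bias': False}

optimizer = torchopt.adamw(lr=1e-3, weight_decay=1e-2, mask=mask)
opt_state = optimizer.init(params)

loss = (params['weight'].sum() + params['bias'].sum()) ** 2
grad_w, grad_b = torch.autograd.grad(loss, (params['weight'], params['bias']))
grads = {'weight': grad_w, 'bias': grad_b}

# AdamW needs the current params to apply the decoupled weight decay.
updates, opt_state = optimizer.update(grads, opt_state, params=params)
params = torchopt.apply_updates(params, updates)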
Functional AdaMax Optimizer
- torchopt.adamax(lr=0.001, betas=(0.9, 0.999), eps=1e-06, weight_decay=0.0, *, moment_requires_grad=False)[source]
Create a functional version of the AdaMax optimizer.
References
Kingma et al., 2014: https://arxiv.org/abs/1412.6980
- Parameters:
lr (float or callable, optional) – This is a fixed global scaling factor or a learning rate scheduler. (default: 1e-3)
betas (tuple of float, optional) – Coefficients used for computing running averages of the gradient and its square. (default: (0.9, 0.999))
eps (float, optional) – A small constant applied to the square root (as in the AdaMax paper) to avoid dividing by zero when rescaling. (default: 1e-6)
weight_decay (float, optional) – Weight decay, add L2 penalty to parameters. (default: 0.0)
moment_requires_grad (bool, optional) – If True, the momentums will be created with flag requires_grad=True. This flag is often used in Meta-Learning algorithms. (default: False)
- Return type:
GradientTransformation
- Returns:
The corresponding
GradientTransformation
instance.
See also
The functional optimizer wrapper
torchopt.FuncOptimizer
.
Functional RAdam Optimizer
- torchopt.radam(lr=0.001, betas=(0.9, 0.999), eps=1e-06, weight_decay=0.0, *, moment_requires_grad=False)[source]
Create a functional version of the RAdam optimizer.
RAdam is a variant of Adam that rectifies the variance of the adaptive learning rate.
References
Liu, 2019: https://arxiv.org/abs/1908.03265
- Parameters:
lr (float or callable, optional) – This is a fixed global scaling factor or a learning rate scheduler. (default: 1e-3)
betas (tuple of float, optional) – Coefficients used for computing running averages of the gradient and its square. (default: (0.9, 0.999))
eps (float, optional) – A small constant applied to the square root (as in the RAdam paper) to avoid dividing by zero when rescaling. (default: 1e-6)
weight_decay (float, optional) – Weight decay, add L2 penalty to parameters. (default: 0.0)
moment_requires_grad (bool, optional) – If True, the momentums will be created with flag requires_grad=True. This flag is often used in Meta-Learning algorithms. (default: False)
- Return type:
GradientTransformation
- Returns:
The corresponding
GradientTransformation
instance.
See also
The functional optimizer wrapper
torchopt.FuncOptimizer
.
Functional RMSProp Optimizer
- torchopt.rmsprop(lr=0.01, alpha=0.99, eps=1e-08, weight_decay=0.0, momentum=0.0, centered=False, *, initial_scale=0.0, nesterov=False, maximize=False)[source]
Create a functional version of the RMSProp optimizer.
RMSProp is an SGD variant with learning rate adaptation. The learning rate used for each weight is scaled by a suitable estimate of the magnitude of the gradients on previous steps. Several variants of RMSProp can be found in the literature. This alias provides an easy to configure RMSProp optimizer that can be used to switch between several of these variants.
References
Tieleman and Hinton, 2012: http://www.cs.toronto.edu/~hinton/coursera/lecture6/lec6.pdf
Graves, 2013: https://arxiv.org/abs/1308.0850
- Parameters:
lr (float or callable, optional) – This is a fixed global scaling factor or a learning rate scheduler. (default: 1e-2)
alpha (float, optional) – Smoothing constant, the decay used to track the magnitude of previous gradients. (default: 0.99)
eps (float, optional) – A small numerical constant to avoid dividing by zero when rescaling. (default: 1e-8)
weight_decay (float, optional) – Weight decay, add L2 penalty to parameters. (default: 0.0)
momentum (float, optional) – The decay rate used by the momentum term. The momentum is not used when it is set to 0.0. (default: 0.0)
centered (bool, optional) – If True, use the variance of the past gradients to rescale the latest gradients. (default: False)
initial_scale (float, optional) – Initialization of accumulators tracking the magnitude of previous updates. PyTorch uses 0.0, TensorFlow 1.x uses 1.0. When reproducing results from a paper, verify the value used by the authors. (default: 0.0)
nesterov (bool, optional) – Whether to use Nesterov momentum. (default: False)
maximize (bool, optional) – Maximize the params based on the objective, instead of minimizing. (default: False)
- Return type:
GradientTransformation
- Returns:
The corresponding
GradientTransformation
instance.
See also
The functional optimizer wrapper
torchopt.FuncOptimizer
.
Functional SGD Optimizer
- torchopt.sgd(lr, momentum=0.0, dampening=0.0, weight_decay=0.0, nesterov=False, *, moment_requires_grad=False, maximize=False)[source]
Create a functional version of the canonical Stochastic Gradient Descent optimizer.
This implements stochastic gradient descent. It also includes support for momentum and Nesterov acceleration, as these are standard practice when using stochastic gradient descent to train deep neural networks.
References
Sutskever et al., 2013: http://proceedings.mlr.press/v28/sutskever13.pdf
- Parameters:
lr (float or callable) – This is a fixed global scaling factor or a learning rate scheduler.
momentum (float, optional) – The decay rate used by the momentum term. The momentum is not used when it is set to 0.0. (default: 0.0)
weight_decay (float, optional) – Weight decay, add L2 penalty to parameters. (default: 0.0)
dampening (float, optional) – Dampening for momentum. (default: 0.0)
nesterov (bool, optional) – Whether to use Nesterov momentum. (default: False)
moment_requires_grad (bool, optional) – If True, the momentums will be created with flag requires_grad=True. This flag is often used in Meta-Learning algorithms. (default: False)
maximize (bool, optional) – Maximize the params based on the objective, instead of minimizing. (default: False)
- Return type:
GradientTransformation
- Returns:
The corresponding
GradientTransformation
instance.
See also
The functional optimizer wrapper
torchopt.FuncOptimizer
.
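The lr argument accepts a schedule as well as a fixed value. A brief sketch, assuming the linear_schedule helper in torchopt.schedule (check that module for the exact name and signature):

import torchopt

# Assumption: torchopt.schedule.linear_schedule(init_value, end_value, transition_steps, transition_begin)
schedule = torchopt.schedule.linear_schedule(
    init_value=1e-2, end_value=1e-3, transition_steps=1_000, transition_begin=0
)
optimizer = torchopt.sgd(lr=schedule, momentum=0.9, nesterov=True)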
Classic Optimizers

- AdaDelta: The classic AdaDelta optimizer.
- Adadelta: alias of AdaDelta
- AdaGrad: The classic AdaGrad optimizer.
- Adagrad: alias of AdaGrad
- Adam: The classic Adam optimizer.
- AdamW: The classic AdamW optimizer.
- AdaMax: The classic AdaMax optimizer.
- Adamax: alias of AdaMax
- RAdam: The classic RAdam optimizer.
- RMSProp: The classic RMSProp optimizer.
- SGD: The classic SGD optimizer.
Classic AdaDelta Optimizer
- class torchopt.AdaDelta(params, lr=1.0, rho=0.9, eps=1e-06, weight_decay=0.0)[source]
Bases:
Optimizer
The classic AdaDelta optimizer.
See also
The functional AdaDelta optimizer: torchopt.adadelta().
The differentiable meta-AdaDelta optimizer: torchopt.MetaAdaDelta.
Initialize the AdaDelta optimizer.
- Parameters:
params (iterable of Tensor) – An iterable of torch.Tensors. Specifies what tensors should be optimized.
lr (float or callable, optional) – This is a fixed global scaling factor or a learning rate scheduler. (default: 1.0)
rho (float, optional) – Coefficient used for computing running averages of the gradient and its square. (default: 0.9)
eps (float, optional) – A small constant applied to the square root (as in the AdaDelta paper) to avoid dividing by zero when rescaling. (default: 1e-6)
weight_decay (float, optional) – Weight decay, add L2 penalty to parameters. (default: 0.0)
Classic AdaGrad Optimizer
- class torchopt.AdaGrad(params, lr=0.01, lr_decay=0.0, weight_decay=0.0, initial_accumulator_value=0.0, eps=1e-10, *, maximize=False)[source]
Bases:
Optimizer
The classic AdaGrad optimizer.
See also
The functional AdaGrad optimizer: torchopt.adagrad().
The differentiable meta-AdaGrad optimizer: torchopt.MetaAdaGrad.
Initialize the AdaGrad optimizer.
- Parameters:
params (iterable of Tensor) – An iterable of torch.Tensors. Specifies what tensors should be optimized.
lr (float or callable, optional) – This is a fixed global scaling factor or a learning rate scheduler. (default: 1e-2)
lr_decay (float, optional) – Learning rate decay. (default: 0.0)
weight_decay (float, optional) – Weight decay, add L2 penalty to parameters. (default: 0.0)
initial_accumulator_value (float, optional) – Initial value for the accumulator. (default: 0.0)
eps (float, optional) – A small constant applied to the denominator outside of the square root (as in the Adam paper) to avoid dividing by zero when rescaling. (default: 1e-10)
maximize (bool, optional) – Maximize the params based on the objective, instead of minimizing. (default: False)
Classic Adam Optimizer
- class torchopt.Adam(params, lr=0.001, betas=(0.9, 0.999), eps=1e-08, weight_decay=0.0, *, eps_root=0.0, maximize=False, use_accelerated_op=False)[source]
Bases:
Optimizer
The classic Adam optimizer.
See also
The functional Adam optimizer: torchopt.adam().
The differentiable meta-Adam optimizer: torchopt.MetaAdam.
Initialize the Adam optimizer.
- Parameters:
params (iterable of Tensor) – An iterable of torch.Tensors. Specifies what tensors should be optimized.
lr (float or callable, optional) – This is a fixed global scaling factor or a learning rate scheduler. (default: 1e-3)
betas (tuple of float, optional) – Coefficients used for computing running averages of the gradient and its square. (default: (0.9, 0.999))
eps (float, optional) – A small constant applied to the denominator outside of the square root (as in the Adam paper) to avoid dividing by zero when rescaling. (default: 1e-8)
weight_decay (float, optional) – Weight decay, add L2 penalty to parameters. (default: 0.0)
eps_root (float, optional) – A small constant applied to the denominator inside the square root (as in RMSProp), to avoid dividing by zero when rescaling. This is needed, for example, when computing (meta-)gradients through Adam. (default: 0.0)
moment_requires_grad (bool, optional) – If True, the momentums will be created with flag requires_grad=True. This flag is often used in Meta-Learning algorithms. (default: False)
maximize (bool, optional) – Maximize the params based on the objective, instead of minimizing. (default: False)
use_accelerated_op (bool, optional) – If True, use our implemented fused operator. (default: False)
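A short usage sketch: the classic optimizers follow the familiar torch.optim workflow; the model and data are placeholders.

import torch
import torch.nn as nn
import torchopt

net = nn.Linear(4, 1)
optimizer = torchopt.Adam(net.parameters(), lr=1e-3)

x, y = torch.randn(8, 4), torch.randn(8, 1)
for _ in range(10):
    loss = ((net(x) - y) ** 2).mean()
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()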
Classic AdamW Optimizer
- class torchopt.AdamW(params, lr=0.001, betas=(0.9, 0.999), eps=1e-08, weight_decay=0.01, *, eps_root=0.0, mask=None, maximize=False, use_accelerated_op=False)[source]
Bases:
Optimizer
The classic AdamW optimizer.
See also
The functional AdamW optimizer: torchopt.adamw().
The differentiable meta-AdamW optimizer: torchopt.MetaAdamW.
Initialize the AdamW optimizer.
- Parameters:
params (iterable of Tensor) – An iterable of torch.Tensors. Specifies what tensors should be optimized.
lr (float or callable, optional) – This is a fixed global scaling factor or a learning rate scheduler. (default: 1e-3)
betas (tuple of float, optional) – Coefficients used for computing running averages of the gradient and its square. (default: (0.9, 0.999))
eps (float, optional) – A small constant applied to the denominator outside of the square root (as in the Adam paper) to avoid dividing by zero when rescaling. (default: 1e-8)
weight_decay (float, optional) – Strength of the weight decay regularization. Note that this weight decay is multiplied with the learning rate. This is consistent with other frameworks such as PyTorch, but different from (Loshchilov et al., 2019), where the weight decay is only multiplied with the “schedule multiplier”, but not the base learning rate. (default: 1e-2)
eps_root (float, optional) – A small constant applied to the denominator inside the square root (as in RMSProp), to avoid dividing by zero when rescaling. This is needed, for example, when computing (meta-)gradients through Adam. (default: 0.0)
mask (tree of Tensor, callable, or None, optional) – A tree with the same structure as (or a prefix of) the params pytree, or a function that returns such a pytree given the params/updates. The leaves should be booleans: True for leaves/subtrees you want to apply the weight decay to, and False for those you want to skip. Note that the Adam gradient transformations are applied to all parameters. (default: None)
moment_requires_grad (bool, optional) – If True, the momentums will be created with flag requires_grad=True. This flag is often used in Meta-Learning algorithms. (default: False)
maximize (bool, optional) – Maximize the params based on the objective, instead of minimizing. (default: False)
use_accelerated_op (bool, optional) – If True, use our implemented fused operator. (default: False)
Classic AdaMax Optimizer
- class torchopt.AdaMax(params, lr=0.002, betas=(0.9, 0.999), eps=1e-08, weight_decay=0.0)[source]
Bases:
Optimizer
The classic AdaMax optimizer.
See also
The functional AdaMax optimizer: torchopt.adamax().
The differentiable meta-AdaMax optimizer: torchopt.MetaAdaMax.
Initialize the AdaMax optimizer.
- Parameters:
params (iterable of Tensor) – An iterable of torch.Tensors. Specifies what tensors should be optimized.
lr (float or callable, optional) – This is a fixed global scaling factor or a learning rate scheduler. (default: 2e-3)
betas (tuple of float, optional) – Coefficients used for computing running averages of the gradient and its square. (default: (0.9, 0.999))
eps (float, optional) – A small constant applied to the square root (as in the AdaMax paper) to avoid dividing by zero when rescaling. (default: 1e-8)
weight_decay (float, optional) – Weight decay, add L2 penalty to parameters. (default: 0.0)
Classic RAdam Optimizer
- class torchopt.RAdam(params, lr=0.001, betas=(0.9, 0.999), eps=1e-08, weight_decay=0.0)[source]
Bases:
Optimizer
The classic RAdam optimizer.
See also
The functional RAdam optimizer: torchopt.radam().
The differentiable meta-RAdam optimizer: torchopt.MetaRAdam.
Initialize the RAdam optimizer.
- Parameters:
params (iterable of Tensor) – An iterable of torch.Tensors. Specifies what tensors should be optimized.
lr (float or callable, optional) – This is a fixed global scaling factor or a learning rate scheduler. (default: 1e-3)
betas (tuple of float, optional) – Coefficients used for computing running averages of the gradient and its square. (default: (0.9, 0.999))
eps (float, optional) – A small constant applied to the square root (as in the RAdam paper) to avoid dividing by zero when rescaling. (default: 1e-6)
weight_decay (float, optional) – Weight decay, add L2 penalty to parameters. (default: 0.0)
Classic RMSProp Optimizer
- class torchopt.RMSProp(params, lr=0.01, alpha=0.99, eps=1e-08, weight_decay=0.0, momentum=0.0, centered=False, *, initial_scale=0.0, nesterov=False, maximize=False)[source]
Bases:
Optimizer
The classic RMSProp optimizer.
See also
The functional RMSProp optimizer: torchopt.rmsprop().
The differentiable meta-RMSProp optimizer: torchopt.MetaRMSProp.
Initialize the RMSProp optimizer.
- Parameters:
params (iterable of Tensor) – An iterable of torch.Tensors. Specifies what tensors should be optimized.
lr (float or callable, optional) – This is a fixed global scaling factor or a learning rate scheduler. (default: 1e-2)
alpha (float, optional) – Smoothing constant, the decay used to track the magnitude of previous gradients. (default: 0.99)
eps (float, optional) – A small numerical constant to avoid dividing by zero when rescaling. (default: 1e-8)
weight_decay (float, optional) – Weight decay, add L2 penalty to parameters. (default: 0.0)
momentum (float, optional) – The decay rate used by the momentum term. The momentum is not used when it is set to 0.0. (default: 0.0)
centered (bool, optional) – If True, use the variance of the past gradients to rescale the latest gradients. (default: False)
initial_scale (float, optional) – Initialization of accumulators tracking the magnitude of previous updates. PyTorch uses 0.0, TensorFlow 1.x uses 1.0. When reproducing results from a paper, verify the value used by the authors. (default: 0.0)
nesterov (bool, optional) – Whether to use Nesterov momentum. (default: False)
maximize (bool, optional) – Maximize the params based on the objective, instead of minimizing. (default: False)
Classic SGD Optimizer
- class torchopt.SGD(params, lr, momentum=0.0, weight_decay=0.0, dampening=0.0, nesterov=False, maximize=False)[source]
Bases:
Optimizer
The classic SGD optimizer.
See also
The functional SGD optimizer: torchopt.sgd().
The differentiable meta-SGD optimizer: torchopt.MetaSGD.
Initialize the SGD optimizer.
- Parameters:
params (iterable of Tensor) – An iterable of torch.Tensors. Specifies what tensors should be optimized.
lr (float or callable) – This is a fixed global scaling factor or a learning rate scheduler.
momentum (float, optional) – The decay rate used by the momentum term. The momentum is not used when it is set to 0.0. (default: 0.0)
weight_decay (float, optional) – Weight decay, add L2 penalty to parameters. (default: 0.0)
dampening (float, optional) – Dampening for momentum. (default: 0.0)
nesterov (bool, optional) – Whether to use Nesterov momentum. (default: False)
moment_requires_grad (bool, optional) – If True, the momentums will be created with flag requires_grad=True. This flag is often used in Meta-Learning algorithms. (default: False)
maximize (bool, optional) – Maximize the params based on the objective, instead of minimizing. (default: False)
Differentiable Meta-Optimizers

- MetaAdaDelta: The differentiable AdaDelta optimizer.
- MetaAdadelta: alias of MetaAdaDelta
- MetaAdaGrad: The differentiable AdaGrad optimizer.
- MetaAdagrad: alias of MetaAdaGrad
- MetaAdam: The differentiable Adam optimizer.
- MetaAdamW: The differentiable AdamW optimizer.
- MetaAdaMax: The differentiable AdaMax optimizer.
- MetaAdamax: alias of MetaAdaMax
- MetaRAdam: The differentiable RAdam optimizer.
- MetaRMSProp: The differentiable RMSProp optimizer.
- MetaSGD: The differentiable Stochastic Gradient Descent optimizer.
Differentiable Meta-AdaDelta Optimizer
- class torchopt.MetaAdaDelta(module, lr=1.0, rho=0.9, eps=1e-06, weight_decay=0.0, *, moment_requires_grad=True)[source]
Bases:
MetaOptimizer
The differentiable AdaDelta optimizer.
See also
The functional AdaDelta optimizer: torchopt.adadelta().
The classic AdaDelta optimizer: torchopt.AdaDelta.
Initialize the meta AdaDelta optimizer.
- Parameters:
module (nn.Module) – A network whose parameters should be optimized.
lr (float or callable, optional) – This is a fixed global scaling factor or a learning rate scheduler. (default: 1.0)
rho (float, optional) – Coefficient used for computing running averages of the gradient and its square. (default: 0.9)
eps (float, optional) – A small constant applied to the square root (as in the AdaDelta paper) to avoid dividing by zero when rescaling. (default: 1e-6)
weight_decay (float, optional) – Weight decay, add L2 penalty to parameters. (default: 0.0)
moment_requires_grad (bool, optional) – If True, the momentums will be created with flag requires_grad=True. This flag is often used in Meta-Learning algorithms. (default: True)
Differentiable Meta-AdaGrad Optimizer
- class torchopt.MetaAdaGrad(module, lr=0.01, lr_decay=0.0, weight_decay=0.0, initial_accumulator_value=0.0, eps=1e-10, *, maximize=False)[source]
Bases:
MetaOptimizer
The differentiable AdaGrad optimizer.
See also
The functional AdaGrad optimizer: torchopt.adagrad().
The classic AdaGrad optimizer: torchopt.AdaGrad.
Initialize the meta AdaGrad optimizer.
- Parameters:
module (nn.Module) – A network whose parameters should be optimized.
lr (float or callable, optional) – This is a fixed global scaling factor or a learning rate scheduler. (default: 1e-2)
lr_decay (float, optional) – Learning rate decay. (default: 0.0)
weight_decay (float, optional) – Weight decay, add L2 penalty to parameters. (default: 0.0)
initial_accumulator_value (float, optional) – Initial value for the accumulator. (default: 0.0)
eps (float, optional) – A small constant applied to the denominator outside of the square root (as in the Adam paper) to avoid dividing by zero when rescaling. (default: 1e-10)
maximize (bool, optional) – Maximize the params based on the objective, instead of minimizing. (default: False)
Differentiable Meta-Adam Optimizer
- class torchopt.MetaAdam(module, lr=0.001, betas=(0.9, 0.999), eps=1e-08, weight_decay=0.0, *, eps_root=0.0, moment_requires_grad=True, maximize=False, use_accelerated_op=False)[source]
Bases:
MetaOptimizer
The differentiable Adam optimizer.
See also
The functional Adam optimizer: torchopt.adam().
The classic Adam optimizer: torchopt.Adam.
Initialize the meta-Adam optimizer.
- Parameters:
module (nn.Module) – A network whose parameters should be optimized.
lr (float or callable, optional) – This is a fixed global scaling factor or a learning rate scheduler. (default: 1e-3)
betas (tuple of float, optional) – Coefficients used for computing running averages of the gradient and its square. (default: (0.9, 0.999))
eps (float, optional) – A small constant applied to the denominator outside of the square root (as in the Adam paper) to avoid dividing by zero when rescaling. (default: 1e-8)
weight_decay (float, optional) – Weight decay, add L2 penalty to parameters. (default: 0.0)
eps_root (float, optional) – A small constant applied to the denominator inside the square root (as in RMSProp), to avoid dividing by zero when rescaling. This is needed, for example, when computing (meta-)gradients through Adam. (default: 0.0)
moment_requires_grad (bool, optional) – If True, the momentums will be created with flag requires_grad=True. This flag is often used in Meta-Learning algorithms. (default: True)
maximize (bool, optional) – Maximize the params based on the objective, instead of minimizing. (default: False)
use_accelerated_op (bool, optional) – If True, use our implemented fused operator. (default: False)
Differentiable Meta-AdamW Optimizer
- class torchopt.MetaAdamW(module, lr=0.001, betas=(0.9, 0.999), eps=1e-08, weight_decay=0.01, *, eps_root=0.0, mask=None, moment_requires_grad=False, maximize=False, use_accelerated_op=False)[source]
Bases:
MetaOptimizer
The differentiable AdamW optimizer.
See also
The functional AdamW optimizer: torchopt.adamw().
The classic AdamW optimizer: torchopt.AdamW.
Initialize the meta-AdamW optimizer.
- Parameters:
module (nn.Module) – A network whose parameters should be optimized.
lr (float or callable, optional) – This is a fixed global scaling factor or a learning rate scheduler. (default: 1e-3)
betas (tuple of float, optional) – Coefficients used for computing running averages of the gradient and its square. (default: (0.9, 0.999))
eps (float, optional) – A small constant applied to the denominator outside of the square root (as in the Adam paper) to avoid dividing by zero when rescaling. (default: 1e-8)
weight_decay (float, optional) – Strength of the weight decay regularization. Note that this weight decay is multiplied with the learning rate. This is consistent with other frameworks such as PyTorch, but different from (Loshchilov et al., 2019), where the weight decay is only multiplied with the “schedule multiplier”, but not the base learning rate. (default: 1e-2)
eps_root (float, optional) – A small constant applied to the denominator inside the square root (as in RMSProp), to avoid dividing by zero when rescaling. This is needed, for example, when computing (meta-)gradients through Adam. (default: 0.0)
mask (tree of Tensor, callable, or None, optional) – A tree with the same structure as (or a prefix of) the params pytree, or a function that returns such a pytree given the params/updates. The leaves should be booleans: True for leaves/subtrees you want to apply the weight decay to, and False for those you want to skip. Note that the Adam gradient transformations are applied to all parameters. (default: None)
moment_requires_grad (bool, optional) – If True, the momentums will be created with flag requires_grad=True. This flag is often used in Meta-Learning algorithms. (default: False)
maximize (bool, optional) – Maximize the params based on the objective, instead of minimizing. (default: False)
use_accelerated_op (bool, optional) – If True, use our implemented fused operator. (default: False)
Differentiable Meta-AdaMax Optimizer
- class torchopt.MetaAdaMax(module, lr=0.002, betas=(0.9, 0.999), eps=1e-08, weight_decay=0.0, *, moment_requires_grad=True)[source]
Bases:
MetaOptimizer
The differentiable AdaMax optimizer.
See also
The functional AdaMax optimizer: torchopt.adamax().
The classic AdaMax optimizer: torchopt.AdaMax.
Initialize the meta AdaMax optimizer.
- Parameters:
module (nn.Module) – A network whose parameters should be optimized.
lr (float or callable, optional) – This is a fixed global scaling factor or a learning rate scheduler. (default: 2e-3)
betas (tuple of float, optional) – Coefficients used for computing running averages of the gradient and its square. (default: (0.9, 0.999))
eps (float, optional) – A small constant applied to the square root (as in the AdaMax paper) to avoid dividing by zero when rescaling. (default: 1e-8)
weight_decay (float, optional) – Weight decay, add L2 penalty to parameters. (default: 0.0)
moment_requires_grad (bool, optional) – If True, the momentums will be created with flag requires_grad=True. This flag is often used in Meta-Learning algorithms. (default: True)
Differentiable Meta-RAdam Optimizer
- class torchopt.MetaRAdam(module, lr=0.001, betas=(0.9, 0.999), eps=1e-08, weight_decay=0.0, *, moment_requires_grad=True)[source]
Bases:
MetaOptimizer
The differentiable RAdam optimizer.
See also
The functional RAdam optimizer: torchopt.radam().
The classic RAdam optimizer: torchopt.RAdam.
Initialize the meta-RAdam optimizer.
- Parameters:
module (nn.Module) – A network whose parameters should be optimized.
lr (float or callable, optional) – This is a fixed global scaling factor or a learning rate scheduler. (default: 1e-3)
betas (tuple of float, optional) – Coefficients used for computing running averages of the gradient and its square. (default: (0.9, 0.999))
eps (float, optional) – A small constant applied to the square root (as in the RAdam paper) to avoid dividing by zero when rescaling. (default: 1e-8)
weight_decay (float, optional) – Weight decay, add L2 penalty to parameters. (default: 0.0)
moment_requires_grad (bool, optional) – If True, the momentums will be created with flag requires_grad=True. This flag is often used in Meta-Learning algorithms. (default: True)
Differentiable Meta-RMSProp Optimizer
- class torchopt.MetaRMSProp(module, lr=0.01, alpha=0.99, eps=1e-08, weight_decay=0.0, momentum=0.0, centered=False, *, initial_scale=0.0, nesterov=False, maximize=False)[source]
Bases:
MetaOptimizer
The differentiable RMSProp optimizer.
See also
The functional RMSProp optimizer: torchopt.rmsprop().
The classic RMSProp optimizer: torchopt.RMSProp.
Initialize the meta-RMSProp optimizer.
- Parameters:
module (nn.Module) – A network whose parameters should be optimized.
lr (float or callable, optional) – This is a fixed global scaling factor or a learning rate scheduler. (default: 1e-2)
alpha (float, optional) – Smoothing constant, the decay used to track the magnitude of previous gradients. (default: 0.99)
eps (float, optional) – A small numerical constant to avoid dividing by zero when rescaling. (default: 1e-8)
weight_decay (float, optional) – Weight decay, add L2 penalty to parameters. (default: 0.0)
momentum (float, optional) – The decay rate used by the momentum term. The momentum is not used when it is set to 0.0. (default: 0.0)
centered (bool, optional) – If True, use the variance of the past gradients to rescale the latest gradients. (default: False)
initial_scale (float, optional) – Initialization of accumulators tracking the magnitude of previous updates. PyTorch uses 0.0, TensorFlow 1.x uses 1.0. When reproducing results from a paper, verify the value used by the authors. (default: 0.0)
nesterov (bool, optional) – Whether to use Nesterov momentum. (default: False)
maximize (bool, optional) – Maximize the params based on the objective, instead of minimizing. (default: False)
Differentiable Meta-SGD Optimizer
- class torchopt.MetaSGD(module, lr, momentum=0.0, weight_decay=0.0, dampening=0.0, nesterov=False, moment_requires_grad=True, maximize=False)[source]
Bases:
MetaOptimizer
The differentiable Stochastic Gradient Descent optimizer.
See also
The functional SGD optimizer: torchopt.sgd().
The classic SGD optimizer: torchopt.SGD.
Initialize the meta-SGD optimizer.
- Parameters:
module (nn.Module) – A network whose parameters should be optimized.
lr (float or callable) – This is a fixed global scaling factor or a learning rate scheduler.
momentum (float, optional) – The decay rate used by the momentum term. The momentum is not used when it is set to 0.0. (default: 0.0)
weight_decay (float, optional) – Weight decay, add L2 penalty to parameters. (default: 0.0)
dampening (float, optional) – Dampening for momentum. (default: 0.0)
nesterov (bool, optional) – Whether to use Nesterov momentum. (default: False)
moment_requires_grad (bool, optional) – If True, the momentums will be created with flag requires_grad=True. This flag is often used in Meta-Learning algorithms. (default: True)
maximize (bool, optional) – Maximize the params based on the objective, instead of minimizing. (default: False)
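A minimal MAML-style sketch using MetaSGD for the differentiable inner loop; it assumes torchopt.extract_state_dict / torchopt.recover_state_dict for checkpointing and restoring the network state, and the task data and network are placeholders.

import torch
import torch.nn as nn
import torchopt

net = nn.Linear(4, 1)
inner_optim = torchopt.MetaSGD(net, lr=0.1)
meta_optim = torch.optim.Adam(net.parameters(), lr=1e-3)

x_support, y_support = torch.randn(8, 4), torch.randn(8, 1)
x_query, y_query = torch.randn(8, 4), torch.randn(8, 1)

net_state = torchopt.extract_state_dict(net)   # checkpoint the pre-adaptation parameters

for _ in range(3):                              # differentiable inner-loop adaptation
    inner_loss = ((net(x_support) - y_support) ** 2).mean()
    inner_optim.step(inner_loss)

outer_loss = ((net(x_query) - y_query) ** 2).mean()
meta_optim.zero_grad()
outer_loss.backward()                           # gradients flow back through the inner updates
meta_optim.step()

torchopt.recover_state_dict(net, net_state)     # restore the (meta-updated) initial parameters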
Implicit Differentiation

- custom_root(): Return a decorator for adding implicit differentiation to a root solver.
- ImplicitMetaGradientModule: The base class for differentiable implicit meta-gradient models.

Custom Solvers
- torchopt.diff.implicit.custom_root(optimality_fn, argnums, has_aux=False, solve=None)[source]
Return a decorator for adding implicit differentiation to a root solver.
This wrapper should be used as a decorator:
def optimality_fn(optimal_params, ...):
    ...
    return residual

@custom_root(optimality_fn, argnums=argnums)
def solver_fn(params, arg1, arg2, ...):
    ...
    return optimal_params

optimal_params = solver_fn(init_params, ...)

The first argument to optimality_fn and solver_fn is preserved as the parameter input. The argnums argument refers to the indices of the variables in solver_fn's signature. For example, setting argnums=(1, 2) will compute the gradient of optimal_params with respect to arg1 and arg2 in the above snippet. Note that, except for the first argument, the keyword arguments of optimality_fn should be a subset of the ones of solver_fn. As a best practice, optimality_fn should have the same signature as solver_fn.
- Parameters:
optimality_fn (callable) – An equation function, optimality_fn(params, *args). The invariant is optimality_fn(solution, *args) == 0 at the solution / root of solution.
argnums (int or tuple of int) – Specifies the arguments to compute gradients with respect to. The argnums can be an integer or a tuple of integers, which refer to the zero-based indices of the arguments of the solver_fn(params, *args) function. The argument params is included in the counting and is indexed as argnums=0.
has_aux (bool, optional) – Whether the decorated solver function returns auxiliary data. (default: False)
solve (callable, optional) – A linear solver of the form solve(matvec, b). (default: linear_solve.solve_normal_cg())
- Return type:
Callable[[Callable[..., Union[Tensor, Sequence[Tensor], tuple[Union[Tensor, Sequence[Tensor]], Any]]]], Callable[..., Union[Tensor, Sequence[Tensor], tuple[Union[Tensor, Sequence[Tensor]], Any]]]]
- Returns:
A solver function decorator, i.e., custom_root(optimality_fn)(solver_fn).
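A concrete sketch on a ridge-regression-style inner problem; the functions and data below are illustrative, not part of the API.

import torch
from torchopt.diff.implicit import custom_root

# Inner problem: minimize ||X w - y||^2 + reg * ||w||^2 over w.
def optimality_fn(w, reg, X, y):
    # Stationarity condition: gradient of the inner objective with respect to w.
    return 2.0 * X.T @ (X @ w - y) + 2.0 * reg * w

@custom_root(optimality_fn, argnums=1)  # differentiate the solution with respect to reg
def solver_fn(w_init, reg, X, y):
    # Any black-box solver works; here the closed-form ridge solution is used.
    n = X.shape[1]
    return torch.linalg.solve(X.T @ X + reg * torch.eye(n), X.T @ y)

X, y = torch.randn(10, 3), torch.randn(10)
reg = torch.tensor(1.0, requires_grad=True)     # meta-parameter

w_star = solver_fn(torch.zeros(3), reg, X, y)
outer_loss = w_star.sum()
outer_loss.backward()                           # d(outer_loss)/d(reg) via implicit differentiation
print(reg.grad)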
Implicit Meta-Gradient Module
- class torchopt.diff.implicit.nn.ImplicitMetaGradientModule(*args, **kwargs)[source]
Bases:
MetaGradientModule
The base class for differentiable implicit meta-gradient models.
Initialize a new module instance.
- linear_solve: LinearSolver | None
- classmethod __init_subclass__(linear_solve=None)[source]
Validate and initialize the subclass.
- Return type:
None
- abstract solve(*input, **kwargs)[source]
Solve the inner optimization problem.
- Return type:
Any
Warning
For gradient-based optimization methods, the parameter inputs should be explicitly specified in the torch.autograd.backward() function as the argument inputs. Otherwise, if not provided, the gradient is accumulated into all the leaf Tensors (including the meta-parameters) that were used to compute the objective output. Alternatively, please use torch.autograd.grad() instead.
Examples
def solve(self, batch, labels):
    parameters = tuple(self.parameters())
    optimizer = torch.optim.Adam(parameters, lr=1e-3)
    with torch.enable_grad():
        for _ in range(100):
            loss = self.objective(batch, labels)
            optimizer.zero_grad()
            # Only update the `.grad` attribute for parameters
            # and leave the meta-parameters unchanged
            loss.backward(inputs=parameters)
            optimizer.step()
    return self
- optimality(*input, **kwargs)[source]
Compute the optimality residual.
This method stands for the optimality residual to the optimal parameters after solving the inner optimization problem (solve()), i.e.:
module.solve(*input, **kwargs)
module.optimality(*input, **kwargs)  # -> 0
1. For gradient-based optimization, the optimality() function is the KKT condition; usually it is the gradient of the objective() function with respect to the module parameters (not the meta-parameters). If this method is not implemented, it will be automatically derived from the gradient of the objective() function.

\[\text{optimality residual} = \nabla_{\boldsymbol{x}} f (\boldsymbol{x}, \boldsymbol{\theta}) \to \boldsymbol{0}\]

where \(\boldsymbol{x}\) is the joint vector of the module parameters and \(\boldsymbol{\theta}\) is the joint vector of the meta-parameters.
References
Karush-Kuhn-Tucker (KKT) conditions: https://en.wikipedia.org/wiki/Karush-Kuhn-Tucker_conditions
2. For fixed point iteration, the optimality() function can be the residual of the parameters between iterations, i.e.:

\[\text{optimality residual} = f (\boldsymbol{x}, \boldsymbol{\theta}) - \boldsymbol{x} \to \boldsymbol{0}\]

where \(\boldsymbol{x}\) is the joint vector of the module parameters and \(\boldsymbol{\theta}\) is the joint vector of the meta-parameters.
- objective(*input, **kwargs)[source]
Compute the objective function value.
This method is used to calculate the optimality() if it is not implemented. Otherwise, this method is optional.
- Return type:
Tensor
- Returns:
A scalar tensor (dim=0), the objective function value.
- __abstractmethods__ = frozenset({'solve'})
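A minimal subclass sketch with an iMAML-style proximal inner objective; it assumes that assigning a tensor with requires_grad=True registers it as a meta-parameter, and all names and data here are illustrative.

import torch
import torch.nn as nn
from torchopt.diff.implicit.nn import ImplicitMetaGradientModule

class InnerNet(ImplicitMetaGradientModule):
    def __init__(self, meta_param):
        super().__init__()
        self.meta_param = meta_param                              # meta-parameter (assumption: auto-registered)
        self.param = nn.Parameter(torch.zeros_like(meta_param))   # inner parameter optimized by solve()

    def forward(self, batch):
        return batch @ self.param

    def objective(self, batch, labels):
        # Inner objective with a proximal term toward the meta-parameter; optimality() is derived from it.
        return ((self(batch) - labels) ** 2).mean() + 0.5 * ((self.param - self.meta_param) ** 2).sum()

    def solve(self, batch, labels):
        parameters = tuple(self.parameters())
        optimizer = torch.optim.Adam(parameters, lr=1e-2)
        with torch.enable_grad():
            for _ in range(100):
                loss = self.objective(batch, labels)
                optimizer.zero_grad()
                loss.backward(inputs=parameters)  # update only the inner parameters
                optimizer.step()
        return self

meta_param = torch.randn(4, requires_grad=True)
batch, labels = torch.randn(8, 4), torch.randn(8)

inner_net = InnerNet(meta_param)
optimal_net = inner_net.solve(batch, labels)
outer_loss = ((optimal_net(batch) - labels) ** 2).mean()
meta_grad = torch.autograd.grad(outer_loss, meta_param)  # via implicit differentiation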
Linear System Solvers

- solve_cg(): Return a solver function to solve A x = b using conjugate gradient.
- solve_normal_cg(): Return a solver function to solve A^T A x = A^T b using conjugate gradient.
- solve_inv(): Return a solver function to solve A x = b using matrix inversion.

Indirect Solvers
- torchopt.linear_solve.solve_cg(**kwargs)[source]
Return a solver function to solve A x = b using conjugate gradient.
This assumes that A is a hermitian, positive definite matrix.
- Parameters:
ridge (float or None, optional) – Optional ridge regularization. If provided, solves the equation for A x + ridge x = b. (default: None)
init (Tensor, tree of Tensor, or None, optional) – Optional initialization to be used by conjugate gradient. If None, uses zero initialization. (default: None)
**kwargs (Any) – Additional keyword arguments for the conjugate gradient solver torchopt.linalg.cg().
- Return type:
Callable[[Callable[[TensorTree], TensorTree], TensorTree], TensorTree]
- Returns:
A solver function with signature (matvec, b) -> x that solves A x = b using conjugate gradient, where matvec(v) = A v.
See also
Conjugate gradient iteration torchopt.linalg.cg().
Examples
>>> A = {'a': torch.eye(5, 5), 'b': torch.eye(3, 3)}
>>> x = {'a': torch.randn(5), 'b': torch.randn(3)}
>>> def matvec(x: TensorTree) -> TensorTree:
...     return {'a': A['a'] @ x['a'], 'b': A['b'] @ x['b']}
>>> b = matvec(x)
>>> solver = solve_cg(init={'a': torch.zeros(5), 'b': torch.zeros(3)})
>>> x_hat = solver(matvec, b)
>>> assert torch.allclose(x_hat['a'], x['a']) and torch.allclose(x_hat['b'], x['b'])
- torchopt.linear_solve.solve_normal_cg(**kwargs)[source]
Return a solver function to solve A^T A x = A^T b using conjugate gradient.
This can be used to solve A x = b using conjugate gradient when A is not hermitian, positive definite.
- Parameters:
ridge (float or None, optional) – Optional ridge regularization. If provided, solves the equation for A^T A x + ridge x = A^T b. (default: None)
init (Tensor, tree of Tensor, or None, optional) – Optional initialization to be used by conjugate gradient. If None, uses zero initialization. (default: None)
**kwargs (Any) – Additional keyword arguments for the conjugate gradient solver torchopt.linalg.cg().
- Return type:
Callable[[Callable[[TensorTree], TensorTree], TensorTree], TensorTree]
- Returns:
A solver function with signature (matvec, b) -> x that solves A^T A x = A^T b using conjugate gradient, where matvec(v) = A v.
See also
Conjugate gradient iteration torchopt.linalg.cg().
Examples
>>> A = {'a': torch.randn(5, 5), 'b': torch.randn(3, 3)}
>>> x = {'a': torch.randn(5), 'b': torch.randn(3)}
>>> def matvec(x: TensorTree) -> TensorTree:
...     return {'a': A['a'] @ x['a'], 'b': A['b'] @ x['b']}
>>> b = matvec(x)
>>> solver = solve_normal_cg(init={'a': torch.zeros(5), 'b': torch.zeros(3)})
>>> x_hat = solver(matvec, b)
>>> assert torch.allclose(x_hat['a'], x['a']) and torch.allclose(x_hat['b'], x['b'])
- torchopt.linear_solve.solve_inv(**kwargs)[source]
Return a solver function to solve
A x = b
using matrix inversion.If
ns = False
, this assumes the matrixA
is a constant matrix and will materialize it in memory.- Parameters:
ridge (float or None, optional) – Optional ridge regularization. If provided, solves the equation for A x + ridge x = b. (default: None)
ns (bool, optional) – Whether to use Neumann Series matrix inversion approximation. If False, materialize the matrix A in memory and use torch.linalg.solve() instead. (default: False)
**kwargs (Any) – Additional keyword arguments for the Neumann Series matrix inversion approximation solver torchopt.linalg.ns().
- Return type:
Callable[[Callable[[TensorTree], TensorTree], TensorTree], TensorTree]
- Returns:
A solver function with signature (matvec, b) -> x that solves A x = b using matrix inversion, where matvec(v) = A v.
See also
Neumann Series matrix inversion approximation torchopt.linalg.ns().
Examples
>>> A = {'a': torch.eye(5, 5), 'b': torch.eye(3, 3)}
>>> x = {'a': torch.randn(5), 'b': torch.randn(3)}
>>> def matvec(x: TensorTree) -> TensorTree:
...     return {'a': A['a'] @ x['a'], 'b': A['b'] @ x['b']}
>>> b = matvec(x)
>>> solver = solve_inv(ns=True, maxiter=10)
>>> x_hat = solver(matvec, b)
>>> assert torch.allclose(x_hat['a'], x['a']) and torch.allclose(x_hat['b'], x['b'])
Zero-Order Differentiation
|
Return a decorator for applying zero-order differentiation. |
|
The base class for zero-order gradient models. |
Decorators
- torchopt.diff.zero_order.zero_order(distribution, method='naive', argnums=(0,), num_samples=1, sigma=1.0)[source]
Return a decorator for applying zero-order differentiation.
- Parameters:
distribution (callable or Samplable) – A samplable object that has a method samplable.sample(sample_shape), or a function that takes a shape as input and returns a shaped batch of samples. This is used to sample perturbations from the given distribution. The distribution should be sphere-symmetric.
method (str, optional) – The algorithm to use. The currently supported algorithms are 'naive', 'forward', and 'antithetic'. (default: 'naive')
argnums (int or tuple of int, optional) – Specifies the arguments to compute gradients with respect to. (default: 0)
num_samples (int, optional) – The number of samples used to compute the averaged estimated gradient. (default: 1)
sigma (float, optional) – The standard deviation of the perturbation. (default: 1.0)
- Return type:
- Returns:
A function decorator that enables zero-order gradient estimation.
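For example, the decorated objective behaves like an ordinary differentiable function: calling backward() on its output fills in evolution-strategies style gradient estimates for the arguments selected by argnums. The snippet below is a minimal sketch; the toy least-squares objective and the hyper-parameters are illustrative only.

import torch
import torch.nn.functional as F
from torchopt.diff.zero_order import zero_order

torch.manual_seed(0)
w = torch.randn(8, requires_grad=True)        # parameters optimized without true gradients
x, y = torch.randn(32, 8), torch.randn(32)    # a dummy regression batch

def distribution(sample_shape):
    # Sphere-symmetric perturbation distribution used by the estimator
    return torch.randn(sample_shape)

@zero_order(distribution, method='naive', argnums=0, num_samples=100, sigma=0.01)
def objective(w, x, y):
    # Must return a scalar tensor; gradients w.r.t. argnums=0 (here: w) are estimated
    return F.mse_loss(x @ w, y)

loss = objective(w, x, y)
loss.backward()                               # w.grad now holds the zero-order estimate
print(w.grad.shape)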
Zero-order Gradient Module
- class torchopt.diff.zero_order.nn.ZeroOrderGradientModule(*args, **kwargs)[source]
Bases: Module, Samplable
The base class for zero-order gradient models.
Initialize internal Module state, shared by both nn.Module and ScriptModule.
- classmethod __init_subclass__(method='naive', num_samples=1, sigma=1.0)[source]
Validate and initialize the subclass.
- Return type:
- abstract sample(sample_shape=())[source]
Generate a sample_shape shaped sample or sample_shape shaped batch of samples if the distribution parameters are batched.
- __abstractmethods__ = frozenset({'forward', 'sample'})
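A minimal subclassing sketch is shown below (the tiny network, distribution, and estimator options are illustrative). The estimator options are passed as class keyword arguments, which are consumed by __init_subclass__(), and the subclass must implement both forward() and sample().

import torch
import torch.nn as nn
import torch.nn.functional as F
from torchopt.diff.zero_order.nn import ZeroOrderGradientModule

class ZeroOrderNet(ZeroOrderGradientModule, method='antithetic', num_samples=16, sigma=0.01):
    def __init__(self, dim):
        super().__init__()
        self.fc = nn.Linear(dim, 1)
        self.distribution = torch.distributions.Normal(loc=0.0, scale=1.0)

    def forward(self, x, y):
        # Must return a scalar objective
        return F.mse_loss(self.fc(x).squeeze(-1), y)

    def sample(self, sample_shape=torch.Size()):
        # Sphere-symmetric perturbations used by the estimator
        return self.distribution.sample(sample_shape)

net = ZeroOrderNet(dim=8)
x, y = torch.randn(32, 8), torch.randn(32)
loss = net(x, y)
loss.backward()   # parameter gradients are zero-order estimates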
Optimizer Hooks
|
Stateless identity transformation that leaves input gradients untouched. |
Replace |
|
|
Return a |
Hook
Gradient Transformation
|
Clip gradient norm of an iterable of parameters. |
|
Replace updates with values |
Transforms
- torchopt.clip_grad_norm(max_norm, norm_type=2.0, error_if_nonfinite=False)[source]
Clip gradient norm of an iterable of parameters.
- Parameters:
max_norm (float) – The maximum absolute value for each element in the update.
norm_type (float, optional) – Type of the used p-norm. Can be 'inf' for infinity norm. (default: 2.0)
error_if_nonfinite (bool, optional) – If True, an error is thrown if the total norm of the gradients from updates is nan, inf, or -inf. (default: False)
- Return type:
GradientTransformation
- Returns:
An (init_fn, update_fn) tuple.
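The transformation is typically chained in front of an optimizer, as in the sketch below (the model and learning rate are placeholders):

import torch
import torch.nn as nn
import torchopt

net = nn.Linear(4, 1)
impl = torchopt.chain(
    torchopt.clip_grad_norm(max_norm=1.0),  # clip the incoming gradients first ...
    torchopt.sgd(lr=0.1),                   # ... then apply the SGD update rule
)
optimizer = torchopt.Optimizer(net.parameters(), impl)

loss = net(torch.randn(8, 4)).pow(2).mean()
optimizer.zero_grad()
loss.backward()
optimizer.step()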
Optimizer Schedules
|
Alias polynomial schedule to linear schedule for convenience. |
|
Construct a schedule with polynomial transition from init to end value. |
Schedules
- torchopt.schedule.linear_schedule(init_value, end_value, transition_steps, transition_begin=0)[source]
Alias polynomial schedule to linear schedule for convenience.
- torchopt.schedule.polynomial_schedule(init_value, end_value, power, transition_steps, transition_begin=0)[source]
Construct a schedule with polynomial transition from init to end value.
- Parameters:
init_value (float or Tensor) – Initial value for the scalar to be annealed.
end_value (float or Tensor) – End value of the scalar to be annealed.
power (float or Tensor) – The power of the polynomial used to transition from init to end.
transition_steps (int) – Number of steps over which annealing takes place. The scalar starts changing at transition_begin steps and completes the transition by transition_begin + transition_steps steps. If transition_steps <= 0, the entire annealing process is disabled and the value is held fixed at init_value.
transition_begin (int, optional) – Must be non-negative. After how many steps to start annealing (before this many steps the scalar value is held fixed at init_value). (default: 0)
- Returns:
A function that maps step counts to values.
- Return type:
schedule
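A short sketch of building a schedule and handing it to a functional optimizer follows; it assumes the alias accepts a callable (step count -> value) as its lr argument.

import torchopt

# Anneal the learning rate linearly from 1e-2 to 1e-3 over the first 1000 updates
lr_schedule = torchopt.schedule.linear_schedule(
    init_value=1e-2,
    end_value=1e-3,
    transition_steps=1000,
)
print(lr_schedule(0), lr_schedule(500), lr_schedule(1000))  # 1e-2, 5.5e-3, 1e-3

# Assumption: functional aliases such as torchopt.sgd accept a schedule for lr
optim = torchopt.sgd(lr=lr_schedule)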
Apply Parameter Updates
|
Apply an update to the corresponding parameters. |
Apply Updates
- torchopt.apply_updates(params, updates, *, inplace=True)[source]
Apply an update to the corresponding parameters.
This is a utility function that applies an update to a set of parameters and then returns the updated parameters to the caller. As an example, the update may be a gradient transformed by a sequence of GradientTransformations. This function is exposed for convenience, but it just adds updates and parameters; you may also apply updates to parameters manually, using tree_map() (e.g. if you want to manipulate updates in custom ways before applying them).
- Parameters:
- Return type:
Params
- Returns:
Updated parameters, with same structure, shape and type as params.
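A minimal functional training step built around apply_updates() looks like the sketch below (the model, data, and loss are placeholders):

import torch
import torch.nn as nn
import torch.nn.functional as F
import torchopt

net = nn.Linear(8, 1)
params = tuple(net.parameters())

optim = torchopt.adam(lr=1e-3)    # a GradientTransformation, i.e. an (init_fn, update_fn) pair
opt_state = optim.init(params)    # initialize the optimizer state for these parameters

x, y = torch.randn(32, 8), torch.randn(32, 1)
loss = F.mse_loss(net(x), y)
grads = torch.autograd.grad(loss, params)

# Transform the raw gradients into updates, then add them to the parameters
updates, opt_state = optim.update(grads, opt_state, params=params)
params = torchopt.apply_updates(params, updates)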
Combining Optimizers
|
Apply a list of chainable update transformations. |
Chain
- torchopt.combine.chain(*transformations)[source]
Apply a list of chainable update transformations.
Given a sequence of chainable transforms, chain() returns an init_fn() that constructs a state by concatenating the states of the individual transforms, and an update_fn() that chains the update transformations, feeding the appropriate state to each.
- Parameters:
*transformations (iterable of GradientTransformation) – A sequence of chainable (init_fn, update_fn) tuples.
- Return type:
GradientTransformation
- Returns:
A single (init_fn, update_fn) tuple.
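As a small sketch, chaining a clipping transformation with an optimizer yields a single GradientTransformation whose init and update run both stages in order:

import torch
import torchopt

params = (torch.zeros(3, requires_grad=True),)
grads = (torch.full((3,), 10.0),)

# Clip the incoming updates to norm 1.0, then rescale them with the Adam rule
transform = torchopt.chain(
    torchopt.clip_grad_norm(max_norm=1.0),
    torchopt.adam(lr=1e-3),
)

state = transform.init(params)
updates, state = transform.update(grads, state, params=params)
new_params = torchopt.apply_updates(params, updates, inplace=False)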
Distributed Utilities
Initialization and Synchronization
|
Return a decorator to automatically initialize RPC on the decorated function. |
|
Synchronize local and remote RPC processes. |
- torchopt.distributed.auto_init_rpc(worker_init_fn=None, worker_name_format=<function default_worker_name_format>, *, backend=None, rpc_backend_options=None)[source]
Return a decorator to automatically initialize RPC on the decorated function.
Process group information
Get the world information. |
|
Get the global world rank of the current worker. |
|
|
Get the global world rank of the current worker. |
Get the world size. |
|
Get the local rank of the current worker on the current node. |
|
Get the local world size on the current node. |
|
|
Get the worker id from the given id. |
- torchopt.distributed.get_world_rank()[source]
Get the global world rank of the current worker.
- Return type:
int
- torchopt.distributed.get_local_rank()[source]
Get the local rank of the current worker on the current node.
- Return type:
int
Worker selection
|
Return a decorator to mark a function to be executed only on given ranks. |
|
Return a decorator to mark a function to be executed only on ranks other than the given ranks. |
|
Return a decorator to mark a function to be executed only on rank zero. |
|
Return a decorator to mark a function to be executed only on non-zero ranks. |
- torchopt.distributed.on_rank(*ranks)[source]
Return a decorator to mark a function to be executed only on given ranks.
- torchopt.distributed.not_on_rank(*ranks)[source]
Return a decorator to mark a function to be executed only on ranks other than the given ranks.
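A small usage sketch (the decorated functions are illustrative):

import torchopt.distributed as todist

@todist.on_rank(0)
def log_metrics(step, loss):
    # Executed only on the worker with global rank 0
    print(f'step={step} loss={loss:.4f}')

@todist.not_on_rank(0)
def worker_side_setup():
    # Executed on every worker except rank 0
    ...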
Remote Procedure Call (RPC)
|
Asynchronously do an RPC on remote workers and return a |
|
Do an RPC synchronously on remote workers and return the result to the current worker. |
- torchopt.distributed.remote_async_call(func, *, args=None, kwargs=None, partitioner=None, reducer=None, timeout=-1.0)[source]
Asynchronously do an RPC on remote workers and return a torch.Future instance at the current worker.
- Parameters:
func (callable) – The function to call.
args (tuple of object or None, optional) – The arguments to pass to the function. (default: None)
kwargs (dict[str, object] or None, optional) – The keyword arguments to pass to the function. (default: None)
partitioner (int, str, or callable, optional) – A partitioner that partitions the arguments to multiple workers. (default: batch_partitioner())
reducer (callable or None, optional) – A reducer that reduces the results from multiple workers. If None, do not reduce the results. (default: None)
timeout (float, optional) – The timeout for the RPC call. (default: rpc.api.UNSET_RPC_TIMEOUT)
- Return type:
- Returns:
A torch.Future instance for the result. The result is at the current worker.
- torchopt.distributed.remote_sync_call(func, *, args=None, kwargs=None, partitioner=None, reducer=None, timeout=-1.0)[source]
Do an RPC synchronously on remote workers and return the result to the current worker.
- Parameters:
func (callable) – The function to call.
args (tuple of object or None, optional) – The arguments to pass to the function. (default: None)
kwargs (dict[str, object] or None, optional) – The keyword arguments to pass to the function. (default: None)
partitioner (int, str, or callable, optional) – A partitioner that partitions the arguments to multiple workers. (default: batch_partitioner())
reducer (callable or None, optional) – A reducer that reduces the results from multiple workers. If None, do not reduce the results. (default: None)
timeout (float, optional) – The timeout for the RPC call. (default: rpc.api.UNSET_RPC_TIMEOUT)
- Return type:
- Returns:
The result of the RPC call. The result is at the current worker.
Predefined partitioners and reducers
|
Partition a batch of inputs along a given dimension. |
Partitioner class that partitions a batch of inputs along a given dimension. |
|
|
Reduce the results by averaging them. |
|
Reduce the results by summing them. |
- torchopt.distributed.dim_partitioner(dim=0, *, exclusive=False, keepdim=True, workers=None)[source]
Partition a batch of inputs along a given dimension.
All tensors in the args and kwargs will be partitioned along the dimension dim, while the non-tensor values will be broadcasted to partitions.
- Parameters:
dim (int, optional) – The dimension to partition. (default: 0)
exclusive (bool, optional) – Whether to partition the batch exclusively. (default: False) If True, the batch will be partitioned into batch_size partitions, where batch_size is the size of the batch along the given dimension. Each batch sample will be assigned to a separate RPC call. If False, the batch will be partitioned into min(batch_size, num_workers) partitions, where num_workers is the number of workers in the world. When batch_size > num_workers, there can be multiple batch samples forwarded in a single RPC call.
keepdim (bool, optional) – Whether to keep the partitioned dimension. (default: True) If True, keep the batch dimension. If False, use select instead of slicing. This functionality should be used with exclusive=True.
workers (sequence of int or str, or None, optional) – The workers to partition the batch to. If None, the batch will be partitioned to all workers in the world. (default: None)
- Return type:
Callable[..., Sequence[Tuple[int, Optional[Tuple[Any, ...]], Optional[Dict[str, Any]]]]]
- Returns:
A partition function.
- torchopt.distributed.batch_partitioner(*args: Any, **kwargs: Any) list[tuple[int, Args | None, KwArgs | None]]
Partitioner for the batch dimension. Divides and replicates the arguments to all workers along the first dimension.
The batch will be partitioned into min(batch_size, num_workers) partitions, where num_workers is the number of workers in the world. When batch_size > num_workers, there can be multiple batch samples forwarded in a single RPC call.
All tensors in the args and kwargs will be partitioned along the first dimension, while the non-tensor values will be broadcasted to partitions.
Function parallelization wrappers
|
Return a decorator for parallelizing a function. |
|
Return a decorator for parallelizing a function. |
|
Return a decorator for parallelizing a function. |
- torchopt.distributed.parallelize(partitioner=None, reducer=None, timeout=-1.0)[source]
Return a decorator for parallelizing a function.
This decorator can be used to parallelize a function call across multiple workers.
- Parameters:
partitioner (int, str, or callable, optional) – A partitioner that partitions the arguments to multiple workers. (default: batch_partitioner())
reducer (callable or None, optional) – A reducer that reduces the results from multiple workers. If None, do not reduce the results. (default: None)
timeout (float, optional) – The timeout for the RPC call. (default: rpc.api.UNSET_RPC_TIMEOUT)
- Return type:
Callable[[Callable[..., T]], Callable[..., Union[list[T], U]]]
- Returns:
The decorator function.
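A sketch of the decorator in use is shown below. It assumes an RPC group has already been initialized (e.g. via auto_init_rpc()) and that the averaging reducer from the table above is exposed as torchopt.distributed.mean_reducer; both the model and the loss are placeholders.

import torch.nn.functional as F
import torchopt.distributed as todist

@todist.parallelize(
    partitioner=todist.batch_partitioner,  # split tensor arguments along the first dimension
    reducer=todist.mean_reducer,           # assumption: averaging reducer provided by torchopt.distributed
)
def compute_loss(model, batch, labels):
    # Each RPC call receives a slice of `batch` / `labels`; the per-worker
    # losses are averaged back into a single scalar on the caller.
    return F.mse_loss(model(batch), labels)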
- torchopt.distributed.parallelize_async(partitioner=None, reducer=None, timeout=-1.0)[source]
Return a decorator for parallelizing a function.
This decorator can be used to parallelize a function call across multiple workers. The function will be called asynchronously on remote workers. The decorated function will return a torch.Future instance of the result.
- Parameters:
partitioner (int, str, or callable, optional) – A partitioner that partitions the arguments to multiple workers. (default: batch_partitioner())
reducer (callable or None, optional) – A reducer that reduces the results from multiple workers. If None, do not reduce the results. (default: None)
timeout (float, optional) – The timeout for the RPC call. (default: rpc.api.UNSET_RPC_TIMEOUT)
- Return type:
Callable[[Callable[..., T]], Callable[..., Union[Future[list[T]], Future[U]]]]
- Returns:
The decorator function.
- torchopt.distributed.parallelize_sync(partitioner=None, reducer=None, timeout=-1.0)
Return a decorator for parallelizing a function.
This decorator can be used to parallelize a function call across multiple workers.
- Parameters:
partitioner (int, str, or callable, optional) – A partitioner that partitions the arguments to multiple workers. (default: batch_partitioner())
reducer (callable or None, optional) – A reducer that reduces the results from multiple workers. If None, do not reduce the results. (default: None)
timeout (float, optional) – The timeout for the RPC call. (default: rpc.api.UNSET_RPC_TIMEOUT)
- Return type:
Callable[[Callable[..., T]], Callable[..., Union[list[T], U]]]
- Returns:
The decorator function.
Distributed Autograd
|
Context object to wrap forward and backward passes when using distributed autograd. |
|
Retrieves a map from Tensor to the appropriate gradient for that Tensor accumulated in the provided context corresponding to the given |
|
Perform distributed backward pass for local parameters. |
|
Compute and return the sum of gradients of outputs with respect to the inputs. |
- torchopt.distributed.autograd.context()[source]
Context object to wrap forward and backward passes when using distributed autograd. The context_id generated in the with statement is required to uniquely identify a distributed backward pass on all workers. Each worker stores metadata associated with this context_id, which is required to correctly execute a distributed autograd pass.
- Example::
>>> # xdoctest: +SKIP
>>> import torch.distributed.autograd as dist_autograd
>>> with dist_autograd.context() as context_id:
>>>     t1 = torch.rand((3, 3), requires_grad=True)
>>>     t2 = torch.rand((3, 3), requires_grad=True)
>>>     loss = rpc.rpc_sync("worker1", torch.add, args=(t1, t2)).sum()
>>>     dist_autograd.backward(context_id, [loss])
- torchopt.distributed.autograd.get_gradients(context_id: int) Dict[Tensor, Tensor]
Retrieves a map from Tensor to the appropriate gradient for that Tensor accumulated in the provided context corresponding to the given context_id as part of the distributed autograd backward pass.
- Parameters:
context_id (int) – The autograd context id for which we should retrieve the gradients.
- Returns:
A map where the key is the Tensor and the value is the associated gradient for that Tensor.
- Example::
>>> import torch.distributed.autograd as dist_autograd
>>> with dist_autograd.context() as context_id:
>>>     t1 = torch.rand((3, 3), requires_grad=True)
>>>     t2 = torch.rand((3, 3), requires_grad=True)
>>>     loss = t1 + t2
>>>     dist_autograd.backward(context_id, [loss.sum()])
>>>     grads = dist_autograd.get_gradients(context_id)
>>>     print(grads[t1])
>>>     print(grads[t2])
- torchopt.distributed.autograd.backward(autograd_ctx_id, tensors, retain_graph=False, inputs=None)[source]
Perform distributed backward pass for local parameters.
Compute the sum of gradients of given tensors with respect to graph leaves.
- Parameters:
autograd_ctx_id (int) – The autograd context id.
tensors (Tensor or sequence of Tensor) – Tensors of which the derivative will be computed.
retain_graph (bool, optional) – If False, the graph used to compute the grad will be freed. Note that in nearly all cases setting this option to True is not needed and often can be worked around in a much more efficient way. (default: False)
inputs (Tensor, sequence of Tensor, or None, optional) – Inputs w.r.t. which the gradient will be accumulated into .grad. All other Tensors will be ignored. If not provided, the gradient is accumulated into all the leaf Tensors that were used to compute the tensors. (default: None)
- Return type:
- torchopt.distributed.autograd.grad(autograd_ctx_id, outputs, inputs, retain_graph=False, allow_unused=False)[source]
Compute and return the sum of gradients of outputs with respect to the inputs.
- Parameters:
autograd_ctx_id (int) – The autograd context id.
outputs (Tensor or sequence of Tensor) – Outputs of the differentiated function.
inputs (Tensor or sequence of Tensor) – Inputs w.r.t. which the gradient will be returned (and not accumulated into .grad).
retain_graph (bool, optional) – If False, the graph used to compute the grad will be freed. Note that in nearly all cases setting this option to True is not needed and often can be worked around in a much more efficient way. (default: False)
allow_unused (bool, optional) – If False, specifying inputs that were not used when computing outputs (and therefore their grad is always zero) is an error. (default: False)
- Return type:
General Utilities
|
Extract target state. |
|
Recover state. |
|
Stop the gradient for the input object. |
Extract State Dict
- torchopt.extract_state_dict(target, *, by='reference', device=None, with_buffers=True, detach_buffers=False, enable_visual=False, visual_prefix='')[source]
Extract target state.
Since a tensor uses grad_fn to connect itself with the previous computation graph, the backpropagated gradient will flow over the tensor and continue to flow to the tensors that are connected by grad_fn. Some algorithms require manually detaching tensors from the computation graph.
Note that the extracted state is a reference, which means any in-place operation will affect the target that the state is extracted from.
- Parameters:
target (nn.Module or MetaOptimizer) – It could be an nn.Module or a torchopt.MetaOptimizer.
by (str, optional) – The extract policy for the tensors in the target. (default: 'reference')
- 'reference': The extracted tensors will be references to the original tensors.
- 'copy': The extracted tensors will be clones of the original tensors. This makes the copied tensors have a grad_fn that is a <CloneBackward> function pointing to the original tensors.
- 'deepcopy': The extracted tensors will be deep-copied from the original tensors. The deep-copied tensors will be detached from the original computation graph.
device (Device or None, optional) – If specified, move the extracted state to the specified device. (default: None)
with_buffers (bool, optional) – Extract buffers together with parameters; this argument is only used if the input target is an nn.Module. (default: True)
detach_buffers (bool, optional) – Whether to detach the references to the buffers; this argument is only used if the input target is an nn.Module and by='reference'. (default: False)
enable_visual (bool, optional) – Add additional annotations, which could be used in computation graph visualization. Currently, this flag only has an effect on nn.Module, but we will support torchopt.MetaOptimizer later. (default: False)
visual_prefix (str, optional) – Prefix for the visualization annotations. (default: '')
- Return type:
ModuleState | tuple[TensorTree, ...]
- Returns:
State extracted of the input object.
Recover State Dict
- torchopt.recover_state_dict(target, state)[source]
Recover state.
This function is compatible with the state extracted by extract_state_dict().
Note that the recovering process is not in-place, so the tensors of the object will not be modified.
- Parameters:
target (nn.Module or MetaOptimizer) – The target that needs to be recovered.
state (ModuleState or sequence of tree of Tensor) – The recovering state.
- Return type:
None
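Together, extract_state_dict() and recover_state_dict() support the common meta-learning pattern of snapshotting the inner-loop state and rolling it back after each task. The sketch below is illustrative (toy module, toy data):

import torch
import torch.nn as nn
import torch.nn.functional as F
import torchopt

net = nn.Linear(8, 1)
inner_optim = torchopt.MetaSGD(net, lr=0.1)

# Snapshot references to the current network and optimizer state
net_state = torchopt.extract_state_dict(net, by='reference')
optim_state = torchopt.extract_state_dict(inner_optim)

inner_loop_batches = [(torch.randn(4, 8), torch.randn(4, 1)) for _ in range(3)]
for x, y in inner_loop_batches:
    inner_loss = F.mse_loss(net(x), y)
    inner_optim.step(inner_loss)          # differentiable inner-loop update

# ... compute the outer / meta loss on the adapted `net` here ...

# Roll the network and optimizer back to the snapshot for the next task
torchopt.recover_state_dict(net, net_state)
torchopt.recover_state_dict(inner_optim, optim_state)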
Stop Gradient
- torchopt.stop_gradient(target)[source]
Stop the gradient for the input object.
Since a tensor uses grad_fn to connect itself with the previous computation graph, the backpropagated gradient will flow over the tensor and continue to flow to the tensors that are connected by grad_fn. Some algorithms require manually detaching tensors from the computation graph.
Note that the stop_gradient() operation is in-place.
- Parameters:
target (ModuleState, nn.Module, MetaOptimizer, or tree of Tensor) – The target to be detached from the computation graph. It could be an nn.Module, a torchopt.MetaOptimizer, the state of a torchopt.MetaOptimizer, or just a plain list of tensors.
- Return type:
None
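A short sketch of the typical use, detaching a module and its meta-optimizer after a truncated inner loop:

import torch.nn as nn
import torchopt

net = nn.Linear(8, 1)
inner_optim = torchopt.MetaSGD(net, lr=0.1)

# ... differentiable inner-loop updates via inner_optim.step(loss) ...

# Detach the network parameters and the optimizer state from the
# inner-loop computation graph; both calls modify their targets in-place.
torchopt.stop_gradient(net)
torchopt.stop_gradient(inner_optim)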
Visualizing Gradient Flow
|
Produce Graphviz representation of PyTorch autograd graph. |
Make Dot
- torchopt.visual.make_dot(var, params=None, show_attrs=False, show_saved=False, max_attr_chars=50)[source]
Produce Graphviz representation of PyTorch autograd graph.
If a node represents a backward function, it is gray. Otherwise, the node represents a tensor and is either blue, orange, or green:
- Blue
Reachable leaf tensors that require grad (tensors whose grad fields will be populated during backward()).
- Orange
Saved tensors of custom autograd functions as well as those saved by built-in backward nodes.
- Green
Tensors passed in as outputs.
- Dark green
If any output is a view, we represent its base tensor with a dark green node.
- Parameters:
var (Tensor or sequence of Tensor) – Output tensor.
params (dict[str, Tensor], ModuleState, iterable of tuple[str, Tensor], or None, optional) – Parameters to add names to nodes that require grad. (default: None)
show_attrs (bool, optional) – Whether to display non-tensor attributes of backward nodes. (default: False)
show_saved (bool, optional) – Whether to display saved tensor nodes that are not by custom autograd functions. Saved tensor nodes for custom functions, if present, are always displayed. (default: False)
max_attr_chars (int, optional) – If show_attrs is True, sets the max number of characters to display for any given attribute. (default: 50)
- Return type:
Digraph
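A small sketch of visualizing one differentiable inner step follows. Rendering the output requires the graphviz package; the file name and prefixes are illustrative.

import torch
import torch.nn as nn
import torch.nn.functional as F
import torchopt

net = nn.Linear(8, 1)
optim = torchopt.MetaSGD(net, lr=0.1)
net_state_0 = torchopt.extract_state_dict(net, enable_visual=True, visual_prefix='step0.')

x, y = torch.randn(16, 8), torch.randn(16, 1)
optim.step(F.mse_loss(net(x), y))          # one differentiable update

outer_loss = F.mse_loss(net(x), y)
dot = torchopt.visual.make_dot(outer_loss, params=(net_state_0, {'outer_loss': outer_loss}))
dot.render('inner_step', format='svg')     # writes inner_step.svg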