This section describes useful concepts across TorchOpt.

TorchOpt Types

torchopt.base.GradientTransformation(init, ...)

A pair of pure functions implementing a gradient transformation.

torchopt.base.TransformInitFn(*args, **kwargs)

A callable type for the init() step of a GradientTransformation.

torchopt.base.TransformUpdateFn(*args, **kwargs)

A callable type for the update() step of a GradientTransformation.
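Here is a minimal, self-contained sketch of the init/update split in plain Python. `Transformation`, `momentum_sgd` and `apply_updates` are hypothetical names chosen for illustration. Floats in a dictionary stand in for tensors. The real `TransformUpdateFn` may accept extra arguments that this sketch omits. This is not TorchOpt's implementation, only the same two-pure-functions pattern.

```python
from typing import Any, Callable, Dict, NamedTuple, Tuple

# Floats in a dict stand in for tensors so the sketch runs with the standard library only.
Params = Dict[str, float]


class Transformation(NamedTuple):
    """Hypothetical stand-in with the same shape as a GradientTransformation."""

    init: Callable[[Params], Any]  # role of TransformInitFn
    update: Callable[[Params, Any], Tuple[Params, Any]]  # role of TransformUpdateFn


def momentum_sgd(lr: float, beta: float) -> Transformation:
    """Hypothetical momentum-SGD transformation built from two pure functions."""

    def init(params: Params) -> Params:
        # State: one momentum buffer per parameter, initialized to zero.
        return {k: 0.0 for k in params}

    def update(grads: Params, state: Params) -> Tuple[Params, Params]:
        # Build new dicts rather than mutating the inputs, keeping the function pure.
        new_state = {k: beta * state[k] + grads[k] for k in grads}
        updates = {k: -lr * new_state[k] for k in grads}
        return updates, new_state

    return Transformation(init, update)


def apply_updates(params: Params, updates: Params) -> Params:
    # The transformation returns additive updates; adding them performs the step.
    return {k: params[k] + updates[k] for k in params}


# Usage: minimize f(w) = w**2, whose gradient is 2 * w.
opt = momentum_sgd(lr=0.1, beta=0.9)
params = {"w": 1.0}
state = opt.init(params)
for _ in range(3):
    grads = {"w": 2.0 * params["w"]}
    updates, state = opt.update(grads, state)
    params = apply_updates(params, updates)
print(params)
```

Because `init` and `update` return fresh state instead of mutating their inputs, the same transformation object can be reused across runs without hidden side effects.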


PyTrees

PyTrees are an essential concept in TorchOpt. They can be thought of as a generalization of vectors: a pytree structures parameters or weights as nested containers such as tuples and dictionaries, with the individual values stored at the leaves. Many solvers in TorchOpt support pytrees natively.
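The sketch below illustrates what pytree support buys a solver. A single function maps over every leaf of a nested structure, so a gradient step or an inner product can be written once for any parameter layout. `tree_map` and `tree_leaves` are hypothetical, hand-rolled helpers, not TorchOpt's utilities. Real pytree libraries support more container types and handle details such as dict-key ordering. Plain floats stand in for tensors.

```python
from typing import Any, Callable, List


def tree_map(fn: Callable[..., Any], tree: Any, *rest: Any) -> Any:
    """Apply fn leaf by leaf across pytrees that share the same structure.

    Minimal sketch: only dicts, lists and tuples are treated as containers;
    every other value is a leaf.
    """
    if isinstance(tree, dict):
        # Assumes each tree in `rest` has the same keys as `tree`.
        return {k: tree_map(fn, tree[k], *(r[k] for r in rest)) for k in tree}
    if isinstance(tree, list):
        return [tree_map(fn, *children) for children in zip(tree, *rest)]
    if isinstance(tree, tuple):
        return tuple(tree_map(fn, *children) for children in zip(tree, *rest))
    return fn(tree, *rest)


def tree_leaves(tree: Any) -> List[Any]:
    """Flatten a pytree into its leaves, in container iteration order."""
    if isinstance(tree, dict):
        return [leaf for k in tree for leaf in tree_leaves(tree[k])]
    if isinstance(tree, (list, tuple)):
        return [leaf for child in tree for leaf in tree_leaves(child)]
    return [tree]


# Parameters and gradients share one structure: a dict holding a dict and a tuple.
params = {"layer1": {"w": 1.0, "b": 0.5}, "layer2": (2.0, -1.0)}
grads = {"layer1": {"w": 0.2, "b": -0.1}, "layer2": (1.0, 0.0)}

# A gradient step treats the whole pytree like a single vector.
lr = 0.1
new_params = tree_map(lambda p, g: p - lr * g, params, grads)

# Vector-style reductions work on the flattened leaves, e.g. an inner product.
dot = sum(p * g for p, g in zip(tree_leaves(params), tree_leaves(grads)))
print(new_params, dot)
```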

Floating-Point Precision

TorchOpt uses single-precision (32-bit) floating point (torch.float32) by default. For some algorithms, however, this may not be precise enough. Double precision (64-bit, torch.float64) can be enabled by adding the following lines at the beginning of the file:

import torch