Optimizers

The core design of TorchOpt follows the philosophy of functional programming. In line with functorch, users can program models, optimizers, and training loops in a functional style in PyTorch. We first introduce our functional optimizers, which treat the optimization process as a functional transformation.

Functional Optimizers

Currently, TorchOpt supports the following functional optimizers: adadelta(), adagrad(), adam(), adamw(), adamax(), radam(), rmsprop(), and sgd().

torchopt.FuncOptimizer(impl, *[, inplace])

A wrapper class to hold the functional optimizer.

torchopt.adadelta([lr, rho, eps, ...])

Create a functional version of the AdaDelta optimizer.

torchopt.adagrad([lr, lr_decay, ...])

Create a functional version of the AdaGrad optimizer.

torchopt.adam([lr, betas, eps, ...])

Create a functional version of the Adam optimizer.

torchopt.adamw([lr, betas, eps, ...])

Create a functional version of the Adam optimizer with weight decay regularization.

torchopt.adamax([lr, betas, eps, ...])

Create a functional version of the AdaMax optimizer.

torchopt.radam([lr, betas, eps, ...])

Create a functional version of the RAdam optimizer.

torchopt.rmsprop([lr, alpha, eps, ...])

Create a functional version of the RMSProp optimizer.

torchopt.sgd(lr[, momentum, dampening, ...])

Create a functional version of the canonical Stochastic Gradient Descent optimizer.

Apply Parameter Updates

TorchOpt offers a functional API in which gradients and the optimizer state are passed to the optimizer function to compute updates, which are then applied to the parameters.

torchopt.apply_updates(params, updates, *[, ...])

Apply an update to the corresponding parameters.

Here is an example of functional optimization coupled with functorch:

import functorch
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader

import torchopt

class Net(nn.Module): ...

class Loader(DataLoader): ...

net = Net()  # init
loader = Loader()
optimizer = torchopt.adam(lr=1e-3)                       # create functional Adam optimizer

model, params = functorch.make_functional(net)           # use functorch to extract network parameters
opt_state = optimizer.init(params)                       # init optimizer

xs, ys = next(loader)                                    # get data
pred = model(params, xs)                                 # forward
loss = F.cross_entropy(pred, ys)                         # compute loss

grads = torch.autograd.grad(loss, params)                # compute gradients
updates, opt_state = optimizer.update(grads, opt_state)  # get updates
params = torchopt.apply_updates(params, updates)         # update network parameters
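
The summary above shows that torchopt.apply_updates() also takes keyword-only options. Assuming it accepts the same inplace flag listed for torchopt.FuncOptimizer, the update can be requested out of place so the original parameter tensors are left untouched (a hedged sketch, not confirmed by the summary):

params = torchopt.apply_updates(params, updates, inplace=False)  # assumed flag: return new tensors instead of mutating in place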

We also provide a wrapper torchopt.FuncOptimizer to make maintaining the optimizer state easier:

net = Net()  # init
loader = Loader()
optimizer = torchopt.FuncOptimizer(torchopt.adam())      # wrap with `torchopt.FuncOptimizer`

model, params = functorch.make_functional(net)           # use functorch to extract network parameters

for xs, ys in loader:                                    # get data
    pred = model(params, xs)                             # forward
    loss = F.cross_entropy(pred, ys)                     # compute loss

    params = optimizer.step(loss, params)                # update network parameters

Classic OOP Optimizers

Built on top of the functional optimizers above, we can define classic OOP optimizers. The base class torchopt.Optimizer has the same interface as torch.optim.Optimizer, so the familiar PyTorch APIs (e.g., zero_grad() and step()) are available for traditional PyTorch-style (OOP) parameter updates.

torchopt.Optimizer(params, impl)

A base class for classic optimizers that is similar to torch.optim.Optimizer.

torchopt.AdaDelta(params[, lr, rho, eps, ...])

The classic AdaDelta optimizer.

torchopt.Adadelta

alias of AdaDelta

torchopt.AdaGrad(params[, lr, lr_decay, ...])

The classic AdaGrad optimizer.

torchopt.Adagrad

alias of AdaGrad

torchopt.Adam(params[, lr, betas, eps, ...])

The classic Adam optimizer.

torchopt.AdamW(params[, lr, betas, eps, ...])

The classic AdamW optimizer.

torchopt.AdaMax(params[, lr, betas, eps, ...])

The classic AdaMax optimizer.

torchopt.Adamax

alias of AdaMax

torchopt.RAdam(params[, lr, betas, eps, ...])

The classic RAdam optimizer.

torchopt.RMSProp(params[, lr, alpha, eps, ...])

The classic RMSProp optimizer.

torchopt.SGD(params, lr[, momentum, ...])

The classic SGD optimizer.

By combining the low-level torchopt.Optimizer API with the functional optimizers above, we can reproduce the high-level API:

learning_rate = 1.0
# High-level API
optim = torchopt.Adam(net.parameters(), lr=learning_rate)
# which can be achieved by low-level API:
optim = torchopt.Optimizer(net.parameters(), torchopt.adam(lr=learning_rate))

Here is an example of PyTorch-like APIs:

net = Net()  # init
loader = Loader()
optimizer = torchopt.Adam(net.parameters())

xs, ys = next(loader)             # get data
pred = net(xs)                    # forward
loss = F.cross_entropy(pred, ys)  # compute loss

optimizer.zero_grad()             # zero gradients
loss.backward()                   # backward
optimizer.step()                  # step updates

Combining Transformations

Users often need to apply multiple gradient transformations (functions) before the final update. In the design of TorchOpt, these transformations are chainable via torchopt.chain(). For example, we can build a chain like torchopt.chain(torchopt.clip_grad_norm(max_norm=1.0), torchopt.sgd(lr=1.0, moment_requires_grad=True)) to first clip the gradients and then update the parameters using sgd().

torchopt.chain(*transformations)

Apply a list of chainable update transformations.

Note

torchopt.chain() applies the transformations sequentially, so the order matters. For example, to clip the gradients before the optimizer step, the arguments must be passed to torchopt.chain() in the order (clip, sgd).

Here is an example of chaining torchopt.clip_grad_norm() and torchopt.adam() for both the functional and the OOP optimizers:

func_optimizer = torchopt.chain(torchopt.clip_grad_norm(max_norm=2.0), torchopt.adam(1e-1))
oop_optimizer = torchopt.Optimizer(net.parameters(), func_optimizer)
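
Since the chained result is itself a functional optimizer, it can also drive a purely functional training step. The following is a minimal sketch, reusing model, params, xs, and ys from the functorch example above:

opt_state = func_optimizer.init(params)                       # init chained optimizer state

pred = model(params, xs)                                      # forward
loss = F.cross_entropy(pred, ys)                              # compute loss

grads = torch.autograd.grad(loss, params)                     # compute gradients
updates, opt_state = func_optimizer.update(grads, opt_state)  # clip gradients, then compute Adam updates
params = torchopt.apply_updates(params, updates)              # update network parameters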

Optimizer Hooks

Users can also register optimizer hooks to control the gradient flow.

torchopt.hook.register_hook(hook)

Stateless identity transformation that leaves input gradients untouched.

torchopt.hook.zero_nan_hook(g)

Replace nan with zero.

torchopt.hook.nan_to_num_hook([nan, posinf, ...])

Return a nan to num hook to replace nan / +inf / -inf with the given numbers.

For example, torchopt.hook.zero_nan_hook() is registered as a hook on the first-order gradients. During backpropagation, any NaN gradients are set to 0. Here is an example of such an operation coupled with torchopt.chain():

impl = torchopt.chain(torchopt.hook.register_hook(torchopt.hook.zero_nan_hook), torchopt.adam(1e-1))
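
The resulting impl is a gradient transformation like any other, so it can back an OOP optimizer as well. A minimal sketch, assuming net is the network defined earlier and that the keyword arguments of torchopt.hook.nan_to_num_hook() mirror torch.nan_to_num():

optimizer = torchopt.Optimizer(net.parameters(), impl)  # OOP optimizer backed by the hooked chain

# Alternatively, replace NaN / +inf / -inf with finite numbers instead of zeroing NaNs only
# (the argument values below are illustrative, not prescribed by the API summary).
impl = torchopt.chain(
    torchopt.hook.register_hook(torchopt.hook.nan_to_num_hook(nan=0.0, posinf=1e4, neginf=-1e4)),
    torchopt.adam(1e-1),
)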

Optimizer Schedules

TorchOpt also provides implementations of learning rate schedulers, which can be used to control the learning rate during the training process. TorchOpt mainly offers the linear learning rate scheduler and the polynomial learning rate scheduler.

torchopt.schedule.linear_schedule(...[, ...])

Alias polynomial schedule to linear schedule for convenience.

torchopt.schedule.polynomial_schedule(...[, ...])

Construct a schedule with polynomial transition from init to end value.

Here is an example of combining an optimizer with a learning rate scheduler:

functional_adam = torchopt.adam(
    lr=torchopt.schedule.linear_schedule(
        init_value=1e-3, end_value=1e-4, transition_steps=10000, transition_begin=2000
    )
)

adam = torchopt.Adam(
    net.parameters(),
    lr=torchopt.schedule.linear_schedule(
        init_value=1e-3, end_value=1e-4, transition_steps=10000, transition_begin=2000
    ),
)
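
A polynomial transition can be configured analogously. The sketch below assumes polynomial_schedule() accepts the same arguments as the linear variant plus a power exponent; the power=2.0 value is only illustrative:

functional_sgd = torchopt.sgd(
    lr=torchopt.schedule.polynomial_schedule(
        init_value=1e-2, end_value=1e-4, power=2.0, transition_steps=10000, transition_begin=2000
    )
)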

Notebook Tutorial

Check the notebook tutorial at Functional Optimizer.