Optimizers
The core design of TorchOpt follows the philosophy of functional programming.
Aligned with functorch, users can conduct functional-style programming with models, optimizers, and training in PyTorch.
We first introduce our functional optimizers, which treat the optimization process as a functional transformation.
Functional Optimizers
Currently, TorchOpt supports several functional optimizers, including sgd(), adam(), rmsprop(), and adamw(); the full list follows.
| | A wrapper class to hold the functional optimizer. |
| | Create a functional version of the AdaDelta optimizer. |
| | Create a functional version of the AdaGrad optimizer. |
| | Create a functional version of the Adam optimizer. |
| | Create a functional version of the Adam optimizer with weight decay regularization. |
| | Create a functional version of the AdaMax optimizer. |
| | Create a functional version of the RAdam optimizer. |
| | Create a functional version of the RMSProp optimizer. |
| | Create a functional version of the canonical Stochastic Gradient Descent optimizer. |
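To make the "functional transformation" idea concrete, here is a minimal pure-Python sketch of the init/update pattern that these optimizers follow (compare optimizer.init(params) and optimizer.update(grads, opt_state) in the example code of this section). It is an illustration under stated assumptions, not TorchOpt's implementation: adam_sketch and AdamState are hypothetical names, parameters are plain float tuples instead of tensors, and options such as weight decay are omitted.

```python
import math
from typing import NamedTuple, Tuple


class AdamState(NamedTuple):
    """Immutable optimizer state: step count and moment estimates (hypothetical)."""
    count: int
    mu: Tuple[float, ...]  # first-moment (mean) estimates
    nu: Tuple[float, ...]  # second-moment (uncentered variance) estimates


def adam_sketch(lr=1e-3, b1=0.9, b2=0.999, eps=1e-8):
    """Return a pair of pure functions (init, update); nothing is mutated."""

    def init(params):
        zeros = tuple(0.0 for _ in params)
        return AdamState(count=0, mu=zeros, nu=zeros)

    def update(grads, state):
        count = state.count + 1
        mu = tuple(b1 * m + (1 - b1) * g for m, g in zip(state.mu, grads))
        nu = tuple(b2 * v + (1 - b2) * g * g for v, g in zip(state.nu, grads))
        # Bias correction compensates for initializing mu and nu with zeros.
        mu_hat = [m / (1 - b1**count) for m in mu]
        nu_hat = [v / (1 - b2**count) for v in nu]
        updates = tuple(-lr * m / (math.sqrt(v) + eps) for m, v in zip(mu_hat, nu_hat))
        # A new state object is returned; the old one stays valid.
        return updates, AdamState(count, mu, nu)

    return init, update


init, update = adam_sketch(lr=0.1)
params = (1.0, -2.0)
state = init(params)
updates, new_state = update((0.5, -0.5), state)
```

Because update returns a new state instead of mutating the old one, the caller decides where the state lives. On the very first step the bias-corrected update is approximately -lr * sign(grad), so here it is close to (-0.1, 0.1).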
Apply Parameter Updates
TorchOpt offers a functional API: the gradients and the optimizer state are passed to the optimizer's update function, and the resulting updates are then applied to the parameters.
| | Apply an update to the corresponding parameters. |
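Conceptually, applying updates adds each update to the parameter in the same position of a (possibly nested) container. The sketch below shows that idea in plain Python; apply_updates_sketch is a hypothetical helper that works on floats inside dicts, lists, and plain tuples and returns a new structure, whereas torchopt.apply_updates works on tensors.

```python
from typing import Any


def apply_updates_sketch(params: Any, updates: Any) -> Any:
    """Return params + updates, preserving the nesting of params.

    Supports floats inside (possibly nested) dicts, lists, and plain tuples.
    The inputs are left unchanged: this is an out-of-place variant.
    """
    if isinstance(params, dict):
        return {k: apply_updates_sketch(params[k], updates[k]) for k in params}
    if isinstance(params, (list, tuple)):
        # Rebuild a container of the same type from the element-wise results.
        return type(params)(apply_updates_sketch(p, u) for p, u in zip(params, updates))
    return params + updates


params = {"w": [1.0, 2.0], "b": (0.5,)}
updates = {"w": [-0.5, 0.25], "b": (0.5,)}
new_params = apply_updates_sketch(params, updates)
```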
Here is an example of functional optimization coupled with functorch:
import functorch
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchopt
from torch.utils.data import DataLoader

class Net(nn.Module): ...
class Loader(DataLoader): ...
net = Net()  # init
loader = Loader()
optimizer = torchopt.adam(lr=1e-3)  # 1e-3 is an example learning rate
model, params = functorch.make_functional(net)  # use functorch to extract network parameters
opt_state = optimizer.init(params)  # init optimizer state
xs, ys = next(iter(loader))  # get data (a DataLoader is iterable, not an iterator)
pred = model(params, xs)  # forward
loss = F.cross_entropy(pred, ys)  # compute loss
grads = torch.autograd.grad(loss, params)  # compute gradients
updates, opt_state = optimizer.update(grads, opt_state)  # get updates
params = torchopt.apply_updates(params, updates)  # update network parameters
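The example above performs a single step. Over many steps, the functional style means the training loop itself carries the optimizer state: each update call returns a new state that must be passed to the next call. The following plain-Python sketch shows that threading on a toy loss sum((p - 3)^2), using a hypothetical momentum-SGD pair (sgd_momentum_sketch) that stands in for the functorch/TorchOpt calls; it is not their implementation.

```python
from typing import NamedTuple, Tuple


class MomentumState(NamedTuple):
    velocity: Tuple[float, ...]


def sgd_momentum_sketch(lr: float, momentum: float = 0.9):
    """Hypothetical functional SGD with momentum as an (init, update) pair."""

    def init(params):
        return MomentumState(tuple(0.0 for _ in params))

    def update(grads, state):
        velocity = tuple(momentum * v + g for v, g in zip(state.velocity, grads))
        updates = tuple(-lr * v for v in velocity)
        return updates, MomentumState(velocity)

    return init, update


def grad_fn(params):
    # Gradient of the toy loss sum((p - 3)^2) with respect to each parameter.
    return tuple(2.0 * (p - 3.0) for p in params)


init, update = sgd_momentum_sketch(lr=0.05)
params = (0.0, 10.0)
opt_state = init(params)
history = [params]
for _ in range(300):
    grads = grad_fn(params)
    updates, opt_state = update(grads, opt_state)  # thread the state explicitly
    params = tuple(p + u for p, u in zip(params, updates))  # new tuple, no mutation
    history.append(params)
```

Since parameters are rebuilt rather than mutated, history still holds every intermediate parameter tuple, and both parameters end up close to the minimizer 3.0.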
We also provide a wrapper, torchopt.FuncOptimizer, to make maintaining the optimizer state easier:
net = Net() # init
loader = Loader()
optimizer = torchopt.FuncOptimizer(torchopt.adam()) # wrap with `torchopt.FuncOptimizer`
model, params = functorch.make_functional(net)  # use functorch to extract network parameters
for xs, ys in loader: # get data
pred = model(params, xs) # forward
loss = F.cross_entropy(pred, ys) # compute loss
params = optimizer.step(loss, params) # update network parameters
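A wrapper like torchopt.FuncOptimizer removes the need to pass opt_state around by storing it on an object. The sketch below illustrates that design with a hypothetical StatefulOptimizerSketch class. Unlike torchopt.FuncOptimizer, whose step() takes the loss, this sketch takes precomputed gradients (plain floats have no autograd), and creating the state lazily on the first step is an assumption of the sketch.

```python
from typing import Callable, Tuple


class StatefulOptimizerSketch:
    """Hypothetical wrapper that keeps the functional optimizer state internally."""

    def __init__(self, init: Callable, update: Callable):
        self._init = init
        self._update = update
        self.state = None  # created on the first call to step()

    def step(self, grads: Tuple[float, ...], params: Tuple[float, ...]) -> Tuple[float, ...]:
        if self.state is None:
            self.state = self._init(params)
        updates, self.state = self._update(grads, self.state)
        return tuple(p + u for p, u in zip(params, updates))


def plain_sgd(lr: float):
    """Stateless SGD as an (init, update) pair; the state is an empty tuple."""

    def init(params):
        return ()

    def update(grads, state):
        return tuple(-lr * g for g in grads), state

    return init, update


optimizer = StatefulOptimizerSketch(*plain_sgd(lr=0.25))
params = (4.0,)
for _ in range(3):
    grads = tuple(2.0 * p for p in params)  # gradient of the toy loss p**2
    params = optimizer.step(grads, params)  # caller never touches the state
```

Each step halves the parameter here (4.0, 2.0, 1.0, 0.5), and the loop body no longer mentions the optimizer state at all.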
Classic OOP Optimizers
Combined with the functional optimizers above, we can define our classic OOP optimizers.
We designed a base class torchopt.Optimizer that has the same interface as torch.optim.Optimizer.
We offer the original PyTorch APIs (e.g., zero_grad() or step()) for traditional PyTorch-like (OOP) parameter updates.
| | A base class for classic optimizers that is similar to |
| | The classic AdaDelta optimizer. |
| | alias of |
| | The classic AdaGrad optimizer. |
| | alias of |
| | The classic Adam optimizer. |
| | The classic AdamW optimizer. |
| | The classic AdaMax optimizer. |
| | alias of |
| | The classic RAdam optimizer. |
| | The classic RMSProp optimizer. |
| | The classic SGD optimizer. |
By combining the low-level API torchopt.Optimizer with the functional optimizers above, we can build the high-level API:
learning_rate = 1.0
# High-level API
optim = torchopt.Adam(net.parameters(), lr=learning_rate)
# which can be achieved by low-level API:
optim = torchopt.Optimizer(net.parameters(), torchopt.adam(lr=learning_rate))
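The relationship between the high-level and low-level lines above can be sketched in plain Python: an OOP base class stores a functional (init, update) pair plus its state and mutates the parameters in step(), and a high-level class only fixes which functional optimizer it uses. Param, OptimizerSketch, and SGDSketch are hypothetical names; torchopt.Optimizer itself works on tensors and their .grad fields.

```python
from typing import List, Optional


class Param:
    """Hypothetical mutable parameter: a value plus an accumulated gradient."""

    def __init__(self, value: float):
        self.value = value
        self.grad: Optional[float] = None


def sgd_pair(lr: float):
    """Stateless functional SGD as an (init, update) pair."""

    def init(values):
        return ()

    def update(grads, state):
        return [-lr * g for g in grads], state

    return init, update


class OptimizerSketch:
    """OOP facade: holds the functional state and mutates parameters in step()."""

    def __init__(self, params: List[Param], impl):
        self.params = list(params)
        self._init, self._update = impl
        self.state = self._init([p.value for p in self.params])

    def zero_grad(self) -> None:
        for p in self.params:
            p.grad = None

    def step(self) -> None:
        # Missing gradients are treated as zero in this simplified sketch.
        grads = [0.0 if p.grad is None else p.grad for p in self.params]
        updates, self.state = self._update(grads, self.state)
        for p, u in zip(self.params, updates):
            p.value += u  # in-place update, as in PyTorch-style optimizers


class SGDSketch(OptimizerSketch):
    """High-level class: the base class plus a fixed functional optimizer."""

    def __init__(self, params: List[Param], lr: float):
        super().__init__(params, sgd_pair(lr))


w = Param(1.0)
opt = SGDSketch([w], lr=0.1)
w.grad = 2.0  # pretend that backward() produced this gradient
opt.step()
opt.zero_grad()
```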
Here is an example of PyTorch-like APIs:
net = Net() # init
loader = Loader()
optimizer = torchopt.Adam(net.parameters())
xs, ys = next(iter(loader))  # get data (a DataLoader is iterable, not an iterator)
pred = net(xs) # forward
loss = F.cross_entropy(pred, ys) # compute loss
optimizer.zero_grad() # zero gradients
loss.backward() # backward
optimizer.step() # step updates
Combining Transformations
Users often need to apply multiple gradient transformations (functions) before the final update.
In TorchOpt's design, these transformations are composed with torchopt.chain().
For example, torchopt.chain(torchopt.clip_grad_norm(max_norm=1.), torchopt.sgd(lr=1., moment_requires_grad=True)) builds a chain that clips the gradient and then updates the parameters using sgd().
| | Apply a list of chainable update transformations. |
Note
torchopt.chain() applies the transformations sequentially, so the order matters.
For example, gradient norm clipping must be applied before the optimizer step, so the order in torchopt.chain() should be (clip, sgd).
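A chain can be sketched as a transformation whose state holds one entry per stage and whose update feeds each stage's output into the next stage. The plain-Python sketch below (hypothetical names Transform, clip_by_norm_sketch, sgd_sketch, and chain_sketch; not TorchOpt's code) builds both orders from the note. With (clip, sgd), the raw gradient of norm 5 is clipped to norm 1 and then scaled to a step of norm 0.1; with (sgd, clip), the clip only sees the already scaled step of norm 0.5, which is below the threshold, so nothing is clipped.

```python
import math
from typing import Callable, NamedTuple


class Transform(NamedTuple):
    """A gradient transformation as a pair of pure functions (hypothetical)."""
    init: Callable
    update: Callable


def clip_by_norm_sketch(max_norm: float) -> Transform:
    def update(grads, state):
        norm = math.sqrt(sum(g * g for g in grads))
        scale = min(1.0, max_norm / norm) if norm > 0 else 1.0
        return tuple(g * scale for g in grads), state

    return Transform(lambda params: (), update)


def sgd_sketch(lr: float) -> Transform:
    def update(grads, state):
        return tuple(-lr * g for g in grads), state

    return Transform(lambda params: (), update)


def chain_sketch(*transforms: Transform) -> Transform:
    def init(params):
        # One state entry per inner transformation.
        return tuple(t.init(params) for t in transforms)

    def update(grads, state):
        new_state = []
        for t, s in zip(transforms, state):
            grads, s = t.update(grads, s)  # each stage consumes the previous output
            new_state.append(s)
        return grads, tuple(new_state)

    return Transform(init, update)


params = (0.0, 0.0)
grads = (3.0, 4.0)  # gradient norm is 5
clip_then_sgd = chain_sketch(clip_by_norm_sketch(1.0), sgd_sketch(0.1))
sgd_then_clip = chain_sketch(sgd_sketch(0.1), clip_by_norm_sketch(1.0))
u1, _ = clip_then_sgd.update(grads, clip_then_sgd.init(params))
u2, _ = sgd_then_clip.update(grads, sgd_then_clip.init(params))
```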
Here is an example of chaining torchopt.clip_grad_norm() and torchopt.adam() for both a functional optimizer and an OOP optimizer:
func_optimizer = torchopt.chain(torchopt.clip_grad_norm(max_norm=2.0), torchopt.adam(1e-1))
oop_optimizer = torchopt.Optimizer(net.parameters(), func_optimizer)
Optimizer Hooks
Users can also add optimizer hooks to control the gradient flow.
| | Stateless identity transformation that leaves input gradients untouched. |
| | Replace |
| | Return a |
For example, torchopt.hook.zero_nan_hook() is a hook on the first-order gradients: during backpropagation, NaN gradients are set to 0.
Here is an example of registering this hook with torchopt.hook.register_hook() inside torchopt.chain():
impl = torchopt.chain(torchopt.hook.register_hook(torchopt.hook.zero_nan_hook), torchopt.adam(1e-1))
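The effect of such a hook can be sketched without autograd: a stateless function that replaces NaN entries with 0 runs before the optimizer's update, so a single bad gradient entry does not turn a parameter into NaN. zero_nan_sketch and sgd_update are hypothetical helpers; the real hook acts on tensor gradients during backpropagation.

```python
import math


def zero_nan_sketch(grads):
    """Stateless: replace NaN entries with 0.0 and pass everything else through."""
    return [0.0 if math.isnan(g) else g for g in grads]


def sgd_update(grads, lr):
    # Plain SGD step: update = -lr * grad.
    return [-lr * g for g in grads]


grads = [1.0, float("nan"), -2.0]
unguarded = sgd_update(grads, lr=0.1)  # the NaN propagates into the update
guarded = sgd_update(zero_nan_sketch(grads), lr=0.1)  # NaN replaced before the step
```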
Optimizer Schedules
TorchOpt also provides implementations of learning rate schedulers, which can be used to control the learning rate during the training process. TorchOpt mainly offers the linear learning rate scheduler and the polynomial learning rate scheduler.
| | Alias polynomial schedule to linear schedule for convenience. |
| | Construct a schedule with polynomial transition from init to end value. |
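One common definition of a polynomial schedule, assumed here, keeps init_value for the first transition_begin steps, then moves to end_value over transition_steps steps along (1 - t)**power, and stays at end_value afterwards; a linear schedule is the case power = 1. The plain-Python sketch below implements that definition with hypothetical names (polynomial_schedule_sketch, linear_schedule_sketch). The power argument and the edge-case behavior are assumptions, so check TorchOpt's API reference for its exact semantics.

```python
def polynomial_schedule_sketch(init_value, end_value, power, transition_steps, transition_begin=0):
    """Return a function that maps a step count to a scheduled value."""

    def schedule(count):
        # Fraction of the transition completed, clamped to [0, 1].
        t = min(max(count - transition_begin, 0), transition_steps) / transition_steps
        return (init_value - end_value) * (1 - t) ** power + end_value

    return schedule


def linear_schedule_sketch(init_value, end_value, transition_steps, transition_begin=0):
    # A linear schedule is the polynomial schedule with power = 1.
    return polynomial_schedule_sketch(init_value, end_value, 1, transition_steps, transition_begin)


lr = linear_schedule_sketch(init_value=1e-3, end_value=1e-4, transition_steps=10000, transition_begin=2000)
```

With the values used in the example that follows, this sketch gives a learning rate of 1e-3 up to step 2000, 5.5e-4 at step 7000 (halfway through the transition), and 1e-4 from step 12000 on.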
Here is an example of combining an optimizer with a learning rate scheduler:
functional_adam = torchopt.adam(
lr=torchopt.schedule.linear_schedule(
init_value=1e-3, end_value=1e-4, transition_steps=10000, transition_begin=2000
)
)
adam = torchopt.Adam(
net.parameters(),
lr=torchopt.schedule.linear_schedule(
init_value=1e-3, end_value=1e-4, transition_steps=10000, transition_begin=2000
),
)
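Passing a schedule as lr works because the optimizer can evaluate it at the current step count, which it keeps in its state. The sketch below shows that mechanism with a hypothetical scheduled_sgd_sketch and a hypothetical halving schedule; it assumes the schedule is evaluated at the count before incrementing (starting from 0), which may differ from TorchOpt's exact convention.

```python
from typing import Callable, NamedTuple, Union


class ScheduledSGDState(NamedTuple):
    count: int  # number of updates applied so far


def scheduled_sgd_sketch(lr: Union[float, Callable[[int], float]]):
    """Hypothetical SGD whose learning rate may be a constant or a schedule."""
    schedule = lr if callable(lr) else (lambda count: lr)

    def init(params):
        return ScheduledSGDState(count=0)

    def update(grads, state):
        step_size = schedule(state.count)  # learning rate for this step
        updates = tuple(-step_size * g for g in grads)
        return updates, ScheduledSGDState(count=state.count + 1)

    return init, update


def halving(count: int) -> float:
    # Hypothetical schedule: 1.0, 0.5, 0.25, ...
    return 0.5**count


init, update = scheduled_sgd_sketch(halving)
state = init((0.0,))
steps = []
for _ in range(3):
    updates, state = update((1.0,), state)
    steps.append(updates[0])
```

Keeping the counter in the optimizer state means the same schedule object can be shared by the functional and the OOP optimizer without either of them tracking the step elsewhere.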
Notebook Tutorial
Check the notebook tutorial at Functional Optimizer.