Explicit Gradient Differentiation

Explicit Gradient


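The idea of explicit gradient differentiation is to treat each gradient step as a differentiable function and backpropagate through the unrolled optimization path. Namely, given the \(K\) inner gradient-descent iterates \(\boldsymbol{\theta}_{i+1} = \boldsymbol{\theta}_i - \alpha \nabla_{\boldsymbol{\theta}_i} \mathcal{L}^{\text{in}} (\boldsymbol{\phi},\boldsymbol{\theta}_i)\) with step size \(\alpha\), starting from \(\boldsymbol{\theta}_0\), the final inner parameters are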

\[\boldsymbol{\theta}^{\prime} (\boldsymbol{\phi}) \triangleq \boldsymbol{\theta}_0 - \alpha \sum_{i=0}^{K-1} \nabla_{\boldsymbol{\theta}_i} \mathcal{L}^{\text{in}} (\boldsymbol{\phi},\boldsymbol{\theta}_i),\]

we would like to compute the gradient \(\nabla_{\boldsymbol{\phi}} \boldsymbol{\theta}^{\prime} (\boldsymbol{\phi})\). This is usually done with automatic differentiation (AutoDiff) through the unrolled iterates of the inner optimization.
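As a minimal sketch in plain PyTorch, assume a toy quadratic inner loss \(\mathcal{L}^{\text{in}}(\phi, \theta) = \tfrac{1}{2}(\theta - \phi)^2\), chosen only for illustration. The unrolled iterates can then be differentiated directly. For this loss each step is \(\theta_{i+1} = (1-\alpha)\theta_i + \alpha\phi\), so \(\partial\theta^{\prime}/\partial\phi = 1 - (1-\alpha)^K\), and the code reproduces that value. The helper names `inner_loss` and `unrolled_theta` are hypothetical.

```python
import torch


def inner_loss(phi: torch.Tensor, theta: torch.Tensor) -> torch.Tensor:
    # Hypothetical toy inner objective: pulls theta toward phi
    return 0.5 * (theta - phi) ** 2


def unrolled_theta(phi: torch.Tensor, theta0: torch.Tensor,
                   alpha: float = 0.1, K: int = 5) -> torch.Tensor:
    theta = theta0
    for _ in range(K):
        # create_graph=True records the gradient computation itself,
        # so the result stays differentiable with respect to phi
        (g,) = torch.autograd.grad(inner_loss(phi, theta), theta, create_graph=True)
        # Out-of-place update keeps theta on the autograd graph
        theta = theta - alpha * g
    return theta


phi = torch.tensor(2.0, requires_grad=True)
theta0 = torch.tensor(0.0, requires_grad=True)
theta_K = unrolled_theta(phi, theta0)
(dtheta_dphi,) = torch.autograd.grad(theta_K, phi)
print(dtheta_dphi.item())  # close to 1 - 0.9**5 = 0.40951
```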

Differentiable Functional Optimizers

Passing inplace=False to the update functions makes the optimization differentiable. Here is an example of making torchopt.adam() differentiable.

import torch
import torchopt
from functorch import make_functional

opt = torchopt.adam()
# Define meta and inner parameters
meta_params = ...
fmodel, params = make_functional(model)
# Initialize optimizer state
state = opt.init(params)

for _ in range(iter_times):
    loss = inner_loss(fmodel, params, meta_params)
    # create_graph=True keeps the inner gradients differentiable w.r.t. meta_params
    grads = torch.autograd.grad(loss, params, create_graph=True)
    # Apply non-inplace parameter update
    updates, state = opt.update(grads, state, inplace=False)
    # Non-inplace application keeps the new params on the autograd graph
    params = torchopt.apply_updates(params, updates, inplace=False)

loss = outer_loss(fmodel, params, meta_params)
meta_grads = torch.autograd.grad(loss, meta_params)

Differentiable OOP Meta-Optimizers

For a PyTorch-like API (e.g., step()), TorchOpt provides the base class torchopt.MetaOptimizer, which wraps the functional optimizers into differentiable object-oriented (OOP) meta-optimizers.

torchopt.MetaOptimizer(module, impl)

The base class for high-level differentiable optimizers.

torchopt.MetaAdaDelta(module[, lr, rho, ...])

The differentiable AdaDelta optimizer. An alias of MetaAdaDelta is also provided.

torchopt.MetaAdaGrad(module[, lr, lr_decay, ...])

The differentiable AdaGrad optimizer. An alias of MetaAdaGrad is also provided.

torchopt.MetaAdam(module[, lr, betas, eps, ...])

The differentiable Adam optimizer.

torchopt.MetaAdamW(module[, lr, betas, eps, ...])

The differentiable AdamW optimizer.

torchopt.MetaAdaMax(module[, lr, betas, ...])

The differentiable AdaMax optimizer. An alias of MetaAdaMax is also provided.

torchopt.MetaRAdam(module[, lr, betas, eps, ...])

The differentiable RAdam optimizer.

torchopt.MetaRMSProp(module[, lr, alpha, ...])

The differentiable RMSProp optimizer.

torchopt.MetaSGD(module, lr[, momentum, ...])

The differentiable Stochastic Gradient Descent optimizer.

Combining the low-level torchopt.MetaOptimizer with a functional optimizer gives the same optimizer as the corresponding high-level class:

# Low-level API
optim = torchopt.MetaOptimizer(net, torchopt.sgd(lr=1.0))

# High-level API
optim = torchopt.MetaSGD(net, lr=1.0)

Here is an example of using the OOP API torchopt.MetaAdam to compute meta-gradients.

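import torchopt

# Define meta and inner parameters
meta_params = ...
model = ...
# Define differentiable optimizer (it takes the module itself, not model.parameters())
opt = torchopt.MetaAdam(model)

for _ in range(iter_times):
    # Perform the inner update
    loss = inner_loss(model, meta_params)
    opt.step(loss)

loss = outer_loss(model, meta_params)
# Backpropagate through the unrolled inner updates to obtain the meta-gradients
loss.backward()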

CPU/GPU Accelerated Optimizer

TorchOpt performs symbolic reduction by manually writing the forward and backward functions in C++ with OpenMP (CPU) and in CUDA (GPU), which substantially improves the efficiency of meta-gradient computation. To use the accelerated optimizer, set use_accelerated_op=True. TorchOpt automatically detects the device and dispatches to the corresponding accelerated implementation.

# Check whether the `accelerated_op` is available:


net = Net(1).cuda()
optim = torchopt.Adam(net.parameters(), lr=1.0, use_accelerated_op=True)

General Utilities

TorchOpt provides the torchopt.extract_state_dict() and torchopt.recover_state_dict() functions to extract and restore the states of networks and optimizers. By default, the extracted state dictionary holds references to the original tensors. This design supports gradient accumulation in multi-task batch training, such as MAML. Set by='copy' to extract a copy of the state dictionary, or by='deepcopy' to obtain a detached copy.
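The difference between the three modes can be pictured with plain tensor operations. This is a sketch of the semantics described above, not TorchOpt's implementation. It assumes that a copy behaves like Tensor.clone() (new storage, still attached to the autograd graph) and that a detached copy behaves like Tensor.detach().clone().

```python
import torch

w = torch.tensor(1.0, requires_grad=True)
y = w * 2  # a non-leaf tensor, like a parameter produced by a differentiable update

ref = y                        # reference: the very same tensor object
copied = y.clone()             # copy: new storage, still connected to w in the graph
detached = y.detach().clone()  # detached copy: new storage, no graph history

# Gradients still flow from the copy back to w (d(2w)/dw = 2)
(g,) = torch.autograd.grad(copied, w)
print(ref is y, g.item(), detached.requires_grad)  # True 2.0 False
```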

torchopt.extract_state_dict(target, *[, by, ...])

Extract target state.

torchopt.recover_state_dict(target, state)

Recover state.


Stop the gradient for the input object.

Here is a usage example.

import torch
import torch.nn as nn
import torchopt

net = Net()
x = nn.Parameter(torch.tensor(2.0), requires_grad=True)

optim = torchopt.MetaAdam(net, lr=1.0)

# Get the reference of state dictionary
init_net_state = torchopt.extract_state_dict(net, by='reference')
init_optim_state = torchopt.extract_state_dict(optim, by='reference')
# With `detach_buffers=True`, parameters are extracted as references while buffers are stored as detached copies
init_net_state = torchopt.extract_state_dict(net, by='reference', detach_buffers=True)

# Set `copy` to get the copy of the state dictionary
init_net_state_copy = torchopt.extract_state_dict(net, by='copy')
init_optim_state_copy = torchopt.extract_state_dict(optim, by='copy')

# Set `deepcopy` to get the detached copy of state dictionary
init_net_state_deepcopy = torchopt.extract_state_dict(net, by='deepcopy')
init_optim_state_deepcopy = torchopt.extract_state_dict(optim, by='deepcopy')

# Run 2 inner-loop optimization steps
for i in range(2):
    inner_loss = net(x)
    optim.step(inner_loss)

print(f'a = {net.a!r}')

# Restore the initial states and rerun the same 2 inner-loop steps
torchopt.recover_state_dict(net, init_net_state)
torchopt.recover_state_dict(optim, init_optim_state)

for i in range(2):
    inner_loss = net(x)
    optim.step(inner_loss)

print(f'a = {net.a!r}')  # the same result as the first run

Notebook Tutorial

Check the notebook tutorials at Meta Optimizer and Stop Gradient.