Model-Agnostic Meta-Learning
Meta-learning has achieved significant success in a variety of applications, and Model-Agnostic Meta-Learning (MAML) [FAL17] is one of its pioneering algorithms. In this tutorial, we show step by step how to train MAML on few-shot Omniglot classification with TorchOpt. The full script is at examples/few-shot/maml_omniglot.py.
In contrast to existing differentiable-optimizer libraries such as higher, which closely mirror PyTorch's optimizer design and can therefore be inflexible, TorchOpt lets you construct differentiable optimizers directly at the code level.
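The snippet below is a minimal, self-contained sketch (not part of the example script) of this code-level pattern: torchopt.MetaSGD performs differentiable inner updates, while torchopt.extract_state_dict and torchopt.recover_state_dict snapshot and restore the meta-parameters between tasks. The toy network and data here are illustrative only.

import torch
import torch.nn.functional as F
from torch import nn

import torchopt

# Toy model and data, purely for illustration.
toy_net = nn.Linear(4, 2)
inner_opt = torchopt.MetaSGD(toy_net, lr=0.1)

# Snapshot the initial (meta-)parameters and the optimizer state.
net_state = torchopt.extract_state_dict(toy_net)
opt_state = torchopt.extract_state_dict(inner_opt)

x, y = torch.randn(8, 4), torch.randint(0, 2, (8,))

# One differentiable inner step: the updated parameters stay connected
# to the initial parameters in the autograd graph.
inner_loss = F.cross_entropy(toy_net(x), y)
inner_opt.step(inner_loss)

# An outer (query) loss evaluated with the adapted parameters therefore
# back-propagates all the way to the initial parameters.
outer_loss = F.cross_entropy(toy_net(x), y)
outer_loss.backward()

# Roll back before adapting to the next task.
torchopt.recover_state_dict(toy_net, net_state)
torchopt.recover_state_dict(inner_opt, opt_state)

This save / adapt / restore pattern is exactly what the train and test functions below use for each task in a meta-batch.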
Overview
There are six steps in the MAML training pipeline:
Load Dataset: load Omniglot dataset;
Build the Network: build the neural network architecture of the model;
Train: meta-train;
Test: meta-test;
Plot: plot the results;
Pipeline: combine steps 3-5 together.
In the following sections, we load the dataset, build the neural network, define the meta-training and meta-testing loops, and plot the results, which together form the full MAML training and evaluation pipeline.
Load Dataset
In your Python code, simply import torch and load the Omniglot dataset with the OmniglotNShot loader; its full implementation is at examples/few-shot/support/omniglot_loaders.py:
import torch

from .support.omniglot_loaders import OmniglotNShot

device = torch.device('cuda:0')

# `args` (task hyperparameters) and `rng` (a NumPy random state used to
# shuffle the data) are defined by the surrounding script; see the sketch
# after this code block for one way to define them.
db = OmniglotNShot(
    '/tmp/omniglot-data',
    batchsz=args.task_num,
    n_way=args.n_way,
    k_shot=args.k_spt,
    k_query=args.k_qry,
    imgsz=28,
    rng=rng,
    device=device,
)
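The snippet above assumes that args and rng already exist. The following minimal sketch, with hypothetical default values, shows one way to define them so that the tutorial snippets are self-contained; the full script parses these hyperparameters from the command line and seeds the random state there.

import argparse

import numpy as np

# Hypothetical defaults, used here only to make the tutorial snippets
# self-contained; the full script parses these from the command line.
parser = argparse.ArgumentParser()
parser.add_argument('--task_num', type=int, default=32)  # tasks per meta-batch
parser.add_argument('--n_way', type=int, default=5)      # classes per task
parser.add_argument('--k_spt', type=int, default=5)      # support examples per class
parser.add_argument('--k_qry', type=int, default=15)     # query examples per class
parser.add_argument('--seed', type=int, default=1)
args = parser.parse_args([])  # use the defaults above

torch.manual_seed(args.seed)
rng = np.random.RandomState(args.seed)  # we assume a NumPy RandomState here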
The goal is to train a model for few-shot Omniglot classification.
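To get a feel for the data layout, you can draw one meta-batch from the loader and inspect the tensor shapes; this quick check is illustrative only and is not part of the training loop. The first dimension of each tensor indexes the tasks in the meta-batch.

# Illustrative check of the loader output (consumes one training meta-batch).
x_spt, y_spt, x_qry, y_qry = db.next()
print(x_spt.shape, y_spt.shape)  # support images and labels, one task per row
print(x_qry.shape, y_qry.shape)  # query images and labels, one task per row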
Build the Network
TorchOpt supports any user-defined PyTorch network. Here is the one used in this example:
import numpy as np
import torch
from torch import nn
import torch.optim as optim

net = nn.Sequential(
    nn.Conv2d(1, 64, 3),
    nn.BatchNorm2d(64, momentum=1.0, affine=True),
    nn.ReLU(inplace=False),
    nn.MaxPool2d(2, 2),
    nn.Conv2d(64, 64, 3),
    nn.BatchNorm2d(64, momentum=1.0, affine=True),
    nn.ReLU(inplace=False),
    nn.MaxPool2d(2, 2),
    nn.Conv2d(64, 64, 3),
    nn.BatchNorm2d(64, momentum=1.0, affine=True),
    nn.ReLU(inplace=False),
    nn.MaxPool2d(2, 2),
    nn.Flatten(),
    nn.Linear(64, args.n_way),
).to(device)

# We will use Adam to (meta-)optimize the initial parameters
# to be adapted.
meta_opt = optim.Adam(net.parameters(), lr=1e-3)
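As a quick, purely illustrative sanity check (not part of the example script), you can push a dummy batch of 28x28 images through the network and confirm that it produces one logit per class: the three conv/pool stages reduce each 28x28 input to a single 64-dimensional feature vector, which matches the final nn.Linear(64, args.n_way) layer.

# Illustrative shape check with a fake batch of four grayscale images.
dummy = torch.zeros(4, 1, 28, 28, device=device)
print(net(dummy).shape)  # expected: torch.Size([4, args.n_way])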
Train
Define the train function:
# Additional imports used by the training and testing loops.
import time

import torch.nn.functional as F

import torchopt

def train(db, net, meta_opt, epoch, log):
    net.train()
    n_train_iter = db.x_train.shape[0] // db.batchsz
    # TorchOpt's differentiable SGD serves as the inner-loop optimizer.
    inner_opt = torchopt.MetaSGD(net, lr=1e-1)

    for batch_idx in range(n_train_iter):
        start_time = time.time()
        # Sample a batch of support and query images and labels.
        x_spt, y_spt, x_qry, y_qry = db.next()

        task_num = x_spt.size(0)

        # TODO: Maybe pull this out into a separate module so it
        # doesn't have to be duplicated between `train` and `test`?

        # Initialize the inner optimizer to adapt the parameters to
        # the support set.
        n_inner_iter = 5

        qry_losses = []
        qry_accs = []
        meta_opt.zero_grad()

        # Snapshot the meta-parameters and optimizer state so that each
        # task starts its adaptation from the same initialization.
        net_state_dict = torchopt.extract_state_dict(net)
        optim_state_dict = torchopt.extract_state_dict(inner_opt)
        for i in range(task_num):
            # Optimize the likelihood of the support set by taking
            # gradient steps w.r.t. the model's parameters.
            # This adapts the model's meta-parameters to the task.
            for _ in range(n_inner_iter):
                spt_logits = net(x_spt[i])
                spt_loss = F.cross_entropy(spt_logits, y_spt[i])
                inner_opt.step(spt_loss)

            # The final set of adapted parameters will induce some
            # final loss and accuracy on the query dataset.
            # These will be used to update the model's meta-parameters.
            # Note that `qry_loss` is kept attached to the graph so it can
            # later be back-propagated to the initial parameters.
            qry_logits = net(x_qry[i])
            qry_loss = F.cross_entropy(qry_logits, y_qry[i])
            qry_acc = (qry_logits.argmax(dim=1) == y_qry[i]).float().mean()
            qry_losses.append(qry_loss)
            qry_accs.append(qry_acc.item())

            # Restore the meta-parameters and optimizer state before
            # adapting to the next task.
            torchopt.recover_state_dict(net, net_state_dict)
            torchopt.recover_state_dict(inner_opt, optim_state_dict)

        qry_losses = torch.mean(torch.stack(qry_losses))
        qry_losses.backward()
        meta_opt.step()

        qry_losses = qry_losses.item()
        qry_accs = 100.0 * np.mean(qry_accs)
        i = epoch + float(batch_idx) / n_train_iter
        iter_time = time.time() - start_time

        print(
            f'[Epoch {i:.2f}] Train Loss: {qry_losses:.2f} | Acc: {qry_accs:.2f} | Time: {iter_time:.2f}'
        )
        log.append(
            {
                'epoch': i,
                'loss': qry_losses,
                'acc': qry_accs,
                'mode': 'train',
                'time': time.time(),
            }
        )
Test
Define the test function:
def test(db, net, epoch, log):
    # Crucially in our testing procedure here, we do *not* fine-tune
    # the model during testing for simplicity.
    # Most research papers using MAML for this task do an extra
    # stage of fine-tuning here that should be added if you are
    # adapting this code for research.
    net.train()
    n_test_iter = db.x_test.shape[0] // db.batchsz
    inner_opt = torchopt.MetaSGD(net, lr=1e-1)

    qry_losses = []
    qry_accs = []

    for batch_idx in range(n_test_iter):
        x_spt, y_spt, x_qry, y_qry = db.next('test')
        task_num = x_spt.size(0)

        # TODO: Maybe pull this out into a separate module so it
        # doesn't have to be duplicated between `train` and `test`?
        n_inner_iter = 5

        # Snapshot the meta-parameters and optimizer state so that each
        # task is adapted from the same initialization.
        net_state_dict = torchopt.extract_state_dict(net)
        optim_state_dict = torchopt.extract_state_dict(inner_opt)
        for i in range(task_num):
            # Optimize the likelihood of the support set by taking
            # gradient steps w.r.t. the model's parameters.
            # This adapts the model's meta-parameters to the task.
            for _ in range(n_inner_iter):
                spt_logits = net(x_spt[i])
                spt_loss = F.cross_entropy(spt_logits, y_spt[i])
                inner_opt.step(spt_loss)

            # The query loss and acc induced by these parameters.
            # No meta-update happens at test time, so the logits are detached.
            qry_logits = net(x_qry[i]).detach()
            qry_loss = F.cross_entropy(qry_logits, y_qry[i])
            qry_acc = (qry_logits.argmax(dim=1) == y_qry[i]).float().mean()
            qry_losses.append(qry_loss.item())
            qry_accs.append(qry_acc.item())

            # Restore the meta-parameters and optimizer state before
            # adapting to the next task.
            torchopt.recover_state_dict(net, net_state_dict)
            torchopt.recover_state_dict(inner_opt, optim_state_dict)

    qry_losses = np.mean(qry_losses)
    qry_accs = 100.0 * np.mean(qry_accs)
    print(f'[Epoch {epoch+1:.2f}] Test Loss: {qry_losses:.2f} | Acc: {qry_accs:.2f}')
    log.append(
        {
            'epoch': epoch + 1,
            'loss': qry_losses,
            'acc': qry_accs,
            'mode': 'test',
            'time': time.time(),
        }
    )
Plot
Finally, we define a plot helper that visualizes the training and test accuracies recorded in log. Here is an example:
import matplotlib.pyplot as plt
import pandas as pd

def plot(log):
    # Generally you should pull your plotting code out of your training
    # script but we are doing it here for brevity.
    df = pd.DataFrame(log)

    fig, ax = plt.subplots(figsize=(6, 4))
    train_df = df[df['mode'] == 'train']
    test_df = df[df['mode'] == 'test']
    ax.plot(train_df['epoch'], train_df['acc'], label='Train')
    ax.plot(test_df['epoch'], test_df['acc'], label='Test')
    ax.set_xlabel('Epoch')
    ax.set_ylabel('Accuracy')
    ax.set_ylim(70, 100)
    fig.legend(ncol=2, loc='lower right')
    fig.tight_layout()

    fname = 'maml-accs.png'
    print(f'--- Plotting accuracy to {fname}')
    fig.savefig(fname)
    plt.close(fig)
Pipeline
We can now combine all the components and plot the results.
log = []
for epoch in range(10):
    train(db, net, meta_opt, epoch, log)
    test(db, net, epoch, log)
    plot(log)
References
Chelsea Finn, Pieter Abbeel, and Sergey Levine. Model-agnostic meta-learning for fast adaptation of deep networks. In Doina Precup and Yee Whye Teh, editors, Proceedings of the 34th International Conference on Machine Learning, ICML 2017, Sydney, NSW, Australia, 6-11 August 2017, volume 70 of Proceedings of Machine Learning Research, 1126–1135. PMLR, 2017. URL: http://proceedings.mlr.press/v70/finn17a.html.