Background:
few-shot learning问题简单来说,就是通过减少学习的次数来获得有用的知识。我们以传统的image classification为例,对于传统的问题,我们经过大量的样本对模型进行训练从而增强它的泛化能力。也就是给machine看了很多张图片,让它区分开不同类别图片的区别。这和人学习很不一样,人只需要非常少的图片就可以区分图片里的是猫还是狗。few--shot learning就是给少量的训练样本,也能够具有比较好的泛化能力。如何达到呢?这类问题我们使用解决meta-learning的方式去处理。可以参见meta-learning tutorial中给出的数学定义。
这里用通俗的话来讲,过去的机器学习只是教会了机器分类猫和狗-CNN结构的参数 \(\theta\),meta-learning教会了机器如何分类,这是一个更大框架的参数 \(\phi\) 。这样相当于训练了多个 \(\theta\) ,然后根据 \(\theta\) 的训练情况,得到最好的 \(\phi\) . 可以认为 \(\phi\) 控制 \(\theta\) , \(\phi=f(\theta)\) ,这就是完整训练的过程。将得到的参数 \(\phi\) 去测试集上测试,给几张原来没看过的image,丢到网络里面训练,网络生成的 \(\theta\) 同时会受到 \(\phi\) 的控制, 可以发现可以很快的区分原来没有学习过的image。 并且能够正确的识别出他们的label。
MAML和Reptiles都是解决上述问题比较好的方法。现在把meta-learning的问题用来解决生成问题。通常来说 生成image有三种方法:pixelCNN,VAE和GAN。VAE曾经在以前的文章中介绍过,实质上是衡量的是真实样本和生成样本的KL-divergence。VAE仍然有很多问题需要解决 ,而GAN有更好的生成特性。
Method:
本文使用latent vector作为input,使用reptile的网络架构,将GAN的generator来生成逼真的图像。GAN由generator和discriminator组成。基本思想通俗来说,用Reptiles的方式来训练Generator和Discriminator,得到的 \(\phi_{\mathcal{d}}\) 和 \(\phi_{\mathcal{g}}\), \(W_{d}\) 和\(W_{g}\),使用Wasserstein GP loss进行训练,先进行内层循环K次,得到最佳的 \(W_{d}\) 和 \(W_{g}\) 。通过\(\phi_{\mathcal{d}}-W_{d}\) 和 \(\phi_{\mathcal{g}}-W_{g}\) 来更新参数。其loss function为 \[
\text {minimize} \sum_{T}\left(\Phi_{d}-W_{d \tau}\right)+\left(\Phi_{g}-W_{g \tau}\right)
\] 训练的流程如下:

Experiment:
构造一个ResNetGenerator和ResNetDiscriminator。再分别构造inner_loop和meta_training_loop
核心代码如下:
inner_loop:
def inner_loop(self, real_batch): self.meta_g.train() fake_batch = self.meta_g(torch.tensor(np.random.normal(size=(self.batch_size, self.z_shape)), dtype=torch.float, device=device)) training_batch = torch.cat([real_batch, fake_batch])
gradient_penalty = calc_gradient_penalty(self.meta_d, real_batch, fake_batch) discriminator_pred = self.meta_d(training_batch) discriminator_loss = wassertein_loss(discriminator_pred, self.discriminator_targets) discriminator_loss += gradient_penalty
self.meta_d_optim.zero_grad() discriminator_loss.backward() self.meta_d_optim.step()
output = self.meta_d(self.meta_g(torch.tensor(np.random.normal(size=(self.batch_size, self.z_shape)),dtype=torch.float, device=device))) generator_loss = wassertein_loss(output, self.generator_targets)
self.meta_g_optim.zero_grad() generator_loss.backward() self.meta_g_optim.step()
return discriminator_loss.item(), generator_loss.item()
|
meta_training_loop:
def meta_training_loop(self):
data, task = self.env.sample_training_task(self.batch_size) data = normalize_data(data) real_batch = data.to(device)
discriminator_total_loss = 0 generator_total_loss = 0
for _ in range(self.inner_epochs): disc_loss, gen_loss = self.inner_loop(real_batch) discriminator_total_loss += disc_loss generator_total_loss += gen_loss
self.writer.add_scalar('Training_discriminator_loss', discriminator_total_loss, self.eps) self.writer.add_scalar('Training_generator_loss', generator_total_loss, self.eps)
print('epochs:{},Train_D_loss:{},Train_G_loss:{}'.format(self.eps,discriminator_total_loss,generator_total_loss))
for p, meta_p in zip(self.g.parameters(),self.meta_g.parameters()): diff = p - meta_p.cpu() p.grad = diff
self.g_optim.step()
for p, meta_p in zip(self.d.parameters(), self.meta_d.parameters()): diff = p - meta_p.cpu() p.grad = diff self.d_optim.step()
|
可以看到meta_training_loop调用了inner_loop,计算了p - meta_p和p - meta_p的loss,并进行了更新。
真实的training过程
def training(self): while self.eps <= 10000: self.reset_meta_model() self.meta_training_loop() self.eps += 1
if self.eps % 100 == 0: self.reset_meta_model() gan_score=self.validation_run() self.gan_score_total.append(gan_score) print("valscore_total:{}".format(self.gan_score_total)) self.checkpoint_model()
|
Result:
最后生成的效果非常好。Mnist上的测试结果:

Omniglot上的测试结果:
