V · 7月15日

持续学习中的Elastic Weight Consolidation Loss数学原理及代码实现

训练人工神经网络最重要的挑战之一是灾难性遗忘。神经网络的灾难性遗忘(catastrophic forgetting)是指在神经网络学习新任务时,可能会忘记之前学习的任务。这种现象特别常见于传统的反向传播算法和深度学习模型中。主要原因是网络在学习新数据时,会调整权重以适应新任务,这可能会导致之前学到的知识被覆盖或忘记,尤其是当新任务与旧任务有重叠时。

image.png

在本文中,我们将探讨一种方法来解决这个问题,称为Elastic Weight Consolidation。EWC提供了一种很有前途的方法来减轻灾难性遗忘,使神经网络在获得新技能的同时保留先前学习任务的知识。

image.png

在任务a和任务B的灰色和黄色区域中,存在许多具有期望的低误差的最优参数配置。假设我们为任务A找到了一个这样的配置θꭺ*,当继续从这样的配置训练模型到新的任务B时,会出现三种不同的场景:

蓝色箭头:简单地继续在任务B上进行训练而不受惩罚,将在任务B的低水平区域结束,但在任务A上的表现低于预期的准确性。

绿色箭头:使用任务A的权重的L2约束可能太强,使得模型在任务A上表现良好,但在任务B上表现不佳。

红色箭头:这是EWC是提出的解决方案,它将在模型在两个任务上都表现良好的区域(两个区域之间的交叉点)中找到参数。

下面我们将解释这是如何完成的。

费雪信息矩阵(FIM)

EWC方法所基于的FIM(Fisher Information Matrix)。FIM是一种统计度量,用于量化给定数据提供的关于我们要估计的未知参数θ的信息量。在持续学习的背景下,FIM将有助于识别神经网络参数,这些参数从以前的任务中获取的数据信息较少。通过更新这些参数,网络可以学习新的任务,而不会删除存储在参数中的重要信息,这些信息是关于先前学习任务的非常有用的信息。

假设X是一个随机变量,其概率密度函数f(X |θ)参数化为θ。样本x的似然函数(仅在数据固定的情况下为参数函数)为:

image.png

当求二阶导数时,基本上是在看似然函数的曲率。

可以考虑下面的两个绘制的似然函数的图表。蓝色曲线表示在峰值附近非常窄的分布,表明数据更有可能在θ附近,并且随着远离θ而迅速减少。相反,黑色曲线代表一个更广泛的分布,即使远离θ,数据也保持相似的可能性。

FIM量化了这个概念——数据是多么严格地限制在某个θ值上。较大的FIM(如蓝色曲线所示)意味着参数值的微小变化将导致数据在这些参数下的可能性显著下降。相反,较小的FIM(如黑色曲线所示)意味着参数值的较小变化将导致可能性的较小降低。

image.png

弹性重量固结

给定数据D和一个参数为θ的神经网络,我们的目标是在给定数据p(θ|D)的情况下最大化参数的概率。根据贝叶斯规则,我们得到:

弹性权重保持

弹性权重保持(Elastic Weight Consolidation,EWC)是一种用于减轻神经网络灾难性遗忘问题的方法。它的基本思想是在学习新任务时保护先前任务的关键权重。

给定数据D和一个参数为θ的神经网络,我们的目标是在给定数据p(θ|D)的情况下最大化参数的概率。根据贝叶斯规则,我们得到:

image.png

最后一个是独立于A和B的。这里log(p(B|θ))是任务B的损失,log(p(B))是B的可能性,它可以作为优化的常数,因为它不依赖于θ, log(p(θ| a))是任务a的后验分布,它包含了任务a重要参数的所有信息。

估计log(p(θ|A))比较复杂的,因为计算它将涉及在整个参数空间上对高维函数进行积分。但是它近似为正态分布,其均值为任务a - θꭺ的最优参数,方差为费雪信息矩阵。这种近似是有意义的,因为我们可以假设A和B任务的新参数θ与任务A的最优参数相差不远。在所有θꭺ的参数中,会有一些参数对任务A的良好表现更重要,并且不希望它们改变太多,这就是FIM的作用,FIM的值表明在这种情况下,改变某个参数将如何影响任务A的损失。因此,FIM中值越高的参数变化受到的惩罚越大。

现在,我们对任务A的最优权值进行泰勒展开直到第二项

image.png

其中log(p(θꭺ|A))是一个常数,我们可以在优化中忽略它。我们也可以忽略第二项,因为在最优θꭺ处,梯度为零。这样就找到了log(p(θ|A))的表达式,把它代回到图8的原始公式中:

image.png

第二项的二阶导数为Hessian,可以根据图5的定义用费雪信息矩阵近似。log(p(B|θ))是新任务B的损失,例如交叉熵,我们记为Lᵦ(θ)

我们不需要进行二阶导数,只需根据图4中等价于图5的定义,即对数似然梯度的外积,用一阶导数近似FIM即可:

image.png

λ是一个超参数,表示在前一个任务a上保持精度的重要性。

上面涉及梯度向量的外积的定义捕获了梯度的协方差结构。而FIM的对角线近似通常由梯度的平方给出,它只计算参数的方差,但计算成本较低,足以完成任务:

image.png

Pytorch实现

上面我们介绍了弹性权重保持的数学原理,下面我们来看看Pytorch的代码实现

让我们首先导入一些库以及分别代表任务A和任务B的MNIST和Fashion MNIST数据集。我们还定义了一个简单的神经网络:

 importtorch
 importtorch.nnasnn
 importtorch.nn.functionalasF
 importtorch.optimasoptim
 fromtorchimportautograd
 importnumpyasnp
 fromtorch.utils.dataimportDataLoader
 
 fromtorch.utils.dataimportDataset, DataLoader
 fromtorchvisionimportdatasets, transforms
 fromtqdmimporttqdm
 
 defget_accuracy(model, dataloader):
     model=model.eval()
     acc=0
     forinput, targetindataloader:
         o=model(input.to(device))
         acc+= (o.argmax(dim=1).long() ==target.to(device)).float().mean()
     returnacc/len(dataloader)
 
 classLinearLayer(nn.Module):
     # from https://github.com/shivamsaboo17/Overcoming-Catastrophic-forgetting-in-Neural-Networks/blob/master/elastic_weight_consolidation.py
     def__init__(self, input_dim, output_dim, act='relu', use_bn=False):
         super(LinearLayer, self).__init__()
         self.use_bn=use_bn
         self.lin=nn.Linear(input_dim, output_dim)
         self.act=nn.ReLU() ifact=='relu'elseact
         ifuse_bn:
             self.bn=nn.BatchNorm1d(output_dim)
     defforward(self, x):
         ifself.use_bn:
             returnself.bn(self.act(self.lin(x)))
         returnself.act(self.lin(x))
 
 classFlatten(nn.Module):
     
     defforward(self, x):
         returnx.view(x.shape[0], -1)
     
 classModel(nn.Module):
     
     def__init__(self, num_inputs, num_hidden, num_outputs):
         super(Model, self).__init__()
         self.f1=Flatten()
         self.lin1=LinearLayer(num_inputs, num_hidden, use_bn=True)
         self.lin2=LinearLayer(num_hidden, num_hidden, use_bn=True)
         self.lin3=nn.Linear(num_hidden, num_outputs)
         
     defforward(self, x):
         returnself.lin3(self.lin2(self.lin1(self.f1(x))))
 
 # Load MNIST dataset, representint task A
 mnist_train=datasets.MNIST("../data", train=True, download=True, transform=transforms.ToTensor())
 mnist_test=datasets.MNIST("../data", train=False, download=True, transform=transforms.ToTensor())
 train_loader=DataLoader(mnist_train, batch_size=100, shuffle=True)
 test_loader=DataLoader(mnist_test, batch_size=100, shuffle=False)
 
 # FashiomMNIST is task B
 f_mnist_train=datasets.FashionMNIST("../data", train=True, download=True, transform=transforms.ToTensor())
 f_mnist_test=datasets.FashionMNIST("../data", train=False, download=True, transform=transforms.ToTensor())
 f_train_loader=DataLoader(f_mnist_train, batch_size=100, shuffle=True)
 f_test_loader=DataLoader(f_mnist_test, batch_size=100, shuffle=False)

现在让我们在MNIST任务上训练模型:

 # parameters
 EPOCHS=4
 lr=0.001
 weight=100000
 accuracies= {}
 
 device='cuda:1'
 
 criterion=nn.CrossEntropyLoss()
 
 # train model on task A 
 model=Model(28*28, 100, 10).to(device)
 optimizer=optim.Adam(model.parameters(), lr)
 
 for_inrange(EPOCHS):
     forinput, targetintqdm(train_loader):
         output=model(input.to(device))
         loss=criterion(output, target.to(device))
         optimizer.zero_grad()
         loss.backward()
         optimizer.step()
         
 accuracies['mnist_initial'] =get_accuracy(model, test_loader)

现在可以定义函数来估计FIM和EWC损失中使用的先前参数:

 defewc_loss(model, weight, estimated_fishers, estimated_means):
     losses= []
     forparam_name, paraminmodel.named_parameters():
         estimated_mean=estimated_means[param_name]
         estimated_fisher=estimated_fishers[param_name]
         losses.append((estimated_fisher* (param-estimated_mean) **2).sum())
         
     return (weight/2) *sum(losses)
 
 defestimate_ewc_params(model, train_ds, batch_size=100, num_batch=300, estimate_type='true'):
     estimated_mean= {}
 
     forparam_name, paraminmodel.named_parameters():
         estimated_mean[param_name] =param.data.clone()
         
     estimated_fisher= {}
     dl=DataLoader(train_ds, batch_size, shuffle=True)
     
     forn, pinmodel.named_parameters():
         estimated_fisher[n] =torch.zeros_like(p)
         
     model.eval()
     fori, (input, target) inenumerate(dl):
         ifi>num_batch:
             break
         model.zero_grad()
 
         output=model(input.to(device))
         # https://www.inference.vc/on-empirical-fisher-information/ - more on this here
         ifESTIMATE_TYPE=='empirical':
             # empirical
             label=target.to(device)
         else:
             # true estimate
             label=output.max(1)[1]
 
         loss=F.nll_loss(F.log_softmax(output, dim=1), label)
         loss.backward()
 
         # accumulate all the gradients
         forn, pinmodel.named_parameters():
             estimated_fisher[n].data+=p.grad.data**2/len(dl)
 
     estimated_fisher= {n: pforn, pinestimated_fisher.items()}
     returnestimated_mean, estimated_fisher

然后继续在任务B上训练EWC损失的网络:

 # compute fisher and mean parameters for EWC loss
 estimated_mean, estimated_fisher=estimate_ewc_params(model, mnist_train)
 
 # Train task B fashion mnist
 for_inrange(EPOCHS):
     forinput, targetintqdm(f_train_loader):
         output=model(input.to(device))
         loss=ewc_loss(model, weight, estimated_fisher, estimated_mean) +criterion(output, target.to(device))
         optimizer.zero_grad()
         loss.backward()
         optimizer.step()
         
 accuracies['mnist_EWC'] =get_accuracy(model, test_loader)
 accuracies['f_mnist_EWC'] =get_accuracy(model, f_test_loader)

可以得到以下精度:

 {'mnist_initial': tensor(0.9772, device='cuda:1'),
  'mnist_AB': tensor(0.9717, device='cuda:1'),
  'f_mnist': tensor(0.8312, device='cuda:1')}

最后将这些与没有EWC损失的模型进行比较:

 {'mnist_initial': tensor(0.9762, device='cuda:1'),
  'mnist_AB': tensor(0.1769, device='cuda:1'),
  'f_mnist': tensor(0.8672, device='cuda:1')}

可以看到EWC损失有助于保持任务A的准确率几乎不变,而学习任务B的准确率几乎与没有EWC损失的情况相同。

总结

我们看到了一种允许神经网络在继续学习新任务的同时保留其先前学习的知识的技术,虽然EWC在解决灾难性遗忘方面效果显著,但仍有一些挑战,例如对费雪信息矩阵的计算和存储需求较高,以及在复杂的深度神经网络结构中的实施复杂性。

还有还有其他方法可以使模型进行持续学习,比如:

重播记忆(Replay Memory):保存旧数据以便周期性地重训练。

联合训练(Joint Training):同时训练网络以处理旧任务和新任务。

元学习方法(Meta-learning Approaches):通过元学习算法来优化模型,以便快速适应新任务而不会忘记旧任务。

这些方法有助于减轻灾难性遗忘的影响,使神经网络能够持续学习和适应多个任务。

https://avoid.overfit.cn/post/56aee34117764e89a1a707c316fa305f

推荐阅读
关注数
4189
内容数
867
SegmentFault 思否旗下人工智能领域产业媒体,专注技术与产业,一起探索人工智能。
目录
极术微信服务号
关注极术微信号
实时接收点赞提醒和评论通知
安谋科技学堂公众号
关注安谋科技学堂
实时获取安谋科技及 Arm 教学资源
安谋科技招聘公众号
关注安谋科技招聘
实时获取安谋科技中国职位信息