机器之心编辑部发布

CMU团队解析CNN泛化能力:一切秘密都在数据中

如何理解神经网络的泛化能力?CMU 的汪浩瀚、邢波等人在论文《High-frequency Component Helps Explain the Generalization of Convolutional Neural Network》中另辟蹊径,从数据的角度入手,探讨那些曾让我们百思不得其解的泛化现象。

论文链接:https://arxiv.org/pdf/1905.13545.pdf

都是数据惹的祸

神经网络泛化能力的理解一向是众多机器学习研究者追求的目标,而致力于解决相关问题的文章也已是浩如烟海。本文将考虑一个完全不同的角度:只需在图片数据上执行一个小小的操作,就能够帮助我们理解许多曾经百思不解的泛化现象(比如对抗攻击和 Batch Normalization 的功效等)。很多现象难以解释,可能是因为我们把数据想象的太简单了,简而言之:数据可能比模型更复杂

我们首先从一个有趣的现象开始(如图 1 所示):在 CIFAR10 数据集上训练一个 ResNet18 模型。我们选择一张测试图片,把这张图片输入进训练好的模型中,绘制出模型的预测概率,我们看到模型有很大的信心认为图片上的物体是「汽车」。如果我们将这张图片经过傅里叶变换投影到频域中,然后选取一个半径,用这个半径把频域切分称低频域和高频域;然后利用切分之后的频域重新构建图片,其中第二行的蓝色汽车是低频重建的图片,而第三行的黑色块是高频重建的图片(高频重建的图片几乎全部是肉眼不可见的噪音)。将这些重建的图片重新放回刚刚的模型中,我们可以发现一个很有趣的现象:第二行肉眼看起来与原图相似的低频重建图片被预测成了「青蛙」,而第三行肉眼完全无法识别的图片被模型预测成了「汽车」,与原图预测结果一致。 

图 1. 人与模型视觉上的区别:低频重建的图片与原图看起来几乎一致,却被模型预测成了不同的 label。高频重建的图片人眼几乎无法识别,模型却能成功预测出原来的 label。

尽管我们只在大约 600 张图片中发现了这个神奇的现象,但也足够引起警觉了。

为什么会出现这种情况?

我们认为其中一个重要原因是在一个有限的数据集中存在着高频信息与图片所表达语义的相互关联(如图 2 所示):在一个同分布的有限数据集中可能存在着一些无法察觉的高频信号刚好与图片的语义有些关联,进而导致了这些高频信号与数据集 label 的相互关联。这样,当我们优化模型去降低损失函数时,我们并没有明确告知模型去学习语义还是高频信号,这导致模型随意学习各种可能会降低损失的信号。这样,尽管模型可能会达到较高的准确率,但它未必真的理解数据。

图 2. 本文的最主要观点:在一个有限的数据集中,除了数据本身的内容和 label 的关联之外,还有一些数据内容和其中的高频噪音的关联。当训练一个模型的时候,如果我们没有指明模型应该学习数据想表达的内容还是这些高频信号,模型可能会无差别地学习数据本身的信号或者这些高频信号,而这将使得对模型泛化能力的评估出现各种难以解释的现象。而当我们理解了这种数据特征之后,诸如对抗攻击、Batch Normalization 的功效等泛化的多种神奇现象将迎刃而解。

请注意:我们并没有说模型有捕捉高频信号的倾向性。这里的主要观点是:模型并没有任何理由忽略高频信息,从而导致模型学到了高频和语义的混合信息。

那么,这种现象是否是好事呢?一方面,这个特点有机会让我们创造出能够超越人类视觉系统的模型,在同分布的数据集上甚至可以达到比人类水平更高的准确率。另一方面,这样的模型,尽管在同分布数据集上效果很好,但在其他不同分布的相似数据集上效果可能参差不齐(高频信息很可能在每个分布上是不一致的)。在这里,我们并不争辩哪个观点更正确,我们只是提供这些观测结果供大家探讨。

关于泛化的一些解释

这个特点可以用来解释很多关于泛化的有趣现象,本文只涉及论文中的两点。

对抗性攻击的根源

显而易见地,上述现象和对抗性攻击有密切的联系。

为了更好地探究这些联系,我们利用对抗性训练(adversarial training, PGD)训练了一个可以抵御对抗性攻击的模型,进而研究该模型的卷积核特征,并比较了该卷积核与普通模型的卷积核的差别。我们发现,一个对对抗攻击鲁棒性更强的模型的卷积核看起来更加平滑(平滑指的是相邻位置的权重非常相似),如图 3 所示。很多数学工具可以帮我们证明平滑的卷积核能够有效地移除高频信号。这些结果将对抗性攻击的研究和数据中的高频信号联系了起来。

图 3. 左:普通卷积神经网络的卷积核的可视化;右:对对抗攻击鲁棒的卷积神经网络的卷积核的可视化。

有了这些观测结果,一个更有趣的问题就是我们是否可以通过将卷积核变得更加平滑来提高模型的对抗攻击鲁棒性。为了探究这个问题,我们测试了如下三种方法:

  • 对于一个训练好的模型,我们调整其权重,使卷积核变得更加平滑;

  • 直接在训练好的卷积核上将高频信息过滤掉;

  • 在训练卷积神经网络的过程中增加正则化,使得相邻位置的权重更加接近。

我们很希望这些调整可以增加模型的对抗鲁棒性,可是很遗憾,我们只观测到了很小一部分的提升。这样,我们大概可以得出结论:对抗鲁棒性较好的模型卷积核更加平滑,然而卷积核更加平滑的模型对抗鲁棒性未必更好。换言之:高频信息是对抗攻击的一部分,但并非全部

然而,另一个策略确确实实可以提高模型的对抗鲁棒性:

  • 先将图片的高频信息过滤掉,再把它们放入模型中。

值得说明的是,这个方法虽然有机会在工业界大展拳脚,但是对抗攻击和防御的科研社区通常不太认可类似的图片预处理方法。

关于 Batch Normalization 的神奇功效

另一个很有趣的现象是 BatchNorm 的效果究竟来源于何处。BatchNorm 是当下深度学习,尤其是计算机视觉中的一个相当重要的技术。尽管成效斐然,BatchNorm 在训练中起到的作用似乎还没有得到很好地理解。我们的实验很有趣地从数据的角度上解释了 BatchNorm 的功效来自何方。

图 4. 在训练过程中,测试准确率随着 epoch 数的变化。每一个板块描述的是一个不同的训练技巧。颜色代表着区分低频信息和高频信息的半径。实线代表低频信息,虚线代表高频信息。虚线越高,表示越多的高频信息被学习到了。

在图 4 中,随着训练 epoch 数的增加,我们汇报了在训练过程中不同测试集的测试准确率,这些测试集是由不同的半径生成的,其中实线代表了低频数据,虚线代表了高频信息。虚线越高,一个模型就学到了越多的高频信息。

很意外的是,利用 BatchNorm 训练出来的模型学到了大量的高频信息:我们可以看到,BatchNorm 对应板块的虚线远远高于其他板块的虚线。这些结果说明 BatchNorm 之所以能够如此有效地提高模型的准确率,可能是在鼓励模型大量使用高频信息。正如前文所说的,在一个数据集里有各种信号,如果一个模型能利用更多的信号,那么它很有可能具备更高的准确率,这也符合我们所熟知的 BatchNorm 能够有效提高测试准确率的特点。

直观上来讲,我们猜测 BatchNorm 的优势来源于高频信息的像素值通常比较小(比如在图 1 中,高频重构的图片几乎只是一个黑色的方块)。而 BatchNorm 可能通过 normalization 提高了这个较小的值,使得模型更容易学到相关的信息。

那么 BatchNorm 的这个特点代表了什么呢?我们觉得这个特点可能要让工业界和学术界重新审视 BatchNorm 的效果,尤其是要训练的模型在很多数据集上需要有较为稳定的表现的时候。我们发现的这个结论也与其他的结果相互关联,如 Batch Normalization is a Cause of Adversarial Vulnerability。

结论

论文还讨论了一些其他的相关问题,例如著名的「rethinking the generalization」论文提到的问题,模型鲁棒性和准确率之间的平衡,还有相关现象在目标检测领域中的讨论。

本文主要总结了下面几个结论:

  • 由于高频信息很可能随着数据集的变化而变化,SOTA 可能没有我们想象的那么重要,而模型的表现和人的视觉能否相互呼应要重要的多。

  • 对于计算机视觉,我们可能需要新的测试模式,比如在原有测试集的基础上,同样测试低频信息重构之后的数据集。

  • 未来,更加直观地把人的视觉特征加入模型中的技术可能会比较重要。

理论泛化能力CNNCMU
相关数据
深度学习技术

深度学习(deep learning)是机器学习的分支,是一种试图使用包含复杂结构或由多重非线性变换构成的多个处理层对数据进行高层抽象的算法。 深度学习是机器学习中一种基于对数据进行表征学习的算法,至今已有数种深度学习框架,如卷积神经网络和深度置信网络和递归神经网络等已被应用在计算机视觉、语音识别、自然语言处理、音频识别与生物信息学等领域并获取了极好的效果。

重构技术

代码重构(英语:Code refactoring)指对软件代码做任何更动以增加可读性或者简化结构而不影响输出结果。 软件重构需要借助工具完成,重构工具能够修改代码同时修改所有引用该代码的地方。在极限编程的方法学中,重构需要单元测试来支持。

损失函数技术

在数学优化,统计学,计量经济学,决策理论,机器学习和计算神经科学等领域,损失函数或成本函数是将一或多个变量的一个事件或值映射为可以直观地表示某种与之相关“成本”的实数的函数。

计算机视觉技术

计算机视觉(CV)是指机器感知环境的能力。这一技术类别中的经典任务有图像形成、图像处理、图像提取和图像的三维推理。目标识别和面部识别也是很重要的研究领域。

神经网络技术

(人工)神经网络是一种起源于 20 世纪 50 年代的监督式机器学习模型,那时候研究者构想了「感知器(perceptron)」的想法。这一领域的研究者通常被称为「联结主义者(Connectionist)」,因为这种模型模拟了人脑的功能。神经网络模型通常是通过反向传播算法应用梯度下降训练的。目前神经网络有两大主要类型,它们都是前馈神经网络:卷积神经网络(CNN)和循环神经网络(RNN),其中 RNN 又包含长短期记忆(LSTM)、门控循环单元(GRU)等等。深度学习是一种主要应用于神经网络帮助其取得更好结果的技术。尽管神经网络主要用于监督学习,但也有一些为无监督学习设计的变体,比如自动编码器和生成对抗网络(GAN)。

傅里叶变换技术

傅里叶变换(法语:Transformation de Fourier、英语:Fourier transform)是一种线性积分变换,用于信号在时域(或空域)和频域之间的变换,在物理学和工程学中有许多应用。因其基本思想首先由法国学者约瑟夫·傅里叶系统地提出,所以以其名字来命名以示纪念。实际上傅里叶变换就像化学分析,确定物质的基本成分;信号来自自然界,也可对其进行分析,确定其基本成分。

卷积神经网络技术

卷积神经网路(Convolutional Neural Network, CNN)是一种前馈神经网络,它的人工神经元可以响应一部分覆盖范围内的周围单元,对于大型图像处理有出色表现。卷积神经网路由一个或多个卷积层和顶端的全连通层(对应经典的神经网路)组成,同时也包括关联权重和池化层(pooling layer)。这一结构使得卷积神经网路能够利用输入数据的二维结构。与其他深度学习结构相比,卷积神经网路在图像和语音识别方面能够给出更好的结果。这一模型也可以使用反向传播算法进行训练。相比较其他深度、前馈神经网路,卷积神经网路需要考量的参数更少,使之成为一种颇具吸引力的深度学习结构。 卷积网络是一种专门用于处理具有已知的、网格状拓扑的数据的神经网络。例如时间序列数据,它可以被认为是以一定时间间隔采样的一维网格,又如图像数据,其可以被认为是二维像素网格。

准确率技术

分类模型的正确预测所占的比例。在多类别分类中,准确率的定义为:正确的预测数/样本总数。 在二元分类中,准确率的定义为:(真正例数+真负例数)/样本总数

噪音技术

噪音是一个随机误差或观测变量的方差。在拟合数据的过程中,我们常见的公式$y=f(x)+\epsilon$中$\epsilon$即为噪音。 数据通常包含噪音,错误,例外或不确定性,或者不完整。 错误和噪音可能会混淆数据挖掘过程,从而导致错误模式的衍生。去除噪音是数据挖掘(data mining)或知识发现(Knowledge Discovery in Database,KDD)的一个重要步骤。

正则化技术

当模型的复杂度增大时,训练误差会逐渐减小并趋向于0;而测试误差会先减小,达到最小值后又增大。当选择的模型复杂度过大时,过拟合现象就会发生。这样,在学习时就要防止过拟合。进行最优模型的选择,即选择复杂度适当的模型,以达到使测试误差最小的学习目的。

推荐文章
暂无评论
暂无评论~