基于隐变量后验生成对抗网络的不平衡学习
Unbalanced Learning of Generative Adversarial Network Based on Latent Posterior
通讯作者: 李建勋,男,教授,博士生导师;E-mail:lijx@sjtu.edu.cn.
责任编辑: 陈晓燕
收稿日期: 2019-09-16
基金资助: |
|
Received: 2019-09-16
作者简介 About authors
何新林(1992-),男,湖南省常德市人,硕士生,主要研究方向为数据挖掘. 。
针对现有不平衡分类问题中过采样方法不能充分利用数据概率密度分布的问题,提出了一种基于隐变量后验生成对抗网络的过采样(LGOS)算法.该方法利用变分自编码求取隐变量的近似后验分布,生成器能有效估计数据真实概率分布,在隐空间中采样克服了生成对抗网络采样过程的随机性,并引入边缘分布自适应损失和条件分布自适应损失提升生成数据质量.此外,将生成样本当作源领域样本放入迁移学习框架中,提出了改进的基于实例的迁移学习(TrWSBoost)分类算法,引入了权重缩放因子,有效解决了源领域样本权重收敛过快、学习不充分的问题.实验结果表明,提出的方法在分类问题各指标上的表现明显优于现有方法.
关键词:
Based on the problem that the oversampling method in the existing unbalanced classification problem cannot fully utilize the data probability density distribution, a method named latent posterior based generative adversarial network for oversampling (LGOS) was proposed. This method used variational auto-encoder to obtain the approximate posterior distribution of latent variable and generation network could effectively estimate the true probability distribution function of the data. The sampling in the latent space could overcome the randomness of generative adversarial network. The marginal distribution adaptive loss and the conditional distribution adaptive loss were introduced to improve the quality of generated data. Besides, the generated samples as source domain samples were put into the transfer learning framework, the classification algorithm of transfer learning for boosting with weight scaling (TrWSBoost) was proposed, and the weight scaling factor was introduced, which effectively solved the problem that the weight of source domain samples converge too fast and lead to insufficient learning. The experimental results show that the proposed method is superior to the existing oversampling method in the performance of common metrics.
Keywords:
本文引用格式
何新林, 戚宗锋, 李建勋.
HE Xinlin, QI Zongfeng, LI Jianxun.
尽管现有的解决办法达到了良好的表现[10],鉴于最近几年深度网络生成模型在表示学习上显现出的巨大优势[14,15],本文关注利用深度神经网络对少数类进行过采样,因为过采样不会丢失数据中重要的信息,而且可以作为预处理步骤来进行可视化或者与算法层面方法相结合.传统的采样方法都是基于线性插值的方式,不能根据数据的概率分布函数进行采样.用过采样方法来解决不平衡分类问题是通过生成少数类样本来使数据达到均衡.最简单的方法是复制现有的少数类样本,这种方法容易导致过拟合.Chawla等[5]提出在选定的少数类样本和它们的K近邻之间进行线性插值来生成少数类样本,这种方法把所有少数类样本等同看待,没有考虑数据内分布的差异性,容易导致生成样本落入多数类区域.Han等[6]提出识别出位于类间边界的难以学习的少数类样本,对每个边界集合样本生成同等数量的样本.He等[7]用自适应方法根据每个少数类样本K近邻中多数类样本的数量来决定对每个少数类样本生成样本的数目,这种方法容易受噪声影响,对落入多数类区域的噪声给予过多关注.Barua等[8]提出识别出那些难以学习的少数类样本,并基于其与多数类样本的欧式距离给每个少数类样本分配权重,再用层次聚类法把少数类样本分为若干簇,在簇内根据权重采样插值生成少数类样本.Douzas等[9]提出利用条件生成对抗网络学习数据的多类分布,再进行少数类过采样.
针对现有的基于插值的过采样算法仅仅利用邻域样本的缺点,本文引入了隐变量模型,提出了一种基于隐变量后验生成对抗网络的过采样(LGOS)算法.生成对抗网络利用了所有少数类样本来学习数据真实概率分布,在隐变量后验上采样克服了基于高斯噪声生成对抗网络生成数据的随机性.同时本文引入了权重缩放因子,提出了与过采样算法相结合的不平衡分类算法TrWSBoost,人工合成的过采样样本和原始样本有很大相关性就相当于迁移源领域样本,原始样本被当作目标领域样本来迭代训练集成分类器.
1 生成对抗网络
图1
生成器G和判别器D的训练目标是相互对抗的.判别器对输入样本进行真假判定,通过训练不断提升自己的分类效果,识别出生成器所生成的样本.生成器希望生成更加真实的样本以混淆判别器,让判别器无法分辨真假.设输入的随机噪声为z,生成器G将随机噪声转换为生成样本G(z).判别器D对输入样本输出D(x)为[0,1]范围内的一个实数,表示输入样本为真实样本的概率值.其损失函数为
式中:x为真实输入样本;Pr为真实数据分布;Pz为输入噪声分布;E为数据期望.
两个网络进行迭代训练,理论上最终达到纳什均衡时,生成器G生成的数据分布和真实数据分布相同,判别器D输出概率值为0.5,无法区分真实样本和生成样本.
2 基于隐变量后验的生成对抗网络模型
2.1 隐变量模型
在生成对抗网络中,把高斯噪声或者均匀噪声当作隐变量先验分布,而隐变量真实先验分布和真实后验分布未知,所以生成数据质量具有随机性.变分自编码隐变量模型用近似后验分布代替真实先验分布,运用变分贝叶斯方法,在概率图模型上执行高效的近似推理和学习.均值场方法在很多情况下难以求得后验分布的解析解,变分自编码隐变量模型在概率图框架下形式化这个问题,通过优化对数似然的下界来间接优化最大对数似然.
近似后验分布和真实后验分布的距离用KL散度度量:
式中:ϕ为变分模型参数;qϕ(z|x)为近似后验分布,假设其服从高斯分布;θ为数据生成模型参数;pθ(z|x)为真实后验分布.
通过贝叶斯变换可得变分下界为
隐变量模型通过编码器得到近似后验分布qϕ(z|x)的均值和协方差,在隐空间采样输入解码器重构原始输入数据,误差沿网络反向传播更新网络参数来逼近变分下界.
2.2 基于隐变量后验分布的生成对抗网络模型建立
本文所建立模型中,编码器E从真实样本提取隐变量作为监督信号,在隐空间采样作为信号输入生成器G用来生成和真实数据同分布的样本.隐变量模型的解码器从隐空间采样重构原始输入样本,故可把生成器和解码器结合,提出了一种数据生成模型LGOS.
编码器E输入x,输出为隐变量分布均值和方差,可表示为
式中:qϕ(z|x)~N
在隐空间采样输入生成器G得到生成数据:
式中:
假设隐变量先验分布为正态分布N(0,I),则变分下界为
式中:Lele为重构误差;LKL为隐变量近似后验分布和先验分布之间的KL散度,具体表示如下.
式中:J为隐变量维度;μj和σj分别为样本近似后验分布对应的均值和方差.
判别器D损失为
LD=-
判别器D对真实输入样本输出较大的似然概率值,而对生成器G生成的样本输出小的似然概率值.
生成器G对抗损失为
生成器G和判别器D进行反向迭代,两个模型一直处于对抗训练过程.
2.3 分布自适应
用真实数据和生成数据之间的欧式距离来度量似然函数在很多情况下不适用.因为真实数据和生成数据要服从同分布,在模型中添加边缘分布自适应和条件分布自适应两个限制条件.
边缘分布的差距用最大化均值差异(MMD)度量[17],最大化均值差异把原变量映射到再生希尔伯特空间,在另一空间中求取两个分布的距离.在生成对抗网络中,判别器的目的就是学习数据样本的特征来进行区分,所以在LGOS模型中,用生成对抗网络判别器的最后一个隐层作为特征空间,特征向量的欧式距离即为MMD距离.
式中:l为判别器最后1个隐层;f为输入数据在第l层对应的特征提取函数.
在LGOS模型中,用一个分类器C获得条件概率,分类器输出激活函数为softmax,输出向量各维度表示样本属于各个类别的概率.条件分布距离损失为
分类器用原始数据训练,用交叉熵函数作为其损失函数,分类器损失为
各模块最终损失为
式中:γ1、γ2及γ3为超参数,用于调节各部分损失比重大小.
网络结构图如图2所示.
图2
2.4 权重缩放的迁移学习模型
以TrAdaboost[20]为基础,提出了改进的带权重缩放因子的TrWSBoost迁移学习分类算法.把生成的少数类样本当作源领域样本,原始训练数据当作目标领域样本,目标是要训练迁移学习集成分类器.
在TrWSBoost模型中,在每一轮迭代时,对于源领域样本,被基学习器错分时,认为这些错分样本是与原始样本不同分布的样本,错分样本权重在下一轮迭代时应该降低.正确分类样本权重保持不变.目标领域样本错分时下一轮迭代权重增加,正确分类时权重保持不变.在TrAdaboost算法中,源领域样本错分时权重衰减过快[21,22],且模型融合时仅融合了后一半模型,没有充分利用源领域信息.考虑到本文中源领域样本和目标领域样本较大的相关性,为了解决权重衰减过快的问题,本文以目标领域样本加权错误率和源领域样本加权错误率为基础,设定了权重缩放因子.当目标领域加权错误率低时,认为模型表现良好,减慢源领域样本权重更新速度,反之亦然.
最终算法结构图如图3、4所示,算法流程如下:
图3
图4
图4
训练TrWSBoost集成分类器流程图
Fig.4
Flowchart of training of TrWSBoost ensemble classifier
(1) LGOS过采样算法.
(a) 初始化.设置训练批次大小为m,初始化编码器E、生成器G、判别器D和分类器C 4个网络参数.设置超参数γ1=0.01,γ2=1,γ3=0.02.
(b) 从真实数据中随机抽取批次大小为m的训练数据x,并输入编码器E后得到隐变量近似后验分布z.
(c) 从隐空间采样输入生成器G得到生成样本
(d) 输入真实样本x和生成样本
(e) 输入真实样本x和生成样本
(f) 汇总各网络损失,误差反向传播更新网络参数.
(g) 重复执行步骤(2)~(6),更新网络参数直至收敛.从隐空间采样输入生成器G得到最终生成样本.
(2) TrWSBoost集成分类算法.
(a) 初始化.设目标领域数据为D=
(b) 权重归一化,设归一化权重向量为Pt=
(c) 由D∪S预测标签值和真实标签,根据归一化权重Pt得到源领域样本和目标领域样本加权错误率.目标领域加权错误率为εt=
(d) 更新权重向量.
(e) 当前步数t自增1,未到N步时重复步骤(9)~(11).
(f) 最终集成分类器为H(x)=
3 实验与分析
3.1 实验参数设置和评估指标
实验选取了6个UCI公开数据集,把某类或几类指定为少数类,其他类作为多数类人为制造不平衡数据集,各数据集描述见表1.He等[10]的实验结果显示当数据分布接近均衡时,分类器表现最好,故在本文中,对每一数据集均通过过采样方法使少数类和多数类均衡.在实验中,随机选取80%的数据为训练集,其余数据作为测试集,取10次实验结果平均值作为报告结果.第一阶段实验采用决策树分类器,采用基尼系数作为节点切分标准,叶子节点最少样本数设置为1.SMOTE、Borderline-SMOTE、ADASYN算法K近邻设置为5,对MWMOTE算法K1设为5,K2设为3,K3设为3,聚类簇合并阈值Cp设为3.第二阶段实验把过采样生成的少数类样本当作源领域样本,原训练集数据当作目标领域样本,利用TrWSBoost算法训练迁移学习分类器.为了与第一阶段结果对比,弱分类器同样选取决策树分类器,迭代步数设置为50,当分类精确度很高时提前终止以防止过拟合.
表1 各数据集特性描述
Tab.1
数据集 | 大小 | 特征数 | 多数类 | 少数类 | 不平衡比 |
---|---|---|---|---|---|
phoneme | 5404 | 5 | 3818 | 1586 | 2.41 |
satimage | 4435 | 36 | 3956 | 479 | 8.26 |
pen | 10992 | 16 | 9937 | 1055 | 9.42 |
wine | 6497 | 11 | 5617 | 880 | 6.38 |
letter | 20000 | 16 | 18445 | 1555 | 11.86 |
avila | 10430 | 10 | 9335 | 1095 | 8.53 |
3.2 实验结果与分析
选取satimage数据集,用各过采样算法生成同样数量的少数类样本,利用t分布随机领域嵌入(TSNE)投影算法将数据降到两维进行图形可视化表示.
图5所示为LGOS 算法和其他过采样算法生成数据图形比较.从图中可见,ROS算法生成样本和原始数据中少数类重合,容易导致过拟合.SMOTE算法生成样本相比于原始样本差异小,而且有少部分生成样本落入多数类区域成为噪声,对分类器训练不利.Borderline-SMOTE、MWMOTE及ADASYN算法侧重边界区域少数类样本,这些样本容易受落入多数类区域的噪声影响,对应忽略的噪声较大的权重,导致生成更多噪声,而且容易导致边界混合.从图3(f)中可以看出,生成样本的分布区域基本都在原始少数类样本分布区域内,而且和原始样本的关联更小,说明本文所提出的LGOS算法能够准确估计出真实样本概率密度函数,生成样本时是在真实的概率密度函数上采样,不同于基于插值的方式,生成样本时利用了全局的概率分布,生成样本相比于原始样本差异更大,提供的信息更多.
图5
图5
LGOS 算法和其他过采样算法生成数据图形比较
Fig.5
Visual comparison of synthetic data of LGOS and other oversampling methods
在6个UCI公开数据集上进行对比实验,用过采样算法生成的样本和原始样本混合训练决策树分类器,在测试集上的Recall、F-measure、G-mean和AUC指标见表2,粗体表示最优值.
表2 基于数据过采样的决策树分类器指标
Tab.2
指标 | 数据集 | 原始数据 | ROS | SMOTE | Border | MWMOTE | ADASYN | LGOS |
---|---|---|---|---|---|---|---|---|
Recall | phoneme | 0.7566 | 0.7396 | 0.7953 | 0.7976 | 0.8023 | 0.8046 | 0.8433 |
satimage | 0.9146 | 0.9365 | 0.9414 | 0.9268 | 0.9512 | 0.9524 | 0.9634 | |
pen | 0.9583 | 0.9819 | 0.9814 | 0.9814 | 0.9856 | 0.9625 | 0.9861 | |
wine | 0.6000 | 0.5627 | 0.6511 | 0.6188 | 0.6533 | 0.6583 | 0.6944 | |
letter | 0.9016 | 0.8759 | 0.8983 | 0.8769 | 0.9037 | 0.8586 | 0.9118 | |
avila | 0.9357 | 0.9394 | 0.9564 | 0.9784 | 0.9422 | 0.9697 | 0.9816 | |
F-measure | phoneme | 0.7479 | 0.7394 | 0.7528 | 0.7602 | 0.7627 | 0.7564 | 0.7586 |
satimage | 0.9146 | 0.9411 | 0.9374 | 0.9319 | 0.9414 | 0.9398 | 0.9461 | |
pen | 0.9430 | 0.9718 | 0.9586 | 0.9676 | 0.9755 | 0.9563 | 0.9681 | |
wine | 0.6084 | 0.5885 | 0.5759 | 0.5605 | 0.5945 | 0.5722 | 0.5966 | |
letter | 0.8721 | 0.8898 | 0.8646 | 0.8571 | 0.8772 | 0.8519 | 0.8936 | |
avila | 0.9400 | 0.9483 | 0.9411 | 0.9576 | 0.9280 | 0.9538 | 0.9511 | |
G-mean | phoneme | 0.8241 | 0.8157 | 0.8356 | 0.8398 | 0.8422 | 0.8394 | 0.8486 |
satimage | 0.9521 | 0.9650 | 0.9669 | 0.9596 | 0.9718 | 0.9721 | 0.9778 | |
pen | 0.9749 | 0.9888 | 0.9871 | 0.9881 | 0.9908 | 0.9783 | 0.9902 | |
wine | 0.7510 | 0.7286 | 0.7661 | 0.7483 | 0.7720 | 0.7681 | 0.7897 | |
letter | 0.9432 | 0.9324 | 0.9409 | 0.9301 | 0.9446 | 0.9207 | 0.9500 | |
avila | 0.9642 | 0.9668 | 0.9735 | 0.9853 | 0.9656 | 0.9810 | 0.9859 | |
AUC | phoneme | 0.8271 | 0.8197 | 0.8366 | 0.8410 | 0.8432 | 0.8402 | 0.8486 |
satimage | 0.9529 | 0.9655 | 0.9673 | 0.9602 | 0.9720 | 0.9724 | 0.9779 | |
pen | 0.9751 | 0.9888 | 0.9871 | 0.9881 | 0.9909 | 0.9785 | 0.9902 | |
wine | 0.7700 | 0.7533 | 0.7765 | 0.7621 | 0.7829 | 0.7775 | 0.7963 | |
letter | 0.9442 | 0.9342 | 0.9419 | 0.9317 | 0.9456 | 0.9230 | 0.9508 | |
avila | 0.9646 | 0.9672 | 0.9737 | 0.9854 | 0.9659 | 0.9811 | 0.9860 |
把过采样少数类样本当作源领域样本,原始样本当作目标领域样本,利用本文的TrWSBoost算法训练集成分类器,本文实验中选取决策树作为基分类器,分类器测试指标见表3,粗体表示最优值.其中ROS表示先用ROS过采样算法生成少数类样本,再用所有数据训练TrWSBoost分类器, 其余类同.TrAdaboost列表示用LGOS生成少数类样本,再用TrAdaboost算法训练集成分类器.
表3 基于数据过采样的迁移学习分类器指标
Tab.3
指标 | 数据集 | ROS | SMOTE | Border | MWMOTE | ADASYN | TrAdaboost | LGOS |
---|---|---|---|---|---|---|---|---|
Recall | phoneme | 0.8266 | 0.8333 | 0.8466 | 0.8400 | 0.8500 | 0.8433 | 0.8633 |
satimage | 0.9512 | 0.9390 | 0.9390 | 0.9512 | 0.9634 | 0.9512 | 0.9756 | |
pen | 1.0000 | 1.0000 | 0.9907 | 0.9953 | 1.0000 | 0.9953 | 1.0000 | |
wine | 0.5166 | 0.6166 | 0.6277 | 0.6388 | 0.5944 | 0.6944 | 0.7722 | |
letter | 0.9152 | 0.9186 | 0.9152 | 0.9220 | 0.9322 | 0.9220 | 0.9491 | |
avila | 0.9862 | 0.9862 | 0.9862 | 0.9954 | 0.9862 | 0.9954 | 1.0000 | |
F-measure | phoneme | 0.8378 | 0.8361 | 0.8396 | 0.84 | 0.8388 | 0.8281 | 0.8477 |
satimage | 0.9512 | 0.9565 | 0.9506 | 0.9512 | 0.9634 | 0.9512 | 0.9696 | |
pen | 0.9953 | 0.9976 | 0.9930 | 0.9976 | 0.9953 | 0.9976 | 1.0000 | |
wine | 0.5942 | 0.6646 | 0.6420 | 0.6301 | 0.6114 | 0.5868 | 0.6698 | |
letter | 0.9540 | 0.9559 | 0.9523 | 0.9560 | 0.9649 | 0.9560 | 0.9705 | |
avila | 0.9930 | 0.9907 | 0.9930 | 0.9954 | 0.9907 | 0.9954 | 1.0000 | |
G-mean | phoneme | 0.8832 | 0.8843 | 0.8895 | 0.8879 | 0.8901 | 0.8835 | 0.8976 |
satimage | 0.9728 | 0.9678 | 0.9672 | 0.9728 | 0.9797 | 0.9728 | 0.9858 | |
pen | 0.9994 | 0.9997 | 0.9951 | 0.9976 | 0.9994 | 0.9976 | 1.0000 | |
wine | 0.7058 | 0.7700 | 0.7711 | 0.7739 | 0.7490 | 0.7870 | 0.8402 | |
letter | 0.9565 | 0.9583 | 0.9564 | 0.9599 | 0.9655 | 0.9599 | 0.9739 | |
avila | 0.9930 | 0.9928 | 0.9930 | 0.9974 | 0.9928 | 0.9974 | 1.0000 | |
AUC | phoneme | 0.8851 | 0.8859 | 0.8906 | 0.8892 | 0.8910 | 0.8845 | 0.8983 |
satimage | 0.9731 | 0.9682 | 0.9676 | 0.9731 | 0.9798 | 0.9731 | 0.9859 | |
pen | 0.9994 | 0.9997 | 0.9951 | 0.9976 | 0.9994 | 0.9976 | 1.0000 | |
wine | 0.7404 | 0.7891 | 0.7875 | 0.7881 | 0.7690 | 0.7932 | 0.8432 | |
letter | 0.9574 | 0.9591 | 0.9573 | 0.9607 | 0.9661 | 0.9607 | 0.9743 | |
avila | 0.9931 | 0.9928 | 0.9931 | 0.9974 | 0.9928 | 0.9974 | 1.0000 |
从表2、3中可知,集成后各指标相比于单分类器均有明显提升,表示集成方法是解决不平衡学习的一个好办法.LGOS算法生成样本在集成后在各指标上均超出了其他方法.TrWSBoost算法相比于TrAdaboost算法解决了权重衰减过快的问题.在本文研究中,由于源领域样本和目标领域样本极大的相关性,防止权重衰减过快具有合理性.
4 结语
现有的不平衡分类问题过采样方法均是基于样本间插值的方法,区别在于如何区分需要关注的少数类样本以及每个样本对应的生成样本数量.然而,这些方法均没有有效利用数据的概率密度分布函数,导致生成样本相比于原始样本差异小.基于这一观察以及最近几年深度网络生成模型显现出的优越性,本文提出了一种基于隐变量后验分布生成对抗网络的过采样方法,这一方法在隐空间中采样通过生成器得到生成样本,生成模型能够学习真实样本概率分布函数,故模型能够生成和原始少数类同分布的样本.在6个公开数据集上的对比实验结果及生成数据图形可视化分布均证明了LGOS算法的优越性.另外,提出了改进的基于实例的迁移学习方法,进一步提升了分类器的性能.接下来的工作可以从几方面展开:① 本文仅关注于二类分类问题,可以扩展到多类分类问题;② 改进深度网络处理离散变量的能力以适用于带名义变量的分类问题;③ 该方法在回归问题中的应用.
参考文献
A comprehensive data level analysis for cancer diagnosis on imbalanced data
[J]. ,
Credit risk prediction in an imbalanced social lending environment
[J]. ,DOI:10.2991/ijcis.11.1.70 URL [本文引用: 1]
Progressive boosting for class imbalance and its application to face re-identification
[J]. ,DOI:10.1016/j.eswa.2018.01.023 URL [本文引用: 1]
Performance of machine learning algorithms for class-imbalanced process fault detection problems
[J]. ,DOI:10.1109/TSM.2016.2602226 URL [本文引用: 1]
SMOTE: Synthetic minority over-sampling technique
[J]. ,DOI:10.1613/jair.953 URL [本文引用: 3]
Borderline-SMOTE: A new over-sampling method in imbalanced data sets learning
,
ADASYN: Adaptive synthetic sampling approach for imbalanced learning
,
MWMOTE: Majority weighted minority oversampling technique for imbalanced data set learning
[J]. ,DOI:10.1109/TKDE.2012.232 URL [本文引用: 3]
Effective data generation for imbalanced learning using conditional generative adversarial networks
[J]. ,DOI:10.1016/j.eswa.2017.09.030 URL [本文引用: 2]
Learning from imbalanced data
[J]. ,DOI:10.1109/TKDE.2008.239 URL [本文引用: 5]
Cost-sensitive boosting for classification of imba-lanced data
[J]. ,DOI:10.1016/j.patcog.2007.04.009 URL [本文引用: 1]
SMOTEBoost: Improving prediction of the minority class in boosting
,
RAMOBoost: Ranked minority oversampling in boosting
[J]. ,DOI:10.1109/TNN.2010.2066988 URL [本文引用: 1]
Unpaired image-to-image translation using cycle-consistent adversarial networks
,
StackGAN: Realistic image synjournal with stacked generative adversarial networks
[J]. ,DOI:10.1109/TPAMI.34 URL [本文引用: 1]
Domain adaptation via transfer component analysis
[J]. ,DOI:10.1109/TNN.2010.2091281 URL [本文引用: 1]
Transfer feature learning with joint distribution adaptation
,
Deep transfer learning with joint adaptation networks
,
基于迁移过采样的类别不平衡学习算法研究
[D]. ,
Research on transfer-sampling based method for class-imbalance learning
[D]. ,
/
〈 | 〉 |