背景介绍
Skip-gram(SG)模型也叫连续跳字模型,它来源于 NLP(Natural Language Processing,自然语言处理)领域的 Word2Vec 词嵌入算法,该算法能将单词嵌入到向量空间,并使得语义越相近的单词在向量空间相距越近。Word2Vec 算法是由 Google 公司的 T. Mikolov 等人于 2013 年提出的,它实际上包含两个模型——Skip-gram 和 CBOW,该算法的相关资料如下:
- T. Mikolov, K. Chen, G. Corrado, J. Dean, Efficient Estimation of Word Representations in Vector Space (2013)
- X. Rong, word2vec Parameter Learning Explained (2016)
在图嵌入领域,2014 年提出的 DeepWalk 算法首次使用 Skip-gram 模型训练图中节点的向量表示,其核心思想是:将图中的节点视为一个个的“单词”,每个节点进行随机游走生成的节点序列相当于“句子”,一系列节点序列就组成“语料库”;将“语料库”送入 Skip-gram 模型进行训练,就可以得到每个节点的向量表示。该算法的相关资料如下:
- B. Perozzi, R. Al-Rfou, S. Skiena, DeepWalk: Online Learning of Social Representations (2014)
后续提出的 Node2Vec、Struc2Vec 等图嵌入算法对 DeepWalk 进行了若干改进,但都仍然使用 Skip-gram 模型。
本文介绍 Skip-gram 模型时,使用其原始应用的自然语言作为例子。
模型概览
Skip-gram 模型的基本模式是通过给定的目标词预测出它的若干个上下文单词。如下图所示:将一个单词 w(t) 输入到模型,模型输出该单词的 4 个上下文单词 w(t-2)、w(t-1)、w(t+1) 和 w(t+2);其中 +/- 分别表示目标词的上下文,上下文单词数量是可以调整的。
值得留意的是,Skip-gram 模型训练的最终目的并不是使用模型做预测,而是获得中间映射关系(即上图的 PROJECTION)中包含的一个权重矩阵,该权重矩阵代表每个单词的向量表示。
模型训练——原始形式
Skip-gram 模型训练采用反向传播算法(BP 算法),如果读者不熟悉 BP 算法,为了更好地理解以下内容,建议先阅读文档——反向传播算法。
语料库
假设我们的语料库中一共有 10 个单词:
graph, is, a, good, way, to, visualize, data, very, at
还有一系列包含这些单词的句子,其中一个句子是:
Graph is a good way to visualize data.
滑动窗口采样
Skip-gram 模型使用滑动窗口(Sliding Window)对数据进行采样:窗口依次以句子中的每个单词为目标,将目标词前后的 window_size
个单词与其进行组合。
下图展示 window_size
= 1 时的采样结果。请注意,如果 window_size
> 1,采样结果并不区分窗口内各单词距离目标词的距离。
独热编码
单词并不能直接输送到模型中,需要进行机器编码。Skip-gram 采用独热编码(One-hot Encoding)对单词进行预处理,在本例中,10 个单词的独热编码如下:
模型描述
Skip-gram 模型的细节如下:
其中,
- V:语料库中的单词总数
- N:隐藏层向量维度,也是单词嵌入向量的维度
- C:输出向量的数目,即上下文单词数量,与采样窗口大小有关
- x:输入单词的 1×V 维独热编码向量
- h:隐藏层 1×N 维向量
- u:未经过 Softmax 函数处理的输出向量
- y: 经过 Softmax 函数处理后的实际输出向量;输出 C 个相同的 y (记作 yc),每个 yc 代表一个上下文单词
- W:输入层和隐藏层之间的 V×N 维权重矩阵
- W':隐藏层和输出层之间的 N×V 维权重矩阵
关于权重矩阵:矩阵 W 的每一行 vw 就是一个单词的 N 维嵌入向量(行顺序与独热编码顺序一致),vw 称为单词的输入向量。矩阵 W' 的每一列 v'w 则可以看成是单词的另一种 N 维向量表示(列顺序与独热编码顺序一致),v'w 称为单词的输出向量。两个矩阵的各权重值都是随机进行初始化的,W 是 W' 两个不同的矩阵,W' 并不是 W 的转置。模型训练的过程就是不断优化两个矩阵权重值的过程,训练的最终目的是得到矩阵 W。
的确,两个矩阵相互独立意味着每个单词有两种嵌入向量表示:输入向量 vw 可视为单词作为目标词时的向量表示,而输出向量 v'w 则是单词作为上下文时的向量表示。实际上,两个矩阵完全独立不仅使得计算更容易,结果也更准确。
关于 Softmax 函数:Softmax 是一种激活函数,可以将一个数值向量归一化为一个概率分布向量,各概率之和为 1。Softmax 函数的公式如下
前向传播过程
以训练样本 (is, graph) 和 (is, a) 为例:
输入层 → 隐藏层
向量 x 与矩阵 W 的转置 WT 相乘得到隐藏层向量 h。由于 x 是一个独热编码向量,即 xk = 1,其余位置都是 0,此步骤相当于把矩阵 W 的第 k 行向量提取出来,不需要计算。
隐藏层 → 输出层
向量 h 与 W' 的转置 W'T相乘得到向量 u:
向量 u 的各分量 uj 可看成是一系列分数,再使用 Softmax 函数处理后得到最终的输出向量 y:
向量 y 的各分量代表当输入为 x 时,输出其他每个单词的概率,所有概率的和为 1(独热编码顺序)。
本例中,获得最高概率的两个词 good 和 visualize 即为模型本次输出的两个上下文词,与期望的输出 graph 和 a 不一致,接着会通过反向传播调整矩阵 W 和 W' 的权重值。
损失函数
一个上下文单词
先考虑一个简单的情况:模型只输出一个向量 y 时,yj 表示向量 y 的第 j 个分量,也就是输出语料库中第 j 个单词的概率。假设期望输出的是第 j* 个单词,则模型的训练目标就是最大化 yj*,即
在机器学习中,最小化目标通常比最大化目标更容易处理,因此我们对上述目标进行一些变换:
取 yj 的对数并不影响目标,并且由于 yj ∈(0,1),其对数是一个负数,这样就巧妙地转换成求解最小值的问题。
由此得到模型的损失函数 E 如下:
计算 E 对 uj 的偏导数,并将结果定义为 ej:
上述求偏导的过程涉及到求对数和指数函数的导数,读者需熟悉这一点。
考虑一下 ej 的实际意义。以训练样本 (is, graph) 为例,期望的输出 graph 对应第 1 位,因此有
本例中输出的向量 y 如下,其中 y4 最大,即模型实际输出第 4 个单词 good。可以看出,ej 的计算相当于用输出向量 y 减去期望输出单词的独热编码,ej 就是模型的预测误差。
多个上下文单词
模型输出多个向量 y 时,yc,j 表示输出的第 c 个向量 yc 的第 j 个分量,同样地,这是输出语料库中第 j 个单词的概率。我们之前提到过,一个滑动窗口内所有的采样结果是同等的。对于 C 个上下文单词,模型的训练目标是使得每个上下文单词对应的概率都达到最大,也就是使得这些概率的乘积最大化,即
类似地,经过变换后得到损失函数 E 为:
E 对 uj 求偏导:
也就是说,当输出多个上下文单词时,模型预测误差相当于预测每个上下文单词的误差总和。
反向传播过程
在反向传播中,调整模型参数,即矩阵 W 和 W' 的各权重值。Skip-gram 采用随机梯度下降法(SGD)更新权重,以下说明采用上面举例的样本。
如果读者不熟悉梯度下降法,请先阅读文档——梯度下降法。
输出层 → 隐藏层
从输出层到隐藏层是矩阵 W',要计算权重 w'ij 对 E 产生了多少影响,使用链式法则对 w'ij 求偏导:
因此,根据学习率 η∈(0,1) 调整 w'ij 为
假设 η = 0.4,举例来说,权重 w'14 = 0.86 和 w'24 = 0.67 更新为:
w'14 := w'14 - η*(e1,4+e2,4)*h1 = 0.86 - 0.4*0.314*0.65 = 0.78
w'24 := w'24 - η*(e1,4+e2,4)*h2 = 0.67 - 0.4*0.314*0.87 = 0.56
矩阵 W' 的所有权重都会更新,意味着所有单词的输出向量都会更新。
隐藏层 → 输入层
上面提到向量 h 其实不需要计算,进行查表即可,但为了求偏导,这里给出等效的计算公式:
从隐藏层到输入层是矩阵 W,要计算权重 wki 对 E 产生了多少影响,使用链式法则对 wki 求偏导:
因此调整 wki 为
由于向量 x 为独热编码向量,只有一个分量 xk = 1,其余均为 0,因此矩阵 W 只有第 k 行权重 wki 会更新,其余权重均保持不变。本例中,只更新权重 w21 = 0.65 和 w22 = 0.87,具体如下:
先求
∂E/∂w21 = (e1,1+e2,1)*w'11*x2 + (e1,2+e2,2)*w'12*x2 + ... + (e1,10+e2,10)*w'1,10*x2 = -0.794*0.05*1 + 0.111*0.65*1 + ... + 0.216*0.83*1 = 0.283
∂E/∂w22 = (e1,1+e2,1)*w'21*x2 + (e1,2+e2,2)*w'22*x2 + ... + (e2,10+e2,10)*w'1,10*x2 = -0.794*0.79*1 + 0.111*0.27*1 + ... + 0.216*0.26*1 = 0.081
再求
w21 := w21 - η*∂E/∂w21 = 0.65 - 0.4*0.283 = 0.54
w22 := w22 - η*∂E/∂w22 = 0.87 - 0.4*0.081 = 0.84
矩阵 W 只有一行权重会更新,即只有目标词的输入向量会更新。
优化计算效率
关于优化 Skip-gram 模型训练过程,使其计算复杂度在现实中可行,请阅读文档——Skip-gram 模型优化。