概述
GraphSAGE(SAmple and aggreGatE)是一种归纳式的学习框架,它不是为每个节点训练单独的嵌入,而是学习一个函数,该函数通过从节点的本地邻域采样和聚合特征来生成嵌入,于是可以为以前未见过的节点生成节点嵌入。GraphSAGE 是由斯坦福大学的 W.L. Hamilton 等人于 2017 年提出的。
算法的相关资料如下:
- W.L. Hamilton, R. Ying, J. Leskovec, Inductive Representation Learning on Large Graphs (2017)
基本概念
转导学习与归纳学习
传统的图嵌入算法(例如基于矩阵分解、随机游走的算法)在迭代的过程中需要用到所有节点的信息,学习得到所有节点的嵌入结果;如果有新加入的节点,则需要重新使用所有节点进行训练,这种转导式(Transductive,也译作直推式)学习框架泛化性差。
GraphSAGE 是一种归纳式(Inductive)的图嵌入算法,它的学习结果不是每个节点的嵌入,而是一系列聚合函数(Aggregator Function)。如果有新加入的节点,只要根据各节点的特征信息和结构信息,就可以得到新节点的嵌入,而不必整个重新迭代训练。GraphSAGE 的这种泛化性对于拥有高吞吐量的机器学习系统至关重要。
GraphSAGE 生成嵌入向量算法
假设我们已经学习好了 K 个用来聚合邻居节点信息的聚合函数(AGGREGATEk)以及 K 个用来在模型不同层级或搜索深度间传递信息的权重矩阵 Wk。训练过程的描述详见 GraphSAGE 训练。以下是 minibatch 环境中 GraphSAGE 生成嵌入向量(即前向传播)算法的主要步骤:
1. 邻域采样
- 设置在目标节点每层邻域固定的采样节点数 Si (i = 1,2,...,K)
- 对于每个的目标节点:按照设定的数量在每层邻域均匀(Uniform)采样,即所有节点被选中的概率相同
- 如果某层邻域节点数少于设定的数目,采取有放回的抽样方法,直到采样出规定数量的节点
固定大小的采样有利于将算法扩展到 minibatch 环境上,因为此时每批(batch)计算所占用的空间是固定的。GraphSAGE 原作者发现,K 不必取很大的值,K = 2 以及 S1·S2 < 500 时就能取得很好的效果。
以上图为例,当采样大小 sample_size
设定为 [5,3] 时,从目标节点 a 出发由内而外进行两轮采样,第一轮采样的所得的邻域集合为 N2(a) = {b,c,d},第二轮采样的所得的邻域集合为 N1(a) = {f,g,h,i,j}。
参数
sample_size
是一个整数数组,数组的长度即为最大搜索深度 K,数组内的元素依次是在目标节点的第 K 层、第 K-1 层、...、第 1 层邻域的采样个数。
采样顺序是从第 1 层依次到第 K 层,而邻域集合的下标顺序刚好相反(为 K, K-1, ..., 1),这是为了配合下一步从外向内聚合节点的特征信息。
2. 聚合邻居节点的特征信息
图 G = (V, E) 以及图中所有节点的特征向量 Xv (v∈V) 作为输入,特征向量由节点的若干属性组成。将 Xv 作为每个节点的初始向量表示:
算法每次的迭代(迭代次数 k = 1,2,...,K)过程如下:
-
对于所有的目标节点及其邻域集合(排除下标 ≥ k 的邻域集合)中的所有节点,将每个节点表示为 u 并进行如下计算:
- 使用聚合函数 AGGREGATEk 聚合 Nk(u) 中所有节点在第 k-1 轮迭代时的向量表示:
- 拼接上一步得到的向量与节点 u 在第 k-1 轮迭代时得到的向量,再经过一个非线性变换
σ
(比如 Sigmoid 函数)即得到节点 u 在本轮迭代的向量表示:
-
再对节点 u 在本轮迭代的向量表示进行 L2 归一化处理:
对于上面例子中的目标节点 a:
- 第一轮迭代中的计算:
- 第二轮迭代中的计算:
特殊处理
孤点、不连通图
孤点没有邻居节点,因此无法聚合其他任何节点的特征信息。
节点只能聚合处于同一连通分量的邻居节点的特征信息。
自环边
GraphSAGE 图嵌入算法忽略节点的自环边。
有向边
GraphSAGE 图嵌入算法的结果与边的方向无关。
命令和参数配置
- 命令:
algo(graph_sage)
params()
参数配置如下:
名称 | 类型 | 默认值 |
规范 |
描述 |
---|---|---|---|---|
model_name | string | / | / | 由 GraphSAGE 训练算法 algo(graph_sage_train) 训练后得到的模型名 |
model_task_id | int | / | / | 训练模型的算法任务 ID |
ids | []_id |
/ | / | 需要进行图嵌入的节点 ID;忽略表示全部点 |
node_property_names | []@<schema>?.<property> |
从模型读取 | 点属性,需LTE | 节点的特征属性 |
edge_property_name | @<schema>?.<property> |
从模型读取 | 数值类的边属性,需LTE | 边权重的一个或多个属性名称,带不带 schema 均可;随机游走过程中,节点只会沿着带有这些属性的边游走,且经过这些边的概率与边权重成正比;如果边带有多个指定属性,权重值为这些属性值的和;忽略表示所有边权重为 1 |
sample_size | []int | [25,10] | 从模型读取 | 数组的长度即为最大搜索深度 K,数组内的元素依次是在目标节点的第 K 层、第 K-1 层、...、第 1 层邻域的采样个数 |
算法执行
任务回写
1. 文件回写
算法不支持文件回写。
2. 属性回写
配置项 | 回写内容 | 类型 | 数据类型 |
---|---|---|---|
property_name | 节点的向量表示 | 点属性 | string |
示例:执行 GraphSAGE 算法,使用名为 model_1 的模型,将算法结果回写至名为 embedding 的点属性
algo(graph_sage).params({
model_task_id: 1,
model_name: "model_1"
}).write({
db:{
property_name: "embedding_graph_sage_model_1"
}
})
3. 统计回写
算法无统计值。
直接返回
算法不支持直接返回。
流式返回
算法不支持流式返回。
实时统计
算法无统计值。