概述
GraphSAGE 训练算法得出的模型可用于 GraphSAGE 图嵌入算法。GraphSAGE 是由斯坦福大学的 W.H Hamilton 等人于 2017 年提出的。
算法的相关资料如下:
- W.L. Hamilton, R. Ying, J. Leskovec, Inductive Representation Learning on Large Graphs (2017)
基本概念
GraphSAGE 训练
GraphSAGE 模型参数是通过标准的随机梯度下降(SGD)以及反向传播算法学习的。
GraphSAGE 原作者直接提出了如下基于图的损失函数,通过它调整模型参数能使得在图上相距较近的节点有相似的向量表示,同时相距较远的节点的向量表示区别很大:
算法采用随机游走的方式确定“邻近”关系,即从节点 u 开始进行固定长度的随机游走,节点 v 是在随机游走序列中出现的节点,相当于正样本。Zu、Zv 分别是节点 u、v 的嵌入向量表示,它们的内积越大则表示越相似,σ
是 Sigmoid 激活函数。Q 是负样本数量,vn 表示负样本,负样本的概率分布是 Pn。
关于负采样,读者可参考文档——Skip-gram 模型优化。
如果有特定的下游任务,可根据任务目标使用其他的损失函数,比如交叉熵损失函数。
聚合函数
聚合器的作用是把一个向量集合转换成一个向量。和其他机器学习任务中的数据(如句子、图像等)不同,节点的邻居节点没有特定的顺序,因此聚合函数需要能处理无序的向量集合并且具有对称性(即改变输入顺序,输出结果不变),同时具有较高的表达能力。Ultipa 的 GraphSAGE 训练算法提供以下两种聚合器:
1. 均值聚合器
均值聚合器(Mean Aggregator)直接取各向量的元素均值得到聚合向量。使用此聚合器时,嵌入生成算法直接产生节点的第 k 轮迭代的向量表示:
例如,三个向量分别为 [1,2]、[4,3]、[3,4],均值聚合后的向量为 [2.667,3]。
2. 池化聚合器
在池化聚合器(Pooling Aggregator)中,每个邻居的向量先单独通过一个完全连接的神经网络;经过此转换之后,再对每个元素应用最大池化操作来聚合邻域集合的信息:
其中 max 代表按元素取最大值的操作,σ
是一个非线性的激活函数。
命令和参数配置
- 命令:
algo(graph_sage_train)
params()
参数配置如下:
名称 | 类型 | 默认值 |
规范 |
描述 |
---|---|---|---|---|
dimension | int | 64 | >0 | 生成的节点嵌入向量的维度,也是各隐藏层向量的维度,同时也是点属性的数量 |
node_property_names | []@<schema>?.<property> |
/ | 点属性,需LTE | 节点的特征属性 |
edge_property_name | @<schema>?.<property> |
/ | 数值类的边属性,需LTE | 边权重的一个或多个属性名称,带不带 schema 均可;随机游走过程中,节点只会沿着带有这些属性的边游走,且经过这些边的概率与边权重成正比;如果边带有多个指定属性,权重值为这些属性值的和;忽略表示所有边权重为 1 |
search_depth | int | 5 | >0 | 随机游走的深度 |
sample_size | []int | [25,10] | >0 | 数组的长度即为最大搜索深度 K,数组内的元素依次是在目标节点的第 K 层、第 K-1 层、...、第 1 层邻域的采样个数 |
learning_rate | float | 0.1 | [0, 1] | 学习率 |
epochs | int | 10 | >0 | 图遍历次数,即大循环的次数;每轮大循环前,会重新进行邻域采样 |
max_iterations | int | 10 | >0 | 每轮大循环中的迭代次数;每轮迭代使用随机选择一批(batch)节点的梯度更新权重 |
tolerance | double | 1e-10 | >0 | 收敛标准;每轮迭代结束时,计算本轮传播与上一轮传播的损失的差值,差值小于收敛标准则提前结束计算 |
aggregator | string | mean | mean 或 pool | 聚合器类型,mean 代表均值聚合,pool 代表池化聚合 |
batch_size | int | 节点数/线程数 | >0 | 每批(batch)计算的节点数量,也是负采样的数量 |
算法执行
任务回写
1. 文件回写
配置项 | 各列数据 | 描述 | 格式 |
---|---|---|---|
model_name | / | 训练后的模型 | JSON |
示例:执行 GraphSAGE 训练算法,以属性 age、hot 作为节点的特征属性,将算法结果回写至名为 model_1 的模型
algo(graph_sage_train).params({
node_property_names: ['age','hot']
}).write({
file:{
model_name: "model_1"
}
})
2. 属性回写
算法不支持属性回写。
3. 统计回写
算法无统计值。
直接返回
算法不支持直接返回。
流式返回
算法不支持流式返回。
实时统计
算法无统计值。