import json

import torch
from torch import tensor
from dgl import DGLHeteroGraph, heterograph
from nebula3.gclient.net import ConnectionPool
from nebula3.Config import Config

# Connect to NebulaGraph and fetch the 2-step subgraph around one vertex.
config = Config()
config.max_connection_pool_size = 2
connection_pool = ConnectionPool()
connection_pool.init([('graphd', 9669)], config)

vertex_id = 2048
client = connection_pool.get_session('root', 'nebula')
r = client.execute_json(
    "USE yelp;"
    f"GET SUBGRAPH WITH PROP 2 STEPS FROM {vertex_id} YIELD VERTICES AS nodes, EDGES AS relationships;")
r = json.loads(r)
data = r.get('results', [{}])[0].get('data')
columns = r.get('results', [{}])[0].get('columns')

# create node and node data
node_id_map = {}  # key: vertex id in NebulaGraph, value: node id in dgl_graph
node_idx = 0
features = [[] for _ in range(32)] + [[]]  # f0..f31 feature lists plus the is_fraud label list
for i in range(len(data)):
    for index, node in enumerate(data[i]['meta'][0]):
        nodeid = data[i]['meta'][0][index]['id']
        if nodeid not in node_id_map:
            node_id_map[nodeid] = node_idx
            node_idx += 1
            # BUG FIX: collect features only the first time a vertex is seen.
            # The original appended on every occurrence, so the feature/label
            # lists could grow longer than the number of distinct nodes and
            # tensor(features[32]) would not line up with the graph's nodes.
            for f in range(32):
                features[f].append(data[i]['row'][0][index][f"review.f{f}"])
            features[32].append(data[i]['row'][0][index]['review.is_fraud'])

# Collect endpoints per edge type, remapped to dgl node ids.
rur_start, rur_end, rsr_start, rsr_end, rtr_start, rtr_end = [], [], [], [], [], []
for i in range(len(data)):
    for edge in data[i]['meta'][1]:
        edge = edge['id']
        if edge['name'] == 'shares_user_with':
            rur_start.append(node_id_map[edge['src']])
            rur_end.append(node_id_map[edge['dst']])
        elif edge['name'] == 'shares_restaurant_rating_with':
            rsr_start.append(node_id_map[edge['src']])
            rsr_end.append(node_id_map[edge['dst']])
        elif edge['name'] == 'shares_restaurant_in_one_month_with':
            rtr_start.append(node_id_map[edge['src']])
            rtr_end.append(node_id_map[edge['dst']])

# Only include edge types that actually occur in this subgraph.
data_dict = {}
if rur_start:
    data_dict[('review', 'shares_user_with', 'review')] = tensor(rur_start), tensor(rur_end)
if rsr_start:
    data_dict[('review', 'shares_restaurant_rating_with', 'review')] = tensor(rsr_start), tensor(rsr_end)
if rtr_start:
    data_dict[('review', 'shares_restaurant_in_one_month_with', 'review')] = tensor(rtr_start), tensor(rtr_end)

# construct a dgl_graph, ref:
# https://docs.dgl.ai/en/0.9.x/generated/dgl.heterograph.html
dgl_graph: DGLHeteroGraph = heterograph(data_dict)

# load node features to dgl_graph
dgl_graph.ndata['label'] = tensor(features[32])

# heterogeneous graph to homogeneous graph, keep ndata and edata
import dgl
hg = dgl.to_homogeneous(dgl_graph, ndata=['label'])
from abc import abstractmethod

import torch


class BaseLabelPropagation:
    """Base class for label propagation models.

    Parameters
    ----------
    adj_matrix: torch.FloatTensor
        Adjacency matrix of the graph.
    """

    def __init__(self, adj_matrix):
        self.norm_adj_matrix = self._normalize(adj_matrix)
        self.n_nodes = adj_matrix.size(0)
        self.one_hot_labels = None
        self.n_classes = None
        self.labeled_mask = None
        self.predictions = None

    @staticmethod
    @abstractmethod
    def _normalize(adj_matrix):
        raise NotImplementedError("_normalize must be implemented")

    @abstractmethod
    def _propagate(self):
        raise NotImplementedError("_propagate must be implemented")

    def _one_hot_encode(self, labels):
        # Get the number of classes (the value -1 marks unlabeled nodes)
        classes = torch.unique(labels)
        classes = classes[classes != -1]
        self.n_classes = classes.size(0)

        # One-hot encode labeled data instances and zero rows corresponding
        # to unlabeled instances
        unlabeled_mask = (labels == -1)
        labels = labels.clone()  # defensive copying
        labels[unlabeled_mask] = 0
        self.one_hot_labels = torch.zeros((self.n_nodes, self.n_classes), dtype=torch.float)
        self.one_hot_labels = self.one_hot_labels.scatter(1, labels.unsqueeze(1), 1)
        self.one_hot_labels[unlabeled_mask, 0] = 0

        self.labeled_mask = ~unlabeled_mask

    def fit(self, labels, max_iter, tol):
        """Fits a semi-supervised learning label propagation model.

        labels: torch.LongTensor
            Tensor of size n_nodes indicating the class number of each node.
            Unlabeled nodes are denoted with -1.
        max_iter: int
            Maximum number of iterations allowed.
        tol: float
            Convergence tolerance: threshold to consider the system at steady state.
        """
        self._one_hot_encode(labels)

        self.predictions = self.one_hot_labels.clone()
        prev_predictions = torch.zeros((self.n_nodes, self.n_classes), dtype=torch.float)

        for i in range(max_iter):
            # Stop iterations if the system is considered at a steady state
            variation = torch.abs(self.predictions - prev_predictions).sum().item()
            if variation < tol:
                print(f"The method stopped after {i} iterations, variation={variation:.4f}.")
                break
            prev_predictions = self.predictions
            self._propagate()

    def predict(self):
        return self.predictions

    def predict_classes(self):
        return self.predictions.max(dim=1).indices


class LabelPropagation(BaseLabelPropagation):
    def __init__(self, adj_matrix):
        super().__init__(adj_matrix)

    @staticmethod
    def _normalize(adj_matrix):
        """Computes D^-1 * W"""
        degs = adj_matrix.sum(dim=1)
        degs[degs == 0] = 1  # avoid division by 0 error
        return adj_matrix / degs[:, None]

    def _propagate(self):
        self.predictions = torch.matmul(self.norm_adj_matrix, self.predictions)

        # Put back already known labels
        self.predictions[self.labeled_mask] = self.one_hot_labels[self.labeled_mask]

    def fit(self, labels, max_iter=1000, tol=1e-3):
        super().fit(labels, max_iter, tol)


class LabelSpreading(BaseLabelPropagation):
    def __init__(self, adj_matrix):
        super().__init__(adj_matrix)
        self.alpha = None

    @staticmethod
    def _normalize(adj_matrix):
        """Computes D^-1/2 * W * D^-1/2"""
        degs = adj_matrix.sum(dim=1)
        norm = torch.pow(degs, -0.5)
        norm[torch.isinf(norm)] = 1
        return adj_matrix * norm[:, None] * norm[None, :]

    def _propagate(self):
        # Soft clamping: keep alpha of the propagated signal and (1 - alpha)
        # of the original one-hot labels.
        self.predictions = (
            self.alpha * torch.matmul(self.norm_adj_matrix, self.predictions)
            + (1 - self.alpha) * self.one_hot_labels
        )

    def fit(self, labels, max_iter=1000, tol=1e-3, alpha=0.5):
        """
        Parameters
        ----------
        alpha: float
            Clamping factor.
        """
        self.alpha = alpha
        super().fit(labels, max_iter, tol)


# BUG FIX: this demo driver previously ran at import time; guard it so the
# classes above can be imported without side effects.  NOTE(review): `hg`
# comes from the earlier NebulaGraph-loading snippet — confirm it is in scope
# when running this as a script.
if __name__ == "__main__":
    import pandas as pd
    import numpy as np
    import networkx as nx
    import matplotlib.pyplot as plt

    nx_hg = hg.to_networkx()
    adj_matrix = nx.adjacency_matrix(nx_hg).toarray()
    labels = hg.ndata['label']

    # Create input tensors
    adj_matrix_t = torch.FloatTensor(adj_matrix)
    labels_t = torch.LongTensor(labels)

    # Learn with Label Propagation
    label_propagation = LabelPropagation(adj_matrix_t)
    print("Label Propagation: ", end="")
    label_propagation.fit(labels_t)
    label_propagation_output_labels = label_propagation.predict_classes()

    # Learn with Label Spreading
    label_spreading = LabelSpreading(adj_matrix_t)
    print("Label Spreading: ", end="")
    label_spreading.fit(labels_t, alpha=0.8)
    label_spreading_output_labels = label_spreading.predict_classes()
edge_types:
  - name: serve
    start_vertex_tag: player
    end_vertex_tag: team
    features:
      - name: service_time
        properties:
          - name: start_year
            type: int
            nullable: False
          - name: end_year
            type: int
            nullable: False
        # The variable was mapped by order of properties
        filter:
          type: function
          function: "lambda start_year, end_year: (end_year - start_year) / 30"
枚举属性值为数字特征
在这个例子中,我们对 team 顶点中的 name 属性做枚举,根据球队名判断它属于西岸还是东岸:
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
vertex_tags:
  - name: team
    features:
      - name: coast
        properties:
          - name: name
            type: str
            nullable: False
        filter:
          # 0 stands for the east coast, 1 stands for the west coast
          type: enumeration
          enumeration:
            Celtics: 0
            Nets: 0
            Nuggets: 1
            Timberwolves: 1
            Thunder: 1
            # ... not showing all teams here
---
# If vertex id is string-typed, remap_vertex_id must be true.
remap_vertex_id: True
space: yelp
# str or int
vertex_id_type: int
vertex_tags:
  - name: review
    label:
      name: is_fraud
      properties:
        - name: is_fraud
          type: int
          nullable: False
      filter:
        type: value
    features:
      - name: f0
        properties:
          - name: f0
            type: float
            nullable: False
        filter:
          type: value
      - name: f1
        properties:
          - name: f1
            type: float
            nullable: False
        filter:
          type: value
      # ... f2 through f30 follow the same pattern
      - name: f31
        properties:
          - name: f31
            type: float
            nullable: False
        filter:
          type: value
edge_types:
  - name: shares_user_with
    start_vertex_tag: review
    end_vertex_tag: review
  - name: shares_restaurant_rating_with
    start_vertex_tag: review
    end_vertex_tag: review
  - name: shares_restaurant_in_one_month_with
    start_vertex_tag: review
    end_vertex_tag: review
# Split the graph into train, validation, and test sets
import pandas as pd
import numpy as np
from sklearn.model_selection import train_test_split

# Features live in g.ndata['f0'] ... g.ndata['f31']; the label in g.ndata['is_fraud'].

# Stack the 32 scalar features into one (n_nodes, 32) feature matrix.
features = [g.ndata[f'f{i}'] for i in range(32)]
g.ndata['feat'] = torch.stack(features, dim=1)
g.ndata['label'] = g.ndata['is_fraud']

# Index tensor 0..n_nodes-1 so the splits below return node indices.
idx = torch.tensor(np.arange(g.number_of_nodes()), device=device, dtype=torch.int64)

# Stratify on the binary label "is_fraud" so each split keeps its distribution.
X_train_and_val_idx, X_test_idx, y_train_and_val, y_test = train_test_split(
    idx, g.ndata['is_fraud'], test_size=0.2, random_state=42,
    stratify=g.ndata['is_fraud'])

# Split train and validation.
X_train_idx, X_val_idx, y_train, y_val = train_test_split(
    X_train_and_val_idx, y_train_and_val, test_size=0.2, random_state=42,
    stratify=y_train_and_val)

# Turn the index lists into boolean node masks.
train_mask = torch.zeros(g.number_of_nodes(), dtype=torch.bool)
train_mask[X_train_idx] = True
val_mask = torch.zeros(g.number_of_nodes(), dtype=torch.bool)
val_mask[X_val_idx] = True
test_mask = torch.zeros(g.number_of_nodes(), dtype=torch.bool)
test_mask[X_test_idx] = True

g.ndata['train_mask'] = train_mask
g.ndata['val_mask'] = val_mask
g.ndata['test_mask'] = test_mask
# three types of edges
In [1]: g.etypes
Out[1]: ['shares_restaurant_in_one_month_with', 'shares_restaurant_rating_with', 'shares_user_with']

In [2]: g.edges['shares_restaurant_in_one_month_with'].data['he'] = torch.ones(
            g.number_of_edges('shares_restaurant_in_one_month_with'), dtype=torch.int64)
        g.edges['shares_restaurant_rating_with'].data['he'] = torch.full(
            (g.number_of_edges('shares_restaurant_rating_with'),), 2, dtype=torch.int64)
        g.edges['shares_user_with'].data['he'] = torch.full(
            (g.number_of_edges('shares_user_with'),), 4, dtype=torch.int64)

In [3]: g.edata['he']
Out[3]: {('review', 'shares_restaurant_in_one_month_with', 'review'): tensor([1, 1, 1, ..., 1, 1, 1]),
         ('review', 'shares_restaurant_rating_with', 'review'): tensor([2, 2, 2, ..., 2, 2, 2]),
         ('review', 'shares_user_with', 'review'): tensor([4, 4, 4, ..., 4, 4, 4])}
from dgl import function as fn
from dgl.utils import check_eq_shape, expand_as_pair


class SAGEConv(dglnn.SAGEConv):
    def forward(self, graph, feat, edge_weight=None):
        r"""Compute GraphSAGE layer, additionally weighting messages by the
        edge-type feature ``'he'`` stored on the graph.

        Parameters
        ----------
        graph : DGLGraph
            The graph.
        feat : torch.Tensor or pair of torch.Tensor
            If a torch.Tensor is given, it represents the input feature of shape
            :math:`(N, D_{in})`
            where :math:`D_{in}` is size of input feature, :math:`N` is the number of nodes.
            If a pair of torch.Tensor is given, the pair must contain two tensors of shape
            :math:`(N_{in}, D_{in_{src}})` and :math:`(N_{out}, D_{in_{dst}})`.
        edge_weight : torch.Tensor, optional
            Optional tensor on the edge. If given, the convolution will weight
            with regard to the message.

        Returns
        -------
        torch.Tensor
            The output feature of shape :math:`(N_{dst}, D_{out})`
            where :math:`N_{dst}` is the number of destination nodes in the input graph,
            :math:`D_{out}` is the size of the output feature.
        """
        self._compatibility_check()
        with graph.local_scope():
            if isinstance(feat, tuple):
                feat_src = self.feat_drop(feat[0])
                feat_dst = self.feat_drop(feat[1])
            else:
                feat_src = feat_dst = self.feat_drop(feat)
                if graph.is_block:
                    feat_dst = feat_src[:graph.number_of_dst_nodes()]
            msg_fn = fn.copy_src('h', 'm')
            if edge_weight is not None:
                assert edge_weight.shape[0] == graph.number_of_edges()
                graph.edata['_edge_weight'] = edge_weight
                msg_fn = fn.u_mul_e('h', '_edge_weight', 'm')

            h_self = feat_dst

            # Handle the case of graphs without edges
            if graph.number_of_edges() == 0:
                graph.dstdata['neigh'] = torch.zeros(
                    feat_dst.shape[0], self._in_src_feats).to(feat_dst)

            # Determine whether to apply linear transformation before message passing A(XW)
            lin_before_mp = self._in_src_feats > self._out_feats

            # Message Passing
            if self._aggre_type == 'mean':
                graph.srcdata['h'] = self.fc_neigh(feat_src) if lin_before_mp else feat_src
                # graph.update_all(msg_fn, fn.mean('m', 'neigh'))
                ##########################################################################
                # consider edge type with different weight, edata['he'] as weight here.
                # BUG FIX: the original called `g.update_all(...)` (a module-level
                # global graph) and reduced into key 'h'; it must run on the local
                # `graph` and write 'neigh', which is read immediately below.
                graph.update_all(fn.u_mul_e('h', 'he', 'm'), fn.mean('m', 'neigh'))
                #########################################################################
                h_neigh = graph.dstdata['neigh']
                if not lin_before_mp:
                    h_neigh = self.fc_neigh(h_neigh)
            elif self._aggre_type == 'gcn':
                check_eq_shape(feat)
                graph.srcdata['h'] = self.fc_neigh(feat_src) if lin_before_mp else feat_src
                if isinstance(feat, tuple):  # heterogeneous
                    graph.dstdata['h'] = self.fc_neigh(feat_dst) if lin_before_mp else feat_dst
                else:
                    if graph.is_block:
                        graph.dstdata['h'] = graph.srcdata['h'][:graph.num_dst_nodes()]
                    else:
                        graph.dstdata['h'] = graph.srcdata['h']
                graph.update_all(msg_fn, fn.sum('m', 'neigh'))
                # NOTE(review): this second update_all overwrites the 'neigh'
                # computed just above with the plain sum of edge weights —
                # confirm this is intended rather than an accumulation.
                graph.update_all(fn.copy_e('he', 'm'), fn.sum('m', 'neigh'))
                # divide in_degrees
                degs = graph.in_degrees().to(feat_dst)
                h_neigh = (graph.dstdata['neigh'] + graph.dstdata['h']) / (degs.unsqueeze(-1) + 1)
                if not lin_before_mp:
                    h_neigh = self.fc_neigh(h_neigh)
            elif self._aggre_type == 'pool':
                graph.srcdata['h'] = F.relu(self.fc_pool(feat_src))
                graph.update_all(msg_fn, fn.max('m', 'neigh'))
                # NOTE(review): as in the 'gcn' branch, this overwrites 'neigh'
                # with the max edge weight — verify intent.
                graph.update_all(fn.copy_e('he', 'm'), fn.max('m', 'neigh'))
                h_neigh = self.fc_neigh(graph.dstdata['neigh'])
            elif self._aggre_type == 'lstm':
                graph.srcdata['h'] = feat_src
                graph.update_all(msg_fn, self._lstm_reducer)
                h_neigh = self.fc_neigh(graph.dstdata['neigh'])
            else:
                raise KeyError('Aggregator type {} not recognized.'.format(self._aggre_type))

            # GraphSAGE GCN does not require fc_self.
            if self._aggre_type == 'gcn':
                rst = h_neigh
            else:
                rst = self.fc_self(h_self) + h_neigh

            # bias term
            if self.bias is not None:
                rst = rst + self.bias

            # activation
            if self.activation is not None:
                rst = self.activation(rst)
            # normalization
            if self.norm is not None:
                rst = self.norm(rst)
            return rst
class SAGE(nn.Module):
    """Three-layer GraphSAGE (mean aggregator) node classifier."""

    def __init__(self, in_size, hid_size, out_size):
        super().__init__()
        # three-layer GraphSAGE-mean
        self.layers = nn.ModuleList([
            dglnn.SAGEConv(in_size, hid_size, 'mean'),
            dglnn.SAGEConv(hid_size, hid_size, 'mean'),
            dglnn.SAGEConv(hid_size, out_size, 'mean'),
        ])
        self.dropout = nn.Dropout(0.5)
        self.hid_size = hid_size
        self.out_size = out_size

    def forward(self, blocks, x):
        h = x
        last = len(self.layers) - 1
        for depth, (layer, block) in enumerate(zip(self.layers, blocks)):
            h = layer(block, h)
            # ReLU + dropout between layers, not after the last one
            if depth != last:
                h = F.relu(h)
                h = self.dropout(h)
        return h

    def inference(self, g, device, batch_size):
        """Conduct layer-wise inference to get all the node embeddings."""
        feat = g.ndata['feat']
        sampler = MultiLayerFullNeighborSampler(1, prefetch_node_feats=['feat'])
        dataloader = DataLoader(
            g, torch.arange(g.num_nodes()).to(g.device), sampler,
            device=device, batch_size=batch_size, shuffle=False,
            drop_last=False, num_workers=0)
        buffer_device = torch.device('cpu')
        pin_memory = (buffer_device != device)
        last = len(self.layers) - 1
        for depth, layer in enumerate(self.layers):
            # Output buffer for this layer, kept on the (CPU) buffer device.
            y = torch.empty(
                g.num_nodes(),
                self.hid_size if depth != last else self.out_size,
                device=buffer_device, pin_memory=pin_memory)
            feat = feat.to(device)
            for input_nodes, output_nodes, blocks in tqdm.tqdm(dataloader):
                x = feat[input_nodes]
                h = layer(blocks[0], x)  # len(blocks) = 1
                if depth != last:
                    h = F.relu(h)
                    h = self.dropout(h)
                # by design, our output nodes are contiguous
                y[output_nodes[0]:output_nodes[-1] + 1] = h.to(buffer_device)
            feat = y
        return y
# BUG FIX: yaml and torch are used below but were never imported in this snippet.
import yaml
import torch

from nebula_dgl import NebulaLoader

nebula_config = {
    "graph_hosts": [('graphd', 9669), ('graphd1', 9669), ('graphd2', 9669)],
    "nebula_user": "root",
    "nebula_password": "nebula",
}

# Load the tag/edge-to-feature mapping for the yelp space.
with open('nebulagraph_yelp_dgl_mapper.yaml', 'r') as f:
    feature_mapper = yaml.safe_load(f)

nebula_loader = NebulaLoader(nebula_config, feature_mapper)
g = nebula_loader.load()  # This will take you some time

# Being on a budget, we use the CPU
g = g.to('cpu')
device = torch.device('cpu')
# Split the graph into train, validation and test sets
import pandas as pd
import numpy as np
from sklearn.model_selection import train_test_split

# Features live in g.ndata['f0'] ... g.ndata['f31']; the label in g.ndata['is_fraud'].

# Stack the 32 scalar features into one (n_nodes, 32) feature matrix.
features = [g.ndata[f'f{i}'] for i in range(32)]
g.ndata['feat'] = torch.stack(features, dim=1)
g.ndata['label'] = g.ndata['is_fraud']

# Index tensor 0..n_nodes-1 so the splits below return node indices.
idx = torch.tensor(np.arange(g.number_of_nodes()), device=device, dtype=torch.int64)

# Stratify on the binary label "is_fraud" so each split keeps its distribution.
X_train_and_val_idx, X_test_idx, y_train_and_val, y_test = train_test_split(
    idx, g.ndata['is_fraud'], test_size=0.2, random_state=42,
    stratify=g.ndata['is_fraud'])

# Split train and validation.
X_train_idx, X_val_idx, y_train, y_val = train_test_split(
    X_train_and_val_idx, y_train_and_val, test_size=0.2, random_state=42,
    stratify=y_train_and_val)

# Turn the index lists into boolean node masks.
train_mask = torch.zeros(g.number_of_nodes(), dtype=torch.bool)
train_mask[X_train_idx] = True
val_mask = torch.zeros(g.number_of_nodes(), dtype=torch.bool)
val_mask[X_val_idx] = True
test_mask = torch.zeros(g.number_of_nodes(), dtype=torch.bool)
test_mask[X_test_idx] = True

g.ndata['train_mask'] = train_mask
g.ndata['val_mask'] = val_mask
g.ndata['test_mask'] = test_mask

# Encode the edge type as a bit-flag edge feature 'he':
# shares_restaurant_in_one_month_with: 1, b"001"
# shares_restaurant_rating_with: 2, b"010"
# shares_user_with: 4, b"100"
g.edges['shares_restaurant_in_one_month_with'].data['he'] = torch.ones(
    g.number_of_edges('shares_restaurant_in_one_month_with'), dtype=torch.float32)
g.edges['shares_restaurant_rating_with'].data['he'] = torch.full(
    (g.number_of_edges('shares_restaurant_rating_with'),), 2, dtype=torch.float32)
g.edges['shares_user_with'].data['he'] = torch.full(
    (g.number_of_edges('shares_user_with'),), 4, dtype=torch.float32)

# heterogeneous graph to homogeneous graph, keep ndata and edata
hg = dgl.to_homogeneous(
    g, edata=['he'],
    ndata=['feat', 'label', 'train_mask', 'val_mask', 'test_mask'])
训练、测试模型!
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
# create GraphSAGE model
in_size = hg.ndata['feat'].shape[1]
out_size = 2
model = SAGE(in_size, 256, out_size).to(device)

# model training
print('Training...')
train(device, hg, model, X_train_idx, X_val_idx)

# test the model
print('Testing...')
acc = layerwise_infer(device, hg, X_test_idx, model, batch_size=4096)
print("Test Accuracy {:.4f}".format(acc.item()))

# Observed result:
# Test Accuracy 0.9996
# save model
torch.save(model.state_dict(), "fraud_d.model")

# load model (32 input features, 256 hidden units, 2 output classes)
device = torch.device('cpu')
model = SAGE(32, 256, 2).to(device)
model.load_state_dict(torch.load("fraud_d.model"))
# Inductive Learning: the test set consists of entirely new nodes and edges.
hg_train = hg.subgraph(torch.cat([X_train_idx, X_val_idx]))

# model training — inside the subgraph, train nodes come first, then val nodes
print('Training...')
train(device, hg_train, model,
      torch.arange(X_train_idx.shape[0]),
      torch.arange(X_train_idx.shape[0], hg_train.num_nodes()))

# test the model on a subgraph built only from unseen test nodes
print('Testing...')
hg_test = hg.subgraph(torch.cat([X_test_idx]))
sg_X_test_idx = torch.arange(hg_test.num_nodes())
acc = layerwise_infer(device, hg_test, sg_X_test_idx, model, batch_size=4096)
print("Test Accuracy {:.4f}".format(acc.item()))

# Observed result:
# Test Accuracy 0.9990
可以看到返回的结果其实还是很多的,这里我就不贴出来了;不过对 NebulaGraph 来说,这个子图结果大约在 10 ms 内就能返回。如果在 NebulaGraph Studio 或者 Explorer 中把结果渲染出来(可视化展示的 Query 可以去掉 WITH PROP,给浏览器省点内存),结果就更容易被人理解了:
# get SUBGRAPH of one node
import json

from torch import tensor
from dgl import DGLHeteroGraph, heterograph
from nebula3.gclient.net import ConnectionPool
from nebula3.Config import Config

config = Config()
config.max_connection_pool_size = 2
connection_pool = ConnectionPool()
connection_pool.init([('graphd', 9669)], config)

vertex_id = 2048
client = connection_pool.get_session('root', 'nebula')

# Fetch the vertex's 2-step neighborhood, with properties, as JSON.
resp = client.execute_json(
    "USE yelp;"
    f"GET SUBGRAPH WITH PROP 2 STEPS FROM {vertex_id} YIELD VERTICES AS nodes, EDGES AS relationships;")
r = json.loads(resp)
data = r.get('results', [{}])[0].get('data')
# create node and node data
node_id_map = {}  # key: vertex id in NebulaGraph, value: node id in dgl_graph
node_idx = 0
features = [[] for _ in range(32)] + [[]]  # f0..f31 plus the is_fraud label
for i in range(len(data)):
    for index, node in enumerate(data[i]['meta'][0]):
        nodeid = data[i]['meta'][0][index]['id']
        if nodeid not in node_id_map:
            node_id_map[nodeid] = node_idx
            node_idx += 1
            # BUG FIX: collect features only for the first occurrence of a
            # vertex; the original appended on every occurrence, so the
            # feature lists could outgrow the number of distinct nodes.
            for f in range(32):
                features[f].append(data[i]['row'][0][index][f"review.f{f}"])
            features[32].append(data[i]['row'][0][index]['review.is_fraud'])

# Edge semantics (Yelp fraud dataset):
# - R-U-R: two reviews posted by the same user -> shares_user_with
# - R-S-R: two reviews of the same restaurant with the same rating (1-5) -> shares_restaurant_rating_with
# - R-T-R: two reviews of the same restaurant in the same month -> shares_restaurant_in_one_month_with
rur_start, rur_end, rsr_start, rsr_end, rtr_start, rtr_end = [], [], [], [], [], []
for i in range(len(data)):
    for edge in data[i]['meta'][1]:
        edge = edge['id']
        if edge['name'] == 'shares_user_with':
            rur_start.append(node_id_map[edge['src']])
            rur_end.append(node_id_map[edge['dst']])
        elif edge['name'] == 'shares_restaurant_rating_with':
            rsr_start.append(node_id_map[edge['src']])
            rsr_end.append(node_id_map[edge['dst']])
        elif edge['name'] == 'shares_restaurant_in_one_month_with':
            rtr_start.append(node_id_map[edge['src']])
            rtr_end.append(node_id_map[edge['dst']])

# Only include edge types that actually occur in this subgraph.
data_dict = {}
if rur_start:
    data_dict[('review', 'shares_user_with', 'review')] = tensor(rur_start), tensor(rur_end)
if rsr_start:
    data_dict[('review', 'shares_restaurant_rating_with', 'review')] = tensor(rsr_start), tensor(rsr_end)
if rtr_start:
    data_dict[('review', 'shares_restaurant_in_one_month_with', 'review')] = tensor(rtr_start), tensor(rtr_end)

# construct a dgl_graph
dgl_graph: DGLHeteroGraph = heterograph(data_dict)

# BUG FIX: the collected features/label were never attached to the graph,
# but the next step reads dgl_graph.ndata[f"f{i}"] and ndata['label'].
# NOTE(review): assumes every distinct vertex appears in at least one edge
# so feature-list order matches graph node ids — confirm for edge cases.
for f in range(32):
    dgl_graph.ndata[f"f{f}"] = tensor(features[f])
dgl_graph.ndata['label'] = tensor(features[32])
import torch

# to homogeneous graph: stack the 32 scalar features into one matrix
features = [dgl_graph.ndata[f"f{i}"] for i in range(32)]
dgl_graph.ndata['feat'] = torch.stack(features, dim=1)

# Encode the edge type as a bit-flag edge feature 'he' (1, 2, 4).
dgl_graph.edges['shares_restaurant_in_one_month_with'].data['he'] = torch.ones(
    dgl_graph.number_of_edges('shares_restaurant_in_one_month_with'), dtype=torch.float32)
dgl_graph.edges['shares_restaurant_rating_with'].data['he'] = torch.full(
    (dgl_graph.number_of_edges('shares_restaurant_rating_with'),), 2, dtype=torch.float32)
dgl_graph.edges['shares_user_with'].data['he'] = torch.full(
    (dgl_graph.number_of_edges('shares_user_with'),), 4, dtype=torch.float32)

# heterogeneous graph to homogeneous graph, keep ndata and edata
import dgl
hg = dgl.to_homogeneous(dgl_graph, edata=['he'], ndata=['feat', 'label'])
最后,我们的推理接口就是:
1
2
3
4
5
def do_inference(device, graph, node_idx, model, batch_size):
    """Run layer-wise inference and return predictions for the given nodes."""
    model.eval()
    with torch.no_grad():
        # model.inference returns predictions for every node (kept on the
        # buffer device); slice out just the nodes we were asked about.
        all_pred = model.inference(graph, device, batch_size)
        return all_pred[node_idx]
def test_inference(device, graph, nid, model, batch_size):
    """Run full inference and return the accuracy on the given node ids."""
    model.eval()
    with torch.no_grad():
        pred = model.inference(graph, device, batch_size)  # pred in buffer_device
        pred = pred[nid]
        label = graph.ndata['label'][nid].to(pred.device)
        return MF.accuracy(pred, label)


# Evaluate on every node that was pulled from the NebulaGraph subgraph.
node_idx = torch.tensor(list(node_id_map.values()))
acc = test_inference(device, hg, node_idx, model, batch_size=4096)
print("Test Accuracy {:.4f}".format(acc.item()))
输出结果:
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
In [307]: def test_inference(device, graph, nid, model, batch_size):
     ...:     model.eval()
     ...:     with torch.no_grad():
     ...:         pred = model.inference(graph, device, batch_size)  # pred in buffer_device
     ...:         pred = pred[nid]
     ...:         label = graph.ndata['label'][nid].to(pred.device)
     ...:         return MF.accuracy(pred, label)
     ...:
     ...: node_idx = torch.tensor(list(node_id_map.values()))
     ...: acc = test_inference(device, hg, node_idx, model, batch_size=4096)
     ...: print("Test Accuracy {:.4f}".format(acc.item()))
     ...:
100%|████████████████████████████████████████████████| 1/1 [00:00<00:00, 130.31it/s]
100%|████████████████████████████████████████████████| 1/1 [00:00<00:00, 152.29it/s]
100%|████████████████████████████████████████████████| 1/1 [00:00<00:00, 173.55it/s]
Test Accuracy 0.9688