Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Model] Add HGAT #213

Open
wants to merge 3 commits into
base: main
Choose a base branch
from
Open

[Model] Add HGAT #213

wants to merge 3 commits into from

Conversation

martinjingyu
Copy link

Add a new model: HGAT
corresponding with two conv layer, HINConv and HGATConv
and a new method to convert the Heterogeneous graph dict to homogeneous graph matrix: heter_homo_mutual_convert

@gyzhou2000 gyzhou2000 changed the title Add HGAT [Model] Add HGAT Jul 2, 2024
gammagl/utils/homo_heter_mutual_convert.py Outdated Show resolved Hide resolved
Comment on lines 51 to 62
# out_dict={}
# for node_type, _ in x_dict.items():
# out_dict[node_type]=[]
# for edge_type, edge_index in edge_index_dict.items():
# src_type, _, dst_type = edge_type
# src = edge_index[0,:]
# dst = edge_index[1,:]
# message = unsorted_segment_sum(tlx.gather(x_dict[src_type],src),dst,num_nodes_dict[dst_type])
# out_dict[dst_type].append(message)
# for node_type, outs in out_dict.items():
# aggr_out = tlx.reduce_sum(outs,axis=0)
# out_dict[node_type]=tlx.relu(aggr_out)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

多余注释也删除

gammagl/layers/conv/hin_conv.py Outdated Show resolved Hide resolved
gammagl/layers/conv/hgat_conv.py Outdated Show resolved Hide resolved
gammagl/layers/conv/hgat_conv.py Outdated Show resolved Hide resolved
examples/hgat/hgat_trainer.py Outdated Show resolved Hide resolved
examples/hgat/hgat_trainer.py Outdated Show resolved Hide resolved
examples/hgat/hgat_trainer.py Outdated Show resolved Hide resolved
Comment on lines +40 to +79
data = HeteroGraph()

node_types = ['documents', 'topics', 'words']
for i, node_type in enumerate(node_types):
x = sp.load_npz(osp.join(self.raw_dir, f'features_{i}.npz'))
data[node_type].x = tlx.convert_to_tensor(x.todense(), dtype=tlx.float32)

y = np.load(osp.join(self.raw_dir, 'labels.npy'))
y = np.argmax(y,axis=1)
data['documents'].y = tlx.convert_to_tensor(y, dtype=tlx.int64)

split = np.load(osp.join(self.raw_dir, 'train_val_test_idx.npz'))
for name in ['train', 'val', 'test']:
idx = split[f'{name}_idx']
mask = np.zeros(data['documents'].num_nodes, dtype=np.bool_)
mask[idx] = True
data['documents'][f'{name}_mask'] = tlx.convert_to_tensor(mask, dtype=tlx.bool)


s = {}
N_m = data['documents'].num_nodes
N_d = data['topics'].num_nodes
N_a = data['words'].num_nodes
s['documents'] = (0, N_m)
s['topics'] = (N_m, N_m + N_d)
s['words'] = (N_m + N_d, N_m + N_d + N_a)

A = sp.load_npz(osp.join(self.raw_dir, 'adjM.npz')).tocsr()
for src, dst in product(node_types, node_types):
A_sub = A[s[src][0]:s[src][1], s[dst][0]:s[dst][1]].tocoo()
if A_sub.nnz > 0:
row = tlx.convert_to_tensor(A_sub.row, dtype=tlx.int64)
col = tlx.convert_to_tensor(A_sub.col, dtype=tlx.int64)
data[src, dst].edge_index = tlx.stack([row, col], axis=0)
print(src+"____"+dst)

if self.pre_transform is not None:
data = self.pre_transform(data)

self.save_data(self.collate([data]), self.processed_paths[0])
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

修改成从作者提供的url下载数据集,然后处理成heterograph

from gammagl.data import (HeteroGraph, InMemoryDataset, download_url,
extract_zip)

class OHSUMED(InMemoryDataset):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

写一个test文件,参考gammagl/datasets/路径下的其他测试文件

examples/hgat/hgat_trainer.py Outdated Show resolved Hide resolved
gammagl/models/hgat.py Outdated Show resolved Hide resolved
gammagl/models/hgat.py Outdated Show resolved Hide resolved
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants