Incorrect handling of node ids by MetaPath2Vec #9928

Open
dgattiwsu opened this issue Jan 9, 2025 · 0 comments

🐛 Describe the bug

I am encountering an issue with the MetaPath2Vec model in which node IDs are handled incorrectly, which produces out-of-range node indices in the random walks. The model parameters also report an incorrect number of nodes for each type. The following code builds the heterogeneous graph, the MetaPath2Vec model, and the loader:

import torch
from torch_geometric.data import HeteroData
from torch_geometric.nn import MetaPath2Vec

# Example data
data = HeteroData()
data['patient'].x = torch.randn(3, 3)
data['patient'].y = torch.tensor([2])
data['patient'].y_index = torch.tensor([2])
data['disorder'].x = torch.randn(3, 1)
data['drug'].x = torch.randn(4, 2)
data['test'].x = torch.randn(4, 1)
data['symptom'].x = torch.randn(3, 1)
data['outcome'].x = torch.randn(2, 1)

data['patient', 'was_diagnosed', 'disorder'].edge_index = torch.tensor([[0, 0, 1, 2], [3, 4, 5, 3]])
data['patient', 'was_prescribed', 'drug'].edge_index = torch.tensor([[0, 1, 1, 2], [8, 6, 7, 9]])
data['patient', 'had', 'test'].edge_index = torch.tensor([[0, 1, 1, 2], [11, 10, 12, 13]])
data['patient', 'experienced', 'symptom'].edge_index = torch.tensor([[0, 0, 0, 1, 2], [14, 15, 16, 14, 15]])
data['patient', 'resulted_in', 'outcome'].edge_index = torch.tensor([[1, 2], [18, 17]])

data['disorder', 'associates_with', 'symptom'].edge_index = torch.tensor([[3, 3], [15, 16]])
data['drug', 'affects', 'symptom'].edge_index = torch.tensor([[6, 8, 8, 9, 9], [14, 15, 16, 15, 16]])
data['disorder', 'rev_was_diagnosed', 'patient'].edge_index = torch.tensor([[3, 4, 5, 3], [0, 0, 1, 2]])
data['drug', 'rev_was_prescribed', 'patient'].edge_index = torch.tensor([[8, 6, 7, 9], [0, 1, 1, 2]])
data['test', 'rev_had', 'patient'].edge_index = torch.tensor([[11, 10, 12, 13], [0, 1, 1, 2]])
data['symptom', 'rev_experienced', 'patient'].edge_index = torch.tensor([[14, 15, 16, 14, 15], [0, 0, 0, 1, 2]])
data['outcome', 'rev_resulted_in', 'patient'].edge_index = torch.tensor([[18, 17], [1, 2]])
data['symptom', 'rev_associates_with', 'disorder'].edge_index = torch.tensor([[15, 16], [3, 3]])
data['symptom', 'rev_affects', 'drug'].edge_index = torch.tensor([[14, 15, 16, 15, 16], [6, 8, 8, 9, 9]])

# Define node names for each type
symptom_names = ['Headache', 'Polyuria', 'Polydipsia']
drug_names = ['Ibuprofen', 'Ciprofloxacin', 'Metformin', 'Insulin']
disorder_names = ['Diabetes', 'Hypertension', 'Cancer']
test_names = ['Blood Test', 'Ultrasound', 'Urine Test', 'ECG']
outcome_names = ['Term Delivery', 'Preterm Delivery']
patient_ids = [f'Patient_{i}' for i in range(3)]

# Create a mapping from node IDs to names with global indices
global_counter = 0
node_name_mapping = {}

def add_to_global_mapping(names, node_type):
    global global_counter, node_name_mapping
    node_name_mapping[node_type] = {}
    for name in names:
        node_name_mapping[node_type][global_counter] = name
        global_counter += 1

add_to_global_mapping(patient_ids, 'patient')
add_to_global_mapping(disorder_names, 'disorder')
add_to_global_mapping(drug_names, 'drug')
add_to_global_mapping(test_names, 'test')
add_to_global_mapping(symptom_names, 'symptom')
add_to_global_mapping(outcome_names, 'outcome')

# Verify the total number of nodes
total_nodes = sum(len(mapping) for mapping in node_name_mapping.values())
print(f"Total number of nodes: {total_nodes}")

# Define the metapath
meta_path = [
    ('patient', 'experienced', 'symptom'),
    ('symptom', 'rev_affects', 'drug'),
    ('drug', 'rev_was_prescribed', 'patient') 
]

# Initialize the MetaPath2Vec model
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = MetaPath2Vec(data.edge_index_dict,
                     embedding_dim=16,
                     metapath=meta_path,
                     walk_length=6,
                     context_size=5,
                     walks_per_node=5,
                     num_negative_samples=1,
                     sparse=True).to(device)

# Check the MetaPath2Vec model parameters
print("MetaPath2Vec model parameters:")
print(model)

# Create the loader
loader = model.loader(batch_size=1, shuffle=True, num_workers=2)

# Check the random walks
for idx, (pos_rw, neg_rw) in enumerate(loader):
    if idx == 10: break
    print(idx, pos_rw.shape, neg_rw.shape)

# Print the first 15 positive and negative random walks
for i in range(15):
    print(f"Positive walk {i}: {pos_rw[i]}")
    print(f"Negative walk {i}: {neg_rw[i]}")

# Verify that the random walks are within the valid range of node indices
max_node_index = total_nodes - 1
for i in range(pos_rw.size(0)):
    if pos_rw[i].max().item() > max_node_index or neg_rw[i].max().item() > max_node_index:
        print(f"Error: Node index out of range in walk {i}")

# Check the edge index dictionary
print("Edge index dictionary:")
for key, value in data.edge_index_dict.items():
    print(f"{key}: {value}")

# Inspect the internal state of MetaPath2Vec
print("MetaPath2Vec internal state:")
print(f"Number of nodes: {model.num_nodes_dict}")

# Verify the node ID mapping
print("Node ID mapping:")
for node_type, mapping in node_name_mapping.items():
    print(f"{node_type}: {mapping}")

The following is the printout from the code:

Total number of nodes: 19
MetaPath2Vec model parameters:
MetaPath2Vec(30, 16)
0 torch.Size([15, 5]) torch.Size([15, 5])
1 torch.Size([15, 5]) torch.Size([15, 5])
2 torch.Size([15, 5]) torch.Size([15, 5])
Positive walk 0: tensor([11, 27,  6, 11, 27])
Negative walk 0: tensor([11, 19,  8, 10, 24])
Positive walk 1: tensor([11, 27,  6, 11, 27])
Negative walk 1: tensor([11, 13,  3, 12, 23])
Positive walk 2: tensor([11, 27,  6, 11, 27])
Negative walk 2: tensor([11, 15,  3, 12, 14])
Positive walk 3: tensor([11, 27,  6, 11, 27])
Negative walk 3: tensor([11, 17,  7, 12, 23])
Positive walk 4: tensor([11, 27,  6, 11, 27])
Negative walk 4: tensor([11, 21,  9, 10, 20])
Positive walk 5: tensor([27,  6, 11, 27,  6])
Negative walk 5: tensor([19,  8, 10, 24,  0])
Positive walk 6: tensor([27,  6, 11, 27,  6])
Negative walk 6: tensor([13,  3, 12, 23,  1])
Positive walk 7: tensor([27,  6, 11, 27,  6])
Negative walk 7: tensor([15,  3, 12, 14,  7])
Positive walk 8: tensor([27,  6, 11, 27,  6])
Negative walk 8: tensor([17,  7, 12, 23,  7])
Positive walk 9: tensor([27,  6, 11, 27,  6])
Negative walk 9: tensor([21,  9, 10, 20,  3])
Positive walk 10: tensor([ 6, 11, 27,  6, 11])
Negative walk 10: tensor([ 8, 10, 24,  0, 10])
Positive walk 11: tensor([ 6, 11, 27,  6, 11])
Negative walk 11: tensor([ 3, 12, 23,  1, 11])
Positive walk 12: tensor([ 6, 11, 27,  6, 11])
Negative walk 12: tensor([ 3, 12, 14,  7, 11])
Positive walk 13: tensor([ 6, 11, 27,  6, 11])
Negative walk 13: tensor([ 7, 12, 23,  7, 10])
Positive walk 14: tensor([ 6, 11, 27,  6, 11])
Negative walk 14: tensor([ 9, 10, 20,  3, 12])
Error: Node index out of range in walk 0
Error: Node index out of range in walk 1
Error: Node index out of range in walk 2
Error: Node index out of range in walk 3
Error: Node index out of range in walk 4
Error: Node index out of range in walk 5
Error: Node index out of range in walk 6
Error: Node index out of range in walk 7
Error: Node index out of range in walk 8
Error: Node index out of range in walk 9
Error: Node index out of range in walk 10
Error: Node index out of range in walk 11
Error: Node index out of range in walk 12
Error: Node index out of range in walk 13
Error: Node index out of range in walk 14
Edge index dictionary:
('patient', 'was_diagnosed', 'disorder'): tensor([[0, 0, 1, 2],
        [3, 4, 5, 3]])
('patient', 'was_prescribed', 'drug'): tensor([[0, 1, 1, 2],
        [8, 6, 7, 9]])
('patient', 'had', 'test'): tensor([[ 0,  1,  1,  2],
        [11, 10, 12, 13]])
('patient', 'experienced', 'symptom'): tensor([[ 0,  0,  0,  1,  2],
        [14, 15, 16, 14, 15]])
('patient', 'resulted_in', 'outcome'): tensor([[ 1,  2],
        [18, 17]])
('disorder', 'associates_with', 'symptom'): tensor([[ 3,  3],
        [15, 16]])
('drug', 'affects', 'symptom'): tensor([[ 6,  8,  8,  9,  9],
        [14, 15, 16, 15, 16]])
('disorder', 'rev_was_diagnosed', 'patient'): tensor([[3, 4, 5, 3],
        [0, 0, 1, 2]])
('drug', 'rev_was_prescribed', 'patient'): tensor([[8, 6, 7, 9],
        [0, 1, 1, 2]])
('test', 'rev_had', 'patient'): tensor([[11, 10, 12, 13],
        [ 0,  1,  1,  2]])
('symptom', 'rev_experienced', 'patient'): tensor([[14, 15, 16, 14, 15],
        [ 0,  0,  0,  1,  2]])
('outcome', 'rev_resulted_in', 'patient'): tensor([[18, 17],
        [ 1,  2]])
('symptom', 'rev_associates_with', 'disorder'): tensor([[15, 16],
        [ 3,  3]])
('symptom', 'rev_affects', 'drug'): tensor([[14, 15, 16, 15, 16],
        [ 6,  8,  8,  9,  9]])
MetaPath2Vec internal state:
Number of nodes: {'patient': 3, 'disorder': 6, 'drug': 10, 'test': 14, 'symptom': 17, 'outcome': 19}
Node ID mapping:
patient: {0: 'Patient_0', 1: 'Patient_1', 2: 'Patient_2'}
disorder: {3: 'Diabetes', 4: 'Hypertension', 5: 'Cancer'}
drug: {6: 'Ibuprofen', 7: 'Ciprofloxacin', 8: 'Metformin', 9: 'Insulin'}
test: {10: 'Blood Test', 11: 'Ultrasound', 12: 'Urine Test', 13: 'ECG'}
symptom: {14: 'Headache', 15: 'Polyuria', 16: 'Polydipsia'}
outcome: {17: 'Term Delivery', 18: 'Preterm Delivery'}
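The num_nodes_dict values and the MetaPath2Vec(30, 16) summary in the output above can be reproduced with a few lines of plain Python. This rests on two assumptions that are not confirmed from the PyG source here. The first is that, when no num_nodes_dict is passed, MetaPath2Vec takes each node type's count as the largest index that appears for that type in edge_index_dict, plus one. The second is that the node types of the metapath share a single embedding table in sorted (alphabetical) order. The helper names infer_num_nodes, embedding_offsets and decode are hypothetical, and this is a sketch rather than the library's implementation.

```python
# Hypothetical reconstruction for illustration; this is NOT PyG source code.
# Assumption 1: without num_nodes_dict, each type's count is (max index seen) + 1.
# Assumption 2: the metapath's node types share one embedding table, in sorted order.

# Forward edge lists from the issue: (src_type, relation, dst_type) -> (src_ids, dst_ids).
# The rev_* relations mirror these, so they do not change the per-type maxima.
edges = {
    ('patient', 'was_diagnosed', 'disorder'): ([0, 0, 1, 2], [3, 4, 5, 3]),
    ('patient', 'was_prescribed', 'drug'): ([0, 1, 1, 2], [8, 6, 7, 9]),
    ('patient', 'had', 'test'): ([0, 1, 1, 2], [11, 10, 12, 13]),
    ('patient', 'experienced', 'symptom'): ([0, 0, 0, 1, 2], [14, 15, 16, 14, 15]),
    ('patient', 'resulted_in', 'outcome'): ([1, 2], [18, 17]),
    ('disorder', 'associates_with', 'symptom'): ([3, 3], [15, 16]),
    ('drug', 'affects', 'symptom'): ([6, 8, 8, 9, 9], [14, 15, 16, 15, 16]),
}

# Rows of each node type's feature matrix x in the issue's HeteroData.
feature_rows = {'patient': 3, 'disorder': 3, 'drug': 4, 'test': 4, 'symptom': 3, 'outcome': 2}

metapath = [
    ('patient', 'experienced', 'symptom'),
    ('symptom', 'rev_affects', 'drug'),
    ('drug', 'rev_was_prescribed', 'patient'),
]


def infer_num_nodes(edges):
    """Per-type node count taken as the largest index used for that type, plus one."""
    num_nodes = {}
    for (src, _, dst), (src_ids, dst_ids) in edges.items():
        num_nodes[src] = max(num_nodes.get(src, 0), max(src_ids) + 1)
        num_nodes[dst] = max(num_nodes.get(dst, 0), max(dst_ids) + 1)
    return num_nodes


def embedding_offsets(num_nodes, metapath):
    """Assign each metapath node type a contiguous block of one shared index range."""
    types = sorted({rel[0] for rel in metapath} | {rel[-1] for rel in metapath})
    start, count = {}, 0
    for node_type in types:
        start[node_type] = count
        count += num_nodes[node_type]
    return start, count


def decode(walk, start, num_nodes):
    """Map shared walk indices back to (node_type, per-type index) pairs."""
    decoded = []
    for idx in walk:
        for node_type, offset in start.items():
            if offset <= idx < offset + num_nodes[node_type]:
                decoded.append((node_type, idx - offset))
                break
    return decoded


num_nodes = infer_num_nodes(edges)
start, count = embedding_offsets(num_nodes, metapath)
# node type -> (rows in x, inferred count), for types where the two disagree
mismatched = {t: (feature_rows[t], n) for t, n in num_nodes.items() if n != feature_rows[t]}

print(num_nodes)    # {'patient': 3, 'disorder': 6, 'drug': 10, 'test': 14, 'symptom': 17, 'outcome': 19}
print(mismatched)
print(start, count)  # {'drug': 0, 'patient': 10, 'symptom': 13} 30
print(decode([11, 27, 6, 11, 27], start, num_nodes))
```

Under these assumptions, the inferred counts differ from the number of rows in x for every type except patient, and the shared table has 10 + 3 + 17 = 30 entries. The positive walk [11, 27, 6, 11, 27] then decodes to patient 1, symptom 14 and drug 6, which are linked by the experienced, rev_affects and rev_was_prescribed edges in the issue's data.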

As the output shows, the node IDs in the random walks (up to 27) exceed the largest node index in the graph (18), even though the edge index dictionary contains the expected nodes and edges. The model also reports 30 embeddings (MetaPath2Vec(30, 16)) and a num_nodes_dict whose values do not match the number of nodes of each type. In addition, the negative random walks generated by MetaPath2Vec do not always start from the same node as the corresponding positive random walks (walks 5 to 14 above). My understanding is that each negative walk should start from the same node as its positive counterpart so that it provides a meaningful negative sample for contrastive learning.
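To make the pairing of positive and negative walks concrete, here is a plain-Python sketch of one way the walks could be built and windowed. It assumes, without confirming this against the PyG source, that each negative walk keeps the start node and fills the remaining positions with random nodes of the type the metapath expects at that step. It also assumes that both kinds of walk are then sliced into sliding windows of context_size nodes, grouped by window offset. The global index ranges (drug 0 to 9, patient 10 to 12, symptom 13 to 29) and the names build_walk, stack_windows, STEP_RANGES and CYCLE are assumptions made for this illustration only.

```python
import random

# Hypothetical sketch for illustration; this is NOT PyG source code.
# Assumed global index ranges per node type (they sum to the 30 embeddings reported):
# drug 0-9, patient 10-12, symptom 13-29.
STEP_RANGES = [range(13, 30), range(0, 10), range(10, 13)]  # symptom, drug, patient
WALK_LENGTH, CONTEXT_SIZE, WALKS_PER_NODE = 6, 5, 5  # values used in the issue
START = 11  # patient 1, the start node of the last batch in the issue's output
# Patient 1's only metapath chain: symptom 14 (27), drug 6 (6), back to patient 1 (11).
CYCLE = [27, 6, 11]


def build_walk(start, pick):
    """A walk of WALK_LENGTH + 1 nodes: the fixed start node, then pick(step) per step."""
    return [start] + [pick(step) for step in range(WALK_LENGTH)]


def stack_windows(walks):
    """Slide a CONTEXT_SIZE window along every walk and group the windows by offset."""
    num_windows = (WALK_LENGTH + 1) - CONTEXT_SIZE + 1
    return [walk[j:j + CONTEXT_SIZE] for j in range(num_windows) for walk in walks]


rng = random.Random(0)
pos_walks = [build_walk(START, lambda step: CYCLE[step % 3]) for _ in range(WALKS_PER_NODE)]
# Negative walks keep the start node but draw every later node at random
# from the node type that the metapath expects at that step.
neg_walks = [build_walk(START, lambda step: rng.choice(STEP_RANGES[step % 3]))
             for _ in range(WALKS_PER_NODE)]

pos_rw, neg_rw = stack_windows(pos_walks), stack_windows(neg_walks)
for i, (pos, neg) in enumerate(zip(pos_rw, neg_rw)):
    print(i, pos, neg, 'same start' if pos[0] == neg[0] else 'different start')
```

Under these assumptions there are 3 windows per walk and 15 windows in total. Windows 0 to 4 start at the patient node, windows 5 to 9 at a symptom node, and windows 10 to 14 at a drug node, for both positive and negative walks, so only the first group is guaranteed to share the exact start node. This is the same pattern seen in walks 5 to 14 of the output above: each negative window there starts at a different node of the same type as its positive counterpart.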

Versions

PyTorch version: 2.2.2
Is debug build: False
CUDA used to build PyTorch: None
ROCM used to build PyTorch: N/A

OS: macOS 12.7.6 (x86_64)
GCC version: Could not collect
Clang version: 18.1.8
CMake version: Could not collect
Libc version: N/A

Python version: 3.11.11 (main, Dec 11 2024, 10:28:39) [Clang 14.0.6 ] (64-bit runtime)
Python platform: macOS-10.16-x86_64-i386-64bit
Is CUDA available: False
CUDA runtime version: No CUDA
CUDA_MODULE_LOADING set to: N/A
GPU models and configuration: No CUDA
Nvidia driver version: No CUDA
cuDNN version: No CUDA
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True

CPU:
Intel(R) Core(TM) i5-5287U CPU @ 2.90GHz

Versions of relevant libraries:
[pip3] numpy==1.26.4
[pip3] pytorch-lightning==2.5.0.post0
[pip3] torch==2.2.2
[pip3] torch_cluster==1.6.3
[pip3] torch-geometric==2.6.1
[pip3] torch_scatter==2.1.2
[pip3] torch_sparse==0.6.18
[pip3] torch_spline_conv==1.2.2
[pip3] torchaudio==2.2.2
[pip3] torchdata==0.10.1
[pip3] torchmetrics==1.6.1
[pip3] torchsummary==1.5.1
[pip3] torchvision==0.17.2
[conda] numpy 1.26.4 pypi_0 pypi
[conda] pytorch-lightning 2.5.0.post0 pypi_0 pypi
[conda] torch 2.2.2 pypi_0 pypi
[conda] torch-cluster 1.6.3 pypi_0 pypi
[conda] torch-geometric 2.6.1 pypi_0 pypi
[conda] torch-scatter 2.1.2 pypi_0 pypi
[conda] torch-sparse 0.6.18 pypi_0 pypi
[conda] torch-spline-conv 1.2.2 pypi_0 pypi
[conda] torchaudio 2.2.2 pypi_0 pypi
[conda] torchdata 0.10.1 pypi_0 pypi
[conda] torchmetrics 1.6.1 pypi_0 pypi
[conda] torchsummary 1.5.1 pypi_0 pypi
[conda] torchvision 0.17.2 pypi_0 pypi

dgattiwsu added the bug label on Jan 9, 2025