Skip to content

Commit

Permalink
Allow avg_degree in FakeDataset to be float instead of int (pyg-t…
Browse files Browse the repository at this point in the history
…eam#8404)

I needed to use the `FakeDataset` for checking something in a randomly
generated graph that was identical in "size" to a real graph
(cora/citeseer), in which the `avg_degree` is usually not an integer.

So, I decided to extend the `FakeDataset` and `FakeHeteroDataset` to
float.
I've tried to make sure other things don't mess up because of this by
using `int(avg_degree)` where I felt it was necessary.

Also, I noticed a small thing: inside `FakeDataset` L62 and L63 are
switched, and inside `FakeHeteroDataset` L174 and L175 are switched,
whereas it should be checking whether `avg_degree` is < 1 before using
it to find `avg_num_nodes`. It wont make a difference anyway, the only
time that would matter is when `avg_degree` changed after using it's
value, which will happen only when it is < 1, in which case the previous
`max(avg_num_nodes, avg_degree)` should come out to be the former, which
_should_ be >= 1.
There are no checks for any of this though, if needed, I would be glad
to add those checks.

---------

Co-authored-by: rusty1s <[email protected]>
  • Loading branch information
plutonium-239 and rusty1s authored Nov 20, 2023
1 parent 362fe90 commit c7483ac
Show file tree
Hide file tree
Showing 2 changed files with 21 additions and 14 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

### Added

- Added support for floating-point average degree numbers in `FakeDataset` and `FakeHeteroDataset` ([#8404](https://github.com/pyg-team/pytorch_geometric/pull/8404))
- Added support for device conversions of `InMemoryDataset` ([#8402] (https://github.com/pyg-team/pytorch_geometric/pull/8402))
- Added support for edge-level temporal sampling in `NeighborLoader` and `LinkNeighborLoader` ([#8372] (https://github.com/pyg-team/pytorch_geometric/pull/8372))
- Added support for `torch.compile` in `ModuleDict` and `ParameterDict` ([#8363](https://github.com/pyg-team/pytorch_geometric/pull/8363))
Expand Down
34 changes: 20 additions & 14 deletions torch_geometric/datasets/fake.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from typing import Callable, Optional

import torch
from torch import Tensor

from torch_geometric.data import Data, HeteroData, InMemoryDataset
from torch_geometric.utils import coalesce, remove_self_loops, to_undirected
Expand All @@ -17,8 +18,8 @@ class FakeDataset(InMemoryDataset):
num_graphs (int, optional): The number of graphs. (default: :obj:`1`)
avg_num_nodes (int, optional): The average number of nodes in a graph.
(default: :obj:`1000`)
avg_degree (int, optional): The average degree per node.
(default: :obj:`10`)
avg_degree (float, optional): The average degree per node.
(default: :obj:`10.0`)
num_channels (int, optional): The number of node features.
(default: :obj:`64`)
edge_dim (int, optional): The number of edge features.
Expand All @@ -43,7 +44,7 @@ def __init__(
self,
num_graphs: int = 1,
avg_num_nodes: int = 1000,
avg_degree: int = 10,
avg_degree: float = 10.0,
num_channels: int = 64,
edge_dim: int = 0,
num_classes: int = 10,
Expand All @@ -59,7 +60,7 @@ def __init__(
task = 'graph' if num_graphs > 1 else 'node'
assert task in ['node', 'graph']

self.avg_num_nodes = max(avg_num_nodes, avg_degree)
self.avg_num_nodes = max(avg_num_nodes, int(avg_degree))
self.avg_degree = max(avg_degree, 1)
self.num_channels = num_channels
self.edge_dim = edge_dim
Expand Down Expand Up @@ -115,8 +116,8 @@ class FakeHeteroDataset(InMemoryDataset):
(default: :obj:`6`)
avg_num_nodes (int, optional): The average number of nodes in a graph.
(default: :obj:`1000`)
avg_degree (int, optional): The average degree per node.
(default: :obj:`10`)
avg_degree (float, optional): The average degree per node.
(default: :obj:`10.0`)
avg_num_channels (int, optional): The average number of node features.
(default: :obj:`64`)
edge_dim (int, optional): The number of edge features.
Expand All @@ -141,7 +142,7 @@ def __init__(
num_node_types: int = 3,
num_edge_types: int = 6,
avg_num_nodes: int = 1000,
avg_degree: int = 10,
avg_degree: float = 10.0,
avg_num_channels: int = 64,
edge_dim: int = 0,
num_classes: int = 10,
Expand Down Expand Up @@ -171,7 +172,7 @@ def __init__(
count[edge_type] += 1
self.edge_types.append((edge_type[0], rel, edge_type[1]))

self.avg_num_nodes = max(avg_num_nodes, avg_degree)
self.avg_num_nodes = max(avg_num_nodes, int(avg_degree))
self.avg_degree = max(avg_degree, 1)
self.num_channels = [
get_num_channels(avg_num_channels) for _ in self.node_types
Expand Down Expand Up @@ -231,8 +232,8 @@ def generate_data(self) -> HeteroData:
###############################################################################


def get_num_nodes(avg_num_nodes: int, avg_degree: int) -> int:
min_num_nodes = max(3 * avg_num_nodes // 4, avg_degree)
def get_num_nodes(avg_num_nodes: int, avg_degree: float) -> int:
min_num_nodes = max(3 * avg_num_nodes // 4, int(avg_degree))
max_num_nodes = 5 * avg_num_nodes // 4
return random.randint(min_num_nodes, max_num_nodes)

Expand All @@ -243,10 +244,15 @@ def get_num_channels(num_channels) -> int:
return random.randint(min_num_channels, max_num_channels)


def get_edge_index(num_src_nodes: int, num_dst_nodes: int, avg_degree: int,
is_undirected: bool = False,
remove_loops: bool = False) -> torch.Tensor:
num_edges = num_src_nodes * avg_degree
def get_edge_index(
num_src_nodes: int,
num_dst_nodes: int,
avg_degree: float,
is_undirected: bool = False,
remove_loops: bool = False,
) -> Tensor:

num_edges = int(num_src_nodes * avg_degree)
row = torch.randint(num_src_nodes, (num_edges, ), dtype=torch.int64)
col = torch.randint(num_dst_nodes, (num_edges, ), dtype=torch.int64)
edge_index = torch.stack([row, col], dim=0)
Expand Down

0 comments on commit c7483ac

Please sign in to comment.