Skip to content

Commit

Permalink
Solved problem with pyg loaders
Browse files Browse the repository at this point in the history
  • Loading branch information
gbg141 committed May 20, 2024
1 parent fd0c488 commit ac85988
Show file tree
Hide file tree
Showing 6 changed files with 6 additions and 44 deletions.
2 changes: 1 addition & 1 deletion configs/dataset/ZINC.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ _target_: topobenchmarkx.io.load.loaders.GraphLoader
# USE python train.py dataset.transforms.one_hot_node_degree_features.degrees_fields=x to run this config

defaults:
#- transforms/data_manipulations: node_feat_to_float
- transforms/data_manipulations: node_degrees
- transforms/[email protected]_hot_node_degree_features: one_hot_node_degree_features
- transforms: ${get_default_transform:graph,${model}}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,5 +4,5 @@ transform_type: "data manipulation"

degrees_fields: "node_degrees"
features_fields: "x"
max_degrees: ${dataset.parameters.max_node_degree}
max_degree: ${dataset.parameters.max_node_degree}

2 changes: 1 addition & 1 deletion env.sh
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ pip install git+https://github.com/pyt-team/TopoEmbedX.git

CUDA="cu121" # if available, select the CUDA version suitable for your system
# e.g. cpu, cu102, cu111, cu113, cu115
pip install torch_geometric==2.4.0
pip install fsspec==2024.5.0
pip install torch==2.0.1 --extra-index-url https://download.pytorch.org/whl/${CUDA}
pip install torch-scatter torch-sparse -f https://data.pyg.org/whl/torch-2.0.1+${CUDA}.html
pip install torch-cluster -f https://data.pyg.org/whl/torch-2.0.0+${CUDA}.html
Expand Down
41 changes: 1 addition & 40 deletions topobenchmarkx/io/load/preprocessor.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,43 +145,4 @@ def process(self) -> None:
self._data_list = None # Reset cache.

assert isinstance(self._data, torch_geometric.data.Data)
self.save(self.data_list, self.processed_paths[0])

def _process(self):
    """Validate cached pre-processing state, then (re)process if needed.

    Mirrors :meth:`torch_geometric.data.Dataset._process`: warns when the
    ``pre_transform``/``pre_filter`` signatures stored on disk differ from
    the current ones, skips work when the processed files already exist
    (unless ``self.force_reload`` is set), and persists the new signatures
    after calling :meth:`process`.
    """
    f = osp.join(self.processed_dir, 'pre_transform.pt')
    if osp.exists(f) and torch.load(f) != _repr(self.pre_transform):
        warnings.warn(
            f"The `pre_transform` argument differs from the one used in "
            f"the pre-processed version of this dataset. If you want to "
            f"make use of another pre-processing technique, make sure to "
            f"delete '{self.processed_dir}' first")

    f = osp.join(self.processed_dir, 'pre_filter.pt')
    if osp.exists(f) and torch.load(f) != _repr(self.pre_filter):
        # Bug fix: this message was a plain (non-f) string, so
        # '{self.processed_dir}' was printed literally instead of being
        # interpolated; also fixes the "pre-fitering" typo.
        warnings.warn(
            f"The `pre_filter` argument differs from the one used in "
            f"the pre-processed version of this dataset. If you want to "
            f"make use of another pre-filtering technique, make sure to "
            f"delete '{self.processed_dir}' first")

    # Nothing to do when results are already on disk and no reload forced.
    if not self.force_reload and files_exist(self.processed_paths):
        return

    if self.log and 'pytest' not in sys.modules:
        print('Processing...', file=sys.stderr)

    makedirs(self.processed_dir)
    self.process()

    # Persist the signatures of the transforms actually used, so future
    # runs can detect a mismatch via the warnings above.
    path = osp.join(self.processed_dir, 'pre_transform.pt')
    torch.save(_repr(self.pre_transform), path)
    path = osp.join(self.processed_dir, 'pre_filter.pt')
    torch.save(_repr(self.pre_filter), path)

    if self.log and 'pytest' not in sys.modules:
        print('Done!', file=sys.stderr)

def _repr(obj: Any) -> str:
if obj is None:
return 'None'
return re.sub('(<.*?)\\s.*(>)', r'\1\2', str(obj))
self.save(self.data_list, self.processed_paths[0])
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ def __init__(
self,
max_degree: int,
cat: bool = False,
**kwargs,
) -> None:
self.max_degree = max_degree
self.cat = cat
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ def __init__(self, **kwargs):
self.type = "one_hot_degree_features"
self.deg_field = kwargs["degrees_fields"]
self.features_fields = kwargs["features_fields"]
self.transform = OneHotDegree(max_degree=kwargs["max_degrees"])
self.transform = OneHotDegree(**kwargs)

def __repr__(self) -> str:
return f"{self.__class__.__name__}(type={self.type!r}, degrees_field={self.deg_field!r}, features_field={self.features_fields!r})"
Expand Down

0 comments on commit ac85988

Please sign in to comment.