Skip to content

Commit

Permalink
Add script, notebook, minor folders structure debugging when calling …
Browse files Browse the repository at this point in the history
…some functions
  • Loading branch information
AdrienC21 committed Sep 18, 2023
1 parent 7c5680a commit da710be
Show file tree
Hide file tree
Showing 12 changed files with 1,591 additions and 12 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/apply_fixes.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
packages = site.getsitepackages()
site_packages = None
for p in packages:
if "site-packages" in p:
if ("dist-packages" in p) or ("site-packages" in p):
site_packages = p
break

Expand Down
7 changes: 7 additions & 0 deletions CHANGELOG.rst
Original file line number Diff line number Diff line change
@@ -1,6 +1,13 @@
Changelog
==================================

0.3.1 (2023/09/18)
--------------------

- Debug folder structure when calling some functions

- Add example script and notebook

0.3.0 (2023/09/18)
--------------------

Expand Down
2 changes: 1 addition & 1 deletion ccsd/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,6 @@

__author__ = "Adrien Carrel"
__email__ = "[email protected]"
__version__ = "0.3.0"
__version__ = "0.3.1"

__all__ = ["src", "data"]
2 changes: 1 addition & 1 deletion ccsd/diffusion.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,7 @@ def run(self) -> None:
ValueError: raise and error the experiment type is not one of [train, sample].
"""
# Get the configuration and the general configuration
config = get_config(self.args.config, self.args.seed)
config = get_config(self.args.config, self.args.seed, self.args.folder)
general_config = get_general_config()

# Print the initial message
Expand Down
5 changes: 3 additions & 2 deletions ccsd/src/parsers/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,17 +12,18 @@
from easydict import EasyDict


def get_config(config: str, seed: int) -> EasyDict:
def get_config(config: str, seed: int, folder: str = "./") -> EasyDict:
"""Load the config file.
Args:
config (str): name of the config file.
seed (int): random seed (to be added to the config object).
folder (str, optional): folder where the config folder is located. Defaults to "./".
Returns:
EasyDict: configuration object.
"""
config_dir = os.path.join("config", f"{config}.yaml")
config_dir = os.path.join(folder, "config", f"{config}.yaml")
config = EasyDict(yaml.load(open(config_dir, "r"), Loader=yaml.FullLoader))
config.seed = seed

Expand Down
4 changes: 3 additions & 1 deletion ccsd/src/sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -1156,7 +1156,9 @@ def sample(self) -> None:
logger.log(f"GEN SEED: {self.config.sample.seed}")
load_seed(self.config.sample.seed)

train_smiles, test_smiles = load_smiles(self.configt.data.data)
train_smiles, test_smiles = load_smiles(
self.configt.data.data, self.config.folder
)
train_smiles, test_smiles = canonicalize_smiles(
train_smiles
), canonicalize_smiles(test_smiles)
Expand Down
9 changes: 6 additions & 3 deletions ccsd/src/utils/mol_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,11 +105,14 @@ def canonicalize_smiles(smiles: List[str]) -> List[str]:
return [Chem.MolToSmiles(Chem.MolFromSmiles(smi)) for smi in smiles]


def load_smiles(dataset: str = "QM9") -> Tuple[List[str], List[str]]:
def load_smiles(
dataset: str = "QM9", folder: str = "./"
) -> Tuple[List[str], List[str]]:
"""Loads SMILES strings from a dataset and return train and test splits.
Args:
dataset (str, optional): smiles dataset to load. Defaults to "QM9".
folder (str, optional): folder where the data folder is located. Defaults to "./".
Raises:
ValueError: raise an error if dataset is not supported
Expand All @@ -124,9 +127,9 @@ def load_smiles(dataset: str = "QM9") -> Tuple[List[str], List[str]]:
else:
raise ValueError(f"Wrong dataset name {dataset} in load_smiles")

df = pd.read_csv(os.path.join("data", f"{dataset.lower()}.csv"))
df = pd.read_csv(os.path.join(folder, "data", f"{dataset.lower()}.csv"))

with open(os.path.join("data", f"valid_idx_{dataset.lower()}.json")) as f:
with open(os.path.join(folder, "data", f"valid_idx_{dataset.lower()}.json")) as f:
test_idx = json.load(f)

if dataset == "QM9": # special case for QM9
Expand Down
2 changes: 1 addition & 1 deletion config/general_config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -5,5 +5,5 @@ project_name: "CCSD" # name of the project in wandb
entity: "a-carrel" # name of the entity in wandb
timezone: "Europe/London" # timezone to name the output files
print_initial: True # print an initial message with logo and current experiment
plotly_fig: False # if True, create plotly figures (rotating 3D plots, diffusion animation, etc)
plotly_fig: True # if True, create plotly figures (rotating 3D plots, diffusion animation, etc)
engine: "kaleido" # engine for the plotly plots. Windows users should use "kaleido" instead of "orca"
2 changes: 1 addition & 1 deletion config/sample_qm9_CC demonstration.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ sampler:
n_steps: 1

sample:
# divide_batch: 4 # optional, only if RAM issue occurs
divide_batch: 4 # optional, only if RAM issue occurs
n_samples: 16 # param only for mol datasets
cc_nb_eval: 1000 # param only for cc datasets
use_ema: False
Expand Down
2 changes: 1 addition & 1 deletion main.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ def main(args: argparse.Namespace) -> None:
"""

# Get the configuration and the general configuration
config = get_config(args.config, args.seed)
config = get_config(args.config, args.seed, args.folder)
general_config = get_general_config()

# Print the initial message
Expand Down
Loading

0 comments on commit da710be

Please sign in to comment.