make_predictions.py
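"""Predict the targets m, sigma, epsilon_k, epsilonAB and KAB for a list of SMILES.

Downloads a trained model artifact from Weights & Biases, computes Morgan
fingerprints for the input SMILES, and writes the predictions (one row per
molecule) to ``predictions.csv`` or a user-supplied path.
"""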
import logging
from pathlib import Path
import pandas as pd
from joblib import load
from mlsaft.extras.utils.molecular_fingerprints import compute_morgan_fingerprints
from wandb.apis import Api


def main(smiles_list_path: str, save_path: str = "predictions.csv"):
    logging.basicConfig(
        level=logging.INFO,
        format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
    )
    logger = logging.getLogger(__name__)

    # Download the trained model artifact from wandb
    logger.info("Downloading model from wandb")
    api = Api()
    run = api.run("ceb-sre/dl4thermo/2eftwbx2")
    artifacts = run.logged_artifacts()
    model_path = None
    for artifact in artifacts:
        if artifact.type == "model":
            model_path = artifact.download()
    if model_path is None:
        raise ValueError("Model not found")
    model_path = Path(model_path) / "model.pkl"

    # Load model
    logger.info("Loading model")
    with open(model_path, "rb") as f:
        model = load(f)

    # Load SMILES, one per line, dropping trailing newlines and blank lines
    logger.info("Loading smiles")
    with open(smiles_list_path, "r") as f:
        smiles_list = [line.strip() for line in f if line.strip()]

    # Create fingerprints
    logger.info("Computing fingerprints")
    fps = compute_morgan_fingerprints(smiles_list)

    # Make predictions
    target_columns = ["m", "sigma", "epsilon_k", "epsilonAB", "KAB"]
    logger.info("Making predictions")
    preds = model.predict(fps)
    df = pd.DataFrame(preds, columns=target_columns)

    # Save predictions to csv
    logger.info("Saving predictions")
    df["smiles"] = smiles_list
    df.to_csv(save_path, index=False)


if __name__ == "__main__":
    # Use argparse for the command-line interface
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument("smiles_list_path", type=str)
    parser.add_argument("--save_path", type=str, default="predictions.csv")
    args = parser.parse_args()
    main(args.smiles_list_path, args.save_path)
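
# Example usage (a sketch; the input SMILES and output filename below are
# illustrative, not taken from the repository):
#
#   $ cat smiles.txt
#   CCO
#   c1ccccc1
#   $ python make_predictions.py smiles.txt --save_path saft_predictions.csv
#
# The resulting CSV contains one row per input SMILES with the predicted
# m, sigma, epsilon_k, epsilonAB and KAB values plus a smiles column.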