Skip to content

Commit

Permalink
Polishing by linters
Browse files Browse the repository at this point in the history
  • Loading branch information
alinzh committed Dec 14, 2024
1 parent dc3818c commit 4214593
Show file tree
Hide file tree
Showing 6 changed files with 275 additions and 217 deletions.
3 changes: 1 addition & 2 deletions mpds_client/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@

import sys

from .retrieve_MPDS import MPDSDataTypes, APIError, MPDSDataRetrieval
Expand All @@ -7,4 +6,4 @@

MIN_PY_VER = (3, 5)

assert sys.version_info >= MIN_PY_VER, "Python version must be >= {}".format(MIN_PY_VER)
assert sys.version_info >= MIN_PY_VER, "Python version must be >= {}".format(MIN_PY_VER)
30 changes: 17 additions & 13 deletions mpds_client/errors.py
Original file line number Diff line number Diff line change
@@ -1,20 +1,20 @@

class APIError(Exception):
"""
Simple error handling
"""

codes = {
204: 'No Results',
400: 'Bad Request',
401: 'Unauthorized',
402: 'Unauthorized (Payment Required)',
403: 'Forbidden',
404: 'Not Found',
413: 'Too Much Data Given',
429: 'Too Many Requests (Rate Limiting)',
500: 'Internal Server Error',
501: 'Not Implemented',
503: 'Service Unavailable'
204: "No Results",
400: "Bad Request",
401: "Unauthorized",
402: "Unauthorized (Payment Required)",
403: "Forbidden",
404: "Not Found",
413: "Too Much Data Given",
429: "Too Many Requests (Rate Limiting)",
500: "Internal Server Error",
501: "Not Implemented",
503: "Service Unavailable",
}

def __init__(self, msg, code=0):
Expand All @@ -23,4 +23,8 @@ def __init__(self, msg, code=0):
self.code = code

def __str__(self):
return "HTTP error code %s: %s (%s)" % (self.code, self.codes.get(self.code, 'Communication Error'), self.msg)
return "HTTP error code %s: %s (%s)" % (
self.code,
self.codes.get(self.code, "Communication Error"),
self.msg,
)
80 changes: 51 additions & 29 deletions mpds_client/export_MPDS.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
Utilities for convenient
exporting the MPDS data
"""

import os
import random
import ujson as json
Expand All @@ -10,13 +11,12 @@


class MPDSExport(object):

export_dir = "/tmp/_MPDS"

human_names = {
'length': 'Bond lengths, A',
'occurrence': 'Counts',
'bandgap': 'Band gap, eV'
"length": "Bond lengths, A",
"occurrence": "Counts",
"bandgap": "Band gap, eV",
}

@classmethod
Expand All @@ -32,7 +32,11 @@ def _gen_basename(cls):
basename = []
random.seed()
for _ in range(12):
basename.append(random.choice("abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789"))
basename.append(
random.choice(
"abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789"
)
)
return "".join(basename)

@classmethod
Expand All @@ -42,7 +46,7 @@ def _get_title(cls, term: Union[str, int]):
return cls.human_names.get(term, term.capitalize())

@classmethod
def save_plot(cls, data, columns, plottype, fmt='json', **kwargs):
def save_plot(cls, data, columns, plottype, fmt="json", **kwargs):
"""
Exports the data in the following formats for plotting:
Expand All @@ -59,51 +63,67 @@ def save_plot(cls, data, columns, plottype, fmt='json', **kwargs):
if not all(col in data.columns for col in columns):
raise ValueError("Some specified columns are not in the DataFrame")

if fmt == 'csv':
if fmt == "csv":
# export to CSV
fmt_export = os.path.join(cls.export_dir, cls._gen_basename() + ".csv")
with open(fmt_export, "w") as f_export:
f_export.write(",".join(columns) + "\n")
for row in data.select(columns).iter_rows():
f_export.write(",".join(map(str, row)) + "\n")

elif fmt == 'json':
elif fmt == "json":
# export to JSON
fmt_export = os.path.join(cls.export_dir, cls._gen_basename() + ".json")
with open(fmt_export, "w") as f_export:
if plottype == 'bar':
if plottype == "bar":
# bar plot payload
plot["payload"] = {
"x": [data[columns[0]].to_list()],
"y": data[columns[1]].to_list(),
"xtitle": cls._get_title(columns[0]),
"ytitle": cls._get_title(columns[1])
"ytitle": cls._get_title(columns[1]),
}

elif plottype == 'plot3d':
elif plottype == "plot3d":
# 3D plot payload
plot["payload"] = {
"points": {"x": [], "y": [], "z": [], "labels": []},
"meshes": [],
"xtitle": cls._get_title(columns[0]),
"ytitle": cls._get_title(columns[1]),
"ztitle": cls._get_title(columns[2])
"ztitle": cls._get_title(columns[2]),
}
recent_mesh = None
for row in data.iter_rows():
plot["payload"]["points"]["x"].append(row[data.columns.index(columns[0])])
plot["payload"]["points"]["y"].append(row[data.columns.index(columns[1])])
plot["payload"]["points"]["z"].append(row[data.columns.index(columns[2])])
plot["payload"]["points"]["labels"].append(row[data.columns.index(columns[3])])
plot["payload"]["points"]["x"].append(
row[data.columns.index(columns[0])]
)
plot["payload"]["points"]["y"].append(
row[data.columns.index(columns[1])]
)
plot["payload"]["points"]["z"].append(
row[data.columns.index(columns[2])]
)
plot["payload"]["points"]["labels"].append(
row[data.columns.index(columns[3])]
)

if row[data.columns.index(columns[4])] != recent_mesh:
plot["payload"]["meshes"].append({"x": [], "y": [], "z": []})
plot["payload"]["meshes"].append(
{"x": [], "y": [], "z": []}
)
recent_mesh = row[data.columns.index(columns[4])]

if plot["payload"]["meshes"]:
plot["payload"]["meshes"][-1]["x"].append(row[data.columns.index(columns[0])])
plot["payload"]["meshes"][-1]["y"].append(row[data.columns.index(columns[1])])
plot["payload"]["meshes"][-1]["z"].append(row[data.columns.index(columns[2])])
plot["payload"]["meshes"][-1]["x"].append(
row[data.columns.index(columns[0])]
)
plot["payload"]["meshes"][-1]["y"].append(
row[data.columns.index(columns[1])]
)
plot["payload"]["meshes"][-1]["z"].append(
row[data.columns.index(columns[2])]
)
else:
raise RuntimeError(f"Error: {plottype} is an unknown plot type")

Expand All @@ -116,8 +136,7 @@ def save_plot(cls, data, columns, plottype, fmt='json', **kwargs):
else:
raise ValueError(f"Unsupported format: {fmt}")

return fmt_export

return fmt_export

@classmethod
def save_df(cls, frame, tag):
Expand All @@ -126,22 +145,25 @@ def save_df(cls, frame, tag):
raise TypeError("Input frame must be a Polars DataFrame")

if tag is None:
tag = '-'
tag = "-"

pkl_export = os.path.join(cls.export_dir, f'df{tag}_{cls._gen_basename()}.parquet')
frame.write_parquet(pkl_export) # cos pickle is not supported in polars
pkl_export = os.path.join(
cls.export_dir, "df" + str(tag) + "_" + cls._gen_basename() + ".pkl"
)
frame.write_parquet(pkl_export) # cos pickle is not supported in polars
return pkl_export

@classmethod
def save_model(cls, skmodel, tag):

import _pickle as cPickle

cls._verify_export_dir()
if tag is None:
tag = '-'
tag = "-"

pkl_export = os.path.join(cls.export_dir, 'ml' + str(tag) + '_' + cls._gen_basename() + ".pkl")
with open(pkl_export, 'wb') as f:
pkl_export = os.path.join(
cls.export_dir, "ml" + str(tag) + "_" + cls._gen_basename() + ".pkl"
)
with open(pkl_export, "wb") as f:
cPickle.dump(skmodel, f)
return pkl_export
Loading

0 comments on commit 4214593

Please sign in to comment.