diff --git a/mpds_client/__init__.py b/mpds_client/__init__.py
index aa26463..00f8de4 100755
--- a/mpds_client/__init__.py
+++ b/mpds_client/__init__.py
@@ -1,4 +1,3 @@
-
 import sys

 from .retrieve_MPDS import MPDSDataTypes, APIError, MPDSDataRetrieval
@@ -7,4 +6,4 @@

 MIN_PY_VER = (3, 5)

-assert sys.version_info >= MIN_PY_VER, "Python version must be >= {}".format(MIN_PY_VER)
\ No newline at end of file
+assert sys.version_info >= MIN_PY_VER, "Python version must be >= {}".format(MIN_PY_VER)
diff --git a/mpds_client/errors.py b/mpds_client/errors.py
index d20cdf1..de3d0ef 100755
--- a/mpds_client/errors.py
+++ b/mpds_client/errors.py
@@ -1,20 +1,20 @@
-
 class APIError(Exception):
     """
     Simple error handling
     """
+
     codes = {
-        204: 'No Results',
-        400: 'Bad Request',
-        401: 'Unauthorized',
-        402: 'Unauthorized (Payment Required)',
-        403: 'Forbidden',
-        404: 'Not Found',
-        413: 'Too Much Data Given',
-        429: 'Too Many Requests (Rate Limiting)',
-        500: 'Internal Server Error',
-        501: 'Not Implemented',
-        503: 'Service Unavailable'
+        204: "No Results",
+        400: "Bad Request",
+        401: "Unauthorized",
+        402: "Unauthorized (Payment Required)",
+        403: "Forbidden",
+        404: "Not Found",
+        413: "Too Much Data Given",
+        429: "Too Many Requests (Rate Limiting)",
+        500: "Internal Server Error",
+        501: "Not Implemented",
+        503: "Service Unavailable",
     }

     def __init__(self, msg, code=0):
@@ -23,4 +23,8 @@ def __init__(self, msg, code=0):
         self.code = code

     def __str__(self):
-        return "HTTP error code %s: %s (%s)" % (self.code, self.codes.get(self.code, 'Communication Error'), self.msg)
\ No newline at end of file
+        return "HTTP error code %s: %s (%s)" % (
+            self.code,
+            self.codes.get(self.code, "Communication Error"),
+            self.msg,
+        )
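+
+
+# A minimal usage sketch (illustrative only, not part of the module):
+#
+#     try:
+#         raise APIError("no hits for this query", code=204)
+#     except APIError as ex:
+#         print(ex)  # HTTP error code 204: No Results (no hits for this query)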
diff --git a/mpds_client/export_MPDS.py b/mpds_client/export_MPDS.py
index f2858eb..9f2ee57 100755
--- a/mpds_client/export_MPDS.py
+++ b/mpds_client/export_MPDS.py
@@ -2,6 +2,7 @@
 Utilities for convenient exporting the MPDS data
 """
+
 import os
 import random
 import ujson as json
@@ -10,13 +11,12 @@

 class MPDSExport(object):
-
     export_dir = "/tmp/_MPDS"

     human_names = {
-        'length': 'Bond lengths, A',
-        'occurrence': 'Counts',
-        'bandgap': 'Band gap, eV'
+        "length": "Bond lengths, A",
+        "occurrence": "Counts",
+        "bandgap": "Band gap, eV",
     }

     @classmethod
@@ -32,7 +32,11 @@ def _gen_basename(cls):
         basename = []
         random.seed()
         for _ in range(12):
-            basename.append(random.choice("abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789"))
+            basename.append(
+                random.choice(
+                    "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789"
+                )
+            )
         return "".join(basename)

     @classmethod
@@ -42,7 +46,7 @@ def _get_title(cls, term: Union[str, int]):
         return cls.human_names.get(term, term.capitalize())

     @classmethod
-    def save_plot(cls, data, columns, plottype, fmt='json', **kwargs):
+    def save_plot(cls, data, columns, plottype, fmt="json", **kwargs):
         """
         Exports the data in the following formats for plotting:
@@ -59,7 +63,7 @@ def save_plot(cls, data, columns, plottype, fmt='json', **kwargs):
         if not all(col in data.columns for col in columns):
             raise ValueError("Some specified columns are not in the DataFrame")

-        if fmt == 'csv':
+        if fmt == "csv":
             # export to CSV
             fmt_export = os.path.join(cls.export_dir, cls._gen_basename() + ".csv")
             with open(fmt_export, "w") as f_export:
@@ -67,43 +71,59 @@ def save_plot(cls, data, columns, plottype, fmt='json', **kwargs):
                 f_export.write(",".join(columns) + "\n")
                 for row in data.select(columns).iter_rows():
                     f_export.write(",".join(map(str, row)) + "\n")

-        elif fmt == 'json':
+        elif fmt == "json":
             # export to JSON
             fmt_export = os.path.join(cls.export_dir, cls._gen_basename() + ".json")
             with open(fmt_export, "w") as f_export:

-                if plottype == 'bar':
+                if plottype == "bar":
                     # bar plot payload
                     plot["payload"] = {
                         "x": [data[columns[0]].to_list()],
                         "y": data[columns[1]].to_list(),
                         "xtitle": cls._get_title(columns[0]),
-                        "ytitle": cls._get_title(columns[1])
+                        "ytitle": cls._get_title(columns[1]),
                     }

-                elif plottype == 'plot3d':
+                elif plottype == "plot3d":
                     # 3D plot payload
                     plot["payload"] = {
                         "points": {"x": [], "y": [], "z": [], "labels": []},
                         "meshes": [],
                         "xtitle": cls._get_title(columns[0]),
                         "ytitle": cls._get_title(columns[1]),
-                        "ztitle": cls._get_title(columns[2])
+                        "ztitle": cls._get_title(columns[2]),
                     }
                     recent_mesh = None
                     for row in data.iter_rows():
-                        plot["payload"]["points"]["x"].append(row[data.columns.index(columns[0])])
-                        plot["payload"]["points"]["y"].append(row[data.columns.index(columns[1])])
-                        plot["payload"]["points"]["z"].append(row[data.columns.index(columns[2])])
-                        plot["payload"]["points"]["labels"].append(row[data.columns.index(columns[3])])
+                        plot["payload"]["points"]["x"].append(
+                            row[data.columns.index(columns[0])]
+                        )
+                        plot["payload"]["points"]["y"].append(
+                            row[data.columns.index(columns[1])]
+                        )
+                        plot["payload"]["points"]["z"].append(
+                            row[data.columns.index(columns[2])]
+                        )
+                        plot["payload"]["points"]["labels"].append(
+                            row[data.columns.index(columns[3])]
+                        )

                         if row[data.columns.index(columns[4])] != recent_mesh:
-                            plot["payload"]["meshes"].append({"x": [], "y": [], "z": []})
+                            plot["payload"]["meshes"].append(
+                                {"x": [], "y": [], "z": []}
+                            )
                             recent_mesh = row[data.columns.index(columns[4])]

                         if plot["payload"]["meshes"]:
-                            plot["payload"]["meshes"][-1]["x"].append(row[data.columns.index(columns[0])])
-                            plot["payload"]["meshes"][-1]["y"].append(row[data.columns.index(columns[1])])
-                            plot["payload"]["meshes"][-1]["z"].append(row[data.columns.index(columns[2])])
+                            plot["payload"]["meshes"][-1]["x"].append(
+                                row[data.columns.index(columns[0])]
+                            )
+                            plot["payload"]["meshes"][-1]["y"].append(
+                                row[data.columns.index(columns[1])]
+                            )
+                            plot["payload"]["meshes"][-1]["z"].append(
+                                row[data.columns.index(columns[2])]
+                            )

                 else:
                     raise RuntimeError(f"Error: {plottype} is an unknown plot type")
@@ -116,8 +136,7 @@ def save_plot(cls, data, columns, plottype, fmt='json', **kwargs):
         else:
             raise ValueError(f"Unsupported format: {fmt}")

-        return fmt_export
-
+        return fmt_export

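+    # Usage sketch (illustrative only; any Polars DataFrame whose columns match
+    # the requested ones works the same way):
+    #
+    #     df = pl.DataFrame({"length": [1.5, 2.1], "occurrence": [10, 4]})
+    #     path = MPDSExport.save_plot(df, ["length", "occurrence"], "bar", fmt="json")
+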
     @classmethod
     def save_df(cls, frame, tag):
@@ -126,22 +145,25 @@
             raise TypeError("Input frame must be a Polars DataFrame")

         if tag is None:
-            tag = '-'
+            tag = "-"

-        pkl_export = os.path.join(cls.export_dir, f'df{tag}_{cls._gen_basename()}.parquet')
-        frame.write_parquet(pkl_export) # cos pickle is not supported in polars
+        pkl_export = os.path.join(
+            cls.export_dir, "df" + str(tag) + "_" + cls._gen_basename() + ".parquet"
+        )
+        frame.write_parquet(pkl_export)  # Parquet, as pickle is not supported in Polars
         return pkl_export

     @classmethod
     def save_model(cls, skmodel, tag):
         import _pickle as cPickle

         cls._verify_export_dir()
         if tag is None:
-            tag = '-'
+            tag = "-"

-        pkl_export = os.path.join(cls.export_dir, 'ml' + str(tag) + '_' + cls._gen_basename() + ".pkl")
-        with open(pkl_export, 'wb') as f:
+        pkl_export = os.path.join(
+            cls.export_dir, "ml" + str(tag) + "_" + cls._gen_basename() + ".pkl"
+        )
+        with open(pkl_export, "wb") as f:
             cPickle.dump(skmodel, f)
         return pkl_export
diff --git a/mpds_client/retrieve_MPDS.py b/mpds_client/retrieve_MPDS.py
index e8f7dde..5cdf34c 100755
--- a/mpds_client/retrieve_MPDS.py
+++ b/mpds_client/retrieve_MPDS.py
@@ -18,22 +18,26 @@
 try:
     from pymatgen.core.structure import Structure
     from pymatgen.core.lattice import Lattice
+
     use_pmg = True
-except ImportError: pass
+except ImportError:
+    pass

 try:
     from ase import Atom
     from ase.spacegroup import crystal
+
     use_ase = True
-except ImportError: pass
+except ImportError:
+    pass

 if not use_pmg and not use_ase:
     warnings.warn("Crystal structure treatment unavailable")

-__author__ = 'Evgeny Blokhin '
-__copyright__ = 'Copyright (c) 2020, Evgeny Blokhin, Tilde Materials Informatics'
-__license__ = 'MIT'
+__author__ = "Evgeny Blokhin "
+__copyright__ = "Copyright (c) 2020, Evgeny Blokhin, Tilde Materials Informatics"
+__license__ = "MIT"


 class MPDSDataTypes(object):
@@ -67,46 +71,51 @@ class MPDSDataRetrieval(object):
     *or*
     jsonobj = client.get_data({"formula":"SrTiO3"}, fields={})
     """
+
     default_fields = {
-        'S': [
-            'phase_id',
-            'chemical_formula',
-            'sg_n',
-            'entry',
-            lambda: 'crystal structure',
-            lambda: 'angstrom'
+        "S": [
+            "phase_id",
+            "chemical_formula",
+            "sg_n",
+            "entry",
+            lambda: "crystal structure",
+            lambda: "angstrom",
         ],
-        'P': [
-            'sample.material.phase_id',
-            'sample.material.chemical_formula',
-            'sample.material.condition[0].scalar[0].value',
-            'sample.material.entry',
-            'sample.measurement[0].property.name',
-            'sample.measurement[0].property.units',
-            'sample.measurement[0].property.scalar'
+        "P": [
+            "sample.material.phase_id",
+            "sample.material.chemical_formula",
+            "sample.material.condition[0].scalar[0].value",
+            "sample.material.entry",
+            "sample.measurement[0].property.name",
+            "sample.measurement[0].property.units",
+            "sample.measurement[0].property.scalar",
         ],
-        'C': [
+        "C": [
             lambda: None,
-            'title',
+            "title",
             lambda: None,
-            'entry',
-            lambda: 'phase diagram',
-            'naxes',
-            'arity'
-        ]
+            "entry",
+            lambda: "phase diagram",
+            "naxes",
+            "arity",
+        ],
     }
-    default_titles = ['Phase', 'Formula', 'SG', 'Entry', 'Property', 'Units', 'Value']
+    default_titles = ["Phase", "Formula", "SG", "Entry", "Property", "Units", "Value"]

     endpoint = "https://api.mpds.io/v0/download/facet"

     pagesize = 1000
-    maxnpages = 120   # one hit may reach 50kB in RAM, consider pagesize*maxnpages*50kB free RAM
-    maxnphases = 1500 # more phases require additional requests
+    # one hit may reach 50kB in RAM, consider pagesize*maxnpages*50kB free RAM
+    maxnpages = 120
+    maxnphases = 1500  # more phases require additional requests
     chillouttime = 2  # please, do not use values < 2, because the server may burn out

     verbose = True
     debug = False

-    def __init__(self, api_key=None, endpoint=None, dtype=None, verbose=None, debug=None):
+    def __init__(
+        self, api_key=None, endpoint=None, dtype=None, verbose=None, debug=None
+    ):
         """
         MPDS API consumer constructor
@@ -116,7 +125,7 @@ def __init__(self, api_key=None, endpoint=None, dtype=None, verbose=None, debug=None):

         Returns: None
         """
-        self.api_key = api_key if api_key else os.environ['MPDS_KEY']
+        self.api_key = api_key if api_key else os.environ["MPDS_KEY"]

         self.network = httplib2.Http()

@@ -126,39 +135,42 @@ def __init__(self, api_key=None, endpoint=None, dtype=None, verbose=None, debug=None):
         self.debug = debug or self.debug

     def _request(self, query, phases=None, page=0, pagesize=None):
-
-        phases = ','.join([str(int(x)) for x in phases]) if phases else ''
-
-        uri = self.endpoint + '?' + urlencode({
-            'q': json.dumps(query),
-            'phases': phases,
-            'page': page,
-            'pagesize': pagesize or self.pagesize,
-            'dtype': self.dtype
-        })
+        phases = ",".join([str(int(x)) for x in phases]) if phases else ""
+
+        uri = (
+            self.endpoint
+            + "?"
+            + urlencode(
+                {
+                    "q": json.dumps(query),
+                    "phases": phases,
+                    "page": page,
+                    "pagesize": pagesize or self.pagesize,
+                    "dtype": self.dtype,
+                }
+            )
+        )

         if self.debug:
-            print('curl -XGET -HKey:%s \"%s\"' % (self.api_key, uri))
+            print('curl -XGET -HKey:%s "%s"' % (self.api_key, uri))

         response, content = self.network.request(
-            uri=uri,
-            method='GET',
-            headers={'Key': self.api_key}
+            uri=uri, method="GET", headers={"Key": self.api_key}
         )

         if response.status != 200:
-            return {'error': content, 'code': response.status}
+            return {"error": content, "code": response.status}

         try:
             content = json.loads(content)
         except:
-            return {'error': 'Unreadable data obtained'}
+            return {"error": "Unreadable data obtained"}

-        if content.get('error'):
-            return {'error': content['error']}
+        if content.get("error"):
+            return {"error": content["error"]}

-        if not content['out']:
-            return {'error': 'No hits', 'code': 204}
+        if not content["out"]:
+            return {"error": "No hits", "code": 204}

         return content
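+
+    # For illustration, the URI built above for query={"props": "band gap"} and
+    # the defaults looks like (dtype=1 assumes the peer-reviewed default):
+    #
+    #   https://api.mpds.io/v0/download/facet?q=%7B%22props%22%3A+%22band+gap%22%7D&phases=&page=0&pagesize=1000&dtype=1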

@@ -171,8 +183,8 @@ def _massage(self, array, fields):
         for item in array:
             filtered = []
-            for object_type in ['S', 'P', 'C']:
-                if item['object_type'] == object_type:
+            for object_type in ["S", "P", "C"]:
+                if item["object_type"] == object_type:
                     for expr in fields.get(object_type, []):
                         if isinstance(expr, jmespath.parser.ParsedResult):
                             filtered.append(expr.search(item))
@@ -201,16 +213,16 @@ def count_data(self, search, phases=None, **kwargs):
         """
         result = self._request(search, phases=phases, pagesize=10)

-        if result['error']:
-            raise APIError(result['error'], result.get('code', 0))
+        if result["error"]:
+            raise APIError(result["error"], result.get("code", 0))

-        if result['npages'] > self.maxnpages:
+        if result["npages"] > self.maxnpages:
             warnings.warn(
-                "\r\nDataset is too big, you may risk to change maxnpages from %s to %s" % \
-                (self.maxnpages, int(math.ceil(result['count']/self.pagesize)))
+                "\r\nDataset is too big; consider changing maxnpages from %s to %s"
+                % (self.maxnpages, int(math.ceil(result["count"] / self.pagesize)))
             )

-        return result['count']
+        return result["count"]

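+    # NB get_data below splits an over-long phase list into chunks of maxnphases,
+    # e.g. 4000 phase ids with maxnphases = 1500 give ceil(4000 / 1500) = 3
+    # consecutive series of requests (illustrative numbers)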
     def get_data(self, search, phases=None, fields=default_fields):
         """
@@ -232,56 +244,68 @@ def get_data(self, search, phases=None, fields=default_fields):
         documented at https://developer.mpds.io/#JSON-schemata
         """
         output = []
-        fields = {
-            key: [jmespath.compile(item) if isinstance(item, str) else item() for item in value]
-            for key, value in fields.items()
-        } if fields else None
+        fields = (
+            {
+                key: [
+                    jmespath.compile(item) if isinstance(item, str) else item()
+                    for item in value
+                ]
+                for key, value in fields.items()
+            }
+            if fields
+            else None
+        )

         tot_count = 0

         phases = list(set(phases)) if phases else []

         if len(phases) > self.maxnphases:
-            all_phases = array_split(phases, int(math.ceil(
-                len(phases)/self.maxnphases
-            )))
-        else: all_phases = [phases]
+            all_phases = array_split(
+                phases, int(math.ceil(len(phases) / self.maxnphases))
+            )
+        else:
+            all_phases = [phases]

         nsteps = len(all_phases)

         for step, current_phases in enumerate(all_phases, start=1):
-
             counter, hits_count = 0, 0
             while True:
-                result = self._request(search, phases=list(current_phases), page=counter)
-                if result['error']:
-                    raise APIError(result['error'], result.get('code', 0))
+                result = self._request(
+                    search, phases=list(current_phases), page=counter
+                )
+                if result["error"]:
+                    raise APIError(result["error"], result.get("code", 0))

-                if result['npages'] > self.maxnpages:
+                if result["npages"] > self.maxnpages:
                     raise APIError(
-                        "Too many hits (%s > %s), please, be more specific" % \
-                        (result['count'], self.maxnpages * self.pagesize),
-                        2
+                        "Too many hits (%s > %s), please, be more specific"
+                        % (result["count"], self.maxnpages * self.pagesize),
+                        2,
                     )
-                output.extend(self._massage(result['out'], fields))
+                output.extend(self._massage(result["out"], fields))

-                if hits_count and hits_count != result['count']:
-                    raise APIError("API error: hits count has been changed during the query")
+                if hits_count and hits_count != result["count"]:
+                    raise APIError(
+                        "API error: hits count has been changed during the query"
+                    )

-                hits_count = result['count']
+                hits_count = result["count"]
                 time.sleep(self.chillouttime)

-                if counter == result['npages'] - 1:
+                if counter == result["npages"] - 1:
                     break

                 counter += 1
                 if self.verbose:
-                    sys.stdout.write("\r\t%d%% of step %s from %s" % (
-                        (counter/result['npages']) * 100, step, nsteps)
-                    )
+                    sys.stdout.write(
+                        "\r\t%d%% of step %s from %s"
+                        % ((counter / result["npages"]) * 100, step, nsteps)
+                    )
                     sys.stdout.flush()

             tot_count += hits_count

@@ -311,24 +335,24 @@ def get_dataframe(self, *args, **kwargs):

         Returns: (object) Polars dataframe object containing the results
         """
-        columns = kwargs.get('columns')
+        columns = kwargs.get("columns")

         if columns:
-            del kwargs['columns']
+            del kwargs["columns"]
         else:
             columns = self.default_titles

         data = self.get_data(*args, **kwargs)

         return pl.DataFrame(data, schema=columns)

-    def get_crystals(self, search, phases=None, flavor='pmg', **kwargs):
+    def get_crystals(self, search, phases=None, flavor="pmg", **kwargs):
         search["props"] = "atomic structure"

         crystals = []
         for crystal_struct in self.get_data(
-                search,
-                phases,
-                fields={'S':['cell_abc', 'sg_n', 'basis_noneq', 'els_noneq']},
-                **kwargs
+            search,
+            phases,
+            fields={"S": ["cell_abc", "sg_n", "basis_noneq", "els_noneq"]},
+            **kwargs,
         ):
             crobj = self.compile_crystal(crystal_struct, flavor)
             if crobj is not None:
                 crystals.append(crobj)

         return crystals
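+
+    # Usage sketch (illustrative; flavor="pmg" needs Pymatgen installed):
+    #
+    #     client = MPDSDataRetrieval()
+    #     for struct in client.get_crystals({"formula": "SrTiO3"}, flavor="pmg"):
+    #         print(struct.composition)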

@@ -337,7 +361,7 @@
     @staticmethod
-    def compile_crystal(datarow, flavor='pmg'):
+    def compile_crystal(datarow, flavor="pmg"):
         """
         Helper method for representing the MPDS crystal structures in two flavors:
         either as a Pymatgen Structure object, or as an ASE Atoms object.
@@ -376,20 +400,22 @@ def compile_crystal(datarow, flavor='pmg'):
         if len(datarow) < 4:
             raise ValueError(
                 "Must supply a data row that ends with the entries "
-                "'cell_abc', 'sg_n', 'basis_noneq', 'els_noneq'")
+                "'cell_abc', 'sg_n', 'basis_noneq', 'els_noneq'"
+            )

-        cell_abc, sg_n, basis_noneq, els_noneq = \
-            datarow[-4], int(datarow[-3]), datarow[-2], datarow[-1]
+        cell_abc, sg_n, basis_noneq, els_noneq = (
+            datarow[-4],
+            int(datarow[-3]),
+            datarow[-2],
+            datarow[-1],
+        )

-        if flavor == 'pmg' and use_pmg:
+        if flavor == "pmg" and use_pmg:
             return Structure.from_spacegroup(
-                sg_n,
-                Lattice.from_parameters(*cell_abc),
-                els_noneq,
-                basis_noneq
+                sg_n, Lattice.from_parameters(*cell_abc), els_noneq, basis_noneq
             )

-        elif flavor == 'ase' and use_ase:
+        elif flavor == "ase" and use_ase:
             atom_data = []

             for num, i in enumerate(basis_noneq):
@@ -400,8 +426,8 @@
                 spacegroup=sg_n,
                 cellpar=cell_abc,
                 primitive_cell=True,
-                onduplicates='replace'
+                onduplicates="replace",
             )

-        else: raise APIError("Crystal structure treatment unavailable")
-
+        else:
+            raise APIError("Crystal structure treatment unavailable")
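+
+
+# A minimal sketch of calling compile_crystal directly (values are illustrative,
+# shaped like the 'S' fields cell_abc, sg_n, basis_noneq, els_noneq; here
+# alpha-Po, space group 221):
+#
+#     row = [[3.35, 3.35, 3.35, 90, 90, 90], 221, [[0.0, 0.0, 0.0]], ["Po"]]
+#     atoms = MPDSDataRetrieval.compile_crystal(row, flavor="ase")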
pl.DataFrame({"column1": [1, 2, 3], "column2": [4, 5, 6]}) tag = "test" exported_file = MPDSExport.save_df(data, tag) diff --git a/mpds_client/test_retrieve_MPDS.py b/mpds_client/test_retrieve_MPDS.py index 3cecae7..898354b 100755 --- a/mpds_client/test_retrieve_MPDS.py +++ b/mpds_client/test_retrieve_MPDS.py @@ -1,5 +1,5 @@ import unittest -#import warnings +# import warnings import polars as pl @@ -15,22 +15,23 @@ class MPDSDataRetrievalTest(unittest.TestCase): @classmethod def setUpClass(cls): - #warnings.filterwarnings("ignore", category=ResourceWarning, message="unclosed.*") + # warnings.filterwarnings("ignore", category=ResourceWarning, message="unclosed.*") network = httplib2.Http() - response, content = network.request('https://developer.mpds.io/mpds.schema.json') + response, content = network.request( + "https://developer.mpds.io/mpds.schema.json" + ) assert response.status == 200 cls.schema = json.loads(content) Draft4Validator.check_schema(cls.schema) def test_valid_answer(self): - query = { "elements": "K-Ag", "classes": "iodide", "props": "heat capacity", - "lattices": "cubic" + "lattices": "cubic", } client = MPDSDataRetrieval() @@ -40,30 +41,28 @@ def test_valid_answer(self): validate(answer, self.schema) except ValidationError as e: self.fail( - "The item: \r\n\r\n %s \r\n\r\n has an issue: \r\n\r\n %s" % ( - e.instance, e.context - ) + "The item: \r\n\r\n %s \r\n\r\n has an issue: \r\n\r\n %s" + % (e.instance, e.context) ) def test_crystal_structure(self): - query = { "elements": "Ti-O", "classes": "binary", "props": "atomic structure", - "sgs": 136 + "sgs": 136, } client = MPDSDataRetrieval() ntot = client.count_data(query) self.assertTrue(150 < ntot < 175) - for crystal_struct in client.get_data(query, fields={ - 'S': ['cell_abc', 'sg_n', 'basis_noneq', 'els_noneq']}): - + for crystal_struct in client.get_data( + query, fields={"S": ["cell_abc", "sg_n", "basis_noneq", "els_noneq"]} + ): self.assertEqual(crystal_struct[1], 136) - ase_obj = MPDSDataRetrieval.compile_crystal(crystal_struct, 'ase') + ase_obj = MPDSDataRetrieval.compile_crystal(crystal_struct, "ase") if ase_obj: self.assertEqual(len(ase_obj), 6) @@ -72,20 +71,22 @@ def test_get_crystals(self): "elements": "Ti-O", "classes": "binary", "props": "atomic structure", - "sgs": 136 + "sgs": 136, } client = MPDSDataRetrieval() ntot = client.count_data(query) logging.debug(f"Value of ntot: {ntot}") self.assertTrue(150 < ntot < 175) - crystals = client.get_crystals(query, flavor='ase') + crystals = client.get_crystals(query, flavor="ase") for crystal in crystals: self.assertIsNotNone(crystal) # now try getting the crystal from the phase_id(s) - phase_ids = {_[0] for _ in client.get_data(query, fields={'S': ['phase_id']})} - crystals_from_phase_ids = client.get_crystals(query, phases=phase_ids, flavor='ase') + phase_ids = {_[0] for _ in client.get_data(query, fields={"S": ["phase_id"]})} + crystals_from_phase_ids = client.get_crystals( + query, phases=phase_ids, flavor="ase" + ) self.assertEqual(len(crystals), len(crystals_from_phase_ids)) @@ -95,15 +96,11 @@ def test_retrieval_of_phases(self): in two ways: maxnphases = changed and maxnphases = default """ - query_a = { - "elements": "O", - "classes": "binary", - "props": "band gap" - } + query_a = {"elements": "O", "classes": "binary", "props": "band gap"} query_b = { "elements": "O", "classes": "binary", - "props": "isothermal bulk modulus" + "props": "isothermal bulk modulus", } client_one = MPDSDataRetrieval() @@ -111,24 +108,28 @@ def 
diff --git a/mpds_client/test_retrieve_MPDS.py b/mpds_client/test_retrieve_MPDS.py
index 3cecae7..898354b 100755
--- a/mpds_client/test_retrieve_MPDS.py
+++ b/mpds_client/test_retrieve_MPDS.py
@@ -1,5 +1,5 @@
 import unittest
-#import warnings
+# import warnings

 import polars as pl

@@ -15,22 +15,23 @@ class MPDSDataRetrievalTest(unittest.TestCase):
     @classmethod
     def setUpClass(cls):
-        #warnings.filterwarnings("ignore", category=ResourceWarning, message="unclosed.*")
+        # warnings.filterwarnings("ignore", category=ResourceWarning, message="unclosed.*")

         network = httplib2.Http()
-        response, content = network.request('https://developer.mpds.io/mpds.schema.json')
+        response, content = network.request(
+            "https://developer.mpds.io/mpds.schema.json"
+        )
         assert response.status == 200

         cls.schema = json.loads(content)
         Draft4Validator.check_schema(cls.schema)

     def test_valid_answer(self):
-
         query = {
             "elements": "K-Ag",
             "classes": "iodide",
             "props": "heat capacity",
-            "lattices": "cubic"
+            "lattices": "cubic",
         }

         client = MPDSDataRetrieval()
@@ -40,30 +41,28 @@ def test_valid_answer(self):
             validate(answer, self.schema)
         except ValidationError as e:
             self.fail(
-                "The item: \r\n\r\n %s \r\n\r\n has an issue: \r\n\r\n %s" % (
-                    e.instance, e.context
-                )
+                "The item: \r\n\r\n %s \r\n\r\n has an issue: \r\n\r\n %s"
+                % (e.instance, e.context)
             )

     def test_crystal_structure(self):
-
         query = {
             "elements": "Ti-O",
             "classes": "binary",
             "props": "atomic structure",
-            "sgs": 136
+            "sgs": 136,
         }

         client = MPDSDataRetrieval()
         ntot = client.count_data(query)
         self.assertTrue(150 < ntot < 175)

-        for crystal_struct in client.get_data(query, fields={
-                'S': ['cell_abc', 'sg_n', 'basis_noneq', 'els_noneq']}):
-
+        for crystal_struct in client.get_data(
+            query, fields={"S": ["cell_abc", "sg_n", "basis_noneq", "els_noneq"]}
+        ):
             self.assertEqual(crystal_struct[1], 136)

-            ase_obj = MPDSDataRetrieval.compile_crystal(crystal_struct, 'ase')
+            ase_obj = MPDSDataRetrieval.compile_crystal(crystal_struct, "ase")
             if ase_obj:
                 self.assertEqual(len(ase_obj), 6)

@@ -72,20 +71,22 @@ def test_get_crystals(self):
             "elements": "Ti-O",
             "classes": "binary",
             "props": "atomic structure",
-            "sgs": 136
+            "sgs": 136,
         }

         client = MPDSDataRetrieval()
         ntot = client.count_data(query)
         logging.debug(f"Value of ntot: {ntot}")
         self.assertTrue(150 < ntot < 175)

-        crystals = client.get_crystals(query, flavor='ase')
+        crystals = client.get_crystals(query, flavor="ase")
         for crystal in crystals:
             self.assertIsNotNone(crystal)

         # now try getting the crystal from the phase_id(s)
-        phase_ids = {_[0] for _ in client.get_data(query, fields={'S': ['phase_id']})}
-        crystals_from_phase_ids = client.get_crystals(query, phases=phase_ids, flavor='ase')
+        phase_ids = {_[0] for _ in client.get_data(query, fields={"S": ["phase_id"]})}
+        crystals_from_phase_ids = client.get_crystals(
+            query, phases=phase_ids, flavor="ase"
+        )

         self.assertEqual(len(crystals), len(crystals_from_phase_ids))

@@ -95,15 +96,11 @@ def test_retrieval_of_phases(self):
        in two ways: maxnphases = changed and maxnphases = default
        """
-        query_a = {
-            "elements": "O",
-            "classes": "binary",
-            "props": "band gap"
-        }
+        query_a = {"elements": "O", "classes": "binary", "props": "band gap"}

         query_b = {
             "elements": "O",
             "classes": "binary",
-            "props": "isothermal bulk modulus"
+            "props": "isothermal bulk modulus",
         }

         client_one = MPDSDataRetrieval()
@@ -111,24 +108,28 @@ def test_retrieval_of_phases(self):

         answer_one = client_one.get_dataframe(
             query_a,
-            fields={'P': ['sample.material.phase_id', 'sample.material.chemical_formula']},
-            columns=['Phid', 'Object']
+            fields={
+                "P": ["sample.material.phase_id", "sample.material.chemical_formula"]
+            },
+            columns=["Phid", "Object"],
         )

-        if not(isinstance(answer_one, pl.DataFrame)):
+        if not isinstance(answer_one, pl.DataFrame):
             print(type(answer_one))
             raise ValueError("answer_one is not a Polars DataFrame", type(answer_one))

-        answer_one = answer_one.filter(pl.col('Phid').is_not_null())
-        answer_one = answer_one.with_columns(pl.col('Phid').cast(pl.Int32))
-        phases_one = answer_one['Phid'].to_list()
-
+        answer_one = answer_one.filter(pl.col("Phid").is_not_null())
+        answer_one = answer_one.with_columns(pl.col("Phid").cast(pl.Int32))
+        phases_one = answer_one["Phid"].to_list()
+
         self.assertTrue(len(phases_one) > client_one.maxnphases)

         result_one = client_one.get_dataframe(
             query_b,
-            fields={'P': ['sample.material.phase_id', 'sample.material.chemical_formula']},
-            columns=['Phid', 'Object'],
-            phases=phases_one
+            fields={
+                "P": ["sample.material.phase_id", "sample.material.chemical_formula"]
+            },
+            columns=["Phid", "Object"],
+            phases=phases_one,
         )

         client_two = MPDSDataRetrieval()
@@ -136,23 +137,29 @@ def test_retrieval_of_phases(self):

         answer_two = client_two.get_dataframe(
             query_a,
-            fields={'P': ['sample.material.phase_id', 'sample.material.chemical_formula']},
-            columns=['Phid', 'Object']
+            fields={
+                "P": ["sample.material.phase_id", "sample.material.chemical_formula"]
+            },
+            columns=["Phid", "Object"],
         )

-        if not(isinstance(answer_one, pl.DataFrame)):
+        if not isinstance(answer_two, pl.DataFrame):
             print(type(answer_two))
-            raise ValueError("answer_one is not a Polars DataFrame, is", type(answer_two))
-
-        answer_two = answer_two.filter(pl.col('Phid').is_not_null())
-        phases_two = answer_two['Phid'].cast(pl.Int32).to_list()
+            raise ValueError(
+                "answer_two is not a Polars DataFrame, got", type(answer_two)
+            )
+
+        answer_two = answer_two.filter(pl.col("Phid").is_not_null())
+        phases_two = answer_two["Phid"].cast(pl.Int32).to_list()

         self.assertTrue(len(phases_two) < client_two.maxnphases)

         result_two = client_two.get_dataframe(
             query_b,
-            fields={'P': ['sample.material.phase_id', 'sample.material.chemical_formula']},
-            columns=['Phid', 'Object'],
-            phases=phases_two
+            fields={
+                "P": ["sample.material.phase_id", "sample.material.chemical_formula"]
+            },
+            columns=["Phid", "Object"],
+            phases=phases_two,
         )

         self.assertEqual(len(result_one), len(result_two))

@@ -160,11 +167,14 @@ def test_retrieval_of_phases(self):
         # check equality of result_one and result_two
         merge = pl.concat([result_one, result_two])
         merge = merge.with_columns(pl.Series("index", range(len(merge))))
-        merge_gpby = merge.group_by(list(merge.columns), maintain_order=True).agg(pl.len())
+        merge_gpby = merge.group_by(list(merge.columns), maintain_order=True).agg(
+            pl.len()
+        )
         idx = [x[0] for x in merge_gpby.iter_rows() if x[-1] == 1]
         self.assertTrue(merge.filter(pl.col("index").is_in(idx)).is_empty())

-if __name__ == "__main__":
+
+if __name__ == "__main__":
     logging.basicConfig(level=logging.DEBUG)
     unittest.main()
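+
+# To run this suite directly (an MPDS_KEY environment variable is assumed,
+# since MPDSDataRetrieval falls back to os.environ["MPDS_KEY"]):
+#
+#     python mpds_client/test_retrieve_MPDS.py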