From 421459305303300d66c06911cc1a41cb38284b2d Mon Sep 17 00:00:00 2001
From: alinzh <siberiangodness@mail.ru>
Date: Sat, 14 Dec 2024 11:54:33 +0000
Subject: [PATCH] Polishing by linters

---
 mpds_client/__init__.py           |   3 +-
 mpds_client/errors.py             |  30 ++--
 mpds_client/export_MPDS.py        |  80 ++++++----
 mpds_client/retrieve_MPDS.py      | 234 +++++++++++++++++-------------
 mpds_client/test_export_MPDS.py   |  43 +++---
 mpds_client/test_retrieve_MPDS.py | 102 +++++++------
 6 files changed, 275 insertions(+), 217 deletions(-)

diff --git a/mpds_client/__init__.py b/mpds_client/__init__.py
index aa26463..00f8de4 100755
--- a/mpds_client/__init__.py
+++ b/mpds_client/__init__.py
@@ -1,4 +1,3 @@
-
 import sys
 
 from .retrieve_MPDS import MPDSDataTypes, APIError, MPDSDataRetrieval
@@ -7,4 +6,4 @@
 
 MIN_PY_VER = (3, 5)
 
-assert sys.version_info >= MIN_PY_VER, "Python version must be >= {}".format(MIN_PY_VER)
\ No newline at end of file
+assert sys.version_info >= MIN_PY_VER, "Python version must be >= {}".format(MIN_PY_VER)
diff --git a/mpds_client/errors.py b/mpds_client/errors.py
index d20cdf1..de3d0ef 100755
--- a/mpds_client/errors.py
+++ b/mpds_client/errors.py
@@ -1,20 +1,20 @@
-
 class APIError(Exception):
     """
     Simple error handling
     """
+
     codes = {
-        204: 'No Results',
-        400: 'Bad Request',
-        401: 'Unauthorized',
-        402: 'Unauthorized (Payment Required)',
-        403: 'Forbidden',
-        404: 'Not Found',
-        413: 'Too Much Data Given',
-        429: 'Too Many Requests (Rate Limiting)',
-        500: 'Internal Server Error',
-        501: 'Not Implemented',
-        503: 'Service Unavailable'
+        204: "No Results",
+        400: "Bad Request",
+        401: "Unauthorized",
+        402: "Unauthorized (Payment Required)",
+        403: "Forbidden",
+        404: "Not Found",
+        413: "Too Much Data Given",
+        429: "Too Many Requests (Rate Limiting)",
+        500: "Internal Server Error",
+        501: "Not Implemented",
+        503: "Service Unavailable",
     }
 
     def __init__(self, msg, code=0):
@@ -23,4 +23,8 @@ def __init__(self, msg, code=0):
         self.code = code
 
     def __str__(self):
-        return "HTTP error code %s: %s (%s)" % (self.code, self.codes.get(self.code, 'Communication Error'), self.msg)
\ No newline at end of file
+        return "HTTP error code %s: %s (%s)" % (
+            self.code,
+            self.codes.get(self.code, "Communication Error"),
+            self.msg,
+        )
diff --git a/mpds_client/export_MPDS.py b/mpds_client/export_MPDS.py
index f2858eb..9f2ee57 100755
--- a/mpds_client/export_MPDS.py
+++ b/mpds_client/export_MPDS.py
@@ -2,6 +2,7 @@
 Utilities for convenient
 exporting the MPDS data
 """
+
 import os
 import random
 import ujson as json
@@ -10,13 +11,12 @@
 
 
 class MPDSExport(object):
-
     export_dir = "/tmp/_MPDS"
 
     human_names = {
-        'length': 'Bond lengths, A',
-        'occurrence': 'Counts',
-        'bandgap': 'Band gap, eV'
+        "length": "Bond lengths, A",
+        "occurrence": "Counts",
+        "bandgap": "Band gap, eV",
     }
 
     @classmethod
@@ -32,7 +32,11 @@ def _gen_basename(cls):
         basename = []
         random.seed()
         for _ in range(12):
-            basename.append(random.choice("abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789"))
+            basename.append(
+                random.choice(
+                    "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789"
+                )
+            )
         return "".join(basename)
 
     @classmethod
@@ -42,7 +46,7 @@ def _get_title(cls, term: Union[str, int]):
         return cls.human_names.get(term, term.capitalize())
 
     @classmethod
-    def save_plot(cls, data, columns, plottype, fmt='json', **kwargs):
+    def save_plot(cls, data, columns, plottype, fmt="json", **kwargs):
         """
         Exports the data in the following formats for plotting:
 
@@ -59,7 +63,7 @@ def save_plot(cls, data, columns, plottype, fmt='json', **kwargs):
         if not all(col in data.columns for col in columns):
             raise ValueError("Some specified columns are not in the DataFrame")
 
-        if fmt == 'csv':
+        if fmt == "csv":
             # export to CSV
             fmt_export = os.path.join(cls.export_dir, cls._gen_basename() + ".csv")
             with open(fmt_export, "w") as f_export:
@@ -67,43 +71,59 @@ def save_plot(cls, data, columns, plottype, fmt='json', **kwargs):
                 for row in data.select(columns).iter_rows():
                     f_export.write(",".join(map(str, row)) + "\n")
 
-        elif fmt == 'json':
+        elif fmt == "json":
             # export to JSON
             fmt_export = os.path.join(cls.export_dir, cls._gen_basename() + ".json")
             with open(fmt_export, "w") as f_export:
-                if plottype == 'bar':
+                if plottype == "bar":
                     # bar plot payload
                     plot["payload"] = {
                         "x": [data[columns[0]].to_list()],
                         "y": data[columns[1]].to_list(),
                         "xtitle": cls._get_title(columns[0]),
-                        "ytitle": cls._get_title(columns[1])
+                        "ytitle": cls._get_title(columns[1]),
                     }
 
-                elif plottype == 'plot3d':
+                elif plottype == "plot3d":
                     # 3D plot payload
                     plot["payload"] = {
                         "points": {"x": [], "y": [], "z": [], "labels": []},
                         "meshes": [],
                         "xtitle": cls._get_title(columns[0]),
                         "ytitle": cls._get_title(columns[1]),
-                        "ztitle": cls._get_title(columns[2])
+                        "ztitle": cls._get_title(columns[2]),
                     }
                     recent_mesh = None
                     for row in data.iter_rows():
-                        plot["payload"]["points"]["x"].append(row[data.columns.index(columns[0])])
-                        plot["payload"]["points"]["y"].append(row[data.columns.index(columns[1])])
-                        plot["payload"]["points"]["z"].append(row[data.columns.index(columns[2])])
-                        plot["payload"]["points"]["labels"].append(row[data.columns.index(columns[3])])
+                        plot["payload"]["points"]["x"].append(
+                            row[data.columns.index(columns[0])]
+                        )
+                        plot["payload"]["points"]["y"].append(
+                            row[data.columns.index(columns[1])]
+                        )
+                        plot["payload"]["points"]["z"].append(
+                            row[data.columns.index(columns[2])]
+                        )
+                        plot["payload"]["points"]["labels"].append(
+                            row[data.columns.index(columns[3])]
+                        )
 
                         if row[data.columns.index(columns[4])] != recent_mesh:
-                            plot["payload"]["meshes"].append({"x": [], "y": [], "z": []})
+                            plot["payload"]["meshes"].append(
+                                {"x": [], "y": [], "z": []}
+                            )
                         recent_mesh = row[data.columns.index(columns[4])]
 
                         if plot["payload"]["meshes"]:
-                            plot["payload"]["meshes"][-1]["x"].append(row[data.columns.index(columns[0])])
-                            plot["payload"]["meshes"][-1]["y"].append(row[data.columns.index(columns[1])])
-                            plot["payload"]["meshes"][-1]["z"].append(row[data.columns.index(columns[2])])
+                            plot["payload"]["meshes"][-1]["x"].append(
+                                row[data.columns.index(columns[0])]
+                            )
+                            plot["payload"]["meshes"][-1]["y"].append(
+                                row[data.columns.index(columns[1])]
+                            )
+                            plot["payload"]["meshes"][-1]["z"].append(
+                                row[data.columns.index(columns[2])]
+                            )
                 else:
                     raise RuntimeError(f"Error: {plottype} is an unknown plot type")
 
@@ -116,8 +136,7 @@ def save_plot(cls, data, columns, plottype, fmt='json', **kwargs):
         else:
             raise ValueError(f"Unsupported format: {fmt}")
 
-        return fmt_export           
-
+        return fmt_export
 
     @classmethod
     def save_df(cls, frame, tag):
@@ -126,22 +145,25 @@ def save_df(cls, frame, tag):
             raise TypeError("Input frame must be a Polars DataFrame")
 
         if tag is None:
-            tag = '-'
+            tag = "-"
 
-        pkl_export = os.path.join(cls.export_dir, f'df{tag}_{cls._gen_basename()}.parquet')
-        frame.write_parquet(pkl_export) # cos pickle is not supported in polars
+        pkl_export = os.path.join(
+            cls.export_dir, "df" + str(tag) + "_" + cls._gen_basename() + ".pkl"
+        )
+        frame.write_parquet(pkl_export)  # cos pickle is not supported in polars
         return pkl_export
 
     @classmethod
     def save_model(cls, skmodel, tag):
-
         import _pickle as cPickle
 
         cls._verify_export_dir()
         if tag is None:
-            tag = '-'
+            tag = "-"
 
-        pkl_export = os.path.join(cls.export_dir, 'ml' + str(tag) + '_' + cls._gen_basename() + ".pkl")
-        with open(pkl_export, 'wb') as f:
+        pkl_export = os.path.join(
+            cls.export_dir, "ml" + str(tag) + "_" + cls._gen_basename() + ".pkl"
+        )
+        with open(pkl_export, "wb") as f:
             cPickle.dump(skmodel, f)
         return pkl_export
diff --git a/mpds_client/retrieve_MPDS.py b/mpds_client/retrieve_MPDS.py
index e8f7dde..5cdf34c 100755
--- a/mpds_client/retrieve_MPDS.py
+++ b/mpds_client/retrieve_MPDS.py
@@ -18,22 +18,26 @@
 try:
     from pymatgen.core.structure import Structure
     from pymatgen.core.lattice import Lattice
+
     use_pmg = True
-except ImportError: pass
+except ImportError:
+    pass
 
 try:
     from ase import Atom
     from ase.spacegroup import crystal
+
     use_ase = True
-except ImportError: pass
+except ImportError:
+    pass
 
 
 if not use_pmg and not use_ase:
     warnings.warn("Crystal structure treatment unavailable")
 
-__author__ = 'Evgeny Blokhin <eb@tilde.pro>'
-__copyright__ = 'Copyright (c) 2020, Evgeny Blokhin, Tilde Materials Informatics'
-__license__ = 'MIT'
+__author__ = "Evgeny Blokhin <eb@tilde.pro>"
+__copyright__ = "Copyright (c) 2020, Evgeny Blokhin, Tilde Materials Informatics"
+__license__ = "MIT"
 
 
 class MPDSDataTypes(object):
@@ -67,46 +71,51 @@ class MPDSDataRetrieval(object):
     *or*
     jsonobj = client.get_data({"formula":"SrTiO3"}, fields={})
     """
+
     default_fields = {
-        'S': [
-            'phase_id',
-            'chemical_formula',
-            'sg_n',
-            'entry',
-            lambda: 'crystal structure',
-            lambda: 'angstrom'
+        "S": [
+            "phase_id",
+            "chemical_formula",
+            "sg_n",
+            "entry",
+            lambda: "crystal structure",
+            lambda: "angstrom",
         ],
-        'P': [
-            'sample.material.phase_id',
-            'sample.material.chemical_formula',
-            'sample.material.condition[0].scalar[0].value',
-            'sample.material.entry',
-            'sample.measurement[0].property.name',
-            'sample.measurement[0].property.units',
-            'sample.measurement[0].property.scalar'
+        "P": [
+            "sample.material.phase_id",
+            "sample.material.chemical_formula",
+            "sample.material.condition[0].scalar[0].value",
+            "sample.material.entry",
+            "sample.measurement[0].property.name",
+            "sample.measurement[0].property.units",
+            "sample.measurement[0].property.scalar",
         ],
-        'C': [
+        "C": [
             lambda: None,
-            'title',
+            "title",
             lambda: None,
-            'entry',
-            lambda: 'phase diagram',
-            'naxes',
-            'arity'
-        ]
+            "entry",
+            lambda: "phase diagram",
+            "naxes",
+            "arity",
+        ],
     }
-    default_titles = ['Phase', 'Formula', 'SG', 'Entry', 'Property', 'Units', 'Value']
+    default_titles = ["Phase", "Formula", "SG", "Entry", "Property", "Units", "Value"]
 
     endpoint = "https://api.mpds.io/v0/download/facet"
 
     pagesize = 1000
-    maxnpages = 120   # one hit may reach 50kB in RAM, consider pagesize*maxnpages*50kB free RAM
-    maxnphases = 1500 # more phases require additional requests
+    maxnpages = (
+        120  # one hit may reach 50kB in RAM, consider pagesize*maxnpages*50kB free RAM
+    )
+    maxnphases = 1500  # more phases require additional requests
     chillouttime = 2  # please, do not use values < 2, because the server may burn out
     verbose = True
     debug = False
 
-    def __init__(self, api_key=None, endpoint=None, dtype=None, verbose=None, debug=None):
+    def __init__(
+        self, api_key=None, endpoint=None, dtype=None, verbose=None, debug=None
+    ):
         """
         MPDS API consumer constructor
 
@@ -116,7 +125,7 @@ def __init__(self, api_key=None, endpoint=None, dtype=None, verbose=None, debug=
 
         Returns: None
         """
-        self.api_key = api_key if api_key else os.environ['MPDS_KEY']
+        self.api_key = api_key if api_key else os.environ["MPDS_KEY"]
 
         self.network = httplib2.Http()
 
@@ -126,39 +135,42 @@ def __init__(self, api_key=None, endpoint=None, dtype=None, verbose=None, debug=
         self.debug = debug or self.debug
 
     def _request(self, query, phases=None, page=0, pagesize=None):
-
-        phases = ','.join([str(int(x)) for x in phases]) if phases else ''
-
-        uri = self.endpoint + '?' + urlencode({
-            'q': json.dumps(query),
-            'phases': phases,
-            'page': page,
-            'pagesize': pagesize or self.pagesize,
-            'dtype': self.dtype
-        })
+        phases = ",".join([str(int(x)) for x in phases]) if phases else ""
+
+        uri = (
+            self.endpoint
+            + "?"
+            + urlencode(
+                {
+                    "q": json.dumps(query),
+                    "phases": phases,
+                    "page": page,
+                    "pagesize": pagesize or self.pagesize,
+                    "dtype": self.dtype,
+                }
+            )
+        )
 
         if self.debug:
-            print('curl -XGET -HKey:%s \"%s\"' % (self.api_key, uri))
+            print('curl -XGET -HKey:%s "%s"' % (self.api_key, uri))
 
         response, content = self.network.request(
-            uri=uri,
-            method='GET',
-            headers={'Key': self.api_key}
+            uri=uri, method="GET", headers={"Key": self.api_key}
         )
 
         if response.status != 200:
-            return {'error': content, 'code': response.status}
+            return {"error": content, "code": response.status}
 
         try:
             content = json.loads(content)
         except:
-            return {'error': 'Unreadable data obtained'}
+            return {"error": "Unreadable data obtained"}
 
-        if content.get('error'):
-            return {'error': content['error']}
+        if content.get("error"):
+            return {"error": content["error"]}
 
-        if not content['out']:
-            return {'error': 'No hits', 'code': 204}
+        if not content["out"]:
+            return {"error": "No hits", "code": 204}
 
         return content
 
@@ -171,8 +183,8 @@ def _massage(self, array, fields):
         for item in array:
             filtered = []
 
-            for object_type in ['S', 'P', 'C']:
-                if item['object_type'] == object_type:
+            for object_type in ["S", "P", "C"]:
+                if item["object_type"] == object_type:
                     for expr in fields.get(object_type, []):
                         if isinstance(expr, jmespath.parser.ParsedResult):
                             filtered.append(expr.search(item))
@@ -201,16 +213,16 @@ def count_data(self, search, phases=None, **kwargs):
         """
         result = self._request(search, phases=phases, pagesize=10)
 
-        if result['error']:
-            raise APIError(result['error'], result.get('code', 0))
+        if result["error"]:
+            raise APIError(result["error"], result.get("code", 0))
 
-        if result['npages'] > self.maxnpages:
+        if result["npages"] > self.maxnpages:
             warnings.warn(
-                "\r\nDataset is too big, you may risk to change maxnpages from %s to %s" % \
-                (self.maxnpages, int(math.ceil(result['count']/self.pagesize)))
+                "\r\nDataset is too big, you may risk to change maxnpages from %s to %s"
+                % (self.maxnpages, int(math.ceil(result["count"] / self.pagesize)))
             )
 
-        return result['count']
+        return result["count"]
 
     def get_data(self, search, phases=None, fields=default_fields):
         """
@@ -232,56 +244,68 @@ def get_data(self, search, phases=None, fields=default_fields):
             documented at https://developer.mpds.io/#JSON-schemata
         """
         output = []
-        fields = {
-            key: [jmespath.compile(item) if isinstance(item, str) else item() for item in value]
-            for key, value in fields.items()
-        } if fields else None
+        fields = (
+            {
+                key: [
+                    jmespath.compile(item) if isinstance(item, str) else item()
+                    for item in value
+                ]
+                for key, value in fields.items()
+            }
+            if fields
+            else None
+        )
 
         tot_count = 0
 
         phases = list(set(phases)) if phases else []
 
         if len(phases) > self.maxnphases:
-            all_phases = array_split(phases, int(math.ceil(
-                len(phases)/self.maxnphases
-            )))
-        else: all_phases = [phases]
+            all_phases = array_split(
+                phases, int(math.ceil(len(phases) / self.maxnphases))
+            )
+        else:
+            all_phases = [phases]
 
         nsteps = len(all_phases)
 
         for step, current_phases in enumerate(all_phases, start=1):
-
             counter, hits_count = 0, 0
 
             while True:
-                result = self._request(search, phases=list(current_phases), page=counter)
-                if result['error']:
-                    raise APIError(result['error'], result.get('code', 0))
+                result = self._request(
+                    search, phases=list(current_phases), page=counter
+                )
+                if result["error"]:
+                    raise APIError(result["error"], result.get("code", 0))
 
-                if result['npages'] > self.maxnpages:
+                if result["npages"] > self.maxnpages:
                     raise APIError(
-                        "Too many hits (%s > %s), please, be more specific" % \
-                        (result['count'], self.maxnpages * self.pagesize),
-                        2
+                        "Too many hits (%s > %s), please, be more specific"
+                        % (result["count"], self.maxnpages * self.pagesize),
+                        2,
                     )
-                output.extend(self._massage(result['out'], fields))
+                output.extend(self._massage(result["out"], fields))
 
-                if hits_count and hits_count != result['count']:
-                    raise APIError("API error: hits count has been changed during the query")
+                if hits_count and hits_count != result["count"]:
+                    raise APIError(
+                        "API error: hits count has been changed during the query"
+                    )
 
-                hits_count = result['count']
+                hits_count = result["count"]
 
                 time.sleep(self.chillouttime)
 
-                if counter == result['npages'] - 1:
+                if counter == result["npages"] - 1:
                     break
 
                 counter += 1
 
                 if self.verbose:
-                    sys.stdout.write("\r\t%d%% of step %s from %s" % (
-                        (counter/result['npages']) * 100, step, nsteps)
-                                        )
+                    sys.stdout.write(
+                        "\r\t%d%% of step %s from %s"
+                        % ((counter / result["npages"]) * 100, step, nsteps)
+                    )
                     sys.stdout.flush()
 
             tot_count += hits_count
@@ -311,24 +335,24 @@ def get_dataframe(self, *args, **kwargs):
 
         Returns: (object) Polars dataframe object containing the results
         """
-        columns = kwargs.get('columns')
+        columns = kwargs.get("columns")
         if columns:
-            del kwargs['columns']
+            del kwargs["columns"]
         else:
             columns = self.default_titles
 
         data = self.get_data(*args, **kwargs)
         return pl.DataFrame(data, schema=columns)
 
-    def get_crystals(self, search, phases=None, flavor='pmg', **kwargs):
+    def get_crystals(self, search, phases=None, flavor="pmg", **kwargs):
         search["props"] = "atomic structure"
 
         crystals = []
         for crystal_struct in self.get_data(
-                search,
-                phases,
-                fields={'S':['cell_abc', 'sg_n', 'basis_noneq', 'els_noneq']},
-                **kwargs
+            search,
+            phases,
+            fields={"S": ["cell_abc", "sg_n", "basis_noneq", "els_noneq"]},
+            **kwargs,
         ):
             crobj = self.compile_crystal(crystal_struct, flavor)
             if crobj is not None:
@@ -337,7 +361,7 @@ def get_crystals(self, search, phases=None, flavor='pmg', **kwargs):
         return crystals
 
     @staticmethod
-    def compile_crystal(datarow, flavor='pmg'):
+    def compile_crystal(datarow, flavor="pmg"):
         """
         Helper method for representing the MPDS crystal structures in two flavors:
         either as a Pymatgen Structure object, or as an ASE Atoms object.
@@ -376,20 +400,22 @@ def compile_crystal(datarow, flavor='pmg'):
         if len(datarow) < 4:
             raise ValueError(
                 "Must supply a data row that ends with the entries "
-                "'cell_abc', 'sg_n', 'basis_noneq', 'els_noneq'")
+                "'cell_abc', 'sg_n', 'basis_noneq', 'els_noneq'"
+            )
 
-        cell_abc, sg_n, basis_noneq, els_noneq = \
-            datarow[-4], int(datarow[-3]), datarow[-2], datarow[-1]
+        cell_abc, sg_n, basis_noneq, els_noneq = (
+            datarow[-4],
+            int(datarow[-3]),
+            datarow[-2],
+            datarow[-1],
+        )
 
-        if flavor == 'pmg' and use_pmg:
+        if flavor == "pmg" and use_pmg:
             return Structure.from_spacegroup(
-                sg_n,
-                Lattice.from_parameters(*cell_abc),
-                els_noneq,
-                basis_noneq
+                sg_n, Lattice.from_parameters(*cell_abc), els_noneq, basis_noneq
             )
 
-        elif flavor == 'ase' and use_ase:
+        elif flavor == "ase" and use_ase:
             atom_data = []
 
             for num, i in enumerate(basis_noneq):
@@ -400,8 +426,8 @@ def compile_crystal(datarow, flavor='pmg'):
                 spacegroup=sg_n,
                 cellpar=cell_abc,
                 primitive_cell=True,
-                onduplicates='replace'
+                onduplicates="replace",
             )
 
-        else: raise APIError("Crystal structure treatment unavailable")
-        
+        else:
+            raise APIError("Crystal structure treatment unavailable")
diff --git a/mpds_client/test_export_MPDS.py b/mpds_client/test_export_MPDS.py
index 901087e..fbe8d94 100644
--- a/mpds_client/test_export_MPDS.py
+++ b/mpds_client/test_export_MPDS.py
@@ -1,58 +1,55 @@
 import unittest
 import os
 import polars as pl
-from export_MPDS import MPDSExport 
+from export_MPDS import MPDSExport
 
 
 class TestMPDSExport(unittest.TestCase):
     def test_save_plot_csv(self):
         """Test saving a plot in CSV format."""
-        data = pl.DataFrame({
-            "length": [1.2, 1.5, 1.8, 2.0, 2.2],
-            "occurrence": [10, 15, 8, 20, 12]
-        })
+        data = pl.DataFrame(
+            {"length": [1.2, 1.5, 1.8, 2.0, 2.2], "occurrence": [10, 15, 8, 20, 12]}
+        )
         columns = ["length", "occurrence"]
         plottype = "bar"
 
-        exported_file = MPDSExport.save_plot(data, columns, plottype, fmt='csv')
+        exported_file = MPDSExport.save_plot(data, columns, plottype, fmt="csv")
         self.assertTrue(os.path.isfile(exported_file))
         self.assertTrue(exported_file.endswith(".csv"))
 
     def test_save_plot_json(self):
         """Test saving a plot in JSON format."""
-        data = pl.DataFrame({
-            "length": [1.2, 1.5, 1.8, 2.0, 2.2],
-            "occurrence": [10, 15, 8, 20, 12]
-        })
+        data = pl.DataFrame(
+            {"length": [1.2, 1.5, 1.8, 2.0, 2.2], "occurrence": [10, 15, 8, 20, 12]}
+        )
         columns = ["length", "occurrence"]
         plottype = "bar"
 
-        exported_file = MPDSExport.save_plot(data, columns, plottype, fmt='json')
+        exported_file = MPDSExport.save_plot(data, columns, plottype, fmt="json")
         self.assertTrue(os.path.isfile(exported_file))
         self.assertTrue(exported_file.endswith(".json"))
 
     def test_save_plot_3d_json(self):
         """Test saving a 3D plot in JSON format."""
-        data = pl.DataFrame({
-            "x": [1, 2, 3, 4],
-            "y": [5, 6, 7, 8],
-            "z": [9, 10, 11, 12],
-            "labels": ["A", "B", "C", "D"],
-            "meshes_id": [1, 1, 2, 2]
-        })
+        data = pl.DataFrame(
+            {
+                "x": [1, 2, 3, 4],
+                "y": [5, 6, 7, 8],
+                "z": [9, 10, 11, 12],
+                "labels": ["A", "B", "C", "D"],
+                "meshes_id": [1, 1, 2, 2],
+            }
+        )
         columns = ["x", "y", "z", "labels", "meshes_id"]
         plottype = "plot3d"
 
-        exported_file = MPDSExport.save_plot(data, columns, plottype, fmt='json')
+        exported_file = MPDSExport.save_plot(data, columns, plottype, fmt="json")
         self.assertTrue(os.path.isfile(exported_file))
         self.assertTrue(exported_file.endswith(".json"))
 
     def test_save_df(self):
         """Test saving Polars DataFrame."""
-        data = pl.DataFrame({
-            "column1": [1, 2, 3],
-            "column2": [4, 5, 6]
-        })
+        data = pl.DataFrame({"column1": [1, 2, 3], "column2": [4, 5, 6]})
         tag = "test"
 
         exported_file = MPDSExport.save_df(data, tag)
diff --git a/mpds_client/test_retrieve_MPDS.py b/mpds_client/test_retrieve_MPDS.py
index 3cecae7..898354b 100755
--- a/mpds_client/test_retrieve_MPDS.py
+++ b/mpds_client/test_retrieve_MPDS.py
@@ -1,5 +1,5 @@
 import unittest
-#import warnings
+# import warnings
 
 import polars as pl
 
@@ -15,22 +15,23 @@
 class MPDSDataRetrievalTest(unittest.TestCase):
     @classmethod
     def setUpClass(cls):
-        #warnings.filterwarnings("ignore", category=ResourceWarning, message="unclosed.*<ssl.SSLSocket.*>")
+        # warnings.filterwarnings("ignore", category=ResourceWarning, message="unclosed.*<ssl.SSLSocket.*>")
 
         network = httplib2.Http()
-        response, content = network.request('https://developer.mpds.io/mpds.schema.json')
+        response, content = network.request(
+            "https://developer.mpds.io/mpds.schema.json"
+        )
         assert response.status == 200
 
         cls.schema = json.loads(content)
         Draft4Validator.check_schema(cls.schema)
 
     def test_valid_answer(self):
-
         query = {
             "elements": "K-Ag",
             "classes": "iodide",
             "props": "heat capacity",
-            "lattices": "cubic"
+            "lattices": "cubic",
         }
 
         client = MPDSDataRetrieval()
@@ -40,30 +41,28 @@ def test_valid_answer(self):
             validate(answer, self.schema)
         except ValidationError as e:
             self.fail(
-                "The item: \r\n\r\n %s \r\n\r\n has an issue: \r\n\r\n %s" % (
-                    e.instance, e.context
-                )
+                "The item: \r\n\r\n %s \r\n\r\n has an issue: \r\n\r\n %s"
+                % (e.instance, e.context)
             )
 
     def test_crystal_structure(self):
-
         query = {
             "elements": "Ti-O",
             "classes": "binary",
             "props": "atomic structure",
-            "sgs": 136
+            "sgs": 136,
         }
 
         client = MPDSDataRetrieval()
         ntot = client.count_data(query)
         self.assertTrue(150 < ntot < 175)
 
-        for crystal_struct in client.get_data(query, fields={
-            'S': ['cell_abc', 'sg_n', 'basis_noneq', 'els_noneq']}):
-
+        for crystal_struct in client.get_data(
+            query, fields={"S": ["cell_abc", "sg_n", "basis_noneq", "els_noneq"]}
+        ):
             self.assertEqual(crystal_struct[1], 136)
 
-            ase_obj = MPDSDataRetrieval.compile_crystal(crystal_struct, 'ase')
+            ase_obj = MPDSDataRetrieval.compile_crystal(crystal_struct, "ase")
             if ase_obj:
                 self.assertEqual(len(ase_obj), 6)
 
@@ -72,20 +71,22 @@ def test_get_crystals(self):
             "elements": "Ti-O",
             "classes": "binary",
             "props": "atomic structure",
-            "sgs": 136
+            "sgs": 136,
         }
         client = MPDSDataRetrieval()
         ntot = client.count_data(query)
         logging.debug(f"Value of ntot: {ntot}")
         self.assertTrue(150 < ntot < 175)
 
-        crystals = client.get_crystals(query, flavor='ase')
+        crystals = client.get_crystals(query, flavor="ase")
         for crystal in crystals:
             self.assertIsNotNone(crystal)
 
         # now try getting the crystal from the phase_id(s)
-        phase_ids = {_[0] for _ in client.get_data(query, fields={'S': ['phase_id']})}
-        crystals_from_phase_ids = client.get_crystals(query, phases=phase_ids, flavor='ase')
+        phase_ids = {_[0] for _ in client.get_data(query, fields={"S": ["phase_id"]})}
+        crystals_from_phase_ids = client.get_crystals(
+            query, phases=phase_ids, flavor="ase"
+        )
 
         self.assertEqual(len(crystals), len(crystals_from_phase_ids))
 
@@ -95,15 +96,11 @@ def test_retrieval_of_phases(self):
         in two ways:
         maxnphases = changed and maxnphases = default
         """
-        query_a = {
-            "elements": "O",
-            "classes": "binary",
-            "props": "band gap"
-        }
+        query_a = {"elements": "O", "classes": "binary", "props": "band gap"}
         query_b = {
             "elements": "O",
             "classes": "binary",
-            "props": "isothermal bulk modulus"
+            "props": "isothermal bulk modulus",
         }
 
         client_one = MPDSDataRetrieval()
@@ -111,24 +108,28 @@ def test_retrieval_of_phases(self):
 
         answer_one = client_one.get_dataframe(
             query_a,
-            fields={'P': ['sample.material.phase_id', 'sample.material.chemical_formula']},
-            columns=['Phid', 'Object']
+            fields={
+                "P": ["sample.material.phase_id", "sample.material.chemical_formula"]
+            },
+            columns=["Phid", "Object"],
         )
-        if not(isinstance(answer_one, pl.DataFrame)):
+        if not (isinstance(answer_one, pl.DataFrame)):
             print(type(answer_one))
             raise ValueError("answer_one is not a Polars DataFrame", type(answer_one))
 
-        answer_one = answer_one.filter(pl.col('Phid').is_not_null())
-        answer_one = answer_one.with_columns(pl.col('Phid').cast(pl.Int32))
-        phases_one = answer_one['Phid'].to_list()
-        
+        answer_one = answer_one.filter(pl.col("Phid").is_not_null())
+        answer_one = answer_one.with_columns(pl.col("Phid").cast(pl.Int32))
+        phases_one = answer_one["Phid"].to_list()
+
         self.assertTrue(len(phases_one) > client_one.maxnphases)
 
         result_one = client_one.get_dataframe(
             query_b,
-            fields={'P': ['sample.material.phase_id', 'sample.material.chemical_formula']},
-            columns=['Phid', 'Object'],
-            phases=phases_one
+            fields={
+                "P": ["sample.material.phase_id", "sample.material.chemical_formula"]
+            },
+            columns=["Phid", "Object"],
+            phases=phases_one,
         )
 
         client_two = MPDSDataRetrieval()
@@ -136,23 +137,29 @@ def test_retrieval_of_phases(self):
 
         answer_two = client_two.get_dataframe(
             query_a,
-            fields={'P': ['sample.material.phase_id', 'sample.material.chemical_formula']},
-            columns=['Phid', 'Object']
+            fields={
+                "P": ["sample.material.phase_id", "sample.material.chemical_formula"]
+            },
+            columns=["Phid", "Object"],
         )
-        if not(isinstance(answer_one, pl.DataFrame)):
+        if not (isinstance(answer_one, pl.DataFrame)):
             print(type(answer_two))
-            raise ValueError("answer_one is not a Polars DataFrame, is", type(answer_two))
-        
-        answer_two = answer_two.filter(pl.col('Phid').is_not_null())
-        phases_two = answer_two['Phid'].cast(pl.Int32).to_list()
+            raise ValueError(
+                "answer_one is not a Polars DataFrame, is", type(answer_two)
+            )
+
+        answer_two = answer_two.filter(pl.col("Phid").is_not_null())
+        phases_two = answer_two["Phid"].cast(pl.Int32).to_list()
 
         self.assertTrue(len(phases_two) < client_two.maxnphases)
 
         result_two = client_two.get_dataframe(
             query_b,
-            fields={'P': ['sample.material.phase_id', 'sample.material.chemical_formula']},
-            columns=['Phid', 'Object'],
-            phases=phases_two
+            fields={
+                "P": ["sample.material.phase_id", "sample.material.chemical_formula"]
+            },
+            columns=["Phid", "Object"],
+            phases=phases_two,
         )
 
         self.assertEqual(len(result_one), len(result_two))
@@ -160,11 +167,14 @@ def test_retrieval_of_phases(self):
         # check equality of result_one and result_two
         merge = pl.concat([result_one, result_two])
         merge = merge.with_columns(pl.Series("index", range(len(merge))))
-        merge_gpby = merge.group_by(list(merge.columns), maintain_order=True).agg(pl.len())
+        merge_gpby = merge.group_by(list(merge.columns), maintain_order=True).agg(
+            pl.len()
+        )
         idx = [x[0] for x in merge_gpby.iter_rows() if x[-1] == 1]
 
         self.assertTrue(merge.filter(pl.col("index").is_in(idx)).is_empty())
 
-if __name__ == "__main__": 
+
+if __name__ == "__main__":
     logging.basicConfig(level=logging.DEBUG)
     unittest.main()