Skip to content

Commit

Permalink
Linting
Browse files Browse the repository at this point in the history
  • Loading branch information
munrojm committed Jun 14, 2023
1 parent 8f0358a commit 928a696
Showing 1 changed file with 69 additions and 25 deletions.
94 changes: 69 additions & 25 deletions mp_api/client/routes/materials/electronic_structure.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,12 @@
import base64
import warnings
import zlib
from collections import defaultdict
from typing import List, Optional, Tuple, Union

import msgpack
from emmet.core.electronic_structure import (
BSPathType,
DOSProjectionType,
ElectronicStructureDoc,
)
from monty.serialization import MontyDecoder
from pymatgen.analysis.magnetism.analyzer import Ordering
from pymatgen.core.periodic_table import Element
from pymatgen.electronic_structure.core import OrbitalType, Spin
Expand Down Expand Up @@ -109,7 +105,9 @@ def search(
query_params.update({"exclude_elements": ",".join(exclude_elements)})

if band_gap:
query_params.update({"band_gap_min": band_gap[0], "band_gap_max": band_gap[1]})
query_params.update(
{"band_gap_min": band_gap[0], "band_gap_max": band_gap[1]}
)

if efermi:
query_params.update({"efermi_min": efermi[0], "efermi_max": efermi[1]})
Expand All @@ -120,7 +118,9 @@ def search(
if num_elements:
if isinstance(num_elements, int):
num_elements = (num_elements, num_elements)
query_params.update({"nelements_min": num_elements[0], "nelements_max": num_elements[1]})
query_params.update(
{"nelements_min": num_elements[0], "nelements_max": num_elements[1]}
)

if is_gap_direct is not None:
query_params.update({"is_gap_direct": is_gap_direct})
Expand All @@ -129,9 +129,15 @@ def search(
query_params.update({"is_metal": is_metal})

if sort_fields:
query_params.update({"_sort_fields": ",".join([s.strip() for s in sort_fields])})
query_params.update(
{"_sort_fields": ",".join([s.strip() for s in sort_fields])}
)

query_params = {entry: query_params[entry] for entry in query_params if query_params[entry] is not None}
query_params = {
entry: query_params[entry]
for entry in query_params
if query_params[entry] is not None
}

return super()._search(
num_chunks=num_chunks,
Expand Down Expand Up @@ -195,7 +201,9 @@ def search(
query_params["path_type"] = path_type.value

if band_gap:
query_params.update({"band_gap_min": band_gap[0], "band_gap_max": band_gap[1]})
query_params.update(
{"band_gap_min": band_gap[0], "band_gap_max": band_gap[1]}
)

if efermi:
query_params.update({"efermi_min": efermi[0], "efermi_max": efermi[1]})
Expand All @@ -210,9 +218,15 @@ def search(
query_params.update({"is_metal": is_metal})

if sort_fields:
query_params.update({"_sort_fields": ",".join([s.strip() for s in sort_fields])})
query_params.update(
{"_sort_fields": ",".join([s.strip() for s in sort_fields])}
)

query_params = {entry: query_params[entry] for entry in query_params if query_params[entry] is not None}
query_params = {
entry: query_params[entry]
for entry in query_params
if query_params[entry] is not None
}

return super()._search(
num_chunks=num_chunks,
Expand All @@ -232,7 +246,9 @@ def get_bandstructure_from_task_id(self, task_id: str):
bandstructure (BandStructure): BandStructure or BandStructureSymmLine object
"""

result = self._query_open_data(bucket="materialsproject-parsed", prefix="bandstructures", key=task_id)
result = self._query_open_data(
bucket="materialsproject-parsed", prefix="bandstructures", key=task_id
)

if result.get("data", None) is not None:
return result["data"]
Expand All @@ -255,32 +271,46 @@ def get_bandstructure_from_material_id(
Returns:
bandstructure (Union[BandStructure, BandStructureSymmLine]): BandStructure or BandStructureSymmLine object
"""
es_rester = ElectronicStructureRester(endpoint=self.base_endpoint, api_key=self.api_key)
es_rester = ElectronicStructureRester(
endpoint=self.base_endpoint, api_key=self.api_key
)

if line_mode:
bs_data = es_rester.get_data_by_id(document_id=material_id, fields=["bandstructure"]).bandstructure
bs_data = es_rester.get_data_by_id(
document_id=material_id, fields=["bandstructure"]
).bandstructure

if bs_data is None:
raise MPRestError(f"No {path_type.value} band structure data found for {material_id}")
raise MPRestError(
f"No {path_type.value} band structure data found for {material_id}"
)
else:
bs_data = bs_data.dict()

if bs_data.get(path_type.value, None):
bs_task_id = bs_data[path_type.value]["task_id"]
else:
raise MPRestError(f"No {path_type.value} band structure data found for {material_id}")
raise MPRestError(
f"No {path_type.value} band structure data found for {material_id}"
)
else:
bs_data = es_rester.get_data_by_id(document_id=material_id, fields=["dos"]).dos
bs_data = es_rester.get_data_by_id(
document_id=material_id, fields=["dos"]
).dos

if bs_data is None:
raise MPRestError(f"No uniform band structure data found for {material_id}")
raise MPRestError(
f"No uniform band structure data found for {material_id}"
)
else:
bs_data = bs_data.dict()

if bs_data.get("total", None):
bs_task_id = bs_data["total"]["1"]["task_id"]
else:
raise MPRestError(f"No uniform band structure data found for {material_id}")
raise MPRestError(
f"No uniform band structure data found for {material_id}"
)

bs_obj = self.get_bandstructure_from_task_id(bs_task_id)

Expand Down Expand Up @@ -352,7 +382,9 @@ def search(
query_params["orbital"] = orbital.value

if band_gap:
query_params.update({"band_gap_min": band_gap[0], "band_gap_max": band_gap[1]})
query_params.update(
{"band_gap_min": band_gap[0], "band_gap_max": band_gap[1]}
)

if efermi:
query_params.update({"efermi_min": efermi[0], "efermi_max": efermi[1]})
Expand All @@ -361,9 +393,15 @@ def search(
query_params.update({"magnetic_ordering": magnetic_ordering.value})

if sort_fields:
query_params.update({"_sort_fields": ",".join([s.strip() for s in sort_fields])})
query_params.update(
{"_sort_fields": ",".join([s.strip() for s in sort_fields])}
)

query_params = {entry: query_params[entry] for entry in query_params if query_params[entry] is not None}
query_params = {
entry: query_params[entry]
for entry in query_params
if query_params[entry] is not None
}

return super()._search(
num_chunks=num_chunks,
Expand All @@ -382,7 +420,9 @@ def get_dos_from_task_id(self, task_id: str):
Returns:
bandstructure (CompleteDos): CompleteDos object
"""
result = self._query_open_data(bucket="materialsproject-parsed", prefix="dos", key=task_id)
result = self._query_open_data(
bucket="materialsproject-parsed", prefix="dos", key=task_id
)

if result.get("data", None) is not None:
return result["data"]
Expand All @@ -398,9 +438,13 @@ def get_dos_from_material_id(self, material_id: str):
Returns:
dos (CompleteDos): CompleteDos object
"""
es_rester = ElectronicStructureRester(endpoint=self.base_endpoint, api_key=self.api_key)
es_rester = ElectronicStructureRester(
endpoint=self.base_endpoint, api_key=self.api_key
)

dos_data = es_rester.get_data_by_id(document_id=material_id, fields=["dos"]).dict()
dos_data = es_rester.get_data_by_id(
document_id=material_id, fields=["dos"]
).dict()

if dos_data["dos"]:
dos_task_id = dos_data["dos"]["total"]["1"]["task_id"]
Expand Down

0 comments on commit 928a696

Please sign in to comment.