Skip to content

Commit

Permalink
General repo and tooling cleanup (#839)
Browse files Browse the repository at this point in the history
* Replace setup.py with pyproject.toml

* Add pytest config to pyproject

* Add more to pytest section

* Add concurrency setting to testing wflow

* Add ruff section to pyproject

* Don't error out on dep warnings

* Upgrade ruff py version

* mypy and black docs

* Add pre-commit custom message

* FIx py versions

* Linting

* Remove mypy

* Ruff add show fixes

* Change ruff pre-commit repo

* Ensure import annotations

* Fix type hinting

* Add docstring ignores

* More docstring linting

* Fix type hint in utils

* Fix typing for python < 3.10

* Linting

* Revert typing and exclude settings

* Revert most typing changes

* Factor out common testing function

* Fix materials float test

* Linting
  • Loading branch information
Jason Munro authored Aug 24, 2023
1 parent f8cfd82 commit 57d8acf
Show file tree
Hide file tree
Showing 78 changed files with 1,168 additions and 1,984 deletions.
4 changes: 4 additions & 0 deletions .github/workflows/testing.yml
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,10 @@ on:
pull_request:
branches: [main]

concurrency:
group: ${{ github.workflow }}-${{ github.ref }}
cancel-in-progress: true

jobs:
test:
strategy:
Expand Down
20 changes: 14 additions & 6 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
@@ -1,18 +1,26 @@
default_stages: [commit]

default_install_hook_types: [pre-commit, commit-msg]

ci:
autoupdate_commit_msg: "chore: update pre-commit hooks"

repos:
- repo: https://github.com/charliermarsh/ruff-pre-commit
rev: v0.0.261
- repo: https://github.com/astral-sh/ruff-pre-commit
rev: v0.0.284
hooks:
- id: ruff
args: [--fix, --ignore, "D,E501"]
args: [--fix, --show-fixes]

- repo: https://github.com/psf/black
rev: 23.3.0
rev: 23.7.0
hooks:
- id: black

- repo: https://github.com/asottile/blacken-docs
rev: "1.15.0"
hooks:
- id: black-jupyter
- id: blacken-docs
additional_dependencies: [black>=23.7.0]

- repo: https://github.com/pre-commit/pre-commit-hooks
rev: v4.4.0
Expand Down
2 changes: 2 additions & 0 deletions mp_api/client/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
"""Primary MAPI module."""
from __future__ import annotations

import os
from importlib.metadata import PackageNotFoundError, version

Expand Down
2 changes: 2 additions & 0 deletions mp_api/client/core/__init__.py
Original file line number Diff line number Diff line change
@@ -1,2 +1,4 @@
from __future__ import annotations

from .client import BaseRester, MPRestError
from .settings import MAPIClientSettings
92 changes: 48 additions & 44 deletions mp_api/client/core/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
API v3 to enable the creation of data structures and pymatgen objects using
Materials Project data.
"""
from __future__ import annotations

import gzip
import itertools
Expand All @@ -14,7 +15,7 @@
from json import JSONDecodeError
from math import ceil
from os import environ
from typing import Any, Dict, Generic, List, Optional, Tuple, TypeVar, Union
from typing import Any, Generic, TypeVar
from urllib.parse import quote, urljoin

import requests
Expand Down Expand Up @@ -52,18 +53,18 @@
class BaseRester(Generic[T]):
"""Base client class with core stubs."""

suffix: Optional[str] = None
suffix: str | None = None
document_model: BaseModel = None # type: ignore
supports_versions: bool = False
primary_key: str = "material_id"

def __init__(
self,
api_key: Union[str, None] = None,
api_key: str | None = None,
endpoint: str = DEFAULT_ENDPOINT,
include_user_agent: bool = True,
session: Optional[requests.Session] = None,
s3_resource: Optional[Any] = None,
session: requests.Session | None = None,
s3_resource: Any | None = None,
debug: bool = False,
monty_decode: bool = True,
use_document_model: bool = True,
Expand Down Expand Up @@ -191,11 +192,11 @@ def __exit__(self, exc_type, exc_val, exc_tb): # pragma: no cover

def _post_resource(
self,
body: Dict = None,
params: Optional[Dict] = None,
suburl: Optional[str] = None,
use_document_model: Optional[bool] = None,
) -> Dict:
body: dict = None,
params: dict | None = None,
suburl: str | None = None,
use_document_model: bool | None = None,
) -> dict:
"""Post data to the endpoint for a Resource.
Arguments:
Expand Down Expand Up @@ -261,11 +262,11 @@ def _post_resource(

def _patch_resource(
self,
body: Dict = None,
params: Optional[Dict] = None,
suburl: Optional[str] = None,
use_document_model: Optional[bool] = None,
) -> Dict:
body: dict = None,
params: dict | None = None,
suburl: str | None = None,
use_document_model: bool | None = None,
) -> dict:
"""Patch data to the endpoint for a Resource.
Arguments:
Expand Down Expand Up @@ -330,7 +331,7 @@ def _patch_resource(
raise MPRestError(str(ex))

def _query_open_data(self, bucket: str, prefix: str, key: str) -> dict:
"""Query Materials Project AWS open data s3 buckets
"""Query Materials Project AWS open data s3 buckets.
Args:
bucket (str): Materials project bucket name
Expand All @@ -340,7 +341,6 @@ def _query_open_data(self, bucket: str, prefix: str, key: str) -> dict:
Returns:
dict: MontyDecoded data
"""

ref = self.s3_resource.Object(bucket, f"{prefix}/{key}.json.gz") # type: ignore
bytes = ref.get()["Body"] # type: ignore

Expand All @@ -352,15 +352,15 @@ def _query_open_data(self, bucket: str, prefix: str, key: str) -> dict:

def _query_resource(
self,
criteria: Optional[Dict] = None,
fields: Optional[List[str]] = None,
suburl: Optional[str] = None,
use_document_model: Optional[bool] = None,
parallel_param: Optional[str] = None,
num_chunks: Optional[int] = None,
chunk_size: Optional[int] = None,
timeout: Optional[int] = None,
) -> Dict:
criteria: dict | None = None,
fields: list[str] | None = None,
suburl: str | None = None,
use_document_model: bool | None = None,
parallel_param: str | None = None,
num_chunks: int | None = None,
chunk_size: int | None = None,
timeout: int | None = None,
) -> dict:
"""Query the endpoint for a Resource containing a list of documents
and meta information about pagination and total document count.
Expand Down Expand Up @@ -429,7 +429,7 @@ def _submit_requests(
num_chunks=None,
chunk_size=None,
timeout=None,
) -> Dict:
) -> dict:
"""Handle submitting requests. Parallel requests supported if possible.
Parallelization will occur either over the largest list of supported
query parameters used and/or over pagination.
Expand Down Expand Up @@ -712,7 +712,7 @@ def _submit_requests(
def _multi_thread(
self,
use_document_model: bool,
params_list: List[dict],
params_list: list[dict],
progress_bar: tqdm = None,
timeout: int = None,
):
Expand Down Expand Up @@ -788,7 +788,7 @@ def _submit_request_and_process(
params: dict,
use_document_model: bool,
timeout: int = None,
) -> Tuple[Dict, int]:
) -> tuple[dict, int]:
"""Submits GET request and handles the response.
Arguments:
Expand Down Expand Up @@ -936,12 +936,12 @@ def new_dict(self, *args, **kwargs):

def _query_resource_data(
self,
criteria: Optional[Dict] = None,
fields: Optional[List[str]] = None,
suburl: Optional[str] = None,
use_document_model: Optional[bool] = None,
timeout: Optional[int] = None,
) -> Union[List[T], List[Dict]]:
criteria: dict | None = None,
fields: list[str] | None = None,
suburl: str | None = None,
use_document_model: bool | None = None,
timeout: int | None = None,
) -> list[T] | list[dict]:
"""Query the endpoint for a list of documents without associated meta information. Only
returns a single page of results.
Expand All @@ -967,7 +967,7 @@ def _query_resource_data(
def get_data_by_id(
self,
document_id: str,
fields: Optional[List[str]] = None,
fields: list[str] | None = None,
) -> T:
"""Query the endpoint for a single document.
Expand Down Expand Up @@ -1039,12 +1039,12 @@ def get_data_by_id(

def _search(
self,
num_chunks: Optional[int] = None,
num_chunks: int | None = None,
chunk_size: int = 1000,
all_fields: bool = True,
fields: Optional[List[str]] = None,
fields: list[str] | None = None,
**kwargs,
) -> Union[List[T], List[Dict]]:
) -> list[T] | list[dict]:
"""A generic search method to retrieve documents matching specific parameters.
Arguments:
Expand Down Expand Up @@ -1082,7 +1082,7 @@ def _get_all_documents(
fields=None,
chunk_size=1000,
num_chunks=None,
) -> Union[List[T], List[Dict]]:
) -> list[T] | list[dict]:
"""Iterates over pages until all documents are retrieved. Displays
progress using tqdm. This method is designed to give a common
implementation for the search_* methods on various endpoints. See
Expand Down Expand Up @@ -1124,10 +1124,14 @@ def _get_all_documents(

return results["data"]

def count(self, criteria: Optional[Dict] = None) -> Union[int, str]:
def count(self, criteria: dict | None = None) -> int | str:
"""Return a count of total documents.
:param criteria: As in .query()
:return:
Args:
criteria (dict | None): As in .query(). Defaults to None
Returns:
(int | str): Count of total results, or string indicating error
"""
try:
criteria = criteria or {}
Expand All @@ -1145,7 +1149,7 @@ def count(self, criteria: Optional[Dict] = None) -> Union[int, str]:
return "Problem getting count"

@property
def available_fields(self) -> List[str]:
def available_fields(self) -> list[str]:
if self.document_model is None:
return ["Unknown fields."]
return list(self.document_model.schema()["properties"].keys()) # type: ignore
Expand Down
20 changes: 13 additions & 7 deletions mp_api/client/core/utils.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,15 @@
from __future__ import annotations

import re
from typing import List, Optional, Type, get_args
from typing import get_args

from monty.json import MSONable
from pydantic import BaseModel
from pydantic.schema import get_flat_models_from_model
from pydantic.utils import lenient_issubclass


def validate_ids(id_list: List[str]):
def validate_ids(id_list: list[str]):
"""Function to validate material and task IDs.
Args:
Expand All @@ -29,8 +31,8 @@ def validate_ids(id_list: List[str]):


def api_sanitize(
pydantic_model: Type[BaseModel],
fields_to_leave: Optional[List[str]] = None,
pydantic_model: BaseModel,
fields_to_leave: list[str] | None = None,
allow_dict_msonable=False,
):
"""Function to clean up pydantic models for the API by:
Expand All @@ -40,13 +42,17 @@ def api_sanitize(
WARNING: This works in place, so it mutates the model and all sub-models
Args:
fields_to_leave: list of strings for model fields as "model__name__.field"
pydantic_model (BaseModel): Pydantic model to alter
fields_to_leave (list[str] | None): list of strings for model fields as "model__name__.field".
Defaults to None.
allow_dict_msonable (bool): Whether to allow dictionaries in place of MSONable quantities.
Defaults to False
"""
models = [
model
for model in get_flat_models_from_model(pydantic_model)
if issubclass(model, BaseModel)
] # type: List[Type[BaseModel]]
] # type: list[BaseModel]

fields_to_leave = fields_to_leave or []
fields_tuples = [f.split(".") for f in fields_to_leave]
Expand Down Expand Up @@ -77,7 +83,7 @@ def api_sanitize(
return pydantic_model


def allow_msonable_dict(monty_cls: Type[MSONable]):
def allow_msonable_dict(monty_cls: type[MSONable]):
"""Patch Monty to allow for dict values for MSONable."""

def validate_monty(cls, v):
Expand Down
Loading

0 comments on commit 57d8acf

Please sign in to comment.