diff --git a/bionty/base/entities/_gene.py b/bionty/base/entities/_gene.py index 4543ee2..dda60c7 100644 --- a/bionty/base/entities/_gene.py +++ b/bionty/base/entities/_gene.py @@ -1,6 +1,6 @@ from __future__ import annotations -from typing import Dict, Iterable, Literal, Optional +from typing import Iterable, Literal import pandas as pd from lamin_utils import logger @@ -86,23 +86,33 @@ def map_legacy_ids( class EnsemblGene: - def __init__(self, organism: str, version: str) -> None: + def __init__( + self, + organism: str, + version: str, + taxa: Literal[ + "vertebrates", "bacteria", "fungi", "metazoa", "plants", "all" + ] = "vertebrates", + ) -> None: """Ensembl Gene mysql. Args: - organism: a bionty.Organism object - version: name of the ensembl DB version, e.g. "release-110" + organism: Name of the organism + version: Name of the ensembl DB version, e.g. "release-110" """ self._import() import mysql.connector as sql from sqlalchemy import create_engine self._organism = ( - Organism(version=version).lookup().dict().get(organism) # type:ignore - ) - self._url = ( - f"mysql+mysqldb://anonymous:@ensembldb.ensembl.org/{self._organism.core_db}" + Organism(version=version, taxa=taxa).lookup().dict().get(organism) # type:ignore ) + # vertebrates and plants use different ports + if taxa == "plants": + port = 4157 + else: + port = 3306 + self._url = f"mysql+mysqldb://anonymous:@ensembldb.ensembl.org:{port}/{self._organism.core_db}" self._engine = create_engine(url=self._url) def _import(self): @@ -221,8 +231,10 @@ def add_external_db_column(df: pd.DataFrame, ext_db: str, df_col: str): df_res = df_res[~df_res["ensembl_gene_id"].isna()] # if stable_id is not ensembl_gene_id, keep a stable_id column - if not any(df_res["ensembl_gene_id"].str.startswith("ENS")): - logger.warning("no ensembl_gene_id found, writing to table_id column.") + if not all(df_res["ensembl_gene_id"].str.startswith("ENS")): + logger.warning( + "ensembl_gene_id column not all ENS-prefixed, writing to stable_id column." + ) df_res.insert(0, "stable_id", df_res.pop("ensembl_gene_id")) df_res = df_res.sort_values("stable_id").reset_index(drop=True) else: diff --git a/bionty/base/entities/_organism.py b/bionty/base/entities/_organism.py index 6f9c154..ab61a8f 100644 --- a/bionty/base/entities/_organism.py +++ b/bionty/base/entities/_organism.py @@ -22,9 +22,7 @@ class Organism(PublicOntology): def __init__( self, - organism: Literal[ - "vertebrates", "bacteria", "fungi", "metazoa", "plants", "all" - ] + taxa: Literal["vertebrates", "bacteria", "fungi", "metazoa", "plants", "all"] | None = None, source: Literal["ensembl", "ncbitaxon"] | None = None, version: Literal[ @@ -39,7 +37,10 @@ def __init__( | None = None, **kwargs, ): - super().__init__(organism=organism, source=source, version=version, **kwargs) + # To support the organism kwarg being passed in getattr access in other parts of the code + if kwargs.get("organism") is not None: + taxa = kwargs.pop("organism") + super().__init__(organism=taxa, source=source, version=version, **kwargs) def _load_df(self) -> pd.DataFrame: if self.source == "ensembl":