Skip to content

Commit

Permalink
Add "allow_unmapped" to Ax SQA objects for SQA 2.0 forward compatibil…
Browse files Browse the repository at this point in the history
…ity (#3008)

Summary:
Pull Request resolved: #3008

Pull Request resolved: #2748

T163607006 for more context

OSS User trying to use Ax encountered this SQA error when using version 2.0:
```
ArgumentError: Type annotation for "SQAGeneratorRun.arms" can't be correctly interpreted for Annotated Declarative Table form.
ORM annotations should normally make use of the ``Mapped[]`` generic type, or other ORM-compatible generic type, as a container for the actual type, which indicates the intent that the attribute is mapped. Class variables that are not intended to be mapped by the ORM should use ClassVar[].

To allow Annotated Declarative to disregard legacy annotations which don't use Mapped[] to pass,
set "__allow_unmapped__ = True" on the class or a superclass this class. (Background on this error at: https://sqlalche.me/e/20/zlpr)
```
Currently SQA 1.4 is the only supported version internally.

This change follows the suggestion of the error to set "__allow_unmapped__" equal to true, which is also suggested in the SQL alchemy wiki
https://docs.sqlalchemy.org/en/20/changelog/migration_20.html?fbclid=IwZXh0bgNhZW0CMTEAAR083E0mVk0DkKTo9R1AimFUsoZ4iV2ei1BVKFYmH4iQVrMqcS6F6fv7ZUw_aem_S3WfZmTwJIdpYJkQDo2icQ#migration-to-2-0-step-six-add-allow-unmapped-to-explicitly-typed-orm-models

This should fix the issue encountered when using SQA 2.0 in OSS.

Errors came up:
```
ERROR ax/core/tests/test_experiment.py - sqlalchemy.exc.ArgumentError: Could not interpret annotation list[SQAMetric].
Check that it uses names that are correctly imported at the module level. See chained stack trace for more hints.

ERROR ax/service/tests/test_ax_client.py - sqlalchemy.exc.InvalidRequestError: Table 'parameter_v2' is already defined for this MetaData instance.
Specify 'extend_existing=True' to redefine options and columns on an existing Table object.
```

Fixing these led to another round of errors
```
FAILED ax/core/tests/test_experiment.py::ExperimentTest::test_clone_with - sqlalchemy.exc.ArgumentError: Strings are not accepted for attribute names in loader options; please use class-bound attributes directly.
FAILED ax/service/tests/test_ax_client.py::TestAxClient::test_db_write_failure_on_create_experiment - ValueError: `db_settings` argument should be of type ax.storage.sqa_store.(Got: DBSettings(creator=None, decoder=<ax.storage.sqa_store.decoder.Decoder object at 0x7f0b971f5630>, encoder=<ax.storage.sqa_store.encoder.Encoder object at 0x7f0b971f6fe0>, url=None) of type <class 'ax.storage.sqa_store.structs.DBSettings'>. structs.DBSettings. To use `DBSettings`, you will need SQLAlchemy installed in your environment (can be installed through pip).
FAILED ax/service/tests/test_ax_client.py::TestAxClient::test_save_and_load_generation_strategy - ValueError: `db_settings` argument should be of type ax.storage.sqa_store.(Got: DBSettings(creator=None, decoder=<ax.storage.sqa_store.decoder.Decoder object at 0x7f0ab0687d30>, encoder=<ax.storage.sqa_store.encoder.Encoder object at 0x7f0ab0685630>, url=None) of type <class 'ax.storage.sqa_store.structs.DBSettings'>. structs.DBSettings. To use `DBSettings`, you will need SQLAlchemy installed in your environment (can be installed through pip).
FAILED ax/service/tests/test_ax_client.py::TestAxClient::test_sqa_storage - ValueError: `db_settings` argument should be of type ax.storage.sqa_store.(Got: DBSettings(creator=None, decoder=<ax.storage.sqa_store.decoder.Decoder object at 0x7f0b7f73b460>, encoder=<ax.storage.sqa_store.encoder.Encoder object at 0x7f0b7f7397b0>, url=None) of type <class 'ax.storage.sqa_store.structs.DBSettings'>. structs.DBSettings. To use `DBSettings`, you will need SQLAlchemy installed in your environment (can be installed through pip).
```

Reviewed By: Balandat, mgrange1998

Differential Revision: D62261700

fbshipit-source-id: 6a9968b07cfe7855652507f2d29f64a7ff80826a
  • Loading branch information
paschai authored and facebook-github-bot committed Nov 1, 2024
1 parent a9a9a7c commit 633460b
Show file tree
Hide file tree
Showing 2 changed files with 24 additions and 22 deletions.
2 changes: 2 additions & 0 deletions ax/storage/sqa_store/db.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,8 @@
class SQABase:
"""Metaclass for SQLAlchemy classes corresponding to core Ax classes."""

__allow_unmapped__ = True
__table_args__ = {"extend_existing": True}
pass


Expand Down
44 changes: 22 additions & 22 deletions ax/storage/sqa_store/sqa_classes.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@

from datetime import datetime
from decimal import Decimal
from typing import Any
from typing import Any, List

from ax.core.base_trial import TrialStatus
from ax.core.batch_trial import LifecycleStage
Expand Down Expand Up @@ -79,10 +79,10 @@ class SQAParameter(Base):
upper: Column[Decimal | None] = Column(Float)

# Attributes for Choice Parameters
choice_values: Column[list[TParamValue] | None] = Column(JSONEncodedList)
choice_values: Column[List[TParamValue] | None] = Column(JSONEncodedList)
is_ordered: Column[bool | None] = Column(Boolean)
is_task: Column[bool | None] = Column(Boolean)
dependents: Column[dict[TParamValue, list[str]] | None] = Column(JSONEncodedObject)
dependents: Column[dict[TParamValue, List[str]] | None] = Column(JSONEncodedObject)

# Attributes for Fixed Parameters
fixed_value: Column[TParamValue | None] = Column(JSONEncodedObject)
Expand Down Expand Up @@ -134,7 +134,7 @@ class SQAMetric(Base):
# of Multi/Scalarized Objective contains all children of the parent metric
# join_depth argument: used for loading self-referential relationships
# https://docs.sqlalchemy.org/en/13/orm/self_referential.html#configuring-self-referential-eager-loading
scalarized_objective_children_metrics: list[SQAMetric] = relationship(
scalarized_objective_children_metrics: List["SQAMetric"] = relationship(
"SQAMetric",
cascade="all, delete-orphan",
lazy=True,
Expand All @@ -146,7 +146,7 @@ class SQAMetric(Base):
scalarized_outcome_constraint_id: Column[int | None] = Column(
Integer, ForeignKey("metric_v2.id")
)
scalarized_outcome_constraint_children_metrics: list[SQAMetric] = relationship(
scalarized_outcome_constraint_children_metrics: List["SQAMetric"] = relationship(
"SQAMetric",
cascade="all, delete-orphan",
lazy=True,
Expand Down Expand Up @@ -213,19 +213,19 @@ class SQAGeneratorRun(Base):
# relationships
# Use selectin loading for collections to prevent idle timeout errors
# (https://docs.sqlalchemy.org/en/13/orm/loading_relationships.html#selectin-eager-loading)
arms: list[SQAArm] = relationship(
arms: List[SQAArm] = relationship(
"SQAArm",
cascade="all, delete-orphan",
lazy="selectin",
order_by=lambda: SQAArm.id,
)
metrics: list[SQAMetric] = relationship(
metrics: List[SQAMetric] = relationship(
"SQAMetric", cascade="all, delete-orphan", lazy="selectin"
)
parameters: list[SQAParameter] = relationship(
parameters: List[SQAParameter] = relationship(
"SQAParameter", cascade="all, delete-orphan", lazy="selectin"
)
parameter_constraints: list[SQAParameterConstraint] = relationship(
parameter_constraints: List[SQAParameterConstraint] = relationship(
"SQAParameterConstraint", cascade="all, delete-orphan", lazy="selectin"
)

Expand Down Expand Up @@ -267,15 +267,15 @@ class SQAGenerationStrategy(Base):

id: Column[int] = Column(Integer, primary_key=True)
name: Column[str] = Column(String(NAME_OR_TYPE_FIELD_LENGTH), nullable=False)
steps: Column[list[dict[str, Any]]] = Column(JSONEncodedList, nullable=False)
steps: Column[List[dict[str, Any]]] = Column(JSONEncodedList, nullable=False)
curr_index: Column[int | None] = Column(Integer, nullable=True)
experiment_id: Column[int | None] = Column(Integer, ForeignKey("experiment_v2.id"))
nodes: Column[list[dict[str, Any]]] = Column(JSONEncodedList, nullable=True)
nodes: Column[List[dict[str, Any]]] = Column(JSONEncodedList, nullable=True)
curr_node_name: Column[str | None] = Column(
String(NAME_OR_TYPE_FIELD_LENGTH), nullable=True
)

generator_runs: list[SQAGeneratorRun] = relationship(
generator_runs: List[SQAGeneratorRun] = relationship(
"SQAGeneratorRun",
cascade="all, delete-orphan",
lazy="selectin",
Expand Down Expand Up @@ -321,10 +321,10 @@ class SQATrial(Base):
# a child, the old one will be deleted.
# Use selectin loading for collections to prevent idle timeout errors
# (https://docs.sqlalchemy.org/en/13/orm/loading_relationships.html#selectin-eager-loading)
abandoned_arms: list[SQAAbandonedArm] = relationship(
abandoned_arms: List[SQAAbandonedArm] = relationship(
"SQAAbandonedArm", cascade="all, delete-orphan", lazy="selectin"
)
generator_runs: list[SQAGeneratorRun] = relationship(
generator_runs: List[SQAGeneratorRun] = relationship(
"SQAGeneratorRun", cascade="all, delete-orphan", lazy="selectin"
)
runner: SQARunner = relationship(
Expand Down Expand Up @@ -371,7 +371,7 @@ class SQAExperiment(Base):
# pyre-fixme[8]: Incompatible attribute type [8]: Attribute
# `auxiliary_experiments_by_purpose` declared in class `SQAExperiment` has
# type `Optional[Dict[str, List[str]]]` but is used as type `Column[typing.Any]`
auxiliary_experiments_by_purpose: dict[str, list[str]] | None = Column(
auxiliary_experiments_by_purpose: dict[str, List[str]] | None = Column(
JSONEncodedTextDict, nullable=True, default={}
)

Expand All @@ -381,22 +381,22 @@ class SQAExperiment(Base):
# a child, the old one will be deleted.
# Use selectin loading for collections to prevent idle timeout errors
# (https://docs.sqlalchemy.org/en/13/orm/loading_relationships.html#selectin-eager-loading)
data: list[SQAData] = relationship(
data: List[SQAData] = relationship(
"SQAData", cascade="all, delete-orphan", lazy="selectin"
)
metrics: list[SQAMetric] = relationship(
metrics: List[SQAMetric] = relationship(
"SQAMetric", cascade="all, delete-orphan", lazy="selectin"
)
parameters: list[SQAParameter] = relationship(
parameters: List[SQAParameter] = relationship(
"SQAParameter", cascade="all, delete-orphan", lazy="selectin"
)
parameter_constraints: list[SQAParameterConstraint] = relationship(
parameter_constraints: List[SQAParameterConstraint] = relationship(
"SQAParameterConstraint", cascade="all, delete-orphan", lazy="selectin"
)
runners: list[SQARunner] = relationship(
runners: List[SQARunner] = relationship(
"SQARunner", cascade="all, delete-orphan", lazy=False
)
trials: list[SQATrial] = relationship(
trials: List[SQATrial] = relationship(
"SQATrial", cascade="all, delete-orphan", lazy="selectin"
)
generation_strategy: SQAGenerationStrategy | None = relationship(
Expand All @@ -405,6 +405,6 @@ class SQAExperiment(Base):
uselist=False,
lazy=True,
)
analysis_cards: list[SQAAnalysisCard] = relationship(
analysis_cards: List[SQAAnalysisCard] = relationship(
"SQAAnalysisCard", cascade="all, delete-orphan", lazy="selectin"
)

0 comments on commit 633460b

Please sign in to comment.