-
Notifications
You must be signed in to change notification settings - Fork 101
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add more metadata and try to guess if relationship is iterable #74
base: master
Are you sure you want to change the base?
Changes from all commits
7bd2ace
6317b79
9470f5c
c268644
ee7baf3
25c3ef3
d5f2ede
605a997
79d0504
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -3,10 +3,10 @@ | |
Plugin, FunctionContext, ClassDefContext, DynamicClassDefContext, | ||
SemanticAnalyzerPluginInterface | ||
) | ||
from mypy.plugins.common import add_method | ||
from mypy.plugins.common import add_method, _get_argument | ||
from mypy.nodes import ( | ||
NameExpr, Expression, StrExpr, TypeInfo, ClassDef, Block, SymbolTable, SymbolTableNode, GDEF, | ||
Argument, Var, ARG_STAR2, MDEF, TupleExpr, RefExpr | ||
Argument, Var, ARG_STAR2, MDEF, TupleExpr, RefExpr, AssignmentStmt, CallExpr, MemberExpr | ||
) | ||
from mypy.types import ( | ||
UnionType, NoneTyp, Instance, Type, AnyType, TypeOfAny, UninhabitedType, CallableType | ||
|
@@ -25,6 +25,7 @@ | |
COLUMN_ELEMENT_NAME = 'sqlalchemy.sql.elements.ColumnElement' # type: Final | ||
GROUPING_NAME = 'sqlalchemy.sql.elements.Grouping' # type: Final | ||
RELATIONSHIP_NAME = 'sqlalchemy.orm.relationships.RelationshipProperty' # type: Final | ||
FOREIGN_KEY_NAME = 'sqlalchemy.sql.schema.ForeignKey' # type: Final | ||
|
||
|
||
def is_declarative(info: TypeInfo) -> bool: | ||
|
@@ -110,6 +111,53 @@ def add_model_init_hook(ctx: ClassDefContext) -> None: | |
add_method(ctx, '__init__', [kw_arg], NoneTyp()) | ||
ctx.cls.info.metadata.setdefault('sqlalchemy', {})['generated_init'] = True | ||
|
||
for stmt in ctx.cls.defs.body: | ||
if not (isinstance(stmt, AssignmentStmt) and len(stmt.lvalues) == 1 and isinstance(stmt.lvalues[0], NameExpr)): | ||
continue | ||
|
||
# We currently only handle setting __tablename__ as a class attribute, and not through a property. | ||
if stmt.lvalues[0].name == "__tablename__" and isinstance(stmt.rvalue, StrExpr): | ||
ctx.cls.info.metadata.setdefault('sqlalchemy', {})['table_name'] = stmt.rvalue.value | ||
|
||
if (isinstance(stmt.rvalue, CallExpr) and isinstance(stmt.rvalue.callee, NameExpr) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I would factor out everything you add on this line and below in a helper function with a docstring explaining what we do here, e.g. |
||
and stmt.rvalue.callee.fullname == COLUMN_NAME): | ||
# Save columns. The name of a column on the db side can be different from the one inside the SA model. | ||
sa_column_name = stmt.lvalues[0].name | ||
|
||
db_column_name = None # type: Optional[str] | ||
if 'name' in stmt.rvalue.arg_names: | ||
name_str_expr = stmt.rvalue.args[stmt.rvalue.arg_names.index('name')] | ||
assert isinstance(name_str_expr, StrExpr) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This is unsafe, it will crash mypy if someone passes Please also add a test for this. |
||
db_column_name = name_str_expr.value | ||
else: | ||
if len(stmt.rvalue.args) >= 1 and isinstance(stmt.rvalue.args[0], StrExpr): | ||
db_column_name = stmt.rvalue.args[0].value | ||
|
||
ctx.cls.info.metadata.setdefault('sqlalchemy', {}).setdefault('columns', []).append( | ||
{"sa_name": sa_column_name, "db_name": db_column_name or sa_column_name} | ||
) | ||
|
||
# Save foreign keys. | ||
for arg in stmt.rvalue.args: | ||
if (isinstance(arg, CallExpr) and isinstance(arg.callee, NameExpr) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Maybe use |
||
and arg.callee.fullname == FOREIGN_KEY_NAME and len(arg.args) >= 1): | ||
fk = arg.args[0] | ||
if isinstance(fk, StrExpr): | ||
*r, parent_table_name, parent_db_col_name = fk.value.split(".") | ||
assert len(r) <= 1 | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Again, this is not safe, we should never crash on bad user input. |
||
ctx.cls.info.metadata.setdefault('sqlalchemy', {}).setdefault('foreign_keys', | ||
{})[sa_column_name] = { | ||
"db_name": parent_db_col_name, | ||
"table_name": parent_table_name, | ||
"schema": r[0] if r else None | ||
} | ||
elif isinstance(fk, MemberExpr) and isinstance(fk.expr, NameExpr): | ||
ctx.cls.info.metadata.setdefault('sqlalchemy', {}).setdefault('foreign_keys', | ||
{})[sa_column_name] = { | ||
"sa_name": fk.name, | ||
"model_fullname": fk.expr.fullname | ||
} | ||
|
||
# Also add a selection of auto-generated attributes. | ||
sym = ctx.api.lookup_fully_qualified_or_none('sqlalchemy.sql.schema.Table') | ||
if sym: | ||
|
@@ -317,6 +365,55 @@ def grouping_hook(ctx: FunctionContext) -> Type: | |
return ctx.default_return_type | ||
|
||
|
||
class IncompleteModelMetadata(Exception): | ||
pass | ||
|
||
|
||
def has_foreign_keys(local_model: TypeInfo, remote_model: TypeInfo) -> bool: | ||
"""Tells if `local_model` has a fk to `remote_model`. | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. fk -> foreign key |
||
Will raise an `IncompleteModelMetadata` if some mandatory metadata is missing. | ||
""" | ||
local_metadata = local_model.metadata.get("sqlalchemy", {}) | ||
remote_metadata = remote_model.metadata.get("sqlalchemy", {}) | ||
|
||
for fk in local_metadata.get("foreign_keys", {}).values(): | ||
if 'model_fullname' in fk and remote_model.fullname() == fk['model_fullname']: | ||
return True | ||
if 'table_name' in fk: | ||
if 'table_name' not in remote_metadata: | ||
raise IncompleteModelMetadata | ||
# TODO: handle different schemas. | ||
# It's not straightforward because schema can be specified in `__table_args__` or in metadata for example | ||
if remote_metadata['table_name'] == fk['table_name']: | ||
return True | ||
|
||
return False | ||
|
||
|
||
def is_relationship_iterable(ctx: FunctionContext, local_model: TypeInfo, remote_model: TypeInfo) -> bool: | ||
"""Tries to guess if the relationship is onetoone/onetomany/manytoone. | ||
|
||
Currently we handle the most current case, where a model relates to the other one through a relationship. | ||
We also handle cases where secondaryjoin argument is provided. | ||
We don't handle advanced usecases (foreign keys on both sides, primaryjoin, etc.). | ||
""" | ||
secondaryjoin = get_argument_by_name(ctx, 'secondaryjoin') | ||
|
||
if secondaryjoin is not None: | ||
return True | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Could you please add a test for this? |
||
|
||
try: | ||
can_be_many_to_one = has_foreign_keys(local_model, remote_model) | ||
can_be_one_to_many = has_foreign_keys(remote_model, local_model) | ||
|
||
if not can_be_many_to_one and can_be_one_to_many: | ||
return True | ||
except IncompleteModelMetadata: | ||
pass | ||
|
||
return False # Assume relationship is not iterable, if we weren't able to guess better. | ||
|
||
|
||
def relationship_hook(ctx: FunctionContext) -> Type: | ||
"""Support basic use cases for relationships. | ||
|
||
|
@@ -369,10 +466,18 @@ class User(Base): | |
# Something complex, stay silent for now. | ||
new_arg = AnyType(TypeOfAny.special_form) | ||
|
||
# use private api | ||
current_model = ctx.api.scope.active_class() # type: ignore # type: TypeInfo | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Put |
||
assert current_model is not None | ||
|
||
# TODO: handle backref relationships | ||
|
||
# We figured out, the model type. Now check if we need to wrap it in Iterable | ||
if uselist_arg: | ||
if parse_bool(uselist_arg): | ||
new_arg = ctx.api.named_generic_type('typing.Iterable', [new_arg]) | ||
elif isinstance(new_arg, Instance) and is_relationship_iterable(ctx, current_model, new_arg.type): | ||
new_arg = ctx.api.named_generic_type('typing.Iterable', [new_arg]) | ||
else: | ||
if has_annotation: | ||
# If there is an annotation we use it as a source of truth. | ||
|
@@ -387,10 +492,10 @@ class User(Base): | |
# We really need to add this to TypeChecker API | ||
def parse_bool(expr: Expression) -> Optional[bool]: | ||
if isinstance(expr, NameExpr): | ||
if expr.fullname == 'builtins.True': | ||
return True | ||
if expr.fullname == 'builtins.False': | ||
return False | ||
if expr.fullname == 'builtins.True': | ||
return True | ||
if expr.fullname == 'builtins.False': | ||
return False | ||
return None | ||
|
||
|
||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -280,3 +280,93 @@ class M2(M1): | |
Base = declarative_base(cls=(M1, M2)) # E: Not able to calculate MRO for declarative base | ||
reveal_type(Base) # E: Revealed type is 'Any' | ||
[out] | ||
|
||
[case testRelationshipIsGuessed] | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Do you have tests where |
||
from sqlalchemy import Column, Integer, String, ForeignKey | ||
from sqlalchemy.orm import relationship | ||
from sqlalchemy.ext.declarative import declarative_base | ||
|
||
Base = declarative_base() | ||
|
||
class Parent(Base): | ||
__tablename__ = 'parents' | ||
id = Column(Integer, primary_key=True) | ||
name = Column(String) | ||
|
||
children = relationship("Child") | ||
|
||
class Child(Base): | ||
__tablename__ = 'children' | ||
id = Column(Integer, primary_key=True) | ||
name = Column(String) | ||
parent_id = Column(Integer, ForeignKey(Parent.id)) | ||
|
||
parent = relationship(Parent) | ||
|
||
child: Child | ||
parent: Parent | ||
|
||
reveal_type(child.parent) # E: Revealed type is 'main.Parent*' | ||
reveal_type(parent.children) # E: Revealed type is 'typing.Iterable*[main.Child]' | ||
|
||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. No need for empty lines before |
||
[out] | ||
|
||
[case testRelationshipIsGuessed2] | ||
from sqlalchemy import Column, Integer, String, ForeignKey | ||
from sqlalchemy.orm import relationship | ||
from sqlalchemy.ext.declarative import declarative_base | ||
|
||
Base = declarative_base() | ||
|
||
class Parent(Base): | ||
__tablename__ = 'parents' | ||
id = Column(Integer, primary_key=True) | ||
name = Column(String) | ||
|
||
children = relationship("Child") | ||
|
||
class Child(Base): | ||
__tablename__ = 'children' | ||
id = Column(Integer, primary_key=True) | ||
name = Column(String) | ||
parent_id = Column(Integer, ForeignKey("parents.id")) | ||
|
||
parent = relationship(Parent) | ||
|
||
child: Child | ||
parent: Parent | ||
|
||
reveal_type(child.parent) # E: Revealed type is 'main.Parent*' | ||
reveal_type(parent.children) # E: Revealed type is 'typing.Iterable*[main.Child]' | ||
|
||
[out] | ||
|
||
[case testRelationshipIsGuessed3] | ||
from sqlalchemy import Column, Integer, String, ForeignKey | ||
from sqlalchemy.orm import relationship | ||
from sqlalchemy.ext.declarative import declarative_base | ||
|
||
Base = declarative_base() | ||
|
||
class Parent(Base): | ||
__tablename__ = 'parents' | ||
id = Column(Integer, primary_key=True) | ||
name = Column(String) | ||
|
||
children = relationship("Child") | ||
|
||
class Child(Base): | ||
__tablename__ = 'children' | ||
id = Column(Integer, primary_key=True) | ||
name = Column(String) | ||
parent_id = Column(Integer, ForeignKey("other_parents.id")) | ||
|
||
parent = relationship(Parent) | ||
|
||
child: Child | ||
parent: Parent | ||
|
||
reveal_type(child.parent) # E: Revealed type is 'main.Parent*' | ||
reveal_type(parent.children) # E: Revealed type is 'main.Child*' | ||
|
||
[out] | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Please add newline at the end of file. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Maybe we can also support constants like:
No need to this in this PR, but may worth leaving a TODO and/or a follow-up issue.