Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add more metadata and try to guess if relationship is iterable #74

Open
wants to merge 9 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
117 changes: 111 additions & 6 deletions sqlmypy.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,10 @@
Plugin, FunctionContext, ClassDefContext, DynamicClassDefContext,
SemanticAnalyzerPluginInterface
)
from mypy.plugins.common import add_method
from mypy.plugins.common import add_method, _get_argument
from mypy.nodes import (
NameExpr, Expression, StrExpr, TypeInfo, ClassDef, Block, SymbolTable, SymbolTableNode, GDEF,
Argument, Var, ARG_STAR2, MDEF, TupleExpr, RefExpr
Argument, Var, ARG_STAR2, MDEF, TupleExpr, RefExpr, AssignmentStmt, CallExpr, MemberExpr
)
from mypy.types import (
UnionType, NoneTyp, Instance, Type, AnyType, TypeOfAny, UninhabitedType, CallableType
Expand All @@ -25,6 +25,7 @@
COLUMN_ELEMENT_NAME = 'sqlalchemy.sql.elements.ColumnElement' # type: Final
GROUPING_NAME = 'sqlalchemy.sql.elements.Grouping' # type: Final
RELATIONSHIP_NAME = 'sqlalchemy.orm.relationships.RelationshipProperty' # type: Final
FOREIGN_KEY_NAME = 'sqlalchemy.sql.schema.ForeignKey' # type: Final


def is_declarative(info: TypeInfo) -> bool:
Expand Down Expand Up @@ -110,6 +111,53 @@ def add_model_init_hook(ctx: ClassDefContext) -> None:
add_method(ctx, '__init__', [kw_arg], NoneTyp())
ctx.cls.info.metadata.setdefault('sqlalchemy', {})['generated_init'] = True

for stmt in ctx.cls.defs.body:
if not (isinstance(stmt, AssignmentStmt) and len(stmt.lvalues) == 1 and isinstance(stmt.lvalues[0], NameExpr)):
continue

# We currently only handle setting __tablename__ as a class attribute, and not through a property.
if stmt.lvalues[0].name == "__tablename__" and isinstance(stmt.rvalue, StrExpr):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe we can also support constants like:

NAME: Final = 'users'

class User(Base):
    __tablename__ = NAME
    ...

No need to do this in this PR, but it may be worth leaving a TODO and/or a follow-up issue.

ctx.cls.info.metadata.setdefault('sqlalchemy', {})['table_name'] = stmt.rvalue.value

if (isinstance(stmt.rvalue, CallExpr) and isinstance(stmt.rvalue.callee, NameExpr)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I would factor out everything you add on this line and below in a helper function with a docstring explaining what we do here, e.g. process_field_assignment().

and stmt.rvalue.callee.fullname == COLUMN_NAME):
# Save columns. The name of a column on the db side can be different from the one inside the SA model.
sa_column_name = stmt.lvalues[0].name

db_column_name = None # type: Optional[str]
if 'name' in stmt.rvalue.arg_names:
name_str_expr = stmt.rvalue.args[stmt.rvalue.arg_names.index('name')]
assert isinstance(name_str_expr, StrExpr)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is unsafe, it will crash mypy if someone passes name= as a variable or function call.

Please also add a test for this.

db_column_name = name_str_expr.value
else:
if len(stmt.rvalue.args) >= 1 and isinstance(stmt.rvalue.args[0], StrExpr):
db_column_name = stmt.rvalue.args[0].value

ctx.cls.info.metadata.setdefault('sqlalchemy', {}).setdefault('columns', []).append(
{"sa_name": sa_column_name, "db_name": db_column_name or sa_column_name}
)

# Save foreign keys.
for arg in stmt.rvalue.args:
if (isinstance(arg, CallExpr) and isinstance(arg.callee, NameExpr)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe use RefExpr instead of NameExpr?

and arg.callee.fullname == FOREIGN_KEY_NAME and len(arg.args) >= 1):
fk = arg.args[0]
if isinstance(fk, StrExpr):
*r, parent_table_name, parent_db_col_name = fk.value.split(".")
assert len(r) <= 1
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Again, this is not safe, we should never crash on bad user input.

ctx.cls.info.metadata.setdefault('sqlalchemy', {}).setdefault('foreign_keys',
{})[sa_column_name] = {
"db_name": parent_db_col_name,
"table_name": parent_table_name,
"schema": r[0] if r else None
}
elif isinstance(fk, MemberExpr) and isinstance(fk.expr, NameExpr):
ctx.cls.info.metadata.setdefault('sqlalchemy', {}).setdefault('foreign_keys',
{})[sa_column_name] = {
"sa_name": fk.name,
"model_fullname": fk.expr.fullname
}

# Also add a selection of auto-generated attributes.
sym = ctx.api.lookup_fully_qualified_or_none('sqlalchemy.sql.schema.Table')
if sym:
Expand Down Expand Up @@ -317,6 +365,55 @@ def grouping_hook(ctx: FunctionContext) -> Type:
return ctx.default_return_type


class IncompleteModelMetadata(Exception):
    """Raised when metadata required to inspect a model's foreign keys is missing."""


def has_foreign_keys(local_model: TypeInfo, remote_model: TypeInfo) -> bool:
    """Tells if `local_model` has a foreign key to `remote_model`.

    Reads the 'sqlalchemy' entry of each model's `metadata` dict: the
    'foreign_keys' mapping of `local_model` and the 'table_name' of `remote_model`.

    Returns True as soon as one foreign key of `local_model` points to
    `remote_model`, either by model fullname or by table name; False otherwise.

    Will raise an `IncompleteModelMetadata` if some mandatory metadata is missing.
    """
    local_metadata = local_model.metadata.get("sqlalchemy", {})
    remote_metadata = remote_model.metadata.get("sqlalchemy", {})

    for fk in local_metadata.get("foreign_keys", {}).values():
        # Foreign key recorded with the fullname of the model it refers to.
        if 'model_fullname' in fk and remote_model.fullname() == fk['model_fullname']:
            return True
        # Foreign key recorded with the name of the table it refers to.
        if 'table_name' in fk:
            if 'table_name' not in remote_metadata:
                # Without the remote table name we cannot compare anything.
                raise IncompleteModelMetadata
            # TODO: handle different schemas.
            # It's not straightforward because schema can be specified in `__table_args__` or in metadata for example
            if remote_metadata['table_name'] == fk['table_name']:
                return True

    return False


def is_relationship_iterable(ctx: FunctionContext, local_model: TypeInfo, remote_model: TypeInfo) -> bool:
    """Tries to guess if the relationship is onetoone/onetomany/manytoone.

    Currently we handle the most common case, where a model relates to the other one through a relationship.
    We also handle cases where secondaryjoin argument is provided.
    We don't handle advanced usecases (foreign keys on both sides, primaryjoin, etc.).

    Returns True when the relationship is guessed to be iterable, and False
    when it is not or when no better guess could be made.
    """
    secondaryjoin = get_argument_by_name(ctx, 'secondaryjoin')

    # TODO: add a test for the secondaryjoin case (requested in review).
    if secondaryjoin is not None:
        return True

    try:
        can_be_many_to_one = has_foreign_keys(local_model, remote_model)
        can_be_one_to_many = has_foreign_keys(remote_model, local_model)

        # Only the remote side holds a foreign key to us: one-to-many.
        if not can_be_many_to_one and can_be_one_to_many:
            return True
    except IncompleteModelMetadata:
        # Not enough metadata to guess; fall through to the default below.
        pass

    return False  # Assume relationship is not iterable, if we weren't able to guess better.


def relationship_hook(ctx: FunctionContext) -> Type:
"""Support basic use cases for relationships.

Expand Down Expand Up @@ -369,10 +466,18 @@ class User(Base):
# Something complex, stay silent for now.
new_arg = AnyType(TypeOfAny.special_form)

# use private api
current_model = ctx.api.scope.active_class() # type: ignore # type: TypeInfo
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Put # type: ignore at the end, after type comment, otherwise the type comment will be ignored.

assert current_model is not None

# TODO: handle backref relationships

# We figured out the model type. Now check if we need to wrap it in Iterable.
if uselist_arg:
if parse_bool(uselist_arg):
new_arg = ctx.api.named_generic_type('typing.Iterable', [new_arg])
elif isinstance(new_arg, Instance) and is_relationship_iterable(ctx, current_model, new_arg.type):
new_arg = ctx.api.named_generic_type('typing.Iterable', [new_arg])
else:
if has_annotation:
# If there is an annotation we use it as a source of truth.
Expand All @@ -387,10 +492,10 @@ class User(Base):
# We really need to add this to TypeChecker API
def parse_bool(expr: Expression) -> Optional[bool]:
    """Return the boolean a `True`/`False` name expression stands for.

    Returns True or False when `expr` is a `NameExpr` whose fullname is
    'builtins.True' or 'builtins.False', and None for anything else.
    """
    if isinstance(expr, NameExpr):
        if expr.fullname == 'builtins.True':
            return True
        if expr.fullname == 'builtins.False':
            return False
    return None


Expand Down
90 changes: 90 additions & 0 deletions test/test-data/sqlalchemy-plugin-features.test
Original file line number Diff line number Diff line change
Expand Up @@ -280,3 +280,93 @@ class M2(M1):
Base = declarative_base(cls=(M1, M2)) # E: Not able to calculate MRO for declarative base
reveal_type(Base) # E: Revealed type is 'Any'
[out]

[case testRelationshipIsGuessed]
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do you have tests where schema is not None?

from sqlalchemy import Column, Integer, String, ForeignKey
from sqlalchemy.orm import relationship
from sqlalchemy.ext.declarative import declarative_base

Base = declarative_base()

class Parent(Base):
__tablename__ = 'parents'
id = Column(Integer, primary_key=True)
name = Column(String)

children = relationship("Child")

class Child(Base):
__tablename__ = 'children'
id = Column(Integer, primary_key=True)
name = Column(String)
parent_id = Column(Integer, ForeignKey(Parent.id))

parent = relationship(Parent)

child: Child
parent: Parent

reveal_type(child.parent) # E: Revealed type is 'main.Parent*'
reveal_type(parent.children) # E: Revealed type is 'typing.Iterable*[main.Child]'

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No need for empty lines before [out].

[out]

[case testRelationshipIsGuessed2]
from sqlalchemy import Column, Integer, String, ForeignKey
from sqlalchemy.orm import relationship
from sqlalchemy.ext.declarative import declarative_base

Base = declarative_base()

class Parent(Base):
__tablename__ = 'parents'
id = Column(Integer, primary_key=True)
name = Column(String)

children = relationship("Child")

class Child(Base):
__tablename__ = 'children'
id = Column(Integer, primary_key=True)
name = Column(String)
parent_id = Column(Integer, ForeignKey("parents.id"))

parent = relationship(Parent)

child: Child
parent: Parent

reveal_type(child.parent) # E: Revealed type is 'main.Parent*'
reveal_type(parent.children) # E: Revealed type is 'typing.Iterable*[main.Child]'

[out]

[case testRelationshipIsGuessed3]
from sqlalchemy import Column, Integer, String, ForeignKey
from sqlalchemy.orm import relationship
from sqlalchemy.ext.declarative import declarative_base

Base = declarative_base()

class Parent(Base):
__tablename__ = 'parents'
id = Column(Integer, primary_key=True)
name = Column(String)

children = relationship("Child")

class Child(Base):
__tablename__ = 'children'
id = Column(Integer, primary_key=True)
name = Column(String)
parent_id = Column(Integer, ForeignKey("other_parents.id"))

parent = relationship(Parent)

child: Child
parent: Parent

reveal_type(child.parent) # E: Revealed type is 'main.Parent*'
reveal_type(parent.children) # E: Revealed type is 'main.Child*'

[out]
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please add newline at the end of file.