Skip to content

Add tests and partial fixes for issues #547, #548, #549, #550 #551

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 35 additions & 0 deletions pyrefly/lib/alt/attr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -432,6 +432,22 @@ enum AttributeBase {
}

impl<'a, Ans: LookupAnswer> AnswersSolver<'a, Ans> {
/// Check if a type contains Any (either directly or in a union)
fn type_contains_any(&self, ty: &Type) -> bool {
match ty {
Type::Any(_) => true,
Type::Union(types) => types.iter().any(|t| self.type_contains_any(t)),
Type::Var(v) => {
if let Some(_guard) = self.recurser.recurse(*v) {
self.type_contains_any(&self.solver().force_var(*v))
} else {
false
}
}
_ => false,
}
}

/// Gets the possible attribute bases for a type:
/// If the type is a union, we will attempt to generate bases for each member of the union
/// If the type is a bounded type var w/ a union upper bound, we will attempt to generate 1 base for
Expand Down Expand Up @@ -498,6 +514,12 @@ impl<'a, Ans: LookupAnswer> AnswersSolver<'a, Ans> {
context: Option<&dyn Fn() -> ErrorContext>,
todo_ctx: &str,
) -> Type {
// If the base type contains Any, we should return Any without errors
// This handles cases like `(A | Any).attr` where Any allows any attribute
if self.type_contains_any(base) {
return Type::Any(AnyStyle::Implicit);
}

let bases = self.get_possible_attribute_bases(base);
let mut results = Vec::new();
for attr_base in bases {
Expand Down Expand Up @@ -637,6 +659,14 @@ impl<'a, Ans: LookupAnswer> AnswersSolver<'a, Ans> {
context: Option<&dyn Fn() -> ErrorContext>,
todo_ctx: &str,
) -> Option<Type> {
// If the base type contains Any, we should allow any attribute assignment
if self.type_contains_any(base) {
return match got {
TypeOrExpr::Expr(expr) => Some(self.expr(expr, None, errors)),
TypeOrExpr::Type(ty, _) => Some(ty.clone()),
};
}

let mut narrowed_types = Some(Vec::new());
let bases = self.get_possible_attribute_bases(base);
for attr_base in bases {
Expand Down Expand Up @@ -781,6 +811,11 @@ impl<'a, Ans: LookupAnswer> AnswersSolver<'a, Ans> {
context: Option<&dyn Fn() -> ErrorContext>,
todo_ctx: &str,
) {
// If the base type contains Any, we should allow any attribute deletion
if self.type_contains_any(base) {
return;
}

let bases = self.get_possible_attribute_bases(base);
for attr_base in bases {
let lookup_result = attr_base.map_or_else(
Expand Down
89 changes: 89 additions & 0 deletions pyrefly/lib/test/issue_547_dict_narrowing.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,89 @@
/*
* Copyright (c) Meta Platforms, Inc. and affiliates.
*
* This source code is licensed under the MIT license found in the
* LICENSE file in the root directory of this source tree.
*/

use crate::testcase;

// Test case for Issue #547: Cannot Handle None Checks Before Dictionary Access
testcase!(
test_dict_none_narrow_attribute_chain,
r#"
from typing import Optional, Dict, Any, assert_type
class ConfigManager:
def __init__(self):
self.system_context: Optional[Dict[str, Any]] = None
def test_explicit_none_check(self) -> None:
if self.system_context is not None:
# After narrowing, should be Dict[str, Any], not Optional
assert_type(self.system_context, Dict[str, Any])
# Should be able to set items without error
self.system_context["updated"] = True
self.system_context["data"] = {"key": "value"}
def test_dict_methods(self) -> Any:
if self.system_context is not None:
# Dictionary methods should work without error
value = self.system_context.get("key", "default")
keys = self.system_context.keys()
items = self.system_context.items()
return value
return None
def test_truthy_check(self) -> list[str]:
if self.system_context:
# Truthy check should also narrow
assert_type(self.system_context, Dict[str, Any])
return list(self.system_context.keys())
return []
"#,
);

testcase!(
test_dict_none_narrow_early_return,
r#"
from typing import Optional, Dict, Any, assert_type
def process_config(config: Optional[Dict[str, Any]]) -> str:
if config is None:
return "no config"
# After early return, config cannot be None
assert_type(config, Dict[str, Any])
# Should not error on dictionary methods
return config.get("setting", "default")
def process_with_isinstance(data: Optional[Dict[str, Any]]) -> None:
if isinstance(data, dict):
# After isinstance check, should narrow to dict
assert_type(data, Dict[str, Any])
data["checked"] = True
items = data.items()
for k, v in items:
print(f"{k}: {v}")
"#,
);

testcase!(
test_dict_none_narrow_nested,
r#"
from typing import Optional, Dict, Any, assert_type
class NestedConfig:
def __init__(self):
self.outer: Optional[Dict[str, Optional[Dict[str, Any]]]] = None
def test_nested_narrowing(self) -> None:
if self.outer is not None:
assert_type(self.outer, Dict[str, Optional[Dict[str, Any]]])
self.outer["key"] = {"nested": "value"}
inner = self.outer.get("key")
if inner is not None:
assert_type(inner, Dict[str, Any])
inner["updated"] = True
"#,
);
85 changes: 85 additions & 0 deletions pyrefly/lib/test/issue_548_pydantic_field.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
/*
* Copyright (c) Meta Platforms, Inc. and affiliates.
*
* This source code is licensed under the MIT license found in the
* LICENSE file in the root directory of this source tree.
*/

use crate::test::util::TestEnv;
use crate::testcase;

// Create a test environment with mocked pydantic module
fn env_with_pydantic() -> TestEnv {
let mut env = TestEnv::new();
env.add(
"pydantic",
r#"
from typing import TypeVar, overload, Any
_T = TypeVar("_T")
class BaseModel:
pass
# Field overloads based on typical pydantic signatures
@overload
def Field(default: _T, *, ge: int | None = None, le: int | None = None, description: str | None = None, min_length: int | None = None) -> _T: ...
@overload
def Field(*, default: _T, ge: int | None = None, le: int | None = None, description: str | None = None, min_length: int | None = None) -> _T: ...
@overload
def Field(*, ge: int | None = None, le: int | None = None, description: str | None = None, min_length: int | None = None) -> Any: ...
def Field(default: Any | None = None, *, ge: int | None = None, le: int | None = None, description: str | None = None, min_length: int | None = None) -> Any:
return default
"#,
);
env
}

// Test case for Issue #548: Pydantic Field Overload Resolution Failure
testcase!(
test_pydantic_field_positional,
env_with_pydantic(),
r#"
from typing import assert_type
from pydantic import BaseModel, Field
class DatabaseConfig(BaseModel):
# Positional default argument should work
port: int = Field(5432, ge=1, le=65535, description="Database port")
assert_type(port, int)
# Multiple fields with positional defaults
min_connections: int = Field(5, ge=1, le=100)
max_connections: int = Field(20, ge=1, le=1000)
# Keyword default should also work
timeout: int = Field(default=30, ge=1)
"#,
);

testcase!(
test_pydantic_field_type_inference,
env_with_pydantic(),
r#"
from typing import assert_type
from pydantic import BaseModel, Field
class Config(BaseModel):
# TypeVar inference should work with positional args
name: str = Field("default_name", min_length=1)
assert_type(name, str)
count: int = Field(42, ge=0)
assert_type(count, int)
flag: bool = Field(True, description="Feature flag")
assert_type(flag, bool)
# None as default
optional_value: int | None = Field(None, description="Optional")
assert_type(optional_value, int | None)
"#,
);
48 changes: 48 additions & 0 deletions pyrefly/lib/test/issue_549_any_union.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
/*
* Copyright (c) Meta Platforms, Inc. and affiliates.
*
* This source code is licensed under the MIT license found in the
* LICENSE file in the root directory of this source tree.
*/

use crate::testcase;

testcase!(
test_any_union_attribute_access,
r#"
from typing import Any
class A:
pass
def foo(x: A | Any) -> None:
x.bar # Should not error, Any allows any attribute
"#,
);

testcase!(
test_any_union_method_call,
r#"
from typing import Any
class A:
pass
def foo(x: A | Any) -> None:
x.bar() # Should not error, Any allows any method call
"#,
);

testcase!(
test_non_any_union_attribute_access,
r#"
class A:
pass
class B:
pass
def foo(x: A | B) -> None:
x.bar # E: Object of class `A` has no attribute `bar` # E: Object of class `B` has no attribute `bar`
"#,
);
97 changes: 97 additions & 0 deletions pyrefly/lib/test/issue_550_init_helpers.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,97 @@
/*
* Copyright (c) Meta Platforms, Inc. and affiliates.
*
* This source code is licensed under the MIT license found in the
* LICENSE file in the root directory of this source tree.
*/

use crate::testcase;

// Test case for Issue #550: Support initialization helper methods called from __init__
testcase!(
test_init_helper_methods,
r#"
class MetricsCollector:
def __init__(self):
# Call helper methods during initialization
self._init_performance_metrics()
self._init_error_metrics()
def _init_performance_metrics(self) -> None:
# These should not error when called from __init__
self.request_duration = 0.0
self.db_query_duration = 0.0
def _init_error_metrics(self) -> None:
# These should not error when called from __init__
self.error_count = 0
self.error_rate = 0.0
def get_metrics(self) -> dict[str, float]:
# These attributes should be recognized as defined
return {
"request_duration": self.request_duration,
"error_count": self.error_count,
}
"#,
);

testcase!(
test_conditional_init_helpers,
r#"
class ConfigurableService:
def __init__(self, enable_cache: bool = True):
self.enabled = True
if enable_cache:
self._setup_cache()
def _setup_cache(self) -> None:
# Attributes defined in conditionally-called helpers
self.cache_size = 1000
self.cache = {}
def use_cache(self) -> None:
# Should understand cache might not be defined
if hasattr(self, "cache"):
self.cache.clear()
"#,
);

testcase!(
test_nested_init_helpers,
r#"
class ComplexService:
def __init__(self):
self._init_base()
def _init_base(self) -> None:
self.base_attr = "base"
self._init_extended()
def _init_extended(self) -> None:
# Nested helper called from another helper
self.extended_attr = "extended"
def get_attrs(self) -> tuple[str, str]:
return (self.base_attr, self.extended_attr)
"#,
);

testcase!(
bug = "Should still error when helper is not called from init",
test_uncalled_helper_method,
r#"
class Service:
def __init__(self):
self.initialized = True
# Note: _setup is NOT called from __init__
def _setup(self) -> None:
# This SHOULD error - not called from __init__
self.uncalled_attr = "error" # E: Attribute `uncalled_attr` is implicitly defined
def use_attr(self) -> None:
# This should error - uncalled_attr might not exist
print(self.uncalled_attr) # E: Object of class `Service` has no attribute `uncalled_attr`
"#,
);
4 changes: 4 additions & 0 deletions pyrefly/lib/test/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,10 @@ mod flow;
mod generic_basic;
mod generic_restrictions;
mod imports;
mod issue_547_dict_narrowing;
mod issue_548_pydantic_field;
mod issue_549_any_union;
mod issue_550_init_helpers;
mod literal;
mod lsp;
mod metadata;
Expand Down
Loading