Skip to content

Commit

Permalink
Added deref of members. (#5731)
Browse files Browse the repository at this point in the history
  • Loading branch information
gilbens-starkware authored Jun 6, 2024
1 parent 4f33533 commit 9592f1d
Show file tree
Hide file tree
Showing 7 changed files with 196 additions and 30 deletions.
12 changes: 10 additions & 2 deletions corelib/src/nullable.cairo
Original file line number Diff line number Diff line change
Expand Up @@ -15,14 +15,22 @@ pub(crate) extern fn nullable_from_box<T>(value: Box<T>) -> Nullable<T> nopanic;
pub extern fn match_nullable<T>(value: Nullable<T>) -> FromNullableResult<T> nopanic;
extern fn nullable_forward_snapshot<T>(value: @Nullable<T>) -> Nullable<@T> nopanic;

#[generate_trait]
pub impl NullableImpl<T> of NullableTrait<T> {
impl NullableDeref<T> of core::ops::Deref<Nullable<T>> {
type Target = T;
fn deref(self: Nullable<T>) -> T {
match match_nullable(self) {
FromNullableResult::Null => core::panic_with_felt252('Attempted to deref null value'),
FromNullableResult::NotNull(value) => value.unbox(),
}
}
}

#[generate_trait]
pub impl NullableImpl<T> of NullableTrait<T> {
fn deref(nullable: Nullable<T>) -> T {
nullable.deref()
}

fn deref_or<+Drop<T>>(self: Nullable<T>, default: T) -> T {
match match_nullable(self) {
FromNullableResult::Null => default,
Expand Down
3 changes: 3 additions & 0 deletions corelib/src/ops.cairo
Original file line number Diff line number Diff line change
Expand Up @@ -3,3 +3,6 @@ pub use index::{Index, IndexView};

mod arith;
pub use arith::{AddAssign, SubAssign, MulAssign, DivAssign, RemAssign};
mod deref;
pub use deref::Deref;

6 changes: 6 additions & 0 deletions corelib/src/ops/deref.cairo
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
/// A trait for dereferencing a value. This is used in order to directly access members of the
/// dereferenced value.
pub trait Deref<T> {
type Target;
fn deref(self: T) -> Self::Target;
}
1 change: 1 addition & 0 deletions corelib/src/prelude/v2023_01.cairo
Original file line number Diff line number Diff line change
Expand Up @@ -99,3 +99,4 @@ use core::{zeroable, zeroable::{NonZero, Zeroable}};

#[cfg(test)]
use core::test;
pub use core::ops::Deref;
1 change: 1 addition & 0 deletions corelib/src/prelude/v2023_10.cairo
Original file line number Diff line number Diff line change
Expand Up @@ -62,3 +62,4 @@ pub use core::traits::Index;
#[feature("deprecated-index-traits")]
pub use core::traits::IndexView;
pub use core::zeroable::NonZero;
pub use core::ops::Deref;
45 changes: 45 additions & 0 deletions corelib/src/test/deref_test.cairo
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
#[derive(Drop, Copy)]
struct S1 {
a: usize,
b: felt252
}

#[derive(Drop, Copy)]
struct S2 {
inner: S1,
a: usize,
}

#[derive(Drop, Copy)]
struct S3 {
inner: S2
}


impl S2Deref of core::ops::deref::Deref<S2> {
type Target = S1;
fn deref(self: S2) -> S1 {
self.inner
}
}

impl S3Deref of core::ops::deref::Deref<S3> {
type Target = S2;
fn deref(self: S3) -> S2 {
self.inner
}
}

#[test]
fn test_simple_deref() {
let s1 = S1 { a: 1, b: 2 };
let s2 = S2 { inner: s1, a: 3 };
let s3 = S3 { inner: s2 };
assert_eq!(s1.a, 1);
assert_eq!(s2.a, 3);
assert_eq!(s3.a, 3);
assert_eq!(s3.inner.a, 3);
assert_eq!(s3.inner.inner.a, 1);
assert_eq!(s3.b, 2);
assert_eq!(s3.inner.b, 2);
}
158 changes: 130 additions & 28 deletions crates/cairo-lang-semantic/src/expr/compute.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ use cairo_lang_defs::ids::{
EnumId, FunctionTitleId, GenericKind, LanguageElementId, LocalVarLongId, MemberId,
TraitFunctionId, TraitId,
};
use cairo_lang_diagnostics::{Maybe, ToOption};
use cairo_lang_diagnostics::{skip_diagnostic, Maybe, ToOption};
use cairo_lang_filesystem::ids::{FileKind, FileLongId, VirtualFile};
use cairo_lang_syntax::node::ast::{
BlockOrIf, ExprPtr, PatternListOr, PatternStructParam, UnaryOperator,
Expand Down Expand Up @@ -1375,8 +1375,8 @@ fn compute_expr_indexed_semantic(
expr,
syntax.into(),
None,
|ty, _, inference_errors| NoImplementationOfIndexOperator { ty, inference_errors },
|ty, _, _| MultipleImplementationOfIndexOperator(ty),
|ty, _, inference_errors| Some(NoImplementationOfIndexOperator { ty, inference_errors }),
|ty, _, _| Some(MultipleImplementationOfIndexOperator(ty)),
)?;

let index_expr_syntax = &syntax.index_expr(syntax_db);
Expand Down Expand Up @@ -1409,12 +1409,12 @@ fn compute_method_function_call_data(
TypeId,
SmolStr,
TraitInferenceErrors,
) -> SemanticDiagnosticKind,
) -> Option<SemanticDiagnosticKind>,
multiple_trait_diagnostic: fn(
TypeId,
TraitFunctionId,
TraitFunctionId,
) -> SemanticDiagnosticKind,
) -> Option<SemanticDiagnosticKind>,
) -> Maybe<(FunctionId, ExprAndId, Mutability)> {
let self_ty = ctx.reduce_ty(self_expr.ty());
// Inference errors found when looking for candidates. Only relevant in the case of 0 candidates
Expand All @@ -1431,21 +1431,19 @@ fn compute_method_function_call_data(
);
let trait_function_id = match candidates[..] {
[] => {
return Err(ctx.diagnostics.report(
method_syntax,
no_implementation_diagnostic(
self_ty,
func_name,
TraitInferenceErrors { traits_and_errors: inference_errors },
),
));
return Err(no_implementation_diagnostic(
self_ty,
func_name,
TraitInferenceErrors { traits_and_errors: inference_errors },
)
.map(|diag| ctx.diagnostics.report(method_syntax, diag))
.unwrap_or_else(skip_diagnostic));
}
[trait_function_id] => trait_function_id,
[trait_function_id0, trait_function_id1, ..] => {
return Err(ctx.diagnostics.report(
method_syntax,
multiple_trait_diagnostic(self_ty, trait_function_id0, trait_function_id1),
));
return Err(multiple_trait_diagnostic(self_ty, trait_function_id0, trait_function_id1)
.map(|diag| ctx.diagnostics.report(method_syntax, diag))
.unwrap_or_else(skip_diagnostic));
}
};
let (function_id, n_snapshots) =
Expand Down Expand Up @@ -2236,10 +2234,11 @@ fn method_call_expr(
lexpr,
path.stable_ptr().untyped(),
generic_args_syntax,
|ty, method_name, inference_errors| CannotCallMethod { ty, method_name, inference_errors },
|_, trait_function_id0, trait_function_id1| AmbiguousTrait {
trait_function_id0,
trait_function_id1,
|ty, method_name, inference_errors| {
Some(CannotCallMethod { ty, method_name, inference_errors })
},
|_, trait_function_id0, trait_function_id1| {
Some(AmbiguousTrait { trait_function_id0, trait_function_id1 })
},
)?;
ctx.resolver.data.resolved_items.mark_concrete(
Expand Down Expand Up @@ -2286,8 +2285,9 @@ fn member_access_expr(
TypeLongId::Concrete(concrete) => match concrete {
ConcreteTypeId::Struct(concrete_struct_id) => {
// TODO(lior): Add a diagnostic test when accessing a member of a missing type.
let members = ctx.db.concrete_struct_members(*concrete_struct_id)?;
let Some(member) = members.get(&member_name) else {
let EnrichedMembers { members, deref_functions } =
enriched_members(ctx, lexpr.clone(), stable_ptr)?;
let Some((member, n_derefs)) = members.get(&member_name) else {
return Err(ctx.diagnostics.report(
&rhs_syntax,
NoSuchMember {
Expand All @@ -2302,7 +2302,7 @@ fn member_access_expr(
rhs_syntax.stable_ptr().untyped(),
&member_name,
);
let member_path = if n_snapshots == 0 {
let member_path = if n_snapshots == 0 && *n_derefs == 0 {
lexpr.as_member_path().map(|parent| ExprVarMemberPath::Member {
parent: Box::new(parent),
member_id: member.id,
Expand All @@ -2313,12 +2313,33 @@ fn member_access_expr(
} else {
None
};
let lexpr_id = lexpr.id;
let mut derefed_expr: ExprAndId = lexpr;
for deref_function in deref_functions.iter().take(*n_derefs) {
let cur_expr = expr_function_call(
ctx,
*deref_function,
vec![NamedArg(derefed_expr, None, Mutability::Immutable)],
stable_ptr,
stable_ptr,
)
.unwrap();

let ty = wrap_in_snapshots(ctx.db, member.ty, n_snapshots);
derefed_expr =
ExprAndId { expr: cur_expr.clone(), id: ctx.exprs.alloc(cur_expr) };
}
let (_, long_ty) =
finalized_snapshot_peeled_ty(ctx, derefed_expr.ty(), &rhs_syntax)?;
let derefed_expr_concrete_struct_id = match long_ty {
TypeLongId::Concrete(ConcreteTypeId::Struct(concrete_struct_id)) => {
concrete_struct_id
}
_ => unreachable!(),
};
let ty = member.ty;
let ty = wrap_in_snapshots(ctx.db, ty, n_snapshots);
Ok(Expr::MemberAccess(ExprMemberAccess {
expr: lexpr_id,
concrete_struct_id: *concrete_struct_id,
expr: derefed_expr.id,
concrete_struct_id: derefed_expr_concrete_struct_id,
member: member.id,
ty,
member_path,
Expand Down Expand Up @@ -2360,6 +2381,87 @@ fn member_access_expr(
}
}

/// The result of enriched_members lookup.
struct EnrichedMembers {
/// A map from member names to their semantic representation and the number of deref operations
/// needed to access them.
members: OrderedHashMap<SmolStr, (semantic::Member, usize)>,
/// The sequence of deref functions needed to access the members.
deref_functions: Vec<FunctionId>,
}

/// Enriched members include both direct members (in case of a struct), and members of derefed types
/// if the type implements the Deref trait into a struct. Returns a map from member names to the
/// semantic representation, and the number of deref operations needed for each member.
fn enriched_members(
ctx: &mut ComputationContext<'_>,
mut expr: ExprAndId,
stable_ptr: ast::ExprPtr,
) -> Maybe<EnrichedMembers> {
// TODO(Gil): Use this function for LS completions.
let mut ty = expr.ty();
let mut res = OrderedHashMap::default();
let mut deref_functions = vec![];
// Add direct members.
let (_, mut long_ty) = peel_snapshots(ctx.db, ty);
if matches!(long_ty, TypeLongId::Var(_)) {
// Save some work. ignore the result. The error, if any, will be reported later.
ctx.resolver.inference().solve().ok();
long_ty = ctx.resolver.inference().rewrite(long_ty).no_err();
}
let (_, long_ty) = peel_snapshots_ex(ctx.db, long_ty);

if let TypeLongId::Concrete(ConcreteTypeId::Struct(concrete_struct_id)) = long_ty {
let members = ctx.db.concrete_struct_members(concrete_struct_id)?;
for (member_name, member) in members.iter() {
res.insert(member_name.clone(), (member.clone(), 0));
}
}
// Add members of derefed types.
let mut n_deref = 0;
let deref_trait = get_core_trait(ctx.db, CoreTraitContext::Ops, "Deref".into());

while let Ok((function_id, cur_expr, mutability)) = compute_method_function_call_data(
ctx,
&[deref_trait],
"deref".into(),
expr,
stable_ptr.0,
None,
|_, _, _| None,
|_, _, _| None,
) {
n_deref += 1;
expr = cur_expr;
let derefed_expr = expr_function_call(
ctx,
function_id,
vec![NamedArg(expr, None, mutability)],
stable_ptr,
stable_ptr,
)?;
ty = derefed_expr.ty();
ty = ctx.reduce_ty(ty);
let (_, long_ty) = finalized_snapshot_peeled_ty(ctx, ty, stable_ptr)?;
// If the type is still a variable we stop looking for derefed members.
if let TypeLongId::Var(_) = long_ty {
break;
}
expr = ExprAndId { expr: derefed_expr.clone(), id: ctx.exprs.alloc(derefed_expr) };
if let TypeLongId::Concrete(ConcreteTypeId::Struct(concrete_struct_id)) = long_ty {
let members = ctx.db.concrete_struct_members(concrete_struct_id)?;
for (member_name, member) in members.iter() {
// Insert member if there is not already a member with the same name.
if res.get(&member_name.clone()).is_none() {
res.insert(member_name.clone(), (member.clone(), n_deref));
}
}
}
deref_functions.push(function_id);
}
Ok(EnrichedMembers { members: res, deref_functions })
}

/// Peels snapshots from a type and making sure it is fully not a variable type.
fn finalized_snapshot_peeled_ty(
ctx: &mut ComputationContext<'_>,
Expand Down

0 comments on commit 9592f1d

Please sign in to comment.