From 321f9c4e4bdef5722a3a6e710518e6347db559dd Mon Sep 17 00:00:00 2001 From: TomerStarkware Date: Mon, 17 Feb 2025 12:49:42 +0200 Subject: [PATCH] added caching to trait functions and generated impl's functions --- crates/cairo-lang-lowering/src/cache/mod.rs | 328 ++++++++++++++++++-- 1 file changed, 294 insertions(+), 34 deletions(-) diff --git a/crates/cairo-lang-lowering/src/cache/mod.rs b/crates/cairo-lang-lowering/src/cache/mod.rs index 8778489b872..2b215cc2a81 100644 --- a/crates/cairo-lang-lowering/src/cache/mod.rs +++ b/crates/cairo-lang-lowering/src/cache/mod.rs @@ -14,8 +14,8 @@ use cairo_lang_defs::ids::{ ImplFunctionLongId, LanguageElementId, LocalVarId, LocalVarLongId, MemberLongId, ModuleFileId, ModuleId, ParamLongId, PluginGeneratedFileId, PluginGeneratedFileLongId, StatementConstLongId, StatementItemId, StatementUseLongId, StructLongId, SubmoduleId, SubmoduleLongId, - TraitConstantId, TraitConstantLongId, TraitFunctionLongId, TraitTypeId, TraitTypeLongId, - VariantLongId, + TraitConstantId, TraitConstantLongId, TraitFunctionLongId, TraitId, TraitLongId, TraitTypeId, + TraitTypeLongId, VariantLongId, }; use cairo_lang_diagnostics::{Maybe, skip_diagnostic}; use cairo_lang_filesystem::ids::{ @@ -29,19 +29,25 @@ use cairo_lang_semantic::items::functions::{ ConcreteFunctionWithBody, GenericFunctionId, GenericFunctionWithBodyId, ImplFunctionBodyId, ImplGenericFunctionId, ImplGenericFunctionWithBodyId, }; -use cairo_lang_semantic::items::imp::{ImplId, ImplLongId}; +use cairo_lang_semantic::items::generics::{GenericParamConst, GenericParamImpl, GenericParamType}; +use cairo_lang_semantic::items::imp::{ + GeneratedImplId, GeneratedImplItems, GeneratedImplLongId, ImplId, ImplLongId, +}; +use cairo_lang_semantic::items::trt::ConcreteTraitGenericFunctionLongId; use cairo_lang_semantic::types::{ - ConcreteEnumLongId, ConcreteExternTypeLongId, ConcreteStructLongId, ImplTypeId, + ClosureTypeLongId, ConcreteEnumLongId, ConcreteExternTypeLongId, ConcreteStructLongId, + ImplTypeId, }; use cairo_lang_semantic::{ - ConcreteFunction, ConcreteImplLongId, MatchArmSelector, TypeId, TypeLongId, ValueSelectorArm, + ConcreteFunction, ConcreteImplLongId, ConcreteTraitLongId, MatchArmSelector, TypeId, + TypeLongId, ValueSelectorArm, }; use cairo_lang_syntax::node::TypedStablePtr; use cairo_lang_syntax::node::ast::{ ExprPtr, FunctionWithBodyPtr, GenericParamPtr, ItemConstantPtr, ItemEnumPtr, - ItemExternFunctionPtr, ItemExternTypePtr, ItemImplPtr, ItemModulePtr, ItemStructPtr, MemberPtr, - ParamPtr, TerminalIdentifierPtr, TraitItemConstantPtr, TraitItemFunctionPtr, TraitItemTypePtr, - UsePathLeafPtr, VariantPtr, + ItemExternFunctionPtr, ItemExternTypePtr, ItemImplPtr, ItemModulePtr, ItemStructPtr, + ItemTraitPtr, MemberPtr, ParamPtr, TerminalIdentifierPtr, TraitItemConstantPtr, + TraitItemFunctionPtr, TraitItemTypePtr, UsePathLeafPtr, VariantPtr, }; use cairo_lang_syntax::node::green::{GreenNode, GreenNodeDetails}; use cairo_lang_syntax::node::ids::{GreenId, SyntaxStablePtrId}; @@ -103,7 +109,7 @@ pub fn load_cached_crate_functions( ) } -/// Cache the lowering of each function in the crate into a blob. +/// Cache the lowering of a crate. pub fn generate_crate_cache( db: &dyn LoweringGroup, crate_id: cairo_lang_filesystem::ids::CrateId, @@ -119,17 +125,27 @@ pub fn generate_crate_cache( function_ids.push(FunctionWithBodyId::Impl(*impl_func)); } } + for trait_id in db.module_traits_ids(*module_id)?.iter() { + for trait_func in db.trait_functions(*trait_id)?.values() { + function_ids.push(FunctionWithBodyId::Trait(*trait_func)); + } + } } let mut ctx = CacheSavingContext::new(db, crate_id); let cached = function_ids .iter() - .map(|id| { - let multi = db.priv_function_with_body_multi_lowering(*id)?; - Ok(( + .filter_map(|id| { + db.function_body(*id).ok()?; + let multi = match db.priv_function_with_body_multi_lowering(*id) { + Ok(multi) => multi, + Err(err) => return Some(Err(err)), + }; + + Some(Ok(( DefsFunctionWithBodyIdCached::new(*id, &mut ctx.semantic_ctx), MultiLoweringCached::new((*multi).clone(), &mut ctx), - )) + ))) }) .collect::>>()?; @@ -1230,7 +1246,7 @@ enum ConstValueCached { Enum(ConcreteVariantCached, Box), NonZero(Box), Boxed(Box), - Generic(GenericParamCached), + Generic(GenericParamIdCached), ImplConstant(ImplConstantCached), } impl ConstValueCached { @@ -1252,7 +1268,7 @@ impl ConstValueCached { ConstValueCached::Boxed(Box::new(ConstValueCached::new(*value, ctx))) } ConstValue::Generic(generic_param) => { - ConstValueCached::Generic(GenericParamCached::new(generic_param, ctx)) + ConstValueCached::Generic(GenericParamIdCached::new(generic_param, ctx)) } ConstValue::ImplConstant(impl_constant_id) => { ConstValueCached::ImplConstant(ImplConstantCached::new(impl_constant_id, ctx)) @@ -1584,6 +1600,7 @@ impl SemanticConcreteFunctionWithBodyCached { enum GenericFunctionWithBodyCached { Free(LanguageElementCached), Impl(ConcreteImplCached, ImplFunctionBodyCached), + Trait(ConcreteTraitIdCached, LanguageElementCached), } impl GenericFunctionWithBodyCached { @@ -1599,9 +1616,10 @@ impl GenericFunctionWithBodyCached { ConcreteImplCached::new(id.concrete_impl_id, ctx), ImplFunctionBodyCached::new(id.function_body, ctx), ), - GenericFunctionWithBodyId::Trait(_id) => { - unreachable!("Trait functions are not supported in serialization") - } + GenericFunctionWithBodyId::Trait(id) => GenericFunctionWithBodyCached::Trait( + ConcreteTraitIdCached::new(id.concrete_trait(ctx.db), ctx), + LanguageElementCached::new(id.trait_function(ctx.db), ctx), + ), } } fn embed(self, ctx: &mut SemanticCacheLoadingContext<'_>) -> GenericFunctionWithBodyId { @@ -1613,12 +1631,27 @@ impl GenericFunctionWithBodyCached { GenericFunctionWithBodyId::Free(id) } GenericFunctionWithBodyCached::Impl(id, function_body) => { - // todo handle trait functions GenericFunctionWithBodyId::Impl(ImplGenericFunctionWithBodyId { concrete_impl_id: id.embed(ctx), function_body: function_body.embed(ctx), }) } + GenericFunctionWithBodyCached::Trait(id, name) => { + let concrete_trait_id = id.embed(ctx); + let (module_file_id, stable_ptr) = name.embed(ctx); + let trait_function_id = + TraitFunctionLongId(module_file_id, TraitItemFunctionPtr(stable_ptr)) + .intern(ctx.db); + + GenericFunctionWithBodyId::Trait( + ConcreteTraitGenericFunctionLongId::new( + ctx.db, + concrete_trait_id, + trait_function_id, + ) + .intern(ctx.db), + ) + } } } } @@ -1861,9 +1894,10 @@ enum TypeCached { /// during inference. Tuple(Vec), Snapshot(Box), - GenericParameter(GenericParamCached), + GenericParameter(GenericParamIdCached), ImplType(ImplTypeCached), FixedSizeArray(TypeIdCached, ConstValueCached), + Closure(ClosureTypeLongIdCached), } impl TypeCached { @@ -1879,7 +1913,7 @@ impl TypeCached { TypeCached::Snapshot(Box::new(TypeIdCached::new(type_id, ctx))) } semantic::TypeLongId::GenericParameter(generic_param_id) => { - TypeCached::GenericParameter(GenericParamCached::new(generic_param_id, ctx)) + TypeCached::GenericParameter(GenericParamIdCached::new(generic_param_id, ctx)) } semantic::TypeLongId::ImplType(impl_type_id) => { TypeCached::ImplType(ImplTypeCached::new(impl_type_id, ctx)) @@ -1888,10 +1922,10 @@ impl TypeCached { TypeIdCached::new(type_id, ctx), ConstValueCached::new(size.lookup_intern(ctx.db), ctx), ), - TypeLongId::Var(_) - | TypeLongId::Closure(_) - | TypeLongId::Missing(_) - | TypeLongId::Coupon(_) => { + semantic::TypeLongId::Closure(closure_ty) => { + TypeCached::Closure(ClosureTypeLongIdCached::new(closure_ty, ctx)) + } + TypeLongId::Var(_) | TypeLongId::Missing(_) | TypeLongId::Coupon(_) => { unreachable!( "type {:?} is not supported for caching", type_id.debug(ctx.db.elongate()) @@ -1914,6 +1948,7 @@ impl TypeCached { type_id: type_id.embed(ctx), size: size.embed(ctx).intern(ctx.db), }, + TypeCached::Closure(closure_ty) => TypeLongId::Closure(closure_ty.embed(ctx)), } } } @@ -1977,6 +2012,51 @@ impl ConcreteTypeCached { } } +#[derive(Serialize, Deserialize, Clone)] +struct ClosureTypeLongIdCached { + pub param_tys: Vec, + pub ret_ty: TypeIdCached, + pub captured_types: Vec, + pub parent_function: Option, + pub wrapper_location: SyntaxStablePtrIdCached, +} + +impl ClosureTypeLongIdCached { + fn new(closure_type_id: ClosureTypeLongId, ctx: &mut SemanticCacheSavingContext<'_>) -> Self { + Self { + param_tys: closure_type_id + .param_tys + .iter() + .map(|ty| TypeIdCached::new(*ty, ctx)) + .collect(), + ret_ty: TypeIdCached::new(closure_type_id.ret_ty, ctx), + captured_types: closure_type_id + .captured_types + .iter() + .map(|ty| TypeIdCached::new(*ty, ctx)) + .collect(), + parent_function: closure_type_id + .parent_function + .ok() + .map(|f| SemanticFunctionIdCached::new(f, ctx)), + wrapper_location: SyntaxStablePtrIdCached::new( + closure_type_id.wrapper_location.stable_ptr(), + ctx, + ), + } + } + + fn embed(self, ctx: &mut SemanticCacheLoadingContext<'_>) -> ClosureTypeLongId { + ClosureTypeLongId { + param_tys: self.param_tys.into_iter().map(|ty| ty.embed(ctx)).collect(), + ret_ty: self.ret_ty.embed(ctx), + captured_types: self.captured_types.into_iter().map(|ty| ty.embed(ctx)).collect(), + parent_function: self.parent_function.map(|f| f.embed(ctx)).ok_or_else(skip_diagnostic), + wrapper_location: StableLocation::new(self.wrapper_location.embed(ctx)), + } + } +} + #[derive(Serialize, Deserialize, Clone)] struct ImplTypeCached { impl_id: ImplIdCached, @@ -1996,7 +2076,7 @@ impl ImplTypeCached { } } -#[derive(Serialize, Deserialize, Clone)] +#[derive(Serialize, Deserialize, Clone, Hash, PartialEq, Eq)] struct TraitTypeCached { language_element: LanguageElementCached, } @@ -2010,10 +2090,57 @@ impl TraitTypeCached { } } +#[derive(Serialize, Deserialize, Clone)] +struct TraitIdCached { + language_element: LanguageElementCached, +} +impl TraitIdCached { + fn new(trait_id: TraitId, ctx: &mut SemanticCacheSavingContext<'_>) -> Self { + Self { language_element: LanguageElementCached::new(trait_id, ctx) } + } + fn embed(self, ctx: &mut SemanticCacheLoadingContext<'_>) -> TraitId { + let (module_file_id, stable_ptr) = self.language_element.embed(ctx); + TraitLongId(module_file_id, ItemTraitPtr(stable_ptr)).intern(ctx.db) + } +} + +#[derive(Serialize, Deserialize, Clone)] +struct ConcreteTraitIdCached { + trait_id: TraitIdCached, + generic_args: Vec, +} + +impl ConcreteTraitIdCached { + fn new( + concrete_trait_id: semantic::ConcreteTraitId, + ctx: &mut SemanticCacheSavingContext<'_>, + ) -> Self { + let long_id = concrete_trait_id.lookup_intern(ctx.db); + Self { + trait_id: TraitIdCached::new(long_id.trait_id, ctx), + generic_args: long_id + .generic_args + .into_iter() + .map(|arg| GenericArgumentCached::new(arg, ctx)) + .collect(), + } + } + fn embed(self, ctx: &mut SemanticCacheLoadingContext<'_>) -> semantic::ConcreteTraitId { + let trait_id = self.trait_id.embed(ctx); + let long_id = ConcreteTraitLongId { + trait_id, + generic_args: self.generic_args.into_iter().map(|arg| arg.embed(ctx)).collect(), + }; + long_id.intern(ctx.db) + } +} + #[derive(Serialize, Deserialize, Clone)] enum ImplCached { Concrete(ConcreteImplCached), - GenericParameter(GenericParamCached), + GenericParameter(GenericParamIdCached), + GeneratedImpl(GeneratedImplCached), + SelfImpl(ConcreteTraitIdCached), } impl ImplCached { fn new(impl_id: ImplLongId, ctx: &mut SemanticCacheSavingContext<'_>) -> Self { @@ -2022,12 +2149,15 @@ impl ImplCached { ImplCached::Concrete(ConcreteImplCached::new(concrete_impl, ctx)) } ImplLongId::GenericParameter(generic_param_id) => { - ImplCached::GenericParameter(GenericParamCached::new(generic_param_id, ctx)) + ImplCached::GenericParameter(GenericParamIdCached::new(generic_param_id, ctx)) } - ImplLongId::ImplVar(_) - | ImplLongId::ImplImpl(_) - | ImplLongId::SelfImpl(_) - | ImplLongId::GeneratedImpl(_) => { + ImplLongId::GeneratedImpl(generated_impl_id) => { + ImplCached::GeneratedImpl(GeneratedImplCached::new(generated_impl_id, ctx)) + } + ImplLongId::SelfImpl(concrete_trait_id) => { + ImplCached::SelfImpl(ConcreteTraitIdCached::new(concrete_trait_id, ctx)) + } + ImplLongId::ImplVar(_) | ImplLongId::ImplImpl(_) => { unreachable!( "impl {:?} is not supported for caching", impl_id.debug(ctx.db.elongate()) @@ -2041,6 +2171,10 @@ impl ImplCached { ImplCached::GenericParameter(generic_param) => { ImplLongId::GenericParameter(generic_param.embed(ctx)) } + ImplCached::GeneratedImpl(generated_impl) => { + ImplLongId::GeneratedImpl(generated_impl.embed(ctx)) + } + ImplCached::SelfImpl(concrete_trait) => ImplLongId::SelfImpl(concrete_trait.embed(ctx)), } } } @@ -2115,10 +2249,47 @@ impl ImplDefIdCached { } #[derive(Serialize, Deserialize, Clone)] -struct GenericParamCached { +struct GeneratedImplCached { + pub concrete_trait: ConcreteTraitIdCached, + /// The generic params required for the impl. Typically impls and negative impls. + /// We save the params so that we can validate negative impls. + pub generic_params: Vec, + pub impl_items: OrderedHashMap, +} +impl GeneratedImplCached { + fn new(impl_id: GeneratedImplId, ctx: &mut SemanticCacheSavingContext<'_>) -> Self { + let long_id = impl_id.lookup_intern(ctx.db); + Self { + concrete_trait: ConcreteTraitIdCached::new(long_id.concrete_trait, ctx), + generic_params: long_id + .generic_params + .into_iter() + .map(|param| GenericParamCached::new(param, ctx)) + .collect(), + impl_items: long_id + .impl_items + .0 + .into_iter() + .map(|(k, v)| (TraitTypeCached::new(k, ctx), TypeIdCached::new(v, ctx))) + .collect(), + } + } + fn embed(self, ctx: &mut SemanticCacheLoadingContext<'_>) -> GeneratedImplId { + let concrete_trait = self.concrete_trait.embed(ctx); + let generic_params = + self.generic_params.into_iter().map(|param| param.embed(ctx)).collect(); + let impl_items = GeneratedImplItems( + self.impl_items.into_iter().map(|(k, v)| (k.embed(ctx), v.embed(ctx))).collect(), + ); + GeneratedImplLongId { concrete_trait, generic_params, impl_items }.intern(ctx.db) + } +} + +#[derive(Serialize, Deserialize, Clone)] +struct GenericParamIdCached { language_element: LanguageElementCached, } -impl GenericParamCached { +impl GenericParamIdCached { fn new(generic_param_id: GenericParamId, ctx: &mut SemanticCacheSavingContext<'_>) -> Self { Self { language_element: LanguageElementCached::new(generic_param_id, ctx) } } @@ -2128,6 +2299,95 @@ impl GenericParamCached { } } +#[derive(Serialize, Deserialize, Clone)] +enum GenericParamCached { + Type(GenericParamIdCached), + Const(GenericParamIdCached, TypeIdCached), + Impl( + GenericParamIdCached, + Option, + OrderedHashMap, + ), + NegImpl( + GenericParamIdCached, + Option, + OrderedHashMap, + ), +} +impl GenericParamCached { + fn new( + generic_param: semantic::GenericParam, + ctx: &mut SemanticCacheSavingContext<'_>, + ) -> Self { + match generic_param { + semantic::GenericParam::Type(generic_type) => { + GenericParamCached::Type(GenericParamIdCached::new(generic_type.id, ctx)) + } + semantic::GenericParam::Const(generic_param_const) => GenericParamCached::Const( + GenericParamIdCached::new(generic_param_const.id, ctx), + TypeIdCached::new(generic_param_const.ty, ctx), + ), + semantic::GenericParam::Impl(generic_param_impl) => GenericParamCached::Impl( + GenericParamIdCached::new(generic_param_impl.id, ctx), + generic_param_impl.concrete_trait.ok().map(|t| ConcreteTraitIdCached::new(t, ctx)), + generic_param_impl + .type_constraints + .into_iter() + .map(|(k, v)| (TraitTypeCached::new(k, ctx), TypeIdCached::new(v, ctx))) + .collect(), + ), + semantic::GenericParam::NegImpl(generic_param_neg_impl) => GenericParamCached::NegImpl( + GenericParamIdCached::new(generic_param_neg_impl.id, ctx), + generic_param_neg_impl + .concrete_trait + .ok() + .map(|t| ConcreteTraitIdCached::new(t, ctx)), + generic_param_neg_impl + .type_constraints + .into_iter() + .map(|(k, v)| (TraitTypeCached::new(k, ctx), TypeIdCached::new(v, ctx))) + .collect(), + ), + } + } + + fn embed(self, ctx: &mut SemanticCacheLoadingContext<'_>) -> semantic::GenericParam { + match self { + GenericParamCached::Type(id) => { + semantic::GenericParam::Type(GenericParamType { id: id.embed(ctx) }) + } + GenericParamCached::Const(id, ty) => semantic::GenericParam::Const(GenericParamConst { + id: id.embed(ctx), + ty: ty.embed(ctx), + }), + GenericParamCached::Impl(id, concrete_trait, type_constraints) => { + semantic::GenericParam::Impl(GenericParamImpl { + id: id.embed(ctx), + concrete_trait: concrete_trait + .map(|t| t.embed(ctx)) + .ok_or_else(skip_diagnostic), + type_constraints: type_constraints + .into_iter() + .map(|(k, v)| (k.embed(ctx), v.embed(ctx))) + .collect(), + }) + } + GenericParamCached::NegImpl(id, concrete_trait, type_constraints) => { + semantic::GenericParam::NegImpl(GenericParamImpl { + id: id.embed(ctx), + concrete_trait: concrete_trait + .map(|t| t.embed(ctx)) + .ok_or_else(skip_diagnostic), + type_constraints: type_constraints + .into_iter() + .map(|(k, v)| (k.embed(ctx), v.embed(ctx))) + .collect(), + }) + } + } + } +} + #[derive(Serialize, Deserialize, Clone)] struct ConcreteVariantCached { concrete_enum_id: ConcreteEnumCached,