From 91a761bb3a8d783942491a81b4e3a8eda3cccba5 Mon Sep 17 00:00:00 2001
From: Eagle941 <8973725+Eagle941@users.noreply.github.com>
Date: Mon, 23 Dec 2024 22:13:12 +0000
Subject: [PATCH] Fixed `DebugReplacer` to prevent stack overflow on self-referencing types

---
 starknet-replay/src/profiler/analysis.rs      |  29 ++---
 starknet-replay/src/profiler/mod.rs           |   7 +-
 starknet-replay/src/profiler/replace_ids.rs   | 102 ++++++++++++++----
 .../src/runner/replay_class_hash.rs           |   2 +-
 starknet-replay/src/storage/rpc/mod.rs        |   4 +-
 .../src/storage/rpc/visited_pcs.rs            |  14 +--
 6 files changed, 109 insertions(+), 49 deletions(-)

diff --git a/starknet-replay/src/profiler/analysis.rs b/starknet-replay/src/profiler/analysis.rs
index 0fd5568..549acdb 100644
--- a/starknet-replay/src/profiler/analysis.rs
+++ b/starknet-replay/src/profiler/analysis.rs
@@ -89,9 +89,10 @@ pub fn extract_libfuncs_weight(
     visited_pcs: &VisitedPcs,
     storage: &impl Storage,
 ) -> Result<ReplayStatistics, ProfilerError> {
-    let mut local_cumulative_libfuncs_weight: ReplayStatistics = ReplayStatistics::new();
+    let mut local_cumulative_libfuncs_weight = ReplayStatistics::new();
 
     for (replay_class_hash, all_pcs) in visited_pcs {
+        tracing::info!("Processing pcs from {replay_class_hash:?}.");
         let Ok(contract_class) = storage.get_contract_class_at_block(replay_class_hash) else {
             continue;
         };
@@ -100,13 +101,12 @@ pub fn extract_libfuncs_weight(
             continue;
         };
 
-        let runner = SierraProfiler::new(sierra_program.clone())?;
+        let runner = SierraProfiler::new(sierra_program)?;
 
-        for pcs in all_pcs {
-            let concrete_libfunc_weights = internal_extract_libfuncs_weight(&runner, pcs);
-            local_cumulative_libfuncs_weight =
-                local_cumulative_libfuncs_weight.add_statistics(&concrete_libfunc_weights);
-        }
+        let concrete_libfunc_weights = internal_extract_libfuncs_weight(&runner, all_pcs);
+
+        local_cumulative_libfuncs_weight =
+            local_cumulative_libfuncs_weight.add_statistics(&concrete_libfunc_weights);
     }
 
     for (concrete_name, weight) in local_cumulative_libfuncs_weight
@@ -191,14 +191,14 @@ mod tests {
         }
     }
 
-    fn compile_cairo_program(filename: &str) -> Program {
+    fn compile_cairo_program(filename: &str, replace_ids: bool) -> Program {
         let absolute_path = env::var("CARGO_MANIFEST_DIR").unwrap();
         let filename = [absolute_path.as_str(), filename].iter().join("");
         let file_path = Path::new(&filename);
         compile_cairo_project_at_path(
             file_path,
             CompilerConfig {
-                replace_ids: true,
+                replace_ids,
                 ..CompilerConfig::default()
             },
         )
@@ -207,12 +207,13 @@ mod tests {
 
     fn compile_cairo_contract(
         filename: &str,
+        replace_ids: bool,
     ) -> cairo_lang_starknet_classes::contract_class::ContractClass {
         let absolute_path = env::var("CARGO_MANIFEST_DIR").unwrap();
         let filename = [absolute_path.as_str(), filename].iter().join("");
         let file_path = Path::new(&filename);
         let config = CompilerConfig {
-            replace_ids: true,
+            replace_ids,
             ..CompilerConfig::default()
         };
         let contract_path = None;
@@ -241,7 +242,7 @@ mod tests {
         entrypoint_offset: usize,
         args: &[MaybeRelocatable],
     ) -> Vec<usize> {
-        let contract_class = compile_cairo_contract(filename);
+        let contract_class = compile_cairo_contract(filename, true);
 
         let add_pythonic_hints = false;
         let max_bytecode_size = 180_000;
@@ -388,7 +389,7 @@ mod tests {
         let visited_pcs: Vec<usize> = vec![1, 4, 6, 8, 3];
 
         let cairo_file = "/test_data/sierra_add_program.cairo";
-        let sierra_program = compile_cairo_program(cairo_file);
+        let sierra_program = compile_cairo_program(cairo_file, true);
 
         let sierra_profiler = SierraProfiler::new(sierra_program.clone()).unwrap();
 
@@ -438,7 +439,7 @@ mod tests {
         // }
 
         let cairo_file = "/test_data/sierra_add_contract.cairo";
-        let sierra_program = compile_cairo_contract(cairo_file)
+        let sierra_program = compile_cairo_contract(cairo_file, true)
             .extract_sierra_program()
             .unwrap();
         let visited_pcs = visited_pcs_from_entrypoint(cairo_file, 0, &[]);
@@ -524,7 +525,7 @@ mod tests {
         // }
 
         let cairo_file = "/test_data/sierra_dict.cairo";
-        let sierra_program = compile_cairo_contract(cairo_file)
+        let sierra_program = compile_cairo_contract(cairo_file, true)
             .extract_sierra_program()
             .unwrap();
         let visited_pcs = visited_pcs_from_entrypoint(cairo_file, 0, &[]);
diff --git a/starknet-replay/src/profiler/mod.rs b/starknet-replay/src/profiler/mod.rs
index 1e4a9cf..231ecea 100644
--- a/starknet-replay/src/profiler/mod.rs
+++ b/starknet-replay/src/profiler/mod.rs
@@ -20,6 +20,7 @@ use cairo_lang_sierra_to_casm::metadata::{
     MetadataComputationConfig,
     MetadataError,
 };
+use itertools::Itertools;
 use tracing::trace;
 
 use crate::error::ProfilerError;
@@ -190,12 +191,12 @@ impl SierraProfiler {
     #[must_use]
     pub fn collect_profiling_info(&self, pcs: &[usize]) -> HashMap<StatementIdx, usize> {
         let mut sierra_statement_weights = HashMap::default();
-        for pc in pcs {
+        for (pc, frequency) in pcs.iter().counts() {
             let statements: Vec<&CompiledStatement> =
                 self.commands.iter().filter(|c| c.pc == *pc).collect();
             for statement in statements {
                 let statement_idx = StatementIdx(statement.statement_idx);
-                *sierra_statement_weights.entry(statement_idx).or_insert(0) += 1;
+                *sierra_statement_weights.entry(statement_idx).or_insert(0) += frequency;
             }
         }
 
@@ -225,7 +226,7 @@ impl SierraProfiler {
         &self,
         statements: &HashMap<StatementIdx, usize>,
     ) -> HashMap<String, usize> {
-        let mut libfunc_weights = HashMap::<String, usize>::default();
+        let mut libfunc_weights = HashMap::default();
         for (statement_idx, frequency) in statements {
             if let Some(GenStatement::Invocation(invocation)) =
                 self.statement_idx_to_gen_statement(*statement_idx)
diff --git a/starknet-replay/src/profiler/replace_ids.rs b/starknet-replay/src/profiler/replace_ids.rs
index 115f850..be75714 100644
--- a/starknet-replay/src/profiler/replace_ids.rs
+++ b/starknet-replay/src/profiler/replace_ids.rs
@@ -4,10 +4,11 @@
 //! data. Without debug information, the [`cairo_lang_sierra::program::Program`]
 //! contains only numeric ids to indicate libfuncs and types.
 
+use std::collections::HashSet;
 use std::sync::Arc;
 
 use cairo_lang_sierra::ids::{ConcreteLibfuncId, ConcreteTypeId, FunctionId};
-use cairo_lang_sierra::program::{self, ConcreteLibfuncLongId, Program};
+use cairo_lang_sierra::program::{self, ConcreteLibfuncLongId, Program, TypeDeclaration};
 use cairo_lang_sierra_generator::db::SierraGeneratorTypeLongId;
 use cairo_lang_sierra_generator::replace_ids::SierraIdReplacer;
 use cairo_lang_utils::extract_matches;
@@ -45,11 +46,11 @@ use cairo_lang_utils::extract_matches;
 /// [`cairo_lang_sierra_generator::replace_ids::SierraIdReplacer`] to be able to
 /// perform the replacement from id to string.
 #[derive(Debug, Clone, Eq, PartialEq)]
-pub struct DebugReplacer {
+pub struct DebugReplacer<'a> {
     /// The Sierra program to replace ids from.
-    program: Program,
+    program: &'a Program,
 }
-impl DebugReplacer {
+impl DebugReplacer<'_> {
     /// Get the long debug name for the libfunc with id equivalent to `id`.
     fn lookup_intern_concrete_lib_func(&self, id: &ConcreteLibfuncId) -> ConcreteLibfuncLongId {
         self.program
@@ -61,19 +62,85 @@ impl DebugReplacer {
             .long_id
     }
 
-    /// Get the long debug name for the type with id equivalent to `id`.
-    fn lookup_intern_concrete_type(&self, id: &ConcreteTypeId) -> SierraGeneratorTypeLongId {
-        let concrete_type = self
-            .program
+    /// Get the type declaration for a given `type_id`.
+    fn get_type_declaration(&self, type_id: &ConcreteTypeId) -> TypeDeclaration {
+        self.program
             .type_declarations
             .iter()
-            .find(|f| f.id.id == id.id)
+            .find(|f| f.id.id == type_id.id)
             .expect("ConcreteTypeId should be found in type_declarations.")
-            .clone();
-        SierraGeneratorTypeLongId::Regular(Arc::new(concrete_type.long_id))
+            .clone()
+    }
+
+    /// Builds the `HashSet` of types that `type_id` depends on, directly or
+    /// transitively. The argument `visited_types` keeps track of previously
+    /// visited types to break cycles and avoid infinite recursion.
+    fn type_dependencies(
+        &self,
+        visited_types: &mut HashSet<ConcreteTypeId>,
+        type_id: &ConcreteTypeId,
+    ) -> HashSet<ConcreteTypeId> {
+        let mut dependencies = HashSet::new();
+
+        if visited_types.contains(type_id) {
+            return dependencies;
+        }
+        visited_types.insert(type_id.clone());
+
+        let concrete_type = self.get_type_declaration(type_id);
+
+        concrete_type
+            .long_id
+            .generic_args
+            .iter()
+            .for_each(|t| match t {
+                program::GenericArg::Type(concrete_type_id) => {
+                    dependencies.insert(concrete_type_id.clone());
+                    if visited_types.contains(concrete_type_id) {
+                        return;
+                    }
+                    dependencies.extend(self.type_dependencies(visited_types, concrete_type_id));
+                    return;
+                }
+                _ => return,
+            });
+
+        dependencies
+    }
+
+    /// Returns true if `type_id` depends on `needle`. False otherwise.
+    fn has_in_deps(&self, type_id: &ConcreteTypeId, needle: &ConcreteTypeId) -> bool {
+        let mut visited_types = HashSet::new();
+        let deps = self.type_dependencies(&mut visited_types, type_id);
+        if deps.contains(&needle) {
+            return true;
+        }
+        return false;
+    }
+
+    /// Get the long debug name for the type with id equivalent to `id`.
+    ///
+    /// If `id` is a self-referencing type (i.e. it depends on itself), then the
+    /// function returns `None` as an alternative to
+    /// [`SierraGeneratorTypeLongId::CycleBreaker`]. It's not possible to
+    /// construct a [`SierraGeneratorTypeLongId::CycleBreaker`] object because
+    /// it requires having access to the Salsa db of the program.
+    fn lookup_intern_concrete_type(
+        &self,
+        id: &ConcreteTypeId,
+    ) -> Option<SierraGeneratorTypeLongId> {
+        let concrete_type = self.get_type_declaration(id);
+        if self.has_in_deps(id, id) {
+            None
+        } else {
+            Some(SierraGeneratorTypeLongId::Regular(Arc::new(
+                concrete_type.long_id,
+            )))
+        }
     }
 }
-impl SierraIdReplacer for DebugReplacer {
+
+impl SierraIdReplacer for DebugReplacer<'_> {
     fn replace_libfunc_id(&self, id: &ConcreteLibfuncId) -> ConcreteLibfuncId {
         let mut long_id = self.lookup_intern_concrete_lib_func(id);
         self.replace_generic_args(&mut long_id.generic_args);
@@ -91,10 +158,7 @@ impl SierraIdReplacer for DebugReplacer {
             // It's not possible to recover the `debug_name` of `Phantom` and `CycleBreaker` because
             // it relies on access to the Salsa db which is available only during
             // contract compilation.
-            SierraGeneratorTypeLongId::Phantom(_) | SierraGeneratorTypeLongId::CycleBreaker(_) => {
-                id.clone()
-            }
-            SierraGeneratorTypeLongId::Regular(long_id) => {
+            Some(SierraGeneratorTypeLongId::Regular(long_id)) => {
                 let mut long_id = long_id.as_ref().clone();
                 self.replace_generic_args(&mut long_id.generic_args);
                 if long_id.generic_id == "Enum".into() || long_id.generic_id == "Struct".into() {
@@ -116,6 +180,7 @@ impl SierraIdReplacer for DebugReplacer {
                     debug_name: Some(long_id.to_string().into()),
                 }
             }
+            _ => id.clone(),
         }
     }
 
@@ -149,10 +214,7 @@ impl SierraIdReplacer for DebugReplacer {
 /// [`cairo_lang_sierra_generator::db::SierraGenGroup`] trait object.
 #[must_use]
 pub fn replace_sierra_ids_in_program(program: &Program) -> Program {
-    DebugReplacer {
-        program: program.clone(),
-    }
-    .apply(program)
+    DebugReplacer { program }.apply(program)
 }
 
 #[cfg(test)]
diff --git a/starknet-replay/src/runner/replay_class_hash.rs b/starknet-replay/src/runner/replay_class_hash.rs
index 63824f8..ed022e6 100644
--- a/starknet-replay/src/runner/replay_class_hash.rs
+++ b/starknet-replay/src/runner/replay_class_hash.rs
@@ -24,7 +24,7 @@ pub struct ReplayClassHash {
 
 /// The type [`VisitedPcs`] is a hashmap to store the visited program counters
 /// for each contract invocation during replay.
-pub type VisitedPcs = HashMap<ReplayClassHash, Vec<Vec<usize>>>;
+pub type VisitedPcs = HashMap<ReplayClassHash, Vec<usize>>;
 
 /// The type [`TransactionOutput`] contains the combination of transaction
 /// receipt and list of visited program counters.
diff --git a/starknet-replay/src/storage/rpc/mod.rs b/starknet-replay/src/storage/rpc/mod.rs
index f7ee0eb..e507a8a 100644
--- a/starknet-replay/src/storage/rpc/mod.rs
+++ b/starknet-replay/src/storage/rpc/mod.rs
@@ -489,15 +489,15 @@ impl ReplayStorage for RpcStorage {
 
                     let visited_pcs: VisitedPcs = state
                         .visited_pcs
-                        .clone()
                         .0
+                        .to_owned()
                         .into_iter()
                         .map(|(class_hash, pcs)| {
                             let replay_class_hash = ReplayClassHash {
                                 block_number,
                                 class_hash,
                             };
-                            (replay_class_hash, pcs.into_iter().collect())
+                            (replay_class_hash, pcs.clone())
                         })
                         .collect();
                     if let Some(filename) = trace_out {
diff --git a/starknet-replay/src/storage/rpc/visited_pcs.rs b/starknet-replay/src/storage/rpc/visited_pcs.rs
index 010427f..b4d21ed 100644
--- a/starknet-replay/src/storage/rpc/visited_pcs.rs
+++ b/starknet-replay/src/storage/rpc/visited_pcs.rs
@@ -12,19 +12,17 @@ use starknet_api::core::ClassHash;
 
 /// The hashmap of [`VisitedPcsRaw`] is a map from a
 /// [`starknet_api::core::ClassHash`] to a vector of visited program counters.
-/// The vector returned from each call to [`starknet_api::core::ClassHash`] is
-/// added to the nested vector.
 #[derive(Clone, Debug, Default, PartialEq, Eq)]
-pub struct VisitedPcsRaw(pub HashMap<ClassHash, Vec<Vec<usize>>>);
+pub struct VisitedPcsRaw(pub HashMap<ClassHash, Vec<usize>>);
 impl VisitedPcs for VisitedPcsRaw {
-    type Pcs = Vec<Vec<usize>>;
+    type Pcs = Vec<usize>;
 
     fn new() -> Self {
         VisitedPcsRaw(HashMap::default())
     }
 
     fn insert(&mut self, class_hash: &ClassHash, pcs: &[usize]) {
-        self.0.entry(*class_hash).or_default().push(pcs.to_vec());
+        self.0.entry(*class_hash).or_default().extend(pcs.iter());
     }
 
     fn iter(&self) -> impl Iterator<Item = (&ClassHash, &Self::Pcs)> {
@@ -41,15 +39,13 @@ impl VisitedPcs for VisitedPcsRaw {
 
     fn to_set(pcs: Self::Pcs) -> HashSet<usize> {
         let mut set = HashSet::new();
-        pcs.into_iter().flatten().for_each(|p| {
+        pcs.into_iter().for_each(|p| {
             set.insert(p);
         });
         set
     }
 
     fn add_visited_pcs(state: &mut dyn State, class_hash: &ClassHash, pcs: Self::Pcs) {
-        for pc in pcs {
-            state.add_visited_pcs(*class_hash, &pc);
-        }
+        state.add_visited_pcs(*class_hash, &pcs);
     }
 }