Skip to content

Commit

Permalink
Fixed to prevent stack overflow on self-referencing types
Browse files Browse the repository at this point in the history
  • Loading branch information
Eagle941 committed Dec 23, 2024
1 parent 63414c8 commit 91a761b
Show file tree
Hide file tree
Showing 6 changed files with 109 additions and 49 deletions.
29 changes: 15 additions & 14 deletions starknet-replay/src/profiler/analysis.rs
Original file line number Diff line number Diff line change
Expand Up @@ -89,9 +89,10 @@ pub fn extract_libfuncs_weight(
visited_pcs: &VisitedPcs,
storage: &impl Storage,
) -> Result<ReplayStatistics, ProfilerError> {
let mut local_cumulative_libfuncs_weight: ReplayStatistics = ReplayStatistics::new();
let mut local_cumulative_libfuncs_weight = ReplayStatistics::new();

for (replay_class_hash, all_pcs) in visited_pcs {
tracing::info!("Processing pcs from {replay_class_hash:?}.");
let Ok(contract_class) = storage.get_contract_class_at_block(replay_class_hash) else {
continue;
};
Expand All @@ -100,13 +101,12 @@ pub fn extract_libfuncs_weight(
continue;
};

let runner = SierraProfiler::new(sierra_program.clone())?;
let runner = SierraProfiler::new(sierra_program)?;

for pcs in all_pcs {
let concrete_libfunc_weights = internal_extract_libfuncs_weight(&runner, pcs);
local_cumulative_libfuncs_weight =
local_cumulative_libfuncs_weight.add_statistics(&concrete_libfunc_weights);
}
let concrete_libfunc_weights = internal_extract_libfuncs_weight(&runner, all_pcs);

local_cumulative_libfuncs_weight =
local_cumulative_libfuncs_weight.add_statistics(&concrete_libfunc_weights);
}

for (concrete_name, weight) in local_cumulative_libfuncs_weight
Expand Down Expand Up @@ -191,14 +191,14 @@ mod tests {
}
}

fn compile_cairo_program(filename: &str) -> Program {
fn compile_cairo_program(filename: &str, replace_ids: bool) -> Program {
let absolute_path = env::var("CARGO_MANIFEST_DIR").unwrap();
let filename = [absolute_path.as_str(), filename].iter().join("");
let file_path = Path::new(&filename);
compile_cairo_project_at_path(
file_path,
CompilerConfig {
replace_ids: true,
replace_ids,
..CompilerConfig::default()
},
)
Expand All @@ -207,12 +207,13 @@ mod tests {

fn compile_cairo_contract(
filename: &str,
replace_ids: bool,
) -> cairo_lang_starknet_classes::contract_class::ContractClass {
let absolute_path = env::var("CARGO_MANIFEST_DIR").unwrap();
let filename = [absolute_path.as_str(), filename].iter().join("");
let file_path = Path::new(&filename);
let config = CompilerConfig {
replace_ids: true,
replace_ids,
..CompilerConfig::default()
};
let contract_path = None;
Expand Down Expand Up @@ -241,7 +242,7 @@ mod tests {
entrypoint_offset: usize,
args: &[MaybeRelocatable],
) -> Vec<usize> {
let contract_class = compile_cairo_contract(filename);
let contract_class = compile_cairo_contract(filename, true);

let add_pythonic_hints = false;
let max_bytecode_size = 180_000;
Expand Down Expand Up @@ -388,7 +389,7 @@ mod tests {
let visited_pcs: Vec<usize> = vec![1, 4, 6, 8, 3];

let cairo_file = "/test_data/sierra_add_program.cairo";
let sierra_program = compile_cairo_program(cairo_file);
let sierra_program = compile_cairo_program(cairo_file, true);

let sierra_profiler = SierraProfiler::new(sierra_program.clone()).unwrap();

Expand Down Expand Up @@ -438,7 +439,7 @@ mod tests {
// }

let cairo_file = "/test_data/sierra_add_contract.cairo";
let sierra_program = compile_cairo_contract(cairo_file)
let sierra_program = compile_cairo_contract(cairo_file, true)
.extract_sierra_program()
.unwrap();
let visited_pcs = visited_pcs_from_entrypoint(cairo_file, 0, &[]);
Expand Down Expand Up @@ -524,7 +525,7 @@ mod tests {
// }

let cairo_file = "/test_data/sierra_dict.cairo";
let sierra_program = compile_cairo_contract(cairo_file)
let sierra_program = compile_cairo_contract(cairo_file, true)
.extract_sierra_program()
.unwrap();
let visited_pcs = visited_pcs_from_entrypoint(cairo_file, 0, &[]);
Expand Down
7 changes: 4 additions & 3 deletions starknet-replay/src/profiler/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ use cairo_lang_sierra_to_casm::metadata::{
MetadataComputationConfig,
MetadataError,
};
use itertools::Itertools;
use tracing::trace;

use crate::error::ProfilerError;
Expand Down Expand Up @@ -190,12 +191,12 @@ impl SierraProfiler {
#[must_use]
pub fn collect_profiling_info(&self, pcs: &[usize]) -> HashMap<StatementIdx, usize> {
let mut sierra_statement_weights = HashMap::default();
for pc in pcs {
for (pc, frequency) in pcs.iter().counts() {
let statements: Vec<&CompiledStatement> =
self.commands.iter().filter(|c| c.pc == *pc).collect();
for statement in statements {
let statement_idx = StatementIdx(statement.statement_idx);
*sierra_statement_weights.entry(statement_idx).or_insert(0) += 1;
*sierra_statement_weights.entry(statement_idx).or_insert(0) += frequency;
}
}

Expand Down Expand Up @@ -225,7 +226,7 @@ impl SierraProfiler {
&self,
statements: &HashMap<StatementIdx, usize>,
) -> HashMap<String, usize> {
let mut libfunc_weights = HashMap::<String, usize>::default();
let mut libfunc_weights = HashMap::default();
for (statement_idx, frequency) in statements {
if let Some(GenStatement::Invocation(invocation)) =
self.statement_idx_to_gen_statement(*statement_idx)
Expand Down
102 changes: 82 additions & 20 deletions starknet-replay/src/profiler/replace_ids.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,11 @@
//! data. Without debug information, the [`cairo_lang_sierra::program::Program`]
//! contains only numeric ids to indicate libfuncs and types.
use std::collections::HashSet;
use std::sync::Arc;

use cairo_lang_sierra::ids::{ConcreteLibfuncId, ConcreteTypeId, FunctionId};
use cairo_lang_sierra::program::{self, ConcreteLibfuncLongId, Program};
use cairo_lang_sierra::program::{self, ConcreteLibfuncLongId, Program, TypeDeclaration};
use cairo_lang_sierra_generator::db::SierraGeneratorTypeLongId;
use cairo_lang_sierra_generator::replace_ids::SierraIdReplacer;
use cairo_lang_utils::extract_matches;
Expand Down Expand Up @@ -45,11 +46,11 @@ use cairo_lang_utils::extract_matches;
/// [`cairo_lang_sierra_generator::replace_ids::SierraIdReplacer`] to be able to
/// perform the replacement from id to string.
#[derive(Debug, Clone, Eq, PartialEq)]
pub struct DebugReplacer {
pub struct DebugReplacer<'a> {
/// The Sierra program to replace ids from.
program: Program,
program: &'a Program,
}
impl DebugReplacer {
impl DebugReplacer<'_> {
/// Get the long debug name for the libfunc with id equivalent to `id`.
fn lookup_intern_concrete_lib_func(&self, id: &ConcreteLibfuncId) -> ConcreteLibfuncLongId {
self.program
Expand All @@ -61,19 +62,85 @@ impl DebugReplacer {
.long_id
}

/// Get the long debug name for the type with id equivalent to `id`.
fn lookup_intern_concrete_type(&self, id: &ConcreteTypeId) -> SierraGeneratorTypeLongId {
let concrete_type = self
.program
/// Get the type declaration for a given `type_id`.
fn get_type_declaration(&self, type_id: &ConcreteTypeId) -> TypeDeclaration {
self.program
.type_declarations
.iter()
.find(|f| f.id.id == id.id)
.find(|f| f.id.id == type_id.id)
.expect("ConcreteTypeId should be found in type_declarations.")
.clone();
SierraGeneratorTypeLongId::Regular(Arc::new(concrete_type.long_id))
.clone()
}

/// Builds the set of types that `type_id` depends on, directly or
/// transitively, through the type arguments of its declaration.
///
/// `visited_types` records every type whose declaration has already been
/// expanded. A type that is already in `visited_types` is not expanded
/// again, which breaks cycles and avoids infinite recursion on
/// self-referencing types. Callers starting a fresh query should pass an
/// empty set.
///
/// A type reachable from itself through a cycle is included in the returned
/// set, because every type argument is recorded as a dependency before the
/// visited check is applied to it.
///
/// # Panics
///
/// Panics (via `get_type_declaration`) if an expanded type id has no entry
/// in the program's type declarations.
fn type_dependencies(
    &self,
    visited_types: &mut HashSet<ConcreteTypeId>,
    type_id: &ConcreteTypeId,
) -> HashSet<ConcreteTypeId> {
    let mut dependencies = HashSet::new();

    // `insert` returns `false` when the id was already present, so a single
    // call both marks the type as visited and detects a repeated visit.
    if !visited_types.insert(type_id.clone()) {
        return dependencies;
    }

    let declaration = self.get_type_declaration(type_id);
    for generic_arg in &declaration.long_id.generic_args {
        // Only type arguments contribute dependencies; other generic
        // arguments are ignored.
        if let program::GenericArg::Type(dependency) = generic_arg {
            dependencies.insert(dependency.clone());
            // The recursive call returns an empty set for an already visited
            // type, so no separate visited check is needed here.
            dependencies.extend(self.type_dependencies(visited_types, dependency));
        }
    }

    dependencies
}

/// Returns `true` if `type_id` depends on `needle`, directly or
/// transitively, and `false` otherwise.
///
/// Passing the same id for both arguments tests whether the type is
/// self-referencing.
fn has_in_deps(&self, type_id: &ConcreteTypeId, needle: &ConcreteTypeId) -> bool {
    // Start every query from an empty visited set so that results do not
    // depend on earlier queries.
    let mut visited_types = HashSet::new();
    self.type_dependencies(&mut visited_types, type_id)
        .contains(needle)
}

/// Get the long debug name for the type with id equivalent to `id`.
///
/// If `id` is a self-referencing type (i.e. it depends on itself), then the
/// function returns `None` as an alternative to
/// [`SierraGeneratorTypeLongId::CycleBreaker`]. It's not possible to
/// construct a [`SierraGeneratorTypeLongId::CycleBreaker`] object because
/// it requires having access to the Salsa db of the program.
///
/// Otherwise the function returns the long id of the type declaration
/// wrapped in [`SierraGeneratorTypeLongId::Regular`].
fn lookup_intern_concrete_type(
    &self,
    id: &ConcreteTypeId,
) -> Option<SierraGeneratorTypeLongId> {
    // Check for a cycle first so the declaration is fetched only when it is
    // actually returned.
    if self.has_in_deps(id, id) {
        return None;
    }

    let declaration = self.get_type_declaration(id);
    Some(SierraGeneratorTypeLongId::Regular(Arc::new(
        declaration.long_id,
    )))
}
}
impl SierraIdReplacer for DebugReplacer {

impl SierraIdReplacer for DebugReplacer<'_> {
fn replace_libfunc_id(&self, id: &ConcreteLibfuncId) -> ConcreteLibfuncId {
let mut long_id = self.lookup_intern_concrete_lib_func(id);
self.replace_generic_args(&mut long_id.generic_args);
Expand All @@ -91,10 +158,7 @@ impl SierraIdReplacer for DebugReplacer {
// It's not possible to recover the `debug_name` of `Phantom` and `CycleBreaker` because
// it relies on access to the Salsa db which is available only during
// contract compilation.
SierraGeneratorTypeLongId::Phantom(_) | SierraGeneratorTypeLongId::CycleBreaker(_) => {
id.clone()
}
SierraGeneratorTypeLongId::Regular(long_id) => {
Some(SierraGeneratorTypeLongId::Regular(long_id)) => {
let mut long_id = long_id.as_ref().clone();
self.replace_generic_args(&mut long_id.generic_args);
if long_id.generic_id == "Enum".into() || long_id.generic_id == "Struct".into() {
Expand All @@ -116,6 +180,7 @@ impl SierraIdReplacer for DebugReplacer {
debug_name: Some(long_id.to_string().into()),
}
}
_ => id.clone(),
}
}

Expand Down Expand Up @@ -149,10 +214,7 @@ impl SierraIdReplacer for DebugReplacer {
/// [`cairo_lang_sierra_generator::db::SierraGenGroup`] trait object.
#[must_use]
pub fn replace_sierra_ids_in_program(program: &Program) -> Program {
DebugReplacer {
program: program.clone(),
}
.apply(program)
DebugReplacer { program }.apply(program)
}

#[cfg(test)]
Expand Down
2 changes: 1 addition & 1 deletion starknet-replay/src/runner/replay_class_hash.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ pub struct ReplayClassHash {

/// The type [`VisitedPcs`] is a hashmap to store the visited program counters
/// for each contract invocation during replay.
pub type VisitedPcs = HashMap<ReplayClassHash, Vec<Vec<usize>>>;
pub type VisitedPcs = HashMap<ReplayClassHash, Vec<usize>>;

/// The type [`TransactionOutput`] contains the combination of transaction
/// receipt and list of visited program counters.
Expand Down
4 changes: 2 additions & 2 deletions starknet-replay/src/storage/rpc/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -489,15 +489,15 @@ impl ReplayStorage for RpcStorage {

let visited_pcs: VisitedPcs = state
.visited_pcs
.clone()
.0
.to_owned()
.into_iter()
.map(|(class_hash, pcs)| {
let replay_class_hash = ReplayClassHash {
block_number,
class_hash,
};
(replay_class_hash, pcs.into_iter().collect())
(replay_class_hash, pcs.clone())
})
.collect();
if let Some(filename) = trace_out {
Expand Down
14 changes: 5 additions & 9 deletions starknet-replay/src/storage/rpc/visited_pcs.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,19 +12,17 @@ use starknet_api::core::ClassHash;

/// The hashmap of [`VisitedPcsRaw`] is a map from a
/// [`starknet_api::core::ClassHash`] to a vector of visited program counters.
/// The vector returned from each call to [`starknet_api::core::ClassHash`] is
/// added to the nested vector.
#[derive(Clone, Debug, Default, PartialEq, Eq)]
pub struct VisitedPcsRaw(pub HashMap<ClassHash, Vec<Vec<usize>>>);
pub struct VisitedPcsRaw(pub HashMap<ClassHash, Vec<usize>>);
impl VisitedPcs for VisitedPcsRaw {
type Pcs = Vec<Vec<usize>>;
type Pcs = Vec<usize>;

fn new() -> Self {
VisitedPcsRaw(HashMap::default())
}

fn insert(&mut self, class_hash: &ClassHash, pcs: &[usize]) {
self.0.entry(*class_hash).or_default().push(pcs.to_vec());
self.0.entry(*class_hash).or_default().extend(pcs.iter());
}

fn iter(&self) -> impl Iterator<Item = (&ClassHash, &Self::Pcs)> {
Expand All @@ -41,15 +39,13 @@ impl VisitedPcs for VisitedPcsRaw {

fn to_set(pcs: Self::Pcs) -> HashSet<usize> {
let mut set = HashSet::new();
pcs.into_iter().flatten().for_each(|p| {
pcs.into_iter().for_each(|p| {
set.insert(p);
});
set
}

fn add_visited_pcs(state: &mut dyn State, class_hash: &ClassHash, pcs: Self::Pcs) {
for pc in pcs {
state.add_visited_pcs(*class_hash, &pc);
}
state.add_visited_pcs(*class_hash, &pcs);
}
}

0 comments on commit 91a761b

Please sign in to comment.