Skip to content

Commit

Permalink
add the parser for conditional reqs
Browse files Browse the repository at this point in the history
  • Loading branch information
prsabahrami committed Jan 27, 2025
1 parent 3ddb16c commit 07e13e3
Show file tree
Hide file tree
Showing 2 changed files with 80 additions and 52 deletions.
13 changes: 11 additions & 2 deletions src/requirement.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,9 @@ pub struct ConditionalRequirement {

impl ConditionalRequirement {
/// Creates a new conditional requirement.
pub fn new(condition: VersionSetId, requirement: Requirement) -> Self {
pub fn new(condition: Option<VersionSetId>, requirement: Requirement) -> Self {
Self {
condition: Some(condition),
condition,
requirement,
}
}
Expand Down Expand Up @@ -62,6 +62,15 @@ impl From<VersionSetId> for ConditionalRequirement {
}
}

impl From<VersionSetUnionId> for ConditionalRequirement {
fn from(value: VersionSetUnionId) -> Self {
Self {
condition: None,
requirement: value.into(),
}
}
}

/// Specifies the dependency of a solvable on a set of version sets.
#[derive(Clone, Copy, Debug, Hash, Eq, PartialEq, Ord, PartialOrd)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
Expand Down
119 changes: 69 additions & 50 deletions tests/solver.rs
Original file line number Diff line number Diff line change
Expand Up @@ -113,11 +113,16 @@ impl FromStr for Pack {
struct Spec {
name: String,
versions: Ranges<Pack>,
condition: Option<Box<Spec>>,
}

impl Spec {
pub fn new(name: String, versions: Ranges<Pack>) -> Self {
Self { name, versions }
pub fn new(name: String, versions: Ranges<Pack>, condition: Option<Box<Spec>>) -> Self {
Self {
name,
versions,
condition,
}
}

pub fn parse_union(
Expand All @@ -131,11 +136,23 @@ impl FromStr for Spec {
type Err = ();

fn from_str(s: &str) -> Result<Self, Self::Err> {
let split = s.split(' ').collect::<Vec<_>>();
let name = split
.first()
.expect("spec does not have a name")
.to_string();
let split = s.split(';').collect::<Vec<_>>(); // c 1; if b 1..2

if split.len() == 1 {
// c 1
let split = s.split(' ').collect::<Vec<_>>();
let name = split
.first()
.expect("spec does not have a name")
.to_string();
let versions = version_range(split.get(1));
return Ok(Spec::new(name, versions, None));
}

let binding = split.get(1).unwrap().replace("if", "");
let condition = Spec::parse_union(&binding).next().unwrap().unwrap();

let spec = Spec::from_str(split.first().unwrap()).unwrap();

fn version_range(s: Option<&&str>) -> Ranges<Pack> {
if let Some(s) = s {
Expand All @@ -154,9 +171,11 @@ impl FromStr for Spec {
}
}

let versions = version_range(split.get(1));

Ok(Spec::new(name, versions))
Ok(Spec::new(
spec.name,
spec.versions,
Some(Box::new(condition)),
))
}
}

Expand Down Expand Up @@ -500,24 +519,47 @@ impl DependencyProvider for BundleBoxProvider {
.intern_version_set(first_name, first.versions.clone());

let requirement = if remaining_req_specs.len() == 0 {
first_version_set.into()
if let Some(condition) = &first.condition {
ConditionalRequirement::new(
Some(self.intern_version_set(condition)),
first_version_set.into(),
)
} else {
first_version_set.into()
}
} else {
let other_version_sets = remaining_req_specs.map(|spec| {
self.pool.intern_version_set(
// Check if all specs have the same condition
let common_condition = first.condition.as_ref().map(|c| self.intern_version_set(c));

// Collect version sets for union
let mut version_sets = vec![first_version_set];
for spec in remaining_req_specs {
// Verify condition matches
if spec.condition.as_ref().map(|c| self.intern_version_set(c))
!= common_condition
{
panic!("All specs in a union must have the same condition");
}

version_sets.push(self.pool.intern_version_set(
self.pool.intern_package_name(&spec.name),
spec.versions.clone(),
)
});

self.pool
.intern_version_set_union(first_version_set, other_version_sets)
.into()
));
}

// Create union and wrap in conditional if needed
let union = self
.pool
.intern_version_set_union(version_sets[0], version_sets.into_iter().skip(1));

if let Some(condition) = common_condition {
ConditionalRequirement::new(Some(condition), union.into())
} else {
union.into()
}
};

result.requirements.push(ConditionalRequirement {
requirement,
condition: None,
});
result.requirements.push(requirement);
}

for req in &deps.constrains {
Expand Down Expand Up @@ -1440,18 +1482,8 @@ fn test_conditional_requirements() {
provider.add_package("b", 1.into(), &[], &[]); // Simple package b
provider.add_package("c", 1.into(), &[], &[]); // Simple package c

// Create conditional requirement: if b=1 is installed, require c
let b_spec = Spec::parse_union("b 1").next().unwrap().unwrap();
let c_spec = Spec::parse_union("c 1").next().unwrap().unwrap();

let b_version_set = provider.intern_version_set(&b_spec);
let c_version_set = provider.intern_version_set(&c_spec);

let conditional_req = ConditionalRequirement::new(b_version_set, c_version_set.into());

// Create problem with both regular and conditional requirements
let mut requirements = provider.requirements(&["a"]);
requirements.push(conditional_req);
let requirements = provider.requirements(&["a", "c 1; if b 1..2"]);

let mut solver = Solver::new(provider);
let problem = Problem::new().requirements(requirements);
Expand All @@ -1470,31 +1502,18 @@ fn test_conditional_requirements_not_met() {
provider.add_package("b", 1.into(), &[], &[]); // Add b=1 as a candidate
provider.add_package("b", 2.into(), &[], &[]); // Different version of b
provider.add_package("c", 1.into(), &[], &[]); // Simple package c
provider.add_package("a", 1.into(), &["b 2"], &[]); // a depends on b

// Create conditional requirement: if b=1 is installed, require c
let b_spec = Spec::parse_union("b 1").next().unwrap().unwrap();
let c_spec = Spec::parse_union("c 1").next().unwrap().unwrap();

let b_version_set = provider.intern_version_set(&b_spec);
let c_version_set = provider.intern_version_set(&c_spec);

let conditional_req = ConditionalRequirement::new(b_version_set, c_version_set.into());

// Create problem with just the conditional requirement
let mut requirements = vec![conditional_req];

// Add a requirement for b=2 to ensure we get a version that doesn't trigger the condition
let b2_spec = Spec::parse_union("b 2").next().unwrap().unwrap();
let b2_version_set = provider.intern_version_set(&b2_spec);
requirements.push(b2_version_set.into());
let requirements = provider.requirements(&["a", "c 1; if b 1"]);

let mut solver = Solver::new(provider);
let problem = Problem::new().requirements(requirements);
let solved = solver.solve(problem).unwrap();
let result = transaction_to_string(solver.provider(), &solved);
// Since b=1 is not installed (b=2 is), c should not be installed
insta::assert_snapshot!(result, @r###"
b=2
a=1
"###);
}

Expand Down

0 comments on commit 07e13e3

Please sign in to comment.