diff --git a/src/uu/tsort/src/tsort.rs b/src/uu/tsort/src/tsort.rs index 85380bf403b..838ce477848 100644 --- a/src/uu/tsort/src/tsort.rs +++ b/src/uu/tsort/src/tsort.rs @@ -227,8 +227,8 @@ impl<'input> Graph<'input> { for node in &cycle { show!(TsortError::LoopNode((*node).to_string())); } - let u = cycle[0]; - let v = cycle[1]; + let u = *cycle.last().expect("cycle must be non-empty"); + let v = cycle[0]; self.remove_edge(u, v); if self.indegree(v).unwrap() == 0 { frontier.push_back(v); @@ -236,18 +236,20 @@ impl<'input> Graph<'input> { } fn detect_cycle(&self) -> Vec<&'input str> { - // Sort the nodes just to make this function deterministic. - let mut nodes: Vec<_> = self.nodes.keys().collect(); + let mut nodes: Vec<_> = self.nodes.keys().copied().collect(); nodes.sort_unstable(); - let mut visited = HashSet::new(); - let mut stack = Vec::with_capacity(self.nodes.len()); + let mut visited: HashSet<&'input str> = HashSet::new(); + let mut stack: Vec<&'input str> = Vec::with_capacity(self.nodes.len()); for node in nodes { - if !visited.contains(node) && self.dfs(node, &mut visited, &mut stack) { - return stack; + if visited.contains(&node) { + continue; + } + if let Some(cycle) = self.dfs(node, &mut visited, &mut stack) { + return cycle; } } - unreachable!(); + unreachable!("detect_cycle called only when a cycle exists"); } fn dfs( @@ -255,12 +257,12 @@ impl<'input> Graph<'input> { node: &'input str, visited: &mut HashSet<&'input str>, stack: &mut Vec<&'input str>, - ) -> bool { - if stack.contains(&node) { - return true; + ) -> Option> { + if let Some(pos) = stack.iter().position(|&n| n == node) { + return Some(stack[pos..].to_vec()); } if visited.contains(&node) { - return false; + return None; } visited.insert(node); @@ -268,13 +270,13 @@ impl<'input> Graph<'input> { if let Some(successor_names) = self.nodes.get(node).map(|n| &n.successor_names) { for &successor in successor_names { - if self.dfs(successor, visited, stack) { - return true; + if let Some(cycle) = self.dfs(successor, visited, stack) { + return Some(cycle); } } } stack.pop(); - false + None } } diff --git a/tests/by-util/test_tsort.rs b/tests/by-util/test_tsort.rs index eb1a8630d31..0ae86cdc2ca 100644 --- a/tests/by-util/test_tsort.rs +++ b/tests/by-util/test_tsort.rs @@ -103,7 +103,7 @@ fn test_cycle() { new_ucmd!() .pipe_in("a b b c c d c b") .fails_with_code(1) - .stdout_is("a\nc\nd\nb\n") + .stdout_is("a\nb\nc\nd\n") .stderr_is("tsort: -: input contains a loop:\ntsort: b\ntsort: c\n"); } @@ -119,6 +119,6 @@ fn test_two_cycles() { new_ucmd!() .pipe_in("a b b c c b b d d b") .fails_with_code(1) - .stdout_is("a\nc\nd\nb\n") + .stdout_is("a\nb\nc\nd\n") .stderr_is("tsort: -: input contains a loop:\ntsort: b\ntsort: c\ntsort: -: input contains a loop:\ntsort: b\ntsort: d\n"); }