diff --git a/crates/hyperqueue/src/server/job.rs b/crates/hyperqueue/src/server/job.rs index 59d703fc2..906575bc1 100644 --- a/crates/hyperqueue/src/server/job.rs +++ b/crates/hyperqueue/src/server/job.rs @@ -1,9 +1,10 @@ use serde::{Deserialize, Serialize}; +use crate::client::status::get_task_status; use crate::server::Senders; use crate::transfer::messages::{ JobDescription, JobDetail, JobInfo, JobSubmitDescription, JobTaskDescription, TaskIdSelector, - TaskSelector, + TaskSelector, TaskStatusSelector, }; use crate::worker::start::RunningTaskContext; use crate::{make_tako_id, JobId, JobTaskCount, JobTaskId, Map, TakoTaskId, WorkerId}; @@ -160,20 +161,40 @@ impl Job { pub fn make_job_detail(&self, task_selector: Option<&TaskSelector>) -> JobDetail { let (mut tasks, tasks_not_found) = if let Some(selector) = task_selector { - match &selector.id_selector { - TaskIdSelector::All => ( + match (&selector.id_selector, &selector.status_selector) { + (TaskIdSelector::All, TaskStatusSelector::All) => ( self.tasks .iter() .map(|(task_id, info)| (*task_id, info.clone())) .collect(), Vec::new(), ), - TaskIdSelector::Specific(ids) => { + (TaskIdSelector::All, TaskStatusSelector::Specific(status)) => ( + self.tasks + .iter() + .filter_map(|(task_id, info)| { + if status.contains(&get_task_status(&info.state)) { + Some((*task_id, info.clone())) + } else { + None + } + }) + .collect(), + Vec::new(), + ), + (TaskIdSelector::Specific(ids), status) => { let mut not_found = Vec::new(); let mut tasks = Vec::with_capacity(ids.id_count() as usize); for task_id in ids.iter() { - if let Some(task) = self.tasks.get(&JobTaskId::new(task_id)) { - tasks.push((JobTaskId::new(task_id), task.clone())); + if let Some(info) = self.tasks.get(&JobTaskId::new(task_id)) { + if match status { + TaskStatusSelector::All => true, + TaskStatusSelector::Specific(s) => { + s.contains(&get_task_status(&info.state)) + } + } { + tasks.push((JobTaskId::new(task_id), info.clone())); + } } else { not_found.push(JobTaskId::new(task_id)); } diff --git a/tests/job/test_job_cat.py b/tests/job/test_job_cat.py index 8ba6c018e..02aa2d06c 100644 --- a/tests/job/test_job_cat.py +++ b/tests/job/test_job_cat.py @@ -146,7 +146,6 @@ def test_job_cat_status(hq_env: HqEnv): ] ) wait_for_job_state(hq_env, 1, "FAILED") - output = hq_env.command(["job", "cat", "--task-status=finished", "1", "stdout", "--print-task-header"]) assert ( output