diff --git a/crates/pyhq/src/client/job.rs b/crates/pyhq/src/client/job.rs index d2c27e73a..1f08eb9c1 100644 --- a/crates/pyhq/src/client/job.rs +++ b/crates/pyhq/src/client/job.rs @@ -14,7 +14,8 @@ use hyperqueue::transfer::messages::{ TaskKind, TaskKindProgram, TaskSelector, TaskStatusSelector, TaskWithDependencies, ToClientMessage, }; -use hyperqueue::{rpc_call, tako, JobTaskCount, JobTaskId, Set}; +use hyperqueue::{rpc_call, tako, JobTaskCount, Set}; +use pyo3::exceptions::PyException; use pyo3::types::PyTuple; use pyo3::{IntoPy, PyAny, PyResult, Python}; use std::collections::HashMap; @@ -101,6 +102,12 @@ pub fn submit_job_impl(py: Python, ctx: ClientContextPtr, job: PyJobDescription) // This code is unreachable as long Python cannot submit into open jobs unreachable!() } + SubmitResponse::NonUniqueTaskId(_) => Err(PyException::new_err( + "Non unique IDs in submitted task graph", + )), + SubmitResponse::InvalidDependencies(_) => { + Err(PyException::new_err("Invalid dependency id")) + } } }) } @@ -344,11 +351,10 @@ pub fn get_failed_tasks_impl( if let Some(job_detail) = job_detail { let mut task_path_map = resolve_task_paths(&job_detail, &response.server_uid); let mut tasks = HashMap::with_capacity(job_detail.tasks.len()); - for task in job_detail.tasks { + for (task_id, task) in job_detail.tasks { match task.state { JobTaskState::Failed { error, .. } => { - let id = task.task_id.as_num(); - let path_ctx = task_path_map.remove(&JobTaskId::from(id)).flatten(); + let path_ctx = task_path_map.remove(&task_id).flatten(); let (stdout, stderr, cwd) = match path_ctx { Some(paths) => ( stdio_to_string(paths.stdout), @@ -359,7 +365,7 @@ pub fn get_failed_tasks_impl( }; tasks.insert( - id, + task_id.as_num(), FailedTaskContext { stdout, stderr,