Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add Serialized wrapper to make it harder to accidentally deserialize the wrong type #774

Merged
merged 5 commits into from
Oct 21, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 4 additions & 7 deletions crates/hyperqueue/src/client/commands/journal/output.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,7 @@
use crate::client::output::json::format_datetime;
use crate::server::event::payload::EventPayload;
use crate::server::event::{bincode_config, Event};
use crate::transfer::messages::JobDescription;
use bincode::Options;
use crate::server::event::Event;
use crate::transfer::messages::{JobDescription, SubmitRequest};
use serde_json::json;
use tako::worker::WorkerOverview;

Expand Down Expand Up @@ -103,14 +102,12 @@ fn format_payload(event: EventPayload) -> serde_json::Value {
closed_job,
serialized_desc,
} => {
let job_desc: JobDescription = bincode_config()
.deserialize(&serialized_desc)
.expect("Invalid job description data");
let submit: SubmitRequest = serialized_desc.deserialize().expect("Invalid submit data");
json!({
"type": "job-created",
"job": job_id,
"closed_job": closed_job,
"desc": JobInfoFormatter(&job_desc).to_json(),
"desc": JobInfoFormatter(&submit.job_desc).to_json(),
})
}
EventPayload::JobCompleted(job_id) => json!({
Expand Down
1 change: 1 addition & 0 deletions crates/hyperqueue/src/common/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ pub mod parser;
pub mod parser2;
pub mod placeholders;
pub mod rpc;
pub mod serialization;
pub mod serverdir;
pub mod setup;
pub mod utils;
70 changes: 70 additions & 0 deletions crates/hyperqueue/src/common/serialization.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
use bincode::Options;
use serde::de::DeserializeOwned;
use serde::{Deserialize, Serialize};
use std::fmt::{Debug, Formatter};
use std::marker::PhantomData;

/// Helper trait to configure serialization options via separate types.
pub trait SerializationConfig {
fn config() -> impl Options;
}

pub struct DefaultConfig;

impl SerializationConfig for DefaultConfig {
fn config() -> impl Options {
bincode::DefaultOptions::new()
}
}

pub struct TrailingAllowedConfig;

impl SerializationConfig for TrailingAllowedConfig {
fn config() -> impl Options {
bincode::DefaultOptions::new().allow_trailing_bytes()
}
}

/// Strongly typed wrapper over `<T>` serialized with Bincode.
#[derive(Serialize, Deserialize)]
pub struct Serialized<T, C = DefaultConfig> {
#[serde(with = "serde_bytes")]
data: Box<[u8]>,
_phantom: PhantomData<(T, C)>,
}

impl<T, C> Debug for Serialized<T, C> {
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
write!(
f,
"Serialized {} ({}) byte(s)",
std::any::type_name::<T>(),
self.data.len()
)
}
}

impl<T, C> Clone for Serialized<T, C> {
fn clone(&self) -> Self {
Self {
data: self.data.clone(),
_phantom: PhantomData,
}
}
}

impl<T: Serialize + DeserializeOwned, C: SerializationConfig> Serialized<T, C> {
pub fn new(value: &T) -> bincode::Result<Self> {
let result = C::config().serialize(value)?;
// Check that we're not reallocating needlessly in `into_boxed_slice`
debug_assert_eq!(result.capacity(), result.len());
Ok(Self {
data: result.into_boxed_slice(),
_phantom: Default::default(),
})
}

pub fn deserialize(&self) -> bincode::Result<T> {
C::config().deserialize(&self.data)
}
}
7 changes: 4 additions & 3 deletions crates/hyperqueue/src/server/event/log/read.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
use crate::common::serialization::SerializationConfig;
use crate::server::event::log::HQ_JOURNAL_HEADER;
use crate::server::event::{bincode_config, Event};
use crate::server::event::{Event, EventSerializationConfig};
use crate::HQ_VERSION;
use anyhow::{anyhow, bail};
use bincode::Options;
Expand Down Expand Up @@ -30,7 +31,7 @@ impl JournalReader {
if header != HQ_JOURNAL_HEADER {
bail!("Invalid journal format");
}
let hq_version: String = bincode_config()
let hq_version: String = EventSerializationConfig::config()
.deserialize_from(&mut file)
.map_err(|error| anyhow!("Cannot load HQ event log file header: {error:?}"))?;
if hq_version != HQ_VERSION {
Expand Down Expand Up @@ -61,7 +62,7 @@ impl Iterator for &mut JournalReader {
if self.position == self.size {
return None;
}
match bincode_config().deserialize_from(&mut self.source) {
match EventSerializationConfig::config().deserialize_from(&mut self.source) {
Ok(event) => Some(Ok(event)),
Err(error) => match error.deref() {
bincode::ErrorKind::Io(e)
Expand Down
7 changes: 4 additions & 3 deletions crates/hyperqueue/src/server/event/log/write.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
use crate::common::serialization::SerializationConfig;
use crate::server::event::log::HQ_JOURNAL_HEADER;
use crate::server::event::{bincode_config, Event};
use crate::server::event::{Event, EventSerializationConfig};
use crate::HQ_VERSION;
use bincode::Options;
use std::fs::{File, OpenOptions};
Expand Down Expand Up @@ -31,15 +32,15 @@ impl JournalWriter {

if position == 0 && file.stream_position()? == 0 {
file.write_all(HQ_JOURNAL_HEADER)?;
bincode_config().serialize_into(&mut file, HQ_VERSION)?;
EventSerializationConfig::config().serialize_into(&mut file, HQ_VERSION)?;
file.flush()?;
};

Ok(Self { file })
}

pub fn store(&mut self, event: Event) -> anyhow::Result<()> {
bincode_config().serialize_into(&mut self.file, &event)?;
EventSerializationConfig::config().serialize_into(&mut self.file, &event)?;
Ok(())
}

Expand Down
9 changes: 3 additions & 6 deletions crates/hyperqueue/src/server/event/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,22 +2,19 @@ pub mod log;
pub mod payload;
pub mod streamer;

use bincode::Options;
use crate::stream::StreamSerializationConfig;
use chrono::serde::ts_milliseconds;
use chrono::{DateTime, Utc};
use payload::EventPayload;
use serde::{Deserialize, Serialize};

pub type EventId = u32;

type EventSerializationConfig = StreamSerializationConfig;

#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct Event {
#[serde(with = "ts_milliseconds")]
pub time: DateTime<Utc>,
pub payload: EventPayload,
}

#[inline]
pub(crate) fn bincode_config() -> impl Options {
bincode::DefaultOptions::new().allow_trailing_bytes()
}
9 changes: 5 additions & 4 deletions crates/hyperqueue/src/server/event/payload.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
use crate::common::serialization::Serialized;
use crate::server::autoalloc::AllocationId;
use crate::server::autoalloc::QueueId;
use crate::transfer::messages::{AllocationQueueParams, JobDescription};
use crate::transfer::messages::{AllocationQueueParams, JobDescription, SubmitRequest};
use crate::JobId;
use crate::{JobTaskId, WorkerId};
use serde::{Deserialize, Serialize};
Expand All @@ -19,13 +20,13 @@ pub enum EventPayload {
WorkerOverviewReceived(WorkerOverview),
/// A Job was submitted by the user -- full information to reconstruct the job;
/// it will be only stored into file, not held in memory
/// Vec<u8> is serialized JobDescription; the main reason is avoiding duplication of JobDescription
/// Vec<u8> is serialized SubmitRequest; the main reason is avoiding duplication of SubmitRequest
/// (we serialize it before it is stripped down)
/// and a nice side effect is that Events can be deserialized without deserializing a potentially large submit data
Submit {
job_id: JobId,
closed_job: bool,
serialized_desc: Vec<u8>,
serialized_desc: Serialized<SubmitRequest>,
},
/// All tasks of the job have finished.
JobCompleted(JobId),
Expand All @@ -43,7 +44,7 @@ pub enum EventPayload {
job_id: JobId,
task_id: JobTaskId,
},
// Task that failed to execute
/// Task has failed to execute
TaskFailed {
job_id: JobId,
task_id: JobTaskId,
Expand Down
6 changes: 3 additions & 3 deletions crates/hyperqueue/src/server/event/streamer.rs
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
use crate::common::serialization::Serialized;
use crate::server::autoalloc::{AllocationId, QueueId};
use crate::server::event::log::{EventStreamMessage, EventStreamSender};
use crate::server::event::payload::EventPayload;
use crate::server::event::{bincode_config, Event};
use crate::server::event::Event;
use crate::transfer::messages::{AllocationQueueParams, JobDescription, SubmitRequest};
use crate::{JobId, JobTaskId, WorkerId};
use bincode::Options;
use chrono::Utc;
use smallvec::SmallVec;
use tako::gateway::LostWorkerReason;
Expand Down Expand Up @@ -64,7 +64,7 @@ impl EventStreamer {
self.send_event(EventPayload::Submit {
job_id,
closed_job: submit_request.job_id.is_none(),
serialized_desc: bincode_config().serialize(submit_request)?,
serialized_desc: Serialized::new(submit_request)?,
});
Ok(())
}
Expand Down
5 changes: 1 addition & 4 deletions crates/hyperqueue/src/server/restore.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
use crate::server::autoalloc::QueueId;
use crate::server::client::submit_job_desc;
use crate::server::event::bincode_config;
use crate::server::event::log::JournalReader;
use crate::server::event::payload::EventPayload;
use crate::server::job::{Job, JobTaskState, StartedTaskData};
Expand All @@ -10,7 +9,6 @@ use crate::transfer::messages::{
};
use crate::worker::start::RunningTaskContext;
use crate::{JobId, JobTaskId, Map};
use bincode::Options;
use std::path::Path;
use tako::gateway::NewTasksMessage;
use tako::{ItemId, WorkerId};
Expand Down Expand Up @@ -177,8 +175,7 @@ impl StateRestorer {
serialized_desc,
} => {
log::debug!("Replaying: JobTasksCreated {job_id}");
let submit_request: SubmitRequest =
bincode_config().deserialize(&serialized_desc)?;
let submit_request: SubmitRequest = serialized_desc.deserialize()?;
if closed_job {
let mut job = RestorerJob::new(submit_request.job_desc, false);
job.add_submit(submit_request.submit_desc);
Expand Down
4 changes: 4 additions & 0 deletions crates/hyperqueue/src/stream/mod.rs
Original file line number Diff line number Diff line change
@@ -1 +1,5 @@
use crate::common::serialization::TrailingAllowedConfig;

pub mod reader;

pub type StreamSerializationConfig = TrailingAllowedConfig;
7 changes: 4 additions & 3 deletions crates/hyperqueue/src/stream/reader/outputlog.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
use crate::client::commands::outputlog::{CatOpts, Channel, ExportOpts, ShowOpts};
use crate::common::arraydef::IntArray;
use crate::common::error::HqError;
use crate::server::event::bincode_config;
use crate::common::serialization::SerializationConfig;
use crate::stream::StreamSerializationConfig;
use crate::transfer::stream::{ChannelId, StreamChunkHeader};
use crate::worker::streamer::{StreamFileHeader, STREAM_FILE_HEADER, STREAM_FILE_SUFFIX};
use crate::{JobId, JobTaskId, Set};
Expand Down Expand Up @@ -156,7 +157,7 @@ impl OutputLog {
}

fn read_chunk(file: &mut BufReader<File>) -> crate::Result<Option<StreamChunkHeader>> {
match bincode_config().deserialize_from(file) {
match StreamSerializationConfig::config().deserialize_from(file) {
Ok(event) => Ok(Some(event)),
Err(error) => match error.deref() {
bincode::ErrorKind::Io(e)
Expand Down Expand Up @@ -217,7 +218,7 @@ impl OutputLog {
if header != STREAM_FILE_HEADER {
anyhow::bail!("Invalid file format");
}
Ok(bincode_config().deserialize_from(file)?)
Ok(StreamSerializationConfig::config().deserialize_from(file)?)
}

pub fn summary(&self) -> Summary {
Expand Down
7 changes: 4 additions & 3 deletions crates/hyperqueue/src/worker/streamer.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
use crate::common::error::HqError;
use crate::server::event::bincode_config;
use crate::common::serialization::SerializationConfig;
use crate::stream::StreamSerializationConfig;
use crate::transfer::stream::{ChannelId, StreamChunkHeader};
use crate::WrappedRcRefCell;
use crate::{JobId, JobTaskId, Map};
Expand Down Expand Up @@ -186,15 +187,15 @@ async fn stream_writer(
server_uid: Cow::Borrowed(&streamer.server_uid),
worker_id: streamer.worker_id,
};
bincode_config().serialize_into(&mut buffer, &header)?;
StreamSerializationConfig::config().serialize_into(&mut buffer, &header)?;
};
file.write_all(&buffer).await?;
while let Some(message) = receiver.recv().await {
match message {
StreamerMessage::Write { header, data } => {
log::debug!("Waiting data chunk into stream file");
buffer.clear();
bincode_config().serialize_into(&mut buffer, &header)?;
StreamSerializationConfig::config().serialize_into(&mut buffer, &header)?;
file.write_all(&buffer).await?;
if !data.is_empty() {
file.write_all(&data).await?
Expand Down
Loading