Skip to content
This repository has been archived by the owner on Jun 4, 2023. It is now read-only.

Commit

Permalink
tf: add operation for unloading a TF session
Browse files Browse the repository at this point in the history
Also, rename current TF operations to reflect their actual actions.

Signed-off-by: Babis Chalios <[email protected]>
  • Loading branch information
bchalios committed Sep 13, 2021
1 parent 1e03308 commit 41e2ce0
Show file tree
Hide file tree
Showing 5 changed files with 63 additions and 34 deletions.
4 changes: 2 additions & 2 deletions Cargo.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[package]
name = "vaccel"
version = "0.1.0"
version = "0.2.0"
authors = ["Babis Chalios <[email protected]>"]
edition = "2018"
build = "build.rs"
Expand All @@ -11,7 +11,7 @@ name = "vaccel"
path = "src/lib.rs"

[dependencies]
protocols = { git = "https://github.com/cloudkernels/vaccel-grpc", tag = "v0.1.0" }
protocols = { git = "https://github.com/cloudkernels/vaccel-grpc", tag = "v0.2.0" }
protobuf = "=2.14.0"
memchr = "2.3.3"
libc = "0.2.68"
Expand Down
39 changes: 22 additions & 17 deletions examples/tf_inference/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ fn main() -> utilities::Result<()> {
info!("Registered model {} with session {}", model.id(), sess.id());

// Load model graph
if let Err(err) = model.load_graph(&mut sess) {
if let Err(err) = model.session_load(&mut sess) {
error!("Could not load graph for model {}: {}", model.id(), err);

info!("Destroying session {}", sess.id());
Expand All @@ -45,24 +45,29 @@ fn main() -> utilities::Result<()> {
sess_args.add_input(&in_node, &in_tensor);
sess_args.request_output(&out_node);

let result = model.inference(&mut sess, &mut sess_args)?;

let out = result.get_output::<f32>(0).unwrap();

println!("Success!");
println!(
"Output tensor => type:{:?} nr_dims:{}",
out.data_type(),
out.nr_dims()
);
for i in 0..out.nr_dims() {
println!("dim[{}]: {}", i, out.dim(i as usize).unwrap());
}
println!("Result Tensor :");
for i in 0..10 {
println!("{:.6}", out[i]);
let result = model.session_run(&mut sess, &mut sess_args)?;

match result.get_output::<f32>(0) {
Ok(out) => {
println!("Success!");
println!(
"Output tensor => type:{:?} nr_dims:{}",
out.data_type(),
out.nr_dims()
);
for i in 0..out.nr_dims() {
println!("dim[{}]: {}", i, out.dim(i as usize).unwrap());
}
println!("Result Tensor :");
for i in 0..10 {
println!("{:.6}", out[i]);
}
}
Err(err) => println!("Inference failed: '{}'", err),
}

model.session_delete(&mut sess)?;

sess.close()?;

Ok(())
Expand Down
20 changes: 14 additions & 6 deletions src/ffi.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5117,12 +5117,13 @@ pub const VACCEL_TF_MODEL_NEW: vaccel_op_type = 6;
pub const VACCEL_TF_MODEL_DESTROY: vaccel_op_type = 7;
pub const VACCEL_TF_MODEL_REGISTER: vaccel_op_type = 8;
pub const VACCEL_TF_MODEL_UNREGISTER: vaccel_op_type = 9;
pub const VACCEL_TF_MODEL_LOAD_GRAPH: vaccel_op_type = 10;
pub const VACCEL_TF_MODEL_RUN_GRAPH: vaccel_op_type = 11;
pub const VACCEL_FUNCTIONS_NR: vaccel_op_type = 12;
pub const VACCEL_TF_SESSION_LOAD: vaccel_op_type = 10;
pub const VACCEL_TF_SESSION_RUN: vaccel_op_type = 11;
pub const VACCEL_TF_SESSION_DELETE: vaccel_op_type = 12;
pub const VACCEL_FUNCTIONS_NR: vaccel_op_type = 13;
pub type vaccel_op_type = u32;
extern "C" {
pub static mut vaccel_op_name: [*const ::std::os::raw::c_char; 12usize];
pub static mut vaccel_op_name: [*const ::std::os::raw::c_char; 13usize];
}
extern "C" {
pub fn vaccel_sgemm(
Expand Down Expand Up @@ -5541,14 +5542,14 @@ impl Default for vaccel_tf_status {
}
}
extern "C" {
pub fn vaccel_tf_model_load_graph(
pub fn vaccel_tf_session_load(
session: *mut vaccel_session,
model: *mut vaccel_tf_saved_model,
status: *mut vaccel_tf_status,
) -> ::std::os::raw::c_int;
}
extern "C" {
pub fn vaccel_tf_model_run(
pub fn vaccel_tf_session_run(
session: *mut vaccel_session,
model: *const vaccel_tf_saved_model,
run_options: *const vaccel_tf_buffer,
Expand All @@ -5561,6 +5562,13 @@ extern "C" {
status: *mut vaccel_tf_status,
) -> ::std::os::raw::c_int;
}
extern "C" {
pub fn vaccel_tf_session_delete(
session: *mut vaccel_session,
model: *mut vaccel_tf_saved_model,
status: *mut vaccel_tf_status,
) -> ::std::os::raw::c_int;
}
#[repr(C)]
#[derive(Debug, Copy, Clone)]
pub struct vaccel_tf_saved_model {
Expand Down
32 changes: 24 additions & 8 deletions src/ops/inference.rs
Original file line number Diff line number Diff line change
Expand Up @@ -131,7 +131,7 @@ impl InferenceResult {
}

impl SavedModel {
/// Load a TensorFlow graph from a model
/// Load a TensorFlow session from a SavedModel
///
/// The TensorFlow model must have been created and registered to
/// a session. The operation will load the graph and keep the graph
Expand All @@ -142,32 +142,32 @@ impl SavedModel {
/// * `session` - The session in the context of which we perform the operation. The model needs
/// to be registered with this session.
///
pub fn load_graph(&mut self, sess: &mut Session) -> Result<tf::Status> {
pub fn session_load(&mut self, sess: &mut Session) -> Result<tf::Status> {
let mut status = tf::Status::new();

match unsafe {
ffi::vaccel_tf_model_load_graph(sess.inner_mut(), self.inner_mut(), status.inner_mut())
ffi::vaccel_tf_session_load(sess.inner_mut(), self.inner_mut(), status.inner_mut())
as u32
} {
ffi::VACCEL_OK => Ok(status),
err => Err(Error::Runtime(err)),
}
}

/// Run inference on a TensorFlow model
/// Run a TensorFlow session
///
/// This will run inference using a TensorFlow graph that has been previously loaded
/// using `vaccel_tf_model::load_graph`.
/// This will run using a TensorFlow session that has been previously loaded
/// using `vaccel_tf_model::load_session`.
///
pub fn inference(
pub fn session_run(
&mut self,
sess: &mut Session,
args: &mut InferenceArgs,
) -> Result<InferenceResult> {
let mut result = InferenceResult::new(args.out_nodes.len());

match unsafe {
ffi::vaccel_tf_model_run(
ffi::vaccel_tf_session_run(
sess.inner_mut(),
self.inner_mut(),
args.run_options,
Expand All @@ -184,4 +184,20 @@ impl SavedModel {
err => Err(Error::Runtime(err)),
}
}

/// Delete a TensorFlow session
///
/// This will unload a TensorFlow session that was previously loaded in memory
/// using `vaccel_tf_model::load_session`.
pub fn session_delete(&mut self, sess: &mut Session) -> Result<()> {
let mut status = tf::Status::new();

match unsafe {
ffi::vaccel_tf_session_delete(sess.inner_mut(), self.inner_mut(), status.inner_mut())
as u32
} {
ffi::VACCEL_OK => Ok(()),
err => Err(Error::Runtime(err)),
}
}
}
2 changes: 1 addition & 1 deletion vaccelrt

0 comments on commit 41e2ce0

Please sign in to comment.