Skip to content

Commit

Permalink
#5 implements first set of get methods for client
Browse files Browse the repository at this point in the history
  • Loading branch information
gavinmead committed Dec 15, 2024
1 parent 8a16dfd commit 286ef03
Show file tree
Hide file tree
Showing 3 changed files with 199 additions and 8 deletions.
1 change: 1 addition & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -13,3 +13,4 @@ log = "0.4.22"
[dev-dependencies]
rstest = "0.23.0"
testcontainers = { version = "=0.23.1", features = ["blocking"] }
rand = "0.8.5"
107 changes: 100 additions & 7 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,9 @@
//! ```
//!
use crate::MLFlowError::{ExperimentBuilderError, UnknownError};
use log::debug;
use reqwest::blocking::Client;
use crate::MLFlowError::{ExperimentBuilderError, ExperimentNotFound, UnknownError};
use reqwest::blocking::{Client, Response};
use reqwest::StatusCode;
use serde::{Deserialize, Serialize};

pub type MLFlowResult<T> = Result<T, MLFlowError>;
Expand All @@ -22,6 +22,9 @@ pub enum MLFlowError {
#[error("ExperimentBuilderError: {0}")]
ExperimentBuilderError(String),

#[error("{0}")]
ExperimentNotFound(String),

#[error("ClientError: {0}")]
ClientError(String),

Expand Down Expand Up @@ -52,8 +55,17 @@ struct CreateExperimentResponse {
experiment_id: String,
}

#[derive(Clone, Debug, Serialize, Deserialize)]
struct GetExperimentResponse {
experiment: Experiment,
}

trait MLFlowClient {
fn create_experiment(&self, experiment: Experiment) -> MLFlowResult<CreateExperimentResponse>;

fn get_experiment_by_id(&self, id: impl AsRef<str>) -> MLFlowResult<GetExperimentResponse>;

fn get_experiment_by_name(&self, name: impl AsRef<str>) -> MLFlowResult<GetExperimentResponse>;
}

#[derive(Clone, Debug, Default)]
Expand All @@ -71,6 +83,35 @@ impl MLFLowRestClient {
host: host.as_ref().to_string(),
}
}

fn _process_get(
&self,
result: Result<Response, reqwest::Error>,
) -> MLFlowResult<GetExperimentResponse> {
match result {
Ok(r) => {
if r.status().is_success() {
let e = r.json::<GetExperimentResponse>();
match e {
Ok(result) => Ok(result),
Err(e) => {
println!("{}", e);
Err(UnknownError(e.to_string()))
}
}
} else if r.status() == StatusCode::NOT_FOUND {
Err(ExperimentNotFound("experiment was not found".to_string()))
} else {
println!("experiment not found server message: {}", r.status());
Err(UnknownError("error finding experiment".to_string()))
}
}
Err(e) => {
println!("{}", e);
Err(UnknownError(e.to_string()))
}
}
}
}

impl MLFlowClient for MLFLowRestClient {
Expand All @@ -92,11 +133,34 @@ impl MLFlowClient for MLFLowRestClient {
}
}
Err(result) => {
debug!("{}", result.to_string());
println!("{}", result);
Err(UnknownError(result.to_string()))
}
}
}

fn get_experiment_by_id(&self, id: impl AsRef<str>) -> MLFlowResult<GetExperimentResponse> {
let url = format!("{}{}", &self.host, "/api/2.0/mlflow/experiments/get");
let result = self
.client
.get(url)
.query(&[("experiment_id", id.as_ref())])
.send();
self._process_get(result)
}

fn get_experiment_by_name(&self, name: impl AsRef<str>) -> MLFlowResult<GetExperimentResponse> {
let url = format!(
"{}{}",
&self.host, "/api/2.0/mlflow/experiments/get-by-name"
);
let result = self
.client
.get(url)
.query(&[("experiment_name", name.as_ref())])
.send();
self._process_get(result)
}
}

pub trait ExperimentIdentifier {
Expand Down Expand Up @@ -182,9 +246,38 @@ impl ExperimentBuilder {
}
}

#[derive(Serialize, Deserialize, Debug, Clone)]
struct Version {
version: String,
pub enum ExperimentIdentifierType {
ById(String),
ByName(String),
}

#[derive(Default)]
pub struct ExperimentLoader {
client: Option<MLFLowRestClient>,
}

impl ExperimentLoader {
pub fn with_client(mut self, client: MLFLowRestClient) -> Self {
self.client = Some(client);
self
}

pub fn load(self, experiment_identifier: ExperimentIdentifierType) -> MLFlowResult<Experiment> {
let client: MLFLowRestClient = self
.client
.unwrap_or_else(|| MLFLowRestClient::new("http://localhost:5000"));

match experiment_identifier {
ExperimentIdentifierType::ById(id) => match client.get_experiment_by_id(id) {
Ok(resp) => Ok(resp.experiment),
Err(e) => Err(UnknownError(e.to_string())),
},
ExperimentIdentifierType::ByName(name) => match client.get_experiment_by_name(name) {
Ok(resp) => Ok(resp.experiment),
Err(e) => Err(e),
},
}
}
}

#[cfg(test)]
Expand Down
99 changes: 98 additions & 1 deletion tests/integration_test.rs
Original file line number Diff line number Diff line change
@@ -1,10 +1,15 @@
use rand::distributions::Alphanumeric;
use rand::{thread_rng, Rng};
use rstest::*;
use testcontainers::core::ContainerPort::Tcp;
use testcontainers::{
core::WaitFor, runners::SyncRunner, ContainerRequest, GenericImage, ImageExt,
};

use mlflow_rs::{ExperimentBuilder, ExperimentIdentifier, MLFLowRestClient};
use mlflow_rs::ExperimentIdentifierType::{ById, ByName};
use mlflow_rs::{
ExperimentBuilder, ExperimentIdentifier, ExperimentLoader, MLFLowRestClient, MLFlowError,
};

const MLFLOW_VERSION: &str = "2.18.0";
const MLFLOW_DOCKER_IMAGE: &str = "ghcr.io/mlflow/mlflow";
Expand All @@ -25,6 +30,17 @@ fn mlflow_server_container() -> ContainerRequest<GenericImage> {
container
}

#[fixture]
fn experiment_name() -> String {
let rand_string: String = thread_rng()
.sample_iter(&Alphanumeric)
.take(5)
.map(char::from)
.collect();

format!("experiment-{}", rand_string)
}

#[rstest]
fn test_fixture_ok(mlflow_server_container: ContainerRequest<GenericImage>) {
let container = mlflow_server_container.start().unwrap();
Expand Down Expand Up @@ -57,3 +73,84 @@ fn test_create_experiment(mlflow_server_container: ContainerRequest<GenericImage
let id = experiment.clone().unwrap().experiment_id().unwrap();
println!("{}", id);
}

#[rstest]
fn test_create_and_get_by_id_ok(
mlflow_server_container: ContainerRequest<GenericImage>,
experiment_name: String,
) {
let container = mlflow_server_container.start().unwrap();
let host_port = container.get_host_port_ipv4(Tcp(5000)).unwrap();
let url = format!("http://localhost:{}", host_port);
println!("{}", experiment_name.clone());

let client = MLFLowRestClient::new(url);
let experiment = ExperimentBuilder::new(experiment_name)
.unwrap()
.with_tag(("tag1", "value1"))
.with_rest_client(client.clone())
.build();
assert!(experiment.clone().is_ok());

let id = experiment.clone().unwrap().experiment_id().unwrap();
println!("{}", id.clone());

let found_experiment = ExperimentLoader::default()
.with_client(client.clone())
.load(ById(id.clone()));
assert!(found_experiment.clone().is_ok());

let found_experiment_id = found_experiment.unwrap().experiment_id().unwrap();

assert_eq!(found_experiment_id, id);
}

#[rstest]
fn test_create_and_get_by_name_ok(
mlflow_server_container: ContainerRequest<GenericImage>,
experiment_name: String,
) {
let container = mlflow_server_container.start().unwrap();
let host_port = container.get_host_port_ipv4(Tcp(5000)).unwrap();
let url = format!("http://localhost:{}", host_port);
println!("{}", experiment_name.clone());

let client = MLFLowRestClient::new(url);
let experiment = ExperimentBuilder::new(experiment_name.clone())
.unwrap()
.with_tag(("tag1", "value1"))
.with_rest_client(client.clone())
.build();
assert!(experiment.clone().is_ok());

let id = experiment.clone().unwrap().experiment_id().unwrap();
println!("{}", id.clone());

let found_experiment = ExperimentLoader::default()
.with_client(client.clone())
.load(ByName(experiment_name.clone()));
assert!(found_experiment.clone().is_ok());

let found_experiment_id = found_experiment.unwrap().experiment_id().unwrap();
assert_eq!(found_experiment_id, id);
}

#[rstest]
fn test_get_by_name_not_found(
mlflow_server_container: ContainerRequest<GenericImage>,
experiment_name: String,
) {
let container = mlflow_server_container.start().unwrap();
let host_port = container.get_host_port_ipv4(Tcp(5000)).unwrap();
let url = format!("http://localhost:{}", host_port);
println!("{}", experiment_name.clone());

let client = MLFLowRestClient::new(url);
let found_experiment = ExperimentLoader::default()
.with_client(client.clone())
.load(ByName(experiment_name.clone()));
assert!(matches!(
found_experiment.err().unwrap(),
MLFlowError::ExperimentNotFound(s) if s == "experiment was not found"
));
}

0 comments on commit 286ef03

Please sign in to comment.