Skip to content

Commit

Permalink
Merge pull request #47 from contramundum53/batch-eval
Browse files Browse the repository at this point in the history
Add batch-evaluate mode
  • Loading branch information
HideakiImamura authored Jun 25, 2024
2 parents 83ae97e + 80e3e73 commit bd67e50
Show file tree
Hide file tree
Showing 3 changed files with 99 additions and 0 deletions.
91 changes: 91 additions & 0 deletions src/batch_eval.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,91 @@
//! `kurobako run` command.
use crate::problem::KurobakoProblemRecipe;
use crate::solver::KurobakoSolverRecipe;
use kurobako_core::json;
use kurobako_core::problem::ProblemRecipe as _;
use kurobako_core::problem::{Evaluator as _, Problem as _, ProblemFactory as _};
use kurobako_core::registry::FactoryRegistry;
use kurobako_core::rng::ArcRng;
use kurobako_core::trial::{Params, Values};
use kurobako_core::{ErrorKind, Result};
use serde::Deserialize;
use serde::Serialize;
use std::io;
use structopt::StructOpt;
use serde_json::Error;
use std::io::Write;

/// Options of the `kurobako batch-evaluate` command.
#[derive(Debug, Clone, StructOpt)]
#[structopt(rename_all = "kebab-case")]
pub struct BatchEvaluateOpt {
/// Evaluation target problem.
#[structopt(long, parse(try_from_str = json::parse_json))]
pub problem: KurobakoProblemRecipe,

/// Random seed.
#[structopt(long)]
pub seed: Option<u64>,
}

#[derive(Debug, Clone, Deserialize)]
struct EvalCall {
params: Params,
step: Option<u64>
}

#[derive(Debug, Clone, Serialize)]
struct EvalReply {
values: Values,
}

impl BatchEvaluateOpt {
/// Evaluates the given parameters.
pub fn run(&self) -> Result<()> {
let random_seed = self.seed.unwrap_or_else(rand::random);
let rng = ArcRng::new(random_seed);
let registry = FactoryRegistry::new::<KurobakoProblemRecipe, KurobakoSolverRecipe>();
let problem_factory = track!(self.problem.create_factory(&registry))?;
let problem_spec = track!(problem_factory.specification())?;

let problem = track!(problem_factory.create_problem(rng))?;
let mut writer = io::stdout();
loop{
let mut line = String::new();
let n = io::stdin().read_line(&mut line)?;
if n == 0 {
break;
}
let EvalCall { params, step } = serde_json::from_str(&line).map_err(Error::from)?;

track_assert_eq!(
params.len(),
problem_spec.params_domain.variables().len(),
ErrorKind::InvalidInput
);


let evaluator_or_error = track!(problem.create_evaluator(params.clone()));

let values = match evaluator_or_error {
Ok(mut evaluator) => {
let s = step.unwrap_or_else(|| problem_spec.steps.last());
let (_, values) = track!(evaluator.evaluate(s))?;
values
},
Err(e) => {
if *e.kind() != ErrorKind::UnevaluableParams {
return Err(e);
} else {
Values::new(vec![])
}
}
};

serde_json::to_writer(&mut writer, &EvalReply{values}).map_err(Error::from)?;
writer.write("\n".as_bytes())?;
writer.flush()?;
}
Ok(())
}
}
1 change: 1 addition & 0 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ macro_rules! track_writeln {
}
}

pub mod batch_eval;
pub mod dataset;
pub mod evaluate;
pub mod plot;
Expand Down
7 changes: 7 additions & 0 deletions src/main.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
#[macro_use]
extern crate trackable;

use kurobako::batch_eval::BatchEvaluateOpt;
use kurobako::dataset::DatasetOpt;
use kurobako::evaluate::EvaluateOpt;
use kurobako::plot::PlotOpt;
Expand Down Expand Up @@ -62,6 +63,9 @@ enum Opt {
/// Evaluates parameters of a problem.
Evaluate(EvaluateOpt),

/// Evaluates a set of parameters of a problem through stdio.
BatchEvaluate(BatchEvaluateOpt),

/// Show problem or solver specification.
Spec(SpecOpt),
}
Expand Down Expand Up @@ -114,6 +118,9 @@ fn main() -> trackable::result::TopLevelResult {
let spec = track!(opt.get_spec())?;
print_json!(spec);
}
Opt::BatchEvaluate(opt) => {
track!(opt.run())?;
}
}

Ok(())
Expand Down

0 comments on commit bd67e50

Please sign in to comment.