Skip to content

Commit

Permalink
Merge pull request #51 from HideakiImamura/fix-ci
Browse files Browse the repository at this point in the history
Fix CI
  • Loading branch information
HideakiImamura authored Jun 25, 2024
2 parents bd67e50 + ba1c91f commit ba468f8
Show file tree
Hide file tree
Showing 19 changed files with 70 additions and 76 deletions.
33 changes: 15 additions & 18 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ kurobako_core = { path = "kurobako_core", version = "0.1" }
kurobako_problems = { path = "kurobako_problems", version = "0.1" }
kurobako_solvers = { path = "kurobako_solvers", version = "0.2" }
nasbench = "0.1"
num = "0.3"
num = "0.4"
num-integer = "0.1"
ordered-float = "2"
rand = "0.8"
Expand Down
14 changes: 7 additions & 7 deletions kurobako_core/src/domain.rs
Original file line number Diff line number Diff line change
Expand Up @@ -65,8 +65,8 @@ impl VariableBuilder {
Self {
name: name.to_owned(),
range: Range::Continuous {
low: std::f64::NEG_INFINITY,
high: std::f64::INFINITY,
low: f64::NEG_INFINITY,
high: f64::INFINITY,
},
distribution: Distribution::Uniform,
constraint: None,
Expand All @@ -75,7 +75,7 @@ impl VariableBuilder {

/// Sets the name of this variable.
pub fn name(mut self, name: &str) -> Self {
self.name = name.to_owned();
name.clone_into(&mut self.name);
self
}

Expand Down Expand Up @@ -121,7 +121,7 @@ impl VariableBuilder {
///
/// This is equivalent to `self.categorical(&["false", "true"])`.
pub fn boolean(self) -> Self {
self.categorical(&["false", "true"])
self.categorical(["false", "true"])
}

/// Sets the range of this variable.
Expand Down Expand Up @@ -242,11 +242,11 @@ fn is_not_finite(x: &f64) -> bool {
}

fn neg_infinity() -> f64 {
std::f64::NEG_INFINITY
f64::NEG_INFINITY
}

fn infinity() -> f64 {
std::f64::INFINITY
f64::INFINITY
}

/// Variable range.
Expand Down Expand Up @@ -396,7 +396,7 @@ mod tests {
let vars = vec![
var("a").continuous(-10.0, 10.0).finish()?,
var("b").discrete(0, 5).finish()?,
var("c").categorical(&["foo", "bar", "baz"]).finish()?,
var("c").categorical(["foo", "bar", "baz"]).finish()?,
];

let constraint = Constraint::new("(a + b) < 2");
Expand Down
2 changes: 1 addition & 1 deletion kurobako_core/src/epi/problem/external_program.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ use structopt::StructOpt;

thread_local! {
static FACTORY_CACHE : RefCell<Option<(Vec<u8>, ExternalProgramProblemFactory)>> =
RefCell::new(None);
const { RefCell::new(None) };
}

/// Recipe for the problem implemented by an external program.
Expand Down
8 changes: 6 additions & 2 deletions kurobako_core/src/trial.rs
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,11 @@ pub struct EvaluatedTrial {
pub struct IdGen {
next: u64,
}
impl Default for IdGen {
fn default() -> Self {
Self::new()
}
}
impl IdGen {
/// Makes a new `IdGen` instance.
pub const fn new() -> Self {
Expand Down Expand Up @@ -201,15 +206,14 @@ impl Deref for Values {

mod nullable_f64_vec {
use serde::{Deserialize, Deserializer, Serialize, Serializer};
use std::f64::NAN;

pub fn deserialize<'de, D>(deserializer: D) -> Result<Vec<f64>, D::Error>
where
D: Deserializer<'de>,
{
let v: Vec<Option<f64>> = Deserialize::deserialize(deserializer)?;
Ok(v.into_iter()
.map(|v| if let Some(v) = v { v } else { NAN })
.map(|v| if let Some(v) = v { v } else { f64::NAN })
.collect())
}

Expand Down
6 changes: 3 additions & 3 deletions kurobako_problems/src/hpobench.rs
Original file line number Diff line number Diff line change
Expand Up @@ -70,13 +70,13 @@ impl ProblemFactory for HpobenchProblemFactory {
arXiv preprint arXiv:1905.04970 (2019).",
)
.attr("github", "https://github.com/automl/nas_benchmarks")
.param(domain::var("activation_fn_1").categorical(&["tanh", "relu"]))
.param(domain::var("activation_fn_2").categorical(&["tanh", "relu"]))
.param(domain::var("activation_fn_1").categorical(["tanh", "relu"]))
.param(domain::var("activation_fn_2").categorical(["tanh", "relu"]))
.param(domain::var("batch_size").discrete(0, 4))
.param(domain::var("dropout_1").discrete(0, 3))
.param(domain::var("dropout_2").discrete(0, 3))
.param(domain::var("init_lr").discrete(0, 6))
.param(domain::var("lr_schedule").categorical(&["cosine", "const"]))
.param(domain::var("lr_schedule").categorical(["cosine", "const"]))
.param(domain::var("n_units_1").discrete(0, 6))
.param(domain::var("n_units_2").discrete(0, 6))
.value(domain::var("Validation MSE").continuous(0.0, f64::INFINITY))
Expand Down
11 changes: 4 additions & 7 deletions kurobako_problems/src/nasbench.rs
Original file line number Diff line number Diff line change
Expand Up @@ -214,7 +214,9 @@ impl Evaluator for NasbenchEvaluator {
/// [nas_cifar10.py]: https://github.com/automl/nas_benchmarks/blob/c1bae6632bf15d45ba49c269c04dbbeb3f0379f0/tabular_benchmarks/nas_cifar10.py
#[derive(Debug, Clone, Copy, Serialize, Deserialize)]
#[allow(missing_docs)]
#[derive(Default)]
pub enum Encoding {
#[default]
A,
B,
C,
Expand All @@ -233,7 +235,7 @@ impl Encoding {
fn common_params() -> Vec<VariableBuilder> {
let mut params = Vec::new();
for i in 0..5 {
params.push(domain::var(&format!("op{}", i)).categorical(&[
params.push(domain::var(&format!("op{}", i)).categorical([
"conv1x1-bn-relu",
"conv3x3-bn-relu",
"maxpool3x3",
Expand Down Expand Up @@ -297,7 +299,7 @@ impl Encoding {
fn edges_a(params: &[f64]) -> HashSet<usize> {
let mut edges = HashSet::new();
for (i, p) in params.iter().enumerate() {
if (*p - 1.0).abs() < std::f64::EPSILON {
if (*p - 1.0).abs() < f64::EPSILON {
edges.insert(i);
}
}
Expand Down Expand Up @@ -342,11 +344,6 @@ impl FromStr for Encoding {
}
}
}
impl Default for Encoding {
fn default() -> Self {
Encoding::A
}
}

/// Evaluation metric.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
Expand Down
13 changes: 8 additions & 5 deletions kurobako_problems/src/sigopt/functions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
use super::bessel::bessel0;
use kurobako_core::{ErrorKind, Result};
use std::f64::consts::PI;
use std::f64::EPSILON;
use std::fmt;
use std::iter;

Expand Down Expand Up @@ -122,7 +121,7 @@ impl TestFunction for Csendes {

fn evaluate(&self, xs: &[f64]) -> f64 {
xs.iter()
.map(|&x| x.powi(6) * (2.0 + (1.0 / (x + EPSILON)).sin()))
.map(|&x| x.powi(6) * (2.0 + (1.0 / (x + f64::EPSILON)).sin()))
.sum()
}
}
Expand Down Expand Up @@ -394,7 +393,7 @@ impl McCourtBase {
e_mat: &'static [&'static [f64]],
) -> impl 'a + Iterator<Item = f64> {
e_mat.iter().zip(centers.iter()).map(move |(evec, center)| {
let mut max = std::f64::NEG_INFINITY;
let mut max = f64::NEG_INFINITY;
for x in xs
.iter()
.zip(center.iter())
Expand Down Expand Up @@ -2374,7 +2373,7 @@ mod tests {
fn shekel05_works() {
assert_eq!(
Shekel05.evaluate(&[4.0, 4.0, 4.0, 4.0]),
-10.152719932456289
-10.152_719_932_456_29
);
}

Expand Down Expand Up @@ -2402,7 +2401,11 @@ mod tests {
#[test]
fn styblinskitang_works() {
assert_eq!(
StyblinskiTang.evaluate(&[-2.903534018185960, -2.903534018185960, -2.903534018185960]),
StyblinskiTang.evaluate(&[
-2.903_534_018_185_96,
-2.903_534_018_185_96,
-2.903_534_018_185_96
]),
-117.49849711131424
);
}
Expand Down
15 changes: 7 additions & 8 deletions src/batch_eval.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,10 +10,10 @@ use kurobako_core::trial::{Params, Values};
use kurobako_core::{ErrorKind, Result};
use serde::Deserialize;
use serde::Serialize;
use std::io;
use structopt::StructOpt;
use serde_json::Error;
use std::io;
use std::io::Write;
use structopt::StructOpt;

/// Options of the `kurobako batch-evaluate` command.
#[derive(Debug, Clone, StructOpt)]
Expand All @@ -31,7 +31,7 @@ pub struct BatchEvaluateOpt {
#[derive(Debug, Clone, Deserialize)]
struct EvalCall {
params: Params,
step: Option<u64>
step: Option<u64>,
}

#[derive(Debug, Clone, Serialize)]
Expand All @@ -50,7 +50,7 @@ impl BatchEvaluateOpt {

let problem = track!(problem_factory.create_problem(rng))?;
let mut writer = io::stdout();
loop{
loop {
let mut line = String::new();
let n = io::stdin().read_line(&mut line)?;
if n == 0 {
Expand All @@ -64,15 +64,14 @@ impl BatchEvaluateOpt {
ErrorKind::InvalidInput
);


let evaluator_or_error = track!(problem.create_evaluator(params.clone()));

let values = match evaluator_or_error {
Ok(mut evaluator) => {
let s = step.unwrap_or_else(|| problem_spec.steps.last());
let (_, values) = track!(evaluator.evaluate(s))?;
values
},
}
Err(e) => {
if *e.kind() != ErrorKind::UnevaluableParams {
return Err(e);
Expand All @@ -82,8 +81,8 @@ impl BatchEvaluateOpt {
}
};

serde_json::to_writer(&mut writer, &EvalReply{values}).map_err(Error::from)?;
writer.write("\n".as_bytes())?;
serde_json::to_writer(&mut writer, &EvalReply { values }).map_err(Error::from)?;
writer.write_all("\n".as_bytes())?;
writer.flush()?;
}
Ok(())
Expand Down
2 changes: 1 addition & 1 deletion src/dataset.rs
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ impl NasbenchOpt {
);

let file = track!(
std::fs::File::open(&tfrecord_format_dataset_path).map_err(Error::from);
std::fs::File::open(tfrecord_format_dataset_path).map_err(Error::from);
tfrecord_format_dataset_path
)?;
let nasbench = track!(nasbench::NasBench::from_tfrecord_reader(
Expand Down
8 changes: 4 additions & 4 deletions src/dataset/surrogate.rs
Original file line number Diff line number Diff line change
Expand Up @@ -90,8 +90,8 @@ impl SurrogateOpt {
let mut table = TableBuilder::new();
let column_types = trials[0]
.distributions
.iter()
.map(|(_, d)| {
.values()
.map(|d| {
if matches!(d, Distribution::CategoricalDistribution { .. }) {
ColumnType::Categorical
} else {
Expand Down Expand Up @@ -208,11 +208,11 @@ impl SurrogateOpt {
track!(std::fs::create_dir_all(&dir).map_err(Error::from))?;

let spec_path = dir.join("spec.json");
let spec_file = track!(std::fs::File::create(&spec_path).map_err(Error::from))?;
let spec_file = track!(std::fs::File::create(spec_path).map_err(Error::from))?;
serde_json::to_writer(spec_file, &spec)?;

let regressor_path = dir.join("model.bin");
let regressor_file = track!(std::fs::File::create(&regressor_path).map_err(Error::from))?;
let regressor_file = track!(std::fs::File::create(regressor_path).map_err(Error::from))?;
model.regressor.serialize(BufWriter::new(regressor_file))?;

eprintln!("Saved the surrogate model to the direcotry {:?}", dir);
Expand Down
2 changes: 1 addition & 1 deletion src/plot.rs
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ impl PlotOpt {

fn execute_gnuplot(script: &str) -> Result<()> {
let output = track!(Command::new("gnuplot")
.args(&["-e", script])
.args(["-e", script])
.output()
.map_err(Error::from))?;
if !output.status.success() {
Expand Down
Loading

0 comments on commit ba468f8

Please sign in to comment.