Skip to content

Commit

Permalink
Merge branch 'main' of gh:CQCL-DEV/hugr into doug/fix-non-root-profile
Browse files Browse the repository at this point in the history
  • Loading branch information
doug-q committed Dec 11, 2024
2 parents c453cd3 + 97ba7f4 commit 750bc0e
Show file tree
Hide file tree
Showing 133 changed files with 6,628 additions and 2,505 deletions.
2 changes: 1 addition & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ serde_yaml = "0.9.19"
smol_str = "0.3.1"
strum = "0.26.1"
strum_macros = "0.26.1"
thiserror = "1.0.28"
thiserror = "2.0.6"
typetag = "0.2.7"
urlencoding = "2.1.2"
webbrowser = "1.0.0"
Expand Down
4 changes: 2 additions & 2 deletions hugr-cli/src/extensions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -28,9 +28,9 @@ impl ExtArgs {
pub fn run_dump(&self, registry: &ExtensionRegistry) {
let base_dir = &self.outdir;

for (name, ext) in registry.iter() {
for ext in registry.iter() {
let mut path = base_dir.clone();
for part in name.split('.') {
for part in ext.name().split('.') {
path.push(part);
}
path.set_extension("json");
Expand Down
52 changes: 37 additions & 15 deletions hugr-cli/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,9 @@ use clap_verbosity_flag::{InfoLevel, Verbosity};
use clio::Input;
use derive_more::{Display, Error, From};
use hugr::extension::ExtensionRegistry;
use hugr::package::PackageValidationError;
use hugr::package::{PackageEncodingError, PackageValidationError};
use hugr::Hugr;
use std::io::{Cursor, Read, Seek, SeekFrom};
use std::{ffi::OsString, path::PathBuf};

pub mod extensions;
Expand Down Expand Up @@ -46,6 +47,9 @@ pub enum CliError {
/// Error parsing input.
#[display("Error parsing package: {_0}")]
Parse(serde_json::Error),
/// Hugr load error.
#[display("Error parsing package: {_0}")]
HUGRLoad(PackageEncodingError),
#[display("Error validating HUGR: {_0}")]
/// Errors produced by the `validate` subcommand.
Validate(PackageValidationError),
Expand Down Expand Up @@ -96,15 +100,10 @@ impl PackageOrHugr {
}

/// Validates the package or hugr.
///
/// Updates the extension registry with any new extensions defined in the package.
pub fn update_validate(
&mut self,
reg: &mut ExtensionRegistry,
) -> Result<(), PackageValidationError> {
pub fn validate(&self) -> Result<(), PackageValidationError> {
match self {
PackageOrHugr::Package(pkg) => pkg.update_validate(reg),
PackageOrHugr::Hugr(hugr) => hugr.update_validate(reg).map_err(Into::into),
PackageOrHugr::Package(pkg) => pkg.validate(),
PackageOrHugr::Hugr(hugr) => Ok(hugr.validate()?),
}
}
}
Expand All @@ -120,13 +119,21 @@ impl AsRef<[Hugr]> for PackageOrHugr {

impl HugrArgs {
/// Read either a package or a single hugr from the input.
pub fn get_package_or_hugr(&mut self) -> Result<PackageOrHugr, CliError> {
let val: serde_json::Value = serde_json::from_reader(&mut self.input)?;
if let Ok(hugr) = serde_json::from_value::<Hugr>(val.clone()) {
return Ok(PackageOrHugr::Hugr(hugr));
pub fn get_package_or_hugr(
&mut self,
extensions: &ExtensionRegistry,
) -> Result<PackageOrHugr, CliError> {
// We need to read the input twice; once to try to load it as a HUGR, and if that fails, as a package.
// If `input` is a file, we can reuse the reader by seeking back to the start.
// Else, we need to read the file into a buffer.
match self.input.can_seek() {
true => get_package_or_hugr_seek(&mut self.input, extensions),
false => {
let mut buffer = Vec::new();
self.input.read_to_end(&mut buffer)?;
get_package_or_hugr_seek(Cursor::new(buffer), extensions)
}
}
let pkg = serde_json::from_value::<Package>(val.clone())?;
Ok(PackageOrHugr::Package(pkg))
}

/// Read either a package from the input.
Expand All @@ -142,3 +149,18 @@ impl HugrArgs {
Ok(pkg)
}
}

/// Load a package or hugr from a seekable input.
fn get_package_or_hugr_seek<I: Seek + Read>(
mut input: I,
extensions: &ExtensionRegistry,
) -> Result<PackageOrHugr, CliError> {
match Hugr::load_json(&mut input, extensions) {
Ok(hugr) => Ok(PackageOrHugr::Hugr(hugr)),
Err(_) => {
input.seek(SeekFrom::Start(0))?;
let pkg = Package::from_json_reader(input, extensions)?;
Ok(PackageOrHugr::Package(pkg))
}
}
}
16 changes: 14 additions & 2 deletions hugr-cli/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,15 +2,15 @@
use clap::Parser as _;

use hugr_cli::{validate, CliArgs};
use hugr_cli::{mermaid, validate, CliArgs};

use clap_verbosity_flag::log::Level;

fn main() {
match CliArgs::parse() {
CliArgs::Validate(args) => run_validate(args),
CliArgs::GenExtensions(args) => args.run_dump(&hugr::std_extensions::STD_REG),
CliArgs::Mermaid(mut args) => args.run_print().unwrap(),
CliArgs::Mermaid(args) => run_mermaid(args),
CliArgs::External(_) => {
// TODO: Implement support for external commands.
// Running `hugr COMMAND` would look for `hugr-COMMAND` in the path
Expand All @@ -36,3 +36,15 @@ fn run_validate(mut args: validate::ValArgs) {
std::process::exit(1);
}
}

/// Run the `mermaid` subcommand.
fn run_mermaid(mut args: mermaid::MermaidArgs) {
let result = args.run_print();

if let Err(e) = result {
if args.hugr_args.verbosity(Level::Error) {
eprintln!("{}", e);
}
std::process::exit(1);
}
}
7 changes: 5 additions & 2 deletions hugr-cli/src/mermaid.rs
Original file line number Diff line number Diff line change
Expand Up @@ -30,9 +30,12 @@ impl MermaidArgs {
/// Write the mermaid diagram to the output.
pub fn run_print(&mut self) -> Result<(), crate::CliError> {
let hugrs = if self.validate {
self.hugr_args.validate()?.0
self.hugr_args.validate()?
} else {
self.hugr_args.get_package_or_hugr()?.into_hugrs()
let extensions = self.hugr_args.extensions()?;
self.hugr_args
.get_package_or_hugr(&extensions)?
.into_hugrs()
};

for hugr in hugrs {
Expand Down
27 changes: 12 additions & 15 deletions hugr-cli/src/validate.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,14 +6,6 @@ use hugr::{extension::ExtensionRegistry, Extension, Hugr};

use crate::{CliError, HugrArgs};

// TODO: Deprecated re-export. Remove on a breaking release.
#[doc(inline)]
#[deprecated(
since = "0.13.2",
note = "Use `hugr::package::PackageValidationError` instead."
)]
pub use hugr::package::PackageValidationError as ValError;

/// Validate and visualise a HUGR file.
#[derive(Parser, Debug)]
#[clap(version = "1.0", long_about = None)]
Expand All @@ -31,7 +23,7 @@ pub const VALID_PRINT: &str = "HUGR valid!";

impl ValArgs {
/// Run the HUGR cli and validate against an extension registry.
pub fn run(&mut self) -> Result<(Vec<Hugr>, ExtensionRegistry), CliError> {
pub fn run(&mut self) -> Result<Vec<Hugr>, CliError> {
let result = self.hugr_args.validate()?;
if self.verbosity(Level::Info) {
eprintln!("{}", VALID_PRINT);
Expand All @@ -50,24 +42,29 @@ impl HugrArgs {
///
/// Returns the validated modules and the extension registry the modules
/// were validated against.
pub fn validate(&mut self) -> Result<(Vec<Hugr>, ExtensionRegistry), CliError> {
let mut package = self.get_package_or_hugr()?;
pub fn validate(&mut self) -> Result<Vec<Hugr>, CliError> {
let reg = self.extensions()?;
let package = self.get_package_or_hugr(&reg)?;

package.validate()?;
Ok(package.into_hugrs())
}

let mut reg: ExtensionRegistry = if self.no_std {
/// Return a register with the selected extensions.
pub fn extensions(&self) -> Result<ExtensionRegistry, CliError> {
let mut reg = if self.no_std {
hugr::extension::PRELUDE_REGISTRY.to_owned()
} else {
hugr::std_extensions::STD_REG.to_owned()
};

// register external extensions
for ext in &self.extensions {
let f = std::fs::File::open(ext)?;
let ext: Extension = serde_json::from_reader(f)?;
reg.register_updated(ext);
}

package.update_validate(&mut reg)?;
Ok((package.into_hugrs(), reg))
Ok(reg)
}

/// Test whether a `level` message should be output.
Expand Down
29 changes: 13 additions & 16 deletions hugr-cli/tests/validate.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,17 +4,14 @@
//! calling the CLI binary, which Miri doesn't support.
#![cfg(all(test, not(miri)))]

use std::sync::Arc;

use assert_cmd::Command;
use assert_fs::{fixture::FileWriteStr, NamedTempFile};
use hugr::builder::{DFGBuilder, DataflowSubContainer, ModuleBuilder};
use hugr::types::Type;
use hugr::{
builder::{Container, Dataflow},
extension::prelude::{BOOL_T, QB_T},
std_extensions::arithmetic::float_types::FLOAT64_TYPE,
type_row,
extension::prelude::{bool_t, qb_t},
std_extensions::arithmetic::float_types::float64_type,
types::Signature,
Hugr,
};
Expand All @@ -41,7 +38,7 @@ const FLOAT_EXT_FILE: &str = concat!(

/// A test package, containing a module-rooted HUGR.
#[fixture]
fn test_package(#[default(BOOL_T)] id_type: Type) -> Package {
fn test_package(#[default(bool_t())] id_type: Type) -> Package {
let mut module = ModuleBuilder::new();
let df = module
.define_function("test", Signature::new_endo(id_type))
Expand All @@ -50,14 +47,12 @@ fn test_package(#[default(BOOL_T)] id_type: Type) -> Package {
df.finish_with_outputs([i]).unwrap();
let hugr = module.hugr().clone(); // unvalidated

let rdr = std::fs::File::open(FLOAT_EXT_FILE).unwrap();
let float_ext: Arc<hugr::Extension> = serde_json::from_reader(rdr).unwrap();
Package::new(vec![hugr], vec![float_ext]).unwrap()
Package::new(vec![hugr]).unwrap()
}

/// A DFG-rooted HUGR.
#[fixture]
fn test_hugr(#[default(BOOL_T)] id_type: Type) -> Hugr {
fn test_hugr(#[default(bool_t())] id_type: Type) -> Hugr {
let mut df = DFGBuilder::new(Signature::new_endo(id_type)).unwrap();
let [i] = df.input_wires_arr();
df.set_outputs([i]).unwrap();
Expand Down Expand Up @@ -120,7 +115,7 @@ fn test_mermaid(test_hugr_file: NamedTempFile, mut cmd: Command) {

#[fixture]
fn bad_hugr_string() -> String {
let df = DFGBuilder::new(Signature::new_endo(type_row![QB_T])).unwrap();
let df = DFGBuilder::new(Signature::new_endo(vec![qb_t()])).unwrap();
let bad_hugr = df.hugr().clone();

serde_json::to_string(&bad_hugr).unwrap()
Expand All @@ -131,7 +126,9 @@ fn test_mermaid_invalid(bad_hugr_string: String, mut cmd: Command) {
cmd.arg("mermaid");
cmd.arg("--validate");
cmd.write_stdin(bad_hugr_string);
cmd.assert().failure().stderr(contains("UnconnectedPort"));
cmd.assert()
.failure()
.stderr(contains("has an unconnected port"));
}

#[rstest]
Expand All @@ -142,7 +139,7 @@ fn test_bad_hugr(bad_hugr_string: String, mut val_cmd: Command) {
val_cmd
.assert()
.failure()
.stderr(contains("Error validating HUGR").and(contains("unconnected port")));
.stderr(contains("Node(1)").and(contains("unconnected port")));
}

#[rstest]
Expand Down Expand Up @@ -178,7 +175,7 @@ fn test_no_std(test_hugr_string: String, mut val_cmd: Command) {
}

#[fixture]
fn float_hugr_string(#[with(FLOAT64_TYPE)] test_hugr: Hugr) -> String {
fn float_hugr_string(#[with(float64_type())] test_hugr: Hugr) -> String {
serde_json::to_string(&test_hugr).unwrap()
}

Expand All @@ -191,7 +188,7 @@ fn test_no_std_fail(float_hugr_string: String, mut val_cmd: Command) {
val_cmd
.assert()
.failure()
.stderr(contains(" Extension 'arithmetic.float.types' not found"));
.stderr(contains(" requires extension arithmetic.float.types"));
}

#[rstest]
Expand All @@ -205,7 +202,7 @@ fn test_float_extension(float_hugr_string: String, mut val_cmd: Command) {
val_cmd.assert().success().stderr(contains(VALID_PRINT));
}
#[fixture]
fn package_string(#[with(FLOAT64_TYPE)] test_package: Package) -> String {
fn package_string(#[with(float64_type())] test_package: Package) -> String {
serde_json::to_string(&test_package).unwrap()
}

Expand Down
Loading

0 comments on commit 750bc0e

Please sign in to comment.