From be62891ce7730345f848bdbf0d10a9e98818def0 Mon Sep 17 00:00:00 2001 From: Evan Helbig Date: Wed, 31 May 2023 21:20:09 -0400 Subject: [PATCH] Adding image endpoint and messing around with terminal output affects --- Cargo.lock | 124 ++++++++++++---------- openai-api/Cargo.toml | 1 + openai-api/src/error.rs | 6 ++ openai-api/src/lib.rs | 31 ++++++ openai-api/src/model/create_image.rs | 95 +++++++++++++++++ openai-api/src/model/mod.rs | 1 + openai-cli/Cargo.toml | 1 + openai-cli/src/main.rs | 2 + openai-cli/src/presentation/chat.rs | 12 ++- openai-cli/src/presentation/completion.rs | 6 +- openai-cli/src/presentation/image.rs | 54 ++++++++++ openai-cli/src/presentation/mod.rs | 1 + 12 files changed, 274 insertions(+), 60 deletions(-) create mode 100644 openai-api/src/model/create_image.rs create mode 100644 openai-cli/src/presentation/image.rs diff --git a/Cargo.lock b/Cargo.lock index 2b77e48..c0c1525 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -11,6 +11,12 @@ dependencies = [ "memchr", ] +[[package]] +name = "android-tzdata" +version = "0.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e999941b234f3131b00bc13c22d06e8c5ff726d1b6318ac7eb276997bbb4fef0" + [[package]] name = "android_system_properties" version = "0.1.5" @@ -43,7 +49,7 @@ checksum = "b9ccdd8f2a161be9bd5c023df56f1b2a0bd1d83872ae53b71a84a12c9bf6e842" dependencies = [ "proc-macro2", "quote", - "syn 2.0.16", + "syn 2.0.18", ] [[package]] @@ -65,9 +71,9 @@ checksum = "d468802bab17cbc0cc575e9b053f41e72aa36bfa6b7f55e3529ffa43161b97fa" [[package]] name = "base64" -version = "0.21.1" +version = "0.21.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3f1e31e207a6b8fb791a38ea3105e6cb541f55e4d029902d3039a4ad07cc4105" +checksum = "604178f6c5c21f02dc555784810edfb88d34ac2c73b2eae109655649ee73ce3d" [[package]] name = "bitflags" @@ -77,9 +83,9 @@ checksum = "bef38d45163c2f1dde094a7dfd33ccf595c92905c8f8f4fdc18d06fb1037718a" [[package]] name = "bumpalo" -version = "3.12.2" +version = "3.13.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3c6ed94e98ecff0c12dd1b04c15ec0d7d9458ca8fe806cea6f12954efe74c63b" +checksum = "a3e2c3daef883ecc1b5d58c15adae93470a91d425f3532ba1695849656af3fc1" [[package]] name = "bytes" @@ -101,13 +107,13 @@ checksum = "baf1de4339761588bc0619e3cbc0120ee582ebb74b53b4efbf79117bd2da40fd" [[package]] name = "chrono" -version = "0.4.24" +version = "0.4.26" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4e3c5919066adf22df73762e50cffcde3a758f2a848b113b586d1f86728b673b" +checksum = "ec837a71355b28f6556dbd569b37b3f363091c0bd4b2e735674521b4c5fd9bc5" dependencies = [ + "android-tzdata", "iana-time-zone", "js-sys", - "num-integer", "num-traits", "serde", "time", @@ -130,6 +136,19 @@ dependencies = [ "vec_map", ] +[[package]] +name = "console" +version = "0.15.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c926e00cc70edefdc64d3a5ff31cc65bb97a3460097762bd23afb4d8145fccf8" +dependencies = [ + "encode_unicode", + "lazy_static", + "libc", + "unicode-width", + "windows-sys 0.45.0", +] + [[package]] name = "core-foundation" version = "0.9.3" @@ -146,6 +165,12 @@ version = "0.8.4" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e496a50fda8aacccc86d7529e2c1e0892dbd0f898a6b5645b5561b89c3210efa" +[[package]] +name = "encode_unicode" +version = "0.3.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a357d28ed41a50f9c765dbfe56cbc04a64e53e5fc58ba79fbc34c10ef3df831f" + [[package]] name = "encoding_rs" version = "0.8.32" @@ -456,9 +481,9 @@ dependencies = [ [[package]] name = "io-lifetimes" -version = "1.0.10" +version = "1.0.11" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9c66c74d2ae7e79a5a8f7ac924adbe38ee42a859c6539ad869eb51f0b52dc220" +checksum = "eae7b9aee968036d54dce06cebaefd919e4472e753296daccd6d344e3e2df0c2" dependencies = [ "hermit-abi 0.3.1", "libc", @@ -528,12 +553,9 @@ dependencies = [ [[package]] name = "log" -version = "0.4.17" +version = "0.4.18" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "abb12e687cfb44aa40f41fc3978ef76448f9b6038cad6aef4259d3c095a2382e" -dependencies = [ - "cfg-if", -] +checksum = "518ef76f2f87365916b142844c16d8fefd85039bc5699050210a7778ee1cd1de" [[package]] name = "memchr" @@ -549,14 +571,13 @@ checksum = "6877bb514081ee2a7ff5ef9de3281f14a4dd4bceac4c09388074a6b5df8a139a" [[package]] name = "mio" -version = "0.8.6" +version = "0.8.8" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5b9d9a46eff5b4ff64b45a9e316a6d1e0bc719ef429cbec4dc630684212bfdf9" +checksum = "927a765cd3fc26206e66b296465fa9d3e5ab003e651c1b3c060e7956d96b19d2" dependencies = [ "libc", - "log", "wasi 0.11.0+wasi-snapshot-preview1", - "windows-sys 0.45.0", + "windows-sys 0.48.0", ] [[package]] @@ -577,16 +598,6 @@ dependencies = [ "tempfile", ] -[[package]] -name = "num-integer" -version = "0.1.45" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "225d3389fb3509a24c93f5c29eb6bde2586b98d9f016636dff58d7c6f7569cd9" -dependencies = [ - "autocfg", - "num-traits", -] - [[package]] name = "num-traits" version = "0.2.15" @@ -608,9 +619,9 @@ dependencies = [ [[package]] name = "once_cell" -version = "1.17.1" +version = "1.17.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b7e5500299e16ebb147ae15a00a942af264cf3688f47923b8fc2cd5858f23ad3" +checksum = "9670a07f94779e00908f3e686eab508878ebb390ba6e604d3a284c00e8d0487b" [[package]] name = "openai-api" @@ -622,6 +633,7 @@ dependencies = [ "serde", "serde_json", "thiserror", + "url", ] [[package]] @@ -630,6 +642,7 @@ version = "0.1.0" dependencies = [ "anyhow", "async-trait", + "console", "env_logger", "log", "openai-api", @@ -640,9 +653,9 @@ dependencies = [ [[package]] name = "openssl" -version = "0.10.52" +version = "0.10.53" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "01b8574602df80f7b85fdfc5392fa884a4e3b3f4f35402c070ab34c3d3f78d56" +checksum = "12df40a956736488b7b44fe79fe12d4f245bb5b3f5a1f6095e499760015be392" dependencies = [ "bitflags", "cfg-if", @@ -661,7 +674,7 @@ checksum = "a948666b637a0f465e8564c73e89d4dde00d72d4d473cc972f390fc3dcee7d9c" dependencies = [ "proc-macro2", "quote", - "syn 2.0.16", + "syn 2.0.18", ] [[package]] @@ -672,9 +685,9 @@ checksum = "ff011a302c396a5197692431fc1948019154afc178baf7d8e37367442a4601cf" [[package]] name = "openssl-sys" -version = "0.9.87" +version = "0.9.88" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8e17f59264b2809d77ae94f0e1ebabc434773f370d6ca667bd223ea10e06cc7e" +checksum = "c2ce0f250f34a308dcfdbb351f511359857d4ed2134ba715a4eadd46e1ffd617" dependencies = [ "cc", "libc", @@ -755,18 +768,18 @@ dependencies = [ [[package]] name = "proc-macro2" -version = "1.0.58" +version = "1.0.59" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "fa1fb82fc0c281dd9671101b66b771ebbe1eaf967b96ac8740dcba4b70005ca8" +checksum = "6aeca18b86b413c660b781aa319e4e2648a3e6f9eadc9b47e9038e6fe9f3451b" dependencies = [ "unicode-ident", ] [[package]] name = "quote" -version = "1.0.27" +version = "1.0.28" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8f4f29d145265ec1c483c7c654450edde0bfe043d3938d6972630663356d9500" +checksum = "1b9ab9c7eadfd8df19006f1cf1a4aed13540ed5cbc047010ece5826e10825488" dependencies = [ "proc-macro2", ] @@ -791,9 +804,9 @@ dependencies = [ [[package]] name = "regex" -version = "1.8.1" +version = "1.8.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "af83e617f331cc6ae2da5443c602dfa5af81e517212d9d611a5b3ba1777b5370" +checksum = "81ca098a9821bd52d6b24fd8b10bd081f47d39c22778cafaa75a2857a62c6390" dependencies = [ "aho-corasick", "memchr", @@ -802,9 +815,9 @@ dependencies = [ [[package]] name = "regex-syntax" -version = "0.7.1" +version = "0.7.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a5996294f19bd3aae0453a862ad728f60e6600695733dd5df01da90c54363a3c" +checksum = "436b050e76ed2903236f032a59761c1eb99e1b0aead2c257922771dab1fc8c78" [[package]] name = "reqwest" @@ -918,7 +931,7 @@ checksum = "8c805777e3930c8883389c602315a24224bcc738b63905ef87cd1420353ea93e" dependencies = [ "proc-macro2", "quote", - "syn 2.0.16", + "syn 2.0.18", ] [[package]] @@ -1021,9 +1034,9 @@ dependencies = [ [[package]] name = "syn" -version = "2.0.16" +version = "2.0.18" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a6f671d4b5ffdb8eadec19c0ae67fe2639df8684bd7bc4b83d986b8db549cf01" +checksum = "32d41677bcbe24c20c52e7c70b0d8db04134c5d1066bf98662e2871ad200ea3e" dependencies = [ "proc-macro2", "quote", @@ -1078,7 +1091,7 @@ checksum = "f9456a42c5b0d803c8cd86e73dd7cc9edd429499f37a3550d286d5e86720569f" dependencies = [ "proc-macro2", "quote", - "syn 2.0.16", + "syn 2.0.18", ] [[package]] @@ -1109,9 +1122,9 @@ checksum = "1f3ccbac311fea05f86f61904b462b55fb3df8837a366dfc601a0161d0532f20" [[package]] name = "tokio" -version = "1.28.1" +version = "1.28.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0aa32867d44e6f2ce3385e89dceb990188b8bb0fb25b0cf576647a6f98ac5105" +checksum = "94d7b1cfd2aa4011f2de74c2c4c63665e27a71006b0a192dcd2710272e73dfa2" dependencies = [ "autocfg", "bytes", @@ -1134,7 +1147,7 @@ checksum = "630bdcf245f78637c13ec01ffae6187cca34625e8c63150d424b59e55af2675e" dependencies = [ "proc-macro2", "quote", - "syn 2.0.16", + "syn 2.0.18", ] [[package]] @@ -1201,9 +1214,9 @@ checksum = "92888ba5573ff080736b3648696b70cafad7d250551175acbaa4e0385b3e1460" [[package]] name = "unicode-ident" -version = "1.0.8" +version = "1.0.9" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e5464a87b239f13a63a501f2701565754bae92d243d4bb7eb12f6d57d2269bf4" +checksum = "b15811caf2415fb889178633e7724bad2509101cde276048e013b9def5e51fa0" [[package]] name = "unicode-normalization" @@ -1235,6 +1248,7 @@ dependencies = [ "form_urlencoded", "idna", "percent-encoding", + "serde", ] [[package]] @@ -1298,7 +1312,7 @@ dependencies = [ "once_cell", "proc-macro2", "quote", - "syn 2.0.16", + "syn 2.0.18", "wasm-bindgen-shared", ] @@ -1332,7 +1346,7 @@ checksum = "e128beba882dd1eb6200e1dc92ae6c5dbaa4311aa7bb211ca035779e5efc39f8" dependencies = [ "proc-macro2", "quote", - "syn 2.0.16", + "syn 2.0.18", "wasm-bindgen-backend", "wasm-bindgen-shared", ] diff --git a/openai-api/Cargo.toml b/openai-api/Cargo.toml index 3436c16..8e14255 100644 --- a/openai-api/Cargo.toml +++ b/openai-api/Cargo.toml @@ -12,3 +12,4 @@ reqwest = { version = "0.11.18", features = ["json"] } serde = { version = "1.0.163", features = ["derive"] } serde_json = "1.0.96" thiserror = "1.0.40" +url = { version = "2.3.1", features = ["serde"] } diff --git a/openai-api/src/error.rs b/openai-api/src/error.rs index edfb487..826e6d5 100644 --- a/openai-api/src/error.rs +++ b/openai-api/src/error.rs @@ -14,6 +14,12 @@ pub enum Error { #[error("Unsupported role: {0}")] UnsupportedRole(String), + #[error("Unsupported image size: {0}")] + UnsupportedImageSize(String), + + #[error("Unsupported response format: {0}")] + UnsupportedResponseFormat(String), + #[error("Json Serialization: {0}")] JsonSerialization(String), diff --git a/openai-api/src/lib.rs b/openai-api/src/lib.rs index caefd98..0b9fe4d 100644 --- a/openai-api/src/lib.rs +++ b/openai-api/src/lib.rs @@ -15,6 +15,10 @@ pub trait Datasource { &self, request: &model::create_chat::Request, ) -> Result; + async fn create_image( + &self, + request: &model::create_image::Request, + ) -> Result; async fn list_files(&self) -> Result; } @@ -104,6 +108,33 @@ impl Datasource for OpenAIApi { } } + async fn create_image( + &self, + request: &model::create_image::Request, + ) -> Result { + let body = serde_json::to_string(&request)?; + + println!("{}", &body); + + let response = self + .http_client + .post(format!("{}/v1/images/generations", &self.base_url)) + .header("Content-Type", "application/json") + .bearer_auth(&self.api_key) + .body(body) + .send() + .await?; + + match response.error_for_status() { + Ok(response) => { + let data: model::create_image::Response = response.json().await?; + + Ok(data) + } + Err(error) => Err(error::Error::InvalidHttpResponse(error.to_string())), + } + } + async fn list_files(&self) -> Result { let response = self .http_client diff --git a/openai-api/src/model/create_image.rs b/openai-api/src/model/create_image.rs new file mode 100644 index 0000000..149eef9 --- /dev/null +++ b/openai-api/src/model/create_image.rs @@ -0,0 +1,95 @@ +use crate::error; +use serde::{Deserialize, Serialize}; +use std::str::FromStr; +use url; + +#[derive(Debug, Serialize)] +pub struct Request { + pub prompt: String, + + #[serde(skip_serializing_if = "Option::is_none")] + pub n: Option, + + #[serde(skip_serializing_if = "Option::is_none")] + pub size: Option, + + #[serde(skip_serializing_if = "Option::is_none")] + pub response_format: Option, + + #[serde(skip_serializing_if = "Option::is_none")] + pub user: Option, +} + +impl Request { + pub fn new(prompt: String) -> Self { + Self { + prompt, + n: None, + size: None, + response_format: None, + user: None, + } + } + + pub fn response_format(mut self, response_format: ResponseFormat) -> Self { + self.response_format = Some(response_format); + self + } +} + +#[derive(Debug, Serialize)] +pub enum Size { + #[serde(rename = "256x256")] + _256x256_, + + #[serde(rename = "512x512")] + _512x512_, + + #[serde(rename = "1024x1024")] + _1024x1024_, +} + +impl FromStr for Size { + type Err = error::Error; + + fn from_str(s: &str) -> Result { + match s { + "256x256" => Ok(Self::_256x256_), + "512x512" => Ok(Self::_512x512_), + "1024x1024" => Ok(Self::_1024x1024_), + _ => Err(Self::Err::UnsupportedImageSize(s.to_string())), + } + } +} + +#[derive(Clone, Debug, Serialize)] +pub enum ResponseFormat { + Url, + + #[serde(rename = "b64_json")] + B64Json, +} + +impl FromStr for ResponseFormat { + type Err = error::Error; + + fn from_str(s: &str) -> Result { + match s { + "url" => Ok(Self::Url), + "b64json" => Ok(Self::B64Json), + _ => Err(error::Error::UnsupportedResponseFormat(s.to_string())), + } + } +} + +#[derive(Debug, Deserialize)] +pub struct Response { + pub created: usize, + pub data: Vec, +} + +#[derive(Debug, Deserialize)] +pub struct ImageUrl { + pub url: Option, + pub b64_json: Option, +} diff --git a/openai-api/src/model/mod.rs b/openai-api/src/model/mod.rs index f2de0a9..fd1c876 100644 --- a/openai-api/src/model/mod.rs +++ b/openai-api/src/model/mod.rs @@ -1,5 +1,6 @@ pub mod create_chat; pub mod create_completion; +pub mod create_image; pub mod list_files; pub mod list_models; pub mod model; diff --git a/openai-cli/Cargo.toml b/openai-cli/Cargo.toml index bf25e79..85520c3 100644 --- a/openai-cli/Cargo.toml +++ b/openai-cli/Cargo.toml @@ -8,6 +8,7 @@ edition = "2021" [dependencies] anyhow = "1.0.71" async-trait = "0.1.68" +console = "0.15.7" env_logger = "0.10.0" log = "0.4.17" openai-api = { path = "../openai-api" } diff --git a/openai-cli/src/main.rs b/openai-cli/src/main.rs index ce89053..9680dfe 100644 --- a/openai-cli/src/main.rs +++ b/openai-cli/src/main.rs @@ -22,6 +22,7 @@ enum Subcommand { Model(presentation::model::Opt), Completion(presentation::completion::Opt), Chat(presentation::chat::Opt), + Image(presentation::image::Opt), File(presentation::file::Opt), } @@ -48,6 +49,7 @@ async fn main() -> Result<(), Error> { Subcommand::Model(opt) => opt.run(http_client, api_key).await?, Subcommand::Completion(opt) => opt.run(http_client, api_key).await?, Subcommand::Chat(opt) => opt.run(http_client, api_key).await?, + Subcommand::Image(opt) => opt.run(http_client, api_key).await?, Subcommand::File(opt) => opt.run(http_client, api_key).await?, } diff --git a/openai-cli/src/presentation/chat.rs b/openai-cli/src/presentation/chat.rs index ac69ec3..8a8ef27 100644 --- a/openai-cli/src/presentation/chat.rs +++ b/openai-cli/src/presentation/chat.rs @@ -1,5 +1,6 @@ use anyhow::Error; use async_trait::async_trait; +use console; use openai_api::Datasource; use std::{io, sync}; use structopt::StructOpt; @@ -51,7 +52,13 @@ impl Command for Opt { } }; - println!("What can I assist you with?"); + let assistant_response = console::Style::new().blue(); + let assistant = console::Style::new().dim(); + + println!( + "{}", + assistant_response.apply_to("What can I assist you with?") + ); loop { let mut content = String::new(); @@ -72,7 +79,8 @@ impl Command for Opt { println!( "{:#?}: {:#?}", - response.choices[0].message.role, response.choices[0].message.content + assistant.apply_to(&response.choices[0].message.role), + assistant_response.apply_to(&response.choices[0].message.content) ); request.messages.push(response.choices[0].message.clone()); diff --git a/openai-cli/src/presentation/completion.rs b/openai-cli/src/presentation/completion.rs index 01f75f2..0d916a5 100644 --- a/openai-cli/src/presentation/completion.rs +++ b/openai-cli/src/presentation/completion.rs @@ -27,10 +27,10 @@ pub struct Create { #[structopt(long, short)] pub suffix: Option, - #[structopt(long, default_value="100")] + #[structopt(long, default_value = "100")] pub max_tokens: usize, - #[structopt(long, short, default_value="0.0")] + #[structopt(long, short, default_value = "0.0")] pub temperature: f32, } @@ -53,7 +53,7 @@ impl Command for Opt { let response = datasource.create_completion(&request).await?; - println!("{:#?}", response); + println!("{}", response.choices[0].text); Ok(()) } diff --git a/openai-cli/src/presentation/image.rs b/openai-cli/src/presentation/image.rs new file mode 100644 index 0000000..c299a36 --- /dev/null +++ b/openai-cli/src/presentation/image.rs @@ -0,0 +1,54 @@ +use anyhow::Error; +use async_trait::async_trait; +use openai_api::Datasource; +use std::sync; +use structopt::StructOpt; + +use super::command::Command; + +#[derive(StructOpt)] +pub struct Opt { + #[structopt(subcommand)] + pub subcommand: Subcommand, +} + +#[derive(StructOpt)] +pub enum Subcommand { + Create(Create), +} + +#[derive(StructOpt)] +pub struct Create { + pub prompt: String, + + #[structopt(short, long = "number")] + pub n: Option, + + #[structopt(short, long)] + pub size: Option, + + #[structopt(short, long)] + pub response_format: Option, +} + +#[async_trait] +impl Command for Opt { + async fn run( + &self, + http_client: sync::Arc, + api_key: String, + ) -> Result<(), Error> { + let datasource = openai_api::OpenAIApi::new(http_client, api_key); + let request = match &self.subcommand { + Subcommand::Create(opt) => { + openai_api::model::create_image::Request::new(opt.prompt.clone()) + } + }; + + let response = datasource.create_image(&request).await?; + + println!("{}", response.data[0].url.as_ref().unwrap()); + + Ok(()) + } +} diff --git a/openai-cli/src/presentation/mod.rs b/openai-cli/src/presentation/mod.rs index 072dde4..20ac586 100644 --- a/openai-cli/src/presentation/mod.rs +++ b/openai-cli/src/presentation/mod.rs @@ -2,4 +2,5 @@ pub mod chat; pub mod command; pub mod completion; pub mod file; +pub mod image; pub mod model;