diff --git a/Cargo.toml b/Cargo.toml index 71aefc3..7fcc550 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -36,6 +36,13 @@ ureq = { version = "2.8.0", optional = true, features = [ "json", "socks-proxy", ] } +sha2 = "0.10" +base64 = "0.22.1" +hex = "0.4.3" +regex = "1.11.1" +lazy_static = "1.5.0" +sha1 = "0.10.6" +urlencoding = "2.1.3" [features] default = ["default-tls", "tokio", "ureq"] @@ -68,6 +75,8 @@ ureq = [ ] [dev-dependencies] +env_logger = "0.11.5" hex-literal = "0.4.1" +rand = "0.8.5" sha2 = "0.10" tokio-test = "0.4.2" diff --git a/README.md b/README.md index 3172115..4a5d86e 100644 --- a/README.md +++ b/README.md @@ -1,3 +1,9 @@ +## temporary todo + +to run the test, do `RUST_LOG=trace HF_TOKEN=token_here HF_REPO=repo_here cargo run --example upload --release` + +## real readme here + This crates aims to emulate and be compatible with the [huggingface_hub](https://github.com/huggingface/huggingface_hub/) python package. @@ -18,14 +24,14 @@ However allowing new features or creating new features might be denied by lack o time. We're focusing on what we currently internally need. Hopefully that subset is already interesting to more users. - -# How to use +# How to use Add the dependency ```bash cargo add hf-hub # --features tokio ``` + `tokio` feature will enable an async (and potentially faster) API. Use the crate: diff --git a/examples/upload.rs b/examples/upload.rs new file mode 100644 index 0000000..6e44bae --- /dev/null +++ b/examples/upload.rs @@ -0,0 +1,83 @@ +use std::time::Instant; + +use hf_hub::{ + api::tokio::{ApiBuilder, ApiError}, + Repo, +}; +use rand::Rng; + +const ONE_MB: usize = 1024 * 1024; + +#[tokio::main] +async fn main() -> Result<(), Box> { + env_logger::init(); + let token = std::env::var("HF_TOKEN") + .map_err(|_| "HF_TOKEN environment variable not set".to_string())?; + let hf_repo = std::env::var("HF_REPO") + .map_err(|_| "HF_REPO environment variable not set, e.g. apyh/gronk".to_string())?; + + let api = ApiBuilder::new().with_token(Some(token)).build()?; + let repo = Repo::model(hf_repo); + let api_repo = api.repo(repo); + + let exists = api_repo.exists().await; + if !exists { + return Err(ApiError::GatedRepoError("repo does not exist".to_string()).into()); + } else { + log::info!("repo exists!"); + } + + let is_writable = api_repo.is_writable().await; + if !is_writable { + return Err(ApiError::GatedRepoError("repo is not writable".to_string()).into()); + } else { + log::info!("repo is writable!"); + } + + let files = [ + ( + format!("im a tiny file {:?}", Instant::now()) + .as_bytes() + .to_vec(), + "tiny_file.txt", + ), + ( + { + let mut data = vec![0u8; ONE_MB]; + rand::thread_rng().fill(&mut data[..]); + data + }, + "1m_file.txt", + ), + ( + { + let mut data = vec![0u8; 10 * ONE_MB]; + rand::thread_rng().fill(&mut data[..]); + data + }, + "10m_file.txt", + ), + ( + { + let mut data = vec![0u8; 20 * ONE_MB]; + rand::thread_rng().fill(&mut data[..]); + data + }, + "20m_file.txt", + ), + ]; + let res = api_repo + .upload_files( + files + .into_iter() + .map(|(data, path)| (data.into(), path.into())) + .collect(), + None, + "update multiple files!".to_string().into(), + false, + ) + .await?; + log::info!("commit result: {:?}", res); + log::info!("Success!!"); + Ok(()) +} diff --git a/flake.nix b/flake.nix index 2ef5f7b..926dab1 100644 --- a/flake.nix +++ b/flake.nix @@ -8,17 +8,15 @@ url = "github:oxalica/rust-overlay"; }; }; - outputs = - { - self, - crate2nix, - nixpkgs, - flake-utils, - rust-overlay, - }: + outputs = { + self, + crate2nix, + nixpkgs, + flake-utils, + rust-overlay, + }: flake-utils.lib.eachDefaultSystem ( - system: - let + system: let cargoNix = crate2nix.tools.${system}.appliedCargoNix { name = "hf-hub"; src = ./.; @@ -30,10 +28,8 @@ ]; }; hf-hub = cargoNix.rootCrate.build; - in - { + in { devShells = with pkgs; rec { - default = pure; pure = mkShell { @@ -43,19 +39,16 @@ }; impure = mkShell { - buildInputs = - [ - openssl.dev - pkg-config - (rust-bin.stable.latest.default.override { - extensions = [ - "rust-analyzer" - "rust-src" - ]; - }) - ]; - - inputsFrom = [ ]; + buildInputs = [ + openssl.dev + pkg-config + (rust-bin.stable.latest.default.override { + extensions = [ + "rust-analyzer" + "rust-src" + ]; + }) + ]; postShellHook = '' export PATH=$PATH:~/.cargo/bin diff --git a/src/api/tokio.rs b/src/api/tokio/download.rs similarity index 59% rename from src/api/tokio.rs rename to src/api/tokio/download.rs index b616fa0..3de0ba9 100644 --- a/src/api/tokio.rs +++ b/src/api/tokio/download.rs @@ -1,428 +1,45 @@ -use super::RepoInfo; -use crate::{Cache, Repo, RepoType}; +use super::{exponential_backoff, symlink_or_rename, ApiError, ApiRepo, RepoInfo}; use indicatif::{ProgressBar, ProgressStyle}; -use rand::Rng; -use reqwest::{ - header::{ - HeaderMap, HeaderName, HeaderValue, InvalidHeaderValue, ToStrError, AUTHORIZATION, - CONTENT_RANGE, LOCATION, RANGE, USER_AGENT, - }, - redirect::Policy, - Client, Error as ReqwestError, RequestBuilder, -}; -use std::num::ParseIntError; -use std::path::{Component, Path, PathBuf}; +use reqwest::{header::RANGE, RequestBuilder}; +use std::path::PathBuf; use std::sync::Arc; -use thiserror::Error; use tokio::io::{AsyncSeekExt, AsyncWriteExt, SeekFrom}; -use tokio::sync::{AcquireError, Semaphore, TryAcquireError}; - -/// Current version (used in user-agent) -const VERSION: &str = env!("CARGO_PKG_VERSION"); -/// Current name (used in user-agent) -const NAME: &str = env!("CARGO_PKG_NAME"); - -#[derive(Debug, Error)] -/// All errors the API can throw -pub enum ApiError { - /// Api expects certain header to be present in the results to derive some information - #[error("Header {0} is missing")] - MissingHeader(HeaderName), - - /// The header exists, but the value is not conform to what the Api expects. - #[error("Header {0} is invalid")] - InvalidHeader(HeaderName), - - /// The value cannot be used as a header during request header construction - #[error("Invalid header value {0}")] - InvalidHeaderValue(#[from] InvalidHeaderValue), - - /// The header value is not valid utf-8 - #[error("header value is not a string")] - ToStr(#[from] ToStrError), - - /// Error in the request - #[error("request error: {0}")] - RequestError(#[from] ReqwestError), - - /// Error parsing some range value - #[error("Cannot parse int")] - ParseIntError(#[from] ParseIntError), - - /// I/O Error - #[error("I/O error {0}")] - IoError(#[from] std::io::Error), - - /// We tried to download chunk too many times - #[error("Too many retries: {0}")] - TooManyRetries(Box), - - /// Semaphore cannot be acquired - #[error("Try acquire: {0}")] - TryAcquireError(#[from] TryAcquireError), - - /// Semaphore cannot be acquired - #[error("Acquire: {0}")] - AcquireError(#[from] AcquireError), - // /// Semaphore cannot be acquired - // #[error("Invalid Response: {0:?}")] - // InvalidResponse(Response), -} - -/// Helper to create [`Api`] with all the options. -#[derive(Debug)] -pub struct ApiBuilder { - endpoint: String, - cache: Cache, - url_template: String, - token: Option, - max_files: usize, - chunk_size: usize, - parallel_failures: usize, - max_retries: usize, - progress: bool, -} - -impl Default for ApiBuilder { - fn default() -> Self { - Self::new() - } -} - -impl ApiBuilder { - /// Default api builder - /// ``` - /// use hf_hub::api::tokio::ApiBuilder; - /// let api = ApiBuilder::new().build().unwrap(); - /// ``` - pub fn new() -> Self { - let cache = Cache::default(); - Self::from_cache(cache) - } - - /// From a given cache - /// ``` - /// use hf_hub::{api::tokio::ApiBuilder, Cache}; - /// let path = std::path::PathBuf::from("/tmp"); - /// let cache = Cache::new(path); - /// let api = ApiBuilder::from_cache(cache).build().unwrap(); - /// ``` - pub fn from_cache(cache: Cache) -> Self { - let token = cache.token(); - - let progress = true; - - Self { - endpoint: "https://huggingface.co".to_string(), - url_template: "{endpoint}/{repo_id}/resolve/{revision}/{filename}".to_string(), - cache, - token, - max_files: num_cpus::get(), - chunk_size: 10_000_000, - parallel_failures: 0, - max_retries: 0, - progress, - } - } - - /// Wether to show a progressbar - pub fn with_progress(mut self, progress: bool) -> Self { - self.progress = progress; - self - } - - /// Changes the location of the cache directory. Defaults is `~/.cache/huggingface/`. - pub fn with_cache_dir(mut self, cache_dir: PathBuf) -> Self { - self.cache = Cache::new(cache_dir); - self - } - - /// Sets the token to be used in the API - pub fn with_token(mut self, token: Option) -> Self { - self.token = token; - self - } - - fn build_headers(&self) -> Result { - let mut headers = HeaderMap::new(); - let user_agent = format!("unkown/None; {NAME}/{VERSION}; rust/unknown"); - headers.insert(USER_AGENT, HeaderValue::from_str(&user_agent)?); - if let Some(token) = &self.token { - headers.insert( - AUTHORIZATION, - HeaderValue::from_str(&format!("Bearer {token}"))?, - ); - } - Ok(headers) - } - - /// Consumes the builder and builds the final [`Api`] - pub fn build(self) -> Result { - let headers = self.build_headers()?; - let client = Client::builder().default_headers(headers.clone()).build()?; - - // Policy: only follow relative redirects - // See: https://github.com/huggingface/huggingface_hub/blob/9c6af39cdce45b570f0b7f8fad2b311c96019804/src/huggingface_hub/file_download.py#L411 - let relative_redirect_policy = Policy::custom(|attempt| { - // Follow redirects up to a maximum of 10. - if attempt.previous().len() > 10 { - return attempt.error("too many redirects"); - } - - if let Some(last) = attempt.previous().last() { - // If the url is not relative - if last.make_relative(attempt.url()).is_none() { - return attempt.stop(); - } - } - - // Follow redirect - attempt.follow() - }); - - let relative_redirect_client = Client::builder() - .redirect(relative_redirect_policy) - .default_headers(headers) - .build()?; - Ok(Api { - endpoint: self.endpoint, - url_template: self.url_template, - cache: self.cache, - client, - relative_redirect_client, - max_files: self.max_files, - chunk_size: self.chunk_size, - parallel_failures: self.parallel_failures, - max_retries: self.max_retries, - progress: self.progress, - }) - } -} - -#[derive(Debug)] -struct Metadata { - commit_hash: String, - etag: String, - size: usize, -} - -/// The actual Api used to interact with the hub. -/// You can inspect repos with [`Api::info`] -/// or download files with [`Api::download`] -#[derive(Clone, Debug)] -pub struct Api { - endpoint: String, - url_template: String, - cache: Cache, - client: Client, - relative_redirect_client: Client, - max_files: usize, - chunk_size: usize, - parallel_failures: usize, - max_retries: usize, - progress: bool, -} - -fn make_relative(src: &Path, dst: &Path) -> PathBuf { - let path = src; - let base = dst; - - assert_eq!( - path.is_absolute(), - base.is_absolute(), - "This function is made to look at absolute paths only" - ); - let mut ita = path.components(); - let mut itb = base.components(); - - loop { - match (ita.next(), itb.next()) { - (Some(a), Some(b)) if a == b => (), - (some_a, _) => { - // Ignoring b, because 1 component is the filename - // for which we don't need to go back up for relative - // filename to work. - let mut new_path = PathBuf::new(); - for _ in itb { - new_path.push(Component::ParentDir); - } - if let Some(a) = some_a { - new_path.push(a); - for comp in ita { - new_path.push(comp); - } - } - return new_path; - } - } - } -} - -fn symlink_or_rename(src: &Path, dst: &Path) -> Result<(), std::io::Error> { - if dst.exists() { - return Ok(()); - } - - let rel_src = make_relative(src, dst); - #[cfg(target_os = "windows")] - { - if std::os::windows::fs::symlink_file(rel_src, dst).is_err() { - std::fs::rename(src, dst)?; - } - } - - #[cfg(target_family = "unix")] - std::os::unix::fs::symlink(rel_src, dst)?; - - Ok(()) -} +use tokio::sync::Semaphore; -fn jitter() -> usize { - rand::thread_rng().gen_range(0..=500) -} - -fn exponential_backoff(base_wait_time: usize, n: usize, max: usize) -> usize { - (base_wait_time + n.pow(2) + jitter()).min(max) -} - -impl Api { - /// Creates a default Api, for Api options See [`ApiBuilder`] - pub fn new() -> Result { - ApiBuilder::new().build() - } - - /// Get the underlying api client - /// Allows for lower level access - pub fn client(&self) -> &Client { - &self.client - } - - async fn metadata(&self, url: &str) -> Result { - let response = self - .relative_redirect_client - .get(url) - .header(RANGE, "bytes=0-0") - .send() - .await?; - let response = response.error_for_status()?; - let headers = response.headers(); - let header_commit = HeaderName::from_static("x-repo-commit"); - let header_linked_etag = HeaderName::from_static("x-linked-etag"); - let header_etag = HeaderName::from_static("etag"); - - let etag = match headers.get(&header_linked_etag) { - Some(etag) => etag, - None => headers - .get(&header_etag) - .ok_or(ApiError::MissingHeader(header_etag))?, - }; - // Cleaning extra quotes - let etag = etag.to_str()?.to_string().replace('"', ""); - let commit_hash = headers - .get(&header_commit) - .ok_or(ApiError::MissingHeader(header_commit))? - .to_str()? - .to_string(); - - // The response was redirected o S3 most likely which will - // know about the size of the file - let response = if response.status().is_redirection() { - self.client - .get(headers.get(LOCATION).unwrap().to_str()?.to_string()) - .header(RANGE, "bytes=0-0") - .send() - .await? - } else { - response - }; - let headers = response.headers(); - let content_range = headers - .get(CONTENT_RANGE) - .ok_or(ApiError::MissingHeader(CONTENT_RANGE))? - .to_str()?; - - let size = content_range - .split('/') - .last() - .ok_or(ApiError::InvalidHeader(CONTENT_RANGE))? - .parse()?; - Ok(Metadata { - commit_hash, - etag, - size, - }) - } - - /// Creates a new handle [`ApiRepo`] which contains operations - /// on a particular [`Repo`] - pub fn repo(&self, repo: Repo) -> ApiRepo { - ApiRepo::new(self.clone(), repo) - } - - /// Simple wrapper over - /// ``` - /// # use hf_hub::{api::tokio::Api, Repo, RepoType}; - /// # let model_id = "gpt2".to_string(); - /// let api = Api::new().unwrap(); - /// let api = api.repo(Repo::new(model_id, RepoType::Model)); - /// ``` - pub fn model(&self, model_id: String) -> ApiRepo { - self.repo(Repo::new(model_id, RepoType::Model)) - } - - /// Simple wrapper over - /// ``` - /// # use hf_hub::{api::tokio::Api, Repo, RepoType}; - /// # let model_id = "gpt2".to_string(); - /// let api = Api::new().unwrap(); - /// let api = api.repo(Repo::new(model_id, RepoType::Dataset)); - /// ``` - pub fn dataset(&self, model_id: String) -> ApiRepo { - self.repo(Repo::new(model_id, RepoType::Dataset)) - } - - /// Simple wrapper over +impl ApiRepo { + /// Get the fully qualified URL of the remote filename /// ``` - /// # use hf_hub::{api::tokio::Api, Repo, RepoType}; - /// # let model_id = "gpt2".to_string(); + /// # use hf_hub::api::tokio::Api; /// let api = Api::new().unwrap(); - /// let api = api.repo(Repo::new(model_id, RepoType::Space)); + /// let url = api.model("gpt2".to_string()).url("model.safetensors"); + /// assert_eq!(url, "https://huggingface.co/gpt2/resolve/main/model.safetensors"); /// ``` - pub fn space(&self, model_id: String) -> ApiRepo { - self.repo(Repo::new(model_id, RepoType::Space)) - } -} - -/// Shorthand for accessing things within a particular repo -#[derive(Debug)] -pub struct ApiRepo { - api: Api, - repo: Repo, -} - -impl ApiRepo { - fn new(api: Api, repo: Repo) -> Self { - Self { api, repo } + pub fn file_url(&self, filename: &str) -> String { + let endpoint = &self.api.endpoint; + let revision = &self.repo.url_revision(); + format!( + "{endpoint}/{}/resolve/{revision}/{filename}", + self.repo.url() + ) + .replace("{endpoint}", endpoint) + .replace("{repo_id}", &self.repo.url()) + .replace("{revision}", revision) + .replace("{filename}", filename) } -} -impl ApiRepo { - /// Get the fully qualified URL of the remote filename + /// Get the fully qualified URL for a preupload /// ``` /// # use hf_hub::api::tokio::Api; /// let api = Api::new().unwrap(); /// let url = api.model("gpt2".to_string()).url("model.safetensors"); - /// assert_eq!(url, "https://huggingface.co/gpt2/resolve/main/model.safetensors"); + /// assert_eq!(url, "https://huggingface.co/api/models/gpt2/model.safetensors/preupload/main"); /// ``` - pub fn url(&self, filename: &str) -> String { + pub fn preupload_url(&self) -> String { let endpoint = &self.api.endpoint; + let repo_id = self.repo.url(); + let repo_type = self.repo.repo_type.to_string(); let revision = &self.repo.url_revision(); - self.api - .url_template - .replace("{endpoint}", endpoint) - .replace("{repo_id}", &self.repo.url()) - .replace("{revision}", revision) - .replace("{filename}", filename) + format!("{endpoint}/api/{repo_type}s/{repo_id}/preupload/{revision}") } async fn download_tempfile( @@ -546,7 +163,7 @@ impl ApiRepo { /// # }) /// ``` pub async fn download(&self, filename: &str) -> Result { - let url = self.url(filename); + let url = self.file_url(filename); let metadata = self.api.metadata(&url).await?; let cache = self.api.cache.repo(self.repo.clone()); @@ -622,9 +239,12 @@ impl ApiRepo { #[cfg(test)] mod tests { use super::*; - use crate::api::Siblings; + use crate::{ + api::{tokio::ApiBuilder, Siblings}, + Repo, RepoType, + }; use hex_literal::hex; - use rand::distributions::Alphanumeric; + use rand::{distributions::Alphanumeric, Rng}; use serde_json::{json, Value}; use sha2::{Digest, Sha256}; diff --git a/src/api/tokio/mod.rs b/src/api/tokio/mod.rs new file mode 100644 index 0000000..b371682 --- /dev/null +++ b/src/api/tokio/mod.rs @@ -0,0 +1,565 @@ +use super::RepoInfo; +use crate::{Cache, Repo, RepoType}; +use http::StatusCode; +use rand::Rng; +use regex::Regex; +use reqwest::{ + header::{ + HeaderMap, HeaderName, HeaderValue, InvalidHeaderValue, ToStrError, AUTHORIZATION, + CONTENT_RANGE, LOCATION, RANGE, USER_AGENT, + }, + redirect::Policy, + Client, Error as ReqwestError, +}; +use std::{fmt::Display, num::ParseIntError}; +use std::{ + future::Future, + path::{Component, Path, PathBuf}, +}; +use thiserror::Error; +use tokio::sync::{AcquireError, TryAcquireError}; + +mod download; +mod repo_info; +mod upload; +pub use upload::{CommitError, UploadSource}; + +/// Current version (used in user-agent) +const VERSION: &str = env!("CARGO_PKG_VERSION"); +/// Current name (used in user-agent) +const NAME: &str = env!("CARGO_PKG_NAME"); + +/// A custom error type that combines a Reqwest error with the response body. +/// +/// This struct wraps a [`reqwest::Error`] and includes the response body as a string, +/// which can be useful for debugging and error reporting when HTTP requests fail. +#[derive(Debug)] +pub struct ReqwestErrorWithBody { + url: String, + error: ReqwestError, + body: Result, +} + +impl Display for ReqwestErrorWithBody { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + writeln!(f, "Request error: {}", self.url)?; + writeln!(f, "{}", self.error)?; + match &self.body { + Ok(body) => { + writeln!(f, "Response body:")?; + writeln!(f, "{body}")?; + } + Err(err) => { + writeln!(f, "Failed to fetch body:")?; + writeln!(f, "{err}")?; + } + } + Ok(()) + } +} + +impl std::error::Error for ReqwestErrorWithBody {} + +// Extension trait for `reqwest::Response` that provides error handling with response body capture. +/// +/// This trait adds the ability to check for HTTP error status codes while preserving the response body +/// in case of an error, which is useful for debugging and error reporting. +/// +/// # Examples +/// +/// ``` +/// use hf_hub::api::tokio::HfBadResponse; +/// +/// async fn example() -> Result<(), Box> { +/// let response = reqwest::get("https://api.example.com/data").await?; +/// +/// // Will return Err with both the error and response body if status code is not successful +/// let response = response.maybe_err().await?; +/// +/// // Process successful response... +/// Ok(()) +/// } +/// ``` +/// +/// # Error Handling +/// +/// - If the response status is successful (2xx), returns `Ok(Response)` +/// - If the response status indicates an error (4xx, 5xx), returns `Err(ApiError)` +/// containing both the original error and the response body text +pub trait HfBadResponse { + /// Checks if the response status code indicates an error, and if so, captures the response body + /// along with the error details. + /// + /// Returns a Future that resolves to: + /// - `Ok(Response)` if the status code is successful + /// - `Err(ApiError)` if the status code indicates an error + fn maybe_hf_err(self) -> impl Future> + where + Self: Sized; +} + +lazy_static::lazy_static! { + static ref REPO_API_REGEX: Regex = Regex::new( + r#"(?x) + # staging or production endpoint + ^https://[^/]+ + ( + # on /api/repo_type/repo_id + /api/(models|datasets|spaces)/(.+) + | + # or /repo_id/resolve/revision/... + /(.+)/resolve/(.+) + ) + "#, + ).unwrap(); +} + +impl HfBadResponse for reqwest::Response { + async fn maybe_hf_err(self) -> Result + where + Self: Sized, + { + let error = self.error_for_status_ref(); + if let Err(error) = error { + let hf_error_code = self + .headers() + .get("X-Error-Code") + .and_then(|v| v.to_str().ok()); + let hf_error_message = self + .headers() + .get("X-Error-Message") + .and_then(|v| v.to_str().ok()); + let url = self.url().to_string(); + Err(match (hf_error_code, hf_error_message) { + (Some("RevisionNotFound"), _) => ApiError::RevisionNotFoundError(url), + (Some("EntryNotFound"), _) => ApiError::EntryNotFoundError(url), + (Some("GatedRepo"), _) => ApiError::GatedRepoError(url), + (_, Some("Access to this resource is disabled.")) => { + ApiError::DisabledRepoError(url) + } + // 401 is misleading as it is returned for: + // - private and gated repos if user is not authenticated + // - missing repos + // => for now, we process them as `RepoNotFound` anyway. + // See https://gist.github.com/Wauplin/46c27ad266b15998ce56a6603796f0b9 + (Some("RepoNotFound"), _) + if self.status() == StatusCode::UNAUTHORIZED + && REPO_API_REGEX.is_match(&url) => + { + ApiError::RepositoryNotFoundError(url) + } + (_, _) => { + let body = self.text().await; + ApiError::RequestErrorWithBody(ReqwestErrorWithBody { url, body, error }) + } + }) + } else { + Ok(self) + } + } +} + +#[derive(Debug, Error)] +/// All errors the API can throw +pub enum ApiError { + /// Api expects certain header to be present in the results to derive some information + #[error("Header {0} is missing")] + MissingHeader(HeaderName), + + /// The header exists, but the value does not conform to what the Api expects. + #[error("Header {0} is invalid")] + InvalidHeader(HeaderName), + + /// The value cannot be used as a header during request header construction + #[error("Invalid header value {0}")] + InvalidHeaderValue(#[from] InvalidHeaderValue), + + /// The header value is not valid utf-8 + #[error("header value is not a string")] + ToStr(#[from] ToStrError), + + /// Error in the request + #[error("request error: {0}")] + RequestError(#[from] ReqwestError), + + /// Error in the request + #[error("request error: {0}")] + RequestErrorWithBody(#[from] ReqwestErrorWithBody), + + /// Error parsing some range value + #[error("Cannot parse int")] + ParseIntError(#[from] ParseIntError), + + /// I/O Error + #[error("I/O error {0}")] + IoError(#[from] std::io::Error), + + /// We tried to download chunk too many times + #[error("Too many retries: {0}")] + TooManyRetries(Box), + + /// Semaphore cannot be acquired + #[error("Try acquire: {0}")] + TryAcquireError(#[from] TryAcquireError), + + /// Semaphore cannot be acquired + #[error("Acquire: {0}")] + AcquireError(#[from] AcquireError), + + /// Bad data from the API + #[error("Invalid Response: {0}")] + InvalidResponse(String), + + /// Repo exists, but the revision / oid doesn't exist. + #[error("Revision Not Found for url: {0}")] + RevisionNotFoundError(String), + + /// todo what is this? + #[error("Entry Not Found for url: {0}")] + EntryNotFoundError(String), + + /// Repo is gated + #[error("Cannot access gated repo for url: {0}")] + GatedRepoError(String), + + /// Repo is disabled + #[error("Cannot access repo - access to resource is disabled for url: {0}")] + DisabledRepoError(String), + + /// Repo does not exist for the caller (could be private) + #[error("Repository Not Found for url: {0}")] + RepositoryNotFoundError(String), +} + +/// Helper to create [`Api`] with all the options. +#[derive(Debug)] +pub struct ApiBuilder { + endpoint: String, + cache: Cache, + token: Option, + max_files: usize, + chunk_size: usize, + parallel_failures: usize, + max_retries: usize, + progress: bool, +} + +impl Default for ApiBuilder { + fn default() -> Self { + Self::new() + } +} + +impl ApiBuilder { + /// Default api builder + /// ``` + /// use hf_hub::api::tokio::ApiBuilder; + /// let api = ApiBuilder::new().build().unwrap(); + /// ``` + pub fn new() -> Self { + let cache = Cache::default(); + Self::from_cache(cache) + } + + /// From a given cache + /// ``` + /// use hf_hub::{api::tokio::ApiBuilder, Cache}; + /// let path = std::path::PathBuf::from("/tmp"); + /// let cache = Cache::new(path); + /// let api = ApiBuilder::from_cache(cache).build().unwrap(); + /// ``` + pub fn from_cache(cache: Cache) -> Self { + let token = cache.token(); + + let progress = true; + + Self { + endpoint: "https://huggingface.co".to_string(), + cache, + token, + max_files: num_cpus::get(), + chunk_size: 10_000_000, + parallel_failures: 0, + max_retries: 0, + progress, + } + } + + /// Wether to show a progressbar + pub fn with_progress(mut self, progress: bool) -> Self { + self.progress = progress; + self + } + + /// Changes the location of the cache directory. Defaults is `~/.cache/huggingface/`. + pub fn with_cache_dir(mut self, cache_dir: PathBuf) -> Self { + self.cache = Cache::new(cache_dir); + self + } + + /// Sets the t to be used in the API + pub fn with_token(mut self, token: Option) -> Self { + self.token = token; + self + } + + fn build_headers(&self) -> Result { + let mut headers = HeaderMap::new(); + let user_agent = format!("unkown/None; {NAME}/{VERSION}; rust/unknown"); + headers.insert(USER_AGENT, HeaderValue::from_str(&user_agent)?); + if let Some(token) = &self.token { + headers.insert( + AUTHORIZATION, + HeaderValue::from_str(&format!("Bearer {token}"))?, + ); + } + Ok(headers) + } + + /// Consumes the builder and builds the final [`Api`] + pub fn build(self) -> Result { + let headers = self.build_headers()?; + let client = Client::builder().default_headers(headers.clone()).build()?; + + // Policy: only follow relative redirects + // See: https://github.com/huggingface/huggingface_hub/blob/9c6af39cdce45b570f0b7f8fad2b311c96019804/src/huggingface_hub/file_download.py#L411 + let relative_redirect_policy = Policy::custom(|attempt| { + // Follow redirects up to a maximum of 10. + if attempt.previous().len() > 10 { + return attempt.error("too many redirects"); + } + + if let Some(last) = attempt.previous().last() { + // If the url is not relative + if last.make_relative(attempt.url()).is_none() { + return attempt.stop(); + } + } + + // Follow redirect + attempt.follow() + }); + + let relative_redirect_client = Client::builder() + .redirect(relative_redirect_policy) + .default_headers(headers) + .build()?; + Ok(Api { + endpoint: self.endpoint, + cache: self.cache, + client, + relative_redirect_client, + max_files: self.max_files, + chunk_size: self.chunk_size, + parallel_failures: self.parallel_failures, + max_retries: self.max_retries, + progress: self.progress, + }) + } +} + +#[derive(Debug)] +struct Metadata { + commit_hash: String, + etag: String, + size: usize, +} + +/// The actual Api used to interact with the hub. +/// You can inspect repos with [`Api::info`] +/// or download files with [`Api::download`] +#[derive(Clone, Debug)] +pub struct Api { + endpoint: String, + cache: Cache, + client: Client, + relative_redirect_client: Client, + max_files: usize, + chunk_size: usize, + parallel_failures: usize, + max_retries: usize, + progress: bool, +} + +fn make_relative(src: &Path, dst: &Path) -> PathBuf { + let path = src; + let base = dst; + + assert_eq!( + path.is_absolute(), + base.is_absolute(), + "This function is made to look at absolute paths only" + ); + let mut ita = path.components(); + let mut itb = base.components(); + + loop { + match (ita.next(), itb.next()) { + (Some(a), Some(b)) if a == b => (), + (some_a, _) => { + // Ignoring b, because 1 component is the filename + // for which we don't need to go back up for relative + // filename to work. + let mut new_path = PathBuf::new(); + for _ in itb { + new_path.push(Component::ParentDir); + } + if let Some(a) = some_a { + new_path.push(a); + for comp in ita { + new_path.push(comp); + } + } + return new_path; + } + } + } +} + +fn symlink_or_rename(src: &Path, dst: &Path) -> Result<(), std::io::Error> { + if dst.exists() { + return Ok(()); + } + + let rel_src = make_relative(src, dst); + #[cfg(target_os = "windows")] + { + if std::os::windows::fs::symlink_file(rel_src, dst).is_err() { + std::fs::rename(src, dst)?; + } + } + + #[cfg(target_family = "unix")] + std::os::unix::fs::symlink(rel_src, dst)?; + + Ok(()) +} + +fn jitter() -> usize { + rand::thread_rng().gen_range(0..=500) +} + +fn exponential_backoff(base_wait_time: usize, n: usize, max: usize) -> usize { + (base_wait_time + n.pow(2) + jitter()).min(max) +} + +impl Api { + /// Creates a default Api, for Api options See [`ApiBuilder`] + pub fn new() -> Result { + ApiBuilder::new().build() + } + + /// Get the underlying api client + /// Allows for lower level access + pub fn client(&self) -> &Client { + &self.client + } + + async fn metadata(&self, url: &str) -> Result { + let response = self + .relative_redirect_client + .get(url) + .header(RANGE, "bytes=0-0") + .send() + .await?; + let response = response.error_for_status()?; + let headers = response.headers(); + let header_commit = HeaderName::from_static("x-repo-commit"); + let header_linked_etag = HeaderName::from_static("x-linked-etag"); + let header_etag = HeaderName::from_static("etag"); + + let etag = match headers.get(&header_linked_etag) { + Some(etag) => etag, + None => headers + .get(&header_etag) + .ok_or(ApiError::MissingHeader(header_etag))?, + }; + // Cleaning extra quotes + let etag = etag.to_str()?.to_string().replace('"', ""); + let commit_hash = headers + .get(&header_commit) + .ok_or(ApiError::MissingHeader(header_commit))? + .to_str()? + .to_string(); + + // The response was redirected o S3 most likely which will + // know about the size of the file + let response = if response.status().is_redirection() { + self.client + .get(headers.get(LOCATION).unwrap().to_str()?.to_string()) + .header(RANGE, "bytes=0-0") + .send() + .await? + } else { + response + }; + let headers = response.headers(); + let content_range = headers + .get(CONTENT_RANGE) + .ok_or(ApiError::MissingHeader(CONTENT_RANGE))? + .to_str()?; + + let size = content_range + .split('/') + .last() + .ok_or(ApiError::InvalidHeader(CONTENT_RANGE))? + .parse()?; + Ok(Metadata { + commit_hash, + etag, + size, + }) + } + + /// Creates a new handle [`ApiRepo`] which contains operations + /// on a particular [`Repo`] + pub fn repo(&self, repo: Repo) -> ApiRepo { + ApiRepo::new(self.clone(), repo) + } + + /// Simple wrapper over + /// ``` + /// # use hf_hub::{api::tokio::Api, Repo, RepoType}; + /// # let model_id = "gpt2".to_string(); + /// let api = Api::new().unwrap(); + /// let api = api.repo(Repo::new(model_id, RepoType::Model)); + /// ``` + pub fn model(&self, model_id: String) -> ApiRepo { + self.repo(Repo::new(model_id, RepoType::Model)) + } + + /// Simple wrapper over + /// ``` + /// # use hf_hub::{api::tokio::Api, Repo, RepoType}; + /// # let model_id = "gpt2".to_string(); + /// let api = Api::new().unwrap(); + /// let api = api.repo(Repo::new(model_id, RepoType::Dataset)); + /// ``` + pub fn dataset(&self, model_id: String) -> ApiRepo { + self.repo(Repo::new(model_id, RepoType::Dataset)) + } + + /// Simple wrapper over + /// ``` + /// # use hf_hub::{api::tokio::Api, Repo, RepoType}; + /// # let model_id = "gpt2".to_string(); + /// let api = Api::new().unwrap(); + /// let api = api.repo(Repo::new(model_id, RepoType::Space)); + /// ``` + pub fn space(&self, model_id: String) -> ApiRepo { + self.repo(Repo::new(model_id, RepoType::Space)) + } +} + +/// Shorthand for accessing things within a particular repo +#[derive(Debug)] +pub struct ApiRepo { + api: Api, + repo: Repo, +} + +impl ApiRepo { + fn new(api: Api, repo: Repo) -> Self { + Self { api, repo } + } +} diff --git a/src/api/tokio/repo_info.rs b/src/api/tokio/repo_info.rs new file mode 100644 index 0000000..07973c4 --- /dev/null +++ b/src/api/tokio/repo_info.rs @@ -0,0 +1,321 @@ +use crate::RepoType; + +use super::{Api, ApiError, ApiRepo, HfBadResponse}; + +#[derive(Debug)] +pub enum RepoInfo { + Model(ModelInfo), + // TODO add dataset and space info +} + +impl RepoInfo { + pub fn sha(&self) -> Option<&str> { + match self { + RepoInfo::Model(m) => m.sha.as_deref(), + } + } +} + +impl From for RepoInfo { + fn from(value: ModelInfo) -> Self { + Self::Model(value) + } +} + +impl ApiRepo { + /// Get the info object for a given repo. + pub async fn repo_info(&self) -> Result { + match self.repo.repo_type { + RepoType::Model => Ok(self + .api + .model_info(&self.repo.repo_id, Some(&self.repo.revision)) + .await? + .into()), + RepoType::Dataset => todo!(), + RepoType::Space => todo!(), + } + } + + /// Checks if this repository exists on the Hugging Face Hub. + pub async fn exists(&self) -> bool { + match self.repo_info().await { + Ok(_) => true, + // no access, but it exists + Err(ApiError::GatedRepoError(_)) => true, + Err(ApiError::RepositoryNotFoundError(_)) => false, + Err(_) => false, + } + } + + /// Checks if this repository exists and is writable on the Hugging Face Hub. + pub async fn is_writable(&self) -> bool { + if !self.exists().await { + return false; + } + let mut headers = HeaderMap::new(); + headers.insert("Content-Type", "application/x-ndjson".parse().unwrap()); + + let url = format!( + "{}/api/{}s/{}/commit/{}", + self.api.endpoint, + self.repo.repo_type.to_string(), + self.repo.url(), + self.repo.revision + ); + + let res: Result = (async { + Ok(self + .api + .client + .post(&url) + .headers(headers) + .send() + .await + .map_err(ApiError::from)? + .status()) + }) + .await; + if let Ok(status) = res { + if status == StatusCode::FORBIDDEN { + return false; + } + } + true + } +} + +impl Api { + /// Get info on one specific model on huggingface.co + /// + /// Model can be private if you pass an acceptable token or are logged in. + /// + /// Args: + /// repo_id (`str`): + /// A namespace (user or an organization) and a repo name separated + /// by a `/`. + /// revision (`str`, *optional*): + /// The revision of the model repository from which to get the + /// information. + async fn model_info( + &self, + repo_id: &str, + revision: Option<&str>, + ) -> Result { + let url = if let Some(revision) = revision { + format!( + "{}/api/models/{repo_id}/revision/{}", + self.endpoint, + urlencoding::encode(revision) + ) + } else { + format!("{}/api/models/{repo_id}", self.endpoint) + }; + + // TODO add params for security status, blobs, expand, etc. + + let model_info: ModelInfo = self + .client + .get(url) + .send() + .await? + .maybe_hf_err() + .await? + .json() + .await?; + + Ok(model_info) + } +} + +use http::{HeaderMap, StatusCode}; +use serde::{Deserialize, Serialize}; +use std::collections::HashMap; + +#[derive(Debug, Serialize, Deserialize)] +pub struct ModelInfo { + #[serde(default)] + pub _id: Option, + + #[serde(default)] + #[serde(alias = "modelId")] + pub model_id: Option, + + pub id: String, + + #[serde(default)] + pub author: Option, + + #[serde(default)] + pub sha: Option, + + #[serde(default)] + #[serde(alias = "createdAt", alias = "created_at")] + pub created_at: Option, + + #[serde(default)] + #[serde(alias = "lastModified", alias = "last_modified")] + pub last_modified: Option, + + #[serde(default)] + pub private: Option, + + #[serde(default)] + pub disabled: Option, + + #[serde(default)] + pub downloads: Option, + + #[serde(default)] + #[serde(alias = "downloadsAllTime")] + pub downloads_all_time: Option, + + #[serde(default)] + pub gated: Option, + + #[serde(default)] + pub gguf: Option>, + + #[serde(default)] + pub inference: Option, + + #[serde(default)] + pub likes: Option, + + #[serde(default)] + pub library_name: Option, + + #[serde(default)] + pub tags: Option>, + + #[serde(default)] + pub pipeline_tag: Option, + + #[serde(default)] + pub mask_token: Option, + + #[serde(default)] + #[serde(alias = "cardData", alias = "card_data")] + pub card_data: Option, + + #[serde(default)] + #[serde(alias = "widgetData")] + pub widget_data: Option, + + #[serde(default)] + #[serde(alias = "model-index", alias = "model_index")] + pub model_index: Option>, + + #[serde(default)] + pub config: Option>, + + #[serde(default)] + #[serde(alias = "transformersInfo", alias = "transformers_info")] + pub transformers_info: Option, + + #[serde(default)] + #[serde(alias = "trendingScore")] + pub trending_score: Option, + + #[serde(default)] + pub siblings: Option>, + + #[serde(default)] + pub spaces: Option>, + + #[serde(default)] + pub safetensors: Option, +} + +#[derive(Debug, Serialize)] +#[serde(rename_all = "lowercase")] +pub enum GatedStatus { + Auto, + Manual, + False, +} + +impl<'de> Deserialize<'de> for GatedStatus { + fn deserialize(deserializer: D) -> Result + where + D: serde::Deserializer<'de>, + { + struct GatedStatusVisitor; + + impl<'de> serde::de::Visitor<'de> for GatedStatusVisitor { + type Value = GatedStatus; + + fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result { + formatter.write_str("a string \"auto\", \"manual\", or boolean false") + } + + fn visit_bool(self, value: bool) -> Result + where + E: serde::de::Error, + { + if !value { + Ok(GatedStatus::False) + } else { + Err(E::custom("expected false")) + } + } + + fn visit_str(self, value: &str) -> Result + where + E: serde::de::Error, + { + match value { + "auto" => Ok(GatedStatus::Auto), + "manual" => Ok(GatedStatus::Manual), + _ => Err(E::custom("invalid value")), + } + } + } + + deserializer.deserialize_any(GatedStatusVisitor) + } +} + +#[derive(Debug, Serialize, Deserialize)] +#[serde(rename_all = "lowercase")] +pub enum InferenceStatus { + Warm, + Cold, + Frozen, +} + +#[derive(Debug, Serialize, Deserialize)] +pub struct RepoSibling { + pub rfilename: String, + #[serde(default)] + pub size: Option, + #[serde(alias = "blobId")] + #[serde(default)] + pub blob_id: Option, + #[serde(default)] + pub lfs: Option, +} + +#[derive(Debug, Serialize, Deserialize)] +pub struct BlobLfsInfo { + pub size: i64, + pub sha256: String, + #[serde(alias = "pointerSize")] + pub pointer_size: i64, +} + +#[derive(Debug, Serialize, Deserialize)] +pub struct SafeTensorsInfo { + pub parameters: i64, + pub total: i64, +} + +// Note: You'll need to implement ModelCardData and TransformersInfo structs separately +#[derive(Debug, Serialize, Deserialize)] +pub struct ModelCardData { + // Add fields as needed +} + +#[derive(Debug, Serialize, Deserialize)] +pub struct TransformersInfo { + // Add fields as needed +} diff --git a/src/api/tokio/upload/commit_api.rs b/src/api/tokio/upload/commit_api.rs new file mode 100644 index 0000000..d21ae93 --- /dev/null +++ b/src/api/tokio/upload/commit_api.rs @@ -0,0 +1,923 @@ +use base64::{engine::general_purpose::STANDARD as BASE64, Engine as _}; +use lazy_static::lazy_static; +use log::warn; +use regex::Regex; +use reqwest::header::HeaderMap; +use serde::{Deserialize, Serialize}; +use sha1::Sha1; +use sha2::{Digest, Sha256}; +use std::collections::HashMap; +use std::fmt::Debug; +use std::fs; +use std::path::{Path, PathBuf}; +use thiserror::Error; +use tokio::fs::{read_to_string, File}; +use tokio::io::{self, AsyncRead, AsyncReadExt, BufReader}; + +use crate::api::tokio::upload::lfs::lfs_upload; +use crate::api::tokio::{ApiError, ApiRepo, HfBadResponse}; + +use super::commit_info::{CommitInfo, InvalidHfIdError}; + +const CHUNK_SIZE: usize = 8192; // 8KB chunks for streaming +const SAMPLE_SIZE: usize = 1024; // 1KB sample size + +#[derive(Debug, Clone, PartialEq, Eq)] +pub enum UploadMode { + Lfs, + Regular, +} + +#[derive(Debug, Serialize)] +struct PreuploadFile { + path: String, + sample: String, + size: u64, +} + +#[derive(Debug, Serialize)] +struct PreuploadRequest { + files: Vec, + #[serde(skip_serializing_if = "Option::is_none")] + git_ignore: Option, +} + +#[derive(Debug, Deserialize)] +#[serde(rename_all = "camelCase")] +struct PreuploadResponseFile { + path: String, + upload_mode: String, + should_ignore: bool, + oid: Option, +} + +#[derive(Debug, Deserialize)] +struct PreuploadResponse { + files: Vec, +} + +pub struct UploadInfo { + pub size: u64, + pub sample: Vec, + pub sha256: Vec, +} + +impl Debug for UploadInfo { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("UploadInfo") + .field("size", &self.size) + .field("sample", &hex::encode(&self.sample)) + .field("sha256", &hex::encode(&self.sha256)) + .finish() + } +} + +async fn process_stream(mut reader: R, size: u64) -> io::Result +where + R: AsyncRead + Unpin, +{ + let mut sample = vec![0u8; SAMPLE_SIZE.min(size as usize)]; + reader.read_exact(&mut sample).await?; + + let mut hasher = Sha256::new(); + hasher.update(&sample); // Hash the sample bytes too + let mut total_bytes = sample.len() as u64; // Start with `sample` size + let mut buffer = vec![0u8; CHUNK_SIZE]; + + loop { + let bytes_read = reader.read(&mut buffer).await?; + if bytes_read == 0 { + break; + } + hasher.update(&buffer[..bytes_read]); + total_bytes += bytes_read as u64; + } + + Ok(UploadInfo { + size: total_bytes, + sample, + sha256: hasher.finalize().to_vec(), + }) +} + +impl UploadInfo { + pub async fn from_file(path: &Path) -> io::Result { + let file = File::open(path).await?; + let metadata = file.metadata().await?; + let size = metadata.len(); + + let reader = BufReader::with_capacity(CHUNK_SIZE, file); + process_stream(reader, size).await + } + + pub async fn from_bytes(bytes: &[u8]) -> io::Result { + let cursor = std::io::Cursor::new(bytes); + process_stream(cursor, bytes.len() as u64).await + } +} + +#[derive(Debug)] +pub struct CommitOperationAdd { + pub path_in_repo: String, + pub upload_info: UploadInfo, + pub upload_mode: UploadMode, + pub should_ignore: bool, + pub remote_oid: Option, + // Store the source for streaming + pub(crate) source: UploadSource, +} + +/// Represents different sources for upload data. +/// +/// # Examples +/// +/// ``` +/// use std::path::PathBuf; +/// +/// let file_source = UploadSource::File(PathBuf::from("path/to/file.txt")); +/// let bytes_source = UploadSource::Bytes(vec![1, 2, 3, 4]); +/// let empty_source = UploadSource::Emptied; +/// ``` +pub enum UploadSource { + /// Contains a file path from which to read the upload data + File(PathBuf), + /// Contains the upload data directly as a byte vector + Bytes(Vec), + /// Represents a state where the upload source has been consumed or cleared + Emptied, +} + +impl Debug for UploadSource { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + Self::File(arg0) => f.debug_tuple("File").field(arg0).finish(), + Self::Bytes(arg0) => f + .debug_tuple("Bytes") + .field(&format!("{} bytes", arg0.len())) + .finish(), + Self::Emptied => write!(f, "Emptied"), + } + } +} + +impl From<&Path> for UploadSource { + fn from(value: &Path) -> Self { + Self::File(value.into()) + } +} + +impl From for UploadSource { + fn from(value: PathBuf) -> Self { + Self::File(value) + } +} + +impl From> for UploadSource { + fn from(value: Vec) -> Self { + Self::Bytes(value) + } +} + +#[derive(Deserialize)] +#[serde(rename_all = "camelCase")] +struct CommitData { + commit_url: String, + commit_oid: String, + #[serde(default)] + pull_request_url: Option, +} + +impl CommitOperationAdd { + pub async fn from_file(path_in_repo: String, file_path: &Path) -> io::Result { + let upload_info = UploadInfo::from_file(file_path).await?; + Ok(Self { + path_in_repo, + upload_info, + upload_mode: UploadMode::Regular, + should_ignore: false, + remote_oid: None, + source: file_path.into(), + }) + } + + pub async fn from_bytes(path_in_repo: String, bytes: Vec) -> io::Result { + let upload_info = UploadInfo::from_bytes(&bytes).await?; + Ok(Self { + path_in_repo, + upload_info, + upload_mode: UploadMode::Regular, + should_ignore: false, + remote_oid: None, + source: UploadSource::Bytes(bytes), + }) + } + pub async fn from_upload_source( + path_in_repo: String, + upload_source: UploadSource, + ) -> io::Result { + match upload_source { + UploadSource::Emptied => Err(io::Error::new( + io::ErrorKind::NotFound, + "upload source was empty.".to_string(), + )), + UploadSource::Bytes(bytes) => CommitOperationAdd::from_bytes(path_in_repo, bytes).await, + UploadSource::File(file) => CommitOperationAdd::from_file(path_in_repo, &file).await, + } + } + + /// Return the OID of the local file. + /// + /// This OID is then compared to `self._remote_oid` to check if the file has changed compared to the remote one. + /// If the file did not change, we won't upload it again to prevent empty commits. + /// + /// For LFS files, the OID corresponds to the SHA256 of the file content (used a LFS ref). + /// + /// For regular files, the OID corresponds to the SHA1 of the file content. + /// + /// Note: this is slightly different to git OID computation since the oid of an LFS file is usually the git-SHA1 of the + /// pointer file content (not the actual file content). However, using the SHA256 is enough to detect changes + /// and more convenient client-side. + pub fn local_oid(&self) -> Option { + match self.upload_mode { + UploadMode::Lfs => Some(hex::encode(&self.upload_info.sha256)), + UploadMode::Regular => { + let data = match &self.source { + UploadSource::Bytes(b) => b, + UploadSource::Emptied => return None, + UploadSource::File(f) => &fs::read(f).unwrap(), + }; + Some(git_hash(data)) + } + } + } +} + +lazy_static! { + static ref REGEX_COMMIT_OID: Regex = Regex::new(r"[A-Fa-f0-9]{5,40}").unwrap(); +} + +#[derive(Debug)] +pub enum CommitOperation { + Add(CommitOperationAdd), +} + +impl From for CommitOperation { + fn from(value: CommitOperationAdd) -> Self { + Self::Add(value) + } +} + +#[derive(Debug, Error)] +pub enum CommitError { + #[error("no commit message passed")] + NoMessage, + #[error("invalid OID for parent commit")] + InvalidOid, + #[error("failed to parse huggingface ID: {0}")] + InvalidHuggingFaceId(#[from] InvalidHfIdError), + #[error("error from HF api: {0}")] + Api(#[from] ApiError), + #[error("i/o error: {0}")] + Io(#[from] io::Error), +} + +impl ApiRepo { + /// Creates a commit in the given repo, deleting & uploading files as needed. + /// + /// # Arguments + /// + /// * `operations` - Vector of operations to include in the commit (Add, Delete, Copy) + /// * `commit_message` - The summary (first line) of the commit + /// * `commit_description` - Optional description of the commit + /// * `revision` - The git revision to commit from (defaults to "main") + /// * `create_pr` - Whether to create a Pull Request + /// * `num_threads` - Number of concurrent threads for uploading files + /// * `parent_commit` - The OID/SHA of the parent commit + /// + /// # Returns + /// + /// Returns CommitInfo containing information about the newly created commit + pub async fn create_commit( + &self, + operations: Vec, + commit_message: String, + commit_description: Option, + create_pr: Option, + num_threads: Option, + parent_commit: Option, + ) -> Result { + // Validate inputs + if commit_message.is_empty() { + return Err(CommitError::NoMessage); + } + + if let Some(parent) = &parent_commit { + if !REGEX_COMMIT_OID.is_match(parent) { + return Err(CommitError::InvalidOid); + } + } + + log::trace!( + "create_commit got {} operations: {:?}", + operations.len(), + operations + ); + + let commit_description = commit_description.unwrap_or_default(); + let create_pr = create_pr.unwrap_or(false); + let num_threads = num_threads.unwrap_or(5); + + // Warn on overwriting operations + warn_on_overwriting_operations(&operations); + + // Split operations by type + let additions: Vec<_> = operations + .into_iter() + .map(|op| match op { + CommitOperation::Add(add) => add, + }) + .collect(); + + // todo copy + // let copies: Vec<_> = operations + // .iter() + // .filter_map(|op| match op { + // // todo one day + // // CommitOperation::Copy(copy) => Some(copy), + // _ => None, + // }) + // .collect(); + // let deletions = operations.len() - additions.len() - copies.len(); + + log::debug!( + "About to commit to the hub: {} addition(s), {} copie(s) and {} deletion(s).", + additions.len(), + 0, + 0 // copies.len(), + // deletions + ); + + // TODO Validate README.md metadata if present + + // Pre-upload LFS files + let additions = self + .preupload_lfs_files(additions, Some(create_pr), Some(num_threads), None) + .await + .map_err(CommitError::Api)?; + + // re-collect into operations, after lfs upload. + let operations: Vec = additions.into_iter().map(|a| a.into()).collect(); + log::trace!( + "after preuploading lfs files, have {} operations: {:?}", + operations.len(), + operations + ); + // Remove no-op operations + let operations: Vec<_> = operations + .into_iter() + .filter(|op| match op { + CommitOperation::Add(add) => { + if let (Some(remote_oid), Some(local_oid)) = (&add.remote_oid, &add.local_oid()) + { + if remote_oid == local_oid { + log::debug!( + "Skipping upload for '{}' as the file has not changed.", + add.path_in_repo + ); + return false; + } + } + true + } + }) + .collect(); + + if operations.is_empty() { + log::warn!( + "No files have been modified since last commit. Skipping to prevent empty commit." + ); + // Return latest commit info + let info = self.repo_info().await?; + let sha = info + .sha() + .ok_or(CommitError::Api(ApiError::InvalidResponse( + "no SHA returned from repo info".to_string(), + )))? + .to_string(); + return Ok(CommitInfo::new( + &format!("{}/{}/commit/{}", self.api.endpoint, self.repo.repo_id, sha), + &commit_description, + &commit_message, + sha, + )?); + } + + // Prepare and send commit + // TODO add copy + // let files_to_copy = self.fetch_files_to_copy(&copies, &revision).await?; + + let commit_payload = prepare_commit_payload( + &operations, + // TODO: &files_to_copy, + &commit_message, + &commit_description, + parent_commit.as_deref(), + ); + log::trace!( + "commit payload: {}", + serde_json::to_string(&commit_payload).unwrap() + ); + + let mut headers = HeaderMap::new(); + headers.insert("Content-Type", "application/x-ndjson".parse().unwrap()); + + let url = format!( + "{}/api/{}s/{}/commit/{}", + self.api.endpoint, + self.repo.repo_type.to_string(), + self.repo.url(), + self.repo.revision + ); + let mut params = HashMap::new(); + if create_pr { + params.insert("create_pr", "1"); + } + + let serialized_payload: Vec = payload_as_ndjson(commit_payload).flatten().collect(); + + let response = self + .api + .client + .post(&url) + .headers(headers) + .query(¶ms) + .body(serialized_payload) + .send() + .await + .map_err(ApiError::from)? + .maybe_hf_err() + .await + .map_err(ApiError::from)?; + + let commit_data: CommitData = response.json().await.map_err(|e| { + CommitError::Api(ApiError::InvalidResponse(format!( + "Failed to parse json from commit API: {e}" + ))) + })?; + let mut commit_info = CommitInfo::new( + &commit_data.commit_url, + &commit_description, + &commit_message, + commit_data.commit_oid, + ) + .map_err(|e| { + CommitError::Api(ApiError::InvalidResponse(format!( + "Bad commit data returned from API: {e}" + ))) + })?; + + if create_pr { + if let Some(pr_url) = commit_data.pull_request_url { + commit_info.set_pr_info(&pr_url).map_err(|_| { + CommitError::Api(ApiError::InvalidResponse(format!( + "Invalid PR URL {pr_url}" + ))) + })?; + } + } + + Ok(commit_info) + } + /// Pre-upload LFS files to S3 in preparation for a future commit. + pub async fn preupload_lfs_files( + &self, + mut additions: Vec, + create_pr: Option, + num_threads: Option, + gitignore_content: Option, + ) -> Result, ApiError> { + // Set default values + let create_pr = create_pr.unwrap_or(false); + let num_threads = num_threads.unwrap_or(5); + + // Check for gitignore content in additions if not provided + let mut gitignore_content = gitignore_content; + if gitignore_content.is_none() { + for addition in &additions { + if addition.path_in_repo == ".gitignore" { + gitignore_content = Some(read_to_string(addition.path_in_repo.clone()).await?); + break; + } + } + } + + // Fetch upload modes for new files + self.fetch_and_apply_upload_modes(&mut additions, create_pr, gitignore_content) + .await?; + + // Filter LFS files + let (lfs_files, small_files): (Vec<_>, Vec<_>) = additions + .into_iter() + .partition(|addition| addition.upload_mode == UploadMode::Lfs); + + // Filter out ignored files + let mut new_lfs_additions_to_upload = Vec::new(); + let mut ignored_count = 0; + for addition in lfs_files { + if addition.should_ignore { + ignored_count += 1; + log::debug!( + "Skipping upload for LFS file '{}' (ignored by gitignore file).", + addition.path_in_repo + ); + } else { + new_lfs_additions_to_upload.push(addition); + } + } + + if ignored_count > 0 { + log::info!( + "Skipped upload for {} LFS file(s) (ignored by gitignore file).", + ignored_count + ); + } + + // Upload LFS files + let uploaded_lfs_files = self + .upload_lfs_files(new_lfs_additions_to_upload, num_threads) + .await?; + Ok(small_files.into_iter().chain(uploaded_lfs_files).collect()) + } + + /// Requests the Hub "preupload" endpoint to determine whether each input file should be uploaded as a regular git blob + /// or as git LFS blob. Input `additions` are mutated in-place with the upload mode. + pub async fn fetch_and_apply_upload_modes( + &self, + additions: &mut Vec, + create_pr: bool, + gitignore_content: Option, + ) -> Result<(), ApiError> { + // Process in chunks of 256 + for chunk in additions.chunks_mut(256) { + let files: Vec = chunk + .iter() + .map(|op| PreuploadFile { + path: op.path_in_repo.clone(), + sample: BASE64.encode(&op.upload_info.sample), + size: op.upload_info.size, + }) + .collect(); + + let payload = PreuploadRequest { + files, + git_ignore: gitignore_content.clone(), + }; + + let mut url = self.preupload_url(); + + if create_pr { + url.push_str("?create_pr=1"); + } + + let preupload_info: PreuploadResponse = self + .api + .client + .post(&url) + .json(&payload) + .send() + .await? + .maybe_hf_err() + .await? + .json() + .await?; + + // Update the operations with the response information + for file_info in preupload_info.files { + if let Some(op) = chunk + .iter_mut() + .find(|op| op.path_in_repo == file_info.path) + { + op.upload_mode = match file_info.upload_mode.as_str() { + "lfs" => UploadMode::Lfs, + "regular" => UploadMode::Regular, + m => { + return Err(ApiError::InvalidResponse(format!( + "Bad upload mode {m} returned from preupload info." + ))) + } + }; + op.should_ignore = file_info.should_ignore; + op.remote_oid = file_info.oid; + } + } + } + + // Handle empty files + for addition in additions.iter_mut() { + if addition.upload_info.size == 0 { + addition.upload_mode = UploadMode::Regular; + } + } + + Ok(()) + } + + /// Uploads the content of `additions` to the Hub using the large file storage protocol. + /// Relevant external documentation: + /// - LFS Batch API: https://github.com/git-lfs/git-lfs/blob/main/docs/api/batch.md + + async fn upload_lfs_files( + &self, + additions: Vec, + num_threads: usize, + ) -> Result, ApiError> { + // Step 1: Retrieve upload instructions from LFS batch endpoint + let mut batch_objects = Vec::new(); + + // Process in chunks of 256 files + for chunk in additions.chunks(256) { + let mut batch_info = self.post_lfs_batch_info(chunk).await?; + let errs: Vec<_> = batch_info + .iter_mut() + .flat_map(|b| b.error.take().map(|e| (b.oid.clone(), e))) + .collect(); + if !errs.is_empty() { + return Err(ApiError::InvalidResponse( + errs.into_iter() + .map(|e| { + format!( + "Encountered error for file with OID {}: `{}`)", + e.0, e.1.message + ) + }) + .collect::>() + .join("\n"), + )); + } + batch_objects.extend(batch_info); + } + + // Create mapping of OID to addition operation + let mut oid2addop: HashMap = additions + .into_iter() + .map(|op| { + let oid = hex::encode(&op.upload_info.sha256); + (oid, op) + }) + .collect(); + + // Step 2: Filter out already uploaded files + let filtered_actions: Vec<_> = batch_objects + .into_iter() + .filter(|action| { + if action.actions.is_none() { + if let Some(op) = oid2addop.get(&action.oid) { + log::debug!( + "Content of file {} is already present upstream - skipping upload.", + op.path_in_repo + ); + } + false + } else { + true + } + }) + .collect(); + + if filtered_actions.is_empty() { + log::debug!("No LFS files to upload."); + return Ok(oid2addop.into_values().collect()); + } + + let s3_client = reqwest::Client::new(); + + // Step 3: Upload files concurrently + let endpoint = self.api.endpoint.clone(); + let upload_futures: Vec<_> = filtered_actions + .into_iter() + .map(|batch_action| { + let operation = oid2addop.remove(&batch_action.oid).unwrap(); + lfs_upload( + self.api.client.clone(), + s3_client.clone(), + operation, + batch_action, + endpoint.clone(), + ) + }) + .collect(); + + log::debug!( + "Uploading {} LFS files to the Hub using up to {} threads concurrently", + upload_futures.len(), + num_threads + ); + + // Use tokio::spawn to handle concurrent uploads + let handles: Vec<_> = upload_futures + .into_iter() + .map(|future| tokio::spawn(future)) + .collect(); + + let mut operations: Vec<_> = oid2addop.drain().map(|(_k, v)| v).collect(); + for handle in handles { + log::trace!("joining handle.."); + operations.push(handle.await.map_err(|e| { + ApiError::IoError(std::io::Error::new( + std::io::ErrorKind::BrokenPipe, + format!("failed to join lfs upload thread {e}"), + )) + })??); + } + + log::debug!("Uploaded {} LFS files to the Hub.", operations.len(),); + + Ok(operations) + } +} + +/// Computes the git-sha1 hash of the given bytes, using the same algorithm as git. +/// +/// This is equivalent to running `git hash-object`. See https://git-scm.com/docs/git-hash-object +/// for more details. +/// +/// Note: this method is valid for regular files. For LFS files, the proper git hash is supposed to be computed on the +/// pointer file content, not the actual file content. However, for simplicity, we directly compare the sha256 of +/// the LFS file content when we want to compare LFS files. +/// +/// # Returns +/// +/// The git-hash of `data` as a hexadecimal string. +/// +/// # Example +/// +/// ``` +/// let hash = git_hash(b"Hello, World!"); +/// assert_eq!(hash, "b45ef6fec89518d314f546fd6c3025367b721684"); +/// ``` +pub fn git_hash(data: &[u8]) -> String { + let mut hasher = Sha1::new(); + + // Add header + hasher.update(b"blob "); + hasher.update(data.len().to_string().as_bytes()); + hasher.update(b"\0"); + + // Add data + hasher.update(data); + + // Convert to hex string + format!("{:x}", hasher.finalize()) +} + +/// Warn user when a list of operations is expected to overwrite itself in a single +/// commit. +/// +/// Rules: +/// - If a filepath is updated by multiple `CommitOperationAdd` operations, a warning +/// message is triggered. +/// - If a filepath is updated at least once by a `CommitOperationAdd` and then deleted +/// by a `CommitOperationDelete`, a warning is triggered. +/// - If a `CommitOperationDelete` deletes a filepath that is then updated by a +/// `CommitOperationAdd`, no warning is triggered. This is usually useless (no need to +/// delete before upload) but can happen if a user deletes an entire folder and then +/// add new files to it. +fn warn_on_overwriting_operations(operations: &[CommitOperation]) { + let mut nb_additions_per_path: HashMap = HashMap::new(); + for operation in operations { + // i know it's irrefutable, but later we're gonna add more operations, so we'll if-let or match then. + let CommitOperation::Add(CommitOperationAdd { path_in_repo, .. }) = operation; + { + if nb_additions_per_path.contains_key(path_in_repo) { + warn!( + "About to update multiple times the same file in the same commit: '{path_in_repo}'. This can cause undesired inconsistencies in your repo." + ); + } + + *nb_additions_per_path + .entry(path_in_repo.clone()) + .or_insert(0) += 1 + } + } +} + +#[derive(Serialize, Debug)] +struct HeaderValue { + summary: String, + description: String, + #[serde(skip_serializing_if = "Option::is_none")] + parent_commit: Option, +} + +#[derive(Serialize, Debug)] +struct FileValue { + content: String, + path: String, + encoding: String, +} + +#[derive(Serialize, Debug)] +struct LfsFileValue { + path: String, + algo: String, + oid: String, + size: u64, +} + +// todo add +// #[derive(Serialize)] +// struct DeletedValue { +// path: String, +// } + +#[derive(Serialize, Debug)] +#[serde(tag = "key", content = "value")] +#[serde(rename_all = "camelCase")] +enum CommitPayloadItem { + Header(HeaderValue), + File(FileValue), + LfsFile(LfsFileValue), + // todo add + // DeletedFile(DeletedValue), + // DeletedFolder(DeletedValue), +} + +fn prepare_commit_payload( + operations: &[CommitOperation], + // TODO: add copy functionality + // files_to_copy: &[], + commit_message: &str, + commit_description: &str, + parent_commit: Option<&str>, +) -> Vec { + let mut payload = Vec::new(); + + // 1. Send header item with commit metadata + payload.push(CommitPayloadItem::Header(HeaderValue { + summary: commit_message.to_string(), + description: commit_description.to_string(), + parent_commit: parent_commit.map(String::from), + })); + + let mut nb_ignored_files = 0; + + // 2. Send operations, one per line + for operation in operations { + match operation { + // 2.a and 2.b: Adding files (regular or LFS) + CommitOperation::Add(add_op) => { + if add_op.should_ignore { + log::debug!( + "Skipping file '{}' in commit (ignored by gitignore file).", + add_op.path_in_repo + ); + nb_ignored_files += 1; + continue; + } + + match &add_op.upload_mode { + UploadMode::Regular => { + let content = match &add_op.source { + UploadSource::Bytes(bytes) => BASE64.encode(bytes), + UploadSource::File(path) => { + BASE64.encode(std::fs::read(path).unwrap()) // TODO: proper error handling + } + UploadSource::Emptied => continue, + }; + + payload.push(CommitPayloadItem::File(FileValue { + content, + path: add_op.path_in_repo.clone(), + encoding: "base64".to_string(), + })); + } + UploadMode::Lfs => { + payload.push(CommitPayloadItem::LfsFile(LfsFileValue { + path: add_op.path_in_repo.clone(), + algo: "sha256".to_string(), + oid: hex::encode(&add_op.upload_info.sha256), + size: add_op.upload_info.size, + })); + } + } + } // TODO: Add other operations when implemented + } + } + + if nb_ignored_files > 0 { + log::info!( + "Skipped {} file(s) in commit (ignored by gitignore file).", + nb_ignored_files + ); + } + + payload +} + +fn payload_as_ndjson(payload: Vec) -> impl Iterator> { + payload.into_iter().flat_map(|item| { + let mut json = serde_json::to_vec(&item).unwrap(); + json.push(b'\n'); + vec![json] + }) +} diff --git a/src/api/tokio/upload/commit_info.rs b/src/api/tokio/upload/commit_info.rs new file mode 100644 index 0000000..8bea425 --- /dev/null +++ b/src/api/tokio/upload/commit_info.rs @@ -0,0 +1,244 @@ +use std::{error::Error, fmt, num::ParseIntError}; + +use crate::RepoType; +use lazy_static::lazy_static; +use regex::Regex; + +#[derive(Debug)] +pub struct RepoUrl { + pub endpoint: String, + pub namespace: Option, + pub repo_name: String, + pub repo_id: String, + pub repo_type: Option, + pub url: String, +} + +const HF_DEFAULT_ENDPOINT: &str = "https://huggingface.co"; +const HF_DEFAULT_STAGING_ENDPOINT: &str = "https://hub-ci.huggingface.co"; + +impl RepoUrl { + pub fn new(url: &str) -> Result { + Self::new_with_endpoint(url, HF_DEFAULT_ENDPOINT) + } + pub fn new_with_endpoint(url: &str, endpoint: &str) -> Result { + let url = fix_hf_endpoint_in_url(url, endpoint); + let (repo_type, namespace, repo_name) = + repo_type_and_name_from_hf_id(&url, Some(endpoint))?; + let repo_id = if let Some(ns) = &namespace { + format!("{ns}/{repo_name}") + } else { + repo_name.clone() + }; + + Ok(Self { + url, + endpoint: endpoint.into(), + namespace, + repo_id, + repo_type, + repo_name, + }) + } +} + +#[derive(Debug)] +pub struct InvalidHfIdError(String); + +impl fmt::Display for InvalidHfIdError { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!( + f, + "Unable to retrieve user and repo ID from the passed HF ID: {}", + self.0 + ) + } +} + +impl Error for InvalidHfIdError {} + +/// Returns the repo type and ID from a huggingface.co URL linking to a repository +/// +/// # Arguments +/// +/// * `hf_id` - An URL or ID of a repository on the HF hub. Accepted values are: +/// - https://huggingface.co/// +/// - https://huggingface.co// +/// - hf://// +/// - hf:/// +/// - // +/// - / +/// - +/// * `hub_url` - The URL of the HuggingFace Hub, defaults to https://huggingface.co +/// +/// # Returns +/// +/// A tuple with three items: (repo_type, namespace, repo_name) +fn repo_type_and_name_from_hf_id( + hf_id: &str, + hub_url: Option<&str>, +) -> Result<(Option, Option, String), InvalidHfIdError> { + let hub_url = hub_url.unwrap_or(HF_DEFAULT_ENDPOINT); + let hub_url = Regex::new(r"https?://") + .unwrap() + .replace(hub_url, "") + .into_owned(); + + let is_hf_url = hf_id.contains(&hub_url) && !hf_id.contains('@'); + + const HFFS_PREFIX: &str = "hf://"; + let hf_id = hf_id.strip_prefix(HFFS_PREFIX).unwrap_or(hf_id); + + let url_segments: Vec<&str> = hf_id.split('/').collect(); + let is_hf_id = url_segments.len() <= 3; + + let (repo_type, namespace, repo_id) = if is_hf_url { + let (namespace, repo_id) = ( + url_segments[url_segments.len() - 2], + url_segments.last().unwrap(), + ); + let namespace = if namespace == hub_url { + None + } else { + Some(namespace.to_string()) + }; + + let repo_type: Option = + if url_segments.len() > 2 && !url_segments[url_segments.len() - 3].contains(&hub_url) { + url_segments[url_segments.len() - 3] + .to_string() + .parse() + .ok() + } else { + namespace + .clone() + .unwrap_or("".to_string()) + .parse::() + .ok() + }; + + (repo_type, namespace, repo_id.to_string()) + } else if is_hf_id { + match url_segments.len() { + 3 => { + let (repo_type, namespace, repo_id) = ( + url_segments[0].parse().ok(), + Some(url_segments[1].to_string()), + url_segments[2].to_string(), + ); + (repo_type, namespace, repo_id) + } + 2 => { + if let Ok(repo_type) = url_segments[0].parse() { + (Some(repo_type), None, url_segments[1].to_string()) + } else { + ( + None, + Some(url_segments[0].to_string()), + url_segments[1].to_string(), + ) + } + } + 1 => (None, None, url_segments[0].to_string()), + _ => return Err(InvalidHfIdError(hf_id.to_string())), + } + } else { + return Err(InvalidHfIdError(hf_id.to_string())); + }; + + Ok((repo_type, namespace, repo_id)) +} + +/// Replace the default endpoint in a URL by a custom one. +/// This is useful when using a proxy and the Hugging Face Hub returns a URL with the default endpoint. +pub fn fix_hf_endpoint_in_url(url: &str, endpoint: &str) -> String { + // check if a proxy has been set => if yes, update the returned URL to use the proxy + let mut url = url.to_string(); + if endpoint != HF_DEFAULT_ENDPOINT { + url = url.replace(HF_DEFAULT_ENDPOINT, endpoint); + } else if endpoint != HF_DEFAULT_STAGING_ENDPOINT { + url = url.replace(HF_DEFAULT_STAGING_ENDPOINT, endpoint); + } + url +} + +/// Data structure containing information about a newly created commit. +/// Returned by any method that creates a commit on the Hub. +#[derive(Debug)] +pub struct CommitInfo { + /// Url where to find the commit. + pub commit_url: String, + /// The summary (first line) of the commit that has been created. + pub commit_message: String, + /// Description of the commit that has been created. Can be empty. + pub commit_description: String, + /// Commit hash id. Example: `"91c54ad1727ee830252e457677f467be0bfd8a57"`. + pub oid: String, + + /// Repo URL of the commit containing info like repo_id, repo_type, etc. + pub repo_url: RepoUrl, + + // Info about the associated pull request + pub pull_request: Option, +} +#[derive(Debug)] +pub struct PullRequestInfo { + pub url: String, + pub revision: String, + pub num: u32, +} + +impl PullRequestInfo { + fn new(pr_url: &str) -> Result { + let pr_revision = parse_revision_from_pr_url(pr_url); + let pr_num: u32 = pr_revision.split("/").last().unwrap().parse()?; + Ok(PullRequestInfo { + num: pr_num, + revision: pr_revision, + url: pr_url.into(), + }) + } +} + +lazy_static! { + static ref REGEX_DISCUSSION_URL: Regex = Regex::new(r".*/discussions/(\d+)$").unwrap(); +} + +/// Safely parse revision number from a PR url. +/// # Example +/// ``` +/// assert_eq!(parse_revision_from_pr_url("https://huggingface.co/bigscience/bloom/discussions/2"), "refs/pr/2"); +/// ``` +fn parse_revision_from_pr_url(pr_url: &str) -> String { + let re_match = REGEX_DISCUSSION_URL.captures(pr_url).unwrap_or_else(|| { + panic!( + "Unexpected response from the hub, expected a Pull Request URL but got: '{}'", + pr_url + ) + }); + + format!("refs/pr/{}", &re_match[1]) +} + +impl CommitInfo { + pub fn new( + url: &str, + commit_description: &str, + commit_message: &str, + oid: String, + ) -> Result { + Ok(Self { + commit_url: url.into(), + commit_description: commit_description.into(), + commit_message: commit_message.into(), + oid, + pull_request: None, + repo_url: RepoUrl::new(url)?, + }) + } + + pub fn set_pr_info(&mut self, pr_url: &str) -> Result<(), ParseIntError> { + self.pull_request = Some(PullRequestInfo::new(pr_url)?); + Ok(()) + } +} diff --git a/src/api/tokio/upload/completion_payload.rs b/src/api/tokio/upload/completion_payload.rs new file mode 100644 index 0000000..e963f21 --- /dev/null +++ b/src/api/tokio/upload/completion_payload.rs @@ -0,0 +1,40 @@ +use http::HeaderMap; +use serde::Serialize; + +#[derive(Debug, Serialize)] +pub struct CompletionPayload { + oid: String, + parts: Vec, +} + +#[derive(Debug, Serialize)] +struct PayloadPart { + #[serde(rename = "partNumber")] + part_number: u32, + etag: String, +} + +pub fn get_completion_payload(response_headers: &[HeaderMap], sha256: &[u8]) -> CompletionPayload { + let parts: Vec = response_headers + .iter() + .enumerate() + .map(|(part_number, headers)| { + let etag = headers + .get("etag") + .and_then(|h| h.to_str().ok()) + .filter(|&s| !s.is_empty()) + .ok_or_else(|| format!("Invalid etag returned for part {}", part_number + 1)) + .unwrap(); // You might want to handle this error differently + + PayloadPart { + part_number: (part_number + 1) as u32, + etag: etag.to_string(), + } + }) + .collect(); + + CompletionPayload { + oid: hex::encode(sha256), + parts, + } +} diff --git a/src/api/tokio/upload/lfs.rs b/src/api/tokio/upload/lfs.rs new file mode 100644 index 0000000..3184d32 --- /dev/null +++ b/src/api/tokio/upload/lfs.rs @@ -0,0 +1,386 @@ +use http::HeaderMap; +use reqwest::Client; +use serde::{Deserialize, Serialize}; +use std::{collections::HashMap, io::ErrorKind}; +use tokio::io::{AsyncReadExt, AsyncSeekExt}; + +use crate::{ + api::tokio::{ApiError, ApiRepo, HfBadResponse}, + RepoType, +}; + +use super::{ + commit_api::{CommitOperationAdd, UploadInfo, UploadSource}, + commit_info::fix_hf_endpoint_in_url, + completion_payload::get_completion_payload, +}; + +#[derive(Debug, Deserialize)] +pub struct BatchInfo { + pub objects: Vec, +} + +#[derive(Debug, Deserialize)] +pub struct BatchObject { + pub oid: String, + pub size: u64, + #[serde(default)] + pub error: Option, + pub actions: Option, +} + +#[derive(Debug, Deserialize)] +pub struct BatchError { + pub code: i32, + pub message: String, +} + +#[derive(Debug, Deserialize)] +pub struct BatchActions { + pub upload: Option, + pub verify: Option, +} + +#[derive(Debug, Deserialize)] +pub struct LfsAction { + pub href: String, + pub header: Option>, +} + +#[derive(Debug, Serialize)] +struct BatchRequest { + operation: String, + transfers: Vec, + objects: Vec, + hash_algo: String, + #[serde(skip_serializing_if = "Option::is_none")] + r#ref: Option, +} + +#[derive(Debug, Serialize)] +struct BatchRequestObject { + oid: String, + size: u64, +} + +impl From for BatchRequestObject { + fn from(value: BatchObject) -> Self { + Self { + oid: value.oid, + size: value.size, + } + } +} + +#[derive(Debug, Serialize)] +struct BatchRequestRef { + name: String, +} + +fn lfs_endpoint(repo_type: RepoType, repo_id: &str) -> String { + let prefix = match repo_type { + RepoType::Model => "", + RepoType::Dataset => "datasets/", + RepoType::Space => "spaces/", + }; + format!("{}{}.git", prefix, repo_id) +} + +impl ApiRepo { + /// Requests the LFS batch endpoint to retrieve upload instructions + /// + /// Learn more: https://github.com/git-lfs/git-lfs/blob/main/docs/api/batch.md + pub(crate) async fn post_lfs_batch_info( + &self, + additions: &[CommitOperationAdd], + ) -> Result, ApiError> { + let batch_url = format!( + "{}/{}/info/lfs/objects/batch", + self.api.endpoint, + lfs_endpoint(self.repo.repo_type, &self.repo.repo_id) + ); + + let objects: Vec = additions + .iter() + .map(|op| BatchRequestObject { + oid: hex::encode(&op.upload_info.sha256), + size: op.upload_info.size, + }) + .collect(); + + let payload = BatchRequest { + operation: "upload".to_string(), + transfers: vec!["basic".to_string(), "multipart".to_string()], + objects, + hash_algo: "sha256".to_string(), + r#ref: None, // Add revision handling if needed + }; + + let headers = make_lfs_headers(); + + let response = self + .api + .client + .post(&batch_url) + .headers(headers) + .json(&payload) + .send() + .await? + .maybe_hf_err() + .await?; + + let batch_info: BatchInfo = response.json().await?; + Ok(batch_info.objects) + } +} + +fn make_lfs_headers() -> HeaderMap { + let mut headers = HeaderMap::new(); + headers.insert("Accept", "application/vnd.git-lfs+json".parse().unwrap()); + headers.insert( + "Content-Type", + "application/vnd.git-lfs+json".parse().unwrap(), + ); + headers +} + +fn get_sorted_parts_urls( + headers: &HashMap, + upload_info: &UploadInfo, + chunk_size: u64, +) -> Result, ApiError> { + let mut part_urls: Vec<(u32, String)> = headers + .iter() + .filter_map(|(key, value)| { + if let Ok(part_num) = key.parse::() { + Some((part_num, value.clone())) + } else { + None + } + }) + .collect(); + + part_urls.sort_by_key(|&(num, _)| num); + let sorted_urls: Vec = part_urls.into_iter().map(|(_, url)| url).collect(); + + let expected_parts = ((upload_info.size as f64) / (chunk_size as f64)).ceil() as usize; + log::trace!("chunk size is {chunk_size} and the whole file is {size}, and {size}/{chunk_size} === {result}", size = upload_info.size, result=expected_parts); + if sorted_urls.len() != expected_parts { + return Err(ApiError::InvalidResponse( + "Invalid server response to upload large LFS file".into(), + )); + } + + Ok(sorted_urls) +} + +async fn upload_part( + client: Client, + part_upload_url: &str, + data: Vec, +) -> Result { + let l = data.len(); + log::trace!("uploading part ({} bytes)", l); + let response = client + .put(part_upload_url) + .body(data) + .send() + .await? + .maybe_hf_err() + .await?; + + log::trace!("uploaded ({} bytes)", l); + Ok(response) +} + +async fn upload_multi_part( + hf_client: Client, + s3_client: Client, + operation: CommitOperationAdd, + headers: &HashMap, + chunk_size: u64, + upload_url: &str, +) -> Result { + // 1. Get upload URLs for each part + log::trace!("getting upload urls.."); + let sorted_parts_urls = get_sorted_parts_urls(headers, &operation.upload_info, chunk_size)?; + log::trace!("got upload URLs: {sorted_parts_urls:?}."); + + // 2. Upload parts + log::trace!("uploading parts..."); + let sha256 = operation.upload_info.sha256.clone(); + let (operation, response_headers) = + upload_parts_iteratively(s3_client, operation, &sorted_parts_urls, chunk_size).await?; + log::trace!("parts uploaded."); + + // 3. Send completion request + let completion_payload = get_completion_payload(&response_headers, &sha256); + log::trace!("sending completion request: {completion_payload:?}"); + + let headers = make_lfs_headers(); + + let response = hf_client + .post(upload_url) + .headers(headers) + .json(&completion_payload) + .send() + .await? + .maybe_hf_err() + .await?; + + log::trace!("completion response: {:?}", response.text().await?); + + Ok(operation) +} + +async fn upload_parts_iteratively( + client: Client, + mut operation: CommitOperationAdd, + sorted_parts_urls: &[String], + chunk_size: u64, +) -> Result<(CommitOperationAdd, Vec), ApiError> { + let mut response_headers = Vec::new(); + + match &operation.source { + UploadSource::File(path) => { + let file = tokio::fs::File::open(path).await?; + let mut reader = tokio::io::BufReader::new(file); + + for (part_idx, part_upload_url) in sorted_parts_urls.iter().enumerate() { + let mut buffer = vec![0u8; chunk_size as usize]; + let start_pos = part_idx as u64 * chunk_size; + log::trace!("uploading path {path:?} chunk {part_idx}, start_pos {start_pos}"); + reader.seek(std::io::SeekFrom::Start(start_pos)).await?; + + // read either until the chunk is done or we hit EoF + let bytes_read = { + let mut bytes_read = 0; + while bytes_read < chunk_size as usize { + match reader.read(&mut buffer[bytes_read..]).await? { + 0 => break, // EOF reached + n => bytes_read += n, + } + } + bytes_read + }; + buffer.truncate(bytes_read); + + let response = upload_part(client.clone(), part_upload_url, buffer).await?; + response_headers.push(response.headers().clone()); + } + } + UploadSource::Bytes(bytes) => { + for (part_idx, part_upload_url) in sorted_parts_urls.iter().enumerate() { + let start = (part_idx as u64 * chunk_size) as usize; + let end = ((part_idx + 1) as u64 * chunk_size) as usize; + let chunk = bytes[start..std::cmp::min(end, bytes.len())].to_vec(); + + let response = upload_part(client.clone(), part_upload_url, chunk).await?; + response_headers.push(response.headers().clone()); + } + } + UploadSource::Emptied => { + return Err(ApiError::IoError(std::io::Error::new( + ErrorKind::NotFound, + "File has already been emptied!", + ))); + } + } + + operation.source = UploadSource::Emptied; + + Ok((operation, response_headers)) +} + +async fn upload_single_part( + s3_client: Client, + mut operation: CommitOperationAdd, + upload_url: &str, +) -> Result { + let body = match &operation.source { + UploadSource::File(path) => tokio::fs::read(path).await?, + UploadSource::Bytes(bytes) => bytes.clone(), + UploadSource::Emptied => { + return Err(ApiError::IoError(std::io::Error::new( + ErrorKind::NotFound, + "File has already been emptied!".to_string(), + ))); + } + }; + + let _ = s3_client + .put(upload_url) + .body(body) + .send() + .await? + .maybe_hf_err() + .await?; + + operation.source = UploadSource::Emptied; + + Ok(operation) +} + +pub(crate) async fn lfs_upload( + hf_client: Client, + s3_client: Client, + operation: CommitOperationAdd, + batch_action: BatchObject, + endpoint: String, +) -> Result { + // Skip if already uploaded + if batch_action.actions.is_none() { + log::debug!( + "Content of file {} is already present upstream - skipping upload", + operation.path_in_repo + ); + return Ok(operation); + } + let path = operation.path_in_repo.clone(); + + let actions = batch_action.actions.as_ref().unwrap(); + let upload_action = actions.upload.as_ref().ok_or_else(|| { + ApiError::InvalidResponse("Missing upload action in LFS batch response".into()) + })?; + + let multipart = upload_action.header.as_ref().and_then(|h| { + h.get("chunk_size") + .and_then(|size| size.parse::().ok().map(|s| (h, s))) + }); + let finished_upload = if let Some((header, chunk_size)) = multipart { + const ONE_MB: u64 = 1024 * 1024; + const MIN_UPLOAD_SIZE: u64 = 5 * ONE_MB; + if chunk_size < MIN_UPLOAD_SIZE { + return Err(ApiError::InvalidResponse(format!("API gave us a chunk size of {chunk_size}, but the smallest allowed chunk size for AWS multipart uploads is {MIN_UPLOAD_SIZE}"))); + } else { + log::trace!("chunk size for {path}: {chunk_size}") + } + // Handle multipart upload if chunk_size is present + log::debug!("starting multipart upload for {path}"); + upload_multi_part( + hf_client.clone(), + s3_client.clone(), + operation, + header, + chunk_size, + &upload_action.href, + ) + .await + } else { + // Fall back to single-part upload + log::debug!("starting single-part upload for {}", path); + upload_single_part(s3_client.clone(), operation, &upload_action.href).await + }?; + + if let Some(verify) = &actions.verify { + log::debug!("running verify for {}", path); + let verify_url = fix_hf_endpoint_in_url(&verify.href, &endpoint); + let verify_body: BatchRequestObject = batch_action.into(); + let res = hf_client.post(verify_url).json(&verify_body).send().await?; + log::debug!("verify result: {}", res.text().await?) + } + + log::debug!("{}: Upload successful", path); + + Ok(finished_upload) +} diff --git a/src/api/tokio/upload/mod.rs b/src/api/tokio/upload/mod.rs new file mode 100644 index 0000000..fa6ee7f --- /dev/null +++ b/src/api/tokio/upload/mod.rs @@ -0,0 +1,70 @@ +use super::ApiRepo; +use commit_api::CommitOperationAdd; +use commit_info::CommitInfo; +use futures::future::join_all; + +pub use commit_api::{CommitError, UploadSource}; + +mod commit_api; +mod commit_info; +mod completion_payload; +mod lfs; + +impl ApiRepo { + /// Upload a local file (up to 50 GB) to the given repo. The upload is done + /// through an HTTP post request, and doesn't require git or git-lfs to be + /// installed. + pub async fn upload_file( + &self, + source: impl Into, + path_in_repo: &str, + commit_message: Option, + commit_description: Option, + create_pr: bool, + ) -> Result { + self.upload_files( + vec![(source.into(), path_in_repo.to_string())], + commit_message, + commit_description, + create_pr, + ) + .await + } + + /// Upload multiple local files (up to 50 GB each) to the given repo. The upload is done + /// through an HTTP post request, and doesn't require git or git-lfs to be + /// installed. + pub async fn upload_files( + &self, + files: Vec<(UploadSource, String)>, + commit_message: Option, + commit_description: Option, + create_pr: bool, + ) -> Result { + let commit_message = + commit_message.unwrap_or_else(|| format!("Upload {} files with hf_hub", files.len())); + + let operations = join_all( + files + .into_iter() + .map(|(source, path)| CommitOperationAdd::from_upload_source(path, source)), + ) + .await + .into_iter() + .map(|operation| operation.map(|o| o.into())) + .collect::>()?; + + let commit_info = self + .create_commit( + operations, + commit_message, + commit_description, + Some(create_pr), + None, + None, + ) + .await?; + + Ok(commit_info) + } +} diff --git a/src/lib.rs b/src/lib.rs index 8f5a03d..014b6ad 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -4,6 +4,7 @@ use rand::{distributions::Alphanumeric, Rng}; use std::io::Write; use std::path::PathBuf; +use std::str::FromStr; /// The actual Api to interact with the hub. #[cfg(any(feature = "tokio", feature = "ureq"))] @@ -21,6 +22,32 @@ pub enum RepoType { Space, } +impl ToString for RepoType { + fn to_string(&self) -> String { + match self { + Self::Dataset => "dataset".to_string(), + Self::Model => "model".to_string(), + Self::Space => "space".to_string(), + } + } +} + +impl FromStr for RepoType { + type Err = String; + + fn from_str(s: &str) -> Result { + match s { + "dataset" => Ok(Self::Dataset), + "datasets" => Ok(Self::Dataset), + "model" => Ok(Self::Model), + "models" => Ok(Self::Model), + "space" => Ok(Self::Space), + "spaces" => Ok(Self::Space), + _ => Err(format!("Invalid repo type {s}.")), + } + } +} + /// A local struct used to fetch information from the cache folder. #[derive(Clone, Debug)] pub struct Cache {