Skip to content

Commit

Permalink
Merge pull request #280 from L-jasmine/fix/download_urls
Browse files Browse the repository at this point in the history
Fix/download urls
  • Loading branch information
juntao authored Oct 12, 2024
2 parents 2769ba1 + 4eb5134 commit 267ade1
Show file tree
Hide file tree
Showing 4 changed files with 87 additions and 41 deletions.
35 changes: 23 additions & 12 deletions moly-backend/src/backend_impls/api_server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -181,14 +181,12 @@ impl BackendModel for LLamaEdgeApiServer {
let (wasm_module, listen_addr) = if let Some(old_model) = &old_model {
let listen_addr = load_model_options.override_server_address.clone().map_or(
old_model.listen_addr,
|addr| {
match std::net::TcpListener::bind(&addr) {
Ok(listener) => listener.local_addr().unwrap(),
Err(_) => {
eprintln!("Failed to start the model on address {}", addr);
eprintln!("Using the previous one {}", old_model.listen_addr);
old_model.listen_addr
}
|addr| match std::net::TcpListener::bind(&addr) {
Ok(listener) => listener.local_addr().unwrap(),
Err(_) => {
eprintln!("Failed to start the model on address {}", addr);
eprintln!("Using the previous one {}", old_model.listen_addr);
old_model.listen_addr
}
},
);
Expand All @@ -205,10 +203,23 @@ impl BackendModel for LLamaEdgeApiServer {
(old_model.wasm_module.clone(), listen_addr)
} else {
let addr = std::env::var("MOLY_API_SERVER_ADDR").unwrap_or("localhost:0".to_string());
let new_addr = std::net::TcpListener::bind(&addr)
.unwrap()
.local_addr()
.unwrap();

let listen_addr = load_model_options
.override_server_address
.clone()
.map(|addr| match std::net::TcpListener::bind(&addr) {
Ok(listener) => Some(listener.local_addr().unwrap()),
Err(_) => None,
})
.flatten();

let new_addr = match listen_addr {
Some(addr) => addr,
None => {
let listener = std::net::TcpListener::bind(&addr).unwrap();
listener.local_addr().unwrap()
}
};

(Module::from_bytes(None, WASM).unwrap(), new_addr)
};
Expand Down
17 changes: 11 additions & 6 deletions moly-backend/src/backend_impls/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ use moly_protocol::{

use crate::store::{
self,
model_cards::{ModelCard, ModelCardManager},
model_cards::{self, ModelCard, ModelCardManager},
ModelFileDownloader,
};

Expand Down Expand Up @@ -145,6 +145,7 @@ fn test_chat() {
context_overflow_policy: moly_protocol::protocol::ContextOverflowPolicy::StopAtLimit,
n_batch: Some(128),
n_ctx: Some(1024),
override_server_address: None,
},
tx,
);
Expand Down Expand Up @@ -220,6 +221,7 @@ fn test_chat_stop() {
rope_freq_scale: 0.0,
rope_freq_base: 0.0,
context_overflow_policy: moly_protocol::protocol::ContextOverflowPolicy::StopAtLimit,
override_server_address: None,
},
tx,
);
Expand Down Expand Up @@ -381,6 +383,7 @@ pub struct BackendImpl<Model: BackendModel> {
download_tx: tokio::sync::mpsc::UnboundedSender<(
store::models::Model,
store::download_files::DownloadedFile,
model_cards::RemoteFile,
Sender<anyhow::Result<FileDownloadResponse>>,
)>,
model: Option<Model>,
Expand Down Expand Up @@ -444,7 +447,7 @@ impl<Model: BackendModel + Send + 'static> BackendImpl<Model> {
{
let client = reqwest::Client::new();
let downloader =
ModelFileDownloader::new(client, sql_conn.clone(), control_tx.clone(), 0.1);
ModelFileDownloader::new(client, sql_conn.clone(), control_tx.clone(),model_indexs.country_code.clone(), 0.1);
async_rt.spawn(ModelFileDownloader::run_loop(
downloader,
max_download_threads.max(3),
Expand Down Expand Up @@ -526,7 +529,7 @@ impl<Model: BackendModel + Send + 'static> BackendImpl<Model> {
}
ModelManagementCommand::DownloadFile(file_id, tx) => {
//search model from remote
let mut search_model_from_remote = || -> anyhow::Result<( crate::store::models::Model , crate::store::download_files::DownloadedFile)> {
let mut search_model_from_remote = || -> anyhow::Result<( crate::store::models::Model , crate::store::download_files::DownloadedFile,crate::store::model_cards::RemoteFile)> {
let (model_id, file) = file_id
.split_once("#")
.ok_or_else(|| anyhow::anyhow!("Illegal file_id"))?;
Expand All @@ -541,6 +544,8 @@ impl<Model: BackendModel + Send + 'static> BackendImpl<Model> {
.find(|f| f.name == file)
.ok_or_else(|| anyhow::anyhow!("file not found"))?;

let remote_file_ = remote_file.clone();

let download_model = crate::store::models::Model {
id: Arc::new(remote_model.id),
name: remote_model.name,
Expand Down Expand Up @@ -578,12 +583,12 @@ impl<Model: BackendModel + Send + 'static> BackendImpl<Model> {
sha256: remote_file.sha256.unwrap_or_default(),
};

Ok((download_model,download_file))
Ok((download_model,download_file,remote_file_))
};

match search_model_from_remote() {
Ok((model, file)) => {
let _ = self.download_tx.send((model, file, tx));
Ok((model, file,remote_file)) => {
let _ = self.download_tx.send((model, file,remote_file, tx));
}
Err(e) => {
let _ = tx.send(Err(e));
Expand Down
34 changes: 21 additions & 13 deletions moly-backend/src/store/model_cards.rs
Original file line number Diff line number Diff line change
Expand Up @@ -212,18 +212,26 @@ pub struct IpResult {
country_code: String,
}

fn get_model_cards_repo() -> String {
fn get_model_cards_repo() -> (String, String) {
let repo_url = std::env::var("MODEL_CARDS_REPO");
match repo_url {
Ok(url) => url,
Ok(url) => (url, ModelCardManager::DEFAULT_COUNTRY_CODE.to_string()),
Err(_) => {
match reqwest::blocking::get("http://ip-api.com/json")
.and_then(|r| r.json::<IpResult>())
{
Ok(ip_result) if ip_result.country_code == "CN" => {
"https://gitcode.com/xun_csh/model-cards".to_string()
}
_ => "https://github.com/moxin-org/model-cards".to_string(),
Ok(ip_result) if ip_result.country_code.to_ascii_uppercase() == "CN" => (
"https://gitcode.com/xun_csh/model-cards".to_string(),
"CN".to_string(),
),
Ok(ip_result) => (
"https://github.com/moxin-org/model-cards".to_string(),
ip_result.country_code.to_ascii_uppercase(),
),
_ => (
"https://github.com/moxin-org/model-cards".to_string(),
ModelCardManager::DEFAULT_COUNTRY_CODE.to_string(),
),
}
}
}
Expand All @@ -232,7 +240,7 @@ fn get_model_cards_repo() -> String {
pub static REPO_NAME: &'static str = "model-cards";

pub fn sync_model_cards_repo<P: AsRef<Path>>(app_data_dir: P) -> anyhow::Result<ModelCardManager> {
let repo_url = get_model_cards_repo();
let (repo_url, country_code) = get_model_cards_repo();
log::info!("Using model_cards repo: {}", repo_url);
let repo_dirs = app_data_dir.as_ref().join(REPO_NAME);

Expand Down Expand Up @@ -292,6 +300,7 @@ pub fn sync_model_cards_repo<P: AsRef<Path>>(app_data_dir: P) -> anyhow::Result<
Ok(ModelCardManager {
app_data_dir: app_data_dir.as_ref().to_path_buf(),
embedding_index,
country_code,
indexs,
caches: HashMap::new(),
})
Expand Down Expand Up @@ -398,13 +407,8 @@ pub struct RemoteFile {
pub tags: Vec<String>,
#[serde(default)]
pub sha256: Option<String>,
pub download: DownloadUrls,
}

#[derive(Debug, Clone, Default, serde::Serialize, serde::Deserialize)]
pub struct DownloadUrls {
#[serde(default)]
pub default: String,
pub download: HashMap<String, String>,
}

impl ModelCard {
Expand Down Expand Up @@ -484,6 +488,7 @@ impl ModelCard {

pub struct ModelCardManager {
app_data_dir: PathBuf,
pub country_code: String,
embedding_index: EmbeddingState,
indexs: HashMap<String, ModelIndex>,
caches: HashMap<String, ModelCard>,
Expand All @@ -495,11 +500,14 @@ pub enum EmbeddingState {
}

impl ModelCardManager {
const DEFAULT_COUNTRY_CODE: &'static str = "default";

/// Creates a manager that holds no model indexes and no cached model cards.
///
/// The country code is initialised to `DEFAULT_COUNTRY_CODE` and the
/// embedding index to `EmbeddingState::Finish(None)`.
pub fn empty(app_data_dir: PathBuf) -> Self {
    let country_code = String::from(Self::DEFAULT_COUNTRY_CODE);
    let embedding_index = EmbeddingState::Finish(None);

    Self {
        app_data_dir,
        country_code,
        embedding_index,
        indexs: HashMap::new(),
        caches: HashMap::new(),
    }
}
Expand Down
42 changes: 32 additions & 10 deletions moly-backend/src/store/remote.rs
Original file line number Diff line number Diff line change
Expand Up @@ -101,34 +101,53 @@ pub struct ModelFileDownloader {
client: reqwest::Client,
sql_conn: Arc<Mutex<rusqlite::Connection>>,
control_tx: tokio::sync::broadcast::Sender<DownloadControlCommand>,
country_code: String,
step: f64,
}

impl ModelFileDownloader {
const DEFAULT_COUNTRY_CODE: &'static str = "default";
pub fn new(
client: reqwest::Client,
sql_conn: Arc<Mutex<rusqlite::Connection>>,
control_tx: tokio::sync::broadcast::Sender<DownloadControlCommand>,
country_code: String,
step: f64,
) -> Self {
Self {
client,
sql_conn,
control_tx,
country_code,
step,
}
}

fn get_download_url(&self, file: &super::download_files::DownloadedFile) -> String {
format!(
"https://huggingface.co/{}/resolve/main/{}",
file.model_id, file.name
)
/// Resolves the URL from which `file` should be downloaded.
///
/// Lookup order in `remote_file.download`:
/// 1. the entry keyed by this downloader's `country_code`;
/// 2. the entry keyed by `DEFAULT_COUNTRY_CODE`;
/// 3. otherwise a HuggingFace `resolve/main` URL built from the file's
///    `model_id` and `name`.
fn get_download_url(
    &self,
    file: &super::download_files::DownloadedFile,
    remote_file: &super::model_cards::RemoteFile,
) -> String {
    let urls = &remote_file.download;

    urls.get(&self.country_code)
        .or_else(|| urls.get(Self::DEFAULT_COUNTRY_CODE))
        .cloned()
        // Built lazily: the fallback URL is only formatted when no mapped URL exists.
        .unwrap_or_else(|| {
            format!(
                "https://huggingface.co/{}/resolve/main/{}",
                file.model_id, file.name
            )
        })
}

async fn download(
self,
file: super::download_files::DownloadedFile,
remote_file: super::model_cards::RemoteFile,
tx: Sender<anyhow::Result<FileDownloadResponse>>,
) {
let file_id = file.id.to_string();
Expand All @@ -143,7 +162,7 @@ impl ModelFileDownloader {
};

let r = self
.download_file_from_remote(file, &mut send_progress)
.download_file_from_remote(file, remote_file, &mut send_progress)
.await;

match r {
Expand All @@ -165,13 +184,15 @@ impl ModelFileDownloader {
mut download_rx: tokio::sync::mpsc::UnboundedReceiver<(
super::models::Model,
super::download_files::DownloadedFile,
super::model_cards::RemoteFile,
Sender<anyhow::Result<FileDownloadResponse>>,
)>,
) {
let semaphore = Arc::new(tokio::sync::Semaphore::new(max_downloader));

while let Some((model, mut file, tx)) = download_rx.recv().await {
let url = downloader.get_download_url(&file);
while let Some((model, mut file, remote_file, tx)) = download_rx.recv().await {
let url = downloader.get_download_url(&file, &remote_file);
log::info!("Downloading file: {}", url);

let f = async {
let content_length = get_file_content_length(&downloader.client, &url)
Expand Down Expand Up @@ -200,7 +221,7 @@ impl ModelFileDownloader {
let semaphore_ = semaphore.clone();
tokio::spawn(async move {
let permit = semaphore_.acquire_owned().await.unwrap();
downloader_.download(file, tx).await;
downloader_.download(file, remote_file, tx).await;
drop(permit);
});
}
Expand All @@ -209,9 +230,10 @@ impl ModelFileDownloader {
async fn download_file_from_remote(
&self,
mut file: super::download_files::DownloadedFile,
remote_file: super::model_cards::RemoteFile,
report_fn: &mut (dyn FnMut(f64) -> anyhow::Result<()> + Send),
) -> anyhow::Result<Option<FileDownloadResponse>> {
let url = self.get_download_url(&file);
let url = self.get_download_url(&file, &remote_file);

let local_path = Path::new(&file.download_dir)
.join(&file.model_id)
Expand Down

0 comments on commit 267ade1

Please sign in to comment.