Skip to content

Commit

Permalink
llama: llama_split_prefix fix strncpy does not include string termination

Browse files Browse the repository at this point in the history

common: llama_load_model_from_url:
 - fix case-sensitive header name comparison
 - support downloading additional splits in parallel
 - hide password in url
  • Loading branch information
phymbert committed Mar 23, 2024
1 parent 50ccaf5 commit ddb13ed
Show file tree
Hide file tree
Showing 4 changed files with 159 additions and 48 deletions.
192 changes: 150 additions & 42 deletions common/common.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,9 @@
#endif
#if defined(LLAMA_USE_CURL)
#include <curl/curl.h>
#include <curl/easy.h>
#include <thread>
#include <future>
#endif

#if defined(_MSC_VER)
Expand Down Expand Up @@ -1702,27 +1705,27 @@ void llama_batch_add(

#ifdef LLAMA_USE_CURL

struct llama_model * llama_load_model_from_url(
const char * model_url,
const char * path_model,
const struct llama_model_params & params) {
// Basic validation of the model_url
if (!model_url || strlen(model_url) == 0) {
fprintf(stderr, "%s: invalid model_url\n", __func__);
return NULL;
// Return a copy of `url` with any userinfo credentials ("user:pass@") masked
// as "********", so the URL can be safely printed in logs.
// Malformed URLs (no "://") and URLs without an '@' are returned unchanged.
static std::string llama_download_hide_password_in_url(const std::string & url) {
    std::size_t protocol_pos = url.find("://");
    if (protocol_pos == std::string::npos) {
        return url; // Malformed URL
    }

    // userinfo, if present, sits between "://" and the first '@'
    std::size_t at_pos = url.find('@', protocol_pos + 3);
    if (at_pos == std::string::npos) {
        return url; // No password in URL
    }

    return url.substr(0, protocol_pos + 3) + "********" + url.substr(at_pos);
}

static bool llama_download_file(CURL * curl, const char * url, const char * path) {
bool force_download = false;

// Set the URL, allow to follow http redirection
curl_easy_setopt(curl, CURLOPT_URL, model_url);
curl_easy_setopt(curl, CURLOPT_URL, url);
curl_easy_setopt(curl, CURLOPT_FOLLOWLOCATION, 1L);

#if defined(_WIN32)
// CURLSSLOPT_NATIVE_CA tells libcurl to use standard certificate store of
// operating system. Currently implemented under MS-Windows.
Expand All @@ -1731,24 +1734,24 @@ struct llama_model * llama_load_model_from_url(

// Check if the file already exists locally
struct stat model_file_info;
auto file_exists = (stat(path_model, &model_file_info) == 0);
auto file_exists = (stat(path, &model_file_info) == 0);

// If the file exists, check for ${path_model}.etag or ${path_model}.lastModified files
char etag[LLAMA_CURL_MAX_HEADER_LENGTH] = {0};
char etag_path[LLAMA_CURL_MAX_PATH_LENGTH] = {0};
snprintf(etag_path, sizeof(etag_path), "%s.etag", path_model);
char etag_path[PATH_MAX] = {0};
snprintf(etag_path, sizeof(etag_path), "%s.etag", path);

char last_modified[LLAMA_CURL_MAX_HEADER_LENGTH] = {0};
char last_modified_path[LLAMA_CURL_MAX_PATH_LENGTH] = {0};
snprintf(last_modified_path, sizeof(last_modified_path), "%s.lastModified", path_model);
char last_modified_path[PATH_MAX] = {0};
snprintf(last_modified_path, sizeof(last_modified_path), "%s.lastModified", path);

if (file_exists) {
auto * f_etag = fopen(etag_path, "r");
if (f_etag) {
if (!fgets(etag, sizeof(etag), f_etag)) {
fprintf(stderr, "%s: unable to read file %s\n", __func__, etag_path);
} else {
fprintf(stderr, "%s: previous model file found %s: %s\n", __func__, etag_path, etag);
fprintf(stderr, "%s: previous file found %s: %s\n", __func__, etag_path, etag);
}
fclose(f_etag);
}
Expand All @@ -1758,7 +1761,7 @@ struct llama_model * llama_load_model_from_url(
if (!fgets(last_modified, sizeof(last_modified), f_last_modified)) {
fprintf(stderr, "%s: unable to read file %s\n", __func__, last_modified_path);
} else {
fprintf(stderr, "%s: previous model file found %s: %s\n", __func__, last_modified_path,
fprintf(stderr, "%s: previous file found %s: %s\n", __func__, last_modified_path,
last_modified);
}
fclose(f_last_modified);
Expand All @@ -1776,6 +1779,11 @@ struct llama_model * llama_load_model_from_url(
auto header_callback = [](char * buffer, size_t /*size*/, size_t n_items, void * userdata) -> size_t {
llama_load_model_from_url_headers *headers = (llama_load_model_from_url_headers *) userdata;

// Convert header field name to lowercase
for(size_t i = 0; i < n_items && buffer[i] != ':'; ++i) {
buffer[i] = tolower(buffer[i]);
}

const char * etag_prefix = "etag: ";
if (strncmp(buffer, etag_prefix, strlen(etag_prefix)) == 0) {
strncpy(headers->etag, buffer + strlen(etag_prefix), n_items - strlen(etag_prefix) - 2); // Remove CRLF
Expand All @@ -1798,38 +1806,42 @@ struct llama_model * llama_load_model_from_url(
if (res != CURLE_OK) {
curl_easy_cleanup(curl);
fprintf(stderr, "%s: curl_easy_perform() failed: %s\n", __func__, curl_easy_strerror(res));
return NULL;
return false;
}

long http_code = 0;
curl_easy_getinfo(curl, CURLINFO_RESPONSE_CODE, &http_code);
if (http_code != 200) {
// HEAD not supported, we don't know if the file has changed
// force trigger downloading
file_exists = false;
force_download = true;
fprintf(stderr, "%s: HEAD invalid http status code received: %ld\n", __func__, http_code);
}
}

// If the ETag or the Last-Modified headers are different: trigger a new download
if (!file_exists || strcmp(etag, headers.etag) != 0 || strcmp(last_modified, headers.last_modified) != 0) {
char path_model_temporary[LLAMA_CURL_MAX_PATH_LENGTH] = {0};
snprintf(path_model_temporary, sizeof(path_model_temporary), "%s.downloadInProgress", path_model);
bool should_download = !file_exists
|| force_download
|| (strlen(headers.etag) > 0 && strcmp(etag, headers.etag) != 0)
|| (strlen(headers.last_modified) > 0 && strcmp(last_modified, headers.last_modified) != 0);
if (should_download) {
char path_temporary[PATH_MAX] = {0};
snprintf(path_temporary, sizeof(path_temporary), "%s.downloadInProgress", path);
if (file_exists) {
fprintf(stderr, "%s: deleting previous downloaded model file: %s\n", __func__, path_model);
if (remove(path_model) != 0) {
fprintf(stderr, "%s: deleting previous downloaded file: %s\n", __func__, path);
if (remove(path) != 0) {
curl_easy_cleanup(curl);
fprintf(stderr, "%s: unable to delete file: %s\n", __func__, path_model);
return NULL;
fprintf(stderr, "%s: unable to delete file: %s\n", __func__, path);
return false;
}
}

// Set the output file
auto * outfile = fopen(path_model_temporary, "wb");
auto * outfile = fopen(path_temporary, "wb");
if (!outfile) {
curl_easy_cleanup(curl);
fprintf(stderr, "%s: error opening local file for writing: %s\n", __func__, path_model);
return NULL;
fprintf(stderr, "%s: error opening local file for writing: %s\n", __func__, path);
return false;
}

typedef size_t(*CURLOPT_WRITEFUNCTION_PTR)(void * data, size_t size, size_t nmemb, void * fd);
Expand All @@ -1844,14 +1856,14 @@ struct llama_model * llama_load_model_from_url(
curl_easy_setopt(curl, CURLOPT_NOPROGRESS, 0L);

// start the download
fprintf(stderr, "%s: downloading model from %s to %s (server_etag:%s, server_last_modified:%s)...\n", __func__,
model_url, path_model, headers.etag, headers.last_modified);
fprintf(stderr, "%s: downloading from %s to %s (server_etag:%s, server_last_modified:%s)...\n", __func__,
llama_download_hide_password_in_url(url).c_str(), path, headers.etag, headers.last_modified);
auto res = curl_easy_perform(curl);
if (res != CURLE_OK) {
fclose(outfile);
curl_easy_cleanup(curl);
fprintf(stderr, "%s: curl_easy_perform() failed: %s\n", __func__, curl_easy_strerror(res));
return NULL;
return false;
}

long http_code = 0;
Expand All @@ -1860,7 +1872,7 @@ struct llama_model * llama_load_model_from_url(
fclose(outfile);
curl_easy_cleanup(curl);
fprintf(stderr, "%s: invalid http status code received: %ld\n", __func__, http_code);
return NULL;
return false;
}

// Clean up
Expand All @@ -1872,7 +1884,7 @@ struct llama_model * llama_load_model_from_url(
if (etag_file) {
fputs(headers.etag, etag_file);
fclose(etag_file);
fprintf(stderr, "%s: model etag saved %s: %s\n", __func__, etag_path, headers.etag);
fprintf(stderr, "%s: file etag saved %s: %s\n", __func__, etag_path, headers.etag);
}
}

Expand All @@ -1882,20 +1894,116 @@ struct llama_model * llama_load_model_from_url(
if (last_modified_file) {
fputs(headers.last_modified, last_modified_file);
fclose(last_modified_file);
fprintf(stderr, "%s: model last modified saved %s: %s\n", __func__, last_modified_path,
fprintf(stderr, "%s: file last modified saved %s: %s\n", __func__, last_modified_path,
headers.last_modified);
}
}

if (rename(path_model_temporary, path_model) != 0) {
if (rename(path_temporary, path) != 0) {
curl_easy_cleanup(curl);
fprintf(stderr, "%s: unable to rename file: %s to %s\n", __func__, path_model_temporary, path_model);
fprintf(stderr, "%s: unable to rename file: %s to %s\n", __func__, path_temporary, path);
return false;
}
}

return true;
}

// Download a (possibly split) GGUF model from `model_url` to `path_model`,
// then load it with `params`. Additional splits are downloaded in parallel.
// Returns NULL on any validation, download, or parse failure.
struct llama_model * llama_load_model_from_url(
        const char * model_url,
        const char * path_model,
        const struct llama_model_params & params) {
    // Basic validation of the model_url
    if (!model_url || strlen(model_url) == 0) {
        fprintf(stderr, "%s: invalid model_url\n", __func__);
        return NULL;
    }

    // Initialize libcurl
    auto * curl = curl_easy_init();

    if (!curl) {
        fprintf(stderr, "%s: error initializing libcurl\n", __func__);
        return NULL;
    }

    // NOTE: llama_download_file() calls curl_easy_cleanup() itself on its
    // failure paths, so the handle must NOT be cleaned up again here.
    if (!llama_download_file(curl, model_url, path_model)) {
        return NULL;
    }

    // The handle is no longer needed once the first file is on disk; release
    // it now so the early-return paths below cannot leak it.
    curl_easy_cleanup(curl);

    // check for additional GGUFs split to download
    int n_split = 0;
    {
        struct gguf_init_params gguf_params = {
            /*.no_alloc = */ true,
            /*.ctx      = */ NULL,
        };
        auto * ctx_gguf = gguf_init_from_file(path_model, gguf_params);
        if (!ctx_gguf) {
            fprintf(stderr, "\n%s: failed to load input GGUF from %s\n", __func__, path_model);
            return NULL;
        }

        // split.count > 1 marks the first file of a multi-part model
        auto key_n_split = gguf_find_key(ctx_gguf, LLM_KV_SPLIT_COUNT);
        if (key_n_split >= 0) {
            n_split = gguf_get_val_u16(ctx_gguf, key_n_split);
        }

        gguf_free(ctx_gguf);
    }

    if (n_split > 1) {
        char split_prefix[PATH_MAX] = {0};
        char split_url_prefix[PATH_MAX] = {0};

        // Verify the first split file format and extract the path/url prefixes
        {
            if (!llama_split_prefix(split_prefix, sizeof(split_prefix), path_model, 0, n_split)) {
                fprintf(stderr, "\n%s: unexpected model file name: %s"
                                " n_split=%d\n", __func__, path_model, n_split);
                return NULL;
            }

            if (!llama_split_prefix(split_url_prefix, sizeof(split_url_prefix), model_url, 0, n_split)) {
                fprintf(stderr, "\n%s: unexpected model url: %s"
                                " n_split=%d\n", __func__, model_url, n_split);
                return NULL;
            }
        }

        // Download the remaining splits in parallel, one curl handle per task
        std::vector<std::future<bool>> futures_download;
        for (int idx = 1; idx < n_split; idx++) {
            futures_download.push_back(std::async(std::launch::async, [&split_prefix, &split_url_prefix, &n_split](int download_idx) -> bool {
                char split_path[PATH_MAX] = {0};
                llama_split_path(split_path, sizeof(split_path), split_prefix, download_idx, n_split);

                char split_url[PATH_MAX] = {0};
                llama_split_path(split_url, sizeof(split_url), split_url_prefix, download_idx, n_split);

                auto * curl = curl_easy_init();
                if (!curl) {
                    fprintf(stderr, "%s: error initializing libcurl\n", __func__);
                    return false;
                }

                bool res = llama_download_file(curl, split_url, split_path);
                if (res) {
                    // on failure llama_download_file() already cleaned up the handle
                    curl_easy_cleanup(curl);
                }

                return res;
            }, idx));
        }

        // Wait for all downloads to complete
        for (auto & f : futures_download) {
            if (!f.get()) {
                return NULL;
            }
        }
    }

    return llama_load_model_from_file(path_model, params);
}

Expand Down
7 changes: 7 additions & 0 deletions common/common.h
Original file line number Diff line number Diff line change
Expand Up @@ -306,3 +306,10 @@ struct llama_control_vector_load_info {
// Load control vectors, scale each by strength, and add them together.
// On error, returns {-1, empty}
llama_control_vector_data llama_control_vector_load(const std::vector<llama_control_vector_load_info> & load_infos);

//
// Split utils
//
static const char * const LLM_KV_SPLIT_NO = "split.no";
static const char * const LLM_KV_SPLIT_COUNT = "split.count";
static const char * const LLM_KV_SPLIT_TENSORS_COUNT = "split.tensors.count";
4 changes: 0 additions & 4 deletions examples/gguf-split/gguf-split.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -26,10 +26,6 @@ enum split_operation : uint8_t {
SPLIT_OP_MERGE,
};

static const char * const LLM_KV_SPLIT_NO = "split.no";
static const char * const LLM_KV_SPLIT_COUNT = "split.count";
static const char * const LLM_KV_SPLIT_TENSORS_COUNT = "split.tensors.count";

struct split_params {
split_operation operation = SPLIT_OP_SPLIT;
int n_split_tensors = 128;
Expand Down
4 changes: 2 additions & 2 deletions llama.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -14839,9 +14839,9 @@ int llama_split_prefix(char * dest, size_t maxlen, const char * split_path, int
std::string str_postfix(postfix);

// check if dest ends with postfix
int size_prefix = str_split_path.size() - str_postfix.size();
int size_prefix = str_split_path.size() - str_postfix.size() ;
if (size_prefix > 0 && str_split_path.find(str_postfix, size_prefix) != std::string::npos) {
snprintf(dest, std::min((size_t) size_prefix, maxlen), "%s", split_path);
snprintf(dest, std::min((size_t) size_prefix + 1, maxlen), "%s", split_path);
return size_prefix;
}

Expand Down

0 comments on commit ddb13ed

Please sign in to comment.