Compare commits

...

3 Commits

Author SHA1 Message Date
Olivier Chafik 8843a98c2b
Improve usability of --model-url & related flags (#6930)
* args: default --model to models/ + filename from --model-url or --hf-file (or else legacy models/7B/ggml-model-f16.gguf)

* args: main & server now call gpt_params_handle_model_default

* args: define DEFAULT_MODEL_PATH + update cli docs

* curl: check url of previous download (.json metadata w/ url, etag & lastModified)

* args: fix update to quantize-stats.cpp

* curl: support legacy .etag / .lastModified companion files

* curl: rm legacy .etag file support

* curl: reuse regex across headers callback calls

* curl: unique_ptr to manage lifecycle of curl & outfile

* curl: nit: no need for multiline regex flag

* curl: update failed test (model file collision) + gitignore *.gguf.json
2024-04-30 00:52:50 +01:00
Clint Herron b8c1476e44
Extending grammar integration tests (#6644)
* Cleaning up integration tests to share code between tests and make it simpler to add new tests.

* Add tests around quantifiers to ensure both matching and non-matching compliance.

* Add slightly more complex grammar with quantifiers to test references with quantifiers.

* Fixing build when C++17 is not present.

* Separating test calls to give more helpful stack traces on failure. Adding verbose messages to give visibility for what is being tested.

* Adding quotes around strings to explicitly show whitespace

* Removing trailing whitespace.

* Implementing suggestions from @ochafik -- grammars and test strings now print and flush before tests to aid in debugging segfaults and whatnot.

* Cleaning up forgotten symbols. Modifying simple test to use test harness. Added comments for more verbose descriptions of what each test is accomplishing.

* Unicode symbol modifications to hopefully make log easier to parse visually.
2024-04-29 14:40:14 -04:00
Daniel Bevenius 5539e6fdd1
main : fix typo in comment in main.cpp (#6985)
Signed-off-by: Daniel Bevenius <daniel.bevenius@gmail.com>
2024-04-29 13:56:59 -04:00
9 changed files with 372 additions and 272 deletions

1
.gitignore vendored
View File

@ -2,6 +2,7 @@
*.a
*.so
*.gguf
*.gguf.json
*.bin
*.exe
*.dll

View File

@ -67,7 +67,6 @@
#include <sys/syslimits.h>
#endif
#define LLAMA_CURL_MAX_URL_LENGTH 2084 // Maximum URL Length in Chrome: 2083
#define LLAMA_CURL_MAX_HEADER_LENGTH 256
#endif // LLAMA_USE_CURL
using json = nlohmann::ordered_json;
@ -1324,6 +1323,29 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa
return false;
}
void gpt_params_handle_model_default(gpt_params & params) {
if (!params.hf_repo.empty()) {
// short-hand to avoid specifying --hf-file -> default it to --model
if (params.hf_file.empty()) {
if (params.model.empty()) {
throw std::invalid_argument("error: --hf-repo requires either --hf-file or --model\n");
}
params.hf_file = params.model;
} else if (params.model.empty()) {
params.model = "models/" + string_split(params.hf_file, '/').back();
}
} else if (!params.model_url.empty()) {
if (params.model.empty()) {
auto f = string_split(params.model_url, '#').front();
f = string_split(f, '?').front();
f = string_split(f, '/').back();
params.model = "models/" + f;
}
} else if (params.model.empty()) {
params.model = DEFAULT_MODEL_PATH;
}
}
bool gpt_params_parse_ex(int argc, char ** argv, gpt_params & params) {
bool invalid_param = false;
std::string arg;
@ -1352,10 +1374,7 @@ bool gpt_params_parse_ex(int argc, char ** argv, gpt_params & params) {
throw std::invalid_argument("error: --prompt-cache-all not supported in interactive mode yet\n");
}
// short-hand to avoid specifying --hf-file -> default it to --model
if (!params.hf_repo.empty() && params.hf_file.empty()) {
params.hf_file = params.model;
}
gpt_params_handle_model_default(params);
if (params.escape) {
process_escapes(params.prompt);
@ -1548,7 +1567,7 @@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) {
printf(" --control-vector-layer-range START END\n");
printf(" layer range to apply the control vector(s) to, start and end inclusive\n");
printf(" -m FNAME, --model FNAME\n");
printf(" model path (default: %s)\n", params.model.c_str());
printf(" model path (default: models/$filename with filename from --hf-file or --model-url if set, otherwise %s)\n", DEFAULT_MODEL_PATH);
printf(" -md FNAME, --model-draft FNAME\n");
printf(" draft model for speculative decoding (default: unused)\n");
printf(" -mu MODEL_URL, --model-url MODEL_URL\n");
@ -1896,59 +1915,75 @@ void llama_batch_add(
#ifdef LLAMA_USE_CURL
static bool llama_download_file(CURL * curl, const char * url, const char * path) {
static bool starts_with(const std::string & str, const std::string & prefix) {
// While we wait for C++20's std::string::starts_with...
return str.rfind(prefix, 0) == 0;
}
static bool llama_download_file(const std::string & url, const std::string & path) {
// Initialize libcurl
std::unique_ptr<CURL, decltype(&curl_easy_cleanup)> curl(curl_easy_init(), &curl_easy_cleanup);
if (!curl) {
fprintf(stderr, "%s: error initializing libcurl\n", __func__);
return false;
}
bool force_download = false;
// Set the URL, allow to follow http redirection
curl_easy_setopt(curl, CURLOPT_URL, url);
curl_easy_setopt(curl, CURLOPT_FOLLOWLOCATION, 1L);
curl_easy_setopt(curl.get(), CURLOPT_URL, url.c_str());
curl_easy_setopt(curl.get(), CURLOPT_FOLLOWLOCATION, 1L);
#if defined(_WIN32)
// CURLSSLOPT_NATIVE_CA tells libcurl to use standard certificate store of
// operating system. Currently implemented under MS-Windows.
curl_easy_setopt(curl, CURLOPT_SSL_OPTIONS, CURLSSLOPT_NATIVE_CA);
curl_easy_setopt(curl.get(), CURLOPT_SSL_OPTIONS, CURLSSLOPT_NATIVE_CA);
#endif
// Check if the file already exists locally
struct stat model_file_info;
auto file_exists = (stat(path, &model_file_info) == 0);
auto file_exists = (stat(path.c_str(), &model_file_info) == 0);
// If the file exists, check for ${path_model}.etag or ${path_model}.lastModified files
char etag[LLAMA_CURL_MAX_HEADER_LENGTH] = {0};
char etag_path[PATH_MAX] = {0};
snprintf(etag_path, sizeof(etag_path), "%s.etag", path);
char last_modified[LLAMA_CURL_MAX_HEADER_LENGTH] = {0};
char last_modified_path[PATH_MAX] = {0};
snprintf(last_modified_path, sizeof(last_modified_path), "%s.lastModified", path);
// If the file exists, check its JSON metadata companion file.
std::string metadata_path = path + ".json";
nlohmann::json metadata;
std::string etag;
std::string last_modified;
if (file_exists) {
auto * f_etag = fopen(etag_path, "r");
if (f_etag) {
if (!fgets(etag, sizeof(etag), f_etag)) {
fprintf(stderr, "%s: unable to read file %s\n", __func__, etag_path);
} else {
fprintf(stderr, "%s: previous file found %s: %s\n", __func__, etag_path, etag);
// Try and read the JSON metadata file (note: stream autoclosed upon exiting this block).
std::ifstream metadata_in(metadata_path);
if (metadata_in.good()) {
try {
metadata_in >> metadata;
fprintf(stderr, "%s: previous metadata file found %s: %s\n", __func__, metadata_path.c_str(), metadata.dump().c_str());
if (metadata.contains("url") && metadata["url"].is_string()) {
auto previous_url = metadata["url"].get<std::string>();
if (previous_url != url) {
fprintf(stderr, "%s: Model URL mismatch: %s != %s\n", __func__, url.c_str(), previous_url.c_str());
return false;
}
}
if (metadata.contains("etag") && metadata["etag"].is_string()) {
etag = metadata["etag"];
}
if (metadata.contains("lastModified") && metadata["lastModified"].is_string()) {
last_modified = metadata["lastModified"];
}
} catch (const nlohmann::json::exception & e) {
fprintf(stderr, "%s: error reading metadata file %s: %s\n", __func__, metadata_path.c_str(), e.what());
return false;
}
fclose(f_etag);
}
auto * f_last_modified = fopen(last_modified_path, "r");
if (f_last_modified) {
if (!fgets(last_modified, sizeof(last_modified), f_last_modified)) {
fprintf(stderr, "%s: unable to read file %s\n", __func__, last_modified_path);
} else {
fprintf(stderr, "%s: previous file found %s: %s\n", __func__, last_modified_path,
last_modified);
}
fclose(f_last_modified);
}
} else {
fprintf(stderr, "%s: no previous model file found %s\n", __func__, path.c_str());
}
// Send a HEAD request to retrieve the etag and last-modified headers
struct llama_load_model_from_url_headers {
char etag[LLAMA_CURL_MAX_HEADER_LENGTH] = {0};
char last_modified[LLAMA_CURL_MAX_HEADER_LENGTH] = {0};
std::string etag;
std::string last_modified;
};
llama_load_model_from_url_headers headers;
{
@ -1956,38 +1991,37 @@ static bool llama_download_file(CURL * curl, const char * url, const char * path
auto header_callback = [](char * buffer, size_t /*size*/, size_t n_items, void * userdata) -> size_t {
llama_load_model_from_url_headers *headers = (llama_load_model_from_url_headers *) userdata;
// Convert header field name to lowercase
for (size_t i = 0; i < n_items && buffer[i] != ':'; ++i) {
buffer[i] = tolower(buffer[i]);
}
static std::regex header_regex("([^:]+): (.*)\r\n");
static std::regex etag_regex("ETag", std::regex_constants::icase);
static std::regex last_modified_regex("Last-Modified", std::regex_constants::icase);
const char * etag_prefix = "etag: ";
if (strncmp(buffer, etag_prefix, strlen(etag_prefix)) == 0) {
strncpy(headers->etag, buffer + strlen(etag_prefix), n_items - strlen(etag_prefix) - 2); // Remove CRLF
}
const char * last_modified_prefix = "last-modified: ";
if (strncmp(buffer, last_modified_prefix, strlen(last_modified_prefix)) == 0) {
strncpy(headers->last_modified, buffer + strlen(last_modified_prefix),
n_items - strlen(last_modified_prefix) - 2); // Remove CRLF
std::string header(buffer, n_items);
std::smatch match;
if (std::regex_match(header, match, header_regex)) {
const std::string & key = match[1];
const std::string & value = match[2];
if (std::regex_match(key, match, etag_regex)) {
headers->etag = value;
} else if (std::regex_match(key, match, last_modified_regex)) {
headers->last_modified = value;
}
}
return n_items;
};
curl_easy_setopt(curl, CURLOPT_NOBODY, 1L); // will trigger the HEAD verb
curl_easy_setopt(curl, CURLOPT_NOPROGRESS, 1L); // hide head request progress
curl_easy_setopt(curl, CURLOPT_HEADERFUNCTION, static_cast<CURLOPT_HEADERFUNCTION_PTR>(header_callback));
curl_easy_setopt(curl, CURLOPT_HEADERDATA, &headers);
curl_easy_setopt(curl.get(), CURLOPT_NOBODY, 1L); // will trigger the HEAD verb
curl_easy_setopt(curl.get(), CURLOPT_NOPROGRESS, 1L); // hide head request progress
curl_easy_setopt(curl.get(), CURLOPT_HEADERFUNCTION, static_cast<CURLOPT_HEADERFUNCTION_PTR>(header_callback));
curl_easy_setopt(curl.get(), CURLOPT_HEADERDATA, &headers);
CURLcode res = curl_easy_perform(curl);
CURLcode res = curl_easy_perform(curl.get());
if (res != CURLE_OK) {
curl_easy_cleanup(curl);
fprintf(stderr, "%s: curl_easy_perform() failed: %s\n", __func__, curl_easy_strerror(res));
return false;
}
long http_code = 0;
curl_easy_getinfo(curl, CURLINFO_RESPONSE_CODE, &http_code);
curl_easy_getinfo(curl.get(), CURLINFO_RESPONSE_CODE, &http_code);
if (http_code != 200) {
// HEAD not supported, we don't know if the file has changed
// force trigger downloading
@ -1996,28 +2030,30 @@ static bool llama_download_file(CURL * curl, const char * url, const char * path
}
}
// If the ETag or the Last-Modified headers are different: trigger a new download
bool should_download = !file_exists
|| force_download
|| (strlen(headers.etag) > 0 && strcmp(etag, headers.etag) != 0)
|| (strlen(headers.last_modified) > 0 && strcmp(last_modified, headers.last_modified) != 0);
bool should_download = !file_exists || force_download;
if (!should_download) {
if (!etag.empty() && etag != headers.etag) {
fprintf(stderr, "%s: ETag header is different (%s != %s): triggering a new download\n", __func__, etag.c_str(), headers.etag.c_str());
should_download = true;
} else if (!last_modified.empty() && last_modified != headers.last_modified) {
fprintf(stderr, "%s: Last-Modified header is different (%s != %s): triggering a new download\n", __func__, last_modified.c_str(), headers.last_modified.c_str());
should_download = true;
}
}
if (should_download) {
char path_temporary[PATH_MAX] = {0};
snprintf(path_temporary, sizeof(path_temporary), "%s.downloadInProgress", path);
std::string path_temporary = path + ".downloadInProgress";
if (file_exists) {
fprintf(stderr, "%s: deleting previous downloaded file: %s\n", __func__, path);
if (remove(path) != 0) {
curl_easy_cleanup(curl);
fprintf(stderr, "%s: unable to delete file: %s\n", __func__, path);
fprintf(stderr, "%s: deleting previous downloaded file: %s\n", __func__, path.c_str());
if (remove(path.c_str()) != 0) {
fprintf(stderr, "%s: unable to delete file: %s\n", __func__, path.c_str());
return false;
}
}
// Set the output file
auto * outfile = fopen(path_temporary, "wb");
std::unique_ptr<FILE, decltype(&fclose)> outfile(fopen(path_temporary.c_str(), "wb"), fclose);
if (!outfile) {
curl_easy_cleanup(curl);
fprintf(stderr, "%s: error opening local file for writing: %s\n", __func__, path);
fprintf(stderr, "%s: error opening local file for writing: %s\n", __func__, path.c_str());
return false;
}
@ -2025,12 +2061,12 @@ static bool llama_download_file(CURL * curl, const char * url, const char * path
auto write_callback = [](void * data, size_t size, size_t nmemb, void * fd) -> size_t {
return fwrite(data, size, nmemb, (FILE *)fd);
};
curl_easy_setopt(curl, CURLOPT_NOBODY, 0L);
curl_easy_setopt(curl, CURLOPT_WRITEFUNCTION, static_cast<CURLOPT_WRITEFUNCTION_PTR>(write_callback));
curl_easy_setopt(curl, CURLOPT_WRITEDATA, outfile);
curl_easy_setopt(curl.get(), CURLOPT_NOBODY, 0L);
curl_easy_setopt(curl.get(), CURLOPT_WRITEFUNCTION, static_cast<CURLOPT_WRITEFUNCTION_PTR>(write_callback));
curl_easy_setopt(curl.get(), CURLOPT_WRITEDATA, outfile.get());
// display download progress
curl_easy_setopt(curl, CURLOPT_NOPROGRESS, 0L);
curl_easy_setopt(curl.get(), CURLOPT_NOPROGRESS, 0L);
// helper function to hide password in URL
auto llama_download_hide_password_in_url = [](const std::string & url) -> std::string {
@ -2049,51 +2085,34 @@ static bool llama_download_file(CURL * curl, const char * url, const char * path
// start the download
fprintf(stderr, "%s: downloading from %s to %s (server_etag:%s, server_last_modified:%s)...\n", __func__,
llama_download_hide_password_in_url(url).c_str(), path, headers.etag, headers.last_modified);
auto res = curl_easy_perform(curl);
llama_download_hide_password_in_url(url).c_str(), path.c_str(), headers.etag.c_str(), headers.last_modified.c_str());
auto res = curl_easy_perform(curl.get());
if (res != CURLE_OK) {
fclose(outfile);
curl_easy_cleanup(curl);
fprintf(stderr, "%s: curl_easy_perform() failed: %s\n", __func__, curl_easy_strerror(res));
return false;
}
long http_code = 0;
curl_easy_getinfo (curl, CURLINFO_RESPONSE_CODE, &http_code);
curl_easy_getinfo (curl.get(), CURLINFO_RESPONSE_CODE, &http_code);
if (http_code < 200 || http_code >= 400) {
fclose(outfile);
curl_easy_cleanup(curl);
fprintf(stderr, "%s: invalid http status code received: %ld\n", __func__, http_code);
return false;
}
// Clean up
fclose(outfile);
// Causes file to be closed explicitly here before we rename it.
outfile.reset();
// Write the new ETag to the .etag file
if (strlen(headers.etag) > 0) {
auto * etag_file = fopen(etag_path, "w");
if (etag_file) {
fputs(headers.etag, etag_file);
fclose(etag_file);
fprintf(stderr, "%s: file etag saved %s: %s\n", __func__, etag_path, headers.etag);
}
}
// Write the updated JSON metadata file.
metadata.update({
{"url", url},
{"etag", headers.etag},
{"lastModified", headers.last_modified}
});
std::ofstream(metadata_path) << metadata.dump(4);
fprintf(stderr, "%s: file metadata saved: %s\n", __func__, metadata_path.c_str());
// Write the new lastModified to the .etag file
if (strlen(headers.last_modified) > 0) {
auto * last_modified_file = fopen(last_modified_path, "w");
if (last_modified_file) {
fputs(headers.last_modified, last_modified_file);
fclose(last_modified_file);
fprintf(stderr, "%s: file last modified saved %s: %s\n", __func__, last_modified_path,
headers.last_modified);
}
}
if (rename(path_temporary, path) != 0) {
curl_easy_cleanup(curl);
fprintf(stderr, "%s: unable to rename file: %s to %s\n", __func__, path_temporary, path);
if (rename(path_temporary.c_str(), path.c_str()) != 0) {
fprintf(stderr, "%s: unable to rename file: %s to %s\n", __func__, path_temporary.c_str(), path.c_str());
return false;
}
}
@ -2111,15 +2130,7 @@ struct llama_model * llama_load_model_from_url(
return NULL;
}
// Initialize libcurl
auto * curl = curl_easy_init();
if (!curl) {
fprintf(stderr, "%s: error initializing libcurl\n", __func__);
return NULL;
}
if (!llama_download_file(curl, model_url, path_model)) {
if (!llama_download_file(model_url, path_model)) {
return NULL;
}
@ -2133,7 +2144,6 @@ struct llama_model * llama_load_model_from_url(
auto * ctx_gguf = gguf_init_from_file(path_model, gguf_params);
if (!ctx_gguf) {
fprintf(stderr, "\n%s: failed to load input GGUF from %s\n", __func__, path_model);
curl_easy_cleanup(curl);
return NULL;
}
@ -2145,8 +2155,6 @@ struct llama_model * llama_load_model_from_url(
gguf_free(ctx_gguf);
}
curl_easy_cleanup(curl);
if (n_split > 1) {
char split_prefix[PATH_MAX] = {0};
char split_url_prefix[LLAMA_CURL_MAX_URL_LENGTH] = {0};
@ -2177,11 +2185,7 @@ struct llama_model * llama_load_model_from_url(
char split_url[LLAMA_CURL_MAX_URL_LENGTH] = {0};
llama_split_path(split_url, sizeof(split_url), split_url_prefix, download_idx, n_split);
auto * curl = curl_easy_init();
bool res = llama_download_file(curl, split_url, split_path);
curl_easy_cleanup(curl);
return res;
return llama_download_file(split_url, split_path);
}, idx));
}
@ -2668,7 +2672,7 @@ void dump_non_result_info_yaml(FILE * stream, const gpt_params & params, const l
fprintf(stream, "mirostat_ent: %f # default: 5.0\n", sparams.mirostat_tau);
fprintf(stream, "mirostat_lr: %f # default: 0.1\n", sparams.mirostat_eta);
fprintf(stream, "mlock: %s # default: false\n", params.use_mlock ? "true" : "false");
fprintf(stream, "model: %s # default: models/7B/ggml-model.bin\n", params.model.c_str());
fprintf(stream, "model: %s # default: %s\n", params.model.c_str(), DEFAULT_MODEL_PATH);
fprintf(stream, "model_draft: %s # default:\n", params.model_draft.c_str());
fprintf(stream, "multiline_input: %s # default: false\n", params.multiline_input ? "true" : "false");
fprintf(stream, "n_gpu_layers: %d # default: -1\n", params.n_gpu_layers);

View File

@ -31,6 +31,8 @@
fprintf(stderr, "%s: built with %s for %s\n", __func__, LLAMA_COMPILER, LLAMA_BUILD_TARGET); \
} while(0)
#define DEFAULT_MODEL_PATH "models/7B/ggml-model-f16.gguf"
// build info
extern int LLAMA_BUILD_NUMBER;
extern char const *LLAMA_COMMIT;
@ -92,7 +94,7 @@ struct gpt_params {
// // sampling parameters
struct llama_sampling_params sparams;
std::string model = "models/7B/ggml-model-f16.gguf"; // model path
std::string model = ""; // model path
std::string model_draft = ""; // draft model for speculative decoding
std::string model_alias = "unknown"; // model alias
std::string model_url = ""; // model url to download
@ -171,6 +173,8 @@ struct gpt_params {
std::vector<std::string> image; // path to image file(s)
};
void gpt_params_handle_model_default(gpt_params & params);
bool parse_kv_override(const char * data, std::vector<llama_model_kv_override> & overrides);
bool gpt_params_parse_ex(int argc, char ** argv, gpt_params & params);

View File

@ -66,7 +66,7 @@ main.exe -m models\7B\ggml-model.bin --ignore-eos -n -1 --random-prompt
In this section, we cover the most commonly used options for running the `main` program with the LLaMA models:
- `-m FNAME, --model FNAME`: Specify the path to the LLaMA model file (e.g., `models/7B/ggml-model.bin`).
- `-m FNAME, --model FNAME`: Specify the path to the LLaMA model file (e.g., `models/7B/ggml-model.gguf`; inferred from `--model-url` if set).
- `-mu MODEL_URL --model-url MODEL_URL`: Specify a remote http url to download the file (e.g https://huggingface.co/ggml-org/models/resolve/main/phi-2/ggml-model-q4_0.gguf).
- `-i, --interactive`: Run the program in interactive mode, allowing you to provide input directly and receive real-time responses.
- `-ins, --instruct`: Run the program in instruction mode, which is particularly useful when working with Alpaca models.

View File

@ -324,7 +324,7 @@ int main(int argc, char ** argv) {
log_tostr(embd_inp.empty()), n_matching_session_tokens, embd_inp.size(), session_tokens.size(), embd_inp.size());
// if we will use the cache for the full prompt without reaching the end of the cache, force
// reevaluation of the last token token to recalculate the cached logits
// reevaluation of the last token to recalculate the cached logits
if (!embd_inp.empty() && n_matching_session_tokens == embd_inp.size() && session_tokens.size() > embd_inp.size()) {
LOGLN("recalculate the cached logits (do): session_tokens.resize( %zu )", embd_inp.size() - 1);

View File

@ -23,7 +23,7 @@
#endif
struct quantize_stats_params {
std::string model = "models/7B/ggml-model-f16.gguf";
std::string model = DEFAULT_MODEL_PATH;
bool verbose = false;
bool per_layer_stats = false;
bool print_histogram = false;

View File

@ -2353,7 +2353,7 @@ static void server_print_usage(const char * argv0, const gpt_params & params, co
printf(" disable KV offload\n");
}
printf(" -m FNAME, --model FNAME\n");
printf(" model path (default: %s)\n", params.model.c_str());
printf(" model path (default: models/$filename with filename from --hf-file or --model-url if set, otherwise %s)\n", DEFAULT_MODEL_PATH);
printf(" -mu MODEL_URL, --model-url MODEL_URL\n");
printf(" model download url (default: unused)\n");
printf(" -hfr REPO, --hf-repo REPO\n");
@ -2835,6 +2835,8 @@ static void server_params_parse(int argc, char ** argv, server_params & sparams,
}
}
gpt_params_handle_model_default(params);
if (!params.kv_overrides.empty()) {
params.kv_overrides.emplace_back();
params.kv_overrides.back().key[0] = 0;

View File

@ -5,7 +5,7 @@ Feature: llama.cpp server
Background: Server startup
Given a server listening on localhost:8080
And a model url https://huggingface.co/ggml-org/models/resolve/main/bert-bge-small/ggml-model-f16.gguf
And a model file ggml-model-f16.gguf
And a model file bert-bge-small.gguf
And a model alias bert-bge-small
And 42 as server seed
And 2 slots

View File

@ -10,15 +10,10 @@
#include "unicode.h"
#include <cassert>
#include <string>
#include <vector>
static void test_simple_grammar() {
// Test case for a simple grammar
const std::string grammar_str = R"""(root ::= expr
expr ::= term ("+" term)*
term ::= number
number ::= [0-9]+)""";
grammar_parser::parse_state parsed_grammar = grammar_parser::parse(grammar_str.c_str());
static llama_grammar* build_grammar(const std::string & grammar_str) {
auto parsed_grammar = grammar_parser::parse(grammar_str.c_str());
// Ensure we parsed correctly
assert(!parsed_grammar.rules.empty());
@ -30,8 +25,10 @@ number ::= [0-9]+)""";
llama_grammar* grammar = llama_grammar_init(
grammar_rules.data(), grammar_rules.size(), parsed_grammar.symbol_ids.at("root"));
std::string input = "123+456";
return grammar;
}
static bool match_string(const std::string & input, llama_grammar* grammar) {
auto decoded = decode_utf8(input, {});
const auto & code_points = decoded.first;
@ -39,159 +36,67 @@ number ::= [0-9]+)""";
for (auto it = code_points.begin(), end = code_points.end() - 1; it != end; ++it) {
auto prev_stacks = grammar->stacks;
llama_grammar_accept(grammar->rules, prev_stacks, *it, grammar->stacks);
assert(!grammar->stacks.empty());
}
bool completed_grammar = false;
for (const auto & stack : grammar->stacks) {
if (stack.empty()) {
completed_grammar = true;
break;
if (grammar->stacks.empty()) {
// no stacks means that the grammar failed to match at this point
return false;
}
}
assert(completed_grammar);
for (const auto & stack : grammar->stacks) {
if (stack.empty()) {
// An empty stack means that the grammar has been completed
return true;
}
}
// Clean up allocated memory
llama_grammar_free(grammar);
return false;
}
static void test_complex_grammar() {
// Test case for a more complex grammar, with both failure strings and success strings
const std::string grammar_str = R"""(root ::= expression
expression ::= term ws (("+"|"-") ws term)*
term ::= factor ws (("*"|"/") ws factor)*
factor ::= number | variable | "(" expression ")" | function-call
number ::= [0-9]+
variable ::= [a-zA-Z_][a-zA-Z0-9_]*
function-call ::= variable ws "(" (expression ("," ws expression)*)? ")"
ws ::= [ \t\n\r]?)""";
static void test_grammar(const std::string & test_desc, const std::string & grammar_str, const std::vector<std::string> & passing_strings, const std::vector<std::string> & failing_strings) {
fprintf(stderr, "⚫ Testing %s. Grammar: %s\n", test_desc.c_str(), grammar_str.c_str());
fflush(stderr);
grammar_parser::parse_state parsed_grammar = grammar_parser::parse(grammar_str.c_str());
// Ensure we parsed correctly
assert(!parsed_grammar.rules.empty());
// Ensure we have a root node
assert(!(parsed_grammar.symbol_ids.find("root") == parsed_grammar.symbol_ids.end()));
std::vector<const llama_grammar_element*> grammar_rules(parsed_grammar.c_rules());
llama_grammar* grammar = llama_grammar_init(
grammar_rules.data(), grammar_rules.size(), parsed_grammar.symbol_ids.at("root"));
auto grammar = build_grammar(grammar_str);
// Save the original grammar stacks so that we can reset after every new string we want to test
auto original_stacks = grammar->stacks;
// Test a few strings
std::vector<std::string> test_strings_pass = {
"42",
"1*2*3*4*5",
"x",
"x+10",
"x1+y2",
"(a+b)*(c-d)",
"func()",
"func(x,y+2)",
"a*(b+c)-d/e",
"f(g(x),h(y,z))",
"x + 10",
"x1 + y2",
"(a + b) * (c - d)",
"func()",
"func(x, y + 2)",
"a * (b + c) - d / e",
"f(g(x), h(y, z))",
"123+456",
"123*456*789-123/456+789*123",
"123+456*789-123/456+789*123-456/789+123*456-789/123+456*789-123/456+789*123-456"
};
std::vector<std::string> test_strings_fail = {
"+",
"/ 3x",
"x + + y",
"a * / b",
"func(,)",
"func(x y)",
"(a + b",
"x + y)",
"a + b * (c - d",
"42 +",
"x +",
"x + 10 +",
"(a + b) * (c - d",
"func(",
"func(x, y + 2",
"a * (b + c) - d /",
"f(g(x), h(y, z)",
"123+456*789-123/456+789*123-456/789+123*456-789/123+456*789-123/456+789*123-456/",
};
fprintf(stderr, " 🔵 Valid strings:\n");
// Passing strings
for (const auto & test_string : test_strings_pass) {
auto decoded = decode_utf8(test_string, {});
for (const auto & test_string : passing_strings) {
fprintf(stderr, " \"%s\" ", test_string.c_str());
fflush(stderr);
const auto & code_points = decoded.first;
bool matched = match_string(test_string, grammar);
int pos = 0;
for (auto it = code_points.begin(), end = code_points.end() - 1; it != end; ++it) {
++pos;
auto prev_stacks = grammar->stacks;
llama_grammar_accept(grammar->rules, prev_stacks, *it, grammar->stacks);
// Expect that each code point will not cause the grammar to fail
if (grammar->stacks.empty()) {
fprintf(stdout, "Error at position %d\n", pos);
fprintf(stderr, "Unexpected character '%s'\n", unicode_cpt_to_utf8(*it).c_str());
fprintf(stderr, "Input string is %s:\n", test_string.c_str());
}
assert(!grammar->stacks.empty());
if (!matched) {
fprintf(stderr, "❌ (failed to match)\n");
} else {
fprintf(stdout, "✅︎\n");
}
bool completed_grammar = false;
for (const auto & stack : grammar->stacks) {
if (stack.empty()) {
completed_grammar = true;
break;
}
}
assert(completed_grammar);
assert(matched);
// Reset the grammar stacks
grammar->stacks = original_stacks;
}
fprintf(stderr, " 🟠 Invalid strings:\n");
// Failing strings
for (const auto & test_string : test_strings_fail) {
auto decoded = decode_utf8(test_string, {});
for (const auto & test_string : failing_strings) {
fprintf(stderr, " \"%s\" ", test_string.c_str());
fflush(stderr);
const auto & code_points = decoded.first;
bool parse_failed = false;
bool matched = match_string(test_string, grammar);
for (auto it = code_points.begin(), end = code_points.end() - 1; it != end; ++it) {
auto prev_stacks = grammar->stacks;
llama_grammar_accept(grammar->rules, prev_stacks, *it, grammar->stacks);
if (grammar->stacks.empty()) {
parse_failed = true;
break;
}
assert(!grammar->stacks.empty());
if (matched) {
fprintf(stderr, "❌ (incorrectly matched)\n");
} else {
fprintf(stdout, "✅︎\n");
}
bool completed_grammar = false;
for (const auto & stack : grammar->stacks) {
if (stack.empty()) {
completed_grammar = true;
break;
}
}
// Ensure that the grammar is not completed, or that each string failed to match as-expected
assert((!completed_grammar) || parse_failed);
assert(!matched);
// Reset the grammar stacks
grammar->stacks = original_stacks;
@ -201,7 +106,183 @@ ws ::= [ \t\n\r]?)""";
llama_grammar_free(grammar);
}
static void test_simple_grammar() {
// Test case for a simple grammar
test_grammar(
"simple grammar",
R"""(
root ::= expr
expr ::= term ("+" term)*
term ::= number
number ::= [0-9]+)""",
// Passing strings
{
"42",
"1+2+3+4+5",
"123+456",
},
// Failing strings
{
"+",
"/ 3",
"1+2+3+4+5+",
"12a45",
}
);
}
static void test_complex_grammar() {
// Test case for a more complex grammar, with both failure strings and success strings
test_grammar(
"medium complexity grammar",
// Grammar
R"""(
root ::= expression
expression ::= term ws (("+"|"-") ws term)*
term ::= factor ws (("*"|"/") ws factor)*
factor ::= number | variable | "(" expression ")" | function-call
number ::= [0-9]+
variable ::= [a-zA-Z_][a-zA-Z0-9_]*
function-call ::= variable ws "(" (expression ("," ws expression)*)? ")"
ws ::= [ \t\n\r]?)""",
// Passing strings
{
"42",
"1*2*3*4*5",
"x",
"x+10",
"x1+y2",
"(a+b)*(c-d)",
"func()",
"func(x,y+2)",
"a*(b+c)-d/e",
"f(g(x),h(y,z))",
"x + 10",
"x1 + y2",
"(a + b) * (c - d)",
"func()",
"func(x, y + 2)",
"a * (b + c) - d / e",
"f(g(x), h(y, z))",
"123+456",
"123*456*789-123/456+789*123",
"123+456*789-123/456+789*123-456/789+123*456-789/123+456*789-123/456+789*123-456"
},
// Failing strings
{
"+",
"/ 3x",
"x + + y",
"a * / b",
"func(,)",
"func(x y)",
"(a + b",
"x + y)",
"a + b * (c - d",
"42 +",
"x +",
"x + 10 +",
"(a + b) * (c - d",
"func(",
"func(x, y + 2",
"a * (b + c) - d /",
"f(g(x), h(y, z)",
"123+456*789-123/456+789*123-456/789+123*456-789/123+456*789-123/456+789*123-456/",
}
);
}
static void test_quantifiers() {
// A collection of tests to exercise * + and ? quantifiers
test_grammar(
"* quantifier",
// Grammar
R"""(root ::= "a"*)""",
// Passing strings
{
"",
"a",
"aaaaa",
"aaaaaaaaaaaaaaaaaa",
"aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa"
},
// Failing strings
{
"b",
"ab",
"aab",
"ba",
"aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaab"
}
);
test_grammar(
"+ quantifier",
// Grammar
R"""(root ::= "a"+)""",
// Passing strings
{
"a",
"aaaaa",
"aaaaaaaaaaaaaaaaaa",
"aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa"
},
// Failing strings
{
"",
"b",
"ab",
"aab",
"ba",
"aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaab"
}
);
test_grammar(
"? quantifier",
// Grammar
R"""(root ::= "a"?)""",
// Passing strings
{
"",
"a"
},
// Failing strings
{
"b",
"ab",
"aa",
"ba",
}
);
test_grammar(
"mixed quantifiers",
// Grammar
R"""(
root ::= cons+ vowel* cons? (vowel cons)*
vowel ::= [aeiouy]
cons ::= [bcdfghjklmnpqrstvwxyz]
)""",
// Passing strings
{
"yes",
"no",
"noyes",
"crwth",
"four",
"bryyyy",
},
// Failing strings
{
"yess",
"yesno",
"forty",
"catyyy",
}
);
}
static void test_failure_missing_root() {
fprintf(stderr, "⚫ Testing missing root node:\n");
// Test case for a grammar that is missing a root rule
const std::string grammar_str = R"""(rot ::= expr
expr ::= term ("+" term)*
@ -215,29 +296,37 @@ number ::= [0-9]+)""";
// Ensure we do NOT have a root node
assert(parsed_grammar.symbol_ids.find("root") == parsed_grammar.symbol_ids.end());
fprintf(stderr, " ✅︎ Passed\n");
}
static void test_failure_missing_reference() {
fprintf(stderr, "⚫ Testing missing reference node:\n");
// Test case for a grammar that is missing a referenced rule
const std::string grammar_str = R"""(root ::= expr
const std::string grammar_str =
R"""(root ::= expr
expr ::= term ("+" term)*
term ::= numero
number ::= [0-9]+)""";
fprintf(stderr, "Expected error: ");
fprintf(stderr, " Expected error: ");
grammar_parser::parse_state parsed_grammar = grammar_parser::parse(grammar_str.c_str());
// Ensure we did NOT parsed correctly
assert(parsed_grammar.rules.empty());
fprintf(stderr, "End of expected error. Test successful.\n");
fprintf(stderr, " End of expected error.\n");
fprintf(stderr, " ✅︎ Passed\n");
}
int main() {
fprintf(stdout, "Running grammar integration tests...\n");
test_simple_grammar();
test_complex_grammar();
test_quantifiers();
test_failure_missing_root();
test_failure_missing_reference();
fprintf(stdout, "All tests passed.\n");
return 0;
}