diff --git a/include/llama.h b/include/llama.h index 1507107f1a..c3360ae57c 100644 --- a/include/llama.h +++ b/include/llama.h @@ -1476,12 +1476,12 @@ extern "C" { /// @details Build a split GGUF final path for this chunk. /// llama_split_path(split_path, sizeof(split_path), "/models/ggml-model-q4_0", 2, 4) => split_path = "/models/ggml-model-q4_0-00002-of-00004.gguf" // Returns the split_path length. - LLAMA_API int llama_split_path(char * split_path, size_t maxlen, const char * path_prefix, int split_no, int split_count); + LLAMA_API int32_t llama_split_path(char * split_path, size_t maxlen, const char * path_prefix, int32_t split_no, int32_t split_count); /// @details Extract the path prefix from the split_path if and only if the split_no and split_count match. /// llama_split_prefix(split_prefix, 64, "/models/ggml-model-q4_0-00002-of-00004.gguf", 2, 4) => split_prefix = "/models/ggml-model-q4_0" // Returns the split_prefix length. - LLAMA_API int llama_split_prefix(char * split_prefix, size_t maxlen, const char * split_path, int split_no, int split_count); + LLAMA_API int32_t llama_split_prefix(char * split_prefix, size_t maxlen, const char * split_path, int32_t split_no, int32_t split_count); // Print system information LLAMA_API const char * llama_print_system_info(void); diff --git a/src/llama.cpp b/src/llama.cpp index 11b75fcff9..6da90d6f1f 100644 --- a/src/llama.cpp +++ b/src/llama.cpp @@ -1095,25 +1095,55 @@ int32_t llama_chat_apply_template( // model split // -int llama_split_path(char * split_path, size_t maxlen, const char * path_prefix, int split_no, int split_count) { +int32_t llama_split_path( + char * split_path, + size_t maxlen, + const char * path_prefix, + int32_t split_no, + int32_t split_count) { + static const char * const SPLIT_PATH_FORMAT = "%s-%05d-of-%05d.gguf"; - if (snprintf(split_path, maxlen, SPLIT_PATH_FORMAT, path_prefix, split_no + 1, split_count)) { - return strlen(split_path); + + const int written = snprintf( + split_path, + maxlen, + SPLIT_PATH_FORMAT, + path_prefix, + split_no + 1, + split_count + ); + + if (written < 0 || (size_t) written >= maxlen) { + return 0; } - return 0; + + return (int32_t) written; } -int llama_split_prefix(char * split_prefix, size_t maxlen, const char * split_path, int split_no, int split_count) { - std::string str_split_path(split_path); - char postfix[32]; - snprintf(postfix, 32, "-%05d-of-%05d.gguf", split_no + 1, split_count); - std::string str_postfix(postfix); +int32_t llama_split_prefix( + char * split_prefix, + size_t maxlen, + const char * split_path, + int32_t split_no, + int32_t split_count) { - // check if split_prefix ends with postfix - int size_prefix = str_split_path.size() - str_postfix.size(); - if (size_prefix > 0 && str_split_path.find(str_postfix, size_prefix) != std::string::npos) { - snprintf(split_prefix, std::min((size_t) size_prefix + 1, maxlen), "%s", split_path); - return size_prefix; + const std::string str_split_path(split_path); + + char postfix[32]; + snprintf(postfix, sizeof(postfix), "-%05d-of-%05d.gguf", split_no + 1, split_count); + + const std::string str_postfix(postfix); + if (str_split_path.size() <= str_postfix.size()) { + return 0; + } + + const size_t size_prefix = str_split_path.size() - str_postfix.size(); + + if (str_split_path.compare(size_prefix, std::string::npos, str_postfix) == 0) { + const size_t copy_len = std::min(size_prefix + 1, maxlen); + snprintf(split_prefix, copy_len, "%s", split_path); + + return (int32_t) size_prefix; } return 0;