diff --git a/ggml/src/gguf.cpp b/ggml/src/gguf.cpp index 8625439366..cbeedf6c4b 100644 --- a/ggml/src/gguf.cpp +++ b/ggml/src/gguf.cpp @@ -228,13 +228,41 @@ struct gguf_context { }; struct gguf_reader { - FILE * file; + gguf_reader(FILE * file) : file(file) { + // read the remaining bytes once and update on each read + nbytes_remain = file_remain(file); + } - gguf_reader(FILE * file) : file(file) {} + // helper for remaining bytes in a file + static uint64_t file_remain(FILE * file) { + const int64_t cur = gguf_ftell(file); + if (cur < 0) { + return 0; + } + if (gguf_fseek(file, 0, SEEK_END) != 0) { + gguf_fseek(file, cur, SEEK_SET); + + return 0; + } + const int64_t end = gguf_ftell(file); + if (end < 0) { + gguf_fseek(file, cur, SEEK_SET); + + return 0; + } + gguf_fseek(file, cur, SEEK_SET); + return static_cast(end - cur); + } template bool read(T & dst) const { - return fread(&dst, 1, sizeof(dst), file) == sizeof(dst); + const size_t size = sizeof(dst); + if (nbytes_remain < size) { + return false; + } + const size_t nread = fread(&dst, 1, size, file); + nbytes_remain -= nread; + return nread == size; } template @@ -242,20 +270,19 @@ struct gguf_reader { if (n > GGUF_MAX_ARRAY_ELEMENTS) { return false; } - const uint64_t nbytes = nbytes_remain(); if constexpr (std::is_same::value) { // strings are prefixed with their length, so we need to account for that if (n > SIZE_MAX / sizeof(uint64_t)) { return false; } - if (nbytes < n * sizeof(uint64_t)) { + if (nbytes_remain < n * sizeof(uint64_t)) { return false; } } else { if (n > SIZE_MAX / sizeof(T)) { return false; } - if (nbytes < n * sizeof(T)) { + if (nbytes_remain < n * sizeof(T)) { return false; } } @@ -312,39 +339,29 @@ struct gguf_reader { GGML_LOG_ERROR("%s: string length %" PRIu64 " exceeds maximum %" PRIu64 "\n", __func__, size, (uint64_t) GGUF_MAX_STRING_LENGTH); return false; } - const uint64_t nbytes = nbytes_remain(); - if (size > nbytes) { - GGML_LOG_ERROR("%s: string length %" PRIu64 " exceeds remaining file size %" PRIu64 " bytes\n", __func__, size, nbytes); + if (size > nbytes_remain) { + GGML_LOG_ERROR("%s: string length %" PRIu64 " exceeds remaining file size %" PRIu64 " bytes\n", __func__, size, nbytes_remain); return false; } dst.resize(static_cast(size)); - return fread(dst.data(), 1, dst.length(), file) == dst.length(); + const size_t nread = fread(dst.data(), 1, size, file); + nbytes_remain -= nread; + return nread == size; } bool read(void * dst, const size_t size) const { - return fread(dst, 1, size, file) == size; + if (size > nbytes_remain) { + return false; + } + const size_t nread = fread(dst, 1, size, file); + nbytes_remain -= nread; + return nread == size; } - // remaining bytes in the file - uint64_t nbytes_remain() const { - const int64_t cur = gguf_ftell(file); - if (cur < 0) { - return 0; - } - if (gguf_fseek(file, 0, SEEK_END) != 0) { - gguf_fseek(file, cur, SEEK_SET); +private: + FILE * file; - return 0; - } - const int64_t end = gguf_ftell(file); - if (end < 0) { - gguf_fseek(file, cur, SEEK_SET); - - return 0; - } - gguf_fseek(file, cur, SEEK_SET); - return static_cast(end - cur); - } + mutable uint64_t nbytes_remain; }; struct gguf_context * gguf_init_empty(void) {