Improved layer idx parsing

PiperOrigin-RevId: 788868522
This commit is contained in:
Jan Wassenberg 2025-07-30 05:49:13 -07:00 committed by Copybara-Service
parent d831ddce5b
commit 8715eda512
1 changed files with 7 additions and 6 deletions

View File

@ -24,6 +24,7 @@
#include <cstdlib>
#include <cstring> // strcmp
#include <string>
#include <system_error> // std::errc // NOLINT
#include "compression/types.h"
#include "gemma/configs.h" // ModelConfig, kMaxQKVDim
@ -174,12 +175,12 @@ class TypePrefix {
// Returns 0 if the blob does not seem to be a per-layer tensor, otherwise the
// layer index.
static size_t LayerIdxFromKey(const std::string& key) {
const auto parse_num = [&key](size_t begin, size_t end) -> size_t {
const auto parse_num = [&key](size_t begin, size_t end) -> int {
HWY_DASSERT(begin <= end);
HWY_DASSERT(end <= key.size());
size_t val = 0;
(void)std::from_chars(key.data() + begin, key.data() + end, val);
return val;
int val = 0;
auto [ptr, ec] = std::from_chars(key.data() + begin, key.data() + end, val);
return (ec == std::errc()) ? val : -1;
};
const size_t suffix_pos = key.rfind('_');
@ -187,10 +188,10 @@ static size_t LayerIdxFromKey(const std::string& key) {
if (suffix_pos == std::string::npos) return 0;
if (suffix_pos == key.size() - 1) return 0;
size_t layer_idx = parse_num(suffix_pos + 1, key.size());
int layer_idx = parse_num(suffix_pos + 1, key.size());
HWY_ASSERT(layer_idx < 999);
return layer_idx;
return layer_idx == -1 ? 0 : static_cast<size_t>(layer_idx);
}
// Returns the number of layers based on the largest blob name suffix seen.