mirror of https://github.com/google/gemma.cpp.git
Update layer index parsing and allow tokenizer override
PiperOrigin-RevId: 788797948
This commit is contained in:
parent
d1638587f0
commit
d22ba2ac96
|
|
@ -20,6 +20,7 @@
|
||||||
#include <stdio.h>
|
#include <stdio.h>
|
||||||
|
|
||||||
#include <array>
|
#include <array>
|
||||||
|
#include <charconv>
|
||||||
#include <cstdlib>
|
#include <cstdlib>
|
||||||
#include <cstring> // strcmp
|
#include <cstring> // strcmp
|
||||||
#include <string>
|
#include <string>
|
||||||
|
|
@ -71,8 +72,15 @@ static std::string ReadTokenizer(BlobReader& reader,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Read actual tokenizer from blob.
|
||||||
if (!tokenizer.empty() && tokenizer != kMockTokenizer) {
|
if (!tokenizer.empty() && tokenizer != kMockTokenizer) {
|
||||||
return tokenizer; // Read actual tokenizer from blob.
|
if (!tokenizer_path.Empty()) {
|
||||||
|
HWY_WARN("--weights has tokenizer but overriding with %s.",
|
||||||
|
tokenizer_path.path.c_str());
|
||||||
|
return ReadFileToString(tokenizer_path);
|
||||||
|
}
|
||||||
|
|
||||||
|
return tokenizer;
|
||||||
}
|
}
|
||||||
|
|
||||||
// No blob but user specified path to file: read it or abort.
|
// No blob but user specified path to file: read it or abort.
|
||||||
|
|
@ -163,22 +171,42 @@ class TypePrefix {
|
||||||
std::array<size_t, kNumTypes> blobs_{0};
|
std::array<size_t, kNumTypes> blobs_{0};
|
||||||
};
|
};
|
||||||
|
|
||||||
|
// Returns 0 if the blob does not seem to be a per-layer tensor, otherwise the
|
||||||
|
// layer index.
|
||||||
|
static size_t LayerIdxFromKey(const std::string& key) {
|
||||||
|
const auto parse_num = [&key](size_t begin, size_t end) -> size_t {
|
||||||
|
HWY_DASSERT(begin <= end);
|
||||||
|
HWY_DASSERT(end <= key.size());
|
||||||
|
size_t val = 0;
|
||||||
|
(void)std::from_chars(key.data() + begin, key.data() + end, val);
|
||||||
|
return val;
|
||||||
|
};
|
||||||
|
|
||||||
|
const size_t suffix_pos = key.rfind('_');
|
||||||
|
// If there is no digit after the last underscore, it is not a layer name.
|
||||||
|
if (suffix_pos == std::string::npos) return 0;
|
||||||
|
if (suffix_pos == key.size() - 1) return 0;
|
||||||
|
|
||||||
|
size_t layer_idx = parse_num(suffix_pos + 1, key.size());
|
||||||
|
|
||||||
|
HWY_ASSERT(layer_idx < 999);
|
||||||
|
return layer_idx;
|
||||||
|
}
|
||||||
|
|
||||||
// Returns the number of layers based on the largest blob name suffix seen.
|
// Returns the number of layers based on the largest blob name suffix seen.
|
||||||
// This works with or without type prefixes because it searches for suffixes.
|
// This works with or without type prefixes because it searches for suffixes.
|
||||||
static size_t DeduceNumLayers(const KeyVec& keys) {
|
static size_t DeduceNumLayers(const KeyVec& keys) {
|
||||||
|
// Built-in self-test: on every call, check LayerIdxFromKey against keys with
// known results before applying it to the actual `keys`.
|
||||||
|
{
|
||||||
|
HWY_ASSERT(LayerIdxFromKey("gr_conv_w_2") == 2); // common case
|
||||||
|
HWY_ASSERT(LayerIdxFromKey("prefix_") == 0); // no number
|
||||||
|
HWY_ASSERT(LayerIdxFromKey("c_embedding") == 0); // per-model
|
||||||
|
HWY_ASSERT(LayerIdxFromKey("c_final_norm") == 0); // per-model, two _
|
||||||
|
}
|
||||||
|
|
||||||
size_t max_layer_idx = 0;
|
size_t max_layer_idx = 0;
|
||||||
for (const std::string& key : keys) {
|
for (const std::string& key : keys) {
|
||||||
const size_t suffix_pos = key.rfind('_');
|
// Track the largest layer index reported for any key; LayerIdxFromKey returns
// 0 for keys it does not treat as per-layer, which leaves the maximum as is.
max_layer_idx = HWY_MAX(max_layer_idx, LayerIdxFromKey(key));
|
||||||
if (suffix_pos == std::string::npos) continue;
|
|
||||||
|
|
||||||
char* end;
|
|
||||||
auto layer_idx = strtoul(key.c_str() + suffix_pos + 1, &end, 10); // NOLINT
|
|
||||||
HWY_ASSERT(layer_idx < 999); // Also checks for `ULONG_MAX` if out of range
|
|
||||||
// Ignore if not a suffix. Some names are prefixed with "c_" for historical
|
|
||||||
// reasons. In such cases, parsing layer_idx anyway returns 0.
|
|
||||||
if (end - key.c_str() != key.size()) continue;
|
|
||||||
|
|
||||||
max_layer_idx = HWY_MAX(max_layer_idx, layer_idx);
|
|
||||||
}
|
}
|
||||||
return max_layer_idx + 1;
|
// The count is one more than the largest index seen, so this returns 1 when no
// key carries a layer suffix.
return max_layer_idx + 1;
|
||||||
}
|
}
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue