mirror of https://github.com/google/gemma.cpp.git
Add blob_path to config deduction message
PiperOrigin-RevId: 782188689
This commit is contained in:
parent
349c86f2d9
commit
56c9196eb6
|
|
@ -132,6 +132,7 @@ cc_library(
|
|||
deps = [
|
||||
":basics",
|
||||
"//compression:types",
|
||||
"//io",
|
||||
"//io:fields",
|
||||
"@highway//:hwy", # base.h
|
||||
],
|
||||
|
|
|
|||
|
|
@ -23,6 +23,7 @@
|
|||
|
||||
#include "compression/types.h" // Type
|
||||
#include "io/fields.h" // IFields
|
||||
#include "io/io.h" // Path
|
||||
#include "hwy/base.h"
|
||||
|
||||
namespace gcpp {
|
||||
|
|
@ -708,7 +709,7 @@ bool ModelConfig::OverwriteWithCanonical() {
|
|||
return found;
|
||||
}
|
||||
|
||||
Model DeduceModel(size_t layers, int layer_types) {
|
||||
Model DeduceModel(const Path& blob_path, size_t layers, int layer_types) {
|
||||
switch (layers) {
|
||||
case 2:
|
||||
return Model::GEMMA_TINY;
|
||||
|
|
@ -740,8 +741,8 @@ Model DeduceModel(size_t layers, int layer_types) {
|
|||
return Model::PALIGEMMA2_772M_224;
|
||||
*/
|
||||
default:
|
||||
HWY_WARN("Failed to deduce model type from layer count %zu types %x.",
|
||||
layers, layer_types);
|
||||
HWY_WARN("Failed to deduce model type from %s, layer count %zu types %x.",
|
||||
blob_path.path.c_str(), layers, layer_types);
|
||||
return Model::UNKNOWN;
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -27,6 +27,7 @@
|
|||
|
||||
#include "compression/types.h" // Type
|
||||
#include "io/fields.h" // IFieldsVisitor
|
||||
#include "io/io.h" // Path
|
||||
#include "util/basics.h"
|
||||
|
||||
namespace gcpp {
|
||||
|
|
@ -482,7 +483,7 @@ enum DeducedLayerTypes {
|
|||
};
|
||||
|
||||
// layer_types is one or more of `DeducedLayerTypes`.
|
||||
Model DeduceModel(size_t layers, int layer_types);
|
||||
Model DeduceModel(const Path& blob_path, size_t layers, int layer_types);
|
||||
|
||||
} // namespace gcpp
|
||||
|
||||
|
|
|
|||
|
|
@ -219,7 +219,8 @@ static ModelConfig ReadOrDeduceConfig(BlobReader& reader,
|
|||
// Always deduce so we can verify it against the config we read.
|
||||
const size_t layers = DeduceNumLayers(reader.Keys());
|
||||
const int layer_types = DeduceLayerTypes(reader);
|
||||
const Model deduced_model = DeduceModel(layers, layer_types);
|
||||
const Model deduced_model =
|
||||
DeduceModel(reader.blob_path(), layers, layer_types);
|
||||
|
||||
ModelConfig config;
|
||||
// Check first to prevent `CallWithSpan` from printing a warning.
|
||||
|
|
|
|||
|
|
@ -297,7 +297,9 @@ class BlobStore {
|
|||
}; // BlobStore
|
||||
|
||||
BlobReader::BlobReader(const Path& blob_path)
|
||||
: file_(OpenFileOrAbort(blob_path, "r")), file_bytes_(file_->FileSize()) {
|
||||
: blob_path_(blob_path),
|
||||
file_(OpenFileOrAbort(blob_path, "r")),
|
||||
file_bytes_(file_->FileSize()) {
|
||||
if (file_bytes_ == 0) HWY_ABORT("Zero-sized file %s", blob_path.path.c_str());
|
||||
|
||||
BlobStore bs(*file_);
|
||||
|
|
|
|||
|
|
@ -54,6 +54,8 @@ class BlobReader {
|
|||
// Aborts on error.
|
||||
explicit BlobReader(const Path& blob_path);
|
||||
|
||||
const Path& blob_path() const { return blob_path_; }
|
||||
|
||||
// Non-const version required for File::Map().
|
||||
File& file() { return *file_; }
|
||||
const File& file() const { return *file_; }
|
||||
|
|
@ -101,6 +103,7 @@ class BlobReader {
|
|||
}
|
||||
|
||||
private:
|
||||
Path blob_path_;
|
||||
std::unique_ptr<File> file_;
|
||||
const uint64_t file_bytes_;
|
||||
|
||||
|
|
|
|||
Loading…
Reference in New Issue