mirror of https://github.com/google/gemma.cpp.git
Huge refactor of weight handling and model loading.
Weight handling:
- new ModelStore2 supports both pre-2025 multi-file and single-file formats
- simpler ForEachTensor with TensorArgs
- tensors are constructed with their full suffixed name

I/O:
- support mmap and stride
- Simplified SbsWriter, single insert(); add SbsReader

Misc:
- kMockTokenizer: allow creating with unavailable tokenizer
- configs.h: Simpler enum validity checks via kSentinel
- matmul.h: remove unused enable_bind (now in allocator.h)
- tensor_info: single TensorInfoRegistry class, rename from tensor_index.h

Frontends:
- Replace Allocate/CreateGemma with ctor(LoaderArgs, MatMulEnv&)
- Deduce model/weight type, remove --model and parsing
- Replace most common.h includes with configs.h
- Remove --compressed_weights, use --weights instead
- Remove ModelInfo, replaced by ModelConfig.

Backprop:
- Reduce max loss, remove backward_scalar_test (timeout)
- Update thresholds because new RandInit changes rng eval order and thus numerics

PiperOrigin-RevId: 755317484
parent: a3caf6e5d2
commit: 8d0882b966

BUILD.bazel (284 changed lines)
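A minimal, self-contained sketch of the "enum validity checks via kSentinel" pattern mentioned in the commit message above. The enum name and enumerators below are illustrative assumptions; the actual definitions live in gemma/configs.h and may differ.

// Sketch: a trailing kSentinel enumerator bounds the valid range, so one
// comparison validates any value, and enumerators added before kSentinel
// are covered automatically.
#include <cstddef>
#include <cstdio>

enum class LayerAttentionType { kGemma, kGriffinRecurrentBlock, kSentinel };

constexpr bool IsValid(LayerAttentionType type) {
  return static_cast<size_t>(type) <
         static_cast<size_t>(LayerAttentionType::kSentinel);
}

int main() {
  printf("%d\n", IsValid(LayerAttentionType::kGriffinRecurrentBlock));  // 1
  printf("%d\n", IsValid(static_cast<LayerAttentionType>(7)));          // 0
  return 0;
}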
@@ -126,17 +126,9 @@ cc_library(
)

cc_library(
name = "common",
srcs = [
"gemma/common.cc",
"gemma/configs.cc",
"gemma/tensor_index.cc",
],
hdrs = [
"gemma/common.h",
"gemma/configs.h",
"gemma/tensor_index.h",
],
name = "configs",
srcs = ["gemma/configs.cc"],
hdrs = ["gemma/configs.h"],
deps = [
":basics",
"//compression:fields",

@@ -149,23 +141,21 @@ cc_test(
name = "configs_test",
srcs = ["gemma/configs_test.cc"],
deps = [
":common",
":configs",
"@googletest//:gtest_main", # buildcleaner: keep
"@highway//:hwy",
"//compression:fields",
"//compression:shared",
],
)

cc_test(
name = "tensor_index_test",
srcs = ["gemma/tensor_index_test.cc"],
cc_library(
name = "tensor_info",
srcs = ["gemma/tensor_info.cc"],
hdrs = ["gemma/tensor_info.h"],
deps = [
":basics",
":common",
":mat",
":weights",
"@googletest//:gtest_main", # buildcleaner: keep
"//compression:compress",
"@highway//:hwy", # aligned_allocator.h
":configs",
"//compression:shared",
],
)

@@ -176,7 +166,7 @@ cc_library(
deps = [
":allocator",
":basics",
":common",
":tensor_info",
":threading_context",
"//compression:fields",
"//compression:shared",
@ -186,6 +176,82 @@ cc_library(
|
|||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "tokenizer",
|
||||
srcs = ["gemma/tokenizer.cc"],
|
||||
hdrs = ["gemma/tokenizer.h"],
|
||||
deps = [
|
||||
":configs",
|
||||
"@highway//:hwy",
|
||||
"@highway//:profiler",
|
||||
"@com_google_sentencepiece//:sentencepiece_processor",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "model_store",
|
||||
srcs = ["gemma/model_store.cc"],
|
||||
hdrs = ["gemma/model_store.h"],
|
||||
deps = [
|
||||
":allocator",
|
||||
":basics",
|
||||
":configs",
|
||||
":mat",
|
||||
":tensor_info",
|
||||
":threading_context",
|
||||
":tokenizer",
|
||||
"//compression:blob_store",
|
||||
"//compression:fields",
|
||||
"//compression:io",
|
||||
"//compression:shared",
|
||||
"@highway//:hwy",
|
||||
"@highway//:thread_pool",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "weights",
|
||||
srcs = ["gemma/weights.cc"],
|
||||
hdrs = ["gemma/weights.h"],
|
||||
deps = [
|
||||
":configs",
|
||||
":mat",
|
||||
":model_store",
|
||||
":tensor_info",
|
||||
"//compression:blob_store",
|
||||
"//compression:compress",
|
||||
"@highway//:hwy",
|
||||
"@highway//:profiler",
|
||||
"@highway//:stats",
|
||||
"@highway//:thread_pool",
|
||||
],
|
||||
)
|
||||
|
||||
cc_test(
|
||||
name = "tensor_info_test",
|
||||
srcs = ["gemma/tensor_info_test.cc"],
|
||||
deps = [
|
||||
":configs",
|
||||
":mat",
|
||||
":tensor_info",
|
||||
":weights",
|
||||
"@googletest//:gtest_main", # buildcleaner: keep
|
||||
"//compression:compress",
|
||||
"@highway//:hwy", # aligned_allocator.h
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "common",
|
||||
srcs = ["gemma/common.cc"],
|
||||
hdrs = ["gemma/common.h"],
|
||||
deps = [
|
||||
":basics",
|
||||
":configs",
|
||||
"@highway//:hwy", # base.h
|
||||
],
|
||||
)
|
||||
|
||||
# For building all tests in one command, so we can test several.
|
||||
test_suite(
|
||||
name = "ops_tests",
|
||||
|
|
@ -343,43 +409,24 @@ cc_test(
|
|||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "weights",
|
||||
srcs = ["gemma/weights.cc"],
|
||||
hdrs = ["gemma/weights.h"],
|
||||
deps = [
|
||||
":common",
|
||||
":mat",
|
||||
"//compression:blob_store",
|
||||
"//compression:compress",
|
||||
"//compression:io", # Path
|
||||
"@highway//:hwy",
|
||||
"@highway//:profiler",
|
||||
"@highway//:stats",
|
||||
"@highway//:thread_pool",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "tokenizer",
|
||||
srcs = ["gemma/tokenizer.cc"],
|
||||
hdrs = ["gemma/tokenizer.h"],
|
||||
deps = [
|
||||
":common",
|
||||
"//compression:io", # Path
|
||||
"//compression:shared",
|
||||
"@highway//:hwy",
|
||||
"@highway//:profiler",
|
||||
"@com_google_sentencepiece//:sentencepiece_processor",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "kv_cache",
|
||||
srcs = ["gemma/kv_cache.cc"],
|
||||
hdrs = ["gemma/kv_cache.h"],
|
||||
deps = [
|
||||
":common",
|
||||
":configs",
|
||||
"@highway//:hwy",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "gemma_args",
|
||||
hdrs = ["gemma/gemma_args.h"],
|
||||
deps = [
|
||||
":args",
|
||||
":basics",
|
||||
":ops", # matmul.h
|
||||
"//compression:io",
|
||||
"@highway//:hwy",
|
||||
],
|
||||
)
|
||||
|
|
@ -409,8 +456,11 @@ cc_library(
|
|||
":allocator",
|
||||
":basics",
|
||||
":common",
|
||||
":configs",
|
||||
":gemma_args",
|
||||
":kv_cache",
|
||||
":mat",
|
||||
":model_store",
|
||||
":ops",
|
||||
":tokenizer",
|
||||
":threading",
|
||||
|
|
@ -428,6 +478,36 @@ cc_library(
|
|||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "cross_entropy",
|
||||
srcs = ["evals/cross_entropy.cc"],
|
||||
hdrs = ["evals/cross_entropy.h"],
|
||||
deps = [
|
||||
":gemma_lib",
|
||||
":ops",
|
||||
"@highway//:hwy",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "benchmark_helper",
|
||||
srcs = ["evals/benchmark_helper.cc"],
|
||||
hdrs = ["evals/benchmark_helper.h"],
|
||||
deps = [
|
||||
":configs",
|
||||
":cross_entropy",
|
||||
":gemma_args",
|
||||
":gemma_lib",
|
||||
":ops",
|
||||
":threading_context",
|
||||
":tokenizer",
|
||||
"@google_benchmark//:benchmark",
|
||||
"//compression:compress",
|
||||
"@highway//:hwy",
|
||||
"@highway//:nanobenchmark",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "gemma_shared_lib",
|
||||
srcs = [
|
||||
|
|
@ -459,51 +539,6 @@ cc_library(
|
|||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "cross_entropy",
|
||||
srcs = ["evals/cross_entropy.cc"],
|
||||
hdrs = ["evals/cross_entropy.h"],
|
||||
deps = [
|
||||
":common",
|
||||
":gemma_lib",
|
||||
":ops",
|
||||
"@highway//:hwy",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "gemma_args",
|
||||
hdrs = ["gemma/gemma_args.h"],
|
||||
deps = [
|
||||
":args",
|
||||
":basics",
|
||||
":common",
|
||||
":gemma_lib",
|
||||
":ops",
|
||||
"//compression:io",
|
||||
"//compression:shared",
|
||||
"@highway//:hwy",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "benchmark_helper",
|
||||
srcs = ["evals/benchmark_helper.cc"],
|
||||
hdrs = ["evals/benchmark_helper.h"],
|
||||
deps = [
|
||||
":cross_entropy",
|
||||
":gemma_args",
|
||||
":gemma_lib",
|
||||
":ops",
|
||||
":threading_context",
|
||||
":tokenizer",
|
||||
"@google_benchmark//:benchmark",
|
||||
"//compression:compress",
|
||||
"@highway//:hwy",
|
||||
"@highway//:nanobenchmark",
|
||||
],
|
||||
)
|
||||
|
||||
cc_test(
|
||||
name = "gemma_test",
|
||||
srcs = ["evals/gemma_test.cc"],
|
||||
|
|
@ -516,7 +551,7 @@ cc_test(
|
|||
],
|
||||
deps = [
|
||||
":benchmark_helper",
|
||||
":common",
|
||||
":configs",
|
||||
":gemma_lib",
|
||||
"@googletest//:gtest_main", # buildcleaner: keep
|
||||
"@highway//:hwy",
|
||||
|
|
@ -535,7 +570,7 @@ cc_test(
|
|||
],
|
||||
deps = [
|
||||
":benchmark_helper",
|
||||
":common",
|
||||
":configs",
|
||||
":gemma_lib",
|
||||
"@googletest//:gtest_main", # buildcleaner: keep
|
||||
"@highway//:hwy",
|
||||
|
|
@ -549,11 +584,9 @@ cc_binary(
|
|||
deps = [
|
||||
":args",
|
||||
":benchmark_helper",
|
||||
":common",
|
||||
":gemma_args",
|
||||
":gemma_lib",
|
||||
":ops",
|
||||
":threading_context",
|
||||
":tokenizer",
|
||||
"//compression:shared",
|
||||
"//paligemma:image",
|
||||
|
|
@ -568,7 +601,6 @@ cc_binary(
|
|||
deps = [
|
||||
":args",
|
||||
":benchmark_helper",
|
||||
":common",
|
||||
":cross_entropy",
|
||||
":gemma_lib",
|
||||
"//compression:io",
|
||||
|
|
@ -578,12 +610,6 @@ cc_binary(
|
|||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "benchmark_prompts",
|
||||
hdrs = ["evals/prompts.h"],
|
||||
deps = ["@highway//:hwy"],
|
||||
)
|
||||
|
||||
cc_binary(
|
||||
name = "benchmarks",
|
||||
srcs = [
|
||||
|
|
@ -592,7 +618,6 @@ cc_binary(
|
|||
],
|
||||
deps = [
|
||||
":benchmark_helper",
|
||||
":benchmark_prompts",
|
||||
"@google_benchmark//:benchmark",
|
||||
"@highway//:hwy", # base.h
|
||||
],
|
||||
|
|
@ -600,9 +625,7 @@ cc_binary(
|
|||
|
||||
cc_binary(
|
||||
name = "debug_prompt",
|
||||
srcs = [
|
||||
"evals/debug_prompt.cc",
|
||||
],
|
||||
srcs = ["evals/debug_prompt.cc"],
|
||||
deps = [
|
||||
":args",
|
||||
":benchmark_helper",
|
||||
|
|
@ -623,7 +646,6 @@ cc_binary(
|
|||
"//compression:io",
|
||||
"@highway//:hwy",
|
||||
"@highway//:profiler",
|
||||
"@highway//:thread_pool",
|
||||
"@nlohmann_json//:json",
|
||||
],
|
||||
)
|
||||
|
|
@ -660,6 +682,7 @@ cc_library(
|
|||
deps = [
|
||||
":allocator",
|
||||
":common",
|
||||
":configs",
|
||||
":mat",
|
||||
":ops",
|
||||
":prompt",
|
||||
|
|
@ -680,6 +703,7 @@ cc_library(
|
|||
],
|
||||
deps = [
|
||||
":common",
|
||||
":configs",
|
||||
":mat",
|
||||
":prompt",
|
||||
":weights",
|
||||
|
|
@ -687,26 +711,6 @@ cc_library(
|
|||
],
|
||||
)
|
||||
|
||||
cc_test(
|
||||
name = "backward_scalar_test",
|
||||
size = "large",
|
||||
srcs = [
|
||||
"backprop/backward_scalar_test.cc",
|
||||
"backprop/test_util.h",
|
||||
],
|
||||
deps = [
|
||||
":backprop_scalar",
|
||||
":common",
|
||||
":mat",
|
||||
":prompt",
|
||||
":sampler",
|
||||
":threading_context",
|
||||
":weights",
|
||||
"@googletest//:gtest_main", # buildcleaner: keep
|
||||
"@highway//:thread_pool",
|
||||
],
|
||||
)
|
||||
|
||||
cc_test(
|
||||
name = "backward_test",
|
||||
size = "large",
|
||||
|
|
@ -721,7 +725,7 @@ cc_test(
|
|||
deps = [
|
||||
":backprop",
|
||||
":backprop_scalar",
|
||||
":common",
|
||||
":configs",
|
||||
":mat",
|
||||
":ops",
|
||||
":prompt",
|
||||
|
|
@ -741,11 +745,8 @@ cc_library(
|
|||
hdrs = ["backprop/optimizer.h"],
|
||||
deps = [
|
||||
":allocator",
|
||||
":common",
|
||||
":mat",
|
||||
":weights",
|
||||
"//compression:compress",
|
||||
"//compression:shared",
|
||||
"@highway//:hwy",
|
||||
"@highway//:thread_pool",
|
||||
],
|
||||
|
|
@ -762,13 +763,14 @@ cc_test(
|
|||
":allocator",
|
||||
":backprop",
|
||||
":basics",
|
||||
":common",
|
||||
":configs",
|
||||
":gemma_lib",
|
||||
":ops",
|
||||
":optimizer",
|
||||
":prompt",
|
||||
":sampler",
|
||||
":threading",
|
||||
":tokenizer",
|
||||
":weights",
|
||||
"@googletest//:gtest_main", # buildcleaner: keep
|
||||
"//compression:shared",
|
||||
|
@@ -86,8 +86,10 @@ set(SOURCES
gemma/instantiations/sfp.cc
gemma/kv_cache.cc
gemma/kv_cache.h
gemma/tensor_index.cc
gemma/tensor_index.h
gemma/model_store.cc
gemma/model_store.h
gemma/tensor_info.cc
gemma/tensor_info.h
gemma/tokenizer.cc
gemma/tokenizer.h
gemma/weights.cc

@@ -196,7 +198,6 @@ enable_testing()
include(GoogleTest)

set(GEMMA_TEST_FILES
backprop/backward_scalar_test.cc
backprop/backward_test.cc
backprop/optimize_test.cc
compression/blob_store_test.cc

@@ -206,7 +207,7 @@ set(GEMMA_TEST_FILES
compression/nuq_test.cc
compression/sfp_test.cc
evals/gemma_test.cc
gemma/tensor_index_test.cc
gemma/tensor_info_test.cc
ops/bench_matmul.cc
ops/dot_test.cc
ops/gemma_matvec_test.cc
@@ -96,9 +96,10 @@ https://github.com/keras-team/keras-nlp/blob/master/tools/gemma/export_gemma_to_
From Pytorch, use the following script to generate uncompressed weights:
https://github.com/google/gemma.cpp/blob/dev/compression/convert_weights.py

Then run `compression/compress_weights.cc` (Bazel target
`compression:compress_weights`), specifying the resulting file as `--weights`
and the desired .sbs name as the `--compressed_weights`.
For PaliGemma, use `python/convert_from_safetensors` to create an SBS file
directly.

For other models, `gemma_export_main.py` is not yet open sourced.

## Compile-Time Flags (Advanced)

@@ -453,9 +453,8 @@ $ ./gemma [...]
|___/ |_| |_|

tokenizer : tokenizer.spm
compressed_weights : 2b-it-sfp.sbs
weights : 2b-it-sfp.sbs
model : 2b-it
weights : [no path specified]
max_generated_tokens : 2048

*Usage*
|||
|
|
@ -1,634 +0,0 @@
|
|||
// Copyright 2024 Google LLC
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// https://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include "backprop/backward_scalar.h"
|
||||
|
||||
#include <stddef.h>
|
||||
#include <stdio.h>
|
||||
#include <string.h> // memcpy
|
||||
|
||||
#include <complex>
|
||||
#include <limits>
|
||||
#include <random>
|
||||
#include <vector>
|
||||
|
||||
#include "gtest/gtest.h"
|
||||
#include "backprop/activations.h"
|
||||
#include "backprop/common_scalar.h"
|
||||
#include "backprop/forward_scalar.h"
|
||||
#include "backprop/prompt.h"
|
||||
#include "backprop/sampler.h"
|
||||
#include "backprop/test_util.h"
|
||||
#include "gemma/configs.h"
|
||||
#include "gemma/weights.h"
|
||||
#include "util/mat.h"
|
||||
|
||||
namespace gcpp {
|
||||
|
||||
TEST(BackPropTest, MatMulVJP) {
|
||||
static const size_t kRows = 8;
|
||||
static const size_t kCols = 64;
|
||||
static const size_t kTokens = 5;
|
||||
std::mt19937 gen(42);
|
||||
using T = double;
|
||||
using TC = std::complex<T>;
|
||||
auto weights = MakePacked<T>("weights", kRows, kCols);
|
||||
auto x = MakePacked<T>("x", kTokens, kCols);
|
||||
auto grad = MakePacked<T>("grad", kRows, kCols);
|
||||
auto dx = MakePacked<T>("dx", kTokens, kCols);
|
||||
auto c_weights = MakePacked<TC>("c_weights", kRows, kCols);
|
||||
auto c_x = MakePacked<TC>("c_x", kTokens, kCols);
|
||||
auto c_y = MakePacked<TC>("c_y", kTokens, kRows);
|
||||
auto dy = MakePacked<T>("dy", kTokens, kRows);
|
||||
|
||||
for (int iter = 0; iter < 10; ++iter) {
|
||||
RandInit(weights, 1.0 * (1 << iter), gen);
|
||||
RandInit(x, 1.0 * (1 << iter), gen);
|
||||
RandInit(dy, 1.0f, gen);
|
||||
Complexify(weights, c_weights);
|
||||
Complexify(x, c_x);
|
||||
auto func = [&]() {
|
||||
MatMulT(c_weights.Packed(), c_x.Packed(), c_y.Packed(), kRows, kCols,
|
||||
kTokens);
|
||||
return DotT(dy.Packed(), c_y.Packed(), kTokens * kRows);
|
||||
};
|
||||
ZeroInit(grad);
|
||||
MatMulVJPT(weights.Packed(), x.Packed(), dy.Packed(), grad.Packed(),
|
||||
dx.Packed(), kRows, kCols, kTokens);
|
||||
TestGradient(dx, c_x, func, 1e-11, 1e-12, __LINE__, __LINE__);
|
||||
TestGradient(grad, c_weights, func, 1e-14, 1e-11, __LINE__, __LINE__);
|
||||
}
|
||||
}
|
||||
|
||||
TEST(BackPropTest, MultiHeadMatMulVJP) {
|
||||
static const size_t kRows = 2;
|
||||
static const size_t kCols = 16;
|
||||
static const size_t kHeads = 4;
|
||||
static const size_t kTokens = 3;
|
||||
std::mt19937 gen(42);
|
||||
using T = double;
|
||||
using TC = std::complex<T>;
|
||||
auto weights = MakePacked<T>("weights", kRows, kCols * kHeads);
|
||||
auto x = MakePacked<T>("x", kTokens, kCols * kHeads);
|
||||
auto grad = MakePacked<T>("grad", kRows, kCols * kHeads);
|
||||
auto dx = MakePacked<T>("dx", kTokens, kCols * kHeads);
|
||||
auto c_weights = MakePacked<TC>("c_weights", kRows, kCols * kHeads);
|
||||
auto c_x = MakePacked<TC>("c_x", kTokens, kCols * kHeads);
|
||||
auto c_y = MakePacked<TC>("c_y", kTokens, kRows);
|
||||
auto dy = MakePacked<T>("dy", kTokens, kRows);
|
||||
|
||||
for (int iter = 0; iter < 10; ++iter) {
|
||||
RandInit(weights, 1.0 * (1 << iter), gen);
|
||||
RandInit(x, 1.0 * (1 << iter), gen);
|
||||
RandInit(dy, 1.0f, gen);
|
||||
Complexify(weights, c_weights);
|
||||
Complexify(x, c_x);
|
||||
auto func = [&]() {
|
||||
MultiHeadMatMul(c_weights.Packed(), c_x.Packed(), c_y.Packed(), kHeads,
|
||||
kRows, kCols, kTokens);
|
||||
return DotT(dy.Packed(), c_y.Packed(), kTokens * kRows);
|
||||
};
|
||||
ZeroInit(grad);
|
||||
MultiHeadMatMulVJPT(weights.Packed(), x.Packed(), dy.Packed(),
|
||||
grad.Packed(), dx.Packed(), kHeads, kRows, kCols,
|
||||
kTokens);
|
||||
TestGradient(dx, c_x, func, 1e-15, 1e-13, __LINE__, __LINE__);
|
||||
TestGradient(grad, c_weights, func, 1e-15, 1e-13, __LINE__, __LINE__);
|
||||
}
|
||||
}
|
||||
|
||||
TEST(BackPropTest, RMSNormVJP) {
|
||||
static const size_t K = 2;
|
||||
static const size_t N = 64;
|
||||
std::mt19937 gen(42);
|
||||
using T = double;
|
||||
using TC = std::complex<T>;
|
||||
auto weights = MakePacked<T>("weights", N, 1);
|
||||
auto grad = MakePacked<T>("grad", N, 1);
|
||||
auto x = MakePacked<T>("x", K, N);
|
||||
auto dx = MakePacked<T>("dx", K, N);
|
||||
auto dy = MakePacked<T>("dy", K, N);
|
||||
auto c_weights = MakePacked<TC>("c_weights", N, 1);
|
||||
auto c_x = MakePacked<TC>("c_x", K, N);
|
||||
auto c_y = MakePacked<TC>("c_y", K, N);
|
||||
|
||||
for (int iter = 0; iter < 10; ++iter) {
|
||||
RandInit(weights, 1.0 * (1 << iter), gen);
|
||||
RandInit(x, 1.0 * (1 << iter), gen);
|
||||
Complexify(weights, c_weights);
|
||||
Complexify(x, c_x);
|
||||
RandInit(dy, 1.0f, gen);
|
||||
auto func = [&]() {
|
||||
RMSNormT(c_weights.Packed(), c_x.Packed(), c_y.Packed(), N, K);
|
||||
return DotT(dy.Packed(), c_y.Packed(), K * N);
|
||||
};
|
||||
ZeroInit(grad);
|
||||
RMSNormVJPT(weights.Packed(), x.Packed(), dy.Packed(), grad.Packed(),
|
||||
dx.Packed(), N, K);
|
||||
TestGradient(dx, c_x, func, 1e-15, 1e-14, __LINE__, __LINE__);
|
||||
TestGradient(grad, c_weights, func, 1e-15, 1e-14, __LINE__, __LINE__);
|
||||
}
|
||||
}
|
||||
|
||||
TEST(BackPropTest, SoftmaxVJP) {
|
||||
static const size_t N = 64;
|
||||
std::mt19937 gen(42);
|
||||
using T = double;
|
||||
using TC = std::complex<T>;
|
||||
auto x = MakePacked<T>("x", N, 1);
|
||||
auto dx = MakePacked<T>("dx", N, 1);
|
||||
auto dy = MakePacked<T>("dy", N, 1);
|
||||
auto c_x = MakePacked<TC>("c_x", N, 1);
|
||||
auto c_y = MakePacked<TC>("c_y", N, 1);
|
||||
|
||||
for (int iter = 0; iter < 10; ++iter) {
|
||||
RandInit(x, 1.0f * (1 << iter), gen);
|
||||
Complexify(x, c_x);
|
||||
RandInit(dy, 1.0f, gen);
|
||||
auto func = [&]() {
|
||||
CopyMat(c_x, c_y);
|
||||
Softmax(c_y.Packed(), N);
|
||||
return DotT(dy.Packed(), c_y.Packed(), N);
|
||||
};
|
||||
Softmax(x.Packed(), N);
|
||||
CopyMat(dy, dx);
|
||||
SoftmaxVJPT(x.Packed(), dx.Packed(), N);
|
||||
TestGradient(dx, c_x, func, 1e-15, 1e-15, __LINE__, __LINE__);
|
||||
}
|
||||
}
|
||||
|
||||
TEST(BackPropTest, MaskedSoftmaxVJP) {
|
||||
static const size_t kSeqLen = 16;
|
||||
static const size_t kHeads = 2;
|
||||
static const size_t kTokens = 14;
|
||||
static const size_t N = kTokens * kHeads * kSeqLen;
|
||||
std::mt19937 gen(42);
|
||||
using T = double;
|
||||
using TC = std::complex<T>;
|
||||
auto x = MakePacked<T>("x", N, 1);
|
||||
auto dy = MakePacked<T>("dy", N, 1);
|
||||
auto dx = MakePacked<T>("dx", N, 1);
|
||||
auto c_x = MakePacked<TC>("c_x", N, 1);
|
||||
auto c_y = MakePacked<TC>("c_y", N, 1);
|
||||
ZeroInit(dx);
|
||||
|
||||
for (int iter = 0; iter < 10; ++iter) {
|
||||
RandInit(x, 1.0 * (1 << iter), gen);
|
||||
Complexify(x, c_x);
|
||||
RandInit(dy, 1.0f, gen);
|
||||
auto func = [&]() {
|
||||
CopyMat(c_x, c_y);
|
||||
MaskedSoftmax(c_y.Packed(), kTokens, kHeads, kSeqLen);
|
||||
return DotT(dy.Packed(), c_y.Packed(), N);
|
||||
};
|
||||
MaskedSoftmax(x.Packed(), kTokens, kHeads, kSeqLen);
|
||||
CopyMat(dy, dx);
|
||||
MaskedSoftmaxVJPT(x.Packed(), dx.Packed(), kTokens, kHeads, kSeqLen);
|
||||
TestGradient(dx, c_x, func, 1e-14, 1e-15, __LINE__, __LINE__);
|
||||
}
|
||||
}
|
||||
|
||||
TEST(BackPropTest, SoftcapVJP) {
|
||||
static const size_t N = 64;
|
||||
std::mt19937 gen(42);
|
||||
using T = double;
|
||||
using TC = std::complex<T>;
|
||||
auto x = MakePacked<T>("x", N, 1);
|
||||
auto dx = MakePacked<T>("dx", N, 1);
|
||||
auto dy = MakePacked<T>("dy", N, 1);
|
||||
auto c_x = MakePacked<TC>("c_x", N, 1);
|
||||
auto c_y = MakePacked<TC>("c_y", N, 1);
|
||||
|
||||
constexpr float kCap = 30.0f;
|
||||
for (int iter = 0; iter < 10; ++iter) {
|
||||
RandInit(x, 1.0 * (1 << iter), gen);
|
||||
Complexify(x, c_x);
|
||||
RandInit(dy, 1.0f, gen);
|
||||
auto func = [&]() {
|
||||
CopyMat(c_x, c_y);
|
||||
Softcap(kCap, c_y.Packed(), N);
|
||||
return DotT(dy.Packed(), c_y.Packed(), N);
|
||||
};
|
||||
Softcap(kCap, x.Packed(), N);
|
||||
CopyMat(dy, dx);
|
||||
SoftcapVJPT(kCap, x.Packed(), dx.Packed(), N);
|
||||
TestGradient(dx, c_x, func, 1e-15, 1e-14, __LINE__, __LINE__);
|
||||
}
|
||||
}
|
||||
|
||||
TEST(BackPropTest, CrossEntropyLossGrad) {
|
||||
static const size_t K = 8;
|
||||
static const size_t V = 64;
|
||||
std::mt19937 gen(42);
|
||||
using T = double;
|
||||
using TC = std::complex<T>;
|
||||
auto x = MakePacked<T>("x", K, V);
|
||||
auto dx = MakePacked<T>("dx", K, V);
|
||||
auto c_x = MakePacked<TC>("c_x", K, V);
|
||||
Prompt prompt;
|
||||
prompt.tokens = { 0, 1, 2, 3, 0, 3, 2, 1, 0 };
|
||||
|
||||
const float kCap = 30.0f;
|
||||
for (int iter = 0; iter < 10; ++iter) {
|
||||
prompt.context_size = 1 + (iter % 6);
|
||||
RandInit(x, 1.0 * (1 << iter), gen);
|
||||
Softcap(kCap, x.Packed(), V * K);
|
||||
Softmax(x.Packed(), V, K);
|
||||
CrossEntropyLossGrad(x.Packed(), dx.Packed(), prompt, V);
|
||||
Complexify(x, c_x);
|
||||
auto func = [&]() { return CrossEntropyLoss(c_x.Packed(), prompt, V); };
|
||||
TestGradient(dx, c_x, func, 1e-100, 1e-15, __LINE__, __LINE__);
|
||||
}
|
||||
}
|
||||
|
||||
TEST(BackPropTest, GatedGeluVJP) {
|
||||
static const size_t K = 2;
|
||||
static const size_t N = 64;
|
||||
std::mt19937 gen(42);
|
||||
using T = double;
|
||||
using TC = std::complex<T>;
|
||||
auto x = MakePacked<T>("x", K, 2 * N);
|
||||
auto dx = MakePacked<T>("dx", K, 2 * N);
|
||||
auto dy = MakePacked<T>("dy", K, N);
|
||||
auto c_x = MakePacked<TC>("c_x", K, 2 * N);
|
||||
auto c_y = MakePacked<TC>("c_y", K, N);
|
||||
|
||||
for (int iter = 0; iter < 10; ++iter) {
|
||||
RandInit(x, 1.0f, gen);
|
||||
Complexify(x, c_x);
|
||||
RandInit(dy, 1.0f, gen);
|
||||
auto func = [&]() {
|
||||
GatedGelu(c_x.Packed(), c_y.Packed(), N, K);
|
||||
return DotT(dy.Packed(), c_y.Packed(), N * K);
|
||||
};
|
||||
GatedGeluVJP(x.Packed(), dy.Packed(), dx.Packed(), N, K);
|
||||
TestGradient(dx, c_x, func, 1e-15, 1e-15, __LINE__, __LINE__);
|
||||
}
|
||||
}
|
||||
|
||||
TEST(BackPropTest, MaskedAttentionVJP) {
|
||||
static const size_t kSeqLen = 16;
|
||||
static const size_t kHeads = 2;
|
||||
static const size_t kQKVDim = 8;
|
||||
static const size_t kTokens = 14;
|
||||
static const size_t kQKVSize = kSeqLen * (kHeads + 2) * kQKVDim;
|
||||
static const size_t kOutSize = kTokens * kHeads * kSeqLen;
|
||||
std::mt19937 gen(42);
|
||||
using T = double;
|
||||
using TC = std::complex<T>;
|
||||
auto x = MakePacked<T>("x", kQKVSize, 1);
|
||||
auto dx = MakePacked<T>("dx", kQKVSize, 1);
|
||||
auto dy = MakePacked<T>("dy", kOutSize, 1);
|
||||
auto c_x = MakePacked<TC>("c_x", kQKVSize, 1);
|
||||
auto c_y = MakePacked<TC>("c_y", kOutSize, 1);
|
||||
ZeroInit(dx);
|
||||
ZeroInit(c_y);
|
||||
|
||||
for (int iter = 0; iter < 10; ++iter) {
|
||||
RandInit(x, 1.0f, gen);
|
||||
Complexify(x, c_x);
|
||||
RandInit(dy, 1.0f, gen);
|
||||
auto func = [&]() {
|
||||
MaskedAttention(c_x.Packed(), c_y.Packed(), kTokens, kHeads, kQKVDim,
|
||||
kSeqLen);
|
||||
return DotT(dy.Packed(), c_y.Packed(), kOutSize);
|
||||
};
|
||||
MaskedAttentionVJP(x.Packed(), dy.Packed(), dx.Packed(), kTokens, kHeads,
|
||||
kQKVDim, kSeqLen);
|
||||
TestGradient(dx, c_x, func, 1e-14, 1e-15, __LINE__, __LINE__);
|
||||
}
|
||||
}
|
||||
|
||||
TEST(BackPropTest, MixByAttentionVJP) {
|
||||
static const size_t kSeqLen = 16;
|
||||
static const size_t kHeads = 2;
|
||||
static const size_t kQKVDim = 8;
|
||||
static const size_t kTokens = 14;
|
||||
static const size_t kQKVSize = kSeqLen * (kHeads + 2) * kQKVDim;
|
||||
static const size_t kAttnSize = kSeqLen * kHeads * kSeqLen;
|
||||
static const size_t kOutSize = kSeqLen * kHeads * kQKVDim;
|
||||
std::mt19937 gen(42);
|
||||
using T = double;
|
||||
using TC = std::complex<T>;
|
||||
auto qkv = MakePacked<T>("qkv", kQKVSize, 1);
|
||||
auto dqkv = MakePacked<T>("dqkv", kQKVSize, 1);
|
||||
auto attn = MakePacked<T>("attn", kAttnSize, 1);
|
||||
auto dattn = MakePacked<T>("dattn", kAttnSize, 1);
|
||||
auto dy = MakePacked<T>("dy", kOutSize, 1);
|
||||
auto c_qkv = MakePacked<TC>("c_qkv", kQKVSize, 1);
|
||||
auto c_attn = MakePacked<TC>("c_attn", kAttnSize, 1);
|
||||
auto c_y = MakePacked<TC>("c_y", kOutSize, 1);
|
||||
ZeroInit(dqkv);
|
||||
ZeroInit(dattn);
|
||||
ZeroInit(c_y);
|
||||
|
||||
for (int iter = 0; iter < 10; ++iter) {
|
||||
RandInit(qkv, 1.0f, gen);
|
||||
RandInit(attn, 1.0f, gen);
|
||||
Complexify(qkv, c_qkv);
|
||||
Complexify(attn, c_attn);
|
||||
RandInit(dy, 1.0f, gen);
|
||||
auto func = [&]() {
|
||||
MixByAttention(c_qkv.Packed(), c_attn.Packed(), c_y.Packed(), kTokens,
|
||||
kHeads, kQKVDim, kSeqLen);
|
||||
return DotT(dy.Packed(), c_y.Packed(), kOutSize);
|
||||
};
|
||||
MixByAttentionVJP(qkv.Packed(), attn.Packed(), dy.Packed(), dqkv.Packed(),
|
||||
dattn.Packed(), kTokens, kHeads, kQKVDim, kSeqLen);
|
||||
TestGradient(dqkv, c_qkv, func, 1e-14, 1e-15, __LINE__, __LINE__);
|
||||
TestGradient(dattn, c_attn, func, 1e-14, 1e-15, __LINE__, __LINE__);
|
||||
}
|
||||
}
|
||||
|
||||
TEST(BackPropTest, InputEmbeddingVJP) {
|
||||
static const size_t kSeqLen = 8;
|
||||
static const size_t kVocabSize = 4;
|
||||
static const size_t kModelDim = 16;
|
||||
std::mt19937 gen(42);
|
||||
using T = double;
|
||||
using TC = std::complex<T>;
|
||||
auto weights = MakePacked<T>("weights", kVocabSize, kModelDim);
|
||||
auto grad = MakePacked<T>("grad", kVocabSize, kModelDim);
|
||||
auto dy = MakePacked<T>("dy", kSeqLen, kModelDim);
|
||||
auto c_weights = MakePacked<TC>("c_weights", kVocabSize, kModelDim);
|
||||
auto c_y = MakePacked<TC>("c_y", kSeqLen, kModelDim);
|
||||
std::vector<int> tokens = { 0, 1, 2, 3, 0, 1, 2 };
|
||||
size_t num_tokens = tokens.size() - 1;
|
||||
|
||||
for (size_t iter = 0; iter < 10; ++iter) {
|
||||
RandInit(weights, 1.0f, gen);
|
||||
RandInit(dy, 1.0f, gen);
|
||||
Complexify(weights, c_weights);
|
||||
auto func = [&]() {
|
||||
InputEmbedding(c_weights.Packed(), tokens, TC(3.0), c_y.Packed(),
|
||||
kModelDim);
|
||||
return DotT(dy.Packed(), c_y.Packed(), num_tokens * kModelDim);
|
||||
};
|
||||
ZeroInit(grad);
|
||||
InputEmbeddingVJPT(weights.Packed(), tokens, 3.0, dy.Packed(),
|
||||
grad.Packed(), kModelDim);
|
||||
TestGradient(grad, c_weights, func, 1e-14, 1e-14, __LINE__, __LINE__);
|
||||
}
|
||||
}
|
||||
|
||||
static ModelConfig TestConfig() {
|
||||
ModelConfig config;
|
||||
config.scale_names = {"att_ein", "qkv_ein", "gr_lin_x_w", "gr_lin_y_w",
|
||||
"gr_lin_out_w", "gr_gate_w", "gating_ein", "linear_w"};
|
||||
config.model_dim = 32;
|
||||
config.vocab_size = 12;
|
||||
config.seq_len = 18;
|
||||
LayerConfig layer_config;
|
||||
layer_config.model_dim = config.model_dim;
|
||||
layer_config.ff_hidden_dim = 48;
|
||||
layer_config.heads = 3;
|
||||
layer_config.kv_heads = 1;
|
||||
layer_config.qkv_dim = 12;
|
||||
config.layer_configs = {2, layer_config};
|
||||
config.num_tensor_scales = 4 * config.layer_configs.size();
|
||||
config.query_scale = QueryScaleType::SqrtKeySize;
|
||||
config.attention_window_sizes = FixedAttentionWindowSizes<2>(32);
|
||||
// This is required for optimize_test to pass.
|
||||
config.final_cap = 30.0f;
|
||||
return config;
|
||||
}
|
||||
|
||||
TEST(BackPropTest, LayerVJP) {
|
||||
std::mt19937 gen(42);
|
||||
using T = double;
|
||||
using TC = std::complex<T>;
|
||||
ModelConfig config = TestConfig();
|
||||
const TensorIndex tensor_index = TensorIndexLLM(config, size_t{0});
|
||||
const size_t kOutputSize = config.seq_len * config.model_dim;
|
||||
LayerWeightsPtrs<T> weights(config.layer_configs[0], tensor_index);
|
||||
LayerWeightsPtrs<T> grad(config.layer_configs[0], tensor_index);
|
||||
ForwardLayer<T> forward(config.layer_configs[0], config.seq_len);
|
||||
ForwardLayer<T> backward(config.layer_configs[0], config.seq_len);
|
||||
LayerWeightsPtrs<TC> c_weights(config.layer_configs[0], tensor_index);
|
||||
ForwardLayer<TC> c_forward(config.layer_configs[0], config.seq_len);
|
||||
auto y = MakePacked<T>("y", kOutputSize, 1);
|
||||
auto dy = MakePacked<T>("dy", kOutputSize, 1);
|
||||
auto c_y = MakePacked<TC>("c_y", kOutputSize, 1);
|
||||
const size_t num_tokens = 3;
|
||||
std::vector<MatOwner> layer_storage;
|
||||
weights.Allocate(layer_storage);
|
||||
grad.Allocate(layer_storage);
|
||||
c_weights.Allocate(layer_storage);
|
||||
ZeroInit(backward.input);
|
||||
|
||||
for (size_t iter = 0; iter < 10; ++iter) {
|
||||
RandInit(weights, 1.0, gen);
|
||||
RandInit(forward.input, 1.0, gen);
|
||||
RandInit(dy, 1.0, gen);
|
||||
Complexify(weights, c_weights);
|
||||
Complexify(forward.input, c_forward.input);
|
||||
auto func = [&]() {
|
||||
ApplyLayer(c_weights, c_forward, num_tokens, c_y.Packed());
|
||||
return DotT(dy.Packed(), c_y.Packed(), num_tokens * config.model_dim);
|
||||
};
|
||||
grad.ZeroInit(/*layer_idx=*/0);
|
||||
ApplyLayer(weights, forward, num_tokens, y.Packed());
|
||||
LayerVJP(weights, forward, dy.Packed(), grad, backward, num_tokens);
|
||||
TestGradient(backward.input, c_forward.input, func, 1e-11, 5e-11, __LINE__,
|
||||
__LINE__);
|
||||
TestGradient(grad, c_weights, func, 2e-11, __LINE__);
|
||||
}
|
||||
}
|
||||
|
||||
TEST(BackPropTest, EndToEnd) {
|
||||
std::mt19937 gen(42);
|
||||
using T = double;
|
||||
using TC = std::complex<T>;
|
||||
ModelConfig config = TestConfig();
|
||||
WeightsWrapper<T> weights(config);
|
||||
WeightsWrapper<T> grad(config);
|
||||
ForwardPass<T> forward(config);
|
||||
ForwardPass<T> backward(config);
|
||||
WeightsWrapper<TC> c_weights(config);
|
||||
ForwardPass<TC> c_forward(config);
|
||||
|
||||
ReverseSequenceSampler training_task({0, 0, 1, 1});
|
||||
std::vector<Prompt> batch = training_task.SampleBatch(3, gen);
|
||||
|
||||
for (const Prompt& prompt : batch) {
|
||||
ReverseSequenceSampler::LogPrompt(prompt);
|
||||
RandInit(weights.get(), 1.0, gen);
|
||||
CrossEntropyLossForwardPass(prompt, weights.get(), forward);
|
||||
grad.ZeroInit();
|
||||
CrossEntropyLossBackwardPass(
|
||||
prompt, weights.get(), forward, grad.get(), backward);
|
||||
|
||||
Complexify(weights.get(), c_weights.get());
|
||||
auto func = [&]() {
|
||||
return CrossEntropyLossForwardPass(prompt, c_weights.get(), c_forward);
|
||||
};
|
||||
|
||||
TestGradient(grad.get(), c_weights.get(), func, 1e-11, __LINE__);
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void MulByConstAndAddT(T c, const LayerWeightsPtrs<T>& x,
|
||||
LayerWeightsPtrs<T>& out) {
|
||||
MulByConstAndAddT(c, x.pre_attention_norm_scale,
|
||||
out.pre_attention_norm_scale);
|
||||
MulByConstAndAddT(c, x.attn_vec_einsum_w, out.attn_vec_einsum_w);
|
||||
MulByConstAndAddT(c, x.qkv_einsum_w, out.qkv_einsum_w);
|
||||
MulByConstAndAddT(c, x.pre_ffw_norm_scale, out.pre_ffw_norm_scale);
|
||||
MulByConstAndAddT(c, x.gating_einsum_w, out.gating_einsum_w);
|
||||
MulByConstAndAddT(c, x.linear_w, out.linear_w);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void MulByConstAndAddT(T c, const ModelWeightsPtrs<T>& x,
|
||||
ModelWeightsPtrs<T>& out) {
|
||||
const size_t layers = x.c_layers.size();
|
||||
MulByConstAndAddT(c, x.embedder_input_embedding,
|
||||
out.embedder_input_embedding);
|
||||
MulByConstAndAddT(c, x.final_norm_scale, out.final_norm_scale);
|
||||
for (size_t i = 0; i < layers; ++i) {
|
||||
MulByConstAndAddT(c, *x.GetLayer(i), *out.GetLayer(i));
|
||||
}
|
||||
}
|
||||
|
||||
// Evaluates forward pass on a batch.
|
||||
template <typename T>
|
||||
T CrossEntropyLossForwardPass(const std::vector<Prompt>& batch,
|
||||
const WeightsWrapper<T>& weights,
|
||||
ForwardPass<T>& forward) {
|
||||
T loss = 0.0;
|
||||
for (const Prompt& prompt : batch) {
|
||||
loss += CrossEntropyLossForwardPass(prompt, weights.get(), forward);
|
||||
}
|
||||
T scale = 1.0 / batch.size();
|
||||
return loss * scale;
|
||||
}
|
||||
|
||||
// Evaluates forward pass on a batch by applying gradient with the given
|
||||
// learning rate. Does not update weights, but uses the given tmp weights
|
||||
// instead.
|
||||
template <typename T>
|
||||
T CrossEntropyLossForwardPass(T learning_rate, const std::vector<Prompt>& batch,
|
||||
const WeightsWrapper<T>& weights,
|
||||
const WeightsWrapper<T>& grad,
|
||||
WeightsWrapper<T>& tmp, ForwardPass<T>& forward) {
|
||||
tmp.CopyFrom(weights);
|
||||
const T scale = -learning_rate / batch.size();
|
||||
MulByConstAndAddT(scale, grad.get(), tmp.get());
|
||||
return CrossEntropyLossForwardPass(batch, tmp, forward);
|
||||
}
|
||||
|
||||
// Uses line search in the negative gradient direction to update weights. We do
|
||||
// this so that we can test that each step during the gradient descent can
|
||||
// decrease the objective function value.
|
||||
template <typename T>
|
||||
T FindOptimalUpdate(const WeightsWrapper<T>& grad, WeightsWrapper<T>& weights,
|
||||
WeightsWrapper<T>& tmp, ForwardPass<T>& forward,
|
||||
const std::vector<Prompt>& batch, T loss,
|
||||
T initial_learning_rate) {
|
||||
T lr0 = initial_learning_rate;
|
||||
T loss0 = CrossEntropyLossForwardPass(
|
||||
lr0, batch, weights, grad, tmp, forward);
|
||||
for (size_t iter = 0; iter < 30; ++iter) {
|
||||
T lr1 = lr0 * 0.5;
|
||||
T loss1 = CrossEntropyLossForwardPass(
|
||||
lr1, batch, weights, grad, tmp, forward);
|
||||
if (loss0 < loss && loss1 >= loss0) {
|
||||
break;
|
||||
}
|
||||
loss0 = loss1;
|
||||
lr0 = lr1;
|
||||
}
|
||||
for (size_t iter = 0; iter < 30; ++iter) {
|
||||
T lr1 = lr0 * 2.0;
|
||||
T loss1 = CrossEntropyLossForwardPass(
|
||||
lr1, batch, weights, grad, tmp, forward);
|
||||
if (loss1 >= loss0) {
|
||||
break;
|
||||
}
|
||||
loss0 = loss1;
|
||||
lr0 = lr1;
|
||||
}
|
||||
const T scale = -lr0 / batch.size();
|
||||
MulByConstAndAddT(scale, grad.get(), weights.get());
|
||||
return lr0;
|
||||
}
|
||||
|
||||
TEST(BackProptest, Convergence) {
|
||||
std::mt19937 gen(42);
|
||||
using T = float;
|
||||
using TC = std::complex<double>;
|
||||
ModelConfig config = TestConfig();
|
||||
WeightsWrapper<T> weights(config);
|
||||
WeightsWrapper<T> grad(config);
|
||||
WeightsWrapper<T> tmp(config);
|
||||
ForwardPass<T> forward(config);
|
||||
ForwardPass<T> backward(config);
|
||||
WeightsWrapper<TC> c_weights(config);
|
||||
ForwardPass<TC> c_forward(config);
|
||||
constexpr size_t kBatchSize = 5;
|
||||
ReverseSequenceSampler training_task({0, 0, 0, 1, 1});
|
||||
T learning_rate = 0.01;
|
||||
|
||||
RandInit(weights.get(), T(1.0), gen);
|
||||
|
||||
printf("Sample batch:\n");
|
||||
for (size_t i = 0; i < 10; ++i) {
|
||||
ReverseSequenceSampler::LogPrompt(training_task.Sample(gen));
|
||||
}
|
||||
|
||||
T prev_loss = std::numeric_limits<T>::max();
|
||||
bool stop = false;
|
||||
size_t step = 0;
|
||||
while (!stop) {
|
||||
T loss = 0.0;
|
||||
grad.ZeroInit();
|
||||
std::mt19937 sgen(42);
|
||||
std::vector<Prompt> batch = training_task.SampleBatch(kBatchSize, sgen);
|
||||
for (const Prompt& prompt : batch) {
|
||||
loss += CrossEntropyLossForwardPass(prompt, weights.get(), forward);
|
||||
CrossEntropyLossBackwardPass(
|
||||
prompt, weights.get(), forward, grad.get(), backward);
|
||||
}
|
||||
|
||||
if (step % 250 == 0) {
|
||||
printf("Checking gradient...\n");
|
||||
Complexify(weights.get(), c_weights.get());
|
||||
auto func = [&]() {
|
||||
TC scale = batch.size();
|
||||
return CrossEntropyLossForwardPass(batch, c_weights, c_forward) * scale;
|
||||
};
|
||||
|
||||
TestGradient(grad.get(), c_weights.get(), func, 5e-3f, __LINE__);
|
||||
}
|
||||
|
||||
loss /= batch.size();
|
||||
EXPECT_LT(loss, prev_loss);
|
||||
stop = step >= 1000 || loss < T{1.0};
|
||||
if (step % 10 == 0 || stop) {
|
||||
printf("step: %5zu loss: %.15f learning_rate: %.15f\n",
|
||||
step, loss, learning_rate);
|
||||
}
|
||||
if (!stop) {
|
||||
learning_rate = FindOptimalUpdate(
|
||||
grad, weights, tmp, forward, batch, loss, learning_rate);
|
||||
++step;
|
||||
}
|
||||
prev_loss = loss;
|
||||
}
|
||||
EXPECT_LT(step, 1000);
|
||||
}
|
||||
|
||||
} // namespace gcpp
|
||||
|
|
@ -25,9 +25,8 @@
|
|||
#include <vector>
|
||||
|
||||
#include "backprop/activations.h"
|
||||
#include "backprop/backward_scalar.h"
|
||||
#include "backprop/common_scalar.h"
|
||||
#include "backprop/forward_scalar.h"
|
||||
#include "backprop/common_scalar.h" // DotT
|
||||
#include "backprop/forward_scalar.h" // MatMulT
|
||||
#include "backprop/prompt.h"
|
||||
#include "backprop/sampler.h"
|
||||
#include "backprop/test_util.h"
|
||||
|
|
@ -50,6 +49,14 @@
|
|||
#include "backprop/forward-inl.h"
|
||||
#include "ops/ops-inl.h"
|
||||
|
||||
// 'include guard' so we only define this once. Note that HWY_ONCE is only
|
||||
// defined during the last pass, but this is used in each pass.
|
||||
#ifndef BACKWARD_TEST_ONCE
|
||||
#define BACKWARD_TEST_ONCE
|
||||
// TestEndToEnd is slow, so only run it for the best-available target.
|
||||
static int run_once;
|
||||
#endif
|
||||
|
||||
HWY_BEFORE_NAMESPACE();
|
||||
namespace gcpp {
|
||||
namespace HWY_NAMESPACE {
|
||||
|
|
@ -81,8 +88,6 @@ void TestMatMulVJP() {
|
|||
auto dy = MakePacked<float>("dy", kTokens, kRows);
|
||||
auto grad = MakePacked<float>("grad", kRows, kCols);
|
||||
auto dx = MakePacked<float>("dx", kTokens, kCols);
|
||||
auto grad_scalar = MakePacked<float>("grad_scalar", kRows, kCols);
|
||||
auto dx_scalar = MakePacked<float>("dx_scalar", kTokens, kCols);
|
||||
using TC = std::complex<double>;
|
||||
auto c_weights = MakePacked<TC>("c_weights", kRows, kCols);
|
||||
auto c_x = MakePacked<TC>("c_x", kTokens, kCols);
|
||||
|
|
@ -105,12 +110,6 @@ void TestMatMulVJP() {
|
|||
grad.Packed(), dx.Packed(), pool);
|
||||
TestGradient(dx, c_x, func, 5e-5f, 5e-5f, __LINE__, __LINE__);
|
||||
TestGradient(grad, c_weights, func, 5e-5f, 5e-5f, __LINE__, __LINE__);
|
||||
|
||||
ZeroInit(grad_scalar);
|
||||
MatMulVJPT(weights.Packed(), x.Packed(), dy.Packed(), grad_scalar.Packed(),
|
||||
dx_scalar.Packed(), kRows, kCols, kTokens);
|
||||
TestNear(dx, dx_scalar, 5e-5, 1e-4, __LINE__, __LINE__);
|
||||
TestNear(grad, grad_scalar, 5e-5, 5e-5, __LINE__, __LINE__);
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -126,8 +125,6 @@ void TestMultiHeadMatMulVJP() {
|
|||
auto grad = MakePacked<float>("grad", kRows, kCols * kHeads);
|
||||
auto dx = MakePacked<float>("dx", kTokens, kCols * kHeads);
|
||||
auto dy = MakePacked<float>("dy", kTokens, kRows);
|
||||
auto grad_scalar = MakePacked<float>("grad_scalar", kRows, kCols * kHeads);
|
||||
auto dx_scalar = MakePacked<float>("dx_scalar", kTokens, kCols * kHeads);
|
||||
using TC = std::complex<double>;
|
||||
auto c_weights = MakePacked<TC>("c_weights", kRows, kCols * kHeads);
|
||||
auto c_x = MakePacked<TC>("c_x", kTokens, kCols * kHeads);
|
||||
|
|
@ -150,13 +147,6 @@ void TestMultiHeadMatMulVJP() {
|
|||
kRows, kTokens, grad.Packed(), dx.Packed(), pool);
|
||||
TestGradient(dx, c_x, func, 5e-5f, 5e-5f, __LINE__, __LINE__);
|
||||
TestGradient(grad, c_weights, func, 5e-5f, 5e-5f, __LINE__, __LINE__);
|
||||
|
||||
ZeroInit(grad_scalar);
|
||||
MultiHeadMatMulVJPT(weights.Packed(), x.Packed(), dy.Packed(),
|
||||
grad_scalar.Packed(), dx_scalar.Packed(), kHeads, kRows,
|
||||
kCols, kTokens);
|
||||
TestNear(dx, dx_scalar, 5e-5, 5e-5, __LINE__, __LINE__);
|
||||
TestNear(grad, grad_scalar, 5e-5, 5e-5, __LINE__, __LINE__);
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -170,8 +160,6 @@ void TestRMSNormVJP() {
|
|||
auto grad = MakePacked<float>("grad", N, 1);
|
||||
auto dx = MakePacked<float>("dx", K, N);
|
||||
auto dy = MakePacked<float>("dy", K, N);
|
||||
auto grad_scalar = MakePacked<float>("grad_scalar", N, 1);
|
||||
auto dx_scalar = MakePacked<float>("dx_scalar", K, N);
|
||||
using TC = std::complex<double>;
|
||||
auto c_weights = MakePacked<TC>("c_weights", N, 1);
|
||||
auto c_x = MakePacked<TC>("c_x", K, N);
|
||||
|
|
@ -193,42 +181,15 @@ void TestRMSNormVJP() {
|
|||
dx.Packed(), pool);
|
||||
TestGradient(dx, c_x, func, 5e-5f, 5e-5f, __LINE__, __LINE__);
|
||||
TestGradient(grad, c_weights, func, 5e-5f, 5e-5f, __LINE__, __LINE__);
|
||||
|
||||
ZeroInit(grad_scalar);
|
||||
RMSNormVJPT(weights.Packed(), x.Packed(), dy.Packed(), grad_scalar.Packed(),
|
||||
dx_scalar.Packed(), N, K);
|
||||
TestNear(dx, dx_scalar, 0, 2e-5, __LINE__, __LINE__);
|
||||
TestNear(grad, grad_scalar, 0, 2e-5, __LINE__, __LINE__);
|
||||
}
|
||||
}
|
||||
|
||||
static ModelConfig TestConfig() {
|
||||
ModelConfig config;
|
||||
config.scale_names = {"att_ein", "qkv_ein", "gr_lin_x_w", "gr_lin_y_w",
|
||||
"gr_lin_out_w", "gr_gate_w", "gating_ein", "linear_w"};
|
||||
config.model_dim = 32;
|
||||
config.vocab_size = 16;
|
||||
config.seq_len = 24;
|
||||
LayerConfig layer_config;
|
||||
layer_config.model_dim = config.model_dim;
|
||||
layer_config.ff_hidden_dim = 64;
|
||||
layer_config.heads = 3;
|
||||
layer_config.kv_heads = 1;
|
||||
layer_config.qkv_dim = 16;
|
||||
config.layer_configs = {2, layer_config};
|
||||
config.num_tensor_scales = 4 * config.layer_configs.size();
|
||||
config.query_scale = QueryScaleType::SqrtKeySize;
|
||||
config.attention_window_sizes = FixedAttentionWindowSizes<2>(32);
|
||||
// This is required for optimize_test to pass.
|
||||
config.att_cap = 50.0f;
|
||||
config.final_cap = 30.0f;
|
||||
return config;
|
||||
}
|
||||
|
||||
void TestEndToEnd() {
|
||||
if (++run_once > 1) return; // ~3 min on SKX, only run best available target
|
||||
|
||||
std::mt19937 gen(42);
|
||||
hwy::ThreadPool& pool = ThreadHostileGetPool();
|
||||
ModelConfig config = TestConfig();
|
||||
ModelConfig config(Model::GEMMA_TINY, Type::kF32, PromptWrapping::GEMMA_IT);
|
||||
WeightsWrapper<float> weights(config);
|
||||
WeightsWrapper<float> grad(config);
|
||||
ForwardPass<float> forward0(config);
|
||||
|
|
@ -246,7 +207,7 @@ void TestEndToEnd() {
|
|||
config.layer_configs[0].post_qk == PostQKType::HalfRope);
|
||||
for (const Prompt& prompt : batch) {
|
||||
ReverseSequenceSampler::LogPrompt(prompt);
|
||||
RandInit(weights.get(), 1.0f, gen);
|
||||
weights.get().RandInit(1.0f, gen);
|
||||
|
||||
float loss0 = CrossEntropyLossForwardPass(prompt, weights.get(), forward0);
|
||||
|
||||
|
|
@ -256,7 +217,7 @@ void TestEndToEnd() {
|
|||
|
||||
EXPECT_NEAR(loss1, loss0, std::abs(loss0) * 2e-5);
|
||||
|
||||
grad.ZeroInit();
|
||||
grad.get().ZeroInit();
|
||||
CrossEntropyLossBackwardPassInl(prompt, weights.get(), forward1, grad.get(),
|
||||
backward, inv_timescale, pool);
|
||||
|
||||
|
|
|
|||
|
|
@ -28,9 +28,9 @@
|
|||
#include "backprop/prompt.h"
|
||||
#include "backprop/sampler.h"
|
||||
#include "compression/shared.h"
|
||||
#include "gemma/common.h"
|
||||
#include "gemma/configs.h"
|
||||
#include "gemma/gemma.h"
|
||||
#include "gemma/tokenizer.h"
|
||||
#include "gemma/weights.h"
|
||||
#include "ops/ops.h"
|
||||
#include "util/allocator.h"
|
||||
|
|
@ -51,16 +51,14 @@ TEST(OptimizeTest, GradientDescent) {
|
|||
hwy::ThreadPool& pool = env.ctx.pools.Pool();
|
||||
std::mt19937 gen(42);
|
||||
|
||||
const ModelInfo info = {
|
||||
.model = Model::GEMMA_TINY,
|
||||
.wrapping = PromptWrapping::GEMMA_IT,
|
||||
.weight = Type::kF32,
|
||||
};
|
||||
ModelConfig config = ConfigFromModel(info.model);
|
||||
ModelWeightsStorage grad, grad_m, grad_v;
|
||||
grad.Allocate(info.model, info.weight, pool);
|
||||
grad_m.Allocate(info.model, info.weight, pool);
|
||||
grad_v.Allocate(info.model, info.weight, pool);
|
||||
ModelConfig config(Model::GEMMA_TINY, Type::kF32,
|
||||
ChooseWrapping(Model::GEMMA_TINY));
|
||||
config.eos_id = ReverseSequenceSampler::kEndToken;
|
||||
|
||||
WeightsOwner grad(Type::kF32), grad_m(Type::kF32), grad_v(Type::kF32);
|
||||
grad.AllocateForTest(config, pool);
|
||||
grad_m.AllocateForTest(config, pool);
|
||||
grad_v.AllocateForTest(config, pool);
|
||||
grad_m.ZeroInit();
|
||||
grad_v.ZeroInit();
|
||||
ForwardPass<float> forward(config), backward(config);
|
||||
|
|
@ -70,7 +68,7 @@ TEST(OptimizeTest, GradientDescent) {
|
|||
allocator, config.layer_configs[0].qkv_dim,
|
||||
config.layer_configs[0].post_qk == PostQKType::HalfRope);
|
||||
|
||||
Gemma gemma(GemmaTokenizer(), info, env);
|
||||
Gemma gemma(config, GemmaTokenizer(kMockTokenizer), env);
|
||||
|
||||
const auto generate = [&](const std::vector<int>& prompt) {
|
||||
std::vector<int> reply;
|
||||
|
|
@ -84,7 +82,6 @@ TEST(OptimizeTest, GradientDescent) {
|
|||
.gen = &gen,
|
||||
.verbosity = 0,
|
||||
.stream_token = stream_token,
|
||||
.eos_id = ReverseSequenceSampler::kEndToken,
|
||||
};
|
||||
TimingInfo timing_info;
|
||||
gemma.Generate(runtime, prompt, 0, kv_cache, timing_info);
|
||||
|
|
@ -102,11 +99,11 @@ TEST(OptimizeTest, GradientDescent) {
|
|||
reply.begin() + context.size());
|
||||
};
|
||||
|
||||
gemma.MutableWeights().RandInit(gen);
|
||||
gemma.MutableWeights().AllocAndCopyWithTranspose(pool);
|
||||
gemma.MutableWeights().RandInit(1.0f, gen);
|
||||
gemma.MutableWeights().Reshape(pool);
|
||||
|
||||
printf("Initial weights:\n");
|
||||
gemma.MutableWeights().LogWeightStats();
|
||||
gemma.MutableWeights().LogWeightStatsF32();
|
||||
|
||||
constexpr size_t kBatchSize = 8;
|
||||
constexpr float kAlpha = 0.001f;
|
||||
|
|
@ -128,29 +125,28 @@ TEST(OptimizeTest, GradientDescent) {
|
|||
for (size_t i = 0; i < kBatchSize; ++i) {
|
||||
Prompt prompt = training_task.Sample(sgen);
|
||||
total_loss += CrossEntropyLossForwardPass(
|
||||
prompt, *gemma.Weights().GetWeightsOfType<float>(), forward,
|
||||
inv_timescale, pool);
|
||||
CrossEntropyLossBackwardPass(
|
||||
prompt, *gemma.Weights().GetWeightsOfType<float>(), forward,
|
||||
*grad.GetWeightsOfType<float>(), backward, inv_timescale, pool);
|
||||
gemma.MutableWeights().CopyWithTranspose(pool);
|
||||
prompt, *gemma.Weights().GetF32(), forward, inv_timescale, pool);
|
||||
CrossEntropyLossBackwardPass(prompt, *gemma.Weights().GetF32(), forward,
|
||||
*grad.GetF32(), backward, inv_timescale,
|
||||
pool);
|
||||
gemma.MutableWeights().Reshape(pool);
|
||||
num_ok += verify(prompt) ? 1 : 0;
|
||||
}
|
||||
total_loss /= kBatchSize;
|
||||
|
||||
AdamUpdate(info.weight, grad, kAlpha, kBeta1, kBeta2, kEpsilon, steps + 1,
|
||||
AdamUpdate(grad, kAlpha, kBeta1, kBeta2, kEpsilon, steps + 1,
|
||||
gemma.Weights(), grad_m, grad_v, pool);
|
||||
printf("step: %zu total_loss: %.15f num_ok: %zu/%zu\n",
|
||||
steps, total_loss, num_ok, kBatchSize);
|
||||
if (steps % 100 == 0) {
|
||||
printf("Batch gradient:\n");
|
||||
grad.LogWeightStats();
|
||||
grad.LogWeightStatsF32();
|
||||
}
|
||||
if (total_loss < kMaxLoss) break; // Done
|
||||
}
|
||||
printf("Num steps: %zu\n", steps);
|
||||
printf("Final weights:\n");
|
||||
gemma.MutableWeights().LogWeightStats();
|
||||
gemma.MutableWeights().LogWeightStatsF32();
|
||||
EXPECT_LT(steps, 50);
|
||||
EXPECT_EQ(num_ok, kBatchSize);
|
||||
}
|
||||
|
|
|
|||
|
|
@ -17,11 +17,9 @@
|
|||
|
||||
#include <cmath>
|
||||
|
||||
#include "compression/compress.h"
|
||||
#include "gemma/weights.h"
|
||||
#include "util/allocator.h"
|
||||
#include "util/mat.h"
|
||||
#include "hwy/aligned_allocator.h"
|
||||
#include "hwy/base.h"
|
||||
#include "hwy/contrib/thread_pool/thread_pool.h"
|
||||
|
||||
|
|
@ -29,37 +27,67 @@ namespace gcpp {
|
|||
|
||||
namespace {
|
||||
|
||||
class AdamUpdater {
|
||||
// Split into two classes so that ForEachTensor only requires two "other"
|
||||
// arguments. This is anyway useful for locality, because `grad` only feeds
|
||||
// into `grad_m` and `grad_v` here.
|
||||
class AdamUpdateMV {
|
||||
public:
|
||||
explicit AdamUpdater(float alpha, float beta1, float beta2, float epsilon,
|
||||
size_t t)
|
||||
: alpha_(alpha), beta1_(beta1), beta2_(beta2), cbeta1_(1.0f - beta1),
|
||||
cbeta2_(1.0f - beta2), norm1_(1.0 / (1.0 - std::pow(beta1, t))),
|
||||
norm2_(1.0 / (1.0 - std::pow(beta2, t))), epsilon_(epsilon) {}
|
||||
AdamUpdateMV(float beta1, float beta2, size_t t)
|
||||
: beta1_(beta1),
|
||||
beta2_(beta2),
|
||||
cbeta1_(1.0f - beta1),
|
||||
cbeta2_(1.0f - beta2),
|
||||
norm1_(1.0 / (1.0 - std::pow(beta1, t))),
|
||||
norm2_(1.0 / (1.0 - std::pow(beta2, t))) {}
|
||||
|
||||
void operator()(const char* name, const MatPtr& grad, MatPtr& weights,
|
||||
MatPtr& grad_m, MatPtr& grad_v) {
|
||||
const float* HWY_RESTRICT g = grad.RowT<float>(0);
|
||||
float* HWY_RESTRICT w = weights.RowT<float>(0);
|
||||
float* HWY_RESTRICT m = grad_m.RowT<float>(0);
|
||||
float* HWY_RESTRICT v = grad_v.RowT<float>(0);
|
||||
for (size_t i = 0; i < grad.Extents().Area(); ++i) {
|
||||
m[i] *= beta1_;
|
||||
m[i] += cbeta1_ * g[i];
|
||||
v[i] *= beta2_;
|
||||
v[i] += cbeta2_ * g[i] * g[i];
|
||||
const float mhat = m[i] * norm1_;
|
||||
const float vhat = v[i] * norm2_;
|
||||
w[i] -= alpha_ * mhat / (std::sqrt(vhat) + epsilon_);
|
||||
void operator()(const MatPtr& grad, const MatPtr& grad_m,
|
||||
const MatPtr& grad_v) {
|
||||
for (size_t r = 0; r < grad.Rows(); ++r) {
|
||||
const float* HWY_RESTRICT g = grad.RowT<float>(r);
|
||||
float* HWY_RESTRICT m = grad_m.MutableRowT<float>(r);
|
||||
float* HWY_RESTRICT v = grad_v.MutableRowT<float>(r);
|
||||
for (size_t c = 0; c < grad.Cols(); ++c) {
|
||||
m[c] *= beta1_;
|
||||
m[c] += cbeta1_ * g[c];
|
||||
v[c] *= beta2_;
|
||||
v[c] += cbeta2_ * g[c] * g[c];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
private:
|
||||
float beta1_;
|
||||
float beta2_;
|
||||
float cbeta1_;
|
||||
float cbeta2_;
|
||||
float norm1_;
|
||||
float norm2_;
|
||||
};
|
||||
|
||||
// Updates `weights` based on the updated `grad_m` and `grad_v` from above.
|
||||
class AdamUpdateW {
|
||||
public:
|
||||
AdamUpdateW(float alpha, float beta1, float beta2, float epsilon, size_t t)
|
||||
: alpha_(alpha),
|
||||
norm1_(1.0 / (1.0 - std::pow(beta1, t))),
|
||||
norm2_(1.0 / (1.0 - std::pow(beta2, t))),
|
||||
epsilon_(epsilon) {}
|
||||
|
||||
void operator()(MatPtr& weights, const MatPtr& grad_m, const MatPtr& grad_v) {
|
||||
for (size_t r = 0; r < weights.Rows(); ++r) {
|
||||
float* HWY_RESTRICT w = weights.RowT<float>(r);
|
||||
const float* HWY_RESTRICT m = grad_m.RowT<float>(r);
|
||||
const float* HWY_RESTRICT v = grad_v.RowT<float>(r);
|
||||
for (size_t c = 0; c < weights.Cols(); ++c) {
|
||||
const float mhat = m[c] * norm1_;
|
||||
const float vhat = v[c] * norm2_;
|
||||
w[c] -= alpha_ * mhat / (std::sqrt(vhat) + epsilon_);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
private:
|
||||
float alpha_;
|
||||
float beta1_;
|
||||
float beta2_;
|
||||
float cbeta1_;
|
||||
float cbeta2_;
|
||||
float norm1_;
|
||||
float norm2_;
|
||||
float epsilon_;
|
||||
|
|
@ -70,26 +98,25 @@ void AdamUpdate(ModelWeightsPtrs<float>* grad, float alpha, float beta1,
|
|||
ModelWeightsPtrs<float>* weights,
|
||||
ModelWeightsPtrs<float>* grad_m,
|
||||
ModelWeightsPtrs<float>* grad_v, hwy::ThreadPool& pool) {
|
||||
AdamUpdater updater(alpha, beta1, beta2, epsilon, t);
|
||||
ModelWeightsPtrs<float>::ForEachTensor(
|
||||
{grad, weights, grad_m, grad_v}, ForEachType::kLoadNoToc,
|
||||
[&updater](const char* name, hwy::Span<MatPtr*> tensors) {
|
||||
updater(name, *tensors[0], *tensors[1], *tensors[2], *tensors[3]);
|
||||
});
|
||||
AdamUpdateMV update_mv(beta1, beta2, t);
|
||||
grad->ForEachTensor(grad_m, grad_v, [&update_mv](const TensorArgs& t) {
|
||||
update_mv(t.mat, *t.other_mat1, *t.other_mat2);
|
||||
});
|
||||
|
||||
AdamUpdateW update_w(alpha, beta1, beta2, epsilon, t);
|
||||
weights->ForEachTensor(grad_m, grad_v, [&update_w](const TensorArgs& t) {
|
||||
update_w(t.mat, *t.other_mat1, *t.other_mat2);
|
||||
});
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
void AdamUpdate(Type weight_type, const ModelWeightsStorage& grad, float alpha,
|
||||
float beta1, float beta2, float epsilon, size_t t,
|
||||
const ModelWeightsStorage& weights,
|
||||
const ModelWeightsStorage& grad_m,
|
||||
const ModelWeightsStorage& grad_v, hwy::ThreadPool& pool) {
|
||||
HWY_ASSERT(weight_type == Type::kF32);
|
||||
AdamUpdate(grad.GetWeightsOfType<float>(), alpha, beta1, beta2, epsilon, t,
|
||||
weights.GetWeightsOfType<float>(),
|
||||
grad_m.GetWeightsOfType<float>(), grad_v.GetWeightsOfType<float>(),
|
||||
pool);
|
||||
void AdamUpdate(const WeightsOwner& grad, float alpha, float beta1, float beta2,
|
||||
float epsilon, size_t t, const WeightsOwner& weights,
|
||||
const WeightsOwner& grad_m, const WeightsOwner& grad_v,
|
||||
hwy::ThreadPool& pool) {
|
||||
AdamUpdate(grad.GetF32(), alpha, beta1, beta2, epsilon, t, weights.GetF32(),
|
||||
grad_m.GetF32(), grad_v.GetF32(), pool);
|
||||
}
|
||||
|
||||
} // namespace gcpp
|
||||
@@ -16,17 +16,17 @@
#ifndef THIRD_PARTY_GEMMA_CPP_GEMMA_OPTIMIZER_H_
#define THIRD_PARTY_GEMMA_CPP_GEMMA_OPTIMIZER_H_

#include "gemma/common.h"
#include <stddef.h>

#include "gemma/weights.h"
#include "hwy/contrib/thread_pool/thread_pool.h"

namespace gcpp {

void AdamUpdate(Type weight_type, const ModelWeightsStorage& grad, float alpha,
float beta1, float beta2, float epsilon, size_t t,
const ModelWeightsStorage& weights,
const ModelWeightsStorage& grad_m,
const ModelWeightsStorage& grad_v, hwy::ThreadPool& pool);
void AdamUpdate(const WeightsOwner& grad, float alpha, float beta1, float beta2,
float epsilon, size_t t, const WeightsOwner& weights,
const WeightsOwner& grad_m, const WeightsOwner& grad_v,
hwy::ThreadPool& pool);

} // namespace gcpp
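For context, a hedged usage sketch of the new WeightsOwner-based AdamUpdate declared above. The construction calls mirror lines that appear elsewhere in this diff (the optimize_test changes); the Adam hyperparameter values are typical defaults chosen for illustration, not taken from the commit.

// Sketch only; assumes the types and signatures introduced by this commit.
#include "backprop/optimizer.h"
#include "gemma/configs.h"
#include "gemma/weights.h"
#include "hwy/contrib/thread_pool/thread_pool.h"

namespace gcpp {

void AdamStepSketch(hwy::ThreadPool& pool) {
  // Model and weight type now live in ModelConfig (ModelInfo was removed).
  ModelConfig config(Model::GEMMA_TINY, Type::kF32,
                     ChooseWrapping(Model::GEMMA_TINY));

  WeightsOwner weights(Type::kF32), grad(Type::kF32);
  WeightsOwner grad_m(Type::kF32), grad_v(Type::kF32);
  weights.AllocateForTest(config, pool);
  grad.AllocateForTest(config, pool);
  grad_m.AllocateForTest(config, pool);
  grad_v.AllocateForTest(config, pool);
  grad_m.ZeroInit();
  grad_v.ZeroInit();

  // One Adam step (t = 1) with commonly used hyperparameters (assumed values).
  AdamUpdate(grad, /*alpha=*/0.001f, /*beta1=*/0.9f, /*beta2=*/0.999f,
             /*epsilon=*/1e-8f, /*t=*/1, weights, grad_m, grad_v, pool);
}

}  // namespace gcpp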
@@ -20,8 +20,6 @@

#include <cmath>
#include <complex>
#include <random>
#include <vector>

#include "gtest/gtest.h"
#include "gemma/configs.h"

@@ -32,27 +30,6 @@

namespace gcpp {

// TODO: make a member of Layer<T>.
template <typename T>
void RandInit(LayerWeightsPtrs<T>& w, float stddev, std::mt19937& gen) {
RandInit(w.pre_attention_norm_scale, stddev, gen);
RandInit(w.attn_vec_einsum_w, stddev, gen);
RandInit(w.qkv_einsum_w, stddev, gen);
RandInit(w.pre_ffw_norm_scale, stddev, gen);
RandInit(w.gating_einsum_w, stddev, gen);
RandInit(w.linear_w, stddev, gen);
}

template <typename T>
void RandInit(ModelWeightsPtrs<T>& w, float stddev, std::mt19937& gen) {
const size_t kLayers = w.c_layers.size();
RandInit(w.embedder_input_embedding, stddev, gen);
RandInit(w.final_norm_scale, stddev, gen);
for (size_t i = 0; i < kLayers; ++i) {
RandInit(*w.GetLayer(i), stddev, gen);
}
}

template <typename T, typename U>
void Complexify(const MatPtrT<T>& x, MatPtrT<std::complex<U>>& c_x) {
for (size_t r = 0; r < x.Rows(); ++r) {

@@ -84,26 +61,21 @@ void Complexify(const ModelWeightsPtrs<T>& w, ModelWeightsPtrs<U>& c_w) {
}
}

// Somewhat duplicates WeightsOwner, but that has neither double nor
// Somewhat duplicates `WeightsOwner`, but that has neither double nor
// complex types allowed and it would cause code bloat to add them there.
template <typename T>
class WeightsWrapper {
public:
explicit WeightsWrapper(const ModelConfig& config)
: pool_(0), weights_(config) {
weights_.Allocate(owners_, pool_);
explicit WeightsWrapper(const ModelConfig& config) : weights_(config) {
hwy::ThreadPool& pool = ThreadingContext2::Get().pools.Pool();
weights_.AllocateForTest(owners_, pool);
}

const ModelWeightsPtrs<T>& get() const { return weights_; }
ModelWeightsPtrs<T>& get() { return weights_; }
void ZeroInit() { weights_.ZeroInit(); }
void CopyFrom(const WeightsWrapper<T>& other) {
weights_.CopyFrom(other.weights_);
}

private:
hwy::ThreadPool pool_;
std::vector<MatOwner> owners_;
MatOwners owners_;
ModelWeightsPtrs<T> weights_;
};

@@ -73,6 +73,7 @@ cc_library(
"//:basics",
"//:threading_context",
"@highway//:hwy",
"@highway//:profiler",
"@highway//:thread_pool",
],
)

@@ -84,9 +85,9 @@ cc_test(
":blob_store",
":io",
"@googletest//:gtest_main", # buildcleaner: keep
"//:basics",
"//:threading_context",
"@highway//:hwy_test_util",
"@highway//:thread_pool",
],
)

@@ -212,15 +213,10 @@ cc_library(
],
textual_hdrs = ["compress-inl.h"],
deps = [
":blob_store",
":distortion",
":fields",
":io",
":nuq",
":sfp",
"//:allocator",
"//:basics",
"//:common",
"//:mat",
"@highway//:hwy",
"@highway//:nanobenchmark",

@@ -283,7 +279,6 @@ cc_binary(
deps = [
":blob_store",
":io",
"//:allocator",
"//:basics",
"//:threading",
"//:threading_context",

@@ -15,14 +15,15 @@

#include <stdint.h>
#include <stdio.h>
#include <string.h>
#include <string.h> // strcmp

#include <atomic>
#include <memory>
#include <string>
#include <vector>

#include "compression/blob_store.h"
#include "compression/io.h" // Path
#include "util/allocator.h"
#include "util/basics.h" // IndexRange
#include "util/threading.h"
#include "util/threading_context.h"

@@ -33,32 +34,56 @@

namespace gcpp {

using KeySpan = hwy::Span<const hwy::uint128_t>;

// Returns false if any keys differ, because then blobs are not comparable.
bool CompareKeys(const BlobReader& reader1, const BlobReader& reader2) {
KeySpan keys1 = reader1.Keys();
KeySpan keys2 = reader2.Keys();
if (keys1.size() != keys2.size()) {
fprintf(stderr, "#keys mismatch: %zu vs %zu\n", keys1.size(), keys2.size());
return false;
// Aborts if any keys differ, because then blobs are not comparable.
void CompareKeys(const BlobReader2& reader1, const BlobReader2& reader2) {
if (reader1.Keys().size() != reader2.Keys().size()) {
HWY_ABORT("#keys mismatch: %zu vs %zu\n", reader1.Keys().size(),
reader2.Keys().size());
}
for (size_t i = 0; i < keys1.size(); ++i) {
if (keys1[i] != keys2[i]) {
fprintf(stderr, "key %zu mismatch: %s vs %s\n", i,
StringFromKey(keys1[i]).c_str(), StringFromKey(keys2[i]).c_str());
return false;
for (size_t i = 0; i < reader1.Keys().size(); ++i) {
if (reader1.Keys()[i] != reader2.Keys()[i]) {
HWY_ABORT("key %zu mismatch: %s vs %s\n", i, reader1.Keys()[i].c_str(),
reader2.Keys()[i].c_str());
}
}
}

return true;
using KeyVec = std::vector<std::string>;
using RangeVec = std::vector<BlobRange2>;

RangeVec AllRanges(const KeyVec& keys, const BlobReader2& reader) {
RangeVec ranges;
ranges.reserve(keys.size());
for (const std::string& key : keys) {
const BlobRange2* range = reader.Find(key);
if (!range) {
HWY_ABORT("Key %s not found, but was in KeyVec\n", key.c_str());
}
ranges.push_back(*range);
}
return ranges;
}

// Aborts if any sizes differ, because that already guarantees a mismatch.
void CompareRangeSizes(const KeyVec& keys, const RangeVec& ranges1,
const RangeVec& ranges2) {
HWY_ASSERT(keys.size() == ranges1.size());
HWY_ASSERT(keys.size() == ranges2.size());
for (size_t i = 0; i < ranges1.size(); ++i) {
// Tolerate differing key_idx and offset because blobs may be in different
// order in the two files.
if (ranges1[i].bytes != ranges2[i].bytes) {
HWY_ABORT("range #%zu (%s) size mismatch: %zu vs %zu\n", i,
keys[i].c_str(), ranges1[i].bytes, ranges2[i].bytes);
}
}
}

// Total amount to allocate for all blobs.
size_t TotalBytes(BlobReader& reader) {
size_t TotalBytes(const RangeVec& ranges) {
size_t total_bytes = 0;
for (const hwy::uint128_t key : reader.Keys()) {
total_bytes += reader.BlobSize(key);
for (const BlobRange2& range : ranges) {
total_bytes += range.bytes;
}
return total_bytes;
}

@@ -67,55 +92,56 @@ using BytePtr = hwy::AlignedFreeUniquePtr<uint8_t[]>;
using ByteSpan = hwy::Span<uint8_t>; // Sections within BytePtr
using BlobVec = std::vector<ByteSpan>; // in order of keys

// Allocates memory within the single allocation and updates `pos`.
BlobVec ReserveMemory(BlobReader& reader, BytePtr& all_blobs, size_t& pos) {
// Assigns pointers within the single allocation and updates `pos`.
BlobVec ReserveMemory(const RangeVec& ranges, BytePtr& all_blobs, size_t& pos) {
BlobVec blobs;
for (const hwy::uint128_t key : reader.Keys()) {
const size_t bytes = reader.BlobSize(key);
blobs.push_back(ByteSpan(all_blobs.get() + pos, bytes));
pos += bytes;
for (const BlobRange2& range : ranges) {
blobs.push_back(ByteSpan(all_blobs.get() + pos, range.bytes));
pos += range.bytes;
}
return blobs;
}

// Reads one set of blobs in parallel (helpful if in disk cache).
void ReadBlobs(BlobReader& reader, BlobVec& blobs, hwy::ThreadPool& pool) {
// Aborts on error.
void ReadBlobs(BlobReader2& reader, const RangeVec& ranges, BlobVec& blobs,
hwy::ThreadPool& pool) {
HWY_ASSERT(reader.Keys().size() == blobs.size());
HWY_ASSERT(ranges.size() == blobs.size());
for (size_t i = 0; i < blobs.size(); ++i) {
reader.Enqueue(reader.Keys()[i], blobs[i].data(), blobs[i].size());
}
const BlobError err = reader.ReadAll(pool);
if (err != 0) {
HWY_ABORT("Parallel read failed: %d\n", err);
HWY_ASSERT(ranges[i].bytes == blobs[i].size());
reader.Enqueue(ranges[i], blobs[i].data());
}
reader.ReadAll(pool);
}

// Parallelizes ReadBlobs across (two) packages, if available.
void ReadBothBlobs(BlobReader& reader1, BlobReader& reader2, size_t total_bytes,
BlobVec& blobs1, BlobVec& blobs2, NestedPools& pools) {
void ReadBothBlobs(BlobReader2& reader1, BlobReader2& reader2,
const RangeVec& ranges1, const RangeVec& ranges2,
size_t total_bytes, BlobVec& blobs1, BlobVec& blobs2,
NestedPools& pools) {
const double t0 = hwy::platform::Now();
fprintf(stderr, "Reading %zu GiB, %zux%zu cores: ", total_bytes >> 30,
pools.AllPackages().NumWorkers(), pools.Pool().NumWorkers());
HWY_WARN("Reading %zu GiB, %zux%zu cores: ", total_bytes >> 30,
pools.AllPackages().NumWorkers(), pools.Pool().NumWorkers());
pools.AllPackages().Run(0, 2, [&](size_t task, size_t pkg_idx) {
ReadBlobs(task ? reader2 : reader1, task ? blobs2 : blobs1,
pools.Pool(pkg_idx));
ReadBlobs(task ? reader2 : reader1, task ? ranges2 : ranges1,
task ? blobs2 : blobs1, pools.Pool(pkg_idx));
});
const double t1 = hwy::platform::Now();
fprintf(stderr, "%.1f GB/s\n", total_bytes / (t1 - t0) * 1E-9);
HWY_WARN("%.1f GB/s\n", total_bytes / (t1 - t0) * 1E-9);
}

// Returns number of elements with a mismatch. For float and bf16 blobs, uses
// L1 and relative error, otherwise byte-wise comparison.
size_t BlobDifferences(const ByteSpan& data1, const ByteSpan& data2,
const hwy::uint128_t key) {
size_t BlobDifferences(const ByteSpan data1, const ByteSpan data2,
const std::string& key) {
if (data1.size() != data2.size() || data1.size() == 0) {
HWY_ABORT("key %s size mismatch: %zu vs %zu\n", StringFromKey(key).c_str(),
data1.size(), data2.size());
HWY_ABORT("key %s size mismatch: %zu vs %zu\n", key.c_str(), data1.size(),
data2.size());
}

size_t mismatches = 0;
char type;
hwy::CopyBytes(&key, &type, 1);
const char type = key[0];
if (type == 'F') {
HWY_ASSERT(data1.size() % sizeof(float) == 0);
for (size_t j = 0; j < data1.size(); j += sizeof(float)) {

@@ -125,8 +151,7 @@ size_t BlobDifferences(const ByteSpan& data1, const ByteSpan& data2,
const float l1 = hwy::ScalarAbs(f1 - f2);
const float rel = hwy::ScalarAbs(f1) == 0.0f ? 0.0f : l1 / f1;
if (l1 > 1E-3f || rel > 1E-2f) {
fprintf(stderr, "key %s %5zu: L1 %.5f rel %.4f\n",
StringFromKey(key).c_str(), j, l1, rel);
HWY_WARN("key %s %5zu: L1 %.5f rel %.4f\n", key.c_str(), j, l1, rel);
++mismatches;
}
}

@@ -140,8 +165,7 @@ size_t BlobDifferences(const ByteSpan& data1, const ByteSpan& data2,
const float l1 = hwy::ScalarAbs(f1 - f2);
const float rel = hwy::ScalarAbs(f1) == 0.0f ? 0.0f : l1 / f1;
if (l1 > 1E-2f || rel > 1E-1f) {
fprintf(stderr, "key %s %5zu: L1 %.5f rel %.4f\n",
StringFromKey(key).c_str(), j, l1, rel);
HWY_WARN("key %s %5zu: L1 %.5f rel %.4f\n", key.c_str(), j, l1, rel);
++mismatches;
}
}

@@ -149,8 +173,7 @@ size_t BlobDifferences(const ByteSpan& data1, const ByteSpan& data2,
for (size_t j = 0; j < data1.size(); ++j) {
if (data1[j] != data2[j]) {
if (mismatches == 0) {
fprintf(stderr, "key %s mismatch at byte %5zu\n",
StringFromKey(key).c_str(), j);
HWY_WARN("key %s mismatch at byte %5zu\n", key.c_str(), j);
}
++mismatches;
}

@@ -159,9 +182,9 @@ size_t BlobDifferences(const ByteSpan& data1, const ByteSpan& data2,
return mismatches;
}
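The float branch above flags an element as a mismatch when either the absolute (L1) or relative difference exceeds a threshold. A standalone sketch of that check, mirroring the thresholds and the relative-error formula in the code above:

// Sketch of the per-element tolerance used for 'F' (float) blobs above.
// Note the relative error divides by f1 as in the original, not by |f1|.
#include <cmath>

static bool FloatsMismatch(float f1, float f2) {
  const float l1 = std::abs(f1 - f2);
  const float rel = std::abs(f1) == 0.0f ? 0.0f : l1 / f1;
  return l1 > 1E-3f || rel > 1E-2f;  // bf16 blobs use 1E-2f / 1E-1f instead.
}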

void CompareBlobs(const KeySpan& keys, BlobVec& blobs1, BlobVec& blobs2,
void CompareBlobs(const KeyVec& keys, BlobVec& blobs1, BlobVec& blobs2,
size_t total_bytes, NestedPools& pools) {
fprintf(stderr, "Comparing %zu blobs in parallel: ", keys.size());
HWY_WARN("Comparing %zu blobs in parallel: ", keys.size());
const double t0 = hwy::platform::Now();
std::atomic<size_t> blobs_equal{};
std::atomic<size_t> blobs_diff{};

@@ -175,9 +198,8 @@ void CompareBlobs(const KeySpan& keys, BlobVec& blobs1, BlobVec& blobs2,
const size_t mismatches =
BlobDifferences(blobs1[i], blobs2[i], keys[i]);
if (mismatches != 0) {
fprintf(stderr, "key %s has %zu mismatches in %zu bytes!\n",
StringFromKey(keys[i]).c_str(), mismatches,
blobs1[i].size());
HWY_WARN("key %s has %zu mismatches in %zu bytes!\n",
keys[i].c_str(), mismatches, blobs1[i].size());
blobs_diff.fetch_add(1);
} else {
blobs_equal.fetch_add(1);

@@ -185,35 +207,39 @@ void CompareBlobs(const KeySpan& keys, BlobVec& blobs1, BlobVec& blobs2,
});
});
const double t1 = hwy::platform::Now();
fprintf(stderr, "%.1f GB/s; total blob matches=%zu, mismatches=%zu\n",
total_bytes / (t1 - t0) * 1E-9, blobs_equal.load(),
blobs_diff.load());
HWY_WARN("%.1f GB/s; total blob matches=%zu, mismatches=%zu\n",
total_bytes / (t1 - t0) * 1E-9, blobs_equal.load(),
blobs_diff.load());
}

// Compares two sbs files, including blob order.
void ReadAndCompareBlobs(const char* path1, const char* path2) {
// Open files.
BlobReader reader1;
BlobReader reader2;
const BlobError err1 = reader1.Open(Path(path1));
const BlobError err2 = reader2.Open(Path(path2));
if (err1 != 0 || err2 != 0) {
HWY_ABORT("Failed to open files: %s %s: %d %d\n", path1, path2, err1, err2);
const Tristate map = Tristate::kFalse;
std::unique_ptr<BlobReader2> reader1 = BlobReader2::Make(Path(path1), map);
std::unique_ptr<BlobReader2> reader2 = BlobReader2::Make(Path(path2), map);
if (!reader1 || !reader2) {
HWY_ABORT(
"Failed to create readers for files %s %s, see error messages above.\n",
path1, path2);
}

if (!CompareKeys(reader1, reader2)) return;
CompareKeys(*reader1, *reader2);
const RangeVec ranges1 = AllRanges(reader1->Keys(), *reader1);
const RangeVec ranges2 = AllRanges(reader2->Keys(), *reader2);
CompareRangeSizes(reader1->Keys(), ranges1, ranges2);

// Single allocation, avoid initializing the memory.
const size_t total_bytes = TotalBytes(reader1) + TotalBytes(reader2);
const size_t total_bytes = TotalBytes(ranges1) + TotalBytes(ranges2);
BytePtr all_blobs = hwy::AllocateAligned<uint8_t>(total_bytes);
size_t pos = 0;
BlobVec blobs1 = ReserveMemory(reader1, all_blobs, pos);
BlobVec blobs2 = ReserveMemory(reader2, all_blobs, pos);
BlobVec blobs1 = ReserveMemory(ranges1, all_blobs, pos);
BlobVec blobs2 = ReserveMemory(ranges2, all_blobs, pos);

NestedPools& pools = ThreadingContext2::Get().pools;
ReadBothBlobs(reader1, reader2, total_bytes, blobs1, blobs2, pools);
ReadBothBlobs(*reader1, *reader2, ranges1, ranges2, total_bytes, blobs1,
blobs2, pools);

CompareBlobs(reader1.Keys(), blobs1, blobs2, total_bytes, pools);
CompareBlobs(reader1->Keys(), blobs1, blobs2, total_bytes, pools);
}

} // namespace gcpp

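A hypothetical driver for the comparison entry point above (not part of this commit; argument handling is an assumption for illustration):

// Compare two .sbs files given on the command line. ReadAndCompareBlobs
// aborts on structural mismatches and warns on content differences.
#include <cstdio>

namespace gcpp {
void ReadAndCompareBlobs(const char* path1, const char* path2);  // see above
}

int main(int argc, char** argv) {
  if (argc != 3) {
    fprintf(stderr, "Usage: %s <file1.sbs> <file2.sbs>\n", argv[0]);
    return 1;
  }
  gcpp::ReadAndCompareBlobs(argv[1], argv[2]);
  return 0;
}
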
@@ -18,28 +18,48 @@
#include <stddef.h>
#include <stdint.h>

#include <atomic>
#include <cstdio>
#include <memory>
#include <string>
#include <unordered_set>
#include <utility> // std::move
#include <vector>

#include "compression/io.h"
#include "hwy/aligned_allocator.h"
#include "util/threading_context.h"
#include "hwy/aligned_allocator.h" // Span
#include "hwy/base.h"
#include "hwy/contrib/thread_pool/thread_pool.h"
#include "hwy/detect_compiler_arch.h"
#include "hwy/profiler.h"

namespace gcpp {

hwy::uint128_t MakeKey(const char* string) {
static_assert(HWY_IS_LITTLE_ENDIAN, "Assumes little endian");

// Each blob offset is a multiple of this, an upper bound on SVE vectors and
// usually also larger than L2 cache lines. This is useful when memory mapping
// the entire file, because offset alignment then determines the alignment of
// the blob in memory. Aligning each blob to the (largest) page size would be
// too wasteful, see `kEndAlign`.
constexpr size_t kBlobAlign = 256; // test also hard-codes this value

// Linux mmap requires the file to be a multiple of the (base) page size, which
// can be up to 64 KiB on Arm. Apple uses 16 KiB, most others use 4 KiB.
constexpr size_t kEndAlign = 64 * 1024;

constexpr size_t kU128Bytes = sizeof(hwy::uint128_t);

// Conversion between strings (<= `kU128Bytes` chars) and the fixed-size u128
// used to store them on disk.
static hwy::uint128_t KeyFromString(const char* string) {
size_t length = 0;
for (size_t i = 0; string[i] != '\0'; ++i) {
++length;
}
if (length > 16) {
if (length > kU128Bytes) {
HWY_ABORT("Key %s is too long, please truncate to 16 chars.", string);
}
HWY_ASSERT(length != 0);

hwy::uint128_t ret;
hwy::ZeroBytes<sizeof(ret)>(&ret);

@@ -47,7 +67,7 @@ hwy::uint128_t MakeKey(const char* string) {
return ret;
}

std::string StringFromKey(hwy::uint128_t key) {
static std::string StringFromKey(hwy::uint128_t key) {
std::string name(sizeof(key) + 1, '\0');
hwy::CopyBytes(&key, name.data(), sizeof(key));
name.resize(name.find('\0'));
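The two helpers above are inverses for keys of at most 16 characters: the string is copied into a zero-padded 16-byte `hwy::uint128_t`, and converted back by truncating at the first NUL. A small sketch of the round trip under the little-endian assumption asserted above (standalone names, not the static functions themselves):

// Sketch: string of <= 16 chars -> zero-padded 16-byte key -> string.
#include <cstring>
#include <string>
#include "hwy/base.h"  // hwy::uint128_t

static hwy::uint128_t ToKey(const std::string& s) {
  // Caller ensures s.size() <= sizeof(hwy::uint128_t) and s is non-empty.
  hwy::uint128_t key;
  std::memset(&key, 0, sizeof(key));
  std::memcpy(&key, s.data(), s.size());
  return key;
}

static std::string FromKey(hwy::uint128_t key) {
  std::string name(sizeof(key) + 1, '\0');
  std::memcpy(&name[0], &key, sizeof(key));
  name.resize(name.find('\0'));
  return name;
}
// Round trip: FromKey(ToKey("tokenizer")) == "tokenizer".
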
|
||||
|
|
@ -55,287 +75,456 @@ std::string StringFromKey(hwy::uint128_t key) {
|
|||
}
|
||||
|
||||
namespace {
|
||||
void EnqueueChunkRequests(uint64_t offset, uint64_t size, uint8_t* data,
|
||||
std::vector<BlobIO>& requests) {
|
||||
// Split into chunks for load-balancing even if blob sizes vary.
|
||||
constexpr size_t kChunkSize = 4 * 1024 * 1024; // bytes
|
||||
|
||||
// Split into whole chunks and possibly one remainder.
|
||||
uint64_t pos = 0;
|
||||
if (size >= kChunkSize) {
|
||||
for (; pos <= size - kChunkSize; pos += kChunkSize) {
|
||||
requests.emplace_back(offset + pos, kChunkSize, data + pos, 0);
|
||||
}
|
||||
}
|
||||
if (pos != size) {
|
||||
requests.emplace_back(offset + pos, size - pos, data + pos, 0);
|
||||
}
|
||||
}
|
||||
#pragma pack(push, 1)
|
||||
struct Header { // standard layout class
|
||||
uint32_t magic = 0; // kMagic
|
||||
uint32_t num_blobs = 0; // never zero
|
||||
uint64_t file_bytes = 0; // must match actual size of file
|
||||
};
|
||||
#pragma pack(pop)
|
||||
static_assert(sizeof(Header) == 16);
|
||||
} // namespace
|
||||
|
||||
static_assert(HWY_IS_LITTLE_ENDIAN, "Assumes little endian");
|
||||
|
||||
// On-disk representation (little-endian).
|
||||
// Little-endian on-disk representation: a fixed-size `Header`, then a padded
|
||||
// variable-length 'directory' of blob keys and their offset/sizes, then the
|
||||
// 'payload' of each blob's data with padding in between, followed by padding to
|
||||
// `kEndAlign`. Keys are unique, opaque 128-bit keys.
|
||||
//
|
||||
// Deliberately omits a version number because this file format is unchanging.
|
||||
// The file format deliberately omits a version number because it is unchanging.
|
||||
// Additional data may be added only inside new blobs. Changes to the blob
|
||||
// contents or type should be handled by renaming keys.
|
||||
#pragma pack(push, 1)
|
||||
//
|
||||
// This class is for internal use by `BlobReader2` and `BlobWriter2`. Its
|
||||
// interface is more low-level: fixed-size keys instead of strings.
|
||||
class BlobStore {
|
||||
static constexpr uint32_t kMagic = 0x0A534253; // SBS\n
|
||||
|
||||
// Arbitrary upper limit to avoid allocating a huge vector.
|
||||
static constexpr size_t kMaxBlobs = 64 * 1024;
|
||||
|
||||
// Returns the end of the directory, including padding, which is also the
|
||||
// start of the first payload. `num_blobs` is `NumBlobs()` if the header is
|
||||
// already available, otherwise the number of blobs to be written.
|
||||
static constexpr size_t PaddedDirEnd(size_t num_blobs) {
|
||||
HWY_ASSERT(num_blobs < kMaxBlobs);
|
||||
// Per blob, a key and offset/size.
|
||||
return RoundUpToAlign(sizeof(Header) + 2 * kU128Bytes * num_blobs);
|
||||
}
|
||||
|
||||
static uint64_t PaddedPayloadBytes(size_t num_blobs,
|
||||
const hwy::Span<const uint8_t> blobs[]) {
|
||||
uint64_t total_payload_bytes = 0;
|
||||
for (size_t i = 0; i < num_blobs; ++i) {
|
||||
total_payload_bytes += RoundUpToAlign(blobs[i].size());
|
||||
}
|
||||
// Do not round up to `kEndAlign` because the padding also depends on the
|
||||
// directory size. Here we only count the payload.
|
||||
return total_payload_bytes;
|
||||
}
|
||||
|
||||
static void EnsureUnique(hwy::Span<const hwy::uint128_t> keys) {
|
||||
std::unordered_set<std::string> key_set;
|
||||
for (const hwy::uint128_t key : keys) {
|
||||
HWY_ASSERT(key_set.insert(StringFromKey(key)).second); // ensure inserted
|
||||
}
|
||||
}
|
||||
|
||||
public:
|
||||
// NOT including padding, so that we can also use ZeroFillPadding after
|
||||
// copying the header.
|
||||
static constexpr size_t HeaderSize(size_t num_blobs) {
|
||||
// 16-byte fixed fields plus per-blob: 16-byte key, 16-byte offset/size.
|
||||
return 16 + 32 * num_blobs;
|
||||
template <typename T>
|
||||
static T RoundUpToAlign(T size_or_offset) {
|
||||
return hwy::RoundUpTo(size_or_offset, kBlobAlign);
|
||||
}
|
||||
|
||||
// Returns how many bytes to allocate for the header without the subsequent
|
||||
// blobs. Requires num_blobs_ to already be set, typically by reading
|
||||
// sizeof(BlobStore) bytes from disk.
|
||||
size_t PaddedHeaderSize() const {
|
||||
return hwy::RoundUpTo(HeaderSize(num_blobs_), kBlobAlign);
|
||||
}
|
||||
|
||||
// Returns aligned offset and zero-fills between that and `offset`.
|
||||
uint64_t ZeroFillPadding(uint64_t offset) {
|
||||
uint8_t* const bytes = reinterpret_cast<uint8_t*>(this);
|
||||
const uint64_t padded = hwy::RoundUpTo(offset, kBlobAlign);
|
||||
hwy::ZeroBytes(bytes + offset, padded - offset);
|
||||
return padded;
|
||||
}
|
||||
|
||||
BlobError CheckValidity(const uint64_t file_size) {
|
||||
if (magic_ != kMagic) return __LINE__;
|
||||
if (num_blobs_ == 0) return __LINE__;
|
||||
if (file_size_ != file_size) return __LINE__;
|
||||
|
||||
// Ensure blobs are back to back, and zero-pad.
|
||||
uint64_t offset = ZeroFillPadding(HeaderSize(num_blobs_));
|
||||
for (size_t i = 0; i < num_blobs_; ++i) {
|
||||
const hwy::uint128_t val = keys_[num_blobs_ + i];
|
||||
if (val.lo != offset) return __LINE__;
|
||||
offset = hwy::RoundUpTo(offset + val.hi, kBlobAlign);
|
||||
// Reads header/directory from file.
|
||||
explicit BlobStore(const File& file) {
|
||||
if (!file.Read(0, sizeof(header_), &header_)) {
|
||||
HWY_WARN("Failed to read BlobStore header.");
|
||||
return;
|
||||
}
|
||||
// Avoid allocating a huge vector.
|
||||
if (header_.num_blobs >= kMaxBlobs) {
|
||||
HWY_WARN("Too many blobs, likely corrupt file.");
|
||||
return;
|
||||
}
|
||||
|
||||
if (offset != file_size_) return __LINE__;
|
||||
|
||||
return 0; // all OK
|
||||
const size_t padded_dir_end = PaddedDirEnd(NumBlobs());
|
||||
const size_t padded_dir_bytes = padded_dir_end - sizeof(header_);
|
||||
HWY_ASSERT(padded_dir_bytes % kU128Bytes == 0);
|
||||
directory_.resize(padded_dir_bytes / kU128Bytes);
|
||||
if (!file.Read(sizeof(header_), padded_dir_bytes, directory_.data())) {
|
||||
HWY_WARN("Failed to read BlobStore directory.");
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
static BlobStorePtr Allocate(uint64_t total_size) {
|
||||
uint8_t* bytes =
|
||||
static_cast<uint8_t*>(hwy::AllocateAlignedBytes(total_size));
|
||||
if (!bytes) return BlobStorePtr();
|
||||
return BlobStorePtr(new (bytes) BlobStore(), hwy::AlignedFreer());
|
||||
}
|
||||
// Initializes header/directory for writing to disk.
|
||||
BlobStore(size_t num_blobs, const hwy::uint128_t keys[],
|
||||
const hwy::Span<const uint8_t> blobs[]) {
|
||||
HWY_ASSERT(num_blobs < kMaxBlobs); // Ensures safe to cast to u32.
|
||||
HWY_ASSERT(keys && blobs);
|
||||
EnsureUnique(hwy::Span<const hwy::uint128_t>(keys, num_blobs));
|
||||
|
||||
static std::vector<BlobIO> PrepareWriteRequests(
|
||||
const hwy::uint128_t keys[], const hwy::Span<const uint8_t> blobs[],
|
||||
size_t num_blobs, BlobStore* bs) {
|
||||
// Sanity check and ensure the cast below is safe.
|
||||
HWY_ASSERT(num_blobs < (1ULL << 20));
|
||||
uint64_t offset = PaddedDirEnd(num_blobs);
|
||||
const size_t padded_dir_bytes =
|
||||
static_cast<size_t>(offset) - sizeof(header_);
|
||||
|
||||
// Allocate var-length header.
|
||||
const size_t header_size = HeaderSize(num_blobs);
|
||||
const size_t padded_header_size = hwy::RoundUpTo(header_size, kBlobAlign);
|
||||
const uint64_t padded_header_end = bs->ZeroFillPadding(header_size);
|
||||
HWY_ASSERT(padded_header_end == padded_header_size);
|
||||
header_.magic = kMagic;
|
||||
header_.num_blobs = static_cast<uint32_t>(num_blobs);
|
||||
header_.file_bytes = hwy::RoundUpTo(
|
||||
offset + PaddedPayloadBytes(num_blobs, blobs), kEndAlign);
|
||||
|
||||
// All-zero buffer used to write padding to the file without copying the
|
||||
// input blobs.
|
||||
static uint8_t zeros[kBlobAlign] = {0};
|
||||
HWY_ASSERT(padded_dir_bytes % kU128Bytes == 0);
|
||||
directory_.resize(padded_dir_bytes / kU128Bytes);
|
||||
hwy::CopyBytes(keys, directory_.data(), num_blobs * kU128Bytes);
|
||||
EnsureUnique(Keys());
|
||||
// `SetRange` below will fill `directory_[num_blobs, 2 * num_blobs)`.
|
||||
hwy::ZeroBytes(directory_.data() + 2 * num_blobs,
|
||||
padded_dir_bytes - 2 * num_blobs * kU128Bytes);
|
||||
|
||||
// Total file size will be the header plus all padded blobs.
|
||||
uint64_t payload = 0;
|
||||
// We already zero-initialized the directory padding;
|
||||
// `BlobWriter2::WriteAll` takes care of padding after each blob via an
|
||||
// additional I/O.
|
||||
for (size_t i = 0; i < num_blobs; ++i) {
|
||||
payload += hwy::RoundUpTo(blobs[i].size(), kBlobAlign);
|
||||
HWY_ASSERT(blobs[i].data() != nullptr);
|
||||
SetRange(i, offset, blobs[i].size());
|
||||
offset = RoundUpToAlign(offset + blobs[i].size());
|
||||
}
|
||||
const size_t total_size = padded_header_size + payload;
|
||||
|
||||
// Fill header.
|
||||
bs->magic_ = kMagic;
|
||||
bs->num_blobs_ = static_cast<uint32_t>(num_blobs);
|
||||
bs->file_size_ = total_size;
|
||||
hwy::CopyBytes(keys, bs->keys_, num_blobs * sizeof(keys[0]));
|
||||
|
||||
// First IO request is for the header (not yet filled!).
|
||||
std::vector<BlobIO> requests;
|
||||
requests.reserve(1 + 2 * num_blobs);
|
||||
requests.emplace_back(/*offset=*/0, padded_header_size,
|
||||
reinterpret_cast<uint8_t*>(bs), 0);
|
||||
|
||||
// Fill second half of keys_ with offset/size and prepare IO requests.
|
||||
uint64_t offset = padded_header_end;
|
||||
for (size_t i = 0; i < num_blobs; ++i) {
|
||||
bs->keys_[num_blobs + i].lo = offset;
|
||||
bs->keys_[num_blobs + i].hi = blobs[i].size();
|
||||
|
||||
EnqueueChunkRequests(offset, blobs[i].size(),
|
||||
const_cast<uint8_t*>(blobs[i].data()), requests);
|
||||
offset += blobs[i].size();
|
||||
const size_t padded_size = hwy::RoundUpTo(blobs[i].size(), kBlobAlign);
|
||||
if (padded_size != blobs[i].size()) {
|
||||
const size_t padding = padded_size - blobs[i].size();
|
||||
HWY_ASSERT(padding <= kBlobAlign);
|
||||
requests.emplace_back(offset, padding, zeros, 0);
|
||||
offset += padding;
|
||||
}
|
||||
}
|
||||
|
||||
HWY_ASSERT(offset == total_size);
|
||||
return requests;
|
||||
// When writing new files, we always pad to `kEndAlign`.
|
||||
HWY_ASSERT(hwy::RoundUpTo(offset, kEndAlign) == header_.file_bytes);
|
||||
}
|
||||
|
||||
bool FindKey(const hwy::uint128_t key, uint64_t& offset, size_t& size) const {
|
||||
for (size_t i = 0; i < num_blobs_; ++i) {
|
||||
if (keys_[i] == key) {
|
||||
const hwy::uint128_t val = keys_[num_blobs_ + i];
|
||||
offset = val.lo;
|
||||
size = val.hi;
|
||||
return true;
|
||||
}
|
||||
// Must be checked by readers before other methods.
|
||||
bool IsValid(const uint64_t file_size) const {
|
||||
// Ctor failed and already printed a warning.
|
||||
if (directory_.empty()) return false;
|
||||
|
||||
if (header_.magic != kMagic) {
|
||||
HWY_WARN("Given file is not a BlobStore (magic %08x).", header_.magic);
|
||||
return false;
|
||||
}
|
||||
return false;
|
||||
if (header_.num_blobs == 0) {
|
||||
HWY_WARN("Invalid BlobStore (empty), likely corrupt file.");
|
||||
return false;
|
||||
}
|
||||
if (header_.file_bytes != file_size) {
|
||||
HWY_WARN("File length %zu does not match header %zu (truncated?).",
|
||||
static_cast<size_t>(file_size),
|
||||
static_cast<size_t>(header_.file_bytes));
|
||||
return false;
|
||||
}
|
||||
|
||||
// Ensure blobs are back to back.
|
||||
uint64_t expected_offset = PaddedDirEnd(NumBlobs());
|
||||
for (size_t key_idx = 0; key_idx < NumBlobs(); ++key_idx) {
|
||||
uint64_t actual_offset;
|
||||
size_t bytes;
|
||||
GetRange(key_idx, actual_offset, bytes);
|
||||
if (expected_offset != actual_offset) {
|
||||
HWY_WARN("Invalid BlobStore: blob %zu at offset %zu but expected %zu.",
|
||||
key_idx, static_cast<size_t>(actual_offset),
|
||||
static_cast<size_t>(expected_offset));
|
||||
return false;
|
||||
}
|
||||
expected_offset = RoundUpToAlign(expected_offset + bytes);
|
||||
}
|
||||
// Previously files were not padded to `kEndAlign`, so also allow that.
|
||||
if (expected_offset != header_.file_bytes &&
|
||||
hwy::RoundUpTo(expected_offset, kEndAlign) != header_.file_bytes) {
|
||||
HWY_WARN("Invalid BlobStore: end of blobs %zu but file size %zu.",
|
||||
static_cast<size_t>(expected_offset),
|
||||
static_cast<size_t>(header_.file_bytes));
|
||||
return false;
|
||||
}
|
||||
|
||||
return true; // all OK
|
||||
}
|
||||
|
||||
void EnqueueWriteForHeaderAndDirectory(std::vector<BlobIO2>& writes) const {
|
||||
const size_t key_idx = 0; // not actually associated with a key/blob
|
||||
writes.emplace_back(
|
||||
BlobRange2{.offset = 0, .bytes = sizeof(header_), .key_idx = key_idx},
|
||||
// members are const and BlobIO2 requires non-const pointers, and they
|
||||
// are not modified by file writes.
|
||||
const_cast<Header*>(&header_));
|
||||
writes.emplace_back(
|
||||
BlobRange2{.offset = sizeof(header_),
|
||||
.bytes = PaddedDirEnd(NumBlobs()) - sizeof(header_),
|
||||
.key_idx = key_idx},
|
||||
const_cast<hwy::uint128_t*>(directory_.data()));
|
||||
}
|
||||
|
||||
size_t NumBlobs() const { return static_cast<size_t>(header_.num_blobs); }
|
||||
|
||||
// Not the entirety of `directory_`! The second half is offset/size.
|
||||
hwy::Span<const hwy::uint128_t> Keys() const {
|
||||
return hwy::Span<const hwy::uint128_t>(keys_, num_blobs_);
|
||||
return hwy::Span<const hwy::uint128_t>(directory_.data(), NumBlobs());
|
||||
}
|
||||
|
||||
// Retrieves blob's offset and size, not including padding.
|
||||
void GetRange(size_t key_idx, uint64_t& offset, size_t& bytes) const {
|
||||
HWY_ASSERT(key_idx < NumBlobs());
|
||||
const hwy::uint128_t val = directory_[NumBlobs() + key_idx];
|
||||
offset = val.lo;
|
||||
bytes = val.hi;
|
||||
HWY_ASSERT(offset % kBlobAlign == 0);
|
||||
HWY_ASSERT(bytes != 0);
|
||||
HWY_ASSERT(offset + bytes <= header_.file_bytes);
|
||||
}
|
||||
|
||||
private:
|
||||
uint32_t magic_;
|
||||
uint32_t num_blobs_; // never 0
|
||||
uint64_t file_size_; // must match actual size of file
|
||||
hwy::uint128_t keys_[1]; // length: 2 * num_blobs
|
||||
// Padding, then the blob identified by keys[0], then padding etc.
|
||||
};
|
||||
#pragma pack(pop)
|
||||
|
||||
BlobError BlobReader::Open(const Path& filename) {
|
||||
file_ = OpenFileOrNull(filename, "r");
|
||||
if (!file_) return __LINE__;
|
||||
|
||||
// Read first part of header to get actual size.
|
||||
BlobStore bs;
|
||||
if (!file_->Read(0, sizeof(bs), &bs)) return __LINE__;
|
||||
const size_t padded_size = bs.PaddedHeaderSize();
|
||||
HWY_ASSERT(padded_size >= sizeof(bs));
|
||||
|
||||
// Allocate full header.
|
||||
blob_store_ = BlobStore::Allocate(padded_size);
|
||||
if (!blob_store_) return __LINE__;
|
||||
|
||||
// Copy what we already read (more efficient than seek + re-read).
|
||||
hwy::CopySameSize(&bs, blob_store_.get());
|
||||
// Read the rest of the header, but not the full file.
|
||||
uint8_t* bytes = reinterpret_cast<uint8_t*>(blob_store_.get());
|
||||
if (!file_->Read(sizeof(bs), padded_size - sizeof(bs), bytes + sizeof(bs))) {
|
||||
return __LINE__;
|
||||
// Stores offset and range into u128 following the keys, so the directory
|
||||
// can be one array of the same type, and read/written together with keys.
|
||||
void SetRange(size_t key_idx, uint64_t offset, size_t bytes) {
|
||||
HWY_ASSERT(key_idx < NumBlobs());
|
||||
HWY_ASSERT(offset % kBlobAlign == 0);
|
||||
HWY_ASSERT(bytes != 0);
|
||||
HWY_ASSERT(offset + bytes <= header_.file_bytes);
|
||||
hwy::uint128_t& val = directory_[NumBlobs() + key_idx];
|
||||
val.lo = offset;
|
||||
val.hi = bytes;
|
||||
}
|
||||
|
||||
return blob_store_->CheckValidity(file_->FileSize());
|
||||
}
|
||||
Header header_;
|
||||
|
||||
size_t BlobReader::BlobSize(hwy::uint128_t key) const {
|
||||
uint64_t offset;
|
||||
size_t size;
|
||||
if (!blob_store_->FindKey(key, offset, size)) return 0;
|
||||
return size;
|
||||
}
|
||||
std::vector<hwy::uint128_t> directory_; // two per blob, see `SetRange`.
|
||||
}; // BlobStore
|
||||
|
||||
BlobError BlobReader::Enqueue(hwy::uint128_t key, void* data, size_t size) {
|
||||
uint64_t offset;
|
||||
size_t actual_size;
|
||||
if (!blob_store_->FindKey(key, offset, actual_size)) return __LINE__;
|
||||
if (actual_size != size) {
|
||||
fprintf(stderr,
|
||||
"Mismatch between expected %d and actual %d KiB size of blob %s. "
|
||||
"Please see README.md on how to update the weights.\n",
|
||||
static_cast<int>(size >> 10), static_cast<int>(actual_size >> 10),
|
||||
StringFromKey(key).c_str());
|
||||
return __LINE__;
|
||||
BlobReader2::BlobReader2(std::unique_ptr<File> file, uint64_t file_bytes,
|
||||
const BlobStore& bs, BlobReader2::Mode mode)
|
||||
: file_(std::move(file)), file_bytes_(file_bytes), mode_(mode) {
|
||||
HWY_ASSERT(file_ && file_bytes_ != 0);
|
||||
|
||||
keys_.reserve(bs.NumBlobs());
|
||||
for (const hwy::uint128_t key : bs.Keys()) {
|
||||
keys_.push_back(StringFromKey(key));
|
||||
}
|
||||
|
||||
EnqueueChunkRequests(offset, actual_size, reinterpret_cast<uint8_t*>(data),
|
||||
requests_);
|
||||
return 0;
|
||||
ranges_.reserve(bs.NumBlobs());
|
||||
// Populate hash map for O(1) lookups.
|
||||
for (size_t key_idx = 0; key_idx < keys_.size(); ++key_idx) {
|
||||
uint64_t offset;
|
||||
size_t bytes;
|
||||
bs.GetRange(key_idx, offset, bytes);
|
||||
ranges_.emplace_back(
|
||||
BlobRange2{.offset = offset, .bytes = bytes, .key_idx = key_idx});
|
||||
key_idx_for_key_[keys_[key_idx]] = key_idx;
|
||||
}
|
||||
|
||||
if (mode_ == Mode::kMap) {
|
||||
const Allocator2& allocator = ThreadingContext2::Get().allocator;
|
||||
// Verify `kEndAlign` is an upper bound on the page size.
|
||||
if (kEndAlign % allocator.BasePageBytes() != 0) {
|
||||
HWY_ABORT("Please raise an issue about kEndAlign %zu %% page size %zu.",
|
||||
kEndAlign, allocator.BasePageBytes());
|
||||
}
|
||||
if (file_bytes_ % allocator.BasePageBytes() == 0) {
|
||||
mapped_ = file_->Map();
|
||||
if (!mapped_) {
|
||||
HWY_WARN("Failed to map file (%zu KiB), reading instead.",
|
||||
static_cast<size_t>(file_bytes_ >> 10));
|
||||
mode_ = Mode::kRead; // Switch to kRead and continue.
|
||||
}
|
||||
} else {
|
||||
HWY_WARN("Unable to map non-padded file (%zu, %zu), reading instead.",
|
||||
static_cast<size_t>(file_bytes_ >> 10),
|
||||
allocator.BasePageBytes());
|
||||
mode_ = Mode::kRead; // Switch to kRead and continue.
|
||||
}
|
||||
}
|
||||
|
||||
if (mode_ == Mode::kRead) {
|
||||
// Potentially one per tensor row, so preallocate many.
|
||||
requests_.reserve(2 << 20);
|
||||
}
|
||||
}
|
||||
|
||||
void BlobReader2::Enqueue(const BlobRange2& range, void* data) {
|
||||
// Debug-only because there may be many I/O requests (per row).
|
||||
if constexpr (HWY_IS_DEBUG_BUILD) {
|
||||
HWY_DASSERT(!IsMapped());
|
||||
HWY_DASSERT(range.offset != 0 && range.bytes != 0 && data != nullptr);
|
||||
const BlobRange2& blob_range = Range(range.key_idx);
|
||||
HWY_DASSERT(blob_range.End() <= file_bytes_);
|
||||
if (range.End() > blob_range.End()) {
|
||||
HWY_ABORT(
|
||||
"Bug: want to read %zu bytes of %s until %zu, past blob end %zu.",
|
||||
range.bytes, keys_[range.key_idx].c_str(),
|
||||
static_cast<size_t>(range.End()),
|
||||
static_cast<size_t>(blob_range.End()));
|
||||
}
|
||||
}
|
||||
requests_.emplace_back(range, data);
|
||||
}
|
||||
|
||||
// Parallel synchronous I/O. Alternatives considered:
|
||||
// - readv is limited to 0x7FFFF000 bytes on Linux (even 64-bit). Note that
|
||||
// pread calls preadv with a single iovec.
|
||||
// TODO: use preadv for per-tensor batches of sysconf(_SC_IOV_MAX) / IOV_MAX.
|
||||
// - O_DIRECT seems undesirable because we do want to use the OS cache
|
||||
// between consecutive runs.
|
||||
// - memory-mapped I/O is less predictable and adds noise to measurements.
|
||||
BlobError BlobReader::ReadAll(hwy::ThreadPool& pool) {
|
||||
File* pfile = file_.get(); // not owned
|
||||
const auto& requests = requests_;
|
||||
std::atomic_flag err = ATOMIC_FLAG_INIT;
|
||||
void BlobReader2::ReadAll(hwy::ThreadPool& pool) const {
|
||||
PROFILER_ZONE("Startup.ReadAll");
|
||||
HWY_ASSERT(!IsMapped());
|
||||
// >5x speedup from parallel reads when cached.
|
||||
pool.Run(0, requests.size(),
|
||||
[pfile, &requests, &err](uint64_t i, size_t /*thread*/) {
|
||||
if (!pfile->Read(requests[i].offset, requests[i].size,
|
||||
requests[i].data)) {
|
||||
fprintf(stderr, "Failed to read blob %zu\n",
|
||||
static_cast<size_t>(i));
|
||||
err.test_and_set();
|
||||
}
|
||||
});
|
||||
if (err.test_and_set()) return __LINE__;
|
||||
return 0;
|
||||
pool.Run(0, requests_.size(), [this](uint64_t i, size_t /*thread*/) {
|
||||
const BlobRange2& range = requests_[i].range;
|
||||
const uint64_t end = range.End();
|
||||
const std::string& key = keys_[range.key_idx];
|
||||
const BlobRange2& blob_range = Range(range.key_idx);
|
||||
HWY_ASSERT(blob_range.End() <= file_bytes_);
|
||||
if (end > blob_range.End()) {
|
||||
HWY_ABORT(
|
||||
"Bug: want to read %zu bytes of %s until %zu, past blob end %zu.",
|
||||
range.bytes, key.c_str(), static_cast<size_t>(end),
|
||||
static_cast<size_t>(blob_range.End()));
|
||||
}
|
||||
if (!file_->Read(range.offset, range.bytes, requests_[i].data)) {
|
||||
HWY_ABORT("Read failed for %s from %zu, %zu bytes to %p.", key.c_str(),
|
||||
static_cast<size_t>(range.offset), range.bytes,
|
||||
requests_[i].data);
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
BlobError BlobReader::ReadOne(hwy::uint128_t key, void* data,
|
||||
size_t size) const {
|
||||
uint64_t offset;
|
||||
size_t actual_size;
|
||||
if (!blob_store_->FindKey(key, offset, actual_size)) return __LINE__;
|
||||
if (actual_size != size) {
|
||||
fprintf(stderr,
|
||||
"Mismatch between expected %d and actual %d KiB size of blob %s. "
|
||||
"Please see README.md on how to update the weights.\n",
|
||||
static_cast<int>(size >> 10), static_cast<int>(actual_size >> 10),
|
||||
StringFromKey(key).c_str());
|
||||
return __LINE__;
|
||||
// Decides whether to read or map the file.
|
||||
static BlobReader2::Mode ChooseMode(uint64_t file_mib, Tristate map) {
|
||||
const Allocator2& allocator = ThreadingContext2::Get().allocator;
|
||||
// User has explicitly requested a map or read via args.
|
||||
if (map == Tristate::kTrue) return BlobReader2::Mode::kMap;
|
||||
if (map == Tristate::kFalse) return BlobReader2::Mode::kRead;
|
||||
// Else: use heuristics to choose. Note that `FreeMiB` is generally low
|
||||
// because idle memory is used as cache, so do not use it to decide.
|
||||
const size_t total_mib = allocator.TotalMiB();
|
||||
if (file_mib > total_mib) {
|
||||
HWY_WARN("Weight file %zu MiB > detected memory %zu MiB.",
|
||||
static_cast<size_t>(file_mib), total_mib);
|
||||
}
|
||||
if (!file_->Read(offset, actual_size, data)) {
|
||||
return __LINE__;
|
||||
// Large fraction of total.
|
||||
if (file_mib >= total_mib / 3) return BlobReader2::Mode::kMap;
|
||||
// Big enough that even parallel loading wouldn't be quick.
|
||||
if (file_mib > 50 * 1024) return BlobReader2::Mode::kMap;
|
||||
return BlobReader2::Mode::kRead;
|
||||
}
|
||||
|
||||
std::unique_ptr<BlobReader2> BlobReader2::Make(const Path& blob_path,
|
||||
const Tristate map) {
|
||||
if (blob_path.Empty()) HWY_ABORT("No --weights specified.");
|
||||
std::unique_ptr<File> file = OpenFileOrNull(blob_path, "r");
|
||||
if (!file) HWY_ABORT("Failed to open file %s", blob_path.path.c_str());
|
||||
const uint64_t file_bytes = file->FileSize();
|
||||
if (file_bytes == 0) HWY_ABORT("Zero-sized file %s", blob_path.path.c_str());
|
||||
|
||||
// Even if `kMap`, read the directory via the `kRead` mode for simplicity.
|
||||
BlobStore bs(*file);
|
||||
if (!bs.IsValid(file_bytes)) {
|
||||
return std::unique_ptr<BlobReader2>(); // IsValid already printed a warning
|
||||
}
|
||||
return 0;
|
||||
|
||||
return std::unique_ptr<BlobReader2>(new BlobReader2(
|
||||
std::move(file), file_bytes, bs, ChooseMode(file_bytes >> 20, map)));
|
||||
}
|
||||
|
||||
hwy::Span<const hwy::uint128_t> BlobReader::Keys() const {
|
||||
return blob_store_->Keys();
|
||||
// Split into chunks for load-balancing even if blob sizes vary.
|
||||
static void EnqueueChunks(size_t key_idx, uint64_t offset, uint64_t bytes,
|
||||
uint8_t* data, std::vector<BlobIO2>& writes) {
|
||||
constexpr size_t kChunkBytes = 4 * 1024 * 1024;
|
||||
const uint64_t end = offset + bytes;
|
||||
// Split into whole chunks and possibly one remainder.
|
||||
if (end >= kChunkBytes) {
|
||||
for (; offset <= end - kChunkBytes;
|
||||
offset += kChunkBytes, data += kChunkBytes) {
|
||||
writes.emplace_back(
|
||||
BlobRange2{
|
||||
.offset = offset, .bytes = kChunkBytes, .key_idx = key_idx},
|
||||
data);
|
||||
}
|
||||
}
|
||||
if (offset != end) {
|
||||
writes.emplace_back(
|
||||
BlobRange2{.offset = offset, .bytes = end - offset, .key_idx = key_idx},
|
||||
data);
|
||||
}
|
||||
}
|
||||
|
||||
BlobError BlobWriter::WriteAll(hwy::ThreadPool& pool, const Path& filename) {
|
||||
HWY_ASSERT(keys_.size() == blobs_.size());
|
||||
static void EnqueueWritesForBlobs(const BlobStore& bs,
|
||||
const hwy::Span<const uint8_t> blobs[],
|
||||
std::vector<uint8_t>& zeros,
|
||||
std::vector<BlobIO2>& writes) {
|
||||
// All-zero buffer used to write padding to the file without copying the
|
||||
// input blobs.
|
||||
static constexpr uint8_t kZeros[kBlobAlign] = {0};
|
||||
|
||||
// Concatenate blobs in memory.
|
||||
const size_t header_size = BlobStore::HeaderSize(keys_.size());
|
||||
const size_t padded_header_size = hwy::RoundUpTo(header_size, kBlobAlign);
|
||||
const BlobStorePtr bs = BlobStore::Allocate(padded_header_size);
|
||||
const std::vector<BlobIO> requests = BlobStore::PrepareWriteRequests(
|
||||
keys_.data(), blobs_.data(), keys_.size(), bs.get());
|
||||
uint64_t file_end = 0; // for padding
|
||||
for (size_t key_idx = 0; key_idx < bs.NumBlobs(); ++key_idx) {
|
||||
// We know the size, but `BlobStore` tells us the offset to write each blob.
|
||||
uint64_t offset;
|
||||
size_t bytes;
|
||||
bs.GetRange(key_idx, offset, bytes);
|
||||
HWY_ASSERT(offset != 0);
|
||||
HWY_ASSERT(bytes == blobs[key_idx].size());
|
||||
const uint64_t new_file_end = offset + bytes;
|
||||
HWY_ASSERT(new_file_end >= file_end); // blobs are ordered by offset
|
||||
file_end = new_file_end;
|
||||
|
||||
EnqueueChunks(key_idx, offset, bytes,
|
||||
const_cast<uint8_t*>(blobs[key_idx].data()), writes);
|
||||
const size_t padding = BlobStore::RoundUpToAlign(bytes) - bytes;
|
||||
if (padding != 0) {
|
||||
HWY_ASSERT(padding <= kBlobAlign);
|
||||
writes.emplace_back(
|
||||
BlobRange2{
|
||||
.offset = offset + bytes, .bytes = padding, .key_idx = key_idx},
|
||||
const_cast<uint8_t*>(kZeros));
|
||||
}
|
||||
}
|
||||
|
||||
const size_t padding = hwy::RoundUpTo(file_end, kEndAlign) - file_end;
|
||||
if (padding != 0) {
|
||||
// Bigger than `kZeros`, better to allocate than issue multiple I/Os. Must
|
||||
// remain alive until the last I/O is done.
|
||||
zeros.resize(padding);
|
||||
writes.emplace_back(
|
||||
BlobRange2{.offset = file_end, .bytes = padding, .key_idx = 0},
|
||||
zeros.data());
|
||||
}
|
||||
}
|
||||
|
||||
void BlobWriter2::Add(const std::string& key, const void* data, size_t bytes) {
|
||||
HWY_ASSERT(data != nullptr);
|
||||
HWY_ASSERT(bytes != 0);
|
||||
keys_.push_back(KeyFromString(key.c_str()));
|
||||
blobs_.emplace_back(static_cast<const uint8_t*>(data), bytes);
|
||||
}
|
||||
|
||||
void BlobWriter2::WriteAll(hwy::ThreadPool& pool, const Path& filename) {
|
||||
const size_t num_blobs = keys_.size();
|
||||
HWY_ASSERT(num_blobs != 0);
|
||||
HWY_ASSERT(num_blobs == blobs_.size());
|
||||
|
||||
std::vector<BlobIO2> writes;
|
||||
writes.reserve(16384);
|
||||
|
||||
const BlobStore bs(num_blobs, keys_.data(), blobs_.data());
|
||||
bs.EnqueueWriteForHeaderAndDirectory(writes);
|
||||
|
||||
std::vector<uint8_t> zeros;
|
||||
EnqueueWritesForBlobs(bs, blobs_.data(), zeros, writes);
|
||||
|
||||
// Create/replace existing file.
|
||||
std::unique_ptr<File> file = OpenFileOrNull(filename, "w+");
|
||||
if (!file) return __LINE__;
|
||||
File* pfile = file.get(); // not owned
|
||||
if (!file) HWY_ABORT("Failed to open for writing %s", filename.path.c_str());
|
||||
|
||||
std::atomic_flag err = ATOMIC_FLAG_INIT;
|
||||
pool.Run(0, requests.size(),
|
||||
[pfile, &requests, &err](uint64_t i, size_t /*thread*/) {
|
||||
if (!pfile->Write(requests[i].data, requests[i].size,
|
||||
requests[i].offset)) {
|
||||
err.test_and_set();
|
||||
pool.Run(0, writes.size(),
|
||||
[this, &file, &writes](uint64_t i, size_t /*thread*/) {
|
||||
const BlobRange2& range = writes[i].range;
|
||||
|
||||
if (!file->Write(writes[i].data, range.bytes, range.offset)) {
|
||||
const std::string& key = StringFromKey(keys_[range.key_idx]);
|
||||
HWY_ABORT("Write failed for %s from %zu, %zu bytes to %p.",
|
||||
key.c_str(), static_cast<size_t>(range.offset),
|
||||
range.bytes, writes[i].data);
|
||||
}
|
||||
});
|
||||
if (err.test_and_set()) return __LINE__;
|
||||
return 0;
|
||||
}
|
||||
|
||||
} // namespace gcpp
|
||||
|
|
|
|||
|
|
@ -16,96 +16,160 @@
|
|||
#ifndef THIRD_PARTY_GEMMA_CPP_COMPRESSION_BLOB_STORE_H_
|
||||
#define THIRD_PARTY_GEMMA_CPP_COMPRESSION_BLOB_STORE_H_
|
||||
|
||||
// Reads/writes arrays of bytes from/to file.
|
||||
|
||||
#include <stddef.h>
|
||||
#include <stdint.h>
|
||||
|
||||
#include <memory>
|
||||
#include <memory> // std::unique_ptr
|
||||
#include <string>
|
||||
#include <unordered_map>
|
||||
#include <vector>
|
||||
|
||||
#include "compression/io.h"
|
||||
#include "hwy/aligned_allocator.h"
|
||||
#include "hwy/base.h" // hwy::uint128_t
|
||||
#include "compression/io.h" // File, Path, MapPtr
|
||||
#include "util/basics.h" // Tristate
|
||||
#include "hwy/aligned_allocator.h" // Span
|
||||
#include "hwy/base.h" // HWY_ASSERT
|
||||
#include "hwy/contrib/thread_pool/thread_pool.h"
|
||||
|
||||
namespace gcpp {
|
||||
|
||||
// Convenient way to construct a key from a string (<= 16 chars).
|
||||
hwy::uint128_t MakeKey(const char* string);
|
||||
// One blob's extents within the file.
|
||||
struct BlobRange2 {
|
||||
uint64_t End() const { return offset + bytes; }
|
||||
|
||||
// Returns a string from a key.
|
||||
std::string StringFromKey(hwy::uint128_t key);
|
||||
uint64_t offset = 0;
|
||||
size_t bytes = 0; // We check blobs are not zero-sized.
|
||||
// Index within `BlobReader2::Keys()` for error reporting.
|
||||
size_t key_idx;
|
||||
};
|
||||
|
||||
// A read or write I/O request, each serviced by one thread in a pool.
|
||||
struct BlobIO2 {
|
||||
BlobIO2(BlobRange2 range, void* data) : range(range), data(data) {}
|
||||
|
||||
BlobRange2 range;
|
||||
void* data; // Modified only if a read request. Read-only for writes.
|
||||
};
|
||||
|
||||
// Ordered list of opaque blobs (~hundreds), identified by unique opaque
|
||||
// 128-bit keys.
|
||||
class BlobStore;
|
||||
|
||||
// Incomplete type, so dtor will not be called.
|
||||
using BlobStorePtr = hwy::AlignedFreeUniquePtr<BlobStore>;
|
||||
|
||||
// 0 if successful, otherwise the line number of the failing check.
|
||||
using BlobError = int;
|
||||
|
||||
// Blob offsets on disk and memory addresses are a multiple of this, because
|
||||
// we pad the header and each blob's size. This matches CUDA alignment and the
|
||||
// maximum SVE vector size, and exceeds typical x86 cache line sizes (64 or
|
||||
// 128), which can help performance.
|
||||
static constexpr size_t kBlobAlign = 256;
|
||||
|
||||
// One I/O request, serviced by threads in a pool.
|
||||
struct BlobIO {
|
||||
BlobIO(uint64_t offset, size_t size, void* data, uint64_t padding)
|
||||
: offset(offset), size(size), data(data), padding(padding) {}
|
||||
|
||||
uint64_t offset;
|
||||
size_t size; // bytes
|
||||
void* data;
|
||||
uint64_t padding;
|
||||
};
|
||||
|
||||
class BlobReader {
|
||||
// Reads `BlobStore` header, converts keys to strings and creates a hash map for
|
||||
// faster lookups, and reads or maps blob data.
|
||||
// Thread-safe: it is safe to concurrently call all methods except `Enqueue`,
|
||||
// because they are const.
|
||||
// TODO(janwas): split into header and reader/mapper classes.
|
||||
class BlobReader2 {
|
||||
public:
|
||||
BlobReader() { requests_.reserve(500); }
|
||||
~BlobReader() = default;
|
||||
// Parallel I/O into allocated memory, or mapped view of file. The latter is
|
||||
// better when the file is huge, but page faults add noise to measurements.
|
||||
enum class Mode { kRead, kMap };
|
||||
|
||||
// Opens `filename` and reads its header.
|
||||
BlobError Open(const Path& filename);
|
||||
// Acquires ownership of `file` (which must be non-null) and reads its header.
|
||||
// Factory function instead of ctor because this can fail (return null).
|
||||
static std::unique_ptr<BlobReader2> Make(const Path& blob_path,
|
||||
Tristate map = Tristate::kDefault);
|
||||
|
||||
// Returns the size of the blob identified by `key`, or 0 if not found.
|
||||
size_t BlobSize(hwy::uint128_t key) const;
|
||||
~BlobReader2() = default;
|
||||
|
||||
// Enqueues read requests if `key` is found and its size matches `size`, which
|
||||
// is in units of bytes.
|
||||
BlobError Enqueue(hwy::uint128_t key, void* data, size_t size);
|
||||
// Returns true if the mode passed to ctor was `kMap` and mapping succeeded.
|
||||
bool IsMapped() const { return mode_ == Mode::kMap; }
|
||||
|
||||
// Reads all enqueued requests.
|
||||
BlobError ReadAll(hwy::ThreadPool& pool);
|
||||
const std::vector<std::string>& Keys() const { return keys_; }
|
||||
|
||||
// Reads one blob directly.
|
||||
BlobError ReadOne(hwy::uint128_t key, void* data, size_t size) const;
|
||||
|
||||
// Returns all available blob keys.
|
||||
hwy::Span<const hwy::uint128_t> Keys() const;
|
||||
|
||||
private:
|
||||
BlobStorePtr blob_store_; // holds header, not the entire file
|
||||
std::vector<BlobIO> requests_;
|
||||
std::unique_ptr<File> file_;
|
||||
};
|
||||
|
||||
class BlobWriter {
|
||||
public:
|
||||
// `size` is in bytes.
|
||||
void Add(hwy::uint128_t key, const void* data, size_t size) {
|
||||
keys_.push_back(key);
|
||||
blobs_.emplace_back(static_cast<const uint8_t*>(data), size);
|
||||
const BlobRange2& Range(size_t key_idx) const {
|
||||
HWY_ASSERT(key_idx < keys_.size());
|
||||
return ranges_[key_idx];
|
||||
}
|
||||
|
||||
// Stores all blobs to disk in the given order with padding for alignment.
|
||||
BlobError WriteAll(hwy::ThreadPool& pool, const Path& filename);
|
||||
// Returns nullptr if not found. O(1).
|
||||
const BlobRange2* Find(const std::string& key) const {
|
||||
auto it = key_idx_for_key_.find(key);
|
||||
if (it == key_idx_for_key_.end()) return nullptr;
|
||||
const BlobRange2& range = Range(it->second);
|
||||
HWY_ASSERT(range.offset != 0 && range.bytes != 0);
|
||||
HWY_ASSERT(range.End() <= file_bytes_);
|
||||
return ⦥
|
||||
}
|
||||
|
||||
// Returns the number of blobs added.
|
||||
size_t DebugNumBlobsAdded() const { return keys_.size(); }
|
||||
// Only if `IsMapped()`: returns blob as a read-only span of `T`. Note that
|
||||
// everything else except `CallWithSpan` is in units of bytes.
|
||||
template <typename T>
|
||||
hwy::Span<const T> MappedSpan(const BlobRange2& range) const {
|
||||
HWY_ASSERT(IsMapped());
|
||||
HWY_ASSERT(range.bytes % sizeof(T) == 0);
|
||||
return hwy::Span<const T>(
|
||||
HWY_RCAST_ALIGNED(const T*, mapped_.get() + range.offset),
|
||||
range.bytes / sizeof(T));
|
||||
}
|
||||
|
||||
// Returns error, or calls `func(span)` with the blob identified by `key`.
|
||||
// This may allocate memory for the blob, and is intended for small blobs for
|
||||
// which an aligned allocation is unnecessary.
|
||||
template <typename T, class Func>
|
||||
bool CallWithSpan(const std::string& key, const Func& func) const {
|
||||
const BlobRange2* range = Find(key);
|
||||
if (!range) {
|
||||
HWY_WARN("Blob %s not found, sizeof T=%zu", key.c_str(), sizeof(T));
|
||||
return false;
|
||||
}
|
||||
|
||||
if (mode_ == Mode::kMap) {
|
||||
func(MappedSpan<T>(*range));
|
||||
return true;
|
||||
}
|
||||
|
||||
HWY_ASSERT(range->bytes % sizeof(T) == 0);
|
||||
std::vector<T> storage(range->bytes / sizeof(T));
|
||||
if (!file_->Read(range->offset, range->bytes, storage.data())) {
|
||||
HWY_WARN("Read failed for blob %s from %zu, size %zu; file %zu\n",
|
||||
key.c_str(), static_cast<size_t>(range->offset), range->bytes,
|
||||
static_cast<size_t>(file_bytes_));
|
||||
return false;
|
||||
}
|
||||
func(hwy::Span<const T>(storage.data(), storage.size()));
|
||||
return true;
|
||||
}
|
||||
|
||||
// The following methods must only be called if `!IsMapped()`.
|
||||
|
||||
// Enqueues a BlobIO2 for `ReadAll` to execute.
|
||||
void Enqueue(const BlobRange2& range, void* data);
|
||||
|
||||
// Reads in parallel all enqueued requests to the specified destinations.
|
||||
// Aborts on error.
|
||||
void ReadAll(hwy::ThreadPool& pool) const;
|
||||
|
||||
private:
|
||||
// Only for use by `Make`.
|
||||
BlobReader2(std::unique_ptr<File> file, uint64_t file_bytes,
|
||||
const BlobStore& bs, Mode mode);
|
||||
|
||||
const std::unique_ptr<File> file_;
|
||||
const uint64_t file_bytes_;
|
||||
Mode mode_;
|
||||
|
||||
std::vector<std::string> keys_;
|
||||
std::vector<BlobRange2> ranges_;
|
||||
std::unordered_map<std::string, size_t> key_idx_for_key_;
|
||||
|
||||
MapPtr mapped_; // only if `kMap`
|
||||
std::vector<BlobIO2> requests_; // only if `kRead`
|
||||
};
|
||||
|
||||
// Collects references to blobs and writes them all at once with parallel I/O.
|
||||
// Thread-compatible: independent instances can be used concurrently, but it
|
||||
// does not make sense to call the methods concurrently.
|
||||
class BlobWriter2 {
|
||||
public:
|
||||
void Add(const std::string& key, const void* data, size_t bytes);
|
||||
|
||||
// For `ModelStore`: this is the `key_idx` of the next blob to be added.
|
||||
size_t NumAdded() const { return keys_.size(); }
|
||||
|
||||
// Stores all blobs to disk in the given order with padding for alignment.
|
||||
// Aborts on error.
|
||||
void WriteAll(hwy::ThreadPool& pool, const Path& filename);
|
||||
|
||||
private:
|
||||
std::vector<hwy::uint128_t> keys_;
|
||||
|
|
|
|||
|
|
@ -19,9 +19,13 @@
|
|||
|
||||
#include <algorithm>
|
||||
#include <array>
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
#include "compression/io.h"
|
||||
#include "hwy/contrib/thread_pool/thread_pool.h"
|
||||
#include "util/basics.h"
|
||||
#include "util/threading_context.h"
|
||||
#include "hwy/tests/hwy_gtest.h"
|
||||
#include "hwy/tests/test_util-inl.h" // HWY_ASSERT_EQ
|
||||
|
||||
|
|
@ -32,8 +36,9 @@ namespace {
|
|||
class BlobStoreTest : public testing::Test {};
|
||||
#endif
|
||||
|
||||
#if !HWY_OS_WIN
|
||||
TEST(BlobStoreTest, TestReadWrite) {
|
||||
void TestWithMapped(Tristate map) {
|
||||
hwy::ThreadPool& pool = ThreadingContext2::Get().pools.Pool();
|
||||
|
||||
static const std::array<float, 4> kOriginalData = {-1, 0, 3.14159, 2.71828};
|
||||
|
||||
// mkstemp will modify path_str so it holds a newly-created temporary file.
|
||||
@@ -41,44 +46,133 @@ TEST(BlobStoreTest, TestReadWrite) {
const int fd = mkstemp(path_str);
|
||||
HWY_ASSERT(fd > 0);
|
||||
|
||||
hwy::ThreadPool pool(4);
|
||||
const Path path(path_str);
|
||||
std::array<float, 4> buffer = kOriginalData;
|
||||
|
||||
const hwy::uint128_t keyA = MakeKey("0123456789abcdef");
|
||||
const hwy::uint128_t keyB = MakeKey("q");
|
||||
BlobWriter writer;
|
||||
const std::string keyA("0123456789abcdef"); // max 16 characters
|
||||
const std::string keyB("q");
|
||||
BlobWriter2 writer;
|
||||
writer.Add(keyA, "DATA", 5);
|
||||
writer.Add(keyB, buffer.data(), sizeof(buffer));
|
||||
HWY_ASSERT_EQ(writer.WriteAll(pool, path), 0);
|
||||
writer.WriteAll(pool, path);
|
||||
HWY_ASSERT_ARRAY_EQ(kOriginalData.data(), buffer.data(), buffer.size());
|
||||
|
||||
std::fill(buffer.begin(), buffer.end(), 0);
|
||||
BlobReader reader;
|
||||
HWY_ASSERT_EQ(reader.Open(path), 0);
|
||||
HWY_ASSERT_EQ(reader.BlobSize(keyA), 5);
|
||||
HWY_ASSERT_EQ(reader.BlobSize(keyB), sizeof(buffer));
|
||||
|
||||
HWY_ASSERT_EQ(reader.Enqueue(keyB, buffer.data(), sizeof(buffer)), 0);
|
||||
HWY_ASSERT_EQ(reader.ReadAll(pool), 0);
|
||||
HWY_ASSERT_ARRAY_EQ(kOriginalData.data(), buffer.data(), buffer.size());
|
||||
std::unique_ptr<BlobReader2> reader = BlobReader2::Make(path, map);
|
||||
HWY_ASSERT(reader);
|
||||
|
||||
{
|
||||
std::array<char, 5> buffer;
|
||||
HWY_ASSERT(reader.ReadOne(keyA, buffer.data(), 1) != 0);
|
||||
HWY_ASSERT_EQ(reader.ReadOne(keyA, buffer.data(), 5), 0);
|
||||
HWY_ASSERT_STRING_EQ("DATA", buffer.data());
|
||||
HWY_ASSERT_EQ(reader->Keys().size(), 2);
|
||||
HWY_ASSERT_STRING_EQ(reader->Keys()[0].c_str(), keyA.c_str());
|
||||
HWY_ASSERT_STRING_EQ(reader->Keys()[1].c_str(), keyB.c_str());
|
||||
|
||||
const BlobRange2* range = reader->Find(keyA);
|
||||
HWY_ASSERT(range);
|
||||
const uint64_t offsetA = range->offset;
|
||||
HWY_ASSERT_EQ(offsetA, 256); // kBlobAlign
|
||||
HWY_ASSERT_EQ(range->bytes, 5);
|
||||
range = reader->Find(keyB);
|
||||
HWY_ASSERT(range);
|
||||
const uint64_t offsetB = range->offset;
|
||||
HWY_ASSERT_EQ(offsetB, 2 * 256);
|
||||
HWY_ASSERT_EQ(range->bytes, sizeof(buffer));
|
||||
|
||||
if (!reader->IsMapped()) {
|
||||
char str[5];
|
||||
reader->Enqueue(
|
||||
BlobRange2{.offset = offsetA, .bytes = sizeof(str), .key_idx = 0}, str);
|
||||
reader->Enqueue(
|
||||
BlobRange2{.offset = offsetB, .bytes = sizeof(buffer), .key_idx = 1},
|
||||
buffer.data());
|
||||
reader->ReadAll(pool);
|
||||
HWY_ASSERT_STRING_EQ("DATA", str);
|
||||
HWY_ASSERT_ARRAY_EQ(kOriginalData.data(), buffer.data(), buffer.size());
|
||||
}
|
||||
|
||||
const hwy::Span<const hwy::uint128_t> keys = reader.Keys();
|
||||
HWY_ASSERT_EQ(keys.size(), 2);
|
||||
HWY_ASSERT_EQ(keys[0], keyA);
|
||||
HWY_ASSERT_EQ(keys[1], keyB);
|
||||
HWY_ASSERT(
|
||||
reader->CallWithSpan<char>(keyA, [](const hwy::Span<const char> span) {
|
||||
HWY_ASSERT_EQ(span.size(), 5);
|
||||
HWY_ASSERT_STRING_EQ("DATA", span.data());
|
||||
}));
|
||||
HWY_ASSERT(
|
||||
reader->CallWithSpan<float>(keyB, [](const hwy::Span<const float> span) {
|
||||
HWY_ASSERT_EQ(span.size(), 4);
|
||||
HWY_ASSERT_ARRAY_EQ(kOriginalData.data(), span.data(), span.size());
|
||||
}));
|
||||
|
||||
close(fd);
|
||||
unlink(path_str);
|
||||
}
|
||||
#endif
|
||||
|
||||
TEST(BlobStoreTest, TestReadWrite) {
|
||||
TestWithMapped(Tristate::kFalse);
|
||||
TestWithMapped(Tristate::kTrue);
|
||||
}
|
||||
|
||||
// Ensures padding works for any number of random-sized blobs.
|
||||
TEST(BlobStoreTest, TestNumBlobs) {
|
||||
hwy::ThreadPool& pool = ThreadingContext2::Get().pools.Pool();
|
||||
hwy::RandomState rng;
|
||||
|
||||
for (size_t num_blobs = 1; num_blobs <= 512; ++num_blobs) {
|
||||
// mkstemp will modify path_str so it holds a newly-created temporary file.
|
||||
char path_str[] = "/tmp/blob_store_test2.sbs-XXXXXX";
|
||||
const int fd = mkstemp(path_str);
|
||||
HWY_ASSERT(fd > 0);
|
||||
const Path path(path_str);
|
||||
|
||||
BlobWriter2 writer;
|
||||
std::vector<std::string> keys;
|
||||
keys.reserve(num_blobs);
|
||||
std::vector<std::vector<uint8_t>> blobs;
|
||||
blobs.reserve(num_blobs);
|
||||
for (size_t i = 0; i < num_blobs; ++i) {
|
||||
keys.push_back(std::to_string(i));
|
||||
// Smaller blobs when there are many, to speed up the test.
|
||||
const size_t mask = num_blobs > 1000 ? 1023 : 8191;
|
||||
// Never zero, but may be one byte, which we special-case.
|
||||
blobs.emplace_back((size_t{hwy::Random32(&rng)} & mask) + 1);
|
||||
std::vector<uint8_t>& blob = blobs.back();
|
||||
blob[0] = static_cast<uint8_t>(i & 255);
|
||||
if (blob.size() != 1) {
|
||||
blob.back() = static_cast<uint8_t>(i >> 8);
|
||||
}
|
||||
writer.Add(keys.back(), blob.data(), blob.size());
|
||||
}
|
||||
HWY_ASSERT(keys.size() == num_blobs);
|
||||
HWY_ASSERT(blobs.size() == num_blobs);
|
||||
writer.WriteAll(pool, path);
|
||||
|
||||
const Tristate map = Tristate::kFalse;
|
||||
std::unique_ptr<BlobReader2> reader = BlobReader2::Make(path, map);
|
||||
HWY_ASSERT(reader);
|
||||
HWY_ASSERT_EQ(reader->Keys().size(), num_blobs);
|
||||
pool.Run(0, num_blobs, [&](uint64_t i, size_t /*thread*/) {
|
||||
HWY_ASSERT_STRING_EQ(reader->Keys()[i].c_str(),
|
||||
std::to_string(i).c_str());
|
||||
const BlobRange2* range = reader->Find(keys[i]);
|
||||
HWY_ASSERT(range);
|
||||
HWY_ASSERT_EQ(blobs[i].size(), range->bytes);
|
||||
HWY_ASSERT(reader->CallWithSpan<uint8_t>(
|
||||
keys[i], [path_str, num_blobs, i, range,
|
||||
&blobs](const hwy::Span<const uint8_t> span) {
|
||||
HWY_ASSERT_EQ(blobs[i].size(), span.size());
|
||||
const bool match1 = span[0] == static_cast<uint8_t>(i & 255);
|
||||
// If size == 1, we don't have a second byte to check.
|
||||
const bool match2 =
|
||||
span.size() == 1 ||
|
||||
span[span.size() - 1] == static_cast<uint8_t>(i >> 8);
|
||||
if (!match1 || !match2) {
|
||||
HWY_ABORT("%s num_blobs %zu blob %zu offset %zu is corrupted.",
|
||||
path_str, num_blobs, i, range->offset);
|
||||
}
|
||||
}));
|
||||
});
|
||||
|
||||
close(fd);
|
||||
unlink(path_str);
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace
|
||||
} // namespace gcpp
|
||||
@@ -24,11 +24,8 @@
#include <memory>
|
||||
#include <vector>
|
||||
|
||||
#include "compression/blob_store.h"
|
||||
#include "compression/compress.h" // IWYU pragma: export
|
||||
#include "compression/distortion.h"
|
||||
#include "gemma/configs.h"
|
||||
#include "util/mat.h"
|
||||
#include "hwy/aligned_allocator.h"
|
||||
#include "hwy/base.h"
|
||||
#include "hwy/contrib/thread_pool/thread_pool.h"
|
||||
@@ -520,17 +517,6 @@ HWY_NOINLINE void Compress(const float* HWY_RESTRICT raw, size_t num,
}
|
||||
}
|
||||
|
||||
// Adapter that compresses into `MatStorageT`. `raw` must already be scaled
|
||||
// to fit the value range, if `Packed` is `SfpStream`.
|
||||
template <typename Packed>
|
||||
HWY_INLINE void CompressScaled(const float* HWY_RESTRICT raw, size_t num,
|
||||
CompressWorkingSet& work,
|
||||
MatStorageT<Packed>& compressed,
|
||||
hwy::ThreadPool& pool) {
|
||||
Compress(raw, num, work, compressed.Span(),
|
||||
/*packed_ofs=*/0, pool);
|
||||
}
|
||||
|
||||
// Stores two f32 vectors to f32 or bf16; avoids duplicating RMSNorm and
|
||||
// RMSNormInplace for the two output types.
|
||||
template <class DF, typename Packed, HWY_IF_F32_D(DF), class VF = hn::Vec<DF>>
|
||||
@@ -712,49 +698,6 @@ HWY_INLINE float DecompressAndCall(D, const PackedSpan<const VT> v,
comp3);
|
||||
}
|
||||
|
||||
// Functor called for each tensor, which compresses and stores them along with
|
||||
// their scaling factors to BlobStore.
|
||||
class Compressor {
|
||||
public:
|
||||
explicit Compressor(hwy::ThreadPool& pool) : writer_(pool) {}
|
||||
|
||||
template <typename Packed>
|
||||
void operator()(MatPtrT<Packed>* compressed, const char* decorated_name,
|
||||
const float* HWY_RESTRICT weights) {
|
||||
size_t num_weights = compressed->Extents().Area();
|
||||
if (num_weights == 0 || weights == nullptr || !compressed->HasPtr()) return;
|
||||
PackedSpan<Packed> packed = compressed->Span();
|
||||
fprintf(stderr, "Compressing %s (%zuM), please wait\n", decorated_name,
|
||||
num_weights / (1000 * 1000));
|
||||
Compress(weights, num_weights, work_, packed, /*packed_ofs=*/0,
|
||||
writer_.pool());
|
||||
writer_(compressed, decorated_name);
|
||||
}
|
||||
|
||||
void AddTokenizer(const std::string& tokenizer) {
|
||||
writer_.AddTokenizer(tokenizer);
|
||||
}
|
||||
|
||||
void AddScales(const float* scales, size_t len) {
|
||||
writer_.AddScales(scales, len);
|
||||
}
|
||||
|
||||
// Writes all blobs to disk in the given order. The config is optional and
|
||||
// if given, it is written to the file, along with the TOC, making it
|
||||
// single-file format. Otherwise, the file is written in the multi-file format
|
||||
// without a TOC.
|
||||
BlobError WriteAll(const Path& blob_filename, const ModelConfig* config) {
|
||||
return writer_.WriteAll(blob_filename, config);
|
||||
}
|
||||
|
||||
// Returns the number of blobs added.
|
||||
size_t DebugNumBlobsAdded() const { return writer_.DebugNumBlobsAdded(); }
|
||||
|
||||
private:
|
||||
CompressWorkingSet work_;
|
||||
WriteToBlobStore writer_;
|
||||
};
|
||||
|
||||
// NOLINTNEXTLINE(google-readability-namespace-comments)
|
||||
} // namespace HWY_NAMESPACE
|
||||
} // namespace gcpp
|
||||
@@ -21,21 +21,15 @@
|
||||
#include <stddef.h>
|
||||
#include <stdint.h>
|
||||
|
||||
#if COMPRESS_STATS
|
||||
#include <stdio.h>
|
||||
#endif
|
||||
|
||||
#include <memory>
|
||||
#include <vector>
|
||||
|
||||
#include "compression/blob_store.h"
|
||||
#include "compression/fields.h"
|
||||
#include "compression/io.h"
|
||||
#include "compression/shared.h" // NuqStream::ClusterBuf
|
||||
#include "util/basics.h"
|
||||
// IWYU pragma: end_exports
|
||||
#include "gemma/configs.h"
|
||||
#include "util/allocator.h"
|
||||
#include "util/mat.h"
|
||||
#include "hwy/contrib/thread_pool/thread_pool.h"
|
||||
#include "compression/shared.h" // IWYU pragma: export
|
||||
#if COMPRESS_STATS
|
||||
#include "compression/distortion.h"
|
||||
#include "hwy/stats.h"
|
||||
@@ -43,72 +37,6 @@
|
||||
namespace gcpp {
|
||||
|
||||
// Table of contents for a blob store file. Full metadata, but not actual data.
|
||||
class BlobToc {
|
||||
public:
|
||||
BlobToc() = default;
|
||||
|
||||
// Loads the table of contents from the given reader.
|
||||
BlobError LoadToc(BlobReader& reader) {
|
||||
hwy::uint128_t toc_key = MakeKey(kTocName);
|
||||
size_t toc_size = reader.BlobSize(toc_key);
|
||||
if (toc_size != 0) {
|
||||
std::vector<uint32_t> toc(toc_size / sizeof(uint32_t));
|
||||
BlobError err = reader.ReadOne(toc_key, toc.data(), toc_size);
|
||||
if (err != 0) {
|
||||
fprintf(stderr, "Failed to read toc (error %d)\n", err);
|
||||
return err;
|
||||
}
|
||||
size_t consumed = 0;
|
||||
size_t prev_consumed = static_cast<size_t>(-1);
|
||||
while (consumed < toc.size() && prev_consumed != consumed) {
|
||||
MatPtr blob;
|
||||
const IFields::ReadResult result =
|
||||
blob.Read(hwy::Span<const uint32_t>(toc), consumed);
|
||||
prev_consumed = consumed;
|
||||
consumed = result.pos;
|
||||
if (!blob.IsEmpty()) {
|
||||
AddToToc(blob);
|
||||
}
|
||||
}
|
||||
}
|
||||
return 0;
|
||||
}
|
||||
|
||||
bool Empty() const { return toc_map_.empty(); }
|
||||
|
||||
// Returns true if the table of contents contains the given name.
|
||||
bool Contains(const std::string& name) const {
|
||||
return toc_map_.find(name) != toc_map_.end();
|
||||
}
|
||||
|
||||
// Returns the blob with the given name, or nullptr if not found.
|
||||
const MatPtr* Get(const std::string& name) const {
|
||||
auto it = toc_map_.find(name);
|
||||
if (it == toc_map_.end()) return nullptr;
|
||||
return &toc_[it->second];
|
||||
}
|
||||
// The name of the toc in the blob store file.
|
||||
static constexpr char kTocName[] = "toc";
|
||||
|
||||
// The name of the config in the blob store file.
|
||||
static constexpr char kConfigName[] = "config";
|
||||
|
||||
// The name of the tokenizer in the blob store file.
|
||||
static constexpr char kTokenizerName[] = "tokenizer";
|
||||
|
||||
private:
|
||||
// Adds the blob to the table of contents.
|
||||
void AddToToc(const MatPtr& blob) {
|
||||
HWY_ASSERT(!Contains(blob.Name()));
|
||||
toc_map_[blob.Name()] = toc_.size();
|
||||
toc_.push_back(blob);
|
||||
}
|
||||
|
||||
std::unordered_map<std::string, size_t> toc_map_;
|
||||
std::vector<MatPtr> toc_;
|
||||
};
|
||||
|
||||
#if COMPRESS_STATS
|
||||
class CompressStats {
|
||||
public:
|
||||
@@ -176,199 +104,6 @@ struct CompressWorkingSet {
std::vector<CompressPerThread> tls;
|
||||
};
|
||||
|
||||
// Class to collect and write a set of tensors to a blob store file.
|
||||
class WriteToBlobStore {
|
||||
public:
|
||||
explicit WriteToBlobStore(hwy::ThreadPool& pool) : pool_(pool) {}
|
||||
|
||||
template <typename Packed>
|
||||
void operator()(MatPtrT<Packed>* compressed,
|
||||
const char* decorated_name) const {
|
||||
if (!compressed->HasPtr()) return;
|
||||
writer_.Add(MakeKey(decorated_name), compressed->Packed(),
|
||||
compressed->PackedBytes());
|
||||
MatPtr renamed_tensor(*compressed);
|
||||
renamed_tensor.SetName(decorated_name);
|
||||
renamed_tensor.AppendTo(toc_);
|
||||
}
|
||||
|
||||
void AddTokenizer(const std::string& tokenizer) {
|
||||
writer_.Add(MakeKey(BlobToc::kTokenizerName), tokenizer.data(),
|
||||
tokenizer.size() * sizeof(tokenizer[0]));
|
||||
}
|
||||
|
||||
void AddScales(const float* scales, size_t len) {
|
||||
if (len) {
|
||||
MatPtrT<float> scales_ptr("scales", Extents2D(0, 1));
|
||||
writer_.Add(MakeKey(scales_ptr.Name()), scales, len * sizeof(scales[0]));
|
||||
}
|
||||
}
|
||||
|
||||
// Writes all blobs to disk in the given order. The config is optional and
|
||||
// if given, it is written to the file, along with the TOC, making it
|
||||
// single-file format. Otherwise, the file is written in the multi-file format
|
||||
// without a TOC.
|
||||
BlobError WriteAll(const Path& blob_filename, const ModelConfig* config) {
|
||||
if (config) {
|
||||
writer_.Add(MakeKey(BlobToc::kTocName), toc_.data(),
|
||||
toc_.size() * sizeof(toc_[0]));
|
||||
config_buffer_ = config->Write();
|
||||
writer_.Add(MakeKey(BlobToc::kConfigName), config_buffer_.data(),
|
||||
config_buffer_.size() * sizeof(config_buffer_[0]));
|
||||
}
|
||||
const BlobError err = writer_.WriteAll(pool_, blob_filename);
|
||||
if (err != 0) {
|
||||
fprintf(stderr, "Failed to write blobs to %s (error %d)\n",
|
||||
blob_filename.path.c_str(), err);
|
||||
}
|
||||
return err;
|
||||
}
|
||||
|
||||
// Returns the number of blobs added.
|
||||
size_t DebugNumBlobsAdded() const { return writer_.DebugNumBlobsAdded(); }
|
||||
|
||||
hwy::ThreadPool& pool() { return pool_; }
|
||||
|
||||
protected:
|
||||
hwy::ThreadPool& pool_;
|
||||
|
||||
private:
|
||||
mutable std::vector<uint32_t> toc_;
|
||||
mutable BlobWriter writer_;
|
||||
mutable std::vector<uint32_t> config_buffer_;
|
||||
};
|
||||
|
||||
// Functor called for each tensor, which loads them and their scaling factors
|
||||
// from BlobStore.
|
||||
class ReadFromBlobStore {
|
||||
public:
|
||||
explicit ReadFromBlobStore(const Path& blob_filename) {
|
||||
err_ = reader_.Open(blob_filename);
|
||||
if (HWY_UNLIKELY(err_ != 0)) {
|
||||
fprintf(stderr, "Error %d opening BlobStore %s.\n", err_,
|
||||
blob_filename.path.c_str());
|
||||
return; // avoid overwriting err_ to ensure ReadAll will fail.
|
||||
}
|
||||
err_ = file_toc_.LoadToc(reader_);
|
||||
if (HWY_UNLIKELY(err_ != 0)) {
|
||||
fprintf(stderr, "Found a TOC, but failed to load it (code %d)\n", err_);
|
||||
}
|
||||
}
|
||||
|
||||
// Returns true if there is a TOC.
|
||||
bool HaveToc() const { return !file_toc_.Empty(); }
|
||||
|
||||
// Reads the config from the blob store file.
|
||||
BlobError LoadConfig(ModelConfig& config) {
|
||||
hwy::uint128_t config_key = MakeKey(BlobToc::kConfigName);
|
||||
size_t config_size = reader_.BlobSize(config_key);
|
||||
if (config_size == 0) return __LINE__;
|
||||
std::vector<uint32_t> config_buffer(config_size / sizeof(uint32_t));
|
||||
BlobError err =
|
||||
reader_.ReadOne(config_key, config_buffer.data(), config_size);
|
||||
if (err != 0) {
|
||||
fprintf(stderr, "Failed to read config (error %d)\n", err);
|
||||
return err;
|
||||
}
|
||||
config.Read(hwy::Span<const uint32_t>(config_buffer), 0);
|
||||
return 0;
|
||||
}
|
||||
|
||||
// Reads the tokenizer from the blob store file.
|
||||
BlobError LoadTokenizer(std::string& tokenizer) {
|
||||
hwy::uint128_t key = MakeKey(BlobToc::kTokenizerName);
|
||||
size_t tokenizer_size = reader_.BlobSize(key);
|
||||
if (tokenizer_size == 0) return __LINE__;
|
||||
tokenizer.resize(tokenizer_size);
|
||||
;
|
||||
BlobError err = reader_.ReadOne(key, tokenizer.data(), tokenizer_size);
|
||||
if (err != 0) {
|
||||
fprintf(stderr, "Failed to read tokenizer (error %d)\n", err);
|
||||
return err;
|
||||
}
|
||||
return 0;
|
||||
}
|
||||
|
||||
// Called for each tensor, enqueues read requests.
|
||||
void operator()(const char* name, hwy::Span<MatPtr*> tensors) {
|
||||
if (file_toc_.Empty() || file_toc_.Contains(name)) {
|
||||
HWY_ASSERT(tensors[0]);
|
||||
model_toc_.push_back(tensors[0]);
|
||||
file_keys_.push_back(name);
|
||||
}
|
||||
}
|
||||
|
||||
BlobError LoadScales(float* scales, size_t len) {
|
||||
for (size_t i = 0; i < len; ++i) {
|
||||
scales[i] = 1.0f;
|
||||
}
|
||||
MatPtrT<float> scales_ptr("scales", Extents2D(0, 1));
|
||||
auto key = MakeKey(scales_ptr.Name());
|
||||
if (reader_.BlobSize(key) == 0) return 0;
|
||||
return reader_.Enqueue(key, scales, len * sizeof(scales[0]));
|
||||
}
|
||||
|
||||
// Returns whether all tensors are successfully loaded from cache.
|
||||
BlobError ReadAll(hwy::ThreadPool& pool,
|
||||
std::vector<MatOwner>& model_memory) {
|
||||
// reader_ invalid or any Enqueue failed
|
||||
if (err_ != 0) return err_;
|
||||
// Setup the model_memory.
|
||||
for (size_t b = 0; b < model_toc_.size(); ++b) {
|
||||
const std::string& file_key = file_keys_[b];
|
||||
MatPtr* blob = model_toc_[b];
|
||||
if (!file_toc_.Empty()) {
|
||||
const MatPtr* toc_blob = file_toc_.Get(file_key);
|
||||
if (toc_blob == nullptr) {
|
||||
fprintf(stderr, "Blob %s not found in TOC\n", file_key.c_str());
|
||||
return __LINE__;
|
||||
}
|
||||
if (toc_blob->Rows() != blob->Rows() ||
|
||||
toc_blob->Cols() != blob->Cols()) {
|
||||
fprintf(stderr, "Blob %s has size mismatch TOC\n", file_key.c_str());
|
||||
return __LINE__;
|
||||
}
|
||||
std::string name = blob->Name();
|
||||
*blob = *toc_blob;
|
||||
blob->SetName(name.c_str());
|
||||
}
|
||||
model_memory.push_back(MatOwner());
|
||||
}
|
||||
// Allocate in parallel using the pool.
|
||||
pool.Run(0, model_memory.size(),
|
||||
[this, &model_memory](uint64_t task, size_t /*thread*/) {
|
||||
model_memory[task].AllocateFor(*model_toc_[task],
|
||||
MatPadding::kPacked);
|
||||
});
|
||||
// Enqueue the read requests.
|
||||
for (size_t b = 0; b < model_toc_.size(); ++b) {
|
||||
err_ = reader_.Enqueue(MakeKey(file_keys_[b].c_str()),
|
||||
model_toc_[b]->RowT<uint8_t>(0),
|
||||
model_toc_[b]->PackedBytes());
|
||||
if (err_ != 0) {
|
||||
fprintf(
|
||||
stderr,
|
||||
"Failed to read blob %s (error %d) of size %zu x %zu, type %d\n",
|
||||
file_keys_[b].c_str(), err_, model_toc_[b]->Rows(),
|
||||
model_toc_[b]->Cols(), static_cast<int>(model_toc_[b]->GetType()));
|
||||
return err_;
|
||||
}
|
||||
}
|
||||
return reader_.ReadAll(pool);
|
||||
}
|
||||
|
||||
private:
|
||||
BlobReader reader_;
|
||||
BlobError err_ = 0;
|
||||
// Table of contents from the file, if present.
|
||||
BlobToc file_toc_;
|
||||
// Table of contents from the model. Pointers to original MatPtrT so the
|
||||
// data pointers can be updated.
|
||||
std::vector<MatPtr*> model_toc_;
|
||||
// Mangled names of the tensors in model_toc_ for reading from the file.
|
||||
std::vector<std::string> file_keys_;
|
||||
};
|
||||
|
||||
// Returns 1.0f if all magnitudes are <= `SfpStream::kMax`, otherwise scales
|
||||
// them such that the largest magnitude is `SfpStream::kMax`, and returns the
|
||||
// multiplier with which to restore the original values. This is only necessary
|
||||
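// Sketch of the behavior described above, for illustration only (the actual
// definition is not shown in this hunk). SfpStream::kMax is assumed to be
// 1.875, the constant implied by compression_test.py's expected scale of
// 4.001 / 1.875 for "tensor0":
//   max_abs = max_i |w[i]|
//   if (max_abs <= SfpStream::kMax) return 1.0f;    // no rescaling needed
//   for each i: w[i] *= SfpStream::kMax / max_abs;  // fit into the SFP range
//   return max_abs / SfpStream::kMax;               // multiplier to restore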
@@ -80,7 +80,7 @@ struct TestDecompress2T {
stats.Notify(raw[i], hwy::ConvertScalarTo<float>(dec[i]));
|
||||
}
|
||||
|
||||
if constexpr (false) {
|
||||
if constexpr (true) { // leave enabled due to sporadic failures
|
||||
fprintf(stderr,
|
||||
"TypeName<Packed>() %s TypeName<T>() %s: num %zu: stats.SumL1() "
|
||||
"%f stats.GeomeanValueDivL1() %f stats.WeightedAverageL1() %f "
|
||||
@@ -1,209 +0,0 @@
# Copyright 2024 Google LLC
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# https://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Converts pytorch to f32 for use by compress_weights.cc."""
|
||||
|
||||
import argparse
|
||||
import collections
|
||||
import os
|
||||
from gemma import config
|
||||
from gemma import model as gemma_model
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
# Requires torch 2.2 and gemma package from
|
||||
# https://github.com/google/gemma_pytorch
|
||||
|
||||
|
||||
def check_file_exists(value):
|
||||
if not os.path.exists(str(value)):
|
||||
raise argparse.ArgumentTypeError(
|
||||
"The file %s does not appear to exist." % value
|
||||
)
|
||||
return value
|
||||
|
||||
|
||||
def check_model_types(value):
|
||||
if str(value).lower() not in ["2b", "7b"]:
|
||||
raise argparse.ArgumentTypeError(
|
||||
"Model type value %s is not in [2b, 7b]." % value
|
||||
)
|
||||
return value
|
||||
|
||||
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument(
|
||||
"--tokenizer",
|
||||
dest="tokenizer",
|
||||
default="models/tokenizer.spm",
|
||||
help="Location of tokenizer file (.model or .spm)",
|
||||
type=check_file_exists,
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--weights",
|
||||
dest="weights",
|
||||
default="models/gemma-2b-it.ckpt",
|
||||
help="Location of input checkpoint file (.ckpt)",
|
||||
type=check_file_exists,
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--output_file",
|
||||
dest="output_file",
|
||||
default="2bit-f32.sbs",
|
||||
help="Location to write converted weights",
|
||||
type=str,
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--model_type",
|
||||
dest="model_type",
|
||||
default="2b",
|
||||
help="Model size / type (2b, 7b)",
|
||||
type=check_model_types,
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
|
||||
TRANSFORMATIONS = {
|
||||
"2b": collections.defaultdict(
|
||||
lambda: lambda x: x,
|
||||
{
|
||||
"embedder.weight": lambda x: x,
|
||||
"self_attn.qkv_proj.weight": lambda x: x.reshape((10, 256, 2048)),
|
||||
"self_attn.o_proj.weight": lambda x: x.reshape(
|
||||
(2048, 8, 256)
|
||||
).transpose([1, 0, 2]),
|
||||
"mlp.gate_proj.weight": lambda x: x[np.newaxis, :, :],
|
||||
"mlp.up_proj.weight": lambda x: x[np.newaxis, :, :],
|
||||
"mlp.down_proj.weight": lambda x: x,
|
||||
},
|
||||
),
|
||||
"7b": collections.defaultdict(
|
||||
lambda: lambda x: x,
|
||||
{
|
||||
"embedder.weight": lambda x: x,
|
||||
"self_attn.qkv_proj.weight": lambda x: x.reshape(
|
||||
(3, 16, 256, 3072)
|
||||
).transpose([1, 0, 2, 3]),
|
||||
"self_attn.o_proj.weight": lambda x: x.reshape(
|
||||
(3072, 16, 256)
|
||||
).transpose([1, 0, 2]),
|
||||
"mlp.gate_proj.weight": lambda x: x[np.newaxis, :, :],
|
||||
"mlp.up_proj.weight": lambda x: x[np.newaxis, :, :],
|
||||
"mlp.down_proj.weight": lambda x: x,
|
||||
},
|
||||
),
|
||||
}
|
||||
|
||||
VALIDATIONS = {
|
||||
"2b": {
|
||||
"embedder.weight": lambda x: x.shape == (256000, 2048),
|
||||
"model.norm.weight": lambda x: x.shape == (2048,),
|
||||
"self_attn.qkv_proj.weight": lambda x: x.shape == (10, 256, 2048),
|
||||
"self_attn.o_proj.weight": lambda x: x.shape == (8, 2048, 256),
|
||||
"mlp.gate_proj.weight": lambda x: x.shape == (1, 16384, 2048),
|
||||
"mlp.up_proj.weight": lambda x: x.shape == (1, 16384, 2048),
|
||||
"mlp.down_proj.weight": lambda x: x.shape == (2048, 16384),
|
||||
"input_layernorm.weight": lambda x: x.shape == (2048,),
|
||||
"post_attention_layernorm.weight": lambda x: x.shape == (2048,),
|
||||
},
|
||||
"7b": {
|
||||
"embedder.weight": lambda x: x.shape == (256000, 3072),
|
||||
"model.norm.weight": lambda x: x.shape == (3072,),
|
||||
"self_attn.qkv_proj.weight": lambda x: x.shape == (16, 3, 256, 3072),
|
||||
"self_attn.o_proj.weight": lambda x: x.shape == (16, 3072, 256),
|
||||
"mlp.gate_proj.weight": lambda x: x.shape == (1, 24576, 3072),
|
||||
"mlp.up_proj.weight": lambda x: x.shape == (1, 24576, 3072),
|
||||
"mlp.down_proj.weight": lambda x: x.shape == (3072, 24576),
|
||||
"input_layernorm.weight": lambda x: x.shape == (3072,),
|
||||
"post_attention_layernorm.weight": lambda x: x.shape == (3072,),
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
def param_names(num_hidden_layers: int):
|
||||
"""Return parameter names in the order they are expected for deserialization."""
|
||||
|
||||
# note *weight_scaler params are ignored in the forward computation unless
|
||||
# quantization is being used.
|
||||
#
|
||||
# since we are working with the full precision weights as input, don't
|
||||
# include these in the parameters being iterated over.
|
||||
|
||||
names = [
|
||||
("embedder.weight",) * 2, # embedder_input_embedding
|
||||
("model.norm.weight",) * 2, # final_norm_scale
|
||||
]
|
||||
layer_params = [
|
||||
"self_attn.o_proj.weight", # attn_vec_einsum_w
|
||||
"self_attn.qkv_proj.weight", # qkv_einsum_w
|
||||
"mlp.gate_proj.weight", # gating_einsum_w
|
||||
"mlp.up_proj.weight",
|
||||
"mlp.down_proj.weight", # linear_w
|
||||
"input_layernorm.weight", # pre_attention_norm_scale
|
||||
"post_attention_layernorm.weight", # pre_ffw_norm_scale
|
||||
]
|
||||
for layer in range(num_hidden_layers):
|
||||
for layer_param in layer_params:
|
||||
names = names + [(f"model.layers.{layer}.{layer_param}", layer_param)]
|
||||
return names
|
||||
|
||||
|
||||
def convert_weights():
|
||||
"""Main function; loads weights, runs transformations, writes f32."""
|
||||
model_type = args.model_type
|
||||
output_file = args.output_file
|
||||
|
||||
model_config = config.get_model_config(model_type)
|
||||
model_config.dtype = "float32"
|
||||
model_config.tokenizer = args.tokenizer
|
||||
device = torch.device("cpu")
|
||||
torch.set_default_dtype(torch.float)
|
||||
model = gemma_model.GemmaForCausalLM(model_config)
|
||||
|
||||
model.load_weights(args.weights)
|
||||
model.to(device).eval()
|
||||
|
||||
model_dict = dict(model.named_parameters())
|
||||
param_order = param_names(model_config.num_hidden_layers)
|
||||
|
||||
all_ok = True
|
||||
print("Checking transformations ...")
|
||||
for name, layer_name in param_order:
|
||||
arr = model_dict[name].detach().numpy()
|
||||
arr = TRANSFORMATIONS[model_type][layer_name](arr)
|
||||
check = "OK" if VALIDATIONS[model_type][layer_name](arr) else "FAILED"
|
||||
|
||||
if check == "FAILED":
|
||||
all_ok = False
|
||||
print(f" {name : <60}{str(arr.shape) : <20}{check}")
|
||||
|
||||
if all_ok:
|
||||
print("Writing parameters ...")
|
||||
with open(output_file, "wb") as bin_handle:
|
||||
for name, layer_name in param_order:
|
||||
arr = model_dict[name].detach().numpy()
|
||||
arr = TRANSFORMATIONS[model_type][layer_name](arr)
|
||||
check = "OK" if VALIDATIONS[model_type][layer_name](arr) else "FAILED"
|
||||
print(f" {name : <60}{str(arr.shape) : <20}{check}")
|
||||
arr.flatten().astype(np.float32).tofile(bin_handle)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
convert_weights()
|
||||
print("Done")
|
||||
@@ -56,8 +56,7 @@ struct IFields; // breaks circular dependency
// because their `IFields::VisitFields` calls `visitor.operator()`.
|
||||
//
|
||||
// Supported field types `T`: `uint32_t`, `int32_t`, `uint64_t`, `float`,
|
||||
// `std::string`,
|
||||
// classes derived from `IFields`, `bool`, `enum`, `std::vector<T>`.
|
||||
// `std::string`, `IFields` subclasses, `bool`, `enum`, `std::vector<T>`.
|
||||
class IFieldsVisitor {
|
||||
public:
|
||||
virtual ~IFieldsVisitor();
|
||||
@@ -13,11 +13,9 @@
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include <stdio.h>
|
||||
// Loads a model and saves it in single-file format.
|
||||
|
||||
#include <string>
|
||||
|
||||
#include "evals/benchmark_helper.h"
|
||||
#include "evals/benchmark_helper.h" // GemmaEnv
|
||||
#include "gemma/gemma.h"
|
||||
#include "util/args.h"
|
||||
|
||||
@@ -25,18 +23,9 @@ namespace gcpp {
namespace {
|
||||
|
||||
struct WriterArgs : public ArgsBase<WriterArgs> {
|
||||
// --output_weights is required.
|
||||
WriterArgs(int argc, char* argv[]) { InitAndParse(argc, argv); }
|
||||
|
||||
// Returns error string or nullptr if OK.
|
||||
const char* Validate() {
|
||||
if (output_weights.path.empty()) {
|
||||
return "Missing --output_weights flag, a file for the model weights.";
|
||||
}
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
Path output_weights; // weights file location
|
||||
Path output_weights;
|
||||
|
||||
template <class Visitor>
|
||||
void ForEach(const Visitor& visitor) {
|
||||
@@ -49,14 +38,12 @@ struct WriterArgs : public ArgsBase<WriterArgs> {
} // namespace gcpp
|
||||
|
||||
int main(int argc, char** argv) {
|
||||
// Loads a model in the multi-file format and saves it in single-file format.
|
||||
gcpp::WriterArgs args(argc, argv);
|
||||
if (const char* err = args.Validate()) {
|
||||
fprintf(stderr, "Skipping model load because: %s\n", err);
|
||||
return 1;
|
||||
if (args.output_weights.Empty()) {
|
||||
HWY_ABORT("Missing --output_weights flag, a file for the model weights.");
|
||||
}
|
||||
|
||||
gcpp::GemmaEnv env(argc, argv);
|
||||
hwy::ThreadPool pool(0);
|
||||
env.GetGemma()->Save(args.output_weights, pool);
|
||||
env.GetGemma()->Save(args.output_weights, env.Env().ctx.pools.Pool());
|
||||
return 0;
|
||||
}
|
||||
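// An assumed invocation of this tool; only --output_weights appears in this
// change, so the binary name and the loader flags shown here may differ:
//   ./migrate_weights --weights multi_file.sbs --tokenizer tokenizer.spm \
//       --output_weights single_file.sbs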
@@ -14,11 +14,14 @@ cc_library(
hdrs = ["compression_clif_aux.h"],
|
||||
visibility = ["//visibility:private"],
|
||||
deps = [
|
||||
"@abseil-cpp//absl/types:span",
|
||||
"//:common",
|
||||
"//:basics",
|
||||
"//:configs",
|
||||
"//:mat",
|
||||
"//:model_store",
|
||||
"//:tensor_info",
|
||||
"//:threading_context",
|
||||
"//:tokenizer",
|
||||
"//:weights",
|
||||
"//compression:blob_store",
|
||||
"//compression:compress",
|
||||
"//compression:io",
|
||||
"@highway//:hwy",
|
||||
@@ -31,7 +34,8 @@ pybind_extension(
srcs = ["compression_extension.cc"],
|
||||
deps = [
|
||||
":compression_clif_aux",
|
||||
"@abseil-cpp//absl/types:span",
|
||||
"//:mat",
|
||||
"//:tensor_info",
|
||||
"//compression:shared",
|
||||
],
|
||||
)
|
||||
@@ -15,15 +15,24 @@
|
||||
#include "compression/python/compression_clif_aux.h"
|
||||
|
||||
#include <cstddef>
|
||||
#include <cstdio>
|
||||
#include <stddef.h>
|
||||
#include <stdint.h>
|
||||
#include <stdio.h>
|
||||
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
#include "compression/compress.h"
|
||||
#include "compression/shared.h"
|
||||
#include "gemma/weights.h"
|
||||
#include "compression/blob_store.h" // BlobWriter2
|
||||
#include "compression/compress.h" // ScaleWeights
|
||||
#include "compression/io.h" // Path
|
||||
#include "gemma/configs.h" // ModelConfig
|
||||
#include "gemma/model_store.h" // ModelStore
|
||||
#include "gemma/tensor_info.h" // TensorInfo
|
||||
#include "gemma/tokenizer.h"
|
||||
#include "util/basics.h"
|
||||
#include "util/mat.h"
|
||||
#include "util/threading_context.h"
|
||||
#include "hwy/contrib/thread_pool/thread_pool.h"
|
||||
|
||||
#undef HWY_TARGET_INCLUDE
|
||||
#define HWY_TARGET_INCLUDE \
|
||||
@@ -33,151 +42,92 @@
// After highway.h
|
||||
#include "compression/compress-inl.h"
|
||||
|
||||
// Non-SIMD includes and types. Note that HWY_ONCE is only true on the last
|
||||
// compile pass, whereas we want this defined in the first.
|
||||
#ifndef GEMMA_ONCE
|
||||
#define GEMMA_ONCE
|
||||
|
||||
#include "absl/types/span.h"
|
||||
#include "compression/io.h"
|
||||
#include "gemma/configs.h"
|
||||
#include "gemma/tensor_index.h"
|
||||
#include "gemma/tokenizer.h"
|
||||
#include "hwy/base.h"
|
||||
#include "hwy/contrib/thread_pool/thread_pool.h"
|
||||
|
||||
namespace gcpp {
|
||||
|
||||
class WriterInterface {
|
||||
public:
|
||||
virtual ~WriterInterface() = default;
|
||||
|
||||
virtual void Insert(std::string name, absl::Span<const float> weights,
|
||||
Type type, const TensorInfo& tensor_info,
|
||||
float scale) = 0;
|
||||
virtual void InsertSfp(std::string name, absl::Span<const float> weights) = 0;
|
||||
virtual void InsertNUQ(std::string name, absl::Span<const float> weights) = 0;
|
||||
virtual void InsertBfloat16(std::string name,
|
||||
absl::Span<const float> weights) = 0;
|
||||
virtual void InsertFloat(std::string name,
|
||||
absl::Span<const float> weights) = 0;
|
||||
virtual void AddScales(const std::vector<float>& scales) = 0;
|
||||
virtual void AddTokenizer(const std::string& tokenizer_path) = 0;
|
||||
|
||||
virtual size_t DebugNumBlobsAdded() const = 0;
|
||||
|
||||
virtual int WriteWithConfig(std::string path, const ModelConfig* config) = 0;
|
||||
};
|
||||
|
||||
} // namespace gcpp
|
||||
|
||||
#endif // GEMMA_ONCE
|
||||
|
||||
// SIMD code, compiled once per target.
|
||||
HWY_BEFORE_NAMESPACE();
|
||||
namespace gcpp {
|
||||
namespace HWY_NAMESPACE {
|
||||
|
||||
class SbsWriterImpl : public WriterInterface {
|
||||
// Implementation for the currently compiled SIMD target.
|
||||
class SbsWriterImpl : public ISbsWriter {
|
||||
template <typename Packed>
|
||||
void AllocateAndCompress(const std::string& name,
|
||||
absl::Span<const float> weights) {
|
||||
MatPtrT<Packed> storage(name.c_str(), Extents2D(1, weights.size()));
|
||||
model_memory_.push_back(MatOwner());
|
||||
model_memory_.back().AllocateFor(storage, MatPadding::kPacked);
|
||||
std::string decorated_name = CacheName(storage);
|
||||
compressor_(&storage, decorated_name.c_str(), weights.data());
|
||||
}
|
||||
template <typename Packed>
|
||||
void AllocateWithShape(const std::string& name,
|
||||
absl::Span<const float> weights,
|
||||
const TensorInfo& tensor_info, float scale) {
|
||||
MatPtrT<Packed> storage(name.c_str(), &tensor_info);
|
||||
storage.SetScale(scale);
|
||||
void InsertT(const char* name, F32Span weights,
|
||||
const TensorInfo& tensor_info) {
|
||||
MatPtrT<Packed> mat(name, ExtentsFromInfo(&tensor_info));
|
||||
// SFP and NUQ (which uses SFP for cluster centers) have a limited range
|
||||
// and depending on the input values may require rescaling. Scaling is
|
||||
// cheap for matmul and probably not an issue for other ops, but it might be
|
||||
// beneficial for precision to keep the original data range for other types.
|
||||
if (mat.GetType() == Type::kSFP || mat.GetType() == Type::kNUQ) {
|
||||
mat.SetScale(ScaleWeights(weights.data(), weights.size()));
|
||||
}
|
||||
|
||||
model_memory_.push_back(MatOwner());
|
||||
if (mode_ == CompressorMode::kTEST_ONLY) return;
|
||||
model_memory_.back().AllocateFor(storage, MatPadding::kPacked);
|
||||
std::string decorated_name = CacheName(storage);
|
||||
compressor_(&storage, decorated_name.c_str(), weights.data());
|
||||
if (weights.size() == 0) {
|
||||
HWY_WARN("Ignoring zero-sized tensor %s.", name);
|
||||
return;
|
||||
}
|
||||
|
||||
mat.AppendTo(serialized_mat_ptrs_);
|
||||
mat_owners_.AllocateFor(mat, MatPadding::kPacked);
|
||||
|
||||
// Handle gemma_export_test's MockArray. Write blobs so that the test
|
||||
// succeeds, but we only have 10 floats, not the full tensor.
|
||||
if (weights.size() == 10 && mat.Extents().Area() != 10) {
|
||||
Compress(weights.data(), weights.size(), working_set_, mat.Span(),
|
||||
/*packed_ofs=*/0, pool_);
|
||||
writer_.Add(name, mat.Packed(), mat.ElementBytes() * 10);
|
||||
return;
|
||||
}
|
||||
|
||||
fprintf(stderr, "Compressing %s (%zu x %zu = %zuM) to %s, please wait\n",
|
||||
name, mat.Rows(), mat.Cols(), weights.size() / (1000 * 1000),
|
||||
TypeName(TypeEnum<Packed>()));
|
||||
HWY_ASSERT(weights.size() == mat.Extents().Area());
|
||||
Compress(weights.data(), weights.size(), working_set_, mat.Span(),
|
||||
/*packed_ofs=*/0, pool_);
|
||||
writer_.Add(name, mat.Packed(), mat.PackedBytes());
|
||||
}
|
||||
|
||||
public:
|
||||
explicit SbsWriterImpl(CompressorMode mode)
|
||||
: pool_(0), compressor_(pool_), mode_(mode) {}
|
||||
SbsWriterImpl() : pool_(ThreadingContext2::Get().pools.Pool()) {}
|
||||
|
||||
void Insert(std::string name, absl::Span<const float> weights, Type type,
|
||||
const TensorInfo& tensor_info, float scale) override {
|
||||
void Insert(const char* name, F32Span weights, Type type,
|
||||
const TensorInfo& tensor_info) override {
|
||||
switch (type) {
|
||||
case Type::kSFP:
|
||||
AllocateWithShape<SfpStream>(name, weights, tensor_info, scale);
|
||||
InsertT<SfpStream>(name, weights, tensor_info);
|
||||
break;
|
||||
case Type::kNUQ:
|
||||
AllocateWithShape<NuqStream>(name, weights, tensor_info, scale);
|
||||
InsertT<NuqStream>(name, weights, tensor_info);
|
||||
break;
|
||||
case Type::kBF16:
|
||||
AllocateWithShape<BF16>(name, weights, tensor_info, scale);
|
||||
InsertT<BF16>(name, weights, tensor_info);
|
||||
break;
|
||||
case Type::kF32:
|
||||
AllocateWithShape<float>(name, weights, tensor_info, scale);
|
||||
InsertT<float>(name, weights, tensor_info);
|
||||
break;
|
||||
default:
|
||||
HWY_ABORT("Unsupported type");
|
||||
HWY_ABORT("Unsupported destination (compressed) type %s",
|
||||
TypeName(type));
|
||||
}
|
||||
}
|
||||
|
||||
void InsertSfp(std::string name, absl::Span<const float> weights) override {
|
||||
AllocateAndCompress<SfpStream>(name, weights);
|
||||
void Write(const ModelConfig& config, const std::string& tokenizer_path,
|
||||
const std::string& path) override {
|
||||
const GemmaTokenizer tokenizer(
|
||||
tokenizer_path.empty() ? kMockTokenizer
|
||||
: ReadFileToString(Path(tokenizer_path)));
|
||||
WriteSingleFile(config, tokenizer, serialized_mat_ptrs_, writer_, pool_,
|
||||
gcpp::Path(path));
|
||||
}
|
||||
|
||||
void InsertNUQ(std::string name, absl::Span<const float> weights) override {
|
||||
AllocateAndCompress<NuqStream>(name, weights);
|
||||
}
|
||||
|
||||
void InsertBfloat16(std::string name,
|
||||
absl::Span<const float> weights) override {
|
||||
AllocateAndCompress<BF16>(name, weights);
|
||||
}
|
||||
|
||||
void InsertFloat(std::string name, absl::Span<const float> weights) override {
|
||||
AllocateAndCompress<float>(name, weights);
|
||||
}
|
||||
|
||||
void AddScales(const std::vector<float>& scales) override {
|
||||
HWY_ASSERT(scales_.empty());
|
||||
scales_ = scales;
|
||||
compressor_.AddScales(scales_.data(), scales_.size());
|
||||
}
|
||||
|
||||
void AddTokenizer(const std::string& tokenizer_path) override {
|
||||
Path path(tokenizer_path);
|
||||
GemmaTokenizer tokenizer(path);
|
||||
std::string tokenizer_proto = tokenizer.Serialize();
|
||||
HWY_ASSERT(!tokenizer_proto.empty());
|
||||
compressor_.AddTokenizer(tokenizer_proto);
|
||||
}
|
||||
|
||||
// Returns the number of blobs added.
|
||||
size_t DebugNumBlobsAdded() const {
|
||||
if (mode_ == CompressorMode::kTEST_ONLY) return model_memory_.size();
|
||||
return compressor_.DebugNumBlobsAdded();
|
||||
}
|
||||
|
||||
int WriteWithConfig(std::string path, const ModelConfig* config) override {
|
||||
return compressor_.WriteAll(gcpp::Path(path), config);
|
||||
}
|
||||
|
||||
hwy::ThreadPool pool_;
|
||||
Compressor compressor_;
|
||||
hwy::ThreadPool& pool_;
|
||||
MatOwners mat_owners_;
|
||||
CompressWorkingSet working_set_;
|
||||
std::vector<MatOwner> model_memory_;
|
||||
std::vector<float> scales_;
|
||||
CompressorMode mode_;
|
||||
BlobWriter2 writer_;
|
||||
std::vector<uint32_t> serialized_mat_ptrs_;
|
||||
};
|
||||
|
||||
WriterInterface* NewSbsWriter(CompressorMode mode) {
|
||||
return new SbsWriterImpl(mode);
|
||||
}
|
||||
ISbsWriter* NewSbsWriter() { return new SbsWriterImpl; }
|
||||
|
||||
} // namespace HWY_NAMESPACE
|
||||
} // namespace gcpp
|
||||
@@ -188,43 +138,10 @@ namespace gcpp {
|
||||
HWY_EXPORT(NewSbsWriter);
|
||||
|
||||
SbsWriter::SbsWriter(CompressorMode mode)
|
||||
: impl_(HWY_DYNAMIC_DISPATCH(NewSbsWriter)(mode)) {}
|
||||
SbsWriter::~SbsWriter() = default;
|
||||
SbsWriter::SbsWriter() : impl_(HWY_DYNAMIC_DISPATCH(NewSbsWriter)()) {}
|
||||
|
||||
void SbsWriter::Insert(std::string name, absl::Span<const float> weights,
|
||||
Type type, const TensorInfo& tensor_info, float scale) {
|
||||
impl_->Insert(name, weights, type, tensor_info, scale);
|
||||
}
|
||||
void SbsWriter::InsertSfp(std::string name, absl::Span<const float> weights) {
|
||||
impl_->InsertSfp(name, weights);
|
||||
}
|
||||
void SbsWriter::InsertNUQ(std::string name, absl::Span<const float> weights) {
|
||||
impl_->InsertNUQ(name, weights);
|
||||
}
|
||||
void SbsWriter::InsertBfloat16(std::string name,
|
||||
absl::Span<const float> weights) {
|
||||
impl_->InsertBfloat16(name, weights);
|
||||
}
|
||||
void SbsWriter::InsertFloat(std::string name, absl::Span<const float> weights) {
|
||||
impl_->InsertFloat(name, weights);
|
||||
}
|
||||
|
||||
void SbsWriter::AddScales(const std::vector<float>& scales) {
|
||||
impl_->AddScales(scales);
|
||||
}
|
||||
|
||||
void SbsWriter::AddTokenizer(const std::string& tokenizer_path) {
|
||||
impl_->AddTokenizer(tokenizer_path);
|
||||
}
|
||||
|
||||
size_t SbsWriter::DebugNumBlobsAdded() const {
|
||||
return impl_->DebugNumBlobsAdded();
|
||||
}
|
||||
|
||||
int SbsWriter::WriteWithConfig(std::string path, const ModelConfig* config) {
|
||||
return impl_->WriteWithConfig(path, config);
|
||||
}
|
||||
SbsReader::SbsReader(const std::string& path)
|
||||
: reader_(gcpp::BlobReader2::Make(Path(path))), model_(*reader_) {}
|
||||
|
||||
} // namespace gcpp
|
||||
#endif // HWY_ONCE
|
||||
@@ -16,52 +16,69 @@
#ifndef THIRD_PARTY_GEMMA_CPP_COMPRESSION_PYTHON_COMPRESSION_CLIF_AUX_H_
|
||||
#define THIRD_PARTY_GEMMA_CPP_COMPRESSION_PYTHON_COMPRESSION_CLIF_AUX_H_
|
||||
|
||||
#include <cstddef>
|
||||
#include <stddef.h>
|
||||
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
#include "absl/types/span.h"
|
||||
#include "compression/shared.h"
|
||||
#include "compression/blob_store.h"
|
||||
#include "compression/shared.h" // Type
|
||||
#include "gemma/configs.h"
|
||||
#include "gemma/tensor_index.h"
|
||||
#include "gemma/model_store.h"
|
||||
#include "gemma/tensor_info.h"
|
||||
#include "util/mat.h"
|
||||
#include "hwy/aligned_allocator.h" // Span
|
||||
|
||||
namespace gcpp {
|
||||
|
||||
// How to process the data.
|
||||
enum class CompressorMode {
|
||||
// No compression, no write to file, just for testing.
|
||||
kTEST_ONLY,
|
||||
// Old-style compression, no table of contents.
|
||||
kNO_TOC,
|
||||
// New-style compression, with table of contents.
|
||||
kWITH_TOC,
|
||||
// Can be modified in place by ScaleWeights.
|
||||
using F32Span = hwy::Span<float>;
|
||||
|
||||
// Interface because we compile one derived implementation per SIMD target,
|
||||
// because Compress() uses SIMD.
|
||||
class ISbsWriter {
|
||||
public:
|
||||
virtual ~ISbsWriter() = default;
|
||||
|
||||
virtual void Insert(const char* name, F32Span weights, Type type,
|
||||
const TensorInfo& tensor_info) = 0;
|
||||
|
||||
virtual void Write(const ModelConfig& config,
|
||||
const std::string& tokenizer_path,
|
||||
const std::string& path) = 0;
|
||||
};
|
||||
|
||||
class WriterInterface;
|
||||
|
||||
// Non-virtual class used by pybind that calls the interface's virtual methods.
|
||||
// This avoids having to register the derived types with pybind.
|
||||
class SbsWriter {
|
||||
public:
|
||||
explicit SbsWriter(CompressorMode mode);
|
||||
~SbsWriter();
|
||||
SbsWriter();
|
||||
|
||||
void Insert(std::string name, absl::Span<const float> weights, Type type,
|
||||
const TensorInfo& tensor_info, float scale);
|
||||
void InsertSfp(std::string name, absl::Span<const float> weights);
|
||||
void InsertNUQ(std::string name, absl::Span<const float> weights);
|
||||
void InsertBfloat16(std::string name, absl::Span<const float> weights);
|
||||
void InsertFloat(std::string name, absl::Span<const float> weights);
|
||||
void AddScales(const std::vector<float>& scales);
|
||||
void AddTokenizer(const std::string& tokenizer_path);
|
||||
void Insert(const char* name, F32Span weights, Type type,
|
||||
const TensorInfo& tensor_info) {
|
||||
impl_->Insert(name, weights, type, tensor_info);
|
||||
}
|
||||
|
||||
size_t DebugNumBlobsAdded() const;
|
||||
|
||||
int Write(std::string path) { return WriteWithConfig(path, nullptr); }
|
||||
int WriteWithConfig(std::string path, const ModelConfig* config);
|
||||
void Write(const ModelConfig& config, const std::string& tokenizer_path,
|
||||
const std::string& path) {
|
||||
impl_->Write(config, tokenizer_path, path);
|
||||
}
|
||||
|
||||
private:
|
||||
// Isolates Highway-dispatched types and other internals from CLIF.
|
||||
std::unique_ptr<WriterInterface> impl_;
|
||||
std::unique_ptr<ISbsWriter> impl_;
|
||||
};
|
||||
|
||||
// Limited metadata-only reader for tests.
|
||||
class SbsReader {
|
||||
public:
|
||||
SbsReader(const std::string& path);
|
||||
|
||||
const ModelConfig& Config() const { return model_.Config(); }
|
||||
const MatPtr* FindMat(const char* name) const { return model_.FindMat(name); }
|
||||
|
||||
private:
|
||||
std::unique_ptr<gcpp::BlobReader2> reader_;
|
||||
gcpp::ModelStore2 model_;
|
||||
};
|
||||
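// A minimal C++ usage sketch of the wrappers above (not part of this change);
// the Python test below exercises the same flow. Names, the output path and
// the weight span are placeholders.
inline void ExampleSbsRoundTrip(const ModelConfig& config,
                                const TensorInfo& info, F32Span weights) {
  SbsWriter writer;
  writer.Insert("tensor0", weights, Type::kSFP, info);
  writer.Write(config, /*tokenizer_path=*/"", "/tmp/example.sbs");

  SbsReader reader("/tmp/example.sbs");
  const MatPtr* mat = reader.FindMat("tensor0");
  if (mat != nullptr) {
    // mat->Rows(), mat->Cols(), mat->GetType() and mat->Scale() describe the
    // stored tensor's metadata.
  }
}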
|
||||
} // namespace gcpp
|
||||
@@ -15,58 +15,55 @@
|
||||
#include <pybind11/numpy.h>
|
||||
#include <pybind11/pybind11.h>
|
||||
#include <pybind11/stl.h>
|
||||
|
||||
#include <stdexcept>
|
||||
#include <string>
|
||||
|
||||
#include "absl/types/span.h"
|
||||
#include "compression/python/compression_clif_aux.h"
|
||||
#include "compression/shared.h"
|
||||
#include "compression/shared.h" // Type
|
||||
#include "gemma/tensor_info.h"
|
||||
#include "util/mat.h"
|
||||
|
||||
using gcpp::CompressorMode;
|
||||
using gcpp::MatPtr;
|
||||
using gcpp::SbsReader;
|
||||
using gcpp::SbsWriter;
|
||||
|
||||
namespace py = pybind11;
|
||||
namespace pybind11 {
|
||||
|
||||
namespace {
|
||||
template <auto Func>
|
||||
void wrap_span(SbsWriter& writer, std::string name, py::array_t<float> data) {
|
||||
static void CallWithF32Span(SbsWriter& writer, const char* name,
|
||||
array_t<float> data, gcpp::Type type,
|
||||
const gcpp::TensorInfo& tensor_info) {
|
||||
if (data.ndim() != 1 || data.strides(0) != sizeof(float)) {
|
||||
throw std::domain_error("Input array must be 1D and densely packed.");
|
||||
HWY_ABORT("Input array must be 1D (not %d) and contiguous floats.",
|
||||
static_cast<int>(data.ndim()));
|
||||
}
|
||||
std::invoke(Func, writer, name, absl::MakeSpan(data.data(0), data.size()));
|
||||
std::invoke(Func, writer, name,
|
||||
gcpp::F32Span(data.mutable_data(0), data.size()), type,
|
||||
tensor_info);
|
||||
}
|
||||
template <auto Func>
|
||||
void wrap_span_typed(SbsWriter& writer, std::string name,
|
||||
py::array_t<float> data, gcpp::Type type,
|
||||
gcpp::TensorInfo tensor_info, float scale) {
|
||||
if (data.ndim() != 1 || data.strides(0) != sizeof(float)) {
|
||||
throw std::domain_error("Input array must be 1D and densely packed.");
|
||||
}
|
||||
std::invoke(Func, writer, name, absl::MakeSpan(data.data(0), data.size()),
|
||||
type, tensor_info, scale);
|
||||
}
|
||||
} // namespace
|
||||
|
||||
PYBIND11_MODULE(compression, m) {
|
||||
py::enum_<CompressorMode>(m, "CompressorMode")
|
||||
.value("TEST_ONLY", CompressorMode::kTEST_ONLY)
|
||||
.value("NO_TOC", CompressorMode::kNO_TOC)
|
||||
.value("WITH_TOC", CompressorMode::kWITH_TOC);
|
||||
class_<SbsWriter>(m, "SbsWriter")
|
||||
.def(init<>())
|
||||
.def("insert", CallWithF32Span<&SbsWriter::Insert>)
|
||||
.def("write", &SbsWriter::Write, arg("config"), arg("tokenizer_path"),
|
||||
arg("path"));
|
||||
|
||||
py::class_<SbsWriter>(m, "SbsWriter")
|
||||
.def(py::init<CompressorMode>())
|
||||
// NOTE: Individual compression backends may impose constraints on the
|
||||
// array length, such as a minimum of (say) 32 elements.
|
||||
.def("insert", wrap_span_typed<&SbsWriter::Insert>)
|
||||
.def("insert_sfp", wrap_span<&SbsWriter::InsertSfp>)
|
||||
.def("insert_nuq", wrap_span<&SbsWriter::InsertNUQ>)
|
||||
.def("insert_bf16", wrap_span<&SbsWriter::InsertBfloat16>)
|
||||
.def("insert_float", wrap_span<&SbsWriter::InsertFloat>)
|
||||
.def("add_scales", &SbsWriter::AddScales)
|
||||
.def("add_tokenizer", &SbsWriter::AddTokenizer)
|
||||
.def("debug_num_blobs_added", &SbsWriter::DebugNumBlobsAdded)
|
||||
.def("write", &SbsWriter::Write)
|
||||
.def("write_with_config", &SbsWriter::WriteWithConfig);
|
||||
class_<MatPtr>(m, "MatPtr")
|
||||
// No init, only created within C++.
|
||||
.def_property_readonly("rows", &MatPtr::Rows, "Number of rows")
|
||||
.def_property_readonly("cols", &MatPtr::Cols, "Number of cols")
|
||||
.def_property_readonly("type", &MatPtr::GetType, "Element type")
|
||||
.def_property_readonly("scale", &MatPtr::Scale, "Scaling factor");
|
||||
|
||||
class_<SbsReader>(m, "SbsReader")
|
||||
.def(init<std::string>())
|
||||
.def_property_readonly("config", &SbsReader::Config,
|
||||
return_value_policy::reference_internal,
|
||||
"ModelConfig")
|
||||
.def("find_mat", &SbsReader::FindMat,
|
||||
return_value_policy::reference_internal,
|
||||
"Returns MatPtr for given name.");
|
||||
}
|
||||
|
||||
} // namespace pybind11
|
||||
@@ -25,46 +25,132 @@ from python import configs
class CompressionTest(absltest.TestCase):
|
||||
|
||||
def test_sbs_writer(self):
|
||||
temp_file = self.create_tempfile("test.sbs")
|
||||
tensor_info = configs.TensorInfo()
|
||||
tensor_info.name = "foo"
|
||||
tensor_info.axes = [0]
|
||||
tensor_info.shape = [192]
|
||||
info_192 = configs.TensorInfo()
|
||||
info_192.name = "ignored_192"
|
||||
info_192.axes = [0]
|
||||
info_192.shape = [192]
|
||||
|
||||
writer = compression.SbsWriter(compression.CompressorMode.NO_TOC)
|
||||
writer = compression.SbsWriter()
|
||||
writer.insert(
|
||||
"foo",
|
||||
np.array([0.0012] * 128 + [0.001] * 64, dtype=np.float32),
|
||||
"tensor0",
|
||||
# Large enough to require scaling.
|
||||
np.array([3.0012] * 128 + [4.001] * 64, dtype=np.float32),
|
||||
configs.Type.kSFP,
|
||||
tensor_info,
|
||||
1.0,
|
||||
info_192,
|
||||
)
|
||||
|
||||
tensor_info_nuq = configs.TensorInfo()
|
||||
tensor_info_nuq.name = "fooNUQ"
|
||||
tensor_info_nuq.axes = [0]
|
||||
tensor_info_nuq.shape = [256]
|
||||
# 2D tensor.
|
||||
info_2d = configs.TensorInfo()
|
||||
info_2d.name = "ignored_2d"
|
||||
info_2d.axes = [0, 1]
|
||||
info_2d.shape = [96, 192]
|
||||
writer.insert(
|
||||
"fooNUQ",
|
||||
"tensor_2d",
|
||||
np.array([i / 1e3 for i in range(96 * 192)], dtype=np.float32),
|
||||
configs.Type.kBF16,
|
||||
info_2d,
|
||||
)
|
||||
|
||||
# 3D collapsed into rows.
|
||||
info_3d = configs.TensorInfo()
|
||||
info_3d.name = "ignored_3d"
|
||||
info_3d.axes = [0, 1, 2]
|
||||
info_3d.shape = [10, 12, 192]
|
||||
info_3d.cols_take_extra_dims = False
|
||||
writer.insert(
|
||||
"tensor_3d",
|
||||
# Verification of scale below depends on the shape and multiplier here.
|
||||
np.array([i / 1e3 for i in range(10 * 12 * 192)], dtype=np.float32),
|
||||
configs.Type.kSFP,
|
||||
info_3d,
|
||||
)
|
||||
|
||||
# Exercise all types supported by Compress.
|
||||
info_256 = configs.TensorInfo()
|
||||
info_256.name = "ignored_256"
|
||||
info_256.axes = [0]
|
||||
info_256.shape = [256]
|
||||
writer.insert(
|
||||
"tensor_nuq",
|
||||
np.array([0.000375] * 128 + [0.00009] * 128, dtype=np.float32),
|
||||
configs.Type.kNUQ,
|
||||
tensor_info_nuq,
|
||||
1.0,
|
||||
info_256,
|
||||
)
|
||||
writer.insert_sfp(
|
||||
"bar", np.array([0.000375] * 128 + [0.00009] * 128, dtype=np.float32)
|
||||
writer.insert(
|
||||
"tensor_sfp",
|
||||
np.array([0.000375] * 128 + [0.00009] * 128, dtype=np.float32),
|
||||
configs.Type.kSFP,
|
||||
info_256,
|
||||
)
|
||||
writer.insert_nuq(
|
||||
"baz", np.array([0.000125] * 128 + [0.00008] * 128, dtype=np.float32)
|
||||
writer.insert(
|
||||
"tensor_bf",
|
||||
np.array([0.000375] * 128 + [0.00007] * 128, dtype=np.float32),
|
||||
configs.Type.kBF16,
|
||||
info_256,
|
||||
)
|
||||
writer.insert_bf16(
|
||||
"qux", np.array([0.000375] * 128 + [0.00007] * 128, dtype=np.float32)
|
||||
writer.insert(
|
||||
"tensor_f32",
|
||||
np.array([0.000375] * 128 + [0.00006] * 128, dtype=np.float32),
|
||||
configs.Type.kF32,
|
||||
info_256,
|
||||
)
|
||||
writer.insert_float(
|
||||
"quux", np.array([0.000375] * 128 + [0.00006] * 128, dtype=np.float32)
|
||||
|
||||
config = configs.ModelConfig(
|
||||
configs.Model.GEMMA_TINY,
|
||||
configs.Type.kNUQ,
|
||||
configs.PromptWrapping.GEMMA_IT,
|
||||
)
|
||||
self.assertEqual(writer.debug_num_blobs_added(), 6)
|
||||
self.assertEqual(writer.write(temp_file.full_path), 0)
|
||||
tokenizer_path = "" # no tokenizer required for testing
|
||||
temp_file = self.create_tempfile("test.sbs")
|
||||
writer.write(config, tokenizer_path, temp_file.full_path)
|
||||
|
||||
print("Ignore next two warnings; test does not enable model deduction.")
|
||||
reader = compression.SbsReader(temp_file.full_path)
|
||||
|
||||
self.assertEqual(reader.config.model, configs.Model.GEMMA_TINY)
|
||||
self.assertEqual(reader.config.weight, configs.Type.kNUQ)
|
||||
|
||||
mat = reader.find_mat("tensor0")
|
||||
self.assertEqual(mat.cols, 192)
|
||||
self.assertEqual(mat.rows, 1)
|
||||
self.assertEqual(mat.type, configs.Type.kSFP)
|
||||
self.assertAlmostEqual(mat.scale, 4.001 / 1.875, places=5)
|
||||
|
||||
mat = reader.find_mat("tensor_2d")
|
||||
self.assertEqual(mat.cols, 192)
|
||||
self.assertEqual(mat.rows, 96)
|
||||
self.assertEqual(mat.type, configs.Type.kBF16)
|
||||
self.assertAlmostEqual(mat.scale, 1.0)
|
||||
|
||||
mat = reader.find_mat("tensor_3d")
|
||||
self.assertEqual(mat.cols, 192)
|
||||
self.assertEqual(mat.rows, 10 * 12)
|
||||
self.assertEqual(mat.type, configs.Type.kSFP)
|
||||
self.assertAlmostEqual(mat.scale, 192 * 120 / 1e3 / 1.875, places=2)
|
||||
|
||||
mat = reader.find_mat("tensor_nuq")
|
||||
self.assertEqual(mat.cols, 256)
|
||||
self.assertEqual(mat.rows, 1)
|
||||
self.assertEqual(mat.type, configs.Type.kNUQ)
|
||||
self.assertAlmostEqual(mat.scale, 1.0)
|
||||
|
||||
mat = reader.find_mat("tensor_sfp")
|
||||
self.assertEqual(mat.cols, 256)
|
||||
self.assertEqual(mat.rows, 1)
|
||||
self.assertEqual(mat.type, configs.Type.kSFP)
|
||||
self.assertAlmostEqual(mat.scale, 1.0)
|
||||
|
||||
mat = reader.find_mat("tensor_bf")
|
||||
self.assertEqual(mat.cols, 256)
|
||||
self.assertEqual(mat.rows, 1)
|
||||
self.assertEqual(mat.type, configs.Type.kBF16)
|
||||
self.assertAlmostEqual(mat.scale, 1.0)
|
||||
|
||||
mat = reader.find_mat("tensor_f32")
|
||||
self.assertEqual(mat.cols, 256)
|
||||
self.assertEqual(mat.rows, 1)
|
||||
self.assertEqual(mat.type, configs.Type.kF32)
|
||||
self.assertAlmostEqual(mat.scale, 1.0)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
@@ -165,6 +165,8 @@ constexpr bool IsNuqStream() {
// `WeightsPtrs`. When adding a new type that is supported, also
|
||||
// update gemma.cc, weights.*, and add instantiations/new_one.cc.
|
||||
enum class Type { kUnknown, kF32, kBF16, kSFP, kNUQ, kF64, kC64, kU128 };
|
||||
// These are used in `ModelConfig.Specifier`, hence the strings will not
|
||||
// change, though new ones may be added.
|
||||
static constexpr const char* kTypeStrings[] = {
|
||||
"unknown", "f32", "bf16", "sfp", "nuq", "f64", "c64", "u128"};
|
||||
static constexpr size_t kNumTypes =
|
||||
|
|
|
|||
|
|
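Since kTypeStrings is index-aligned with the Type enum, a string round trip is a simple table lookup. A minimal sketch, assuming only the declarations in the hunk above (the helper names are illustrative, not part of the header):

#include <stddef.h>
#include <string>

// Assumes Type, kTypeStrings and kNumTypes from the hunk above are in scope.
inline const char* TypeName(Type type) {
  return kTypeStrings[static_cast<size_t>(type)];
}

inline Type TypeFromName(const std::string& name) {
  for (size_t i = 0; i < kNumTypes; ++i) {
    if (name == kTypeStrings[i]) return static_cast<Type>(i);
  }
  return Type::kUnknown;  // not found
}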
@ -6,7 +6,6 @@
|
|||
#include <iostream>
|
||||
#include <ostream>
|
||||
#include <string>
|
||||
#include <utility> // std::pair
|
||||
#include <vector>
|
||||
|
||||
#include "compression/io.h" // Path
|
||||
|
|
@ -26,7 +25,6 @@ class BenchmarkArgs : public ArgsBase<BenchmarkArgs> {
|
|||
public:
|
||||
BenchmarkArgs(int argc, char* argv[]) { InitAndParse(argc, argv); }
|
||||
|
||||
Path goldens;
|
||||
Path summarize_text;
|
||||
Path cross_entropy;
|
||||
Path trivia_qa;
|
||||
|
|
@ -35,8 +33,6 @@ class BenchmarkArgs : public ArgsBase<BenchmarkArgs> {
|
|||
|
||||
template <class Visitor>
|
||||
void ForEach(const Visitor& visitor) {
|
||||
visitor(goldens.path, "goldens_dir", std::string(""),
|
||||
"Directory containing golden files", 2);
|
||||
visitor(summarize_text.path, "summarize_text", std::string(""),
|
||||
"Path to text file to summarize", 2);
|
||||
visitor(cross_entropy.path, "cross_entropy", std::string(""),
|
||||
|
|
@ -52,56 +48,6 @@ class BenchmarkArgs : public ArgsBase<BenchmarkArgs> {
|
|||
}
|
||||
};
|
||||
|
||||
std::vector<std::pair<std::string, std::string>> load_goldens(
|
||||
const std::string& path) {
|
||||
std::ifstream goldens_file(path);
|
||||
if (!goldens_file) {
|
||||
std::cout << "Could not load goldens file: " << path << "\n" << std::flush;
|
||||
return {};
|
||||
}
|
||||
std::vector<std::pair<std::string, std::string>> res;
|
||||
std::string query_separator;
|
||||
std::string query;
|
||||
std::string answer_separator;
|
||||
std::string answer;
|
||||
while (std::getline(goldens_file, query_separator) &&
|
||||
std::getline(goldens_file, query) &&
|
||||
std::getline(goldens_file, answer_separator) &&
|
||||
std::getline(goldens_file, answer)) {
|
||||
res.push_back({query, answer});
|
||||
}
|
||||
return res;
|
||||
}
|
||||
|
||||
int BenchmarkGoldens(GemmaEnv& env, const std::string& golden_path) {
|
||||
std::vector<std::pair<std::string, std::string>> queries_answers =
|
||||
load_goldens(golden_path);
|
||||
size_t correct_answers = 0;
|
||||
size_t total_tokens = 0;
|
||||
const double time_start = hwy::platform::Now();
|
||||
for (auto& [question, expected_answer] : queries_answers) {
|
||||
QueryResult result = env.QueryModel(question);
|
||||
total_tokens += result.tokens_generated;
|
||||
if (result.response.find(expected_answer) != std::string::npos) {
|
||||
correct_answers++;
|
||||
} else {
|
||||
std::cout << "Wrong!\n";
|
||||
std::cout << "Input: " << question << "\n";
|
||||
std::cout << "Expected: " << expected_answer << "\n";
|
||||
std::cout << "Output: " << result.response << "\n\n" << std::flush;
|
||||
}
|
||||
}
|
||||
LogSpeedStats(time_start, total_tokens);
|
||||
|
||||
std::cout << "Correct: " << correct_answers << " out of "
|
||||
<< queries_answers.size() << "\n"
|
||||
<< std::flush;
|
||||
if (correct_answers != queries_answers.size()) {
|
||||
return EXIT_FAILURE;
|
||||
}
|
||||
return EXIT_SUCCESS;
|
||||
}
|
||||
|
||||
int BenchmarkSummary(GemmaEnv& env, const Path& text) {
|
||||
std::string prompt("Here is some text to summarize:\n");
|
||||
prompt.append(ReadFileToString(text));
|
||||
|
|
@ -182,14 +128,7 @@ int main(int argc, char** argv) {
gcpp::GemmaEnv env(argc, argv);
gcpp::BenchmarkArgs benchmark_args(argc, argv);

if (!benchmark_args.goldens.Empty()) {
const std::string golden_path =
benchmark_args.goldens.path + "/" +
gcpp::ModelString(env.GetGemma()->Info().model,
env.GetGemma()->Info().wrapping) +
".txt";
return BenchmarkGoldens(env, golden_path);
} else if (!benchmark_args.summarize_text.Empty()) {
if (!benchmark_args.summarize_text.Empty()) {
return BenchmarkSummary(env, benchmark_args.summarize_text);
} else if (!benchmark_args.cross_entropy.Empty()) {
return BenchmarkCrossEntropy(env, benchmark_args.cross_entropy,
@ -42,39 +42,33 @@ void InitGenerator(const InferenceArgs& inference, std::mt19937& gen) {
gen.seed(0x12345678);
} else {
// Depending on the library implementation, this may still be deterministic.
std::random_device rd;
std::random_device rd;  // NOLINT
gen.seed(rd());
}
}

GemmaEnv::GemmaEnv(const ThreadingArgs& threading_args,
const LoaderArgs& loader, const InferenceArgs& inference)
: env_(MakeMatMulEnv(threading_args)) {
InferenceArgs mutable_inference = inference;
AbortIfInvalidArgs(mutable_inference);
LoaderArgs mutable_loader = loader;
if (const char* err = mutable_loader.Validate()) {
mutable_loader.Help();
fprintf(stderr, "Skipping model load because: %s\n", err);
} else {
fprintf(stderr, "Loading model...\n");
gemma_ = AllocateGemma(mutable_loader, env_);
// Only allocate one for starters because GenerateBatch might not be called.
kv_caches_.resize(1);
kv_caches_[0] = KVCache::Create(gemma_->GetModelConfig(),
inference.prefill_tbatch_size);
}
GemmaEnv::GemmaEnv(const LoaderArgs& loader,
const ThreadingArgs& threading_args,
const InferenceArgs& inference)
: env_(MakeMatMulEnv(threading_args)), gemma_(loader, env_) {
// Only allocate one for starters because GenerateBatch might not be called.
kv_caches_.resize(1);
kv_caches_[0] =
KVCache::Create(gemma_.GetModelConfig(), inference.prefill_tbatch_size);

InitGenerator(inference, gen_);

runtime_config_ = {
.max_generated_tokens = inference.max_generated_tokens,
.temperature = inference.temperature,
.gen = &gen_,
.verbosity = inference.verbosity,
};
inference.CopyTo(runtime_config_);
}

GemmaEnv::GemmaEnv(int argc, char** argv)
: GemmaEnv(ThreadingArgs(argc, argv), LoaderArgs(argc, argv),
: GemmaEnv(LoaderArgs(argc, argv), ThreadingArgs(argc, argv),
InferenceArgs(argc, argv)) {}

QueryResult GemmaEnv::QueryModel(const std::vector<int>& tokens) {
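The same constructor-based flow works outside GemmaEnv. A minimal sketch using only calls that appear elsewhere in this change (argument parsing and error handling elided):

gcpp::LoaderArgs loader(argc, argv);
gcpp::ThreadingArgs threading(argc, argv);
gcpp::InferenceArgs inference(argc, argv);

gcpp::MatMulEnv env(MakeMatMulEnv(threading));
gcpp::Gemma gemma(loader, env);  // replaces AllocateGemma/CreateGemma
gcpp::KVCache kv_cache = gcpp::KVCache::Create(
    gemma.GetModelConfig(), inference.prefill_tbatch_size);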
@ -97,8 +91,8 @@ QueryResult GemmaEnv::QueryModel(const std::vector<int>& tokens) {
|
|||
}
|
||||
gcpp::TimingInfo timing_info { .verbosity = runtime_config_.verbosity };
|
||||
runtime_config_.batch_stream_token = batch_stream_token;
|
||||
gemma_->Generate(runtime_config_, tokens, /*start_pos=*/0, kv_caches_[0],
|
||||
timing_info);
|
||||
gemma_.Generate(runtime_config_, tokens, /*start_pos=*/0, kv_caches_[0],
|
||||
timing_info);
|
||||
return result;
|
||||
}
|
||||
|
||||
|
|
@ -107,8 +101,8 @@ void GemmaEnv::QueryModel(
|
|||
gcpp::TimingInfo timing_info { .verbosity = runtime_config_.verbosity };
|
||||
const StreamFunc previous_stream_token = runtime_config_.stream_token;
|
||||
runtime_config_.stream_token = stream_token;
|
||||
gemma_->Generate(runtime_config_, tokens, /*start_pos=*/0, kv_caches_[0],
|
||||
timing_info);
|
||||
gemma_.Generate(runtime_config_, tokens, /*start_pos=*/0, kv_caches_[0],
|
||||
timing_info);
|
||||
runtime_config_.stream_token = previous_stream_token;
|
||||
}
|
||||
|
||||
|
|
@ -121,8 +115,7 @@ std::vector<QueryResult> GemmaEnv::BatchQueryModel(
|
|||
size_t query_index, size_t pos,
|
||||
int token, float) {
|
||||
std::string token_text;
|
||||
HWY_ASSERT(
|
||||
gemma_->Tokenizer().Decode(std::vector<int>{token}, &token_text));
|
||||
HWY_ASSERT(gemma_.Tokenizer().Decode(std::vector<int>{token}, &token_text));
|
||||
res[query_index].response.append(token_text);
|
||||
res[query_index].tokens_generated += 1;
|
||||
if (res[query_index].tokens_generated ==
|
||||
|
|
@ -144,7 +137,7 @@ std::vector<QueryResult> GemmaEnv::BatchQueryModel(
|
|||
}
|
||||
for (size_t i = 1; i < num_queries; ++i) {
|
||||
if (kv_caches_[i].seq_len == 0) {
|
||||
kv_caches_[i] = KVCache::Create(gemma_->GetModelConfig(),
|
||||
kv_caches_[i] = KVCache::Create(gemma_.GetModelConfig(),
|
||||
runtime_config_.prefill_tbatch_size);
|
||||
}
|
||||
}
|
||||
|
|
@ -152,9 +145,9 @@ std::vector<QueryResult> GemmaEnv::BatchQueryModel(
|
|||
gcpp::TimingInfo timing_info = {.verbosity = runtime_config_.verbosity};
|
||||
runtime_config_.batch_stream_token = batch_stream_token;
|
||||
std::vector<size_t> queries_pos(num_queries, 0);
|
||||
gemma_->GenerateBatch(runtime_config_, queries_prompt,
|
||||
QueriesPos(queries_pos.data(), num_queries),
|
||||
KVCaches(&kv_caches_[0], num_queries), timing_info);
|
||||
gemma_.GenerateBatch(runtime_config_, queries_prompt,
|
||||
QueriesPos(queries_pos.data(), num_queries),
|
||||
KVCaches(&kv_caches_[0], num_queries), timing_info);
|
||||
return res;
|
||||
}
|
||||
|
||||
|
|
@ -234,11 +227,13 @@ static constexpr const char* CompiledConfig() {
|
|||
}
|
||||
}
|
||||
|
||||
void ShowConfig(ThreadingArgs& threading, LoaderArgs& loader,
|
||||
InferenceArgs& inference) {
|
||||
void ShowConfig(const LoaderArgs& loader, const ThreadingArgs& threading,
|
||||
const InferenceArgs& inference, const ModelConfig& config) {
|
||||
threading.Print(inference.verbosity);
|
||||
loader.Print(inference.verbosity);
|
||||
inference.Print(inference.verbosity);
|
||||
fprintf(stderr, "Model : %s, mmap %d\n",
|
||||
config.Specifier().c_str(), static_cast<int>(loader.map));
|
||||
|
||||
if (inference.verbosity >= 2) {
|
||||
time_t now = time(nullptr);
|
||||
|
|
@ -249,38 +244,32 @@ void ShowConfig(ThreadingArgs& threading, LoaderArgs& loader,
|
|||
|
||||
fprintf(stderr,
|
||||
"Date & Time : %s" // dt includes \n
|
||||
"CPU : %s\n"
|
||||
"CPU : %s, bind %d\n"
|
||||
"CPU topology : %s, %s, %s\n"
|
||||
"Instruction set : %s (%zu bits)\n"
|
||||
"Compiled config : %s\n"
|
||||
"Memory MiB : %4zu, %4zu free\n"
|
||||
"Weight Type : %s\n",
|
||||
dt, cpu100, ctx.topology.TopologyString(), ctx.pools.PinString(),
|
||||
"Memory MiB : %4zu, %4zu free\n",
|
||||
dt, cpu100, static_cast<int>(threading.bind),
|
||||
ctx.topology.TopologyString(), ctx.pools.PinString(),
|
||||
CacheString().c_str(), hwy::TargetName(hwy::DispatchedTarget()),
|
||||
ctx.allocator.VectorBytes() * 8, CompiledConfig(),
|
||||
ctx.allocator.TotalMiB(), ctx.allocator.FreeMiB(),
|
||||
StringFromType(loader.Info().weight));
|
||||
ctx.allocator.TotalMiB(), ctx.allocator.FreeMiB());
|
||||
}
|
||||
}
|
||||
|
||||
void ShowHelp(ThreadingArgs& threading, LoaderArgs& loader,
|
||||
InferenceArgs& inference) {
|
||||
void ShowHelp(const LoaderArgs& loader, const ThreadingArgs& threading,
|
||||
const InferenceArgs& inference) {
|
||||
std::cerr
|
||||
<< "\n\ngemma.cpp : a lightweight, standalone C++ inference engine\n"
|
||||
"==========================================================\n\n"
|
||||
"To run gemma.cpp, you need to "
|
||||
"specify 3 required model loading arguments:\n"
|
||||
" --tokenizer\n"
|
||||
" --weights\n"
|
||||
" --model,\n"
|
||||
" or with the single-file weights format, specify just:\n"
|
||||
" --weights\n";
|
||||
"To run with pre-2025 weights, specify --tokenizer and --weights.\n"
|
||||
"With the single-file weights format, specify just --weights.\n";
|
||||
std::cerr << "\n*Example Usage*\n\n./gemma --tokenizer tokenizer.spm "
|
||||
"--weights 2b-it-sfp.sbs --model 2b-it\n";
|
||||
std::cerr << "\n*Threading Arguments*\n\n";
|
||||
threading.Help();
|
||||
"--weights gemma2-2b-it-sfp.sbs\n";
|
||||
std::cerr << "\n*Model Loading Arguments*\n\n";
|
||||
loader.Help();
|
||||
std::cerr << "\n*Threading Arguments*\n\n";
|
||||
threading.Help();
|
||||
std::cerr << "\n*Inference Arguments*\n\n";
|
||||
inference.Help();
|
||||
std::cerr << "\n";
|
||||
|
|
|
|||
|
|
@ -18,11 +18,11 @@
|
|||
|
||||
#include <stddef.h>
|
||||
|
||||
#include <memory>
|
||||
#include <random>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
#include "gemma/configs.h"
|
||||
#include "gemma/gemma.h"
|
||||
#include "gemma/gemma_args.h"
|
||||
#include "gemma/tokenizer.h" // WrapAndTokenize
|
||||
|
|
@ -47,7 +47,7 @@ class GemmaEnv {
|
|||
public:
|
||||
// Calls the other constructor with *Args arguments initialized from argv.
|
||||
GemmaEnv(int argc, char** argv);
|
||||
GemmaEnv(const ThreadingArgs& threading, const LoaderArgs& loader,
|
||||
GemmaEnv(const LoaderArgs& loader, const ThreadingArgs& threading,
|
||||
const InferenceArgs& inference);
|
||||
// Avoid memory leaks in test.
|
||||
~GemmaEnv() { ThreadingContext2::ThreadHostileInvalidate(); }
|
||||
|
|
@ -64,7 +64,7 @@ class GemmaEnv {
|
|||
|
||||
std::vector<int> Tokenize(const std::string& input) const {
|
||||
std::vector<int> tokens;
|
||||
HWY_ASSERT(gemma_->Tokenizer().Encode(input, &tokens));
|
||||
HWY_ASSERT(gemma_.Tokenizer().Encode(input, &tokens));
|
||||
return tokens;
|
||||
}
|
||||
|
||||
|
|
@ -75,13 +75,13 @@ class GemmaEnv {
|
|||
}
|
||||
|
||||
std::vector<int> WrapAndTokenize(std::string& input) const {
|
||||
return gcpp::WrapAndTokenize(gemma_->Tokenizer(), gemma_->ChatTemplate(),
|
||||
gemma_->Info(), 0, input);
|
||||
return gcpp::WrapAndTokenize(gemma_.Tokenizer(), gemma_.ChatTemplate(),
|
||||
gemma_.GetModelConfig().wrapping, 0, input);
|
||||
}
|
||||
|
||||
std::string StringFromTokens(const std::vector<int>& tokens) const {
|
||||
std::string string;
|
||||
HWY_ASSERT(gemma_->Tokenizer().Decode(tokens, &string));
|
||||
HWY_ASSERT(gemma_.Tokenizer().Decode(tokens, &string));
|
||||
return string;
|
||||
}
|
||||
|
||||
|
|
@ -104,8 +104,7 @@ class GemmaEnv {
|
|||
// number of bits per token.
|
||||
float CrossEntropy(const std::string& input);
|
||||
|
||||
// Returns nullptr if the model failed to load.
|
||||
Gemma* GetGemma() const { return gemma_.get(); }
|
||||
const Gemma* GetGemma() const { return &gemma_; }
|
||||
|
||||
int Verbosity() const { return runtime_config_.verbosity; }
|
||||
RuntimeConfig& MutableConfig() { return runtime_config_; }
|
||||
|
|
@ -114,8 +113,8 @@ class GemmaEnv {
|
|||
|
||||
private:
|
||||
MatMulEnv env_;
|
||||
std::mt19937 gen_; // Random number generator.
|
||||
std::unique_ptr<Gemma> gemma_;
|
||||
Gemma gemma_;
|
||||
std::mt19937 gen_; // Random number generator.
|
||||
std::vector<KVCache> kv_caches_; // Same number as query batch.
|
||||
RuntimeConfig runtime_config_;
|
||||
};
|
||||
|
|
@ -123,10 +122,10 @@ class GemmaEnv {
|
|||
// Logs the inference speed in tokens/sec.
|
||||
void LogSpeedStats(double time_start, size_t total_tokens);
|
||||
|
||||
void ShowConfig(ThreadingArgs& threading, LoaderArgs& loader,
|
||||
InferenceArgs& inference);
|
||||
void ShowHelp(ThreadingArgs& threading, LoaderArgs& loader,
|
||||
InferenceArgs& inference);
|
||||
void ShowConfig(const LoaderArgs& loader, const ThreadingArgs& threading,
|
||||
const InferenceArgs& inference, const ModelConfig& config);
|
||||
void ShowHelp(const LoaderArgs& loader, const ThreadingArgs& threading,
|
||||
const InferenceArgs& inference);
|
||||
|
||||
} // namespace gcpp
|
||||
|
||||
|
|
|
|||
|
|
@ -44,10 +44,6 @@
|
|||
namespace gcpp {
|
||||
|
||||
namespace {
|
||||
template <typename TConfig>
|
||||
struct GetVocabSize {
|
||||
int operator()() const { return TConfig::kVocabSize; }
|
||||
};
|
||||
|
||||
static std::string TokenString(const GemmaTokenizer& tokenizer, int token) {
|
||||
std::string token_str;
|
||||
|
|
@ -96,7 +92,7 @@ namespace gcpp {
|
|||
|
||||
HWY_EXPORT(CallSoftmax);
|
||||
|
||||
float ComputeCrossEntropy(Gemma& gemma, size_t max_generated_tokens,
|
||||
float ComputeCrossEntropy(const Gemma& gemma, size_t max_generated_tokens,
|
||||
const std::vector<int>& prompt, KVCache& kv_cache,
|
||||
int verbosity) {
|
||||
const StreamFunc stream_token = [](int, float) { return true; };
|
||||
|
|
|
|||
|
|
@ -24,7 +24,7 @@
|
|||
|
||||
namespace gcpp {
|
||||
|
||||
float ComputeCrossEntropy(Gemma& gemma, size_t max_generated_tokens,
|
||||
float ComputeCrossEntropy(const Gemma& gemma, size_t max_generated_tokens,
|
||||
const std::vector<int>& prompt, KVCache& kv_cache,
|
||||
int verbosity);
|
||||
|
||||
|
|
|
|||
|
|
@ -46,13 +46,14 @@ class GemmaTest : public ::testing::Test {
|
|||
s_env->SetMaxGeneratedTokens(64);
|
||||
s_env->MutableConfig().temperature = 0.0f; // deterministic
|
||||
s_env->MutableConfig().verbosity = 5;
|
||||
const ModelConfig& config = s_env->GetGemma()->GetModelConfig();
|
||||
std::vector<std::string> replies;
|
||||
// Using the turn structure worsens results sometimes.
|
||||
// However, some models need the turn structure to work.
|
||||
// It would be good to make these tests more consistent.
|
||||
if (s_env->GetGemma()->Info().model == Model::GEMMA2_27B ||
|
||||
s_env->GetGemma()->Info().model == Model::GRIFFIN_2B) {
|
||||
for (QueryResult result : s_env->BatchQueryModel(inputs)) {
|
||||
if (config.model == Model::GEMMA2_27B ||
|
||||
config.model == Model::GRIFFIN_2B) {
|
||||
for (const QueryResult& result : s_env->BatchQueryModel(inputs)) {
|
||||
replies.push_back(result.response);
|
||||
}
|
||||
return replies;
|
||||
|
|
|
|||
|
|
@ -21,7 +21,7 @@
|
|||
#include <vector>
|
||||
|
||||
#include "evals/benchmark_helper.h"
|
||||
#include "gemma/common.h"
|
||||
#include "gemma/configs.h"
|
||||
#include "hwy/base.h"
|
||||
#include "hwy/tests/hwy_gtest.h"
|
||||
|
||||
|
|
@ -36,22 +36,30 @@
namespace gcpp {
namespace {

// Shared state. Requires argc/argv, so construct in main and use the same raw
// pointer approach as in benchmarks.cc. Note that the style guide forbids
// non-local static variables with dtors.
GemmaEnv* s_env = nullptr;

class GemmaTest : public ::testing::Test {
public:
// Requires argc/argv, hence do not use `SetUpTestSuite`.
static void InitEnv(int argc, char** argv) {
HWY_ASSERT(s_env == nullptr);  // Should only be called once.
s_env = new GemmaEnv(argc, argv);
const gcpp::ModelConfig& config = s_env->GetGemma()->GetModelConfig();
fprintf(stderr, "Using %s\n", config.Specifier().c_str());
}

static void DeleteEnv() { delete s_env; }

protected:
std::string GemmaReply(const std::string& prompt) {
HWY_ASSERT(s_env);  // must have called InitEnv()
s_env->SetMaxGeneratedTokens(2048);
s_env->MutableConfig().temperature = 0.0f;  // deterministic
s_env->MutableConfig().verbosity = 0;
const ModelConfig& config = s_env->GetGemma()->GetModelConfig();
// Using the turn structure worsens results sometimes.
// However, some models need the turn structure to work.
// It would be good to make these tests more consistent.
if (s_env->GetGemma()->Info().model == Model::GEMMA2_27B ||
s_env->GetGemma()->Info().model == Model::GRIFFIN_2B) {
if (config.model == Model::GEMMA2_27B ||
config.model == Model::GRIFFIN_2B) {
std::string mutable_prompt = prompt;
QueryResult result = s_env->QueryModel(mutable_prompt);  // Uses turns.
return result.response;
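For context, a sketch of how a test body can use the shared s_env once InitEnv has run; the prompts are placeholders, and QueryModel/BatchQueryModel and the QueryResult fields are used exactly as elsewhere in this file:

std::string prompt = "What is the capital of France?";
QueryResult result = s_env->QueryModel(prompt);
fprintf(stderr, "%s (%zu tokens)\n", result.response.c_str(),
        result.tokens_generated);

std::vector<std::string> inputs = {"one prompt", "another prompt"};
for (const QueryResult& r : s_env->BatchQueryModel(inputs)) {
  EXPECT_FALSE(r.response.empty());
}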
@ -64,15 +72,17 @@ class GemmaTest : public ::testing::Test {
|
|||
|
||||
std::vector<std::string> BatchGemmaReply(
|
||||
const std::vector<std::string>& inputs) {
|
||||
HWY_ASSERT(s_env); // must have called InitEnv()
|
||||
s_env->SetMaxGeneratedTokens(64);
|
||||
s_env->MutableConfig().temperature = 0.0f; // deterministic
|
||||
s_env->MutableConfig().verbosity = 0;
|
||||
const ModelConfig& config = s_env->GetGemma()->GetModelConfig();
|
||||
std::vector<std::string> replies;
|
||||
// Using the turn structure worsens results sometimes.
|
||||
// However, some models need the turn structure to work.
|
||||
// It would be good to make these tests more consistent.
|
||||
if (s_env->GetGemma()->Info().model == Model::GEMMA2_27B ||
|
||||
s_env->GetGemma()->Info().model == Model::GRIFFIN_2B) {
|
||||
if (config.model == Model::GEMMA2_27B ||
|
||||
config.model == Model::GRIFFIN_2B) {
|
||||
for (QueryResult result : s_env->BatchQueryModel(inputs)) {
|
||||
replies.push_back(result.response);
|
||||
}
|
||||
|
|
@ -118,8 +128,14 @@ class GemmaTest : public ::testing::Test {
|
|||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Shared state. Requires argc/argv, so construct in main via InitEnv.
|
||||
// Note that the style guide forbids non-local static variables with dtors.
|
||||
static GemmaEnv* s_env;
|
||||
};
|
||||
|
||||
GemmaEnv* GemmaTest::s_env = nullptr;
|
||||
|
||||
TEST_F(GemmaTest, GeographyBatched) {
|
||||
s_env->MutableConfig().decode_qbatch_size = 3;
|
||||
// 6 are enough to test batching and the loop.
|
||||
|
|
@ -155,7 +171,8 @@ TEST_F(GemmaTest, Arithmetic) {
|
|||
}
|
||||
|
||||
TEST_F(GemmaTest, Multiturn) {
|
||||
Gemma* model = s_env->GetGemma();
|
||||
const Gemma* model = s_env->GetGemma();
|
||||
const ModelConfig& config = model->GetModelConfig();
|
||||
HWY_ASSERT(model != nullptr);
|
||||
size_t abs_pos = 0;
|
||||
std::string response;
|
||||
|
|
@ -179,8 +196,8 @@ TEST_F(GemmaTest, Multiturn) {
|
|||
// First "say" something slightly unusual.
|
||||
std::string mutable_prompt = "I have a car and its color is turquoise.";
|
||||
std::vector<int> tokens =
|
||||
WrapAndTokenize(model->Tokenizer(), model->ChatTemplate(), model->Info(),
|
||||
abs_pos, mutable_prompt);
|
||||
WrapAndTokenize(model->Tokenizer(), model->ChatTemplate(),
|
||||
config.wrapping, abs_pos, mutable_prompt);
|
||||
|
||||
model->Generate(runtime_config, tokens, abs_pos, s_env->MutableKVCache(),
|
||||
timing_info);
|
||||
|
|
@ -189,7 +206,7 @@ TEST_F(GemmaTest, Multiturn) {
|
|||
// duplicated.
|
||||
mutable_prompt = "Please repeat all prior statements.";
|
||||
tokens = WrapAndTokenize(model->Tokenizer(), model->ChatTemplate(),
|
||||
model->Info(), abs_pos, mutable_prompt);
|
||||
config.wrapping, abs_pos, mutable_prompt);
|
||||
|
||||
// Reset the `response` string here, then check that the model actually has
|
||||
// access to the previous turn by asking to reproduce.
|
||||
|
|
@ -240,11 +257,12 @@ static const char kGettysburg[] = {
|
|||
|
||||
TEST_F(GemmaTest, CrossEntropySmall) {
|
||||
HWY_ASSERT(s_env->GetGemma() != nullptr);
|
||||
const ModelConfig& config = s_env->GetGemma()->GetModelConfig();
|
||||
static const char kSmall[] =
|
||||
"The capital of Hungary is Budapest which is located in Europe.";
|
||||
float entropy = s_env->CrossEntropy(kSmall);
|
||||
fprintf(stderr, "per-token entropy: %f\n", entropy);
|
||||
switch (s_env->GetGemma()->Info().model) {
|
||||
switch (config.model) {
|
||||
case gcpp::Model::GEMMA_2B:
|
||||
// 2B v.1 and v.1.1 produce slightly different results.
|
||||
EXPECT_NEAR(entropy, 2.6f, 0.2f);
|
||||
|
|
@ -273,9 +291,10 @@ TEST_F(GemmaTest, CrossEntropySmall) {
|
|||
|
||||
TEST_F(GemmaTest, CrossEntropyJingleBells) {
|
||||
HWY_ASSERT(s_env->GetGemma() != nullptr);
|
||||
const ModelConfig& config = s_env->GetGemma()->GetModelConfig();
|
||||
float entropy = s_env->CrossEntropy(kJingleBells);
|
||||
fprintf(stderr, "per-token entropy: %f\n", entropy);
|
||||
switch (s_env->GetGemma()->Info().model) {
|
||||
switch (config.model) {
|
||||
case gcpp::Model::GEMMA_2B:
|
||||
// 2B v.1 and v.1.1 produce slightly different results.
|
||||
EXPECT_NEAR(entropy, 1.9f, 0.2f);
|
||||
|
|
@ -304,9 +323,10 @@ TEST_F(GemmaTest, CrossEntropyJingleBells) {
|
|||
|
||||
TEST_F(GemmaTest, CrossEntropyGettysburg) {
|
||||
HWY_ASSERT(s_env->GetGemma() != nullptr);
|
||||
const ModelConfig& config = s_env->GetGemma()->GetModelConfig();
|
||||
float entropy = s_env->CrossEntropy(kGettysburg);
|
||||
fprintf(stderr, "per-token entropy: %f\n", entropy);
|
||||
switch (s_env->GetGemma()->Info().model) {
|
||||
switch (config.model) {
|
||||
case gcpp::Model::GEMMA_2B:
|
||||
// 2B v.1 and v.1.1 produce slightly different results.
|
||||
EXPECT_NEAR(entropy, 1.1f, 0.1f);
|
||||
|
|
@ -337,10 +357,9 @@ TEST_F(GemmaTest, CrossEntropyGettysburg) {
} // namespace gcpp

int main(int argc, char** argv) {
gcpp::GemmaEnv env(argc, argv);
gcpp::s_env = &env;

testing::InitGoogleTest(&argc, argv);

return RUN_ALL_TESTS();
gcpp::GemmaTest::InitEnv(argc, argv);
int ret = RUN_ALL_TESTS();
gcpp::GemmaTest::DeleteEnv();
return ret;
}
@ -24,7 +24,6 @@
#include "gemma/gemma.h" // Gemma
#include "util/args.h"
#include "hwy/base.h"
#include "hwy/contrib/thread_pool/thread_pool.h"
#include "hwy/highway.h"
#include "hwy/profiler.h"
#include "nlohmann/json.hpp"
@ -31,15 +31,12 @@
#include "hwy/base.h"

int main(int argc, char** argv) {
gcpp::ThreadingArgs threading(argc, argv);
gcpp::LoaderArgs loader(argc, argv);
gcpp::ThreadingArgs threading(argc, argv);
gcpp::InferenceArgs inference(argc, argv);
if (gcpp::HasHelp(argc, argv)) {
loader.Help();
return 0;
} else if (const char* error = loader.Validate()) {
loader.Help();
HWY_ABORT("\nInvalid args: %s", error);
}

// Demonstrate constrained decoding by never outputting certain tokens.
@ -55,32 +52,31 @@ int main(int argc, char** argv) {

// Instantiate model and KV Cache
gcpp::MatMulEnv env(MakeMatMulEnv(threading));
gcpp::Gemma model = gcpp::CreateGemma(loader, env);
gcpp::KVCache kv_cache =
gcpp::KVCache::Create(model.GetModelConfig(),
inference.prefill_tbatch_size);
gcpp::Gemma gemma(loader, env);
gcpp::KVCache kv_cache = gcpp::KVCache::Create(gemma.GetModelConfig(),
inference.prefill_tbatch_size);
size_t generated = 0;

// Initialize random number generator
std::mt19937 gen;
std::random_device rd;
std::random_device rd;  // NOLINT
gen.seed(rd());

// Tokenize instructions.
std::string prompt = "Write a greeting to the world.";
const std::vector<int> tokens =
gcpp::WrapAndTokenize(model.Tokenizer(), model.ChatTemplate(),
loader.Info(), generated, prompt);
gcpp::WrapAndTokenize(gemma.Tokenizer(), gemma.ChatTemplate(),
gemma.GetModelConfig().wrapping, generated, prompt);
const size_t prompt_size = tokens.size();

// This callback function gets invoked every time a token is generated
auto stream_token = [&generated, &prompt_size, &model](int token, float) {
auto stream_token = [&generated, &prompt_size, &gemma](int token, float) {
++generated;
if (generated < prompt_size) {
// print feedback
} else if (!model.GetModelConfig().IsEOS(token)) {
} else if (!gemma.GetModelConfig().IsEOS(token)) {
std::string token_text;
HWY_ASSERT(model.Tokenizer().Decode({token}, &token_text));
HWY_ASSERT(gemma.Tokenizer().Decode({token}, &token_text));
std::cout << token_text << std::flush;
}
return true;
@ -98,5 +94,5 @@ int main(int argc, char** argv) {
return !reject_tokens.contains(token);
},
};
model.Generate(runtime_config, tokens, 0, kv_cache, timing_info);
gemma.Generate(runtime_config, tokens, 0, kv_cache, timing_info);
}
@ -39,9 +39,9 @@ class SimplifiedGemma {
|
|||
threading_(threading),
|
||||
inference_(inference),
|
||||
env_(MakeMatMulEnv(threading_)),
|
||||
model_(gcpp::CreateGemma(loader_, env_)) {
|
||||
gemma_(loader_, env_) {
|
||||
// Instantiate model and KV Cache
|
||||
kv_cache_ = gcpp::KVCache::Create(model_.GetModelConfig(),
|
||||
kv_cache_ = gcpp::KVCache::Create(gemma_.GetModelConfig(),
|
||||
inference_.prefill_tbatch_size);
|
||||
|
||||
// Initialize random number generator
|
||||
|
|
@ -50,7 +50,7 @@ class SimplifiedGemma {
|
|||
}
|
||||
|
||||
SimplifiedGemma(int argc, char** argv)
|
||||
: SimplifiedGemma(gcpp::LoaderArgs(argc, argv, /*validate=*/true),
|
||||
: SimplifiedGemma(gcpp::LoaderArgs(argc, argv),
|
||||
gcpp::ThreadingArgs(argc, argv),
|
||||
gcpp::InferenceArgs(argc, argv)) {}
|
||||
|
||||
|
|
@ -60,8 +60,8 @@ class SimplifiedGemma {
|
|||
size_t generated = 0;
|
||||
|
||||
const std::vector<int> tokens = gcpp::WrapAndTokenize(
|
||||
model_.Tokenizer(), model_.ChatTemplate(), loader_.Info(),
|
||||
generated, prompt);
|
||||
gemma_.Tokenizer(), gemma_.ChatTemplate(),
|
||||
gemma_.GetModelConfig().wrapping, generated, prompt);
|
||||
const size_t prompt_size = tokens.size();
|
||||
|
||||
// This callback function gets invoked every time a token is generated
|
||||
|
|
@ -69,9 +69,9 @@ class SimplifiedGemma {
|
|||
++generated;
|
||||
if (generated < prompt_size) {
|
||||
// print feedback
|
||||
} else if (!this->model_.GetModelConfig().IsEOS(token)) {
|
||||
} else if (!gemma_.GetModelConfig().IsEOS(token)) {
|
||||
std::string token_text;
|
||||
HWY_ASSERT(this->model_.Tokenizer().Decode({token}, &token_text));
|
||||
HWY_ASSERT(gemma_.Tokenizer().Decode({token}, &token_text));
|
||||
std::cout << token_text << std::flush;
|
||||
}
|
||||
return true;
|
||||
|
|
@ -89,7 +89,7 @@ class SimplifiedGemma {
|
|||
return !reject_tokens.contains(token);
|
||||
},
|
||||
};
|
||||
model_.Generate(runtime_config, tokens, 0, kv_cache_, timing_info);
|
||||
gemma_.Generate(runtime_config, tokens, 0, kv_cache_, timing_info);
|
||||
}
|
||||
~SimplifiedGemma() = default;
|
||||
|
||||
|
|
@ -98,7 +98,7 @@ class SimplifiedGemma {
|
|||
gcpp::ThreadingArgs threading_;
|
||||
gcpp::InferenceArgs inference_;
|
||||
gcpp::MatMulEnv env_;
|
||||
gcpp::Gemma model_;
|
||||
gcpp::Gemma gemma_;
|
||||
gcpp::KVCache kv_cache_;
|
||||
std::mt19937 gen_;
|
||||
std::string validation_error_;
|
||||
|
|
|
|||
|
|
@ -23,7 +23,7 @@
|
|||
int main(int argc, char** argv) {
|
||||
// Standard usage: LoaderArgs takes argc and argv as input, then parses
|
||||
// necessary flags.
|
||||
gcpp::LoaderArgs loader(argc, argv, /*validate=*/true);
|
||||
gcpp::LoaderArgs loader(argc, argv);
|
||||
|
||||
// Optional: LoaderArgs can also take tokenizer and weights paths directly.
|
||||
//
|
||||
|
|
|
|||
|
|
@ -15,10 +15,12 @@
|
|||
|
||||
#include "gemma/bindings/context.h"
|
||||
|
||||
#include <cstddef>
|
||||
#include <cstring>
|
||||
#include <stddef.h>
|
||||
#include <string.h> // strncpy
|
||||
|
||||
#include <memory>
|
||||
#include <sstream>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
#include "evals/benchmark_helper.h" // InitGenerator
|
||||
|
|
@ -51,33 +53,22 @@ GemmaLogCallback GemmaContext::s_log_callback = nullptr;
|
|||
void* GemmaContext::s_log_user_data = nullptr;
|
||||
|
||||
GemmaContext* GemmaContext::Create(const char* tokenizer_path,
|
||||
const char* model_type,
|
||||
const char* ignored1,
|
||||
const char* weights_path,
|
||||
const char* weight_type, int max_length) {
|
||||
const char* ignored2, int max_length) {
|
||||
std::stringstream ss;
|
||||
ss << "Creating GemmaContext with tokenizer_path: "
|
||||
<< (tokenizer_path ? tokenizer_path : "null")
|
||||
<< ", model_type: " << (model_type ? model_type : "null")
|
||||
<< ", weights_path: " << (weights_path ? weights_path : "null")
|
||||
<< ", weight_type: " << (weight_type ? weight_type : "null")
|
||||
<< ", max_length: " << max_length;
|
||||
LogDebug(ss.str().c_str());
|
||||
|
||||
ThreadingArgs threading_args;
|
||||
threading_args.spin = gcpp::Tristate::kFalse;
|
||||
|
||||
LoaderArgs loader(tokenizer_path, weights_path, model_type);
|
||||
loader.weight_type_str = weight_type;
|
||||
LoaderArgs loader(tokenizer_path, weights_path);
|
||||
LogDebug("LoaderArgs created");
|
||||
|
||||
if (const char* error = loader.Validate()) {
|
||||
ss.str("");
|
||||
ss << "Invalid loader configuration: " << error;
|
||||
LogDebug(ss.str().c_str());
|
||||
HWY_ABORT("Invalid loader configuration: %s", error);
|
||||
}
|
||||
LogDebug("Loader validated successfully");
|
||||
|
||||
// Initialize cached args
|
||||
LogDebug("Initializing inference args");
|
||||
InferenceArgs inference_args;
|
||||
|
|
@ -103,7 +94,7 @@ GemmaContext::GemmaContext(const LoaderArgs& loader,
|
|||
: inference_args(inference_args),
|
||||
threading_args(threading_args),
|
||||
matmul_env(MakeMatMulEnv(threading_args)),
|
||||
model(CreateGemma(loader, matmul_env)) {
|
||||
model(loader, matmul_env) {
|
||||
std::stringstream ss;
|
||||
|
||||
LogDebug("Creating initial ConversationData");
|
||||
|
|
@ -186,8 +177,8 @@ int GemmaContext::GenerateInternal(const char* prompt_string,
|
|||
Extents2D(model.GetModelConfig().vit_config.seq_len /
|
||||
(pool_dim * pool_dim),
|
||||
model.GetModelConfig().model_dim));
|
||||
HWY_ASSERT(model.Info().wrapping == PromptWrapping::PALIGEMMA ||
|
||||
model.Info().wrapping == PromptWrapping::GEMMA_VLM);
|
||||
HWY_ASSERT(model.GetModelConfig().wrapping == PromptWrapping::PALIGEMMA ||
|
||||
model.GetModelConfig().wrapping == PromptWrapping::GEMMA_VLM);
|
||||
|
||||
Image image;
|
||||
image.Set(image_width, image_height, static_cast<const float*>(image_data));
|
||||
|
|
@ -210,8 +201,9 @@ int GemmaContext::GenerateInternal(const char* prompt_string,
|
|||
LogDebug(ss.str().c_str());
|
||||
|
||||
prompt = WrapAndTokenize(model.Tokenizer(), model.ChatTemplate(),
|
||||
model.Info(), active_conversation->abs_pos,
|
||||
prompt_string, image_tokens.BatchSize());
|
||||
model.GetModelConfig().wrapping,
|
||||
active_conversation->abs_pos, prompt_string,
|
||||
image_tokens.BatchSize());
|
||||
runtime_config.image_tokens = &image_tokens;
|
||||
prompt_size = prompt.size();
|
||||
// The end of the prefix for prefix-LM style attention in Paligemma.
|
||||
|
|
@ -220,9 +212,9 @@ int GemmaContext::GenerateInternal(const char* prompt_string,
|
|||
} else {
|
||||
// Text-only case (original logic)
|
||||
// Use abs_pos from the active conversation
|
||||
prompt =
|
||||
WrapAndTokenize(model.Tokenizer(), model.ChatTemplate(), model.Info(),
|
||||
active_conversation->abs_pos, prompt_string);
|
||||
prompt = WrapAndTokenize(model.Tokenizer(), model.ChatTemplate(),
|
||||
model.GetModelConfig().wrapping,
|
||||
active_conversation->abs_pos, prompt_string);
|
||||
prompt_size = prompt.size();
|
||||
}
|
||||
|
||||
|
|
@ -238,7 +230,7 @@ int GemmaContext::GenerateInternal(const char* prompt_string,
|
|||
|
||||
// prepare for next turn
|
||||
if (!inference_args.multiturn ||
|
||||
model.Info().wrapping == PromptWrapping::PALIGEMMA) {
|
||||
model.GetModelConfig().wrapping == PromptWrapping::PALIGEMMA) {
|
||||
// If not multiturn, or Paligemma (which handles turns differently),
|
||||
// reset the *active* conversation's position.
|
||||
active_conversation->abs_pos = 0;
|
||||
|
|
|
|||
|
|
@ -60,9 +60,9 @@ class GemmaContext {
|
|||
const ThreadingArgs& threading_args, int max_length);
|
||||
|
||||
public:
|
||||
static GemmaContext* Create(const char* tokenizer_path,
|
||||
const char* model_type, const char* weights_path,
|
||||
const char* weight_type, int max_length);
|
||||
static GemmaContext* Create(const char* tokenizer_path, const char* ignored1,
|
||||
const char* weights_path, const char* ignored2,
|
||||
int max_length);
|
||||
|
||||
// Returns length of generated text, or -1 on error
|
||||
int Generate(const char* prompt_string, char* output, int max_length,
|
||||
|
|
|
|||
142
gemma/common.cc
@ -17,142 +17,20 @@
|
|||
|
||||
#include <math.h> // sqrtf
|
||||
#include <stddef.h>
|
||||
#include <string.h>
|
||||
|
||||
#include <algorithm> // std::transform
|
||||
#include <cctype>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
#include "gemma/configs.h"
|
||||
#include "util/basics.h" // BF16
|
||||
// TODO: change include when PromptWrapping is moved.
|
||||
#include "compression/shared.h" // PromptWrapping
|
||||
#include "hwy/base.h"
|
||||
#include "hwy/base.h" // ConvertScalarTo
|
||||
|
||||
namespace gcpp {
|
||||
|
||||
constexpr const char* kModelFlags[] = {
|
||||
"2b-pt", "2b-it", // Gemma 2B
|
||||
"7b-pt", "7b-it", // Gemma 7B
|
||||
"gr2b-pt", "gr2b-it", // RecurrentGemma
|
||||
"tiny", // Gemma Tiny (mostly for debugging)
|
||||
"gemma2-2b-pt", "gemma2-2b-it", // Gemma2 2B
|
||||
"9b-pt", "9b-it", // Gemma2 9B
|
||||
"27b-pt", "27b-it", // Gemma2 27B
|
||||
"paligemma-224", // PaliGemma 224
|
||||
"paligemma-448", // PaliGemma 448
|
||||
"paligemma2-3b-224", // PaliGemma2 3B 224
|
||||
"paligemma2-3b-448", // PaliGemma2 3B 448
|
||||
"paligemma2-10b-224", // PaliGemma2 10B 224
|
||||
"paligemma2-10b-448", // PaliGemma2 10B 448
|
||||
"gemma3-4b", // Gemma3 4B
|
||||
"gemma3-1b", // Gemma3 1B
|
||||
"gemma3-12b", // Gemma3 12B
|
||||
"gemma3-27b", // Gemma3 27B
|
||||
};
|
||||
constexpr Model kModelTypes[] = {
|
||||
Model::GEMMA_2B, Model::GEMMA_2B, // Gemma 2B
|
||||
Model::GEMMA_7B, Model::GEMMA_7B, // Gemma 7B
|
||||
Model::GRIFFIN_2B, Model::GRIFFIN_2B, // RecurrentGemma
|
||||
Model::GEMMA_TINY, // Gemma Tiny
|
||||
Model::GEMMA2_2B, Model::GEMMA2_2B, // Gemma2 2B
|
||||
Model::GEMMA2_9B, Model::GEMMA2_9B, // Gemma2 9B
|
||||
Model::GEMMA2_27B, Model::GEMMA2_27B, // Gemma2 27B
|
||||
Model::PALIGEMMA_224, // PaliGemma 224
|
||||
Model::PALIGEMMA_448, // PaliGemma 448
|
||||
Model::PALIGEMMA2_3B_224, // PaliGemma2 3B 224
|
||||
Model::PALIGEMMA2_3B_448, // PaliGemma2 3B 448
|
||||
Model::PALIGEMMA2_10B_224, // PaliGemma2 10B 224
|
||||
Model::PALIGEMMA2_10B_448, // PaliGemma2 10B 448
|
||||
Model::GEMMA3_4B, // Gemma3 4B
|
||||
Model::GEMMA3_1B, // Gemma3 1B
|
||||
Model::GEMMA3_12B, // Gemma3 12B
|
||||
Model::GEMMA3_27B, // Gemma3 27B
|
||||
};
|
||||
constexpr PromptWrapping kPromptWrapping[] = {
|
||||
PromptWrapping::GEMMA_PT, PromptWrapping::GEMMA_IT, // Gemma 2B
|
||||
PromptWrapping::GEMMA_PT, PromptWrapping::GEMMA_IT, // Gemma 7B
|
||||
PromptWrapping::GEMMA_PT, PromptWrapping::GEMMA_IT, // RecurrentGemma
|
||||
PromptWrapping::GEMMA_IT, // Gemma Tiny
|
||||
PromptWrapping::GEMMA_PT, PromptWrapping::GEMMA_IT, // Gemma2 2B
|
||||
PromptWrapping::GEMMA_PT, PromptWrapping::GEMMA_IT, // Gemma2 9B
|
||||
PromptWrapping::GEMMA_PT, PromptWrapping::GEMMA_IT, // Gemma2 27B
|
||||
PromptWrapping::PALIGEMMA, PromptWrapping::PALIGEMMA, // PaliGemma 224/448
|
||||
PromptWrapping::PALIGEMMA, PromptWrapping::PALIGEMMA, // PG2 3B 224/448
|
||||
PromptWrapping::PALIGEMMA, PromptWrapping::PALIGEMMA, // PG2 10B 224/448
|
||||
PromptWrapping::GEMMA_VLM, // Gemma3 4B
|
||||
PromptWrapping::GEMMA_IT, // Gemma3 1B
|
||||
PromptWrapping::GEMMA_VLM, // Gemma3 12B
|
||||
PromptWrapping::GEMMA_VLM, // Gemma3 27B
|
||||
};
|
||||
|
||||
constexpr size_t kNumModelFlags = std::size(kModelFlags);
|
||||
static_assert(kNumModelFlags == std::size(kModelTypes));
|
||||
static_assert(kNumModelFlags == std::size(kPromptWrapping));
|
||||
|
||||
const char* ParseModelTypeAndWrapping(const std::string& model_flag,
|
||||
Model& model, PromptWrapping& wrapping) {
|
||||
static std::string kErrorMessageBuffer =
|
||||
"Invalid or missing model flag, need to specify one of ";
|
||||
for (size_t i = 0; i + 1 < kNumModelFlags; ++i) {
|
||||
kErrorMessageBuffer.append(kModelFlags[i]);
|
||||
kErrorMessageBuffer.append(", ");
|
||||
}
|
||||
kErrorMessageBuffer.append(kModelFlags[kNumModelFlags - 1]);
|
||||
kErrorMessageBuffer.append(".");
|
||||
std::string model_type_lc = model_flag;
|
||||
std::transform(model_type_lc.begin(), model_type_lc.end(),
|
||||
model_type_lc.begin(), ::tolower);
|
||||
for (size_t i = 0; i < kNumModelFlags; ++i) {
|
||||
if (kModelFlags[i] == model_type_lc) {
|
||||
model = kModelTypes[i];
|
||||
wrapping = kPromptWrapping[i];
|
||||
HWY_ASSERT(std::string(ModelString(model, wrapping)) == model_type_lc);
|
||||
return nullptr;
|
||||
}
|
||||
}
|
||||
return kErrorMessageBuffer.c_str();
|
||||
}
|
||||
|
||||
const char* ModelString(Model model, PromptWrapping wrapping) {
|
||||
for (size_t i = 0; i < kNumModelFlags; i++) {
|
||||
if (kModelTypes[i] == model && kPromptWrapping[i] == wrapping)
|
||||
return kModelFlags[i];
|
||||
}
|
||||
HWY_ABORT("Unknown model %d wrapping %d\n", static_cast<int>(model),
|
||||
static_cast<int>(wrapping));
|
||||
}
|
||||
|
||||
const char* StringFromType(Type type) {
|
||||
return kTypeStrings[static_cast<size_t>(type)];
|
||||
}
|
||||
|
||||
const char* ParseType(const std::string& type_string, Type& type) {
|
||||
constexpr size_t kNum = std::size(kTypeStrings);
|
||||
static std::string kErrorMessageBuffer =
|
||||
"Invalid or missing type, need to specify one of ";
|
||||
for (size_t i = 0; i + 1 < kNum; ++i) {
|
||||
kErrorMessageBuffer.append(kTypeStrings[i]);
|
||||
kErrorMessageBuffer.append(", ");
|
||||
}
|
||||
kErrorMessageBuffer.append(kTypeStrings[kNum - 1]);
|
||||
kErrorMessageBuffer.append(".");
|
||||
std::string type_lc = type_string;
|
||||
std::transform(type_lc.begin(), type_lc.end(), type_lc.begin(), ::tolower);
|
||||
for (size_t i = 0; i < kNum; ++i) {
|
||||
if (kTypeStrings[i] == type_lc) {
|
||||
type = static_cast<Type>(i);
|
||||
HWY_ASSERT(std::string(StringFromType(type)) == type_lc);
|
||||
return nullptr;
|
||||
}
|
||||
}
|
||||
return kErrorMessageBuffer.c_str();
|
||||
}
|
||||
|
||||
void Wrap(const ModelInfo& info, size_t pos, std::string& prompt) {
|
||||
void Wrap(const ModelConfig& config, size_t pos, std::string& prompt) {
|
||||
|
||||
// Instruction-tuned models are trained to expect control tokens.
|
||||
if (info.wrapping == PromptWrapping::GEMMA_IT) {
|
||||
if (config.wrapping == PromptWrapping::GEMMA_IT) {
|
||||
// Prepend "<end_of_turn>" if this is a multi-turn dialogue continuation.
|
||||
const std::string start = (pos == 0)
|
||||
? "<start_of_turn>user\n"
|
||||
|
|
@ -175,4 +53,16 @@ float ChooseQueryScale(const ModelConfig& config) {
return 1.0f / sqrtf(static_cast<float>(config.layer_configs[0].qkv_dim));
}

void RangeChecks(const ModelConfig& weights_config,
size_t& max_generated_tokens, const size_t prompt_size) {
if (!weights_config.use_local_attention) {
if (max_generated_tokens > weights_config.seq_len) {
HWY_WARN("max_generated_tokens %zu > kSeqLen %u, truncating.",
max_generated_tokens, weights_config.seq_len);
max_generated_tokens = weights_config.seq_len;
}
}
HWY_ASSERT(prompt_size > 0);
}

} // namespace gcpp
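A sketch of the intended call pattern, with illustrative call-site names: clamp the requested token budget to the model's sequence length before generating.

// `prompt` is the tokenized prompt (std::vector<int>); variable names are
// placeholders, not from this commit.
size_t max_generated_tokens = runtime_config.max_generated_tokens;
RangeChecks(gemma.GetModelConfig(), max_generated_tokens, prompt.size());
runtime_config.max_generated_tokens = max_generated_tokens;
gemma.Generate(runtime_config, prompt, /*start_pos=*/0, kv_cache, timing_info);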
@ -20,39 +20,24 @@

#include <string>

#include "compression/shared.h" // Type
#include "gemma/configs.h" // IWYU pragma: export
#include "hwy/base.h" // ConvertScalarTo
#include "gemma/configs.h" // IWYU pragma: export

namespace gcpp {

// Struct to bundle model information.
struct ModelInfo {
Model model;
PromptWrapping wrapping;
Type weight;
};

// Returns error string or nullptr if OK.
// Thread-hostile.
const char* ParseModelTypeAndWrapping(const std::string& model_flag,
Model& model, PromptWrapping& wrapping);
const char* ParseType(const std::string& type_string, Type& type);

// Inverse of ParseModelTypeAndWrapping.
const char* ModelString(Model model, PromptWrapping wrapping);
const char* StringFromType(Type type);

// Wraps the given prompt using the expected control tokens for IT models.
// `GemmaChatTemplate` is preferred if a tokenized return value is fine.
void Wrap(const ModelInfo& info, size_t pos, std::string& prompt);
// DEPRECATED, use WrapAndTokenize instead if a tokenized return value is fine.
void Wrap(const ModelConfig& config, size_t pos, std::string& prompt);

// Returns the scale value to use for the embedding (basically sqrt model_dim).
// Also used by backprop/.
float EmbeddingScaling(size_t model_dim);

// Returns the scale value to use for the query in the attention computation.
float ChooseQueryScale(const ModelConfig& config);

void RangeChecks(const ModelConfig& weights_config,
size_t& max_generated_tokens, size_t prompt_size);

} // namespace gcpp

#endif // THIRD_PARTY_GEMMA_CPP_GEMMA_COMMON_H_
477
gemma/configs.cc
@ -15,17 +15,30 @@
|
|||
|
||||
#include "gemma/configs.h"
|
||||
|
||||
#include <cstddef>
|
||||
#include <iostream>
|
||||
#include <stddef.h>
|
||||
#include <stdio.h>
|
||||
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
#include "compression/fields.h" // IFields
|
||||
#include "compression/shared.h" // Type
|
||||
#include "hwy/base.h"
|
||||
|
||||
namespace gcpp {
|
||||
|
||||
// Allow changing pre-allocated kv cache size as a compiler flag
|
||||
#ifndef GEMMA_MAX_SEQLEN
|
||||
#define GEMMA_MAX_SEQLEN 4096
|
||||
#endif // !GEMMA_MAX_SEQLEN
|
||||
|
||||
static constexpr size_t kVocabSize = 256000;
|
||||
|
||||
static ModelConfig ConfigNoSSM() {
|
||||
ModelConfig config;
|
||||
config.scale_names = {"att_ein", "qkv_ein", "gr_lin_x_w", "gr_lin_y_w",
|
||||
"gr_lin_out_w", "gr_gate_w", "gating_ein", "linear_w"};
|
||||
config.scale_base_names = {"att_ein", "qkv_ein", "gr_lin_x_w",
|
||||
"gr_lin_y_w", "gr_lin_out_w", "gr_gate_w",
|
||||
"gating_ein", "linear_w"};
|
||||
return config;
|
||||
}
|
||||
|
||||
|
|
@ -54,14 +67,14 @@ static LayerConfig LayerConfigGemma2_27B(size_t model_dim) {
|
|||
|
||||
static ModelConfig ConfigGemma2_27B() {
|
||||
ModelConfig config = ConfigBaseGemmaV2();
|
||||
config.model_name = "Gemma2_27B";
|
||||
config.display_name = "Gemma2_27B";
|
||||
config.model = Model::GEMMA2_27B;
|
||||
config.model_dim = 4608;
|
||||
config.vocab_size = kVocabSize;
|
||||
config.seq_len = 8192;
|
||||
LayerConfig layer_config = LayerConfigGemma2_27B(config.model_dim);
|
||||
config.layer_configs = {46, layer_config};
|
||||
config.num_tensor_scales = 4 * config.layer_configs.size();
|
||||
config.num_layers = 46;
|
||||
config.layer_configs = {config.num_layers, layer_config};
|
||||
config.query_scale = QueryScaleType::SqrtModelDimDivNumHeads;
|
||||
config.attention_window_sizes =
|
||||
RepeatedAttentionWindowSizes<46, 2>({4096, 8192});
|
||||
|
|
@ -82,14 +95,14 @@ static LayerConfig LayerConfigGemma2_9B(size_t model_dim) {
|
|||
|
||||
static ModelConfig ConfigGemma2_9B() {
|
||||
ModelConfig config = ConfigBaseGemmaV2();
|
||||
config.model_name = "Gemma2_9B";
|
||||
config.display_name = "Gemma2_9B";
|
||||
config.model = Model::GEMMA2_9B;
|
||||
config.model_dim = 3584;
|
||||
config.vocab_size = kVocabSize;
|
||||
config.seq_len = 8192;
|
||||
LayerConfig layer_config = LayerConfigGemma2_9B(config.model_dim);
|
||||
config.layer_configs = {42, layer_config};
|
||||
config.num_tensor_scales = 4 * config.layer_configs.size();
|
||||
config.num_layers = 42;
|
||||
config.layer_configs = {config.num_layers, layer_config};
|
||||
config.query_scale = QueryScaleType::SqrtKeySize;
|
||||
config.attention_window_sizes =
|
||||
RepeatedAttentionWindowSizes<42, 2>({4096, 8192});
|
||||
|
|
@ -110,14 +123,14 @@ static LayerConfig LayerConfigGemma2_2B(size_t model_dim) {
|
|||
|
||||
static ModelConfig ConfigGemma2_2B() {
|
||||
ModelConfig config = ConfigBaseGemmaV2();
|
||||
config.model_name = "Gemma2_2B";
|
||||
config.display_name = "Gemma2_2B";
|
||||
config.model = Model::GEMMA2_2B;
|
||||
config.model_dim = 2304;
|
||||
config.vocab_size = kVocabSize;
|
||||
config.seq_len = 8192;
|
||||
LayerConfig layer_config = LayerConfigGemma2_2B(config.model_dim);
|
||||
config.layer_configs = {26, layer_config};
|
||||
config.num_tensor_scales = 4 * config.layer_configs.size();
|
||||
config.num_layers = 26;
|
||||
config.layer_configs = {config.num_layers, layer_config};
|
||||
config.query_scale = QueryScaleType::SqrtKeySize;
|
||||
config.attention_window_sizes =
|
||||
RepeatedAttentionWindowSizes<26, 2>({4096, 8192});
|
||||
|
|
@ -136,16 +149,17 @@ static LayerConfig LayerConfigGemma7B(size_t model_dim) {
|
|||
|
||||
static ModelConfig ConfigGemma7B() {
|
||||
ModelConfig config = ConfigBaseGemmaV1();
|
||||
config.model_name = "Gemma7B";
|
||||
config.display_name = "Gemma7B";
|
||||
config.model = Model::GEMMA_7B;
|
||||
config.model_dim = 3072;
|
||||
config.vocab_size = kVocabSize;
|
||||
config.seq_len = kSeqLen;
|
||||
config.seq_len = GEMMA_MAX_SEQLEN;
|
||||
LayerConfig layer_config = LayerConfigGemma7B(config.model_dim);
|
||||
config.layer_configs = {28, layer_config};
|
||||
config.num_tensor_scales = 4 * config.layer_configs.size();
|
||||
config.num_layers = 28;
|
||||
config.layer_configs = {config.num_layers, layer_config};
|
||||
config.query_scale = QueryScaleType::SqrtKeySize;
|
||||
config.attention_window_sizes = FixedAttentionWindowSizes<28>(kSeqLen);
|
||||
config.attention_window_sizes =
|
||||
FixedAttentionWindowSizes<28>(GEMMA_MAX_SEQLEN);
|
||||
return config;
|
||||
}
|
||||
|
||||
|
|
@ -161,15 +175,16 @@ static LayerConfig LayerConfigGemma2B(size_t model_dim) {
|
|||
|
||||
static ModelConfig ConfigGemma2B() {
|
||||
ModelConfig config = ConfigBaseGemmaV1();
|
||||
config.model_name = "Gemma2B";
|
||||
config.display_name = "Gemma2B";
|
||||
config.model = Model::GEMMA_2B;
|
||||
config.model_dim = 2048;
|
||||
config.vocab_size = kVocabSize;
|
||||
config.seq_len = kSeqLen;
|
||||
config.seq_len = GEMMA_MAX_SEQLEN;
|
||||
LayerConfig layer_config = LayerConfigGemma2B(config.model_dim);
|
||||
config.layer_configs = {18, layer_config};
|
||||
config.num_tensor_scales = 4 * config.layer_configs.size();
|
||||
config.attention_window_sizes = FixedAttentionWindowSizes<18>(kSeqLen);
|
||||
config.num_layers = 18;
|
||||
config.layer_configs = {config.num_layers, layer_config};
|
||||
config.attention_window_sizes =
|
||||
FixedAttentionWindowSizes<18>(GEMMA_MAX_SEQLEN);
|
||||
return config;
|
||||
}
|
||||
|
||||
|
|
@ -185,18 +200,19 @@ static LayerConfig LayerConfigGemmaTiny(size_t model_dim) {
|
|||
|
||||
static ModelConfig ConfigGemmaTiny() {
|
||||
ModelConfig config = ConfigNoSSM();
|
||||
config.model_name = "GemmaTiny";
|
||||
config.display_name = "GemmaTiny";
|
||||
config.model = Model::GEMMA_TINY;
|
||||
config.wrapping = PromptWrapping::GEMMA_IT;
|
||||
config.model_dim = 128;
|
||||
config.vocab_size = 64;
|
||||
config.seq_len = 32;
|
||||
config.model_dim = 32;
|
||||
config.vocab_size = 16;
|
||||
config.seq_len = 32; // optimize_test requires more than 24
|
||||
LayerConfig layer_config = LayerConfigGemmaTiny(config.model_dim);
|
||||
config.layer_configs = {3, layer_config};
|
||||
config.num_tensor_scales = 4 * config.layer_configs.size();
|
||||
config.num_layers = 2;
|
||||
config.layer_configs = {config.num_layers, layer_config};
|
||||
config.query_scale = QueryScaleType::SqrtKeySize;
|
||||
config.attention_window_sizes = FixedAttentionWindowSizes<3>(32);
|
||||
config.attention_window_sizes = FixedAttentionWindowSizes<2>(32);
|
||||
// This is required for optimize_test to pass.
|
||||
config.att_cap = 50.0f;
|
||||
config.final_cap = 30.0f;
|
||||
config.eos_id = 11;
|
||||
config.secondary_eos_id = 11;
|
||||
|
|
@ -224,20 +240,20 @@ static LayerConfig LayerConfigGriffin2B(size_t model_dim) {
|
|||
|
||||
static ModelConfig ConfigGriffin2B() {
|
||||
ModelConfig config = ConfigNoSSM();
|
||||
config.model_name = "Griffin2B";
|
||||
config.display_name = "Griffin2B";
|
||||
config.model = Model::GRIFFIN_2B;
|
||||
// Griffin uses local attention, so kSeqLen is actually the local attention
|
||||
// window.
|
||||
// Griffin uses local attention, so GEMMA_MAX_SEQLEN is actually the local
|
||||
// attention window.
|
||||
config.model_dim = 2560;
|
||||
config.vocab_size = kVocabSize;
|
||||
config.seq_len = 2048;
|
||||
LayerConfig layer_config = LayerConfigGriffin2B(config.model_dim);
|
||||
config.layer_configs = {26, layer_config};
|
||||
for (size_t i = 2; i < config.layer_configs.size(); i += 3) {
|
||||
config.num_layers = 26;
|
||||
config.layer_configs = {config.num_layers, layer_config};
|
||||
for (size_t i = 2; i < config.num_layers; i += 3) {
|
||||
config.layer_configs[i].type = LayerAttentionType::kGemma;
|
||||
config.layer_configs[i].griffin_dim = 0;
|
||||
}
|
||||
config.num_tensor_scales = 140;
|
||||
config.attention_window_sizes = FixedAttentionWindowSizes<26>(config.seq_len);
|
||||
config.use_local_attention = true;
|
||||
// This is required for optimize_test to pass.
|
||||
|
|
@ -276,7 +292,7 @@ static void AddVitConfig(ModelConfig& config, size_t image_size = 224) {
|
|||
|
||||
static ModelConfig ConfigPaliGemma_224() {
|
||||
ModelConfig config = ConfigGemma2B();
|
||||
config.model_name = "PaliGemma_224";
|
||||
config.display_name = "PaliGemma_224";
|
||||
config.model = Model::PALIGEMMA_224;
|
||||
config.wrapping = PromptWrapping::PALIGEMMA;
|
||||
AddVitConfig(config);
|
||||
|
|
@ -285,7 +301,7 @@ static ModelConfig ConfigPaliGemma_224() {
|
|||
|
||||
static ModelConfig ConfigPaliGemma_448() {
|
||||
ModelConfig config = ConfigGemma2B();
|
||||
config.model_name = "PaliGemma_448";
|
||||
config.display_name = "PaliGemma_448";
|
||||
config.model = Model::PALIGEMMA_448;
|
||||
config.wrapping = PromptWrapping::PALIGEMMA;
|
||||
AddVitConfig(config, /*image_size=*/448);
|
||||
|
|
@ -306,7 +322,7 @@ ModelConfig GetVitConfig(const ModelConfig& config) {
|
|||
|
||||
static ModelConfig ConfigPaliGemma2_3B_224() {
|
||||
ModelConfig config = ConfigGemma2_2B();
|
||||
config.model_name = "PaliGemma2_3B_224";
|
||||
config.display_name = "PaliGemma2_3B_224";
|
||||
config.model = Model::PALIGEMMA2_3B_224;
|
||||
config.wrapping = PromptWrapping::PALIGEMMA;
|
||||
AddVitConfig(config);
|
||||
|
|
@ -315,7 +331,7 @@ static ModelConfig ConfigPaliGemma2_3B_224() {
|
|||
|
||||
static ModelConfig ConfigPaliGemma2_3B_448() {
|
||||
ModelConfig config = ConfigGemma2_2B();
|
||||
config.model_name = "PaliGemma2_3B_448";
|
||||
config.display_name = "PaliGemma2_3B_448";
|
||||
config.model = Model::PALIGEMMA2_3B_448;
|
||||
config.wrapping = PromptWrapping::PALIGEMMA;
|
||||
AddVitConfig(config, /*image_size=*/448);
|
||||
|
|
@ -324,7 +340,7 @@ static ModelConfig ConfigPaliGemma2_3B_448() {
|
|||
|
||||
static ModelConfig ConfigPaliGemma2_10B_224() {
|
||||
ModelConfig config = ConfigGemma2_9B();
|
||||
config.model_name = "PaliGemma2_10B_224";
|
||||
config.display_name = "PaliGemma2_10B_224";
|
||||
config.model = Model::PALIGEMMA2_10B_224;
|
||||
config.wrapping = PromptWrapping::PALIGEMMA;
|
||||
AddVitConfig(config);
|
||||
|
|
@ -333,7 +349,7 @@ static ModelConfig ConfigPaliGemma2_10B_224() {
|
|||
|
||||
static ModelConfig ConfigPaliGemma2_10B_448() {
|
||||
ModelConfig config = ConfigGemma2_9B();
|
||||
config.model_name = "PaliGemma2_10B_448";
|
||||
config.display_name = "PaliGemma2_10B_448";
|
||||
config.model = Model::PALIGEMMA2_10B_448;
|
||||
config.wrapping = PromptWrapping::PALIGEMMA;
|
||||
AddVitConfig(config, /*image_size=*/448);
|
||||
|
|
@ -365,15 +381,15 @@ static LayerConfig LayerConfigGemma3_1B_LM(size_t model_dim) {
|
|||
|
||||
static ModelConfig ConfigGemma3_1B() {
|
||||
ModelConfig config = ConfigBaseGemmaV3();
|
||||
config.model_name = "Gemma3_1B";
|
||||
config.display_name = "Gemma3_1B";
|
||||
config.model = Model::GEMMA3_1B;
|
||||
config.wrapping = PromptWrapping::GEMMA_VLM;
|
||||
config.model_dim = 1152;
|
||||
config.vocab_size = 262144; // new vocab size / tokenizer
|
||||
config.seq_len = 32 * 1024;
|
||||
LayerConfig layer_config = LayerConfigGemma3_1B_LM(config.model_dim);
|
||||
config.layer_configs = {26, layer_config};
|
||||
config.num_tensor_scales = 4 * config.layer_configs.size();
|
||||
config.num_layers = 26;
|
||||
config.layer_configs = {config.num_layers, layer_config};
|
||||
config.query_scale = QueryScaleType::SqrtKeySize;
|
||||
// interleaved local / global attention
|
||||
config.attention_window_sizes = RepeatedAttentionWindowSizes<26, 6>(
|
||||
|
|
@ -397,15 +413,15 @@ static LayerConfig LayerConfigGemma3_4B_LM(size_t model_dim) {
|
|||
// Until we have the SigLIP checkpoints included, we use the LM config directly.
|
||||
static ModelConfig ConfigGemma3_4B_LM() {
|
||||
ModelConfig config = ConfigBaseGemmaV3();
|
||||
config.model_name = "Gemma3_4B";
|
||||
config.display_name = "Gemma3_4B";
|
||||
config.model = Model::GEMMA3_4B;
|
||||
config.wrapping = PromptWrapping::GEMMA_VLM;
|
||||
config.model_dim = 2560;
|
||||
config.vocab_size = 262144; // new vocab size / tokenizer
|
||||
config.seq_len = 32 * 1024;
|
||||
LayerConfig layer_config = LayerConfigGemma3_4B_LM(config.model_dim);
|
||||
config.layer_configs = {34, layer_config};
|
||||
config.num_tensor_scales = 4 * config.layer_configs.size();
|
||||
config.num_layers = 34;
|
||||
config.layer_configs = {config.num_layers, layer_config};
|
||||
config.query_scale = QueryScaleType::SqrtKeySize;
|
||||
// interleaved local / global attention
|
||||
config.attention_window_sizes = RepeatedAttentionWindowSizes<34, 6>(
|
||||
|
|
@ -415,7 +431,7 @@ static ModelConfig ConfigGemma3_4B_LM() {
|
|||
|
||||
static ModelConfig ConfigGemma3_4B() {
|
||||
ModelConfig config = ConfigGemma3_4B_LM();
|
||||
config.model_name = "Gemma3_4B";
|
||||
config.display_name = "Gemma3_4B";
|
||||
config.model = Model::GEMMA3_4B;
|
||||
config.wrapping = PromptWrapping::GEMMA_VLM;
|
||||
AddVitConfig(config, /*image_size=*/896);
|
||||
|
|
@ -446,15 +462,15 @@ static LayerConfig LayerConfigGemma3_12B_LM(size_t model_dim) {
|
|||
|
||||
static ModelConfig ConfigGemma3_12B_LM() {
|
||||
ModelConfig config = ConfigBaseGemmaV3();
|
||||
config.model_name = "Gemma3_12B";
|
||||
config.display_name = "Gemma3_12B";
|
||||
config.model = Model::GEMMA3_12B;
|
||||
config.wrapping = PromptWrapping::GEMMA_VLM;
|
||||
config.model_dim = 3840;
|
||||
config.vocab_size = 262144; // new vocab size / tokenizer
|
||||
config.seq_len = 32 * 1024;
|
||||
LayerConfig layer_config = LayerConfigGemma3_12B_LM(config.model_dim);
|
||||
config.layer_configs = {48, layer_config};
|
||||
config.num_tensor_scales = 4 * config.layer_configs.size();
|
||||
config.num_layers = 48;
|
||||
config.layer_configs = {config.num_layers, layer_config};
|
||||
config.query_scale = QueryScaleType::SqrtKeySize;
|
||||
// interleaved local / global attention
|
||||
config.attention_window_sizes = RepeatedAttentionWindowSizes<48, 6>(
|
||||
|
|
@ -464,7 +480,7 @@ static ModelConfig ConfigGemma3_12B_LM() {
|
|||
|
||||
static ModelConfig ConfigGemma3_12B() {
|
||||
ModelConfig config = ConfigGemma3_12B_LM();
|
||||
config.model_name = "Gemma3_12B";
|
||||
config.display_name = "Gemma3_12B";
|
||||
config.model = Model::GEMMA3_12B;
|
||||
config.wrapping = PromptWrapping::GEMMA_VLM;
|
||||
AddVitConfig(config, /*image_size=*/896);
|
||||
|
|
@ -495,15 +511,15 @@ static LayerConfig LayerConfigGemma3_27B_LM(size_t model_dim) {
|
|||
|
||||
static ModelConfig ConfigGemma3_27B_LM() {
|
||||
ModelConfig config = ConfigBaseGemmaV3();
|
||||
config.model_name = "Gemma3_27B";
|
||||
config.display_name = "Gemma3_27B";
|
||||
config.model = Model::GEMMA3_27B;
|
||||
config.wrapping = PromptWrapping::GEMMA_VLM;
|
||||
config.model_dim = 5376;
|
||||
config.vocab_size = 262144; // new vocab size / tokenizer
|
||||
config.seq_len = 32 * 1024;
|
||||
LayerConfig layer_config = LayerConfigGemma3_27B_LM(config.model_dim);
|
||||
config.layer_configs = {62, layer_config};
|
||||
config.num_tensor_scales = 4 * config.layer_configs.size();
|
||||
config.num_layers = 62;
|
||||
config.layer_configs = {config.num_layers, layer_config};
|
||||
config.query_scale = QueryScaleType::SqrtKeySize;
|
||||
// interleaved local / global attention
|
||||
config.attention_window_sizes = RepeatedAttentionWindowSizes<62, 6>(
|
||||
|
|
@ -513,7 +529,7 @@ static ModelConfig ConfigGemma3_27B_LM() {
|
|||
|
||||
static ModelConfig ConfigGemma3_27B() {
|
||||
ModelConfig config = ConfigGemma3_27B_LM();
|
||||
config.model_name = "Gemma3_27B";
|
||||
config.display_name = "Gemma3_27B";
|
||||
config.model = Model::GEMMA3_27B;
|
||||
config.wrapping = PromptWrapping::GEMMA_VLM;
|
||||
AddVitConfig(config, /*image_size=*/896);
|
||||
|
|
@ -529,7 +545,7 @@ static ModelConfig ConfigGemma3_27B() {
|
|||
return config;
|
||||
}
|
||||
|
||||
ModelConfig ConfigFromModel(Model model) {
|
||||
static ModelConfig ConfigFromModel(Model model) {
|
||||
switch (model) {
|
||||
case Model::GEMMA_2B:
|
||||
return ConfigGemma2B();
|
||||
|
|
@ -570,124 +586,259 @@ ModelConfig ConfigFromModel(Model model) {
|
|||
}
|
||||
}
|
||||
|
||||
#define TEST_EQUAL(a, b) \
|
||||
if (a != b) { \
|
||||
if (debug) \
|
||||
std::cerr << #a << "=" << a << " != " << #b << "=" << b << "\n"; \
|
||||
result = false; \
|
||||
const char* ModelPrefix(Model model) {
  switch (model) {
    case Model::UNKNOWN:
      return "unknown";
    case Model::GEMMA_2B:
      return "2b";
    case Model::GEMMA_7B:
      return "7b";
    case Model::GEMMA2_2B:
      return "gemma2-2b";
    case Model::GEMMA2_9B:
      return "9b";
    case Model::GEMMA2_27B:
      return "27b";
    case Model::GRIFFIN_2B:
      return "gr2b";
    case Model::GEMMA_TINY:
      return "tiny";
    case Model::PALIGEMMA_224:
      return "paligemma-224";
    case Model::PALIGEMMA_448:
      return "paligemma-448";
    case Model::PALIGEMMA2_3B_224:
      return "paligemma2-3b-224";
    case Model::PALIGEMMA2_3B_448:
      return "paligemma2-3b-448";
    case Model::PALIGEMMA2_10B_224:
      return "paligemma2-10b-224";
    case Model::PALIGEMMA2_10B_448:
      return "paligemma2-10b-448";
    case Model::GEMMA3_4B:
      return "gemma3-4b";
    case Model::GEMMA3_1B:
      return "gemma3-1b";
    case Model::GEMMA3_12B:
      return "gemma3-12b";
    case Model::GEMMA3_27B:
      return "gemma3-27b";
    default:
      HWY_ABORT("Model type %d unknown.", static_cast<int>(model));
  }
|
||||
|
||||
#define RETURN_IF_NOT_EQUAL(a, b) \
|
||||
if (a != b) { \
|
||||
if (debug) \
|
||||
std::cerr << #a << "=" << a << " != " << #b << "=" << b << "\n"; \
|
||||
return false; \
|
||||
}
|
||||
|
||||
#define WARN_IF_NOT_EQUAL(a, b) \
|
||||
if (a != b) { \
|
||||
std::cerr << #a << "=" << a << " != " << #b << "=" << b << "\n"; \
|
||||
}
|
||||
|
||||
bool LayerConfig::TestEqual(const LayerConfig& other, bool partial,
|
||||
bool debug) const {
|
||||
bool result = true;
|
||||
// Optimized gating may not be set correctly in the c++ configs.
|
||||
if (debug) {
|
||||
WARN_IF_NOT_EQUAL(optimized_gating, other.optimized_gating)
|
||||
}
|
||||
TEST_EQUAL(model_dim, other.model_dim);
|
||||
TEST_EQUAL(griffin_dim, other.griffin_dim);
|
||||
TEST_EQUAL(ff_hidden_dim, other.ff_hidden_dim);
|
||||
TEST_EQUAL(heads, other.heads);
|
||||
TEST_EQUAL(kv_heads, other.kv_heads);
|
||||
TEST_EQUAL(qkv_dim, other.qkv_dim);
|
||||
TEST_EQUAL(conv1d_width, other.conv1d_width);
|
||||
if (!partial) {
|
||||
TEST_EQUAL(ff_biases, other.ff_biases);
|
||||
TEST_EQUAL(softmax_attn_output_biases, other.softmax_attn_output_biases);
|
||||
}
|
||||
TEST_EQUAL(static_cast<int>(post_norm), static_cast<int>(other.post_norm));
|
||||
TEST_EQUAL(static_cast<int>(type), static_cast<int>(other.type));
|
||||
TEST_EQUAL(static_cast<int>(activation), static_cast<int>(other.activation));
|
||||
TEST_EQUAL(static_cast<int>(post_qk), static_cast<int>(other.post_qk));
|
||||
return result;
|
||||
}
|
||||
|
||||
bool VitConfig::TestEqual(const VitConfig& other, bool partial,
|
||||
bool debug) const {
|
||||
bool result = true;
|
||||
TEST_EQUAL(model_dim, other.model_dim);
|
||||
TEST_EQUAL(seq_len, other.seq_len);
|
||||
if (!partial) {
|
||||
TEST_EQUAL(num_scales, other.num_scales);
|
||||
PromptWrapping ChooseWrapping(const Model model, Tristate wrapping) {
|
||||
if (IsPaliGemma(model)) {
|
||||
if (wrapping != Tristate::kDefault) {
|
||||
HWY_WARN("Ignoring unnecessary --wrapping for PaliGemma models.");
|
||||
}
|
||||
return PromptWrapping::PALIGEMMA;
|
||||
}
|
||||
TEST_EQUAL(patch_width, other.patch_width);
|
||||
TEST_EQUAL(image_size, other.image_size);
|
||||
RETURN_IF_NOT_EQUAL(layer_configs.size(), other.layer_configs.size());
|
||||
for (size_t i = 0; i < layer_configs.size(); ++i) {
|
||||
result &=
|
||||
layer_configs[i].TestEqual(other.layer_configs[i], partial, debug);
|
||||
if (IsVLM(model)) {
|
||||
if (wrapping != Tristate::kDefault) {
|
||||
HWY_WARN("Ignoring unnecessary --wrapping for VLM models.");
|
||||
}
|
||||
return PromptWrapping::GEMMA_VLM;
|
||||
}
|
||||
return result;
|
||||
// Default to IT unless --wrapping=0.
|
||||
return wrapping == Tristate::kFalse ? PromptWrapping::GEMMA_PT
|
||||
: PromptWrapping::GEMMA_IT;
|
||||
}
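A minimal sketch of the dispatch above, assuming only that gemma/configs.h is included; the resulting values follow the branches shown here rather than any canonical test.

#include "gemma/configs.h"

void ChooseWrappingExamples() {
  using namespace gcpp;
  // PaliGemma and Gemma3 VLM models decide their own wrapping; --wrapping is ignored.
  PromptWrapping a = ChooseWrapping(Model::PALIGEMMA_224);                // PALIGEMMA
  PromptWrapping b = ChooseWrapping(Model::GEMMA3_4B);                    // GEMMA_VLM
  // Other models default to IT; --wrapping=0 selects the PT variant.
  PromptWrapping c = ChooseWrapping(Model::GEMMA2_2B);                    // GEMMA_IT
  PromptWrapping d = ChooseWrapping(Model::GEMMA2_2B, Tristate::kFalse);  // GEMMA_PT
  (void)a; (void)b; (void)c; (void)d;
}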
|
||||
|
||||
bool ModelConfig::TestEqual(const ModelConfig& other, bool partial,
|
||||
bool debug) const {
|
||||
bool result = true;
|
||||
TEST_EQUAL(model_family_version, other.model_family_version);
|
||||
// We don't care about model_name, model, wrapping, or weight being different,
|
||||
// but will output in debug mode if they are.
|
||||
if (debug) {
|
||||
WARN_IF_NOT_EQUAL(model_name, other.model_name);
|
||||
WARN_IF_NOT_EQUAL(static_cast<int>(model), static_cast<int>(other.model));
|
||||
WARN_IF_NOT_EQUAL(static_cast<int>(wrapping),
|
||||
static_cast<int>(other.wrapping));
|
||||
WARN_IF_NOT_EQUAL(static_cast<int>(weight), static_cast<int>(other.weight));
|
||||
}
|
||||
TEST_EQUAL(model_dim, other.model_dim);
|
||||
TEST_EQUAL(vocab_size, other.vocab_size);
|
||||
TEST_EQUAL(seq_len, other.seq_len);
|
||||
if (!partial) {
|
||||
TEST_EQUAL(num_tensor_scales, other.num_tensor_scales);
|
||||
}
|
||||
TEST_EQUAL(att_cap, other.att_cap);
|
||||
TEST_EQUAL(final_cap, other.final_cap);
|
||||
TEST_EQUAL(absolute_pe, other.absolute_pe);
|
||||
TEST_EQUAL(use_local_attention, other.use_local_attention);
|
||||
TEST_EQUAL(static_cast<int>(query_scale),
|
||||
static_cast<int>(other.query_scale));
|
||||
RETURN_IF_NOT_EQUAL(layer_configs.size(), other.layer_configs.size());
|
||||
for (size_t i = 0; i < layer_configs.size(); ++i) {
|
||||
result &=
|
||||
layer_configs[i].TestEqual(other.layer_configs[i], partial, debug);
|
||||
}
|
||||
RETURN_IF_NOT_EQUAL(attention_window_sizes.size(),
|
||||
other.attention_window_sizes.size());
|
||||
for (size_t i = 0; i < attention_window_sizes.size(); ++i) {
|
||||
TEST_EQUAL(attention_window_sizes[i], other.attention_window_sizes[i]);
|
||||
}
|
||||
if (!partial) {
|
||||
if (scale_names != other.scale_names) {
|
||||
result = false;
|
||||
if (debug) {
|
||||
std::cerr << "scale_names mismatch\n";
|
||||
}
|
||||
ModelConfig::ModelConfig(const Model model, Type weight,
|
||||
PromptWrapping wrapping) {
|
||||
HWY_ASSERT(weight != Type::kUnknown);
|
||||
HWY_ASSERT(wrapping != PromptWrapping::kSentinel);
|
||||
this->model = model;
|
||||
if (model != Model::UNKNOWN) *this = ConfigFromModel(model);
|
||||
HWY_ASSERT(this->model == model);
|
||||
this->weight = weight;
|
||||
this->wrapping = wrapping;
|
||||
}
|
||||
|
||||
static Model FindModel(const std::string& specifier) {
|
||||
Model found_model = Model::UNKNOWN;
|
||||
ForEachModel([&](Model model) {
|
||||
const char* prefix = ModelPrefix(model);
|
||||
if (specifier.rfind(prefix, 0) == 0) { // Starts with prefix.
|
||||
// We only expect one match.
|
||||
HWY_ASSERT_M(found_model == Model::UNKNOWN, specifier.c_str());
|
||||
found_model = model;
|
||||
}
|
||||
});
|
||||
HWY_ASSERT_M(found_model != Model::UNKNOWN, specifier.c_str());
|
||||
return found_model;
|
||||
}
|
||||
|
||||
static Type FindType(const std::string& specifier) {
|
||||
Type found_type = Type::kUnknown;
|
||||
for (size_t i = 1; i < kNumTypes; ++i) {
|
||||
const Type type = static_cast<Type>(i);
|
||||
if (specifier.find(TypeName(type)) != std::string::npos) { // NOLINT
|
||||
// We only expect one match.
|
||||
HWY_ASSERT_M(found_type == Type::kUnknown, specifier.c_str());
|
||||
found_type = type;
|
||||
}
|
||||
}
|
||||
TEST_EQUAL(norm_num_groups, other.norm_num_groups);
|
||||
result &= vit_config.TestEqual(other.vit_config, partial, debug);
|
||||
return result;
|
||||
HWY_ASSERT_M(found_type != Type::kUnknown, specifier.c_str());
|
||||
return found_type;
|
||||
}
|
||||
|
||||
Model ModelFromConfig(const ModelConfig& config) {
|
||||
for (Model model : kAllModels) {
|
||||
ModelConfig model_config = ConfigFromModel(model);
|
||||
if (config.TestEqual(model_config, /*partial=*/true, /*debug=*/false)) {
|
||||
return model;
|
||||
static PromptWrapping FindWrapping(const std::string& specifier) {
|
||||
PromptWrapping found_wrapping = PromptWrapping::kSentinel;
|
||||
for (size_t i = 0; i < static_cast<size_t>(PromptWrapping::kSentinel); ++i) {
|
||||
const PromptWrapping w = static_cast<PromptWrapping>(i);
|
||||
if (specifier.find(WrappingSuffix(w)) != std::string::npos) { // NOLINT
|
||||
// We expect zero or one match.
|
||||
HWY_ASSERT_M(found_wrapping == PromptWrapping::kSentinel,
|
||||
specifier.c_str());
|
||||
found_wrapping = w;
|
||||
}
|
||||
}
|
||||
return Model::UNKNOWN;
|
||||
if (found_wrapping == PromptWrapping::kSentinel) {
|
||||
return ChooseWrapping(FindModel(specifier));
|
||||
}
|
||||
return found_wrapping;
|
||||
}
|
||||
|
||||
// Obtains model/weight/wrapping by finding prefix and suffix strings.
|
||||
ModelConfig::ModelConfig(const std::string& specifier)
|
||||
: ModelConfig(FindModel(specifier), FindType(specifier),
|
||||
FindWrapping(specifier)) {}
|
||||
|
||||
std::string ModelConfig::Specifier() const {
|
||||
HWY_ASSERT(model != Model::UNKNOWN);
|
||||
HWY_ASSERT(weight != Type::kUnknown);
|
||||
HWY_ASSERT(wrapping != PromptWrapping::kSentinel);
|
||||
|
||||
std::string base_name = ModelPrefix(model);
|
||||
|
||||
base_name += '-';
|
||||
base_name += TypeName(weight);
|
||||
|
||||
if (wrapping != PromptWrapping::GEMMA_VLM &&
|
||||
wrapping != PromptWrapping::PALIGEMMA) {
|
||||
base_name += WrappingSuffix(wrapping);
|
||||
}
|
||||
|
||||
return base_name;
|
||||
}
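As a hedged illustration of the specifier format implied above (prefix, '-', type name, optional wrapping suffix). The exact type-name string ("sfp") is an assumption about TypeName and is not verified here.

#include <string>
#include "gemma/configs.h"

void SpecifierRoundTrip() {
  using namespace gcpp;
  const ModelConfig config(Model::GEMMA2_2B, Type::kSFP,
                           ChooseWrapping(Model::GEMMA2_2B));
  // Likely of the form "gemma2-2b-sfp-it" (ModelPrefix + TypeName + suffix).
  const std::string spec = config.Specifier();
  // The string ctor re-deduces model, weight type and wrapping from the name.
  const ModelConfig parsed(spec);
  // parsed.model == Model::GEMMA2_2B after parsing.
}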
|
||||
|
||||
// Returns whether all fields match.
|
||||
static bool AllEqual(const IFields& a, const IFields& b, bool print) {
|
||||
const std::vector<uint32_t> serialized_a = a.Write();
|
||||
const std::vector<uint32_t> serialized_b = b.Write();
|
||||
if (serialized_a != serialized_b) {
|
||||
if (print) {
|
||||
fprintf(stderr, "%s differs. Recommend generating a diff:\n", a.Name());
|
||||
a.Print();
|
||||
b.Print();
|
||||
}
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
bool LayerConfig::TestEqual(const LayerConfig& other, bool print) const {
|
||||
return AllEqual(*this, other, print);
|
||||
}
|
||||
|
||||
bool VitConfig::TestEqual(const VitConfig& other, bool print) const {
|
||||
return AllEqual(*this, other, print);
|
||||
}
|
||||
|
||||
bool ModelConfig::TestEqual(const ModelConfig& other, bool print) const {
|
||||
  // Early out to guard the loop below; a differing number of layers would
  // cause a mismatch anyway.
|
||||
if (layer_configs.size() != other.layer_configs.size()) {
|
||||
if (print) {
|
||||
HWY_WARN("Layer configs size mismatch %zu vs %zu", layer_configs.size(),
|
||||
other.layer_configs.size());
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
// Copy so we can 'ignore' fields by setting them to the same value.
|
||||
ModelConfig a = *this;
|
||||
ModelConfig b = other;
|
||||
// Called by `OverwriteWithCanonical`, so ignore the fields it will set.
|
||||
a.display_name = b.display_name;
|
||||
a.model = b.model;
|
||||
|
||||
  // The following are not yet set by config_converter.py. We ignore them here
  // for the purpose of comparison; `OverwriteWithCanonical` then replaces the
  // converter's config with the canonical ModelConfig constructed from the
  // (deduced) enum, so these fields end up populated.
|
||||
// `vit_config` is also not yet set, but we must not ignore it because
|
||||
// otherwise PaliGemma models will be indistinguishable for `configs_test`.
|
||||
a.pool_dim = b.pool_dim; // ViT
|
||||
a.eos_id = b.eos_id;
|
||||
a.secondary_eos_id = b.secondary_eos_id;
|
||||
a.scale_base_names = b.scale_base_names;
|
||||
for (size_t i = 0; i < a.layer_configs.size(); ++i) {
|
||||
a.layer_configs[i].optimized_gating = b.layer_configs[i].optimized_gating;
|
||||
}
|
||||
|
||||
return AllEqual(a, b, print);
|
||||
}
|
||||
|
||||
// Constructs the canonical ModelConfig for each model. If there is one for
|
||||
// which TestEqual returns true, overwrites `*this` with that and returns true.
|
||||
bool ModelConfig::OverwriteWithCanonical() {
|
||||
bool found = false;
|
||||
const bool print = false;
|
||||
ForEachModel([&](Model model) {
|
||||
const ModelConfig config(model, weight, wrapping);
|
||||
if (config.TestEqual(*this, print)) {
|
||||
HWY_ASSERT(!found); // Should only find one.
|
||||
found = true;
|
||||
*this = config;
|
||||
}
|
||||
});
|
||||
return found;
|
||||
}
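A sketch of the converter-side flow described above; the field values are illustrative Gemma2-2B-like numbers taken from elsewhere in this diff, not from config_converter.py itself.

#include "gemma/configs.h"

bool CanonicalizeSketch() {
  using namespace gcpp;
  ModelConfig cfg;          // default-constructed, as the converter does
  cfg.model_dim = 2304;     // illustrative values
  cfg.vocab_size = 256000;
  cfg.seq_len = 8192;
  // ... layer_configs, attention_window_sizes, etc. would also be filled ...
  // If the partially filled config matches a known model, this overwrites cfg
  // with the canonical config, including `model` and `display_name`.
  return cfg.OverwriteWithCanonical();
}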
|
||||
|
||||
Model DeduceModel(size_t layers, int layer_types) {
  switch (layers) {
    case 3:
      return Model::GEMMA_TINY;
    case 18:
      return Model::GEMMA_2B;
    case 26:
      if (layer_types & kDeducedGriffin) return Model::GRIFFIN_2B;
      if (layer_types & kDeducedViT) return Model::GEMMA3_1B;
      return Model::GEMMA2_2B;
    case 28:
      return Model::GEMMA_7B;
    case 34:
      return Model::GEMMA3_4B;
    case 42:
      return Model::GEMMA2_9B;
    case 46:
      return Model::GEMMA2_27B;
    case 48:
      return Model::GEMMA3_12B;
    case 62:
      return Model::GEMMA3_27B;

    // TODO: detect these.
    /*
    return Model::GEMMA2_772M;
    return Model::PALIGEMMA2_772M_224;
    return Model::PALIGEMMA_224;
    return Model::PALIGEMMA_448;
    return Model::PALIGEMMA2_3B_224;
    return Model::PALIGEMMA2_3B_448;
    return Model::PALIGEMMA2_10B_224;
    return Model::PALIGEMMA2_10B_448;
    */
    default:
      HWY_WARN("Failed to deduce model type from layer count %zu types %x.",
               layers, layer_types);
      return Model::UNKNOWN;
  }
}
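A small usage sketch, assuming the caller has already counted layers and classified their types while scanning pre-2025 tensor names:

#include "gemma/configs.h"

gcpp::Model DeduceSketch() {
  using namespace gcpp;
  int layer_types = kDeducedGriffin;  // e.g. a Griffin recurrent block was seen
  Model model = DeduceModel(/*layers=*/26, layer_types);
  // 26 layers + Griffin bit -> GRIFFIN_2B; without the bit it would be
  // GEMMA2_2B, and with kDeducedViT it would be GEMMA3_1B.
  return model;
}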
|
||||
|
||||
} // namespace gcpp
|
||||
|
|
284
gemma/configs.h
@ -23,31 +23,16 @@
|
|||
|
||||
#include <array>
|
||||
#include <string>
|
||||
#include <unordered_set>
|
||||
#include <vector>
|
||||
|
||||
#include "compression/fields.h" // IFieldsVisitor
|
||||
#include "compression/shared.h" // BF16
|
||||
#include "compression/shared.h" // Type
|
||||
#include "util/basics.h"
|
||||
|
||||
namespace gcpp {
|
||||
|
||||
// Allow changing pre-allocated kv cache size as a compiler flag
|
||||
#ifndef GEMMA_MAX_SEQLEN
|
||||
#define GEMMA_MAX_SEQLEN 4096
|
||||
#endif // !GEMMA_MAX_SEQLEN
|
||||
|
||||
// Allow changing k parameter of `SampleTopK` as a compiler flag
|
||||
#ifndef GEMMA_TOPK
|
||||
#define GEMMA_TOPK 1
|
||||
#endif // !GEMMA_TOPK
|
||||
|
||||
static constexpr size_t kSeqLen = GEMMA_MAX_SEQLEN;
|
||||
static constexpr size_t kTopK = GEMMA_TOPK;
|
||||
static constexpr size_t kVocabSize = 256000;
|
||||
static constexpr size_t kMaxConv1DWidth = 4;
|
||||
|
||||
using EmbedderInputT = BF16;
|
||||
|
||||
// Instruction-tuned models require extra 'turn structure' tokens in prompts.
|
||||
enum class PromptWrapping {
|
||||
GEMMA_IT,
|
||||
|
|
@ -57,8 +42,9 @@ enum class PromptWrapping {
|
|||
kSentinel // must be last
|
||||
};
|
||||
|
||||
// Defined as the suffix for use with `ModelString`.
|
||||
static inline const char* ToString(PromptWrapping wrapping) {
|
||||
// This is used in `ModelConfig.Specifier`, so the strings will not change,
|
||||
// though new ones may be added.
|
||||
static inline const char* WrappingSuffix(PromptWrapping wrapping) {
|
||||
switch (wrapping) {
|
||||
case PromptWrapping::GEMMA_IT:
|
||||
return "-it";
|
||||
|
|
@ -177,7 +163,7 @@ enum class Model {
|
|||
GEMMA2_9B,
|
||||
GEMMA2_27B,
|
||||
GRIFFIN_2B,
|
||||
GEMMA_TINY,
|
||||
GEMMA_TINY, // for backprop/ only
|
||||
GEMMA2_2B,
|
||||
PALIGEMMA_224,
|
||||
PALIGEMMA_448,
|
||||
|
|
@ -192,16 +178,28 @@ enum class Model {
|
|||
kSentinel,
|
||||
};
|
||||
|
||||
// Allows the Model enum to be iterated over.
static constexpr Model kAllModels[] = {
    Model::GEMMA_2B, Model::GEMMA_7B, Model::GEMMA2_9B, Model::GEMMA2_27B,
    Model::GRIFFIN_2B, Model::GEMMA_TINY, Model::GEMMA2_2B,
    Model::PALIGEMMA_224, Model::PALIGEMMA_448, Model::PALIGEMMA2_3B_224,
    Model::PALIGEMMA2_3B_448, Model::PALIGEMMA2_10B_224,
    Model::PALIGEMMA2_10B_448, Model::GEMMA3_4B, Model::GEMMA3_1B,
    Model::GEMMA3_12B, Model::GEMMA3_27B,
};
// Returns canonical model name without the PromptWrapping suffix. This is used
// in Specifier and thus does not change.
const char* ModelPrefix(Model model);

// Gemma3 is multimodal and has a different prompt wrapping than PaliGemma.
// This is used for deducing the PromptWrapping for pre-2025 BlobStore.
static inline bool IsVLM(Model model) {
  return model == Model::GEMMA3_4B || model == Model::GEMMA3_1B ||
         model == Model::GEMMA3_12B || model == Model::GEMMA3_27B;
}

static inline bool IsPaliGemma(Model model) {
  if (model == Model::PALIGEMMA_224 || model == Model::PALIGEMMA_448 ||
      model == Model::PALIGEMMA2_3B_224 || model == Model::PALIGEMMA2_3B_448 ||
      model == Model::PALIGEMMA2_10B_224 ||
      model == Model::PALIGEMMA2_10B_448) {
    return true;
  }
  return false;
}
|
||||
|
||||
// Visits every valid model enum, skipping `UNKNOWN` and `kSentinel`.
|
||||
template <class Func>
|
||||
void ForEachModel(const Func& func) {
|
||||
for (size_t i = static_cast<size_t>(Model::UNKNOWN) + 1;
|
||||
|
|
@ -218,24 +216,20 @@ static inline bool EnumValid(Model model) {
|
|||
return false;
|
||||
}
|
||||
|
||||
struct InternalLayerConfig : public IFields {
|
||||
const char* Name() const override { return "InternalLayerConfig"; }
|
||||
|
||||
// Source of truth for field ordering.
|
||||
void VisitFields(IFieldsVisitor& visitor) override {
|
||||
// Append new fields here, then update `python/configs.cc`.
|
||||
}
|
||||
};
|
||||
|
||||
// Per-layer configuration.
|
||||
struct LayerConfig : public IFields {
|
||||
// Returns true if *this and other are equal.
|
||||
// If partial is true, then we don't check for items that are only set after
|
||||
// the tensors are loaded from the checkpoint.
|
||||
// If debug is true, then we output the mismatched fields to stderr.
|
||||
bool TestEqual(const LayerConfig& other, bool partial, bool debug) const;
|
||||
|
||||
size_t CacheLayerSize() const { return kv_heads * qkv_dim * 2; }
|
||||
|
||||
// Multi-Head Attention?
|
||||
bool IsMHA() const { return heads == kv_heads; }
|
||||
|
||||
// Stride between subsequent queries. Each of Q, K, V are of length kQKVDim,
|
||||
// but for MHA we store them as Q,K,V, Q,K,V, .. instead of Q..Q, K..K, V..V.
|
||||
size_t QStride() const { return qkv_dim * (IsMHA() ? 3 : 1); }
|
||||
|
||||
const char* Name() const override { return "LayerConfig"; }
|
||||
|
||||
// Source of truth for field ordering.
|
||||
void VisitFields(IFieldsVisitor& visitor) override {
|
||||
visitor(model_dim);
|
||||
visitor(griffin_dim);
|
||||
|
|
@ -252,35 +246,45 @@ struct LayerConfig : public IFields {
|
|||
visitor(activation);
|
||||
visitor(post_qk);
|
||||
visitor(use_qk_norm);
|
||||
internal.VisitFields(visitor);
|
||||
// Append new fields here, then update `python/configs.cc`.
|
||||
}
|
||||
|
||||
// Returns whether all fields match.
|
||||
bool TestEqual(const LayerConfig& other, bool print) const;
|
||||
|
||||
size_t CacheLayerSize() const { return kv_heads * qkv_dim * 2; }
|
||||
|
||||
// Multi-Head Attention?
|
||||
bool IsMHA() const { return heads == kv_heads; }
|
||||
|
||||
  // Stride between subsequent queries. Each of Q, K, V is of length kQKVDim,
|
||||
// but for MHA we store them as Q,K,V, Q,K,V, .. instead of Q..Q, K..K, V..V.
|
||||
size_t QStride() const { return qkv_dim * (IsMHA() ? 3 : 1); }
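Illustrative numbers for the two helpers above, not taken from a specific model config:

#include "gemma/configs.h"

void QStrideSketch() {
  using namespace gcpp;
  LayerConfig mha;
  mha.heads = 16; mha.kv_heads = 16; mha.qkv_dim = 256;  // MHA: heads == kv_heads
  LayerConfig gqa;
  gqa.heads = 16; gqa.kv_heads = 8; gqa.qkv_dim = 256;   // grouped-query attention
  // MHA interleaves Q,K,V per query, so the stride is 3 * qkv_dim = 768.
  // Otherwise only Q is stored contiguously, so the stride is qkv_dim = 256.
  (void)mha.QStride();  // 768
  (void)gqa.QStride();  // 256
}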
|
||||
|
||||
uint32_t model_dim = 0;
|
||||
uint32_t griffin_dim = 0;
|
||||
uint32_t ff_hidden_dim = 0;
|
||||
uint32_t heads = 0;
|
||||
uint32_t kv_heads = 0;
|
||||
uint32_t qkv_dim = 0;
|
||||
uint32_t conv1d_width = 0; // griffin only
|
||||
uint32_t conv1d_width = 0; // Griffin only
|
||||
bool ff_biases = false;
|
||||
bool softmax_attn_output_biases = false;
|
||||
bool optimized_gating = true;
|
||||
bool softmax_attn_output_biases = false; // for Griffin
|
||||
bool optimized_gating = true; // for Gemma3
|
||||
PostNormType post_norm = PostNormType::None;
|
||||
LayerAttentionType type = LayerAttentionType::kGemma;
|
||||
ActivationType activation = ActivationType::Gelu;
|
||||
PostQKType post_qk = PostQKType::Rope;
|
||||
bool use_qk_norm = false;
|
||||
InternalLayerConfig internal;
|
||||
};
|
||||
|
||||
// Dimensions related to image processing.
|
||||
struct VitConfig : public IFields {
|
||||
// Returns true if *this and other are equal.
|
||||
// If partial is true, then we don't check for items that are only set after
|
||||
// the tensors are loaded from the checkpoint.
|
||||
// If debug is true, then we output the mismatched fields to stderr.
|
||||
bool TestEqual(const VitConfig& other, bool partial, bool debug) const;
|
||||
|
||||
const char* Name() const override { return "VitConfig"; }
|
||||
|
||||
// Source of truth for field ordering.
|
||||
void VisitFields(IFieldsVisitor& visitor) override {
|
||||
visitor(model_dim);
|
||||
visitor(seq_len);
|
||||
|
|
@ -289,8 +293,12 @@ struct VitConfig : public IFields {
|
|||
visitor(image_size);
|
||||
visitor(layer_configs);
|
||||
visitor(pool_dim);
|
||||
// Append new fields here, then update `python/configs.cc`.
|
||||
}
|
||||
|
||||
// Returns whether all fields match.
|
||||
bool TestEqual(const VitConfig& other, bool print) const;
|
||||
|
||||
uint32_t model_dim = 0;
|
||||
uint32_t seq_len = 0;
|
||||
uint32_t num_scales = 0;
|
||||
|
|
@ -300,20 +308,93 @@ struct VitConfig : public IFields {
|
|||
std::vector<LayerConfig> layer_configs;
|
||||
};
|
||||
|
||||
// Returns a valid `PromptWrapping` for the given `model`, for passing to the
|
||||
// `ModelConfig` ctor when the caller does not care about the wrapping. The
|
||||
// wrapping mode is either determined by the model (for PaliGemma and Gemma3),
|
||||
// or defaults to IT, subject to user override for PT.
|
||||
PromptWrapping ChooseWrapping(Model model,
|
||||
Tristate wrapping = Tristate::kDefault);
|
||||
|
||||
struct InternalModelConfig : public IFields {
|
||||
const char* Name() const override { return "InternalModelConfig"; }
|
||||
|
||||
// Source of truth for field ordering.
|
||||
void VisitFields(IFieldsVisitor& visitor) override {
|
||||
// Append new fields here, then update `python/configs.cc`.
|
||||
}
|
||||
};
|
||||
|
||||
struct ModelConfig : public IFields {
|
||||
// Returns true if *this and other are equal.
|
||||
// If partial is true, then we don't check for items that are only set after
|
||||
// the tensors are loaded from the checkpoint.
|
||||
// If debug is true, then we output the mismatched fields to stderr.
|
||||
bool TestEqual(const ModelConfig& other, bool partial, bool debug) const;
|
||||
// Preferred usage (single-file format): default-construct, then deserialize
|
||||
// from a blob. Also used by `config_converter.py`, which sets sufficient
|
||||
// fields for `TestEqual` and then calls `OverwriteWithCanonical()`.
|
||||
ModelConfig() = default;
|
||||
// For use by `backprop/`, and `model_store.cc` for pre-2025 format after
|
||||
// deducing the model from tensors plus a user-specified `wrapping` override
|
||||
// (see `ChooseWrapping`).
|
||||
ModelConfig(Model model, Type weight, PromptWrapping wrapping);
|
||||
// Parses a string returned by `Specifier()`. Used by the exporter to select
|
||||
// the model from command line arguments. Do not use this elsewhere - the
|
||||
// second ctor is preferred because it is type-checked.
|
||||
ModelConfig(const std::string& specifier);
|
||||
|
||||
const char* Name() const override { return "ModelConfig"; }
|
||||
|
||||
// Source of truth for field ordering.
|
||||
void VisitFields(IFieldsVisitor& visitor) override {
|
||||
visitor(model_family_version);
|
||||
visitor(display_name);
|
||||
visitor(model);
|
||||
visitor(wrapping);
|
||||
visitor(weight);
|
||||
|
||||
visitor(num_layers);
|
||||
visitor(model_dim);
|
||||
visitor(vocab_size);
|
||||
visitor(seq_len);
|
||||
|
||||
visitor(unused_num_tensor_scales);
|
||||
|
||||
visitor(att_cap);
|
||||
visitor(final_cap);
|
||||
|
||||
visitor(absolute_pe);
|
||||
visitor(use_local_attention);
|
||||
visitor(query_scale);
|
||||
visitor(layer_configs);
|
||||
visitor(attention_window_sizes);
|
||||
visitor(norm_num_groups);
|
||||
visitor(vit_config);
|
||||
visitor(pool_dim);
|
||||
|
||||
visitor(eos_id);
|
||||
visitor(secondary_eos_id);
|
||||
|
||||
visitor(scale_base_names);
|
||||
|
||||
internal.VisitFields(visitor);
|
||||
|
||||
// Append new fields here, then update `python/configs.cc`.
|
||||
}
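The "append new fields here" convention above is what keeps serialized configs readable across versions. A hypothetical extension (not part of this commit, and assuming IFields/IFieldsVisitor live in namespace gcpp as elsewhere in this header) would look like:

#include <cstdint>

#include "compression/fields.h"

struct ExampleInternal : public gcpp::IFields {
  const char* Name() const override { return "ExampleInternal"; }
  void VisitFields(gcpp::IFieldsVisitor& visitor) override {
    visitor(existing_field);
    visitor(new_field);  // appended last; older blobs simply report it as missing
  }
  uint32_t existing_field = 0;
  uint32_t new_field = 0;  // hypothetical new member with a safe default
};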
|
||||
|
||||
// Returns whether all fields match except `model` and `display_name`, and
|
||||
// some others that are not yet set by config_converter.py. This is for
|
||||
// internal use by `OverwriteWithCanonical`, but potentially useful elsewhere.
|
||||
bool TestEqual(const ModelConfig& other, bool print) const;
|
||||
|
||||
// For each model, constructs its canonical `ModelConfig` and if `TestEqual`
|
||||
// returns true, overwrites `*this` with that. Otherwise, returns false to
|
||||
// indicate this is not a known model. Called by `config_converter.py`.
|
||||
bool OverwriteWithCanonical();
|
||||
|
||||
// Returns a string encoding of the model family, size, weight, and
|
||||
// `PromptWrapping`. Stable/unchanging; can be used as the model file name.
|
||||
// The third ctor also expects a string returned by this.
|
||||
std::string Specifier() const;
|
||||
|
||||
void AddLayerConfig(const LayerConfig& layer_config) {
|
||||
layer_configs.push_back(layer_config);
|
||||
}
|
||||
|
||||
size_t CachePosSize() const {
|
||||
size_t num_layers = layer_configs.size();
|
||||
return num_layers * layer_configs[0].CacheLayerSize();
|
||||
HWY_ASSERT(layer_configs.size() <= num_layers);
|
||||
}
|
||||
|
||||
size_t NumLayersOfTypeBefore(LayerAttentionType type, size_t num) const {
|
||||
|
|
@ -336,72 +417,71 @@ struct ModelConfig : public IFields {
|
|||
return num_heads;
|
||||
}
|
||||
|
||||
const char* Name() const override { return "ModelConfig"; }
|
||||
size_t CachePosSize() const {
|
||||
size_t num_layers = layer_configs.size();
|
||||
return num_layers * layer_configs[0].CacheLayerSize();
|
||||
}
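Rough KV-cache arithmetic implied by CacheLayerSize and CachePosSize; the numbers are the Gemma2-2B values shown in the old test config later in this diff (kv_heads=4, qkv_dim=256, 26 layers) and are illustrative only.

#include <cstddef>

#include "gemma/configs.h"

size_t KVCacheBytesPerPos() {
  using namespace gcpp;
  const ModelConfig config(Model::GEMMA2_2B, Type::kSFP,
                           ChooseWrapping(Model::GEMMA2_2B));
  // CacheLayerSize() = 4 * 256 * 2 = 2048 elements; CachePosSize() = 26 * 2048
  // = 53248 elements per token position.
  return config.CachePosSize() * sizeof(float);
}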
|
||||
|
||||
bool IsEOS(int id) const { return (id == eos_id || id == secondary_eos_id); }
|
||||
|
||||
void VisitFields(IFieldsVisitor& visitor) override {
|
||||
visitor(model_family_version);
|
||||
visitor(model_name);
|
||||
visitor(model);
|
||||
visitor(wrapping);
|
||||
visitor(weight);
|
||||
visitor(num_layers);
|
||||
visitor(model_dim);
|
||||
visitor(vocab_size);
|
||||
visitor(seq_len);
|
||||
visitor(num_tensor_scales);
|
||||
visitor(att_cap);
|
||||
visitor(final_cap);
|
||||
visitor(absolute_pe);
|
||||
visitor(use_local_attention);
|
||||
visitor(query_scale);
|
||||
visitor(layer_configs);
|
||||
visitor(attention_window_sizes);
|
||||
visitor(norm_num_groups);
|
||||
visitor(vit_config);
|
||||
visitor(pool_dim);
|
||||
visitor(eos_id);
|
||||
visitor(secondary_eos_id);
|
||||
}
|
||||
|
||||
// Major version of the model family. It is used as a fallback to distinguish
|
||||
// between model types when there is no explicit information in the config.
|
||||
// Major version of the model family, reflecting architecture changes. This is
|
||||
// more convenient to compare than `Model` because that also includes the
|
||||
// model size.
|
||||
uint32_t model_family_version = 1;
|
||||
std::string model_name;
|
||||
Model model = Model::UNKNOWN;
|
||||
// For display only, may change. Use `Specifier()` for setting the
|
||||
// file name. Not checked by `TestEqual` because `config_converter.py` does
|
||||
// not set this.
|
||||
std::string display_name;
|
||||
Model model = Model::UNKNOWN; // Not checked by `TestEqual`, see above.
|
||||
PromptWrapping wrapping = PromptWrapping::GEMMA_PT;
|
||||
Type weight = Type::kUnknown;
|
||||
|
||||
uint32_t num_layers = 0;
|
||||
uint32_t model_dim = 0;
|
||||
uint32_t vocab_size = 0;
|
||||
uint32_t seq_len = 0;
|
||||
uint32_t num_tensor_scales = 0;
|
||||
|
||||
// We no longer set nor use this: config_converter is not able to set this,
|
||||
// and only pre-2025 format stores scales, and we do not require advance
|
||||
// knowledge of how many there will be. Any scales present will just be
|
||||
// assigned in order to the tensors matching `scale_base_names`.
|
||||
uint32_t unused_num_tensor_scales = 0;
|
||||
|
||||
float att_cap = 0.0f;
|
||||
float final_cap = 0.0f;
|
||||
|
||||
bool absolute_pe = false;
|
||||
bool use_local_attention = false; // griffin only
|
||||
bool use_local_attention = false; // Griffin only
|
||||
QueryScaleType query_scale = QueryScaleType::SqrtKeySize;
|
||||
std::vector<LayerConfig> layer_configs;
|
||||
std::vector<uint32_t> attention_window_sizes;
|
||||
std::unordered_set<std::string> scale_names;
|
||||
uint32_t norm_num_groups = 1;
|
||||
|
||||
// Dimensions related to image processing.
|
||||
VitConfig vit_config;
|
||||
uint32_t pool_dim = 1; // used only for VitConfig copy
|
||||
|
||||
int eos_id = 1;
|
||||
int secondary_eos_id = 1;
|
||||
|
||||
// Tensor base names without a layer suffix, used by `ModelStore` only for
|
||||
// pre-2025 format.
|
||||
std::vector<std::string> scale_base_names;
|
||||
|
||||
InternalModelConfig internal;
|
||||
};
|
||||
|
||||
// Returns the config for the given model.
|
||||
ModelConfig ConfigFromModel(Model model);
|
||||
|
||||
// Returns the model for the given config, if it matches any standard model.
|
||||
Model ModelFromConfig(const ModelConfig& config);
|
||||
|
||||
// Returns the sub-config for the ViT model of the PaliGemma model.
|
||||
ModelConfig GetVitConfig(const ModelConfig& config);
|
||||
|
||||
enum DeducedLayerTypes {
|
||||
kDeducedGriffin = 1,
|
||||
kDeducedViT = 2,
|
||||
};
|
||||
|
||||
// layer_types is one or more of `DeducedLayerTypes`.
|
||||
Model DeduceModel(size_t layers, int layer_types);
|
||||
|
||||
} // namespace gcpp
|
||||
|
||||
#endif // THIRD_PARTY_GEMMA_CPP_GEMMA_CONFIGS_H_
|
||||
|
|
|
|||
|
|
@ -1,461 +1,44 @@
|
|||
#include "gemma/configs.h"
|
||||
|
||||
#include <array>
|
||||
#include <cstddef>
|
||||
#include <cstdint>
|
||||
#include <type_traits>
|
||||
#include <stdio.h>
|
||||
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
#include "gtest/gtest.h"
|
||||
#include "hwy/aligned_allocator.h"
|
||||
#include "compression/fields.h" // Type
|
||||
#include "compression/shared.h" // Type
|
||||
|
||||
namespace gcpp {
|
||||
|
||||
template <size_t kNum>
|
||||
constexpr std::array<LayerAttentionType, kNum> OldFixedLayerConfig(
|
||||
LayerAttentionType type) {
|
||||
std::array<LayerAttentionType, kNum> config = {};
|
||||
for (LayerAttentionType& l : config) {
|
||||
l = type;
|
||||
}
|
||||
return config;
|
||||
}
|
||||
TEST(ConfigsTest, TestAll) {
|
||||
ForEachModel([&](Model model) {
|
||||
ModelConfig config(model, Type::kSFP, ChooseWrapping(model));
|
||||
fprintf(stderr, "Testing %s (%s)\n", config.display_name.c_str(),
|
||||
config.Specifier().c_str());
|
||||
HWY_ASSERT(config.model == model);
|
||||
|
||||
template <size_t kNum>
|
||||
constexpr std::array<size_t, kNum> OldFixedAttentionWindowSizes(
|
||||
size_t window_size) {
|
||||
std::array<size_t, kNum> window_size_configs = {};
|
||||
for (size_t& l : window_size_configs) {
|
||||
l = window_size;
|
||||
}
|
||||
return window_size_configs;
|
||||
}
|
||||
// We can deduce the model/display_name from all other fields.
|
||||
config.model = Model::UNKNOWN;
|
||||
const std::string saved_display_name = config.display_name;
|
||||
config.display_name.clear();
|
||||
HWY_ASSERT(config.OverwriteWithCanonical());
|
||||
HWY_ASSERT(config.model == model);
|
||||
HWY_ASSERT(config.display_name == saved_display_name);
|
||||
|
||||
// Repeat window_size_pattern for kNum / kPatternSize times.
|
||||
template <size_t kNum, size_t kPatternSize>
|
||||
constexpr std::array<size_t, kNum> OldRepeatedAttentionWindowSizes(
|
||||
const std::array<size_t, kPatternSize>& window_size_pattern) {
|
||||
static_assert(kNum % kPatternSize == 0,
|
||||
"kNum must be a multiple of kPatternSize");
|
||||
std::array<size_t, kNum> window_size_configs = {};
|
||||
for (size_t i = 0; i < kNum; ++i) {
|
||||
window_size_configs[i] = window_size_pattern[i % kPatternSize];
|
||||
}
|
||||
return window_size_configs;
|
||||
}
|
||||
|
||||
template <size_t kNumLayers>
|
||||
constexpr size_t OldNumLayersOfTypeBefore(
|
||||
const std::array<LayerAttentionType, kNumLayers>& layers,
|
||||
LayerAttentionType type, size_t num) {
|
||||
size_t count = 0;
|
||||
for (size_t i = 0; i < num; i++) {
|
||||
if (layers[i] == type) count++;
|
||||
}
|
||||
return count;
|
||||
}
|
||||
|
||||
template <class TConfig, typename = void>
|
||||
struct CacheLayerSize {
|
||||
constexpr size_t operator()() const {
|
||||
return TConfig::kKVHeads * TConfig::kQKVDim * 2;
|
||||
}
|
||||
};
|
||||
|
||||
template <class TConfig, typename = void>
|
||||
struct CachePosSize {
|
||||
constexpr size_t operator()() const {
|
||||
return TConfig::kGemmaLayers * CacheLayerSize<TConfig>()();
|
||||
}
|
||||
};
|
||||
|
||||
struct OldConfigNoVit {
|
||||
struct VitConfig {
|
||||
// Some of these are needed to make the compiler happy when trying to
|
||||
// generate code that will actually never be used.
|
||||
using Weight = float;
|
||||
static constexpr int kLayers = 0;
|
||||
static constexpr std::array<LayerAttentionType, 0> kLayerConfig =
|
||||
OldFixedLayerConfig<0>(LayerAttentionType::kVit);
|
||||
static constexpr int kModelDim = 0;
|
||||
static constexpr int kFFHiddenDim = 0;
|
||||
static constexpr int kHeads = 1; // Avoid division by 0 in griffin gate_w.
|
||||
static constexpr int kKVHeads = 0;
|
||||
static constexpr int kQKVDim = 0;
|
||||
static constexpr int kSeqLen = 0;
|
||||
static constexpr ResidualType kResidual = ResidualType::Add;
|
||||
static constexpr int kGriffinLayers = 0;
|
||||
static constexpr int kConv1dWidth = 0;
|
||||
static constexpr bool kFFBiases = false;
|
||||
static constexpr bool kSoftmaxAttnOutputBiases = false;
|
||||
static constexpr PostNormType kPostNorm = PostNormType::None;
|
||||
};
|
||||
};
|
||||
|
||||
struct OldConfigNoSSM : OldConfigNoVit {
|
||||
static constexpr int kGriffinLayers = 0;
|
||||
|
||||
static constexpr int kConv1dWidth = 0;
|
||||
static constexpr bool kFFBiases = false;
|
||||
static constexpr bool kSoftmaxAttnOutputBiases = false;
|
||||
static constexpr bool kUseHalfRope = false;
|
||||
static constexpr bool kUseLocalAttention = false;
|
||||
static constexpr bool kInterleaveQKV = true;
|
||||
static constexpr PostQKType kPostQK = PostQKType::Rope;
|
||||
static constexpr ActivationType kActivation = ActivationType::Gelu;
|
||||
static constexpr ResidualType kResidual = ResidualType::Add;
|
||||
};
|
||||
|
||||
struct OldConfigBaseGemmaV1 : OldConfigNoSSM {
|
||||
static constexpr float kAttCap = 0.0f;
|
||||
static constexpr float kFinalCap = 0.0f;
|
||||
static constexpr PostNormType kPostNorm = PostNormType::None;
|
||||
static constexpr QueryScaleType kQueryScale = QueryScaleType::SqrtKeySize;
|
||||
};
|
||||
|
||||
struct OldConfigBaseGemmaV2 : OldConfigNoSSM {
|
||||
static constexpr float kAttCap = 50.0f;
|
||||
static constexpr float kFinalCap = 30.0f;
|
||||
static constexpr PostNormType kPostNorm = PostNormType::Scale;
|
||||
};
|
||||
|
||||
template <typename TWeight>
|
||||
struct OldConfigGemma2_27B : public OldConfigBaseGemmaV2 {
|
||||
using Weight = TWeight; // make accessible where we only have a TConfig
|
||||
|
||||
static constexpr int kSeqLen = 8192;
|
||||
static constexpr int kVocabSize = gcpp::kVocabSize;
|
||||
static constexpr std::array<LayerAttentionType, 46> kLayerConfig =
|
||||
OldFixedLayerConfig<46>(LayerAttentionType::kGemma);
|
||||
static constexpr std::array<size_t, 46> kAttentionWindowSizes =
|
||||
OldRepeatedAttentionWindowSizes<46, 2>({4096, kSeqLen});
|
||||
static constexpr int kLayers = kLayerConfig.size();
|
||||
static constexpr int kNumTensorScales = 4 * kLayers;
|
||||
static constexpr int kGemmaLayers = kLayers;
|
||||
static constexpr int kModelDim = 4608;
|
||||
static constexpr int kFFHiddenDim = 16 * 4608 / 2; // = 36864
|
||||
static constexpr int kHeads = 32;
|
||||
static constexpr int kKVHeads = 16;
|
||||
static constexpr int kQKVDim = 128; // query size == key size == value size
|
||||
static constexpr int kTopK = gcpp::kTopK;
|
||||
static constexpr bool kAbsolutePE = false;
|
||||
static constexpr QueryScaleType kQueryScale =
|
||||
QueryScaleType::SqrtModelDimDivNumHeads;
|
||||
};
|
||||
|
||||
template <typename TWeight>
|
||||
struct OldConfigGemma2_9B : public OldConfigBaseGemmaV2 {
|
||||
using Weight = TWeight; // make accessible where we only have a TConfig
|
||||
|
||||
static constexpr int kSeqLen = 8192;
|
||||
static constexpr int kVocabSize = gcpp::kVocabSize;
|
||||
static constexpr std::array<LayerAttentionType, 42> kLayerConfig =
|
||||
OldFixedLayerConfig<42>(LayerAttentionType::kGemma);
|
||||
static constexpr std::array<size_t, 42> kAttentionWindowSizes =
|
||||
OldRepeatedAttentionWindowSizes<42, 2>({4096, kSeqLen});
|
||||
static constexpr int kLayers = kLayerConfig.size();
|
||||
static constexpr int kNumTensorScales = 4 * kLayers;
|
||||
static constexpr int kGemmaLayers = kLayers;
|
||||
static constexpr int kModelDim = 3584;
|
||||
static constexpr int kFFHiddenDim = 8 * 3584 / 2; // = 14336
|
||||
static constexpr int kHeads = 16;
|
||||
static constexpr int kKVHeads = 8;
|
||||
static constexpr int kQKVDim = 256; // query size == key size == value size
|
||||
static constexpr int kTopK = gcpp::kTopK;
|
||||
static constexpr bool kAbsolutePE = false;
|
||||
static constexpr QueryScaleType kQueryScale = QueryScaleType::SqrtKeySize;
|
||||
};
|
||||
|
||||
template <typename TWeight>
|
||||
struct OldConfigGemma7B : public OldConfigBaseGemmaV1 {
|
||||
using Weight = TWeight; // make accessible where we only have a TConfig
|
||||
|
||||
static constexpr int kSeqLen = gcpp::kSeqLen;
|
||||
static constexpr int kVocabSize = gcpp::kVocabSize;
|
||||
static constexpr std::array<LayerAttentionType, 28> kLayerConfig =
|
||||
OldFixedLayerConfig<28>(LayerAttentionType::kGemma);
|
||||
static constexpr std::array<size_t, 28> kAttentionWindowSizes =
|
||||
OldFixedAttentionWindowSizes<28>(kSeqLen);
|
||||
static constexpr int kLayers = kLayerConfig.size();
|
||||
static constexpr int kNumTensorScales = 4 * kLayers;
|
||||
static constexpr int kGemmaLayers = kLayers;
|
||||
static constexpr int kModelDim = 3072;
|
||||
static constexpr int kFFHiddenDim = 16 * 3072 / 2; // = 24576
|
||||
static constexpr int kHeads = 16;
|
||||
static constexpr int kKVHeads = 16; // standard MHA
|
||||
static constexpr int kQKVDim = 256; // query size == key size == value size
|
||||
static constexpr int kTopK = gcpp::kTopK;
|
||||
static constexpr bool kAbsolutePE = false;
|
||||
};
|
||||
|
||||
template <typename TWeight>
|
||||
struct OldConfigGemma2B : public OldConfigBaseGemmaV1 {
|
||||
using Weight = TWeight; // make accessible where we only have a TConfig
|
||||
|
||||
static constexpr int kSeqLen = gcpp::kSeqLen;
|
||||
static constexpr int kVocabSize = gcpp::kVocabSize;
|
||||
static constexpr std::array<LayerAttentionType, 18> kLayerConfig =
|
||||
OldFixedLayerConfig<18>(LayerAttentionType::kGemma);
|
||||
static constexpr std::array<size_t, 18> kAttentionWindowSizes =
|
||||
OldFixedAttentionWindowSizes<18>(kSeqLen);
|
||||
static constexpr int kLayers = kLayerConfig.size();
|
||||
static constexpr int kNumTensorScales = 4 * kLayers;
|
||||
static constexpr int kGemmaLayers = kLayers;
|
||||
static constexpr int kModelDim = 2048;
|
||||
static constexpr int kFFHiddenDim = 16 * 2048 / 2; // = 16384
|
||||
static constexpr int kHeads = 8;
|
||||
static constexpr int kKVHeads = 1;
|
||||
static constexpr int kQKVDim = 256; // query size == key size == value size
|
||||
static constexpr int kTopK = gcpp::kTopK;
|
||||
static constexpr bool kAbsolutePE = false;
|
||||
};
|
||||
|
||||
template <typename TWeight>
|
||||
struct OldConfigPaliGemma_224 : public OldConfigGemma2B<TWeight> {
|
||||
// On the LM side, the vocab size is one difference to Gemma1-2B in the
|
||||
// architecture. PaliGemma adds 1024 <locNNNN> and 128 <segNNN> tokens.
|
||||
static constexpr int kVocabSize = 256000 + 1024 + 128; // = 257152
|
||||
|
||||
// Sub-config for the Vision-Transformer part.
|
||||
struct VitConfig : public OldConfigNoSSM {
|
||||
using Weight = TWeight;
|
||||
// The ViT parts. https://arxiv.org/abs/2305.13035
|
||||
// "SoViT-400m/14 [...] has a width of 1152, depth 27, and MLP dim 4304."
|
||||
static constexpr std::array<LayerAttentionType, 27> kLayerConfig =
|
||||
OldFixedLayerConfig<27>(LayerAttentionType::kVit);
|
||||
static constexpr int kLayers = kLayerConfig.size();
|
||||
static constexpr int kNumTensorScales = 4 * kLayers;
|
||||
static constexpr int kModelDim = 1152;
|
||||
static constexpr int kFFHiddenDim = 4304;
|
||||
static constexpr int kHeads = 16;
|
||||
static constexpr int kKVHeads = 16; // standard MHA
|
||||
static constexpr int kQKVDim = 72;
|
||||
static constexpr int kSeqLen = 16 * 16; // 256
|
||||
static constexpr bool kFFBiases = true;
|
||||
// The Vit part does not have a vocabulary, the image patches are embedded.
|
||||
static constexpr int kVocabSize = 0;
|
||||
// Dimensions related to image processing.
|
||||
static constexpr int kPatchWidth = 14;
|
||||
static constexpr int kImageSize = 224;
|
||||
// Necessary constant for the layer configuration.
|
||||
static constexpr PostNormType kPostNorm = PostNormType::None;
|
||||
};
|
||||
};
|
||||
|
||||
template <typename TWeight>
|
||||
struct OldConfigGemma2_2B : public OldConfigBaseGemmaV2 {
|
||||
using Weight = TWeight; // make accessible where we only have a TConfig
|
||||
|
||||
static constexpr int kSeqLen = 8192;
|
||||
static constexpr int kVocabSize = gcpp::kVocabSize;
|
||||
static constexpr std::array<LayerAttentionType, 26> kLayerConfig =
|
||||
OldFixedLayerConfig<26>(LayerAttentionType::kGemma);
|
||||
static constexpr std::array<size_t, 26> kAttentionWindowSizes =
|
||||
OldRepeatedAttentionWindowSizes<26, 2>({4096, kSeqLen});
|
||||
static constexpr int kLayers = kLayerConfig.size();
|
||||
static constexpr int kNumTensorScales = 4 * kLayers;
|
||||
static constexpr int kGemmaLayers = kLayers;
|
||||
static constexpr int kModelDim = 2304;
|
||||
static constexpr int kFFHiddenDim = 8 * 2304 / 2; // = 9216
|
||||
static constexpr int kHeads = 8;
|
||||
static constexpr int kKVHeads = 4;
|
||||
static constexpr int kQKVDim = 256; // query size == key size == value size
|
||||
static constexpr int kTopK = gcpp::kTopK;
|
||||
static constexpr bool kAbsolutePE = false;
|
||||
static constexpr QueryScaleType kQueryScale = QueryScaleType::SqrtKeySize;
|
||||
};
|
||||
|
||||
template <typename TWeight>
|
||||
struct OldConfigGemmaTiny : public OldConfigNoSSM {
|
||||
using Weight = TWeight; // make accessible where we only have a TConfig
|
||||
|
||||
static constexpr int kSeqLen = 32;
|
||||
static constexpr int kVocabSize = 64;
|
||||
static constexpr std::array<LayerAttentionType, 3> kLayerConfig =
|
||||
OldFixedLayerConfig<3>(LayerAttentionType::kGemma);
|
||||
static constexpr std::array<size_t, 3> kAttentionWindowSizes =
|
||||
OldFixedAttentionWindowSizes<3>(kSeqLen);
|
||||
static constexpr int kLayers = kLayerConfig.size();
|
||||
static constexpr int kNumTensorScales = 4 * kLayers;
|
||||
static constexpr int kGemmaLayers = kLayers;
|
||||
static constexpr int kModelDim = 128;
|
||||
static constexpr int kFFHiddenDim = 256;
|
||||
static constexpr int kHeads = 4;
|
||||
static constexpr int kKVHeads = 1;
|
||||
static constexpr int kQKVDim = 16; // query size == key size == value size
|
||||
static constexpr int kTopK = gcpp::kTopK;
|
||||
static constexpr bool kAbsolutePE = false;
|
||||
static constexpr PostNormType kPostNorm = PostNormType::None;
|
||||
static constexpr QueryScaleType kQueryScale = QueryScaleType::SqrtKeySize;
|
||||
|
||||
static constexpr float kAttCap = 0.0f;
|
||||
// This is required for optimize_test to pass.
|
||||
static constexpr float kFinalCap = 30.0f;
|
||||
};
|
||||
|
||||
template <typename TWeight>
|
||||
struct OldConfigGriffin2B : OldConfigNoVit {
|
||||
using Weight = TWeight; // make accessible where we only have a TConfig
|
||||
|
||||
// Griffin uses local attention, so kSeqLen is actually the local attention
|
||||
// window.
|
||||
static constexpr int kSeqLen = 2048;
|
||||
static constexpr int kVocabSize = gcpp::kVocabSize;
|
||||
static constexpr std::array<LayerAttentionType, 26> kLayerConfig = {
|
||||
LayerAttentionType::kGriffinRecurrentBlock,
|
||||
LayerAttentionType::kGriffinRecurrentBlock,
|
||||
LayerAttentionType::kGemma,
|
||||
LayerAttentionType::kGriffinRecurrentBlock,
|
||||
LayerAttentionType::kGriffinRecurrentBlock,
|
||||
LayerAttentionType::kGemma,
|
||||
LayerAttentionType::kGriffinRecurrentBlock,
|
||||
LayerAttentionType::kGriffinRecurrentBlock,
|
||||
LayerAttentionType::kGemma,
|
||||
LayerAttentionType::kGriffinRecurrentBlock,
|
||||
LayerAttentionType::kGriffinRecurrentBlock,
|
||||
LayerAttentionType::kGemma,
|
||||
LayerAttentionType::kGriffinRecurrentBlock,
|
||||
LayerAttentionType::kGriffinRecurrentBlock,
|
||||
LayerAttentionType::kGemma,
|
||||
LayerAttentionType::kGriffinRecurrentBlock,
|
||||
LayerAttentionType::kGriffinRecurrentBlock,
|
||||
LayerAttentionType::kGemma,
|
||||
LayerAttentionType::kGriffinRecurrentBlock,
|
||||
LayerAttentionType::kGriffinRecurrentBlock,
|
||||
LayerAttentionType::kGemma,
|
||||
LayerAttentionType::kGriffinRecurrentBlock,
|
||||
LayerAttentionType::kGriffinRecurrentBlock,
|
||||
LayerAttentionType::kGemma,
|
||||
LayerAttentionType::kGriffinRecurrentBlock,
|
||||
LayerAttentionType::kGriffinRecurrentBlock,
|
||||
};
|
||||
static constexpr std::array<size_t, 26> kAttentionWindowSizes =
|
||||
OldFixedAttentionWindowSizes<26>(kSeqLen);
|
||||
static constexpr int kLayers = kLayerConfig.size();
|
||||
static constexpr int kGemmaLayers = OldNumLayersOfTypeBefore(
|
||||
kLayerConfig, LayerAttentionType::kGemma, kLayers);
|
||||
static constexpr int kGriffinLayers = OldNumLayersOfTypeBefore(
|
||||
kLayerConfig, LayerAttentionType::kGriffinRecurrentBlock, kLayers);
|
||||
static constexpr int kModelDim = 2560;
|
||||
static constexpr int kFFHiddenDim = 7680;
|
||||
static constexpr int kHeads = 10;
|
||||
static constexpr int kKVHeads = 1;
|
||||
static constexpr int kQKVDim = 256; // query size == key size == value size
|
||||
static constexpr int kTopK = gcpp::kTopK;
|
||||
static constexpr bool kAbsolutePE = false;
|
||||
static constexpr PostNormType kPostNorm = PostNormType::None;
|
||||
|
||||
// No SoftCap.
|
||||
static constexpr float kAttCap = 0.0f;
|
||||
static constexpr float kFinalCap = 0.0f;
|
||||
|
||||
// SSM config.
|
||||
static constexpr int kConv1dWidth = 4;
|
||||
static constexpr bool kFFBiases = true;
|
||||
static constexpr bool kSoftmaxAttnOutputBiases = true;
|
||||
static constexpr bool kUseHalfRope = true;
|
||||
static constexpr bool kUseLocalAttention = true;
|
||||
static constexpr bool kInterleaveQKV = false;
|
||||
static constexpr int kNumTensorScales = 140;
|
||||
static constexpr PostQKType kPostQK = PostQKType::Rope;
|
||||
static constexpr ActivationType kActivation = ActivationType::Gelu;
|
||||
static constexpr QueryScaleType kQueryScale = QueryScaleType::SqrtKeySize;
|
||||
static constexpr ResidualType kResidual = ResidualType::Add;
|
||||
};
|
||||
|
||||
template <class TConfig>
|
||||
void AssertMatch(const ModelConfig& config) {
|
||||
ASSERT_EQ(TConfig::kModelDim, config.model_dim);
|
||||
if constexpr (TConfig::VitConfig::kModelDim != 0) {
|
||||
ASSERT_EQ(TConfig::VitConfig::kModelDim, config.vit_config.model_dim);
|
||||
ASSERT_EQ(TConfig::VitConfig::kSeqLen, config.vit_config.seq_len);
|
||||
ASSERT_EQ(TConfig::VitConfig::kNumTensorScales,
|
||||
config.vit_config.num_scales);
|
||||
for (size_t i = 0; i < config.vit_config.layer_configs.size(); ++i) {
|
||||
ASSERT_EQ(TConfig::VitConfig::kLayerConfig[i],
|
||||
config.vit_config.layer_configs[i].type);
|
||||
}
|
||||
}
|
||||
ASSERT_EQ(TConfig::kVocabSize, config.vocab_size);
|
||||
ASSERT_EQ(TConfig::kSeqLen, config.seq_len);
|
||||
ASSERT_EQ(TConfig::kAttCap, config.att_cap);
|
||||
ASSERT_EQ(TConfig::kFinalCap, config.final_cap);
|
||||
ASSERT_EQ(TConfig::kAbsolutePE, config.absolute_pe);
|
||||
ASSERT_EQ(TConfig::kUseLocalAttention, config.use_local_attention);
|
||||
ASSERT_EQ(TConfig::kQueryScale, config.query_scale);
|
||||
ASSERT_EQ(TConfig::kGemmaLayers,
|
||||
config.NumLayersOfType(LayerAttentionType::kGemma));
|
||||
ASSERT_EQ(TConfig::kGriffinLayers,
|
||||
config.NumLayersOfType(LayerAttentionType::kGriffinRecurrentBlock));
|
||||
for (size_t i = 0; i < config.layer_configs.size(); ++i) {
|
||||
ASSERT_EQ(TConfig::kModelDim, config.layer_configs[i].model_dim);
|
||||
ASSERT_EQ(TConfig::kFFHiddenDim, config.layer_configs[i].ff_hidden_dim);
|
||||
ASSERT_EQ(TConfig::kHeads, config.layer_configs[i].heads);
|
||||
ASSERT_EQ(TConfig::kKVHeads, config.layer_configs[i].kv_heads);
|
||||
ASSERT_EQ(TConfig::kQKVDim, config.layer_configs[i].qkv_dim);
|
||||
ASSERT_EQ(TConfig::kConv1dWidth, config.layer_configs[i].conv1d_width);
|
||||
ASSERT_EQ(TConfig::kFFBiases, config.layer_configs[i].ff_biases);
|
||||
ASSERT_EQ(TConfig::kSoftmaxAttnOutputBiases,
|
||||
config.layer_configs[i].softmax_attn_output_biases);
|
||||
ASSERT_EQ(TConfig::kPostNorm, config.layer_configs[i].post_norm);
|
||||
ASSERT_EQ(TConfig::kLayerConfig[i], config.layer_configs[i].type);
|
||||
ASSERT_EQ(TConfig::kActivation, config.layer_configs[i].activation);
|
||||
PostQKType post_qk = TConfig::kPostQK;
|
||||
if (TConfig::kUseHalfRope) {
|
||||
post_qk = PostQKType::HalfRope;
|
||||
}
|
||||
ASSERT_EQ(post_qk, config.layer_configs[i].post_qk);
|
||||
}
|
||||
|
||||
ASSERT_EQ(TConfig::kAttentionWindowSizes.size(),
|
||||
config.attention_window_sizes.size());
|
||||
for (size_t i = 0; i < config.attention_window_sizes.size(); ++i) {
|
||||
ASSERT_EQ(TConfig::kAttentionWindowSizes[i],
|
||||
config.attention_window_sizes[i]);
|
||||
}
|
||||
ASSERT_EQ(TConfig::kNumTensorScales, config.num_tensor_scales);
|
||||
}
|
||||
|
||||
ModelConfig RoundTripSerialize(const ModelConfig& config) {
|
||||
std::vector<uint32_t> config_buffer = config.Write();
|
||||
ModelConfig deserialized;
|
||||
deserialized.Read(hwy::Span<const uint32_t>(config_buffer), 0);
|
||||
return deserialized;
|
||||
}
|
||||
|
||||
TEST(ConfigsTest, OldConfigGemma2B) {
|
||||
AssertMatch<OldConfigGemma2B<float>>(ConfigFromModel(Model::GEMMA_2B));
|
||||
ModelConfig config = RoundTripSerialize(ConfigFromModel(Model::GEMMA_2B));
|
||||
AssertMatch<OldConfigGemma2B<float>>(config);
|
||||
}
|
||||
|
||||
TEST(ConfigsTest, OldConfigGemma7B) {
|
||||
AssertMatch<OldConfigGemma7B<float>>(ConfigFromModel(Model::GEMMA_7B));
|
||||
}
|
||||
|
||||
TEST(ConfigsTest, OldConfigGemma2_2B) {
|
||||
AssertMatch<OldConfigGemma2_2B<float>>(ConfigFromModel(Model::GEMMA2_2B));
|
||||
}
|
||||
|
||||
TEST(ConfigsTest, OldConfigGemma2_9B) {
|
||||
AssertMatch<OldConfigGemma2_9B<float>>(ConfigFromModel(Model::GEMMA2_9B));
|
||||
}
|
||||
|
||||
TEST(ConfigsTest, OldConfigGemma2_27B) {
|
||||
AssertMatch<OldConfigGemma2_27B<float>>(ConfigFromModel(Model::GEMMA2_27B));
|
||||
}
|
||||
|
||||
TEST(ConfigsTest, OldConfigGriffin2B) {
|
||||
AssertMatch<OldConfigGriffin2B<float>>(ConfigFromModel(Model::GRIFFIN_2B));
|
||||
}
|
||||
|
||||
TEST(ConfigsTest, OldConfigGemmaTiny) {
|
||||
AssertMatch<OldConfigGemmaTiny<float>>(ConfigFromModel(Model::GEMMA_TINY));
|
||||
}
|
||||
|
||||
TEST(ConfigsTest, OldConfigPaliGemma_224) {
|
||||
AssertMatch<OldConfigPaliGemma_224<float>>(
|
||||
ConfigFromModel(Model::PALIGEMMA_224));
|
||||
const std::vector<uint32_t> serialized = config.Write();
|
||||
ModelConfig deserialized;
|
||||
const IFields::ReadResult result =
|
||||
deserialized.Read(hwy::Span<const uint32_t>(serialized), /*pos=*/0);
|
||||
HWY_ASSERT(result.pos == serialized.size());
|
||||
// We wrote it, so all fields should be known, and no extra.
|
||||
HWY_ASSERT(result.extra_u32 == 0);
|
||||
HWY_ASSERT(result.missing_fields == 0);
|
||||
// All fields should match.
|
||||
HWY_ASSERT(deserialized.TestEqual(config, /*print=*/true));
|
||||
HWY_ASSERT(deserialized.model == model);
|
||||
HWY_ASSERT(deserialized.display_name == saved_display_name);
|
||||
});
|
||||
}
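The same Write/Read round trip, written generically as a sketch for any IFields subclass; it mirrors the assertions in the test above.

#include <cstdint>
#include <vector>

#include "compression/fields.h"
#include "gemma/configs.h"
#include "hwy/aligned_allocator.h"  // hwy::Span

void RoundTripSketch(const gcpp::ModelConfig& config) {
  const std::vector<uint32_t> blob = config.Write();
  gcpp::ModelConfig copy;
  const gcpp::IFields::ReadResult result =
      copy.Read(hwy::Span<const uint32_t>(blob), /*pos=*/0);
  // A clean round trip consumes the whole blob and reports no unknown or
  // missing fields.
  HWY_ASSERT(result.pos == blob.size());
  HWY_ASSERT(result.extra_u32 == 0 && result.missing_fields == 0);
  HWY_ASSERT(copy.TestEqual(config, /*print=*/true));
}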
|
||||
|
||||
} // namespace gcpp
|
||||
gemma/gemma-inl.h

@@ -25,7 +25,7 @@
#include <vector>

#include "gemma/activations.h"
#include "gemma/common.h"
#include "gemma/common.h"  // EmbeddingScaling
#include "gemma/configs.h"
#include "gemma/gemma.h"
#include "gemma/kv_cache.h"

@ -305,13 +305,7 @@ class GemmaAttention {
|
|||
const size_t kv_offset =
|
||||
cache_pos * cache_pos_size_ + layer_ * cache_layer_size_;
|
||||
float* HWY_RESTRICT kv = kv_cache.kv_cache.get() + kv_offset;
|
||||
if (layer_weights_.qkv_einsum_w.HasPtr()) {
|
||||
MatVec(layer_weights_.qkv_einsum_w, heads * qkv_dim * model_dim,
|
||||
w_rows_kv_cols, model_dim, x, kv, pool_);
|
||||
} else {
|
||||
MatVec(layer_weights_.qkv_einsum_w2, 0, //
|
||||
w_rows_kv_cols, model_dim, x, kv, pool_);
|
||||
}
|
||||
MatVec(w_q2, w_q2.ofs, w_rows_kv_cols, model_dim, x, kv, pool_);
|
||||
}
|
||||
}
|
||||
} // !is_mha_
|
||||
|
|
@ -781,7 +775,6 @@ template <typename T>
|
|||
HWY_NOINLINE void FFWNoVit(Activations& activations, size_t num_interleaved,
|
||||
const LayerWeightsPtrs<T>* layer_weights) {
|
||||
PROFILER_ZONE("Gen.FFW");
|
||||
const size_t model_dim = layer_weights->layer_config.model_dim;
|
||||
const size_t ffh_hidden_dim = layer_weights->layer_config.ff_hidden_dim;
|
||||
HWY_DASSERT(num_interleaved <= activations.bf_pre_ffw_rms_out.BatchSize());
|
||||
|
||||
|
|
@@ -917,8 +910,16 @@ HWY_NOINLINE void EmbedMMToken(int token, size_t batch_idx, size_t pos,
  HWY_DASSERT(token < static_cast<int>(vocab_size));

  const hn::ScalableTag<float> df;
  DecompressAndZeroPad(df, weights.embedder_input_embedding.Span(),
                       token * model_dim, x.Batch(batch_idx), model_dim);
  // Using `Stride` to compute the offset works for both NUQ (because we use an
  // offset and NUQ is never padded) and padded, because non-NUQ types are
  // seekable, hence the offset can also skip any padding.
  const size_t embedding_ofs =
      token * weights.embedder_input_embedding.Stride();
  HWY_ASSERT(weights.embedder_input_embedding.Cols() == model_dim);
  const auto embedding_span = MakeSpan(weights.embedder_input_embedding.Row(0),
                                       embedding_ofs + model_dim);
  DecompressAndZeroPad(df, embedding_span, embedding_ofs, x.Batch(batch_idx),
                       model_dim);
  MulByConst(emb_scaling * weights.embedder_input_embedding.Scale(),
             x.Batch(batch_idx), model_dim);
  if (weights.weights_config.absolute_pe) {
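The hunk above replaces a packed `token * model_dim` offset with a stride-based one. A minimal standalone sketch of the distinction (hypothetical helper names, not gemma.cpp API):

#include <cstddef>

// Packed layout: rows are contiguous, so the offset of row `r` is r * cols.
size_t PackedRowOffset(size_t r, size_t cols) { return r * cols; }

// Padded layout: each row occupies `stride >= cols` elements, so the offset
// must use the stride in order to also skip any per-row padding.
size_t StridedRowOffset(size_t r, size_t stride) { return r * stride; }
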
@ -1128,7 +1129,7 @@ HWY_NOINLINE void Prefill(
|
|||
// Transformer with one batch of tokens from a single query.
|
||||
for (size_t layer = 0;
|
||||
layer < weights.weights_config.layer_configs.size(); ++layer) {
|
||||
const auto* layer_weights = weights.GetLayer(layer);
|
||||
const LayerWeightsPtrs<T>* layer_weights = weights.GetLayer(layer);
|
||||
TransformerLayer(single_query_pos, single_query_prefix_end, tbatch_size,
|
||||
layer, layer_weights, activations, div_seq_len,
|
||||
single_kv_cache);
|
||||
|
|
@ -1222,7 +1223,7 @@ HWY_NOINLINE void PrefillVit(const ModelWeightsPtrs<T>& weights,
|
|||
for (size_t layer = 0;
|
||||
layer < weights.weights_config.vit_config.layer_configs.size();
|
||||
++layer) {
|
||||
const auto* layer_weights = weights.GetVitLayer(layer);
|
||||
const LayerWeightsPtrs<T>* layer_weights = weights.VitLayer(layer);
|
||||
VitTransformerLayer(num_tokens, layer, layer_weights, activations);
|
||||
}
|
||||
// Final Layernorm.
|
||||
|
|
@ -1359,7 +1360,7 @@ HWY_INLINE SampleFunc ChooseSampleFunc(const RuntimeConfig& runtime_config) {
|
|||
template <typename T>
|
||||
// Runs one decode step for all the queries in the batch. Returns true if all
|
||||
// queries are at <end_of_sentence>.
|
||||
bool DecodeStepT(const ModelWeightsPtrs<T>& weights,
|
||||
bool DecodeStepT(const ModelConfig& config, const ModelWeightsPtrs<T>& weights,
|
||||
const RuntimeConfig& runtime_config,
|
||||
const QueriesPromptTokens& queries_prompt,
|
||||
const size_t query_idx_start, const KVCaches& kv_caches,
|
||||
|
|
@ -1398,7 +1399,7 @@ bool DecodeStepT(const ModelWeightsPtrs<T>& weights,
|
|||
token_streamer(query_idx_start + query_idx,
|
||||
queries_mutable_pos[query_idx], tp.token, tp.prob);
|
||||
all_queries_eos &= is_eos;
|
||||
gen_tokens[query_idx] = is_eos ? runtime_config.eos_id : tp.token;
|
||||
gen_tokens[query_idx] = is_eos ? config.eos_id : tp.token;
|
||||
}
|
||||
return all_queries_eos;
|
||||
}
|
||||
|
|
@ -1415,8 +1416,8 @@ bool DecodeStepT(const ModelWeightsPtrs<T>& weights,
|
|||
//
|
||||
// `kv_caches` is for the batch, size must match `queries_prompt`.
|
||||
template <typename T>
|
||||
void GenerateT(const ModelWeightsStorage& model, Activations& activations,
|
||||
const RuntimeConfig& runtime_config,
|
||||
void GenerateT(const ModelStore2& model, const ModelWeightsPtrs<T>& weights,
|
||||
Activations& activations, const RuntimeConfig& runtime_config,
|
||||
const QueriesPromptTokens& queries_prompt,
|
||||
const QueriesPos& queries_pos_in,
|
||||
const QueriesPos& queries_prefix_end,
|
||||
|
|
@ -1438,7 +1439,7 @@ void GenerateT(const ModelWeightsStorage& model, Activations& activations,
|
|||
// Sanity check: prompts should not be empty, nor start with EOS.
|
||||
for (size_t query_idx = 0; query_idx < queries_prompt.size(); ++query_idx) {
|
||||
const PromptTokens& prompt = queries_prompt[query_idx];
|
||||
HWY_ASSERT(prompt.size() != 0 && prompt[0] != runtime_config.eos_id);
|
||||
HWY_ASSERT(prompt.size() != 0 && prompt[0] != model.Config().eos_id);
|
||||
}
|
||||
|
||||
const size_t num_queries = queries_prompt.size();
|
||||
|
|
@ -1447,7 +1448,6 @@ void GenerateT(const ModelWeightsStorage& model, Activations& activations,
|
|||
HWY_ASSERT(queries_pos_in.size() == num_queries);
|
||||
HWY_ASSERT(kv_caches.size() == num_queries);
|
||||
const hwy::Divisor div_seq_len(static_cast<uint32_t>(kv_caches[0].seq_len));
|
||||
const ModelWeightsPtrs<T>& weights = *model.GetWeightsOfType<T>();
|
||||
size_t max_prompt_size = MaxQueryLength(queries_prompt);
|
||||
size_t max_generated_tokens = runtime_config.max_generated_tokens;
|
||||
RangeChecks(weights.weights_config, max_generated_tokens, max_prompt_size);
|
||||
|
|
@ -1497,9 +1497,9 @@ void GenerateT(const ModelWeightsStorage& model, Activations& activations,
|
|||
timing_info.generate_start = hwy::platform::Now();
|
||||
for (size_t gen = 0; gen < max_generated_tokens; ++gen) {
|
||||
bool all_queries_eos = DecodeStepT<T>(
|
||||
weights, runtime_config, queries_prompt, query_idx_start, kv_caches,
|
||||
queries_prefix_end, div_seq_len, vocab_size, sample_token,
|
||||
activations, token_streamer, gen_tokens,
|
||||
model.Config(), weights, runtime_config, queries_prompt,
|
||||
query_idx_start, kv_caches, queries_prefix_end, div_seq_len,
|
||||
vocab_size, sample_token, activations, token_streamer, gen_tokens,
|
||||
timing_info, queries_mutable_pos);
|
||||
if (all_queries_eos) break;
|
||||
} // foreach token to generate
|
||||
|
|
@ -1508,7 +1508,8 @@ void GenerateT(const ModelWeightsStorage& model, Activations& activations,
|
|||
}
|
||||
|
||||
template <typename T>
|
||||
void GenerateSingleT(const ModelWeightsStorage& model,
|
||||
void GenerateSingleT(const ModelStore2& model,
|
||||
const ModelWeightsPtrs<T>& weights,
|
||||
const RuntimeConfig& runtime_config,
|
||||
const PromptTokens& prompt, size_t pos, size_t prefix_end,
|
||||
KVCache& kv_cache, MatMulEnv* env,
|
||||
|
|
@ -1525,12 +1526,14 @@ void GenerateSingleT(const ModelWeightsStorage& model,
|
|||
const QueriesPos queries_prefix_end(&prefix_end, kNumQueries);
|
||||
const KVCaches kv_caches{&kv_cache, kNumQueries};
|
||||
|
||||
GenerateT<T>(model, activations, runtime_config, queries_prompt, queries_pos,
|
||||
queries_prefix_end, qbatch_start, kv_caches, timing_info);
|
||||
GenerateT<T>(model, weights, activations, runtime_config, queries_prompt,
|
||||
queries_pos, queries_prefix_end, qbatch_start, kv_caches,
|
||||
timing_info);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void GenerateBatchT(const ModelWeightsStorage& model,
|
||||
void GenerateBatchT(const ModelStore2& model,
|
||||
const ModelWeightsPtrs<T>& weights,
|
||||
const RuntimeConfig& runtime_config,
|
||||
const QueriesPromptTokens& queries_prompt,
|
||||
const QueriesPos& queries_pos,
|
||||
|
|
@ -1542,7 +1545,7 @@ void GenerateBatchT(const ModelWeightsStorage& model,
|
|||
HWY_ASSERT(kv_caches.size() == num_queries);
|
||||
// Griffin does not support query batching.
|
||||
size_t max_qbatch_size = runtime_config.decode_qbatch_size;
|
||||
for (const auto& layer_config : model.Config().layer_configs) {
|
||||
for (const LayerConfig& layer_config : model.Config().layer_configs) {
|
||||
if (layer_config.type == LayerAttentionType::kGriffinRecurrentBlock) {
|
||||
max_qbatch_size = 1;
|
||||
break;
|
||||
|
|
@ -1563,13 +1566,15 @@ void GenerateBatchT(const ModelWeightsStorage& model,
|
|||
const QueriesPos qbatch_prefix_end(&queries_prefix_end[qbatch_start],
|
||||
qbatch_size);
|
||||
const KVCaches qbatch_kv(&kv_caches[qbatch_start], qbatch_size);
|
||||
GenerateT<T>(model, activations, runtime_config, qbatch_prompts, qbatch_pos,
|
||||
qbatch_prefix_end, qbatch_start, qbatch_kv, timing_info);
|
||||
GenerateT<T>(model, weights, activations, runtime_config, qbatch_prompts,
|
||||
qbatch_pos, qbatch_prefix_end, qbatch_start, qbatch_kv,
|
||||
timing_info);
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void GenerateImageTokensT(const ModelWeightsStorage& model,
|
||||
void GenerateImageTokensT(const ModelStore2& model,
|
||||
const ModelWeightsPtrs<T>& weights,
|
||||
const RuntimeConfig& runtime_config,
|
||||
const Image& image, ImageTokens& image_tokens,
|
||||
MatMulEnv* env) {
|
||||
|
|
@ -1583,8 +1588,8 @@ void GenerateImageTokensT(const ModelWeightsStorage& model,
|
|||
Activations prefill_activations(vit_config);
|
||||
prefill_activations.Allocate(vit_config.seq_len, env);
|
||||
// Weights are for the full PaliGemma model, not just the ViT part.
|
||||
PrefillVit(*model.GetWeightsOfType<T>(), prefill_runtime_config, image,
|
||||
image_tokens, prefill_activations);
|
||||
PrefillVit(weights, prefill_runtime_config, image, image_tokens,
|
||||
prefill_activations);
|
||||
}
|
||||
|
||||
} // namespace HWY_NAMESPACE
|
||||
|
|
@ -1592,33 +1597,34 @@ void GenerateImageTokensT(const ModelWeightsStorage& model,
|
|||
#if HWY_ONCE
|
||||
|
||||
// These are extern functions defined by instantiations/*.cc, which include this
|
||||
// 'header' after defining GEMMA_CONFIG, which is for function overloading.
|
||||
// 'header' after defining `GEMMA_TYPE`.
|
||||
void GenerateSingle( // NOLINT(misc-definitions-in-headers)
|
||||
GEMMA_TYPE, const ModelWeightsStorage& model,
|
||||
const ModelStore2& model, const ModelWeightsPtrs<GEMMA_TYPE>& weights,
|
||||
const RuntimeConfig& runtime_config, const PromptTokens& prompt, size_t pos,
|
||||
size_t prefix_end, KVCache& kv_cache, MatMulEnv* env,
|
||||
TimingInfo& timing_info) {
|
||||
HWY_EXPORT_AND_DYNAMIC_DISPATCH_T(GenerateSingleT<GEMMA_TYPE>)
|
||||
(model, runtime_config, prompt, pos, prefix_end, kv_cache, env, timing_info);
|
||||
(model, weights, runtime_config, prompt, pos, prefix_end, kv_cache, env,
|
||||
timing_info);
|
||||
}
|
||||
|
||||
void GenerateBatch( // NOLINT(misc-definitions-in-headers)
|
||||
GEMMA_TYPE, const ModelWeightsStorage& model,
|
||||
const ModelStore2& model, const ModelWeightsPtrs<GEMMA_TYPE>& weights,
|
||||
const RuntimeConfig& runtime_config,
|
||||
const QueriesPromptTokens& queries_prompt, const QueriesPos& queries_pos,
|
||||
const QueriesPos& queries_prefix_end, const KVCaches& kv_caches,
|
||||
MatMulEnv* env, TimingInfo& timing_info) {
|
||||
HWY_EXPORT_AND_DYNAMIC_DISPATCH_T(GenerateBatchT<GEMMA_TYPE>)
|
||||
(model, runtime_config, queries_prompt, queries_pos, queries_prefix_end,
|
||||
kv_caches, env, timing_info);
|
||||
(model, weights, runtime_config, queries_prompt, queries_pos,
|
||||
queries_prefix_end, kv_caches, env, timing_info);
|
||||
}
|
||||
|
||||
void GenerateImageTokens( // NOLINT(misc-definitions-in-headers)
|
||||
GEMMA_TYPE, const ModelWeightsStorage& model,
|
||||
const ModelStore2& model, const ModelWeightsPtrs<GEMMA_TYPE>& weights,
|
||||
const RuntimeConfig& runtime_config, const Image& image,
|
||||
ImageTokens& image_tokens, MatMulEnv* env) {
|
||||
HWY_EXPORT_AND_DYNAMIC_DISPATCH_T(GenerateImageTokensT<GEMMA_TYPE>)
|
||||
(model, runtime_config, image, image_tokens, env);
|
||||
(model, weights, runtime_config, image, image_tokens, env);
|
||||
}
|
||||
|
||||
#endif // HWY_ONCE
|
||||
|
gemma/gemma.cc
@@ -23,14 +23,16 @@
#include <stdlib.h>
#include <string.h>

#include <string>
#include <memory>
#include <utility>  // std::move
#include <vector>

// Placeholder for internal header, do not modify.
#include "compression/blob_store.h"
#include "compression/io.h"  // Path
#include "compression/shared.h"
#include "gemma/common.h"
#include "gemma/configs.h"
#include "gemma/model_store.h"
#include "gemma/tokenizer.h"
#include "gemma/weights.h"
#include "ops/matmul.h"

@@ -40,8 +42,8 @@

namespace gcpp {

// Internal init must run before I/O; calling it from `GemmaEnv()` is too late.
// This helper function takes care of the internal init plus calling `SetArgs`.
// Internal init must run before I/O. This helper function takes care of that,
// plus calling `SetArgs`.
MatMulEnv MakeMatMulEnv(const ThreadingArgs& threading_args) {
  // Placeholder for internal init, do not modify.

@@ -49,102 +51,72 @@ MatMulEnv MakeMatMulEnv(const ThreadingArgs& threading_args) {
  return MatMulEnv(ThreadingContext2::Get());
}

Gemma::Gemma(const Path& tokenizer_path, const Path& weights,
             const ModelInfo& info, MatMulEnv& env)
    : env_(env), tokenizer_(tokenizer_path) {
  model_.Load(weights, info.model, info.weight, info.wrapping,
              env_.ctx.pools.Pool(0),
              /*tokenizer_proto=*/nullptr);
  chat_template_.Init(tokenizer_, model_.Config().model);
}

Gemma::Gemma(const Path& weights, MatMulEnv& env) : env_(env) {
  std::string tokenizer_proto;
  model_.Load(weights, Model::UNKNOWN, Type::kUnknown, PromptWrapping::GEMMA_IT,
              env_.ctx.pools.Pool(0), &tokenizer_proto);
  tokenizer_.Deserialize(tokenizer_proto);
  chat_template_.Init(tokenizer_, model_.Config().model);
}

Gemma::Gemma(GemmaTokenizer&& tokenizer, const ModelInfo& info, MatMulEnv& env)
Gemma::Gemma(const LoaderArgs& loader, MatMulEnv& env)
    : env_(env),
      tokenizer_(std::move(tokenizer)),
      chat_template_(tokenizer_, info.model) {
  HWY_ASSERT(info.weight == Type::kF32);
  model_.Allocate(info.model, info.weight, env_.ctx.pools.Pool(0));
      reader_(BlobReader2::Make(loader.weights, loader.map)),
      model_(*reader_, loader.tokenizer, loader.wrapping),
      weights_(model_.Config().weight),
      chat_template_(model_.Tokenizer(), model_.Config().model) {
  weights_.ReadOrAllocate(model_, *reader_, env_.ctx.pools.Pool());
  reader_.reset();
}

Gemma::~Gemma() {
Gemma::Gemma(const ModelConfig& config, GemmaTokenizer&& tokenizer,
             MatMulEnv& env)
    : env_(env),
      model_(config, std::move(tokenizer)),
      weights_(config.weight),
      chat_template_(model_.Tokenizer(), model_.Config().model) {
  HWY_ASSERT(config.weight == Type::kF32);
  weights_.AllocateForTest(config, env_.ctx.pools.Pool(0));
}

Gemma::~Gemma() = default;

void Gemma::Save(const Path& weights_path, hwy::ThreadPool& pool) const {
  BlobWriter2 writer;
  const std::vector<uint32_t> serialized_mat_ptrs =
      weights_.AddTensorDataToWriter(writer);
  WriteSingleFile(model_.Config(), model_.Tokenizer(), serialized_mat_ptrs,
                  writer, env_.ctx.pools.Pool(), weights_path);
}

||||
// There are >=3 types of the inference code. To reduce compile time,
|
||||
// we shard them across multiple translation units in instantiations/*.cc.
|
||||
// This declares the functions defined there. We use overloading because
|
||||
// explicit instantiations are still too slow to compile.
|
||||
#define GEMMA_DECLARE(TWEIGHT) \
|
||||
extern void GenerateSingle(TWEIGHT, const ModelWeightsStorage& model, \
|
||||
const RuntimeConfig& runtime_config, \
|
||||
const PromptTokens& prompt, size_t pos, \
|
||||
size_t prefix_end, KVCache& kv_cache, \
|
||||
MatMulEnv* env, TimingInfo& timing_info); \
|
||||
// TODO: we want to move toward type-erasing, where we check the tensor type at
|
||||
// each usage. Then we would have a single function, passing `WeightsOwner`
|
||||
// instead of `WeightsPtrs<T>`.
|
||||
#define GEMMA_DECLARE(WEIGHT_TYPE) \
|
||||
extern void GenerateSingle( \
|
||||
const ModelStore2& model, const ModelWeightsPtrs<WEIGHT_TYPE>& weights, \
|
||||
const RuntimeConfig& runtime_config, const PromptTokens& prompt, \
|
||||
size_t pos, size_t prefix_end, KVCache& kv_cache, MatMulEnv* env, \
|
||||
TimingInfo& timing_info); \
|
||||
extern void GenerateBatch( \
|
||||
TWEIGHT, const ModelWeightsStorage& model, \
|
||||
const ModelStore2& model, const ModelWeightsPtrs<WEIGHT_TYPE>& weights, \
|
||||
const RuntimeConfig& runtime_config, const QueriesPromptTokens& prompts, \
|
||||
const QueriesPos& queries_pos, const QueriesPos& queries_prefix_end, \
|
||||
const KVCaches& kv_caches, MatMulEnv* env, TimingInfo& timing_info); \
|
||||
extern void GenerateImageTokens(TWEIGHT, const ModelWeightsStorage& model, \
|
||||
const RuntimeConfig& runtime_config, \
|
||||
const Image& image, \
|
||||
ImageTokens& image_tokens, MatMulEnv* env);
|
||||
extern void GenerateImageTokens( \
|
||||
const ModelStore2& model, const ModelWeightsPtrs<WEIGHT_TYPE>& weights, \
|
||||
const RuntimeConfig& runtime_config, const Image& image, \
|
||||
ImageTokens& image_tokens, MatMulEnv* env);
|
||||
GEMMA_DECLARE(float)
|
||||
GEMMA_DECLARE(BF16)
|
||||
GEMMA_DECLARE(NuqStream)
|
||||
GEMMA_DECLARE(SfpStream)
|
||||
|
||||
// Adapters to select from the above overloads via CallForModelWeight.
|
||||
template <class TConfig>
|
||||
struct GenerateSingleT {
|
||||
void operator()(const ModelWeightsStorage& model,
|
||||
const RuntimeConfig& runtime_config,
|
||||
const PromptTokens& prompt, size_t pos, size_t prefix_end,
|
||||
KVCache& kv_cache, MatMulEnv* env,
|
||||
TimingInfo& timing_info) const {
|
||||
GenerateSingle(TConfig(), model, runtime_config, prompt, pos, prefix_end,
|
||||
kv_cache, env, timing_info);
|
||||
}
|
||||
};
|
||||
|
||||
template <class TConfig>
|
||||
struct GenerateBatchT {
|
||||
void operator()(const ModelWeightsStorage& model,
|
||||
const RuntimeConfig& runtime_config,
|
||||
const QueriesPromptTokens& queries_prompt,
|
||||
const QueriesPos& queries_pos,
|
||||
const QueriesPos& queries_prefix_end,
|
||||
const KVCaches& kv_caches, MatMulEnv* env,
|
||||
TimingInfo& timing_info) const {
|
||||
GenerateBatch(TConfig(), model, runtime_config, queries_prompt, queries_pos,
|
||||
queries_prefix_end, kv_caches, env, timing_info);
|
||||
}
|
||||
};
|
||||
|
||||
template <class TConfig>
|
||||
struct GenerateImageTokensT {
|
||||
void operator()(const ModelWeightsStorage& model,
|
||||
const RuntimeConfig& runtime_config, const Image& image,
|
||||
ImageTokens& image_tokens, MatMulEnv* env) const {
|
||||
GenerateImageTokens(TConfig(), model, runtime_config, image, image_tokens,
|
||||
env);
|
||||
}
|
||||
};
|
||||
|
||||
void Gemma::Generate(const RuntimeConfig& runtime_config,
|
||||
const PromptTokens& prompt, size_t pos, size_t prefix_end,
|
||||
KVCache& kv_cache, TimingInfo& timing_info) {
|
||||
KVCache& kv_cache, TimingInfo& timing_info) const {
|
||||
env_.ctx.pools.MaybeStartSpinning(runtime_config.use_spinning);
|
||||
|
||||
model_.CallForModelWeight<GenerateSingleT>(
|
||||
runtime_config, prompt, pos, prefix_end, kv_cache, &env_, timing_info);
|
||||
weights_.CallT([&](auto& weights) {
|
||||
GenerateSingle(model_, *weights, runtime_config, prompt, pos, prefix_end,
|
||||
kv_cache, &env_, timing_info);
|
||||
});
|
||||
|
||||
env_.ctx.pools.MaybeStopSpinning(runtime_config.use_spinning);
|
||||
}
|
||||
|
|
@ -153,11 +125,12 @@ void Gemma::GenerateBatch(const RuntimeConfig& runtime_config,
|
|||
const QueriesPromptTokens& queries_prompt,
|
||||
const QueriesPos& queries_pos,
|
||||
const QueriesPos& queries_prefix_end,
|
||||
const KVCaches& kv_caches, TimingInfo& timing_info) {
|
||||
const KVCaches& kv_caches,
|
||||
TimingInfo& timing_info) const {
|
||||
// If we did not get passed prefix ends (size 0), assume 0 and pass that on.
|
||||
QueriesPos mutable_queries_prefix_end = queries_prefix_end;
|
||||
std::vector<size_t> prefix_end_vec;
|
||||
if (queries_prefix_end.size() == 0) {
|
||||
if (queries_prefix_end.size() == 0) { // hwy::Span lacks empty()
|
||||
prefix_end_vec.resize(queries_prompt.size(), 0);
|
||||
mutable_queries_prefix_end =
|
||||
QueriesPos(prefix_end_vec.data(), prefix_end_vec.size());
|
||||
|
|
@ -165,36 +138,26 @@ void Gemma::GenerateBatch(const RuntimeConfig& runtime_config,
|
|||
|
||||
env_.ctx.pools.MaybeStartSpinning(runtime_config.use_spinning);
|
||||
|
||||
model_.CallForModelWeight<GenerateBatchT>(
|
||||
runtime_config, queries_prompt, queries_pos, mutable_queries_prefix_end,
|
||||
kv_caches, &env_, timing_info);
|
||||
weights_.CallT([&](auto& weights) {
|
||||
gcpp::GenerateBatch(model_, *weights, runtime_config, queries_prompt,
|
||||
queries_pos, mutable_queries_prefix_end, kv_caches,
|
||||
&env_, timing_info);
|
||||
});
|
||||
|
||||
env_.ctx.pools.MaybeStopSpinning(runtime_config.use_spinning);
|
||||
}
|
||||
|
||||
void Gemma::GenerateImageTokens(const RuntimeConfig& runtime_config,
|
||||
const Image& image, ImageTokens& image_tokens) {
|
||||
const Image& image,
|
||||
ImageTokens& image_tokens) const {
|
||||
env_.ctx.pools.MaybeStartSpinning(runtime_config.use_spinning);
|
||||
|
||||
model_.CallForModelWeight<GenerateImageTokensT>(runtime_config, image,
|
||||
image_tokens, &env_);
|
||||
weights_.CallT([&](auto& weights) {
|
||||
gcpp::GenerateImageTokens(model_, *weights, runtime_config, image,
|
||||
image_tokens, &env_);
|
||||
});
|
||||
|
||||
env_.ctx.pools.MaybeStopSpinning(runtime_config.use_spinning);
|
||||
}
|
||||
|
||||
// Non-template functions moved from gemma-inl.h to avoid ODR violations.
|
||||
|
||||
void RangeChecks(const ModelConfig& weights_config,
|
||||
size_t& max_generated_tokens, const size_t prompt_size) {
|
||||
if (!weights_config.use_local_attention) {
|
||||
if (max_generated_tokens > weights_config.seq_len) {
|
||||
fprintf(stderr,
|
||||
"WARNING: max_generated_tokens %zu > kSeqLen %u, truncating.\n",
|
||||
max_generated_tokens, weights_config.seq_len);
|
||||
max_generated_tokens = weights_config.seq_len;
|
||||
}
|
||||
}
|
||||
HWY_ASSERT(prompt_size > 0);
|
||||
}
|
||||
|
||||
} // namespace gcpp
|
||||
|
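Taken together, the gemma.cc changes above replace CreateGemma/AllocateGemma with direct construction from LoaderArgs. A minimal usage sketch, assuming only the declarations introduced in this commit (LoaderArgs, MakeMatMulEnv, Gemma, RuntimeConfig); KVCache construction is model-specific and omitted here, and the argv-style ThreadingArgs constructor is an assumption:

#include <random>
#include <vector>

#include "gemma/gemma.h"

void GenerateOnce(int argc, char** argv, gcpp::KVCache& kv_cache,
                  const std::vector<int>& prompt_tokens) {
  gcpp::ThreadingArgs threading(argc, argv);  // assumed Args-style ctor
  gcpp::LoaderArgs loader(argc, argv);        // --weights, optional --tokenizer
  gcpp::MatMulEnv env = gcpp::MakeMatMulEnv(threading);
  gcpp::Gemma gemma(loader, env);             // deduces model/weight type

  std::mt19937 gen(42);
  gcpp::RuntimeConfig runtime;
  runtime.max_generated_tokens = 64;
  runtime.temperature = 1.0f;
  runtime.gen = &gen;
  runtime.verbosity = 0;
  runtime.stream_token = [](int /*token*/, float /*prob*/) { return true; };

  gcpp::TimingInfo timing;
  const gcpp::PromptTokens prompt(prompt_tokens.data(), prompt_tokens.size());
  gemma.Generate(runtime, prompt, /*pos=*/0, kv_cache, timing);
}
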
gemma/gemma.h
@ -18,18 +18,16 @@
|
|||
|
||||
#include <stdio.h>
|
||||
|
||||
#include <functional>
|
||||
#include <random>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
#include <memory>
|
||||
|
||||
// IWYU pragma: begin_exports
|
||||
#include "compression/blob_store.h"
|
||||
#include "compression/io.h" // Path
|
||||
#include "gemma/activations.h"
|
||||
#include "gemma/common.h"
|
||||
#include "gemma/configs.h"
|
||||
#include "gemma/gemma_args.h"
|
||||
#include "gemma/kv_cache.h"
|
||||
#include "gemma/tokenizer.h"
|
||||
#include "gemma/model_store.h"
|
||||
#include "gemma/weights.h"
|
||||
#include "ops/matmul.h" // MatMulEnv
|
||||
#include "paligemma/image.h"
|
||||
|
|
@ -38,104 +36,8 @@
|
|||
#include "util/threading_context.h"
|
||||
#include "hwy/timer.h"
|
||||
// IWYU pragma: end_exports
|
||||
#include "hwy/aligned_allocator.h" // Span
|
||||
|
||||
namespace gcpp {
|
||||
using PromptTokens = hwy::Span<const int>;
|
||||
|
||||
// Batches of independent queries have their own prompt, previous token,
|
||||
// position in the sequence, and KVCache.
|
||||
using QueriesPromptTokens = hwy::Span<const PromptTokens>;
|
||||
using QueriesToken = hwy::Span<const int>;
|
||||
using QueriesPos = hwy::Span<const size_t>;
|
||||
using KVCaches = hwy::Span<KVCache>;
|
||||
|
||||
// StreamFunc is called with (token, probability). For prompt tokens,
|
||||
// probability is 0.0f. StreamFunc should return false to stop generation and
|
||||
// true to continue generation.
|
||||
using StreamFunc = std::function<bool(int, float)>;
|
||||
// BatchStreamFunc is called with (query_idx, pos, token, probability).
|
||||
// For prompt tokens, probability is 0.0f.
|
||||
// StreamFunc should return false to stop generation and true to continue.
|
||||
using BatchStreamFunc = std::function<bool(size_t, size_t, int, float)>;
|
||||
// If not empty, AcceptFunc is called with token. It should return false for
|
||||
// tokens you don't want to generate and true for tokens you want to generate.
|
||||
using AcceptFunc = std::function<bool(int, float)>;
|
||||
// If not empty, SampleFunc is called with the logits for the next token, which
|
||||
// it may modify/overwrite, and its return value is the next generated token
|
||||
// together with its probability.
|
||||
using SampleFunc = std::function<TokenAndProb(float*, size_t)>;
|
||||
// If not empty, LayersOutputFunc is called for layer outputs, specified with:
|
||||
// - index of query within containing batch (if any); zero otherwise.
|
||||
// - position in the tokens sequence
|
||||
// - name of the data, e.g. "tokens" for token IDs
|
||||
// - layer index (or -1 for global outputs)
|
||||
// - pointer to the data array
|
||||
// - size of the data array
|
||||
using LayersOutputFunc = std::function<void(size_t, size_t, const std::string&,
|
||||
int, const float*, size_t)>;
|
||||
// If not empty, ActivationsObserverFunc is invoked after each layer with:
|
||||
// - per-query position within the tokens sequence
|
||||
// - layer index (or -1 for post-norm output)
|
||||
// - activations
|
||||
using ActivationsObserverFunc =
|
||||
std::function<void(const QueriesPos& queries_pos, int, const Activations&)>;
|
||||
|
||||
// ImageTokens are represented as a RowVectorBatch, where each "batch" index
|
||||
// corresponds to a token for an image patch as computed by the image encoder.
|
||||
using ImageTokens = RowVectorBatch<float>;
|
||||
|
||||
// RuntimeConfig holds configuration for a single generation run.
|
||||
struct RuntimeConfig {
|
||||
// If not empty, batch_stream_token is called for each token in the batch,
|
||||
// instead of stream_token.
|
||||
bool StreamToken(size_t query_idx, size_t pos, int token, float prob) const {
|
||||
if (batch_stream_token) {
|
||||
return batch_stream_token(query_idx, pos, token, prob);
|
||||
}
|
||||
return stream_token(token, prob);
|
||||
}
|
||||
|
||||
// Limit on the number of tokens generated.
|
||||
size_t max_generated_tokens;
|
||||
|
||||
// These defaults are overridden by InferenceArgs::CopyTo(*this):
|
||||
// Max tokens per batch during prefill.
|
||||
size_t prefill_tbatch_size = 256;
|
||||
// Max queries per batch (one token from each) during decode.
|
||||
size_t decode_qbatch_size = 16;
|
||||
|
||||
// Sampling-related parameters.
|
||||
float temperature; // Temperature for sampling.
|
||||
size_t top_k = kTopK; // Top-k for sampling.
|
||||
std::mt19937* gen; // Random number generator used for sampling.
|
||||
|
||||
int verbosity; // Controls verbosity of printed messages.
|
||||
|
||||
// Functions operating on the generated tokens.
|
||||
StreamFunc stream_token;
|
||||
BatchStreamFunc batch_stream_token;
|
||||
AcceptFunc accept_token; // if empty, accepts all tokens.
|
||||
SampleFunc sample_func; // if empty, uses SampleTopK.
|
||||
|
||||
// Observer callbacks for intermediate data.
|
||||
LayersOutputFunc layers_output; // if not empty, called after each layer.
|
||||
ActivationsObserverFunc activations_observer; // if set, called per-layer.
|
||||
|
||||
// If not empty, these point to the image tokens and are used in the
|
||||
// PaliGemma prefix-LM style attention.
|
||||
const ImageTokens *image_tokens = nullptr;
|
||||
|
||||
// Whether to use thread spinning to reduce barrier synchronization latency.
|
||||
// Mutable so we can change kDefault to kTrue/kFalse during Generate, because
|
||||
// RuntimeConfig is const there and is not passed to the Gemma ctor. This
|
||||
// default decision is likely sufficient because it is based on whether
|
||||
// threads are successfully pinned.
|
||||
mutable Tristate use_spinning = Tristate::kDefault;
|
||||
|
||||
// End-of-sequence token.
|
||||
int eos_id = EOS_ID;
|
||||
};
|
||||
|
||||
struct TimingInfo {
|
||||
// be sure to populate prefill_start before calling NotifyPrefill.
|
||||
|
|
@ -196,58 +98,52 @@ struct TimingInfo {
|
|||
size_t tokens_generated = 0;
|
||||
};
|
||||
|
||||
// Internal init must run before I/O; calling it from GemmaEnv() is too late.
|
||||
// This helper function takes care of the internal init plus calling `SetArgs`.
|
||||
MatMulEnv MakeMatMulEnv(const ThreadingArgs& threading_args);
|
||||
|
||||
using KVCaches = hwy::Span<KVCache>;
|
||||
|
||||
class Gemma {
|
||||
public:
|
||||
// Reads old format weights file and tokenizer file.
|
||||
// Reads weights/config/tokenizer from the `BlobStore` at `loader.weights`.
|
||||
// `env` must remain valid for the lifetime of this Gemma.
|
||||
Gemma(const Path& tokenizer_path, const Path& weights, const ModelInfo& info,
|
||||
MatMulEnv& env);
|
||||
// Reads new format weights file that contains everything in a single file.
|
||||
Gemma(const LoaderArgs& loader, MatMulEnv& env);
|
||||
|
||||
// Only allocates weights, caller is responsible for filling them. Only used
|
||||
// by `optimize_test.cc`.
|
||||
// `env` must remain valid for the lifetime of this Gemma.
|
||||
Gemma(const Path& weights, MatMulEnv& env);
|
||||
// Allocates weights, caller is responsible for filling them.
|
||||
Gemma(GemmaTokenizer&& tokenizer, const ModelInfo& info, MatMulEnv& env);
|
||||
Gemma(const ModelConfig& config, GemmaTokenizer&& tokenizer, MatMulEnv& env);
|
||||
|
||||
~Gemma();
|
||||
|
||||
MatMulEnv& Env() const { return env_; }
|
||||
// TODO: rename to Config()
|
||||
const ModelConfig& GetModelConfig() const { return model_.Config(); }
|
||||
// DEPRECATED
|
||||
ModelInfo Info() const {
|
||||
return ModelInfo({.model = model_.Config().model,
|
||||
.wrapping = model_.Config().wrapping,
|
||||
.weight = model_.Config().weight});
|
||||
}
|
||||
const GemmaTokenizer& Tokenizer() const { return tokenizer_; }
|
||||
const GemmaTokenizer& Tokenizer() const { return model_.Tokenizer(); }
|
||||
const WeightsOwner& Weights() const { return weights_; }
|
||||
const GemmaChatTemplate& ChatTemplate() const { return chat_template_; }
|
||||
const ModelWeightsStorage& Weights() const { return model_; }
|
||||
ModelWeightsStorage& MutableWeights() { return model_; }
|
||||
void Save(const Path& weights, hwy::ThreadPool& pool) {
|
||||
std::string tokenizer_proto = tokenizer_.Serialize();
|
||||
model_.Save(tokenizer_proto, weights, pool);
|
||||
}
|
||||
|
||||
// For tests.
|
||||
WeightsOwner& MutableWeights() { return weights_; }
|
||||
void Save(const Path& weights_path, hwy::ThreadPool& pool) const;
|
||||
|
||||
// `pos` is the position in the KV cache. Users are responsible for
|
||||
// incrementing it in the `*StreamFunc`, or setting to zero for single-turn.
|
||||
void Generate(const RuntimeConfig& runtime_config, const PromptTokens& prompt,
|
||||
size_t pos, KVCache& kv_cache, TimingInfo& timing_info) {
|
||||
size_t pos, KVCache& kv_cache, TimingInfo& timing_info) const {
|
||||
Generate(runtime_config, prompt, pos, /*prefix_end=*/0, kv_cache,
|
||||
timing_info);
|
||||
}
|
||||
// For prefix-LM style attention, we can pass the end of the prefix.
|
||||
void Generate(const RuntimeConfig& runtime_config, const PromptTokens& prompt,
|
||||
size_t pos, size_t prefix_end, KVCache& kv_cache,
|
||||
TimingInfo& timing_info);
|
||||
TimingInfo& timing_info) const;
|
||||
|
||||
// `queries_pos` are the positions in the KV cache. Users are responsible for
|
||||
// incrementing them in `BatchStreamFunc`, or setting to zero for single-turn.
|
||||
void GenerateBatch(const RuntimeConfig& runtime_config,
|
||||
const QueriesPromptTokens& queries_prompt,
|
||||
const QueriesPos& queries_pos, const KVCaches& kv_caches,
|
||||
TimingInfo& timing_info) {
|
||||
TimingInfo& timing_info) const {
|
||||
GenerateBatch(runtime_config, queries_prompt, queries_pos,
|
||||
/*queries_prefix_end=*/{}, kv_caches, timing_info);
|
||||
}
|
||||
|
|
@ -256,19 +152,18 @@ class Gemma {
|
|||
const QueriesPromptTokens& queries_prompt,
|
||||
const QueriesPos& queries_pos,
|
||||
const QueriesPos& queries_prefix_end,
|
||||
const KVCaches& kv_caches, TimingInfo& timing_info);
|
||||
const KVCaches& kv_caches, TimingInfo& timing_info) const;
|
||||
|
||||
// Generates the image tokens by running the image encoder ViT.
|
||||
void GenerateImageTokens(const RuntimeConfig& runtime_config,
|
||||
const Image& image, ImageTokens& image_tokens);
|
||||
const Image& image, ImageTokens& image_tokens) const;
|
||||
|
||||
private:
|
||||
MatMulEnv& env_;
|
||||
|
||||
GemmaTokenizer tokenizer_;
|
||||
std::unique_ptr<BlobReader2> reader_; // null for second ctor
|
||||
ModelStore2 model_;
|
||||
WeightsOwner weights_;
|
||||
GemmaChatTemplate chat_template_;
|
||||
// Type-erased so that this can be defined in the header.
|
||||
ModelWeightsStorage model_;
|
||||
};
|
||||
|
||||
void RangeChecks(const ModelConfig& weights_config,
|
||||
|
|
|
|||
|
|
@ -21,125 +21,144 @@
|
|||
#include <stddef.h>
|
||||
#include <stdio.h>
|
||||
|
||||
#include <memory>
|
||||
#include <functional>
|
||||
#include <random>
|
||||
#include <string>
|
||||
|
||||
#include "compression/io.h" // Path
|
||||
#include "compression/shared.h"
|
||||
#include "gemma/configs.h"
|
||||
#include "gemma/gemma.h" // For CreateGemma
|
||||
#include "ops/matmul.h"
|
||||
#include "ops/matmul.h" // MMStorage::kMax*
|
||||
#include "util/args.h"
|
||||
#include "util/basics.h" // Tristate
|
||||
#include "hwy/base.h" // HWY_ABORT
|
||||
#include "util/basics.h" // Tristate
|
||||
#include "hwy/aligned_allocator.h" // Span
|
||||
#include "hwy/base.h" // HWY_ABORT
|
||||
|
||||
namespace gcpp {
|
||||
|
||||
struct LoaderArgs : public ArgsBase<LoaderArgs> {
|
||||
LoaderArgs(int argc, char* argv[], bool validate = true) {
|
||||
InitAndParse(argc, argv);
|
||||
// Allow changing k parameter of `SampleTopK` as a compiler flag
|
||||
#ifndef GEMMA_TOPK
|
||||
#define GEMMA_TOPK 1
|
||||
#endif // !GEMMA_TOPK
|
||||
|
||||
if (validate) {
|
||||
if (const char* error = Validate()) {
|
||||
HWY_ABORT("Invalid args: %s", error);
|
||||
}
|
||||
}
|
||||
}
|
||||
LoaderArgs(const std::string& tokenizer_path, const std::string& weights_path,
|
||||
const std::string& model, bool validate = true) {
|
||||
struct LoaderArgs : public ArgsBase<LoaderArgs> {
|
||||
LoaderArgs(int argc, char* argv[]) { InitAndParse(argc, argv); }
|
||||
LoaderArgs(const std::string& tokenizer_path,
|
||||
const std::string& weights_path) {
|
||||
Init(); // Init sets to defaults, so assignments must come after Init().
|
||||
tokenizer.path = tokenizer_path;
|
||||
weights.path = weights_path;
|
||||
model_type_str = model;
|
||||
|
||||
if (validate) {
|
||||
if (const char* error = Validate()) {
|
||||
HWY_ABORT("Invalid args: %s", error);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
// Returns error string or nullptr if OK.
|
||||
const char* Validate() {
|
||||
if (weights.path.empty()) {
|
||||
return "Missing --weights flag, a file for the model weights.";
|
||||
}
|
||||
if (!weights.Exists()) {
|
||||
return "Can't open file specified with --weights flag.";
|
||||
}
|
||||
info_.model = Model::UNKNOWN;
|
||||
info_.wrapping = PromptWrapping::GEMMA_PT;
|
||||
info_.weight = Type::kUnknown;
|
||||
if (!model_type_str.empty()) {
|
||||
const char* err = ParseModelTypeAndWrapping(model_type_str, info_.model,
|
||||
info_.wrapping);
|
||||
if (err != nullptr) return err;
|
||||
}
|
||||
if (!weight_type_str.empty()) {
|
||||
const char* err = ParseType(weight_type_str, info_.weight);
|
||||
if (err != nullptr) return err;
|
||||
}
|
||||
if (!tokenizer.path.empty()) {
|
||||
if (!tokenizer.Exists()) {
|
||||
return "Can't open file specified with --tokenizer flag.";
|
||||
}
|
||||
}
|
||||
// model_type and tokenizer must be either both present or both absent.
|
||||
// Further checks happen on weight loading.
|
||||
if (model_type_str.empty() != tokenizer.path.empty()) {
|
||||
return "Missing or extra flags for model_type or tokenizer.";
|
||||
}
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
Path tokenizer;
|
||||
Path weights; // weights file location
|
||||
Path compressed_weights;
|
||||
std::string model_type_str;
|
||||
std::string weight_type_str;
|
||||
Tristate map;
|
||||
Tristate wrapping;
|
||||
|
||||
template <class Visitor>
|
||||
void ForEach(const Visitor& visitor) {
|
||||
visitor(tokenizer, "tokenizer", Path(),
|
||||
"Path name of tokenizer model file.");
|
||||
"Path name of tokenizer model; only required for pre-2025 format.");
|
||||
visitor(weights, "weights", Path(),
|
||||
"Path name of model weights (.sbs) file.\n Required argument.\n");
|
||||
visitor(compressed_weights, "compressed_weights", Path(),
|
||||
"Deprecated alias for --weights.");
|
||||
visitor(model_type_str, "model", std::string(),
|
||||
"Model type, see common.cc for valid values.\n");
|
||||
visitor(weight_type_str, "weight_type", std::string("sfp"),
|
||||
"Weight type\n f32 = float, bf16 = bfloat16, sfp = 8-bit SFP.");
|
||||
visitor(map, "map", Tristate::kDefault,
|
||||
"Enable memory-mapping? -1 = auto, 0 = no, 1 = yes.");
|
||||
visitor(wrapping, "wrapping", Tristate::kDefault,
|
||||
"Enable prompt wrapping? Specify 0 for pre-2025 format PT models.");
|
||||
}
|
||||
|
||||
// Uninitialized before Validate, must call after that.
|
||||
const ModelInfo& Info() const { return info_; }
|
||||
|
||||
private:
|
||||
ModelInfo info_;
|
||||
};
|
||||
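For callers that build LoaderArgs programmatically rather than from argv, a small sketch assuming the two-argument constructor shown above (header path and field access are assumptions; the 2025 single-file format embeds the tokenizer, so the tokenizer path may be left empty):

#include <string>

#include "gemma/gemma_args.h"  // assumed location of LoaderArgs

gcpp::LoaderArgs MakeSingleFileLoader(const std::string& weights_path) {
  // Single-file (2025) format: tokenizer is embedded, so pass an empty path.
  gcpp::LoaderArgs loader(/*tokenizer_path=*/"", weights_path);
  loader.map = gcpp::Tristate::kTrue;  // optionally request memory-mapping
  return loader;
}
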
|
||||
// `env` must remain valid for the lifetime of the Gemma.
|
||||
static inline Gemma CreateGemma(const LoaderArgs& loader, MatMulEnv& env) {
|
||||
if (Type::kUnknown == loader.Info().weight ||
|
||||
Model::UNKNOWN == loader.Info().model || loader.tokenizer.path.empty()) {
|
||||
// New weights file format doesn't need tokenizer path or model/weightinfo.
|
||||
return Gemma(loader.weights, env);
|
||||
}
|
||||
return Gemma(loader.tokenizer, loader.weights, loader.Info(), env);
|
||||
}
|
||||
using PromptTokens = hwy::Span<const int>;
|
||||
|
||||
// `env` must remain valid for the lifetime of the Gemma.
|
||||
static inline std::unique_ptr<Gemma> AllocateGemma(const LoaderArgs& loader,
|
||||
MatMulEnv& env) {
|
||||
if (Type::kUnknown == loader.Info().weight ||
|
||||
Model::UNKNOWN == loader.Info().model || loader.tokenizer.path.empty()) {
|
||||
// New weights file format doesn't need tokenizer path or model/weight info.
|
||||
return std::make_unique<Gemma>(loader.weights, env);
|
||||
// Batches of independent queries have their own prompt, previous token,
|
||||
// position in the sequence, and KVCache.
|
||||
using QueriesPromptTokens = hwy::Span<const PromptTokens>;
|
||||
using QueriesToken = hwy::Span<const int>;
|
||||
using QueriesPos = hwy::Span<const size_t>;
|
||||
|
||||
// ImageTokens are represented as a RowVectorBatch, where each "batch" index
|
||||
// corresponds to a token for an image patch as computed by the image encoder.
|
||||
using ImageTokens = RowVectorBatch<float>;
|
||||
|
||||
// StreamFunc is called with (token, probability). For prompt tokens,
|
||||
// probability is 0.0f. StreamFunc should return false to stop generation and
|
||||
// true to continue generation.
|
||||
using StreamFunc = std::function<bool(int, float)>;
|
||||
// BatchStreamFunc is called with (query_idx, pos, token, probability).
|
||||
// For prompt tokens, probability is 0.0f.
|
||||
// StreamFunc should return false to stop generation and true to continue.
|
||||
using BatchStreamFunc = std::function<bool(size_t, size_t, int, float)>;
|
||||
// If not empty, AcceptFunc is called with token. It should return false for
|
||||
// tokens you don't want to generate and true for tokens you want to generate.
|
||||
using AcceptFunc = std::function<bool(int, float)>;
|
||||
// If not empty, SampleFunc is called with the logits for the next token, which
|
||||
// it may modify/overwrite, and its return value is the next generated token
|
||||
// together with its probability.
|
||||
using SampleFunc = std::function<TokenAndProb(float*, size_t)>;
|
||||
// If not empty, LayersOutputFunc is called for layer outputs, specified with:
|
||||
// - index of query within containing batch (if any); zero otherwise.
|
||||
// - position in the tokens sequence
|
||||
// - name of the data, e.g. "tokens" for token IDs
|
||||
// - layer index (or -1 for global outputs)
|
||||
// - pointer to the data array
|
||||
// - size of the data array
|
||||
using LayersOutputFunc = std::function<void(size_t, size_t, const std::string&,
|
||||
int, const float*, size_t)>;
|
||||
// If not empty, ActivationsObserverFunc is invoked after each layer with:
|
||||
// - per-query position within the tokens sequence
|
||||
// - layer index (or -1 for post-norm output)
|
||||
// - activations
|
||||
class Activations;
|
||||
using ActivationsObserverFunc =
|
||||
std::function<void(const QueriesPos& queries_pos, int, const Activations&)>;
|
||||
|
||||
// RuntimeConfig holds configuration for a single generation run.
|
||||
struct RuntimeConfig {
|
||||
// If not empty, batch_stream_token is called for each token in the batch,
|
||||
// instead of stream_token.
|
||||
bool StreamToken(size_t query_idx, size_t pos, int token, float prob) const {
|
||||
if (batch_stream_token) {
|
||||
return batch_stream_token(query_idx, pos, token, prob);
|
||||
}
|
||||
return stream_token(token, prob);
|
||||
}
|
||||
return std::make_unique<Gemma>(loader.tokenizer, loader.weights,
|
||||
loader.Info(), env);
|
||||
}
|
||||
|
||||
// Limit on the number of tokens generated.
|
||||
size_t max_generated_tokens;
|
||||
|
||||
// These defaults are overridden by InferenceArgs::CopyTo(*this):
|
||||
// Max tokens per batch during prefill.
|
||||
size_t prefill_tbatch_size = 256;
|
||||
// Max queries per batch (one token from each) during decode.
|
||||
size_t decode_qbatch_size = 16;
|
||||
|
||||
// Sampling-related parameters.
|
||||
float temperature; // Temperature for sampling.
|
||||
|
||||
size_t top_k = GEMMA_TOPK; // Top-k for sampling.
|
||||
std::mt19937* gen; // Random number generator used for sampling.
|
||||
|
||||
int verbosity; // Controls verbosity of printed messages.
|
||||
|
||||
// Functions operating on the generated tokens.
|
||||
StreamFunc stream_token;
|
||||
BatchStreamFunc batch_stream_token;
|
||||
AcceptFunc accept_token; // if empty, accepts all tokens.
|
||||
SampleFunc sample_func; // if empty, uses SampleTopK.
|
||||
|
||||
// Observer callbacks for intermediate data.
|
||||
LayersOutputFunc layers_output; // if not empty, called after each layer.
|
||||
ActivationsObserverFunc activations_observer; // if set, called per-layer.
|
||||
|
||||
// If not empty, these point to the image tokens and are used in the
|
||||
// PaliGemma prefix-LM style attention.
|
||||
const ImageTokens* image_tokens = nullptr;
|
||||
|
||||
// Whether to use thread spinning to reduce barrier synchronization latency.
|
||||
// Mutable so we can change kDefault to kTrue/kFalse during Generate, because
|
||||
// RuntimeConfig is const there and is not passed to the Gemma ctor. This
|
||||
// default decision is likely sufficient because it is based on whether
|
||||
// threads are successfully pinned.
|
||||
mutable Tristate use_spinning = Tristate::kDefault;
|
||||
};
|
||||
|
||||
struct InferenceArgs : public ArgsBase<InferenceArgs> {
|
||||
InferenceArgs(int argc, char* argv[]) { InitAndParse(argc, argv); }
|
||||
|
|
@ -161,15 +180,6 @@ struct InferenceArgs : public ArgsBase<InferenceArgs> {
|
|||
std::string prompt; // Added prompt flag for non-interactive mode
|
||||
std::string eot_line;
|
||||
|
||||
// Returns error string or nullptr if OK.
|
||||
const char* Validate() const {
|
||||
if (max_generated_tokens > gcpp::kSeqLen) {
|
||||
return "max_generated_tokens is larger than the maximum sequence length "
|
||||
"(see configs.h).";
|
||||
}
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
template <class Visitor>
|
||||
void ForEach(const Visitor& visitor) {
|
||||
visitor(verbosity, "verbosity", 1,
|
||||
gemma/model_store.cc

@@ -0,0 +1,418 @@
// Copyright 2025 Google LLC
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//      http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "gemma/model_store.h"

#include <stddef.h>
#include <stdint.h>
#include <stdio.h>

#include <array>
#include <cstdlib>
#include <cstring>  // strcmp
#include <string>

#include "compression/blob_store.h"
#include "compression/fields.h"
#include "compression/io.h"  // Path
#include "compression/shared.h"
#include "gemma/configs.h"  // ModelConfig
#include "gemma/tensor_info.h"
#include "gemma/tokenizer.h"
#include "util/basics.h"
#include "util/threading_context.h"
#include "hwy/base.h"

namespace gcpp {

// Single-file format contains blobs with these names:
static constexpr char kConfigName[] = "config";
static constexpr char kTokenizerName[] = "tokenizer";
static constexpr char kMatPtrsName[] = "toc";
// Pre-2025 format has one metadata blob. 'F' denoted f32.
static constexpr char kDecoratedScalesName[] = "Fscales";

static void WarnIfExtra(const IFields::ReadResult& result, const char* name) {
|
||||
// No warning if missing_fields > 0: those fields are default-initialized.
|
||||
if (result.extra_u32) {
|
||||
HWY_WARN(
|
||||
"Serialized blob %s has %u extra fields the code is not aware of. "
|
||||
"Consider updating to the latest code from GitHub.",
|
||||
name, result.extra_u32);
|
||||
}
|
||||
}
|
||||
|
||||
// Returns the serialized tokenizer (std::string is required for proto).
|
||||
// Reads it from a blob or from a separate file if pre-2025.
|
||||
static std::string ReadTokenizer(BlobReader2& reader,
|
||||
const Path& tokenizer_path) {
|
||||
std::string tokenizer;
|
||||
// Check prevents `CallWithSpan` from printing a warning.
|
||||
if (reader.Find(kTokenizerName)) {
|
||||
if (!reader.CallWithSpan<char>(
|
||||
kTokenizerName, [&tokenizer](const hwy::Span<const char> bytes) {
|
||||
tokenizer.assign(bytes.data(), bytes.size());
|
||||
})) {
|
||||
HWY_WARN(
|
||||
"Reading tokenizer blob failed, please raise an issue. You can "
|
||||
"instead specify a tokenizer file via --tokenizer.");
|
||||
}
|
||||
}
|
||||
|
||||
if (!tokenizer.empty() && tokenizer != kMockTokenizer) {
|
||||
return tokenizer; // Read actual tokenizer from blob.
|
||||
}
|
||||
|
||||
// No blob but user specified path to file: read it or abort.
|
||||
if (!tokenizer_path.Empty()) {
|
||||
return ReadFileToString(tokenizer_path);
|
||||
}
|
||||
|
||||
HWY_WARN(
|
||||
"BlobStore does not contain a tokenizer and no --tokenizer was "
|
||||
"specified. Tests may continue but inference will fail.");
|
||||
return kMockTokenizer;
|
||||
}
|
||||
|
||||
using KeyVec = std::vector<std::string>;
|
||||
|
||||
class TypePrefix {
|
||||
public:
|
||||
static Type TypeFromChar(char c) {
|
||||
switch (c) {
|
||||
case 'F':
|
||||
return Type::kF32;
|
||||
case 'B':
|
||||
return Type::kBF16;
|
||||
case '$':
|
||||
return Type::kSFP;
|
||||
case '2':
|
||||
return Type::kNUQ;
|
||||
default:
|
||||
// The other types were not written to pre-2025 files, hence no need to
|
||||
// encode and check for them here.
|
||||
return Type::kUnknown;
|
||||
}
|
||||
}
|
||||
|
||||
TypePrefix(const KeyVec& keys, const BlobReader2& reader) {
|
||||
for (size_t key_idx = 0; key_idx < keys.size(); ++key_idx) {
|
||||
const std::string& key = keys[key_idx];
|
||||
const Type type = TypeFromChar(key[0]);
|
||||
const uint64_t bytes = reader.Range(key_idx).bytes;
|
||||
bytes_[static_cast<size_t>(type)] += bytes;
|
||||
blobs_[static_cast<size_t>(type)]++;
|
||||
total_bytes_ += bytes;
|
||||
}
|
||||
}
|
||||
|
||||
// Returns true for pre-2025 format, which has type prefixes and thus the
|
||||
// functions below may be used.
|
||||
bool HasPrefixes() const {
|
||||
return bytes_[static_cast<size_t>(Type::kUnknown)] != total_bytes_;
|
||||
}
|
||||
|
||||
  // Returns the weight type deduced from the histogram of blobs per type.
  // Rationale: We expect a mix of types due to varying precision requirements
  // for each tensor. The preferred weight type might not even be the most
  // common, because we prioritize higher compression for the *large* tensors.
  // Ignore types which only have a few blobs (might be metadata), and assume
  // that there would be at least 4 of the large tensors (in particular, global
  // attention layers). Hence return the smallest type with >= 4 blobs.
  Type DeduceWeightType() const {
    size_t min_bits = ~size_t{0};
    Type weight_type = Type::kUnknown;
    for (size_t i = 0; i < kNumTypes; ++i) {
      if (blobs_[i] < 4) continue;
      const size_t bits = TypeBits(static_cast<Type>(i));
      if (bits < min_bits) {
        min_bits = bits;
        weight_type = static_cast<Type>(i);
      }
    }
    return weight_type;
  }

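A concrete illustration of the rule above (blob counts are made up): if a pre-2025 file contains 3 F32 blobs (metadata and norms), 8 BF16 blobs, and 120 SFP blobs, then F32 is ignored for having fewer than 4 blobs, both BF16 and SFP clear the threshold, and SFP is returned because it is the type with fewer bits per element.
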
|
||||
// Prints statistics on the total size of tensors by type.
|
||||
void PrintTypeBytes() const {
|
||||
for (size_t type_idx = 0; type_idx < kNumTypes; ++type_idx) {
|
||||
const Type type = static_cast<Type>(type_idx);
|
||||
const uint64_t bytes = bytes_[type_idx];
|
||||
if (bytes == 0) continue;
|
||||
const double percent = 100.0 * bytes / total_bytes_;
|
||||
fprintf(stderr, "%zu blob bytes (%.2f%%) of %s\n",
|
||||
static_cast<size_t>(bytes), percent, TypeName(type));
|
||||
}
|
||||
}
|
||||
|
||||
private:
|
||||
uint64_t total_bytes_ = 0;
|
||||
std::array<size_t, kNumTypes> bytes_{0};
|
||||
std::array<size_t, kNumTypes> blobs_{0};
|
||||
};
|
||||
|
||||
// Returns the number of layers based on the largest blob name suffix seen.
// This works with or without type prefixes because it searches for suffixes.
static size_t DeduceNumLayers(const KeyVec& keys) {
  size_t max_layer_idx = 0;
  for (const std::string& key : keys) {
    const size_t suffix_pos = key.rfind('_');
    if (suffix_pos == std::string::npos) continue;

    char* end;
    auto layer_idx = strtoul(key.c_str() + suffix_pos + 1, &end, 10);  // NOLINT
    HWY_ASSERT(layer_idx < 999);  // Also checks for `ULONG_MAX` if out of range
    // Ignore if not a suffix. Some names are prefixed with "c_" for historical
    // reasons. In such cases, parsing layer_idx anyway returns 0.
    if (end - key.c_str() != key.size()) continue;

    max_layer_idx = HWY_MAX(max_layer_idx, layer_idx);
  }
  return max_layer_idx + 1;
}

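As a worked example (hypothetical blob names): "gr_conv_w_25" parses the suffix 25, "qkv_einsum_w_0" parses 0, and "Fscales" has no numeric suffix and is skipped, so the maximum suffix 25 yields 26 layers.
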
|
||||
// Looks for known tensor names associated with model families.
|
||||
// This works with or without type prefixes because it searches for substrings.
|
||||
static int DeduceLayerTypes(const KeyVec& keys) {
|
||||
int layer_types = 0;
|
||||
for (const std::string& key : keys) {
|
||||
if (key.find("gr_conv_w") != std::string::npos) { // NOLINT
|
||||
return kDeducedGriffin;
|
||||
}
|
||||
if (key.find("qkv_einsum_w") != std::string::npos) { // NOLINT
|
||||
layer_types |= kDeducedViT;
|
||||
}
|
||||
}
|
||||
return layer_types;
|
||||
}
|
||||
|
||||
// `wrapping_override` is forwarded from the command line. For pre-2025 files
|
||||
// without `ModelConfig`, it is the only way to force PT.
|
||||
static ModelConfig ReadOrDeduceConfig(BlobReader2& reader,
|
||||
Tristate wrapping_override) {
|
||||
const TypePrefix type_prefix(reader.Keys(), reader);
|
||||
Type deduced_weight = Type::kUnknown;
|
||||
if (type_prefix.HasPrefixes()) {
|
||||
deduced_weight = type_prefix.DeduceWeightType();
|
||||
type_prefix.PrintTypeBytes();
|
||||
}
|
||||
|
||||
// Always deduce so we can verify it against the config we read.
|
||||
const size_t layers = DeduceNumLayers(reader.Keys());
|
||||
const int layer_types = DeduceLayerTypes(reader.Keys());
|
||||
const Model deduced_model = DeduceModel(layers, layer_types);
|
||||
|
||||
ModelConfig config;
|
||||
// Check first to prevent `CallWithSpan` from printing a warning.
|
||||
if (reader.Find(kConfigName)) {
|
||||
HWY_ASSERT(reader.CallWithSpan<uint32_t>(
|
||||
kConfigName, [&config](const SerializedSpan serialized) {
|
||||
const IFields::ReadResult result = config.Read(serialized, 0);
|
||||
WarnIfExtra(result, kConfigName);
|
||||
HWY_ASSERT_M(result.pos != 0, "Error deserializing config");
|
||||
}));
|
||||
|
||||
HWY_ASSERT(config.model != Model::UNKNOWN);
|
||||
HWY_ASSERT(config.wrapping != PromptWrapping::kSentinel);
|
||||
HWY_ASSERT(config.weight != Type::kUnknown);
|
||||
|
||||
// We trust the deserialized config, but checking helps to validate the
|
||||
// deduction, which we rely on below for pre-2025 files.
|
||||
if (config.model != deduced_model) {
|
||||
const std::string suffix = WrappingSuffix(config.wrapping);
|
||||
HWY_WARN("Detected model %s does not match config %s.",
|
||||
(std::string(ModelPrefix(deduced_model)) + suffix).c_str(),
|
||||
(std::string(ModelPrefix(config.model)) + suffix).c_str());
|
||||
}
|
||||
return config;
|
||||
}
|
||||
|
||||
// Pre-2025 format: no config, rely on deduction plus `wrapping_override`.
|
||||
return ModelConfig(deduced_model, deduced_weight,
|
||||
ChooseWrapping(config.model, wrapping_override));
|
||||
}
|
||||
|
||||
static std::vector<float> ReadScales(BlobReader2& reader,
|
||||
const ModelConfig& config) {
|
||||
std::vector<float> scales;
|
||||
// Check first to prevent `CallWithSpan` from printing a warning. This blob is
|
||||
// optional even in pre-2025 format; Griffin was the first to include it.
|
||||
if (reader.Find(kDecoratedScalesName)) {
|
||||
HWY_ASSERT(reader.CallWithSpan<float>(
|
||||
kDecoratedScalesName,
|
||||
[&scales](const hwy::Span<const float> scales_blob) {
|
||||
scales.assign(scales_blob.cbegin(), scales_blob.cend());
|
||||
}));
|
||||
}
|
||||
return scales;
|
||||
}
|
||||
|
||||
// Single-file format: reads `MatPtr` from the blob; returns false if not found.
bool ModelStore2::ReadMatPtrs(BlobReader2& reader) {
  // Check first to prevent `CallWithSpan` from printing a warning.
  if (!reader.Find(kMatPtrsName)) return false;

  // For verifying `config_.weight`.
  size_t min_bits = ~size_t{0};
  Type weight_type = Type::kUnknown;

  HWY_ASSERT(reader.CallWithSpan<uint32_t>(
      kMatPtrsName, [&, this](SerializedSpan serialized) {
        for (size_t pos = 0; pos < serialized.size();) {
          MatPtr mat;
          const IFields::ReadResult result = mat.Read(serialized, pos);
          WarnIfExtra(result, mat.Name());
          if (result.pos == 0) {
            HWY_ABORT("Deserializing MatPtr %s failed (pos %zu of %zu).",
                      mat.Name(), pos, serialized.size());
          }
          pos = result.pos + result.extra_u32;

          // Retrieve actual key index because a writer may have written other
          // blobs before the tensor data.
          const BlobRange2* range = reader.Find(mat.Name());
          HWY_ASSERT(range);
          const size_t key_idx = range->key_idx;
          AddMatPtr(key_idx, mat);

          const size_t bits = TypeBits(mat.GetType());
          if (bits < min_bits) {
            min_bits = bits;
            weight_type = mat.GetType();
          }
        }
      }));

  HWY_ASSERT(weight_type != Type::kUnknown);
  HWY_ASSERT(weight_type == config_.weight);

  return true;
}

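// Worked example (hypothetical file contents, not from this commit): if a file
// stores its norm scales as Type::kBF16 (16 bits) but its large matrices in an
// 8-bit format such as SFP, the loop above keeps the type with the fewest
// bits, so `weight_type` ends up as the 8-bit type and `config_.weight` must
// declare the same type for the assertion to hold.
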
// Pre-2025 format: synthesizes `MatPtr` from the blob names if `!ReadMatPtrs`.
void ModelStore2::CreateMatPtrs(BlobReader2& reader) {
  const TensorInfoRegistry tensors(config_);

  const KeyVec& keys = reader.Keys();
  mat_ptrs_.reserve(keys.size());
  // `key_idx` is the blob index. It is not the same as the index of the
  // `MatPtr` in `mat_ptrs_` because not all blobs are tensors.
  for (size_t key_idx = 0; key_idx < keys.size(); ++key_idx) {
    const Type type = TypePrefix::TypeFromChar(keys[key_idx][0]);
    if (type == Type::kUnknown) continue;  // likely not a tensor

    // Strip type prefix from the key. Still includes layer suffix.
    const std::string name = keys[key_idx].substr(1);
    const TensorInfo* info = tensors.Find(name);
    if (HWY_UNLIKELY(!info)) {
      if (name == "scales") continue;  // ignore, not a tensor.
      HWY_ABORT("Unknown tensor %s.", name.c_str());
    }
    // We cannot set the scale yet because scales are ordered according to
    // `ForEachTensor`, which we do not know at this point. The initial value
    // is 1.0f and we set the correct value in `FindAndUpdateMatPtr`.
    AddMatPtr(key_idx, MatPtr(name.c_str(), type, ExtentsFromInfo(info)));
  }
  HWY_ASSERT(mat_ptrs_.size() <= keys.size());
  HWY_ASSERT(mat_ptrs_.size() == key_idx_.size());
}

ModelStore2::ModelStore2(BlobReader2& reader, const Path& tokenizer_path,
                         Tristate wrapping)
    : config_(ReadOrDeduceConfig(reader, wrapping)),
      tokenizer_(ReadTokenizer(reader, tokenizer_path)) {
  if (!ReadMatPtrs(reader)) {  // Pre-2025 format.
    CreateMatPtrs(reader);
    scales_ = ReadScales(reader, config_);
    // ModelConfig serialized a vector of strings. Unpack into a set for more
    // efficient lookup.
    for (const std::string& name : config_.scale_base_names) {
      scale_base_names_.insert(name);
    }
    // If the model has scales, the config must know about it.
    HWY_ASSERT(scales_.empty() || !scale_base_names_.empty());
  }

  HWY_ASSERT(key_idx_.size() == mat_ptrs_.size());
}

ModelStore2::~ModelStore2() {
  // Sanity check: ensure all scales were consumed.
  HWY_ASSERT(scales_consumed_ == scales_.size());
}

const MatPtr* ModelStore2::FindMat(const char* name) const {
  auto it = mat_idx_for_name_.find(name);
  if (it == mat_idx_for_name_.end()) return nullptr;
  const size_t mat_idx = it->second;
  const MatPtr* file_mat = &mat_ptrs_[mat_idx];
  HWY_ASSERT(!strcmp(file_mat->Name(), name));
  return file_mat;
}

bool ModelStore2::FindAndUpdateMatPtr(MatPtr& mat, size_t& key_idx) const {
  const MatPtr* file_mat = FindMat(mat.Name());
  if (!file_mat) return false;
  if (file_mat->Rows() != mat.Rows() || file_mat->Cols() != mat.Cols()) {
    HWY_ABORT("Tensor %s shape %zu %zu mismatches file %zu %zu.", mat.Name(),
              mat.Rows(), mat.Cols(), file_mat->Rows(), file_mat->Cols());
  }
  // `Compress()` output is always packed because it assumes a 1D array.
  HWY_ASSERT(mat.IsPacked());
  // Update fields. Name already matched, otherwise we would not find it.
  mat.SetType(file_mat->GetType());
  if (scales_.empty()) {
    // `file_mat->Scale()` is either read from file, or we have pre-2025 format
    // without the optional scales, and it is default-initialized to 1.0f.
    mat.SetScale(file_mat->Scale());
  } else {  // Pre-2025 with scaling factors: set next if `mat` wants one.
    if (scale_base_names_.find(StripLayerSuffix(mat.Name())) !=
        scale_base_names_.end()) {
      HWY_ASSERT(scales_consumed_ < scales_.size());
      mat.SetScale(scales_[scales_consumed_++]);
    }
  }

  key_idx = key_idx_[file_mat - mat_ptrs_.data()];
  return true;
}

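// A minimal sketch, not part of this commit, of how a caller such as
// `weights.cc` might drive `FindAndUpdateMatPtr`: the model-defined `MatPtr`
// is updated in place with the on-disk type and scale, and `key_idx` then
// identifies which blob holds the tensor bytes. The helper name
// `FindTensorForLoading` is hypothetical.
static bool FindTensorForLoading(const ModelStore2& store, MatPtr& mat,
                                 size_t& key_idx) {
  if (!store.FindAndUpdateMatPtr(mat, key_idx)) {
    return false;  // Not in the file; the caller may allocate/init instead.
  }
  // `mat.GetType()` and `mat.Scale()` now reflect the file; the caller would
  // next read blob `key_idx` via `BlobReader2` (see `ReadOrAllocate`).
  return true;
}
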
static void AddBlob(const char* name, const std::vector<uint32_t>& data,
                    BlobWriter2& writer) {
  HWY_ASSERT(!data.empty());
  writer.Add(name, data.data(), data.size() * sizeof(data[0]));
}

void WriteSingleFile(const ModelConfig& config, const GemmaTokenizer& tokenizer,
                     const std::vector<uint32_t>& serialized_mat_ptrs,
                     BlobWriter2& writer, hwy::ThreadPool& pool,
                     const Path& path) {
  HWY_ASSERT(config.model != Model::UNKNOWN);
  HWY_ASSERT(config.weight != Type::kUnknown);
  HWY_ASSERT(config.wrapping != PromptWrapping::kSentinel);
  const std::vector<uint32_t> serialized_config = config.Write();
  AddBlob(kConfigName, serialized_config, writer);

  const std::string serialized_tokenizer = tokenizer.Serialize();
  HWY_ASSERT(!serialized_tokenizer.empty());
  writer.Add(kTokenizerName, serialized_tokenizer.data(),
             serialized_tokenizer.size());

  AddBlob(kMatPtrsName, serialized_mat_ptrs, writer);

  writer.WriteAll(pool, path);
}

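// A minimal usage sketch, not part of this commit: an exporter serializes each
// tensor's metadata into `serialized_mat_ptrs`, adds one data blob per tensor,
// then calls `WriteSingleFile` to produce a single BlobStore file. The helper
// `SerializeMatPtrs` (concatenating each MatPtr's serialized u32 words) and
// the function name `ExportSingleFile` are hypothetical.
static void ExportSingleFile(const ModelConfig& config,
                             const GemmaTokenizer& tokenizer,
                             const std::vector<MatPtr>& mats,
                             BlobWriter2& writer, hwy::ThreadPool& pool,
                             const Path& path) {
  // Hypothetical helper: append each MatPtr's serialized words in order.
  const std::vector<uint32_t> serialized_mat_ptrs = SerializeMatPtrs(mats);
  // The caller is assumed to have already Add()ed each tensor's data blob.
  WriteSingleFile(config, tokenizer, serialized_mat_ptrs, writer, pool, path);
}
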
} // namespace gcpp
|
||||
|
|
@ -0,0 +1,115 @@
|
|||
// Copyright 2025 Google LLC
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
// Reads/writes model metadata (all but the weights) from/to a `BlobStore`.
|
||||
#ifndef THIRD_PARTY_GEMMA_CPP_GEMMA_MODEL_STORE_H_
|
||||
#define THIRD_PARTY_GEMMA_CPP_GEMMA_MODEL_STORE_H_
|
||||
|
||||
#include <stddef.h>
|
||||
#include <stdint.h>
|
||||
|
||||
#include <string>
|
||||
#include <unordered_map>
|
||||
#include <unordered_set>
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
|
||||
// IWYU pragma: begin_exports
|
||||
#include "compression/blob_store.h"
|
||||
#include "compression/io.h" // Path
|
||||
#include "gemma/configs.h" // ModelConfig
|
||||
#include "gemma/tokenizer.h"
|
||||
#include "util/basics.h" // Tristate
|
||||
#include "util/mat.h" // MatPtr
|
||||
// IWYU pragma: end_exports
|
||||
|
||||
#include "util/allocator.h"
|
||||
#include "hwy/contrib/thread_pool/thread_pool.h"
|
||||
|
||||
namespace gcpp {
|
||||
|
||||
// Reads and holds the model config, tokenizer and all `MatPtr`: everything
|
||||
// except the tensor data, which are read/written by `weights.cc`.
|
||||
//
|
||||
// As of 2025-04, the `BlobStore` format includes blobs for `ModelConfig`,
|
||||
// tokenizer, and all `MatPtr` metadata. "Pre-2025" format instead stored the
|
||||
// tokenizer in a separate file, encoded tensor type in a prefix of the blob
|
||||
// name, and had a blob for tensor scaling factors. We still support reading
|
||||
// both, but only write single-file format.
|
||||
class ModelStore2 {
|
||||
public:
|
||||
// Reads from file(s) or aborts on error. The latter two arguments are only
|
||||
// used for pre-2025 files.
|
||||
ModelStore2(BlobReader2& reader, const Path& tokenizer_path = Path(),
|
||||
Tristate wrapping = Tristate::kDefault);
|
||||
// For optimize_test.cc.
|
||||
ModelStore2(const ModelConfig& config, GemmaTokenizer&& tokenizer)
|
||||
: config_(config), tokenizer_(std::move(tokenizer)) {}
|
||||
~ModelStore2();
|
||||
|
||||
const ModelConfig& Config() const {
|
||||
HWY_ASSERT(config_.model != Model::UNKNOWN);
|
||||
return config_;
|
||||
}
|
||||
|
||||
const GemmaTokenizer& Tokenizer() const { return tokenizer_; }
|
||||
|
||||
// Returns nullptr if `name` is not available for loading, otherwise the
|
||||
// metadata of that tensor.
|
||||
const MatPtr* FindMat(const char* name) const;
|
||||
|
||||
// Returns false if `mat` is not available for loading, otherwise updates
|
||||
// `mat` with metadata from the file and sets `key_idx` for use by
|
||||
// `BlobReader2`. Called via `ReadOrAllocate` in `weights.cc`.
|
||||
bool FindAndUpdateMatPtr(MatPtr& mat, size_t& key_idx) const;
|
||||
|
||||
private:
|
||||
void AddMatPtr(const size_t key_idx, const MatPtr& mat) {
|
||||
auto pair_ib = mat_idx_for_name_.insert({mat.Name(), mat_ptrs_.size()});
|
||||
HWY_ASSERT_M(pair_ib.second, mat.Name()); // Ensure inserted/unique.
|
||||
mat_ptrs_.push_back(mat);
|
||||
key_idx_.push_back(key_idx);
|
||||
}
|
||||
|
||||
bool ReadMatPtrs(BlobReader2& reader);
|
||||
void CreateMatPtrs(BlobReader2& reader); // Aborts on error.
|
||||
|
||||
ModelConfig config_;
|
||||
GemmaTokenizer tokenizer_;
|
||||
|
||||
// All `MatPtr` present in the `BlobStore`, see `ReadMatPtrs`/`CreateMatPtrs`.
|
||||
std::vector<MatPtr> mat_ptrs_;
|
||||
// For each of `mat_ptrs_`, the index within `BlobReader2::Keys()`. This is
|
||||
// not necessarily iota because some blobs are not tensors, and callers may
|
||||
// have added blobs before ours.
|
||||
std::vector<size_t> key_idx_;
|
||||
// Index within `mat_ptrs_` and `key_idx_` for each tensor name.
|
||||
std::unordered_map<std::string, size_t> mat_idx_for_name_;
|
||||
|
||||
// Only used if `!ReadMatPtrs` (pre-2025 format):
|
||||
std::vector<float> scales_;
|
||||
std::unordered_set<std::string> scale_base_names_;
|
||||
mutable size_t scales_consumed_ = 0;
|
||||
};
|
||||
|
||||
// Adds metadata blobs to `writer` and writes everything to `path`. This
|
||||
// produces a single BlobStore file holding everything required for inference.
|
||||
void WriteSingleFile(const ModelConfig& config, const GemmaTokenizer& tokenizer,
|
||||
const std::vector<uint32_t>& serialized_mat_ptrs,
|
||||
BlobWriter2& writer, hwy::ThreadPool& pool,
|
||||
const Path& path);
|
||||
|
||||
} // namespace gcpp
|
||||
#endif // THIRD_PARTY_GEMMA_CPP_GEMMA_MODEL_STORE_H_
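// A minimal construction sketch, not part of this commit; the paths are
// placeholders and opening a `BlobReader2` from a `Path` is an assumption.
// For single-file (2025) weights, the tokenizer path and wrapping override
// are ignored; for pre-2025 multi-file weights they supply what the
// BlobStore lacks:
//
//   BlobReader2 reader(Path("weights.sbs"));           // assumed constructor
//   ModelStore2 single_file(reader);                   // config/tokenizer from blobs
//   ModelStore2 pre_2025(reader, Path("tokenizer.spm"),
//                        Tristate::kTrue /* force IT prompt wrapping */);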
|
||||
79
gemma/run.cc
|
|
@ -25,16 +25,15 @@
|
|||
|
||||
#include "compression/shared.h" // PromptWrapping
|
||||
#include "evals/benchmark_helper.h"
|
||||
#include "gemma/common.h"
|
||||
#include "gemma/gemma.h" // Gemma
|
||||
#include "gemma/gemma_args.h"
|
||||
#include "gemma/tokenizer.h" // WrapAndTokenize
|
||||
#include "gemma/tokenizer.h" // WrapAndTokenize
|
||||
#include "ops/matmul.h" // MatMulEnv
|
||||
#include "paligemma/image.h"
|
||||
#include "util/args.h" // HasHelp
|
||||
#include "hwy/base.h"
|
||||
#include "hwy/highway.h"
|
||||
#include "hwy/profiler.h"
|
||||
#include "ops/matmul.h" // MatMulEnv
|
||||
#include "paligemma/image.h"
|
||||
#include "util/args.h" // HasHelp
|
||||
|
||||
#if (!defined(HWY_VERSION_LT) || HWY_VERSION_LT(1, 2)) && !HWY_IDE
|
||||
#error "Please update to version 1.2 of github.com/google/highway."
|
||||
|
|
@ -91,7 +90,7 @@ std::string GetPrompt(const InferenceArgs& inference) {
|
|||
|
||||
// The main Read-Eval-Print Loop.
|
||||
void ReplGemma(const ThreadingArgs& threading, const InferenceArgs& inference,
|
||||
Gemma& model, KVCache& kv_cache) {
|
||||
const Gemma& gemma, KVCache& kv_cache) {
|
||||
PROFILER_ZONE("Gen.misc");
|
||||
size_t abs_pos = 0; // across turns
|
||||
size_t tokens_generated_this_turn = 0; // differentiates prefill from reply
|
||||
|
|
@ -104,22 +103,22 @@ void ReplGemma(const ThreadingArgs& threading, const InferenceArgs& inference,
|
|||
Image image;
|
||||
ImageTokens image_tokens;
|
||||
if (have_image) {
|
||||
size_t pool_dim = model.GetModelConfig().vit_config.pool_dim;
|
||||
size_t pool_dim = gemma.GetModelConfig().vit_config.pool_dim;
|
||||
image_tokens =
|
||||
ImageTokens(model.Env().ctx.allocator,
|
||||
Extents2D(model.GetModelConfig().vit_config.seq_len /
|
||||
ImageTokens(gemma.Env().ctx.allocator,
|
||||
Extents2D(gemma.GetModelConfig().vit_config.seq_len /
|
||||
(pool_dim * pool_dim),
|
||||
model.GetModelConfig().model_dim));
|
||||
HWY_ASSERT(model.Info().wrapping == PromptWrapping::PALIGEMMA ||
|
||||
model.Info().wrapping == PromptWrapping::GEMMA_VLM);
|
||||
gemma.GetModelConfig().model_dim));
|
||||
HWY_ASSERT(gemma.GetModelConfig().wrapping == PromptWrapping::PALIGEMMA ||
|
||||
gemma.GetModelConfig().wrapping == PromptWrapping::GEMMA_VLM);
|
||||
HWY_ASSERT(image.ReadPPM(inference.image_file.path));
|
||||
const size_t image_size = model.GetModelConfig().vit_config.image_size;
|
||||
const size_t image_size = gemma.GetModelConfig().vit_config.image_size;
|
||||
image.Resize(image_size, image_size);
|
||||
RuntimeConfig runtime_config = {.gen = &gen,
|
||||
.verbosity = inference.verbosity,
|
||||
.use_spinning = threading.spin};
|
||||
double image_tokens_start = hwy::platform::Now();
|
||||
model.GenerateImageTokens(runtime_config, image, image_tokens);
|
||||
gemma.GenerateImageTokens(runtime_config, image, image_tokens);
|
||||
if (inference.verbosity >= 1) {
|
||||
double image_tokens_duration = hwy::platform::Now() - image_tokens_start;
|
||||
fprintf(stderr,
|
||||
|
|
@ -139,14 +138,14 @@ void ReplGemma(const ThreadingArgs& threading, const InferenceArgs& inference,
|
|||
std::cerr << "." << std::flush;
|
||||
}
|
||||
return true;
|
||||
} else if (model.GetModelConfig().IsEOS(token)) {
|
||||
} else if (gemma.GetModelConfig().IsEOS(token)) {
|
||||
if (inference.verbosity >= 2) {
|
||||
std::cout << "\n[ End ]\n";
|
||||
}
|
||||
return true;
|
||||
}
|
||||
std::string token_text;
|
||||
HWY_ASSERT(model.Tokenizer().Decode(std::vector<int>{token}, &token_text));
|
||||
HWY_ASSERT(gemma.Tokenizer().Decode(std::vector<int>{token}, &token_text));
|
||||
if (first_response_token) {
|
||||
token_text.erase(0, token_text.find_first_not_of(" \t\n"));
|
||||
if (inference.verbosity >= 1) {
|
||||
|
|
@ -191,9 +190,9 @@ void ReplGemma(const ThreadingArgs& threading, const InferenceArgs& inference,
|
|||
size_t prompt_size = 0;
|
||||
size_t prefix_end = 0;
|
||||
if (have_image) {
|
||||
prompt =
|
||||
WrapAndTokenize(model.Tokenizer(), model.ChatTemplate(), model.Info(),
|
||||
abs_pos, prompt_string, image_tokens.BatchSize());
|
||||
prompt = WrapAndTokenize(gemma.Tokenizer(), gemma.ChatTemplate(),
|
||||
gemma.GetModelConfig().wrapping, abs_pos,
|
||||
prompt_string, image_tokens.BatchSize());
|
||||
runtime_config.image_tokens = &image_tokens;
|
||||
prompt_size = prompt.size();
|
||||
// The end of the prefix for prefix-LM style attention in Paligemma.
|
||||
|
|
@ -203,8 +202,9 @@ void ReplGemma(const ThreadingArgs& threading, const InferenceArgs& inference,
|
|||
// REMOVED: Don't change prefill_tbatch_size for image handling
|
||||
// runtime_config.prefill_tbatch_size = prompt_size;
|
||||
} else {
|
||||
prompt = WrapAndTokenize(model.Tokenizer(), model.ChatTemplate(),
|
||||
model.Info(), abs_pos, prompt_string);
|
||||
prompt = WrapAndTokenize(gemma.Tokenizer(), gemma.ChatTemplate(),
|
||||
gemma.GetModelConfig().wrapping, abs_pos,
|
||||
prompt_string);
|
||||
prompt_size = prompt.size();
|
||||
}
|
||||
|
||||
|
|
@ -218,7 +218,7 @@ void ReplGemma(const ThreadingArgs& threading, const InferenceArgs& inference,
|
|||
if (inference.verbosity >= 1) {
|
||||
std::cerr << "\n[ Reading prompt ] " << std::flush;
|
||||
}
|
||||
model.Generate(runtime_config, prompt, abs_pos, prefix_end, kv_cache,
|
||||
gemma.Generate(runtime_config, prompt, abs_pos, prefix_end, kv_cache,
|
||||
timing_info);
|
||||
std::cout << "\n\n";
|
||||
|
||||
|
|
@ -229,7 +229,7 @@ void ReplGemma(const ThreadingArgs& threading, const InferenceArgs& inference,
|
|||
|
||||
// Prepare for the next turn. Works only for PaliGemma.
|
||||
if (!inference.multiturn ||
|
||||
model.Info().wrapping == PromptWrapping::PALIGEMMA) {
|
||||
gemma.GetModelConfig().wrapping == PromptWrapping::PALIGEMMA) {
|
||||
abs_pos = 0; // Start a new turn at position 0.
|
||||
InitGenerator(inference, gen);
|
||||
} else {
|
||||
|
|
@ -247,17 +247,15 @@ void ReplGemma(const ThreadingArgs& threading, const InferenceArgs& inference,
|
|||
}
|
||||
}
|
||||
|
||||
void Run(ThreadingArgs& threading, LoaderArgs& loader,
|
||||
InferenceArgs& inference) {
|
||||
void Run(const LoaderArgs& loader, const ThreadingArgs& threading,
|
||||
const InferenceArgs& inference) {
|
||||
PROFILER_ZONE("Run.misc");
|
||||
|
||||
// Note that num_threads is an upper bound; we also limit to the number of
|
||||
// detected and enabled cores.
|
||||
MatMulEnv env(MakeMatMulEnv(threading));
|
||||
if (inference.verbosity >= 2) env.print_best = true;
|
||||
Gemma model = CreateGemma(loader, env);
|
||||
const Gemma gemma(loader, env);
|
||||
KVCache kv_cache =
|
||||
KVCache::Create(model.GetModelConfig(), inference.prefill_tbatch_size);
|
||||
KVCache::Create(gemma.GetModelConfig(), inference.prefill_tbatch_size);
|
||||
|
||||
if (inference.verbosity >= 1) {
|
||||
std::string instructions =
|
||||
|
|
@ -284,12 +282,12 @@ void Run(ThreadingArgs& threading, LoaderArgs& loader,
|
|||
if (inference.prompt.empty()) {
|
||||
std::cout << "\033[2J\033[1;1H" // clear screen
|
||||
<< kAsciiArtBanner << "\n\n";
|
||||
ShowConfig(threading, loader, inference);
|
||||
ShowConfig(loader, threading, inference, gemma.GetModelConfig());
|
||||
std::cout << "\n" << instructions << "\n";
|
||||
}
|
||||
}
|
||||
|
||||
ReplGemma(threading, inference, model, kv_cache);
|
||||
ReplGemma(threading, inference, gemma, kv_cache);
|
||||
}
|
||||
|
||||
} // namespace gcpp
|
||||
|
|
@ -298,30 +296,17 @@ int main(int argc, char** argv) {
|
|||
{
|
||||
PROFILER_ZONE("Startup.misc");
|
||||
|
||||
gcpp::ThreadingArgs threading(argc, argv);
|
||||
gcpp::LoaderArgs loader(argc, argv);
|
||||
gcpp::ThreadingArgs threading(argc, argv);
|
||||
gcpp::InferenceArgs inference(argc, argv);
|
||||
|
||||
if (gcpp::HasHelp(argc, argv)) {
|
||||
std::cerr << gcpp::kAsciiArtBanner;
|
||||
|
||||
gcpp::ShowHelp(threading, loader, inference);
|
||||
gcpp::ShowHelp(loader, threading, inference);
|
||||
return 0;
|
||||
}
|
||||
|
||||
if (const char* error = loader.Validate()) {
|
||||
std::cerr << gcpp::kAsciiArtBanner;
|
||||
gcpp::ShowHelp(threading, loader, inference);
|
||||
HWY_ABORT("\nInvalid args: %s", error);
|
||||
}
|
||||
|
||||
if (const char* error = inference.Validate()) {
|
||||
std::cerr << gcpp::kAsciiArtBanner;
|
||||
gcpp::ShowHelp(threading, loader, inference);
|
||||
HWY_ABORT("\nInvalid args: %s", error);
|
||||
}
|
||||
|
||||
gcpp::Run(threading, loader, inference);
|
||||
gcpp::Run(loader, threading, inference);
|
||||
}
|
||||
PROFILER_PRINT_RESULTS(); // Must call outside the zone above.
|
||||
return 0;
|
||||
|
|
|
|||
|
|
@ -1,608 +0,0 @@
|
|||
#include "gemma/tensor_index.h"
|
||||
|
||||
#include <stddef.h>
|
||||
|
||||
#include <algorithm>
|
||||
#include <cctype>
|
||||
#include <cstdio>
|
||||
#include <string>
|
||||
#include <unordered_map>
|
||||
#include <vector>
|
||||
|
||||
#include "compression/shared.h"
|
||||
#include "gemma/configs.h"
|
||||
|
||||
namespace gcpp {
|
||||
namespace {
|
||||
|
||||
// Returns the non-layer tensors for the model.
|
||||
std::vector<TensorInfo> ModelTensors(const ModelConfig& config) {
|
||||
return {
|
||||
TensorInfo{
|
||||
.name = "c_embedding",
|
||||
.source_names = {"embedder/input_embedding"},
|
||||
.axes = {0, 1},
|
||||
.shape = {config.vocab_size, config.model_dim},
|
||||
.min_size = Type::kBF16,
|
||||
},
|
||||
TensorInfo{
|
||||
.name = "c_final_norm",
|
||||
.source_names = {"final_norm/scale"},
|
||||
.axes = {0},
|
||||
.shape = {config.model_dim},
|
||||
.min_size = Type::kBF16,
|
||||
},
|
||||
TensorInfo{
|
||||
.name = "enc_norm_bias",
|
||||
.source_names = {"img/Transformer/encoder_norm/bias"},
|
||||
.axes = {0},
|
||||
.shape = {config.vit_config.model_dim},
|
||||
.min_size = Type::kBF16,
|
||||
},
|
||||
TensorInfo{
|
||||
.name = "enc_norm_scale",
|
||||
.source_names = {"img/Transformer/encoder_norm/scale"},
|
||||
.axes = {0},
|
||||
.shape = {config.vit_config.model_dim},
|
||||
.min_size = Type::kBF16,
|
||||
},
|
||||
TensorInfo{
|
||||
.name = "img_emb_bias",
|
||||
.source_names = {"img/embedding/bias"},
|
||||
.axes = {0},
|
||||
.shape = {config.vit_config.model_dim},
|
||||
.min_size = Type::kF32,
|
||||
},
|
||||
TensorInfo{
|
||||
.name = "img_emb_kernel",
|
||||
.source_names = {"img/embedding/kernel"},
|
||||
.axes = {3, 0, 1, 2},
|
||||
.shape = {config.vit_config.model_dim, config.vit_config.patch_width,
|
||||
config.vit_config.patch_width, 3},
|
||||
.min_size = Type::kBF16,
|
||||
.cols_take_extra_dims = true,
|
||||
},
|
||||
TensorInfo{
|
||||
.name = "img_head_bias",
|
||||
.source_names = {"img/head/bias", "embedder/mm_input_projection/b"},
|
||||
.axes = {0},
|
||||
.shape = {config.model_dim},
|
||||
.min_size = Type::kF32,
|
||||
},
|
||||
TensorInfo{
|
||||
.name = "img_head_kernel",
|
||||
.source_names = {"img/head/kernel", "embedder/mm_input_projection/w"},
|
||||
.axes = {1, 0},
|
||||
.shape = {config.model_dim, config.vit_config.model_dim},
|
||||
.min_size = Type::kBF16,
|
||||
},
|
||||
TensorInfo{
|
||||
.name = "img_pos_emb",
|
||||
.source_names = {"img/pos_embedding"},
|
||||
.axes = {0, 1},
|
||||
.shape = {/*1,*/ config.vit_config.seq_len,
|
||||
config.vit_config.model_dim},
|
||||
.min_size = Type::kF32,
|
||||
},
|
||||
// RMS norm applied to soft tokens prior to pos embedding.
|
||||
TensorInfo{
|
||||
.name = "mm_embed_norm",
|
||||
.source_names = {"embedder/mm_soft_embedding_norm/scale"},
|
||||
.axes = {0},
|
||||
.shape = {config.vit_config.model_dim},
|
||||
.min_size = Type::kBF16,
|
||||
},
|
||||
};
|
||||
}
|
||||
|
||||
// Returns the tensors for the given image layer config.
|
||||
std::vector<TensorInfo> ImageLayerTensors(const ModelConfig& config,
|
||||
const LayerConfig& layer_config,
|
||||
const int img_layer_idx) {
|
||||
return {
|
||||
// Vit layers.
|
||||
TensorInfo{
|
||||
.name = "attn_out_w",
|
||||
.source_names = {"MultiHeadDotProductAttention_0/out/kernel"},
|
||||
.axes = {2, 0, 1},
|
||||
.shape = {config.vit_config.model_dim, layer_config.heads,
|
||||
layer_config.qkv_dim},
|
||||
.min_size = Type::kBF16,
|
||||
.cols_take_extra_dims = true,
|
||||
},
|
||||
TensorInfo{
|
||||
.name = "attn_out_b",
|
||||
.source_names = {"MultiHeadDotProductAttention_0/out/bias"},
|
||||
.axes = {0},
|
||||
.shape = {config.vit_config.model_dim},
|
||||
.min_size = Type::kF32,
|
||||
},
|
||||
TensorInfo{
|
||||
.name = "q_ein_w",
|
||||
.source_names = {"MultiHeadDotProductAttention_0/query/kernel"},
|
||||
.axes = {1, 2, 0},
|
||||
.shape = {layer_config.heads, layer_config.qkv_dim,
|
||||
config.vit_config.model_dim},
|
||||
.concat_names = {"qkv_ein_w", "k_ein_w", "v_ein_w"},
|
||||
.concat_axis = 1,
|
||||
.min_size = Type::kBF16,
|
||||
},
|
||||
TensorInfo{
|
||||
.name = "k_ein_w",
|
||||
.source_names = {"MultiHeadDotProductAttention_0/key/kernel"},
|
||||
.axes = {1, 2, 0},
|
||||
.shape = {layer_config.heads, layer_config.qkv_dim,
|
||||
config.vit_config.model_dim},
|
||||
.concat_names = {""},
|
||||
.min_size = Type::kBF16,
|
||||
},
|
||||
TensorInfo{
|
||||
.name = "v_ein_w",
|
||||
.source_names = {"MultiHeadDotProductAttention_0/value/kernel"},
|
||||
.axes = {1, 2, 0},
|
||||
.shape = {layer_config.heads, layer_config.qkv_dim,
|
||||
config.vit_config.model_dim},
|
||||
.concat_names = {""},
|
||||
.min_size = Type::kBF16,
|
||||
},
|
||||
TensorInfo{
|
||||
.name = "qkv_ein_w",
|
||||
.source_names = {"MultiHeadDotProductAttention_0/qkv/kernel"},
|
||||
.axes = {1, 2, 0},
|
||||
.shape = {layer_config.heads, 3 * layer_config.qkv_dim,
|
||||
config.vit_config.model_dim},
|
||||
.min_size = Type::kBF16,
|
||||
},
|
||||
TensorInfo{
|
||||
.name = "q_ein_b",
|
||||
.source_names = {"MultiHeadDotProductAttention_0/query/bias"},
|
||||
.axes = {0, 1},
|
||||
.shape = {layer_config.heads, layer_config.qkv_dim},
|
||||
.concat_names = {"qkv_ein_b", "k_ein_b", "v_ein_b"},
|
||||
.concat_axis = 1,
|
||||
.min_size = Type::kF32,
|
||||
},
|
||||
TensorInfo{
|
||||
.name = "k_ein_b",
|
||||
.source_names = {"MultiHeadDotProductAttention_0/key/bias"},
|
||||
.axes = {0, 1},
|
||||
.shape = {layer_config.kv_heads, layer_config.qkv_dim},
|
||||
.concat_names = {""},
|
||||
.min_size = Type::kF32,
|
||||
},
|
||||
TensorInfo{
|
||||
.name = "v_ein_b",
|
||||
.source_names = {"MultiHeadDotProductAttention_0/value/bias"},
|
||||
.axes = {0, 1},
|
||||
.shape = {layer_config.kv_heads, layer_config.qkv_dim},
|
||||
.concat_names = {""},
|
||||
.min_size = Type::kF32,
|
||||
},
|
||||
TensorInfo{
|
||||
.name = "qkv_ein_b",
|
||||
.source_names = {"MultiHeadDotProductAttention_0/qkv/bias"},
|
||||
.axes = {0, 1},
|
||||
.shape = {layer_config.heads + layer_config.kv_heads * 2,
|
||||
layer_config.qkv_dim},
|
||||
.min_size = Type::kF32,
|
||||
},
|
||||
TensorInfo{
|
||||
.name = "linear_0_w",
|
||||
.source_names = {"MlpBlock_0/Dense_0/kernel"},
|
||||
.axes = {1, 0},
|
||||
.shape = {layer_config.ff_hidden_dim, config.vit_config.model_dim},
|
||||
.min_size = Type::kBF16,
|
||||
},
|
||||
TensorInfo{
|
||||
.name = "linear_0_b",
|
||||
.source_names = {"MlpBlock_0/Dense_0/bias"},
|
||||
.axes = {0},
|
||||
.shape = {layer_config.ff_hidden_dim},
|
||||
.min_size = Type::kF32,
|
||||
},
|
||||
TensorInfo{
|
||||
.name = "linear_1_w",
|
||||
.source_names = {"MlpBlock_0/Dense_1/kernel"},
|
||||
.axes = {1, 0},
|
||||
.shape = {config.vit_config.model_dim, layer_config.ff_hidden_dim},
|
||||
.min_size = Type::kBF16,
|
||||
},
|
||||
TensorInfo{
|
||||
.name = "linear_1_b",
|
||||
.source_names = {"MlpBlock_0/Dense_1/bias"},
|
||||
.axes = {0},
|
||||
.shape = {config.vit_config.model_dim},
|
||||
.min_size = Type::kF32,
|
||||
},
|
||||
TensorInfo{
|
||||
.name = "ln_0_bias",
|
||||
.source_names = {"img/Transformer/encoderblock/LayerNorm_0/bias",
|
||||
"img/Transformer/encoderblock_" +
|
||||
std::to_string(img_layer_idx) +
|
||||
"/LayerNorm_0/bias"},
|
||||
.axes = {0},
|
||||
.shape = {config.vit_config.model_dim},
|
||||
.min_size = Type::kBF16,
|
||||
},
|
||||
TensorInfo{
|
||||
.name = "ln_0_scale",
|
||||
.source_names = {"img/Transformer/encoderblock/LayerNorm_0/scale",
|
||||
"img/Transformer/encoderblock_" +
|
||||
std::to_string(img_layer_idx) +
|
||||
"/LayerNorm_0/scale"},
|
||||
.axes = {0},
|
||||
.shape = {config.vit_config.model_dim},
|
||||
.min_size = Type::kBF16,
|
||||
},
|
||||
TensorInfo{
|
||||
.name = "ln_1_bias",
|
||||
.source_names = {"img/Transformer/encoderblock/LayerNorm_1/bias",
|
||||
"img/Transformer/encoderblock_" +
|
||||
std::to_string(img_layer_idx) +
|
||||
"/LayerNorm_1/bias"},
|
||||
.axes = {0},
|
||||
.shape = {config.vit_config.model_dim},
|
||||
.min_size = Type::kBF16,
|
||||
},
|
||||
TensorInfo{
|
||||
.name = "ln_1_scale",
|
||||
.source_names = {"img/Transformer/encoderblock/LayerNorm_1/scale",
|
||||
"img/Transformer/encoderblock_" +
|
||||
std::to_string(img_layer_idx) +
|
||||
"/LayerNorm_1/scale"},
|
||||
.axes = {0},
|
||||
.shape = {config.vit_config.model_dim},
|
||||
.min_size = Type::kBF16,
|
||||
},
|
||||
};
|
||||
}
|
||||
|
||||
// Returns the tensors for the given LLM layer config.
|
||||
std::vector<TensorInfo> LLMLayerTensors(const ModelConfig& config,
|
||||
const LayerConfig& layer_config,
|
||||
bool reshape_att) {
|
||||
std::vector<TensorInfo> tensors = {
|
||||
TensorInfo{
|
||||
.name = "key_norm",
|
||||
.source_names = {"attn/_key_norm/scale"},
|
||||
.axes = {0},
|
||||
.shape = {layer_config.qkv_dim},
|
||||
.min_size = Type::kBF16,
|
||||
},
|
||||
TensorInfo{
|
||||
.name = "query_norm",
|
||||
.source_names = {"attn/_query_norm/scale"},
|
||||
.axes = {0},
|
||||
.shape = {layer_config.qkv_dim},
|
||||
.min_size = Type::kBF16,
|
||||
},
|
||||
TensorInfo{
|
||||
.name = "qkv1_w",
|
||||
.source_names = {"attn/q_einsum/w"},
|
||||
.axes = {0, 2, 1},
|
||||
.shape = {layer_config.heads * layer_config.qkv_dim,
|
||||
config.model_dim},
|
||||
.concat_names = {"qkv_ein", "qkv2_w"},
|
||||
},
|
||||
TensorInfo{
|
||||
.name = "qkv2_w",
|
||||
.source_names = {"attn/kv_einsum/w"},
|
||||
.axes = {1, 0, 3, 2},
|
||||
.shape = {2 * layer_config.kv_heads * layer_config.qkv_dim,
|
||||
config.model_dim},
|
||||
.concat_names = {""},
|
||||
},
|
||||
TensorInfo{
|
||||
.name = "q_ein",
|
||||
.source_names = {"attention_block/proj_q/kernel"},
|
||||
.axes = {1, 0},
|
||||
.shape = {layer_config.model_dim, layer_config.model_dim},
|
||||
.concat_names = {"qkv_ein", "k_ein", "v_ein"},
|
||||
},
|
||||
TensorInfo{
|
||||
.name = "k_ein",
|
||||
.source_names = {"attention_block/proj_k/kernel"},
|
||||
.axes = {1, 0},
|
||||
.shape = {layer_config.qkv_dim, layer_config.model_dim},
|
||||
.concat_names = {""},
|
||||
},
|
||||
TensorInfo{
|
||||
.name = "v_ein",
|
||||
.source_names = {"attention_block/proj_v/kernel"},
|
||||
.axes = {1, 0},
|
||||
.shape = {layer_config.qkv_dim, layer_config.model_dim},
|
||||
.concat_names = {""},
|
||||
},
|
||||
TensorInfo{
|
||||
.name = "qkv_ein",
|
||||
.source_names = {"attn/qkv_einsum/w"},
|
||||
.axes = {1, 0, 3, 2},
|
||||
.shape = {(layer_config.heads + 2 * layer_config.kv_heads) *
|
||||
layer_config.qkv_dim,
|
||||
config.model_dim},
|
||||
},
|
||||
TensorInfo{
|
||||
.name = "attn_ob",
|
||||
.source_names = {"attention_block/proj_final/bias"},
|
||||
.axes = {0},
|
||||
.shape = {config.model_dim},
|
||||
.min_size = Type::kF32,
|
||||
},
|
||||
// Griffin layers.
|
||||
TensorInfo{
|
||||
.name = "gr_lin_x_w",
|
||||
.source_names = {"recurrent_block/linear_x/kernel"},
|
||||
.axes = {1, 0},
|
||||
.shape = {layer_config.griffin_dim, layer_config.griffin_dim},
|
||||
},
|
||||
TensorInfo{
|
||||
.name = "gr_lin_x_b",
|
||||
.source_names = {"recurrent_block/linear_x/bias"},
|
||||
.axes = {0},
|
||||
.shape = {layer_config.griffin_dim},
|
||||
.min_size = Type::kF32,
|
||||
},
|
||||
TensorInfo{
|
||||
.name = "gr_lin_y_w",
|
||||
.source_names = {"recurrent_block/linear_y/kernel"},
|
||||
.axes = {1, 0},
|
||||
.shape = {layer_config.griffin_dim, layer_config.griffin_dim},
|
||||
},
|
||||
TensorInfo{
|
||||
.name = "gr_lin_y_b",
|
||||
.source_names = {"recurrent_block/linear_y/bias"},
|
||||
.axes = {0},
|
||||
.shape = {layer_config.griffin_dim},
|
||||
.min_size = Type::kF32,
|
||||
},
|
||||
TensorInfo{
|
||||
.name = "gr_lin_out_w",
|
||||
.source_names = {"recurrent_block/linear_out/kernel"},
|
||||
.axes = {1, 0},
|
||||
.shape = {layer_config.griffin_dim, layer_config.griffin_dim},
|
||||
},
|
||||
TensorInfo{
|
||||
.name = "gr_lin_out_b",
|
||||
.source_names = {"recurrent_block/linear_out/bias"},
|
||||
.axes = {0},
|
||||
.shape = {layer_config.griffin_dim},
|
||||
.min_size = Type::kF32,
|
||||
},
|
||||
TensorInfo{
|
||||
.name = "gr_conv_w",
|
||||
.source_names = {"recurrent_block/conv_1d/w"},
|
||||
.axes = {0, 1},
|
||||
.shape = {layer_config.conv1d_width, layer_config.griffin_dim},
|
||||
.min_size = Type::kF32,
|
||||
},
|
||||
TensorInfo{
|
||||
.name = "gr_conv_b",
|
||||
.source_names = {"recurrent_block/conv_1d/b"},
|
||||
.axes = {0},
|
||||
.shape = {layer_config.griffin_dim},
|
||||
.min_size = Type::kF32,
|
||||
},
|
||||
TensorInfo{
|
||||
.name = "gr1_gate_w",
|
||||
.source_names = {"recurrent_block/rg_lru/input_gate/w"},
|
||||
.axes = {0, 2, 1},
|
||||
.shape = {layer_config.heads,
|
||||
layer_config.griffin_dim / layer_config.heads,
|
||||
layer_config.griffin_dim / layer_config.heads},
|
||||
.concat_names = {"gr_gate_w", "gr2_gate_w"},
|
||||
},
|
||||
TensorInfo{
|
||||
.name = "gr2_gate_w",
|
||||
.source_names = {"recurrent_block/rg_lru/a_gate/w"},
|
||||
.axes = {0, 2, 1},
|
||||
.shape = {layer_config.heads,
|
||||
layer_config.griffin_dim / layer_config.heads,
|
||||
layer_config.griffin_dim / layer_config.heads},
|
||||
.concat_names = {""},
|
||||
},
|
||||
TensorInfo{
|
||||
.name = "gr_gate_w",
|
||||
.source_names = {"recurrent_block/rg_lru/gate/w"},
|
||||
.axes = {0, 2, 1},
|
||||
.shape = {2 * layer_config.heads,
|
||||
layer_config.griffin_dim / layer_config.heads,
|
||||
layer_config.griffin_dim / layer_config.heads},
|
||||
},
|
||||
TensorInfo{
|
||||
.name = "gr1_gate_b",
|
||||
.source_names = {"recurrent_block/rg_lru/input_gate/b"},
|
||||
.axes = {0},
|
||||
.shape = {layer_config.griffin_dim},
|
||||
.concat_names = {"gr_gate_b", "gr2_gate_b"},
|
||||
.min_size = Type::kF32,
|
||||
},
|
||||
TensorInfo{
|
||||
.name = "gr2_gate_b",
|
||||
.source_names = {"recurrent_block/rg_lru/a_gate/b"},
|
||||
.axes = {0},
|
||||
.shape = {layer_config.griffin_dim},
|
||||
.concat_names = {""},
|
||||
.min_size = Type::kF32,
|
||||
},
|
||||
TensorInfo{
|
||||
.name = "gr_gate_b",
|
||||
.source_names = {"recurrent_block/rg_lru/input_gate/b"},
|
||||
.axes = {0, 1},
|
||||
.shape = {2 * layer_config.griffin_dim},
|
||||
.min_size = Type::kF32,
|
||||
},
|
||||
TensorInfo{
|
||||
.name = "gr_a",
|
||||
.source_names = {"recurrent_block/rg_lru/a_param"},
|
||||
.axes = {0},
|
||||
.shape = {layer_config.griffin_dim},
|
||||
.min_size = Type::kF32,
|
||||
.scaled_softplus = true,
|
||||
},
|
||||
|
||||
TensorInfo{
|
||||
.name = "gating_ein",
|
||||
.source_names = {"mlp/gating_einsum/w", "mlp/gating_einsum",
|
||||
"mlp_block/ffw_up/w"},
|
||||
.axes = {0, layer_config.optimized_gating ? 1u : 2u,
|
||||
layer_config.optimized_gating ? 2u : 1u},
|
||||
.shape = {2, layer_config.ff_hidden_dim, config.model_dim},
|
||||
},
|
||||
TensorInfo{
|
||||
.name = "gating1_w",
|
||||
.source_names = {"none"},
|
||||
.axes = {0, layer_config.optimized_gating ? 1u : 2u,
|
||||
layer_config.optimized_gating ? 2u : 1u},
|
||||
.shape = {layer_config.ff_hidden_dim, config.model_dim},
|
||||
},
|
||||
TensorInfo{
|
||||
.name = "gating2_w",
|
||||
.source_names = {"none"},
|
||||
.axes = {0, layer_config.optimized_gating ? 1u : 2u,
|
||||
layer_config.optimized_gating ? 2u : 1u},
|
||||
.shape = {layer_config.ff_hidden_dim, config.model_dim},
|
||||
},
|
||||
TensorInfo{
|
||||
.name = "linear_w",
|
||||
.source_names = {"mlp/linear/w", "mlp/linear",
|
||||
"mlp_block/ffw_down/kernel"},
|
||||
.axes = {1, 0},
|
||||
.shape = {config.model_dim, layer_config.ff_hidden_dim},
|
||||
},
|
||||
TensorInfo{
|
||||
.name = "pre_att_ns",
|
||||
.source_names = {"pre_attention_norm/scale",
|
||||
"temporal_pre_norm/scale"},
|
||||
.axes = {0},
|
||||
.shape = {config.model_dim},
|
||||
.min_size = Type::kBF16,
|
||||
},
|
||||
TensorInfo{
|
||||
.name = "pre_ff_ns",
|
||||
.source_names = {"pre_ffw_norm/scale", "channel_pre_norm/scale"},
|
||||
.axes = {0},
|
||||
.shape = {config.model_dim},
|
||||
.min_size = Type::kBF16,
|
||||
},
|
||||
TensorInfo{
|
||||
.name = "post_att_ns",
|
||||
.source_names = {"post_attention_norm/scale"},
|
||||
.axes = {0},
|
||||
.shape = {config.model_dim},
|
||||
.min_size = Type::kBF16,
|
||||
},
|
||||
TensorInfo{
|
||||
.name = "post_ff_ns",
|
||||
.source_names = {"post_ffw_norm/scale"},
|
||||
.axes = {0},
|
||||
.shape = {config.model_dim},
|
||||
.min_size = Type::kBF16,
|
||||
},
|
||||
TensorInfo{
|
||||
.name = "ffw_gat_b",
|
||||
.source_names = {"mlp_block/ffw_up/b"},
|
||||
.axes = {0},
|
||||
.shape = {2 * layer_config.ff_hidden_dim},
|
||||
.min_size = Type::kF32,
|
||||
},
|
||||
TensorInfo{
|
||||
.name = "ffw_out_b",
|
||||
.source_names = {"mlp_block/ffw_down/bias"},
|
||||
.axes = {0},
|
||||
.shape = {config.model_dim},
|
||||
.min_size = Type::kF32,
|
||||
},
|
||||
};
|
||||
if (reshape_att) {
|
||||
tensors.push_back(TensorInfo{
|
||||
.name = "att_w",
|
||||
.source_names = {"attn/attn_vec_einsum/w",
|
||||
"attention_block/proj_final/kernel"},
|
||||
.preshape = {layer_config.heads, layer_config.qkv_dim,
|
||||
config.model_dim},
|
||||
.axes = {2, 0, 1},
|
||||
.shape = {config.model_dim, layer_config.heads, layer_config.qkv_dim},
|
||||
.cols_take_extra_dims = true,
|
||||
});
|
||||
tensors.push_back(TensorInfo{
|
||||
.name = "att_ein",
|
||||
.shape = {layer_config.heads, config.model_dim, layer_config.qkv_dim},
|
||||
});
|
||||
} else {
|
||||
tensors.push_back(TensorInfo{
|
||||
.name = "att_ein",
|
||||
.source_names = {"attn/attn_vec_einsum/w",
|
||||
"attention_block/proj_final/kernel"},
|
||||
.preshape = {layer_config.heads, layer_config.qkv_dim,
|
||||
config.model_dim},
|
||||
.axes = {0, 2, 1},
|
||||
.shape = {layer_config.heads, config.model_dim, layer_config.qkv_dim},
|
||||
});
|
||||
tensors.push_back(TensorInfo{
|
||||
.name = "att_w",
|
||||
.shape = {config.model_dim, layer_config.heads, layer_config.qkv_dim},
|
||||
.cols_take_extra_dims = true,
|
||||
});
|
||||
}
|
||||
return tensors;
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
TensorIndex::TensorIndex(const ModelConfig& config, int llm_layer_idx,
|
||||
int img_layer_idx, bool reshape_att)
|
||||
: config_(config),
|
||||
llm_layer_idx_(llm_layer_idx),
|
||||
img_layer_idx_(img_layer_idx) {
|
||||
int layer_idx = std::max(llm_layer_idx_, img_layer_idx_);
|
||||
std::string suffix;
|
||||
if (layer_idx >= 0) {
|
||||
suffix = "_" + std::to_string(layer_idx);
|
||||
}
|
||||
if (llm_layer_idx < 0 && img_layer_idx < 0) {
|
||||
tensors_ = ModelTensors(config);
|
||||
} else if (llm_layer_idx_ < 0 && 0 <= img_layer_idx &&
|
||||
img_layer_idx <
|
||||
static_cast<int>(config.vit_config.layer_configs.size())) {
|
||||
const auto& layer_config = config.vit_config.layer_configs[img_layer_idx];
|
||||
tensors_ = ImageLayerTensors(config, layer_config, img_layer_idx);
|
||||
} else if (0 <= llm_layer_idx &&
|
||||
llm_layer_idx < static_cast<int>(config.layer_configs.size())) {
|
||||
const auto& layer_config = config.layer_configs[llm_layer_idx];
|
||||
tensors_ = LLMLayerTensors(config, layer_config, reshape_att);
|
||||
}
|
||||
for (size_t i = 0; i < tensors_.size(); ++i) {
|
||||
std::string key = tensors_[i].name + suffix;
|
||||
name_map_.insert({key, i});
|
||||
}
|
||||
}
|
||||
|
||||
TensorInfo TensorIndex::TensorInfoFromSourcePath(
|
||||
const std::string& path) const {
|
||||
for (const auto& tensor : tensors_) {
|
||||
for (const auto& source_name : tensor.source_names) {
|
||||
auto pos = path.rfind(source_name);
|
||||
if (pos != std::string::npos && path.size() == pos + source_name.size())
|
||||
return tensor;
|
||||
}
|
||||
}
|
||||
return TensorInfo();
|
||||
}
|
||||
|
||||
const TensorInfo* TensorIndex::FindName(const std::string& name) const {
|
||||
std::string name_to_find = name;
|
||||
if (!std::isdigit(name[name.size() - 1])) {
|
||||
if (img_layer_idx_ >= 0 && llm_layer_idx_ < 0) {
|
||||
name_to_find = name + "_" + std::to_string(img_layer_idx_);
|
||||
} else if (llm_layer_idx_ >= 0) {
|
||||
name_to_find = name + "_" + std::to_string(llm_layer_idx_);
|
||||
}
|
||||
}
|
||||
auto it = name_map_.find(name_to_find);
|
||||
if (it == name_map_.end()) {
|
||||
return nullptr;
|
||||
}
|
||||
return &tensors_[it->second];
|
||||
}
|
||||
|
||||
} // namespace gcpp
|
||||
|
|
@ -1,72 +0,0 @@
|
|||
#include "gemma/tensor_index.h"
|
||||
|
||||
#include <cstddef>
|
||||
#include <cstdio>
|
||||
#include <cstring>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
#include "gtest/gtest.h"
|
||||
#include "compression/compress.h"
|
||||
#include "compression/shared.h"
|
||||
#include "gemma/configs.h"
|
||||
#include "gemma/weights.h"
|
||||
#include "util/basics.h"
|
||||
#include "hwy/aligned_allocator.h"
|
||||
|
||||
namespace gcpp {
|
||||
namespace {
|
||||
|
||||
// Tests that each tensor in the model can be found by exactly one TensorIndex,
|
||||
// and that the TensorIndex returns the correct shape and name for the tensor,
|
||||
// for all models.
|
||||
TEST(TensorIndexTest, FindName) {
|
||||
for (Model model : kAllModels) {
|
||||
fprintf(stderr, "Testing model %d\n", static_cast<int>(model));
|
||||
ModelConfig config = ConfigFromModel(model);
|
||||
std::vector<TensorIndex> tensor_indexes;
|
||||
tensor_indexes.emplace_back(config, /*llm_layer_idx=*/-1,
|
||||
/*img_layer_idx=*/-1,
|
||||
/*split_and_reshape=*/false);
|
||||
for (size_t llm_layer_idx = 0; llm_layer_idx < config.layer_configs.size();
|
||||
++llm_layer_idx) {
|
||||
tensor_indexes.emplace_back(config, static_cast<int>(llm_layer_idx),
|
||||
/*img_layer_idx=*/-1,
|
||||
/*split_and_reshape=*/false);
|
||||
}
|
||||
for (size_t img_layer_idx = 0;
|
||||
img_layer_idx < config.vit_config.layer_configs.size();
|
||||
++img_layer_idx) {
|
||||
tensor_indexes.emplace_back(config, /*llm_layer_idx=*/-1,
|
||||
static_cast<int>(img_layer_idx),
|
||||
/*split_and_reshape=*/false);
|
||||
}
|
||||
// For each tensor in any model, exactly one TensorIndex should find it.
|
||||
ModelWeightsPtrs<SfpStream> weights(config);
|
||||
ModelWeightsPtrs<SfpStream>::ForEachTensor(
|
||||
{&weights}, ForEachType::kInitNoToc,
|
||||
[&tensor_indexes](const char* name, hwy::Span<MatPtr*> tensors) {
|
||||
int num_found = 0;
|
||||
const MatPtr& tensor = *tensors[0];
|
||||
for (const auto& tensor_index : tensor_indexes) {
|
||||
// Skip the type marker prefix, but we want the layer index suffix.
|
||||
std::string name_to_find(name + 1, strlen(name) - 1);
|
||||
const TensorInfo* info = tensor_index.FindName(name_to_find);
|
||||
if (info != nullptr) {
|
||||
// Test that the MatPtr can be constructed from the TensorInfo,
|
||||
// and that the dimensions match.
|
||||
MatPtrT<SfpStream> mat_ptr(tensor.Name(), tensor_index);
|
||||
EXPECT_STREQ(tensor.Name(), mat_ptr.Name())
|
||||
<< "on tensor " << name;
|
||||
EXPECT_EQ(tensor.Rows(), mat_ptr.Rows()) << "on tensor " << name;
|
||||
EXPECT_EQ(tensor.Cols(), mat_ptr.Cols()) << "on tensor " << name;
|
||||
++num_found;
|
||||
}
|
||||
}
|
||||
EXPECT_EQ(num_found, 1) << " for tensor " << name;
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace
|
||||
} // namespace gcpp
|
||||
|
|
@ -0,0 +1,592 @@
|
|||
#include "gemma/tensor_info.h"
|
||||
|
||||
#include <stddef.h>
|
||||
|
||||
#include <string>
|
||||
|
||||
#include "compression/shared.h"
|
||||
#include "gemma/configs.h"
|
||||
|
||||
namespace gcpp {
|
||||
|
||||
void TensorInfoRegistry::Add(const std::string& suffix,
|
||||
const TensorInfo& info) {
|
||||
const size_t idx = tensors_.size();
|
||||
tensors_.push_back(info);
|
||||
// Also add suffix to `concat_names`.
|
||||
for (std::string& name : tensors_.back().concat_names) {
|
||||
name += suffix;
|
||||
}
|
||||
|
||||
const std::string name = info.base_name + suffix;
|
||||
// Ensure successful insertion because `suffix` ensures uniqueness for
|
||||
// per-layer tensors, and per-model should only be inserted once.
|
||||
HWY_ASSERT_M(idx_from_name_.insert({name, idx}).second, name.c_str());
|
||||
}
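
// Worked example, not part of this commit: with `suffix == "_3"` (layer 3),
// Add() registers a TensorInfo whose base_name "qkv_ein" is stored under the
// key "qkv_ein_3" in `idx_from_name_`, and its concat_names receive the same
// suffix, so later lookups by full suffixed tensor name resolve to this entry.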
|
||||
|
||||
// Non-layer tensors.
|
||||
void TensorInfoRegistry::AddModelTensors(const ModelConfig& config) {
|
||||
const std::string no_suffix;
|
||||
Add(no_suffix, {
|
||||
.base_name = "c_embedding",
|
||||
.source_names = {"embedder/input_embedding"},
|
||||
.axes = {0, 1},
|
||||
.shape = {config.vocab_size, config.model_dim},
|
||||
.min_size = Type::kBF16,
|
||||
});
|
||||
Add(no_suffix, {
|
||||
.base_name = "c_final_norm",
|
||||
.source_names = {"final_norm/scale"},
|
||||
.axes = {0},
|
||||
.shape = {config.model_dim},
|
||||
.min_size = Type::kBF16,
|
||||
});
|
||||
Add(no_suffix, {
|
||||
.base_name = "enc_norm_bias",
|
||||
.source_names = {"img/Transformer/encoder_norm/bias"},
|
||||
.axes = {0},
|
||||
.shape = {config.vit_config.model_dim},
|
||||
.min_size = Type::kBF16,
|
||||
});
|
||||
Add(no_suffix, {
|
||||
.base_name = "enc_norm_scale",
|
||||
.source_names = {"img/Transformer/encoder_norm/scale"},
|
||||
.axes = {0},
|
||||
.shape = {config.vit_config.model_dim},
|
||||
.min_size = Type::kBF16,
|
||||
});
|
||||
Add(no_suffix, {
|
||||
.base_name = "img_emb_bias",
|
||||
.source_names = {"img/embedding/bias"},
|
||||
.axes = {0},
|
||||
.shape = {config.vit_config.model_dim},
|
||||
.min_size = Type::kF32,
|
||||
});
|
||||
Add(no_suffix,
|
||||
{
|
||||
.base_name = "img_emb_kernel",
|
||||
.source_names = {"img/embedding/kernel"},
|
||||
.axes = {3, 0, 1, 2},
|
||||
.shape = {config.vit_config.model_dim, config.vit_config.patch_width,
|
||||
config.vit_config.patch_width, 3},
|
||||
.min_size = Type::kBF16,
|
||||
.cols_take_extra_dims = true,
|
||||
});
|
||||
Add(no_suffix,
|
||||
{
|
||||
.base_name = "img_head_bias",
|
||||
.source_names = {"img/head/bias", "embedder/mm_input_projection/b"},
|
||||
.axes = {0},
|
||||
.shape = {config.model_dim},
|
||||
.min_size = Type::kF32,
|
||||
});
|
||||
Add(no_suffix,
|
||||
{
|
||||
.base_name = "img_head_kernel",
|
||||
.source_names = {"img/head/kernel", "embedder/mm_input_projection/w"},
|
||||
.axes = {1, 0},
|
||||
.shape = {config.model_dim, config.vit_config.model_dim},
|
||||
.min_size = Type::kBF16,
|
||||
});
|
||||
Add(no_suffix, {
|
||||
.base_name = "img_pos_emb",
|
||||
.source_names = {"img/pos_embedding"},
|
||||
.axes = {0, 1},
|
||||
.shape = {/*1,*/ config.vit_config.seq_len,
|
||||
config.vit_config.model_dim},
|
||||
.min_size = Type::kF32,
|
||||
});
|
||||
// RMS norm applied to soft tokens prior to pos embedding.
|
||||
Add(no_suffix, {
|
||||
.base_name = "mm_embed_norm",
|
||||
.source_names = {"embedder/mm_soft_embedding_norm/scale"},
|
||||
.axes = {0},
|
||||
.shape = {config.vit_config.model_dim},
|
||||
.min_size = Type::kBF16,
|
||||
});
|
||||
}
|
||||
|
||||
// Returns the tensors for the given image layer config.
|
||||
void TensorInfoRegistry::AddImageLayerTensors(const ModelConfig& config,
|
||||
const LayerConfig& layer_config,
|
||||
const size_t img_layer_idx) {
|
||||
const std::string suffix = LayerSuffix(img_layer_idx);
|
||||
|
||||
// Vit layers.
|
||||
Add(suffix, {
|
||||
.base_name = "attn_out_w",
|
||||
.source_names = {"MultiHeadDotProductAttention_0/out/kernel"},
|
||||
.axes = {2, 0, 1},
|
||||
.shape = {config.vit_config.model_dim, layer_config.heads,
|
||||
layer_config.qkv_dim},
|
||||
.min_size = Type::kBF16,
|
||||
.cols_take_extra_dims = true,
|
||||
});
|
||||
Add(suffix, {
|
||||
.base_name = "attn_out_b",
|
||||
.source_names = {"MultiHeadDotProductAttention_0/out/bias"},
|
||||
.axes = {0},
|
||||
.shape = {config.vit_config.model_dim},
|
||||
.min_size = Type::kF32,
|
||||
});
|
||||
Add(suffix,
|
||||
{
|
||||
.base_name = "q_ein_w",
|
||||
.source_names = {"MultiHeadDotProductAttention_0/query/kernel"},
|
||||
.axes = {1, 2, 0},
|
||||
.shape = {layer_config.heads, layer_config.qkv_dim,
|
||||
config.vit_config.model_dim},
|
||||
.concat_names = {"qkv_ein_w", "k_ein_w", "v_ein_w"},
|
||||
.concat_axis = 1,
|
||||
.min_size = Type::kBF16,
|
||||
});
|
||||
Add(suffix, {
|
||||
.base_name = "k_ein_w",
|
||||
.source_names = {"MultiHeadDotProductAttention_0/key/kernel"},
|
||||
.axes = {1, 2, 0},
|
||||
.shape = {layer_config.heads, layer_config.qkv_dim,
|
||||
config.vit_config.model_dim},
|
||||
.concat_names = {""},
|
||||
.min_size = Type::kBF16,
|
||||
});
|
||||
Add(suffix,
|
||||
{
|
||||
.base_name = "v_ein_w",
|
||||
.source_names = {"MultiHeadDotProductAttention_0/value/kernel"},
|
||||
.axes = {1, 2, 0},
|
||||
.shape = {layer_config.heads, layer_config.qkv_dim,
|
||||
config.vit_config.model_dim},
|
||||
.concat_names = {""},
|
||||
.min_size = Type::kBF16,
|
||||
});
|
||||
Add(suffix, {
|
||||
.base_name = "qkv_ein_w",
|
||||
.source_names = {"MultiHeadDotProductAttention_0/qkv/kernel"},
|
||||
.axes = {1, 2, 0},
|
||||
.shape = {layer_config.heads, 3 * layer_config.qkv_dim,
|
||||
config.vit_config.model_dim},
|
||||
.min_size = Type::kBF16,
|
||||
});
|
||||
Add(suffix, {
|
||||
.base_name = "q_ein_b",
|
||||
.source_names = {"MultiHeadDotProductAttention_0/query/bias"},
|
||||
.axes = {0, 1},
|
||||
.shape = {layer_config.heads, layer_config.qkv_dim},
|
||||
.concat_names = {"qkv_ein_b", "k_ein_b", "v_ein_b"},
|
||||
.concat_axis = 1,
|
||||
.min_size = Type::kF32,
|
||||
});
|
||||
Add(suffix, {
|
||||
.base_name = "k_ein_b",
|
||||
.source_names = {"MultiHeadDotProductAttention_0/key/bias"},
|
||||
.axes = {0, 1},
|
||||
.shape = {layer_config.kv_heads, layer_config.qkv_dim},
|
||||
.concat_names = {""},
|
||||
.min_size = Type::kF32,
|
||||
});
|
||||
Add(suffix, {
|
||||
.base_name = "v_ein_b",
|
||||
.source_names = {"MultiHeadDotProductAttention_0/value/bias"},
|
||||
.axes = {0, 1},
|
||||
.shape = {layer_config.kv_heads, layer_config.qkv_dim},
|
||||
.concat_names = {""},
|
||||
.min_size = Type::kF32,
|
||||
});
|
||||
Add(suffix, {
|
||||
.base_name = "qkv_ein_b",
|
||||
.source_names = {"MultiHeadDotProductAttention_0/qkv/bias"},
|
||||
.axes = {0, 1},
|
||||
.shape = {layer_config.heads + layer_config.kv_heads * 2,
|
||||
layer_config.qkv_dim},
|
||||
.min_size = Type::kF32,
|
||||
});
|
||||
Add(suffix,
|
||||
{
|
||||
.base_name = "linear_0_w",
|
||||
.source_names = {"MlpBlock_0/Dense_0/kernel"},
|
||||
.axes = {1, 0},
|
||||
.shape = {layer_config.ff_hidden_dim, config.vit_config.model_dim},
|
||||
.min_size = Type::kBF16,
|
||||
});
|
||||
Add(suffix, {
|
||||
.base_name = "linear_0_b",
|
||||
.source_names = {"MlpBlock_0/Dense_0/bias"},
|
||||
.axes = {0},
|
||||
.shape = {layer_config.ff_hidden_dim},
|
||||
.min_size = Type::kF32,
|
||||
});
|
||||
Add(suffix,
|
||||
{
|
||||
.base_name = "linear_1_w",
|
||||
.source_names = {"MlpBlock_0/Dense_1/kernel"},
|
||||
.axes = {1, 0},
|
||||
.shape = {config.vit_config.model_dim, layer_config.ff_hidden_dim},
|
||||
.min_size = Type::kBF16,
|
||||
});
|
||||
Add(suffix, {
|
||||
.base_name = "linear_1_b",
|
||||
.source_names = {"MlpBlock_0/Dense_1/bias"},
|
||||
.axes = {0},
|
||||
.shape = {config.vit_config.model_dim},
|
||||
.min_size = Type::kF32,
|
||||
});
|
||||
Add(suffix,
|
||||
{
|
||||
.base_name = "ln_0_bias",
|
||||
.source_names = {"img/Transformer/encoderblock/LayerNorm_0/bias",
|
||||
"img/Transformer/encoderblock_" +
|
||||
std::to_string(img_layer_idx) +
|
||||
"/LayerNorm_0/bias"},
|
||||
.axes = {0},
|
||||
.shape = {config.vit_config.model_dim},
|
||||
.min_size = Type::kBF16,
|
||||
});
|
||||
Add(suffix,
|
||||
{
|
||||
.base_name = "ln_0_scale",
|
||||
.source_names = {"img/Transformer/encoderblock/LayerNorm_0/scale",
|
||||
"img/Transformer/encoderblock_" +
|
||||
std::to_string(img_layer_idx) +
|
||||
"/LayerNorm_0/scale"},
|
||||
.axes = {0},
|
||||
.shape = {config.vit_config.model_dim},
|
||||
.min_size = Type::kBF16,
|
||||
});
|
||||
Add(suffix,
|
||||
{
|
||||
.base_name = "ln_1_bias",
|
||||
.source_names = {"img/Transformer/encoderblock/LayerNorm_1/bias",
|
||||
"img/Transformer/encoderblock_" +
|
||||
std::to_string(img_layer_idx) +
|
||||
"/LayerNorm_1/bias"},
|
||||
.axes = {0},
|
||||
.shape = {config.vit_config.model_dim},
|
||||
.min_size = Type::kBF16,
|
||||
});
|
||||
Add(suffix,
|
||||
{
|
||||
.base_name = "ln_1_scale",
|
||||
.source_names = {"img/Transformer/encoderblock/LayerNorm_1/scale",
|
||||
"img/Transformer/encoderblock_" +
|
||||
std::to_string(img_layer_idx) +
|
||||
"/LayerNorm_1/scale"},
|
||||
.axes = {0},
|
||||
.shape = {config.vit_config.model_dim},
|
||||
.min_size = Type::kBF16,
|
||||
});
|
||||
}
|
||||
|
||||
void TensorInfoRegistry::AddGriffinLayerTensors(const LayerConfig& layer_config,
|
||||
const size_t layer_idx) {
|
||||
const std::string suffix = LayerSuffix(layer_idx);
|
||||
Add(suffix, {
|
||||
.base_name = "gr_lin_x_w",
|
||||
.source_names = {"recurrent_block/linear_x/kernel"},
|
||||
.axes = {1, 0},
|
||||
.shape = {layer_config.griffin_dim, layer_config.griffin_dim},
|
||||
});
|
||||
Add(suffix, {
|
||||
.base_name = "gr_lin_x_b",
|
||||
.source_names = {"recurrent_block/linear_x/bias"},
|
||||
.axes = {0},
|
||||
.shape = {layer_config.griffin_dim},
|
||||
.min_size = Type::kF32,
|
||||
});
|
||||
Add(suffix, {
|
||||
.base_name = "gr_lin_y_w",
|
||||
.source_names = {"recurrent_block/linear_y/kernel"},
|
||||
.axes = {1, 0},
|
||||
.shape = {layer_config.griffin_dim, layer_config.griffin_dim},
|
||||
});
|
||||
Add(suffix, {
|
||||
.base_name = "gr_lin_y_b",
|
||||
.source_names = {"recurrent_block/linear_y/bias"},
|
||||
.axes = {0},
|
||||
.shape = {layer_config.griffin_dim},
|
||||
.min_size = Type::kF32,
|
||||
});
|
||||
Add(suffix, {
|
||||
.base_name = "gr_lin_out_w",
|
||||
.source_names = {"recurrent_block/linear_out/kernel"},
|
||||
.axes = {1, 0},
|
||||
.shape = {layer_config.griffin_dim, layer_config.griffin_dim},
|
||||
});
|
||||
Add(suffix, {
|
||||
.base_name = "gr_lin_out_b",
|
||||
.source_names = {"recurrent_block/linear_out/bias"},
|
||||
.axes = {0},
|
||||
.shape = {layer_config.griffin_dim},
|
||||
.min_size = Type::kF32,
|
||||
});
|
||||
Add(suffix,
|
||||
{
|
||||
.base_name = "gr_conv_w",
|
||||
.source_names = {"recurrent_block/conv_1d/w"},
|
||||
.axes = {0, 1},
|
||||
.shape = {layer_config.conv1d_width, layer_config.griffin_dim},
|
||||
.min_size = Type::kF32,
|
||||
});
|
||||
Add(suffix, {
|
||||
.base_name = "gr_conv_b",
|
||||
.source_names = {"recurrent_block/conv_1d/b"},
|
||||
.axes = {0},
|
||||
.shape = {layer_config.griffin_dim},
|
||||
.min_size = Type::kF32,
|
||||
});
|
||||
Add(suffix, {
|
||||
.base_name = "gr1_gate_w",
|
||||
.source_names = {"recurrent_block/rg_lru/input_gate/w"},
|
||||
.axes = {0, 2, 1},
|
||||
.shape = {layer_config.heads,
|
||||
layer_config.griffin_dim / layer_config.heads,
|
||||
layer_config.griffin_dim / layer_config.heads},
|
||||
.concat_names = {"gr_gate_w", "gr2_gate_w"},
|
||||
});
|
||||
Add(suffix, {
|
||||
.base_name = "gr2_gate_w",
|
||||
.source_names = {"recurrent_block/rg_lru/a_gate/w"},
|
||||
.axes = {0, 2, 1},
|
||||
.shape = {layer_config.heads,
|
||||
layer_config.griffin_dim / layer_config.heads,
|
||||
layer_config.griffin_dim / layer_config.heads},
|
||||
.concat_names = {""},
|
||||
});
|
||||
Add(suffix, {
|
||||
.base_name = "gr_gate_w",
|
||||
.source_names = {"recurrent_block/rg_lru/gate/w"},
|
||||
.axes = {0, 2, 1},
|
||||
.shape = {2 * layer_config.heads,
|
||||
layer_config.griffin_dim / layer_config.heads,
|
||||
layer_config.griffin_dim / layer_config.heads},
|
||||
});
|
||||
Add(suffix, {
|
||||
.base_name = "gr1_gate_b",
|
||||
.source_names = {"recurrent_block/rg_lru/input_gate/b"},
|
||||
.axes = {0},
|
||||
.shape = {layer_config.griffin_dim},
|
||||
.concat_names = {"gr_gate_b", "gr2_gate_b"},
|
||||
.min_size = Type::kF32,
|
||||
});
|
||||
Add(suffix, {
|
||||
.base_name = "gr2_gate_b",
|
||||
.source_names = {"recurrent_block/rg_lru/a_gate/b"},
|
||||
.axes = {0},
|
||||
.shape = {layer_config.griffin_dim},
|
||||
.concat_names = {""},
|
||||
.min_size = Type::kF32,
|
||||
});
|
||||
Add(suffix, {
|
||||
.base_name = "gr_gate_b",
|
||||
.source_names = {"recurrent_block/rg_lru/input_gate/b"},
|
||||
.axes = {0, 1},
|
||||
.shape = {2 * layer_config.griffin_dim},
|
||||
.min_size = Type::kF32,
|
||||
});
|
||||
Add(suffix, {
|
||||
.base_name = "gr_a",
|
||||
.source_names = {"recurrent_block/rg_lru/a_param"},
|
||||
.axes = {0},
|
||||
.shape = {layer_config.griffin_dim},
|
||||
.min_size = Type::kF32,
|
||||
.scaled_softplus = true,
|
||||
});
|
||||
}
|
||||
|
||||
void TensorInfoRegistry::AddLayerTensors(const ModelConfig& config,
|
||||
const LayerConfig& layer_config,
|
||||
const size_t layer_idx) {
|
||||
const std::string suffix = LayerSuffix(layer_idx);
|
||||
Add(suffix, {
|
||||
.base_name = "key_norm",
|
||||
.source_names = {"attn/_key_norm/scale"},
|
||||
.axes = {0},
|
||||
.shape = {layer_config.qkv_dim},
|
||||
.min_size = Type::kBF16,
|
||||
});
|
||||
Add(suffix, {
|
||||
.base_name = "query_norm",
|
||||
.source_names = {"attn/_query_norm/scale"},
|
||||
.axes = {0},
|
||||
.shape = {layer_config.qkv_dim},
|
||||
.min_size = Type::kBF16,
|
||||
});
|
||||
Add(suffix, {
|
||||
.base_name = "qkv1_w",
|
||||
.source_names = {"attn/q_einsum/w"},
|
||||
.axes = {0, 2, 1},
|
||||
.shape = {layer_config.heads * layer_config.qkv_dim,
|
||||
config.model_dim},
|
||||
.concat_names = {"qkv_ein", "qkv2_w"},
|
||||
});
|
||||
Add(suffix, {
|
||||
.base_name = "qkv2_w",
|
||||
.source_names = {"attn/kv_einsum/w"},
|
||||
.axes = {1, 0, 3, 2},
|
||||
.shape = {2 * layer_config.kv_heads * layer_config.qkv_dim,
|
||||
config.model_dim},
|
||||
.concat_names = {""},
|
||||
});
|
||||
Add(suffix, {
|
||||
.base_name = "q_ein",
|
||||
.source_names = {"attention_block/proj_q/kernel"},
|
||||
.axes = {1, 0},
|
||||
.shape = {layer_config.model_dim, layer_config.model_dim},
|
||||
.concat_names = {"qkv_ein", "k_ein", "v_ein"},
|
||||
});
|
||||
Add(suffix, {
|
||||
.base_name = "k_ein",
|
||||
.source_names = {"attention_block/proj_k/kernel"},
|
||||
.axes = {1, 0},
|
||||
.shape = {layer_config.qkv_dim, layer_config.model_dim},
|
||||
.concat_names = {""},
|
||||
});
|
||||
Add(suffix, {
|
||||
.base_name = "v_ein",
|
||||
.source_names = {"attention_block/proj_v/kernel"},
|
||||
.axes = {1, 0},
|
||||
.shape = {layer_config.qkv_dim, layer_config.model_dim},
|
||||
.concat_names = {""},
|
||||
});
|
||||
Add(suffix, {
|
||||
.base_name = "qkv_ein",
|
||||
.source_names = {"attn/qkv_einsum/w"},
|
||||
.axes = {1, 0, 3, 2},
|
||||
.shape = {(layer_config.heads + 2 * layer_config.kv_heads) *
|
||||
layer_config.qkv_dim,
|
||||
config.model_dim},
|
||||
});
|
||||
Add(suffix, {
|
||||
.base_name = "attn_ob",
|
||||
.source_names = {"attention_block/proj_final/bias"},
|
||||
.axes = {0},
|
||||
.shape = {config.model_dim},
|
||||
.min_size = Type::kF32,
|
||||
});
|
||||
|
||||
Add(suffix, {
|
||||
.base_name = "gating_ein",
|
||||
.source_names = {"mlp/gating_einsum/w", "mlp/gating_einsum",
|
||||
"mlp_block/ffw_up/w"},
|
||||
.axes = {0, layer_config.optimized_gating ? 1u : 2u,
|
||||
layer_config.optimized_gating ? 2u : 1u},
|
||||
.shape = {2, layer_config.ff_hidden_dim, config.model_dim},
|
||||
});
|
||||
Add(suffix, {
|
||||
.base_name = "gating1_w",
|
||||
.source_names = {"none"},
|
||||
.axes = {0, layer_config.optimized_gating ? 1u : 2u,
|
||||
layer_config.optimized_gating ? 2u : 1u},
|
||||
.shape = {layer_config.ff_hidden_dim, config.model_dim},
|
||||
});
|
||||
Add(suffix, {
|
||||
.base_name = "gating2_w",
|
||||
.source_names = {"none"},
|
||||
.axes = {0, layer_config.optimized_gating ? 1u : 2u,
|
||||
layer_config.optimized_gating ? 2u : 1u},
|
||||
.shape = {layer_config.ff_hidden_dim, config.model_dim},
|
||||
});
|
||||
Add(suffix, {
|
||||
.base_name = "linear_w",
|
||||
.source_names = {"mlp/linear/w", "mlp/linear",
|
||||
"mlp_block/ffw_down/kernel"},
|
||||
.axes = {1, 0},
|
||||
.shape = {config.model_dim, layer_config.ff_hidden_dim},
|
||||
});
|
||||
Add(suffix, {
|
||||
.base_name = "pre_att_ns",
|
||||
.source_names = {"pre_attention_norm/scale",
|
||||
"temporal_pre_norm/scale"},
|
||||
.axes = {0},
|
||||
.shape = {config.model_dim},
|
||||
.min_size = Type::kBF16,
|
||||
});
|
||||
Add(suffix,
|
||||
{
|
||||
.base_name = "pre_ff_ns",
|
||||
.source_names = {"pre_ffw_norm/scale", "channel_pre_norm/scale"},
|
||||
.axes = {0},
|
||||
.shape = {config.model_dim},
|
||||
.min_size = Type::kBF16,
|
||||
});
|
||||
Add(suffix, {
|
||||
.base_name = "post_att_ns",
|
||||
.source_names = {"post_attention_norm/scale"},
|
||||
.axes = {0},
|
||||
.shape = {config.model_dim},
|
||||
.min_size = Type::kBF16,
|
||||
});
|
||||
Add(suffix, {
|
||||
.base_name = "post_ff_ns",
|
||||
.source_names = {"post_ffw_norm/scale"},
|
||||
.axes = {0},
|
||||
.shape = {config.model_dim},
|
||||
.min_size = Type::kBF16,
|
||||
});
|
||||
Add(suffix, {
|
||||
.base_name = "ffw_gat_b",
|
||||
.source_names = {"mlp_block/ffw_up/b"},
|
||||
.axes = {0},
|
||||
.shape = {2 * layer_config.ff_hidden_dim},
|
||||
.min_size = Type::kF32,
|
||||
});
|
||||
Add(suffix, {
|
||||
.base_name = "ffw_out_b",
|
||||
.source_names = {"mlp_block/ffw_down/bias"},
|
||||
.axes = {0},
|
||||
.shape = {config.model_dim},
|
||||
.min_size = Type::kF32,
|
||||
});
|
||||
Add(suffix,
|
||||
{
|
||||
.base_name = "att_ein",
|
||||
.source_names = {"attn/attn_vec_einsum/w",
|
||||
"attention_block/proj_final/kernel"},
|
||||
.preshape = {layer_config.heads, layer_config.qkv_dim,
|
||||
config.model_dim},
|
||||
.axes = {0, 2, 1},
|
||||
.shape = {layer_config.heads, config.model_dim, layer_config.qkv_dim},
|
||||
});
|
||||
Add(suffix,
|
||||
{
|
||||
.base_name = "att_w",
|
||||
.shape = {config.model_dim, layer_config.heads, layer_config.qkv_dim},
|
||||
.cols_take_extra_dims = true,
|
||||
});
|
||||
|
||||
if (config.model == Model::GRIFFIN_2B) {
|
||||
AddGriffinLayerTensors(layer_config, layer_idx);
|
||||
}
|
||||
}
|
||||
|
||||
TensorInfoRegistry::TensorInfoRegistry(const ModelConfig& config) {
|
||||
// Upper bound on the number of `Add()` calls in `Add*Tensors()`. Loose bound
// in case those are changed without updating this. Better to reserve a bit
// extra than to trigger a 1.5-2x reallocation if the estimate is too low.
|
||||
tensors_.reserve(10 + 32 * config.layer_configs.size() +
|
||||
24 * config.vit_config.layer_configs.size());
|
||||
AddModelTensors(config);
|
||||
for (size_t i = 0; i < config.layer_configs.size(); ++i) {
|
||||
AddLayerTensors(config, config.layer_configs[i], i);
|
||||
}
|
||||
for (size_t i = 0; i < config.vit_config.layer_configs.size(); ++i) {
|
||||
AddImageLayerTensors(config, config.vit_config.layer_configs[i], i);
|
||||
}
|
||||
}
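// Minimal usage sketch, not part of the commit; it only exercises APIs visible
// in this diff (TensorInfoRegistry, Find, LayerSuffix, ExtentsFromInfo). The
// function name and the chosen tensor are illustrative assumptions.
static void ExampleRegistryLookup(const ModelConfig& config) {
  const TensorInfoRegistry tensors(config);
  // Per-layer tensors are registered under their suffixed name, e.g. "qkv_ein_0".
  const std::string name = std::string("qkv_ein") + LayerSuffix(/*layer_idx=*/0);
  const TensorInfo* info = tensors.Find(name);      // nullptr if unknown.
  const Extents2D extents = ExtentsFromInfo(info);  // (0, 0) if not present.
  fprintf(stderr, "%s: %zu x %zu\n", name.c_str(), extents.rows, extents.cols);
}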
|
||||
|
||||
TensorInfo TensorInfoRegistry::TensorInfoFromSourcePath(const std::string& path,
|
||||
int layer_idx) const {
|
||||
for (const TensorInfo& tensor : tensors_) {
|
||||
for (const std::string& source_name : tensor.source_names) {
|
||||
// path ends with source_name?
|
||||
const size_t pos = path.rfind(source_name);
|
||||
if (pos != std::string::npos && path.size() == pos + source_name.size()) {
|
||||
std::string name = tensor.base_name;
|
||||
if (layer_idx >= 0) name += LayerSuffix(static_cast<size_t>(layer_idx));
|
||||
return TensorInfoFromName(name);
|
||||
}
|
||||
}
|
||||
}
|
||||
return TensorInfo();
|
||||
}
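// Illustrative example, not part of the commit: the exporter-side lookup above
// matches on suffixes of checkpoint paths. For instance, with layer_idx = 0,
//   TensorInfoFromSourcePath("transformer/layer_0/attn/q_einsum/w", 0)
// matches source_name "attn/q_einsum/w" and returns the entry registered with
// base_name "qkv1_w" under the suffixed name "qkv1_w_0". The path prefix here
// is an assumption; only the suffix matters for the rfind-based match.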
|
||||
|
||||
} // namespace gcpp
|
||||
|
|
@ -1,5 +1,5 @@
|
|||
#ifndef THIRD_PARTY_GEMMA_CPP_GEMMA_TENSOR_INDEX_H_
|
||||
#define THIRD_PARTY_GEMMA_CPP_GEMMA_TENSOR_INDEX_H_
|
||||
#ifndef THIRD_PARTY_GEMMA_CPP_GEMMA_TENSOR_INFO_H_
|
||||
#define THIRD_PARTY_GEMMA_CPP_GEMMA_TENSOR_INFO_H_
|
||||
|
||||
#include <stddef.h>
|
||||
|
||||
|
|
@ -7,17 +7,18 @@
|
|||
#include <unordered_map>
|
||||
#include <vector>
|
||||
|
||||
#include "compression/shared.h"
|
||||
#include "compression/shared.h" // Type
|
||||
#include "gemma/configs.h"
|
||||
#include "util/basics.h" // Extents2D
|
||||
|
||||
namespace gcpp {
|
||||
|
||||
// Universal tensor information. Holds enough information to construct a
|
||||
// tensor in LayerWeightsPtrs/ModelWeightsPtrs, as well as to export the
|
||||
// tensor from the python model with necessary transpose/reshape info.
|
||||
// Tensor metadata. This is far more than required to construct the `MatPtr` in
|
||||
// `LayerWeightsPtrs/WeightsPtrs`; they only use `.shape` via `ExtentsFromInfo`.
|
||||
// This is also bound to Python and filled by the exporter.
|
||||
struct TensorInfo {
|
||||
// The name of the tensor in the sbs file
|
||||
std::string name;
|
||||
// The base name of the tensor without a layer suffix.
|
||||
std::string base_name;
|
||||
// Strings to match to the end of the name of the tensor in the python model.
|
||||
std::vector<std::string> source_names;
|
||||
// Initial reshape shape. Use only as a last resort when input may have
|
||||
|
|
@ -42,7 +43,7 @@ struct TensorInfo {
|
|||
std::vector<std::string> concat_names;
|
||||
// Axis at which to concatenate.
|
||||
size_t concat_axis = 0;
|
||||
// The minimum compression weight type for this tensor. The default is
|
||||
// The highest permissible compression for this tensor. The default is
|
||||
// kNUQ, which provides maximum compression. Other values such as kBF16
|
||||
// or kF32 can be used to limit the compression to a specific type.
|
||||
Type min_size = Type::kNUQ;
|
||||
|
|
@ -55,7 +56,8 @@ struct TensorInfo {
|
|||
};
|
||||
|
||||
// Collapses/expands the tensor dims into 2D extents, which may be 0, 0 for
|
||||
// not-present tensors such as ViT in a text-only model.
|
||||
// not-present tensors such as ViT in a text-only model. Safely handles nullptr
|
||||
// returned from `TensorInfoRegistry::Find`, hence not a member function.
|
||||
static inline Extents2D ExtentsFromInfo(const TensorInfo* tensor) {
|
||||
if (tensor == nullptr) return Extents2D(0, 0);
|
||||
|
||||
|
|
@ -76,58 +78,64 @@ static inline Extents2D ExtentsFromInfo(const TensorInfo* tensor) {
|
|||
return Extents2D(rows, cols);
|
||||
}
|
||||
|
||||
// Universal index of tensor information, which can be built for a specific
|
||||
// layer_idx.
|
||||
class TensorIndex {
|
||||
static inline std::string LayerSuffix(size_t layer_idx) {
|
||||
return std::string("_") + std::to_string(layer_idx);
|
||||
}
|
||||
|
||||
// Returns tensor base name without any layer suffix.
|
||||
static inline std::string StripLayerSuffix(const std::string& name) {
|
||||
return name.substr(0, name.rfind('_'));
|
||||
}
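// Illustrative examples, not part of the commit:
//   LayerSuffix(3)                  -> "_3"
//   StripLayerSuffix("qkv_ein_3")   -> "qkv_ein"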
|
||||
|
||||
// Holds all `TensorInfo` for a model and retrieves them by (unique) name.
|
||||
class TensorInfoRegistry {
|
||||
public:
|
||||
// Builds a list of TensorInfo for the given layer_idx.
|
||||
// If reshape_att is true, the attn_vec_einsum tensor is reshaped.
|
||||
TensorIndex(const ModelConfig& config, int llm_layer_idx, int img_layer_idx,
|
||||
bool reshape_att);
|
||||
~TensorIndex() = default;
|
||||
explicit TensorInfoRegistry(const ModelConfig& config);
|
||||
~TensorInfoRegistry() = default;
|
||||
|
||||
// Returns the TensorInfo whose source_name matches the end of the given path,
|
||||
// or an empty TensorInfo if not found.
|
||||
// NOTE: that the returned TensorInfo is a copy, so that the source
|
||||
// TensorIndex can be destroyed without affecting the returned TensorInfo.
|
||||
TensorInfo TensorInfoFromSourcePath(const std::string& path) const;
|
||||
// Returns nullptr if not found, otherwise the `TensorInfo` for the given
|
||||
// `name`, which either lacks a suffix, or is per-layer and ends with
|
||||
// `LayerSuffix(layer_idx)`. Used in `WeightsPtrs/LayerWeightsPtrs`.
|
||||
const TensorInfo* Find(const std::string& name) const {
|
||||
auto it = idx_from_name_.find(name);
|
||||
if (it == idx_from_name_.end()) return nullptr;
|
||||
return &tensors_[it->second];
|
||||
}
|
||||
|
||||
// Returns the TensorInfo whose name matches the given name,
|
||||
// or an empty TensorInfo if not found.
|
||||
// NOTE: that the returned TensorInfo is a copy, so that the source
|
||||
// TensorIndex can be destroyed without affecting the returned TensorInfo.
|
||||
// Returns a copy of the `TensorInfo` whose name matches the given name, or a
|
||||
// default-constructed `TensorInfo` if not found. Destroying
|
||||
// `TensorInfoRegistry` afterward will not invalidate the returned value.
|
||||
TensorInfo TensorInfoFromName(const std::string& name) const {
|
||||
const TensorInfo* info = FindName(name);
|
||||
const TensorInfo* info = Find(name);
|
||||
if (info == nullptr) return TensorInfo();
|
||||
return *info;
|
||||
}
|
||||
|
||||
// Returns the TensorInfo for the given tensor name, for concise construction
|
||||
// of ModelWeightsPtrs/LayerWeightsPtrs.
|
||||
const TensorInfo* FindName(const std::string& name) const;
|
||||
// Returns a copy of the `TensorInfo` whose source_name matches the end of the
|
||||
// given path, and whose name ends with the given layer_idx, otherwise a
|
||||
// default-constructed `TensorInfo`. Destroying `TensorInfoRegistry`
|
||||
// afterward will not invalidate the returned value.
|
||||
TensorInfo TensorInfoFromSourcePath(const std::string& path,
|
||||
int layer_idx) const;
|
||||
|
||||
private:
|
||||
// Config that was used to build the tensor index.
|
||||
const ModelConfig& config_;
|
||||
// Layer that this tensor index is for - either LLM or image.
|
||||
int llm_layer_idx_;
|
||||
int img_layer_idx_;
|
||||
// List of tensor information for this layer.
|
||||
// `suffix` is empty (only) for per-model tensors, otherwise `LayerSuffix`.
|
||||
void Add(const std::string& suffix, const TensorInfo& info);
|
||||
void AddModelTensors(const ModelConfig& config);
|
||||
void AddLayerTensors(const ModelConfig& config,
|
||||
const LayerConfig& layer_config, size_t layer_idx);
|
||||
void AddGriffinLayerTensors(const LayerConfig& layer_config,
|
||||
size_t layer_idx);
|
||||
|
||||
void AddImageLayerTensors(const ModelConfig& config,
|
||||
const LayerConfig& layer_config,
|
||||
size_t img_layer_idx);
|
||||
|
||||
std::vector<TensorInfo> tensors_;
|
||||
// Map from tensor name to index in tensors_.
|
||||
std::unordered_map<std::string, size_t> name_map_;
|
||||
// Includes entries for base name *and* the suffixed name for each layer.
|
||||
std::unordered_map<std::string, size_t> idx_from_name_;
|
||||
};
|
||||
|
||||
static inline TensorIndex TensorIndexLLM(const ModelConfig& config,
|
||||
size_t llm_layer_idx) {
|
||||
return TensorIndex(config, static_cast<int>(llm_layer_idx), -1, false);
|
||||
}
|
||||
|
||||
static inline TensorIndex TensorIndexImg(const ModelConfig& config,
|
||||
size_t img_layer_idx) {
|
||||
return TensorIndex(config, -1, static_cast<int>(img_layer_idx), false);
|
||||
}
|
||||
|
||||
} // namespace gcpp
|
||||
|
||||
#endif // THIRD_PARTY_GEMMA_CPP_GEMMA_TENSOR_INDEX_H_
|
||||
#endif // THIRD_PARTY_GEMMA_CPP_GEMMA_TENSOR_INFO_H_
|
||||
|
|
@ -0,0 +1,39 @@
|
|||
#include "gemma/tensor_info.h"
|
||||
|
||||
#include <stdio.h>
|
||||
|
||||
#include "gtest/gtest.h"
|
||||
#include "compression/shared.h" // SfpStream
|
||||
#include "gemma/configs.h"
|
||||
#include "gemma/weights.h"
|
||||
#include "util/mat.h"
|
||||
#include "hwy/base.h" // HWY_ASSERT_M
|
||||
|
||||
namespace gcpp {
|
||||
namespace {
|
||||
|
||||
// Tests, for all models, that each tensor in the model can be found and that
// the TensorInfoRegistry returns the correct shape and name for it.
|
||||
TEST(TensorInfoRegistryTest, Find) {
|
||||
ForEachModel([&](Model model) {
|
||||
const ModelConfig config(model, Type::kSFP, ChooseWrapping(model));
|
||||
fprintf(stderr, "Testing %s (%s)\n", config.display_name.c_str(),
|
||||
config.Specifier().c_str());
|
||||
const TensorInfoRegistry tensors(config);
|
||||
// Each tensor in the model should be known/found.
|
||||
ModelWeightsPtrs<SfpStream> weights(config);
|
||||
weights.ForEachTensor(nullptr, nullptr, [&tensors](const TensorArgs& t) {
|
||||
const TensorInfo* info = tensors.Find(t.mat.Name());
|
||||
HWY_ASSERT_M(info, t.mat.Name());
|
||||
// Test that the `MatPtr` can be constructed from the TensorInfo,
|
||||
// and that the dimensions match.
|
||||
MatPtrT<SfpStream> mat_ptr(t.mat.Name(), tensors);
|
||||
EXPECT_STREQ(t.mat.Name(), mat_ptr.Name()) << t.mat.Name();
|
||||
EXPECT_EQ(t.mat.Rows(), mat_ptr.Rows()) << t.mat.Name();
|
||||
EXPECT_EQ(t.mat.Cols(), mat_ptr.Cols()) << t.mat.Name();
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
} // namespace
|
||||
} // namespace gcpp
|
||||
|
|
@ -21,9 +21,7 @@
|
|||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
#include "compression/io.h" // Path
|
||||
#include "compression/shared.h" // PromptWrapping
|
||||
#include "gemma/common.h" // Wrap
|
||||
#include "gemma/configs.h" // PromptWrapping
|
||||
#include "hwy/base.h" // HWY_ASSERT
|
||||
#include "hwy/profiler.h"
|
||||
// copybara:import_next_line:sentencepiece
|
||||
|
|
@ -37,24 +35,20 @@ constexpr bool kShowTokenization = false;
|
|||
class GemmaTokenizer::Impl {
|
||||
public:
|
||||
Impl() = default;
|
||||
explicit Impl(const Path& tokenizer_path) {
|
||||
PROFILER_ZONE("Startup.tokenizer");
|
||||
spp_ = std::make_unique<sentencepiece::SentencePieceProcessor>();
|
||||
if (!spp_->Load(tokenizer_path.path).ok()) {
|
||||
HWY_ABORT("Failed to load the tokenizer file.");
|
||||
}
|
||||
}
|
||||
// Loads the tokenizer from a serialized proto.
|
||||
explicit Impl(const std::string& tokenizer_proto) {
|
||||
if (tokenizer_proto == kMockTokenizer) return;
|
||||
PROFILER_ZONE("Startup.tokenizer");
|
||||
spp_ = std::make_unique<sentencepiece::SentencePieceProcessor>();
|
||||
if (!spp_->LoadFromSerializedProto(tokenizer_proto).ok()) {
|
||||
fprintf(stderr, "serialized proto size=%zu.\n", tokenizer_proto.size());
|
||||
HWY_ABORT("Failed to load the tokenizer from serialized proto.");
|
||||
HWY_ABORT("Failed to load tokenizer from %zu byte serialized proto.",
|
||||
tokenizer_proto.size());
|
||||
}
|
||||
}
|
||||
|
||||
std::string Serialize() const { return spp_->serialized_model_proto(); }
|
||||
std::string Serialize() const {
|
||||
return spp_ ? spp_->serialized_model_proto() : kMockTokenizer;
|
||||
}
|
||||
|
||||
bool Encode(const std::string& input,
|
||||
std::vector<std::string>* pieces) const {
|
||||
|
|
@ -82,41 +76,38 @@ class GemmaTokenizer::Impl {
|
|||
std::unique_ptr<sentencepiece::SentencePieceProcessor> spp_;
|
||||
};
|
||||
|
||||
GemmaTokenizer::GemmaTokenizer(const Path& tokenizer_path) {
|
||||
impl_ = std::make_unique<Impl>(tokenizer_path);
|
||||
GemmaTokenizer::GemmaTokenizer(const std::string& tokenizer_proto)
|
||||
: impl_(std::make_unique<Impl>(tokenizer_proto)) {
|
||||
HWY_ASSERT(impl_);
|
||||
}
|
||||
|
||||
// Defaults suffice, but they must be defined after GemmaTokenizer::Impl.
|
||||
GemmaTokenizer::GemmaTokenizer() = default;
|
||||
GemmaTokenizer::~GemmaTokenizer() = default;
|
||||
GemmaTokenizer::GemmaTokenizer(GemmaTokenizer&& other) = default;
|
||||
GemmaTokenizer& GemmaTokenizer::operator=(GemmaTokenizer&& other) = default;
|
||||
|
||||
std::string GemmaTokenizer::Serialize() const { return impl_->Serialize(); }
|
||||
|
||||
void GemmaTokenizer::Deserialize(const std::string& tokenizer_proto) {
|
||||
impl_ = std::make_unique<Impl>(tokenizer_proto);
|
||||
}
|
||||
|
||||
bool GemmaTokenizer::Encode(const std::string& input,
|
||||
std::vector<std::string>* pieces) const {
|
||||
return impl_ && impl_->Encode(input, pieces);
|
||||
return impl_->Encode(input, pieces);
|
||||
}
|
||||
|
||||
bool GemmaTokenizer::Encode(const std::string& input,
|
||||
std::vector<int>* ids) const {
|
||||
return impl_ && impl_->Encode(input, ids);
|
||||
return impl_->Encode(input, ids);
|
||||
}
|
||||
|
||||
// Given a sequence of ids, decodes it into a detokenized output.
|
||||
bool GemmaTokenizer::Decode(const std::vector<int>& ids,
|
||||
std::string* detokenized) const {
|
||||
return impl_ && impl_->Decode(ids, detokenized);
|
||||
return impl_->Decode(ids, detokenized);
|
||||
}
|
||||
|
||||
bool GemmaChatTemplate::Init(const GemmaTokenizer& tokenizer, Model model) {
|
||||
GemmaChatTemplate::GemmaChatTemplate(const GemmaTokenizer& tokenizer,
|
||||
Model model) {
|
||||
sot_user_.reserve(3);
|
||||
if (!tokenizer.Encode("<start_of_turn>user\n", &sot_user_)) return false;
|
||||
if (!tokenizer.Encode("<start_of_turn>user\n", &sot_user_)) return;
|
||||
sot_model_.reserve(3);
|
||||
HWY_ASSERT(tokenizer.Encode("<start_of_turn>model\n", &sot_model_));
|
||||
eot_.reserve(2);
|
||||
|
|
@ -127,7 +118,6 @@ bool GemmaChatTemplate::Init(const GemmaTokenizer& tokenizer, Model model) {
|
|||
HWY_ASSERT(tokenizer.Encode("\n\n<start_of_image>", &vlm_soi_));
|
||||
vlm_eoi_.reserve(2);
|
||||
HWY_ASSERT(tokenizer.Encode("<end_of_image>\n\n", &vlm_eoi_));
|
||||
return true;
|
||||
}
|
||||
|
||||
std::vector<int> GemmaChatTemplate::Apply(size_t pos,
|
||||
|
|
@ -182,12 +172,12 @@ std::vector<int> GemmaChatTemplate::WrapVLM(const std::vector<int>& text_part,
|
|||
// Text
|
||||
std::vector<int> WrapAndTokenize(const GemmaTokenizer& tokenizer,
|
||||
const GemmaChatTemplate& chat_template,
|
||||
const ModelInfo& info, size_t pos,
|
||||
const PromptWrapping wrapping, size_t pos,
|
||||
const std::string& prompt) {
|
||||
std::vector<int> tokens;
|
||||
HWY_ASSERT(tokenizer.Encode(prompt, &tokens));
|
||||
|
||||
switch (info.wrapping) {
|
||||
switch (wrapping) {
|
||||
case PromptWrapping::GEMMA_IT:
|
||||
case PromptWrapping::GEMMA_VLM:
|
||||
return chat_template.Apply(pos, tokens);
|
||||
|
|
@ -202,12 +192,12 @@ std::vector<int> WrapAndTokenize(const GemmaTokenizer& tokenizer,
|
|||
// Vision
|
||||
std::vector<int> WrapAndTokenize(const GemmaTokenizer& tokenizer,
|
||||
const GemmaChatTemplate& chat_template,
|
||||
const ModelInfo& info, size_t pos,
|
||||
const PromptWrapping wrapping, size_t pos,
|
||||
const std::string& prompt,
|
||||
size_t image_batch_size) {
|
||||
std::vector<int> text_part;
|
||||
HWY_ASSERT(tokenizer.Encode(prompt, &text_part));
|
||||
switch (info.wrapping) {
|
||||
switch (wrapping) {
|
||||
case PromptWrapping::PALIGEMMA:
|
||||
HWY_ASSERT(pos == 0);
|
||||
return chat_template.WrapPali(text_part, image_batch_size);
|
||||
|
|
|
|||
|
|
@ -22,8 +22,7 @@
|
|||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
#include "compression/io.h" // Path
|
||||
#include "gemma/common.h" // ModelInfo
|
||||
#include "gemma/configs.h" // PromptWrapping
|
||||
|
||||
namespace gcpp {
|
||||
|
||||
|
|
@ -32,19 +31,24 @@ constexpr int EOS_ID = 1;
|
|||
constexpr int SECONDARY_EOS_ID = 106; // for Gemma 3
|
||||
constexpr int BOS_ID = 2;
|
||||
|
||||
class GemmaTokenizer {
|
||||
public:
|
||||
GemmaTokenizer();
|
||||
explicit GemmaTokenizer(const Path& tokenizer_path);
|
||||
// To avoid the complexity of storing the tokenizer in testdata/ or downloading
// it from gs://, while still always writing a tokenizer blob (and never an
// empty one), we store this placeholder string instead.
|
||||
constexpr const char* kMockTokenizer = "unavailable";
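// Minimal usage sketch, not part of the commit: constructing the tokenizer
// with the placeholder makes it "unavailable" rather than aborting, and
// Encode()/Decode() then report failure instead of crashing.
//   GemmaTokenizer tokenizer(kMockTokenizer);
//   std::vector<int> ids;
//   const bool ok = tokenizer.Encode("hello", &ids);  // false: no tokenizer.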
|
||||
|
||||
// must come after definition of Impl
|
||||
class GemmaTokenizer {
|
||||
// These must be defined after the definition of `Impl`.
|
||||
public:
|
||||
// If unavailable, pass `kMockTokenizer`.
|
||||
explicit GemmaTokenizer(const std::string& tokenizer_proto);
|
||||
~GemmaTokenizer();
|
||||
GemmaTokenizer(GemmaTokenizer&& other);
|
||||
GemmaTokenizer& operator=(GemmaTokenizer&& other);
|
||||
|
||||
// Returns `kMockTokenizer` if unavailable.
|
||||
std::string Serialize() const;
|
||||
void Deserialize(const std::string& tokenizer_proto);
|
||||
|
||||
// Returns false on failure or if unavailable.
|
||||
bool Encode(const std::string& input, std::vector<std::string>* pieces) const;
|
||||
bool Encode(const std::string& input, std::vector<int>* ids) const;
|
||||
bool Decode(const std::vector<int>& ids, std::string* detokenized) const;
|
||||
|
|
@ -56,13 +60,9 @@ class GemmaTokenizer {
|
|||
|
||||
class GemmaChatTemplate {
|
||||
public:
|
||||
GemmaChatTemplate() = default;
|
||||
explicit GemmaChatTemplate(const GemmaTokenizer& tokenizer, Model model) {
|
||||
(void)Init(tokenizer, model);
|
||||
}
|
||||
|
||||
// Returns false if the tokenizer is not available (as in optimize_test.cc).
|
||||
bool Init(const GemmaTokenizer& tokenizer, Model model);
|
||||
// No effect if `tokenizer` is unavailable (as happens in optimize_test.cc),
|
||||
// but then any other method may abort.
|
||||
GemmaChatTemplate(const GemmaTokenizer& tokenizer, Model model);
|
||||
|
||||
// Given prompt tokens, this returns the wrapped prompt including BOS and
|
||||
// any "start_of_turn" structure required by the model.
|
||||
|
|
@ -83,12 +83,12 @@ class GemmaChatTemplate {
|
|||
|
||||
std::vector<int> WrapAndTokenize(const GemmaTokenizer& tokenizer,
|
||||
const GemmaChatTemplate& chat_template,
|
||||
const ModelInfo& info, size_t pos,
|
||||
PromptWrapping wrapping, size_t pos,
|
||||
const std::string& prompt);
|
||||
|
||||
std::vector<int> WrapAndTokenize(const GemmaTokenizer& tokenizer,
|
||||
const GemmaChatTemplate& chat_template,
|
||||
const ModelInfo& info, size_t pos,
|
||||
PromptWrapping wrapping, size_t pos,
|
||||
const std::string& prompt,
|
||||
size_t image_batch_size);
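// Minimal call sketch, not part of the commit: callers that previously passed
// a ModelInfo now pass the PromptWrapping from ModelConfig directly. The
// `config`, `tokenizer` and `chat_template` objects are assumed to exist.
//   const std::vector<int> tokens = WrapAndTokenize(
//       tokenizer, chat_template, config.wrapping, /*pos=*/0, "Hello!");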
|
||||
|
||||
|
|
|
|||
436 gemma/weights.cc
|
|
@ -15,7 +15,10 @@
|
|||
|
||||
#include "gemma/weights.h"
|
||||
|
||||
#include <cstdio>
|
||||
#include <stddef.h>
|
||||
#include <stdio.h>
|
||||
|
||||
#include <cstdint>
|
||||
#include <cstdlib>
|
||||
#include <memory>
|
||||
#include <random>
|
||||
|
|
@ -23,264 +26,44 @@
|
|||
#include <vector>
|
||||
|
||||
#include "compression/blob_store.h"
|
||||
#include "compression/compress-inl.h"
|
||||
#include "compression/compress.h"
|
||||
#include "compression/io.h" // Path
|
||||
#include "compression/shared.h"
|
||||
#include "gemma/common.h"
|
||||
#include "gemma/configs.h"
|
||||
#include "gemma/model_store.h"
|
||||
#include "util/mat.h"
|
||||
#include "hwy/aligned_allocator.h"
|
||||
#include "hwy/base.h" // HWY_ABORT
|
||||
#include "hwy/base.h"
|
||||
#include "hwy/contrib/thread_pool/thread_pool.h"
|
||||
#include "hwy/highway.h"
|
||||
#include "hwy/profiler.h"
|
||||
#include "hwy/stats.h"
|
||||
|
||||
// TODO: move into foreach_target; this is only used for NUQ Reshape.
|
||||
#include "compression/compress-inl.h"
|
||||
|
||||
namespace gcpp {
|
||||
|
||||
template <typename T>
|
||||
struct TensorLoader {
|
||||
void operator()(ModelWeightsPtrs<T>& weights, ForEachType fet,
|
||||
ReadFromBlobStore& loader) {
|
||||
weights.ForEachTensor(
|
||||
{&weights}, fet,
|
||||
[&loader](const char* name, hwy::Span<MatPtr*> tensors) {
|
||||
loader(name, tensors);
|
||||
});
|
||||
}
|
||||
};
|
||||
|
||||
BlobError ModelWeightsStorage::Load(const Path& weights, Model model_type,
|
||||
Type weight_type, PromptWrapping wrapping,
|
||||
hwy::ThreadPool& pool,
|
||||
std::string* tokenizer_proto) {
|
||||
PROFILER_ZONE("Startup.LoadModelWeightsPtrs");
|
||||
if (!weights.Exists()) {
|
||||
HWY_ABORT("The model weights file '%s' does not exist.",
|
||||
weights.path.c_str());
|
||||
}
|
||||
ReadFromBlobStore loader(weights);
|
||||
ForEachType fet =
|
||||
loader.HaveToc() ? ForEachType::kLoadWithToc : ForEachType::kLoadNoToc;
|
||||
std::vector<float> scales;
|
||||
if (fet == ForEachType::kLoadWithToc) {
|
||||
BlobError err = loader.LoadConfig(config_);
|
||||
if (err != 0 || config_.model_dim == 0) {
|
||||
fprintf(stderr, "Failed to load model config: %d\n", err);
|
||||
return err;
|
||||
}
|
||||
if (tokenizer_proto != nullptr) {
|
||||
err = loader.LoadTokenizer(*tokenizer_proto);
|
||||
if (err != 0) {
|
||||
fprintf(stderr, "Failed to load tokenizer: %d\n", err);
|
||||
return err;
|
||||
}
|
||||
}
|
||||
} else {
|
||||
if (weight_type == Type::kUnknown || model_type == Model::UNKNOWN) {
|
||||
fprintf(stderr,
|
||||
"weight type (%d) and model type (%d) must be specified when "
|
||||
"no config is present in weights file\n",
|
||||
static_cast<int>(weight_type), static_cast<int>(model_type));
|
||||
return __LINE__;
|
||||
}
|
||||
// No Toc-> no config.
|
||||
config_ = ConfigFromModel(model_type);
|
||||
config_.weight = weight_type;
|
||||
config_.wrapping = wrapping;
|
||||
scales.resize(config_.num_tensor_scales + config_.vit_config.num_scales);
|
||||
}
|
||||
CreateForType(config_.weight, pool);
|
||||
CallForModelWeightT<TensorLoader>(fet, loader);
|
||||
if (!scales.empty()) {
|
||||
loader.LoadScales(scales.data(), scales.size());
|
||||
}
|
||||
BlobError err = loader.ReadAll(pool, model_storage_);
|
||||
if (err != 0) {
|
||||
fprintf(stderr, "Failed to load model weights: %d\n", err);
|
||||
return err;
|
||||
}
|
||||
if (!scales.empty()) {
|
||||
GetOrApplyScales(scales);
|
||||
}
|
||||
if (fet == ForEachType::kLoadNoToc) {
|
||||
PROFILER_ZONE("Startup.Reshape");
|
||||
AllocAndCopyWithTranspose(pool);
|
||||
}
|
||||
return 0;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
struct TensorSaver {
|
||||
// Adds all the tensors to the blob writer.
|
||||
void operator()(ModelWeightsPtrs<T>& weights, ForEachType fet,
|
||||
WriteToBlobStore& writer) {
|
||||
weights.ForEachTensor(
|
||||
{&weights}, fet,
|
||||
[&writer](const char* name, hwy::Span<MatPtr*> tensors) {
|
||||
CallUpcasted(tensors[0]->GetType(), tensors[0], writer, name);
|
||||
});
|
||||
}
|
||||
};
|
||||
|
||||
BlobError ModelWeightsStorage::Save(const std::string& tokenizer,
|
||||
const Path& weights,
|
||||
hwy::ThreadPool& pool) {
|
||||
WriteToBlobStore writer(pool);
|
||||
ForEachType fet = ForEachType::kLoadWithToc;
|
||||
CallForModelWeightT<TensorSaver>(fet, writer);
|
||||
writer.AddTokenizer(tokenizer);
|
||||
int err = writer.WriteAll(weights, &config_);
|
||||
if (err != 0) {
|
||||
fprintf(stderr, "Failed to write model weights: %d\n", err);
|
||||
return err;
|
||||
}
|
||||
return 0;
|
||||
}
|
||||
|
||||
void ModelWeightsStorage::Allocate(const ModelConfig& config, Type weight_type,
|
||||
hwy::ThreadPool& pool) {
|
||||
PROFILER_ZONE("Startup.AllocateModelWeightsPtrs");
|
||||
config_ = config;
|
||||
config_.weight = weight_type;
|
||||
CreateForType(weight_type, pool);
|
||||
if (float_weights_) float_weights_->Allocate(model_storage_, pool);
|
||||
if (bf16_weights_) bf16_weights_->Allocate(model_storage_, pool);
|
||||
if (sfp_weights_) sfp_weights_->Allocate(model_storage_, pool);
|
||||
if (nuq_weights_) nuq_weights_->Allocate(model_storage_, pool);
|
||||
}
|
||||
|
||||
class WeightInitializer {
|
||||
public:
|
||||
WeightInitializer(std::mt19937& gen) : dist_(0.0f, 1.0f), gen_(gen) {}
|
||||
|
||||
void operator()(const char* name, hwy::Span<MatPtr*> tensors) {
|
||||
float* data = tensors[0]->RowT<float>(0);
|
||||
for (size_t i = 0; i < tensors[0]->Extents().Area(); ++i) {
|
||||
data[i] = dist_(gen_);
|
||||
}
|
||||
tensors[0]->SetScale(1.0f);
|
||||
}
|
||||
|
||||
private:
|
||||
std::normal_distribution<float> dist_;
|
||||
std::mt19937& gen_;
|
||||
};
|
||||
|
||||
void ModelWeightsStorage::RandInit(std::mt19937& gen) {
|
||||
HWY_ASSERT(float_weights_);
|
||||
WeightInitializer init(gen);
|
||||
ModelWeightsPtrs<float>::ForEachTensor({float_weights_.get()},
|
||||
ForEachType::kLoadNoToc, init);
|
||||
}
|
||||
|
||||
void ModelWeightsStorage::ZeroInit() {
|
||||
if (float_weights_) float_weights_->ZeroInit();
|
||||
if (bf16_weights_) bf16_weights_->ZeroInit();
|
||||
if (sfp_weights_) sfp_weights_->ZeroInit();
|
||||
if (nuq_weights_) nuq_weights_->ZeroInit();
|
||||
}
|
||||
|
||||
void ModelWeightsStorage::GetOrApplyScales(std::vector<float>& scales) {
|
||||
if (float_weights_) float_weights_->GetOrApplyScales(scales);
|
||||
if (bf16_weights_) bf16_weights_->GetOrApplyScales(scales);
|
||||
if (sfp_weights_) sfp_weights_->GetOrApplyScales(scales);
|
||||
if (nuq_weights_) nuq_weights_->GetOrApplyScales(scales);
|
||||
}
|
||||
|
||||
void ModelWeightsStorage::AllocAndCopyWithTranspose(hwy::ThreadPool& pool) {
|
||||
if (float_weights_)
|
||||
float_weights_->AllocAndCopyWithTranspose(pool, model_storage_);
|
||||
if (bf16_weights_)
|
||||
bf16_weights_->AllocAndCopyWithTranspose(pool, model_storage_);
|
||||
if (sfp_weights_)
|
||||
sfp_weights_->AllocAndCopyWithTranspose(pool, model_storage_);
|
||||
if (nuq_weights_)
|
||||
nuq_weights_->AllocAndCopyWithTranspose(pool, model_storage_);
|
||||
}
|
||||
|
||||
void ModelWeightsStorage::CopyWithTranspose(hwy::ThreadPool& pool) {
|
||||
if (float_weights_) float_weights_->CopyWithTranspose(pool);
|
||||
if (bf16_weights_) bf16_weights_->CopyWithTranspose(pool);
|
||||
if (sfp_weights_) sfp_weights_->CopyWithTranspose(pool);
|
||||
if (nuq_weights_) nuq_weights_->CopyWithTranspose(pool);
|
||||
}
|
||||
|
||||
namespace {
|
||||
|
||||
void LogVec(const char* name, const float* data, size_t len) {
|
||||
hwy::Stats stats;
|
||||
for (size_t i = 0; i < len; ++i) {
|
||||
stats.Notify(data[i]);
|
||||
}
|
||||
printf("%-20s %12zu %13.10f %8.5f %13.10f\n",
|
||||
name, len, stats.Min(), stats.Mean(), stats.Max());
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
void ModelWeightsStorage::LogWeightStats() {
|
||||
size_t total_weights = 0;
|
||||
// Only for float weights.
|
||||
ModelWeightsPtrs<float>::ForEachTensor(
|
||||
{float_weights_.get()}, ForEachType::kInitNoToc,
|
||||
[&total_weights](const char* name, hwy::Span<MatPtr*> tensors) {
|
||||
const MatPtr& tensor = *tensors[0];
|
||||
if (tensor.Scale() != 1.0f) {
|
||||
printf("[scale=%f] ", tensor.Scale());
|
||||
}
|
||||
LogVec(name, tensor.RowT<float>(0), tensor.Extents().Area());
|
||||
total_weights += tensor.Extents().Area();
|
||||
});
|
||||
printf("%-20s %12zu\n", "Total", total_weights);
|
||||
}
|
||||
|
||||
void ModelWeightsStorage::CreateForType(Type weight_type,
|
||||
hwy::ThreadPool& pool) {
|
||||
switch (weight_type) {
|
||||
case Type::kF32:
|
||||
float_weights_ = std::make_unique<ModelWeightsPtrs<float>>(config_);
|
||||
break;
|
||||
case Type::kBF16:
|
||||
bf16_weights_ = std::make_unique<ModelWeightsPtrs<BF16>>(config_);
|
||||
break;
|
||||
case Type::kSFP:
|
||||
sfp_weights_ =
|
||||
std::make_unique<ModelWeightsPtrs<SfpStream>>(config_);
|
||||
break;
|
||||
case Type::kNUQ:
|
||||
nuq_weights_ =
|
||||
std::make_unique<ModelWeightsPtrs<NuqStream>>(config_);
|
||||
break;
|
||||
default:
|
||||
HWY_ABORT("Weight type %d unsupported.", static_cast<int>(weight_type));
|
||||
}
|
||||
}
|
||||
|
||||
template <>
|
||||
void LayerWeightsPtrs<NuqStream>::Reshape(MatOwner* storage) {
|
||||
void LayerWeightsPtrs<NuqStream>::Reshape() {
|
||||
if (!attn_vec_einsum_w.HasPtr()) return;
|
||||
HWY_ASSERT(attn_vec_einsum_w.GetType() == Type::kNUQ);
|
||||
|
||||
HWY_ASSERT(att_weights.HasPtr());
|
||||
HWY_ASSERT(att_weights.GetType() == Type::kNUQ);
|
||||
|
||||
const size_t model_dim = layer_config.model_dim;
|
||||
const size_t heads = layer_config.heads;
|
||||
const size_t qkv_dim = layer_config.qkv_dim;
|
||||
|
||||
// Reshape [kHeads, kModelDim, kQKVDim] to [kModelDim, kHeads * kQKVDim].
|
||||
if (storage != nullptr) {
|
||||
storage->AllocateFor(att_weights, MatPadding::kPacked);
|
||||
}
|
||||
|
||||
const hwy::HWY_NAMESPACE::ScalableTag<float> df;
|
||||
|
||||
hwy::AlignedFreeUniquePtr<float[]> attn_vec_einsum_w_tmp =
|
||||
hwy::AllocateAligned<float>(model_dim * heads * qkv_dim);
|
||||
hwy::AlignedFreeUniquePtr<float[]> att_weights_tmp =
|
||||
hwy::AllocateAligned<float>(model_dim * heads * qkv_dim);
|
||||
|
||||
HWY_NAMESPACE::DecompressAndZeroPad(
|
||||
df, MakeSpan(attn_vec_einsum_w.Packed(), model_dim * heads * qkv_dim), 0,
|
||||
attn_vec_einsum_w_tmp.get(), model_dim * heads * qkv_dim);
|
||||
const hwy::HWY_NAMESPACE::ScalableTag<float> df;
|
||||
HWY_NAMESPACE::DecompressAndZeroPad(df, attn_vec_einsum_w.Span(), 0,
|
||||
attn_vec_einsum_w_tmp.get(),
|
||||
model_dim * heads * qkv_dim);
|
||||
|
||||
for (size_t m = 0; m < model_dim; ++m) {
|
||||
float* HWY_RESTRICT out_row = att_weights_tmp.get() + m * heads * qkv_dim;
|
||||
|
|
@ -293,13 +76,186 @@ void LayerWeightsPtrs<NuqStream>::Reshape(MatOwner* storage) {
|
|||
|
||||
CompressWorkingSet work;
|
||||
hwy::ThreadPool pool(0);
|
||||
|
||||
HWY_NAMESPACE::Compress(
|
||||
att_weights_tmp.get(), model_dim * heads * qkv_dim, work,
|
||||
MakeSpan(att_weights.Packed(), model_dim * heads * qkv_dim),
|
||||
/*packed_ofs=*/0, pool);
|
||||
HWY_NAMESPACE::Compress(att_weights_tmp.get(), model_dim * heads * qkv_dim,
|
||||
work, att_weights.Span(),
|
||||
/*packed_ofs=*/0, pool);
|
||||
|
||||
att_weights.SetScale(attn_vec_einsum_w.Scale());
|
||||
}
|
||||
|
||||
// Aborts on error.
|
||||
static void MapOrRead(const std::vector<MatPtr*>& mats, BlobReader2& reader,
|
||||
const std::vector<BlobRange2>& ranges,
|
||||
MatOwners& mat_owners, const MatPadding padding,
|
||||
hwy::ThreadPool& pool) {
|
||||
HWY_ASSERT(mats.size() == ranges.size());
|
||||
|
||||
if (reader.IsMapped()) {
|
||||
PROFILER_ZONE("Startup.Weights.Map");
|
||||
for (size_t i = 0; i < mats.size(); ++i) {
|
||||
// SetPtr does not change the stride, but it is expected to be packed
|
||||
// because that is what Compress() writes to the file.
|
||||
const size_t mat_bytes = mats[i]->PackedBytes();
|
||||
// Ensure blob size matches that computed from metadata.
|
||||
HWY_ASSERT_M(mat_bytes == ranges[i].bytes, mats[i]->Name());
|
||||
|
||||
hwy::Span<const uint8_t> span = reader.MappedSpan<uint8_t>(ranges[i]);
|
||||
HWY_ASSERT(span.size() == mat_bytes);
|
||||
mats[i]->SetPtr(const_cast<uint8_t*>(span.data()), mats[i]->Stride());
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
PROFILER_ZONE("Startup.Weights.AllocateAndEnqueue");
|
||||
|
||||
// NOTE: this changes the stride of `mats`!
|
||||
mat_owners.AllocateFor(mats, padding, pool);
|
||||
|
||||
// Enqueue the read requests, one per row in each tensor.
|
||||
for (size_t i = 0; i < mats.size(); ++i) {
|
||||
uint64_t offset = ranges[i].offset;
|
||||
const size_t file_bytes_per_row = mats[i]->Cols() * mats[i]->ElementBytes();
|
||||
// Caution, `RowT` requires knowledge of the actual type. We instead use
|
||||
// the first row, which is the same for any type, and advance the *byte*
|
||||
// pointer by the *byte* stride.
|
||||
const size_t mem_stride_bytes = mats[i]->Stride() * mats[i]->ElementBytes();
|
||||
uint8_t* row = mats[i]->RowT<uint8_t>(0);
|
||||
for (size_t r = 0; r < mats[i]->Rows(); ++r) {
|
||||
reader.Enqueue(BlobRange2{.offset = offset,
|
||||
.bytes = file_bytes_per_row,
|
||||
.key_idx = ranges[i].key_idx},
|
||||
row);
|
||||
offset += file_bytes_per_row;
|
||||
row += mem_stride_bytes;
|
||||
// Keep the in-memory row padding uninitialized so msan detects any use.
|
||||
}
|
||||
}
|
||||
|
||||
reader.ReadAll(pool);
|
||||
}
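// Illustrative note, not part of the commit (values made up): rows are packed
// in the file (stride == Cols()) but may be padded in memory after
// AllocateFor(), so each row gets its own read request. E.g. with cols=3 and
// an in-memory stride of 4 elements:
//   file:   a0 a1 a2 b0 b1 b2
//   memory: a0 a1 a2 __ b0 b1 b2 __   (padding left uninitialized for msan)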
|
||||
|
||||
void WeightsOwner::ReadOrAllocate(const ModelStore2& model, BlobReader2& reader,
|
||||
hwy::ThreadPool& pool) {
|
||||
// List of tensors to read/map, and where from.
|
||||
std::vector<MatPtr*> mats;
|
||||
std::vector<BlobRange2> ranges;
|
||||
|
||||
// Padding is inserted when reading row by row, except for NUQ tensors.
|
||||
const MatPadding padding = MatPadding::kOdd;
|
||||
|
||||
AllocatePointer(model.Config());
|
||||
|
||||
// Enumerate all weights (negligible cost).
|
||||
CallT([&](const auto& weights) {
|
||||
weights->ForEachTensor(nullptr, nullptr, [&](const TensorArgs& t) {
|
||||
if (t.flags & TensorArgs::kOnlyAllocate) {
|
||||
mat_owners_.AllocateFor(t.mat, padding);
|
||||
return;
|
||||
}
|
||||
size_t key_idx;
|
||||
if (model.FindAndUpdateMatPtr(t.mat, key_idx)) {
|
||||
mats.push_back(&t.mat);
|
||||
ranges.push_back(reader.Range(key_idx));
|
||||
return;
|
||||
}
|
||||
if (t.flags & TensorArgs::kMaybeRead) return; // optional and not found.
|
||||
HWY_ABORT("Tensor %s is required but not found in file.", t.mat.Name());
|
||||
});
|
||||
});
|
||||
|
||||
MapOrRead(mats, reader, ranges, mat_owners_, padding, pool);
|
||||
|
||||
Reshape(pool);
|
||||
}
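// Summary of the per-tensor dispatch above, not part of the commit:
//  - kOnlyAllocate: allocated here, never read from the file (filled later,
//    e.g. by Reshape()).
//  - present in the model store: queued in `mats`/`ranges` for MapOrRead.
//  - kMaybeRead and absent: silently skipped (optional tensor).
//  - otherwise absent: fatal via HWY_ABORT.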
|
||||
|
||||
// Allocates `*_weights_`, but not yet the tensors inside. This is split out
// of `CallT` because `CallT` is const and would thus pass a const& of the
// `unique_ptr` to its lambda, whereas we want to reset the pointer here.
|
||||
void WeightsOwner::AllocatePointer(const ModelConfig& config) {
|
||||
switch (weight_type_) {
|
||||
case Type::kSFP:
|
||||
sfp_weights_.reset(new ModelWeightsPtrs<SfpStream>(config));
|
||||
break;
|
||||
case Type::kNUQ:
|
||||
nuq_weights_.reset(new ModelWeightsPtrs<NuqStream>(config));
|
||||
break;
|
||||
case Type::kF32:
|
||||
float_weights_.reset(new ModelWeightsPtrs<float>(config));
|
||||
break;
|
||||
case Type::kBF16:
|
||||
bf16_weights_.reset(new ModelWeightsPtrs<BF16>(config));
|
||||
break;
|
||||
default:
|
||||
HWY_ABORT("Unsupported weight type %s.", TypeName(weight_type_));
|
||||
}
|
||||
}
|
||||
|
||||
// Gemma calls `WeightsOwner::ReadOrAllocate`, but test code instead calls
|
||||
// `WeightsPtrs::AllocateForTest`, so the implementation is there, and here
|
||||
// we only type-dispatch.
|
||||
void WeightsOwner::AllocateForTest(const ModelConfig& config,
|
||||
hwy::ThreadPool& pool) {
|
||||
PROFILER_ZONE("Startup.AllocateWeights");
|
||||
|
||||
AllocatePointer(config);
|
||||
CallT([&](const auto& weights) {
|
||||
weights->AllocateForTest(mat_owners_, pool);
|
||||
});
|
||||
}
|
||||
|
||||
void WeightsOwner::ZeroInit() {
|
||||
PROFILER_FUNC;
|
||||
CallT([](const auto& weights) { weights->ZeroInit(); });
|
||||
}
|
||||
|
||||
void WeightsOwner::RandInit(float stddev, std::mt19937& gen) {
|
||||
PROFILER_FUNC;
|
||||
float_weights_->RandInit(stddev, gen);
|
||||
}
|
||||
|
||||
void WeightsOwner::LogWeightStatsF32() {
|
||||
size_t total_weights = 0;
|
||||
HWY_ASSERT(weight_type_ == Type::kF32); // Only for float weights.
|
||||
float_weights_->ForEachTensor(
|
||||
nullptr, nullptr, [&total_weights](const TensorArgs& t) {
|
||||
if (t.mat.Scale() != 1.0f) {
|
||||
printf("[scale=%f] ", t.mat.Scale());
|
||||
}
|
||||
hwy::Stats stats;
|
||||
HWY_ASSERT(t.mat.GetType() == Type::kF32);
|
||||
for (size_t r = 0; r < t.mat.Rows(); ++r) {
|
||||
const float* HWY_RESTRICT row = t.mat.RowT<float>(r);
|
||||
for (size_t c = 0; c < t.mat.Cols(); ++c) {
|
||||
stats.Notify(row[c]);
|
||||
}
|
||||
}
|
||||
printf("%-20s %12zu %13.10f %8.5f %13.10f\n", t.mat.Name(),
|
||||
t.mat.Rows() * t.mat.Cols(), stats.Min(), stats.Mean(),
|
||||
stats.Max());
|
||||
|
||||
total_weights += t.mat.Rows() * t.mat.Cols();
|
||||
});
|
||||
printf("%-20s %12zu\n", "Total", total_weights);
|
||||
}
|
||||
|
||||
void WeightsOwner::Reshape(hwy::ThreadPool& pool) {
|
||||
PROFILER_ZONE("Startup.Reshape");
|
||||
CallT([&pool](const auto& weights) { weights->Reshape(pool); });
|
||||
}
|
||||
|
||||
std::vector<uint32_t> WeightsOwner::AddTensorDataToWriter(
|
||||
BlobWriter2& writer) const {
|
||||
std::vector<uint32_t> serialized_mat_ptrs;
|
||||
CallT([&](const auto& weights) {
|
||||
weights->ForEachTensor(nullptr, nullptr, [&](const TensorArgs& t) {
|
||||
if (t.flags & TensorArgs::kOnlyAllocate) return;
|
||||
if (t.flags & TensorArgs::kMaybeRead && !t.mat.HasPtr()) return;
|
||||
HWY_ASSERT_M(t.mat.HasPtr(), t.mat.Name());
|
||||
writer.Add(t.mat.Name(), t.mat.Packed(), t.mat.PackedBytes());
|
||||
t.mat.AppendTo(serialized_mat_ptrs);
|
||||
});
|
||||
});
|
||||
return serialized_mat_ptrs;
|
||||
}
|
||||
|
||||
} // namespace gcpp
|
||||
|
|
|
|||
921 gemma/weights.h
File diff suppressed because it is too large
|
|
@ -372,14 +372,6 @@ HWY_INLINE float Dot(const WT* HWY_RESTRICT w, const VT* vec, size_t num) {
|
|||
return Dot(d, MakeConstSpan(w, num), /*w_ofs=*/0, vec, num);
|
||||
}
|
||||
|
||||
// Adapter for use by matvec-inl.h. TODO: remove when that is no longer used.
|
||||
template <typename MatT, typename VT>
|
||||
HWY_INLINE float Dot(const MatPtrT<MatT>& w, size_t w_ofs,
|
||||
const VT* vec_aligned, size_t num) {
|
||||
const hn::ScalableTag<VT> d;
|
||||
return w.Scale() * Dot(d, w.Span(), w_ofs, vec_aligned, num);
|
||||
}
|
||||
|
||||
// NOLINTNEXTLINE(google-readability-namespace-comments)
|
||||
} // namespace HWY_NAMESPACE
|
||||
} // namespace gcpp
|
||||
|
|
|
|||
|
|
@ -83,7 +83,7 @@ std::unique_ptr<MatStorageT<float>> GenerateMat(size_t offset,
|
|||
}
|
||||
});
|
||||
|
||||
CompressScaled(raw_mat.get(), extents.Area(), ws, *mat, pool);
|
||||
Compress(raw_mat.get(), extents.Area(), ws, mat->Span(), 0, pool);
|
||||
mat->SetScale(1.9f); // Arbitrary value, different from 1.
|
||||
return mat;
|
||||
}
|
||||
|
|
|
|||
|
|
@ -804,7 +804,6 @@ class MMScaleDemoteAdd {
|
|||
// We manually unroll 2x for higher IPC in batch=1.
|
||||
size_t col_c = range_nc.begin();
|
||||
if (HWY_LIKELY(range_nc.Num() >= 2 * ND)) {
|
||||
HWY_UNROLL(1)
|
||||
for (; col_c <= range_nc.end() - 2 * ND; col_c += 2 * ND) {
|
||||
VD a0, a1; // unused if !kAdd
|
||||
if constexpr (kAdd) {
|
||||
|
|
|
|||
|
|
@ -700,6 +700,10 @@ struct ConstMat {
|
|||
|
||||
const Extents2D& Extents() const { return extents; }
|
||||
size_t Stride() const { return stride; }
|
||||
float Scale() const { return scale; }
|
||||
// So that matvec-inl.h can use the same interface as MatPtrT:
|
||||
size_t Rows() const { return extents.rows; }
|
||||
size_t Cols() const { return extents.cols; }
|
||||
|
||||
// Shrinks the row-extent of this matrix view, i.e. reduces the view to a
|
||||
// subrange of the original rows starting at row 0.
|
||||
|
|
|
|||
|
|
@ -37,6 +37,8 @@
|
|||
|
||||
#include "compression/compress-inl.h"
|
||||
#include "ops/dot-inl.h"
|
||||
#include "ops/matmul.h"
|
||||
#include "util/mat.h" // MatPtrT
|
||||
#include "hwy/contrib/math/math-inl.h"
|
||||
#include "hwy/contrib/matvec/matvec-inl.h"
|
||||
|
||||
|
|
@ -45,14 +47,26 @@ namespace gcpp {
|
|||
namespace HWY_NAMESPACE {
|
||||
namespace hn = hwy::HWY_NAMESPACE;
|
||||
|
||||
// Adapter for use by matvec-inl.h. TODO: remove when that is no longer used.
|
||||
template <class ArrayT, typename VT>
|
||||
HWY_INLINE float Dot(const ArrayT& w, size_t w_ofs, const VT* vec_aligned,
|
||||
// Adapter so that gemma-inl.h can pass ConstMat.
|
||||
// TODO: remove after changing ComputeQKV to MatMul.
|
||||
template <typename WT, typename VT>
|
||||
HWY_INLINE float Dot(const ConstMat<WT>& w, size_t w_ofs, const VT* vec_aligned,
|
||||
size_t num) {
|
||||
const hn::ScalableTag<VT> d;
|
||||
HWY_DASSERT(num <= w.Stride()); // Single row, else padding is an issue.
|
||||
const auto span = MakeSpan(w.ptr, w_ofs + w.extents.rows * w.Stride());
|
||||
return w.Scale() * Dot(d, span, w_ofs, vec_aligned, num);
|
||||
}
|
||||
// For callers that pass `MatPtrT`.
|
||||
template <typename WT, typename VT>
|
||||
HWY_INLINE float Dot(const MatPtrT<WT>& w, size_t w_ofs, const VT* vec_aligned,
|
||||
size_t num) {
|
||||
const hn::ScalableTag<VT> d;
|
||||
return w.Scale() * Dot(d, w.Span(), w_ofs, vec_aligned, num);
|
||||
}
|
||||
|
||||
// ArrayT is either MatPtrT or ConstMat.
|
||||
|
||||
// Simple version without tiling nor threading, but two offsets/outputs and
|
||||
// always with addition.
|
||||
template <typename ArrayT, typename VecT, typename AddT>
|
||||
|
|
@ -67,8 +81,8 @@ HWY_INLINE void TwoOfsMatVecAddLoop(const ArrayT& mat, const size_t mat_ofs0,
|
|||
PROFILER_ZONE("TwoOfsMatVecAddLoop");
|
||||
|
||||
for (size_t idx_row = 0; idx_row < outer; ++idx_row) {
|
||||
const size_t row_ofs0 = mat_ofs0 + (idx_row)*inner;
|
||||
const size_t row_ofs1 = mat_ofs1 + (idx_row)*inner;
|
||||
const size_t row_ofs0 = mat_ofs0 + idx_row * mat.Stride();
|
||||
const size_t row_ofs1 = mat_ofs1 + idx_row * mat.Stride();
|
||||
out0[idx_row] = hwy::ConvertScalarTo<float>(add0[idx_row]) +
|
||||
Dot(mat, row_ofs0, vec_aligned, inner);
|
||||
out1[idx_row] = hwy::ConvertScalarTo<float>(add1[idx_row]) +
|
||||
|
|
@ -107,11 +121,11 @@ namespace detail {
|
|||
// coordinate of the tile is r0, c0.
|
||||
template <class DF, typename ArrayT, typename VecT>
|
||||
HWY_INLINE void AccumulatePartialDotProducts(
|
||||
DF df, const ArrayT& mat, size_t mat_ofs, size_t mat_stride, size_t r0,
|
||||
size_t c0, size_t num_rows, size_t num_cols,
|
||||
const VecT* HWY_RESTRICT vec_aligned, float* HWY_RESTRICT out) {
|
||||
DF df, const ArrayT& mat, size_t mat_ofs, size_t r0, size_t c0,
|
||||
size_t num_rows, size_t num_cols, const VecT* HWY_RESTRICT vec_aligned,
|
||||
float* HWY_RESTRICT out) {
|
||||
for (size_t idx_row = 0; idx_row < num_rows; ++idx_row) {
|
||||
const size_t row_ofs = mat_ofs + (r0 + idx_row) * mat_stride;
|
||||
const size_t row_ofs = mat_ofs + (r0 + idx_row) * mat.Stride();
|
||||
out[idx_row] += Dot(mat, row_ofs + c0, vec_aligned + c0, num_cols);
|
||||
}
|
||||
}
|
||||
|
|
@ -121,14 +135,13 @@ HWY_INLINE void AccumulatePartialDotProducts(
|
|||
// accumulate.
|
||||
template <bool kInit, class DF, typename ArrayT, typename VecT, typename InitT>
|
||||
HWY_INLINE void SetFirstPartialDotProducts(DF df, const ArrayT& mat,
|
||||
size_t mat_ofs, size_t mat_stride,
|
||||
size_t r0, size_t c0,
|
||||
size_t mat_ofs, size_t r0, size_t c0,
|
||||
size_t num_rows, size_t num_cols,
|
||||
const VecT* HWY_RESTRICT vec_aligned,
|
||||
const InitT* HWY_RESTRICT init,
|
||||
float* HWY_RESTRICT out) {
|
||||
for (size_t idx_row = 0; idx_row < num_rows; ++idx_row) {
|
||||
const size_t row_ofs = mat_ofs + (r0 + idx_row) * mat_stride;
|
||||
const size_t row_ofs = mat_ofs + (r0 + idx_row) * mat.Stride();
|
||||
if constexpr (kInit) {
|
||||
out[idx_row] = hwy::ConvertScalarTo<float>(init[idx_row + r0]) +
|
||||
Dot(mat, row_ofs + c0, vec_aligned + c0, num_cols);
|
||||
|
|
@ -144,32 +157,32 @@ HWY_INLINE void SetFirstPartialDotProducts(DF df, const ArrayT& mat,
|
|||
// store into out[r - r0].
|
||||
template <bool kAdd, class DF, typename ArrayT, typename VecT, typename AddT>
|
||||
HWY_INLINE void FullDotProductsForStrip(DF df, const ArrayT& mat,
|
||||
size_t mat_ofs, size_t mat_stride,
|
||||
size_t r0, size_t num_rows,
|
||||
size_t mat_ofs, size_t r0,
|
||||
size_t num_rows, size_t num_cols,
|
||||
const VecT* HWY_RESTRICT vec_aligned,
|
||||
const AddT* HWY_RESTRICT add,
|
||||
float* HWY_RESTRICT out) {
|
||||
HWY_DASSERT(num_cols <= mat.Cols());
|
||||
// Tall and skinny: set `out` to the single dot product.
|
||||
if (mat_stride < MaxCols()) {
|
||||
SetFirstPartialDotProducts<kAdd>(df, mat, mat_ofs, mat_stride, r0, 0,
|
||||
num_rows, mat_stride, vec_aligned, add,
|
||||
out);
|
||||
if (num_cols < MaxCols()) {
|
||||
SetFirstPartialDotProducts<kAdd>(df, mat, mat_ofs, r0, 0, num_rows,
|
||||
num_cols, vec_aligned, add, out);
|
||||
return;
|
||||
}
|
||||
|
||||
// We have at least MaxCols, so start by setting `out` to that:
|
||||
SetFirstPartialDotProducts<kAdd>(df, mat, mat_ofs, mat_stride, r0, 0,
|
||||
num_rows, MaxCols(), vec_aligned, add, out);
|
||||
SetFirstPartialDotProducts<kAdd>(df, mat, mat_ofs, r0, 0, num_rows, MaxCols(),
|
||||
vec_aligned, add, out);
|
||||
// For further multiples of MaxCols, accumulate. Remainders handled below.
|
||||
size_t c0 = MaxCols();
|
||||
for (; c0 <= mat_stride - MaxCols(); c0 += MaxCols()) {
|
||||
AccumulatePartialDotProducts(df, mat, mat_ofs, mat_stride, r0, c0, num_rows,
|
||||
MaxCols(), vec_aligned, out);
|
||||
for (; c0 <= num_cols - MaxCols(); c0 += MaxCols()) {
|
||||
AccumulatePartialDotProducts(df, mat, mat_ofs, r0, c0, num_rows, MaxCols(),
|
||||
vec_aligned, out);
|
||||
}
|
||||
|
||||
if (c0 < mat_stride) { // Final cols
|
||||
AccumulatePartialDotProducts(df, mat, mat_ofs, mat_stride, r0, c0, num_rows,
|
||||
mat_stride - c0, vec_aligned, out);
|
||||
if (c0 < num_cols) { // Final cols
|
||||
AccumulatePartialDotProducts(df, mat, mat_ofs, r0, c0, num_rows,
|
||||
num_cols - c0, vec_aligned, out);
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -193,9 +206,8 @@ HWY_INLINE void MatVecT(const ArrayT& mat, const size_t mat_ofs,
|
|||
pool.Run(0, num_strips, [&](const uint64_t strip, size_t thread) HWY_ATTR {
|
||||
PROFILER_ZONE("MatVec.lambda");
|
||||
const size_t r0 = strip * rows_per_strip;
|
||||
detail::FullDotProductsForStrip<kAdd>(df, mat, mat_ofs, inner, r0,
|
||||
rows_per_strip, vec_aligned, add,
|
||||
out + r0);
|
||||
detail::FullDotProductsForStrip<kAdd>(df, mat, mat_ofs, r0, rows_per_strip,
|
||||
inner, vec_aligned, add, out + r0);
|
||||
});
|
||||
|
||||
// Remaining rows
|
||||
|
|
@ -203,7 +215,7 @@ HWY_INLINE void MatVecT(const ArrayT& mat, const size_t mat_ofs,
|
|||
if (r0 < outer) {
|
||||
PROFILER_ZONE("MatVec remainder");
|
||||
const size_t num_rows = outer - r0;
|
||||
detail::FullDotProductsForStrip<kAdd>(df, mat, mat_ofs, inner, r0, num_rows,
|
||||
detail::FullDotProductsForStrip<kAdd>(df, mat, mat_ofs, r0, num_rows, inner,
|
||||
vec_aligned, add, out + r0);
|
||||
}
|
||||
}
|
||||
|
|
@ -249,12 +261,10 @@ HWY_NOINLINE void TwoMatVecT(const ArrayT1& mat0, const ArrayT2& mat1,
|
|||
pool.Run(0, num_strips, [&](const uint64_t strip, size_t thread) HWY_ATTR {
|
||||
PROFILER_ZONE("TwoMatVec.lambda");
|
||||
const size_t r0 = strip * rows_per_strip;
|
||||
detail::FullDotProductsForStrip<kAdd>(df, mat0, mat_ofs, inner, r0,
|
||||
rows_per_strip, vec_aligned, add0,
|
||||
out0 + r0);
|
||||
detail::FullDotProductsForStrip<kAdd>(df, mat1, mat_ofs, inner, r0,
|
||||
rows_per_strip, vec_aligned, add1,
|
||||
out1 + r0);
|
||||
detail::FullDotProductsForStrip<kAdd>(df, mat0, mat_ofs, r0, rows_per_strip,
|
||||
inner, vec_aligned, add0, out0 + r0);
|
||||
detail::FullDotProductsForStrip<kAdd>(df, mat1, mat_ofs, r0, rows_per_strip,
|
||||
inner, vec_aligned, add1, out1 + r0);
|
||||
});
|
||||
|
||||
// Remaining rows
|
||||
|
|
@ -262,10 +272,10 @@ HWY_NOINLINE void TwoMatVecT(const ArrayT1& mat0, const ArrayT2& mat1,
|
|||
if (r0 < outer) {
|
||||
PROFILER_ZONE("TwoMatVec remainder");
|
||||
const size_t num_rows = outer - r0;
|
||||
detail::FullDotProductsForStrip<kAdd>(
|
||||
df, mat0, mat_ofs, inner, r0, num_rows, vec_aligned, add0, out0 + r0);
|
||||
detail::FullDotProductsForStrip<kAdd>(
|
||||
df, mat1, mat_ofs, inner, r0, num_rows, vec_aligned, add1, out1 + r0);
|
||||
detail::FullDotProductsForStrip<kAdd>(df, mat0, mat_ofs, r0, num_rows,
|
||||
inner, vec_aligned, add0, out0 + r0);
|
||||
detail::FullDotProductsForStrip<kAdd>(df, mat1, mat_ofs, r0, num_rows,
|
||||
inner, vec_aligned, add1, out1 + r0);
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -388,7 +388,8 @@ static HWY_NOINLINE HWY_MAYBE_UNUSED void ScalarRopeAndMulBy(
|
|||
void TestRopeAndMulBy() {
|
||||
const Allocator2& allocator = ThreadingContext2::Get().allocator;
|
||||
|
||||
ModelConfig config = ConfigFromModel(Model::GEMMA2_9B);
|
||||
ModelConfig config(Model::GEMMA2_9B, Type::kSFP,
|
||||
ChooseWrapping(Model::GEMMA2_9B));
|
||||
int dim_qkv = config.layer_configs[0].qkv_dim;
|
||||
RowVectorBatch<float> x(allocator, Extents2D(1, dim_qkv));
|
||||
|
||||
|
|
|
|||
|
|
@ -42,7 +42,7 @@ cc_test(
|
|||
"@googletest//:gtest_main", # buildcleaner: keep
|
||||
"//:allocator",
|
||||
"//:benchmark_helper",
|
||||
"//:common",
|
||||
"//:configs",
|
||||
"//:gemma_lib",
|
||||
"//compression:shared",
|
||||
"@highway//:hwy",
|
||||
|
|
|
|||
|
|
@ -133,7 +133,7 @@ TEST_F(PaliGemmaTest, General) {
|
|||
break;
|
||||
default:
|
||||
FAIL() << "Unsupported model: "
|
||||
<< s_env->GetGemma()->GetModelConfig().model_name;
|
||||
<< s_env->GetGemma()->GetModelConfig().display_name;
|
||||
break;
|
||||
}
|
||||
TestQuestions(qa, num);
|
||||
|
|
|
|||
|
|
@ -1,3 +1,4 @@
|
|||
# [internal2] load py_binary
|
||||
# [internal] load py_binary
|
||||
load("@pybind11_bazel//:build_defs.bzl", "pybind_extension")
|
||||
|
||||
|
|
@ -12,7 +13,8 @@ pybind_extension(
|
|||
name = "configs",
|
||||
srcs = ["configs.cc"],
|
||||
deps = [
|
||||
"//:common",
|
||||
"//:configs",
|
||||
"//:tensor_info",
|
||||
"//compression:shared",
|
||||
],
|
||||
)
|
||||
|
|
@ -25,7 +27,6 @@ pybind_extension(
|
|||
"//:gemma_args",
|
||||
"//:gemma_lib",
|
||||
"//:threading_context",
|
||||
"//compression:shared",
|
||||
"@highway//:hwy",
|
||||
],
|
||||
)
|
||||
|
|
|
|||
|
|
@ -20,9 +20,11 @@
|
|||
#include <pybind11/stl.h>
|
||||
|
||||
#include "compression/shared.h"
|
||||
#include "gemma/tensor_index.h"
|
||||
#include "gemma/tensor_info.h"
|
||||
|
||||
using gcpp::ActivationType;
|
||||
using gcpp::InternalLayerConfig;
|
||||
using gcpp::InternalModelConfig;
|
||||
using gcpp::LayerAttentionType;
|
||||
using gcpp::LayerConfig;
|
||||
using gcpp::Model;
|
||||
|
|
@ -32,8 +34,8 @@ using gcpp::PostQKType;
|
|||
using gcpp::PromptWrapping;
|
||||
using gcpp::QueryScaleType;
|
||||
using gcpp::ResidualType;
|
||||
using gcpp::TensorIndex;
|
||||
using gcpp::TensorInfo;
|
||||
using gcpp::TensorInfoRegistry;
|
||||
using gcpp::Type;
|
||||
using gcpp::VitConfig;
|
||||
|
||||
|
|
@ -99,7 +101,7 @@ PYBIND11_MODULE(configs, py_module) {
|
|||
|
||||
class_<TensorInfo>(py_module, "TensorInfo")
|
||||
.def(init())
|
||||
.def_readwrite("name", &TensorInfo::name)
|
||||
.def_readwrite("name", &TensorInfo::base_name)
|
||||
.def_readwrite("source_names", &TensorInfo::source_names)
|
||||
.def_readwrite("preshape", &TensorInfo::preshape)
|
||||
.def_readwrite("axes", &TensorInfo::axes)
|
||||
|
|
@ -110,13 +112,17 @@ PYBIND11_MODULE(configs, py_module) {
|
|||
.def_readwrite("scaled_softplus", &TensorInfo::scaled_softplus)
|
||||
.def_readwrite("cols_take_extra_dims", &TensorInfo::cols_take_extra_dims);
|
||||
|
||||
class_<TensorIndex>(py_module, "TensorIndex")
|
||||
.def(init<const ModelConfig&, int, int, bool>())
|
||||
class_<TensorInfoRegistry>(py_module, "TensorInfoRegistry")
|
||||
.def(init<const ModelConfig&>())
|
||||
.def("tensor_info_from_source_path",
|
||||
&TensorIndex::TensorInfoFromSourcePath, arg("path"))
|
||||
.def("tensor_info_from_name", &TensorIndex::TensorInfoFromName,
|
||||
&TensorInfoRegistry::TensorInfoFromSourcePath, arg("path"),
|
||||
arg("layer_idx"))
|
||||
.def("tensor_info_from_name", &TensorInfoRegistry::TensorInfoFromName,
|
||||
arg("name"));
|
||||
|
||||
class_<InternalLayerConfig>(py_module, "InternalLayerConfig")
|
||||
.def(init<>());
|
||||
|
||||
class_<LayerConfig>(py_module, "LayerConfig")
|
||||
.def(init())
|
||||
.def_readwrite("model_dim", &LayerConfig::model_dim)
|
||||
|
|
@ -133,7 +139,9 @@ PYBIND11_MODULE(configs, py_module) {
      .def_readwrite("post_norm", &LayerConfig::post_norm)
      .def_readwrite("type", &LayerConfig::type)
      .def_readwrite("activation", &LayerConfig::activation)
      .def_readwrite("post_qk", &LayerConfig::post_qk);
      .def_readwrite("post_qk", &LayerConfig::post_qk)
      .def_readwrite("use_qk_norm", &LayerConfig::use_qk_norm)
      .def_readwrite("internal", &LayerConfig::internal);

  class_<VitConfig>(py_module, "VitConfig")
      .def(init())

@ -144,10 +152,15 @@ PYBIND11_MODULE(configs, py_module) {
      .def_readwrite("image_size", &VitConfig::image_size)
      .def_readwrite("layer_configs", &VitConfig::layer_configs);

  class_<InternalModelConfig>(py_module, "InternalModelConfig")
      .def(init<>());

  class_<ModelConfig>(py_module, "ModelConfig")
      .def(init())
      .def(init<>())
      .def(init<Model, Type, PromptWrapping>())
      .def(init<const char*>())
      .def_readwrite("model_family_version", &ModelConfig::model_family_version)
      .def_readwrite("model_name", &ModelConfig::model_name)
      .def_readwrite("display_name", &ModelConfig::display_name)
      .def_readwrite("model", &ModelConfig::model)
      .def_readwrite("wrapping", &ModelConfig::wrapping)
      .def_readwrite("weight", &ModelConfig::weight)

@ -155,7 +168,7 @@ PYBIND11_MODULE(configs, py_module) {
      .def_readwrite("model_dim", &ModelConfig::model_dim)
      .def_readwrite("vocab_size", &ModelConfig::vocab_size)
      .def_readwrite("seq_len", &ModelConfig::seq_len)
      .def_readwrite("num_tensor_scales", &ModelConfig::num_tensor_scales)
      // Skip `unused_num_tensor_scales`.
      .def_readwrite("att_cap", &ModelConfig::att_cap)
      .def_readwrite("final_cap", &ModelConfig::final_cap)
      .def_readwrite("absolute_pe", &ModelConfig::absolute_pe)

@ -164,22 +177,24 @@ PYBIND11_MODULE(configs, py_module) {
      .def_readwrite("layer_configs", &ModelConfig::layer_configs)
      .def_readwrite("attention_window_sizes",
                     &ModelConfig::attention_window_sizes)
      .def_readwrite("scale_names", &ModelConfig::scale_names)
      .def_readwrite("norm_num_groups", &ModelConfig::norm_num_groups)
      .def_readwrite("vit_config", &ModelConfig::vit_config)
      .def_readwrite("pool_dim", &ModelConfig::pool_dim)
      .def_readwrite("eos_id", &ModelConfig::eos_id)
      .def_readwrite("secondary_eos_id", &ModelConfig::secondary_eos_id)
      .def_readwrite("scale_base_names", &ModelConfig::scale_base_names)
      .def_readwrite("internal", &ModelConfig::internal)

      .def("add_layer_config", &ModelConfig::AddLayerConfig,
           arg("layer_config"))
      .def("test_equal", &ModelConfig::TestEqual, arg("other"), arg("partial"),
           arg("debug"));

  // Returns the config for the given model.
  py_module.def("config_from_model", &gcpp::ConfigFromModel, arg("model"));

  // Returns the model for the given config, if it matches any standard model.
  py_module.def("model_from_config", &gcpp::ModelFromConfig, arg("config"));
      .def("test_equal", &ModelConfig::TestEqual, arg("other"), arg("print"))
      .def("overwrite_with_canonical", &ModelConfig::OverwriteWithCanonical)
      .def("specifier", &ModelConfig::Specifier);

  // Returns the sub-config for the ViT model of the PaliGemma model.
  py_module.def("vit_config", &gcpp::GetVitConfig, arg("config"));

  py_module.def("is_paligemma", &gcpp::IsPaliGemma, arg("model"));
}

} // namespace pybind11

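Example (not part of this diff): a minimal sketch of how the renamed bindings could be used from Python, assuming the module builds as python/configs; the specifier string and tensor name below are hypothetical placeholders.

from python import configs

# ModelConfig can now be constructed from a specifier string (or from
# Model/Type/PromptWrapping); config_from_model/model_from_config are gone.
config = configs.ModelConfig("gemma2-2b-it-sfp")  # hypothetical specifier
print(config.display_name, config.specifier())

# TensorIndex is replaced by TensorInfoRegistry, built from the ModelConfig.
registry = configs.TensorInfoRegistry(config)
info = registry.tensor_info_from_name("att_ein")  # hypothetical tensor name
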
@ -42,6 +42,7 @@ import safetensors
import torch

from compression.python import compression
from python import configs


def flatten_f32(x: np.ndarray) -> np.ndarray:

@ -70,14 +71,23 @@ def compute_scale(x: np.ndarray) -> float:


def _is_float_param(param_name: str) -> bool:
  for prefix in ["img_pos_emb", "attn_out_b", "linear_0_b", "linear_1_b",
                 "qkv_ein_b", "img_emb_bias", "img_head_bias"]:
  """Returns whether the tensor should be stored as float32."""
  for prefix in [
      "img_pos_emb",
      "attn_out_b",
      "linear_0_b",
      "linear_1_b",
      "qkv_ein_b",
      "img_emb_bias",
      "img_head_bias",
  ]:
    if param_name.startswith(prefix):
      return True
  return False


def _is_bf16_param(param_name: str) -> bool:
  """Returns whether the tensor should be stored as bf16."""
  for prefix in ["pre_", "post_", "c_", "img_head_kernel"]:
    if param_name.startswith(prefix):
      return True

@ -106,6 +116,7 @@ def _get_layer_config(dims: Dict[str, Any]):

  Args:
    dims: A dictionary of (mostly) dimension values.

  Returns:
    A dictionary of layer configurations.
  """

@ -114,45 +125,141 @@ def _get_layer_config(dims: Dict[str, Any]):
  vit_seq_len = dims["vit_seq_len"]
  config = {
      "llm-non-layers": [
          ("language_model.model.embed_tokens.weight", (257152, model_dim), "c_embedding"),
          (
              "language_model.model.embed_tokens.weight",
              (257152, model_dim),
              "c_embedding",
          ),
          ("language_model.model.norm.weight", (model_dim,), "c_final_norm"),
      ],
      "llm-layers": [
          ("language_model.model.layers.%d.mlp.down_proj.weight", (model_dim, hidden_dim), "linear_w"),
          (
              "language_model.model.layers.%d.mlp.down_proj.weight",
              (model_dim, hidden_dim),
              "linear_w",
          ),
      ],
      "img-non-layers": [
          ("vision_tower.vision_model.post_layernorm.bias", (1152,), "enc_norm_bias"),
          ("vision_tower.vision_model.post_layernorm.weight", (1152,), "enc_norm_scale"),
          ("vision_tower.vision_model.embeddings.patch_embedding.bias", (1152,), "img_emb_bias"),
          ("vision_tower.vision_model.embeddings.patch_embedding.weight", (1152, 14, 14, 3), "img_emb_kernel"),
          (
              "vision_tower.vision_model.post_layernorm.bias",
              (1152,),
              "enc_norm_bias",
          ),
          (
              "vision_tower.vision_model.post_layernorm.weight",
              (1152,),
              "enc_norm_scale",
          ),
          (
              "vision_tower.vision_model.embeddings.patch_embedding.bias",
              (1152,),
              "img_emb_bias",
          ),
          (
              "vision_tower.vision_model.embeddings.patch_embedding.weight",
              (1152, 14, 14, 3),
              "img_emb_kernel",
          ),
          ("multi_modal_projector.linear.bias", (model_dim,), "img_head_bias"),
          ("multi_modal_projector.linear.weight", (model_dim, 1152), "img_head_kernel"),
          ("vision_tower.vision_model.embeddings.position_embedding.weight", (vit_seq_len, 1152), "img_pos_emb"),
          (
              "multi_modal_projector.linear.weight",
              (model_dim, 1152),
              "img_head_kernel",
          ),
          (
              "vision_tower.vision_model.embeddings.position_embedding.weight",
              (vit_seq_len, 1152),
              "img_pos_emb",
          ),
      ],
      "img-layers": [
          ("vision_tower.vision_model.encoder.layers.%d.layer_norm1.bias", (1152,), "ln_0_bias"),
          ("vision_tower.vision_model.encoder.layers.%d.layer_norm1.weight", (1152,), "ln_0_scale"),
          ("vision_tower.vision_model.encoder.layers.%d.layer_norm2.bias", (1152,), "ln_1_bias"),
          ("vision_tower.vision_model.encoder.layers.%d.layer_norm2.weight", (1152,), "ln_1_scale"),
          ("vision_tower.vision_model.encoder.layers.%d.mlp.fc1.bias", (4304,), "linear_0_b"),
          ("vision_tower.vision_model.encoder.layers.%d.mlp.fc1.weight", (4304, 1152), "linear_0_w"),
          ("vision_tower.vision_model.encoder.layers.%d.mlp.fc2.bias", (1152,), "linear_1_b"),
          ("vision_tower.vision_model.encoder.layers.%d.mlp.fc2.weight", (1152, 4304), "linear_1_w"),
          ("vision_tower.vision_model.encoder.layers.%d.self_attn.out_proj.bias", (1152,), "attn_out_b"),
          ("vision_tower.vision_model.encoder.layers.%d.self_attn.out_proj.weight", (1152, 16 * 72), "attn_out_w"),
          (
              "vision_tower.vision_model.encoder.layers.%d.layer_norm1.bias",
              (1152,),
              "ln_0_bias",
          ),
          (
              "vision_tower.vision_model.encoder.layers.%d.layer_norm1.weight",
              (1152,),
              "ln_0_scale",
          ),
          (
              "vision_tower.vision_model.encoder.layers.%d.layer_norm2.bias",
              (1152,),
              "ln_1_bias",
          ),
          (
              "vision_tower.vision_model.encoder.layers.%d.layer_norm2.weight",
              (1152,),
              "ln_1_scale",
          ),
          (
              "vision_tower.vision_model.encoder.layers.%d.mlp.fc1.bias",
              (4304,),
              "linear_0_b",
          ),
          (
              "vision_tower.vision_model.encoder.layers.%d.mlp.fc1.weight",
              (4304, 1152),
              "linear_0_w",
          ),
          (
              "vision_tower.vision_model.encoder.layers.%d.mlp.fc2.bias",
              (1152,),
              "linear_1_b",
          ),
          (
              "vision_tower.vision_model.encoder.layers.%d.mlp.fc2.weight",
              (1152, 4304),
              "linear_1_w",
          ),
          (
              "vision_tower.vision_model.encoder.layers.%d.self_attn.out_proj.bias",
              (1152,),
              "attn_out_b",
          ),
          (
              "vision_tower.vision_model.encoder.layers.%d.self_attn.out_proj.weight",
              (1152, 16 * 72),
              "attn_out_w",
          ),
      ],
  }
  if dims["has_post_norm"]:  # See longer comment above.
    config["llm-layers"] += [
        ("language_model.model.layers.%d.input_layernorm.weight", (model_dim,), "pre_att_ns"),
        ("language_model.model.layers.%d.pre_feedforward_layernorm.weight", (model_dim,), "pre_ff_ns"),
        ("language_model.model.layers.%d.post_attention_layernorm.weight", (model_dim,), "post_att_ns"),
        ("language_model.model.layers.%d.post_feedforward_layernorm.weight", (model_dim,), "post_ff_ns"),
        (
            "language_model.model.layers.%d.input_layernorm.weight",
            (model_dim,),
            "pre_att_ns",
        ),
        (
            "language_model.model.layers.%d.pre_feedforward_layernorm.weight",
            (model_dim,),
            "pre_ff_ns",
        ),
        (
            "language_model.model.layers.%d.post_attention_layernorm.weight",
            (model_dim,),
            "post_att_ns",
        ),
        (
            "language_model.model.layers.%d.post_feedforward_layernorm.weight",
            (model_dim,),
            "post_ff_ns",
        ),
    ]
  else:
    config["llm-layers"] += [
        ("language_model.model.layers.%d.input_layernorm.weight", (model_dim,), "pre_att_ns"),
        ("language_model.model.layers.%d.post_attention_layernorm.weight", (model_dim,), "pre_ff_ns"),
        (
            "language_model.model.layers.%d.input_layernorm.weight",
            (model_dim,),
            "pre_att_ns",
        ),
        (
            "language_model.model.layers.%d.post_attention_layernorm.weight",
            (model_dim,),
            "pre_ff_ns",
        ),
    ]
  return config

@ -162,6 +269,7 @@ def _get_dimensions(params):

  Args:
    params: A dictionary with parameters.

  Returns:
    A dictionary of dimension values.
  """

@ -191,7 +299,9 @@ def _get_dimensions(params):


def export_paligemma_sbs(
    model_specifier: str,
    load_path: str,
    tokenizer_file: str,
    csv_file: str,
    sbs_file: str,
) -> None:

@ -220,8 +330,7 @@ def export_paligemma_sbs(
      "language_model.model.embed_tokens.weight"
  ][:-64]

  # Initialize a few things.
  writer = compression.SbsWriter(compression.CompressorMode.NO_TOC)
  writer = compression.SbsWriter()
  metadata = []
  scales = {}
  dims = _get_dimensions(params)

@ -255,13 +364,13 @@ def export_paligemma_sbs(

    # Determine the type as which to insert.
    if _is_float_param(sbs_name):
      insert = writer.insert_float  # Insert as float.
      packed = configs.Type.kF32
      print(f"Inserting {both_names} as float (f32) (no scaling)")
    elif _is_bf16_param(sbs_name) or param_name.startswith("vision_tower"):
      insert = writer.insert_bf16  # Insert as BF16.
      packed = configs.Type.kBF16
      print(f"Inserting {both_names} as BF16 (no scaling)")
    else:
      insert = writer.insert_sfp  # Insert as SFP.
      packed = configs.Type.kSFP
      # Assumes that all scales are 1.0 for SFP. Consider adding scales.
      # They would still need to be written, but would be collected here.
      assert scale == 1.0, f"Scale for {both_names} is not 1.0"

@ -272,7 +381,10 @@ def export_paligemma_sbs(
    sys.stdout.flush()

    # Add the data to the writer.
    insert(sbs_name, value)
    info = configs.TensorInfo()
    info.name = sbs_name
    info.shape = data.shape
    writer.insert(sbs_name, value, packed, info)

  def add_qkv_einsum(i):  # Handle qkv for layer i.
    name = "language_model.model.layers.%d.self_attn.q_proj.weight"  # (N*H, D)

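Sketch (not from this commit) of the simplified SbsWriter call used above: a single insert() that takes the packed type and a TensorInfo, replacing the per-type insert_float/insert_bf16/insert_sfp methods. The tensor name and data below are placeholders.

import numpy as np

from compression.python import compression
from python import configs

writer = compression.SbsWriter()
value = np.zeros((4, 8), dtype=np.float32)  # placeholder tensor data

info = configs.TensorInfo()
info.name = "att_ein"  # placeholder name
info.shape = list(value.shape)
# One insert() per tensor; the packed type is passed explicitly.
writer.insert("att_ein", value, configs.Type.kBF16, info)
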
@ -367,7 +479,12 @@ def export_paligemma_sbs(

  # Handle the image embedding kernel transpose.
  name = "vision_tower.vision_model.embeddings.patch_embedding.weight"
  assert params[name].shape == (1152, 3, 14, 14,)
  assert params[name].shape == (
      1152,
      3,
      14,
      14,
  )
  params[name] = params[name].permute(0, 2, 3, 1)

  # Add the non-layer params.

@ -393,18 +510,30 @@ def export_paligemma_sbs(
  assert not params, "Some params were not used: %s" % params.keys()

  # Write everything to the sbs file.
  writer.write(sbs_file)
  assert model_specifier.startswith("paligemma")
  writer.write(configs.ModelConfig(model_specifier), tokenizer_file, sbs_file)

  # Write the metadata for manual inspection.
  with open(csv_file, "w") as csv_handle:
    csv.writer(csv_handle).writerows(metadata)


_MODEL_SPECIFIER = flags.DEFINE_string(
    "model_specifier",
    "",
    "String specifying model, size, weight, wrapping (ModelConfig.Specifier)",
)

_LOAD_PATH = flags.DEFINE_string(
    "load_path",
    "",
    "Path to the safetensors index.json file to read",
)
_TOKENIZER_FILE = flags.DEFINE_string(
    "tokenizer_file",
    "/tmp/tokenizer.spm",
    "Path to the tokenizer file to read and embed",
)
_METADATA_FILE = flags.DEFINE_string(
    "metadata_file",
    "/tmp/gemmacpp.csv",

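Sketch (specifier string and paths are placeholders) of the new write() signature, which embeds the ModelConfig and tokenizer into the single-file .sbs output rather than writing weights alone:

from compression.python import compression
from python import configs

writer = compression.SbsWriter()
# ... one writer.insert(...) call per tensor ...
config = configs.ModelConfig("paligemma-...")  # placeholder specifier
writer.write(config, "/tmp/tokenizer.spm", "/tmp/model.sbs")
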
@ -422,14 +551,22 @@ def main(argv: Sequence[str]) -> None:
    raise app.UsageError("Too many command-line arguments.")
  logging.use_python_logging()
  logging.set_verbosity(logging.INFO)
  model_specifier = _MODEL_SPECIFIER.value
  load_path = _LOAD_PATH.value
  tokenizer_file = _TOKENIZER_FILE.value
  metadata_file = _METADATA_FILE.value
  sbs_file = _SBS_FILE.value

  logging.info(
      "\n====\nReading from %s and writing to %s\n====", load_path, sbs_file
      "\n====\nReading %s from %s and %s, writing to %s\n====",
      model_specifier,
      load_path,
      tokenizer_file,
      sbs_file,
  )
  export_paligemma_sbs(
      model_specifier, load_path, tokenizer_file, metadata_file, sbs_file
  )
  export_paligemma_sbs(load_path, metadata_file, sbs_file)


if __name__ == "__main__":

@ -46,9 +46,9 @@ static void RemoveTrailingZeros(std::vector<int> &vec) {
class GemmaModel {
 public:
  GemmaModel(const gcpp::LoaderArgs& loader,
             const gcpp::InferenceArgs& inference,
             const gcpp::ThreadingArgs& threading)
      : gemma_(threading, loader, inference), last_prob_(0.0f) {}
             const gcpp::ThreadingArgs& threading,
             const gcpp::InferenceArgs& inference)
      : gemma_(loader, threading, inference), last_prob_(0.0f) {}

  // Generates a single example, given a prompt and a callback to stream the
  // generated tokens.

@ -167,7 +167,7 @@ class GemmaModel {
  // Generate* will use this image. Throws an error for other models.
  void SetImage(const py::array_t<float, py::array::c_style |
                                             py::array::forcecast>& image) {
    gcpp::Gemma& gemma = *(gemma_.GetGemma());
    const gcpp::Gemma& gemma = *gemma_.GetGemma();
    const gcpp::Allocator2& allocator = gemma_.Env().ctx.allocator;
    if (gemma.GetModelConfig().wrapping != gcpp::PromptWrapping::PALIGEMMA &&
        gemma.GetModelConfig().wrapping != gcpp::PromptWrapping::GEMMA_VLM) {

@ -200,7 +200,7 @@ class GemmaModel {
    if (image_tokens_.Cols() == 0) {
      throw std::invalid_argument("No image set.");
    }
    gcpp::Gemma& model = *(gemma_.GetGemma());
    const gcpp::Gemma& model = *gemma_.GetGemma();
    gemma_.MutableGen().seed(seed);
    gcpp::RuntimeConfig& config = gemma_.MutableConfig();
    config.max_generated_tokens = max_generated_tokens;

@ -258,27 +258,21 @@ class GemmaModel {

PYBIND11_MODULE(gemma, mod) {
  py::class_<GemmaModel>(mod, "GemmaModel")
      .def(py::init([](std::string tokenizer, std::string weights,
                       std::string model, std::string weight_type,
      .def(py::init([](const std::string& tokenizer, const std::string& weights,
                       size_t max_threads) {
             gcpp::LoaderArgs loader(tokenizer, weights, model);
             if (const char* err = loader.Validate()) {
               throw std::invalid_argument(err);
             }
             loader.weight_type_str = weight_type;
             const gcpp::LoaderArgs loader(tokenizer, weights);
             gcpp::ThreadingArgs threading;
             threading.max_lps = max_threads;
             gcpp::InferenceArgs inference;
             inference.max_generated_tokens = 512;
             auto gemma =
                 std::make_unique<GemmaModel>(loader, inference, threading);
                 std::make_unique<GemmaModel>(loader, threading, inference);
             if (!gemma->ModelIsLoaded()) {
               throw std::invalid_argument("Could not load model.");
             }
             return gemma;
           }),
           py::arg("tokenizer_path"), py::arg("weights_path"),
           py::arg("model_flag"), py::arg("weight_type") = "sfp",
           py::arg("max_threads") = 0)
      .def("generate_ex", &GemmaModel::GenerateEx, py::arg("prompt"),
           py::arg("stream"), py::arg("max_generated_tokens") = 1024,

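Sketch of constructing the Python binding after this change; the import path and file paths are placeholders. The model and weight type are no longer passed because they are deduced from the weights file.

from python import gemma  # assumed module path

model = gemma.GemmaModel(
    tokenizer_path="/tmp/tokenizer.spm",  # placeholder
    weights_path="/tmp/model.sbs",        # placeholder
    max_threads=8,
)
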
@ -229,7 +229,7 @@ static inline HWY_MAYBE_UNUSED bool HasHelp(int argc, char* argv[]) {
}

template <class TArgs>
static inline HWY_MAYBE_UNUSED void AbortIfInvalidArgs(TArgs& args) {
static inline HWY_MAYBE_UNUSED void AbortIfInvalidArgs(const TArgs& args) {
  if (const char* err = args.Validate()) {
    args.Help();
    HWY_ABORT("Problem with args: %s\n", err);

26 util/mat.h

@ -27,7 +27,7 @@
// IWYU pragma: begin_exports
#include "compression/fields.h"
#include "compression/shared.h"  // Type
#include "gemma/tensor_index.h"
#include "gemma/tensor_info.h"
#include "util/allocator.h"
#include "util/basics.h"  // Extents2D
// IWYU pragma: end_exports

@ -205,12 +205,11 @@ class MatPtrT : public MatPtr {
  // Called by `MatStorageT`.
  MatPtrT(const char* name, Extents2D extents)
      : MatPtr(name, TypeEnum<MatT>(), extents) {}
  // Take shape from `TensorInfo` to avoid duplicating it in the caller.
  MatPtrT(const char* name, const TensorInfo* tensor)
      : MatPtrT<MatT>(name, ExtentsFromInfo(tensor)) {}
  // Find `TensorInfo` by name in `TensorIndex`.
  MatPtrT(const char* name, const TensorIndex& tensor_index)
      : MatPtrT<MatT>(name, tensor_index.FindName(name)) {}
  // Retrieves shape by name via `TensorInfo` from `TensorInfoRegistry`. This is
  // not a factory function because `weights.h` initializes members of type
  // `MatPtrT<T>`, and `T` cannot be inferred at compile time from arguments.
  MatPtrT(const std::string& name, const TensorInfoRegistry& info)
      : MatPtrT<MatT>(name.c_str(), ExtentsFromInfo(info.Find(name))) {}

  // Copying allowed because the metadata is small.
  MatPtrT(const MatPtr& other) : MatPtr(other) {}

@ -279,17 +278,8 @@ decltype(auto) CallUpcasted(Type type, MatPtr* base, const Func& func,

void CopyMat(const MatPtr& from, MatPtr& to);
void ZeroInit(MatPtr& mat);

template <typename T>
void RandInit(MatPtrT<T>& x, float stddev, std::mt19937& gen) {
  std::normal_distribution<T> dist(0.0, stddev);
  for (size_t r = 0; r < x.Rows(); ++r) {
    T* row = x.Row(r);
    for (size_t c = 0; c < x.Cols(); ++c) {
      row[c] = dist(gen);
    }
  }
}
// F32/F64 only.
void RandInit(MatPtr& mat, float stddev, std::mt19937& gen);

// Sufficient value of `stride` to enable the "cyclic offsets" optimization. If
// `Allocator2::ShouldBind()`, `Allocator2::QuantumBytes()` is typically 4KiB.