mirror of https://github.com/google/gemma.cpp.git
Added tensor_index as a single source of truth on tensor shapes/sources and transformations
PiperOrigin-RevId: 697903886
This commit is contained in:
parent
7d685a267f
commit
73640d2521
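
To make the intent concrete, here is a minimal usage sketch assembled from the pieces added in this commit (the MatPtrT constructor in compression/compress.h and the new tensor_index_test.cc); the model choice, layer index and function name are arbitrary illustrations, not code from the commit:

#include "compression/compress.h"  // MatPtrT
#include "gemma/configs.h"         // ModelConfig, ConfigFromModel
#include "gemma/tensor_index.h"    // TensorIndex

void Sketch() {
  gcpp::ModelConfig config = gcpp::ConfigFromModel(gcpp::Model::GEMMA2_2B);
  // Build the index for LLM layer 0; -1 disables the image-layer variant.
  gcpp::TensorIndex tensor_index(config, /*llm_layer_idx=*/0,
                                 /*img_layer_idx=*/-1, /*reshape_att=*/false);
  // Rows/cols are derived from the "att_ein" TensorInfo shape rather than
  // being spelled out again at the construction site.
  gcpp::MatPtrT<gcpp::SfpStream> att_w("att_ein", tensor_index);
}

The point of the change: the 2-D rows/cols of each weight are derived from one TensorInfo entry instead of being repeated at every construction site and in the Python exporter.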
BUILD.bazel
@@ -194,13 +194,15 @@ cc_library(
    srcs = [
        "gemma/common.cc",
        "gemma/configs.cc",
        "gemma/tensor_index.cc",
    ],
    hdrs = [
        "gemma/common.h",
        "gemma/configs.h",
        "gemma/tensor_index.h",
    ],
    deps = [
        "//compression:compress",
        "//compression:sfp",
        "@highway//:hwy",  # base.h
        "@highway//:thread_pool",
    ],
@@ -215,6 +217,20 @@ cc_test(
    ],
)

cc_test(
    name = "tensor_index_test",
    srcs = ["gemma/tensor_index_test.cc"],
    deps = [
        ":basics",
        ":common",
        ":weights",
        "@googletest//:gtest_main",
        "//compression:compress",
        "@highway//:hwy",
        "@highway//:thread_pool",
    ],
)

cc_library(
    name = "weights",
    srcs = ["gemma/weights.cc"],
@@ -84,6 +84,8 @@ set(SOURCES
  gemma/instantiations/sfp.cc
  gemma/kv_cache.cc
  gemma/kv_cache.h
  gemma/tensor_index.cc
  gemma/tensor_index.h
  gemma/tokenizer.cc
  gemma/tokenizer.h
  gemma/weights.cc
@@ -157,6 +159,7 @@ set(GEMMA_TEST_FILES
  compression/nuq_test.cc
  compression/sfp_test.cc
  evals/gemma_test.cc
  gemma/tensor_index_test.cc
  ops/dot_test.cc
  ops/gemma_matvec_test.cc
  ops/matmul_test.cc
@@ -201,6 +201,7 @@ cc_library(
        ":sfp",
        "//:allocator",
        "//:basics",
        "//:common",
        "@highway//:hwy",
        "@highway//:nanobenchmark",
        "@highway//:profiler",
@@ -17,6 +17,7 @@
#ifndef THIRD_PARTY_GEMMA_CPP_COMPRESSION_COMPRESS_H_
#define THIRD_PARTY_GEMMA_CPP_COMPRESSION_COMPRESS_H_

#include "hwy/base.h"
#define COMPRESS_STATS 0

#include <stddef.h>
@@ -33,6 +34,7 @@
#include "compression/blob_store.h"
#include "compression/io.h"
#include "compression/shared.h"
#include "gemma/tensor_index.h"
#include "util/basics.h"
// IWYU pragma: end_exports
#include "util/allocator.h"
@@ -211,6 +213,26 @@ class MatPtrT : public MatPtr {
  // Full constructor for dynamic sizing.
  MatPtrT(const std::string& name, size_t rows, size_t cols)
      : MatPtr(name, TypeEnum<MatT>(), sizeof(MatT), rows, cols) {}
  // Construction from TensorIndex entry to remove duplication of sizes.
  MatPtrT(const std::string& name, const TensorIndex& tensor_index)
      : MatPtr(name, TypeEnum<MatT>(), sizeof(MatT), 0, 0) {
    const TensorInfo* tensor = tensor_index.FindName(name);
    HWY_ASSERT(tensor != nullptr);
    cols_ = tensor->shape.back();
    rows_ = 1;
    if (tensor->cols_take_extra_dims) {
      // The columns eat the extra dimensions.
      rows_ = tensor->shape[0];
      for (size_t i = 1; i < tensor->shape.size() - 1; ++i) {
        cols_ *= tensor->shape[i];
      }
    } else {
      // The rows eat the extra dimensions.
      for (size_t i = 0; i < tensor->shape.size() - 1; ++i) {
        rows_ *= tensor->shape[i];
      }
    }
  }

  // Copying allowed as the metadata is small.
  MatPtrT(const MatPtr& other) : MatPtr(other) {}
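
As a quick check of the collapsing rule in this constructor, here is how two of the TensorInfo entries added in gemma/tensor_index.cc (below) map onto 2-D rows/cols; this is a worked example derived from the diff, not additional code in the commit:

// "img_emb_kernel": shape = {vit_model_dim, patch_width, patch_width, 3},
// cols_take_extra_dims = true
//   -> rows = vit_model_dim, cols = patch_width * patch_width * 3
//
// "att_ein": shape = {heads, model_dim, qkv_dim},
// cols_take_extra_dims = false (default)
//   -> rows = heads * model_dim, cols = qkv_dim
//
// Both match the updated constructor arguments in gemma/weights.h below.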
@@ -4,6 +4,7 @@
#include <pybind11/stl.h>

#include "compression/shared.h"
#include "gemma/tensor_index.h"
#include "pybind11/cast.h"

using gcpp::ActivationType;
@@ -16,6 +17,8 @@ using gcpp::PostNormType;
using gcpp::PostQKType;
using gcpp::QueryScaleType;
using gcpp::ResidualType;
using gcpp::TensorIndex;
using gcpp::TensorInfo;
using gcpp::Type;

namespace pybind11 {
@@ -72,6 +75,23 @@ PYBIND11_MODULE(configs, py_module) {
      .value("GEMMA2_2B", Model::GEMMA2_2B)
      .value("PALIGEMMA_224", Model::PALIGEMMA_224);

  class_<TensorInfo>(py_module, "TensorInfo")
      .def(init())
      .def_readwrite("name", &TensorInfo::name)
      .def_readwrite("source_names", &TensorInfo::source_names)
      .def_readwrite("preshape", &TensorInfo::preshape)
      .def_readwrite("axes", &TensorInfo::axes)
      .def_readwrite("shape", &TensorInfo::shape)
      .def_readwrite("concat_names", &TensorInfo::concat_names)
      .def_readwrite("concat_axis", &TensorInfo::concat_axis)
      .def_readwrite("min_size", &TensorInfo::min_size)
      .def_readwrite("scaled_softplus", &TensorInfo::scaled_softplus)
      .def_readwrite("cols_take_extra_dims", &TensorInfo::cols_take_extra_dims);

  class_<TensorIndex>(py_module, "TensorIndex")
      .def(init<const ModelConfig&, int, int, bool>())
      .def("get_tensor_info", &TensorIndex::GetTensorInfo, arg("path"));

  class_<LayerConfig>(py_module, "LayerConfig")
      .def(init())
      .def_readwrite("model_dim", &LayerConfig::model_dim)
@@ -0,0 +1,565 @@
#include "gemma/tensor_index.h"

#include <stddef.h>

#include <algorithm>
#include <cctype>
#include <cstdio>
#include <string>
#include <unordered_map>
#include <vector>

#include "compression/shared.h"
#include "gemma/configs.h"

namespace gcpp {
namespace {

// Returns the non-layer tensors for the model.
std::vector<TensorInfo> ModelTensors(const ModelConfig& config) {
  return {
      TensorInfo{
          .name = "c_embedding",
          .source_names = {"embedder/input_embedding"},
          .axes = {0, 1},
          .shape = {config.vocab_size, config.model_dim},
          .min_size = Type::kBF16,
      },
      TensorInfo{
          .name = "c_final_norm",
          .source_names = {"final_norm/scale"},
          .axes = {0},
          .shape = {config.model_dim},
          .min_size = Type::kBF16,
      },
      TensorInfo{
          .name = "enc_norm_bias",
          .source_names = {"img/Transformer/encoder_norm/bias"},
          .axes = {0},
          .shape = {config.vit_model_dim},
          .min_size = Type::kBF16,
      },
      TensorInfo{
          .name = "enc_norm_scale",
          .source_names = {"img/Transformer/encoder_norm/scale"},
          .axes = {0},
          .shape = {config.vit_model_dim},
          .min_size = Type::kBF16,
      },
      TensorInfo{
          .name = "img_emb_bias",
          .source_names = {"img/embedding/bias"},
          .axes = {0},
          .shape = {config.vit_model_dim},
          .min_size = Type::kF32,
      },
      TensorInfo{
          .name = "img_emb_kernel",
          .source_names = {"img/embedding/kernel"},
          .axes = {3, 0, 1, 2},
          .shape = {config.vit_model_dim, config.patch_width,
                    config.patch_width, 3},
          .min_size = Type::kBF16,
          .cols_take_extra_dims = true,
      },
      TensorInfo{
          .name = "img_head_bias",
          .source_names = {"img/head/bias"},
          .axes = {0},
          .shape = {config.model_dim},
          .min_size = Type::kF32,
      },
      TensorInfo{
          .name = "img_head_kernel",
          .source_names = {"img/head/kernel"},
          .axes = {1, 0},
          .shape = {config.model_dim, config.vit_model_dim},
          .min_size = Type::kBF16,
      },
      TensorInfo{
          .name = "img_pos_emb",
          .source_names = {"img/pos_embedding"},
          .axes = {0, 1},
          .shape = {/*1,*/ 256, config.vit_model_dim},
          .min_size = Type::kF32,
      },
  };
}

// Returns the tensors for the given image layer config.
std::vector<TensorInfo> ImageLayerTensors(const ModelConfig& config,
                                          const LayerConfig& layer_config) {
  return {
      // Vit layers.
      TensorInfo{
          .name = "attn_out_w",
          .source_names = {"MultiHeadDotProductAttention_0/out/kernel"},
          .axes = {2, 0, 1},
          .shape = {config.vit_model_dim, layer_config.heads,
                    layer_config.qkv_dim},
          .min_size = Type::kBF16,
          .cols_take_extra_dims = true,
      },
      TensorInfo{
          .name = "attn_out_b",
          .source_names = {"MultiHeadDotProductAttention_0/out/bias"},
          .axes = {0},
          .shape = {config.vit_model_dim},
          .min_size = Type::kF32,
      },
      TensorInfo{
          .name = "q_ein_w",
          .source_names = {"MultiHeadDotProductAttention_0/query/kernel"},
          .axes = {1, 2, 0},
          .shape = {layer_config.heads, layer_config.qkv_dim,
                    config.vit_model_dim},
          .concat_names = {"qkv_ein_w", "k_ein_w", "v_ein_w"},
          .concat_axis = 1,
          .min_size = Type::kBF16,
      },
      TensorInfo{
          .name = "k_ein_w",
          .source_names = {"MultiHeadDotProductAttention_0/key/kernel"},
          .axes = {1, 2, 0},
          .shape = {layer_config.heads, layer_config.qkv_dim,
                    config.vit_model_dim},
          .concat_names = {""},
          .min_size = Type::kBF16,
      },
      TensorInfo{
          .name = "v_ein_w",
          .source_names = {"MultiHeadDotProductAttention_0/value/kernel"},
          .axes = {1, 2, 0},
          .shape = {layer_config.heads, layer_config.qkv_dim,
                    config.vit_model_dim},
          .concat_names = {""},
          .min_size = Type::kBF16,
      },
      TensorInfo{
          .name = "qkv_ein_w",
          .source_names = {"MultiHeadDotProductAttention_0/qkv/kernel"},
          .axes = {2, 0, 3, 1},
          .shape = {layer_config.heads, 3, layer_config.qkv_dim,
                    config.vit_model_dim},
          .min_size = Type::kBF16,
      },
      TensorInfo{
          .name = "q_ein_b",
          .source_names = {"MultiHeadDotProductAttention_0/query/bias"},
          .axes = {0, 1},
          .shape = {layer_config.heads, layer_config.qkv_dim},
          .concat_names = {"qkv_ein_b", "k_ein_b", "v_ein_b"},
          .concat_axis = 1,
          .min_size = Type::kF32,
      },
      TensorInfo{
          .name = "k_ein_b",
          .source_names = {"MultiHeadDotProductAttention_0/key/bias"},
          .axes = {0, 1},
          .shape = {layer_config.heads, layer_config.qkv_dim},
          .concat_names = {""},
          .min_size = Type::kF32,
      },
      TensorInfo{
          .name = "v_ein_b",
          .source_names = {"MultiHeadDotProductAttention_0/value/bias"},
          .axes = {0, 1},
          .shape = {layer_config.heads, layer_config.qkv_dim},
          .concat_names = {""},
          .min_size = Type::kF32,
      },
      TensorInfo{
          .name = "qkv_ein_b",
          .source_names = {"MultiHeadDotProductAttention_0/qkv/bias"},
          .axes = {1, 0, 2},
          .shape = {layer_config.heads * 3, layer_config.qkv_dim},
          .min_size = Type::kF32,
      },
      TensorInfo{
          .name = "linear_0_w",
          .source_names = {"MlpBlock_0/Dense_0/kernel"},
          .axes = {1, 0},
          .shape = {layer_config.ff_hidden_dim, config.vit_model_dim},
          .min_size = Type::kBF16,
      },
      TensorInfo{
          .name = "linear_0_b",
          .source_names = {"MlpBlock_0/Dense_0/bias"},
          .axes = {0},
          .shape = {layer_config.ff_hidden_dim},
          .min_size = Type::kF32,
      },
      TensorInfo{
          .name = "linear_1_w",
          .source_names = {"MlpBlock_0/Dense_1/kernel"},
          .axes = {1, 0},
          .shape = {config.vit_model_dim, layer_config.ff_hidden_dim},
          .min_size = Type::kBF16,
      },
      TensorInfo{
          .name = "linear_1_b",
          .source_names = {"MlpBlock_0/Dense_1/bias"},
          .axes = {0},
          .shape = {config.vit_model_dim},
          .min_size = Type::kF32,
      },
      TensorInfo{
          .name = "ln_0_bias",
          .source_names = {"img/Transformer/encoderblock/LayerNorm_0/bias"},
          .axes = {0},
          .shape = {config.vit_model_dim},
          .min_size = Type::kBF16,
      },
      TensorInfo{
          .name = "ln_0_scale",
          .source_names = {"img/Transformer/encoderblock/LayerNorm_0/scale"},
          .axes = {0},
          .shape = {config.vit_model_dim},
          .min_size = Type::kBF16,
      },
      TensorInfo{
          .name = "ln_1_bias",
          .source_names = {"img/Transformer/encoderblock/LayerNorm_1/bias"},
          .axes = {0},
          .shape = {config.vit_model_dim},
          .min_size = Type::kBF16,
      },
      TensorInfo{
          .name = "ln_1_scale",
          .source_names = {"img/Transformer/encoderblock/LayerNorm_1/scale"},
          .axes = {0},
          .shape = {config.vit_model_dim},
          .min_size = Type::kBF16,
      },
  };
}

// Returns the tensors for the given LLM layer config.
std::vector<TensorInfo> LLMLayerTensors(const ModelConfig& config,
                                        const LayerConfig& layer_config,
                                        bool reshape_att) {
  std::vector<TensorInfo> tensors = {
      TensorInfo{
          .name = "qkv1_w",
          .source_names = {"attn/q_einsum/w"},
          .axes = {0, 2, 1},
          .shape = {layer_config.heads, layer_config.qkv_dim, config.model_dim},
          .concat_names = {"qkv_ein", "qkv2_w"},
      },
      TensorInfo{
          .name = "qkv2_w",
          .source_names = {"attn/kv_einsum/w"},
          .axes = {1, 0, 3, 2},
          .shape = {2 * layer_config.kv_heads, layer_config.qkv_dim,
                    config.model_dim},
          .concat_names = {""},
      },
      TensorInfo{
          .name = "q_ein",
          .source_names = {"attention_block/proj_q/kernel"},
          .axes = {1, 0},
          .shape = {layer_config.model_dim, layer_config.model_dim},
          .concat_names = {"qkv_ein", "k_ein", "v_ein"},
      },
      TensorInfo{
          .name = "k_ein",
          .source_names = {"attention_block/proj_k/kernel"},
          .axes = {1, 0},
          .shape = {layer_config.qkv_dim, layer_config.model_dim},
          .concat_names = {""},
      },
      TensorInfo{
          .name = "v_ein",
          .source_names = {"attention_block/proj_v/kernel"},
          .axes = {1, 0},
          .shape = {layer_config.qkv_dim, layer_config.model_dim},
          .concat_names = {""},
      },
      TensorInfo{
          .name = "qkv_ein",
          .source_names = {"attn/qkv_einsum/w"},
          .axes = {1, 0, 3, 2},
          .shape = {(layer_config.heads + 2 * layer_config.kv_heads),
                    layer_config.qkv_dim, config.model_dim},
      },
      TensorInfo{
          .name = "attn_ob",
          .source_names = {"attention_block/proj_final/bias"},
          .axes = {0},
          .shape = {config.model_dim},
          .min_size = Type::kF32,
      },
      // Griffin layers.
      TensorInfo{
          .name = "gr_lin_x_w",
          .source_names = {"recurrent_block/linear_x/kernel"},
          .axes = {1, 0},
          .shape = {layer_config.griffin_dim, layer_config.griffin_dim},
      },
      TensorInfo{
          .name = "gr_lin_x_b",
          .source_names = {"recurrent_block/linear_x/bias"},
          .axes = {0},
          .shape = {layer_config.griffin_dim},
          .min_size = Type::kF32,
      },
      TensorInfo{
          .name = "gr_lin_y_w",
          .source_names = {"recurrent_block/linear_y/kernel"},
          .axes = {1, 0},
          .shape = {layer_config.griffin_dim, layer_config.griffin_dim},
      },
      TensorInfo{
          .name = "gr_lin_y_b",
          .source_names = {"recurrent_block/linear_y/bias"},
          .axes = {0},
          .shape = {layer_config.griffin_dim},
          .min_size = Type::kF32,
      },
      TensorInfo{
          .name = "gr_lin_out_w",
          .source_names = {"recurrent_block/linear_out/kernel"},
          .axes = {1, 0},
          .shape = {layer_config.griffin_dim, layer_config.griffin_dim},
      },
      TensorInfo{
          .name = "gr_lin_out_b",
          .source_names = {"recurrent_block/linear_out/bias"},
          .axes = {0},
          .shape = {layer_config.griffin_dim},
          .min_size = Type::kF32,
      },
      TensorInfo{
          .name = "gr_conv_w",
          .source_names = {"recurrent_block/conv_1d/w"},
          .axes = {0, 1},
          .shape = {layer_config.conv1d_width, layer_config.griffin_dim},
          .min_size = Type::kF32,
      },
      TensorInfo{
          .name = "gr_conv_b",
          .source_names = {"recurrent_block/conv_1d/b"},
          .axes = {0},
          .shape = {layer_config.griffin_dim},
          .min_size = Type::kF32,
      },
      TensorInfo{
          .name = "gr1_gate_w",
          .source_names = {"recurrent_block/rg_lru/input_gate/w"},
          .axes = {0, 2, 1},
          .shape = {layer_config.heads,
                    layer_config.griffin_dim / layer_config.heads,
                    layer_config.griffin_dim / layer_config.heads},
          .concat_names = {"gr_gate_w", "gr2_gate_w"},
      },
      TensorInfo{
          .name = "gr2_gate_w",
          .source_names = {"recurrent_block/rg_lru/a_gate/w"},
          .axes = {0, 2, 1},
          .shape = {layer_config.heads,
                    layer_config.griffin_dim / layer_config.heads,
                    layer_config.griffin_dim / layer_config.heads},
          .concat_names = {""},
      },
      TensorInfo{
          .name = "gr_gate_w",
          .source_names = {"recurrent_block/rg_lru/gate/w"},
          .axes = {0, 2, 1},
          .shape = {2 * layer_config.heads,
                    layer_config.griffin_dim / layer_config.heads,
                    layer_config.griffin_dim / layer_config.heads},
      },
      TensorInfo{
          .name = "gr1_gate_b",
          .source_names = {"recurrent_block/rg_lru/input_gate/b"},
          .axes = {0},
          .shape = {layer_config.griffin_dim},
          .concat_names = {"gr_gate_b", "gr2_gate_b"},
          .min_size = Type::kF32,
      },
      TensorInfo{
          .name = "gr2_gate_b",
          .source_names = {"recurrent_block/rg_lru/a_gate/b"},
          .axes = {0},
          .shape = {layer_config.griffin_dim},
          .concat_names = {""},
          .min_size = Type::kF32,
      },
      TensorInfo{
          .name = "gr_gate_b",
          .source_names = {"recurrent_block/rg_lru/input_gate/b"},
          .axes = {0, 1},
          .shape = {2 * layer_config.griffin_dim},
          .min_size = Type::kF32,
      },
      TensorInfo{
          .name = "gr_a",
          .source_names = {"recurrent_block/rg_lru/a_param"},
          .axes = {0},
          .shape = {layer_config.griffin_dim},
          .min_size = Type::kF32,
          .scaled_softplus = true,
      },

      TensorInfo{
          .name = "gating_ein",
          .source_names = {"mlp/gating_einsum/w", "mlp/gating_einsum",
                           "mlp_block/ffw_up/w"},
          .axes = {0, layer_config.optimized_gating ? 1u : 2u,
                   layer_config.optimized_gating ? 2u : 1u},
          .shape = {2, layer_config.ff_hidden_dim, config.model_dim},
      },
      TensorInfo{
          .name = "gating1_w",
          .source_names = {"none"},
          .axes = {0, layer_config.optimized_gating ? 1u : 2u,
                   layer_config.optimized_gating ? 2u : 1u},
          .shape = {layer_config.ff_hidden_dim, config.model_dim},
      },
      TensorInfo{
          .name = "gating2_w",
          .source_names = {"none"},
          .axes = {0, layer_config.optimized_gating ? 1u : 2u,
                   layer_config.optimized_gating ? 2u : 1u},
          .shape = {layer_config.ff_hidden_dim, config.model_dim},
      },
      TensorInfo{
          .name = "linear_w",
          .source_names = {"mlp/linear/w", "mlp/linear",
                           "mlp_block/ffw_down/kernel"},
          .axes = {1, 0},
          .shape = {config.model_dim, layer_config.ff_hidden_dim},
      },
      TensorInfo{
          .name = "pre_att_ns",
          .source_names = {"pre_attention_norm/scale",
                           "temporal_pre_norm/scale"},
          .axes = {0},
          .shape = {config.model_dim},
          .min_size = Type::kBF16,
      },
      TensorInfo{
          .name = "pre_ff_ns",
          .source_names = {"pre_ffw_norm/scale", "channel_pre_norm/scale"},
          .axes = {0},
          .shape = {config.model_dim},
          .min_size = Type::kBF16,
      },
      TensorInfo{
          .name = "post_att_ns",
          .source_names = {"post_attention_norm/scale"},
          .axes = {0},
          .shape = {config.model_dim},
          .min_size = Type::kBF16,
      },
      TensorInfo{
          .name = "post_ff_ns",
          .source_names = {"post_ffw_norm/scale"},
          .axes = {0},
          .shape = {config.model_dim},
          .min_size = Type::kBF16,
      },
      TensorInfo{
          .name = "ffw_gat_b",
          .source_names = {"mlp_block/ffw_up/b"},
          .axes = {0},
          .shape = {2 * layer_config.ff_hidden_dim},
          .min_size = Type::kF32,
      },
      TensorInfo{
          .name = "ffw_out_b",
          .source_names = {"mlp_block/ffw_down/bias"},
          .axes = {0},
          .shape = {config.model_dim},
          .min_size = Type::kF32,
      },
  };
  if (reshape_att) {
    tensors.push_back(TensorInfo{
        .name = "att_w",
        .source_names = {"attn/attn_vec_einsum/w",
                         "attention_block/proj_final/kernel"},
        .axes = {2, 0, 1},
        .shape = {config.model_dim, layer_config.heads, layer_config.qkv_dim},
        .cols_take_extra_dims = true,
    });
    tensors.push_back(TensorInfo{
        .name = "att_ein",
        .shape = {layer_config.heads, config.model_dim, layer_config.qkv_dim},
    });
  } else {
    tensors.push_back(TensorInfo{
        .name = "att_ein",
        .source_names = {"attn/attn_vec_einsum/w",
                         "attention_block/proj_final/kernel"},
        .preshape = {layer_config.heads, layer_config.qkv_dim,
                     config.model_dim},
        .axes = {0, 2, 1},
        .shape = {layer_config.heads, config.model_dim, layer_config.qkv_dim},
    });
    tensors.push_back(TensorInfo{
        .name = "att_w",
        .shape = {config.model_dim, layer_config.heads, layer_config.qkv_dim},
        .cols_take_extra_dims = true,
    });
  }
  return tensors;
}

} // namespace

TensorIndex::TensorIndex(const ModelConfig& config, int llm_layer_idx,
                         int img_layer_idx, bool reshape_att)
    : config_(config),
      llm_layer_idx_(llm_layer_idx),
      img_layer_idx_(img_layer_idx) {
  int layer_idx = std::max(llm_layer_idx_, img_layer_idx_);
  std::string suffix;
  if (layer_idx >= 0) {
    suffix = "_" + std::to_string(layer_idx);
  }
  if (llm_layer_idx < 0 && img_layer_idx < 0) {
    tensors_ = ModelTensors(config);
  } else if (llm_layer_idx_ < 0 && 0 <= img_layer_idx &&
             img_layer_idx < config.vit_layer_configs.size()) {
    const auto& layer_config = config.vit_layer_configs[img_layer_idx];
    tensors_ = ImageLayerTensors(config, layer_config);
  } else if (0 <= llm_layer_idx &&
             llm_layer_idx < config.layer_configs.size()) {
    const auto& layer_config = config.layer_configs[llm_layer_idx];
    tensors_ = LLMLayerTensors(config, layer_config, reshape_att);
  }
  for (size_t i = 0; i < tensors_.size(); ++i) {
    std::string key = tensors_[i].name + suffix;
    name_map_.insert({key, i});
  }
}

TensorInfo TensorIndex::GetTensorInfo(const std::string& path) const {
  for (const auto& tensor : tensors_) {
    for (const auto& source_name : tensor.source_names) {
      auto pos = path.rfind(source_name);
      if (pos != std::string::npos && path.size() == pos + source_name.size())
        return tensor;
    }
  }
  return TensorInfo();
}

const TensorInfo* TensorIndex::FindName(const std::string& name) const {
  std::string name_to_find = name;
  if (!std::isdigit(name[name.size() - 1])) {
    if (img_layer_idx_ >= 0 && llm_layer_idx_ < 0) {
      name_to_find = name + "_" + std::to_string(img_layer_idx_);
    } else if (llm_layer_idx_ >= 0) {
      name_to_find = name + "_" + std::to_string(llm_layer_idx_);
    }
  }
  auto it = name_map_.find(name_to_find);
  if (it == name_map_.end()) {
    return nullptr;
  }
  return &tensors_[it->second];
}

} // namespace gcpp
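
A small sketch of the two lookup paths defined above, assuming a TensorIndex built for LLM layer 2; the surrounding function and the exporter path are hypothetical, while the tensor and source names come from the entries above:

#include "gemma/tensor_index.h"

void LookupSketch(const gcpp::TensorIndex& index_for_layer_2) {
  // FindName appends the layer suffix when the name has no trailing digit,
  // so "att_ein" resolves to the "att_ein_2" entry of this index.
  const gcpp::TensorInfo* info = index_for_layer_2.FindName("att_ein");
  // GetTensorInfo matches a source name against the *end* of an exporter
  // path, so a (hypothetical) checkpoint path ending in
  // "attn/attn_vec_einsum/w" returns a copy of the same entry, or an empty
  // TensorInfo if nothing matches.
  gcpp::TensorInfo copy = index_for_layer_2.GetTensorInfo(
      "transformer/layer_2/attn/attn_vec_einsum/w");
  (void)info;
  (void)copy;
}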
@@ -0,0 +1,91 @@
#ifndef THIRD_PARTY_GEMMA_CPP_GEMMA_TENSOR_INDEX_H_
#define THIRD_PARTY_GEMMA_CPP_GEMMA_TENSOR_INDEX_H_

#include <stddef.h>

#include <string>
#include <unordered_map>
#include <vector>

#include "compression/shared.h"
#include "gemma/configs.h"

namespace gcpp {

// Universal tensor information. Holds enough information to construct a
// tensor in LayerWeightsPtrs/ModelWeightsPtrs, as well as to export the
// tensor from the python model with necessary transpose/reshape info.
struct TensorInfo {
  // The name of the tensor in the sbs file.
  std::string name;
  // Strings to match against the end of the tensor's name in the python model.
  std::vector<std::string> source_names;
  // Initial reshape shape. Use only as a last resort when input may have
  // dimensions combined that need to be split before the transpose, as it
  // defeats the post-transpose shape checking. Normally empty.
  std::vector<size_t> preshape;
  // Transpose axes arg. If the input tensor has more dimensions than axes,
  // then leading dimensions are collapsed until the number of axes matches.
  std::vector<size_t> axes;
  // Expected final shape of the tensor after reshape/transpose.
  // Note that this is the shape of the tensor during export,
  // not the shape of the tensor in the sbs file, as the sbs file
  // is restricted to 2D tensors. With few exceptions, the sbs file
  // tensor rows gather all the excess dimensions. See cols_take_extra_dims.
  std::vector<size_t> shape;
  // List of names to concatenate with, used only if multiple tensors are
  // concatenated into one. The first tensor in the concatenation should have
  // concat names thus: the first name is the name of the result, and the
  // tensors with the remaining names are concatenated after this.
  // The remaining tensors to be concatenated should have just a single
  // empty string in concat_names to indicate that they have been consumed.
  std::vector<std::string> concat_names;
  // Axis at which to concatenate.
  size_t concat_axis = 0;
  // The minimum compression weight type for this tensor. The default is
  // kNUQ, which provides maximum compression. Other values such as kBF16
  // or kF32 can be used to limit the compression to a specific type.
  Type min_size = Type::kNUQ;
  // Whether to apply scaled softplus to the data.
  bool scaled_softplus = false;
  // Whether the columns or the rows take any extra dimensions.
  // If false, then [10, 20, 30] -> [10*20, 30] and [30] -> [1, 30].
  // If true, then [10, 20, 30] -> [10, 20*30] and [30] -> [1, 30].
  bool cols_take_extra_dims = false;
};

// Universal index of tensor information, which can be built for a specific
// layer_idx.
class TensorIndex {
 public:
  // Builds a list of TensorInfo for the given layer_idx.
  // If reshape_att is true, the attn_vec_einsum tensor is reshaped.
  TensorIndex(const ModelConfig& config, int llm_layer_idx, int img_layer_idx,
              bool reshape_att);
  ~TensorIndex() = default;

  // Returns the TensorInfo whose source_name matches the end of the given
  // path, or an empty TensorInfo if not found.
  // NOTE: the returned TensorInfo is a copy, so the source TensorIndex
  // can be destroyed without affecting the returned TensorInfo.
  TensorInfo GetTensorInfo(const std::string& path) const;

  // Returns the TensorInfo for the given tensor name, for concise construction
  // of ModelWeightsPtrs/LayerWeightsPtrs.
  const TensorInfo* FindName(const std::string& name) const;

 private:
  // Config that was used to build the tensor index.
  const ModelConfig& config_;
  // Layer that this tensor index is for - either LLM or image.
  int llm_layer_idx_;
  int img_layer_idx_;
  // List of tensor information for this layer.
  std::vector<TensorInfo> tensors_;
  // Map from tensor name to index in tensors_.
  std::unordered_map<std::string, size_t> name_map_;
};

} // namespace gcpp

#endif  // THIRD_PARTY_GEMMA_CPP_GEMMA_TENSOR_INDEX_H_
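
To make the concat_names convention concrete, this is how the ViT query/key/value kernels in gemma/tensor_index.cc (above) use it; a restatement of entries already in this commit, not new behavior:

// "q_ein_w": concat_names = {"qkv_ein_w", "k_ein_w", "v_ein_w"}, concat_axis = 1
//   -> the exporter concatenates q, k and v along axis 1 and stores the
//      result under the name "qkv_ein_w".
// "k_ein_w": concat_names = {""}  // consumed by the concatenation above
// "v_ein_w": concat_names = {""}  // consumed by the concatenation above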
@@ -0,0 +1,73 @@
#include "gemma/tensor_index.h"

#include <cstddef>
#include <cstdio>
#include <cstring>
#include <string>
#include <vector>

#include "gtest/gtest.h"
#include "compression/compress.h"
#include "compression/shared.h"
#include "gemma/configs.h"
#include "gemma/weights.h"
#include "util/basics.h"
#include "hwy/aligned_allocator.h"
#include "hwy/contrib/thread_pool/thread_pool.h"

namespace gcpp {
namespace {

// Tests that each tensor in the model can be found by exactly one TensorIndex,
// and that the TensorIndex returns the correct shape and name for the tensor,
// for all models.
TEST(TensorIndexTest, FindName) {
  hwy::ThreadPool pool(4);
  for (Model model : kAllModels) {
    fprintf(stderr, "Testing model %d\n", static_cast<int>(model));
    ModelConfig config = ConfigFromModel(model);
    std::vector<TensorIndex> tensor_indexes;
    tensor_indexes.emplace_back(config, /*llm_layer_idx=*/-1,
                                /*img_layer_idx=*/-1,
                                /*split_and_reshape=*/false);
    for (size_t llm_layer_idx = 0; llm_layer_idx < config.layer_configs.size();
         ++llm_layer_idx) {
      tensor_indexes.emplace_back(config, static_cast<int>(llm_layer_idx),
                                  /*img_layer_idx=*/-1,
                                  /*split_and_reshape=*/false);
    }
    for (size_t img_layer_idx = 0;
         img_layer_idx < config.vit_layer_configs.size();
         ++img_layer_idx) {
      tensor_indexes.emplace_back(config, /*llm_layer_idx=*/-1,
                                  static_cast<int>(img_layer_idx),
                                  /*split_and_reshape=*/false);
    }
    // For each tensor in any model, exactly one TensorIndex should find it.
    ModelWeightsPtrs<SfpStream> weights(config, pool);
    ModelWeightsPtrs<SfpStream>::ForEachTensor(
        {&weights}, ForEachType::kInitNoToc,
        [&tensor_indexes](const char* name, hwy::Span<MatPtr*> tensors) {
          int num_found = 0;
          const MatPtr& tensor = *tensors[0];
          for (const auto& tensor_index : tensor_indexes) {
            // Skip the type marker prefix, but we want the layer index suffix.
            std::string name_to_find(name + 1, strlen(name) - 1);
            const TensorInfo* info = tensor_index.FindName(name_to_find);
            if (info != nullptr) {
              // Test that the MatPtr can be constructed from the TensorInfo,
              // and that the dimensions match.
              MatPtrT<SfpStream> mat_ptr(tensor.Name(), tensor_index);
              EXPECT_EQ(tensor.Name(), mat_ptr.Name()) << "on tensor " << name;
              EXPECT_EQ(tensor.Rows(), mat_ptr.Rows()) << "on tensor " << name;
              EXPECT_EQ(tensor.Cols(), mat_ptr.Cols()) << "on tensor " << name;
              ++num_found;
            }
          }
          EXPECT_EQ(num_found, 1) << " for tensor " << name;
        });
  }
}

} // namespace
} // namespace gcpp
@@ -57,8 +57,8 @@ template <class Weight>
 struct LayerWeightsPtrs {
   // Large data is constructed separately.
   explicit LayerWeightsPtrs(const LayerConfig& config)
-      : attn_vec_einsum_w("att_ein", config.model_dim,
-                          config.heads * config.qkv_dim),
+      : attn_vec_einsum_w("att_ein", config.heads * config.model_dim,
+                          config.qkv_dim),
         qkv_einsum_w("qkv_ein",
                      (config.heads + 2 * config.kv_heads) * config.qkv_dim,
                      config.model_dim),
@@ -86,8 +86,8 @@ struct LayerWeightsPtrs {
            .gate_biases = {"gr_gate_b", 1, config.griffin_dim * 2},
            .a = {"gr_a", 1, config.griffin_dim}}),
       // MultiHeadDotProductAttention.
-      vit({.attn_out_w = {"attn_out_w", config.heads * config.qkv_dim,
-                          config.model_dim},
+      vit({.attn_out_w = {"attn_out_w", config.model_dim,
+                          config.heads * config.qkv_dim},
           .attn_out_b = {"attn_out_b", 1, config.model_dim},
           .qkv_einsum_w = {"qkv_ein_w",
                            (config.heads + 2 * config.kv_heads) *
@@ -349,9 +349,8 @@ struct ModelWeightsPtrs {
       vit_encoder_norm_bias("enc_norm_bias", 1, config.vit_model_dim),
       vit_encoder_norm_scale("enc_norm_scale", 1, config.vit_model_dim),
       vit_img_embedding_bias("img_emb_bias", 1, config.vit_model_dim),
-      vit_img_embedding_kernel("img_emb_kernel",
-                               config.patch_width * config.patch_width * 3,
-                               config.vit_model_dim),
+      vit_img_embedding_kernel("img_emb_kernel", config.vit_model_dim,
+                               config.patch_width * config.patch_width * 3),
       vit_img_pos_embedding("img_pos_emb", 256, config.vit_model_dim),
       vit_img_head_bias("img_head_bias", 1, config.model_dim),
       vit_img_head_kernel("img_head_kernel", config.model_dim,