Added tensor_index as a single source of truth for tensor shapes/sources and transformations

PiperOrigin-RevId: 697903886
Ray Smith 2024-11-19 00:25:01 -08:00 committed by Copybara-Service
parent 7d685a267f
commit 73640d2521
9 changed files with 798 additions and 8 deletions
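
The net effect: construction sites no longer hand-copy tensor dimensions. A minimal sketch of the pattern this commit enables (model choice arbitrary; not part of the diff):

#include "compression/compress.h"  // MatPtrT
#include "gemma/configs.h"         // ModelConfig, Model, ConfigFromModel
#include "gemma/tensor_index.h"    // TensorIndex

void Example() {
  gcpp::ModelConfig config = gcpp::ConfigFromModel(gcpp::Model::GEMMA2_2B);
  gcpp::TensorIndex index(config, /*llm_layer_idx=*/0, /*img_layer_idx=*/-1,
                          /*reshape_att=*/false);
  // Rows/cols are looked up from the "gating_ein" entry rather than being
  // repeated at the call site.
  gcpp::MatPtrT<float> gating("gating_ein", index);
}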

BUILD.bazel

@@ -194,13 +194,15 @@ cc_library(
srcs = [
"gemma/common.cc",
"gemma/configs.cc",
"gemma/tensor_index.cc",
],
hdrs = [
"gemma/common.h",
"gemma/configs.h",
"gemma/tensor_index.h",
],
deps = [
"//compression:compress",
"//compression:sfp",
"@highway//:hwy", # base.h
"@highway//:thread_pool",
],
@@ -215,6 +217,20 @@ cc_test(
],
)
cc_test(
name = "tensor_index_test",
srcs = ["gemma/tensor_index_test.cc"],
deps = [
":basics",
":common",
":weights",
"@googletest//:gtest_main",
"//compression:compress",
"@highway//:hwy",
"@highway//:thread_pool",
],
)
cc_library(
name = "weights",
srcs = ["gemma/weights.cc"],

CMakeLists.txt

@@ -84,6 +84,8 @@ set(SOURCES
gemma/instantiations/sfp.cc
gemma/kv_cache.cc
gemma/kv_cache.h
gemma/tensor_index.cc
gemma/tensor_index.h
gemma/tokenizer.cc
gemma/tokenizer.h
gemma/weights.cc
@@ -157,6 +159,7 @@ set(GEMMA_TEST_FILES
compression/nuq_test.cc
compression/sfp_test.cc
evals/gemma_test.cc
gemma/tensor_index_test.cc
ops/dot_test.cc
ops/gemma_matvec_test.cc
ops/matmul_test.cc

compression/BUILD.bazel

@@ -201,6 +201,7 @@ cc_library(
":sfp",
"//:allocator",
"//:basics",
"//:common",
"@highway//:hwy",
"@highway//:nanobenchmark",
"@highway//:profiler",

compression/compress.h

@@ -17,6 +17,7 @@
#ifndef THIRD_PARTY_GEMMA_CPP_COMPRESSION_COMPRESS_H_
#define THIRD_PARTY_GEMMA_CPP_COMPRESSION_COMPRESS_H_
#include "hwy/base.h"
#define COMPRESS_STATS 0
#include <stddef.h>
@@ -33,6 +34,7 @@
#include "compression/blob_store.h"
#include "compression/io.h"
#include "compression/shared.h"
#include "gemma/tensor_index.h"
#include "util/basics.h"
// IWYU pragma: end_exports
#include "util/allocator.h"
@@ -211,6 +213,26 @@ class MatPtrT : public MatPtr {
// Full constructor for dynamic sizing.
MatPtrT(const std::string& name, size_t rows, size_t cols)
: MatPtr(name, TypeEnum<MatT>(), sizeof(MatT), rows, cols) {}
// Construction from TensorIndex entry to remove duplication of sizes.
MatPtrT(const std::string& name, const TensorIndex& tensor_index)
: MatPtr(name, TypeEnum<MatT>(), sizeof(MatT), 0, 0) {
const TensorInfo* tensor = tensor_index.FindName(name);
HWY_ASSERT(tensor != nullptr);
cols_ = tensor->shape.back();
rows_ = 1;
if (tensor->cols_take_extra_dims) {
// The columns eat the extra dimensions.
rows_ = tensor->shape[0];
for (size_t i = 1; i < tensor->shape.size() - 1; ++i) {
cols_ *= tensor->shape[i];
}
} else {
// The rows eat the extra dimensions.
for (size_t i = 0; i < tensor->shape.size() - 1; ++i) {
rows_ *= tensor->shape[i];
}
}
}
// Copying allowed as the metadata is small.
MatPtrT(const MatPtr& other) : MatPtr(other) {}
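
A worked example of the folding in the TensorIndex constructor above, using the img_emb_kernel entry added later in this commit (PaliGemma-style numbers assumed: vit_model_dim = 1152, patch_width = 14):

// shape = {1152, 14, 14, 3}, cols_take_extra_dims = true:
//   cols_ = shape.back() = 3;  rows_ = shape[0] = 1152;
//   cols_ *= 14 * 14;          // -> cols_ = 588
// With cols_take_extra_dims = false, the rows fold instead:
//   rows_ = 1152 * 14 * 14 = 225792;  cols_ = 3;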


@@ -4,6 +4,7 @@
#include <pybind11/stl.h>
#include "compression/shared.h"
#include "gemma/tensor_index.h"
#include "pybind11/cast.h"
using gcpp::ActivationType;
@@ -16,6 +17,8 @@ using gcpp::PostNormType;
using gcpp::PostQKType;
using gcpp::QueryScaleType;
using gcpp::ResidualType;
using gcpp::TensorIndex;
using gcpp::TensorInfo;
using gcpp::Type;
namespace pybind11 {
@@ -72,6 +75,23 @@ PYBIND11_MODULE(configs, py_module) {
.value("GEMMA2_2B", Model::GEMMA2_2B)
.value("PALIGEMMA_224", Model::PALIGEMMA_224);
class_<TensorInfo>(py_module, "TensorInfo")
.def(init())
.def_readwrite("name", &TensorInfo::name)
.def_readwrite("source_names", &TensorInfo::source_names)
.def_readwrite("preshape", &TensorInfo::preshape)
.def_readwrite("axes", &TensorInfo::axes)
.def_readwrite("shape", &TensorInfo::shape)
.def_readwrite("concat_names", &TensorInfo::concat_names)
.def_readwrite("concat_axis", &TensorInfo::concat_axis)
.def_readwrite("min_size", &TensorInfo::min_size)
.def_readwrite("scaled_softplus", &TensorInfo::scaled_softplus)
.def_readwrite("cols_take_extra_dims", &TensorInfo::cols_take_extra_dims);
class_<TensorIndex>(py_module, "TensorIndex")
.def(init<const ModelConfig&, int, int, bool>())
.def("get_tensor_info", &TensorIndex::GetTensorInfo, arg("path"));
class_<LayerConfig>(py_module, "LayerConfig")
.def(init())
.def_readwrite("model_dim", &LayerConfig::model_dim)

gemma/tensor_index.cc (new file)

@@ -0,0 +1,565 @@
#include "gemma/tensor_index.h"
#include <stddef.h>
#include <algorithm>
#include <cctype>
#include <cstdio>
#include <string>
#include <unordered_map>
#include <vector>
#include "compression/shared.h"
#include "gemma/configs.h"
namespace gcpp {
namespace {
// Returns the non-layer tensors for the model.
std::vector<TensorInfo> ModelTensors(const ModelConfig& config) {
return {
TensorInfo{
.name = "c_embedding",
.source_names = {"embedder/input_embedding"},
.axes = {0, 1},
.shape = {config.vocab_size, config.model_dim},
.min_size = Type::kBF16,
},
TensorInfo{
.name = "c_final_norm",
.source_names = {"final_norm/scale"},
.axes = {0},
.shape = {config.model_dim},
.min_size = Type::kBF16,
},
TensorInfo{
.name = "enc_norm_bias",
.source_names = {"img/Transformer/encoder_norm/bias"},
.axes = {0},
.shape = {config.vit_model_dim},
.min_size = Type::kBF16,
},
TensorInfo{
.name = "enc_norm_scale",
.source_names = {"img/Transformer/encoder_norm/scale"},
.axes = {0},
.shape = {config.vit_model_dim},
.min_size = Type::kBF16,
},
TensorInfo{
.name = "img_emb_bias",
.source_names = {"img/embedding/bias"},
.axes = {0},
.shape = {config.vit_model_dim},
.min_size = Type::kF32,
},
TensorInfo{
.name = "img_emb_kernel",
.source_names = {"img/embedding/kernel"},
.axes = {3, 0, 1, 2},
.shape = {config.vit_model_dim, config.patch_width,
config.patch_width, 3},
.min_size = Type::kBF16,
.cols_take_extra_dims = true,
},
TensorInfo{
.name = "img_head_bias",
.source_names = {"img/head/bias"},
.axes = {0},
.shape = {config.model_dim},
.min_size = Type::kF32,
},
TensorInfo{
.name = "img_head_kernel",
.source_names = {"img/head/kernel"},
.axes = {1, 0},
.shape = {config.model_dim, config.vit_model_dim},
.min_size = Type::kBF16,
},
TensorInfo{
.name = "img_pos_emb",
.source_names = {"img/pos_embedding"},
.axes = {0, 1},
.shape = {/*1,*/ 256, config.vit_model_dim},
.min_size = Type::kF32,
},
};
}
// Returns the tensors for the given image layer config.
std::vector<TensorInfo> ImageLayerTensors(const ModelConfig& config,
const LayerConfig& layer_config) {
return {
// Vit layers.
TensorInfo{
.name = "attn_out_w",
.source_names = {"MultiHeadDotProductAttention_0/out/kernel"},
.axes = {2, 0, 1},
.shape = {config.vit_model_dim, layer_config.heads,
layer_config.qkv_dim},
.min_size = Type::kBF16,
.cols_take_extra_dims = true,
},
TensorInfo{
.name = "attn_out_b",
.source_names = {"MultiHeadDotProductAttention_0/out/bias"},
.axes = {0},
.shape = {config.vit_model_dim},
.min_size = Type::kF32,
},
TensorInfo{
.name = "q_ein_w",
.source_names = {"MultiHeadDotProductAttention_0/query/kernel"},
.axes = {1, 2, 0},
.shape = {layer_config.heads, layer_config.qkv_dim,
config.vit_model_dim},
.concat_names = {"qkv_ein_w", "k_ein_w", "v_ein_w"},
.concat_axis = 1,
.min_size = Type::kBF16,
},
TensorInfo{
.name = "k_ein_w",
.source_names = {"MultiHeadDotProductAttention_0/key/kernel"},
.axes = {1, 2, 0},
.shape = {layer_config.heads, layer_config.qkv_dim,
config.vit_model_dim},
.concat_names = {""},
.min_size = Type::kBF16,
},
TensorInfo{
.name = "v_ein_w",
.source_names = {"MultiHeadDotProductAttention_0/value/kernel"},
.axes = {1, 2, 0},
.shape = {layer_config.heads, layer_config.qkv_dim,
config.vit_model_dim},
.concat_names = {""},
.min_size = Type::kBF16,
},
TensorInfo{
.name = "qkv_ein_w",
.source_names = {"MultiHeadDotProductAttention_0/qkv/kernel"},
.axes = {2, 0, 3, 1},
.shape = {layer_config.heads, 3, layer_config.qkv_dim,
config.vit_model_dim},
.min_size = Type::kBF16,
},
TensorInfo{
.name = "q_ein_b",
.source_names = {"MultiHeadDotProductAttention_0/query/bias"},
.axes = {0, 1},
.shape = {layer_config.heads, layer_config.qkv_dim},
.concat_names = {"qkv_ein_b", "k_ein_b", "v_ein_b"},
.concat_axis = 1,
.min_size = Type::kF32,
},
TensorInfo{
.name = "k_ein_b",
.source_names = {"MultiHeadDotProductAttention_0/key/bias"},
.axes = {0, 1},
.shape = {layer_config.heads, layer_config.qkv_dim},
.concat_names = {""},
.min_size = Type::kF32,
},
TensorInfo{
.name = "v_ein_b",
.source_names = {"MultiHeadDotProductAttention_0/value/bias"},
.axes = {0, 1},
.shape = {layer_config.heads, layer_config.qkv_dim},
.concat_names = {""},
.min_size = Type::kF32,
},
TensorInfo{
.name = "qkv_ein_b",
.source_names = {"MultiHeadDotProductAttention_0/qkv/bias"},
.axes = {1, 0, 2},
.shape = {layer_config.heads * 3, layer_config.qkv_dim},
.min_size = Type::kF32,
},
TensorInfo{
.name = "linear_0_w",
.source_names = {"MlpBlock_0/Dense_0/kernel"},
.axes = {1, 0},
.shape = {layer_config.ff_hidden_dim, config.vit_model_dim},
.min_size = Type::kBF16,
},
TensorInfo{
.name = "linear_0_b",
.source_names = {"MlpBlock_0/Dense_0/bias"},
.axes = {0},
.shape = {layer_config.ff_hidden_dim},
.min_size = Type::kF32,
},
TensorInfo{
.name = "linear_1_w",
.source_names = {"MlpBlock_0/Dense_1/kernel"},
.axes = {1, 0},
.shape = {config.vit_model_dim, layer_config.ff_hidden_dim},
.min_size = Type::kBF16,
},
TensorInfo{
.name = "linear_1_b",
.source_names = {"MlpBlock_0/Dense_1/bias"},
.axes = {0},
.shape = {config.vit_model_dim},
.min_size = Type::kF32,
},
TensorInfo{
.name = "ln_0_bias",
.source_names = {"img/Transformer/encoderblock/LayerNorm_0/bias"},
.axes = {0},
.shape = {config.vit_model_dim},
.min_size = Type::kBF16,
},
TensorInfo{
.name = "ln_0_scale",
.source_names = {"img/Transformer/encoderblock/LayerNorm_0/scale"},
.axes = {0},
.shape = {config.vit_model_dim},
.min_size = Type::kBF16,
},
TensorInfo{
.name = "ln_1_bias",
.source_names = {"img/Transformer/encoderblock/LayerNorm_1/bias"},
.axes = {0},
.shape = {config.vit_model_dim},
.min_size = Type::kBF16,
},
TensorInfo{
.name = "ln_1_scale",
.source_names = {"img/Transformer/encoderblock/LayerNorm_1/scale"},
.axes = {0},
.shape = {config.vit_model_dim},
.min_size = Type::kBF16,
},
};
}
// Returns the tensors for the given LLM layer config.
std::vector<TensorInfo> LLMLayerTensors(const ModelConfig& config,
const LayerConfig& layer_config,
bool reshape_att) {
std::vector<TensorInfo> tensors = {
TensorInfo{
.name = "qkv1_w",
.source_names = {"attn/q_einsum/w"},
.axes = {0, 2, 1},
.shape = {layer_config.heads, layer_config.qkv_dim, config.model_dim},
.concat_names = {"qkv_ein", "qkv2_w"},
},
TensorInfo{
.name = "qkv2_w",
.source_names = {"attn/kv_einsum/w"},
.axes = {1, 0, 3, 2},
.shape = {2 * layer_config.kv_heads, layer_config.qkv_dim,
config.model_dim},
.concat_names = {""},
},
TensorInfo{
.name = "q_ein",
.source_names = {"attention_block/proj_q/kernel"},
.axes = {1, 0},
.shape = {layer_config.model_dim, layer_config.model_dim},
.concat_names = {"qkv_ein", "k_ein", "v_ein"},
},
TensorInfo{
.name = "k_ein",
.source_names = {"attention_block/proj_k/kernel"},
.axes = {1, 0},
.shape = {layer_config.qkv_dim, layer_config.model_dim},
.concat_names = {""},
},
TensorInfo{
.name = "v_ein",
.source_names = {"attention_block/proj_v/kernel"},
.axes = {1, 0},
.shape = {layer_config.qkv_dim, layer_config.model_dim},
.concat_names = {""},
},
TensorInfo{
.name = "qkv_ein",
.source_names = {"attn/qkv_einsum/w"},
.axes = {1, 0, 3, 2},
.shape = {(layer_config.heads + 2 * layer_config.kv_heads),
layer_config.qkv_dim, config.model_dim},
},
TensorInfo{
.name = "attn_ob",
.source_names = {"attention_block/proj_final/bias"},
.axes = {0},
.shape = {config.model_dim},
.min_size = Type::kF32,
},
// Griffin layers.
TensorInfo{
.name = "gr_lin_x_w",
.source_names = {"recurrent_block/linear_x/kernel"},
.axes = {1, 0},
.shape = {layer_config.griffin_dim, layer_config.griffin_dim},
},
TensorInfo{
.name = "gr_lin_x_b",
.source_names = {"recurrent_block/linear_x/bias"},
.axes = {0},
.shape = {layer_config.griffin_dim},
.min_size = Type::kF32,
},
TensorInfo{
.name = "gr_lin_y_w",
.source_names = {"recurrent_block/linear_y/kernel"},
.axes = {1, 0},
.shape = {layer_config.griffin_dim, layer_config.griffin_dim},
},
TensorInfo{
.name = "gr_lin_y_b",
.source_names = {"recurrent_block/linear_y/bias"},
.axes = {0},
.shape = {layer_config.griffin_dim},
.min_size = Type::kF32,
},
TensorInfo{
.name = "gr_lin_out_w",
.source_names = {"recurrent_block/linear_out/kernel"},
.axes = {1, 0},
.shape = {layer_config.griffin_dim, layer_config.griffin_dim},
},
TensorInfo{
.name = "gr_lin_out_b",
.source_names = {"recurrent_block/linear_out/bias"},
.axes = {0},
.shape = {layer_config.griffin_dim},
.min_size = Type::kF32,
},
TensorInfo{
.name = "gr_conv_w",
.source_names = {"recurrent_block/conv_1d/w"},
.axes = {0, 1},
.shape = {layer_config.conv1d_width, layer_config.griffin_dim},
.min_size = Type::kF32,
},
TensorInfo{
.name = "gr_conv_b",
.source_names = {"recurrent_block/conv_1d/b"},
.axes = {0},
.shape = {layer_config.griffin_dim},
.min_size = Type::kF32,
},
TensorInfo{
.name = "gr1_gate_w",
.source_names = {"recurrent_block/rg_lru/input_gate/w"},
.axes = {0, 2, 1},
.shape = {layer_config.heads,
layer_config.griffin_dim / layer_config.heads,
layer_config.griffin_dim / layer_config.heads},
.concat_names = {"gr_gate_w", "gr2_gate_w"},
},
TensorInfo{
.name = "gr2_gate_w",
.source_names = {"recurrent_block/rg_lru/a_gate/w"},
.axes = {0, 2, 1},
.shape = {layer_config.heads,
layer_config.griffin_dim / layer_config.heads,
layer_config.griffin_dim / layer_config.heads},
.concat_names = {""},
},
TensorInfo{
.name = "gr_gate_w",
.source_names = {"recurrent_block/rg_lru/gate/w"},
.axes = {0, 2, 1},
.shape = {2 * layer_config.heads,
layer_config.griffin_dim / layer_config.heads,
layer_config.griffin_dim / layer_config.heads},
},
TensorInfo{
.name = "gr1_gate_b",
.source_names = {"recurrent_block/rg_lru/input_gate/b"},
.axes = {0},
.shape = {layer_config.griffin_dim},
.concat_names = {"gr_gate_b", "gr2_gate_b"},
.min_size = Type::kF32,
},
TensorInfo{
.name = "gr2_gate_b",
.source_names = {"recurrent_block/rg_lru/a_gate/b"},
.axes = {0},
.shape = {layer_config.griffin_dim},
.concat_names = {""},
.min_size = Type::kF32,
},
TensorInfo{
.name = "gr_gate_b",
.source_names = {"recurrent_block/rg_lru/input_gate/b"},
.axes = {0, 1},
.shape = {2 * layer_config.griffin_dim},
.min_size = Type::kF32,
},
TensorInfo{
.name = "gr_a",
.source_names = {"recurrent_block/rg_lru/a_param"},
.axes = {0},
.shape = {layer_config.griffin_dim},
.min_size = Type::kF32,
.scaled_softplus = true,
},
TensorInfo{
.name = "gating_ein",
.source_names = {"mlp/gating_einsum/w", "mlp/gating_einsum",
"mlp_block/ffw_up/w"},
.axes = {0, layer_config.optimized_gating ? 1u : 2u,
layer_config.optimized_gating ? 2u : 1u},
.shape = {2, layer_config.ff_hidden_dim, config.model_dim},
},
TensorInfo{
.name = "gating1_w",
.source_names = {"none"},
.axes = {0, layer_config.optimized_gating ? 1u : 2u,
layer_config.optimized_gating ? 2u : 1u},
.shape = {layer_config.ff_hidden_dim, config.model_dim},
},
TensorInfo{
.name = "gating2_w",
.source_names = {"none"},
.axes = {0, layer_config.optimized_gating ? 1u : 2u,
layer_config.optimized_gating ? 2u : 1u},
.shape = {layer_config.ff_hidden_dim, config.model_dim},
},
TensorInfo{
.name = "linear_w",
.source_names = {"mlp/linear/w", "mlp/linear",
"mlp_block/ffw_down/kernel"},
.axes = {1, 0},
.shape = {config.model_dim, layer_config.ff_hidden_dim},
},
TensorInfo{
.name = "pre_att_ns",
.source_names = {"pre_attention_norm/scale",
"temporal_pre_norm/scale"},
.axes = {0},
.shape = {config.model_dim},
.min_size = Type::kBF16,
},
TensorInfo{
.name = "pre_ff_ns",
.source_names = {"pre_ffw_norm/scale", "channel_pre_norm/scale"},
.axes = {0},
.shape = {config.model_dim},
.min_size = Type::kBF16,
},
TensorInfo{
.name = "post_att_ns",
.source_names = {"post_attention_norm/scale"},
.axes = {0},
.shape = {config.model_dim},
.min_size = Type::kBF16,
},
TensorInfo{
.name = "post_ff_ns",
.source_names = {"post_ffw_norm/scale"},
.axes = {0},
.shape = {config.model_dim},
.min_size = Type::kBF16,
},
TensorInfo{
.name = "ffw_gat_b",
.source_names = {"mlp_block/ffw_up/b"},
.axes = {0},
.shape = {2 * layer_config.ff_hidden_dim},
.min_size = Type::kF32,
},
TensorInfo{
.name = "ffw_out_b",
.source_names = {"mlp_block/ffw_down/bias"},
.axes = {0},
.shape = {config.model_dim},
.min_size = Type::kF32,
},
};
if (reshape_att) {
tensors.push_back(TensorInfo{
.name = "att_w",
.source_names = {"attn/attn_vec_einsum/w",
"attention_block/proj_final/kernel"},
.axes = {2, 0, 1},
.shape = {config.model_dim, layer_config.heads, layer_config.qkv_dim},
.cols_take_extra_dims = true,
});
tensors.push_back(TensorInfo{
.name = "att_ein",
.shape = {layer_config.heads, config.model_dim, layer_config.qkv_dim},
});
} else {
tensors.push_back(TensorInfo{
.name = "att_ein",
.source_names = {"attn/attn_vec_einsum/w",
"attention_block/proj_final/kernel"},
.preshape = {layer_config.heads, layer_config.qkv_dim,
config.model_dim},
.axes = {0, 2, 1},
.shape = {layer_config.heads, config.model_dim, layer_config.qkv_dim},
});
tensors.push_back(TensorInfo{
.name = "att_w",
.shape = {config.model_dim, layer_config.heads, layer_config.qkv_dim},
.cols_take_extra_dims = true,
});
}
return tensors;
}
} // namespace
TensorIndex::TensorIndex(const ModelConfig& config, int llm_layer_idx,
int img_layer_idx, bool reshape_att)
: config_(config),
llm_layer_idx_(llm_layer_idx),
img_layer_idx_(img_layer_idx) {
int layer_idx = std::max(llm_layer_idx_, img_layer_idx_);
std::string suffix;
if (layer_idx >= 0) {
suffix = "_" + std::to_string(layer_idx);
}
if (llm_layer_idx < 0 && img_layer_idx < 0) {
tensors_ = ModelTensors(config);
} else if (llm_layer_idx_ < 0 && 0 <= img_layer_idx &&
img_layer_idx < config.vit_layer_configs.size()) {
const auto& layer_config = config.vit_layer_configs[img_layer_idx];
tensors_ = ImageLayerTensors(config, layer_config);
} else if (0 <= llm_layer_idx &&
llm_layer_idx < config.layer_configs.size()) {
const auto& layer_config = config.layer_configs[llm_layer_idx];
tensors_ = LLMLayerTensors(config, layer_config, reshape_att);
}
for (size_t i = 0; i < tensors_.size(); ++i) {
std::string key = tensors_[i].name + suffix;
name_map_.insert({key, i});
}
}
TensorInfo TensorIndex::GetTensorInfo(const std::string& path) const {
for (const auto& tensor : tensors_) {
for (const auto& source_name : tensor.source_names) {
auto pos = path.rfind(source_name);
if (pos != std::string::npos && path.size() == pos + source_name.size())
return tensor;
}
}
return TensorInfo();
}
const TensorInfo* TensorIndex::FindName(const std::string& name) const {
std::string name_to_find = name;
if (!std::isdigit(static_cast<unsigned char>(name[name.size() - 1]))) {
if (img_layer_idx_ >= 0 && llm_layer_idx_ < 0) {
name_to_find = name + "_" + std::to_string(img_layer_idx_);
} else if (llm_layer_idx_ >= 0) {
name_to_find = name + "_" + std::to_string(llm_layer_idx_);
}
}
auto it = name_map_.find(name_to_find);
if (it == name_map_.end()) {
return nullptr;
}
return &tensors_[it->second];
}
} // namespace gcpp
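
A usage sketch of the two lookups above (checkpoint path hypothetical; assumes config has at least three LLM layers):

#include "gemma/tensor_index.h"

void LookupExample(const gcpp::ModelConfig& config) {
  gcpp::TensorIndex index(config, /*llm_layer_idx=*/2, /*img_layer_idx=*/-1,
                          /*reshape_att=*/false);
  // FindName() appends "_2" because the query has no trailing digit.
  const gcpp::TensorInfo* info = index.FindName("gating_ein");
  // GetTensorInfo() matches a source_names entry against the end of the
  // path and returns a copy (an empty TensorInfo if nothing matches).
  gcpp::TensorInfo by_path =
      index.GetTensorInfo("transformer/layer_2/mlp/gating_einsum/w");
}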

gemma/tensor_index.h (new file)

@@ -0,0 +1,91 @@
#ifndef THIRD_PARTY_GEMMA_CPP_GEMMA_TENSOR_INDEX_H_
#define THIRD_PARTY_GEMMA_CPP_GEMMA_TENSOR_INDEX_H_
#include <stddef.h>
#include <string>
#include <unordered_map>
#include <vector>
#include "compression/shared.h"
#include "gemma/configs.h"
namespace gcpp {
// Universal tensor information. Holds enough information to construct a
// tensor in LayerWeightsPtrs/ModelWeightsPtrs, as well as to export the
// tensor from the python model with necessary transpose/reshape info.
struct TensorInfo {
// The name of the tensor in the sbs file
std::string name;
// Strings to match to the end of the name of the tensor in the python model.
std::vector<std::string> source_names;
// Initial reshape shape. Use only as a last resort when input may have
// dimensions combined that need to be split before the transpose, as it
// defeats the post-transpose shape checking. Normally empty.
std::vector<size_t> preshape;
// Transpose axes arg. If the input tensor has more dimensions than axes,
// then leading dimensions are collapsed until the number of axes matches.
std::vector<size_t> axes;
// Expected final shape of the tensor after reshape/transpose.
// Note that this is the shape of the tensor during export,
// not the shape of the tensor in the sbs file, as the sbs file
// is restricted to 2D tensors. With few exceptions, the sbs file
// tensor rows gather all the excess dimensions. See cols_take_extra_dims.
std::vector<size_t> shape;
// List of names to concatenate with, used only if multiple tensors are
// concatenated into one. On the first tensor in the concatenation, the
// first entry names the combined result and the remaining entries name the
// tensors that are concatenated after it. Each of those remaining tensors
// has a single empty string in concat_names to mark that it has been
// consumed by the concatenation.
std::vector<std::string> concat_names;
// Axis at which to concatenate.
size_t concat_axis = 0;
// The minimum compression weight type for this tensor. The default is
// kNUQ, which provides maximum compression. Other values such as kBF16
// or kF32 can be used to limit the compression to a specific type.
Type min_size = Type::kNUQ;
// Whether to apply scaled softplus to the data.
bool scaled_softplus = false;
// Whether the columns or the rows take any extra dimensions.
// If false, then [10, 20, 30] -> [10*20, 30] and [30] -> [1, 30].
// If true, then [10, 20, 30] -> [10, 20*30] and [30] -> [1, 30].
bool cols_take_extra_dims = false;
};
// Universal index of tensor information, which can be built for a specific
// layer_idx.
class TensorIndex {
public:
// Builds a list of TensorInfo for the given layer_idx.
// If reshape_att is true, the attn_vec_einsum tensor is reshaped.
TensorIndex(const ModelConfig& config, int llm_layer_idx, int img_layer_idx,
bool reshape_att);
~TensorIndex() = default;
// Returns the TensorInfo whose source_name matches the end of the given
// path, or an empty TensorInfo if not found.
// NOTE: the returned TensorInfo is a copy, so the source TensorIndex can
// be destroyed without affecting it.
TensorInfo GetTensorInfo(const std::string& path) const;
// Returns the TensorInfo for the given tensor name, for concise construction
// of ModelWeightsPtrs/LayerWeightsPtrs.
const TensorInfo* FindName(const std::string& name) const;
private:
// Config that was used to build the tensor index.
const ModelConfig& config_;
// Layer that this tensor index is for - either LLM or image.
int llm_layer_idx_;
int img_layer_idx_;
// List of tensor information for this layer.
std::vector<TensorInfo> tensors_;
// Map from tensor name to index in tensors_.
std::unordered_map<std::string, size_t> name_map_;
};
} // namespace gcpp
#endif // THIRD_PARTY_GEMMA_CPP_GEMMA_TENSOR_INDEX_H_
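
For reference, a hypothetical entry in the style used by tensor_index.cc, showing how the fields combine (name, source path, and dims invented):

const gcpp::TensorInfo kExample{
    .name = "my_proj_w",                       // key in the .sbs file
    .source_names = {"my_block/proj/kernel"},  // matched against path suffix
    .axes = {1, 0},                            // transpose applied on export
    .shape = {512, 2048},                      // expected post-transpose shape
    .min_size = gcpp::Type::kBF16,             // do not compress below bf16
};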

gemma/tensor_index_test.cc (new file)

@@ -0,0 +1,73 @@
#include "gemma/tensor_index.h"
#include <cstddef>
#include <cstdio>
#include <cstring>
#include <string>
#include <vector>
#include "gtest/gtest.h"
#include "compression/compress.h"
#include "compression/shared.h"
#include "gemma/configs.h"
#include "gemma/weights.h"
#include "util/basics.h"
#include "hwy/aligned_allocator.h"
#include "hwy/contrib/thread_pool/thread_pool.h"
namespace gcpp {
namespace {
// Tests that each tensor in the model can be found by exactly one TensorIndex,
// and that the TensorIndex returns the correct shape and name for the tensor,
// for all models.
TEST(TensorIndexTest, FindName) {
hwy::ThreadPool pool(4);
for (Model model : kAllModels) {
fprintf(stderr, "Testing model %d\n", static_cast<int>(model));
ModelConfig config = ConfigFromModel(model);
std::vector<TensorIndex> tensor_indexes;
tensor_indexes.emplace_back(config, /*llm_layer_idx=*/-1,
/*img_layer_idx=*/-1,
/*reshape_att=*/false);
for (size_t llm_layer_idx = 0; llm_layer_idx < config.layer_configs.size();
++llm_layer_idx) {
tensor_indexes.emplace_back(config, static_cast<int>(llm_layer_idx),
/*img_layer_idx=*/-1,
/*reshape_att=*/false);
}
for (size_t img_layer_idx = 0;
img_layer_idx < config.vit_layer_configs.size();
++img_layer_idx) {
tensor_indexes.emplace_back(config, /*llm_layer_idx=*/-1,
static_cast<int>(img_layer_idx),
/*reshape_att=*/false);
}
// For each tensor in any model, exactly one TensorIndex should find it.
ModelWeightsPtrs<SfpStream> weights(config, pool);
ModelWeightsPtrs<SfpStream>::ForEachTensor(
{&weights}, ForEachType::kInitNoToc,
[&tensor_indexes](const char* name, hwy::Span<MatPtr*> tensors) {
int num_found = 0;
const MatPtr& tensor = *tensors[0];
for (const auto& tensor_index : tensor_indexes) {
// Skip the type marker prefix, but we want the layer index suffix.
std::string name_to_find(name + 1, strlen(name) - 1);
const TensorInfo* info = tensor_index.FindName(name_to_find);
if (info != nullptr) {
// Test that the MatPtr can be constructed from the TensorInfo,
// and that the dimensions match.
MatPtrT<SfpStream> mat_ptr(tensor.Name(), tensor_index);
EXPECT_EQ(tensor.Name(), mat_ptr.Name()) << "on tensor " << name;
EXPECT_EQ(tensor.Rows(), mat_ptr.Rows()) << "on tensor " << name;
EXPECT_EQ(tensor.Cols(), mat_ptr.Cols()) << "on tensor " << name;
++num_found;
}
}
EXPECT_EQ(num_found, 1) << " for tensor " << name;
});
}
}
} // namespace
} // namespace gcpp

gemma/weights.h

@@ -57,8 +57,8 @@ template <class Weight>
struct LayerWeightsPtrs {
// Large data is constructed separately.
explicit LayerWeightsPtrs(const LayerConfig& config)
-    : attn_vec_einsum_w("att_ein", config.model_dim,
-                        config.heads * config.qkv_dim),
+    : attn_vec_einsum_w("att_ein", config.heads * config.model_dim,
+                        config.qkv_dim),
qkv_einsum_w("qkv_ein",
(config.heads + 2 * config.kv_heads) * config.qkv_dim,
config.model_dim),
@@ -86,8 +86,8 @@ struct LayerWeightsPtrs {
.gate_biases = {"gr_gate_b", 1, config.griffin_dim * 2},
.a = {"gr_a", 1, config.griffin_dim}}),
// MultiHeadDotProductAttention.
-    vit({.attn_out_w = {"attn_out_w", config.heads * config.qkv_dim,
-                        config.model_dim},
+    vit({.attn_out_w = {"attn_out_w", config.model_dim,
+                        config.heads * config.qkv_dim},
.attn_out_b = {"attn_out_b", 1, config.model_dim},
.qkv_einsum_w = {"qkv_ein_w",
(config.heads + 2 * config.kv_heads) *
@@ -349,9 +349,8 @@ struct ModelWeightsPtrs {
vit_encoder_norm_bias("enc_norm_bias", 1, config.vit_model_dim),
vit_encoder_norm_scale("enc_norm_scale", 1, config.vit_model_dim),
vit_img_embedding_bias("img_emb_bias", 1, config.vit_model_dim),
-    vit_img_embedding_kernel("img_emb_kernel",
-                             config.patch_width * config.patch_width * 3,
-                             config.vit_model_dim),
+    vit_img_embedding_kernel("img_emb_kernel", config.vit_model_dim,
+                             config.patch_width * config.patch_width * 3),
vit_img_pos_embedding("img_pos_emb", 256, config.vit_model_dim),
vit_img_head_bias("img_head_bias", 1, config.model_dim),
vit_img_head_kernel("img_head_kernel", config.model_dim,
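
The reordered constructor arguments in this file match what the new MatPtrT(name, tensor_index) constructor derives from the entries in tensor_index.cc:

// att_ein: shape {heads, model_dim, qkv_dim}; rows fold the leading dims:
//   rows = heads * model_dim, cols = qkv_dim.
// attn_out_w (ViT): shape {vit_model_dim, heads, qkv_dim} with
//   cols_take_extra_dims = true: rows = vit_model_dim, cols = heads * qkv_dim.
// img_emb_kernel: shape {vit_model_dim, patch_width, patch_width, 3} with
//   cols_take_extra_dims = true: rows = vit_model_dim,
//   cols = patch_width * patch_width * 3.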