diff --git a/BUILD.bazel b/BUILD.bazel
index e5a7939..669caf8 100644
--- a/BUILD.bazel
+++ b/BUILD.bazel
@@ -194,13 +194,15 @@ cc_library(
     srcs = [
         "gemma/common.cc",
         "gemma/configs.cc",
+        "gemma/tensor_index.cc",
    ],
    hdrs = [
        "gemma/common.h",
        "gemma/configs.h",
+        "gemma/tensor_index.h",
    ],
    deps = [
-        "//compression:compress",
+        "//compression:sfp",
        "@highway//:hwy",  # base.h
        "@highway//:thread_pool",
    ],
@@ -215,6 +217,20 @@ cc_test(
    ],
 )
 
+cc_test(
+    name = "tensor_index_test",
+    srcs = ["gemma/tensor_index_test.cc"],
+    deps = [
+        ":basics",
+        ":common",
+        ":weights",
+        "@googletest//:gtest_main",
+        "//compression:compress",
+        "@highway//:hwy",
+        "@highway//:thread_pool",
+    ],
+)
+
 cc_library(
     name = "weights",
     srcs = ["gemma/weights.cc"],
diff --git a/CMakeLists.txt b/CMakeLists.txt
index 876da6c..4d03da0 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -84,6 +84,8 @@ set(SOURCES
   gemma/instantiations/sfp.cc
   gemma/kv_cache.cc
   gemma/kv_cache.h
+  gemma/tensor_index.cc
+  gemma/tensor_index.h
   gemma/tokenizer.cc
   gemma/tokenizer.h
   gemma/weights.cc
@@ -157,6 +159,7 @@ set(GEMMA_TEST_FILES
   compression/nuq_test.cc
   compression/sfp_test.cc
   evals/gemma_test.cc
+  gemma/tensor_index_test.cc
   ops/dot_test.cc
   ops/gemma_matvec_test.cc
   ops/matmul_test.cc
diff --git a/compression/BUILD.bazel b/compression/BUILD.bazel
index 585a74b..c7cec0a 100644
--- a/compression/BUILD.bazel
+++ b/compression/BUILD.bazel
@@ -201,6 +201,7 @@ cc_library(
        ":sfp",
        "//:allocator",
        "//:basics",
+        "//:common",
        "@highway//:hwy",
        "@highway//:nanobenchmark",
        "@highway//:profiler",
diff --git a/compression/compress.h b/compression/compress.h
index 9050d53..ff64b49 100644
--- a/compression/compress.h
+++ b/compression/compress.h
@@ -17,6 +17,7 @@
 #ifndef THIRD_PARTY_GEMMA_CPP_COMPRESSION_COMPRESS_H_
 #define THIRD_PARTY_GEMMA_CPP_COMPRESSION_COMPRESS_H_
 
+#include "hwy/base.h"
 #define COMPRESS_STATS 0
 
 #include <stddef.h>
@@ -33,6 +34,7 @@
 #include "compression/blob_store.h"
 #include "compression/io.h"
 #include "compression/shared.h"
+#include "gemma/tensor_index.h"
 #include "util/basics.h"
 // IWYU pragma: end_exports
 #include "util/allocator.h"
@@ -211,6 +213,26 @@ class MatPtrT : public MatPtr {
   // Full constructor for dynamic sizing.
   MatPtrT(const std::string& name, size_t rows, size_t cols)
       : MatPtr(name, TypeEnum(), sizeof(MatT), rows, cols) {}
+  // Construction from a TensorIndex entry to remove duplication of sizes.
+  MatPtrT(const std::string& name, const TensorIndex& tensor_index)
+      : MatPtr(name, TypeEnum(), sizeof(MatT), 0, 0) {
+    const TensorInfo* tensor = tensor_index.FindName(name);
+    HWY_ASSERT(tensor != nullptr);
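+    // Fold the N-D export shape down to the 2D (rows, cols) used by the
+    // sbs file. Illustrative example: shape {10, 20, 30} yields
+    // rows = 10 * 20, cols = 30 by default, or rows = 10, cols = 20 * 30
+    // when cols_take_extra_dims is set.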
+    cols_ = tensor->shape.back();
+    rows_ = 1;
+    if (tensor->cols_take_extra_dims) {
+      // The columns eat the extra dimensions.
+      rows_ = tensor->shape[0];
+      for (size_t i = 1; i < tensor->shape.size() - 1; ++i) {
+        cols_ *= tensor->shape[i];
+      }
+    } else {
+      // The rows eat the extra dimensions.
+      for (size_t i = 0; i < tensor->shape.size() - 1; ++i) {
+        rows_ *= tensor->shape[i];
+      }
+    }
+  }
 
   // Copying allowed as the metadata is small.
   MatPtrT(const MatPtr& other) : MatPtr(other) {}
diff --git a/gemma/python/configs.cc b/gemma/python/configs.cc
index aff93cc..8c37840 100644
--- a/gemma/python/configs.cc
+++ b/gemma/python/configs.cc
@@ -4,6 +4,7 @@
 #include <string>
 
 #include "compression/shared.h"
+#include "gemma/tensor_index.h"
 #include "pybind11/cast.h"
 
 using gcpp::ActivationType;
@@ -16,6 +17,8 @@ using gcpp::PostNormType;
 using gcpp::PostQKType;
 using gcpp::QueryScaleType;
 using gcpp::ResidualType;
+using gcpp::TensorIndex;
+using gcpp::TensorInfo;
 using gcpp::Type;
 
 namespace pybind11 {
@@ -72,6 +75,23 @@ PYBIND11_MODULE(configs, py_module) {
       .value("GEMMA2_2B", Model::GEMMA2_2B)
       .value("PALIGEMMA_224", Model::PALIGEMMA_224);
 
+  class_<TensorInfo>(py_module, "TensorInfo")
+      .def(init<>())
+      .def_readwrite("name", &TensorInfo::name)
+      .def_readwrite("source_names", &TensorInfo::source_names)
+      .def_readwrite("preshape", &TensorInfo::preshape)
+      .def_readwrite("axes", &TensorInfo::axes)
+      .def_readwrite("shape", &TensorInfo::shape)
+      .def_readwrite("concat_names", &TensorInfo::concat_names)
+      .def_readwrite("concat_axis", &TensorInfo::concat_axis)
+      .def_readwrite("min_size", &TensorInfo::min_size)
+      .def_readwrite("scaled_softplus", &TensorInfo::scaled_softplus)
+      .def_readwrite("cols_take_extra_dims", &TensorInfo::cols_take_extra_dims);
+
+  class_<TensorIndex>(py_module, "TensorIndex")
+      .def(init<const ModelConfig&, int, int, bool>())
+      .def("get_tensor_info", &TensorIndex::GetTensorInfo, arg("path"));
+
   class_<LayerConfig>(py_module, "LayerConfig")
       .def(init<>())
       .def_readwrite("model_dim", &LayerConfig::model_dim)
diff --git a/gemma/tensor_index.cc b/gemma/tensor_index.cc
new file mode 100644
index 0000000..9f44a9c
--- /dev/null
+++ b/gemma/tensor_index.cc
@@ -0,0 +1,565 @@
+#include "gemma/tensor_index.h"
+
+#include <stddef.h>
+
+#include <algorithm>
+#include <cctype>
+#include <string>
+#include <unordered_map>
+#include <utility>
+#include <vector>
+
+#include "compression/shared.h"
+#include "gemma/configs.h"
+
+namespace gcpp {
+namespace {
+
+// Returns the non-layer tensors for the model.
+std::vector<TensorInfo> ModelTensors(const ModelConfig& config) {
+  return {
+      TensorInfo{
+          .name = "c_embedding",
+          .source_names = {"embedder/input_embedding"},
+          .axes = {0, 1},
+          .shape = {config.vocab_size, config.model_dim},
+          .min_size = Type::kBF16,
+      },
+      TensorInfo{
+          .name = "c_final_norm",
+          .source_names = {"final_norm/scale"},
+          .axes = {0},
+          .shape = {config.model_dim},
+          .min_size = Type::kBF16,
+      },
+      TensorInfo{
+          .name = "enc_norm_bias",
+          .source_names = {"img/Transformer/encoder_norm/bias"},
+          .axes = {0},
+          .shape = {config.vit_model_dim},
+          .min_size = Type::kBF16,
+      },
+      TensorInfo{
+          .name = "enc_norm_scale",
+          .source_names = {"img/Transformer/encoder_norm/scale"},
+          .axes = {0},
+          .shape = {config.vit_model_dim},
+          .min_size = Type::kBF16,
+      },
+      TensorInfo{
+          .name = "img_emb_bias",
+          .source_names = {"img/embedding/bias"},
+          .axes = {0},
+          .shape = {config.vit_model_dim},
+          .min_size = Type::kF32,
+      },
+      TensorInfo{
+          .name = "img_emb_kernel",
+          .source_names = {"img/embedding/kernel"},
+          .axes = {3, 0, 1, 2},
+          .shape = {config.vit_model_dim, config.patch_width,
+                    config.patch_width, 3},
+          .min_size = Type::kBF16,
+          .cols_take_extra_dims = true,
+      },
+      TensorInfo{
+          .name = "img_head_bias",
+          .source_names = {"img/head/bias"},
+          .axes = {0},
+          .shape = {config.model_dim},
+          .min_size = Type::kF32,
+      },
+      TensorInfo{
+          .name = "img_head_kernel",
+          .source_names = {"img/head/kernel"},
+          .axes = {1, 0},
+          .shape = {config.model_dim, config.vit_model_dim},
+          .min_size = Type::kBF16,
+      },
+      TensorInfo{
+          .name = "img_pos_emb",
+          .source_names = {"img/pos_embedding"},
+          .axes = {0, 1},
+          .shape = {/*1,*/ 256, config.vit_model_dim},
+          .min_size = Type::kF32,
+      },
+  };
+}
+
+// Returns the tensors for the given image layer config.
+std::vector<TensorInfo> ImageLayerTensors(const ModelConfig& config,
+                                          const LayerConfig& layer_config) {
+  return {
+      // Vit layers.
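+      // attn_out_w folds to 2D as [vit_model_dim, heads * qkv_dim]
+      // because cols_take_extra_dims is set.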
+      TensorInfo{
+          .name = "attn_out_w",
+          .source_names = {"MultiHeadDotProductAttention_0/out/kernel"},
+          .axes = {2, 0, 1},
+          .shape = {config.vit_model_dim, layer_config.heads,
+                    layer_config.qkv_dim},
+          .min_size = Type::kBF16,
+          .cols_take_extra_dims = true,
+      },
+      TensorInfo{
+          .name = "attn_out_b",
+          .source_names = {"MultiHeadDotProductAttention_0/out/bias"},
+          .axes = {0},
+          .shape = {config.vit_model_dim},
+          .min_size = Type::kF32,
+      },
+      TensorInfo{
+          .name = "q_ein_w",
+          .source_names = {"MultiHeadDotProductAttention_0/query/kernel"},
+          .axes = {1, 2, 0},
+          .shape = {layer_config.heads, layer_config.qkv_dim,
+                    config.vit_model_dim},
+          .concat_names = {"qkv_ein_w", "k_ein_w", "v_ein_w"},
+          .concat_axis = 1,
+          .min_size = Type::kBF16,
+      },
+      TensorInfo{
+          .name = "k_ein_w",
+          .source_names = {"MultiHeadDotProductAttention_0/key/kernel"},
+          .axes = {1, 2, 0},
+          .shape = {layer_config.heads, layer_config.qkv_dim,
+                    config.vit_model_dim},
+          .concat_names = {""},
+          .min_size = Type::kBF16,
+      },
+      TensorInfo{
+          .name = "v_ein_w",
+          .source_names = {"MultiHeadDotProductAttention_0/value/kernel"},
+          .axes = {1, 2, 0},
+          .shape = {layer_config.heads, layer_config.qkv_dim,
+                    config.vit_model_dim},
+          .concat_names = {""},
+          .min_size = Type::kBF16,
+      },
+      TensorInfo{
+          .name = "qkv_ein_w",
+          .source_names = {"MultiHeadDotProductAttention_0/qkv/kernel"},
+          .axes = {2, 0, 3, 1},
+          .shape = {layer_config.heads, 3, layer_config.qkv_dim,
+                    config.vit_model_dim},
+          .min_size = Type::kBF16,
+      },
+      TensorInfo{
+          .name = "q_ein_b",
+          .source_names = {"MultiHeadDotProductAttention_0/query/bias"},
+          .axes = {0, 1},
+          .shape = {layer_config.heads, layer_config.qkv_dim},
+          .concat_names = {"qkv_ein_b", "k_ein_b", "v_ein_b"},
+          .concat_axis = 1,
+          .min_size = Type::kF32,
+      },
+      TensorInfo{
+          .name = "k_ein_b",
+          .source_names = {"MultiHeadDotProductAttention_0/key/bias"},
+          .axes = {0, 1},
+          .shape = {layer_config.heads, layer_config.qkv_dim},
+          .concat_names = {""},
+          .min_size = Type::kF32,
+      },
+      TensorInfo{
+          .name = "v_ein_b",
+          .source_names = {"MultiHeadDotProductAttention_0/value/bias"},
+          .axes = {0, 1},
+          .shape = {layer_config.heads, layer_config.qkv_dim},
+          .concat_names = {""},
+          .min_size = Type::kF32,
+      },
+      TensorInfo{
+          .name = "qkv_ein_b",
+          .source_names = {"MultiHeadDotProductAttention_0/qkv/bias"},
+          .axes = {1, 0, 2},
+          .shape = {layer_config.heads * 3, layer_config.qkv_dim},
+          .min_size = Type::kF32,
+      },
+      TensorInfo{
+          .name = "linear_0_w",
+          .source_names = {"MlpBlock_0/Dense_0/kernel"},
+          .axes = {1, 0},
+          .shape = {layer_config.ff_hidden_dim, config.vit_model_dim},
+          .min_size = Type::kBF16,
+      },
+      TensorInfo{
+          .name = "linear_0_b",
+          .source_names = {"MlpBlock_0/Dense_0/bias"},
+          .axes = {0},
+          .shape = {layer_config.ff_hidden_dim},
+          .min_size = Type::kF32,
+      },
+      TensorInfo{
+          .name = "linear_1_w",
+          .source_names = {"MlpBlock_0/Dense_1/kernel"},
+          .axes = {1, 0},
+          .shape = {config.vit_model_dim, layer_config.ff_hidden_dim},
+          .min_size = Type::kBF16,
+      },
+      TensorInfo{
+          .name = "linear_1_b",
+          .source_names = {"MlpBlock_0/Dense_1/bias"},
+          .axes = {0},
+          .shape = {config.vit_model_dim},
+          .min_size = Type::kF32,
+      },
+      TensorInfo{
+          .name = "ln_0_bias",
+          .source_names = {"img/Transformer/encoderblock/LayerNorm_0/bias"},
+          .axes = {0},
+          .shape = {config.vit_model_dim},
+          .min_size = Type::kBF16,
+      },
+      TensorInfo{
+          .name = "ln_0_scale",
+          .source_names = {"img/Transformer/encoderblock/LayerNorm_0/scale"},
+          .axes = {0},
+          .shape = {config.vit_model_dim},
+          .min_size = Type::kBF16,
+      },
+      TensorInfo{
+          .name = "ln_1_bias",
+          .source_names = {"img/Transformer/encoderblock/LayerNorm_1/bias"},
+          .axes = {0},
+          .shape = {config.vit_model_dim},
+          .min_size = Type::kBF16,
+      },
+      TensorInfo{
+          .name = "ln_1_scale",
+          .source_names = {"img/Transformer/encoderblock/LayerNorm_1/scale"},
+          .axes = {0},
+          .shape = {config.vit_model_dim},
+          .min_size = Type::kBF16,
+      },
+  };
+}
+
+// Returns the tensors for the given LLM layer config.
+std::vector<TensorInfo> LLMLayerTensors(const ModelConfig& config,
+                                        const LayerConfig& layer_config,
+                                        bool reshape_att) {
+  std::vector<TensorInfo> tensors = {
+      TensorInfo{
+          .name = "qkv1_w",
+          .source_names = {"attn/q_einsum/w"},
+          .axes = {0, 2, 1},
+          .shape = {layer_config.heads, layer_config.qkv_dim,
+                    config.model_dim},
+          .concat_names = {"qkv_ein", "qkv2_w"},
+      },
+      TensorInfo{
+          .name = "qkv2_w",
+          .source_names = {"attn/kv_einsum/w"},
+          .axes = {1, 0, 3, 2},
+          .shape = {2 * layer_config.kv_heads, layer_config.qkv_dim,
+                    config.model_dim},
+          .concat_names = {""},
+      },
+      TensorInfo{
+          .name = "q_ein",
+          .source_names = {"attention_block/proj_q/kernel"},
+          .axes = {1, 0},
+          .shape = {layer_config.model_dim, layer_config.model_dim},
+          .concat_names = {"qkv_ein", "k_ein", "v_ein"},
+      },
+      TensorInfo{
+          .name = "k_ein",
+          .source_names = {"attention_block/proj_k/kernel"},
+          .axes = {1, 0},
+          .shape = {layer_config.qkv_dim, layer_config.model_dim},
+          .concat_names = {""},
+      },
+      TensorInfo{
+          .name = "v_ein",
+          .source_names = {"attention_block/proj_v/kernel"},
+          .axes = {1, 0},
+          .shape = {layer_config.qkv_dim, layer_config.model_dim},
+          .concat_names = {""},
+      },
+      TensorInfo{
+          .name = "qkv_ein",
+          .source_names = {"attn/qkv_einsum/w"},
+          .axes = {1, 0, 3, 2},
+          .shape = {(layer_config.heads + 2 * layer_config.kv_heads),
+                    layer_config.qkv_dim, config.model_dim},
+      },
+      TensorInfo{
+          .name = "attn_ob",
+          .source_names = {"attention_block/proj_final/bias"},
+          .axes = {0},
+          .shape = {config.model_dim},
+          .min_size = Type::kF32,
+      },
+      // Griffin layers.
+      TensorInfo{
+          .name = "gr_lin_x_w",
+          .source_names = {"recurrent_block/linear_x/kernel"},
+          .axes = {1, 0},
+          .shape = {layer_config.griffin_dim, layer_config.griffin_dim},
+      },
+      TensorInfo{
+          .name = "gr_lin_x_b",
+          .source_names = {"recurrent_block/linear_x/bias"},
+          .axes = {0},
+          .shape = {layer_config.griffin_dim},
+          .min_size = Type::kF32,
+      },
+      TensorInfo{
+          .name = "gr_lin_y_w",
+          .source_names = {"recurrent_block/linear_y/kernel"},
+          .axes = {1, 0},
+          .shape = {layer_config.griffin_dim, layer_config.griffin_dim},
+      },
+      TensorInfo{
+          .name = "gr_lin_y_b",
+          .source_names = {"recurrent_block/linear_y/bias"},
+          .axes = {0},
+          .shape = {layer_config.griffin_dim},
+          .min_size = Type::kF32,
+      },
+      TensorInfo{
+          .name = "gr_lin_out_w",
+          .source_names = {"recurrent_block/linear_out/kernel"},
+          .axes = {1, 0},
+          .shape = {layer_config.griffin_dim, layer_config.griffin_dim},
+      },
+      TensorInfo{
+          .name = "gr_lin_out_b",
+          .source_names = {"recurrent_block/linear_out/bias"},
+          .axes = {0},
+          .shape = {layer_config.griffin_dim},
+          .min_size = Type::kF32,
+      },
+      TensorInfo{
+          .name = "gr_conv_w",
+          .source_names = {"recurrent_block/conv_1d/w"},
+          .axes = {0, 1},
+          .shape = {layer_config.conv1d_width, layer_config.griffin_dim},
+          .min_size = Type::kF32,
+      },
+      TensorInfo{
+          .name = "gr_conv_b",
+          .source_names = {"recurrent_block/conv_1d/b"},
+          .axes = {0},
+          .shape = {layer_config.griffin_dim},
+          .min_size = Type::kF32,
+      },
+      TensorInfo{
+          .name = "gr1_gate_w",
+          .source_names = {"recurrent_block/rg_lru/input_gate/w"},
+          .axes = {0, 2, 1},
+          .shape = {layer_config.heads,
+                    layer_config.griffin_dim / layer_config.heads,
+                    layer_config.griffin_dim / layer_config.heads},
+          .concat_names = {"gr_gate_w", "gr2_gate_w"},
+      },
+      TensorInfo{
+          .name = "gr2_gate_w",
+          .source_names = {"recurrent_block/rg_lru/a_gate/w"},
+          .axes = {0, 2, 1},
+          .shape = {layer_config.heads,
+                    layer_config.griffin_dim / layer_config.heads,
+                    layer_config.griffin_dim / layer_config.heads},
+          .concat_names = {""},
+      },
+      TensorInfo{
+          .name = "gr_gate_w",
+          .source_names = {"recurrent_block/rg_lru/gate/w"},
+          .axes = {0, 2, 1},
+          .shape = {2 * layer_config.heads,
+                    layer_config.griffin_dim / layer_config.heads,
+                    layer_config.griffin_dim / layer_config.heads},
+      },
+      TensorInfo{
+          .name = "gr1_gate_b",
+          .source_names = {"recurrent_block/rg_lru/input_gate/b"},
+          .axes = {0},
+          .shape = {layer_config.griffin_dim},
+          .concat_names = {"gr_gate_b", "gr2_gate_b"},
+          .min_size = Type::kF32,
+      },
+      TensorInfo{
+          .name = "gr2_gate_b",
+          .source_names = {"recurrent_block/rg_lru/a_gate/b"},
+          .axes = {0},
+          .shape = {layer_config.griffin_dim},
+          .concat_names = {""},
+          .min_size = Type::kF32,
+      },
+      TensorInfo{
+          .name = "gr_gate_b",
+          .source_names = {"recurrent_block/rg_lru/input_gate/b"},
+          .axes = {0, 1},
+          .shape = {2 * layer_config.griffin_dim},
+          .min_size = Type::kF32,
+      },
+      TensorInfo{
+          .name = "gr_a",
+          .source_names = {"recurrent_block/rg_lru/a_param"},
+          .axes = {0},
+          .shape = {layer_config.griffin_dim},
+          .min_size = Type::kF32,
+          .scaled_softplus = true,
+      },
+
+      TensorInfo{
+          .name = "gating_ein",
+          .source_names = {"mlp/gating_einsum/w", "mlp/gating_einsum",
+                           "mlp_block/ffw_up/w"},
+          .axes = {0, layer_config.optimized_gating ? 1u : 2u,
+                   layer_config.optimized_gating ? 2u : 1u},
+          .shape = {2, layer_config.ff_hidden_dim, config.model_dim},
+      },
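+      // gating1_w/gating2_w have no python source ("none"); they are the
+      // two halves of gating_ein, hence the same axes and half the shape.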
+      TensorInfo{
+          .name = "gating1_w",
+          .source_names = {"none"},
+          .axes = {0, layer_config.optimized_gating ? 1u : 2u,
+                   layer_config.optimized_gating ? 2u : 1u},
+          .shape = {layer_config.ff_hidden_dim, config.model_dim},
+      },
+      TensorInfo{
+          .name = "gating2_w",
+          .source_names = {"none"},
+          .axes = {0, layer_config.optimized_gating ? 1u : 2u,
+                   layer_config.optimized_gating ? 2u : 1u},
+          .shape = {layer_config.ff_hidden_dim, config.model_dim},
+      },
+      TensorInfo{
+          .name = "linear_w",
+          .source_names = {"mlp/linear/w", "mlp/linear",
+                           "mlp_block/ffw_down/kernel"},
+          .axes = {1, 0},
+          .shape = {config.model_dim, layer_config.ff_hidden_dim},
+      },
+      TensorInfo{
+          .name = "pre_att_ns",
+          .source_names = {"pre_attention_norm/scale",
+                           "temporal_pre_norm/scale"},
+          .axes = {0},
+          .shape = {config.model_dim},
+          .min_size = Type::kBF16,
+      },
+      TensorInfo{
+          .name = "pre_ff_ns",
+          .source_names = {"pre_ffw_norm/scale", "channel_pre_norm/scale"},
+          .axes = {0},
+          .shape = {config.model_dim},
+          .min_size = Type::kBF16,
+      },
+      TensorInfo{
+          .name = "post_att_ns",
+          .source_names = {"post_attention_norm/scale"},
+          .axes = {0},
+          .shape = {config.model_dim},
+          .min_size = Type::kBF16,
+      },
+      TensorInfo{
+          .name = "post_ff_ns",
+          .source_names = {"post_ffw_norm/scale"},
+          .axes = {0},
+          .shape = {config.model_dim},
+          .min_size = Type::kBF16,
+      },
+      TensorInfo{
+          .name = "ffw_gat_b",
+          .source_names = {"mlp_block/ffw_up/b"},
+          .axes = {0},
+          .shape = {2 * layer_config.ff_hidden_dim},
+          .min_size = Type::kF32,
+      },
+      TensorInfo{
+          .name = "ffw_out_b",
+          .source_names = {"mlp_block/ffw_down/bias"},
+          .axes = {0},
+          .shape = {config.model_dim},
+          .min_size = Type::kF32,
+      },
+  };
+  if (reshape_att) {
+    tensors.push_back(TensorInfo{
+        .name = "att_w",
+        .source_names = {"attn/attn_vec_einsum/w",
+                         "attention_block/proj_final/kernel"},
+        .axes = {2, 0, 1},
+        .shape = {config.model_dim, layer_config.heads,
+                  layer_config.qkv_dim},
+        .cols_take_extra_dims = true,
+    });
+    tensors.push_back(TensorInfo{
+        .name = "att_ein",
+        .shape = {layer_config.heads, config.model_dim,
+                  layer_config.qkv_dim},
+    });
+  } else {
+    tensors.push_back(TensorInfo{
+        .name = "att_ein",
+        .source_names = {"attn/attn_vec_einsum/w",
+                         "attention_block/proj_final/kernel"},
+        .preshape = {layer_config.heads, layer_config.qkv_dim,
+                     config.model_dim},
+        .axes = {0, 2, 1},
+        .shape = {layer_config.heads, config.model_dim,
+                  layer_config.qkv_dim},
+    });
+    tensors.push_back(TensorInfo{
+        .name = "att_w",
+        .shape = {config.model_dim, layer_config.heads,
+                  layer_config.qkv_dim},
+        .cols_take_extra_dims = true,
+    });
+  }
+  return tensors;
+}
+
+}  // namespace
+
+TensorIndex::TensorIndex(const ModelConfig& config, int llm_layer_idx,
+                         int img_layer_idx, bool reshape_att)
+    : config_(config),
+      llm_layer_idx_(llm_layer_idx),
+      img_layer_idx_(img_layer_idx) {
+  int layer_idx = std::max(llm_layer_idx_, img_layer_idx_);
+  std::string suffix;
+  if (layer_idx >= 0) {
+    suffix = "_" + std::to_string(layer_idx);
+  }
+  if (llm_layer_idx < 0 && img_layer_idx < 0) {
+    tensors_ = ModelTensors(config);
+  } else if (llm_layer_idx_ < 0 && 0 <= img_layer_idx &&
+             img_layer_idx < config.vit_layer_configs.size()) {
+    const auto& layer_config = config.vit_layer_configs[img_layer_idx];
+    tensors_ = ImageLayerTensors(config, layer_config);
+  } else if (0 <= llm_layer_idx &&
+             llm_layer_idx < config.layer_configs.size()) {
+    const auto& layer_config = config.layer_configs[llm_layer_idx];
+    tensors_ = LLMLayerTensors(config, layer_config, reshape_att);
+  }
+  for (size_t i = 0; i < tensors_.size(); ++i) {
+    std::string key = tensors_[i].name + suffix;
+    name_map_.insert({key, i});
+  }
+}
+
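+// Suffix match: a source name must end the path, so full python paths
+// such as ".../MlpBlock_0/Dense_0/kernel" resolve to the right entry.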
+TensorInfo TensorIndex::GetTensorInfo(const std::string& path) const {
+  for (const auto& tensor : tensors_) {
+    for (const auto& source_name : tensor.source_names) {
+      auto pos = path.rfind(source_name);
+      if (pos != std::string::npos && path.size() == pos + source_name.size())
+        return tensor;
+    }
+  }
+  return TensorInfo();
+}
+
+const TensorInfo* TensorIndex::FindName(const std::string& name) const {
+  std::string name_to_find = name;
+  // Unsuffixed names get this index's layer suffix appended before lookup.
+  if (!std::isdigit(name[name.size() - 1])) {
+    if (img_layer_idx_ >= 0 && llm_layer_idx_ < 0) {
+      name_to_find = name + "_" + std::to_string(img_layer_idx_);
+    } else if (llm_layer_idx_ >= 0) {
+      name_to_find = name + "_" + std::to_string(llm_layer_idx_);
+    }
+  }
+  auto it = name_map_.find(name_to_find);
+  if (it == name_map_.end()) {
+    return nullptr;
+  }
+  return &tensors_[it->second];
+}
+
+}  // namespace gcpp
diff --git a/gemma/tensor_index.h b/gemma/tensor_index.h
new file mode 100644
index 0000000..a1acfd6
--- /dev/null
+++ b/gemma/tensor_index.h
@@ -0,0 +1,91 @@
+#ifndef THIRD_PARTY_GEMMA_CPP_GEMMA_TENSOR_INDEX_H_
+#define THIRD_PARTY_GEMMA_CPP_GEMMA_TENSOR_INDEX_H_
+
+#include <stddef.h>
+
+#include <string>
+#include <unordered_map>
+#include <vector>
+
+#include "compression/shared.h"
+#include "gemma/configs.h"
+
+namespace gcpp {
+
+// Universal tensor information. Holds enough information to construct a
+// tensor in LayerWeightsPtrs/ModelWeightsPtrs, as well as to export the
+// tensor from the python model with the necessary transpose/reshape info.
+struct TensorInfo {
+  // The name of the tensor in the sbs file.
+  std::string name;
+  // Strings to match against the end of the tensor's name in the python
+  // model.
+  std::vector<std::string> source_names;
+  // Initial reshape shape. Use only as a last resort when the input may
+  // have dimensions combined that need to be split before the transpose,
+  // as it defeats the post-transpose shape checking. Normally empty.
+  std::vector<size_t> preshape;
+  // Transpose axes argument. If the input tensor has more dimensions than
+  // axes, leading dimensions are collapsed until the number of axes
+  // matches.
+  std::vector<size_t> axes;
+  // Expected final shape of the tensor after reshape/transpose.
+  // Note that this is the shape of the tensor during export, not the shape
+  // of the tensor in the sbs file, as the sbs file is restricted to 2D
+  // tensors. With few exceptions, the sbs tensor's rows gather all the
+  // excess dimensions. See cols_take_extra_dims.
+  std::vector<size_t> shape;
+  // Names of the tensors to concatenate; used only if multiple tensors are
+  // concatenated into one. On the first tensor in the concatenation, the
+  // first entry is the name of the result, and the tensors with the
+  // remaining names are concatenated after it. Each of the remaining
+  // tensors holds just a single empty string in concat_names to indicate
+  // that it has been consumed.
+  std::vector<std::string> concat_names;
+  // Axis at which to concatenate.
+  size_t concat_axis = 0;
+  // The minimum compression weight type for this tensor. The default,
+  // kNUQ, provides maximum compression. Other values such as kBF16 or kF32
+  // can be used to limit the compression to a specific type.
+  Type min_size = Type::kNUQ;
+  // Whether to apply scaled softplus to the data.
+  bool scaled_softplus = false;
+  // Whether the columns or the rows take any extra dimensions.
+  // If false, then [10, 20, 30] -> [10*20, 30] and [30] -> [1, 30].
+  // If true, then [10, 20, 30] -> [10, 20*30] and [30] -> [1, 30].
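+  // (This is the folding rule implemented by MatPtrT's TensorIndex
+  // constructor in compression/compress.h.)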
+  bool cols_take_extra_dims = false;
+};
+
+// Universal index of tensor information, which can be built for a specific
+// layer_idx.
+class TensorIndex {
+ public:
+  // Builds a list of TensorInfo for the given layer_idx.
+  // If reshape_att is true, the attn_vec_einsum tensor is reshaped.
+  TensorIndex(const ModelConfig& config, int llm_layer_idx,
+              int img_layer_idx, bool reshape_att);
+  ~TensorIndex() = default;
+
+  // Returns the TensorInfo whose source_name matches the end of the given
+  // path, or an empty TensorInfo if not found.
+  // NOTE: the returned TensorInfo is a copy, so the source TensorIndex can
+  // be destroyed without affecting the returned TensorInfo.
+  TensorInfo GetTensorInfo(const std::string& path) const;
+
+  // Returns the TensorInfo for the given tensor name, for concise
+  // construction of ModelWeightsPtrs/LayerWeightsPtrs.
+  const TensorInfo* FindName(const std::string& name) const;
+
+ private:
+  // Config that was used to build the tensor index.
+  const ModelConfig& config_;
+  // Layer that this tensor index is for - either LLM or image.
+  int llm_layer_idx_;
+  int img_layer_idx_;
+  // List of tensor information for this layer.
+  std::vector<TensorInfo> tensors_;
+  // Map from tensor name to index in tensors_.
+  std::unordered_map<std::string, size_t> name_map_;
+};
+
+}  // namespace gcpp
+
+#endif  // THIRD_PARTY_GEMMA_CPP_GEMMA_TENSOR_INDEX_H_
diff --git a/gemma/tensor_index_test.cc b/gemma/tensor_index_test.cc
new file mode 100644
index 0000000..7fd1268
--- /dev/null
+++ b/gemma/tensor_index_test.cc
@@ -0,0 +1,73 @@
+#include "gemma/tensor_index.h"
+
+#include <stddef.h>
+#include <stdio.h>
+#include <string.h>
+
+#include <string>
+#include <vector>
+
+#include "gtest/gtest.h"
+#include "compression/compress.h"
+#include "compression/shared.h"
+#include "gemma/configs.h"
+#include "gemma/weights.h"
+#include "util/basics.h"
+#include "hwy/aligned_allocator.h"
+#include "hwy/contrib/thread_pool/thread_pool.h"
+
+namespace gcpp {
+namespace {
+
+// Tests that each tensor in the model can be found by exactly one
+// TensorIndex, and that the TensorIndex returns the correct shape and name
+// for the tensor, for all models.
+TEST(TensorIndexTest, FindName) {
+  hwy::ThreadPool pool(4);
+  for (Model model : kAllModels) {
+    fprintf(stderr, "Testing model %d\n", static_cast<int>(model));
+    ModelConfig config = ConfigFromModel(model);
+    std::vector<TensorIndex> tensor_indexes;
+    tensor_indexes.emplace_back(config, /*llm_layer_idx=*/-1,
+                                /*img_layer_idx=*/-1,
+                                /*reshape_att=*/false);
+    for (size_t llm_layer_idx = 0;
+         llm_layer_idx < config.layer_configs.size(); ++llm_layer_idx) {
+      tensor_indexes.emplace_back(config, static_cast<int>(llm_layer_idx),
+                                  /*img_layer_idx=*/-1,
+                                  /*reshape_att=*/false);
+    }
+    for (size_t img_layer_idx = 0;
+         img_layer_idx < config.vit_layer_configs.size(); ++img_layer_idx) {
+      tensor_indexes.emplace_back(config, /*llm_layer_idx=*/-1,
+                                  static_cast<int>(img_layer_idx),
+                                  /*reshape_att=*/false);
+    }
+    // For each tensor in any model, exactly one TensorIndex should find it.
+    ModelWeightsPtrs<float> weights(config, pool);
+    ModelWeightsPtrs<float>::ForEachTensor(
+        {&weights}, ForEachType::kInitNoToc,
+        [&tensor_indexes](const char* name, hwy::Span<MatPtr*> tensors) {
+          int num_found = 0;
+          const MatPtr& tensor = *tensors[0];
+          for (const auto& tensor_index : tensor_indexes) {
+            // Skip the type marker prefix, but keep the layer index suffix.
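+            // (name[0] is a one-char type marker prepended by
+            // ForEachTensor; FindName expects the bare name plus any
+            // "_<layer>" suffix.)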
+            std::string name_to_find(name + 1, strlen(name) - 1);
+            const TensorInfo* info = tensor_index.FindName(name_to_find);
+            if (info != nullptr) {
+              // Test that the MatPtr can be constructed from the
+              // TensorInfo, and that the dimensions match.
+              MatPtrT<float> mat_ptr(tensor.Name(), tensor_index);
+              EXPECT_EQ(tensor.Name(), mat_ptr.Name()) << "on tensor " << name;
+              EXPECT_EQ(tensor.Rows(), mat_ptr.Rows()) << "on tensor " << name;
+              EXPECT_EQ(tensor.Cols(), mat_ptr.Cols()) << "on tensor " << name;
+              ++num_found;
+            }
+          }
+          EXPECT_EQ(num_found, 1) << " for tensor " << name;
+        });
+  }
+}
+
+}  // namespace
+}  // namespace gcpp
diff --git a/gemma/weights.h b/gemma/weights.h
index ce2df43..b9acf89 100644
--- a/gemma/weights.h
+++ b/gemma/weights.h
@@ -57,8 +57,8 @@ template <class Weight>
 struct LayerWeightsPtrs {
   // Large data is constructed separately.
   explicit LayerWeightsPtrs(const LayerConfig& config)
-      : attn_vec_einsum_w("att_ein", config.model_dim,
-                          config.heads * config.qkv_dim),
+      : attn_vec_einsum_w("att_ein", config.heads * config.model_dim,
+                          config.qkv_dim),
        qkv_einsum_w("qkv_ein",
                     (config.heads + 2 * config.kv_heads) * config.qkv_dim,
                     config.model_dim),
@@ -86,8 +86,8 @@ struct LayerWeightsPtrs {
            .gate_biases = {"gr_gate_b", 1, config.griffin_dim * 2},
            .a = {"gr_a", 1, config.griffin_dim}}),
        // MultiHeadDotProductAttention.
-        vit({.attn_out_w = {"attn_out_w", config.heads * config.qkv_dim,
-                            config.model_dim},
+        vit({.attn_out_w = {"attn_out_w", config.model_dim,
+                            config.heads * config.qkv_dim},
            .attn_out_b = {"attn_out_b", 1, config.model_dim},
            .qkv_einsum_w = {"qkv_ein_w",
                             (config.heads + 2 * config.kv_heads) *
@@ -349,9 +349,8 @@ struct ModelWeightsPtrs {
        vit_encoder_norm_bias("enc_norm_bias", 1, config.vit_model_dim),
        vit_encoder_norm_scale("enc_norm_scale", 1, config.vit_model_dim),
        vit_img_embedding_bias("img_emb_bias", 1, config.vit_model_dim),
-        vit_img_embedding_kernel("img_emb_kernel",
-                                 config.patch_width * config.patch_width * 3,
-                                 config.vit_model_dim),
+        vit_img_embedding_kernel("img_emb_kernel", config.vit_model_dim,
+                                 config.patch_width * config.patch_width * 3),
        vit_img_pos_embedding("img_pos_emb", 256, config.vit_model_dim),
        vit_img_head_bias("img_head_bias", 1, config.model_dim),
        vit_img_head_kernel("img_head_kernel", config.model_dim,