mirror of https://github.com/google/gemma.cpp.git
Removed duplicated tensor sizes from weights.h by changing the constructor used for MatPtrT
PiperOrigin-RevId: 705085054
This commit is contained in:
parent
aed17396be
commit
6254f2e5ca
|
|
@ -269,7 +269,6 @@ cc_test(
|
|||
"@googletest//:gtest_main",
|
||||
"//compression:compress",
|
||||
"@highway//:hwy",
|
||||
"@highway//:thread_pool",
|
||||
],
|
||||
)
|
||||
|
||||
|
|
|
|||
|
|
@ -411,12 +411,14 @@ TEST(BackPropTest, LayerVJP) {
|
|||
using T = double;
|
||||
using TC = std::complex<T>;
|
||||
ModelConfig config = TestConfig();
|
||||
TensorIndex tensor_index(config, /*llm_layer_idx=*/0, /*img_layer_idx=*/-1,
|
||||
/*reshape_att=*/false);
|
||||
const size_t kOutputSize = config.seq_len * config.model_dim;
|
||||
LayerWeightsPtrs<T> weights(config.layer_configs[0]);
|
||||
LayerWeightsPtrs<T> grad(config.layer_configs[0]);
|
||||
LayerWeightsPtrs<T> weights(config.layer_configs[0], tensor_index);
|
||||
LayerWeightsPtrs<T> grad(config.layer_configs[0], tensor_index);
|
||||
ForwardLayer<T> forward(config.layer_configs[0], config.seq_len);
|
||||
ForwardLayer<T> backward(config.layer_configs[0], config.seq_len);
|
||||
LayerWeightsPtrs<TC> c_weights(config.layer_configs[0]);
|
||||
LayerWeightsPtrs<TC> c_weights(config.layer_configs[0], tensor_index);
|
||||
ForwardLayer<TC> c_forward(config.layer_configs[0], config.seq_len);
|
||||
MatStorageT<T> y("y", kOutputSize, 1);
|
||||
MatStorageT<T> dy("dy", kOutputSize, 1);
|
||||
|
|
|
|||
|
|
@ -93,7 +93,7 @@ template <typename T>
|
|||
class WeightsWrapper {
|
||||
public:
|
||||
explicit WeightsWrapper(const ModelConfig& config)
|
||||
: pool_(0), weights_(config, pool_) {
|
||||
: pool_(0), weights_(config) {
|
||||
weights_.Allocate(data_, pool_);
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -219,7 +219,10 @@ class MatPtrT : public MatPtr {
|
|||
: MatPtrT<MatT>(name, tensor_index.FindName(name)) {}
|
||||
MatPtrT(const std::string& name, const TensorInfo* tensor)
|
||||
: MatPtr(name, TypeEnum<MatT>(), sizeof(MatT), 0, 0) {
|
||||
HWY_ASSERT(tensor != nullptr);
|
||||
if (tensor == nullptr) {
|
||||
cols_ = 0;
|
||||
rows_ = 0;
|
||||
} else {
|
||||
cols_ = tensor->shape.back();
|
||||
rows_ = 1;
|
||||
if (tensor->cols_take_extra_dims) {
|
||||
|
|
@ -235,6 +238,9 @@ class MatPtrT : public MatPtr {
|
|||
}
|
||||
}
|
||||
}
|
||||
stride_ = cols_;
|
||||
num_elements_ = rows_ * cols_;
|
||||
}
|
||||
|
||||
// Copying allowed as the metadata is small.
|
||||
MatPtrT(const MatPtr& other) : MatPtr(other) {}
|
||||
|
|
|
|||
|
|
@ -165,9 +165,9 @@ void CompressWeights(const Path& weights_path,
|
|||
compressed_weights_path.path.c_str());
|
||||
ModelConfig config = ConfigFromModel(model_type);
|
||||
std::vector<MatStorage> model_storage;
|
||||
ModelWeightsPtrs<T> c_weights(config, pool);
|
||||
ModelWeightsPtrs<T> c_weights(config);
|
||||
c_weights.Allocate(model_storage, pool);
|
||||
ModelWeightsPtrs<float> uc_weights(config, pool);
|
||||
ModelWeightsPtrs<float> uc_weights(config);
|
||||
uc_weights.Allocate(model_storage, pool);
|
||||
// Get uncompressed weights, compress, and store.
|
||||
FILE* fptr = fopen(weights_path.path.c_str(), "rb");
|
||||
|
|
|
|||
|
|
@ -13,7 +13,6 @@
|
|||
#include "gemma/weights.h"
|
||||
#include "util/basics.h"
|
||||
#include "hwy/aligned_allocator.h"
|
||||
#include "hwy/contrib/thread_pool/thread_pool.h"
|
||||
|
||||
namespace gcpp {
|
||||
namespace {
|
||||
|
|
@ -22,7 +21,6 @@ namespace {
|
|||
// and that the TensorIndex returns the correct shape and name for the tensor,
|
||||
// for all models.
|
||||
TEST(TensorIndexTest, FindName) {
|
||||
hwy::ThreadPool pool(4);
|
||||
for (Model model : kAllModels) {
|
||||
fprintf(stderr, "Testing model %d\n", static_cast<int>(model));
|
||||
ModelConfig config = ConfigFromModel(model);
|
||||
|
|
@ -44,7 +42,7 @@ TEST(TensorIndexTest, FindName) {
|
|||
/*split_and_reshape=*/false);
|
||||
}
|
||||
// For each tensor in any model, exactly one TensorIndex should find it.
|
||||
ModelWeightsPtrs<SfpStream> weights(config, pool);
|
||||
ModelWeightsPtrs<SfpStream> weights(config);
|
||||
ModelWeightsPtrs<SfpStream>::ForEachTensor(
|
||||
{&weights}, ForEachType::kInitNoToc,
|
||||
[&tensor_indexes](const char* name, hwy::Span<MatPtr*> tensors) {
|
||||
|
|
|
|||
|
|
@ -186,18 +186,18 @@ void ModelWeightsStorage::CreateForType(Type weight_type,
|
|||
hwy::ThreadPool& pool) {
|
||||
switch (weight_type) {
|
||||
case Type::kF32:
|
||||
float_weights_ = std::make_unique<ModelWeightsPtrs<float>>(config_, pool);
|
||||
float_weights_ = std::make_unique<ModelWeightsPtrs<float>>(config_);
|
||||
break;
|
||||
case Type::kBF16:
|
||||
bf16_weights_ = std::make_unique<ModelWeightsPtrs<BF16>>(config_, pool);
|
||||
bf16_weights_ = std::make_unique<ModelWeightsPtrs<BF16>>(config_);
|
||||
break;
|
||||
case Type::kSFP:
|
||||
sfp_weights_ =
|
||||
std::make_unique<ModelWeightsPtrs<SfpStream>>(config_, pool);
|
||||
std::make_unique<ModelWeightsPtrs<SfpStream>>(config_);
|
||||
break;
|
||||
case Type::kNUQ:
|
||||
nuq_weights_ =
|
||||
std::make_unique<ModelWeightsPtrs<NuqStream>>(config_, pool);
|
||||
std::make_unique<ModelWeightsPtrs<NuqStream>>(config_);
|
||||
break;
|
||||
default:
|
||||
HWY_ABORT("Weight type %d unsupported.", static_cast<int>(weight_type));
|
||||
|
|
|
|||
154
gemma/weights.h
154
gemma/weights.h
|
|
@ -30,6 +30,7 @@
|
|||
#include "compression/shared.h"
|
||||
#include "gemma/common.h"
|
||||
#include "gemma/configs.h"
|
||||
#include "gemma/tensor_index.h"
|
||||
#include "hwy/aligned_allocator.h"
|
||||
#include "hwy/base.h"
|
||||
#include "hwy/contrib/thread_pool/thread_pool.h"
|
||||
|
|
@ -56,73 +57,48 @@ enum class ForEachType {
|
|||
template <class Weight>
|
||||
struct LayerWeightsPtrs {
|
||||
// Large data is constructed separately.
|
||||
explicit LayerWeightsPtrs(const LayerConfig& config)
|
||||
: attn_vec_einsum_w("att_ein", config.heads * config.model_dim,
|
||||
config.qkv_dim),
|
||||
qkv_einsum_w("qkv_ein",
|
||||
(config.heads + 2 * config.kv_heads) * config.qkv_dim,
|
||||
config.model_dim),
|
||||
qkv_einsum_w1("qkv1_w", config.heads * config.qkv_dim,
|
||||
config.model_dim),
|
||||
qkv_einsum_w2("qkv2_w", 2 * config.kv_heads * config.qkv_dim,
|
||||
config.model_dim),
|
||||
attention_output_biases(
|
||||
"attn_ob", 1,
|
||||
config.softmax_attn_output_biases ? config.model_dim : 0),
|
||||
griffin(
|
||||
{.linear_x_w = {"gr_lin_x_w", config.griffin_dim,
|
||||
config.griffin_dim},
|
||||
.linear_x_biases = {"gr_lin_x_b", 1, config.griffin_dim},
|
||||
.linear_y_w = {"gr_lin_y_w", config.griffin_dim,
|
||||
config.griffin_dim},
|
||||
.linear_y_biases = {"gr_lin_y_b", 1, config.griffin_dim},
|
||||
.linear_out_w = {"gr_lin_out_w", config.griffin_dim,
|
||||
config.griffin_dim},
|
||||
.linear_out_biases = {"gr_lin_out_b", 1, config.griffin_dim},
|
||||
.conv_w = {"gr_conv_w", config.conv1d_width, config.griffin_dim},
|
||||
.conv_biases = {"gr_conv_b", 1, config.griffin_dim},
|
||||
.gate_w = {"gr_gate_w", 2 * config.griffin_dim,
|
||||
config.griffin_dim / config.heads},
|
||||
.gate_biases = {"gr_gate_b", 1, config.griffin_dim * 2},
|
||||
.a = {"gr_a", 1, config.griffin_dim}}),
|
||||
explicit LayerWeightsPtrs(const LayerConfig& config,
|
||||
const TensorIndex& tensor_index)
|
||||
: attn_vec_einsum_w("att_ein", tensor_index),
|
||||
qkv_einsum_w("qkv_ein", tensor_index),
|
||||
qkv_einsum_w1("qkv1_w", tensor_index),
|
||||
qkv_einsum_w2("qkv2_w", tensor_index),
|
||||
attention_output_biases("attn_ob", tensor_index),
|
||||
griffin({.linear_x_w = {"gr_lin_x_w", tensor_index},
|
||||
.linear_x_biases = {"gr_lin_x_b", tensor_index},
|
||||
.linear_y_w = {"gr_lin_y_w", tensor_index},
|
||||
.linear_y_biases = {"gr_lin_y_b", tensor_index},
|
||||
.linear_out_w = {"gr_lin_out_w", tensor_index},
|
||||
.linear_out_biases = {"gr_lin_out_b", tensor_index},
|
||||
.conv_w = {"gr_conv_w", tensor_index},
|
||||
.conv_biases = {"gr_conv_b", tensor_index},
|
||||
.gate_w = {"gr_gate_w", tensor_index},
|
||||
.gate_biases = {"gr_gate_b", tensor_index},
|
||||
.a = {"gr_a", tensor_index}}),
|
||||
// MultiHeadDotProductAttention.
|
||||
vit({.attn_out_w = {"attn_out_w", config.model_dim,
|
||||
config.heads * config.qkv_dim},
|
||||
.attn_out_b = {"attn_out_b", 1, config.model_dim},
|
||||
.qkv_einsum_w = {"qkv_ein_w",
|
||||
(config.heads + 2 * config.kv_heads) *
|
||||
config.qkv_dim,
|
||||
config.model_dim},
|
||||
.qkv_einsum_b = {"qkv_ein_b", (config.heads + 2 * config.kv_heads),
|
||||
config.qkv_dim},
|
||||
.linear_0_w = {"linear_0_w", config.ff_hidden_dim,
|
||||
config.model_dim},
|
||||
.linear_0_b = {"linear_0_b", 1, config.ff_hidden_dim},
|
||||
.linear_1_w = {"linear_1_w", config.model_dim,
|
||||
config.ff_hidden_dim},
|
||||
.linear_1_b = {"linear_1_b", 1, config.model_dim},
|
||||
.layer_norm_0_bias = {"ln_0_bias", 1, config.model_dim},
|
||||
.layer_norm_0_scale = {"ln_0_scale", 1, config.model_dim},
|
||||
.layer_norm_1_bias = {"ln_1_bias", 1, config.model_dim},
|
||||
.layer_norm_1_scale = {"ln_1_scale", 1, config.model_dim}}),
|
||||
gating_einsum_w("gating_ein", 2 * config.ff_hidden_dim,
|
||||
config.model_dim),
|
||||
gating_einsum_w1("gating1_w", config.ff_hidden_dim, config.model_dim),
|
||||
gating_einsum_w2("gating2_w", config.ff_hidden_dim, config.model_dim),
|
||||
linear_w("linear_w", config.model_dim, config.ff_hidden_dim),
|
||||
pre_attention_norm_scale("pre_att_ns", 1, config.model_dim),
|
||||
pre_ffw_norm_scale("pre_ff_ns", 1, config.model_dim),
|
||||
post_attention_norm_scale(
|
||||
"post_att_ns", 1,
|
||||
config.post_norm == PostNormType::Scale ? config.model_dim : 0),
|
||||
post_ffw_norm_scale(
|
||||
"post_ff_ns", 1,
|
||||
config.post_norm == PostNormType::Scale ? config.model_dim : 0),
|
||||
ffw_gating_biases("ffw_gat_b", 1,
|
||||
config.ff_biases ? 2 * config.ff_hidden_dim : 0),
|
||||
ffw_output_biases("ffw_out_b", 1,
|
||||
config.ff_biases ? config.model_dim : 0),
|
||||
att_weights("att_w", config.model_dim, config.heads * config.qkv_dim),
|
||||
vit({.attn_out_w = {"attn_out_w", tensor_index},
|
||||
.attn_out_b = {"attn_out_b", tensor_index},
|
||||
.qkv_einsum_w = {"qkv_ein_w", tensor_index},
|
||||
.qkv_einsum_b = {"qkv_ein_b", tensor_index},
|
||||
.linear_0_w = {"linear_0_w", tensor_index},
|
||||
.linear_0_b = {"linear_0_b", tensor_index},
|
||||
.linear_1_w = {"linear_1_w", tensor_index},
|
||||
.linear_1_b = {"linear_1_b", tensor_index},
|
||||
.layer_norm_0_bias = {"ln_0_bias", tensor_index},
|
||||
.layer_norm_0_scale = {"ln_0_scale", tensor_index},
|
||||
.layer_norm_1_bias = {"ln_1_bias", tensor_index},
|
||||
.layer_norm_1_scale = {"ln_1_scale", tensor_index}}),
|
||||
gating_einsum_w("gating_ein", tensor_index),
|
||||
gating_einsum_w1("gating1_w", tensor_index),
|
||||
gating_einsum_w2("gating2_w", tensor_index),
|
||||
linear_w("linear_w", tensor_index),
|
||||
pre_attention_norm_scale("pre_att_ns", tensor_index),
|
||||
pre_ffw_norm_scale("pre_ff_ns", tensor_index),
|
||||
post_attention_norm_scale("post_att_ns", tensor_index),
|
||||
post_ffw_norm_scale("post_ff_ns", tensor_index),
|
||||
ffw_gating_biases("ffw_gat_b", tensor_index),
|
||||
ffw_output_biases("ffw_out_b", tensor_index),
|
||||
att_weights("att_w", tensor_index),
|
||||
layer_config(config) {}
|
||||
~LayerWeightsPtrs() = default;
|
||||
|
||||
|
|
@ -342,28 +318,38 @@ struct LayerWeightsPtrs {
|
|||
|
||||
template <class Weight>
|
||||
struct ModelWeightsPtrs {
|
||||
ModelWeightsPtrs(const ModelConfig& config, hwy::ThreadPool& pool)
|
||||
: embedder_input_embedding("c_embedding", config.vocab_size,
|
||||
config.model_dim),
|
||||
final_norm_scale("c_final_norm", 1, config.model_dim),
|
||||
vit_encoder_norm_bias("enc_norm_bias", 1, config.vit_model_dim),
|
||||
vit_encoder_norm_scale("enc_norm_scale", 1, config.vit_model_dim),
|
||||
vit_img_embedding_bias("img_emb_bias", 1, config.vit_model_dim),
|
||||
vit_img_embedding_kernel("img_emb_kernel", config.vit_model_dim,
|
||||
config.patch_width * config.patch_width * 3),
|
||||
vit_img_pos_embedding("img_pos_emb", config.vit_seq_len,
|
||||
config.vit_model_dim),
|
||||
vit_img_head_bias("img_head_bias", 1, config.model_dim),
|
||||
vit_img_head_kernel("img_head_kernel", config.model_dim,
|
||||
config.vit_model_dim),
|
||||
explicit ModelWeightsPtrs(const ModelConfig& config)
|
||||
: ModelWeightsPtrs(
|
||||
config,
|
||||
TensorIndex(config, /*llm_layer_idx=*/-1, /*vit_layer_idx=*/-1,
|
||||
/*reshape_att=*/false)) {}
|
||||
ModelWeightsPtrs(const ModelConfig& config, const TensorIndex& tensor_index)
|
||||
: embedder_input_embedding("c_embedding", tensor_index),
|
||||
final_norm_scale("c_final_norm", tensor_index),
|
||||
vit_encoder_norm_bias("enc_norm_bias", tensor_index),
|
||||
vit_encoder_norm_scale("enc_norm_scale", tensor_index),
|
||||
vit_img_embedding_bias("img_emb_bias", tensor_index),
|
||||
vit_img_embedding_kernel("img_emb_kernel", tensor_index),
|
||||
vit_img_pos_embedding("img_pos_emb", tensor_index),
|
||||
vit_img_head_bias("img_head_bias", tensor_index),
|
||||
vit_img_head_kernel("img_head_kernel", tensor_index),
|
||||
scale_names(config.scale_names),
|
||||
weights_config(config) {
|
||||
c_layers.reserve(config.layer_configs.size());
|
||||
for (const auto& layer_config : config.layer_configs) {
|
||||
c_layers.push_back(LayerWeightsPtrs<Weight>(layer_config));
|
||||
for (int index = 0; index < static_cast<int>(config.layer_configs.size());
|
||||
++index) {
|
||||
const auto& layer_config = config.layer_configs[index];
|
||||
TensorIndex tensor_index(config, index, /*vit_layer_idx=*/-1,
|
||||
/*reshape_att=*/false);
|
||||
c_layers.push_back(LayerWeightsPtrs<Weight>(layer_config, tensor_index));
|
||||
}
|
||||
for (const auto& layer_config : config.vit_layer_configs) {
|
||||
vit_layers.push_back(LayerWeightsPtrs<Weight>(layer_config));
|
||||
for (int index = 0;
|
||||
index < static_cast<int>(config.vit_layer_configs.size()); ++index) {
|
||||
const auto& layer_config = config.vit_layer_configs[index];
|
||||
TensorIndex tensor_index(config, /*llm_layer_idx=*/-1, index,
|
||||
/*reshape_att=*/false);
|
||||
vit_layers.push_back(
|
||||
LayerWeightsPtrs<Weight>(layer_config, tensor_index));
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
|||
Loading…
Reference in New Issue