diff --git a/compression/stats.cc b/compression/stats.cc
index bfc7cbf..2013422 100644
--- a/compression/stats.cc
+++ b/compression/stats.cc
@@ -23,6 +23,8 @@
 
 #include "hwy/base.h"  // HWY_ASSERT
 
+namespace gcpp {
+
 void Stats::Assimilate(const Stats& other) {
   const int64_t total_n = n_ + other.n_;
   if (total_n == 0) return;  // Nothing to do; prevents div by zero.
@@ -115,3 +117,5 @@ std::string Stats::ToString(int exclude) const {
   HWY_ASSERT(pos < sizeof(buf));
   return buf;
 }
+
+}  // namespace gcpp
diff --git a/compression/stats.h b/compression/stats.h
index 1f0d262..01f97f7 100644
--- a/compression/stats.h
+++ b/compression/stats.h
@@ -25,6 +25,8 @@
 
 #include "hwy/base.h"  // HWY_ASSERT
 
+namespace gcpp {
+
 // Thread-compatible.
 template <size_t N>
 class Bins {
@@ -187,4 +189,6 @@ class Stats {
   double m4_;
 };
 
+}  // namespace gcpp
+
 #endif  // THIRD_PARTY_GEMMA_CPP_COMPRESSION_STATS_H_
diff --git a/configs.h b/configs.h
index 278d5ea..ebe6220 100644
--- a/configs.h
+++ b/configs.h
@@ -18,38 +18,34 @@
 #ifndef THIRD_PARTY_GEMMA_CPP_CONFIGS_H_
 #define THIRD_PARTY_GEMMA_CPP_CONFIGS_H_
 
-#include
+#include
 
 namespace gcpp {
 
 static constexpr size_t kSeqLen = 7168;
 
 struct ConfigGemma7B {
-  // NOLINTBEGIN(google3-readability-class-member-naming)
-  static constexpr int seq_len = kSeqLen;
-  static constexpr int vocab_size = 256128;
-  static constexpr int n_layers = 28;
-  static constexpr int dim_model = 3072;
-  static constexpr int dim_ffw_hidden = 16 * 3072 / 2;  // = 24576
-  static constexpr int n_heads = 16;
-  static constexpr int n_kv_heads = 16;  // standard MHA, no GQA or MQA
-  static constexpr int dim_qkv = 256;  // query size == key size == value size
-  static constexpr int top_k = 1;
-  // NOLINTEND(google3-readability-class-member-naming)
+  static constexpr int kSeqLen = gcpp::kSeqLen;
+  static constexpr int kVocabSize = 256128;
+  static constexpr int kLayers = 28;
+  static constexpr int kModelDim = 3072;
+  static constexpr int kFFHiddenDim = 16 * 3072 / 2;  // = 24576
+  static constexpr int kHeads = 16;
+  static constexpr int kKVHeads = 16;  // standard MHA, no GQA or MQA
+  static constexpr int kQKVDim = 256;  // query size == key size == value size
+  static constexpr int kTopK = 1;
 };
 
 struct ConfigGemma2B {
-  // NOLINTBEGIN(google3-readability-class-member-naming)
-  static constexpr int seq_len = kSeqLen;
-  static constexpr int vocab_size = 256128;
-  static constexpr int n_layers = 18;
-  static constexpr int dim_model = 2048;
-  static constexpr int dim_ffw_hidden = 16 * 2048 / 2;  // = 16384
-  static constexpr int n_heads = 8;
-  static constexpr int n_kv_heads = 8;  // TODO(austinvhuang): add MQA support
-  static constexpr int dim_qkv = 256;  // query size == key size == value size
-  static constexpr int top_k = 1;
-  // NOLINTEND(google3-readability-class-member-naming)
+  static constexpr int kSeqLen = gcpp::kSeqLen;
+  static constexpr int kVocabSize = 256128;
+  static constexpr int kLayers = 18;
+  static constexpr int kModelDim = 2048;
+  static constexpr int kFFHiddenDim = 16 * 2048 / 2;  // = 16384
+  static constexpr int kHeads = 8;
+  static constexpr int kKVHeads = 8;  // TODO(austinvhuang): add MQA support
+  static constexpr int kQKVDim = 256;  // query size == key size == value size
+  static constexpr int kTopK = 1;
 };
 
 }  // namespace gcpp
diff --git a/gemma.cc b/gemma.cc
index 5cea7d5..d8e1ca6 100644
--- a/gemma.cc
+++ b/gemma.cc
@@ -69,37 +69,34 @@ namespace gcpp {
 template <class TConfig>
 struct Layer {
   Layer() = default;
-  // NOLINTBEGIN(google3-readability-class-member-naming)
-  static constexpr size_t n_heads = TConfig::n_heads;
-  static constexpr size_t dim_model = TConfig::dim_model;
-  static constexpr size_t dim_qkv = TConfig::dim_qkv;
-  static constexpr size_t dim_ffw_hidden = TConfig::dim_ffw_hidden;
-  static constexpr size_t size_attn_vec_einsum_w =
-      n_heads * dim_qkv * dim_model;
+  static constexpr size_t kHeads = TConfig::kHeads;
+  static constexpr size_t kModelDim = TConfig::kModelDim;
+  static constexpr size_t kQKVDim = TConfig::kQKVDim;
+  static constexpr size_t kFFHiddenDim = TConfig::kFFHiddenDim;
+  static constexpr size_t kAttVecEinsumWSize = kHeads * kQKVDim * kModelDim;
   // 3x for (query, key, value)
-  static constexpr size_t size_qkv_einsum_w = 3 * n_heads * dim_qkv * dim_model;
+  static constexpr size_t kQKVEinsumWSize = 3 * kHeads * kQKVDim * kModelDim;
   // 2x for (gelu gating vector, gated vector)
-  static constexpr size_t size_gating_einsum_w = 2 * dim_ffw_hidden * dim_model;
-  static constexpr size_t size_linear_w = dim_model * dim_ffw_hidden;
-  std::array<float, size_attn_vec_einsum_w> attn_vec_einsum_w;
-  std::array<float, size_qkv_einsum_w> qkv_einsum_w;
-  std::array<float, size_gating_einsum_w> gating_einsum_w;
-  std::array<float, size_linear_w> linear_w;
-  std::array<float, dim_model> pre_attention_norm_scale;
-  std::array<float, dim_model> pre_ffw_norm_scale;
-  // NOLINTEND(google3-readability-class-member-naming)
+  static constexpr size_t kGatingEinsumWSize = 2 * kFFHiddenDim * kModelDim;
+
+  std::array<float, kAttVecEinsumWSize> attn_vec_einsum_w;
+  std::array<float, kQKVEinsumWSize> qkv_einsum_w;
+  std::array<float, kGatingEinsumWSize> gating_einsum_w;
+  std::array<float, kModelDim * kFFHiddenDim> linear_w;
+  std::array<float, kModelDim> pre_attention_norm_scale;
+  std::array<float, kModelDim> pre_ffw_norm_scale;
 };
 
 template <class TConfig>
 struct Weights {
   Weights() = default;
-  hwy::AlignedUniquePtr<Layer<TConfig>[]> layers;  // n_layers
+  hwy::AlignedUniquePtr<Layer<TConfig>[]> layers;  // kLayers
 
-  std::array<float, TConfig::vocab_size * TConfig::dim_model>
+  std::array<float, TConfig::kVocabSize * TConfig::kModelDim>
       embedder_input_embedding;
-  std::array<float, TConfig::dim_model> final_norm_scale;
+  std::array<float, TConfig::kModelDim> final_norm_scale;
 };
 
 // Only called if cached loading fails.
@@ -109,7 +106,7 @@ hwy::AlignedUniquePtr<Weights<TConfig>> LoadWeights(const Path& checkpoint) {
   using TWeights = Weights<TConfig>;
   hwy::AlignedUniquePtr<TWeights> weights = hwy::MakeUniqueAligned<TWeights>();
   weights->layers =
-      hwy::MakeUniqueAlignedArray<Layer<TConfig>>(TConfig::n_layers);
+      hwy::MakeUniqueAlignedArray<Layer<TConfig>>(TConfig::kLayers);
 
   FILE* fptr;
   fptr = fopen(checkpoint.path.c_str(), "rb");
@@ -122,7 +119,7 @@ hwy::AlignedUniquePtr<Weights<TConfig>> LoadWeights(const Path& checkpoint) {
                   sizeof(weights->embedder_input_embedding), 1, fptr);
   ok &= 1 == fread(&(weights->final_norm_scale),
                    sizeof(weights->final_norm_scale), 1, fptr);
-  for (size_t layer = 0; layer < TConfig::n_layers; ++layer) {
+  for (size_t layer = 0; layer < TConfig::kLayers; ++layer) {
     Layer<TConfig>* layer_view = &weights->layers[layer];
     ok &= 1 == fread(&layer_view->attn_vec_einsum_w,
                      sizeof(layer_view->attn_vec_einsum_w), 1, fptr);
@@ -151,19 +148,17 @@ struct CompressedLayer {
   using TLayer = gcpp::Layer<TConfig>;
 
-  // # NOLINTBEGIN(google3-readability-class-member-naming)
-  static constexpr size_t dim_model = TConfig::dim_model;
-  static constexpr size_t dim_ffw_hidden = TConfig::dim_ffw_hidden;
-  // NOLINTEND(google3-readability-class-member-naming)
+  static constexpr size_t kModelDim = TConfig::kModelDim;
+  static constexpr size_t kFFHiddenDim = TConfig::kFFHiddenDim;
 
   // Compressed Parameters
   // We don't yet have an RMSNorm that accepts all WeightT.
-  CompressedArray<hwy::bfloat16_t, dim_model> c_pre_attention_norm_scale;
-  CompressedArray<hwy::bfloat16_t, dim_model> c_pre_ffw_norm_scale;
-  CompressedArray<WeightT, TLayer::size_gating_einsum_w> c_gating_einsum_w;
-  CompressedArray<WeightT, TLayer::size_linear_w> c_linear_w;
-  CompressedArray<WeightT, TLayer::size_qkv_einsum_w> c_qkv_einsum_w;
-  CompressedArray<WeightT, TLayer::size_attn_vec_einsum_w> c_attn_vec_einsum_w;
+  CompressedArray<hwy::bfloat16_t, kModelDim> c_pre_attention_norm_scale;
+  CompressedArray<hwy::bfloat16_t, kModelDim> c_pre_ffw_norm_scale;
+  CompressedArray<WeightT, TLayer::kGatingEinsumWSize> c_gating_einsum_w;
+  CompressedArray<WeightT, kModelDim * kFFHiddenDim> c_linear_w;
+  CompressedArray<WeightT, TLayer::kQKVEinsumWSize> c_qkv_einsum_w;
+  CompressedArray<WeightT, TLayer::kAttVecEinsumWSize> c_attn_vec_einsum_w;
 };
 
 // Array instead of single large allocation for parallel mem init. Split out of
@@ -172,23 +167,23 @@ struct CompressedLayer {
 template <class TConfig>
 struct CompressedLayerPointers {
   explicit CompressedLayerPointers(hwy::ThreadPool& pool) {
-    pool.Run(0, TConfig::n_layers, [this](uint64_t task, size_t /*thread*/) {
+    pool.Run(0, TConfig::kLayers, [this](uint64_t task, size_t /*thread*/) {
      this->c_layers[task] = hwy::AllocateAligned<CompressedLayer<TConfig>>(1);
    });
  }
 
   using CLayer = CompressedLayer<TConfig>;
-  std::array<hwy::AlignedFreeUniquePtr<CLayer[]>, TConfig::n_layers> c_layers;
+  std::array<hwy::AlignedFreeUniquePtr<CLayer[]>, TConfig::kLayers> c_layers;
 };
 
 template <class TConfig>
 struct CompressedWeights {
   // No ctor/dtor, allocated via AllocateAligned.
 
-  CompressedArray<EmbedderInputT, TConfig::vocab_size * TConfig::dim_model>
+  CompressedArray<EmbedderInputT, TConfig::kVocabSize * TConfig::kModelDim>
       c_embedder_input_embedding;
 
-  CompressedArray<hwy::bfloat16_t, TConfig::dim_model> c_final_norm_scale;
+  CompressedArray<hwy::bfloat16_t, TConfig::kModelDim> c_final_norm_scale;
 
   // Must be last so that the other arrays remain aligned.
   CompressedLayerPointers<TConfig> c_layer_ptrs;
@@ -202,38 +197,35 @@ struct CompressedWeights {
 };
 
 // Aligned.
-template <class TConfig, size_t BatchSize>
+template <class TConfig, size_t TBatchSize>
 struct Activations {
-  // # NOLINTBEGIN(google3-readability-class-member-naming)
-  static constexpr size_t batch_size = BatchSize;
+  static constexpr size_t kBatchSize = TBatchSize;
   using LayerConfig = Layer<TConfig>;
-  static constexpr size_t dim_model = TConfig::dim_model;
-  static constexpr size_t dim_qkv = TConfig::dim_qkv;
-  static constexpr size_t n_heads = TConfig::n_heads;
-  static constexpr size_t n_kv_heads = TConfig::n_kv_heads;
-  static constexpr size_t size_cache_pos =
-      TConfig::n_layers * n_kv_heads * dim_qkv;
-  static constexpr size_t size_cache_layer = n_kv_heads * dim_qkv;
-  // NOLINTEND(google3-readability-class-member-naming)
+  static constexpr size_t kModelDim = TConfig::kModelDim;
+  static constexpr size_t kQKVDim = TConfig::kQKVDim;
+  static constexpr size_t kHeads = TConfig::kHeads;
+  static constexpr size_t kKVHeads = TConfig::kKVHeads;
+  static constexpr size_t kCachePosSize =
+      TConfig::kLayers * kKVHeads * kQKVDim;
+  static constexpr size_t kCacheLayerSize = kKVHeads * kQKVDim;
 
-  std::array<float, batch_size * dim_model> x;  // input
-  std::array<float, batch_size * dim_model> pre_att_rms_out;
-  std::array<float, batch_size * n_heads * dim_qkv> q;  // query vector
-  std::array<float, batch_size * n_heads * TConfig::seq_len>
-      att;  // attention vector
-  std::array<float, batch_size * n_heads * dim_qkv>
-      att_out;  // attention output
-  std::array<float, n_heads * batch_size * dim_model>
+  std::array<float, kBatchSize * kModelDim> x;  // input
+  std::array<float, kBatchSize * kModelDim> pre_att_rms_out;
+  std::array<float, kBatchSize * kHeads * kQKVDim> q;  // query vector
+  std::array<float, kBatchSize * kHeads * TConfig::kSeqLen>
+      att;  // attention vector
+  std::array<float, kBatchSize * kHeads * kQKVDim> att_out;  // attention output
+  std::array<float, kHeads * kBatchSize * kModelDim>
       att_post1;  // attention output after linear transformation, per head
-  std::array<float, batch_size * dim_model>
+  std::array<float, kBatchSize * kModelDim>
       att_post2;  // accumulation of attention outputs over heads
-  std::array<hwy::bfloat16_t, batch_size * dim_model> bf_pre_ffw_rms_out;
-  std::array<float, batch_size * TConfig::dim_ffw_hidden * 2> ffw_hidden;
+  std::array<hwy::bfloat16_t, kBatchSize * kModelDim> bf_pre_ffw_rms_out;
+  std::array<float, kBatchSize * TConfig::kFFHiddenDim * 2> ffw_hidden;
   // bf_ version can't be used until GeluMulToBF16 issue in FFW() is resolved.
-  // std::array<hwy::bfloat16_t, batch_size * TConfig::dim_ffw_hidden * 2>
+  // std::array<hwy::bfloat16_t, kBatchSize * TConfig::kFFHiddenDim * 2>
   //     bf_ffw_hidden;
-  std::array<float, batch_size * dim_model> ffw_out;
-  std::array<float, batch_size * TConfig::vocab_size> logits;
+  std::array<float, kBatchSize * kModelDim> ffw_out;
+  std::array<float, kBatchSize * TConfig::kVocabSize> logits;
 };
 
 // GemmaImpl is a template and thus cannot be exposed in gemma.h, hence we
@@ -288,45 +280,45 @@ HWY_BEFORE_NAMESPACE();
 namespace gcpp {
 namespace HWY_NAMESPACE {
 
-template <size_t batch_size, class TConfig>
+template <size_t kBatchSize, class TConfig>
 HWY_NOINLINE void Attention(size_t batch_start, size_t batch_idx, size_t layer,
-                            Activations<TConfig, batch_size>& activations,
+                            Activations<TConfig, kBatchSize>& activations,
                             const CompressedLayer<TConfig>* c_layer,
                             KVCache& kv_cache, hwy::ThreadPool& pool) {
   PROFILER_ZONE("Gen.Attention");
   const size_t pos = batch_start + batch_idx;
-  HWY_DASSERT(batch_idx < batch_size);
-  static constexpr size_t dim_qkv = gcpp::Activations<TConfig, batch_size>::dim_qkv;
-  static constexpr size_t size_cache_pos =
-      gcpp::Activations<TConfig, batch_size>::size_cache_pos;
-  static constexpr size_t size_cache_layer =
-      gcpp::Activations<TConfig, batch_size>::size_cache_layer;
-  static constexpr size_t dim_model =
-      gcpp::Activations<TConfig, batch_size>::dim_model;
-  static constexpr size_t n_heads = TConfig::n_heads;
-  const float kQueryScale = 1.0 / sqrtf(static_cast<float>(dim_qkv));
+  HWY_DASSERT(batch_idx < kBatchSize);
+  static constexpr size_t kQKVDim = gcpp::Activations<TConfig, kBatchSize>::kQKVDim;
+  static constexpr size_t kCachePosSize =
+      gcpp::Activations<TConfig, kBatchSize>::kCachePosSize;
+  static constexpr size_t kCacheLayerSize =
+      gcpp::Activations<TConfig, kBatchSize>::kCacheLayerSize;
+  static constexpr size_t kModelDim =
+      gcpp::Activations<TConfig, kBatchSize>::kModelDim;
+  static constexpr size_t kHeads = TConfig::kHeads;
+  const float kQueryScale = 1.0 / sqrtf(static_cast<float>(kQKVDim));
 
-  pool.Run(0, n_heads, [&](const uint64_t head, size_t /*thread*/) HWY_ATTR {
+  pool.Run(0, kHeads, [&](const uint64_t head, size_t /*thread*/) HWY_ATTR {
     // linear projections to QKV
     const size_t head_offset =
-        3 * dim_qkv * dim_model;  // 3x for QKV dimensions
-    const size_t q_offset = head * head_offset + 0 * dim_qkv * dim_model;
-    const size_t k_offset = head * head_offset + 1 * dim_qkv * dim_model;
-    const size_t v_offset = head * head_offset + 2 * dim_qkv * dim_model;
+        3 * kQKVDim * kModelDim;  // 3x for QKV dimensions
+    const size_t q_offset = head * head_offset + 0 * kQKVDim * kModelDim;
+    const size_t k_offset = head * head_offset + 1 * kQKVDim * kModelDim;
+    const size_t v_offset = head * head_offset + 2 * kQKVDim * kModelDim;
 
     float* HWY_RESTRICT q =
-        activations.q.data() + head * dim_qkv + batch_idx * n_heads * dim_qkv;
+        activations.q.data() + head * kQKVDim + batch_idx * kHeads * kQKVDim;
 
-    const size_t batch_offset = batch_idx * dim_model;
+    const size_t batch_offset = batch_idx * kModelDim;
 
-    MatVecLoop<dim_qkv, dim_model>(
+    MatVecLoop<kQKVDim, kModelDim>(
         c_layer->c_qkv_einsum_w, q_offset,
         activations.pre_att_rms_out.data() + batch_offset, q);
 
     const size_t kv_offset =
-        pos * size_cache_pos + layer * size_cache_layer + head * dim_qkv;
+        pos * kCachePosSize + layer * kCacheLayerSize + head * kQKVDim;
 
-    TwoOfsMatVecLoop<dim_qkv, dim_model>(
+    TwoOfsMatVecLoop<kQKVDim, kModelDim>(
         c_layer->c_qkv_einsum_w, k_offset, v_offset,
         activations.pre_att_rms_out.data() + batch_offset,
         kv_cache.key_cache.get() + kv_offset,
@@ -334,119 +326,119 @@ HWY_NOINLINE void Attention(size_t batch_start, size_t batch_idx, size_t layer,
 
     // Calculate scores
     float* HWY_RESTRICT head_att = activations.att.data() +
-                                   head * TConfig::seq_len +
-                                   batch_idx * n_heads * dim_qkv;
+                                   head * TConfig::kSeqLen +
+                                   batch_idx * kHeads * kQKVDim;
 
-    Rope(q, dim_qkv, pos);
-    Rope(kv_cache.key_cache.get() + kv_offset, dim_qkv, pos);
-    MulByConst(kQueryScale, q, dim_qkv);
+    Rope(q, kQKVDim, pos);
+    Rope(kv_cache.key_cache.get() + kv_offset, kQKVDim, pos);
+    MulByConst(kQueryScale, q, kQKVDim);
 
     // Compute Q dot K scores
     for (size_t pos2 = 0; pos2 <= pos; ++pos2) {
       const size_t cache_offset =
-          pos2 * size_cache_pos + layer * size_cache_layer + head * dim_qkv;
+          pos2 * kCachePosSize + layer * kCacheLayerSize + head * kQKVDim;
       const float* HWY_RESTRICT k2 = kv_cache.key_cache.get() + cache_offset;
-      const float score = Dot(q, k2, dim_qkv);
+      const float score = Dot(q, k2, kQKVDim);
       head_att[pos2] = score;
     }
     Softmax(head_att, pos + 1);
 
     // Weighted summation
-    float* HWY_RESTRICT att_out = activations.att_out.data() + head * dim_qkv +
-                                  batch_idx * n_heads * dim_qkv;
-    hwy::ZeroBytes(att_out, dim_qkv * sizeof(*att_out));
+    float* HWY_RESTRICT att_out = activations.att_out.data() + head * kQKVDim +
+                                  batch_idx * kHeads * kQKVDim;
+    hwy::ZeroBytes(att_out, kQKVDim * sizeof(*att_out));
     for (size_t pos2 = 0; pos2 <= pos; ++pos2) {
       const size_t cache_offset =
-          pos2 * size_cache_pos + layer * size_cache_layer + head * dim_qkv;
+          pos2 * kCachePosSize + layer * kCacheLayerSize + head * kQKVDim;
       float* HWY_RESTRICT v2 = kv_cache.value_cache.get() + cache_offset;
-      MulByConstAndAdd(head_att[pos2], v2, att_out, dim_qkv);
+      MulByConstAndAdd(head_att[pos2], v2, att_out, kQKVDim);
     }
 
-    // linear projection from dim_qkv back to dim_model, sum projections
+    // linear projection from kQKVDim back to kModelDim, sum projections
     // across heads
     float* HWY_RESTRICT head_out =
        head == 0
-            ? activations.att_post2.data() + batch_idx * dim_model
-            : activations.att_post1.data() + head * batch_size * dim_model;
-    MatVecLoop<dim_model, dim_qkv>(c_layer->c_attn_vec_einsum_w,
-                                   head * dim_model * dim_qkv, att_out,
+            ? activations.att_post2.data() + batch_idx * kModelDim
+            : activations.att_post1.data() + head * kBatchSize * kModelDim;
+    MatVecLoop<kModelDim, kQKVDim>(c_layer->c_attn_vec_einsum_w,
+                                   head * kModelDim * kQKVDim, att_out,
                                    head_out);
   });
 
   // accumulate output across all heads into att_post2. head 0 already wrote
   // directly to att_post2.
-  for (size_t head = 1; head < n_heads; ++head) {
-    AddFrom(activations.att_post1.data() + head * batch_size * dim_model,
-            activations.att_post2.data() + batch_idx * dim_model, dim_model);
+  for (size_t head = 1; head < kHeads; ++head) {
+    AddFrom(activations.att_post1.data() + head * kBatchSize * kModelDim,
+            activations.att_post2.data() + batch_idx * kModelDim, kModelDim);
   }
 }
 
-template <size_t batch_size, class TConfig>
-HWY_NOINLINE void FFW(Activations<TConfig, batch_size>& activations,
+template <size_t kBatchSize, class TConfig>
+HWY_NOINLINE void FFW(Activations<TConfig, kBatchSize>& activations,
                       size_t batch_idx, const CompressedLayer<TConfig>* c_layer,
                       hwy::ThreadPool& pool) {
-  HWY_DASSERT(batch_idx < batch_size);
-  static constexpr size_t dim_model = TConfig::dim_model;
-  static constexpr size_t dim_ffw_hidden = TConfig::dim_ffw_hidden;
-  const size_t hidden_offset = batch_idx * dim_ffw_hidden * 2;
+  HWY_DASSERT(batch_idx < kBatchSize);
+  static constexpr size_t kModelDim = TConfig::kModelDim;
+  static constexpr size_t kFFHiddenDim = TConfig::kFFHiddenDim;
+  const size_t hidden_offset = batch_idx * kFFHiddenDim * 2;
 
   {
     PROFILER_ZONE("Gen.FFW.GatedGELU");
     const hwy::bfloat16_t* HWY_RESTRICT vec =
-        activations.bf_pre_ffw_rms_out.data() + batch_idx * dim_model;
+        activations.bf_pre_ffw_rms_out.data() + batch_idx * kModelDim;
     float* HWY_RESTRICT out = activations.ffw_hidden.data() + hidden_offset;
-    float* HWY_RESTRICT out_mul = out + dim_ffw_hidden;
+    float* HWY_RESTRICT out_mul = out + kFFHiddenDim;
 
     // Same matrix, first and second half of rows. Could fuse into one MatVec,
     // but separating them could help on NUMA e.g. multiple sockets.
-    MatVec<dim_ffw_hidden, dim_model>(c_layer->c_gating_einsum_w,
-                                      dim_ffw_hidden * dim_model, vec, out_mul,
-                                      pool);
+    MatVec<kFFHiddenDim, kModelDim>(c_layer->c_gating_einsum_w,
+                                    kFFHiddenDim * kModelDim, vec, out_mul,
+                                    pool);
 
     // Gate, will go through the nonlinearity.
-    MatVec<dim_ffw_hidden, dim_model>(c_layer->c_gating_einsum_w, 0, vec, out,
-                                      pool);
+    MatVec<kFFHiddenDim, kModelDim>(c_layer->c_gating_einsum_w, 0, vec, out,
+                                    pool);
 
     namespace hn = hwy::HWY_NAMESPACE;
     using DF = hn::ScalableTag<float>;
     using VF = hn::Vec<DF>;
-    hn::Transform1(DF(), out, dim_ffw_hidden, out_mul,
+    hn::Transform1(DF(), out, kFFHiddenDim, out_mul,
                    [](DF df, VF v, VF mul) HWY_ATTR {
                      return hn::Mul(mul, Gelu(df, v));
                    });
   }
 
   PROFILER_ZONE("Gen.FFW\\GatedGELU");
-  MatVec<dim_model, dim_ffw_hidden>(
+  MatVec<kModelDim, kFFHiddenDim>(
       c_layer->c_linear_w, 0, activations.ffw_hidden.data() + hidden_offset,
-      activations.ffw_out.data() + batch_idx * dim_model, pool);
+      activations.ffw_out.data() + batch_idx * kModelDim, pool);
 }
 
-template <size_t batch_size, class TConfig>
+template <size_t kBatchSize, class TConfig>
 HWY_NOINLINE void Prefill(const int* tokens, size_t num_tokens, size_t pos,
                           const CompressedWeights<TConfig>& c_weights,
-                          Activations<TConfig, batch_size>& activations,
+                          Activations<TConfig, kBatchSize>& activations,
                           KVCache& kv_cache, hwy::ThreadPool& pool,
                           hwy::ThreadPool& inner_pool) {
   PROFILER_ZONE("Gen.Prefill\\Att\\FFW");
-  static constexpr size_t dim_model = TConfig::dim_model;
-  static const float kEmbScaling = sqrtf(static_cast<float>(dim_model));
+  static constexpr size_t kModelDim = TConfig::kModelDim;
+  static const float kEmbScaling = sqrtf(static_cast<float>(kModelDim));
 
   pool.Run(
       0, num_tokens, [&](const uint64_t token_idx, size_t /*thread*/) HWY_ATTR {
        const int token = tokens[token_idx];
-        Decompress(c_weights.c_embedder_input_embedding, token * dim_model,
-                   activations.x.data() + token_idx * dim_model, dim_model);
-        MulByConst(kEmbScaling, activations.x.data() + token_idx * dim_model,
-                   dim_model);
+        Decompress(c_weights.c_embedder_input_embedding, token * kModelDim,
+                   activations.x.data() + token_idx * kModelDim, kModelDim);
+        MulByConst(kEmbScaling, activations.x.data() + token_idx * kModelDim,
+                   kModelDim);
      });
 
-  for (size_t layer = 0; layer < TConfig::n_layers; ++layer) {
+  for (size_t layer = 0; layer < TConfig::kLayers; ++layer) {
     const CompressedLayer<TConfig>* c_layer = c_weights.CLayer(layer);
     for (size_t token_idx = 0; token_idx < num_tokens; ++token_idx) {
-      RMSNorm(activations.x.data() + token_idx * dim_model,
+      RMSNorm(activations.x.data() + token_idx * kModelDim,
               c_layer->c_pre_attention_norm_scale.data(),
-              activations.pre_att_rms_out.data() + token_idx * dim_model,
-              dim_model);
-      Attention<batch_size>(pos, token_idx, layer, activations,
+              activations.pre_att_rms_out.data() + token_idx * kModelDim,
+              kModelDim);
+      Attention<kBatchSize>(pos, token_idx, layer, activations,
                             c_layer, kv_cache, pool);
     }
 
@@ -454,22 +446,22 @@ HWY_NOINLINE void Prefill(const int* tokens, size_t num_tokens, size_t pos,
     pool.Run(
         0, num_tokens, [&](const uint64_t token_idx, size_t thread_id) HWY_ATTR {
-          AddFrom(activations.att_post2.data() + token_idx * dim_model,
-                  activations.x.data() + token_idx * dim_model, dim_model);
-          RMSNorm(activations.x.data() + token_idx * dim_model,
+          AddFrom(activations.att_post2.data() + token_idx * kModelDim,
+                  activations.x.data() + token_idx * kModelDim, kModelDim);
+          RMSNorm(activations.x.data() + token_idx * kModelDim,
                   c_layer->c_pre_ffw_norm_scale.data(),
-                  activations.bf_pre_ffw_rms_out.data() + token_idx * dim_model,
-                  dim_model);
-          FFW<batch_size>(activations, token_idx, c_layer, inner_pool);
-          AddFrom(activations.ffw_out.data() + token_idx * dim_model,
-                  activations.x.data() + token_idx * dim_model, dim_model);
+                  activations.bf_pre_ffw_rms_out.data() + token_idx * kModelDim,
+                  kModelDim);
+          FFW<kBatchSize>(activations, token_idx, c_layer, inner_pool);
+          AddFrom(activations.ffw_out.data() + token_idx * kModelDim,
+                  activations.x.data() + token_idx * kModelDim, kModelDim);
        });
  }  // foreach layer
 
   pool.Run(
       0, num_tokens, [&](const uint64_t token_idx, size_t /*thread*/) HWY_ATTR {
        RMSNormInplace(c_weights.c_final_norm_scale.data(),
-                       activations.x.data() + token_idx * dim_model, dim_model);
+                       activations.x.data() + token_idx * kModelDim, kModelDim);
      });
 }
 
@@ -479,29 +471,29 @@ void Transformer(int token, size_t pos,
                  const CompressedWeights<TConfig>& c_weights,
                  Activations& activations, KVCache& kv_cache,
                  hwy::ThreadPool& pool, hwy::ThreadPool& inner_pool) {
-  static constexpr size_t n_layers = TConfig::n_layers;
-  static constexpr size_t dim_model = TConfig::dim_model;
+  static constexpr size_t kLayers = TConfig::kLayers;
+  static constexpr size_t kModelDim = TConfig::kModelDim;
 
-  static const float kEmbScaling = sqrtf(static_cast<float>(dim_model));
+  static const float kEmbScaling = sqrtf(static_cast<float>(kModelDim));
 
-  Decompress(c_weights.c_embedder_input_embedding, token * dim_model,
-             activations.x.data(), dim_model);
+  Decompress(c_weights.c_embedder_input_embedding, token * kModelDim,
+             activations.x.data(), kModelDim);
 
-  MulByConst(kEmbScaling, activations.x.data(), dim_model);
+  MulByConst(kEmbScaling, activations.x.data(), kModelDim);
 
-  for (size_t layer = 0; layer < n_layers; ++layer) {
+  for (size_t layer = 0; layer < kLayers; ++layer) {
     const CompressedLayer<TConfig>* c_layer = c_weights.CLayer(layer);
     RMSNorm(activations.x.data(), c_layer->c_pre_attention_norm_scale.data(),
-            activations.pre_att_rms_out.data(), dim_model);
+            activations.pre_att_rms_out.data(), kModelDim);
     Attention(pos, 0, layer, activations, c_layer, kv_cache, pool);
-    AddFrom(activations.att_post2.data(), activations.x.data(), dim_model);
+    AddFrom(activations.att_post2.data(), activations.x.data(), kModelDim);
     RMSNorm(activations.x.data(), c_layer->c_pre_ffw_norm_scale.data(),
-            activations.bf_pre_ffw_rms_out.data(), dim_model);
+            activations.bf_pre_ffw_rms_out.data(), kModelDim);
     FFW(activations, /* batch_idx = */ 0, c_layer, pool);
-    AddFrom(activations.ffw_out.data(), activations.x.data(), dim_model);
+    AddFrom(activations.ffw_out.data(), activations.x.data(), kModelDim);
   }
   RMSNormInplace(c_weights.c_final_norm_scale.data(), activations.x.data(),
-                 dim_model);
+                 kModelDim);
 }
 
 template <class TConfig>
@@ -511,9 +503,9 @@ void GenerateImpl(GemmaImpl<TConfig>& gemma, const InferenceArgs& args,
                   const StreamFunc& stream_token, const AcceptFunc& accept_token,
                   std::mt19937& gen, int verbosity) {
 
-  static constexpr size_t dim_model = TConfig::dim_model;
-  static constexpr size_t vocab_size = TConfig::vocab_size;
-  static constexpr size_t top_k = TConfig::top_k;
+  static constexpr size_t kModelDim = TConfig::kModelDim;
+  static constexpr size_t kVocabSize = TConfig::kVocabSize;
+  static constexpr size_t kTopK = TConfig::kTopK;
   Activations& activations = *gemma.state.get();
   Activations& prefill_activations = *gemma.prefill.get();
 
@@ -587,12 +579,12 @@ void GenerateImpl(GemmaImpl<TConfig>& gemma, const InferenceArgs& args,
     if (pos_offset >= prompt.size()) {
       PROFILER_ZONE("Gen.Embedding");
       // Generation phase
-      MatVec<vocab_size, dim_model>(c_weights.c_embedder_input_embedding, 0,
+      MatVec<kVocabSize, kModelDim>(c_weights.c_embedder_input_embedding, 0,
                                     final_activation, activations.logits.data(),
                                     pool);
       // Barrier: must have all logits so we can subtract max.
-      Softmax(activations.logits.data(), vocab_size);
-      token = SampleTopK<top_k>(activations.logits.data(), vocab_size, gen,
+      Softmax(activations.logits.data(), kVocabSize);
+      token = SampleTopK<kTopK>(activations.logits.data(), kVocabSize, gen,
                                 args.temperature, accept_token);
     }
     if (!stream_token(token, activations.logits[token])) {
@@ -643,7 +635,7 @@ void ForEachTensor(const Weights<TConfig>* weights,
                      c_weights.c_final_norm_scale);
 
   char name[16];
-  for (size_t layer_idx = 0; layer_idx < TConfig::n_layers; ++layer_idx) {
+  for (size_t layer_idx = 0; layer_idx < TConfig::kLayers; ++layer_idx) {
     Layer<TConfig>* layer = weights ? &weights->layers[layer_idx] : nullptr;
     CompressedLayer<TConfig>* c_layer = c_weights.CLayer(layer_idx);
 
@@ -729,10 +721,10 @@ HWY_EXPORT(GetCompressedWeightsT);
 HWY_EXPORT(Generate2B);
 HWY_EXPORT(Generate7B);
 
-KVCache CreateKVCache(size_t size_cache_pos, size_t seq_len) {
+KVCache CreateKVCache(size_t size_cache_pos, size_t kSeqLen) {
   KVCache kv_cache = {};
-  kv_cache.key_cache = hwy::AllocateAligned<float>(seq_len * size_cache_pos);
-  kv_cache.value_cache = hwy::AllocateAligned<float>(seq_len * size_cache_pos);
+  kv_cache.key_cache = hwy::AllocateAligned<float>(kSeqLen * size_cache_pos);
+  kv_cache.value_cache = hwy::AllocateAligned<float>(kSeqLen * size_cache_pos);
   return kv_cache;
 }
 
@@ -743,8 +735,8 @@ GemmaImpl<Config>::GemmaImpl(const LoaderArgs& args, hwy::ThreadPool& pool)
       prefill(hwy::MakeUniqueAligned>()),
      state(hwy::MakeUniqueAligned>()),
      kv_cache(
-          CreateKVCache(Config::n_layers * Config::n_kv_heads * Config::dim_qkv,
-                        Config::seq_len)) {
+          CreateKVCache(Config::kLayers * Config::kKVHeads * Config::kQKVDim,
+                        Config::kSeqLen)) {
   PROFILER_ZONE("Startup.tokenizer");
   HWY_ASSERT(tokenizer.Load(args.tokenizer.path).ok());
 
diff --git a/gemma.h b/gemma.h
index 9647a6c..67b7f85 100644
--- a/gemma.h
+++ b/gemma.h
@@ -51,9 +51,9 @@ constexpr bool kSystemPrompt = false;
 
 struct KVCache {
   hwy::AlignedFreeUniquePtr<float[]>
-      key_cache;  // batch_size * seq_len * n_layers * n_kv_heads * dim_qkv
+      key_cache;  // batch_size * kSeqLen * kLayers * kKVHeads * kQKVDim
   hwy::AlignedFreeUniquePtr<float[]>
-      value_cache;  // batch_size * seq_len * n_layers * n_kv_heads * dim_qkv
+      value_cache;  // batch_size * kSeqLen * kLayers * kKVHeads * kQKVDim
 };
 
 // Model variants: see configs.h for details.