Code update

PiperOrigin-RevId: 609394329
Author: The gemma_cpp Authors, 2024-02-22 09:12:39 -08:00 (committed by Dan Zheng)
Parent: fb6f266db1
Commit: 587e80f276
5 changed files with 186 additions and 190 deletions


@@ -23,6 +23,8 @@
 #include "hwy/base.h"  // HWY_ASSERT
+namespace gcpp {
 void Stats::Assimilate(const Stats& other) {
   const int64_t total_n = n_ + other.n_;
   if (total_n == 0) return;  // Nothing to do; prevents div by zero.
@@ -115,3 +117,5 @@ std::string Stats::ToString(int exclude) const {
   HWY_ASSERT(pos < sizeof(buf));
   return buf;
 }
+}  // namespace gcpp


@@ -25,6 +25,8 @@
 #include "hwy/base.h"  // HWY_ASSERT
+namespace gcpp {
 // Thread-compatible.
 template <size_t N>
 class Bins {
@@ -187,4 +189,6 @@ class Stats {
   double m4_;
 };
+}  // namespace gcpp
 #endif  // THIRD_PARTY_GEMMA_CPP_COMPRESSION_STATS_H_


@@ -18,38 +18,34 @@
 #ifndef THIRD_PARTY_GEMMA_CPP_CONFIGS_H_
 #define THIRD_PARTY_GEMMA_CPP_CONFIGS_H_
-#include <cstddef>
+#include <stddef.h>
 namespace gcpp {
 static constexpr size_t kSeqLen = 7168;
 struct ConfigGemma7B {
-  // NOLINTBEGIN(google3-readability-class-member-naming)
-  static constexpr int seq_len = kSeqLen;
-  static constexpr int vocab_size = 256128;
-  static constexpr int n_layers = 28;
-  static constexpr int dim_model = 3072;
-  static constexpr int dim_ffw_hidden = 16 * 3072 / 2;  // = 24576
-  static constexpr int n_heads = 16;
-  static constexpr int n_kv_heads = 16;  // standard MHA, no GQA or MQA
-  static constexpr int dim_qkv = 256;  // query size == key size == value size
-  static constexpr int top_k = 1;
-  // NOLINTEND(google3-readability-class-member-naming)
+  static constexpr int kSeqLen = gcpp::kSeqLen;
+  static constexpr int kVocabSize = 256128;
+  static constexpr int kLayers = 28;
+  static constexpr int kModelDim = 3072;
+  static constexpr int kFFHiddenDim = 16 * 3072 / 2;  // = 24576
+  static constexpr int kHeads = 16;
+  static constexpr int kKVHeads = 16;  // standard MHA, no GQA or MQA
+  static constexpr int kQKVDim = 256;  // query size == key size == value size
+  static constexpr int kTopK = 1;
 };
 struct ConfigGemma2B {
-  // NOLINTBEGIN(google3-readability-class-member-naming)
-  static constexpr int seq_len = kSeqLen;
-  static constexpr int vocab_size = 256128;
-  static constexpr int n_layers = 18;
-  static constexpr int dim_model = 2048;
-  static constexpr int dim_ffw_hidden = 16 * 2048 / 2;  // = 16384
-  static constexpr int n_heads = 8;
-  static constexpr int n_kv_heads = 8;  // TODO(austinvhuang): add MQA support
-  static constexpr int dim_qkv = 256;  // query size == key size == value size
-  static constexpr int top_k = 1;
-  // NOLINTEND(google3-readability-class-member-naming)
+  static constexpr int kSeqLen = gcpp::kSeqLen;
+  static constexpr int kVocabSize = 256128;
+  static constexpr int kLayers = 18;
+  static constexpr int kModelDim = 2048;
+  static constexpr int kFFHiddenDim = 16 * 2048 / 2;  // = 16384
+  static constexpr int kHeads = 8;
+  static constexpr int kKVHeads = 8;  // TODO(austinvhuang): add MQA support
+  static constexpr int kQKVDim = 256;  // query size == key size == value size
+  static constexpr int kTopK = 1;
 };
 }  // namespace gcpp

gemma.cc (322 changed lines)

@@ -69,37 +69,34 @@ namespace gcpp {
 template <class TConfig>
 struct Layer {
   Layer() = default;
-  // NOLINTBEGIN(google3-readability-class-member-naming)
-  static constexpr size_t n_heads = TConfig::n_heads;
-  static constexpr size_t dim_model = TConfig::dim_model;
-  static constexpr size_t dim_qkv = TConfig::dim_qkv;
-  static constexpr size_t dim_ffw_hidden = TConfig::dim_ffw_hidden;
-  static constexpr size_t size_attn_vec_einsum_w =
-      n_heads * dim_qkv * dim_model;
+  static constexpr size_t kHeads = TConfig::kHeads;
+  static constexpr size_t kModelDim = TConfig::kModelDim;
+  static constexpr size_t kQKVDim = TConfig::kQKVDim;
+  static constexpr size_t kFFHiddenDim = TConfig::kFFHiddenDim;
+  static constexpr size_t kAttVecEinsumWSize = kHeads * kQKVDim * kModelDim;
   // 3x for (query, key, value)
-  static constexpr size_t size_qkv_einsum_w = 3 * n_heads * dim_qkv * dim_model;
+  static constexpr size_t kQKVEinsumWSize = 3 * kHeads * kQKVDim * kModelDim;
   // 2x for (gelu gating vector, gated vector)
-  static constexpr size_t size_gating_einsum_w = 2 * dim_ffw_hidden * dim_model;
-  static constexpr size_t size_linear_w = dim_model * dim_ffw_hidden;
-  std::array<float, size_attn_vec_einsum_w> attn_vec_einsum_w;
-  std::array<float, size_qkv_einsum_w> qkv_einsum_w;
-  std::array<float, size_gating_einsum_w> gating_einsum_w;
-  std::array<float, size_linear_w> linear_w;
-  std::array<float, dim_model> pre_attention_norm_scale;
-  std::array<float, dim_model> pre_ffw_norm_scale;
-  // NOLINTEND(google3-readability-class-member-naming)
+  static constexpr size_t kGatingEinsumWSize = 2 * kFFHiddenDim * kModelDim;
+  std::array<float, kAttVecEinsumWSize> attn_vec_einsum_w;
+  std::array<float, kQKVEinsumWSize> qkv_einsum_w;
+  std::array<float, kGatingEinsumWSize> gating_einsum_w;
+  std::array<float, kModelDim * kFFHiddenDim> linear_w;
+  std::array<float, kModelDim> pre_attention_norm_scale;
+  std::array<float, kModelDim> pre_ffw_norm_scale;
 };
 template <class TConfig>
 struct Weights {
   Weights() = default;
-  hwy::AlignedUniquePtr<Layer<TConfig>[]> layers;  // n_layers
-  std::array<float, TConfig::vocab_size * TConfig::dim_model>
+  hwy::AlignedUniquePtr<Layer<TConfig>[]> layers;  // kLayers
+  std::array<float, TConfig::kVocabSize * TConfig::kModelDim>
       embedder_input_embedding;
-  std::array<float, TConfig::dim_model> final_norm_scale;
+  std::array<float, TConfig::kModelDim> final_norm_scale;
 };
 // Only called if cached loading fails.
@@ -109,7 +106,7 @@ hwy::AlignedUniquePtr<Weights<TConfig>> LoadWeights(const Path& checkpoint) {
   using TWeights = Weights<TConfig>;
   hwy::AlignedUniquePtr<TWeights> weights = hwy::MakeUniqueAligned<TWeights>();
   weights->layers =
-      hwy::MakeUniqueAlignedArray<Layer<TConfig>>(TConfig::n_layers);
+      hwy::MakeUniqueAlignedArray<Layer<TConfig>>(TConfig::kLayers);
   FILE* fptr;
   fptr = fopen(checkpoint.path.c_str(), "rb");
@@ -122,7 +119,7 @@ hwy::AlignedUniquePtr<Weights<TConfig>> LoadWeights(const Path& checkpoint) {
                    sizeof(weights->embedder_input_embedding), 1, fptr);
   ok &= 1 == fread(&(weights->final_norm_scale),
                    sizeof(weights->final_norm_scale), 1, fptr);
-  for (size_t layer = 0; layer < TConfig::n_layers; ++layer) {
+  for (size_t layer = 0; layer < TConfig::kLayers; ++layer) {
     Layer<TConfig>* layer_view = &weights->layers[layer];
     ok &= 1 == fread(&layer_view->attn_vec_einsum_w,
                      sizeof(layer_view->attn_vec_einsum_w), 1, fptr);
@@ -151,19 +148,17 @@ struct CompressedLayer {
   using TLayer = gcpp::Layer<TConfig>;
-  // # NOLINTBEGIN(google3-readability-class-member-naming)
-  static constexpr size_t dim_model = TConfig::dim_model;
-  static constexpr size_t dim_ffw_hidden = TConfig::dim_ffw_hidden;
-  // NOLINTEND(google3-readability-class-member-naming)
+  static constexpr size_t kModelDim = TConfig::kModelDim;
+  static constexpr size_t kFFHiddenDim = TConfig::kFFHiddenDim;
   // Compressed Parameters
   // We don't yet have an RMSNorm that accepts all WeightT.
-  CompressedArray<hwy::bfloat16_t, dim_model> c_pre_attention_norm_scale;
-  CompressedArray<hwy::bfloat16_t, dim_model> c_pre_ffw_norm_scale;
-  CompressedArray<WeightT, TLayer::size_gating_einsum_w> c_gating_einsum_w;
-  CompressedArray<WeightT, dim_model * dim_ffw_hidden> c_linear_w;
-  CompressedArray<WeightT, TLayer::size_qkv_einsum_w> c_qkv_einsum_w;
-  CompressedArray<WeightT, TLayer::size_attn_vec_einsum_w> c_attn_vec_einsum_w;
+  CompressedArray<hwy::bfloat16_t, kModelDim> c_pre_attention_norm_scale;
+  CompressedArray<hwy::bfloat16_t, kModelDim> c_pre_ffw_norm_scale;
+  CompressedArray<WeightT, TLayer::kGatingEinsumWSize> c_gating_einsum_w;
+  CompressedArray<WeightT, kModelDim * kFFHiddenDim> c_linear_w;
+  CompressedArray<WeightT, TLayer::kQKVEinsumWSize> c_qkv_einsum_w;
+  CompressedArray<WeightT, TLayer::kAttVecEinsumWSize> c_attn_vec_einsum_w;
 };
 // Array instead of single large allocation for parallel mem init. Split out of
@@ -172,23 +167,23 @@ struct CompressedLayer {
 template <class TConfig>
 struct CompressedLayerPointers {
   explicit CompressedLayerPointers(hwy::ThreadPool& pool) {
-    pool.Run(0, TConfig::n_layers, [this](uint64_t task, size_t /*thread*/) {
+    pool.Run(0, TConfig::kLayers, [this](uint64_t task, size_t /*thread*/) {
       this->c_layers[task] = hwy::AllocateAligned<CompressedLayer<TConfig>>(1);
     });
   }
   using CLayer = CompressedLayer<TConfig>;
-  std::array<hwy::AlignedFreeUniquePtr<CLayer[]>, TConfig::n_layers> c_layers;
+  std::array<hwy::AlignedFreeUniquePtr<CLayer[]>, TConfig::kLayers> c_layers;
 };
 template <class TConfig>
 struct CompressedWeights {
   // No ctor/dtor, allocated via AllocateAligned.
-  CompressedArray<EmbedderInputT, TConfig::vocab_size * TConfig::dim_model>
+  CompressedArray<EmbedderInputT, TConfig::kVocabSize * TConfig::kModelDim>
       c_embedder_input_embedding;
-  CompressedArray<hwy::bfloat16_t, TConfig::dim_model> c_final_norm_scale;
+  CompressedArray<hwy::bfloat16_t, TConfig::kModelDim> c_final_norm_scale;
   // Must be last so that the other arrays remain aligned.
   CompressedLayerPointers<TConfig> c_layer_ptrs;
@@ -202,38 +197,35 @@ struct CompressedWeights {
 };
 // Aligned.
-template <class TConfig, size_t BatchSize>
+template <class TConfig, size_t TBatchSize>
 struct Activations {
-  // # NOLINTBEGIN(google3-readability-class-member-naming)
-  static constexpr size_t batch_size = BatchSize;
+  static constexpr size_t kBatchSize = TBatchSize;
   using LayerConfig = Layer<TConfig>;
-  static constexpr size_t dim_model = TConfig::dim_model;
-  static constexpr size_t dim_qkv = TConfig::dim_qkv;
-  static constexpr size_t n_heads = TConfig::n_heads;
-  static constexpr size_t n_kv_heads = TConfig::n_kv_heads;
-  static constexpr size_t size_cache_pos =
-      TConfig::n_layers * n_kv_heads * dim_qkv;
-  static constexpr size_t size_cache_layer = n_kv_heads * dim_qkv;
-  // NOLINTEND(google3-readability-class-member-naming)
-  std::array<float, batch_size * dim_model> x;  // input
-  std::array<float, batch_size * dim_model> pre_att_rms_out;
-  std::array<float, batch_size * n_heads * dim_qkv> q;  // query vector
-  std::array<float, batch_size * n_heads * TConfig::seq_len>
+  static constexpr size_t kModelDim = TConfig::kModelDim;
+  static constexpr size_t kQKVDim = TConfig::kQKVDim;
+  static constexpr size_t kHeads = TConfig::kHeads;
+  static constexpr size_t kKVHeads = TConfig::kKVHeads;
+  static constexpr size_t kCachePosSize =
+      TConfig::kLayers * kKVHeads * kQKVDim;
+  static constexpr size_t kCacheLayerSize = kKVHeads * kQKVDim;
+  std::array<float, kBatchSize * kModelDim> x;  // input
+  std::array<float, kBatchSize * kModelDim> pre_att_rms_out;
+  std::array<float, kBatchSize * kHeads * kQKVDim> q;  // query vector
+  std::array<float, kBatchSize * kHeads * TConfig::kSeqLen>
       att;  // attention vector
-  std::array<float, batch_size * n_heads * dim_qkv>
-      att_out;  // attention output
-  std::array<float, n_heads * batch_size * dim_model>
+  std::array<float, kBatchSize * kHeads * kQKVDim> att_out;  // attention output
+  std::array<float, kHeads * kBatchSize * kModelDim>
       att_post1;  // attention output after linear transformation, per head
-  std::array<float, batch_size * dim_model>
+  std::array<float, kBatchSize * kModelDim>
       att_post2;  // accumulation of attention outputs over heads
-  std::array<hwy::bfloat16_t, batch_size * dim_model> bf_pre_ffw_rms_out;
-  std::array<float, batch_size * TConfig::dim_ffw_hidden * 2> ffw_hidden;
+  std::array<hwy::bfloat16_t, kBatchSize * kModelDim> bf_pre_ffw_rms_out;
+  std::array<float, kBatchSize * TConfig::kFFHiddenDim * 2> ffw_hidden;
   // bf_ version can't be used until GeluMulToBF16 issue in FFW() is resolved.
-  // std::array<hwy::bfloat16_t, batch_size * 2 * TConfig::dim_ffw_hidden>
+  // std::array<hwy::bfloat16_t, kBatchSize * 2 * TConfig::kFFHiddenDim>
   //   bf_ffw_hidden;
-  std::array<float, batch_size * dim_model> ffw_out;
-  std::array<float, batch_size * TConfig::vocab_size> logits;
+  std::array<float, kBatchSize * kModelDim> ffw_out;
+  std::array<float, kBatchSize * TConfig::kVocabSize> logits;
 };
 // GemmaImpl is a template and thus cannot be exposed in gemma.h, hence we
@@ -288,45 +280,45 @@ HWY_BEFORE_NAMESPACE();
 namespace gcpp {
 namespace HWY_NAMESPACE {
-template <class TConfig, size_t batch_size>
+template <class TConfig, size_t kBatchSize>
 HWY_NOINLINE void Attention(size_t batch_start, size_t batch_idx, size_t layer,
-                            Activations<TConfig, batch_size>& activations,
+                            Activations<TConfig, kBatchSize>& activations,
                             const CompressedLayer<TConfig>* c_layer,
                             KVCache& kv_cache, hwy::ThreadPool& pool) {
   PROFILER_ZONE("Gen.Attention");
   const size_t pos = batch_start + batch_idx;
-  HWY_DASSERT(batch_idx < batch_size);
-  static constexpr size_t dim_qkv = gcpp::Activations<TConfig, 1>::dim_qkv;
-  static constexpr size_t size_cache_pos =
-      gcpp::Activations<TConfig, batch_size>::size_cache_pos;
-  static constexpr size_t size_cache_layer =
-      gcpp::Activations<TConfig, batch_size>::size_cache_layer;
-  static constexpr size_t dim_model =
-      gcpp::Activations<TConfig, batch_size>::dim_model;
-  static constexpr size_t n_heads = TConfig::n_heads;
-  const float kQueryScale = 1.0 / sqrtf(static_cast<float>(dim_qkv));
-  pool.Run(0, n_heads, [&](const uint64_t head, size_t /*thread*/) HWY_ATTR {
+  HWY_DASSERT(batch_idx < kBatchSize);
+  static constexpr size_t kQKVDim = gcpp::Activations<TConfig, 1>::kQKVDim;
+  static constexpr size_t kCachePosSize =
+      gcpp::Activations<TConfig, kBatchSize>::kCachePosSize;
+  static constexpr size_t kCacheLayerSize =
+      gcpp::Activations<TConfig, kBatchSize>::kCacheLayerSize;
+  static constexpr size_t kModelDim =
+      gcpp::Activations<TConfig, kBatchSize>::kModelDim;
+  static constexpr size_t kHeads = TConfig::kHeads;
+  const float kQueryScale = 1.0 / sqrtf(static_cast<float>(kQKVDim));
+  pool.Run(0, kHeads, [&](const uint64_t head, size_t /*thread*/) HWY_ATTR {
     // linear projections to QKV
     const size_t head_offset =
-        3 * dim_qkv * dim_model;  // 3x for QKV dimensions
-    const size_t q_offset = head * head_offset + 0 * dim_qkv * dim_model;
-    const size_t k_offset = head * head_offset + 1 * dim_qkv * dim_model;
-    const size_t v_offset = head * head_offset + 2 * dim_qkv * dim_model;
+        3 * kQKVDim * kModelDim;  // 3x for QKV dimensions
+    const size_t q_offset = head * head_offset + 0 * kQKVDim * kModelDim;
+    const size_t k_offset = head * head_offset + 1 * kQKVDim * kModelDim;
+    const size_t v_offset = head * head_offset + 2 * kQKVDim * kModelDim;
     float* HWY_RESTRICT q =
-        activations.q.data() + head * dim_qkv + batch_idx * n_heads * dim_qkv;
-    const size_t batch_offset = batch_idx * dim_model;
-    MatVecLoop<dim_qkv, dim_model>(
+        activations.q.data() + head * kQKVDim + batch_idx * kHeads * kQKVDim;
+    const size_t batch_offset = batch_idx * kModelDim;
+    MatVecLoop<kQKVDim, kModelDim>(
         c_layer->c_qkv_einsum_w, q_offset,
         activations.pre_att_rms_out.data() + batch_offset, q);
     const size_t kv_offset =
-        pos * size_cache_pos + layer * size_cache_layer + head * dim_qkv;
-    TwoOfsMatVecLoop<dim_qkv, dim_model>(
+        pos * kCachePosSize + layer * kCacheLayerSize + head * kQKVDim;
+    TwoOfsMatVecLoop<kQKVDim, kModelDim>(
         c_layer->c_qkv_einsum_w, k_offset, v_offset,
         activations.pre_att_rms_out.data() + batch_offset,
         kv_cache.key_cache.get() + kv_offset,
@@ -334,119 +326,119 @@ HWY_NOINLINE void Attention(size_t batch_start, size_t batch_idx, size_t layer,
     // Calculate scores
     float* HWY_RESTRICT head_att = activations.att.data() +
-                                   head * TConfig::seq_len +
-                                   batch_idx * n_heads * dim_qkv;
-    Rope(q, dim_qkv, pos);
-    Rope(kv_cache.key_cache.get() + kv_offset, dim_qkv, pos);
-    MulByConst(kQueryScale, q, dim_qkv);
+                                   head * TConfig::kSeqLen +
+                                   batch_idx * kHeads * kQKVDim;
+    Rope(q, kQKVDim, pos);
+    Rope(kv_cache.key_cache.get() + kv_offset, kQKVDim, pos);
+    MulByConst(kQueryScale, q, kQKVDim);
     // Compute Q dot K scores
     for (size_t pos2 = 0; pos2 <= pos; ++pos2) {
       const size_t cache_offset =
-          pos2 * size_cache_pos + layer * size_cache_layer + head * dim_qkv;
+          pos2 * kCachePosSize + layer * kCacheLayerSize + head * kQKVDim;
       const float* HWY_RESTRICT k2 = kv_cache.key_cache.get() + cache_offset;
-      const float score = Dot(q, k2, dim_qkv);
+      const float score = Dot(q, k2, kQKVDim);
       head_att[pos2] = score;
     }
     Softmax(head_att, pos + 1);
     // Weighted summation
-    float* HWY_RESTRICT att_out = activations.att_out.data() + head * dim_qkv +
-                                  batch_idx * n_heads * dim_qkv;
-    hwy::ZeroBytes(att_out, dim_qkv * sizeof(*att_out));
+    float* HWY_RESTRICT att_out = activations.att_out.data() + head * kQKVDim +
+                                  batch_idx * kHeads * kQKVDim;
+    hwy::ZeroBytes(att_out, kQKVDim * sizeof(*att_out));
     for (size_t pos2 = 0; pos2 <= pos; ++pos2) {
       const size_t cache_offset =
-          pos2 * size_cache_pos + layer * size_cache_layer + head * dim_qkv;
+          pos2 * kCachePosSize + layer * kCacheLayerSize + head * kQKVDim;
       float* HWY_RESTRICT v2 = kv_cache.value_cache.get() + cache_offset;
-      MulByConstAndAdd(head_att[pos2], v2, att_out, dim_qkv);
+      MulByConstAndAdd(head_att[pos2], v2, att_out, kQKVDim);
     }
-    // linear projection from dim_qkv back to dim_model, sum projections
+    // linear projection from kQKVDim back to kModelDim, sum projections
     // across heads
     float* HWY_RESTRICT head_out =
         head == 0
-            ? activations.att_post2.data() + batch_idx * dim_model
-            : activations.att_post1.data() + head * batch_size * dim_model;
-    MatVecLoop<dim_model, dim_qkv>(c_layer->c_attn_vec_einsum_w,
-                                   head * dim_model * dim_qkv, att_out,
+            ? activations.att_post2.data() + batch_idx * kModelDim
+            : activations.att_post1.data() + head * kBatchSize * kModelDim;
+    MatVecLoop<kModelDim, kQKVDim>(c_layer->c_attn_vec_einsum_w,
+                                   head * kModelDim * kQKVDim, att_out,
                                    head_out);
   });
   // accumulate output across all heads into att_post2. head 0 already wrote
   // directly to att_post2.
-  for (size_t head = 1; head < n_heads; ++head) {
-    AddFrom(activations.att_post1.data() + head * batch_size * dim_model,
-            activations.att_post2.data() + batch_idx * dim_model, dim_model);
+  for (size_t head = 1; head < kHeads; ++head) {
+    AddFrom(activations.att_post1.data() + head * kBatchSize * kModelDim,
+            activations.att_post2.data() + batch_idx * kModelDim, kModelDim);
   }
 }
-template <typename TConfig, size_t batch_size>
-HWY_NOINLINE void FFW(Activations<TConfig, batch_size>& activations,
+template <typename TConfig, size_t kBatchSize>
+HWY_NOINLINE void FFW(Activations<TConfig, kBatchSize>& activations,
                       size_t batch_idx, const CompressedLayer<TConfig>* c_layer,
                       hwy::ThreadPool& pool) {
-  HWY_DASSERT(batch_idx < batch_size);
-  static constexpr size_t dim_model = TConfig::dim_model;
-  static constexpr size_t dim_ffw_hidden = TConfig::dim_ffw_hidden;
-  const size_t hidden_offset = batch_idx * dim_ffw_hidden * 2;
+  HWY_DASSERT(batch_idx < kBatchSize);
+  static constexpr size_t kModelDim = TConfig::kModelDim;
+  static constexpr size_t kFFHiddenDim = TConfig::kFFHiddenDim;
+  const size_t hidden_offset = batch_idx * kFFHiddenDim * 2;
   {
     PROFILER_ZONE("Gen.FFW.GatedGELU");
     const hwy::bfloat16_t* HWY_RESTRICT vec =
-        activations.bf_pre_ffw_rms_out.data() + batch_idx * dim_model;
+        activations.bf_pre_ffw_rms_out.data() + batch_idx * kModelDim;
     float* HWY_RESTRICT out = activations.ffw_hidden.data() + hidden_offset;
-    float* HWY_RESTRICT out_mul = out + dim_ffw_hidden;
+    float* HWY_RESTRICT out_mul = out + kFFHiddenDim;
     // Same matrix, first and second half of rows. Could fuse into one MatVec,
     // but separating them could help on NUMA e.g. multiple sockets.
-    MatVec<dim_ffw_hidden, dim_model>(c_layer->c_gating_einsum_w,
-                                      dim_ffw_hidden * dim_model, vec, out_mul,
+    MatVec<kFFHiddenDim, kModelDim>(c_layer->c_gating_einsum_w,
+                                    kFFHiddenDim * kModelDim, vec, out_mul,
                                     pool);
     // Gate, will go through the nonlinearity.
-    MatVec<dim_ffw_hidden, dim_model>(c_layer->c_gating_einsum_w, 0, vec, out,
+    MatVec<kFFHiddenDim, kModelDim>(c_layer->c_gating_einsum_w, 0, vec, out,
                                     pool);
     namespace hn = hwy::HWY_NAMESPACE;
     using DF = hn::ScalableTag<float>;
     using VF = hn::Vec<DF>;
-    hn::Transform1(DF(), out, dim_ffw_hidden, out_mul,
+    hn::Transform1(DF(), out, kFFHiddenDim, out_mul,
                    [](DF df, VF v, VF mul)
                        HWY_ATTR { return hn::Mul(mul, Gelu(df, v)); });
   }
   PROFILER_ZONE("Gen.FFW\\GatedGELU");
-  MatVec<dim_model, dim_ffw_hidden>(
+  MatVec<kModelDim, kFFHiddenDim>(
       c_layer->c_linear_w, 0, activations.ffw_hidden.data() + hidden_offset,
-      activations.ffw_out.data() + batch_idx * dim_model, pool);
+      activations.ffw_out.data() + batch_idx * kModelDim, pool);
 }
-template <typename TConfig, size_t batch_size>
+template <typename TConfig, size_t kBatchSize>
 HWY_NOINLINE void Prefill(const int* tokens, size_t num_tokens, size_t pos,
                           const CompressedWeights<TConfig>& c_weights,
-                          Activations<TConfig, batch_size>& activations,
+                          Activations<TConfig, kBatchSize>& activations,
                           KVCache& kv_cache, hwy::ThreadPool& pool,
                           hwy::ThreadPool& inner_pool) {
   PROFILER_ZONE("Gen.Prefill\\Att\\FFW");
-  static constexpr size_t dim_model = TConfig::dim_model;
-  static const float kEmbScaling = sqrtf(static_cast<float>(dim_model));
+  static constexpr size_t kModelDim = TConfig::kModelDim;
+  static const float kEmbScaling = sqrtf(static_cast<float>(kModelDim));
   pool.Run(
       0, num_tokens, [&](const uint64_t token_idx, size_t /*thread*/) HWY_ATTR {
         const int token = tokens[token_idx];
-        Decompress(c_weights.c_embedder_input_embedding, token * dim_model,
-                   activations.x.data() + token_idx * dim_model, dim_model);
-        MulByConst(kEmbScaling, activations.x.data() + token_idx * dim_model,
-                   dim_model);
+        Decompress(c_weights.c_embedder_input_embedding, token * kModelDim,
+                   activations.x.data() + token_idx * kModelDim, kModelDim);
+        MulByConst(kEmbScaling, activations.x.data() + token_idx * kModelDim,
+                   kModelDim);
       });
-  for (size_t layer = 0; layer < TConfig::n_layers; ++layer) {
+  for (size_t layer = 0; layer < TConfig::kLayers; ++layer) {
     const CompressedLayer<TConfig>* c_layer = c_weights.CLayer(layer);
     for (size_t token_idx = 0; token_idx < num_tokens; ++token_idx) {
-      RMSNorm(activations.x.data() + token_idx * dim_model,
+      RMSNorm(activations.x.data() + token_idx * kModelDim,
              c_layer->c_pre_attention_norm_scale.data(),
-              activations.pre_att_rms_out.data() + token_idx * dim_model,
-              dim_model);
-      Attention<TConfig, batch_size>(pos, token_idx, layer, activations,
+              activations.pre_att_rms_out.data() + token_idx * kModelDim,
+              kModelDim);
+      Attention<TConfig, kBatchSize>(pos, token_idx, layer, activations,
                                      c_layer, kv_cache, pool);
     }
@@ -454,22 +446,22 @@ HWY_NOINLINE void Prefill(const int* tokens, size_t num_tokens, size_t pos,
     pool.Run(
         0, num_tokens,
         [&](const uint64_t token_idx, size_t thread_id) HWY_ATTR {
-          AddFrom(activations.att_post2.data() + token_idx * dim_model,
-                  activations.x.data() + token_idx * dim_model, dim_model);
-          RMSNorm(activations.x.data() + token_idx * dim_model,
+          AddFrom(activations.att_post2.data() + token_idx * kModelDim,
+                  activations.x.data() + token_idx * kModelDim, kModelDim);
+          RMSNorm(activations.x.data() + token_idx * kModelDim,
                   c_layer->c_pre_ffw_norm_scale.data(),
-                  activations.bf_pre_ffw_rms_out.data() + token_idx * dim_model,
-                  dim_model);
-          FFW<TConfig, batch_size>(activations, token_idx, c_layer, inner_pool);
-          AddFrom(activations.ffw_out.data() + token_idx * dim_model,
-                  activations.x.data() + token_idx * dim_model, dim_model);
+                  activations.bf_pre_ffw_rms_out.data() + token_idx * kModelDim,
+                  kModelDim);
+          FFW<TConfig, kBatchSize>(activations, token_idx, c_layer, inner_pool);
+          AddFrom(activations.ffw_out.data() + token_idx * kModelDim,
+                  activations.x.data() + token_idx * kModelDim, kModelDim);
         });
   }  // foreach layer
   pool.Run(
       0, num_tokens, [&](const uint64_t token_idx, size_t /*thread*/) HWY_ATTR {
         RMSNormInplace(c_weights.c_final_norm_scale.data(),
-                       activations.x.data() + token_idx * dim_model, dim_model);
+                       activations.x.data() + token_idx * kModelDim, kModelDim);
       });
 }
@@ -479,29 +471,29 @@ void Transformer(int token, size_t pos,
                  const CompressedWeights<TConfig>& c_weights,
                  Activations<TConfig, 1>& activations, KVCache& kv_cache,
                  hwy::ThreadPool& pool, hwy::ThreadPool& inner_pool) {
-  static constexpr size_t n_layers = TConfig::n_layers;
-  static constexpr size_t dim_model = TConfig::dim_model;
-  static const float kEmbScaling = sqrtf(static_cast<float>(dim_model));
-  Decompress(c_weights.c_embedder_input_embedding, token * dim_model,
-             activations.x.data(), dim_model);
-  MulByConst(kEmbScaling, activations.x.data(), dim_model);
-  for (size_t layer = 0; layer < n_layers; ++layer) {
+  static constexpr size_t kLayers = TConfig::kLayers;
+  static constexpr size_t kModelDim = TConfig::kModelDim;
+  static const float kEmbScaling = sqrtf(static_cast<float>(kModelDim));
+  Decompress(c_weights.c_embedder_input_embedding, token * kModelDim,
+             activations.x.data(), kModelDim);
+  MulByConst(kEmbScaling, activations.x.data(), kModelDim);
+  for (size_t layer = 0; layer < kLayers; ++layer) {
     const CompressedLayer<TConfig>* c_layer = c_weights.CLayer(layer);
     RMSNorm(activations.x.data(), c_layer->c_pre_attention_norm_scale.data(),
-            activations.pre_att_rms_out.data(), dim_model);
+            activations.pre_att_rms_out.data(), kModelDim);
     Attention<TConfig, 1>(pos, 0, layer, activations, c_layer, kv_cache, pool);
-    AddFrom(activations.att_post2.data(), activations.x.data(), dim_model);
+    AddFrom(activations.att_post2.data(), activations.x.data(), kModelDim);
     RMSNorm(activations.x.data(), c_layer->c_pre_ffw_norm_scale.data(),
-            activations.bf_pre_ffw_rms_out.data(), dim_model);
+            activations.bf_pre_ffw_rms_out.data(), kModelDim);
     FFW<TConfig, 1>(activations, /* batch_idx = */ 0, c_layer, pool);
-    AddFrom(activations.ffw_out.data(), activations.x.data(), dim_model);
+    AddFrom(activations.ffw_out.data(), activations.x.data(), kModelDim);
   }
   RMSNormInplace(c_weights.c_final_norm_scale.data(), activations.x.data(),
-                 dim_model);
+                 kModelDim);
 }
 template <class TConfig>
@@ -511,9 +503,9 @@ void GenerateImpl(GemmaImpl<TConfig>& gemma, const InferenceArgs& args,
                   const StreamFunc& stream_token,
                   const AcceptFunc& accept_token, std::mt19937& gen,
                   int verbosity) {
-  static constexpr size_t dim_model = TConfig::dim_model;
-  static constexpr size_t vocab_size = TConfig::vocab_size;
-  static constexpr size_t top_k = TConfig::top_k;
+  static constexpr size_t kModelDim = TConfig::kModelDim;
+  static constexpr size_t kVocabSize = TConfig::kVocabSize;
+  static constexpr size_t kTopK = TConfig::kTopK;
   Activations<TConfig, 1>& activations = *gemma.state.get();
   Activations<TConfig, kPrefillBatchSize>& prefill_activations =
       *gemma.prefill.get();
@@ -587,12 +579,12 @@ void GenerateImpl(GemmaImpl<TConfig>& gemma, const InferenceArgs& args,
     if (pos_offset >= prompt.size()) {
       PROFILER_ZONE("Gen.Embedding");
       // Generation phase
-      MatVec<vocab_size, dim_model>(c_weights.c_embedder_input_embedding, 0,
+      MatVec<kVocabSize, kModelDim>(c_weights.c_embedder_input_embedding, 0,
                                     final_activation, activations.logits.data(),
                                     pool);
       // Barrier: must have all logits so we can subtract max.
-      Softmax(activations.logits.data(), vocab_size);
-      token = SampleTopK<top_k>(activations.logits.data(), vocab_size, gen,
+      Softmax(activations.logits.data(), kVocabSize);
+      token = SampleTopK<kTopK>(activations.logits.data(), kVocabSize, gen,
                                 args.temperature, accept_token);
     }
     if (!stream_token(token, activations.logits[token])) {
@@ -643,7 +635,7 @@ void ForEachTensor(const Weights<TConfig>* weights,
                    c_weights.c_final_norm_scale);
   char name[16];
-  for (size_t layer_idx = 0; layer_idx < TConfig::n_layers; ++layer_idx) {
+  for (size_t layer_idx = 0; layer_idx < TConfig::kLayers; ++layer_idx) {
     Layer<TConfig>* layer = weights ? &weights->layers[layer_idx] : nullptr;
     CompressedLayer<TConfig>* c_layer = c_weights.CLayer(layer_idx);
@@ -729,10 +721,10 @@ HWY_EXPORT(GetCompressedWeightsT);
 HWY_EXPORT(Generate2B);
 HWY_EXPORT(Generate7B);
-KVCache CreateKVCache(size_t size_cache_pos, size_t seq_len) {
+KVCache CreateKVCache(size_t size_cache_pos, size_t kSeqLen) {
   KVCache kv_cache = {};
-  kv_cache.key_cache = hwy::AllocateAligned<float>(seq_len * size_cache_pos);
-  kv_cache.value_cache = hwy::AllocateAligned<float>(seq_len * size_cache_pos);
+  kv_cache.key_cache = hwy::AllocateAligned<float>(kSeqLen * size_cache_pos);
+  kv_cache.value_cache = hwy::AllocateAligned<float>(kSeqLen * size_cache_pos);
   return kv_cache;
 }
@@ -743,8 +735,8 @@ GemmaImpl<Config>::GemmaImpl(const LoaderArgs& args, hwy::ThreadPool& pool)
       prefill(hwy::MakeUniqueAligned<Activations<Config, kPrefillBatchSize>>()),
       state(hwy::MakeUniqueAligned<Activations<Config, 1>>()),
       kv_cache(
-          CreateKVCache(Config::n_layers * Config::n_kv_heads * Config::dim_qkv,
-                        Config::seq_len)) {
+          CreateKVCache(Config::kLayers * Config::kKVHeads * Config::kQKVDim,
+                        Config::kSeqLen)) {
   PROFILER_ZONE("Startup.tokenizer");
   HWY_ASSERT(tokenizer.Load(args.tokenizer.path).ok());


@@ -51,9 +51,9 @@ constexpr bool kSystemPrompt = false;
 struct KVCache {
   hwy::AlignedFreeUniquePtr<float[]>
-      key_cache;  // batch_size * seq_len * n_layers * n_kv_heads * dim_qkv
+      key_cache;  // batch_size * kSeqLen * kLayers * kKVHeads * kQKVDim
   hwy::AlignedFreeUniquePtr<float[]>
-      value_cache;  // batch_size * seq_len * n_layers * n_kv_heads * dim_qkv
+      value_cache;  // batch_size * kSeqLen * kLayers * kKVHeads * kQKVDim
 };
 // Model variants: see configs.h for details.
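
Note: per the comments above, each cache holds batch_size * kSeqLen * kLayers * kKVHeads * kQKVDim floats, matching the CreateKVCache(Config::kLayers * Config::kKVHeads * Config::kQKVDim, Config::kSeqLen) call in gemma.cc. A minimal standalone sketch of that sizing arithmetic, using the ConfigGemma2B constants from configs.h (this helper program is illustrative only and not part of the commit):

#include <cstddef>
#include <cstdio>

// Constants mirrored from configs.h (ConfigGemma2B) for illustration.
constexpr size_t kSeqLen = 7168;
constexpr size_t kLayers = 18;
constexpr size_t kKVHeads = 8;
constexpr size_t kQKVDim = 256;

int main() {
  // size_cache_pos: floats stored per token position (all layers, all KV heads),
  // i.e. the first argument passed to CreateKVCache.
  const size_t size_cache_pos = kLayers * kKVHeads * kQKVDim;
  // Each of key_cache and value_cache holds kSeqLen * size_cache_pos floats.
  const size_t floats_per_cache = kSeqLen * size_cache_pos;
  // Two caches (key + value), sizeof(float) bytes each.
  const double total_mib =
      2.0 * floats_per_cache * sizeof(float) / (1024.0 * 1024.0);
  std::printf("per-position floats: %zu, per-cache floats: %zu, ~%.0f MiB total\n",
              size_cache_pos, floats_per_cache, total_mib);
  return 0;
}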