mirror of https://github.com/google/gemma.cpp.git
Merge branch 'dev' into bugfix/vit_attn
commit ad3002a21c
@@ -469,10 +469,6 @@ cc_library(
     name = "gemma_lib",
     srcs = [
         "gemma/gemma.cc",
-        "gemma/instantiations/bf16.cc",
-        "gemma/instantiations/f32.cc",
-        "gemma/instantiations/nuq.cc",
-        "gemma/instantiations/sfp.cc",
     ],
     hdrs = [
         "gemma/activations.h",

@@ -61,10 +61,6 @@ set(SOURCES
   gemma/gemma-inl.h
   gemma/gemma.cc
   gemma/gemma.h
-  gemma/instantiations/bf16.cc
-  gemma/instantiations/f32.cc
-  gemma/instantiations/nuq.cc
-  gemma/instantiations/sfp.cc
   gemma/kv_cache.cc
   gemma/kv_cache.h
   gemma/model_store.cc

@@ -556,7 +556,7 @@ HWY_INLINE void Decompress2(DRaw d, const PackedSpan<Packed>& packed,
 // Decompresses from any type of `packed`, starting at (any) `packed_ofs`, to
 // (any) `num` elements in `raw`, then appends `[0, hn::Lanes(d))` zeroes as
 // required to round `num` up to one vector, if it is not already. The caller is
-// responsible for scaling `raw` to the original range because `EmbedToken`
+// responsible for scaling `raw` to the original range because `EmbedMMToken`
 // also wants to scale the decompressed elements.
 // `TRaw` can be `float/BF16`, or `double` if `Packed` is `float`.
 template <class DRaw, typename Packed, typename TRaw = hn::TFromD<DRaw>>

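Aside for readers following the hunk above: the contract it documents is that the decompression helper only widens and zero-pads, while scaling stays with the caller (the renamed `EmbedMMToken` applies the scale itself later in this diff). Below is a standalone analogue of that contract — not the actual gemma.cpp API — assuming BF16-packed input and a caller-chosen lane count:

```cpp
#include <cstddef>
#include <cstdint>
#include <cstring>

// Illustrative only: decompress `num` BF16 values starting at `ofs` into
// `out`, then zero-pad `out` up to the next multiple of `lanes` so a
// vectorized caller may safely read one whole final vector.
// Scaling is deliberately left to the caller, mirroring the comment above.
void DecompressAndZeroPadDemo(const uint16_t* packed_bf16, size_t ofs,
                              size_t num, size_t lanes, float* out) {
  for (size_t i = 0; i < num; ++i) {
    // BF16 stores the upper 16 bits of an IEEE-754 float.
    const uint32_t bits = static_cast<uint32_t>(packed_bf16[ofs + i]) << 16;
    std::memcpy(&out[i], &bits, sizeof(float));
  }
  const size_t padded = (num + lanes - 1) / lanes * lanes;
  for (size_t i = num; i < padded; ++i) out[i] = 0.0f;
}
```

The caller then multiplies `out` by its own scale, which is why the comment calls out that `EmbedMMToken` also wants to scale the decompressed elements.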
@@ -86,6 +86,7 @@ struct Activations {
 
   // For MatMul outputs, precompute their row pointers.
   const auto init_row_ptrs = [&](MatPtrT<float>& mat) {
+    if (!mat.HasPtr()) return;
     row_ptrs.push_back(hwy::AllocateAligned<uint8_t*>(mat.Rows()));
     uint8_t** ptrs = row_ptrs.back().get();
     for (size_t r = 0; r < mat.Rows(); ++r) {

@@ -53,30 +53,29 @@
 #include "ops/matvec-inl.h"
 #include "ops/ops-inl.h"
 
-#ifndef GEMMA_TYPE
-#if HWY_IDE
-// Provide a definition so the IDE does not complain.
-#define GEMMA_TYPE float
-#else
-#error "Only include from instantiations/*.cc, which must define GEMMA_TYPE"
-#endif  // HWY_IDE
-#endif  // GEMMA_TYPE
-
 HWY_BEFORE_NAMESPACE();
 namespace gcpp {
 namespace HWY_NAMESPACE {
 
+template <typename TA, typename TC>
+MMPerKey* CallMatMul(const MatPtrT<TA>& A, const MatPtr& B,
+                     const float* HWY_RESTRICT add, MatMulEnv& env,
+                     MatPtrT<TC>& C) {
+  return CallUpcasted(
+      &B, [&](const auto* B_t) { return MatMulStatic(A, *B_t, add, env, C); });
+}
+
 // Different functions use different naming conventions for the number of
 // tokens. Functions that are query-independent, such as RMSNorm*, call the
 // count `num_interleaved`. Functions that are query-dependent, such as
 // `Attention`, use separate `num_tokens` and `num_queries`. `num_tokens` is the
 // number of tokens from one query: 1 for decode, otherwise prefill_tbatch_size.
 
-template <typename T>
-HWY_NOINLINE void GriffinRecurrent(const QueriesPos& queries_pos,
-                                   size_t num_tokens, size_t griffin_layer,
+static HWY_NOINLINE void GriffinRecurrent(const QueriesPos& queries_pos,
+                                          size_t num_tokens,
+                                          size_t griffin_layer,
                                           Activations& activations,
-                                   const LayerWeightsPtrs<T>* layer_weights,
+                                          const LayerWeightsPtrs* layer_weights,
                                           const KVCaches& kv_caches) {
   PROFILER_ZONE("Gen.Griffin");
   hwy::ThreadPool& pool = activations.env->ctx.pools.Pool(0);

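A note on the `CallMatMul` helper introduced above: `B` is passed as a type-erased `MatPtr`, and `CallUpcasted` re-materializes the concrete `MatPtrT<T>` so that `MatMulStatic` is still compiled per element type while call sites stay untyped. This is the pattern that lets the `template <typename T>` parameters (and the per-type instantiations/*.cc shards) disappear throughout the rest of this diff. Below is a minimal, self-contained sketch of the dispatch idea, with hypothetical `Type`/`MatPtr`/`MatPtrT` stand-ins rather than the real gemma.cpp definitions:

```cpp
#include <cstddef>
#include <cstdint>
#include <type_traits>

// Hypothetical stand-ins for a type-erased tensor handle and its typed view.
enum class Type : uint8_t { kF32, kBF16 };

struct MatPtr {
  Type type;
  const void* data;
};

template <typename T>
struct MatPtrT : MatPtr {
  using Element = T;
};

// Calls `func` with a pointer to the typed view that matches `m->type`.
// The lambda body is instantiated once per element type; callers stay generic.
template <class Func>
auto CallUpcasted(const MatPtr* m, const Func& func) {
  if (m->type == Type::kF32) {
    return func(static_cast<const MatPtrT<float>*>(m));
  }
  // BF16 represented here as uint16_t for brevity.
  return func(static_cast<const MatPtrT<uint16_t>*>(m));
}

// Example caller: element size is chosen by the stored type at runtime.
size_t ElementBytes(const MatPtr& m) {
  return CallUpcasted(&m, [](const auto* typed) -> size_t {
    using Typed = std::decay_t<decltype(*typed)>;
    return sizeof(typename Typed::Element);
  });
}

int main() {
  MatPtrT<float> a{};
  a.type = Type::kF32;
  return ElementBytes(a) == sizeof(float) ? 0 : 1;
}
```

In the real code the lambda wraps `MatMulStatic(A, *B_t, add, env, C)`, so only the innermost kernel is duplicated per weight type.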
@@ -101,17 +100,21 @@ HWY_NOINLINE void GriffinRecurrent(const QueriesPos& queries_pos,
   // TODO: MatMul
   HWY_DASSERT(activations.griffin_y.Rows() == activations.griffin_x.Rows());
   HWY_DASSERT(num_interleaved == activations.griffin_y.Rows());
+  CallUpcastedSame(
+      &layer_weights->griffin.linear_x_w, &layer_weights->griffin.linear_y_w,
+      [&](const auto* wx, const auto* wy) {
   for (size_t r = 0; r < num_interleaved; ++r) {
     float* HWY_RESTRICT y = activations.griffin_y.Row(r);
     float* HWY_RESTRICT x = activations.griffin_x.Row(r);
-    TwoMatVecAdd(layer_weights->griffin.linear_x_w,
-                 layer_weights->griffin.linear_y_w, 0, model_dim, model_dim,
+    TwoMatVecAdd(
+        *wx, *wy, 0, model_dim, model_dim,
         activations.pre_att_rms_out.Row(r),
         /*add0=*/layer_weights->griffin.linear_x_biases.PackedScale1(),
         /*add1=*/layer_weights->griffin.linear_y_biases.PackedScale1(),
         /*out0=*/x, /*out1=*/y, pool);
     Gelu(y, model_dim);
   }
+      });
 
   // Conv1D.
   for (size_t interleaved_idx = 0; interleaved_idx < num_interleaved;

@@ -165,14 +168,16 @@ HWY_NOINLINE void GriffinRecurrent(const QueriesPos& queries_pos,
 
   pool.Run(0, heads, [&](const uint64_t head, size_t /*thread*/) HWY_ATTR {
     size_t head_offset = head * kHeadDim;
+    CallUpcasted(&layer_weights->griffin.gate_w, [&](const auto* gate_w) {
     TwoOfsMatVecAddLoop(
-        layer_weights->griffin.gate_w, kMatrixSize * head,
-        kMatrixSize * (heads + head), kHeadDim, kHeadDim, x + head_offset,
+        *gate_w, kMatrixSize * head, kMatrixSize * (heads + head), kHeadDim,
+        kHeadDim, x + head_offset,
        /*add0=*/layer_weights->griffin.gate_biases.PackedScale1() +
            head_offset,
        /*add1=*/layer_weights->griffin.gate_biases.PackedScale1() +
            model_dim + head_offset,
        /*out0=*/gate_x + head_offset, /*out1=*/a + head_offset);
+    });
     Sigmoid(gate_x + head_offset, kHeadDim);
     Sigmoid(a + head_offset, kHeadDim);
     const auto fn_mul = [](D d, hn::Vec<D> x, hn::Vec<D> gate_x)

@@ -206,17 +211,19 @@ HWY_NOINLINE void GriffinRecurrent(const QueriesPos& queries_pos,
 
   // Final linear layer.
   // TODO: MatMul
+  CallUpcasted(
+      &layer_weights->griffin.linear_out_w, [&](const auto* weights_t) {
   for (size_t r = 0; r < num_interleaved; ++r) {
     float* HWY_RESTRICT x = activations.griffin_x.Row(r);
     float* out_ptr = activations.att_sums.Row(r);
-    MatVecAdd(layer_weights->griffin.linear_out_w, 0, model_dim, model_dim, x,
-              layer_weights->griffin.linear_out_biases.PackedScale1(), out_ptr,
-              pool);
+    MatVecAdd(*weights_t, 0, model_dim, model_dim, x,
+              layer_weights->griffin.linear_out_biases.PackedScale1(),
+              out_ptr, pool);
   }
+      });
 }  // GriffinRecurrent
 
 // Wrapper class; holds arguments in member variables to shorten call sites.
-template <typename T>
 class GemmaAttention {
   // The attention window usually starts at 0 unless `pos` is larger than
   // the attention window size, then it is `pos` - window_size + 1.

@@ -257,7 +264,7 @@ class GemmaAttention {
 
     // The original qkv_einsum_w has shape [(heads + kv_heads * 2), qkv_dim,
    // model_dim], which we reshaped to (heads + kv_heads * 2) * qkv_dim rows.
-    MatMulStatic(activations_.pre_att_rms_out, layer_weights_.qkv_einsum_w1,
+    CallMatMul(activations_.pre_att_rms_out, layer_weights_.qkv_einsum_w1,
                /*add=*/nullptr, *activations_.env, activations_.q);
 
     // Set up MatMul row pointers for writing to KV, which consists of

@@ -279,7 +286,7 @@ class GemmaAttention {
                            kv_offset);
       }
       kv_rows.AttachRowPtrs(&activations_.env->storage.OutRow(0));
-      MatMulStatic(activations_.pre_att_rms_out, layer_weights_.qkv_einsum_w2,
+      CallMatMul(activations_.pre_att_rms_out, layer_weights_.qkv_einsum_w2,
                  /*add=*/nullptr, *activations_.env, kv_rows);
 
       // Apply positional encodings for K (and copy KV to cache if MHA).

@@ -299,8 +306,11 @@ class GemmaAttention {
 
           // Apply further processing to K.
           if (layer_weights_.key_norm_scale.HasPtr()) {
-            RMSNormInplace(layer_weights_.key_norm_scale.PackedScale1(),
-                           0, kv, qkv_dim);
+            CallUpcasted(&layer_weights_.key_norm_scale,
+                         [&](const auto* weights_t) {
+                           RMSNormInplace(weights_t->PackedScale1(), 0,
+                                          kv, qkv_dim);
+                         });
           }
           PositionalEncodingQK(kv, pos, layer_, /*mul=*/1.0f);
         });

@@ -367,8 +377,10 @@ class GemmaAttention {
 
       // Apply rope and scaling to Q.
       if (layer_weights_.query_norm_scale.HasPtr()) {
-        RMSNormInplace(layer_weights_.query_norm_scale.PackedScale1(), 0, q,
-                       qkv_dim);
+        CallUpcasted(&layer_weights_.query_norm_scale,
+                     [&](const auto* weights_t) {
+                       RMSNormInplace(weights_t->PackedScale1(), 0, q, qkv_dim);
+                     });
       }
       PositionalEncodingQK(q, pos, layer_, query_scale);
 

@@ -456,7 +468,7 @@ class GemmaAttention {
         layer_weights_.layer_config.softmax_attn_output_biases
             ? layer_weights_.attention_output_biases.PackedScale1()
             : nullptr;
-    MatMulStatic(activations_.att_out, layer_weights_.att_weights, add,
+    CallMatMul(activations_.att_out, layer_weights_.att_weights, add,
                *activations_.env, activations_.att_sums);
   }
 

@@ -467,7 +479,7 @@ class GemmaAttention {
   GemmaAttention(const QueriesPos& queries_pos,
                  const QueriesPos& queries_prefix_end, size_t num_tokens,
                  size_t layer, Activations& activations,
-                 const LayerWeightsPtrs<T>* layer_weights,
+                 const LayerWeightsPtrs* layer_weights,
                  const hwy::Divisor& div_seq_len, const KVCaches& kv_caches)
       : GemmaAttention(queries_pos, &queries_prefix_end, num_tokens, layer,
                        activations, layer_weights, div_seq_len, kv_caches,

@@ -475,7 +487,7 @@ class GemmaAttention {
   // Constructor with default initialization to 0 for queries_prefix_end.
   GemmaAttention(const QueriesPos& queries_pos, size_t num_tokens, size_t layer,
                  Activations& activations,
-                 const LayerWeightsPtrs<T>* layer_weights,
+                 const LayerWeightsPtrs* layer_weights,
                  const hwy::Divisor& div_seq_len, const KVCaches& kv_caches)
       : GemmaAttention(queries_pos, nullptr, num_tokens, layer, activations,
                        layer_weights, div_seq_len, kv_caches,

@@ -487,7 +499,7 @@ class GemmaAttention {
   GemmaAttention(const QueriesPos& queries_pos,
                  const QueriesPos& queries_prefix_end, size_t num_tokens,
                  size_t layer, Activations& activations,
-                 const LayerWeightsPtrs<T>* layer_weights,
+                 const LayerWeightsPtrs* layer_weights,
                  const hwy::Divisor& div_seq_len, const KVCaches& kv_caches,
                  ThreadingContext& ctx)
       : GemmaAttention(queries_pos, &queries_prefix_end, num_tokens, layer,

@@ -507,7 +519,7 @@ class GemmaAttention {
   GemmaAttention(const QueriesPos& queries_pos,
                  const QueriesPos* queries_prefix_end, size_t num_tokens,
                  size_t layer, Activations& activations,
-                 const LayerWeightsPtrs<T>* layer_weights,
+                 const LayerWeightsPtrs* layer_weights,
                  const hwy::Divisor& div_seq_len, const KVCaches& kv_caches,
                  ThreadingContext& ctx)
       : queries_pos_(queries_pos),

@@ -546,20 +558,19 @@ class GemmaAttention {
   const size_t cache_pos_size_ = 0;
 
   Activations& activations_;
-  const LayerWeightsPtrs<T>& layer_weights_;
+  const LayerWeightsPtrs& layer_weights_;
   const hwy::Divisor& div_seq_len_;
   const KVCaches& kv_caches_;
   hwy::ThreadPool& pool_;
 };
 
-template <typename T>
-HWY_NOINLINE void Attention(
+static HWY_NOINLINE void Attention(
     LayerAttentionType type, const QueriesPos& queries_pos,
     const QueriesPos& queries_prefix_end, size_t num_tokens, size_t layer,
-    Activations& activations, const LayerWeightsPtrs<T>* layer_weights,
+    Activations& activations, const LayerWeightsPtrs* layer_weights,
     const hwy::Divisor& div_seq_len, const KVCaches& kv_caches) {
   if (type == LayerAttentionType::kGemma) {
-    GemmaAttention<T>(queries_pos, queries_prefix_end, num_tokens, layer,
+    GemmaAttention(queries_pos, queries_prefix_end, num_tokens, layer,
                       activations, layer_weights, div_seq_len, kv_caches)();
   } else {
     HWY_DASSERT(type == LayerAttentionType::kGriffinRecurrentBlock);

@@ -581,7 +592,6 @@ HWY_NOINLINE void Attention(
 // This results in a much simpler implementation. However, to avoid duplicating
 // code, we should still consider merging the two classes.
 // TODO(keysers): Refactor to share code with GemmaAttention.
-template <typename T>
 class VitAttention {
   // Computes Q, K, V for all heads, stored in activations_.q.
   HWY_NOINLINE void ComputeQKV() {

@@ -589,7 +599,7 @@ class VitAttention {
     auto& qkv = activations_.q;
     HWY_ASSERT(qkv.Rows() == num_tokens_);
     HWY_ASSERT(qkv.Cols() == layer_config_.heads * 3 * layer_config_.qkv_dim);
-    MatMulStatic(activations_.pre_att_rms_out, layer_weights_.vit.qkv_einsum_w,
+    CallMatMul(activations_.pre_att_rms_out, layer_weights_.vit.qkv_einsum_w,
               layer_weights_.vit.qkv_einsum_b.PackedScale1(),
               *activations_.env, qkv);
   }

@@ -631,7 +641,7 @@ class VitAttention {
     });
 
     // this produces C, a (num_tokens_, seq_len) matrix of dot products
-    MatMulStatic(Q, K, nullptr, *activations_.env, C);
+    CallMatMul(Q, K, nullptr, *activations_.env, C);
 
     pool_.Run(0, num_tokens_, [&](uint64_t task, size_t /*thread*/) HWY_ATTR {
       float* HWY_RESTRICT c = C.Row(task);

@@ -697,15 +707,14 @@ class VitAttention {
     // att_weights and att_out are concatenated heads, each of length
     // qkv_dim. Thus the [num_tokens_, layer_config_.model_dim]
    // matmul output is the sum over heads.
-    MatMulStatic(activations_.att_out, layer_weights_.vit.attn_out_w, bias,
+    CallMatMul(activations_.att_out, layer_weights_.vit.attn_out_w, bias,
               *activations_.env, activations_.att_sums);
   }
 
  public:
   VitAttention(size_t num_tokens, size_t layer, Activations& activations,
-               const LayerWeightsPtrs<T>* layer_weights)
+               const LayerWeightsPtrs* layer_weights)
       : num_tokens_(num_tokens),
         layer_(layer),
         activations_(activations),
         layer_weights_(*layer_weights),
         layer_config_(layer_weights->layer_config),

@@ -723,9 +732,8 @@ class VitAttention {
 
  private:
   const size_t num_tokens_;
   const size_t layer_;
   Activations& activations_;
-  const LayerWeightsPtrs<T>& layer_weights_;
+  const LayerWeightsPtrs& layer_weights_;
   const LayerConfig& layer_config_;
   hwy::ThreadPool& pool_;
 };

@@ -775,9 +783,8 @@ void ActivationBatched(ActivationType activation, Mat& c1, const Mat* c2) {
   }
 }
 
-template <typename T>
-HWY_NOINLINE void FFWNoVit(Activations& activations,
-                           const LayerWeightsPtrs<T>* layer_weights) {
+static HWY_NOINLINE void FFWNoVit(Activations& activations,
+                                  const LayerWeightsPtrs* layer_weights) {
   PROFILER_ZONE("Gen.FFW");
   const size_t ffh_hidden_dim = layer_weights->layer_config.ff_hidden_dim;
 

@@ -789,9 +796,9 @@ HWY_NOINLINE void FFWNoVit(Activations& activations,
       add_bias ? layer_weights->ffw_output_biases.PackedScale1() : nullptr;
 
   // Compute the hidden layer activations.
-  MatMulStatic(activations.pre_ffw_rms_out, layer_weights->gating_einsum_w1,
+  CallMatMul(activations.pre_ffw_rms_out, layer_weights->gating_einsum_w1,
             bias1, *activations.env, activations.C1);
-  MatMulStatic(activations.pre_ffw_rms_out, layer_weights->gating_einsum_w2,
+  CallMatMul(activations.pre_ffw_rms_out, layer_weights->gating_einsum_w2,
             bias2, *activations.env, activations.C2);
 
   // Activation (Gelu) and maybe multiply by gate. Store activations in act.

@@ -799,15 +806,14 @@ HWY_NOINLINE void FFWNoVit(Activations& activations,
                     &activations.C2);
 
   // Hidden layer -> output layer.
-  MatMulStatic(activations.C1, layer_weights->linear_w, output_bias,
+  CallMatMul(activations.C1, layer_weights->linear_w, output_bias,
             *activations.env, activations.ffw_out);
 }
 
 // Same as FFWNoVit, but with different layer_weights members and no second
 // gating matrix.
-template <typename T>
-HWY_NOINLINE void FFWVit(Activations& activations,
-                         const LayerWeightsPtrs<T>* layer_weights) {
+static HWY_NOINLINE void FFWVit(Activations& activations,
+                                const LayerWeightsPtrs* layer_weights) {
   PROFILER_ZONE("Gen.FFW.ViT");
 
   const bool add_bias = layer_weights->layer_config.ff_biases;

@@ -817,14 +823,14 @@ HWY_NOINLINE void FFWVit(Activations& activations,
       add_bias ? layer_weights->vit.linear_1_b.PackedScale1() : nullptr;
 
   // Compute the hidden layer activations.
-  MatMulStatic(activations.pre_ffw_rms_out, layer_weights->vit.linear_0_w,
-               bias1, *activations.env, activations.C1);
+  CallMatMul(activations.pre_ffw_rms_out, layer_weights->vit.linear_0_w, bias1,
+             *activations.env, activations.C1);
 
   // Activation (Gelu), store in C1.
   ActivationBatched(layer_weights->layer_config.activation, activations.C1);
 
   // Hidden layer -> output layer.
-  MatMulStatic(activations.C1, layer_weights->vit.linear_1_w, output_bias,
+  CallMatMul(activations.C1, layer_weights->vit.linear_1_w, output_bias,
             *activations.env, activations.ffw_out);
 }
 

@@ -836,67 +842,52 @@ HWY_NOINLINE void FFWVit(Activations& activations,
 // spec) until we run out of image tokens. This allows for a multi-image prompt
 // if -2 locations with appropriate begin/end image tokens are created by the
 // calling application.
-template <typename T>
-HWY_NOINLINE void EmbedMMToken(int token, size_t batch_idx, size_t pos,
-                               size_t pos_in_prompt,
-                               const ModelWeightsPtrs<T>& weights,
-                               MatStorageT<float>& x,
-                               const ImageTokens* image_tokens,
-                               size_t& image_token_position) {
+// Returns new image_token_position.
+static HWY_NOINLINE size_t
+EmbedMMToken(int token, size_t batch_idx, size_t pos, size_t pos_in_prompt,
+             const ModelConfig& model_config, const ModelWeightsPtrs& weights,
+             MatStorageT<float>& x, const ImageTokens* image_tokens = nullptr,
+             size_t image_token_position = 0) {
   // Image tokens just need to be copied.
-  if (weights.weights_config.wrapping == PromptWrapping::GEMMA_VLM &&
+  if (model_config.wrapping == PromptWrapping::GEMMA_VLM &&
       image_tokens != nullptr && token == -2 &&
       image_token_position < image_tokens->Rows()) {
     hwy::CopyBytes(image_tokens->Row(image_token_position), x.Row(batch_idx),
                    x.Cols() * x.ElementBytes());
-    image_token_position++;
-    return;
+    return image_token_position + 1;
   }
 
-  if (weights.weights_config.wrapping == PromptWrapping::PALIGEMMA &&
+  if (model_config.wrapping == PromptWrapping::PALIGEMMA &&
       image_tokens != nullptr && pos_in_prompt < image_tokens->Rows()) {
     hwy::CopyBytes(image_tokens->Row(pos_in_prompt), x.Row(batch_idx),
                    x.Cols() * x.ElementBytes());
-    return;
+    return image_token_position;
   }
 
-  const size_t model_dim = weights.weights_config.model_dim;
+  const size_t model_dim = model_config.model_dim;
   const float emb_scaling = EmbeddingScaling(model_dim);
 
   HWY_DASSERT(token >= 0);
-  HWY_DASSERT(token < static_cast<int>(weights.weights_config.vocab_size));
+  HWY_DASSERT(token < static_cast<int>(model_config.vocab_size));
 
-  const hn::ScalableTag<float> df;
-  // Using `Stride` to compute the offset works for both NUQ (because we use an
-  // offset and NUQ is never padded) and padded, because non-NUQ types are
+  CallUpcasted(&weights.embedder_input_embedding, [&](const auto* weights_t) {
+    // Using `Stride` to compute the offset works for both NUQ (because we use
+    // an offset and NUQ is never padded) and padded, because non-NUQ types are
    // seekable, hence the offset can also skip any padding.
-  const size_t embedding_ofs =
-      token * weights.embedder_input_embedding.Stride();
-  HWY_ASSERT(weights.embedder_input_embedding.Cols() == model_dim);
-  const auto embedding_span = MakeSpan(weights.embedder_input_embedding.Row(0),
-                                       embedding_ofs + model_dim);
+    const size_t embedding_ofs = token * weights_t->Stride();
+    HWY_ASSERT(weights_t->Cols() == model_dim);
+    const auto embedding_span =
+        MakeSpan(weights_t->Row(0), embedding_ofs + model_dim);
+    const hn::ScalableTag<float> df;
    DecompressAndZeroPad(df, embedding_span, embedding_ofs, x.Row(batch_idx),
                         model_dim);
-  MulByConst(emb_scaling * weights.embedder_input_embedding.Scale(),
-             x.Row(batch_idx), model_dim);
-  if (weights.weights_config.absolute_pe) {
+    MulByConst(emb_scaling * weights_t->Scale(), x.Row(batch_idx), model_dim);
+  });
+
+  if (model_config.absolute_pe) {
     AddAbsolutePositionalEmbeddings(x.Row(batch_idx), model_dim, pos);
   }
-}
-
-// `batch_idx` indicates which row of `x` to write to.
-// `pos` is the *token*'s position, not the start of the batch, because this is
-// called for batches of tokens in prefill, but batches of queries in decode.
-// This version of the function doesn't track internal image token position.
-template <typename T>
-HWY_NOINLINE void EmbedToken(int token, size_t batch_idx, size_t pos,
-                             size_t pos_in_prompt,
-                             const ModelWeightsPtrs<T>& weights,
-                             MatStorageT<float>& x,
-                             const ImageTokens* image_tokens) {
-  size_t image_token_position = 0;
-  EmbedMMToken<T>(token, batch_idx, pos, pos_in_prompt, weights, x,
-                  image_tokens, image_token_position);
+  return image_token_position;
 }
 
 template <typename T2, class LayerWeights>

@@ -908,8 +899,8 @@ HWY_NOINLINE void ResidualConnection(const MatPtrT<T2>& other,
   AddFromBatched(other, x);
 }
 
-template <typename WeightT, typename InOutT>
-void PostNorm(PostNormType post_norm, const MatPtrT<WeightT>& weights,
+template <typename InOutT>
+void PostNorm(PostNormType post_norm, const MatPtr& weights,
               MatPtrT<InOutT>& inout) {
   HWY_DASSERT(weights.Rows() == 1);
   if (post_norm == PostNormType::Scale) {

@@ -917,14 +908,11 @@ void PostNorm(PostNormType post_norm, const MatPtrT<WeightT>& weights,
   }
 }
 
-template <typename T>
-HWY_NOINLINE void TransformerLayer(const QueriesPos& queries_pos,
-                                   const QueriesPos& queries_prefix_end,
+static HWY_NOINLINE void TransformerLayer(
+    const QueriesPos& queries_pos, const QueriesPos& queries_prefix_end,
     size_t num_tokens, size_t cache_layer_idx,
-                                   const LayerWeightsPtrs<T>* layer_weights,
-                                   Activations& activations,
-                                   const hwy::Divisor& div_seq_len,
-                                   const KVCaches& kv_caches) {
+    const LayerWeightsPtrs* layer_weights, Activations& activations,
+    const hwy::Divisor& div_seq_len, const KVCaches& kv_caches) {
   auto type = layer_weights->layer_config.type;
 
   RMSNormBatched(activations.x, layer_weights->pre_attention_norm_scale,

@@ -960,9 +948,8 @@ HWY_NOINLINE void TransformerLayer(const QueriesPos& queries_pos,
 // github.com/google-research/big_vision/blob/main/big_vision/models/vit.py
 // TODO(keysers): consider adding a wrapper for both LayerNorm with RMSNorm and
 // try merging this with TransformerLayer.
-template <typename T>
-HWY_NOINLINE void VitTransformerLayer(size_t num_tokens, size_t layer,
-                                      const LayerWeightsPtrs<T>* layer_weights,
+static HWY_NOINLINE void VitTransformerLayer(
+    size_t num_tokens, size_t layer, const LayerWeightsPtrs* layer_weights,
     Activations& activations) {
   const size_t model_dim = activations.weights_config.model_dim;
   auto type = layer_weights->layer_config.type;

@@ -982,7 +969,7 @@ HWY_NOINLINE void VitTransformerLayer(size_t num_tokens, size_t layer,
 
   // y = out["sa"] = nn.MultiHeadDotProductAttention(...)(y, y)
   // y ~ att_sums
-  VitAttention<T>(num_tokens, layer, activations, layer_weights)();
+  VitAttention(num_tokens, layer, activations, layer_weights)();
 
   // x = out["+sa"] = x + y
   AddFromBatched(activations.att_sums, x);

@@ -1005,13 +992,13 @@ HWY_NOINLINE void VitTransformerLayer(size_t num_tokens, size_t layer,
 using QueriesMutablePos = hwy::Span<size_t>;
 
 // Populates KV cache for batches of tokens from one query at a time.
-template <typename T>
-HWY_NOINLINE void Prefill(
+static HWY_NOINLINE void Prefill(
     const QueriesPromptTokens& queries_prompt,
     const QueriesMutablePos& queries_pos, const QueriesPos& queries_prefix_end,
-    const size_t query_idx_start, const ModelWeightsPtrs<T>& weights,
-    Activations& activations, const RuntimeConfig& runtime_config,
-    const hwy::Divisor& div_seq_len, const KVCaches& kv_caches) {
+    const size_t query_idx_start, const ModelConfig& config,
+    const ModelWeightsPtrs& weights, Activations& activations,
+    const RuntimeConfig& runtime_config, const hwy::Divisor& div_seq_len,
+    const KVCaches& kv_caches) {
   PROFILER_ZONE("Gen.Prefill");
   const size_t num_queries = queries_prompt.size();
   HWY_DASSERT(queries_pos.size() == num_queries);

@@ -1072,14 +1059,14 @@ HWY_NOINLINE void Prefill(
         const size_t pos = queries_pos[qi] + ti;
         const size_t pos_in_prompt = tbatch_start + ti;
         const int token = queries_prompt[qi][pos_in_prompt];
-        EmbedMMToken(token, ti, pos, pos_in_prompt, weights, activations.x,
+        image_token_position = EmbedMMToken(
+            token, ti, pos, pos_in_prompt, config, weights, activations.x,
             runtime_config.image_tokens, image_token_position);
       }
 
       // Transformer with one batch of tokens from a single query.
-      for (size_t layer = 0;
-           layer < weights.weights_config.layer_configs.size(); ++layer) {
-        const LayerWeightsPtrs<T>* layer_weights = weights.GetLayer(layer);
+      for (size_t layer = 0; layer < config.layer_configs.size(); ++layer) {
+        const LayerWeightsPtrs* layer_weights = weights.GetLayer(layer);
         TransformerLayer(single_query_pos, single_query_prefix_end, tbatch_size,
                          layer, layer_weights, activations, div_seq_len,
                          single_kv_cache);

@@ -1115,13 +1102,13 @@ HWY_NOINLINE void Prefill(
 
 // Gets the patches of the image and embeds them with the image embedding
 // kernel. The result is stored in activations.x.
-template <typename T>
-HWY_NOINLINE void EmbedImagePatches(const Image& image,
-                                    const ModelWeightsPtrs<T>& weights,
+static HWY_NOINLINE void EmbedImagePatches(const Image& image,
+                                           const ModelConfig& model_config,
+                                           const ModelWeightsPtrs& weights,
                                            Activations& activations) {
-  const size_t model_dim = weights.weights_config.vit_config.model_dim;
-  const size_t patch_width = weights.weights_config.vit_config.patch_width;
-  const size_t seq_len = weights.weights_config.vit_config.seq_len;
+  const size_t model_dim = model_config.vit_config.model_dim;
+  const size_t patch_width = model_config.vit_config.patch_width;
+  const size_t seq_len = model_config.vit_config.seq_len;
   const size_t patch_size = patch_width * patch_width * 3;
   HWY_DASSERT(weights.vit_img_embedding_kernel.Rows() == model_dim);
   HWY_DASSERT(weights.vit_img_embedding_kernel.Cols() == patch_size);

@@ -1138,7 +1125,7 @@ HWY_NOINLINE void EmbedImagePatches(const Image& image,
   // MatStorageT<float> image_patches("patches", Extents2D(kSeqLen,
   // kPatchSize), MatPadding::kPacked);
   // [Get patches]
-  // MatMulStatic(
+  // CallMatMul(
   //     image_patches,
   //     weights.vit_img_embedding_kernel,
   //     weights.vit_img_embedding_bias.PackedScale1(), *activations.env,

@@ -1147,61 +1134,64 @@ HWY_NOINLINE void EmbedImagePatches(const Image& image,
   // A.cols % (2 * hn::Lanes(hn::ScalableTag<MulT>())) == 0
   // which is not the case here. We should relax that requirement on MatMul and
   // then use the above. For now, we rely on MatVecAdd instead.
+  CallUpcasted(&weights.vit_img_embedding_kernel, [&](const auto* embedding_t) {
   for (size_t i = 0; i < seq_len; ++i) {
-    MatVecAdd(weights.vit_img_embedding_kernel, 0, model_dim, patch_size,
-              image_patches[i].get(),
+    MatVecAdd(*embedding_t, 0, model_dim, patch_size, image_patches[i].get(),
              weights.vit_img_embedding_bias.PackedScale1(),
              activations.x.Row(i), activations.env->ctx.pools.Pool(0));
  }
+  });
   // Add position embeddings.
   AddFromBatched(weights.vit_img_pos_embedding, activations.x);
 }
 
 // Prefills the image tokens with the ViT encoder.
-template <typename T>
-HWY_NOINLINE void PrefillVit(const ModelWeightsPtrs<T>& weights,
+static HWY_NOINLINE void PrefillVit(const ModelConfig& model_config,
+                                    const ModelWeightsPtrs& weights,
                              const RuntimeConfig& runtime_config,
-                             const Image& image, ImageTokens& image_tokens,
+                                    const Image& image,
+                                    ImageTokens& image_tokens,
                              Activations& activations) {
   PROFILER_ZONE("Gen.PrefillVit");
-  const size_t num_tokens = weights.weights_config.vit_config.seq_len;
-  const size_t vit_model_dim = weights.weights_config.vit_config.model_dim;
+  const size_t num_tokens = model_config.vit_config.seq_len;
+  const size_t vit_model_dim = model_config.vit_config.model_dim;
   HWY_ASSERT(num_tokens == activations.x.Rows());
   // Embed the image patches.
-  EmbedImagePatches(image, weights, activations);
+  EmbedImagePatches(image, model_config, weights, activations);
   // Go through all layers.
-  for (size_t layer = 0;
-       layer < weights.weights_config.vit_config.layer_configs.size();
+  for (size_t layer = 0; layer < model_config.vit_config.layer_configs.size();
        ++layer) {
-    const LayerWeightsPtrs<T>* layer_weights = weights.VitLayer(layer);
+    const LayerWeightsPtrs* layer_weights = weights.VitLayer(layer);
     VitTransformerLayer(num_tokens, layer, layer_weights, activations);
   }
   // Final Layernorm.
   LayerNormBatched(activations.x, weights.vit_encoder_norm_scale,
                    weights.vit_encoder_norm_bias, activations.x);
 
-  if (weights.weights_config.wrapping == PromptWrapping::GEMMA_VLM) {
+  if (model_config.wrapping == PromptWrapping::GEMMA_VLM) {
     activations.x = AvgPool4x4(activations.x);
 
     // Apply soft embedding norm before input projection.
-    RMSNormInplace(weights.mm_embed_norm.PackedScale1(), 0,
-                   activations.x.Row(0), vit_model_dim);
+    CallUpcasted(&weights.mm_embed_norm, [&](const auto* weights_t) {
+      RMSNormInplace(weights_t->PackedScale1(), 0, activations.x.Row(0),
+                     vit_model_dim);
+    });
   }
 
   // Apply head embedding into image_tokens of size of the LLM kModelDim.
-  MatMulStatic(activations.x, weights.vit_img_head_kernel,
+  CallMatMul(activations.x, weights.vit_img_head_kernel,
             weights.vit_img_head_bias.PackedScale1(), *activations.env,
             image_tokens);
 }
 
 // Generates one token for each query. `queries_token` is the previous token
 // from each query, and `queries_pos` are their position in the sequence.
-template <typename T>
-HWY_NOINLINE void Transformer(
+static HWY_NOINLINE void Transformer(
     const QueriesToken& queries_token, const QueriesMutablePos& queries_pos,
-    const QueriesPos& queries_prefix_end, const ModelWeightsPtrs<T>& weights,
-    Activations& activations, const hwy::Divisor& div_seq_len,
-    const KVCaches& kv_caches, const LayersOutputFunc& layers_output,
+    const QueriesPos& queries_prefix_end, const ModelConfig& config,
+    const ModelWeightsPtrs& weights, Activations& activations,
+    const hwy::Divisor& div_seq_len, const KVCaches& kv_caches,
+    const LayersOutputFunc& layers_output,
     const ActivationsObserverFunc& activations_observer) {
   const size_t num_queries = queries_token.size();
   HWY_DASSERT(queries_pos.size() == num_queries);

@@ -1215,15 +1205,13 @@ HWY_NOINLINE void Transformer(
     }
   }
 
-  size_t image_token_position = 0;
   for (size_t query_idx = 0; query_idx < num_queries; ++query_idx) {
     EmbedMMToken(queries_token[query_idx], query_idx, queries_pos[query_idx],
-                 /*pos_in_prompt=*/0, weights, activations.x,
-                 /*image_tokens=*/nullptr, image_token_position);
+                 /*pos_in_prompt=*/0, config, weights, activations.x);
   }
 
   for (size_t layer = 0; layer < weights.c_layers.size(); ++layer) {
-    const LayerWeightsPtrs<T>* layer_weights = weights.GetLayer(layer);
+    const LayerWeightsPtrs* layer_weights = weights.GetLayer(layer);
     TransformerLayer(queries_pos, queries_prefix_end, /*num_tokens=*/1, layer,
                      layer_weights, activations, div_seq_len, kv_caches);
 

@@ -1302,25 +1290,25 @@ HWY_INLINE SampleFunc ChooseSampleFunc(const RuntimeConfig& runtime_config) {
   };
 }
 
-template <typename T>
 // Runs one decode step for all the queries in the batch. Returns true if all
 // queries are at <end_of_sentence>.
-bool DecodeStepT(const ModelConfig& config, const ModelWeightsPtrs<T>& weights,
+static bool DecodeStepT(const ModelConfig& config,
+                        const ModelWeightsPtrs& weights,
                  const RuntimeConfig& runtime_config,
                  const QueriesPromptTokens& queries_prompt,
                  const size_t query_idx_start, const KVCaches& kv_caches,
                  const QueriesPos& queries_prefix_end,
                  const hwy::Divisor div_seq_len, const size_t vocab_size,
-                 const SampleFunc& sample_token, Activations& activations,
-                 TokenStreamer& token_streamer, std::vector<int>& gen_tokens,
-                 TimingInfo& timing_info,
+                 const SampleFunc& sample_token,
+                 Activations& activations, TokenStreamer& token_streamer,
+                 std::vector<int>& gen_tokens, TimingInfo& timing_info,
                  const QueriesMutablePos& queries_mutable_pos) {
   const size_t num_queries = queries_prompt.size();
   // Decode generates one token per query and increments
   // queries_mutable_pos.
   Transformer(QueriesToken(gen_tokens.data(), num_queries), queries_mutable_pos,
-              queries_prefix_end, weights, activations, div_seq_len, kv_caches,
-              runtime_config.layers_output,
+              queries_prefix_end, config, weights, activations, div_seq_len,
+              kv_caches, runtime_config.layers_output,
               runtime_config.activations_observer);
   // queries_pos are incremented by Transformer.
 

@@ -1329,13 +1317,13 @@ bool DecodeStepT(const ModelConfig& config, const ModelWeightsPtrs<T>& weights,
   {
     PROFILER_ZONE("Gen.EmbeddingMatmul");
     // Compute logits from last layer activations.
-    MatMulStatic(activations.x, weights.embedder_input_embedding,
+    CallMatMul(activations.x, weights.embedder_input_embedding,
              /*add=*/nullptr, *activations.env, activations.logits);
   }
   PROFILER_ZONE("Gen.Softcap+Sample+Stream");
   for (size_t query_idx = 0; query_idx < num_queries; ++query_idx) {
     float* HWY_RESTRICT logits = activations.logits.Row(query_idx);
-    MaybeLogitsSoftCap(weights.weights_config.final_cap, logits, vocab_size);
+    MaybeLogitsSoftCap(config.final_cap, logits, vocab_size);
     const TokenAndProb tp = sample_token(logits, vocab_size);
     timing_info.NotifyGenerated();
 

@@ -1359,15 +1347,16 @@ bool DecodeStepT(const ModelConfig& config, const ModelWeightsPtrs<T>& weights,
 // `StreamFunc` gets the global query index, not relative to the batch.
 //
 // `kv_caches` is for the batch, size must match `queries_prompt`.
-template <typename T>
-void GenerateT(const ModelStore& model, const ModelWeightsPtrs<T>& weights,
-               Activations& activations, const RuntimeConfig& runtime_config,
+static void GenerateT(const ModelConfig& config,
+                      const ModelWeightsPtrs& weights, Activations& activations,
+                      const RuntimeConfig& runtime_config,
                       const QueriesPromptTokens& queries_prompt,
                       const QueriesPos& queries_pos_in,
                       const QueriesPos& queries_prefix_end,
                       const size_t query_idx_start, const KVCaches& kv_caches,
                       TimingInfo& timing_info) {
   HWY_ASSERT(queries_pos_in.size() == kv_caches.size());
 
   // Griffin assumes that the recurrent block cache is zero-initialized.
   for (size_t i = 0; i < kv_caches.size(); ++i) {
     if (queries_pos_in[i] == 0) {

@@ -1384,7 +1373,7 @@ void GenerateT(const ModelStore& model, const ModelWeightsPtrs<T>& weights,
   // Sanity check: prompts should not be empty, nor start with EOS.
   for (size_t query_idx = 0; query_idx < queries_prompt.size(); ++query_idx) {
     const PromptTokens& prompt = queries_prompt[query_idx];
-    HWY_ASSERT(prompt.size() != 0 && prompt[0] != model.Config().eos_id);
+    HWY_ASSERT(prompt.size() != 0 && prompt[0] != config.eos_id);
   }
 
   const size_t num_queries = queries_prompt.size();

@@ -1395,7 +1384,7 @@ void GenerateT(const ModelStore& model, const ModelWeightsPtrs<T>& weights,
   const hwy::Divisor div_seq_len(static_cast<uint32_t>(kv_caches[0].seq_len));
   size_t max_prompt_size = MaxQueryLength(queries_prompt);
   size_t max_generated_tokens = runtime_config.max_generated_tokens;
-  RangeChecks(weights.weights_config, max_generated_tokens, max_prompt_size);
+  RangeChecks(config, max_generated_tokens, max_prompt_size);
   const SampleFunc sample_token = ChooseSampleFunc(runtime_config);
 
   // Prefill stops before min_prompt_size - 1 because the last prompt

@@ -1403,8 +1392,8 @@ void GenerateT(const ModelStore& model, const ModelWeightsPtrs<T>& weights,
   timing_info.prefill_start = hwy::platform::Now();
   // Note that Prefill calls activations.SetBatchSize, so we reset it below.
   Prefill(queries_prompt, queries_mutable_pos, queries_prefix_end,
-          query_idx_start, weights, activations, runtime_config, div_seq_len,
-          kv_caches);
+          query_idx_start, config, weights, activations, runtime_config,
+          div_seq_len, kv_caches);
   // Compute the number of tokens that were prefilled and notify timing_info.
   size_t prefilled_tokens = 0;
   for (size_t qi = 0; qi < num_queries; ++qi) {

@@ -1419,7 +1408,7 @@ void GenerateT(const ModelStore& model, const ModelWeightsPtrs<T>& weights,
   std::vector<int> gen_tokens(num_queries);
 
   // Stream the last prompt token from each query and fill gen_tokens.
-  TokenStreamer token_streamer(runtime_config, model.Config());
+  TokenStreamer token_streamer(runtime_config, config);
   for (size_t query_idx = 0; query_idx < num_queries; ++query_idx) {
     size_t last_token_pos_in_prompt =
         queries_mutable_pos[query_idx] - queries_pos_in[query_idx];

@@ -1430,26 +1419,24 @@ void GenerateT(const ModelStore& model, const ModelWeightsPtrs<T>& weights,
   }
 
   {
-    const size_t vocab_size = model.Config().vocab_size;
+    const size_t vocab_size = config.vocab_size;
     timing_info.generate_start = hwy::platform::Now();
     for (size_t gen = 0; gen < max_generated_tokens; ++gen) {
-      bool all_queries_eos = DecodeStepT<T>(
-          model.Config(), weights, runtime_config, queries_prompt,
-          query_idx_start, kv_caches, queries_prefix_end, div_seq_len,
-          vocab_size, sample_token, activations, token_streamer, gen_tokens,
-          timing_info, queries_mutable_pos);
+      bool all_queries_eos = DecodeStepT(
+          config, weights, runtime_config, queries_prompt, query_idx_start,
+          kv_caches, queries_prefix_end, div_seq_len, vocab_size, sample_token,
+          activations, token_streamer, gen_tokens, timing_info,
+          queries_mutable_pos);
       if (all_queries_eos) break;
     }  // foreach token to generate
     timing_info.NotifyGenerateDone();
   }
 }
 
-template <typename T>
-void GenerateSingleT(const ModelStore& model,
-                     const ModelWeightsPtrs<T>& weights,
-                     const RuntimeConfig& runtime_config,
-                     const PromptTokens& prompt, size_t pos, size_t prefix_end,
-                     KVCache& kv_cache, MatMulEnv* env,
+static HWY_MAYBE_UNUSED void GenerateSingleT(
+    const ModelConfig& config, const ModelWeightsPtrs& weights,
+    const RuntimeConfig& runtime_config, const PromptTokens& prompt, size_t pos,
+    size_t prefix_end, KVCache& kv_cache, MatMulEnv* env,
     TimingInfo& timing_info) {
   constexpr size_t kNumQueries = 1;
   const size_t qbatch_start = 0;

@@ -1457,26 +1444,24 @@ void GenerateSingleT(const ModelStore& model,
   const size_t max_batch_size =
       HWY_MAX(kNumQueries, runtime_config.prefill_tbatch_size);
   // TODO: move into Gemma?
-  Activations activations(model.Config(), max_batch_size, env);
+  Activations activations(config, max_batch_size, env);
 
   const QueriesPromptTokens queries_prompt(&prompt, kNumQueries);
   QueriesPos queries_pos(&pos, kNumQueries);
   const QueriesPos queries_prefix_end(&prefix_end, kNumQueries);
   const KVCaches kv_caches{&kv_cache, kNumQueries};
 
-  GenerateT<T>(model, weights, activations, runtime_config, queries_prompt,
+  GenerateT(config, weights, activations, runtime_config, queries_prompt,
             queries_pos, queries_prefix_end, qbatch_start, kv_caches,
             timing_info);
 }
 
-template <typename T>
-void GenerateBatchT(const ModelStore& model, const ModelWeightsPtrs<T>& weights,
+static HWY_MAYBE_UNUSED void GenerateBatchT(
+    const ModelConfig& config, const ModelWeightsPtrs& weights,
     const RuntimeConfig& runtime_config,
-                    const QueriesPromptTokens& queries_prompt,
-                    const QueriesPos& queries_pos,
-                    const QueriesPos& queries_prefix_end,
-                    const KVCaches& kv_caches, MatMulEnv* env,
-                    TimingInfo& timing_info) {
+    const QueriesPromptTokens& queries_prompt, const QueriesPos& queries_pos,
+    const QueriesPos& queries_prefix_end, const KVCaches& kv_caches,
+    MatMulEnv* env, TimingInfo& timing_info) {
   const size_t num_queries = queries_prompt.size();
   HWY_ASSERT(queries_pos.size() == num_queries);
   HWY_ASSERT(kv_caches.size() >= num_queries);

@@ -1484,7 +1469,7 @@ void GenerateBatchT(const ModelStore& model, const ModelWeightsPtrs<T>& weights,
   const size_t max_batch_size =
       HWY_MAX(max_qbatch_size, runtime_config.prefill_tbatch_size);
 
-  Activations activations(model.Config(), max_batch_size, env);
+  Activations activations(config, max_batch_size, env);
 
   for (size_t qbatch_start = 0; qbatch_start < num_queries;
        qbatch_start += max_qbatch_size) {

@@ -1497,68 +1482,31 @@ void GenerateBatchT(const ModelStore& model, const ModelWeightsPtrs<T>& weights,
     const QueriesPos qbatch_prefix_end(&queries_prefix_end[qbatch_start],
                                        qbatch_size);
     const KVCaches qbatch_kv(&kv_caches[qbatch_start], qbatch_size);
-    GenerateT<T>(model, weights, activations, runtime_config, qbatch_prompts,
+    GenerateT(config, weights, activations, runtime_config, qbatch_prompts,
               qbatch_pos, qbatch_prefix_end, qbatch_start, qbatch_kv,
               timing_info);
   }
 }
 
-template <typename T>
-void GenerateImageTokensT(const ModelStore& model,
-                          const ModelWeightsPtrs<T>& weights,
-                          const RuntimeConfig& runtime_config,
-                          const Image& image, ImageTokens& image_tokens,
-                          MatMulEnv* env) {
-  if (model.Config().vit_config.layer_configs.empty()) {
+static HWY_MAYBE_UNUSED void GenerateImageTokensT(
+    const ModelConfig& config, const ModelWeightsPtrs& weights,
+    const RuntimeConfig& runtime_config, const Image& image,
+    ImageTokens& image_tokens, MatMulEnv* env) {
+  if (config.vit_config.layer_configs.empty()) {
     HWY_ABORT("Model does not support generating image tokens.");
   }
   RuntimeConfig prefill_runtime_config = runtime_config;
-  ModelConfig vit_config = GetVitConfig(model.Config());
+  ModelConfig vit_config = GetVitConfig(config);
   prefill_runtime_config.prefill_tbatch_size =
       vit_config.seq_len / (vit_config.pool_dim * vit_config.pool_dim);
   Activations prefill_activations(vit_config, vit_config.seq_len, env);
   // Weights are for the full PaliGemma model, not just the ViT part.
-  PrefillVit(weights, prefill_runtime_config, image, image_tokens,
+  PrefillVit(config, weights, prefill_runtime_config, image, image_tokens,
              prefill_activations);
 }
 
 // NOLINTNEXTLINE(google-readability-namespace-comments)
 }  // namespace HWY_NAMESPACE
 
 #if HWY_ONCE
-
-// These are extern functions defined by instantiations/*.cc, which include this
-// 'header' after defining `GEMMA_TYPE`.
-void GenerateSingle(  // NOLINT(misc-definitions-in-headers)
-    const ModelStore& model, const ModelWeightsPtrs<GEMMA_TYPE>& weights,
-    const RuntimeConfig& runtime_config, const PromptTokens& prompt, size_t pos,
-    size_t prefix_end, KVCache& kv_cache, MatMulEnv* env,
-    TimingInfo& timing_info) {
-  HWY_EXPORT_AND_DYNAMIC_DISPATCH_T(GenerateSingleT<GEMMA_TYPE>)
-  (model, weights, runtime_config, prompt, pos, prefix_end, kv_cache, env,
-   timing_info);
-}
-
-void GenerateBatch(  // NOLINT(misc-definitions-in-headers)
-    const ModelStore& model, const ModelWeightsPtrs<GEMMA_TYPE>& weights,
-    const RuntimeConfig& runtime_config,
-    const QueriesPromptTokens& queries_prompt, const QueriesPos& queries_pos,
-    const QueriesPos& queries_prefix_end, const KVCaches& kv_caches,
-    MatMulEnv* env, TimingInfo& timing_info) {
-  HWY_EXPORT_AND_DYNAMIC_DISPATCH_T(GenerateBatchT<GEMMA_TYPE>)
-  (model, weights, runtime_config, queries_prompt, queries_pos,
-   queries_prefix_end, kv_caches, env, timing_info);
-}
-
-void GenerateImageTokens(  // NOLINT(misc-definitions-in-headers)
-    const ModelStore& model, const ModelWeightsPtrs<GEMMA_TYPE>& weights,
-    const RuntimeConfig& runtime_config, const Image& image,
-    ImageTokens& image_tokens, MatMulEnv* env) {
-  HWY_EXPORT_AND_DYNAMIC_DISPATCH_T(GenerateImageTokensT<GEMMA_TYPE>)
-  (model, weights, runtime_config, image, image_tokens, env);
-}
-
-#endif  // HWY_ONCE
-
-}  // namespace gcpp
-HWY_AFTER_NAMESPACE();
 

@@ -13,21 +13,33 @@
 // See the License for the specific language governing permissions and
 // limitations under the License.
 
-// Defines Gemma member functions; the actual implementations are in
-// gemma-inl.h, included from instantiations/*.cc.
+// Defines Gemma member functions which dynamic-dispatch into the SIMD
+// implementations in gemma-inl.h.
 
 #include "gemma/gemma.h"
 
+// Compiles this file for multiple architectures via "foreach_target.h", to
+// which we pass the filename via macro 'argument'.
+// clang-format off
+#undef HWY_TARGET_INCLUDE
+#define HWY_TARGET_INCLUDE "gemma/gemma.cc"  // NOLINT
+// clang-format on
+#include "hwy/foreach_target.h"  // IWYU pragma: keep
+#include "hwy/highway.h"
+// After highway.h
+#include "gemma/gemma-inl.h"
+
+#ifndef GEMMA_CC_ONCE
+#define GEMMA_CC_ONCE
+
 #include <stddef.h>
 #include <stdio.h>
 #include <stdlib.h>
 #include <string.h>
 
 #include <memory>
 #include <vector>
 
 // Placeholder for internal header, do not modify.
 #include "compression/types.h"
 #include "gemma/configs.h"
 #include "gemma/model_store.h"
 #include "gemma/tokenizer.h"

@@ -39,7 +51,13 @@
 #include "util/threading_context.h"
 #include "hwy/base.h"
 
+#endif  // GEMMA_CC_ONCE
+
+#if HWY_ONCE
 namespace gcpp {
+HWY_EXPORT(GenerateSingleT);
+HWY_EXPORT(GenerateBatchT);
+HWY_EXPORT(GenerateImageTokensT);
 
 // Internal init must run before I/O. This helper function takes care of that,
 // plus calling `SetArgs`.

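For readers unfamiliar with the Highway idiom the new gemma.cc adopts above: `foreach_target.h` re-includes the translation unit once per enabled SIMD target, `HWY_EXPORT` records each per-target entry point in a dispatch table, and `HWY_DYNAMIC_DISPATCH` selects the best one for the running CPU. A condensed sketch of the same structure for a hypothetical `mylib/scale.cc` (not part of gemma.cpp):

```cpp
// mylib/scale.cc — compiled once per SIMD target via foreach_target.h.
#include <cstddef>

#undef HWY_TARGET_INCLUDE
#define HWY_TARGET_INCLUDE "mylib/scale.cc"
#include "hwy/foreach_target.h"  // IWYU pragma: keep
#include "hwy/highway.h"

HWY_BEFORE_NAMESPACE();
namespace mylib {
namespace HWY_NAMESPACE {
namespace hn = hwy::HWY_NAMESPACE;

// Per-target implementation; recompiled for each enabled target.
void ScaleT(float* HWY_RESTRICT x, size_t n, float mul) {
  const hn::ScalableTag<float> d;
  const auto vmul = hn::Set(d, mul);
  size_t i = 0;
  for (; i + hn::Lanes(d) <= n; i += hn::Lanes(d)) {
    hn::StoreU(hn::Mul(hn::LoadU(d, x + i), vmul), d, x + i);
  }
  for (; i < n; ++i) x[i] *= mul;  // scalar remainder
}

}  // namespace HWY_NAMESPACE
}  // namespace mylib
HWY_AFTER_NAMESPACE();

#if HWY_ONCE
namespace mylib {
HWY_EXPORT(ScaleT);  // table of per-target function pointers

// Public entry point: dispatches to the best target for this CPU.
void Scale(float* x, size_t n, float mul) {
  HWY_DYNAMIC_DISPATCH(ScaleT)(x, n, mul);
}
}  // namespace mylib
#endif  // HWY_ONCE
```

The new gemma.cc follows this shape, with GenerateSingleT/GenerateBatchT/GenerateImageTokensT as the per-target functions and the Gemma member functions below as the dispatching wrappers.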
@@ -52,12 +70,13 @@ MatMulEnv MakeMatMulEnv(const ThreadingArgs& threading_args) {
 
 Gemma::Gemma(const LoaderArgs& loader, MatMulEnv& env)
     : env_(env),
-      reader_(new BlobReader(loader.weights)),
-      model_(*reader_, loader.tokenizer, loader.wrapping),
-      weights_(model_.Config().weight),
+      reader_(loader.weights),
+      model_(reader_, loader.tokenizer, loader.wrapping),
+      weights_(model_.Config()),
       chat_template_(model_.Tokenizer(), model_.Config().model) {
-  weights_.ReadFromBlobs(model_, *reader_, loader.map, env_.ctx.pools.Pool());
-  reader_.reset();
+  weights_.ReadFromBlobs(model_, reader_, loader.map, mat_owners_,
+                         env.ctx.pools.Pool());
+  reader_.CloseFile();
 }
 
 Gemma::~Gemma() = default;

@@ -70,42 +89,14 @@ void Gemma::Save(const Path& weights_path, hwy::ThreadPool& pool) const {
                       writer, env_.ctx.pools.Pool(), weights_path);
 }
 
-// There are >=3 types of the inference code. To reduce compile time,
-// we shard them across multiple translation units in instantiations/*.cc.
-// This declares the functions defined there. We use overloading because
-// explicit instantiations are still too slow to compile.
-// TODO: we want to move toward type-erasing, where we check the tensor type at
-// each usage. Then we would have a single function, passing `WeightsOwner`
-// instead of `WeightsPtrs<T>`.
-#define GEMMA_DECLARE(WEIGHT_TYPE) \
-  extern void GenerateSingle( \
-      const ModelStore& model, const ModelWeightsPtrs<WEIGHT_TYPE>& weights, \
-      const RuntimeConfig& runtime_config, const PromptTokens& prompt, \
-      size_t pos, size_t prefix_end, KVCache& kv_cache, MatMulEnv* env, \
-      TimingInfo& timing_info); \
-  extern void GenerateBatch( \
-      const ModelStore& model, const ModelWeightsPtrs<WEIGHT_TYPE>& weights, \
-      const RuntimeConfig& runtime_config, const QueriesPromptTokens& prompts, \
-      const QueriesPos& queries_pos, const QueriesPos& queries_prefix_end, \
-      const KVCaches& kv_caches, MatMulEnv* env, TimingInfo& timing_info); \
-  extern void GenerateImageTokens( \
-      const ModelStore& model, const ModelWeightsPtrs<WEIGHT_TYPE>& weights, \
-      const RuntimeConfig& runtime_config, const Image& image, \
-      ImageTokens& image_tokens, MatMulEnv* env);
-GEMMA_DECLARE(float)
-GEMMA_DECLARE(BF16)
-GEMMA_DECLARE(NuqStream)
-GEMMA_DECLARE(SfpStream)
-
 void Gemma::Generate(const RuntimeConfig& runtime_config,
                      const PromptTokens& prompt, size_t pos, size_t prefix_end,
                      KVCache& kv_cache, TimingInfo& timing_info) const {
   env_.ctx.pools.MaybeStartSpinning(runtime_config.use_spinning);
 
-  weights_.CallT([&](auto& weights) {
-    GenerateSingle(model_, *weights, runtime_config, prompt, pos, prefix_end,
+  HWY_DYNAMIC_DISPATCH(GenerateSingleT)(model_.Config(), weights_,
+                                        runtime_config, prompt, pos, prefix_end,
                    kv_cache, &env_, timing_info);
-  });
 
   env_.ctx.pools.MaybeStopSpinning(runtime_config.use_spinning);
 }

@@ -127,11 +118,9 @@ void Gemma::GenerateBatch(const RuntimeConfig& runtime_config,
 
   env_.ctx.pools.MaybeStartSpinning(runtime_config.use_spinning);
 
-  weights_.CallT([&](auto& weights) {
-    gcpp::GenerateBatch(model_, *weights, runtime_config, queries_prompt,
-                        queries_pos, mutable_queries_prefix_end, kv_caches,
-                        &env_, timing_info);
-  });
+  HWY_DYNAMIC_DISPATCH(GenerateBatchT)(
+      model_.Config(), weights_, runtime_config, queries_prompt, queries_pos,
+      mutable_queries_prefix_end, kv_caches, &env_, timing_info);
 
   env_.ctx.pools.MaybeStopSpinning(runtime_config.use_spinning);
 }

@@ -141,12 +130,11 @@ void Gemma::GenerateImageTokens(const RuntimeConfig& runtime_config,
                                 ImageTokens& image_tokens) const {
   env_.ctx.pools.MaybeStartSpinning(runtime_config.use_spinning);
 
-  weights_.CallT([&](auto& weights) {
-    gcpp::GenerateImageTokens(model_, *weights, runtime_config, image,
-                              image_tokens, &env_);
-  });
+  HWY_DYNAMIC_DISPATCH(GenerateImageTokensT)(
+      model_.Config(), weights_, runtime_config, image, image_tokens, &env_);
 
   env_.ctx.pools.MaybeStopSpinning(runtime_config.use_spinning);
 }
 
 }  // namespace gcpp
+#endif  // HWY_ONCE

@@ -18,7 +18,7 @@
 
 #include <stdio.h>
 
-#include <memory>
+#include <vector>
 
 // IWYU pragma: begin_exports
 #include "gemma/activations.h"

@@ -113,11 +113,9 @@ class Gemma {
   // TODO: rename to Config()
   const ModelConfig& GetModelConfig() const { return model_.Config(); }
   const GemmaTokenizer& Tokenizer() const { return model_.Tokenizer(); }
-  const WeightsOwner& Weights() const { return weights_; }
+  const ModelWeightsPtrs& Weights() const { return weights_; }
   const GemmaChatTemplate& ChatTemplate() const { return chat_template_; }
 
-  // For tests.
-  WeightsOwner& MutableWeights() { return weights_; }
   void Save(const Path& weights_path, hwy::ThreadPool& pool) const;
 
   // `pos` is the position in the KV cache. Users are responsible for

@@ -154,9 +152,10 @@ class Gemma {
 
  private:
   MatMulEnv& env_;
-  std::unique_ptr<BlobReader> reader_;  // null for second ctor
+  BlobReader reader_;
   ModelStore model_;
-  WeightsOwner weights_;
+  std::vector<MatOwner> mat_owners_;
+  ModelWeightsPtrs weights_;
   GemmaChatTemplate chat_template_;
 };
 

@ -1,21 +0,0 @@
|
|||
// Copyright 2024 Google LLC
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#undef HWY_TARGET_INCLUDE
|
||||
#define HWY_TARGET_INCLUDE "gemma/instantiations/bf16.cc"
|
||||
|
||||
#include "hwy/foreach_target.h" // IWYU pragma: keep
|
||||
#define GEMMA_TYPE hwy::bfloat16_t
|
||||
#include "gemma/gemma-inl.h"
|
||||
|
|
@@ -1,20 +0,0 @@
// Copyright 2024 Google LLC
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#undef HWY_TARGET_INCLUDE
#define HWY_TARGET_INCLUDE "gemma/instantiations/f32.cc"
#include "hwy/foreach_target.h"  // IWYU pragma: keep
#define GEMMA_TYPE float
#include "gemma/gemma-inl.h"
@@ -1,20 +0,0 @@
// Copyright 2024 Google LLC
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#undef HWY_TARGET_INCLUDE
#define HWY_TARGET_INCLUDE "gemma/instantiations/nuq.cc"
#include "hwy/foreach_target.h"  // IWYU pragma: keep
#define GEMMA_TYPE NuqStream
#include "gemma/gemma-inl.h"
@@ -1,20 +0,0 @@
// Copyright 2024 Google LLC
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#undef HWY_TARGET_INCLUDE
#define HWY_TARGET_INCLUDE "gemma/instantiations/sfp.cc"
#include "hwy/foreach_target.h"  // IWYU pragma: keep
#define GEMMA_TYPE SfpStream
#include "gemma/gemma-inl.h"
@@ -372,6 +372,11 @@ bool ModelStore::FindAndUpdateMatPtr(MatPtr& mat, size_t& key_idx) const {
  // `Compress()` output is always packed because it assumes a 1D array.
  HWY_ASSERT(mat.IsPacked());
  // Update fields. Name already matched, otherwise we would not find it.
  // For MatPtr tensors, the type will be `kUnknown`. If it was a `MatPtrT`,
  // ensure the type set via code matches the file.
  HWY_ASSERT_M(
      mat.GetType() == Type::kUnknown || mat.GetType() == file_mat->GetType(),
      mat.Name());
  mat.SetType(file_mat->GetType());
  if (scales_.empty()) {
    // `file_mat->Scale()` is either read from file, or we have pre-2025 format
@@ -1,6 +1,7 @@
#include "gemma/tensor_info.h"

#include <stddef.h>
#include <stdint.h>

#include <string>
@@ -21,13 +21,14 @@ TEST(TensorInfoRegistryTest, Find) {
            config.Specifier().c_str());
    const TensorInfoRegistry tensors(config);
    // Each tensor in the model should be known/found.
    ModelWeightsPtrs<SfpStream> weights(config);
    ModelWeightsPtrs weights(config);
    weights.ForEachTensor(nullptr, nullptr, [&tensors](const TensorArgs& t) {
      const TensorInfo* info = tensors.Find(t.mat.Name());
      HWY_ASSERT_M(info, t.mat.Name());
      // Test that the `MatPtr` can be constructed from the TensorInfo,
      // and that the dimensions match.
      MatPtrT<SfpStream> mat_ptr(t.mat.Name(), tensors);
      const MatPtr mat_ptr(t.mat.Name(), Type::kUnknown,
                           ExtentsFromInfo(tensors.Find(t.mat.Name())));
      EXPECT_STREQ(t.mat.Name(), mat_ptr.Name()) << t.mat.Name();
      EXPECT_EQ(t.mat.Rows(), mat_ptr.Rows()) << t.mat.Name();
      EXPECT_EQ(t.mat.Cols(), mat_ptr.Cols()) << t.mat.Name();
gemma/weights.cc
@@ -20,7 +20,7 @@
#include <stdio.h>
#include <stdlib.h>

#include <memory>
#include <mutex>  // NOLINT
#include <string>
#include <vector>
@@ -42,15 +42,127 @@

namespace gcpp {

static void InitAttWeightsNUQ(const LayerConfig& layer_config,
                              MatPtrT<NuqStream>& attn_vec_einsum_w,
                              MatPtrT<NuqStream>& att_weights,
                              std::vector<MatOwner>& mat_owners) {
// Copies att_weights from `attn_vec_einsum_w`.
void LayerWeightsPtrs::InitAttWeights(std::vector<MatOwner>& mat_owners) {
  // We only use this tensor for Gemma layers.
  if (layer_config.type != LayerAttentionType::kGemma) return;

  // Files must have one or the other.
  HWY_ASSERT(attn_vec_einsum_w.HasPtr() ^ att_weights.HasPtr());
  // Done if we already read the transposed tensor.
  if (att_weights.HasPtr() && !attn_vec_einsum_w.HasPtr()) return;

  // NUQ is handled by a specialization in weights.cc.
  HWY_ASSERT(attn_vec_einsum_w.GetType() != Type::kNUQ);

  const size_t model_dim = layer_config.model_dim;
  const size_t heads = layer_config.heads;
  const size_t qkv_dim = layer_config.qkv_dim;

  // Reshape [heads, model_dim, qkv_dim] to [model_dim, heads * qkv_dim].
  att_weights.SetType(attn_vec_einsum_w.GetType());
  HWY_ASSERT(att_weights.Rows() == model_dim);
  HWY_ASSERT(att_weights.Cols() == heads * qkv_dim);
  HWY_ASSERT(attn_vec_einsum_w.Rows() == heads * model_dim);
  HWY_ASSERT(attn_vec_einsum_w.Cols() == qkv_dim);

  {
    static std::mutex m;
    std::lock_guard<std::mutex> lock(m);
    mat_owners.push_back(MatOwner());
    mat_owners.back().AllocateFor(att_weights, MatPadding::kOdd);
  }

  const size_t T_bytes = att_weights.ElementBytes();
  for (size_t m = 0; m < model_dim; ++m) {
    uint8_t* HWY_RESTRICT out_row = att_weights.RowBytes(m);
    for (size_t h = 0; h < heads; ++h) {
      hwy::CopyBytes(attn_vec_einsum_w.RowBytes(h * model_dim + m),
                     out_row + h * qkv_dim * T_bytes, qkv_dim * T_bytes);
    }
  }
  att_weights.SetScale(attn_vec_einsum_w.Scale());
}
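For readers who find the byte-copy loop in `InitAttWeights` hard to follow, here is a minimal sketch of the same [heads, model_dim, qkv_dim] -> [model_dim, heads * qkv_dim] reshape on a plain contiguous float array. The standalone function name and the use of `std::vector` are illustrative only; the real code operates on (possibly padded) `MatPtr` rows and copies raw bytes so it works for any element type.

// Sketch only (not part of the repository): the same reshape as InitAttWeights,
// written for an unpadded, row-major float array.
#include <cstddef>
#include <vector>

std::vector<float> ReshapeAttWeights(const std::vector<float>& attn_vec_einsum_w,
                                     size_t heads, size_t model_dim,
                                     size_t qkv_dim) {
  std::vector<float> att_weights(model_dim * heads * qkv_dim);
  for (size_t m = 0; m < model_dim; ++m) {
    for (size_t h = 0; h < heads; ++h) {
      for (size_t q = 0; q < qkv_dim; ++q) {
        // Source row (h * model_dim + m), destination row m, column h * qkv_dim + q.
        att_weights[m * heads * qkv_dim + h * qkv_dim + q] =
            attn_vec_einsum_w[(h * model_dim + m) * qkv_dim + q];
      }
    }
  }
  return att_weights;
}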
// For FFN. Fast, only updates pointers.
void LayerWeightsPtrs::SplitW1() {
  // Used for Gemma and Griffin layers; FFWVit uses different tensors.
  if (layer_config.type == LayerAttentionType::kVit) return;

  // Files have both or neither of w1 and w2.
  HWY_ASSERT(gating_einsum_w1.HasPtr() == gating_einsum_w2.HasPtr());
  // w is mutually exclusive with w1 and w2 in the file.
  HWY_ASSERT(gating_einsum_w.HasPtr() ^ gating_einsum_w1.HasPtr());
  // Done if we already read split tensors. Note that they are not
  // necessarily the same type.
  if (gating_einsum_w1.HasPtr() && !gating_einsum_w.HasPtr()) return;

  const size_t ff_hidden_dim = layer_config.ff_hidden_dim;
  HWY_ASSERT(gating_einsum_w.Rows() == 2 * ff_hidden_dim);
  HWY_ASSERT(gating_einsum_w1.Rows() == ff_hidden_dim);
  HWY_ASSERT(gating_einsum_w2.Rows() == ff_hidden_dim);
  // Cols are the model_dim but we don't have ModelConfig here.
  HWY_ASSERT(gating_einsum_w1.Cols() == gating_einsum_w.Cols());
  HWY_ASSERT(gating_einsum_w2.Cols() == gating_einsum_w.Cols());

  const size_t stride = gating_einsum_w.Stride();
  gating_einsum_w1.SetPtr(gating_einsum_w.RowBytes(0), stride);
  gating_einsum_w2.SetPtr(gating_einsum_w.RowBytes(ff_hidden_dim), stride);
  gating_einsum_w1.SetType(gating_einsum_w.GetType());
  gating_einsum_w2.SetType(gating_einsum_w.GetType());
  gating_einsum_w1.SetScale(gating_einsum_w.Scale());
  gating_einsum_w2.SetScale(gating_einsum_w.Scale());
  gating_einsum_w.SetPtr(nullptr, gating_einsum_w.Cols());
}
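`SplitW1` (and `SplitAttW1` below) never copies data: it only points the `w1`/`w2` metadata at row ranges of the stacked tensor and then clears the stacked tensor's pointer. A hedged sketch of that idea, with a hypothetical `RowMajorView` struct standing in for `MatPtr`:

// Sketch only: zero-copy split of a stacked [2 * ff_hidden, cols] matrix into two
// row-range views. RowMajorView is a made-up stand-in for MatPtr; the real code also
// propagates the type and scale via SetType/SetScale.
#include <cstddef>

struct RowMajorView {
  float* data = nullptr;                  // first element of row 0
  size_t rows = 0, cols = 0, stride = 0;  // stride: elements between row starts
};

void SplitStacked(RowMajorView& w, RowMajorView& w1, RowMajorView& w2) {
  const size_t half = w.rows / 2;
  w1 = {w.data, half, w.cols, w.stride};                    // rows [0, half)
  w2 = {w.data + half * w.stride, half, w.cols, w.stride};  // rows [half, 2 * half)
  w = {};  // the stacked view is no longer used, mirroring SetPtr(nullptr, ...)
}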
// For attention, which might not have a w2. Fast, only updates pointers.
void LayerWeightsPtrs::SplitAttW1() {
  // We only use this tensor for Gemma layers.
  if (layer_config.type != LayerAttentionType::kGemma) return;

  // w is mutually exclusive with w1 in the file.
  HWY_ASSERT(qkv_einsum_w.HasPtr() ^ qkv_einsum_w1.HasPtr());
  // Done if we already read split tensors. Note that w2 does not exist for
  // MHA, and otherwise might not be the same type.
  if (qkv_einsum_w1.HasPtr() && !qkv_einsum_w.HasPtr()) return;

  const size_t w1_rows = layer_config.heads * layer_config.qkv_dim;
  const size_t w2_rows = layer_config.kv_heads * 2 * layer_config.qkv_dim;
  HWY_ASSERT(qkv_einsum_w.Rows() == w1_rows + w2_rows);
  HWY_ASSERT(qkv_einsum_w1.Rows() == w1_rows);
  HWY_ASSERT(qkv_einsum_w2.Rows() == w2_rows);
  // Cols are the model_dim but we don't have ModelConfig here.
  HWY_ASSERT(qkv_einsum_w1.Cols() == qkv_einsum_w.Cols());
  HWY_ASSERT(qkv_einsum_w2.Cols() == qkv_einsum_w.Cols());

  const size_t stride = qkv_einsum_w.Stride();
  qkv_einsum_w1.SetPtr(qkv_einsum_w.RowBytes(0), stride);
  qkv_einsum_w2.SetPtr(qkv_einsum_w.RowBytes(w1_rows), stride);
  qkv_einsum_w1.SetType(qkv_einsum_w.GetType());
  qkv_einsum_w2.SetType(qkv_einsum_w.GetType());
  qkv_einsum_w1.SetScale(qkv_einsum_w.Scale());
  qkv_einsum_w2.SetScale(qkv_einsum_w.Scale());
  qkv_einsum_w.SetPtr(nullptr, qkv_einsum_w.Cols());
}

// Must be called after reading weights via `ForEachTensor`.
// TODO: exporters should bake this into the weights already.
// WARNING: called from multiple threads; `mat_owners` requires a lock.
void LayerWeightsPtrs::Fixup(std::vector<MatOwner>& mat_owners) {
  // TODO(janwas): handle NUQ
  InitAttWeights(mat_owners);
  SplitW1();
  SplitAttW1();
}

static void HWY_MAYBE_UNUSED InitAttWeightsNUQ(
    const LayerConfig& layer_config, MatPtrT<NuqStream>& attn_vec_einsum_w,
    MatPtrT<NuqStream>& att_weights, std::vector<MatOwner>& mat_owners) {
  if (!attn_vec_einsum_w.HasPtr()) return;
  HWY_ASSERT(attn_vec_einsum_w.GetType() == Type::kNUQ);

  HWY_ASSERT(att_weights.HasPtr());
  HWY_ASSERT(att_weights.GetType() == Type::kNUQ);
  att_weights.SetType(Type::kNUQ);

  const size_t model_dim = layer_config.model_dim;
  const size_t heads = layer_config.heads;
@@ -85,14 +197,53 @@ static void InitAttWeightsNUQ(const LayerConfig& layer_config,
  att_weights.SetScale(attn_vec_einsum_w.Scale());
}

static void SplitW1NUQ(const LayerConfig& layer_config) {
static void HWY_MAYBE_UNUSED SplitW1NUQ(const LayerConfig& layer_config) {
  // TODO(janwas): implement.
}

template <>
void LayerWeightsPtrs<NuqStream>::Fixup(std::vector<MatOwner>& mat_owners) {
  InitAttWeightsNUQ(layer_config, attn_vec_einsum_w, att_weights, mat_owners);
  SplitW1NUQ(layer_config);
// Zero-initializes only the allocated tensors in `*this`.
void ModelWeightsPtrs::ZeroInit() {
  ForEachTensor(nullptr, nullptr, [](const TensorArgs& t) {
    if (!t.mat.HasPtr()) return;
    gcpp::ZeroInit(t.mat);
  });
}

// Copies only the allocated tensors in `*this` from tensors in `other`.
void ModelWeightsPtrs::CopyFrom(const ModelWeightsPtrs& other) {
  ForEachTensor(const_cast<ModelWeightsPtrs*>(&other), nullptr,
                [](const TensorArgs& t) {
                  if (!t.mat.HasPtr()) return;
                  HWY_ASSERT(t.other_mat1 && t.other_mat1->HasPtr());
                  CopyMat(*t.other_mat1, t.mat);
                });
}

// For reshaping file tensors to the shape expected by the code. This would
// ideally already happen in the importer. Called by WeightsOwner::Fixup.
void ModelWeightsPtrs::Fixup(std::vector<MatOwner>& mat_owners,
                             hwy::ThreadPool& pool) {
  pool.Run(0, c_layers.size(), [&](uint64_t layer, size_t /*thread*/) {
    GetLayer(layer)->Fixup(mat_owners);
  });

  pool.Run(0, vit_layers.size(), [&](uint64_t layer, size_t /*thread*/) {
    VitLayer(layer)->Fixup(mat_owners);
  });
}

std::vector<uint32_t> ModelWeightsPtrs::AddTensorDataToWriter(
    BlobWriter& writer) const {
  std::vector<uint32_t> serialized_mat_ptrs;
  // ForEachTensor is non-const but the lambda does not modify *this.
  const_cast<ModelWeightsPtrs*>(this)->ForEachTensor(
      nullptr, nullptr, [&](const TensorArgs& t) {
        if (t.flags & TensorArgs::kMaybeRead && !t.mat.HasPtr()) return;
        HWY_ASSERT_M(t.mat.HasPtr(), t.mat.Name());
        writer.Add(t.mat.Name(), t.mat.Packed(), t.mat.PackedBytes());
        t.mat.AppendTo(serialized_mat_ptrs);
      });
  return serialized_mat_ptrs;
}

struct TensorToRead {
@@ -144,7 +295,7 @@ static Mode ChooseMode(uint64_t file_bytes, Tristate map) {
  return Mode::kRead;
}

MapPtr MapFileOrNull(File& file, uint64_t file_bytes) {
static MapPtr MapFileOrNull(File& file, uint64_t file_bytes) {
  const Allocator& allocator = ThreadingContext::Get().allocator;
  if (file_bytes % allocator.BasePageBytes() == 0) {
    MapPtr mapped = file.Map();
@@ -175,8 +326,8 @@ static void MapAll(const std::vector<TensorToRead>& tensors,
  }
}

std::vector<IOBatch> MakeBatches(const std::vector<TensorToRead>& tensors,
                                 const uint64_t file_bytes) {
static std::vector<IOBatch> MakeBatches(
    const std::vector<TensorToRead>& tensors, const uint64_t file_bytes) {
  PROFILER_ZONE("Startup.Weights.MakeBatches");
  // Batches must be contiguous but blobs are padded, hence at least one
  // batch per tensor, and more when tensor rows exceed the batch size.
@@ -254,72 +405,34 @@ static void MapOrReadAll(const std::vector<TensorToRead>& tensors,
  ReadBatches(reader, batches, pool);
}

void WeightsOwner::ReadFromBlobs(const ModelStore& model, BlobReader& reader,
                                 Tristate map, hwy::ThreadPool& pool) {
void ModelWeightsPtrs::ReadFromBlobs(const ModelStore& model,
                                     BlobReader& reader, Tristate map,
                                     std::vector<MatOwner>& mat_owners,
                                     hwy::ThreadPool& pool) {
  // List of tensors to read/map, and where from.
  std::vector<TensorToRead> tensors;

  AllocatePointer(model.Config());

  // Enumerate all weights (negligible cost).
  CallT([&](const auto& weights) {
    weights->ForEachTensor(nullptr, nullptr, [&](const TensorArgs& t) {
  ForEachTensor(nullptr, nullptr, [&](const TensorArgs& t) {
    const MatPadding padding = (t.flags & TensorArgs::kPacked)
                                   ? MatPadding::kPacked
                                   : MatPadding::kOdd;
    size_t key_idx;
    if (model.FindAndUpdateMatPtr(t.mat, key_idx)) {
      tensors.push_back({.mat = &t.mat,
                         .range = reader.Range(key_idx),
                         .padding = padding});
      tensors.push_back(
          {.mat = &t.mat, .range = reader.Range(key_idx), .padding = padding});
      return;
    }
    if (t.flags & TensorArgs::kMaybeRead) return;  // optional and not found.
    HWY_ABORT("Tensor %s is required but not found in file.", t.mat.Name());
  });
  });

  MapOrReadAll(tensors, reader, map, mat_owners_, pool);
  MapOrReadAll(tensors, reader, map, mat_owners, pool);

  Fixup(pool);
}

// Allocates `*_weights_`, but not yet the tensors inside. This is split out
// of `CallT` because that is const, hence it would pass a const& of the
// `unique_ptr` to its lambda, but we want to reset the pointer.
void WeightsOwner::AllocatePointer(const ModelConfig& config) {
  switch (weight_type_) {
    case Type::kSFP:
      sfp_weights_.reset(new ModelWeightsPtrs<SfpStream>(config));
      break;
    case Type::kNUQ:
      nuq_weights_.reset(new ModelWeightsPtrs<NuqStream>(config));
      break;
    case Type::kBF16:
      bf16_weights_.reset(new ModelWeightsPtrs<BF16>(config));
      break;
    default:
      HWY_ABORT("Unsupported weight type %s.", TypeName(weight_type_));
  {
    PROFILER_ZONE("Startup.Fixup");
    Fixup(mat_owners, pool);
  }
}

void WeightsOwner::Fixup(hwy::ThreadPool& pool) {
  PROFILER_ZONE("Startup.Fixup");
  CallT([&](const auto& weights) { weights->Fixup(mat_owners_, pool); });
}

std::vector<uint32_t> WeightsOwner::AddTensorDataToWriter(
    BlobWriter& writer) const {
  std::vector<uint32_t> serialized_mat_ptrs;
  CallT([&](const auto& weights) {
    weights->ForEachTensor(nullptr, nullptr, [&](const TensorArgs& t) {
      if (t.flags & TensorArgs::kMaybeRead && !t.mat.HasPtr()) return;
      HWY_ASSERT_M(t.mat.HasPtr(), t.mat.Name());
      writer.Add(t.mat.Name(), t.mat.Packed(), t.mat.PackedBytes());
      t.mat.AppendTo(serialized_mat_ptrs);
    });
  });
  return serialized_mat_ptrs;
}

}  // namespace gcpp
gemma/weights.h
@@ -19,18 +19,14 @@
#include <stddef.h>
#include <stdint.h>

#include <complex>
#include <memory>
#include <mutex>  // NOLINT
#include <string>
#include <vector>

#include "compression/types.h"  // IsF32
#include "compression/types.h"
#include "gemma/configs.h"      // ModelConfig
#include "gemma/model_store.h"  // ModelStore
#include "gemma/tensor_info.h"  // TensorInfoRegistry
#include "io/blob_store.h"      // BlobWriter
#include "ops/matmul.h"         // MatMulEnv
#include "util/mat.h"           // MatPtr
#include "hwy/contrib/thread_pool/thread_pool.h"
@@ -73,142 +69,141 @@ struct TensorArgs {
  TensorArgs(mat, other1 ? &other1->mat : nullptr, \
             other2 ? &other2->mat : nullptr, TensorArgs::flag)

// Per-layer weight metadata and pointers. The tensor data is owned by
// `WeightsOwner`. Note that this class could be type-erased: member functions
// do not actually use the `Weight` template argument. See `WeightsPtrs`.
// `TensorInfoRegistry` (constructed from `ModelConfig`) is the source of truth
// for all tensor shapes.
template <class Weight>
struct LayerWeightsPtrs {
  static inline std::string Concat(const char* base_name,
                                   const std::string& suffix) {
    return std::string(base_name) + suffix;
// Finds tensors by name in `TensorInfoRegistry` (constructed from
// `ModelConfig`) and constructs `MatPtr` metadata with those shapes.
class MatFinder {
 public:
  MatFinder(const std::string& suffix, const TensorInfoRegistry& tensors)
      : suffix_(suffix), tensors_(tensors) {}

  // Retrieves shape by name via `TensorInfo` from `TensorInfoRegistry`.
  MatPtr operator()(const std::string& base_name) const {
    const std::string name = std::string(base_name) + suffix_;
    return MatPtr(name.c_str(), Type::kUnknown,
                  ExtentsFromInfo(tensors_.Find(name)));
  }

 private:
  const std::string suffix_;
  const TensorInfoRegistry& tensors_;
};
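A short, hedged sketch of how `MatFinder` is typically used; it assumes a valid `ModelConfig` is in scope. The returned `MatPtr` carries the name and shape from `TensorInfoRegistry` but no type or data yet; both are filled in later when the tensor is located in the file.

// Sketch only: look up per-layer tensor metadata by base name. The free function and
// its checks are illustrative, not part of the API.
void SketchMatFinderUsage(const ModelConfig& config) {
  const TensorInfoRegistry tensors(config);
  const MatFinder finder(LayerSuffix(/*layer_idx=*/0), tensors);
  const MatPtr qkv = finder("qkv_ein");  // name becomes "qkv_ein" + layer suffix
  HWY_ASSERT(qkv.GetType() == Type::kUnknown);  // type is set once the blob is read
}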
// Per-layer weight metadata and pointers. The tensor data is owned by
// `WeightsOwner`.
struct LayerWeightsPtrs {
  // Initializes tensor metadata without allocating.
  LayerWeightsPtrs(size_t layer_idx, const LayerConfig& config,
                   const TensorInfoRegistry& tensors)
      : suffix_(LayerSuffix(layer_idx)),
        qkv_einsum_w(Concat("qkv_ein", suffix_), tensors),
        qkv_einsum_w1(Concat("qkv1_w", suffix_), tensors),
        qkv_einsum_w2(Concat("qkv2_w", suffix_), tensors),
        attention_output_biases(Concat("attn_ob", suffix_), tensors),
        griffin(
            {.linear_x_w = {Concat("gr_lin_x_w", suffix_), tensors},
             .linear_x_biases = {Concat("gr_lin_x_b", suffix_), tensors},
             .linear_y_w = {Concat("gr_lin_y_w", suffix_), tensors},
             .linear_y_biases = {Concat("gr_lin_y_b", suffix_), tensors},
             .linear_out_w = {Concat("gr_lin_out_w", suffix_), tensors},
             .linear_out_biases = {Concat("gr_lin_out_b", suffix_), tensors},
             .conv_w = {Concat("gr_conv_w", suffix_), tensors},
             .conv_biases = {Concat("gr_conv_b", suffix_), tensors},
             .gate_w = {Concat("gr_gate_w", suffix_), tensors},
             .gate_biases = {Concat("gr_gate_b", suffix_), tensors},
             .a = {Concat("gr_a", suffix_), tensors}}),
      : finder_(LayerSuffix(layer_idx), tensors),
        qkv_einsum_w(finder_("qkv_ein")),
        qkv_einsum_w1(finder_("qkv1_w")),
        qkv_einsum_w2(finder_("qkv2_w")),
        attention_output_biases(finder_("attn_ob")),
        griffin({.linear_x_w = finder_("gr_lin_x_w"),
                 .linear_x_biases = finder_("gr_lin_x_b"),
                 .linear_y_w = finder_("gr_lin_y_w"),
                 .linear_y_biases = finder_("gr_lin_y_b"),
                 .linear_out_w = finder_("gr_lin_out_w"),
                 .linear_out_biases = finder_("gr_lin_out_b"),
                 .conv_w = finder_("gr_conv_w"),
                 .conv_biases = finder_("gr_conv_b"),
                 .gate_w = finder_("gr_gate_w"),
                 .gate_biases = finder_("gr_gate_b"),
                 .a = finder_("gr_a")}),
        // MultiHeadDotProductAttention.
        vit({.attn_out_w = {Concat("attn_out_w", suffix_), tensors},
             .attn_out_b = {Concat("attn_out_b", suffix_), tensors},
             .qkv_einsum_w = {Concat("qkv_ein_w", suffix_), tensors},
             .qkv_einsum_b = {Concat("qkv_ein_b", suffix_), tensors},
             .linear_0_w = {Concat("linear_0_w", suffix_), tensors},
             .linear_0_b = {Concat("linear_0_b", suffix_), tensors},
             .linear_1_w = {Concat("linear_1_w", suffix_), tensors},
             .linear_1_b = {Concat("linear_1_b", suffix_), tensors},
             .layer_norm_0_bias = {Concat("ln_0_bias", suffix_), tensors},
             .layer_norm_0_scale = {Concat("ln_0_scale", suffix_), tensors},
             .layer_norm_1_bias = {Concat("ln_1_bias", suffix_), tensors},
             .layer_norm_1_scale = {Concat("ln_1_scale", suffix_), tensors}}),
        gating_einsum_w(Concat("gating_ein", suffix_), tensors),
        gating_einsum_w1(Concat("gating1_w", suffix_), tensors),
        gating_einsum_w2(Concat("gating2_w", suffix_), tensors),
        linear_w(Concat("linear_w", suffix_), tensors),
        pre_attention_norm_scale(Concat("pre_att_ns", suffix_), tensors),
        pre_ffw_norm_scale(Concat("pre_ff_ns", suffix_), tensors),
        post_attention_norm_scale(Concat("post_att_ns", suffix_), tensors),
        post_ffw_norm_scale(Concat("post_ff_ns", suffix_), tensors),
        ffw_gating_biases(Concat("ffw_gat_b", suffix_), tensors),
        ffw_output_biases(Concat("ffw_out_b", suffix_), tensors),
        vit({.attn_out_w = finder_("attn_out_w"),
             .attn_out_b = finder_("attn_out_b"),
             .qkv_einsum_w = finder_("qkv_ein_w"),
             .qkv_einsum_b = finder_("qkv_ein_b"),
             .linear_0_w = finder_("linear_0_w"),
             .linear_0_b = finder_("linear_0_b"),
             .linear_1_w = finder_("linear_1_w"),
             .linear_1_b = finder_("linear_1_b"),
             .layer_norm_0_bias = finder_("ln_0_bias"),
             .layer_norm_0_scale = finder_("ln_0_scale"),
             .layer_norm_1_bias = finder_("ln_1_bias"),
             .layer_norm_1_scale = finder_("ln_1_scale")}),
        gating_einsum_w(finder_("gating_ein")),
        gating_einsum_w1(finder_("gating1_w")),
        gating_einsum_w2(finder_("gating2_w")),
        linear_w(finder_("linear_w")),
        pre_attention_norm_scale(finder_("pre_att_ns")),
        pre_ffw_norm_scale(finder_("pre_ff_ns")),
        post_attention_norm_scale(finder_("post_att_ns")),
        post_ffw_norm_scale(finder_("post_ff_ns")),
        ffw_gating_biases(finder_("ffw_gat_b")),
        ffw_output_biases(finder_("ffw_out_b")),

        attn_vec_einsum_w(Concat("att_ein", suffix_), tensors),
        att_weights(Concat("att_w", suffix_), tensors),
        attn_vec_einsum_w(finder_("att_ein")),
        att_weights(finder_("att_w")),

        key_norm_scale(Concat("key_norm", suffix_), tensors),
        query_norm_scale(Concat("query_norm", suffix_), tensors),
        key_norm_scale(finder_("key_norm")),
        query_norm_scale(finder_("query_norm")),

        layer_config(config) {
  }
  ~LayerWeightsPtrs() = default;

  const std::string suffix_;

  // If weights are f32, also f32; otherwise at least bf16. Useful for ops that
  // do not yet support smaller compressed types, or require at least bf16. When
  // weights are f32, we also want such tensors to be f32.
  // If weights are complex, this is also complex.
  using WeightF32OrBF16 =
      hwy::If<hwy::IsSame<Weight, std::complex<double>>(), std::complex<double>,
              hwy::If<hwy::IsSame<Weight, double>(), double,
                      hwy::If<IsF32<Weight>(), float, BF16>>>;
  const MatFinder finder_;

  // Files either have qkv_einsum_w with 2 stacked matrices or separate
  // w1/w2 tensors. Fixup ensures w1/w2 are ready for use by gemma-inl.h.
  MatPtrT<Weight> qkv_einsum_w;
  MatPtrT<Weight> qkv_einsum_w1;
  MatPtrT<Weight> qkv_einsum_w2;
  MatPtr qkv_einsum_w;
  MatPtr qkv_einsum_w1;
  MatPtr qkv_einsum_w2;
  MatPtrT<float> attention_output_biases;

  struct {
    MatPtrT<Weight> linear_x_w;
    MatPtr linear_x_w;
    MatPtrT<float> linear_x_biases;
    MatPtrT<Weight> linear_y_w;
    MatPtr linear_y_w;
    MatPtrT<float> linear_y_biases;
    MatPtrT<Weight> linear_out_w;
    MatPtr linear_out_w;
    MatPtrT<float> linear_out_biases;
    MatPtrT<float> conv_w;
    MatPtrT<float> conv_biases;
    MatPtrT<Weight> gate_w;
    MatPtr gate_w;
    MatPtrT<float> gate_biases;
    MatPtrT<float> a;
  } griffin;

  struct {
    // MultiHeadDotProductAttention.
    MatPtrT<WeightF32OrBF16> attn_out_w;
    MatPtr attn_out_w;  // at least BF16.
    MatPtrT<float> attn_out_b;
    MatPtrT<WeightF32OrBF16> qkv_einsum_w;
    MatPtr qkv_einsum_w;  // at least BF16.
    MatPtrT<float> qkv_einsum_b;
    // MlpBlock.
    MatPtrT<WeightF32OrBF16> linear_0_w;
    MatPtr linear_0_w;  // at least BF16.
    MatPtrT<float> linear_0_b;
    MatPtrT<WeightF32OrBF16> linear_1_w;
    MatPtr linear_1_w;  // at least BF16.
    MatPtrT<float> linear_1_b;
    // LayerNorm.
    MatPtrT<WeightF32OrBF16> layer_norm_0_bias;
    MatPtrT<WeightF32OrBF16> layer_norm_0_scale;
    MatPtrT<WeightF32OrBF16> layer_norm_1_bias;
    MatPtrT<WeightF32OrBF16> layer_norm_1_scale;
    MatPtr layer_norm_0_bias;   // at least BF16.
    MatPtr layer_norm_0_scale;  // at least BF16.
    MatPtr layer_norm_1_bias;   // at least BF16.
    MatPtr layer_norm_1_scale;  // at least BF16.
  } vit;

  // Files either have gating_einsum_w with 2 stacked matrices or separate
  // w1/w2 tensors. Fixup ensures w1/w2 are ready for use by gemma-inl.h.
  MatPtrT<Weight> gating_einsum_w;
  MatPtrT<Weight> gating_einsum_w1;
  MatPtrT<Weight> gating_einsum_w2;
  MatPtrT<Weight> linear_w;
  // > W8 is likely helpful.
  MatPtrT<WeightF32OrBF16> pre_attention_norm_scale;
  MatPtrT<WeightF32OrBF16> pre_ffw_norm_scale;
  MatPtrT<WeightF32OrBF16> post_attention_norm_scale;
  MatPtrT<WeightF32OrBF16> post_ffw_norm_scale;
  // w1/w2 tensors. `Fixup` ensures w1/w2 are ready for use by gemma-inl.h.
  MatPtr gating_einsum_w;
  MatPtr gating_einsum_w1;
  MatPtr gating_einsum_w2;
  MatPtr linear_w;
  MatPtr pre_attention_norm_scale;   // at least BF16.
  MatPtr pre_ffw_norm_scale;         // at least BF16.
  MatPtr post_attention_norm_scale;  // at least BF16.
  MatPtr post_ffw_norm_scale;        // at least BF16.

  MatPtrT<float> ffw_gating_biases;
  MatPtrT<float> ffw_output_biases;

  MatPtrT<Weight> attn_vec_einsum_w;  // Use att_weights instead of this.
  MatPtrT<Weight> att_weights;        // Use this instead of attn_vec_einsum_w.
  MatPtr attn_vec_einsum_w;  // Use att_weights instead of this.
  MatPtr att_weights;        // Use this instead of attn_vec_einsum_w.

  MatPtrT<WeightF32OrBF16> key_norm_scale;
  MatPtrT<WeightF32OrBF16> query_norm_scale;
  MatPtr key_norm_scale;    // at least BF16.
  MatPtr query_norm_scale;  // at least BF16.

  const LayerConfig& layer_config;
@@ -217,8 +212,8 @@ struct LayerWeightsPtrs {
  // can also iterate over pairs or triples of tensors for `AdamUpdateMV`.
  // Public because also called by `WeightsPtrs`.
  template <class Func>
  void ForEachTensor(LayerWeightsPtrs<Weight>* other1,
                     LayerWeightsPtrs<Weight>* other2, Func func) {
  void ForEachTensor(LayerWeightsPtrs* other1, LayerWeightsPtrs* other2,
                     Func func) {
    if (layer_config.type == LayerAttentionType::kVit) {
      // MHA.
      func(TENSOR_ARGS(vit.attn_out_w, kMustRead));
@@ -301,208 +296,98 @@ struct LayerWeightsPtrs {
  // Must be called after reading weights via `ForEachTensor`.
  // TODO: exporters should bake this into the weights already.
  // WARNING: called from multiple threads; `mat_owners` requires a lock.
  void Fixup(std::vector<MatOwner>& mat_owners) {
    InitAttWeights(mat_owners);
    SplitW1();
    SplitAttW1();
  }
  void Fixup(std::vector<MatOwner>& mat_owners);

 private:
  // Copies att_weights from `attn_vec_einsum_w`.
  void InitAttWeights(std::vector<MatOwner>& mat_owners) {
    // We only use this tensor for Gemma layers.
    if (layer_config.type != LayerAttentionType::kGemma) return;

    // Files must have one or the other.
    HWY_ASSERT(attn_vec_einsum_w.HasPtr() ^ att_weights.HasPtr());
    // Done if we already read the transposed tensor.
    if (att_weights.HasPtr() && !attn_vec_einsum_w.HasPtr()) return;

    // NUQ is handled by a specialization in weights.cc.
    HWY_ASSERT(attn_vec_einsum_w.GetType() != Type::kNUQ);

    const size_t model_dim = layer_config.model_dim;
    const size_t heads = layer_config.heads;
    const size_t qkv_dim = layer_config.qkv_dim;

    // Reshape [heads, model_dim, qkv_dim] to [model_dim, heads * qkv_dim].
    HWY_ASSERT(att_weights.GetType() == attn_vec_einsum_w.GetType());
    HWY_ASSERT(att_weights.Rows() == model_dim);
    HWY_ASSERT(att_weights.Cols() == heads * qkv_dim);
    HWY_ASSERT(attn_vec_einsum_w.Rows() == heads * model_dim);
    HWY_ASSERT(attn_vec_einsum_w.Cols() == qkv_dim);

    {
      static std::mutex m;
      std::lock_guard<std::mutex> lock(m);
      mat_owners.push_back(MatOwner());
      mat_owners.back().AllocateFor(att_weights, MatPadding::kOdd);
    }

    const size_t T_bytes = att_weights.ElementBytes();
    for (size_t m = 0; m < model_dim; ++m) {
      uint8_t* HWY_RESTRICT out_row =
          reinterpret_cast<uint8_t*>(att_weights.Row(m));
      for (size_t h = 0; h < heads; ++h) {
        hwy::CopyBytes(attn_vec_einsum_w.Row(h * model_dim + m),
                       out_row + h * qkv_dim * T_bytes, qkv_dim * T_bytes);
      }
    }
    att_weights.SetScale(attn_vec_einsum_w.Scale());
  }
  void InitAttWeights(std::vector<MatOwner>& mat_owners);

  // For FFN. Fast, only updates pointers.
  void SplitW1() {
    // Used for Gemma and Griffin layers; FFWVit uses different tensors.
    if (layer_config.type == LayerAttentionType::kVit) return;

    // Files have both or neither of w1 and w2.
    HWY_ASSERT(gating_einsum_w1.HasPtr() == gating_einsum_w2.HasPtr());
    // w is mutually exclusive with w1 and w2 in the file.
    HWY_ASSERT(gating_einsum_w.HasPtr() ^ gating_einsum_w1.HasPtr());
    // Done if we already read split tensors. Note that they are not
    // necessarily the same type.
    if (gating_einsum_w1.HasPtr() && !gating_einsum_w.HasPtr()) return;

    const size_t ff_hidden_dim = layer_config.ff_hidden_dim;
    HWY_ASSERT(gating_einsum_w.Rows() == 2 * ff_hidden_dim);
    HWY_ASSERT(gating_einsum_w1.Rows() == ff_hidden_dim);
    HWY_ASSERT(gating_einsum_w2.Rows() == ff_hidden_dim);
    // Cols are the model_dim but we don't have ModelConfig here.
    HWY_ASSERT(gating_einsum_w1.Cols() == gating_einsum_w.Cols());
    HWY_ASSERT(gating_einsum_w2.Cols() == gating_einsum_w.Cols());

    const size_t stride = gating_einsum_w.Stride();
    gating_einsum_w1.SetPtr(gating_einsum_w.Row(0), stride);
    gating_einsum_w2.SetPtr(gating_einsum_w.Row(ff_hidden_dim), stride);
    gating_einsum_w1.SetType(gating_einsum_w.GetType());
    gating_einsum_w2.SetType(gating_einsum_w.GetType());
    gating_einsum_w1.SetScale(gating_einsum_w.Scale());
    gating_einsum_w2.SetScale(gating_einsum_w.Scale());
    gating_einsum_w.SetPtr(nullptr, gating_einsum_w.Cols());
  }
  void SplitW1();

  // For attention, which might not have a w2. Fast, only updates pointers.
  void SplitAttW1() {
    // We only use this tensor for Gemma layers.
    if (layer_config.type != LayerAttentionType::kGemma) return;

    // w is mutually exclusive with w1 in the file.
    HWY_ASSERT(qkv_einsum_w.HasPtr() ^ qkv_einsum_w1.HasPtr());
    // Done if we already read split tensors. Note that w2 does not exist for
    // MHA, and otherwise might not be the same type.
    if (qkv_einsum_w1.HasPtr() && !qkv_einsum_w.HasPtr()) return;

    const size_t w1_rows = layer_config.heads * layer_config.qkv_dim;
    const size_t w2_rows = layer_config.kv_heads * 2 * layer_config.qkv_dim;

    HWY_ASSERT(qkv_einsum_w.Rows() == w1_rows + w2_rows);
    HWY_ASSERT(qkv_einsum_w1.Rows() == w1_rows);
    HWY_ASSERT(qkv_einsum_w2.Rows() == w2_rows);
    // Cols are the model_dim but we don't have ModelConfig here.
    HWY_ASSERT(qkv_einsum_w1.Cols() == qkv_einsum_w.Cols());
    HWY_ASSERT(qkv_einsum_w2.Cols() == qkv_einsum_w.Cols());

    const size_t stride = qkv_einsum_w.Stride();
    qkv_einsum_w1.SetPtr(qkv_einsum_w.Row(0), stride);
    qkv_einsum_w2.SetPtr(qkv_einsum_w.Row(w1_rows), stride);
    qkv_einsum_w1.SetType(qkv_einsum_w.GetType());
    qkv_einsum_w2.SetType(qkv_einsum_w.GetType());
    qkv_einsum_w1.SetScale(qkv_einsum_w.Scale());
    qkv_einsum_w2.SetScale(qkv_einsum_w.Scale());
    qkv_einsum_w.SetPtr(nullptr, qkv_einsum_w.Cols());
  }
  void SplitAttW1();
};

// Holds layer-independent weight metadata and pointers plus per-layer
// `LayerWeightsPtrs`. The tensor data is owned by `WeightsOwner`. As with
// `LayerWeightsPtrs`, this class could be type-erased: member functions do not
// actually use the `Weight` template argument. The template does allow user
// code to dispatch only once. However, most tensors are large enough that
// dispatch at each usage would be feasible.
// `LayerWeightsPtrs`. The tensor data is owned by `WeightsOwner`.
// TODO: move `gemma-inl.h` toward dispatch at each usage.
// TODO: rename to WeightsPtrs.
template <class Weight>
struct ModelWeightsPtrs {
  using WeightT = Weight;

  explicit ModelWeightsPtrs(const ModelConfig& config)
      : tensors_(config),
        // No suffix, these are per-model.
        embedder_input_embedding("c_embedding", tensors_),
        final_norm_scale("c_final_norm", tensors_),
        vit_encoder_norm_bias("enc_norm_bias", tensors_),
        vit_encoder_norm_scale("enc_norm_scale", tensors_),
        vit_img_embedding_bias("img_emb_bias", tensors_),
        vit_img_embedding_kernel("img_emb_kernel", tensors_),
        vit_img_pos_embedding("img_pos_emb", tensors_),
        vit_img_head_bias("img_head_bias", tensors_),
        vit_img_head_kernel("img_head_kernel", tensors_),
        mm_embed_norm("mm_embed_norm", tensors_),
        weights_config(config) {
    c_layers.reserve(config.layer_configs.size());
    for (size_t idx = 0; idx < config.layer_configs.size(); ++idx) {
      const LayerConfig& layer_config = config.layer_configs[idx];
      : config_(config),
        tensors_(config_),
        finder_("", tensors_),  // no suffix because these are per-model.
        embedder_input_embedding(finder_("c_embedding")),
        final_norm_scale(finder_("c_final_norm")),
        vit_encoder_norm_bias(finder_("enc_norm_bias")),
        vit_encoder_norm_scale(finder_("enc_norm_scale")),
        vit_img_embedding_bias(finder_("img_emb_bias")),
        vit_img_embedding_kernel(finder_("img_emb_kernel")),
        vit_img_pos_embedding(finder_("img_pos_emb")),
        vit_img_head_bias(finder_("img_head_bias")),
        vit_img_head_kernel(finder_("img_head_kernel")),
        mm_embed_norm(finder_("mm_embed_norm")),
        c_layers() {
    c_layers.reserve(config_.layer_configs.size());
    for (size_t idx = 0; idx < config_.layer_configs.size(); ++idx) {
      const LayerConfig& layer_config = config_.layer_configs[idx];
      c_layers.emplace_back(idx, layer_config, tensors_);
    }
    for (size_t idx = 0; idx < config.vit_config.layer_configs.size(); ++idx) {
      const LayerConfig& layer_config = config.vit_config.layer_configs[idx];
    for (size_t idx = 0; idx < config_.vit_config.layer_configs.size(); ++idx) {
      const LayerConfig& layer_config = config_.vit_config.layer_configs[idx];
      vit_layers.emplace_back(idx, layer_config, tensors_);
    }
  }

  ~ModelWeightsPtrs() = default;
  // = F32 if weights are F32, else BF16.
  using WeightF32OrBF16 = typename LayerWeightsPtrs<Weight>::WeightF32OrBF16;

  // Passed to all `MatPtrT` initializers, hence must be initialized first.
  const ModelConfig& config_;
  // Passed to finder_, hence must be initialized first.
  const TensorInfoRegistry tensors_;
  const MatFinder finder_;

  // TODO: switch to SFP?
  MatPtrT<WeightF32OrBF16> embedder_input_embedding;
  MatPtrT<WeightF32OrBF16> final_norm_scale;
  MatPtr embedder_input_embedding;
  MatPtr final_norm_scale;  // at least BF16.

  // Vit parts.
  MatPtrT<WeightF32OrBF16> vit_encoder_norm_bias;
  MatPtrT<WeightF32OrBF16> vit_encoder_norm_scale;
  MatPtr vit_encoder_norm_bias;   // at least BF16.
  MatPtr vit_encoder_norm_scale;  // at least BF16.
  MatPtrT<float> vit_img_embedding_bias;
  MatPtrT<WeightF32OrBF16> vit_img_embedding_kernel;
  MatPtrT<float> vit_img_pos_embedding;
  MatPtr vit_img_embedding_kernel;  // at least BF16.
  MatPtr vit_img_pos_embedding;     // F32?
  // The head maps from VitConfig::model_dim (Vit final layer) to
  // model_dim (LLM input).
  MatPtrT<float> vit_img_head_bias;
  MatPtrT<WeightF32OrBF16> vit_img_head_kernel;
  MatPtr vit_img_head_kernel;  // at least BF16.

  MatPtrT<WeightF32OrBF16> mm_embed_norm;
  MatPtr mm_embed_norm;  // at least BF16.

  const ModelConfig& weights_config;
  std::vector<LayerWeightsPtrs> c_layers;
  std::vector<LayerWeightsPtrs> vit_layers;

  std::vector<LayerWeightsPtrs<Weight>> c_layers;
  std::vector<LayerWeightsPtrs<Weight>> vit_layers;

  const LayerWeightsPtrs<Weight>* GetLayer(size_t layer) const {
  const LayerWeightsPtrs* GetLayer(size_t layer) const {
    return &c_layers[layer];
  }
  LayerWeightsPtrs<Weight>* GetLayer(size_t layer) { return &c_layers[layer]; }
  const LayerWeightsPtrs<Weight>* VitLayer(size_t layer) const {
    return &vit_layers[layer];
  }
  LayerWeightsPtrs<Weight>* VitLayer(size_t layer) {
  LayerWeightsPtrs* GetLayer(size_t layer) { return &c_layers[layer]; }
  const LayerWeightsPtrs* VitLayer(size_t layer) const {
    return &vit_layers[layer];
  }
  LayerWeightsPtrs* VitLayer(size_t layer) { return &vit_layers[layer]; }

  // Called via `CallT`. `other1` and `other2` are usually null, but can be
  // used to copy from another set of weights. Public because called by tests
  // and `WeightsOwner`.
  template <class Func>
  void ForEachTensor(ModelWeightsPtrs<Weight>* other1,
                     ModelWeightsPtrs<Weight>* other2, Func func) {
    LayerWeightsPtrs<Weight>* other_layer1 = nullptr;
    LayerWeightsPtrs<Weight>* other_layer2 = nullptr;
  void ForEachTensor(ModelWeightsPtrs* other1, ModelWeightsPtrs* other2,
                     Func func) {
    LayerWeightsPtrs* other_layer1 = nullptr;
    LayerWeightsPtrs* other_layer2 = nullptr;
    func(TENSOR_ARGS(embedder_input_embedding, kMustRead));
    func(TENSOR_ARGS(final_norm_scale, kMustRead));

    if (!weights_config.vit_config.layer_configs.empty()) {  // Vit parts.
    if (!config_.vit_config.layer_configs.empty()) {  // Vit parts.
      func(TENSOR_ARGS(vit_encoder_norm_bias, kMustRead));
      func(TENSOR_ARGS(vit_encoder_norm_scale, kMustRead));
      func(TENSOR_ARGS(vit_img_embedding_bias, kMustRead));
@@ -511,7 +396,7 @@ struct ModelWeightsPtrs {
      func(TENSOR_ARGS(vit_img_head_bias, kMustRead));
      func(TENSOR_ARGS(vit_img_head_kernel, kMustRead));

      if (weights_config.wrapping == PromptWrapping::GEMMA_VLM) {
      if (config_.wrapping == PromptWrapping::GEMMA_VLM) {
        func(TENSOR_ARGS(mm_embed_norm, kMustRead));
      }
    }
@@ -522,8 +407,7 @@ struct ModelWeightsPtrs {
      GetLayer(layer_idx)->ForEachTensor(other_layer1, other_layer2, func);
    }

    HWY_ASSERT(weights_config.vit_config.layer_configs.empty() ==
               vit_layers.empty());
    HWY_ASSERT(config_.vit_config.layer_configs.empty() == vit_layers.empty());
    for (size_t layer_idx = 0; layer_idx < vit_layers.size(); ++layer_idx) {
      HWY_ASSERT(vit_layers[layer_idx].layer_config.type ==
                 LayerAttentionType::kVit);
@@ -534,87 +418,24 @@ struct ModelWeightsPtrs {
  }  // `ForEachTensor`

  // Zero-initializes only the allocated tensors in `*this`.
  void ZeroInit() {
    ForEachTensor(nullptr, nullptr, [](const TensorArgs& t) {
      if (!t.mat.HasPtr()) return;
      gcpp::ZeroInit(t.mat);
    });
  }

  void ZeroInit();
  // Copies only the allocated tensors in `*this` from tensors in `other`.
  void CopyFrom(const ModelWeightsPtrs<Weight>& other) {
    ForEachTensor(const_cast<ModelWeightsPtrs<Weight>*>(&other), nullptr,
                  [](const TensorArgs& t) {
                    if (!t.mat.HasPtr()) return;
                    HWY_ASSERT(t.other_mat1 && t.other_mat1->HasPtr());
                    CopyMat(*t.other_mat1, t.mat);
                  });
  }

  // For reshaping file tensors to the shape expected by the code. This would
  // ideally already happen in the importer. Must be called after reading and
  // updating the attention weights.
  void Fixup(std::vector<MatOwner>& mat_owners, hwy::ThreadPool& pool) {
    pool.Run(0, c_layers.size(), [&](uint64_t layer, size_t /*thread*/) {
      GetLayer(layer)->Fixup(mat_owners);
    });

    pool.Run(0, vit_layers.size(), [&](uint64_t layer, size_t /*thread*/) {
      VitLayer(layer)->Fixup(mat_owners);
    });
  }
};  // `WeightsPtrs`
#undef TENSOR_ARGS

// Type-erased facade for `WeightsPtrs<T>`, stored inside the non-template
// `Gemma`. Also owns the underlying memory.
class WeightsOwner {
 public:
  // `weight_type` is obtained from `ModelConfig` in `ModelStore`.
  WeightsOwner(Type weight_type) : weight_type_(weight_type) {}
  void CopyFrom(const ModelWeightsPtrs& other);

  // Reads tensor data from `BlobStore` or aborts on error. `map` is a user
  // override for whether to map blobs or read them.
  void ReadFromBlobs(const ModelStore& model, BlobReader& reader, Tristate map,
                     hwy::ThreadPool& pool);

  // Calls `func(std::unique_ptr<WeightsPtrs<T>>&, args)`. `func` typically
  // calls `ForEachTensor`.
  template <class Func, typename... TArgs>
  decltype(auto) CallT(const Func& func, TArgs&&... args) const {
    if (HWY_LIKELY(weight_type_ == Type::kSFP)) {
      return func(sfp_weights_, std::forward<TArgs>(args)...);
    } else if (weight_type_ == Type::kNUQ) {
      return func(nuq_weights_, std::forward<TArgs>(args)...);
    } else if (weight_type_ == Type::kBF16) {
      return func(bf16_weights_, std::forward<TArgs>(args)...);
    }
    return HWY_ABORT("Unsupported weight type %s.", TypeName(weight_type_));
  }

  // For writers:
                     std::vector<MatOwner>& mat_owners, hwy::ThreadPool& pool);

  // Adds one blob for each tensor's data and returns all serialized MatPtr.
  std::vector<uint32_t> AddTensorDataToWriter(BlobWriter& writer) const;

 private:
  Type weight_type_;

  // Allocates `*_weights_`, but not yet the tensors inside. This is split out
  // of `CallT` so that it can be const.
  void AllocatePointer(const ModelConfig& config);

  // Called by `ReadFromBlobs`.
  void Fixup(hwy::ThreadPool& pool);

  // Only one is non-null, determined by `weight_type_`.
  std::unique_ptr<ModelWeightsPtrs<BF16>> bf16_weights_;
  std::unique_ptr<ModelWeightsPtrs<SfpStream>> sfp_weights_;
  std::unique_ptr<ModelWeightsPtrs<NuqStream>> nuq_weights_;

  // Owns the memory referenced by all `MatPtr`.
  std::vector<MatOwner> mat_owners_;
};
  // For reshaping file tensors to the shape expected by the code. This would
  // ideally already happen in the importer. Called by ReadFromBlobs.
  void Fixup(std::vector<MatOwner>& mat_owners, hwy::ThreadPool& pool);
};  // `ModelWeightsPtrs`
#undef TENSOR_ARGS

}  // namespace gcpp
@@ -59,6 +59,8 @@ class BlobReader {
  const File& file() const { return *file_; }
  uint64_t file_bytes() const { return file_bytes_; }

  void CloseFile() { file_.reset(); }

  const std::vector<std::string>& Keys() const { return keys_; }

  const BlobRange& Range(size_t key_idx) const {

@@ -99,7 +101,7 @@
  }

 private:
  const std::unique_ptr<File> file_;
  std::unique_ptr<File> file_;
  const uint64_t file_bytes_;

  std::vector<std::string> keys_;
util/mat.h
@@ -131,6 +131,13 @@ class MatPtr : public IFields {
  Type GetType() const { return type_; }
  void SetType(Type type) {
    type_ = type;
    if (type == Type::kUnknown) {
      // Temporary invalid state. Happens during weights.h construction, before
      // the ForEachTensor that loads them and sets the type.
      element_bytes_ = 0;
      num_elements_ = 0;
      return;
    }
    element_bytes_ = static_cast<uint32_t>(hwy::DivCeil(TypeBits(type), 8));
    num_elements_ = static_cast<uint32_t>(ComputeNumElements(type, Extents()));
    HWY_DASSERT(0 != element_bytes_ && element_bytes_ <= 16);
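A hedged sketch of the lifecycle that the `kUnknown` branch above supports: weights.h constructs metadata with an unknown type, and the type later read from the file is applied afterwards. The tensor name, extents, and final type below are made-up values for illustration, and the `Extents2D(rows, cols)` constructor call is an assumption.

// Sketch only: kUnknown leaves element_bytes_/num_elements_ zeroed; SetType with a
// concrete type derives them from the extents.
void SketchSetTypeLifecycle() {
  MatPtr mat("qkv_ein_0", Type::kUnknown, Extents2D(4, 8));
  // ... later, e.g. after ModelStore::FindAndUpdateMatPtr identifies the blob:
  mat.SetType(Type::kBF16);
  HWY_ASSERT(mat.GetType() == Type::kBF16);
}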
@@ -244,16 +251,18 @@ class MatPtrT : public MatPtr {
  // Called by `MatStorageT`.
  MatPtrT(const char* name, Extents2D extents)
      : MatPtr(name, TypeEnum<MatT>(), extents) {}
  // Retrieves shape by name via `TensorInfo` from `TensorInfoRegistry`. This is
  // not a factory function because `weights.h` initializes members of type
  // `MatPtrT<T>`, and `T` cannot be inferred at compile time from arguments.
  MatPtrT(const std::string& name, const TensorInfoRegistry& info)
      : MatPtrT<MatT>(name.c_str(), ExtentsFromInfo(info.Find(name))) {}

  // Copying allowed because the metadata is small.
  MatPtrT(const MatPtr& other) : MatPtr(other) {
    // Happens in weights.h when constructing via MatFinder, which does not
    // know the type. Setting the type here avoids having to keep the
    // initializer list and member type in sync.
    if (GetType() == Type::kUnknown) {
      SetType(TypeEnum<MatT>());
    } else {
      HWY_ASSERT(other.GetType() == TypeEnum<MatT>());
    }
  }
  MatPtrT& operator=(const MatPtr& other) {
    MatPtr::operator=(other);
    return *this;
@@ -289,27 +298,27 @@ class MatPtrT : public MatPtr {
  }
};

// Calls `func` with a dynamic_cast of `MatPtr` to `MatPtrT<T>`, plus the
// optional `args`. This supports all types used as weights.
// Calls `func` with `MatPtrT<T>*` plus the optional `args`. This supports all
// types used as weights.
template <class Func, typename... Args>
decltype(auto) CallUpcasted(const MatPtr* base, const Func& func,
                            Args&&... args) {
#if GEMMA_ENABLE_NUQ
  if (base->GetType() == Type::kNUQ) {
    return func(dynamic_cast<const MatPtrT<NuqStream>*>(base),
                std::forward<Args>(args)...);
    const MatPtrT<NuqStream> mat(*base);
    return func(&mat, std::forward<Args>(args)...);
  }
#endif  // GEMMA_ENABLE_NUQ

  if (base->GetType() == Type::kF32) {
    return func(dynamic_cast<const MatPtrT<float>*>(base),
                std::forward<Args>(args)...);
    const MatPtrT<float> mat(*base);
    return func(&mat, std::forward<Args>(args)...);
  } else if (base->GetType() == Type::kBF16) {
    return func(dynamic_cast<const MatPtrT<BF16>*>(base),
                std::forward<Args>(args)...);
    const MatPtrT<BF16> mat(*base);
    return func(&mat, std::forward<Args>(args)...);
  } else if (base->GetType() == Type::kSFP) {
    return func(dynamic_cast<const MatPtrT<SfpStream>*>(base),
                std::forward<Args>(args)...);
    const MatPtrT<SfpStream> mat(*base);
    return func(&mat, std::forward<Args>(args)...);
  } else {
    HWY_ABORT("Unhandled type %s.", TypeName(base->GetType()));
  }
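A small usage sketch for the reworked `CallUpcasted`: the caller passes a generic lambda and receives a typed `const MatPtrT<T>*` matching the runtime `Type`. The free function below is illustrative, not part of the API; `PackedBytes()` is a `MatPtr` accessor already used elsewhere in this PR.

// Sketch only: dispatch on the runtime element type and run a generic lambda with the
// correctly typed pointer.
size_t SketchPackedBytes(const MatPtr& base) {
  return CallUpcasted(&base, [](const auto* mat) { return mat->PackedBytes(); });
}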
@@ -323,24 +332,24 @@ decltype(auto) CallUpcastedSame(const MatPtr* base1, const MatPtr* base2,

#if GEMMA_ENABLE_NUQ
  if (base1->GetType() == Type::kNUQ) {
    return func(dynamic_cast<const MatPtrT<NuqStream>*>(base1),
                dynamic_cast<const MatPtrT<NuqStream>*>(base2),
                std::forward<Args>(args)...);
    const MatPtrT<NuqStream> mat1(*base1);
    const MatPtrT<NuqStream> mat2(*base2);
    return func(&mat1, &mat2, std::forward<Args>(args)...);
  }
#endif  // GEMMA_ENABLE_NUQ

  if (base1->GetType() == Type::kF32) {
    return func(dynamic_cast<const MatPtrT<float>*>(base1),
                dynamic_cast<const MatPtrT<float>*>(base2),
                std::forward<Args>(args)...);
    const MatPtrT<float> mat1(*base1);
    const MatPtrT<float> mat2(*base2);
    return func(&mat1, &mat2, std::forward<Args>(args)...);
  } else if (base1->GetType() == Type::kBF16) {
    return func(dynamic_cast<const MatPtrT<BF16>*>(base1),
                dynamic_cast<const MatPtrT<BF16>*>(base2),
                std::forward<Args>(args)...);
    const MatPtrT<BF16> mat1(*base1);
    const MatPtrT<BF16> mat2(*base2);
    return func(&mat1, &mat2, std::forward<Args>(args)...);
  } else if (base1->GetType() == Type::kSFP) {
    return func(dynamic_cast<const MatPtrT<SfpStream>*>(base1),
                dynamic_cast<const MatPtrT<SfpStream>*>(base2),
                std::forward<Args>(args)...);
    const MatPtrT<SfpStream> mat1(*base1);
    const MatPtrT<SfpStream> mat2(*base2);
    return func(&mat1, &mat2, std::forward<Args>(args)...);
  } else {
    HWY_ABORT("Unhandled type %s.", TypeName(base1->GetType()));
  }