mirror of https://github.com/google/gemma.cpp.git
Further cleanup: separate MatMulEnv arg

move row_ptrs into MatMulEnv
Consistent arg order: layer, activations, kv_cache, env

PiperOrigin-RevId: 767886386
parent e774ddbaaa
commit 6ee628ba38
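
Before the per-file diff, here is a minimal compilable sketch of the calling convention this change converges on. The types below are simplified stand-ins rather than the real gcpp declarations (the real row_ptrs type is std::vector<hwy::AlignedFreeUniquePtr<uint8_t*[]>>); only the argument shapes are taken from the diff: callers construct Activations from env.row_ptrs instead of handing it a MatMulEnv*, and MatMulEnv& is passed explicitly in the order layer, activations, kv_cache, env.

// Sketch only; stand-in types, not the real gemma.cpp declarations.
#include <cstddef>
#include <cstdint>
#include <vector>

// Real code: std::vector<hwy::AlignedFreeUniquePtr<uint8_t*[]>>.
using RowPtrs = std::vector<std::vector<uint8_t*>>;

struct MatMulEnv {
  RowPtrs row_ptrs;  // owned by MatMulEnv; entry 0 is MatMul scratch
};
struct LayerWeightsPtrs {};
struct KVCache {};

struct Activations {
  // New constructor shape: receives row_ptrs instead of storing a MatMulEnv*.
  Activations(size_t /*batch_size*/, RowPtrs& /*row_ptrs*/) {}
};

// Consistent argument order from the commit message: layer, activations, kv_cache, env.
void Attention(const LayerWeightsPtrs&, Activations&, const KVCache&, MatMulEnv&) {}
void FFWNoVit(const LayerWeightsPtrs&, Activations&, MatMulEnv&) {}

int main() {
  MatMulEnv env;
  env.row_ptrs.emplace_back();  // the first (scratch) entry
  Activations activations(/*batch_size=*/4, env.row_ptrs);
  LayerWeightsPtrs layer;
  KVCache kv_cache;
  Attention(layer, activations, kv_cache, env);  // env is a separate, explicit argument
  FFWNoVit(layer, activations, env);
  return 0;
}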
BUILD.bazel (10 changed lines)

@@ -242,13 +242,6 @@ cc_test(
     ],
 )
 
-cc_library(
-    name = "common",
-    srcs = ["gemma/common.cc"],
-    hdrs = ["gemma/common.h"],
-    deps = [":configs"],
-)
-
 # For building all tests in one command, so we can test several.
 test_suite(
     name = "ops_tests",

@@ -566,11 +559,10 @@ cc_library(
     },
     deps = [
         ":benchmark_helper",
-        ":common",
         ":gemma_args",
         ":gemma_lib",
        ":kv_cache",
-        ":ops",
+        ":matmul",
         ":threading",
         ":threading_context",
         ":tokenizer",

CMakeLists.txt

@@ -55,8 +55,6 @@ set(SOURCES
   gemma/activations.h
   gemma/attention.cc
   gemma/attention.h
-  gemma/common.cc
-  gemma/common.h
   gemma/configs.cc
   gemma/configs.h
   gemma/gemma_args.h

gemma/activations.h

@@ -41,13 +41,13 @@ static inline float ChooseQueryScale(const ModelConfig& config) {
 }
 
 struct Activations {
-  Activations(const ModelConfig& config, size_t batch_size, MatMulEnv* env)
+  Activations(const ModelConfig& config, size_t batch_size,
+              std::vector<hwy::AlignedFreeUniquePtr<uint8_t*[]>>& row_ptrs)
       : weights_config(config),
         layer_config(config.layer_configs[0]),
         seq_len(config.seq_len),
         cache_pos_size(config.CachePosSize()),
         is_griffin(config.model == Model::GRIFFIN_2B),
-        query_scale(ChooseQueryScale(config)),
 
         x("x", Extents2D(batch_size, config.model_dim), pad_),
         // `vocab_size == 0` means it is for Vit part, VitAttention is still MHA

@@ -96,19 +96,19 @@ struct Activations {
             layer_config.qkv_dim, layer_config.post_qk == PostQKType::HalfRope,
             1000000.0)),
 
-        env(env) {
+        query_scale(ChooseQueryScale(config)) {
     HWY_ASSERT(batch_size != 0);
 
     // For MatMul outputs, precompute their row pointers.
     // If we forget any MatMul outputs here, debug builds print a warning but
     // fill them in each MatMul call.
-    x.AllocateAndAttachRowPtrs(env->row_ptrs);
-    q.AllocateAndAttachRowPtrs(env->row_ptrs);
-    logits.AllocateAndAttachRowPtrs(env->row_ptrs);
-    att_sums.AllocateAndAttachRowPtrs(env->row_ptrs);
-    C1.AllocateAndAttachRowPtrs(env->row_ptrs);
-    C2.AllocateAndAttachRowPtrs(env->row_ptrs);
-    ffw_out.AllocateAndAttachRowPtrs(env->row_ptrs);
+    x.AllocateAndAttachRowPtrs(row_ptrs);
+    q.AllocateAndAttachRowPtrs(row_ptrs);
+    logits.AllocateAndAttachRowPtrs(row_ptrs);
+    att_sums.AllocateAndAttachRowPtrs(row_ptrs);
+    C1.AllocateAndAttachRowPtrs(row_ptrs);
+    C2.AllocateAndAttachRowPtrs(row_ptrs);
+    ffw_out.AllocateAndAttachRowPtrs(row_ptrs);
 
     // Note that BindC on any MatMul output considerably slows down Prefill.
   }

@@ -141,7 +141,6 @@ struct Activations {
   size_t seq_len;
   size_t cache_pos_size = 0;  // TODO: after moving KVCache to MatStorageT.
   bool is_griffin = false;
-  float query_scale;
   const Extents2D none_ = Extents2D();
   const MatPadding pad_ = MatPadding::kOdd;
 

@@ -172,7 +171,7 @@ struct Activations {
   MatStorageT<float> inv_timescale;
   MatStorageT<float> inv_timescale_global;
 
-  MatMulEnv* env;
+  float query_scale;
 };
 
 }  // namespace gcpp

gemma/attention.cc

@@ -222,7 +222,7 @@ static HWY_INLINE void ComputeQKV(
     size_t num_tokens, const QueriesPos& queries_pos,
     const hwy::Divisor& div_seq_len, const size_t layer_idx,
     const LayerWeightsPtrs& layer, Activations& activations,
-    const KVCaches& kv_caches, const int flags, NestedPools& pools) {
+    const KVCaches& kv_caches, const int flags, MatMulEnv& env) {
   PROFILER_ZONE("Gen.Attention.QKV");
   const size_t num_queries = queries_pos.size();
   const size_t num_interleaved = num_tokens * num_queries;

@@ -235,7 +235,7 @@ static HWY_INLINE void ComputeQKV(
   // The original qkv_einsum_w has shape [(heads + kv_heads * 2), qkv_dim,
   // model_dim], which we reshaped to (heads + kv_heads * 2) * qkv_dim rows.
   CallMatMul(activations.pre_att_rms_out, layer.qkv_einsum_w1,
-             /*add=*/nullptr, *activations.env, activations.q);
+             /*add=*/nullptr, env, activations.q);
 
   // Set up MatMul row pointers for writing to KV, which consists of
   // `kv_heads` pairs of (k, v) vectors. This safely handles wraparound

@@ -250,17 +250,16 @@ static HWY_INLINE void ComputeQKV(
           div_seq_len.Remainder(queries_pos[query_idx] + batch_idx);
       const size_t kv_offset =
           cache_pos * cache_pos_size + layer_idx * cache_layer_size;
-      activations.env->storage.OutRow(interleaved_idx) =
-          reinterpret_cast<uint8_t*>(kv_caches[query_idx].kv_cache.get() +
-                                     kv_offset);
+      env.row_ptrs[0][interleaved_idx] = reinterpret_cast<uint8_t*>(
+          kv_caches[query_idx].kv_cache.get() + kv_offset);
     }
-  kv_rows.AttachRowPtrs(&activations.env->storage.OutRow(0));
+  kv_rows.AttachRowPtrs(env.row_ptrs[0].get());
   CallMatMul(activations.pre_att_rms_out, layer.qkv_einsum_w2,
-             /*add=*/nullptr, *activations.env, kv_rows);
+             /*add=*/nullptr, env, kv_rows);
 
   // Apply positional encodings for K.
   // TODO: 2D parallelism to use more threads.
-  pools.Pool(0).Run(
+  env.ctx.pools.Pool(0).Run(
       0, kv_heads * num_interleaved,
       [&](uint64_t task, size_t /*thread*/) HWY_ATTR {
         const size_t head = task % kv_heads;

@@ -289,7 +288,7 @@ static HWY_INLINE void ComputeQKV(
 // Sums encoded (`att_out`) over num_heads (`layer_config.heads`) and
 // head_dim (`qkv_dim`) into output (`layer_out`).
 static HWY_INLINE void SumHeads(const LayerWeightsPtrs& layer,
-                                Activations& activations) {
+                                Activations& activations, MatMulEnv& env) {
   PROFILER_ZONE("Gen.Attention.SumHeads");
   const LayerConfig& layer_config = layer.layer_config;
   // att_weights and att_out are concatenated heads, each of length

@@ -302,7 +301,7 @@ static HWY_INLINE void SumHeads(const LayerWeightsPtrs& layer,
   const float* add = layer_config.softmax_attn_output_biases
                          ? layer.attention_output_biases.PackedScale1()
                          : nullptr;
-  CallMatMul(activations.att_out, layer.att_weights, add, *activations.env,
+  CallMatMul(activations.att_out, layer.att_weights, add, env,
              activations.att_sums);
 }
 

@@ -312,7 +311,7 @@ void GemmaAttention(size_t num_tokens, const QueriesPos& queries_pos,
                     const QueriesPos* queries_prefix_end,
                     const hwy::Divisor& div_seq_len, const size_t layer_idx,
                     const LayerWeightsPtrs& layer, Activations& activations,
-                    const KVCaches& kv_caches, int flags) {
+                    const KVCaches& kv_caches, MatMulEnv& env, int flags) {
   const size_t num_queries = queries_pos.size();
   HWY_DASSERT(num_queries <= kv_caches.size());
 

@@ -331,13 +330,12 @@ void GemmaAttention(size_t num_tokens, const QueriesPos& queries_pos,
     queries_prefix_end = &queries_prefix_end_span;
   }
 
-  NestedPools& pools = activations.env->ctx.pools;
   ComputeQKV(num_tokens, queries_pos, div_seq_len, layer_idx, layer,
-             activations, kv_caches, flags, pools);
+             activations, kv_caches, flags, env);
   DotSoftmaxWeightedSum(num_tokens, queries_pos, *queries_prefix_end,
                         div_seq_len, layer_idx, layer, activations, kv_caches,
-                        pools);
-  SumHeads(layer, activations);
+                        env.ctx.pools);
+  SumHeads(layer, activations, env);
 }
 
 // NOLINTNEXTLINE(google-readability-namespace-comments)

gemma/attention.h

@@ -47,7 +47,7 @@ namespace gcpp {
                       const QueriesPos* queries_prefix_end,                 \
                       const hwy::Divisor& div_seq_len, const size_t layer_idx, \
                       const LayerWeightsPtrs& layer, Activations& activations, \
-                      const KVCaches& kv_caches, int flags);                 \
+                      const KVCaches& kv_caches, MatMulEnv& env, int flags); \
   /* NOLINTNEXTLINE(google-readability-namespace-comments) */                \
   }  // namespace NAMESPACE
 

@@ -29,7 +29,6 @@
 #include <stdio.h>
 #endif
 
-#include "gemma/common.h"
 #include "gemma/gemma.h"
 #include "gemma/gemma_args.h"
 #include "ops/matmul.h"  // MatMulEnv

gemma/common.cc (deleted)

@@ -1,38 +0,0 @@
-// Copyright 2024 Google LLC
-// SPDX-License-Identifier: Apache-2.0
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-//     https://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#include "gemma/common.h"
-
-#include <stddef.h>
-
-#include <string>
-
-#include "gemma/configs.h"
-
-namespace gcpp {
-
-void Wrap(const ModelConfig& config, size_t pos, std::string& prompt) {
-
-  // Instruction-tuned models are trained to expect control tokens.
-  if (config.wrapping == PromptWrapping::GEMMA_IT) {
-    // Prepend "<end_of_turn>" if this is a multi-turn dialogue continuation.
-    const std::string start = (pos == 0)
-                                  ? "<start_of_turn>user\n"
-                                  : "<end_of_turn>\n<start_of_turn>user\n";
-    prompt = start + prompt + "<end_of_turn>\n<start_of_turn>model\n";
-  }
-}
-
-}  // namespace gcpp

gemma/common.h (deleted)

@@ -1,33 +0,0 @@
-// Copyright 2024 Google LLC
-// SPDX-License-Identifier: Apache-2.0
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-//     https://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#ifndef THIRD_PARTY_GEMMA_CPP_GEMMA_COMMON_H_
-#define THIRD_PARTY_GEMMA_CPP_GEMMA_COMMON_H_
-
-#include <stddef.h>
-
-#include <string>
-
-#include "gemma/configs.h"  // IWYU pragma: export
-
-namespace gcpp {
-
-// Wraps the given prompt using the expected control tokens for IT models.
-// DEPRECATED, use WrapAndTokenize instead if a tokenized return value is fine.
-void Wrap(const ModelConfig& config, size_t pos, std::string& prompt);
-
-}  // namespace gcpp
-
-#endif  // THIRD_PARTY_GEMMA_CPP_GEMMA_COMMON_H_

@@ -18,7 +18,10 @@
#include <stddef.h>
#include <stdint.h>

#include "gemma/activations.h"
#include "gemma/configs.h"
#include "gemma/weights.h"
#include "ops/matmul.h"
#include "util/mat.h"
#include "hwy/profiler.h"

@@ -103,8 +106,8 @@ void PostNorm(PostNormType post_norm, const MatPtr& weights,
   }
 }
 
-static inline void FFWNoVit(Activations& activations,
-                            const LayerWeightsPtrs& layer) {
+static inline void FFWNoVit(const LayerWeightsPtrs& layer,
+                            Activations& activations, MatMulEnv& env) {
   PROFILER_ZONE("Gen.FFW");
   const LayerConfig& layer_config = layer.layer_config;
   const size_t ffh_hidden_dim = layer_config.ff_hidden_dim;

@@ -117,16 +120,16 @@ static inline void FFWNoVit(Activations& activations,
       add_bias ? layer.ffw_output_biases.PackedScale1() : nullptr;
 
   // Compute the hidden layer activations.
-  CallMatMul(activations.pre_ffw_rms_out, layer.gating_einsum_w1, bias1,
-             *activations.env, activations.C1);
-  CallMatMul(activations.pre_ffw_rms_out, layer.gating_einsum_w2, bias2,
-             *activations.env, activations.C2);
+  CallMatMul(activations.pre_ffw_rms_out, layer.gating_einsum_w1, bias1, env,
+             activations.C1);
+  CallMatMul(activations.pre_ffw_rms_out, layer.gating_einsum_w2, bias2, env,
+             activations.C2);
 
   // Activation (Gelu) and maybe multiply by gate. Store activations in act.
   ActivationBatched(layer_config.activation, activations.C1, &activations.C2);
 
   // Hidden layer -> output layer.
-  CallMatMul(activations.C1, layer.linear_w, output_bias, *activations.env,
+  CallMatMul(activations.C1, layer.linear_w, output_bias, env,
              activations.ffw_out);
 }
 

gemma/gemma.cc (121 changed lines)

@@ -68,10 +68,10 @@ void Attention(LayerAttentionType type, size_t num_tokens,
                const QueriesPos& queries_prefix_end,
                const hwy::Divisor& div_seq_len, const size_t layer_idx,
                const LayerWeightsPtrs& layer, Activations& activations,
-               const KVCaches& kv_caches) {
+               const KVCaches& kv_caches, MatMulEnv& env) {
   if (type == LayerAttentionType::kGemma) {
     GemmaAttention(num_tokens, queries_pos, &queries_prefix_end, div_seq_len,
-                   layer_idx, layer, activations, kv_caches,
+                   layer_idx, layer, activations, kv_caches, env,
                    /*flags=*/0);
   } else {
     HWY_DASSERT(type == LayerAttentionType::kGriffinRecurrentBlock);

@@ -80,7 +80,7 @@ void Attention(LayerAttentionType type, size_t num_tokens,
     const size_t griffin_layer =
         activations.weights_config.NumLayersOfTypeBefore(type, layer_idx);
     GriffinRecurrent(queries_pos, num_tokens, griffin_layer, activations,
-                     &layer, kv_caches);
+                     &layer, kv_caches, env);
   }
 }
 

@@ -88,14 +88,14 @@ static HWY_NOINLINE void TransformerLayer(
     const size_t num_tokens, const QueriesPos& queries_pos,
     const QueriesPos& queries_prefix_end, const hwy::Divisor& div_seq_len,
     const size_t layer_idx, const LayerWeightsPtrs& layer,
-    Activations& activations, const KVCaches& kv_caches) {
+    Activations& activations, const KVCaches& kv_caches, MatMulEnv& env) {
   const LayerConfig& layer_config = layer.layer_config;
 
   RMSNormBatched(activations.x, layer.pre_attention_norm_scale,
                  activations.pre_att_rms_out);
 
   Attention(layer_config.type, num_tokens, queries_pos, queries_prefix_end,
-            div_seq_len, layer_idx, layer, activations, kv_caches);
+            div_seq_len, layer_idx, layer, activations, kv_caches, env);
 
   PostNorm(layer_config.post_norm, layer.post_attention_norm_scale,
            activations.att_sums);

@@ -107,9 +107,9 @@ static HWY_NOINLINE void TransformerLayer(
                  activations.pre_ffw_rms_out);
 
   if (layer_config.type == LayerAttentionType::kVit) {
-    FFWVit(activations, layer);
+    FFWVit(layer, activations, env);
   } else {
-    FFWNoVit(activations, layer);
+    FFWNoVit(layer, activations, env);
   }
 
   PostNorm(layer_config.post_norm, layer.post_ffw_norm_scale,

@@ -187,12 +187,11 @@ using QueriesMutablePos = hwy::Span<size_t>;
 
 // Populates KV cache for batches of tokens from one query at a time.
 static HWY_NOINLINE void Prefill(
-    const QueriesPromptTokens& queries_prompt,
+    const size_t query_idx_start, const QueriesPromptTokens& queries_prompt,
     const QueriesMutablePos& queries_pos, const QueriesPos& queries_prefix_end,
-    const size_t query_idx_start, const ModelConfig& config,
-    const ModelWeightsPtrs& weights, Activations& activations,
-    const RuntimeConfig& runtime_config, const hwy::Divisor& div_seq_len,
-    const KVCaches& kv_caches) {
+    const hwy::Divisor& div_seq_len, const ModelConfig& config,
+    const RuntimeConfig& runtime_config, const ModelWeightsPtrs& weights,
+    Activations& activations, const KVCaches& kv_caches, MatMulEnv& env) {
   PROFILER_ZONE("Gen.Prefill");
   const size_t num_queries = queries_prompt.size();
   HWY_DASSERT(queries_pos.size() == num_queries);

@@ -263,7 +262,7 @@ static HWY_NOINLINE void Prefill(
          ++layer_idx) {
       TransformerLayer(tbatch_size, single_query_pos, single_query_prefix_end,
                        div_seq_len, layer_idx, *weights.GetLayer(layer_idx),
-                       activations, single_kv_cache);
+                       activations, single_kv_cache, env);
     }
 
     // NOTE: we unconditionally call StreamToken, even if EOS.

@@ -298,9 +297,9 @@ static HWY_NOINLINE void Prefill(
 // from each query, and `queries_pos` are their position in the sequence.
 static HWY_NOINLINE void Transformer(
     const QueriesToken& queries_token, const QueriesMutablePos& queries_pos,
-    const QueriesPos& queries_prefix_end, const ModelConfig& config,
-    const ModelWeightsPtrs& weights, Activations& activations,
-    const hwy::Divisor& div_seq_len, const KVCaches& kv_caches,
+    const QueriesPos& queries_prefix_end, const hwy::Divisor& div_seq_len,
+    const ModelConfig& config, const ModelWeightsPtrs& weights,
+    Activations& activations, const KVCaches& kv_caches, MatMulEnv& env,
     const LayersOutputFunc& layers_output,
     const ActivationsObserverFunc& activations_observer) {
   const size_t num_queries = queries_token.size();

@@ -323,7 +322,7 @@ static HWY_NOINLINE void Transformer(
   for (size_t layer_idx = 0; layer_idx < weights.c_layers.size(); ++layer_idx) {
     TransformerLayer(/*num_tokens=*/1, queries_pos, queries_prefix_end,
                      div_seq_len, layer_idx, *weights.GetLayer(layer_idx),
-                     activations, kv_caches);
+                     activations, kv_caches, env);
 
     if (activations_observer) {
       activations_observer(queries_pos, layer_idx, activations);

@@ -381,23 +380,22 @@ class TokenStreamer {
 
 // Runs one decode step for all the queries in the batch. Returns true if all
 // queries are at <end_of_sentence>.
-static bool DecodeStepT(const ModelConfig& config,
-                        const ModelWeightsPtrs& weights,
-                        const RuntimeConfig& runtime_config,
+static bool DecodeStepT(
+    const ModelConfig& config, const ModelWeightsPtrs& weights,
+    const RuntimeConfig& runtime_config, const size_t query_idx_start,
     const QueriesPromptTokens& queries_prompt,
-                        const size_t query_idx_start, const KVCaches& kv_caches,
-                        const QueriesPos& queries_prefix_end,
-                        const hwy::Divisor div_seq_len, const size_t vocab_size,
-                        const SampleFunc& sample_token,
-                        Activations& activations, TokenStreamer& token_streamer,
-                        std::vector<int>& gen_tokens, TimingInfo& timing_info,
-                        const QueriesMutablePos& queries_mutable_pos) {
+    const QueriesMutablePos& queries_mutable_pos,
+    const QueriesPos& queries_prefix_end, const hwy::Divisor div_seq_len,
+    const size_t vocab_size, const SampleFunc& sample_token,
+    Activations& activations, const KVCaches& kv_caches,
+    TokenStreamer& token_streamer, std::vector<int>& gen_tokens,
+    TimingInfo& timing_info, MatMulEnv& env) {
   const size_t num_queries = queries_prompt.size();
   // Decode generates one token per query and increments
   // queries_mutable_pos.
   Transformer(QueriesToken(gen_tokens.data(), num_queries), queries_mutable_pos,
-              queries_prefix_end, config, weights, activations, div_seq_len,
-              kv_caches, runtime_config.layers_output,
+              queries_prefix_end, div_seq_len, config, weights, activations,
+              kv_caches, env, runtime_config.layers_output,
               runtime_config.activations_observer);
   // queries_pos are incremented by Transformer.
 
@@ -407,7 +405,7 @@ static bool DecodeStepT(const ModelConfig& config,
     PROFILER_ZONE("Gen.EmbeddingMatmul");
     // Compute logits from last layer activations.
     CallMatMul(activations.x, weights.embedder_input_embedding,
-               /*add=*/nullptr, *activations.env, activations.logits);
+               /*add=*/nullptr, env, activations.logits);
   }
   PROFILER_ZONE("Gen.Softcap+Sample+Stream");
   for (size_t query_idx = 0; query_idx < num_queries; ++query_idx) {

@@ -468,14 +466,12 @@ static size_t MaxQueryLength(const QueriesPromptTokens& queries_prompt) {
 // `StreamFunc` gets the global query index, not relative to the batch.
 //
 // `kv_caches` is for the batch, size must match `queries_prompt`.
-static void GenerateT(const ModelConfig& config,
-                      const ModelWeightsPtrs& weights, Activations& activations,
-                      const RuntimeConfig& runtime_config,
-                      const QueriesPromptTokens& queries_prompt,
-                      const QueriesPos& queries_pos_in,
-                      const QueriesPos& queries_prefix_end,
-                      const size_t query_idx_start, const KVCaches& kv_caches,
-                      TimingInfo& timing_info) {
+static void GenerateT(
+    const ModelConfig& config, const ModelWeightsPtrs& weights,
+    const RuntimeConfig& runtime_config, const size_t query_idx_start,
+    const QueriesPromptTokens& queries_prompt, const QueriesPos& queries_pos_in,
+    const QueriesPos& queries_prefix_end, Activations& activations,
+    const KVCaches& kv_caches, TimingInfo& timing_info, MatMulEnv& env) {
   HWY_ASSERT(queries_pos_in.size() == kv_caches.size());
 
   // Griffin assumes that the recurrent block cache is zero-initialized.

@@ -512,9 +508,9 @@ static void GenerateT(const ModelConfig& config,
   // token is the first input token for generation.
   timing_info.prefill_start = hwy::platform::Now();
   // Note that Prefill calls activations.SetBatchSize, so we reset it below.
-  Prefill(queries_prompt, queries_mutable_pos, queries_prefix_end,
-          query_idx_start, config, weights, activations, runtime_config,
-          div_seq_len, kv_caches);
+  Prefill(query_idx_start, queries_prompt, queries_mutable_pos,
+          queries_prefix_end, div_seq_len, config, runtime_config, weights,
+          activations, kv_caches, env);
   // Compute the number of tokens that were prefilled and notify timing_info.
   size_t prefilled_tokens = 0;
   for (size_t qi = 0; qi < num_queries; ++qi) {

@@ -543,11 +539,12 @@ static void GenerateT(const ModelConfig& config,
   const size_t vocab_size = config.vocab_size;
   timing_info.generate_start = hwy::platform::Now();
   for (size_t gen = 0; gen < max_generated_tokens; ++gen) {
-    bool all_queries_eos = DecodeStepT(
-        config, weights, runtime_config, queries_prompt, query_idx_start,
-        kv_caches, queries_prefix_end, div_seq_len, vocab_size, sample_token,
-        activations, token_streamer, gen_tokens, timing_info,
-        queries_mutable_pos);
+    bool all_queries_eos =
+        DecodeStepT(config, weights, runtime_config, query_idx_start,
+                    queries_prompt, queries_mutable_pos, queries_prefix_end,
+                    div_seq_len, vocab_size, sample_token, activations,
+                    kv_caches, token_streamer, gen_tokens, timing_info, env);
+
     if (all_queries_eos) break;
   }  // foreach token to generate
   timing_info.NotifyGenerateDone();

@@ -557,7 +554,7 @@
 void GenerateSingleT(const ModelConfig& config, const ModelWeightsPtrs& weights,
                      const RuntimeConfig& runtime_config,
                      const PromptTokens& prompt, size_t pos, size_t prefix_end,
-                     KVCache& kv_cache, MatMulEnv* env,
+                     KVCache& kv_cache, MatMulEnv& env,
                      TimingInfo& timing_info) {
   constexpr size_t kNumQueries = 1;
   const size_t qbatch_start = 0;

@@ -565,16 +562,16 @@ void GenerateSingleT(const ModelConfig& config, const ModelWeightsPtrs& weights,
   const size_t max_batch_size =
       HWY_MAX(kNumQueries, runtime_config.prefill_tbatch_size);
   // TODO: move into Gemma?
-  Activations activations(config, max_batch_size, env);
+  Activations activations(config, max_batch_size, env.row_ptrs);
 
   const QueriesPromptTokens queries_prompt(&prompt, kNumQueries);
   QueriesPos queries_pos(&pos, kNumQueries);
   const QueriesPos queries_prefix_end(&prefix_end, kNumQueries);
   const KVCaches kv_caches{&kv_cache, kNumQueries};
 
-  GenerateT(config, weights, activations, runtime_config, queries_prompt,
-            queries_pos, queries_prefix_end, qbatch_start, kv_caches,
-            timing_info);
+  GenerateT(config, weights, runtime_config, qbatch_start, queries_prompt,
+            queries_pos, queries_prefix_end, activations, kv_caches,
+            timing_info, env);
 }
 
 void GenerateBatchT(const ModelConfig& config, const ModelWeightsPtrs& weights,

@@ -582,7 +579,7 @@ void GenerateBatchT(const ModelConfig& config, const ModelWeightsPtrs& weights,
                     const QueriesPromptTokens& queries_prompt,
                     const QueriesPos& queries_pos,
                     const QueriesPos& queries_prefix_end,
-                    const KVCaches& kv_caches, MatMulEnv* env,
+                    const KVCaches& kv_caches, MatMulEnv& env,
                     TimingInfo& timing_info) {
   const size_t num_queries = queries_prompt.size();
   HWY_ASSERT(queries_pos.size() == num_queries);

@@ -591,7 +588,7 @@ void GenerateBatchT(const ModelConfig& config, const ModelWeightsPtrs& weights,
   const size_t max_batch_size =
       HWY_MAX(max_qbatch_size, runtime_config.prefill_tbatch_size);
 
-  Activations activations(config, max_batch_size, env);
+  Activations activations(config, max_batch_size, env.row_ptrs);
 
   for (size_t qbatch_start = 0; qbatch_start < num_queries;
        qbatch_start += max_qbatch_size) {

@@ -604,9 +601,9 @@ void GenerateBatchT(const ModelConfig& config, const ModelWeightsPtrs& weights,
     const QueriesPos qbatch_prefix_end(&queries_prefix_end[qbatch_start],
                                        qbatch_size);
     const KVCaches qbatch_kv(&kv_caches[qbatch_start], qbatch_size);
-    GenerateT(config, weights, activations, runtime_config, qbatch_prompts,
-              qbatch_pos, qbatch_prefix_end, qbatch_start, qbatch_kv,
-              timing_info);
+    GenerateT(config, weights, runtime_config, qbatch_start, qbatch_prompts,
+              qbatch_pos, qbatch_prefix_end, activations, qbatch_kv,
+              timing_info, env);
   }
 }
 

@@ -614,7 +611,7 @@ void GenerateImageTokensT(const ModelConfig& config,
                           const ModelWeightsPtrs& weights,
                           const RuntimeConfig& runtime_config,
                           const Image& image, ImageTokens& image_tokens,
-                          MatMulEnv* env) {
+                          MatMulEnv& env) {
   if (config.vit_config.layer_configs.empty()) {
     HWY_ABORT("Model does not support generating image tokens.");
   }

@@ -622,10 +619,10 @@ void GenerateImageTokensT(const ModelConfig& config,
   ModelConfig vit_config = GetVitConfig(config);
   prefill_runtime_config.prefill_tbatch_size =
       vit_config.seq_len / (vit_config.pool_dim * vit_config.pool_dim);
-  Activations prefill_activations(vit_config, vit_config.seq_len, env);
+  Activations prefill_activations(vit_config, vit_config.seq_len, env.row_ptrs);
   // Weights are for the full PaliGemma model, not just the ViT part.
   PrefillVit(config, weights, prefill_runtime_config, image, image_tokens,
-             prefill_activations);
+             prefill_activations, env);
 }
 
 // NOLINTNEXTLINE(google-readability-namespace-comments)

@@ -677,7 +674,7 @@ void Gemma::Generate(const RuntimeConfig& runtime_config,
 
   HWY_DYNAMIC_DISPATCH(GenerateSingleT)(model_.Config(), weights_,
                                         runtime_config, prompt, pos, prefix_end,
-                                        kv_cache, &env_, timing_info);
+                                        kv_cache, env_, timing_info);
 
   env_.ctx.pools.MaybeStopSpinning(runtime_config.use_spinning);
 }

@@ -701,7 +698,7 @@ void Gemma::GenerateBatch(const RuntimeConfig& runtime_config,
 
   HWY_DYNAMIC_DISPATCH(GenerateBatchT)(
       model_.Config(), weights_, runtime_config, queries_prompt, queries_pos,
-      mutable_queries_prefix_end, kv_caches, &env_, timing_info);
+      mutable_queries_prefix_end, kv_caches, env_, timing_info);
 
   env_.ctx.pools.MaybeStopSpinning(runtime_config.use_spinning);
 }

@@ -712,7 +709,7 @@ void Gemma::GenerateImageTokens(const RuntimeConfig& runtime_config,
   env_.ctx.pools.MaybeStartSpinning(runtime_config.use_spinning);
 
   HWY_DYNAMIC_DISPATCH(GenerateImageTokensT)(
-      model_.Config(), weights_, runtime_config, image, image_tokens, &env_);
+      model_.Config(), weights_, runtime_config, image, image_tokens, env_);
 
   env_.ctx.pools.MaybeStopSpinning(runtime_config.use_spinning);
 }

@@ -48,9 +48,9 @@ namespace HWY_NAMESPACE {
 void GriffinRecurrent(const QueriesPos& queries_pos, size_t num_tokens,
                       size_t griffin_layer, Activations& activations,
                       const LayerWeightsPtrs* layer_weights,
-                      const KVCaches& kv_caches) {
+                      const KVCaches& kv_caches, MatMulEnv& env) {
   PROFILER_ZONE("Gen.Griffin");
-  hwy::ThreadPool& pool = activations.env->ctx.pools.Pool(0);
+  hwy::ThreadPool& pool = env.ctx.pools.Pool(0);
   namespace hn = hwy::HWY_NAMESPACE;
   using D = hn::ScalableTag<float>;
   const D df;

@@ -183,8 +183,8 @@ void GriffinRecurrent(const QueriesPos& queries_pos, size_t num_tokens,
 
   // Final linear layer.
   CallMatMul(activations.griffin_x, layer_weights->griffin.linear_out_w,
-             layer_weights->griffin.linear_out_biases.PackedScale1(),
-             *activations.env, activations.att_sums);
+             layer_weights->griffin.linear_out_biases.PackedScale1(), env,
+             activations.att_sums);
 }  // GriffinRecurrent
 
 // NOLINTNEXTLINE(google-readability-namespace-comments)

@@ -31,7 +31,7 @@ namespace gcpp {
   void GriffinRecurrent(const QueriesPos& queries_pos, size_t num_tokens,   \
                         size_t griffin_layer, Activations& activations,     \
                         const LayerWeightsPtrs* layer_weights,              \
-                        const KVCaches& kv_caches);                         \
+                        const KVCaches& kv_caches, MatMulEnv& env);         \
   /* NOLINTNEXTLINE(google-readability-namespace-comments) */               \
   }  // namespace NAMESPACE
 

gemma/vit.cc (47 changed lines)

@@ -60,7 +60,7 @@ class VitAttention {
     HWY_ASSERT(qkv.Rows() == num_tokens_);
     HWY_ASSERT(qkv.Cols() == layer_config_.heads * 3 * layer_config_.qkv_dim);
     CallMatMul(activations_.pre_att_rms_out, layer_.vit.qkv_einsum_w,
-               layer_.vit.qkv_einsum_b.PackedScale1(), *activations_.env, qkv);
+               layer_.vit.qkv_einsum_b.PackedScale1(), env_, qkv);
   }
 
   // TODO(philculliton): transition fully to MatMul.

@@ -100,7 +100,7 @@ class VitAttention {
     });
 
     // this produces C, a (num_tokens_, seq_len) matrix of dot products
-    CallMatMul(Q, K, nullptr, *activations_.env, C);
+    CallMatMul(Q, K, nullptr, env_, C);
 
     pool_.Run(0, num_tokens_, [&](uint64_t task, size_t /*thread*/) HWY_ATTR {
       float* HWY_RESTRICT c = C.Row(task);

@@ -166,18 +166,19 @@ class VitAttention {
     // att_weights and att_out are concatenated heads, each of length
     // qkv_dim. Thus the [num_tokens_, layer_config_.model_dim]
     // matmul output is the sum over heads.
-    CallMatMul(activations_.att_out, layer_.vit.attn_out_w, bias,
-               *activations_.env, activations_.att_sums);
+    CallMatMul(activations_.att_out, layer_.vit.attn_out_w, bias, env_,
+               activations_.att_sums);
   }
 
  public:
   VitAttention(size_t num_tokens, size_t layer_idx, Activations& activations,
-               const LayerWeightsPtrs& layer)
+               const LayerWeightsPtrs& layer, MatMulEnv& env)
       : num_tokens_(num_tokens),
        activations_(activations),
        layer_(layer),
        layer_config_(layer.layer_config),
-        pool_(activations.env->ctx.pools.Pool(0)) {}
+        env_(env),
+        pool_(env_.ctx.pools.Pool(0)) {}
 
   HWY_INLINE void operator()() {
     ComputeQKV();

@@ -194,12 +195,14 @@ class VitAttention {
   Activations& activations_;
   const LayerWeightsPtrs& layer_;
   const LayerConfig& layer_config_;
+  MatMulEnv& env_;
   hwy::ThreadPool& pool_;
 };
 
 // Same as FFWNoVit, but with different layer members and no second
 // gating matrix.
-void FFWVit(Activations& activations, const LayerWeightsPtrs& layer) {
+void FFWVit(const LayerWeightsPtrs& layer, Activations& activations,
+            MatMulEnv& env) {
   PROFILER_ZONE("Gen.FFW.ViT");
   const LayerConfig& layer_config = layer.layer_config;
 

@@ -209,15 +212,15 @@ void FFWVit(Activations& activations, const LayerWeightsPtrs& layer) {
       add_bias ? layer.vit.linear_1_b.PackedScale1() : nullptr;
 
   // Compute the hidden layer activations.
-  CallMatMul(activations.pre_ffw_rms_out, layer.vit.linear_0_w, bias1,
-             *activations.env, activations.C1);
+  CallMatMul(activations.pre_ffw_rms_out, layer.vit.linear_0_w, bias1, env,
+             activations.C1);
 
   // Activation (Gelu), store in C1.
   ActivationBatched(layer_config.activation, activations.C1);
 
   // Hidden layer -> output layer.
-  CallMatMul(activations.C1, layer.vit.linear_1_w, output_bias,
-             *activations.env, activations.ffw_out);
+  CallMatMul(activations.C1, layer.vit.linear_1_w, output_bias, env,
+             activations.ffw_out);
 }
 
 // Vit transformer layer. Some comments below refer to the Vit implementation in
@@ -227,7 +230,7 @@ void FFWVit(Activations& activations, const LayerWeightsPtrs& layer) {
 // try merging this with TransformerLayer.
 void VitTransformerLayer(size_t num_tokens, const size_t layer_idx,
                          const LayerWeightsPtrs& layer,
-                         Activations& activations) {
+                         Activations& activations, MatMulEnv& env) {
   const size_t model_dim = activations.weights_config.model_dim;
   auto type = layer.layer_config.type;
   HWY_DASSERT(type == LayerAttentionType::kVit);

@@ -245,7 +248,7 @@ void VitTransformerLayer(size_t num_tokens, const size_t layer_idx,
 
   // y = out["sa"] = nn.MultiHeadDotProductAttention(...)(y, y)
   // y ~ att_sums
-  VitAttention(num_tokens, layer_idx, activations, layer)();
+  VitAttention(num_tokens, layer_idx, activations, layer, env)();
 
   // x = out["+sa"] = x + y
   AddFromBatched(activations.att_sums, x);

@@ -257,7 +260,7 @@ void VitTransformerLayer(size_t num_tokens, const size_t layer_idx,
 
   // y = out["mlp"] = MlpBlock(...)(y)
   // y ~ ffw_out
-  FFWVit(activations, layer);
+  FFWVit(layer, activations, env);
 
   // x = out["+mlp"] = x + y
   AddFromBatched(activations.ffw_out, x);

@@ -268,7 +271,8 @@ void VitTransformerLayer(size_t num_tokens, const size_t layer_idx,
 static HWY_NOINLINE void EmbedImagePatches(const Image& image,
                                            const ModelConfig& model_config,
                                            const ModelWeightsPtrs& weights,
-                                           Activations& activations) {
+                                           Activations& activations,
+                                           MatMulEnv& env) {
   const size_t model_dim = model_config.vit_config.model_dim;
   const size_t patch_width = model_config.vit_config.patch_width;
   const size_t seq_len = model_config.vit_config.seq_len;

@@ -287,8 +291,7 @@ static HWY_NOINLINE void EmbedImagePatches(const Image& image,
     image.GetPatch(i, image_patches.Row(i));
   }
   CallMatMul(image_patches, weights.vit_img_embedding_kernel,
-             weights.vit_img_embedding_bias.PackedScale1(), *activations.env,
-             activations.x);
+             weights.vit_img_embedding_bias.PackedScale1(), env, activations.x);
   // Add position embeddings.
   CallUpcastedActivation(&weights.vit_img_pos_embedding,
                          [&](const auto* weights_t) {

@@ -300,18 +303,19 @@ static HWY_NOINLINE void EmbedImagePatches(const Image& image,
 void PrefillVit(const ModelConfig& model_config,
                 const ModelWeightsPtrs& weights,
                 const RuntimeConfig& runtime_config, const Image& image,
-                ImageTokens& image_tokens, Activations& activations) {
+                ImageTokens& image_tokens, Activations& activations,
+                MatMulEnv& env) {
   PROFILER_ZONE("Gen.PrefillVit");
   const size_t num_tokens = model_config.vit_config.seq_len;
   const size_t vit_model_dim = model_config.vit_config.model_dim;
   HWY_ASSERT(num_tokens == activations.x.Rows());
   // Embed the image patches.
-  EmbedImagePatches(image, model_config, weights, activations);
+  EmbedImagePatches(image, model_config, weights, activations, env);
   // Go through all layers.
   for (size_t layer_idx = 0;
        layer_idx < model_config.vit_config.layer_configs.size(); ++layer_idx) {
     VitTransformerLayer(num_tokens, layer_idx, *weights.VitLayer(layer_idx),
-                        activations);
+                        activations, env);
   }
   // Final Layernorm.
   LayerNormBatched(activations.x, weights.vit_encoder_norm_scale,

@@ -329,8 +333,7 @@ void PrefillVit(const ModelConfig& model_config,
 
   // Apply head embedding into image_tokens of size of the LLM kModelDim.
   CallMatMul(activations.x, weights.vit_img_head_kernel,
-             weights.vit_img_head_bias.PackedScale1(), *activations.env,
-             image_tokens);
+             weights.vit_img_head_bias.PackedScale1(), env, image_tokens);
 }
 
 // NOLINTNEXTLINE(google-readability-namespace-comments)

gemma/vit.h

@@ -28,12 +28,14 @@ namespace gcpp {
 // Passed to HWY_VISIT_TARGETS; declares for one target.
 #define GEMMA_DECL_VIT(TARGET, NAMESPACE)                                    \
   namespace NAMESPACE {                                                      \
-  void FFWVit(Activations& activations, const LayerWeightsPtrs& layer);     \
+  void FFWVit(const LayerWeightsPtrs& layer, Activations& activations,      \
+              MatMulEnv& env);                                               \
                                                                              \
   void PrefillVit(const ModelConfig& model_config,                          \
                   const ModelWeightsPtrs& weights,                          \
                   const RuntimeConfig& runtime_config, const Image& image,  \
-                  ImageTokens& image_tokens, Activations& activations);     \
+                  ImageTokens& image_tokens, Activations& activations,      \
+                  MatMulEnv& env);                                           \
   /* NOLINTNEXTLINE(google-readability-namespace-comments) */               \
   }  // namespace NAMESPACE
 

@@ -756,7 +756,7 @@ class DotStats {
 
     // Compensated and Double are very accurate.
     ASSERT_LESS(kCompensated, s_muls[kCompensated].Min(), 1.0f + 2E-6f);
-    ASSERT_LESS(kCompensated, s_muls[kCompensated].Max(), 1.0f + 2E-5f);
+    ASSERT_LESS(kCompensated, s_muls[kCompensated].Max(), 1.0f + 1E-4f);
     ASSERT_LESS(kDouble, s_muls[kDouble].Min(), 1.0f + 2E-6f);
     ASSERT_LESS(kDouble, s_muls[kDouble].Max(), 1.0f + 2E-5f);
 

@@ -1324,9 +1324,9 @@ HWY_NOINLINE MMPerKey* MatMul(const MatPtrT<TA>& A, const MatPtrT<TB>& B,
     }
     HWY_DASSERT(C.HasPtr());
     for (size_t r = 0; r < C.Rows(); ++r) {
-      env.storage.OutRow(r) = reinterpret_cast<uint8_t*>(C.Row(r));
+      env.row_ptrs[0][r] = reinterpret_cast<uint8_t*>(C.Row(r));
     }
-    C_rows = CRows<TC>(&env.storage.OutRow(0));
+    C_rows = CRows<TC>(env.row_ptrs[0].get());
   }
 
   const Allocator& allocator = env.ctx.allocator;

@@ -427,6 +427,8 @@ MatMulEnv::MatMulEnv(ThreadingContext& ctx)
     : ctx(ctx), parallel(ctx), storage(ctx.allocator, parallel) {
   char cpu100[100];
   have_timer_stop = hwy::platform::HaveTimerStop(cpu100);
+
+  row_ptrs.push_back(hwy::AllocateAligned<uint8_t*>(MMStorage::kMaxM));
 }
 
 void BindB(MatPtr& B, size_t sizeof_TC, MMParallel& parallel) {

ops/matmul.h (24 changed lines)

@@ -231,10 +231,9 @@ class MMStorage {
   // Internally threaded; must not be called concurrently with the same
   // `ThreadingContext` (used via `parallel`).
   MMStorage(const Allocator& allocator, MMParallel& parallel)
-      : out_rows(hwy::AllocateAligned<uint8_t*>(kMaxM)),
-        // Per-worker copies of `partial` would be wasteful. We instead allocate
-        // one instance of the maximum matrix extents because threads write at
-        // false-sharing-free granularity.
+      :  // Per-worker copies of `partial` would be wasteful. We instead
+         // allocate one instance of the maximum matrix extents because threads
+         // write at false-sharing-free granularity.
         partial_storage_("partial_storage", Extents2D(kMaxM, kMaxN),
                          MatPadding::kOdd),
        // Same stride independent of the actual C.Cols() so we can pre-bind.

@@ -260,8 +259,6 @@ class MMStorage {
     BindC(partial_storage_, parallel);
   }
 
-  uint8_t*& OutRow(size_t row_idx) { return out_rows[row_idx]; }
-
   // Returns per-package matrix view.
   RowPtrBF A(size_t pkg_idx, const Extents2D& extents) const {
     HWY_DASSERT(extents.rows <= kMaxM);

@@ -273,11 +270,6 @@ class MMStorage {
   RowPtrD Partial() const { return partial_; }
 
  private:
-  // Enables arbitrary output rows. Most callers pass `RowPtr`, which assumes a
-  // constant stride, but GemmaAttention::ComputeQKV writes to differing KV
-  // positions per query / output row. `kMaxM` elements are too large for the
-  // stack, hence dynamic allocation.
-  hwy::AlignedFreeUniquePtr<uint8_t*[]> out_rows;
   std::unique_ptr<MatStorageT<BF16>> pkg_A_[MMParallel::kMaxPackages];
   MatStorageT<double> partial_storage_;
   RowPtrD partial_;

@@ -667,9 +659,13 @@ struct MatMulEnv {
   MMKeys keys;
   std::vector<MMPerKey> per_key;
 
-  // Pass to MatPtr::AllocateAndAttachRowPtrs.
-  // Per-tensor allocations to make it likelier that asan detects bugs such as
-  // use after free, overrun, and dangling references.
+  // Storage for arbitrary output rows, see `MatPtr::AllocateAndAttachRowPtrs`.
+  // Most MatMul callers use strided MatPtr, but GemmaAttention::ComputeQKV
+  // writes to differing KV positions per query / output row.
+  // The first entry is sufficient for any MatMul, but also potentially
+  // overwritten by each MatMul. Subsequent entries are precomputed for tensors
+  // and not overwritten. Per-tensor allocations make it likelier that asan
+  // detects bugs such as use after free, overrun, and dangling references.
   std::vector<hwy::AlignedFreeUniquePtr<uint8_t*[]>> row_ptrs;
 };
 
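
Finally, a rough sketch of the row-pointer ownership that the updated ops/matmul.h comment describes, again with simplified stand-in types (kMaxRows and MatMulEnvSketch are illustrative names, not the real MMStorage::kMaxM or MatMulEnv): the environment allocates a first pointer array that any MatMul may overwrite, which is also what ComputeQKV uses to point each output row at an arbitrary KV-cache position; per-tensor arrays added later are filled once and left alone.

// Sketch only; the real code uses hwy::AllocateAligned and MMStorage::kMaxM.
#include <cstddef>
#include <cstdint>
#include <memory>
#include <vector>

constexpr std::size_t kMaxRows = 4096;  // stand-in for MMStorage::kMaxM

struct MatMulEnvSketch {
  // row_ptrs[0]: scratch; overwritten by any MatMul whose output lacks
  // precomputed row pointers, and reused by ComputeQKV for KV rows.
  // row_ptrs[1..]: one array per MatMul output tensor, precomputed once.
  std::vector<std::unique_ptr<uint8_t*[]>> row_ptrs;

  MatMulEnvSketch() { row_ptrs.push_back(std::make_unique<uint8_t*[]>(kMaxRows)); }
};

// What ComputeQKV now does per interleaved token: point row `r` of the MatMul
// output at an arbitrary destination (here a KV-cache slot), then hand
// row_ptrs[0] to the output matrix before running MatMul.
inline void SetOutputRow(MatMulEnvSketch& env, std::size_t r, float* destination) {
  env.row_ptrs[0][r] = reinterpret_cast<uint8_t*>(destination);
}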