mirror of https://github.com/google/gemma.cpp.git
Introduce attention implementation configurability.
PiperOrigin-RevId: 828971705
This commit is contained in:
parent 091b4567c9
commit 35e9f9f05f
@@ -524,6 +524,7 @@ cc_library(
     deps = [
         ":args",
         ":basics",
+        ":configs",
         ":mat",
         "//io",
         "@highway//:hwy",

@@ -24,6 +24,7 @@
 #include <vector>

 #include "gemma/configs.h"  // ModelConfig
+#include "gemma/gemma_args.h"  // AttentionImpl
 #include "ops/ops.h"  // CreateInvTimescale
 #include "util/basics.h"  // BF16
 #include "util/mat.h"  // MatStorageT

@@ -179,8 +180,8 @@ struct AttentionActivationsPtrs {
 };

 struct Activations {
-  Activations(const ModelConfig& config, size_t batch_size, size_t seq_len,
-              ThreadingContext& ctx,
+  Activations(const RuntimeConfig& runtime_config, const ModelConfig& config,
+              size_t batch_size, size_t seq_len, ThreadingContext& ctx,
               std::vector<hwy::AlignedFreeUniquePtr<uint8_t*[]>>& row_ptrs)
       : layer_config(config.layer_configs[0]),

@@ -199,6 +200,7 @@ struct Activations {
         ffw_out(
             MatFactory("ffw_out", batch_size, config.model_dim, ctx.allocator)),

+        attention_impl(runtime_config.attention_impl),
         attention_storage(config, layer_config, batch_size, seq_len,
                           ctx.allocator, row_ptrs),
         attention(config, seq_len, attention_storage) {

@@ -248,6 +250,8 @@ struct Activations {
   MatStorageT<BF16> C2;
   MatStorageT<float> ffw_out;

+  AttentionImpl attention_impl;
+
   AttentionActivations attention_storage;
   AttentionActivationsPtrs attention;
 };

@@ -80,6 +80,34 @@ static inline bool EnumValid(LayerAttentionType type) {
   return type == LayerAttentionType::kGemma || type == LayerAttentionType::kVit;
 }

+enum class AttentionImpl {
+  kOld,
+  kFlash,
+};
+
+/*
+ * Returns a bitmask of flags to pass to attention functions based on the
+ * attention implementation selected.
+ *
+ * If `hwy_native_dot_bf16` is true, the function will use the old attention
+ * implementation, ignoring `impl`.
+ *
+ * `hwy_native_dot_bf16` needs to be passed in, because the HWY_NATIVE_DOT_BF16
+ * macro is not available outside of highway instrumented translation units and
+ * cannot be made accessible from .h files.
+ */
+static inline int AttentionImplToFlags(AttentionImpl impl,
+                                       int hwy_native_dot_bf16) {
+  if (hwy_native_dot_bf16) return kAttentionUseOld;
+
+  switch (impl) {
+    case AttentionImpl::kOld:
+      return kAttentionUseOld;
+    case AttentionImpl::kFlash:
+      return 0;
+  }
+}
+
 // Post attention and ffw normalization type.
 enum class PostNormType {
   None,

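Read together with the existing kAttentionUseOld flag, the new helper amounts to a small truth table. The standalone C++ sketch below reproduces that mapping; the numeric value of kAttentionUseOld is assumed here for illustration, since its definition is not part of this diff.

#include <cstdio>

namespace sketch {

constexpr int kAttentionUseOld = 1;  // assumed placeholder value, not gemma.cpp's

enum class AttentionImpl { kOld, kFlash };

// Mirrors AttentionImplToFlags above: a true `hwy_native_dot_bf16` forces the
// old implementation regardless of the requested `impl`.
static inline int AttentionImplToFlags(AttentionImpl impl,
                                       int hwy_native_dot_bf16) {
  if (hwy_native_dot_bf16) return kAttentionUseOld;
  switch (impl) {
    case AttentionImpl::kOld:
      return kAttentionUseOld;
    case AttentionImpl::kFlash:
      return 0;
  }
  return 0;  // unreachable for valid enum values; keeps -Wreturn-type quiet
}

}  // namespace sketch

int main() {
  using sketch::AttentionImpl;
  using sketch::AttentionImplToFlags;
  std::printf("%d\n", AttentionImplToFlags(AttentionImpl::kFlash, 0));  // 0: flash path
  std::printf("%d\n", AttentionImplToFlags(AttentionImpl::kFlash, 1));  // 1: old path forced
  std::printf("%d\n", AttentionImplToFlags(AttentionImpl::kOld, 0));    // 1: old path requested
  return 0;
}
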
@@ -112,8 +112,9 @@ void TestFlashAttention(size_t target_parallelism) {
   RuntimeConfig runtime_config;
   KVCache kv_cache(config, inference_args, ctx.allocator);
   MatMulEnv env(ctx);
-  Activations activations(config, runtime_config.prefill_tbatch_size,
-                          kv_cache.SeqLen(), env.ctx, env.row_ptrs);
+  Activations activations(runtime_config, config,
+                          runtime_config.prefill_tbatch_size, kv_cache.SeqLen(),
+                          env.ctx, env.row_ptrs);
   std::vector<int> tokens(kOuter);
   std::iota(tokens.begin(), tokens.end(), 1);
   PromptTokens prompt(tokens);

@@ -73,10 +73,12 @@ namespace HWY_NAMESPACE {
 void Attention(LayerAttentionType type, const size_t num_tokens,
                const size_t layer_idx, const LayerWeightsPtrs& layer,
                Activations& activations, QBatch& qbatch, MatMulEnv& env) {
   if (type == LayerAttentionType::kGemma) {
     // TODO: remove flag to enable FlashAttention.
-    GemmaAttention(num_tokens, layer_idx, layer, activations.attention, qbatch,
-                   env, HWY_NATIVE_DOT_BF16 ? kAttentionUseOld : 0);
+    GemmaAttention(
+        num_tokens, layer_idx, layer, activations.attention, qbatch, env,
+        AttentionImplToFlags(activations.attention_impl, HWY_NATIVE_DOT_BF16));
   }
 }

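GemmaAttention's body is not part of this diff, so how it consumes the flag is only implied. The sketch below illustrates that kind of branching on a kAttentionUseOld-style bit; the kernel names and the flag value are placeholders, not gemma.cpp symbols.

#include <cstdio>

constexpr int kAttentionUseOld = 1;  // assumed placeholder flag value

static void OldAttentionKernel() { std::puts("old dot-product attention path"); }
static void FlashAttentionKernel() { std::puts("flash attention path"); }

// Branches on the bit computed by AttentionImplToFlags at the call site.
static void DispatchAttention(int flags) {
  if (flags & kAttentionUseOld) {
    OldAttentionKernel();    // forced when BF16 dot is native, or when requested
  } else {
    FlashAttentionKernel();  // default per RuntimeConfig (AttentionImpl::kFlash)
  }
}

int main() {
  DispatchAttention(/*flags=*/0);                 // flash
  DispatchAttention(/*flags=*/kAttentionUseOld);  // old
  return 0;
}
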
@@ -573,8 +575,9 @@ void GenerateSingleT(const PromptTokens& prompt, size_t pos, size_t prefix_end,
                     const AesCtrEngine& engine, const WeightsPtrs& weights,
                     KVCache& kv_cache, MatMulEnv& env,
                     TimingInfo& timing_info) {
-  Activations activations(config, runtime_config.prefill_tbatch_size,
-                          kv_cache.SeqLen(), env.ctx, env.row_ptrs);
+  Activations activations(runtime_config, config,
+                          runtime_config.prefill_tbatch_size, kv_cache.SeqLen(),
+                          env.ctx, env.row_ptrs);

   AllQueries all_queries(prompt, pos, prefix_end,
                          hwy::Span<KVCache>(&kv_cache, 1));

@@ -592,7 +595,7 @@ void GenerateBatchT(const ModelConfig& config,
                    TimingInfo& timing_info) {
   const size_t max_batch_size = HWY_MAX(runtime_config.decode_qbatch_size,
                                         runtime_config.prefill_tbatch_size);
-  Activations activations(config, max_batch_size,
+  Activations activations(runtime_config, config, max_batch_size,
                           all_queries[0].kv_cache.SeqLen(), env.ctx,
                           env.row_ptrs);

@@ -617,8 +620,8 @@ void GenerateImageTokensT(const ModelConfig& config,
   const size_t num_tokens = vit_config.max_seq_len;
   prefill_runtime_config.prefill_tbatch_size =
       num_tokens / (vit_config.pool_dim * vit_config.pool_dim);
-  Activations prefill_activations(vit_config, num_tokens, num_tokens, env.ctx,
-                                  env.row_ptrs);
+  Activations prefill_activations(runtime_config, vit_config, num_tokens,
+                                  num_tokens, env.ctx, env.row_ptrs);
   // Weights are for the full PaliGemma model, not just the ViT part.
   PrefillVit(config, weights, prefill_runtime_config, image, image_tokens,
              prefill_activations, env);

@@ -24,6 +24,7 @@
 #include <functional>
 #include <string>

+#include "gemma/configs.h"
 #include "io/io.h"  // Path
 #include "util/args.h"
 #include "util/basics.h"  // Tristate

@@ -139,6 +140,9 @@ struct RuntimeConfig {
   int verbosity;  // Controls verbosity of printed messages.

+  // Which attention implementation to use.
+  AttentionImpl attention_impl = AttentionImpl::kFlash;
+
   // Functions operating on the generated tokens.
   StreamFunc stream_token;
   BatchStreamFunc batch_stream_token;

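The new RuntimeConfig field is the user-facing switch: it defaults to the FlashAttention path and is threaded through Activations into the per-layer Attention calls. The sketch below shows how a caller would opt back into the old implementation, mirroring the updated test earlier in this commit; includes beyond gemma/gemma_args.h and the usual setup of config, inference_args, ctx and env are assumed rather than shown.

#include "gemma/gemma_args.h"  // RuntimeConfig, AttentionImpl
// Headers for ModelConfig, Activations, KVCache, MatMulEnv, ThreadingContext
// are elided; names below follow the signatures shown in this diff.

void RunWithOldAttention(const ModelConfig& config,
                         const InferenceArgs& inference_args,
                         ThreadingContext& ctx, MatMulEnv& env) {
  RuntimeConfig runtime_config;
  runtime_config.attention_impl = AttentionImpl::kOld;  // default is kFlash

  KVCache kv_cache(config, inference_args, ctx.allocator);
  Activations activations(runtime_config, config,
                          runtime_config.prefill_tbatch_size, kv_cache.SeqLen(),
                          env.ctx, env.row_ptrs);
  // Attention() calls that use these activations now take the old path
  // (unless the target's native BF16 dot support forces it anyway).
}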