mirror of https://github.com/google/gemma.cpp.git
Temporarily disable flash pending msan fix
PiperOrigin-RevId: 805350234
commit 2695aab5d2 (parent ba6131311a)
@@ -48,9 +48,6 @@ HWY_BEFORE_NAMESPACE();
 namespace gcpp {
 namespace HWY_NAMESPACE {
 
-constexpr int kFlagReserved = 1;  // LINTER: unused, reserved for future use.
-constexpr int kUseOldAttention = 2;
-
 // Computes Q.K scores, which are "logits" (or scores) stored to att.
 // `k` is a strided view of the kv cache with dimensions [seq_len, qkv_dim].
 static HWY_INLINE void QDotK(const size_t start_pos, const size_t last_pos,
@@ -357,7 +354,7 @@ void GemmaAttention(size_t num_tokens, const size_t layer_idx,
   (void)layer_config;  // only used in HWY_DASSERT
 
   ComputeQKV(num_tokens, layer_idx, layer, activations, qbatch, flags, env);
-  if (flags & kUseOldAttention) {
+  if (flags & kAttentionUseOld) {
     DotSoftmaxWeightedSum(num_tokens, layer_idx, layer, activations, qbatch,
                           env.ctx);
   } else {
@@ -32,8 +32,9 @@
 
 namespace gcpp {
 
-static constexpr size_t kMaxConv1DWidth = 4;
-static constexpr size_t kMaxQKVDim = 1024;
+HWY_INLINE_VAR constexpr int kAttentionUseOld = 2;
+
+HWY_INLINE_VAR constexpr size_t kMaxQKVDim = 1024;
 
 // Instruction-tuned models require extra 'turn structure' tokens in prompts.
 enum class PromptWrapping {
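The constants here switch from `static constexpr` to `HWY_INLINE_VAR constexpr` in what appears to be a header shared between the attention code and its callers. As a rough illustration only: assuming the macro provides C++17 inline-variable semantics (an assumption about Highway, not something this diff states), the pattern looks like the sketch below; the header name is hypothetical.

    // hypothetical_shared_config.h -- illustrative sketch, not the real header.
    #pragma once
    #include <cstddef>

    // A C++17 inline variable has one definition shared by every translation
    // unit that includes this header, which is the effect HWY_INLINE_VAR is
    // assumed to provide for these constants.
    inline constexpr int kAttentionUseOld = 2;
    inline constexpr std::size_t kMaxQKVDim = 1024;
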
@@ -73,9 +73,9 @@ void Attention(LayerAttentionType type, const size_t num_tokens,
                const size_t layer_idx, const LayerWeightsPtrs& layer,
                Activations& activations, QBatch& qbatch, MatMulEnv& env) {
   if (type == LayerAttentionType::kGemma) {
+    // TODO: remove flag to enable FlashAttention.
     GemmaAttention(num_tokens, layer_idx, layer, activations.attention, qbatch,
-                   env,
-                   /*flags=*/0);
+                   env, kAttentionUseOld);
   }
 }
 
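Taken together, this is a one-bit toggle: `kAttentionUseOld` moves into the shared header, `GemmaAttention` tests it, and the `Attention` wrapper now passes it unconditionally instead of `/*flags=*/0`, which is what temporarily disables the flash path pending the msan fix. The sketch below is a minimal standalone illustration of that gating pattern, not the gemma.cpp API: the two attention functions are hypothetical stubs, and only the flag value and the bit test mirror the diff.

    #include <cstdio>

    // Mirrors the header constant from this commit: bit 2 forces the old path.
    constexpr int kAttentionUseOld = 2;

    // Hypothetical stand-ins for the two real attention implementations.
    void OldAttentionStub() { std::puts("DotSoftmaxWeightedSum (old) path"); }
    void FlashAttentionStub() { std::puts("flash attention path"); }

    // Same shape as the gate in GemmaAttention: a bit test on `flags`
    // decides which implementation runs.
    void RunAttention(int flags) {
      if (flags & kAttentionUseOld) {
        OldAttentionStub();
      } else {
        FlashAttentionStub();
      }
    }

    int main() {
      RunAttention(kAttentionUseOld);  // what the wrapper passes after this commit
      RunAttention(/*flags=*/0);       // the pre-commit call, i.e. flash enabled
      return 0;
    }

Reverting the single call-site argument, as the TODO in the wrapper notes, is all it should take to re-enable the flash path later.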