Temporarily disable flash pending msan fix

PiperOrigin-RevId: 805350234
Author: Jan Wassenberg, 2025-09-10 07:25:07 -07:00, committed by Copybara-Service
Parent: ba6131311a
Commit: 2695aab5d2
3 changed files with 6 additions and 8 deletions

File 1 of 3:

@@ -48,9 +48,6 @@ HWY_BEFORE_NAMESPACE();
 namespace gcpp {
 namespace HWY_NAMESPACE {
 
-constexpr int kFlagReserved = 1;  // LINTER: unused, reserved for future use.
-constexpr int kUseOldAttention = 2;
-
 // Computes Q.K scores, which are "logits" (or scores) stored to att.
 // `k` is a strided view of the kv cache with dimensions [seq_len, qkv_dim].
 static HWY_INLINE void QDotK(const size_t start_pos, const size_t last_pos,
@@ -357,7 +354,7 @@ void GemmaAttention(size_t num_tokens, const size_t layer_idx,
   (void)layer_config;  // only used in HWY_DASSERT
 
   ComputeQKV(num_tokens, layer_idx, layer, activations, qbatch, flags, env);
-  if (flags & kUseOldAttention) {
+  if (flags & kAttentionUseOld) {
     DotSoftmaxWeightedSum(num_tokens, layer_idx, layer, activations, qbatch,
                           env.ctx);
   } else {
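Note on the hunk above: `flags` is a bitmask, and a set kAttentionUseOld bit routes GemmaAttention to the older DotSoftmaxWeightedSum path instead of FlashAttention. Below is a minimal standalone sketch of that dispatch pattern, not the gemma.cpp implementation; OldPath and FlashPath are hypothetical stand-ins for the two branches.

#include <cstdio>

// Same value as the constant introduced in this commit; the rest is
// illustrative only.
inline constexpr int kAttentionUseOld = 2;

// Hypothetical stand-ins for DotSoftmaxWeightedSum and the flash path.
static void OldPath() { std::printf("old DotSoftmaxWeightedSum path\n"); }
static void FlashPath() { std::printf("FlashAttention path\n"); }

// Mirrors the `flags & kAttentionUseOld` test above: a set bit forces the
// old path, a zero flags value takes the flash path.
static void Attention(int flags) {
  if (flags & kAttentionUseOld) {
    OldPath();
  } else {
    FlashPath();
  }
}

int main() {
  Attention(kAttentionUseOld);  // what this commit's call site now requests
  Attention(/*flags=*/0);       // the previous default: flash
  return 0;
}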

File 2 of 3:

@@ -32,8 +32,9 @@
 namespace gcpp {
 
-static constexpr size_t kMaxConv1DWidth = 4;
-static constexpr size_t kMaxQKVDim = 1024;
+HWY_INLINE_VAR constexpr int kAttentionUseOld = 2;
+
+HWY_INLINE_VAR constexpr size_t kMaxQKVDim = 1024;
 
 // Instruction-tuned models require extra 'turn structure' tokens in prompts.
 enum class PromptWrapping {
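Note on the hunk above: the flag constant moves into a shared header so both the attention kernel and its callers can name it. Assuming HWY_INLINE_VAR expands to the C++17 `inline` specifier (an assumption, not confirmed by this diff), the plain-C++ equivalent is an inline constexpr variable, which has a single definition across all translation units that include the header. A minimal sketch under that assumption:

#include <cstddef>
#include <cstdio>

// Plain C++17 equivalent of the HWY_INLINE_VAR declarations: an inline
// constexpr variable defined in a header is one object program-wide, so
// including the header from many .cc files causes no ODR violation and no
// per-translation-unit copies.
inline constexpr int kAttentionUseOld = 2;
inline constexpr std::size_t kMaxQKVDim = 1024;

int main() {
  std::printf("kAttentionUseOld=%d, kMaxQKVDim=%zu\n", kAttentionUseOld,
              kMaxQKVDim);
  return 0;
}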

File 3 of 3:

@@ -73,9 +73,9 @@ void Attention(LayerAttentionType type, const size_t num_tokens,
                const size_t layer_idx, const LayerWeightsPtrs& layer,
                Activations& activations, QBatch& qbatch, MatMulEnv& env) {
   if (type == LayerAttentionType::kGemma) {
+    // TODO: remove flag to enable FlashAttention.
     GemmaAttention(num_tokens, layer_idx, layer, activations.attention, qbatch,
-                   env,
-                   /*flags=*/0);
+                   env, kAttentionUseOld);
   }
 }
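Note on the call site above: because the parameter is a bitmask, reverting this commit (per the TODO) is a one-argument change back to /*flags=*/0, and further options could occupy other power-of-two bits and be combined with bitwise OR. A small sketch of that composition; kFlagHypothetical is invented for illustration and does not exist in gemma.cpp.

#include <cstdio>

inline constexpr int kAttentionUseOld = 2;   // bit 1, as in this commit
inline constexpr int kFlagHypothetical = 4;  // bit 2, invented for this sketch

int main() {
  const int flags = kAttentionUseOld | kFlagHypothetical;  // combine options
  std::printf("old attention requested: %s\n",
              (flags & kAttentionUseOld) ? "yes" : "no");
  std::printf("hypothetical option set: %s\n",
              (flags & kFlagHypothetical) ? "yes" : "no");
  return 0;
}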