Internal changes

PiperOrigin-RevId: 837775282
This commit is contained in:
Krzysztof Rymski 2025-11-28 02:36:36 -08:00 committed by Copybara-Service
parent 3c9e6cf113
commit 6e5e4123f1
3 changed files with 37 additions and 32 deletions

View File

@ -17,13 +17,16 @@
#include <stdint.h> #include <stdint.h>
#include <algorithm> #include <algorithm>
#include <array>
#include <cmath> #include <cmath>
#include <cstdlib>
#include <limits> #include <limits>
#include "compression/types.h" // GEMMA_DISABLED_TARGETS #include "compression/types.h" // GEMMA_DISABLED_TARGETS
#include "gemma/flash_structs.h" #include "gemma/flash_structs.h"
#include "util/threading_context.h" #include "util/threading_context.h"
#include "util/zones.h" #include "util/zones.h"
#include "hwy/base.h"
#ifndef HWY_DISABLED_TARGETS #ifndef HWY_DISABLED_TARGETS
#define HWY_DISABLED_TARGETS GEMMA_DISABLED_TARGETS #define HWY_DISABLED_TARGETS GEMMA_DISABLED_TARGETS
#endif // HWY_DISABLED_TARGETS #endif // HWY_DISABLED_TARGETS

View File

@ -37,7 +37,7 @@ namespace gcpp {
const AttentionActivationsPtrs& activations, ThreadingContext& ctx); \ const AttentionActivationsPtrs& activations, ThreadingContext& ctx); \
\ \
void SingleFlashAttention(size_t start_pos, size_t last_pos, \ void SingleFlashAttention(size_t start_pos, size_t last_pos, \
const float* HWY_RESTRICT q, \ const BF16* HWY_RESTRICT q, \
const MatPtrT<KV_t>& k, const MatPtrT<KV_t>& v, \ const MatPtrT<KV_t>& k, const MatPtrT<KV_t>& v, \
size_t layer_idx, \ size_t layer_idx, \
const AttentionActivationsPtrs& activations, \ const AttentionActivationsPtrs& activations, \
@ -60,6 +60,7 @@ namespace gcpp {
size_t layer_idx, const MatPtr& query_norm_scale, \ size_t layer_idx, const MatPtr& query_norm_scale, \
AttentionActivationsPtrs& activations, QBatch& qbatch, \ AttentionActivationsPtrs& activations, QBatch& qbatch, \
ThreadingContext& ctx); \ ThreadingContext& ctx); \
\
/* NOLINTNEXTLINE(google-readability-namespace-comments) */ \ /* NOLINTNEXTLINE(google-readability-namespace-comments) */ \
} // namespace NAMESPACE } // namespace NAMESPACE

View File

@ -25,6 +25,7 @@
#include <cstdint> #include <cstdint>
#include <random> #include <random>
#include <type_traits> // std::enable_if_t #include <type_traits> // std::enable_if_t
#include <utility>
#include <vector> #include <vector>
#include "ops/matmul.h" #include "ops/matmul.h"