Internal changes

PiperOrigin-RevId: 837775282
This commit is contained in:
Krzysztof Rymski 2025-11-28 02:36:36 -08:00 committed by Copybara-Service
parent 3c9e6cf113
commit 6e5e4123f1
3 changed files with 37 additions and 32 deletions

View File

@ -17,13 +17,16 @@
#include <stdint.h> #include <stdint.h>
#include <algorithm> #include <algorithm>
#include <array>
#include <cmath> #include <cmath>
#include <cstdlib>
#include <limits> #include <limits>
#include "compression/types.h" // GEMMA_DISABLED_TARGETS #include "compression/types.h" // GEMMA_DISABLED_TARGETS
#include "gemma/flash_structs.h" #include "gemma/flash_structs.h"
#include "util/threading_context.h" #include "util/threading_context.h"
#include "util/zones.h" #include "util/zones.h"
#include "hwy/base.h"
#ifndef HWY_DISABLED_TARGETS #ifndef HWY_DISABLED_TARGETS
#define HWY_DISABLED_TARGETS GEMMA_DISABLED_TARGETS #define HWY_DISABLED_TARGETS GEMMA_DISABLED_TARGETS
#endif // HWY_DISABLED_TARGETS #endif // HWY_DISABLED_TARGETS

View File

@ -29,38 +29,39 @@
namespace gcpp { namespace gcpp {
// Passed to HWY_VISIT_TARGETS; declares for one target. // Passed to HWY_VISIT_TARGETS; declares for one target.
#define GEMMA_DECL_FLASH_ATTENTION(TARGET, NAMESPACE) \ #define GEMMA_DECL_FLASH_ATTENTION(TARGET, NAMESPACE) \
namespace NAMESPACE { \ namespace NAMESPACE { \
void RMSNormAndPositionalEncoding( \ void RMSNormAndPositionalEncoding( \
size_t num_tokens, const QBatch& qbatch, MatPtrT<float>& q, \ size_t num_tokens, const QBatch& qbatch, MatPtrT<float>& q, \
const MatPtr& query_norm_scale, size_t layer_idx, \ const MatPtr& query_norm_scale, size_t layer_idx, \
const AttentionActivationsPtrs& activations, ThreadingContext& ctx); \ const AttentionActivationsPtrs& activations, ThreadingContext& ctx); \
\ \
void SingleFlashAttention(size_t start_pos, size_t last_pos, \ void SingleFlashAttention(size_t start_pos, size_t last_pos, \
const float* HWY_RESTRICT q, \ const BF16* HWY_RESTRICT q, \
const MatPtrT<KV_t>& k, const MatPtrT<KV_t>& v, \ const MatPtrT<KV_t>& k, const MatPtrT<KV_t>& v, \
size_t layer_idx, \ size_t layer_idx, \
const AttentionActivationsPtrs& activations, \ const AttentionActivationsPtrs& activations, \
float* HWY_RESTRICT att_out, \ float* HWY_RESTRICT att_out, \
ThreadingContext& ctx, size_t worker); \ ThreadingContext& ctx, size_t worker); \
\ \
Tile4FlashState TileFlashAttention4( \ Tile4FlashState TileFlashAttention4( \
const MatPtrT<BF16>& q, const uint32_t* HWY_RESTRICT q_offsets, \ const MatPtrT<BF16>& q, const uint32_t* HWY_RESTRICT q_offsets, \
const MatPtrT<KV_t>& k, size_t start_pos, \ const MatPtrT<KV_t>& k, size_t start_pos, \
const uint32_t* HWY_RESTRICT last_pos, size_t min_last_pos, \ const uint32_t* HWY_RESTRICT last_pos, size_t min_last_pos, \
size_t max_last_pos, const MatPtrT<KV_t>& v, size_t layer_idx, \ size_t max_last_pos, const MatPtrT<KV_t>& v, size_t layer_idx, \
const LayerWeightsPtrs& layer, const AttentionActivations& activations, \ const LayerWeightsPtrs& layer, const AttentionActivations& activations, \
MatPtrT<float>& att_out, const uint32_t* HWY_RESTRICT out_offsets, \ MatPtrT<float>& att_out, const uint32_t* HWY_RESTRICT out_offsets, \
ThreadingContext& ctx, const size_t worker); \ ThreadingContext& ctx, const size_t worker); \
\ \
size_t GetVTileSize(size_t kNF, size_t num_head_groups, size_t num_tokens, \ size_t GetVTileSize(size_t kNF, size_t num_head_groups, size_t num_tokens, \
size_t total_tasks, size_t target_parallelism); \ size_t total_tasks, size_t target_parallelism); \
\ \
void FlashAttention(size_t num_tokens, size_t target_parallelism, \ void FlashAttention(size_t num_tokens, size_t target_parallelism, \
size_t layer_idx, const MatPtr& query_norm_scale, \ size_t layer_idx, const MatPtr& query_norm_scale, \
AttentionActivationsPtrs& activations, QBatch& qbatch, \ AttentionActivationsPtrs& activations, QBatch& qbatch, \
ThreadingContext& ctx); \ ThreadingContext& ctx); \
/* NOLINTNEXTLINE(google-readability-namespace-comments) */ \ \
/* NOLINTNEXTLINE(google-readability-namespace-comments) */ \
} // namespace NAMESPACE } // namespace NAMESPACE
// Function declarations for each SIMD target. Allows direct call from the // Function declarations for each SIMD target. Allows direct call from the

View File

@ -25,6 +25,7 @@
#include <cstdint> #include <cstdint>
#include <random> #include <random>
#include <type_traits> // std::enable_if_t #include <type_traits> // std::enable_if_t
#include <utility>
#include <vector> #include <vector>
#include "ops/matmul.h" #include "ops/matmul.h"