Fixed msan error by fixing padding of k_cache and v_cache

PiperOrigin-RevId: 880060015
This commit is contained in:
The gemma.cpp Authors 2026-03-07 03:17:44 -08:00 committed by Copybara-Service
parent d2806fb1dd
commit be511554a9
8 changed files with 62 additions and 153 deletions

View File

@ -20,7 +20,6 @@
#include "compression/types.h" // GEMMA_DISABLED_TARGETS
#include "util/zones.h"
#include "hwy/base.h"
#ifndef HWY_DISABLED_TARGETS
#define HWY_DISABLED_TARGETS GEMMA_DISABLED_TARGETS
#endif // HWY_DISABLED_TARGETS
@ -59,8 +58,8 @@ size_t FloatsPerVector() {
// The k-cache and v-cache are setup without knowing NF. So if it hasn't been
// done already, reshape it to take NF into account.
void MaybeReshapeCache(const size_t default_cols, MatPtrT<KV_t>& cache) {
if (default_cols == cache.Cols()) {
void MaybeReshapeCache(const MatPtrT<KV_t>& kv, MatPtrT<KV_t>& cache) {
if (kv.Cols() > cache.Cols()) {
cache.ReshapePackedRowsToCols(2 * FloatsPerVector());
}
}
@ -72,50 +71,13 @@ void TransposeKVCacheRow(const KV_t* HWY_RESTRICT kv, KV_t* HWY_RESTRICT k,
// is a tiny fraction of the overall computation, and it is linear in the
// token length.
const size_t kFloatsPerTile = 2 * FloatsPerVector();
const size_t kRoundedQkvDim = hwy::RoundUpTo(qkv_dim, kMaxBF16PerVector);
for (size_t i = 0; i < qkv_dim; i += 2) {
k[i * kFloatsPerTile] = kv[i];
k[i * kFloatsPerTile + 1] = kv[i + 1];
}
for (size_t i = qkv_dim; i < kRoundedQkvDim; i += 2) {
k[i * kFloatsPerTile] = hwy::ConvertScalarTo<KV_t>(0.0f);
k[i * kFloatsPerTile + 1] = hwy::ConvertScalarTo<KV_t>(0.0f);
}
for (size_t i = 0; i < qkv_dim; i += kFloatsPerTile) {
if (i + kFloatsPerTile <= qkv_dim) {
for (size_t j = 0; j < kFloatsPerTile; j++) {
v[i * kFloatsPerTile + j] = kv[i + j + qkv_dim];
}
} else {
for (size_t j = 0; j < qkv_dim - i; j++) {
v[i * kFloatsPerTile + j] = kv[i + j + qkv_dim];
}
for (size_t j = qkv_dim - i; j < kFloatsPerTile; j++) {
v[i * kFloatsPerTile + j] = hwy::ConvertScalarTo<KV_t>(0.0f);
}
}
}
for (size_t i = hwy::RoundUpTo(qkv_dim, kFloatsPerTile); i < kRoundedQkvDim;
i += kFloatsPerTile) {
for (size_t j = 0; j < kFloatsPerTile; j++) {
v[i * kFloatsPerTile + j] = hwy::ConvertScalarTo<KV_t>(0.0f);
}
}
}
// Zeros out a part of k and v that corresponds to out-of-bounds cache
// positions.
void TransposeOOBKVCacheRow(KV_t* HWY_RESTRICT k, KV_t* HWY_RESTRICT v,
size_t qkv_dim) {
const size_t kFloatsPerTile = 2 * FloatsPerVector();
const size_t kRoundedQkvDim = hwy::RoundUpTo(qkv_dim, kMaxBF16PerVector);
for (size_t i = 0; i < kRoundedQkvDim; i += 2) {
k[i * kFloatsPerTile] = hwy::ConvertScalarTo<KV_t>(0.0f);
k[i * kFloatsPerTile + 1] = hwy::ConvertScalarTo<KV_t>(0.0f);
}
for (size_t i = 0; i < kRoundedQkvDim; i += kFloatsPerTile) {
for (size_t j = 0; j < kFloatsPerTile; j++) {
v[i * kFloatsPerTile + j] = hwy::ConvertScalarTo<KV_t>(0.0f);
v[i * kFloatsPerTile + j] = kv[i + j + qkv_dim];
}
}
}
@ -352,22 +314,16 @@ static HWY_INLINE void ComputeQKV(size_t num_tokens, const size_t layer_idx,
CallMatMul(activations.pre_att_rms_out, layer.qkv_einsum_w2,
/*add=*/nullptr, env, kv_rows);
for (size_t qi = 0; qi < qbatch.Size(); ++qi) {
MaybeReshapeCache(qbatch.KV(qi).cache->KOrVDefaultCols(),
qbatch.KV(qi).k_cache);
MaybeReshapeCache(qbatch.KV(qi).cache->KOrVDefaultCols(),
qbatch.KV(qi).v_cache);
MaybeReshapeCache(qbatch.KV(qi).kv_cache, qbatch.KV(qi).k_cache);
MaybeReshapeCache(qbatch.KV(qi).kv_cache, qbatch.KV(qi).v_cache);
}
const size_t kFloatsPerVector = FloatsPerVector();
const size_t kRoundedTokens =
hwy::RoundUpTo(num_tokens, 2 * kFloatsPerVector);
const size_t kRoundedNumInterleaved =
kRoundedTokens * div_qbatch.GetDivisor();
// Apply positional encodings for K.
// Note that 2D parallelism is not worth the fork/join overhead because the
// tasks are very lightweight.
ParallelFor(
Parallelism::kFlat, kv_heads * kRoundedNumInterleaved, env.ctx,
Parallelism::kFlat, kv_heads * num_interleaved, env.ctx,
/*cluster_idx=*/0, Callers::kAttComputeQKV,
[&](size_t task, size_t worker) HWY_ATTR {
const size_t head = task % kv_heads;
@ -375,28 +331,6 @@ static HWY_INLINE void ComputeQKV(size_t num_tokens, const size_t layer_idx,
const size_t qi = div_qbatch.Remainder(interleaved_idx);
const size_t token_idx = div_qbatch.Divide(interleaved_idx);
const size_t cache_pos = qbatch.Pos(qi) + token_idx;
if (token_idx >= kRoundedTokens) {
return;
}
// The innermost dimension of v is 2NF values from qkv_dim because they
// will be loaded into a BF16 vector to be scaled and added to the
// cached attention output in 2 NF-sized registers.
auto& k_cache = qbatch.KV(qi).k_cache;
KV_t* HWY_RESTRICT k =
k_cache.Row(cache_pos / (2 * kFloatsPerVector)) +
qbatch.KV(qi).cache->KOffset(layer_idx, head, kFloatsPerVector,
cache_pos);
auto& v_cache = qbatch.KV(qi).v_cache;
KV_t* HWY_RESTRICT v =
v_cache.Row(cache_pos / (2 * kFloatsPerVector)) +
qbatch.KV(qi).cache->VOffset(layer_idx, head, kFloatsPerVector,
cache_pos);
if (token_idx >= num_tokens) {
// Create a zero-filled K/V pair for padding for out-of-sequence
// tokens.
TransposeOOBKVCacheRow(k, v, qkv_dim);
return;
}
// --seq_len must be large enough to avoid wraparound.
HWY_DASSERT(cache_pos < activations.SeqLen());
auto& kv_cache = qbatch.KV(qi).kv_cache;
@ -407,6 +341,22 @@ static HWY_INLINE void ComputeQKV(size_t num_tokens, const size_t layer_idx,
// The innermost dimension of k is 2 values from qkv_dim because they
// are going to be used in a BF16 dot product involving pairs of
// values over NF k positions.
// The innermost dimension of v is 2NF values from qkv_dim because they
// will be loaded into a BF16 vector to be scaled and added to the
// cached attention output in 2 NF-sized registers.
// TODO(rays): factor out these calculations into functions.
auto& k_cache = qbatch.KV(qi).k_cache;
KV_t* HWY_RESTRICT k =
k_cache.Row(cache_pos / (2 * kFloatsPerVector)) +
(layer_idx * cache_layer_size + head * qkv_dim * 2) *
kFloatsPerVector +
(cache_pos % (2 * kFloatsPerVector)) * 2;
auto& v_cache = qbatch.KV(qi).v_cache;
KV_t* HWY_RESTRICT v =
v_cache.Row(cache_pos / (2 * kFloatsPerVector)) +
(layer_idx * cache_layer_size + head * qkv_dim * 2) *
kFloatsPerVector +
(cache_pos % (2 * kFloatsPerVector)) * 2 * kFloatsPerVector;
HWY_ALIGN float kv_f32[2 * kMaxQKVDim];
const hn::ScalableTag<float> df;

View File

@ -33,7 +33,7 @@ namespace gcpp {
namespace NAMESPACE { \
size_t FloatsPerVector(); \
\
void MaybeReshapeCache(size_t default_cols, MatPtrT<KV_t>& cache); \
void MaybeReshapeCache(const MatPtrT<KV_t>& kv, MatPtrT<KV_t>& cache); \
\
void TransposeKVCacheRow(const KV_t* HWY_RESTRICT kv, KV_t* HWY_RESTRICT k, \
KV_t* HWY_RESTRICT v, size_t qkv_dim); \

View File

@ -29,12 +29,9 @@
#include "io/fields.h" // IFieldsVisitor
#include "io/io.h" // Path
#include "util/basics.h"
#include "hwy/detect_compiler_arch.h"
namespace gcpp {
constexpr size_t kMaxBF16PerVector = HWY_ARCH_MAX_BYTES / sizeof(BF16);
HWY_INLINE_VAR constexpr int kAttentionUseOld = 2;
HWY_INLINE_VAR constexpr size_t kMaxQKVDim = 1024;

View File

@ -1700,6 +1700,7 @@ void ComputeFlashParams(size_t num_tokens, const size_t target_parallelism,
// A "head group" in the context of GQA refers to a collection of query
// heads that share the same key and value heads.
const size_t kHeadGroups = layer_config.heads / layer_config.kv_heads;
const size_t cache_layer_size = layer_config.CacheLayerSize();
const size_t token_batch = num_tokens * div_qbatch.GetDivisor();
const size_t total_tasks = token_batch * layer_config.heads;
size_t kVTileSize = GetVTileSize(kNF, kHeadGroups, num_tokens, total_tasks,
@ -1715,9 +1716,11 @@ void ComputeFlashParams(size_t num_tokens, const size_t target_parallelism,
params.clear();
for (uint32_t qi = 0; qi < div_qbatch.GetDivisor(); ++qi) {
for (uint32_t kv_head = 0; kv_head < layer_config.kv_heads; ++kv_head) {
const size_t head_offset = kv_head * qkv_dim * 2;
const uint32_t kv_offset = layer_idx * cache_layer_size + head_offset;
params.push_back(Tile148Params{
.qi_index = qi,
.kv_head = kv_head,
.kv_offset = kv_offset,
});
for (uint32_t batch_idx = 0; batch_idx < num_tokens; ++batch_idx) {
const size_t pos = qbatch.Pos(qi) + batch_idx;
@ -1743,7 +1746,7 @@ void ComputeFlashParams(size_t num_tokens, const size_t target_parallelism,
// current tile is full so start new tile.
params.push_back(Tile148Params{
.qi_index = qi,
.kv_head = kv_head,
.kv_offset = kv_offset,
});
}
const size_t head = head_group + kHeadGroups * kv_head;
@ -2154,20 +2157,13 @@ void FlashAttention(const size_t num_tokens, const size_t target_parallelism,
GCPP_ZONE(ctx, worker, Zones::kFlashAttentionFlashAttention);
auto& param = params[task];
auto& kT_cache = qbatch.KV(param.qi_index).k_cache;
const size_t kRoundedQkvDim = hwy::RoundUpTo(qkv_dim, kMaxBF16PerVector);
MatPtrT<KV_t> kT("k_T_view", Extents2D(hwy::DivCeil(seq_len, 2 * kNF),
kRoundedQkvDim * 2 * kNF));
kT.SetPtr(
kT_cache.Row(0) + qbatch.KV(param.qi_index)
.cache->KOrVOffset(layer_idx, param.kv_head, kNF),
kT_cache.Stride());
qkv_dim * 2 * kNF));
kT.SetPtr(kT_cache.Row(0) + param.kv_offset * kNF, kT_cache.Stride());
auto& vT_cache = qbatch.KV(param.qi_index).v_cache;
MatPtrT<KV_t> vT("v_T_view", Extents2D(hwy::DivCeil(seq_len, 2 * kNF),
kRoundedQkvDim * 2 * kNF));
vT.SetPtr(
vT_cache.Row(0) + qbatch.KV(param.qi_index)
.cache->KOrVOffset(layer_idx, param.kv_head, kNF),
vT_cache.Stride());
qkv_dim * 2 * kNF));
vT.SetPtr(vT_cache.Row(0) + param.kv_offset * kNF, vT_cache.Stride());
MatPtrT<float>& att_out =
param.i_of_n == 0 ? activations.att_out : activations.att_out_reps;
DispatchTileFlashAttention148(param, activations.q_bf, kT, vT, layer_idx,

View File

@ -144,15 +144,10 @@ void TestFlashAttention(size_t target_parallelism,
const size_t kHeadGroups = layer_config.heads / layer_config.kv_heads;
const size_t seq_len =
static_cast<size_t>(attention.div_seq_len.GetDivisor());
MaybeReshapeCache(qbatch.KV(0).cache->KOrVDefaultCols(),
qbatch.KV(0).k_cache);
MaybeReshapeCache(qbatch.KV(0).cache->KOrVDefaultCols(),
qbatch.KV(0).v_cache);
MaybeReshapeCache(qbatch.KV(0).kv_cache, qbatch.KV(0).k_cache);
MaybeReshapeCache(qbatch.KV(0).kv_cache, qbatch.KV(0).v_cache);
auto& kvc = qbatch.KV(0).kv_cache;
using DF = hn::ScalableTag<float>;
const DF df;
const size_t kNF = hn::Lanes(df);
const size_t kFloatsPerTile = 2 * kNF;
const size_t kFloatsPerTile = 2 * FloatsPerVector();
for (size_t h = 0; h < layer_config.heads; ++h) {
// Make strided views into the kv cache for
// this query and head.
@ -165,12 +160,12 @@ void TestFlashAttention(size_t target_parallelism,
SetMat(h + layer_config.heads * 2, v);
for (size_t p = 0; p < tokens.size(); ++p) {
KV_t* HWY_RESTRICT k_src = k.Row(p);
KV_t* HWY_RESTRICT k_dest =
qbatch.KV(0).k_cache.Row(p / kFloatsPerTile) +
qbatch.KV(0).cache->KOffset(0, h / kHeadGroups, kNF, p);
KV_t* HWY_RESTRICT v_dest =
qbatch.KV(0).v_cache.Row(p / kFloatsPerTile) +
qbatch.KV(0).cache->VOffset(0, h / kHeadGroups, kNF, p);
KV_t* HWY_RESTRICT k_dest = qbatch.KV(0).k_cache.Row(p / kFloatsPerTile) +
head_offset * kFloatsPerTile / 2 +
p % kFloatsPerTile * 2;
KV_t* HWY_RESTRICT v_dest = qbatch.KV(0).v_cache.Row(p / kFloatsPerTile) +
head_offset * kFloatsPerTile / 2 +
p % kFloatsPerTile * kFloatsPerTile;
TransposeKVCacheRow(k_src, k_dest, v_dest, qkv_dim);
}
@ -181,6 +176,9 @@ void TestFlashAttention(size_t target_parallelism,
// Copy the output to saved_att to allow for comparison.
auto saved_att = MakeCopyOfMat(attention.att_out, ctx.allocator);
SetMat(1, attention.q);
using DF = hn::ScalableTag<float>;
const DF df;
const size_t kNF = hn::Lanes(df);
const size_t total_tasks =
tokens.size() * div_qbatch.GetDivisor() * layer_config.heads;
const size_t kVTileSize = GetVTileSize(kNF, kHeadGroups, tokens.size(),

View File

@ -48,8 +48,8 @@ struct Tile148Params {
uint32_t max_last_pos = 0;
// Index into the qbatch.KV is the same for each row in the tile.
uint32_t qi_index;
// kv_head is the same for each row in the tile.
uint32_t kv_head;
// Index into the kv_cache is the same for each row in the tile.
uint32_t kv_offset;
// In the original task, the index to the split tasks of the first split task.
uint32_t split_index = 0;
// The index of the split for running split attention.

View File

@ -29,6 +29,11 @@
namespace gcpp {
// TODO: rays - Remove this once hwy is updated.
#ifndef HWY_ARCH_MAX_BYTES
#define HWY_ARCH_MAX_BYTES 256
#endif
// Number of rows for KV cache. Note that both rows and cols are u32, and
// the total number of elements can exceed 2^32.
static size_t CappedSeqLen(const ModelConfig& config,
@ -41,13 +46,8 @@ static size_t CappedSeqLen(const ModelConfig& config,
return inference_args.seq_len;
}
KVCache::KVCache(const Extents2D& kv_extents, size_t num_layers,
size_t kv_heads, size_t qkv_dim, const Allocator& allocator)
: num_layers(num_layers),
kv_heads(kv_heads),
qkv_dim(qkv_dim),
rounded_qkv_dim(hwy::RoundUpTo(qkv_dim, kMaxBF16PerVector)),
kv_cache("kv", kv_extents, allocator, MatPadding::kOdd),
KVCache::KVCache(const Extents2D& kv_extents, const Allocator& allocator)
: kv_cache("kv", kv_extents, allocator, MatPadding::kOdd),
// WARNING: the rows and cols of k_cache and v_cache will be modified
// before use!
// The rows will be reduced by a factor of 2xkFloatsPerVector, and the
@ -56,12 +56,14 @@ KVCache::KVCache(const Extents2D& kv_extents, size_t num_layers,
// machine architecture, since kFloatsPerVector is architecture dependent.
// The change is shape is safe only if the padding is kPacked.
k_cache("k",
Extents2D(hwy::RoundUpTo(kv_extents.rows, kMaxBF16PerVector),
KOrVDefaultCols()),
Extents2D(HWY_MAX(kv_extents.rows,
2 * HWY_ARCH_MAX_BYTES / sizeof(float)),
kv_extents.cols / 2),
allocator, MatPadding::kPacked),
v_cache("v",
Extents2D(hwy::RoundUpTo(kv_extents.rows, kMaxBF16PerVector),
KOrVDefaultCols()),
Extents2D(HWY_MAX(kv_extents.rows,
2 * HWY_ARCH_MAX_BYTES / sizeof(float)),
kv_extents.cols / 2),
allocator, MatPadding::kPacked),
allocator_(allocator) {}
@ -69,8 +71,7 @@ KVCache::KVCache(const ModelConfig& config, const InferenceArgs& inference_args,
const Allocator& allocator)
: KVCache(
Extents2D(CappedSeqLen(config, inference_args), config.KVCacheCols()),
config.layer_configs.size(), config.layer_configs[0].kv_heads,
config.layer_configs[0].qkv_dim, allocator) {}
allocator) {}
KVCache::KVCache(const ModelConfig& config, const InferenceArgs& inference_args,
const RuntimeConfig& runtime_config,
@ -134,7 +135,7 @@ KVCache::KVCache(const ModelConfig& config, const InferenceArgs& inference_args,
}
KVCache KVCache::Copy() {
KVCache copy(kv_cache.Extents(), num_layers, kv_heads, qkv_dim, allocator_);
KVCache copy(kv_cache.Extents(), allocator_);
CopyMat(kv_cache, copy.kv_cache);
return copy;

View File

@ -91,38 +91,6 @@ struct KVCache {
return {start_ptr, source_ptr};
}
// Returns the default size of a row in k_cache or v_cache, before scaling by
// 2 * kNF.
size_t KOrVDefaultCols() const {
return num_layers * kv_heads * rounded_qkv_dim;
}
// Returns an offset into a row of k_cache or v_cache at a position that is
// aligned to the tile size (a multiple of 2kNF).
size_t KOrVOffset(const size_t layer_idx, const size_t kv_head_idx,
const size_t kNF) const {
return (layer_idx * kv_heads + kv_head_idx) * rounded_qkv_dim * 2 * kNF;
}
// Returns an offset into k_cache at any given position.
size_t KOffset(const size_t layer_idx, const size_t kv_head_idx,
const size_t kNF, const size_t pos) const {
return KOrVOffset(layer_idx, kv_head_idx, kNF) + (pos % (2 * kNF)) * 2;
}
// Returns an offset into v_cache at any given position.
size_t VOffset(const size_t layer_idx, const size_t kv_head_idx,
const size_t kNF, const size_t pos) const {
return KOrVOffset(layer_idx, kv_head_idx, kNF) +
(pos % (2 * kNF)) * 2 * kNF;
}
// Saved sizes for computing offsets into the KV cache.
size_t num_layers = 0;
size_t kv_heads = 0;
size_t qkv_dim = 0;
size_t rounded_qkv_dim = 0;
static constexpr size_t kTileSize = 32;
std::optional<uint32_t> tiled_seq_len = std::nullopt;
// Default Format
@ -191,8 +159,7 @@ struct KVCache {
const Allocator& allocator_;
// For use by other ctor and Copy()
KVCache(const Extents2D& kv_extents, size_t num_layers, size_t kv_heads,
size_t qkv_dim, const Allocator& allocator);
KVCache(const Extents2D& kv_extents, const Allocator& allocator);
};
inline size_t KVCachePtr::SeqLen() const {