mirror of https://github.com/google/gemma.cpp.git
Fixed msan error by fixing padding of k_cache and v_cache
PiperOrigin-RevId: 879644219
This commit is contained in:
parent
d6c7576024
commit
d2806fb1dd
|
|
@ -20,6 +20,7 @@
|
|||
|
||||
#include "compression/types.h" // GEMMA_DISABLED_TARGETS
|
||||
#include "util/zones.h"
|
||||
#include "hwy/base.h"
|
||||
#ifndef HWY_DISABLED_TARGETS
|
||||
#define HWY_DISABLED_TARGETS GEMMA_DISABLED_TARGETS
|
||||
#endif // HWY_DISABLED_TARGETS
|
||||
|
|
@ -58,8 +59,8 @@ size_t FloatsPerVector() {
|
|||
|
||||
// The k-cache and v-cache are set up without knowing NF. So if it hasn't been
|
||||
// done already, reshape it to take NF into account.
|
||||
void MaybeReshapeCache(const MatPtrT<KV_t>& kv, MatPtrT<KV_t>& cache) {
|
||||
if (kv.Cols() > cache.Cols()) {
|
||||
void MaybeReshapeCache(const size_t default_cols, MatPtrT<KV_t>& cache) {
|
||||
if (default_cols == cache.Cols()) {
|
||||
cache.ReshapePackedRowsToCols(2 * FloatsPerVector());
|
||||
}
|
||||
}
|
||||
|
|
@ -71,13 +72,50 @@ void TransposeKVCacheRow(const KV_t* HWY_RESTRICT kv, KV_t* HWY_RESTRICT k,
|
|||
// is a tiny fraction of the overall computation, and it is linear in the
|
||||
// token length.
|
||||
const size_t kFloatsPerTile = 2 * FloatsPerVector();
|
||||
const size_t kRoundedQkvDim = hwy::RoundUpTo(qkv_dim, kMaxBF16PerVector);
|
||||
for (size_t i = 0; i < qkv_dim; i += 2) {
|
||||
k[i * kFloatsPerTile] = kv[i];
|
||||
k[i * kFloatsPerTile + 1] = kv[i + 1];
|
||||
}
|
||||
for (size_t i = qkv_dim; i < kRoundedQkvDim; i += 2) {
|
||||
k[i * kFloatsPerTile] = hwy::ConvertScalarTo<KV_t>(0.0f);
|
||||
k[i * kFloatsPerTile + 1] = hwy::ConvertScalarTo<KV_t>(0.0f);
|
||||
}
|
||||
for (size_t i = 0; i < qkv_dim; i += kFloatsPerTile) {
|
||||
if (i + kFloatsPerTile <= qkv_dim) {
|
||||
for (size_t j = 0; j < kFloatsPerTile; j++) {
|
||||
v[i * kFloatsPerTile + j] = kv[i + j + qkv_dim];
|
||||
}
|
||||
} else {
|
||||
for (size_t j = 0; j < qkv_dim - i; j++) {
|
||||
v[i * kFloatsPerTile + j] = kv[i + j + qkv_dim];
|
||||
}
|
||||
for (size_t j = qkv_dim - i; j < kFloatsPerTile; j++) {
|
||||
v[i * kFloatsPerTile + j] = hwy::ConvertScalarTo<KV_t>(0.0f);
|
||||
}
|
||||
}
|
||||
}
|
||||
for (size_t i = hwy::RoundUpTo(qkv_dim, kFloatsPerTile); i < kRoundedQkvDim;
|
||||
i += kFloatsPerTile) {
|
||||
for (size_t j = 0; j < kFloatsPerTile; j++) {
|
||||
v[i * kFloatsPerTile + j] = kv[i + j + qkv_dim];
|
||||
v[i * kFloatsPerTile + j] = hwy::ConvertScalarTo<KV_t>(0.0f);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Zeros out a part of k and v that corresponds to out-of-bounds cache
|
||||
// positions.
|
||||
void TransposeOOBKVCacheRow(KV_t* HWY_RESTRICT k, KV_t* HWY_RESTRICT v,
|
||||
size_t qkv_dim) {
|
||||
const size_t kFloatsPerTile = 2 * FloatsPerVector();
|
||||
const size_t kRoundedQkvDim = hwy::RoundUpTo(qkv_dim, kMaxBF16PerVector);
|
||||
for (size_t i = 0; i < kRoundedQkvDim; i += 2) {
|
||||
k[i * kFloatsPerTile] = hwy::ConvertScalarTo<KV_t>(0.0f);
|
||||
k[i * kFloatsPerTile + 1] = hwy::ConvertScalarTo<KV_t>(0.0f);
|
||||
}
|
||||
for (size_t i = 0; i < kRoundedQkvDim; i += kFloatsPerTile) {
|
||||
for (size_t j = 0; j < kFloatsPerTile; j++) {
|
||||
v[i * kFloatsPerTile + j] = hwy::ConvertScalarTo<KV_t>(0.0f);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -314,16 +352,22 @@ static HWY_INLINE void ComputeQKV(size_t num_tokens, const size_t layer_idx,
|
|||
CallMatMul(activations.pre_att_rms_out, layer.qkv_einsum_w2,
|
||||
/*add=*/nullptr, env, kv_rows);
|
||||
for (size_t qi = 0; qi < qbatch.Size(); ++qi) {
|
||||
MaybeReshapeCache(qbatch.KV(qi).kv_cache, qbatch.KV(qi).k_cache);
|
||||
MaybeReshapeCache(qbatch.KV(qi).kv_cache, qbatch.KV(qi).v_cache);
|
||||
MaybeReshapeCache(qbatch.KV(qi).cache->KOrVDefaultCols(),
|
||||
qbatch.KV(qi).k_cache);
|
||||
MaybeReshapeCache(qbatch.KV(qi).cache->KOrVDefaultCols(),
|
||||
qbatch.KV(qi).v_cache);
|
||||
}
|
||||
const size_t kFloatsPerVector = FloatsPerVector();
|
||||
const size_t kRoundedTokens =
|
||||
hwy::RoundUpTo(num_tokens, 2 * kFloatsPerVector);
|
||||
const size_t kRoundedNumInterleaved =
|
||||
kRoundedTokens * div_qbatch.GetDivisor();
|
||||
|
||||
// Apply positional encodings for K.
|
||||
// Note that 2D parallelism is not worth the fork/join overhead because the
|
||||
// tasks are very lightweight.
|
||||
ParallelFor(
|
||||
Parallelism::kFlat, kv_heads * num_interleaved, env.ctx,
|
||||
Parallelism::kFlat, kv_heads * kRoundedNumInterleaved, env.ctx,
|
||||
/*cluster_idx=*/0, Callers::kAttComputeQKV,
|
||||
[&](size_t task, size_t worker) HWY_ATTR {
|
||||
const size_t head = task % kv_heads;
|
||||
|
|
@ -331,6 +375,28 @@ static HWY_INLINE void ComputeQKV(size_t num_tokens, const size_t layer_idx,
|
|||
const size_t qi = div_qbatch.Remainder(interleaved_idx);
|
||||
const size_t token_idx = div_qbatch.Divide(interleaved_idx);
|
||||
const size_t cache_pos = qbatch.Pos(qi) + token_idx;
|
||||
if (token_idx >= kRoundedTokens) {
|
||||
return;
|
||||
}
|
||||
// The innermost dimension of v is 2NF values from qkv_dim because they
|
||||
// will be loaded into a BF16 vector to be scaled and added to the
|
||||
// cached attention output in 2 NF-sized registers.
|
||||
auto& k_cache = qbatch.KV(qi).k_cache;
|
||||
KV_t* HWY_RESTRICT k =
|
||||
k_cache.Row(cache_pos / (2 * kFloatsPerVector)) +
|
||||
qbatch.KV(qi).cache->KOffset(layer_idx, head, kFloatsPerVector,
|
||||
cache_pos);
|
||||
auto& v_cache = qbatch.KV(qi).v_cache;
|
||||
KV_t* HWY_RESTRICT v =
|
||||
v_cache.Row(cache_pos / (2 * kFloatsPerVector)) +
|
||||
qbatch.KV(qi).cache->VOffset(layer_idx, head, kFloatsPerVector,
|
||||
cache_pos);
|
||||
if (token_idx >= num_tokens) {
|
||||
// Create a zero-filled K/V pair for padding for out-of-sequence
|
||||
// tokens.
|
||||
TransposeOOBKVCacheRow(k, v, qkv_dim);
|
||||
return;
|
||||
}
|
||||
// --seq_len must be large enough to avoid wraparound.
|
||||
HWY_DASSERT(cache_pos < activations.SeqLen());
|
||||
auto& kv_cache = qbatch.KV(qi).kv_cache;
|
||||
|
|
@ -341,22 +407,6 @@ static HWY_INLINE void ComputeQKV(size_t num_tokens, const size_t layer_idx,
|
|||
// The innermost dimension of k is 2 values from qkv_dim because they
|
||||
// are going to be used in a BF16 dot product involving pairs of
|
||||
// values over NF k positions.
|
||||
// The innermost dimension of v is 2NF values from qkv_dim because they
|
||||
// will be loaded into a BF16 vector to be scaled and added to the
|
||||
// cached attention output in 2 NF-sized registers.
|
||||
// TODO(rays): factor out these calculations into functions.
|
||||
auto& k_cache = qbatch.KV(qi).k_cache;
|
||||
KV_t* HWY_RESTRICT k =
|
||||
k_cache.Row(cache_pos / (2 * kFloatsPerVector)) +
|
||||
(layer_idx * cache_layer_size + head * qkv_dim * 2) *
|
||||
kFloatsPerVector +
|
||||
(cache_pos % (2 * kFloatsPerVector)) * 2;
|
||||
auto& v_cache = qbatch.KV(qi).v_cache;
|
||||
KV_t* HWY_RESTRICT v =
|
||||
v_cache.Row(cache_pos / (2 * kFloatsPerVector)) +
|
||||
(layer_idx * cache_layer_size + head * qkv_dim * 2) *
|
||||
kFloatsPerVector +
|
||||
(cache_pos % (2 * kFloatsPerVector)) * 2 * kFloatsPerVector;
|
||||
|
||||
HWY_ALIGN float kv_f32[2 * kMaxQKVDim];
|
||||
const hn::ScalableTag<float> df;
|
||||
|
|
|
|||
|
|
@ -33,7 +33,7 @@ namespace gcpp {
|
|||
namespace NAMESPACE { \
|
||||
size_t FloatsPerVector(); \
|
||||
\
|
||||
void MaybeReshapeCache(const MatPtrT<KV_t>& kv, MatPtrT<KV_t>& cache); \
|
||||
void MaybeReshapeCache(size_t default_cols, MatPtrT<KV_t>& cache); \
|
||||
\
|
||||
void TransposeKVCacheRow(const KV_t* HWY_RESTRICT kv, KV_t* HWY_RESTRICT k, \
|
||||
KV_t* HWY_RESTRICT v, size_t qkv_dim); \
|
||||
|
|
|
|||
|
|
@ -29,9 +29,12 @@
|
|||
#include "io/fields.h" // IFieldsVisitor
|
||||
#include "io/io.h" // Path
|
||||
#include "util/basics.h"
|
||||
#include "hwy/detect_compiler_arch.h"
|
||||
|
||||
namespace gcpp {
|
||||
|
||||
constexpr size_t kMaxBF16PerVector = HWY_ARCH_MAX_BYTES / sizeof(BF16);
|
||||
|
||||
HWY_INLINE_VAR constexpr int kAttentionUseOld = 2;
|
||||
|
||||
HWY_INLINE_VAR constexpr size_t kMaxQKVDim = 1024;
|
||||
|
|
|
|||
|
|
@ -1700,7 +1700,6 @@ void ComputeFlashParams(size_t num_tokens, const size_t target_parallelism,
|
|||
// A "head group" in the context of GQA refers to a collection of query
|
||||
// heads that share the same key and value heads.
|
||||
const size_t kHeadGroups = layer_config.heads / layer_config.kv_heads;
|
||||
const size_t cache_layer_size = layer_config.CacheLayerSize();
|
||||
const size_t token_batch = num_tokens * div_qbatch.GetDivisor();
|
||||
const size_t total_tasks = token_batch * layer_config.heads;
|
||||
size_t kVTileSize = GetVTileSize(kNF, kHeadGroups, num_tokens, total_tasks,
|
||||
|
|
@ -1716,11 +1715,9 @@ void ComputeFlashParams(size_t num_tokens, const size_t target_parallelism,
|
|||
params.clear();
|
||||
for (uint32_t qi = 0; qi < div_qbatch.GetDivisor(); ++qi) {
|
||||
for (uint32_t kv_head = 0; kv_head < layer_config.kv_heads; ++kv_head) {
|
||||
const size_t head_offset = kv_head * qkv_dim * 2;
|
||||
const uint32_t kv_offset = layer_idx * cache_layer_size + head_offset;
|
||||
params.push_back(Tile148Params{
|
||||
.qi_index = qi,
|
||||
.kv_offset = kv_offset,
|
||||
.kv_head = kv_head,
|
||||
});
|
||||
for (uint32_t batch_idx = 0; batch_idx < num_tokens; ++batch_idx) {
|
||||
const size_t pos = qbatch.Pos(qi) + batch_idx;
|
||||
|
|
@ -1746,7 +1743,7 @@ void ComputeFlashParams(size_t num_tokens, const size_t target_parallelism,
|
|||
// current tile is full so start new tile.
|
||||
params.push_back(Tile148Params{
|
||||
.qi_index = qi,
|
||||
.kv_offset = kv_offset,
|
||||
.kv_head = kv_head,
|
||||
});
|
||||
}
|
||||
const size_t head = head_group + kHeadGroups * kv_head;
|
||||
|
|
@ -2157,13 +2154,20 @@ void FlashAttention(const size_t num_tokens, const size_t target_parallelism,
|
|||
GCPP_ZONE(ctx, worker, Zones::kFlashAttentionFlashAttention);
|
||||
auto& param = params[task];
|
||||
auto& kT_cache = qbatch.KV(param.qi_index).k_cache;
|
||||
const size_t kRoundedQkvDim = hwy::RoundUpTo(qkv_dim, kMaxBF16PerVector);
|
||||
MatPtrT<KV_t> kT("k_T_view", Extents2D(hwy::DivCeil(seq_len, 2 * kNF),
|
||||
qkv_dim * 2 * kNF));
|
||||
kT.SetPtr(kT_cache.Row(0) + param.kv_offset * kNF, kT_cache.Stride());
|
||||
kRoundedQkvDim * 2 * kNF));
|
||||
kT.SetPtr(
|
||||
kT_cache.Row(0) + qbatch.KV(param.qi_index)
|
||||
.cache->KOrVOffset(layer_idx, param.kv_head, kNF),
|
||||
kT_cache.Stride());
|
||||
auto& vT_cache = qbatch.KV(param.qi_index).v_cache;
|
||||
MatPtrT<KV_t> vT("v_T_view", Extents2D(hwy::DivCeil(seq_len, 2 * kNF),
|
||||
qkv_dim * 2 * kNF));
|
||||
vT.SetPtr(vT_cache.Row(0) + param.kv_offset * kNF, vT_cache.Stride());
|
||||
kRoundedQkvDim * 2 * kNF));
|
||||
vT.SetPtr(
|
||||
vT_cache.Row(0) + qbatch.KV(param.qi_index)
|
||||
.cache->KOrVOffset(layer_idx, param.kv_head, kNF),
|
||||
vT_cache.Stride());
|
||||
MatPtrT<float>& att_out =
|
||||
param.i_of_n == 0 ? activations.att_out : activations.att_out_reps;
|
||||
DispatchTileFlashAttention148(param, activations.q_bf, kT, vT, layer_idx,
|
||||
|
|
|
|||
|
|
@ -144,10 +144,15 @@ void TestFlashAttention(size_t target_parallelism,
|
|||
const size_t kHeadGroups = layer_config.heads / layer_config.kv_heads;
|
||||
const size_t seq_len =
|
||||
static_cast<size_t>(attention.div_seq_len.GetDivisor());
|
||||
MaybeReshapeCache(qbatch.KV(0).kv_cache, qbatch.KV(0).k_cache);
|
||||
MaybeReshapeCache(qbatch.KV(0).kv_cache, qbatch.KV(0).v_cache);
|
||||
MaybeReshapeCache(qbatch.KV(0).cache->KOrVDefaultCols(),
|
||||
qbatch.KV(0).k_cache);
|
||||
MaybeReshapeCache(qbatch.KV(0).cache->KOrVDefaultCols(),
|
||||
qbatch.KV(0).v_cache);
|
||||
auto& kvc = qbatch.KV(0).kv_cache;
|
||||
const size_t kFloatsPerTile = 2 * FloatsPerVector();
|
||||
using DF = hn::ScalableTag<float>;
|
||||
const DF df;
|
||||
const size_t kNF = hn::Lanes(df);
|
||||
const size_t kFloatsPerTile = 2 * kNF;
|
||||
for (size_t h = 0; h < layer_config.heads; ++h) {
|
||||
// Make strided views into the kv cache for
|
||||
// this query and head.
|
||||
|
|
@ -160,12 +165,12 @@ void TestFlashAttention(size_t target_parallelism,
|
|||
SetMat(h + layer_config.heads * 2, v);
|
||||
for (size_t p = 0; p < tokens.size(); ++p) {
|
||||
KV_t* HWY_RESTRICT k_src = k.Row(p);
|
||||
KV_t* HWY_RESTRICT k_dest = qbatch.KV(0).k_cache.Row(p / kFloatsPerTile) +
|
||||
head_offset * kFloatsPerTile / 2 +
|
||||
p % kFloatsPerTile * 2;
|
||||
KV_t* HWY_RESTRICT v_dest = qbatch.KV(0).v_cache.Row(p / kFloatsPerTile) +
|
||||
head_offset * kFloatsPerTile / 2 +
|
||||
p % kFloatsPerTile * kFloatsPerTile;
|
||||
KV_t* HWY_RESTRICT k_dest =
|
||||
qbatch.KV(0).k_cache.Row(p / kFloatsPerTile) +
|
||||
qbatch.KV(0).cache->KOffset(0, h / kHeadGroups, kNF, p);
|
||||
KV_t* HWY_RESTRICT v_dest =
|
||||
qbatch.KV(0).v_cache.Row(p / kFloatsPerTile) +
|
||||
qbatch.KV(0).cache->VOffset(0, h / kHeadGroups, kNF, p);
|
||||
|
||||
TransposeKVCacheRow(k_src, k_dest, v_dest, qkv_dim);
|
||||
}
|
||||
|
|
@ -176,9 +181,6 @@ void TestFlashAttention(size_t target_parallelism,
|
|||
// Copy the output to saved_att to allow for comparison.
|
||||
auto saved_att = MakeCopyOfMat(attention.att_out, ctx.allocator);
|
||||
SetMat(1, attention.q);
|
||||
using DF = hn::ScalableTag<float>;
|
||||
const DF df;
|
||||
const size_t kNF = hn::Lanes(df);
|
||||
const size_t total_tasks =
|
||||
tokens.size() * div_qbatch.GetDivisor() * layer_config.heads;
|
||||
const size_t kVTileSize = GetVTileSize(kNF, kHeadGroups, tokens.size(),
|
||||
|
|
|
|||
|
|
@ -48,8 +48,8 @@ struct Tile148Params {
|
|||
uint32_t max_last_pos = 0;
|
||||
// Index into the qbatch.KV is the same for each row in the tile.
|
||||
uint32_t qi_index;
|
||||
// Index into the kv_cache is the same for each row in the tile.
|
||||
uint32_t kv_offset;
|
||||
// kv_head is the same for each row in the tile.
|
||||
uint32_t kv_head;
|
||||
// In the original task, the index to the split tasks of the first split task.
|
||||
uint32_t split_index = 0;
|
||||
// The index of the split for running split attention.
|
||||
|
|
|
|||
|
|
@ -29,11 +29,6 @@
|
|||
|
||||
namespace gcpp {
|
||||
|
||||
// TODO: rays - Remove this once hwy is updated.
|
||||
#ifndef HWY_ARCH_MAX_BYTES
|
||||
#define HWY_ARCH_MAX_BYTES 256
|
||||
#endif
|
||||
|
||||
// Number of rows for KV cache. Note that both rows and cols are u32, and
|
||||
// the total number of elements can exceed 2^32.
|
||||
static size_t CappedSeqLen(const ModelConfig& config,
|
||||
|
|
@ -46,8 +41,13 @@ static size_t CappedSeqLen(const ModelConfig& config,
|
|||
return inference_args.seq_len;
|
||||
}
|
||||
|
||||
KVCache::KVCache(const Extents2D& kv_extents, const Allocator& allocator)
|
||||
: kv_cache("kv", kv_extents, allocator, MatPadding::kOdd),
|
||||
KVCache::KVCache(const Extents2D& kv_extents, size_t num_layers,
|
||||
size_t kv_heads, size_t qkv_dim, const Allocator& allocator)
|
||||
: num_layers(num_layers),
|
||||
kv_heads(kv_heads),
|
||||
qkv_dim(qkv_dim),
|
||||
rounded_qkv_dim(hwy::RoundUpTo(qkv_dim, kMaxBF16PerVector)),
|
||||
kv_cache("kv", kv_extents, allocator, MatPadding::kOdd),
|
||||
// WARNING: the rows and cols of k_cache and v_cache will be modified
|
||||
// before use!
|
||||
// The rows will be reduced by a factor of 2xkFloatsPerVector, and the
|
||||
|
|
@ -56,14 +56,12 @@ KVCache::KVCache(const Extents2D& kv_extents, const Allocator& allocator)
|
|||
// machine architecture, since kFloatsPerVector is architecture dependent.
|
||||
// The change in shape is safe only if the padding is kPacked.
|
||||
k_cache("k",
|
||||
Extents2D(HWY_MAX(kv_extents.rows,
|
||||
2 * HWY_ARCH_MAX_BYTES / sizeof(float)),
|
||||
kv_extents.cols / 2),
|
||||
Extents2D(hwy::RoundUpTo(kv_extents.rows, kMaxBF16PerVector),
|
||||
KOrVDefaultCols()),
|
||||
allocator, MatPadding::kPacked),
|
||||
v_cache("v",
|
||||
Extents2D(HWY_MAX(kv_extents.rows,
|
||||
2 * HWY_ARCH_MAX_BYTES / sizeof(float)),
|
||||
kv_extents.cols / 2),
|
||||
Extents2D(hwy::RoundUpTo(kv_extents.rows, kMaxBF16PerVector),
|
||||
KOrVDefaultCols()),
|
||||
allocator, MatPadding::kPacked),
|
||||
allocator_(allocator) {}
|
||||
|
||||
|
|
@ -71,7 +69,8 @@ KVCache::KVCache(const ModelConfig& config, const InferenceArgs& inference_args,
|
|||
const Allocator& allocator)
|
||||
: KVCache(
|
||||
Extents2D(CappedSeqLen(config, inference_args), config.KVCacheCols()),
|
||||
allocator) {}
|
||||
config.layer_configs.size(), config.layer_configs[0].kv_heads,
|
||||
config.layer_configs[0].qkv_dim, allocator) {}
|
||||
|
||||
KVCache::KVCache(const ModelConfig& config, const InferenceArgs& inference_args,
|
||||
const RuntimeConfig& runtime_config,
|
||||
|
|
@ -135,7 +134,7 @@ KVCache::KVCache(const ModelConfig& config, const InferenceArgs& inference_args,
|
|||
}
|
||||
|
||||
KVCache KVCache::Copy() {
|
||||
KVCache copy(kv_cache.Extents(), allocator_);
|
||||
KVCache copy(kv_cache.Extents(), num_layers, kv_heads, qkv_dim, allocator_);
|
||||
|
||||
CopyMat(kv_cache, copy.kv_cache);
|
||||
return copy;
|
||||
|
|
|
|||
|
|
@ -91,6 +91,38 @@ struct KVCache {
|
|||
return {start_ptr, source_ptr};
|
||||
}
|
||||
|
||||
// Returns the default size of a row in k_cache or v_cache, before scaling by
|
||||
// 2 * kNF.
|
||||
size_t KOrVDefaultCols() const {
|
||||
return num_layers * kv_heads * rounded_qkv_dim;
|
||||
}
|
||||
|
||||
// Returns an offset into a row of k_cache or v_cache at a position that is
|
||||
// aligned to the tile size (a multiple of 2kNF).
|
||||
size_t KOrVOffset(const size_t layer_idx, const size_t kv_head_idx,
|
||||
const size_t kNF) const {
|
||||
return (layer_idx * kv_heads + kv_head_idx) * rounded_qkv_dim * 2 * kNF;
|
||||
}
|
||||
|
||||
// Returns an offset into k_cache at any given position.
|
||||
size_t KOffset(const size_t layer_idx, const size_t kv_head_idx,
|
||||
const size_t kNF, const size_t pos) const {
|
||||
return KOrVOffset(layer_idx, kv_head_idx, kNF) + (pos % (2 * kNF)) * 2;
|
||||
}
|
||||
|
||||
// Returns an offset into v_cache at any given position.
|
||||
size_t VOffset(const size_t layer_idx, const size_t kv_head_idx,
|
||||
const size_t kNF, const size_t pos) const {
|
||||
return KOrVOffset(layer_idx, kv_head_idx, kNF) +
|
||||
(pos % (2 * kNF)) * 2 * kNF;
|
||||
}
|
||||
|
||||
// Saved sizes for computing offsets into the KV cache.
|
||||
size_t num_layers = 0;
|
||||
size_t kv_heads = 0;
|
||||
size_t qkv_dim = 0;
|
||||
size_t rounded_qkv_dim = 0;
|
||||
|
||||
static constexpr size_t kTileSize = 32;
|
||||
std::optional<uint32_t> tiled_seq_len = std::nullopt;
|
||||
// Default Format
|
||||
|
|
@ -159,7 +191,8 @@ struct KVCache {
|
|||
const Allocator& allocator_;
|
||||
|
||||
// For use by other ctor and Copy()
|
||||
KVCache(const Extents2D& kv_extents, const Allocator& allocator);
|
||||
KVCache(const Extents2D& kv_extents, size_t num_layers, size_t kv_heads,
|
||||
size_t qkv_dim, const Allocator& allocator);
|
||||
};
|
||||
|
||||
inline size_t KVCachePtr::SeqLen() const {
|
||||
|
|
|
|||
Loading…
Reference in New Issue