Utilities to convert between different encodings of kv cache

PiperOrigin-RevId: 885553004
This commit is contained in:
Krzysztof Rymski 2026-03-18 06:15:57 -07:00 committed by Copybara-Service
parent 0110ddfee7
commit 1a5226e5de
5 changed files with 780 additions and 0 deletions

View File

@ -542,6 +542,31 @@ cc_test(
],
)
cc_library(
name = "kv_transcoding",
srcs = ["gemma/kv_transcoding.cc"],
hdrs = ["gemma/kv_transcoding.h"],
deps = [
":activations",
":basics",
":configs",
":kv_cache",
"//compression:types",
"@highway//:hwy",
],
)
cc_test(
name = "kv_transcoding_test",
srcs = ["gemma/kv_transcoding_test.cc"],
deps = [
":configs",
":kv_transcoding",
"//testing/base/public:gunit_main",
"@highway//:hwy",
],
)
cc_library(
name = "activations",
hdrs = ["gemma/activations.h"],

View File

@ -83,6 +83,17 @@ static inline bool EnumValid(LayerAttentionType type) {
return type == LayerAttentionType::kGemma || type == LayerAttentionType::kVit;
}
// Values stated explicitly to allow for semantic reordering
enum class KVEncoding {
kUnspecified = 0,
kF32 = 1,
kBF16 = 2,
kF32TwoTranspositions = 3,
kBF16TwoTranspositions = 4,
kInt8 = 5,
kInt8TwoTranspositions = 6,
};
enum class AttentionImpl {
kOld, // Previous Attention implementation
kFlash, // Flash Attention (default)

314
gemma/kv_transcoding.cc Normal file
View File

@ -0,0 +1,314 @@
#include "gemma/kv_transcoding.h"
#include <algorithm>
#include <cstddef>
#include <cstdlib>
#include <optional>
#include "compression/types.h"
#include "gemma/activations.h"
#include "gemma/configs.h"
#include "gemma/kv_cache.h"
#include "util/basics.h"
#include "hwy/base.h"
#include "hwy/highway.h"
namespace gcpp {
std::optional<size_t> GetTileSizeBytes(gcpp::KVEncoding encoding,
size_t qkv_dim) {
constexpr size_t kTileSize = gcpp::KVCache::kTileSize;
switch (encoding) {
case gcpp::KVEncoding::kInt8:
case gcpp::KVEncoding::kInt8TwoTranspositions:
return qkv_dim * kTileSize * 2 * sizeof(int8_t) +
kTileSize * 2 * sizeof(gcpp::KV_microscale_t);
case gcpp::KVEncoding::kBF16:
case gcpp::KVEncoding::kBF16TwoTranspositions:
return qkv_dim * kTileSize * 2 * sizeof(gcpp::BF16);
case gcpp::KVEncoding::kF32:
case gcpp::KVEncoding::kF32TwoTranspositions:
return qkv_dim * kTileSize * 2 * sizeof(float);
default:
return std::nullopt;
}
}
namespace {
constexpr size_t kTileSize = gcpp::KVCache::kTileSize;
inline size_t KOffset(bool transposed, size_t qkv_dim, size_t dim,
size_t token) {
HWY_DASSERT(dim < qkv_dim && token < kTileSize);
return transposed ? ((dim / 2) * kTileSize * 2 + token * 2 + (dim % 2))
: (dim * kTileSize + token);
}
inline size_t VOffset(bool transposed, size_t qkv_dim, size_t dim,
size_t token) {
HWY_DASSERT(dim < qkv_dim && token < kTileSize);
return transposed ? ((token / 2) * qkv_dim * 2 + dim * 2 + (token % 2))
: (token * qkv_dim + dim);
}
int8_t Quantize(float v, float inv_scale) {
float scaled = v * inv_scale;
if (scaled > 127.0f) return 127;
if (scaled < -127.0f) return -127;
return hwy::ConvertScalarTo<int8_t>(scaled);
}
template <typename DecodeKFn, typename DecodeVFn>
inline void DecodeTileWithFn(size_t qkv_dim, DecodedTile* out,
const DecodeKFn& decode_k,
const DecodeVFn& decode_v) {
for (size_t token = 0; token < kTileSize; ++token) {
for (size_t dim = 0; dim < qkv_dim; ++dim) {
out->k_elem(token, dim) = decode_k(dim, token);
}
}
for (size_t token = 0; token < kTileSize; ++token) {
for (size_t dim = 0; dim < qkv_dim; ++dim) {
out->v_elem(token, dim) = decode_v(dim, token);
}
}
}
template <typename EncodeKFn, typename EncodeVFn>
inline void EncodeTileWithFn(size_t qkv_dim, const DecodedTile& decoded,
const EncodeKFn& encode_k,
const EncodeVFn& encode_v) {
for (size_t token = 0; token < kTileSize; ++token) {
for (size_t dim = 0; dim < qkv_dim; ++dim) {
encode_k(dim, token, decoded.k_elem(token, dim));
}
}
for (size_t token = 0; token < kTileSize; ++token) {
for (size_t dim = 0; dim < qkv_dim; ++dim) {
encode_v(dim, token, decoded.v_elem(token, dim));
}
}
}
void EncodeTileF32(bool transposed, size_t qkv_dim, const DecodedTile& decoded,
hwy::Span<char> out_encoded_tile_data) {
float* data = HWY_RCAST_ALIGNED(float*, out_encoded_tile_data.data());
const size_t v_start = qkv_dim * kTileSize;
EncodeTileWithFn(
qkv_dim, decoded,
[&](size_t dim, size_t token, float val)
HWY_ATTR { data[KOffset(transposed, qkv_dim, dim, token)] = val; },
[&](size_t dim, size_t token, float val) HWY_ATTR {
data[v_start + VOffset(transposed, qkv_dim, dim, token)] = val;
});
}
void EncodeTileBF16(bool transposed, size_t qkv_dim, const DecodedTile& decoded,
hwy::Span<char> out_encoded_tile_data) {
gcpp::BF16* data =
HWY_RCAST_ALIGNED(gcpp::BF16*, out_encoded_tile_data.data());
const size_t v_start = qkv_dim * kTileSize;
EncodeTileWithFn(
qkv_dim, decoded,
[&](size_t dim, size_t token, float val) HWY_ATTR {
data[KOffset(transposed, qkv_dim, dim, token)] =
hwy::ConvertScalarTo<hwy::bfloat16_t>(val);
},
[&](size_t dim, size_t token, float val) HWY_ATTR {
data[v_start + VOffset(transposed, qkv_dim, dim, token)] =
hwy::ConvertScalarTo<hwy::bfloat16_t>(val);
});
}
void EncodeTileInt8(bool transposed, size_t qkv_dim, const DecodedTile& decoded,
hwy::Span<char> out_encoded_tile_data) {
int8_t* k_data = HWY_RCAST_ALIGNED(int8_t*, out_encoded_tile_data.data());
int8_t* v_data = k_data + qkv_dim * kTileSize;
gcpp::KV_microscale_t* scales =
HWY_RCAST_ALIGNED(gcpp::KV_microscale_t*, v_data + kTileSize * qkv_dim);
gcpp::KV_microscale_t* k_scales = scales;
gcpp::KV_microscale_t* v_scales = scales + kTileSize;
AlignedFloatVector k_max_abs(kTileSize, 0.0f);
AlignedFloatVector v_max_abs(kTileSize, 0.0f);
for (size_t token = 0; token < kTileSize; ++token) {
for (size_t dim = 0; dim < qkv_dim; ++dim) {
k_max_abs[token] =
std::max(k_max_abs[token], std::abs(decoded.k_elem(token, dim)));
}
}
for (size_t token = 0; token < kTileSize; ++token) {
for (size_t dim = 0; dim < qkv_dim; ++dim) {
v_max_abs[token] =
std::max(v_max_abs[token], std::abs(decoded.v_elem(token, dim)));
}
}
AlignedFloatVector inv_scales_k(kTileSize);
AlignedFloatVector inv_scales_v(kTileSize);
for (size_t token = 0; token < kTileSize; ++token) {
float scale_k = k_max_abs[token] == 0.0f ? 1.0f : k_max_abs[token] / 127.0f;
k_scales[token] = hwy::ConvertScalarTo<gcpp::KV_microscale_t>(scale_k);
inv_scales_k[token] = 1.0f / scale_k;
float scale_v = v_max_abs[token] == 0.0f ? 1.0f : v_max_abs[token] / 127.0f;
v_scales[token] = hwy::ConvertScalarTo<gcpp::KV_microscale_t>(scale_v);
inv_scales_v[token] = 1.0f / scale_v;
}
EncodeTileWithFn(
qkv_dim, decoded,
[&](size_t dim, size_t token, float val) HWY_ATTR {
k_data[KOffset(transposed, qkv_dim, dim, token)] =
Quantize(val, inv_scales_k[token]);
},
[&](size_t dim, size_t token, float val) HWY_ATTR {
v_data[VOffset(transposed, qkv_dim, dim, token)] =
Quantize(val, inv_scales_v[token]);
});
}
void DecodeTileF32(bool transposed, size_t qkv_dim,
hwy::Span<const char> encoded_tile_data, DecodedTile* out) {
const float* data = HWY_RCAST_ALIGNED(const float*, encoded_tile_data.data());
const size_t v_start = qkv_dim * kTileSize;
DecodeTileWithFn(
qkv_dim, out,
[&](size_t dim, size_t token)
HWY_ATTR { return data[KOffset(transposed, qkv_dim, dim, token)]; },
[&](size_t dim, size_t token) HWY_ATTR {
return data[v_start + VOffset(transposed, qkv_dim, dim, token)];
});
}
void DecodeTileBF16(bool transposed, size_t qkv_dim,
hwy::Span<const char> encoded_tile_data, DecodedTile* out) {
const gcpp::BF16* data =
HWY_RCAST_ALIGNED(const gcpp::BF16*, encoded_tile_data.data());
const size_t v_start = qkv_dim * kTileSize;
DecodeTileWithFn(
qkv_dim, out,
[&](size_t dim, size_t token) HWY_ATTR {
return hwy::ConvertScalarTo<float>(
data[KOffset(transposed, qkv_dim, dim, token)]);
},
[&](size_t dim, size_t token) HWY_ATTR {
return hwy::ConvertScalarTo<float>(
data[v_start + VOffset(transposed, qkv_dim, dim, token)]);
});
}
void DecodeTileInt8(bool transposed, size_t qkv_dim,
hwy::Span<const char> encoded_tile_data, DecodedTile* out) {
const int8_t* k_data =
HWY_RCAST_ALIGNED(const int8_t*, encoded_tile_data.data());
const int8_t* v_data = k_data + qkv_dim * kTileSize;
const gcpp::KV_microscale_t* scales = HWY_RCAST_ALIGNED(
const gcpp::KV_microscale_t*, v_data + kTileSize * qkv_dim);
const gcpp::KV_microscale_t* k_scales = scales;
const gcpp::KV_microscale_t* v_scales = scales + kTileSize;
DecodeTileWithFn(
qkv_dim, out,
[&](size_t dim, size_t token) HWY_ATTR {
float scale = hwy::ConvertScalarTo<float>(k_scales[token]);
return k_data[KOffset(transposed, qkv_dim, dim, token)] * scale;
},
[&](size_t dim, size_t token) HWY_ATTR {
float scale = hwy::ConvertScalarTo<float>(v_scales[token]);
return v_data[VOffset(transposed, qkv_dim, dim, token)] * scale;
});
}
} // namespace
bool IsTransposed(KVEncoding encoding) {
switch (encoding) {
case KVEncoding::kF32TwoTranspositions:
case KVEncoding::kBF16TwoTranspositions:
case KVEncoding::kInt8TwoTranspositions:
return true;
default:
return false;
}
}
hwy::AlignedUniquePtr<char[]> AllocateEncodedTile(KVEncoding encoding,
size_t qkv_dim) {
std::optional<size_t> size = GetTileSizeBytes(encoding, qkv_dim);
if (!size.has_value()) return hwy::AlignedUniquePtr<char[]>();
return hwy::MakeUniqueAlignedArray<char>(*size);
}
bool DecodeTile(KVEncoding encoding, hwy::Span<const char> encoded_tile_data,
size_t qkv_dim, DecodedTile* out) {
std::optional<size_t> required_size_or = GetTileSizeBytes(encoding, qkv_dim);
if (!required_size_or.has_value()) return false;
size_t required_size = *required_size_or;
if (encoded_tile_data.size() < required_size) {
return false;
}
bool transposed = IsTransposed(encoding);
switch (encoding) {
case gcpp::KVEncoding::kF32:
case gcpp::KVEncoding::kF32TwoTranspositions: {
DecodeTileF32(transposed, qkv_dim, encoded_tile_data, out);
return true;
}
case gcpp::KVEncoding::kBF16:
case gcpp::KVEncoding::kBF16TwoTranspositions: {
DecodeTileBF16(transposed, qkv_dim, encoded_tile_data, out);
return true;
}
case gcpp::KVEncoding::kInt8:
case gcpp::KVEncoding::kInt8TwoTranspositions: {
DecodeTileInt8(transposed, qkv_dim, encoded_tile_data, out);
return true;
}
default:
return false;
}
}
bool EncodeTile(gcpp::KVEncoding encoding, const DecodedTile& decoded,
size_t qkv_dim, hwy::Span<char> out_encoded_tile_data) {
std::optional<size_t> required_size_or = GetTileSizeBytes(encoding, qkv_dim);
if (!required_size_or.has_value()) return false;
size_t required_size = *required_size_or;
if (out_encoded_tile_data.size() < required_size) {
return false;
}
bool transposed = IsTransposed(encoding);
switch (encoding) {
case gcpp::KVEncoding::kF32:
case gcpp::KVEncoding::kF32TwoTranspositions: {
EncodeTileF32(transposed, qkv_dim, decoded, out_encoded_tile_data);
return true;
}
case gcpp::KVEncoding::kBF16:
case gcpp::KVEncoding::kBF16TwoTranspositions: {
EncodeTileBF16(transposed, qkv_dim, decoded, out_encoded_tile_data);
return true;
}
case gcpp::KVEncoding::kInt8:
case gcpp::KVEncoding::kInt8TwoTranspositions: {
EncodeTileInt8(transposed, qkv_dim, decoded, out_encoded_tile_data);
return true;
}
default:
return false;
}
}
bool TranscodeTile(gcpp::KVEncoding src_encoding,
hwy::Span<const char> src_data,
gcpp::KVEncoding dst_encoding, hwy::Span<char> dst_data,
size_t qkv_dim) {
DecodedTile decoded(qkv_dim, kTileSize);
if (!DecodeTile(src_encoding, src_data, qkv_dim, &decoded)) return false;
return EncodeTile(dst_encoding, decoded, qkv_dim, dst_data);
}
} // namespace gcpp

70
gemma/kv_transcoding.h Normal file
View File

@ -0,0 +1,70 @@
#ifndef THIRD_PARTY_GEMMA_CPP_GEMMA_KV_TRANSCODING_H_
#define THIRD_PARTY_GEMMA_CPP_GEMMA_KV_TRANSCODING_H_
#include <cstddef>
#include <optional>
#include <vector>
#include "gemma/configs.h"
#include "hwy/aligned_allocator.h"
namespace gcpp {
// Returns the size in bytes of a single KV cache tile for a given encoding.
// Returns std::nullopt if the encoding is unsupported.
std::optional<size_t> GetTileSizeBytes(gcpp::KVEncoding encoding,
size_t qkv_dim);
// Canonical representation of a single tile of K and V data decoded to float32.
// Layout: K is [tile_size, qkv_dim] contiguous, V is [tile_size, qkv_dim]
// contiguous.
struct DecodedTile {
std::vector<float, hwy::AlignedAllocator<float>> k;
std::vector<float, hwy::AlignedAllocator<float>> v;
size_t qkv_dim = 0;
size_t tile_size = 0;
DecodedTile() = default;
DecodedTile(size_t qkv_dim, size_t tile_size)
: k(qkv_dim * tile_size),
v(tile_size * qkv_dim),
qkv_dim(qkv_dim),
tile_size(tile_size) {}
float& k_elem(size_t token, size_t dim) { return k[token * qkv_dim + dim]; }
const float& k_elem(size_t token, size_t dim) const {
return k[token * qkv_dim + dim];
}
float& v_elem(size_t token, size_t dim) { return v[token * qkv_dim + dim]; }
const float& v_elem(size_t token, size_t dim) const {
return v[token * qkv_dim + dim];
}
};
// Allocates an aligned buffer for storing
// an encoded tile of the given encoding.
hwy::AlignedUniquePtr<char[]> AllocateEncodedTile(gcpp::KVEncoding encoding,
size_t qkv_dim);
// Decodes a single tile's K and V data from its encoded byte buffer into
// float32 using the specified encoding.
bool DecodeTile(gcpp::KVEncoding encoding,
hwy::Span<const char> encoded_tile_data, size_t qkv_dim,
DecodedTile* out);
// Encodes a single tile's K and V data from standard float32 into the target
// encoding. Returns false if the encoding is unsupported.
bool EncodeTile(gcpp::KVEncoding encoding, const DecodedTile& decoded,
size_t qkv_dim, hwy::Span<char> out_encoded_tile_data);
// Convenience utility to convert a tile directly from one encoding to another.
// Return false if either encoding is unsupported or passed data is too small.
bool TranscodeTile(gcpp::KVEncoding src_encoding,
hwy::Span<const char> src_data,
gcpp::KVEncoding dst_encoding, hwy::Span<char> dst_data,
size_t qkv_dim);
} // namespace gcpp
#endif // THIRD_PARTY_GEMMA_CPP_GEMMA_KV_TRANSCODING_H_

View File

@ -0,0 +1,360 @@
#include "gemma/kv_transcoding.h"
#include <cmath>
#include <cstddef>
#include <cstdint>
#include <optional>
#include <vector>
#include "gmock/gmock.h"
#include "gtest/gtest.h"
#include "gemma/configs.h"
#include "hwy/aligned_allocator.h"
#include "hwy/base.h" // For hwy::Span
namespace gcpp {
namespace {
using ::testing::FloatNear;
using ::testing::Pointwise;
using ::testing::TestWithParam;
using ::testing::Values;
struct EncodingTestCase {
gcpp::KVEncoding encoding;
float tolerance;
};
class KVEncodingTest : public TestWithParam<EncodingTestCase> {};
TEST_P(KVEncodingTest, EncodeDecodeRoundTrip) {
const auto& param = GetParam();
constexpr size_t kTileSize = 32;
constexpr size_t qkv_dim = 256;
DecodedTile original(qkv_dim, kTileSize);
// Fill with dummy data within
// a reasonable float range to avoid saturation for INT8
const float pattern[] = {0.5f, 1.0f, 1.5f};
for (size_t token = 0; token < kTileSize; ++token) {
for (size_t dim = 0; dim < qkv_dim; ++dim) {
size_t i = dim * kTileSize + token;
original.k_elem(token, dim) = pattern[i % 3];
original.v_elem(token, dim) = pattern[i % 3];
}
}
std::optional<size_t> tile_size_bytes =
GetTileSizeBytes(param.encoding, qkv_dim);
ASSERT_TRUE(tile_size_bytes.has_value());
std::vector<char> encoded(*tile_size_bytes, 0);
EXPECT_TRUE(EncodeTile(param.encoding, original, qkv_dim,
hwy::Span<char>(encoded.data(), encoded.size())));
DecodedTile decoded(qkv_dim, kTileSize);
EXPECT_TRUE(DecodeTile(param.encoding,
hwy::Span<const char>(encoded.data(), encoded.size()),
qkv_dim, &decoded));
EXPECT_THAT(decoded.k, Pointwise(FloatNear(param.tolerance), original.k));
EXPECT_THAT(decoded.v, Pointwise(FloatNear(param.tolerance), original.v));
}
TEST_P(KVEncodingTest, SizeChecks) {
const auto& param = GetParam();
constexpr size_t kTileSize = 32;
constexpr size_t qkv_dim = 256;
DecodedTile decoded(qkv_dim, kTileSize);
std::optional<size_t> required_size_or =
GetTileSizeBytes(param.encoding, qkv_dim);
ASSERT_TRUE(required_size_or.has_value());
size_t required_size = *required_size_or;
if (required_size > 0) {
std::vector<char> too_small_encoded(required_size - 1, 0);
EXPECT_FALSE(EncodeTile(
param.encoding, decoded, qkv_dim,
hwy::Span<char>(too_small_encoded.data(), too_small_encoded.size())));
EXPECT_FALSE(DecodeTile(param.encoding,
hwy::Span<const char>(too_small_encoded.data(),
too_small_encoded.size()),
qkv_dim, &decoded));
}
}
INSTANTIATE_TEST_SUITE_P(
AllEncodings, KVEncodingTest,
Values(EncodingTestCase{gcpp::KVEncoding::kF32, 1e-6f},
EncodingTestCase{gcpp::KVEncoding::kF32TwoTranspositions, 1e-6f},
EncodingTestCase{gcpp::KVEncoding::kBF16, 0.05f},
EncodingTestCase{gcpp::KVEncoding::kBF16TwoTranspositions, 0.05f},
EncodingTestCase{gcpp::KVEncoding::kInt8, 0.1f},
EncodingTestCase{gcpp::KVEncoding::kInt8TwoTranspositions, 0.1f}));
TEST(KVEncodingTest, ConvertTileFloat32ToBfloat16) {
constexpr size_t kTileSize = 32;
constexpr size_t qkv_dim = 256;
gcpp::KVEncoding src_encoding = gcpp::KVEncoding::kF32;
gcpp::KVEncoding dst_encoding = gcpp::KVEncoding::kBF16;
DecodedTile original(qkv_dim, kTileSize);
for (size_t token = 0; token < kTileSize; ++token) {
for (size_t dim = 0; dim < qkv_dim; ++dim) {
size_t i = dim * kTileSize + token;
original.k_elem(token, dim) = std::sin(i) * 5.0f;
original.v_elem(token, dim) = std::cos(i) * 5.0f;
}
}
size_t src_size = GetTileSizeBytes(src_encoding, qkv_dim).value();
size_t dst_size = GetTileSizeBytes(dst_encoding, qkv_dim).value();
std::vector<char> src_data(src_size);
std::vector<char> dst_data(dst_size);
EXPECT_TRUE(EncodeTile(src_encoding, original, qkv_dim,
hwy::Span<char>(src_data.data(), src_data.size())));
EXPECT_TRUE(TranscodeTile(
src_encoding, hwy::Span<const char>(src_data.data(), src_data.size()),
dst_encoding, hwy::Span<char>(dst_data.data(), dst_data.size()),
qkv_dim));
DecodedTile decoded(qkv_dim, kTileSize);
EXPECT_TRUE(DecodeTile(
dst_encoding, hwy::Span<const char>(dst_data.data(), dst_data.size()),
qkv_dim, &decoded));
EXPECT_THAT(decoded.k, Pointwise(FloatNear(0.05f), original.k));
}
TEST(KVEncodingTest, PairwiseConversion) {
constexpr size_t kTileSize = 32;
constexpr size_t qkv_dim = 256;
std::vector<gcpp::KVEncoding> encodings = {
gcpp::KVEncoding::kF32, gcpp::KVEncoding::kF32TwoTranspositions,
gcpp::KVEncoding::kBF16, gcpp::KVEncoding::kBF16TwoTranspositions,
gcpp::KVEncoding::kInt8, gcpp::KVEncoding::kInt8TwoTranspositions};
for (auto src : encodings) {
for (auto dst : encodings) {
if (src == dst) continue;
DecodedTile original(qkv_dim, kTileSize);
const float pattern[] = {0.5f, 1.0f, 1.5f};
for (size_t token = 0; token < kTileSize; ++token) {
for (size_t dim = 0; dim < qkv_dim; ++dim) {
size_t i = dim * kTileSize + token;
original.k_elem(token, dim) = pattern[i % 3];
original.v_elem(token, dim) = pattern[i % 3];
}
}
size_t src_size = GetTileSizeBytes(src, qkv_dim).value();
size_t dst_size = GetTileSizeBytes(dst, qkv_dim).value();
std::vector<char> src_data(src_size);
std::vector<char> dst_data(dst_size);
ASSERT_TRUE(EncodeTile(src, original, qkv_dim,
hwy::Span<char>(src_data.data(), src_data.size())))
<< "src=" << static_cast<int>(src);
ASSERT_TRUE(TranscodeTile(
src, hwy::Span<const char>(src_data.data(), src_data.size()), dst,
hwy::Span<char>(dst_data.data(), dst_data.size()), qkv_dim))
<< "src=" << static_cast<int>(src)
<< " dst=" << static_cast<int>(dst);
DecodedTile decoded(qkv_dim, kTileSize);
ASSERT_TRUE(DecodeTile(
dst, hwy::Span<const char>(dst_data.data(), dst_data.size()), qkv_dim,
&decoded))
<< "dst=" << static_cast<int>(dst);
float tolerance = 0.1f; // Max tolerance for Int8
EXPECT_THAT(decoded.k, Pointwise(FloatNear(tolerance), original.k))
<< "src=" << static_cast<int>(src)
<< " dst=" << static_cast<int>(dst);
EXPECT_THAT(decoded.v, Pointwise(FloatNear(tolerance), original.v))
<< "src=" << static_cast<int>(src)
<< " dst=" << static_cast<int>(dst);
}
}
}
TEST(KVEncodingTest, LayoutValidationF32) {
constexpr size_t kTileSize = 32;
constexpr size_t qkv_dim = 4;
gcpp::KVEncoding encoding = gcpp::KVEncoding::kF32;
DecodedTile original(qkv_dim, kTileSize);
for (size_t token = 0; token < kTileSize; ++token) {
for (size_t dim = 0; dim < qkv_dim; ++dim) {
original.k_elem(token, dim) = dim * kTileSize + token + 1;
}
}
for (size_t token = 0; token < kTileSize; ++token) {
for (size_t dim = 0; dim < qkv_dim; ++dim) {
original.v_elem(token, dim) =
token * qkv_dim + dim + 1 + qkv_dim * kTileSize;
}
}
size_t size = GetTileSizeBytes(encoding, qkv_dim).value();
std::vector<char> encoded(size);
ASSERT_TRUE(EncodeTile(encoding, original, qkv_dim,
hwy::Span<char>(encoded.data(), encoded.size())));
const float* data = reinterpret_cast<const float*>(encoded.data());
// K should be row-major [qkv_dim, tile_size]
EXPECT_EQ(data[0], 1.0f); // d=0, t=0
EXPECT_EQ(data[1], 2.0f); // d=0, t=1
EXPECT_EQ(data[32], 33.0f); // d=1, t=0
// V should be row-major [tile_size, qkv_dim]
size_t v_start = qkv_dim * kTileSize;
EXPECT_EQ(data[v_start], 129.0f); // t=0, d=0
EXPECT_EQ(data[v_start + 1], 130.0f); // t=0, d=1
EXPECT_EQ(data[v_start + 4], 133.0f); // t=1, d=0
}
TEST(KVEncodingTest, LayoutValidationF32TwoTranspositions) {
constexpr size_t kTileSize = 32;
constexpr size_t qkv_dim = 4;
gcpp::KVEncoding encoding = gcpp::KVEncoding::kF32TwoTranspositions;
DecodedTile original(qkv_dim, kTileSize);
for (size_t token = 0; token < kTileSize; ++token) {
for (size_t dim = 0; dim < qkv_dim; ++dim) {
original.k_elem(token, dim) = dim * kTileSize + token + 1;
}
}
for (size_t token = 0; token < kTileSize; ++token) {
for (size_t dim = 0; dim < qkv_dim; ++dim) {
original.v_elem(token, dim) =
token * qkv_dim + dim + 1 + qkv_dim * kTileSize;
}
}
size_t size = GetTileSizeBytes(encoding, qkv_dim).value();
std::vector<char> encoded(size);
ASSERT_TRUE(EncodeTile(encoding, original, qkv_dim,
hwy::Span<char>(encoded.data(), encoded.size())));
const float* data = reinterpret_cast<const float*>(encoded.data());
// K transposed: [qkv_dim/2, tile_size, 2]
EXPECT_EQ(data[0], 1.0f); // d=0, t=0
EXPECT_EQ(data[1], 33.0f); // d=1, t=0
EXPECT_EQ(data[2], 2.0f); // d=0, t=1
EXPECT_EQ(data[3], 34.0f); // d=1, t=1
EXPECT_EQ(data[64], 65.0f); // d=2, t=0
EXPECT_EQ(data[65], 97.0f); // d=3, t=0
// V transposed: [tile_size/2, qkv_dim, 2]
size_t v_start = qkv_dim * kTileSize;
EXPECT_EQ(data[v_start], 129.0f); // t=0, d=0
EXPECT_EQ(data[v_start + 1], 133.0f); // t=1, d=0
EXPECT_EQ(data[v_start + 2], 130.0f); // t=0, d=1
EXPECT_EQ(data[v_start + 3], 134.0f); // t=1, d=1
}
TEST(KVEncodingTest, LayoutValidationInt8) {
constexpr size_t kTileSize = 32;
constexpr size_t qkv_dim = 4;
gcpp::KVEncoding encoding = gcpp::KVEncoding::kInt8;
DecodedTile original(qkv_dim, kTileSize);
for (size_t token = 0; token < kTileSize; ++token) {
for (size_t dim = 0; dim < qkv_dim; ++dim) {
original.k_elem(token, dim) = dim * kTileSize + token + 1;
}
}
for (size_t token = 0; token < kTileSize; ++token) {
for (size_t dim = 0; dim < qkv_dim; ++dim) {
original.v_elem(token, dim) =
token * qkv_dim + dim + 1 + qkv_dim * kTileSize;
}
}
size_t size = GetTileSizeBytes(encoding, qkv_dim).value();
std::vector<char> encoded(size);
ASSERT_TRUE(EncodeTile(encoding, original, qkv_dim,
hwy::Span<char>(encoded.data(), encoded.size())));
const int8_t* data = reinterpret_cast<const int8_t*>(encoded.data());
// K should be row-major [qkv_dim, tile_size]
// K[3,0] = 97. Max for t=0 is 97. Scale = 97/127.
// Quantized K[3,0] = 127.
// K[3,0] is at offset 3 * 32 + 0 = 96.
EXPECT_EQ(data[96], 127);
// V should be row-major [tile_size, qkv_dim]
size_t v_start = qkv_dim * kTileSize;
// V[0,3] = 132. Max for t=0 is 132. Scale = 132/127.
// Quantized V[0,3] = 127.
// V[0,3] is at offset v_start + 0 * 4 + 3 = v_start + 3.
EXPECT_EQ(data[v_start + 3], 127);
}
TEST(KVEncodingTest, LayoutValidationInt8TwoTranspositions) {
constexpr size_t kTileSize = 32;
constexpr size_t qkv_dim = 4;
gcpp::KVEncoding encoding = gcpp::KVEncoding::kInt8TwoTranspositions;
DecodedTile original(qkv_dim, kTileSize);
for (size_t token = 0; token < kTileSize; ++token) {
for (size_t dim = 0; dim < qkv_dim; ++dim) {
original.k_elem(token, dim) = dim * kTileSize + token + 1;
}
}
for (size_t token = 0; token < kTileSize; ++token) {
for (size_t dim = 0; dim < qkv_dim; ++dim) {
original.v_elem(token, dim) =
token * qkv_dim + dim + 1 + qkv_dim * kTileSize;
}
}
size_t size = GetTileSizeBytes(encoding, qkv_dim).value();
std::vector<char> encoded(size);
ASSERT_TRUE(EncodeTile(encoding, original, qkv_dim,
hwy::Span<char>(encoded.data(), encoded.size())));
const int8_t* data = reinterpret_cast<const int8_t*>(encoded.data());
// K transposed: [qkv_dim/2, tile_size, 2]
// K[0,0] = 1. Max for t=0 is 97. Scale = 97/127.
// Quantized K[0,0] = 1.
// K[1,0] = 33. Quantized K[1,0] = 33 / (97/127) = 43.14 -> 43.
// K[1,0] is at offset 1.
EXPECT_EQ(data[0], 1);
EXPECT_EQ(data[1], 43);
// V transposed: [tile_size/2, qkv_dim, 2]
size_t v_start = qkv_dim * kTileSize;
// V[0,0] = 129. Max for t=0 is 132. Scale = 132/127.
// Quantized V[0,0] = round(129 * 127 / 132) = 124.
// V[1,0] = 133. Max for t=1 is 136. Scale = 136/127.
// Quantized V[1,0] = round(133 * 127 / 136) = 124.
// In transposed layout, V[0,0] is at v_start. V[1,0] is at v_start + 1.
EXPECT_EQ(data[v_start], 124);
EXPECT_EQ(data[v_start + 1], 124);
// V[1,3] = 136. Max for t=1 is 136. Quantized = 127.
// Offset in transposed V: t/2*8 + d*2 + t%2.
// For t=1, d=3: 0*8 + 3*2 + 1 = 7.
EXPECT_EQ(data[v_start + 7], 127);
}
} // namespace
} // namespace gcpp