mirror of https://github.com/google/gemma.cpp.git
Utilities to convert between different encodings of kv cache
PiperOrigin-RevId: 885553004
This commit is contained in:
parent
0110ddfee7
commit
1a5226e5de
25
BUILD.bazel
25
BUILD.bazel
|
|
@ -542,6 +542,31 @@ cc_test(
|
|||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "kv_transcoding",
|
||||
srcs = ["gemma/kv_transcoding.cc"],
|
||||
hdrs = ["gemma/kv_transcoding.h"],
|
||||
deps = [
|
||||
":activations",
|
||||
":basics",
|
||||
":configs",
|
||||
":kv_cache",
|
||||
"//compression:types",
|
||||
"@highway//:hwy",
|
||||
],
|
||||
)
|
||||
|
||||
cc_test(
|
||||
name = "kv_transcoding_test",
|
||||
srcs = ["gemma/kv_transcoding_test.cc"],
|
||||
deps = [
|
||||
":configs",
|
||||
":kv_transcoding",
|
||||
"//testing/base/public:gunit_main",
|
||||
"@highway//:hwy",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "activations",
|
||||
hdrs = ["gemma/activations.h"],
|
||||
|
|
|
|||
|
|
@ -83,6 +83,17 @@ static inline bool EnumValid(LayerAttentionType type) {
|
|||
return type == LayerAttentionType::kGemma || type == LayerAttentionType::kVit;
|
||||
}
|
||||
|
||||
// Values stated explicitly to allow for semantic reordering
|
||||
enum class KVEncoding {
|
||||
kUnspecified = 0,
|
||||
kF32 = 1,
|
||||
kBF16 = 2,
|
||||
kF32TwoTranspositions = 3,
|
||||
kBF16TwoTranspositions = 4,
|
||||
kInt8 = 5,
|
||||
kInt8TwoTranspositions = 6,
|
||||
};
|
||||
|
||||
enum class AttentionImpl {
|
||||
kOld, // Previous Attention implementation
|
||||
kFlash, // Flash Attention (default)
|
||||
|
|
|
|||
|
|
@ -0,0 +1,314 @@
|
|||
#include "gemma/kv_transcoding.h"
|
||||
|
||||
#include <algorithm>
|
||||
#include <cstddef>
|
||||
#include <cstdlib>
|
||||
#include <optional>
|
||||
|
||||
#include "compression/types.h"
|
||||
#include "gemma/activations.h"
|
||||
#include "gemma/configs.h"
|
||||
#include "gemma/kv_cache.h"
|
||||
#include "util/basics.h"
|
||||
#include "hwy/base.h"
|
||||
#include "hwy/highway.h"
|
||||
|
||||
namespace gcpp {
|
||||
|
||||
std::optional<size_t> GetTileSizeBytes(gcpp::KVEncoding encoding,
|
||||
size_t qkv_dim) {
|
||||
constexpr size_t kTileSize = gcpp::KVCache::kTileSize;
|
||||
switch (encoding) {
|
||||
case gcpp::KVEncoding::kInt8:
|
||||
case gcpp::KVEncoding::kInt8TwoTranspositions:
|
||||
return qkv_dim * kTileSize * 2 * sizeof(int8_t) +
|
||||
kTileSize * 2 * sizeof(gcpp::KV_microscale_t);
|
||||
case gcpp::KVEncoding::kBF16:
|
||||
case gcpp::KVEncoding::kBF16TwoTranspositions:
|
||||
return qkv_dim * kTileSize * 2 * sizeof(gcpp::BF16);
|
||||
case gcpp::KVEncoding::kF32:
|
||||
case gcpp::KVEncoding::kF32TwoTranspositions:
|
||||
return qkv_dim * kTileSize * 2 * sizeof(float);
|
||||
default:
|
||||
return std::nullopt;
|
||||
}
|
||||
}
|
||||
|
||||
namespace {
|
||||
constexpr size_t kTileSize = gcpp::KVCache::kTileSize;
|
||||
|
||||
inline size_t KOffset(bool transposed, size_t qkv_dim, size_t dim,
|
||||
size_t token) {
|
||||
HWY_DASSERT(dim < qkv_dim && token < kTileSize);
|
||||
return transposed ? ((dim / 2) * kTileSize * 2 + token * 2 + (dim % 2))
|
||||
: (dim * kTileSize + token);
|
||||
}
|
||||
|
||||
inline size_t VOffset(bool transposed, size_t qkv_dim, size_t dim,
|
||||
size_t token) {
|
||||
HWY_DASSERT(dim < qkv_dim && token < kTileSize);
|
||||
return transposed ? ((token / 2) * qkv_dim * 2 + dim * 2 + (token % 2))
|
||||
: (token * qkv_dim + dim);
|
||||
}
|
||||
|
||||
int8_t Quantize(float v, float inv_scale) {
|
||||
float scaled = v * inv_scale;
|
||||
if (scaled > 127.0f) return 127;
|
||||
if (scaled < -127.0f) return -127;
|
||||
return hwy::ConvertScalarTo<int8_t>(scaled);
|
||||
}
|
||||
|
||||
template <typename DecodeKFn, typename DecodeVFn>
|
||||
inline void DecodeTileWithFn(size_t qkv_dim, DecodedTile* out,
|
||||
const DecodeKFn& decode_k,
|
||||
const DecodeVFn& decode_v) {
|
||||
for (size_t token = 0; token < kTileSize; ++token) {
|
||||
for (size_t dim = 0; dim < qkv_dim; ++dim) {
|
||||
out->k_elem(token, dim) = decode_k(dim, token);
|
||||
}
|
||||
}
|
||||
for (size_t token = 0; token < kTileSize; ++token) {
|
||||
for (size_t dim = 0; dim < qkv_dim; ++dim) {
|
||||
out->v_elem(token, dim) = decode_v(dim, token);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <typename EncodeKFn, typename EncodeVFn>
|
||||
inline void EncodeTileWithFn(size_t qkv_dim, const DecodedTile& decoded,
|
||||
const EncodeKFn& encode_k,
|
||||
const EncodeVFn& encode_v) {
|
||||
for (size_t token = 0; token < kTileSize; ++token) {
|
||||
for (size_t dim = 0; dim < qkv_dim; ++dim) {
|
||||
encode_k(dim, token, decoded.k_elem(token, dim));
|
||||
}
|
||||
}
|
||||
for (size_t token = 0; token < kTileSize; ++token) {
|
||||
for (size_t dim = 0; dim < qkv_dim; ++dim) {
|
||||
encode_v(dim, token, decoded.v_elem(token, dim));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void EncodeTileF32(bool transposed, size_t qkv_dim, const DecodedTile& decoded,
|
||||
hwy::Span<char> out_encoded_tile_data) {
|
||||
float* data = HWY_RCAST_ALIGNED(float*, out_encoded_tile_data.data());
|
||||
const size_t v_start = qkv_dim * kTileSize;
|
||||
EncodeTileWithFn(
|
||||
qkv_dim, decoded,
|
||||
[&](size_t dim, size_t token, float val)
|
||||
HWY_ATTR { data[KOffset(transposed, qkv_dim, dim, token)] = val; },
|
||||
[&](size_t dim, size_t token, float val) HWY_ATTR {
|
||||
data[v_start + VOffset(transposed, qkv_dim, dim, token)] = val;
|
||||
});
|
||||
}
|
||||
|
||||
void EncodeTileBF16(bool transposed, size_t qkv_dim, const DecodedTile& decoded,
|
||||
hwy::Span<char> out_encoded_tile_data) {
|
||||
gcpp::BF16* data =
|
||||
HWY_RCAST_ALIGNED(gcpp::BF16*, out_encoded_tile_data.data());
|
||||
const size_t v_start = qkv_dim * kTileSize;
|
||||
EncodeTileWithFn(
|
||||
qkv_dim, decoded,
|
||||
[&](size_t dim, size_t token, float val) HWY_ATTR {
|
||||
data[KOffset(transposed, qkv_dim, dim, token)] =
|
||||
hwy::ConvertScalarTo<hwy::bfloat16_t>(val);
|
||||
},
|
||||
[&](size_t dim, size_t token, float val) HWY_ATTR {
|
||||
data[v_start + VOffset(transposed, qkv_dim, dim, token)] =
|
||||
hwy::ConvertScalarTo<hwy::bfloat16_t>(val);
|
||||
});
|
||||
}
|
||||
|
||||
void EncodeTileInt8(bool transposed, size_t qkv_dim, const DecodedTile& decoded,
|
||||
hwy::Span<char> out_encoded_tile_data) {
|
||||
int8_t* k_data = HWY_RCAST_ALIGNED(int8_t*, out_encoded_tile_data.data());
|
||||
int8_t* v_data = k_data + qkv_dim * kTileSize;
|
||||
gcpp::KV_microscale_t* scales =
|
||||
HWY_RCAST_ALIGNED(gcpp::KV_microscale_t*, v_data + kTileSize * qkv_dim);
|
||||
gcpp::KV_microscale_t* k_scales = scales;
|
||||
gcpp::KV_microscale_t* v_scales = scales + kTileSize;
|
||||
|
||||
AlignedFloatVector k_max_abs(kTileSize, 0.0f);
|
||||
AlignedFloatVector v_max_abs(kTileSize, 0.0f);
|
||||
|
||||
for (size_t token = 0; token < kTileSize; ++token) {
|
||||
for (size_t dim = 0; dim < qkv_dim; ++dim) {
|
||||
k_max_abs[token] =
|
||||
std::max(k_max_abs[token], std::abs(decoded.k_elem(token, dim)));
|
||||
}
|
||||
}
|
||||
for (size_t token = 0; token < kTileSize; ++token) {
|
||||
for (size_t dim = 0; dim < qkv_dim; ++dim) {
|
||||
v_max_abs[token] =
|
||||
std::max(v_max_abs[token], std::abs(decoded.v_elem(token, dim)));
|
||||
}
|
||||
}
|
||||
|
||||
AlignedFloatVector inv_scales_k(kTileSize);
|
||||
AlignedFloatVector inv_scales_v(kTileSize);
|
||||
for (size_t token = 0; token < kTileSize; ++token) {
|
||||
float scale_k = k_max_abs[token] == 0.0f ? 1.0f : k_max_abs[token] / 127.0f;
|
||||
k_scales[token] = hwy::ConvertScalarTo<gcpp::KV_microscale_t>(scale_k);
|
||||
inv_scales_k[token] = 1.0f / scale_k;
|
||||
|
||||
float scale_v = v_max_abs[token] == 0.0f ? 1.0f : v_max_abs[token] / 127.0f;
|
||||
v_scales[token] = hwy::ConvertScalarTo<gcpp::KV_microscale_t>(scale_v);
|
||||
inv_scales_v[token] = 1.0f / scale_v;
|
||||
}
|
||||
|
||||
EncodeTileWithFn(
|
||||
qkv_dim, decoded,
|
||||
[&](size_t dim, size_t token, float val) HWY_ATTR {
|
||||
k_data[KOffset(transposed, qkv_dim, dim, token)] =
|
||||
Quantize(val, inv_scales_k[token]);
|
||||
},
|
||||
[&](size_t dim, size_t token, float val) HWY_ATTR {
|
||||
v_data[VOffset(transposed, qkv_dim, dim, token)] =
|
||||
Quantize(val, inv_scales_v[token]);
|
||||
});
|
||||
}
|
||||
|
||||
void DecodeTileF32(bool transposed, size_t qkv_dim,
|
||||
hwy::Span<const char> encoded_tile_data, DecodedTile* out) {
|
||||
const float* data = HWY_RCAST_ALIGNED(const float*, encoded_tile_data.data());
|
||||
const size_t v_start = qkv_dim * kTileSize;
|
||||
DecodeTileWithFn(
|
||||
qkv_dim, out,
|
||||
[&](size_t dim, size_t token)
|
||||
HWY_ATTR { return data[KOffset(transposed, qkv_dim, dim, token)]; },
|
||||
[&](size_t dim, size_t token) HWY_ATTR {
|
||||
return data[v_start + VOffset(transposed, qkv_dim, dim, token)];
|
||||
});
|
||||
}
|
||||
|
||||
void DecodeTileBF16(bool transposed, size_t qkv_dim,
|
||||
hwy::Span<const char> encoded_tile_data, DecodedTile* out) {
|
||||
const gcpp::BF16* data =
|
||||
HWY_RCAST_ALIGNED(const gcpp::BF16*, encoded_tile_data.data());
|
||||
const size_t v_start = qkv_dim * kTileSize;
|
||||
DecodeTileWithFn(
|
||||
qkv_dim, out,
|
||||
[&](size_t dim, size_t token) HWY_ATTR {
|
||||
return hwy::ConvertScalarTo<float>(
|
||||
data[KOffset(transposed, qkv_dim, dim, token)]);
|
||||
},
|
||||
[&](size_t dim, size_t token) HWY_ATTR {
|
||||
return hwy::ConvertScalarTo<float>(
|
||||
data[v_start + VOffset(transposed, qkv_dim, dim, token)]);
|
||||
});
|
||||
}
|
||||
|
||||
void DecodeTileInt8(bool transposed, size_t qkv_dim,
|
||||
hwy::Span<const char> encoded_tile_data, DecodedTile* out) {
|
||||
const int8_t* k_data =
|
||||
HWY_RCAST_ALIGNED(const int8_t*, encoded_tile_data.data());
|
||||
const int8_t* v_data = k_data + qkv_dim * kTileSize;
|
||||
const gcpp::KV_microscale_t* scales = HWY_RCAST_ALIGNED(
|
||||
const gcpp::KV_microscale_t*, v_data + kTileSize * qkv_dim);
|
||||
const gcpp::KV_microscale_t* k_scales = scales;
|
||||
const gcpp::KV_microscale_t* v_scales = scales + kTileSize;
|
||||
|
||||
DecodeTileWithFn(
|
||||
qkv_dim, out,
|
||||
[&](size_t dim, size_t token) HWY_ATTR {
|
||||
float scale = hwy::ConvertScalarTo<float>(k_scales[token]);
|
||||
return k_data[KOffset(transposed, qkv_dim, dim, token)] * scale;
|
||||
},
|
||||
[&](size_t dim, size_t token) HWY_ATTR {
|
||||
float scale = hwy::ConvertScalarTo<float>(v_scales[token]);
|
||||
return v_data[VOffset(transposed, qkv_dim, dim, token)] * scale;
|
||||
});
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
bool IsTransposed(KVEncoding encoding) {
|
||||
switch (encoding) {
|
||||
case KVEncoding::kF32TwoTranspositions:
|
||||
case KVEncoding::kBF16TwoTranspositions:
|
||||
case KVEncoding::kInt8TwoTranspositions:
|
||||
return true;
|
||||
default:
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
hwy::AlignedUniquePtr<char[]> AllocateEncodedTile(KVEncoding encoding,
|
||||
size_t qkv_dim) {
|
||||
std::optional<size_t> size = GetTileSizeBytes(encoding, qkv_dim);
|
||||
if (!size.has_value()) return hwy::AlignedUniquePtr<char[]>();
|
||||
return hwy::MakeUniqueAlignedArray<char>(*size);
|
||||
}
|
||||
|
||||
bool DecodeTile(KVEncoding encoding, hwy::Span<const char> encoded_tile_data,
|
||||
size_t qkv_dim, DecodedTile* out) {
|
||||
std::optional<size_t> required_size_or = GetTileSizeBytes(encoding, qkv_dim);
|
||||
if (!required_size_or.has_value()) return false;
|
||||
size_t required_size = *required_size_or;
|
||||
if (encoded_tile_data.size() < required_size) {
|
||||
return false;
|
||||
}
|
||||
bool transposed = IsTransposed(encoding);
|
||||
switch (encoding) {
|
||||
case gcpp::KVEncoding::kF32:
|
||||
case gcpp::KVEncoding::kF32TwoTranspositions: {
|
||||
DecodeTileF32(transposed, qkv_dim, encoded_tile_data, out);
|
||||
return true;
|
||||
}
|
||||
case gcpp::KVEncoding::kBF16:
|
||||
case gcpp::KVEncoding::kBF16TwoTranspositions: {
|
||||
DecodeTileBF16(transposed, qkv_dim, encoded_tile_data, out);
|
||||
return true;
|
||||
}
|
||||
case gcpp::KVEncoding::kInt8:
|
||||
case gcpp::KVEncoding::kInt8TwoTranspositions: {
|
||||
DecodeTileInt8(transposed, qkv_dim, encoded_tile_data, out);
|
||||
return true;
|
||||
}
|
||||
default:
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
bool EncodeTile(gcpp::KVEncoding encoding, const DecodedTile& decoded,
|
||||
size_t qkv_dim, hwy::Span<char> out_encoded_tile_data) {
|
||||
std::optional<size_t> required_size_or = GetTileSizeBytes(encoding, qkv_dim);
|
||||
if (!required_size_or.has_value()) return false;
|
||||
size_t required_size = *required_size_or;
|
||||
if (out_encoded_tile_data.size() < required_size) {
|
||||
return false;
|
||||
}
|
||||
bool transposed = IsTransposed(encoding);
|
||||
switch (encoding) {
|
||||
case gcpp::KVEncoding::kF32:
|
||||
case gcpp::KVEncoding::kF32TwoTranspositions: {
|
||||
EncodeTileF32(transposed, qkv_dim, decoded, out_encoded_tile_data);
|
||||
return true;
|
||||
}
|
||||
case gcpp::KVEncoding::kBF16:
|
||||
case gcpp::KVEncoding::kBF16TwoTranspositions: {
|
||||
EncodeTileBF16(transposed, qkv_dim, decoded, out_encoded_tile_data);
|
||||
return true;
|
||||
}
|
||||
case gcpp::KVEncoding::kInt8:
|
||||
case gcpp::KVEncoding::kInt8TwoTranspositions: {
|
||||
EncodeTileInt8(transposed, qkv_dim, decoded, out_encoded_tile_data);
|
||||
return true;
|
||||
}
|
||||
default:
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
bool TranscodeTile(gcpp::KVEncoding src_encoding,
|
||||
hwy::Span<const char> src_data,
|
||||
gcpp::KVEncoding dst_encoding, hwy::Span<char> dst_data,
|
||||
size_t qkv_dim) {
|
||||
DecodedTile decoded(qkv_dim, kTileSize);
|
||||
if (!DecodeTile(src_encoding, src_data, qkv_dim, &decoded)) return false;
|
||||
|
||||
return EncodeTile(dst_encoding, decoded, qkv_dim, dst_data);
|
||||
}
|
||||
|
||||
} // namespace gcpp
|
||||
|
|
@ -0,0 +1,70 @@
|
|||
#ifndef THIRD_PARTY_GEMMA_CPP_GEMMA_KV_TRANSCODING_H_
|
||||
#define THIRD_PARTY_GEMMA_CPP_GEMMA_KV_TRANSCODING_H_
|
||||
|
||||
#include <cstddef>
|
||||
#include <optional>
|
||||
#include <vector>
|
||||
|
||||
#include "gemma/configs.h"
|
||||
#include "hwy/aligned_allocator.h"
|
||||
|
||||
namespace gcpp {
|
||||
|
||||
// Returns the size in bytes of a single KV cache tile for a given encoding.
|
||||
// Returns std::nullopt if the encoding is unsupported.
|
||||
std::optional<size_t> GetTileSizeBytes(gcpp::KVEncoding encoding,
|
||||
size_t qkv_dim);
|
||||
|
||||
// Canonical representation of a single tile of K and V data decoded to float32.
|
||||
// Layout: K is [tile_size, qkv_dim] contiguous, V is [tile_size, qkv_dim]
|
||||
// contiguous.
|
||||
struct DecodedTile {
|
||||
std::vector<float, hwy::AlignedAllocator<float>> k;
|
||||
std::vector<float, hwy::AlignedAllocator<float>> v;
|
||||
size_t qkv_dim = 0;
|
||||
size_t tile_size = 0;
|
||||
|
||||
DecodedTile() = default;
|
||||
DecodedTile(size_t qkv_dim, size_t tile_size)
|
||||
: k(qkv_dim * tile_size),
|
||||
v(tile_size * qkv_dim),
|
||||
qkv_dim(qkv_dim),
|
||||
tile_size(tile_size) {}
|
||||
|
||||
float& k_elem(size_t token, size_t dim) { return k[token * qkv_dim + dim]; }
|
||||
const float& k_elem(size_t token, size_t dim) const {
|
||||
return k[token * qkv_dim + dim];
|
||||
}
|
||||
|
||||
float& v_elem(size_t token, size_t dim) { return v[token * qkv_dim + dim]; }
|
||||
const float& v_elem(size_t token, size_t dim) const {
|
||||
return v[token * qkv_dim + dim];
|
||||
}
|
||||
};
|
||||
|
||||
// Allocates an aligned buffer for storing
|
||||
// an encoded tile of the given encoding.
|
||||
hwy::AlignedUniquePtr<char[]> AllocateEncodedTile(gcpp::KVEncoding encoding,
|
||||
size_t qkv_dim);
|
||||
|
||||
// Decodes a single tile's K and V data from its encoded byte buffer into
|
||||
// float32 using the specified encoding.
|
||||
bool DecodeTile(gcpp::KVEncoding encoding,
|
||||
hwy::Span<const char> encoded_tile_data, size_t qkv_dim,
|
||||
DecodedTile* out);
|
||||
|
||||
// Encodes a single tile's K and V data from standard float32 into the target
|
||||
// encoding. Returns false if the encoding is unsupported.
|
||||
bool EncodeTile(gcpp::KVEncoding encoding, const DecodedTile& decoded,
|
||||
size_t qkv_dim, hwy::Span<char> out_encoded_tile_data);
|
||||
|
||||
// Convenience utility to convert a tile directly from one encoding to another.
|
||||
// Return false if either encoding is unsupported or passed data is too small.
|
||||
bool TranscodeTile(gcpp::KVEncoding src_encoding,
|
||||
hwy::Span<const char> src_data,
|
||||
gcpp::KVEncoding dst_encoding, hwy::Span<char> dst_data,
|
||||
size_t qkv_dim);
|
||||
|
||||
} // namespace gcpp
|
||||
|
||||
#endif // THIRD_PARTY_GEMMA_CPP_GEMMA_KV_TRANSCODING_H_
|
||||
|
|
@ -0,0 +1,360 @@
|
|||
#include "gemma/kv_transcoding.h"
|
||||
|
||||
#include <cmath>
|
||||
#include <cstddef>
|
||||
#include <cstdint>
|
||||
#include <optional>
|
||||
#include <vector>
|
||||
|
||||
#include "gmock/gmock.h"
|
||||
#include "gtest/gtest.h"
|
||||
#include "gemma/configs.h"
|
||||
#include "hwy/aligned_allocator.h"
|
||||
#include "hwy/base.h" // For hwy::Span
|
||||
|
||||
namespace gcpp {
|
||||
namespace {
|
||||
|
||||
using ::testing::FloatNear;
|
||||
using ::testing::Pointwise;
|
||||
using ::testing::TestWithParam;
|
||||
using ::testing::Values;
|
||||
|
||||
struct EncodingTestCase {
|
||||
gcpp::KVEncoding encoding;
|
||||
float tolerance;
|
||||
};
|
||||
|
||||
class KVEncodingTest : public TestWithParam<EncodingTestCase> {};
|
||||
|
||||
TEST_P(KVEncodingTest, EncodeDecodeRoundTrip) {
|
||||
const auto& param = GetParam();
|
||||
constexpr size_t kTileSize = 32;
|
||||
constexpr size_t qkv_dim = 256;
|
||||
|
||||
DecodedTile original(qkv_dim, kTileSize);
|
||||
// Fill with dummy data within
|
||||
// a reasonable float range to avoid saturation for INT8
|
||||
const float pattern[] = {0.5f, 1.0f, 1.5f};
|
||||
for (size_t token = 0; token < kTileSize; ++token) {
|
||||
for (size_t dim = 0; dim < qkv_dim; ++dim) {
|
||||
size_t i = dim * kTileSize + token;
|
||||
original.k_elem(token, dim) = pattern[i % 3];
|
||||
original.v_elem(token, dim) = pattern[i % 3];
|
||||
}
|
||||
}
|
||||
|
||||
std::optional<size_t> tile_size_bytes =
|
||||
GetTileSizeBytes(param.encoding, qkv_dim);
|
||||
ASSERT_TRUE(tile_size_bytes.has_value());
|
||||
|
||||
std::vector<char> encoded(*tile_size_bytes, 0);
|
||||
EXPECT_TRUE(EncodeTile(param.encoding, original, qkv_dim,
|
||||
hwy::Span<char>(encoded.data(), encoded.size())));
|
||||
|
||||
DecodedTile decoded(qkv_dim, kTileSize);
|
||||
EXPECT_TRUE(DecodeTile(param.encoding,
|
||||
hwy::Span<const char>(encoded.data(), encoded.size()),
|
||||
qkv_dim, &decoded));
|
||||
|
||||
EXPECT_THAT(decoded.k, Pointwise(FloatNear(param.tolerance), original.k));
|
||||
EXPECT_THAT(decoded.v, Pointwise(FloatNear(param.tolerance), original.v));
|
||||
}
|
||||
|
||||
TEST_P(KVEncodingTest, SizeChecks) {
|
||||
const auto& param = GetParam();
|
||||
constexpr size_t kTileSize = 32;
|
||||
constexpr size_t qkv_dim = 256;
|
||||
|
||||
DecodedTile decoded(qkv_dim, kTileSize);
|
||||
std::optional<size_t> required_size_or =
|
||||
GetTileSizeBytes(param.encoding, qkv_dim);
|
||||
ASSERT_TRUE(required_size_or.has_value());
|
||||
size_t required_size = *required_size_or;
|
||||
|
||||
if (required_size > 0) {
|
||||
std::vector<char> too_small_encoded(required_size - 1, 0);
|
||||
EXPECT_FALSE(EncodeTile(
|
||||
param.encoding, decoded, qkv_dim,
|
||||
hwy::Span<char>(too_small_encoded.data(), too_small_encoded.size())));
|
||||
EXPECT_FALSE(DecodeTile(param.encoding,
|
||||
hwy::Span<const char>(too_small_encoded.data(),
|
||||
too_small_encoded.size()),
|
||||
qkv_dim, &decoded));
|
||||
}
|
||||
}
|
||||
|
||||
INSTANTIATE_TEST_SUITE_P(
|
||||
AllEncodings, KVEncodingTest,
|
||||
Values(EncodingTestCase{gcpp::KVEncoding::kF32, 1e-6f},
|
||||
EncodingTestCase{gcpp::KVEncoding::kF32TwoTranspositions, 1e-6f},
|
||||
EncodingTestCase{gcpp::KVEncoding::kBF16, 0.05f},
|
||||
EncodingTestCase{gcpp::KVEncoding::kBF16TwoTranspositions, 0.05f},
|
||||
EncodingTestCase{gcpp::KVEncoding::kInt8, 0.1f},
|
||||
EncodingTestCase{gcpp::KVEncoding::kInt8TwoTranspositions, 0.1f}));
|
||||
|
||||
TEST(KVEncodingTest, ConvertTileFloat32ToBfloat16) {
|
||||
constexpr size_t kTileSize = 32;
|
||||
constexpr size_t qkv_dim = 256;
|
||||
gcpp::KVEncoding src_encoding = gcpp::KVEncoding::kF32;
|
||||
gcpp::KVEncoding dst_encoding = gcpp::KVEncoding::kBF16;
|
||||
|
||||
DecodedTile original(qkv_dim, kTileSize);
|
||||
for (size_t token = 0; token < kTileSize; ++token) {
|
||||
for (size_t dim = 0; dim < qkv_dim; ++dim) {
|
||||
size_t i = dim * kTileSize + token;
|
||||
original.k_elem(token, dim) = std::sin(i) * 5.0f;
|
||||
original.v_elem(token, dim) = std::cos(i) * 5.0f;
|
||||
}
|
||||
}
|
||||
|
||||
size_t src_size = GetTileSizeBytes(src_encoding, qkv_dim).value();
|
||||
size_t dst_size = GetTileSizeBytes(dst_encoding, qkv_dim).value();
|
||||
|
||||
std::vector<char> src_data(src_size);
|
||||
std::vector<char> dst_data(dst_size);
|
||||
|
||||
EXPECT_TRUE(EncodeTile(src_encoding, original, qkv_dim,
|
||||
hwy::Span<char>(src_data.data(), src_data.size())));
|
||||
|
||||
EXPECT_TRUE(TranscodeTile(
|
||||
src_encoding, hwy::Span<const char>(src_data.data(), src_data.size()),
|
||||
dst_encoding, hwy::Span<char>(dst_data.data(), dst_data.size()),
|
||||
qkv_dim));
|
||||
|
||||
DecodedTile decoded(qkv_dim, kTileSize);
|
||||
EXPECT_TRUE(DecodeTile(
|
||||
dst_encoding, hwy::Span<const char>(dst_data.data(), dst_data.size()),
|
||||
qkv_dim, &decoded));
|
||||
|
||||
EXPECT_THAT(decoded.k, Pointwise(FloatNear(0.05f), original.k));
|
||||
}
|
||||
|
||||
TEST(KVEncodingTest, PairwiseConversion) {
|
||||
constexpr size_t kTileSize = 32;
|
||||
constexpr size_t qkv_dim = 256;
|
||||
|
||||
std::vector<gcpp::KVEncoding> encodings = {
|
||||
gcpp::KVEncoding::kF32, gcpp::KVEncoding::kF32TwoTranspositions,
|
||||
gcpp::KVEncoding::kBF16, gcpp::KVEncoding::kBF16TwoTranspositions,
|
||||
gcpp::KVEncoding::kInt8, gcpp::KVEncoding::kInt8TwoTranspositions};
|
||||
|
||||
for (auto src : encodings) {
|
||||
for (auto dst : encodings) {
|
||||
if (src == dst) continue;
|
||||
|
||||
DecodedTile original(qkv_dim, kTileSize);
|
||||
const float pattern[] = {0.5f, 1.0f, 1.5f};
|
||||
for (size_t token = 0; token < kTileSize; ++token) {
|
||||
for (size_t dim = 0; dim < qkv_dim; ++dim) {
|
||||
size_t i = dim * kTileSize + token;
|
||||
original.k_elem(token, dim) = pattern[i % 3];
|
||||
original.v_elem(token, dim) = pattern[i % 3];
|
||||
}
|
||||
}
|
||||
|
||||
size_t src_size = GetTileSizeBytes(src, qkv_dim).value();
|
||||
size_t dst_size = GetTileSizeBytes(dst, qkv_dim).value();
|
||||
|
||||
std::vector<char> src_data(src_size);
|
||||
std::vector<char> dst_data(dst_size);
|
||||
|
||||
ASSERT_TRUE(EncodeTile(src, original, qkv_dim,
|
||||
hwy::Span<char>(src_data.data(), src_data.size())))
|
||||
<< "src=" << static_cast<int>(src);
|
||||
|
||||
ASSERT_TRUE(TranscodeTile(
|
||||
src, hwy::Span<const char>(src_data.data(), src_data.size()), dst,
|
||||
hwy::Span<char>(dst_data.data(), dst_data.size()), qkv_dim))
|
||||
<< "src=" << static_cast<int>(src)
|
||||
<< " dst=" << static_cast<int>(dst);
|
||||
|
||||
DecodedTile decoded(qkv_dim, kTileSize);
|
||||
ASSERT_TRUE(DecodeTile(
|
||||
dst, hwy::Span<const char>(dst_data.data(), dst_data.size()), qkv_dim,
|
||||
&decoded))
|
||||
<< "dst=" << static_cast<int>(dst);
|
||||
|
||||
float tolerance = 0.1f; // Max tolerance for Int8
|
||||
EXPECT_THAT(decoded.k, Pointwise(FloatNear(tolerance), original.k))
|
||||
<< "src=" << static_cast<int>(src)
|
||||
<< " dst=" << static_cast<int>(dst);
|
||||
EXPECT_THAT(decoded.v, Pointwise(FloatNear(tolerance), original.v))
|
||||
<< "src=" << static_cast<int>(src)
|
||||
<< " dst=" << static_cast<int>(dst);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
TEST(KVEncodingTest, LayoutValidationF32) {
|
||||
constexpr size_t kTileSize = 32;
|
||||
constexpr size_t qkv_dim = 4;
|
||||
gcpp::KVEncoding encoding = gcpp::KVEncoding::kF32;
|
||||
|
||||
DecodedTile original(qkv_dim, kTileSize);
|
||||
for (size_t token = 0; token < kTileSize; ++token) {
|
||||
for (size_t dim = 0; dim < qkv_dim; ++dim) {
|
||||
original.k_elem(token, dim) = dim * kTileSize + token + 1;
|
||||
}
|
||||
}
|
||||
for (size_t token = 0; token < kTileSize; ++token) {
|
||||
for (size_t dim = 0; dim < qkv_dim; ++dim) {
|
||||
original.v_elem(token, dim) =
|
||||
token * qkv_dim + dim + 1 + qkv_dim * kTileSize;
|
||||
}
|
||||
}
|
||||
|
||||
size_t size = GetTileSizeBytes(encoding, qkv_dim).value();
|
||||
std::vector<char> encoded(size);
|
||||
|
||||
ASSERT_TRUE(EncodeTile(encoding, original, qkv_dim,
|
||||
hwy::Span<char>(encoded.data(), encoded.size())));
|
||||
|
||||
const float* data = reinterpret_cast<const float*>(encoded.data());
|
||||
|
||||
// K should be row-major [qkv_dim, tile_size]
|
||||
EXPECT_EQ(data[0], 1.0f); // d=0, t=0
|
||||
EXPECT_EQ(data[1], 2.0f); // d=0, t=1
|
||||
EXPECT_EQ(data[32], 33.0f); // d=1, t=0
|
||||
|
||||
// V should be row-major [tile_size, qkv_dim]
|
||||
size_t v_start = qkv_dim * kTileSize;
|
||||
EXPECT_EQ(data[v_start], 129.0f); // t=0, d=0
|
||||
EXPECT_EQ(data[v_start + 1], 130.0f); // t=0, d=1
|
||||
EXPECT_EQ(data[v_start + 4], 133.0f); // t=1, d=0
|
||||
}
|
||||
|
||||
TEST(KVEncodingTest, LayoutValidationF32TwoTranspositions) {
|
||||
constexpr size_t kTileSize = 32;
|
||||
constexpr size_t qkv_dim = 4;
|
||||
gcpp::KVEncoding encoding = gcpp::KVEncoding::kF32TwoTranspositions;
|
||||
|
||||
DecodedTile original(qkv_dim, kTileSize);
|
||||
for (size_t token = 0; token < kTileSize; ++token) {
|
||||
for (size_t dim = 0; dim < qkv_dim; ++dim) {
|
||||
original.k_elem(token, dim) = dim * kTileSize + token + 1;
|
||||
}
|
||||
}
|
||||
for (size_t token = 0; token < kTileSize; ++token) {
|
||||
for (size_t dim = 0; dim < qkv_dim; ++dim) {
|
||||
original.v_elem(token, dim) =
|
||||
token * qkv_dim + dim + 1 + qkv_dim * kTileSize;
|
||||
}
|
||||
}
|
||||
|
||||
size_t size = GetTileSizeBytes(encoding, qkv_dim).value();
|
||||
std::vector<char> encoded(size);
|
||||
|
||||
ASSERT_TRUE(EncodeTile(encoding, original, qkv_dim,
|
||||
hwy::Span<char>(encoded.data(), encoded.size())));
|
||||
|
||||
const float* data = reinterpret_cast<const float*>(encoded.data());
|
||||
|
||||
// K transposed: [qkv_dim/2, tile_size, 2]
|
||||
EXPECT_EQ(data[0], 1.0f); // d=0, t=0
|
||||
EXPECT_EQ(data[1], 33.0f); // d=1, t=0
|
||||
EXPECT_EQ(data[2], 2.0f); // d=0, t=1
|
||||
EXPECT_EQ(data[3], 34.0f); // d=1, t=1
|
||||
EXPECT_EQ(data[64], 65.0f); // d=2, t=0
|
||||
EXPECT_EQ(data[65], 97.0f); // d=3, t=0
|
||||
|
||||
// V transposed: [tile_size/2, qkv_dim, 2]
|
||||
size_t v_start = qkv_dim * kTileSize;
|
||||
EXPECT_EQ(data[v_start], 129.0f); // t=0, d=0
|
||||
EXPECT_EQ(data[v_start + 1], 133.0f); // t=1, d=0
|
||||
EXPECT_EQ(data[v_start + 2], 130.0f); // t=0, d=1
|
||||
EXPECT_EQ(data[v_start + 3], 134.0f); // t=1, d=1
|
||||
}
|
||||
|
||||
TEST(KVEncodingTest, LayoutValidationInt8) {
|
||||
constexpr size_t kTileSize = 32;
|
||||
constexpr size_t qkv_dim = 4;
|
||||
gcpp::KVEncoding encoding = gcpp::KVEncoding::kInt8;
|
||||
|
||||
DecodedTile original(qkv_dim, kTileSize);
|
||||
for (size_t token = 0; token < kTileSize; ++token) {
|
||||
for (size_t dim = 0; dim < qkv_dim; ++dim) {
|
||||
original.k_elem(token, dim) = dim * kTileSize + token + 1;
|
||||
}
|
||||
}
|
||||
for (size_t token = 0; token < kTileSize; ++token) {
|
||||
for (size_t dim = 0; dim < qkv_dim; ++dim) {
|
||||
original.v_elem(token, dim) =
|
||||
token * qkv_dim + dim + 1 + qkv_dim * kTileSize;
|
||||
}
|
||||
}
|
||||
|
||||
size_t size = GetTileSizeBytes(encoding, qkv_dim).value();
|
||||
std::vector<char> encoded(size);
|
||||
|
||||
ASSERT_TRUE(EncodeTile(encoding, original, qkv_dim,
|
||||
hwy::Span<char>(encoded.data(), encoded.size())));
|
||||
|
||||
const int8_t* data = reinterpret_cast<const int8_t*>(encoded.data());
|
||||
|
||||
// K should be row-major [qkv_dim, tile_size]
|
||||
// K[3,0] = 97. Max for t=0 is 97. Scale = 97/127.
|
||||
// Quantized K[3,0] = 127.
|
||||
// K[3,0] is at offset 3 * 32 + 0 = 96.
|
||||
EXPECT_EQ(data[96], 127);
|
||||
|
||||
// V should be row-major [tile_size, qkv_dim]
|
||||
size_t v_start = qkv_dim * kTileSize;
|
||||
// V[0,3] = 132. Max for t=0 is 132. Scale = 132/127.
|
||||
// Quantized V[0,3] = 127.
|
||||
// V[0,3] is at offset v_start + 0 * 4 + 3 = v_start + 3.
|
||||
EXPECT_EQ(data[v_start + 3], 127);
|
||||
}
|
||||
|
||||
TEST(KVEncodingTest, LayoutValidationInt8TwoTranspositions) {
|
||||
constexpr size_t kTileSize = 32;
|
||||
constexpr size_t qkv_dim = 4;
|
||||
gcpp::KVEncoding encoding = gcpp::KVEncoding::kInt8TwoTranspositions;
|
||||
|
||||
DecodedTile original(qkv_dim, kTileSize);
|
||||
for (size_t token = 0; token < kTileSize; ++token) {
|
||||
for (size_t dim = 0; dim < qkv_dim; ++dim) {
|
||||
original.k_elem(token, dim) = dim * kTileSize + token + 1;
|
||||
}
|
||||
}
|
||||
for (size_t token = 0; token < kTileSize; ++token) {
|
||||
for (size_t dim = 0; dim < qkv_dim; ++dim) {
|
||||
original.v_elem(token, dim) =
|
||||
token * qkv_dim + dim + 1 + qkv_dim * kTileSize;
|
||||
}
|
||||
}
|
||||
|
||||
size_t size = GetTileSizeBytes(encoding, qkv_dim).value();
|
||||
std::vector<char> encoded(size);
|
||||
|
||||
ASSERT_TRUE(EncodeTile(encoding, original, qkv_dim,
|
||||
hwy::Span<char>(encoded.data(), encoded.size())));
|
||||
|
||||
const int8_t* data = reinterpret_cast<const int8_t*>(encoded.data());
|
||||
|
||||
// K transposed: [qkv_dim/2, tile_size, 2]
|
||||
// K[0,0] = 1. Max for t=0 is 97. Scale = 97/127.
|
||||
// Quantized K[0,0] = 1.
|
||||
// K[1,0] = 33. Quantized K[1,0] = 33 / (97/127) = 43.14 -> 43.
|
||||
// K[1,0] is at offset 1.
|
||||
EXPECT_EQ(data[0], 1);
|
||||
EXPECT_EQ(data[1], 43);
|
||||
|
||||
// V transposed: [tile_size/2, qkv_dim, 2]
|
||||
size_t v_start = qkv_dim * kTileSize;
|
||||
// V[0,0] = 129. Max for t=0 is 132. Scale = 132/127.
|
||||
// Quantized V[0,0] = round(129 * 127 / 132) = 124.
|
||||
// V[1,0] = 133. Max for t=1 is 136. Scale = 136/127.
|
||||
// Quantized V[1,0] = round(133 * 127 / 136) = 124.
|
||||
// In transposed layout, V[0,0] is at v_start. V[1,0] is at v_start + 1.
|
||||
EXPECT_EQ(data[v_start], 124);
|
||||
EXPECT_EQ(data[v_start + 1], 124);
|
||||
|
||||
// V[1,3] = 136. Max for t=1 is 136. Quantized = 127.
|
||||
// Offset in transposed V: t/2*8 + d*2 + t%2.
|
||||
// For t=1, d=3: 0*8 + 3*2 + 1 = 7.
|
||||
EXPECT_EQ(data[v_start + 7], 127);
|
||||
}
|
||||
|
||||
} // namespace
|
||||
} // namespace gcpp
|
||||
Loading…
Reference in New Issue