// Encoding, decoding and transcoding of single KV-cache tiles between the
// F32 / BF16 / Int8 encodings and their "TwoTranspositions" layout variants.
//
// NOTE(review): this file was reconstructed from a copy in which every
// `<...>` span (system includes, template argument lists) had been stripped.
// Arguments that follow from the surrounding code were restored directly; the
// raw byte element type (see `TileByte` below) could not be inferred and must
// be confirmed against gemma/kv_transcoding.h.

#include "gemma/kv_transcoding.h"

#include <algorithm>  // std::max
#include <cmath>      // std::abs, std::isnan
#include <cstddef>    // size_t
#include <cstdint>    // int8_t, uint8_t
#include <optional>

#include "compression/types.h"
#include "gemma/activations.h"
#include "gemma/configs.h"
#include "gemma/kv_cache.h"
#include "util/basics.h"
#include "hwy/base.h"
#include "hwy/highway.h"

namespace gcpp {

// Returns the number of bytes occupied by one encoded tile, i.e. the K and V
// payloads for `KVCache::kTileSize` tokens of `qkv_dim` dimensions each, plus
// (Int8 only) one K and one V microscale per token. Returns std::nullopt for
// encodings this file does not handle.
std::optional<size_t> GetTileSizeBytes(gcpp::KVEncoding encoding,
                                       size_t qkv_dim) {
  constexpr size_t kTileSize = gcpp::KVCache::kTileSize;
  switch (encoding) {
    case gcpp::KVEncoding::kInt8:
    case gcpp::KVEncoding::kInt8TwoTranspositions:
      // K + V payloads, followed by kTileSize K scales and kTileSize V scales.
      return qkv_dim * kTileSize * 2 * sizeof(int8_t) +
             kTileSize * 2 * sizeof(gcpp::KV_microscale_t);
    case gcpp::KVEncoding::kBF16:
    case gcpp::KVEncoding::kBF16TwoTranspositions:
      return qkv_dim * kTileSize * 2 * sizeof(gcpp::BF16);
    case gcpp::KVEncoding::kF32:
    case gcpp::KVEncoding::kF32TwoTranspositions:
      return qkv_dim * kTileSize * 2 * sizeof(float);
    default:
      return std::nullopt;
  }
}

namespace {

constexpr size_t kTileSize = gcpp::KVCache::kTileSize;

// Element type of the raw encoded-tile buffers.
// NOTE(review): reconstructed; the original template argument was lost. This
// MUST match the element type used by the declarations in
// gemma/kv_transcoding.h (it might be `char` or `std::byte` instead).
using TileByte = uint8_t;

// Element index of K[dim, token] inside the K block of a tile.
// Non-transposed: dim-major (`dim * kTileSize + token`).
// Transposed: two adjacent dims are interleaved per token, so each dim pair
// occupies `2 * kTileSize` consecutive elements. This requires an even
// `qkv_dim`, otherwise the last pair extends past the K block.
inline size_t KOffset(bool transposed, size_t qkv_dim, size_t dim,
                      size_t token) {
  HWY_DASSERT(dim < qkv_dim && token < kTileSize);
  return transposed ? ((dim / 2) * kTileSize * 2 + token * 2 + (dim % 2))
                    : (dim * kTileSize + token);
}

// Element index of V[dim, token] inside the V block of a tile.
// Non-transposed: token-major (`token * qkv_dim + dim`).
// Transposed: two adjacent tokens are interleaved per dim, so each token pair
// occupies `2 * qkv_dim` consecutive elements (requires an even kTileSize).
inline size_t VOffset(bool transposed, size_t qkv_dim, size_t dim,
                      size_t token) {
  HWY_DASSERT(dim < qkv_dim && token < kTileSize);
  return transposed ? ((token / 2) * qkv_dim * 2 + dim * 2 + (token % 2))
                    : (token * qkv_dim + dim);
}

// Scales `v` by `inv_scale`, clamps to [-127, 127] and converts to int8.
// NaN (from NaN inputs, or 0 * inf when `inv_scale` overflowed) maps to 0:
// it passes both clamp comparisons and converting NaN to an integer type is
// undefined behaviour.
int8_t Quantize(float v, float inv_scale) {
  const float scaled = v * inv_scale;
  if (std::isnan(scaled)) return 0;
  if (scaled > 127.0f) return 127;
  if (scaled < -127.0f) return -127;
  return hwy::ConvertScalarTo<int8_t>(scaled);
}

// Fills `out` by calling `decode_k(dim, token)` / `decode_v(dim, token)` for
// every (token, dim) pair of the tile; all of K is decoded before V.
template <typename DecodeKFn, typename DecodeVFn>
inline void DecodeTileWithFn(size_t qkv_dim, DecodedTile* out,
                             const DecodeKFn& decode_k,
                             const DecodeVFn& decode_v) {
  for (size_t token = 0; token < kTileSize; ++token) {
    for (size_t dim = 0; dim < qkv_dim; ++dim) {
      out->k_elem(token, dim) = decode_k(dim, token);
    }
  }
  for (size_t token = 0; token < kTileSize; ++token) {
    for (size_t dim = 0; dim < qkv_dim; ++dim) {
      out->v_elem(token, dim) = decode_v(dim, token);
    }
  }
}

// Calls `encode_k(dim, token, value)` / `encode_v(dim, token, value)` for
// every (token, dim) pair of `decoded`; all of K is encoded before V.
template <typename EncodeKFn, typename EncodeVFn>
inline void EncodeTileWithFn(size_t qkv_dim, const DecodedTile& decoded,
                             const EncodeKFn& encode_k,
                             const EncodeVFn& encode_v) {
  for (size_t token = 0; token < kTileSize; ++token) {
    for (size_t dim = 0; dim < qkv_dim; ++dim) {
      encode_k(dim, token, decoded.k_elem(token, dim));
    }
  }
  for (size_t token = 0; token < kTileSize; ++token) {
    for (size_t dim = 0; dim < qkv_dim; ++dim) {
      encode_v(dim, token, decoded.v_elem(token, dim));
    }
  }
}

// Writes `decoded` as float: K block first, V block starting at element
// `qkv_dim * kTileSize`. The caller guarantees the buffer is large enough.
void EncodeTileF32(bool transposed, size_t qkv_dim, const DecodedTile& decoded,
                   hwy::Span<TileByte> out_encoded_tile_data) {
  float* data = HWY_RCAST_ALIGNED(float*, out_encoded_tile_data.data());
  const size_t v_start = qkv_dim * kTileSize;
  EncodeTileWithFn(
      qkv_dim, decoded,
      [&](size_t dim, size_t token, float val) HWY_ATTR {
        data[KOffset(transposed, qkv_dim, dim, token)] = val;
      },
      [&](size_t dim, size_t token, float val) HWY_ATTR {
        data[v_start + VOffset(transposed, qkv_dim, dim, token)] = val;
      });
}

// Same layout as EncodeTileF32, with each value converted to BF16.
void EncodeTileBF16(bool transposed, size_t qkv_dim,
                    const DecodedTile& decoded,
                    hwy::Span<TileByte> out_encoded_tile_data) {
  gcpp::BF16* data =
      HWY_RCAST_ALIGNED(gcpp::BF16*, out_encoded_tile_data.data());
  const size_t v_start = qkv_dim * kTileSize;
  EncodeTileWithFn(
      qkv_dim, decoded,
      [&](size_t dim, size_t token, float val) HWY_ATTR {
        data[KOffset(transposed, qkv_dim, dim, token)] =
            hwy::ConvertScalarTo<gcpp::BF16>(val);
      },
      [&](size_t dim, size_t token, float val) HWY_ATTR {
        data[v_start + VOffset(transposed, qkv_dim, dim, token)] =
            hwy::ConvertScalarTo<gcpp::BF16>(val);
      });
}

// Writes `decoded` as int8 with one scale per token for K and one for V.
// Buffer layout: K int8 block, V int8 block, kTileSize K scales, kTileSize V
// scales. Each scale is max|x| / 127 over the token's dims (1.0 if all zero).
void EncodeTileInt8(bool transposed, size_t qkv_dim,
                    const DecodedTile& decoded,
                    hwy::Span<TileByte> out_encoded_tile_data) {
  int8_t* k_data = HWY_RCAST_ALIGNED(int8_t*, out_encoded_tile_data.data());
  int8_t* v_data = k_data + qkv_dim * kTileSize;
  gcpp::KV_microscale_t* scales = HWY_RCAST_ALIGNED(
      gcpp::KV_microscale_t*, v_data + kTileSize * qkv_dim);
  gcpp::KV_microscale_t* k_scales = scales;
  gcpp::KV_microscale_t* v_scales = scales + kTileSize;

  // Per-token maximum magnitude across all dims.
  AlignedFloatVector k_max_abs(kTileSize, 0.0f);
  AlignedFloatVector v_max_abs(kTileSize, 0.0f);
  for (size_t token = 0; token < kTileSize; ++token) {
    for (size_t dim = 0; dim < qkv_dim; ++dim) {
      k_max_abs[token] =
          std::max(k_max_abs[token], std::abs(decoded.k_elem(token, dim)));
      v_max_abs[token] =
          std::max(v_max_abs[token], std::abs(decoded.v_elem(token, dim)));
    }
  }

  // NOTE(review): the inverse scale is derived from the float scale, whereas
  // the stored scale is converted to KV_microscale_t; if that type is
  // narrower than float, decoding uses a slightly different scale.
  AlignedFloatVector inv_scales_k(kTileSize);
  AlignedFloatVector inv_scales_v(kTileSize);
  for (size_t token = 0; token < kTileSize; ++token) {
    const float scale_k =
        k_max_abs[token] == 0.0f ? 1.0f : k_max_abs[token] / 127.0f;
    k_scales[token] = hwy::ConvertScalarTo<gcpp::KV_microscale_t>(scale_k);
    inv_scales_k[token] = 1.0f / scale_k;

    const float scale_v =
        v_max_abs[token] == 0.0f ? 1.0f : v_max_abs[token] / 127.0f;
    v_scales[token] = hwy::ConvertScalarTo<gcpp::KV_microscale_t>(scale_v);
    inv_scales_v[token] = 1.0f / scale_v;
  }

  EncodeTileWithFn(
      qkv_dim, decoded,
      [&](size_t dim, size_t token, float val) HWY_ATTR {
        k_data[KOffset(transposed, qkv_dim, dim, token)] =
            Quantize(val, inv_scales_k[token]);
      },
      [&](size_t dim, size_t token, float val) HWY_ATTR {
        v_data[VOffset(transposed, qkv_dim, dim, token)] =
            Quantize(val, inv_scales_v[token]);
      });
}

// Inverse of EncodeTileF32.
void DecodeTileF32(bool transposed, size_t qkv_dim,
                   hwy::Span<const TileByte> encoded_tile_data,
                   DecodedTile* out) {
  const float* data =
      HWY_RCAST_ALIGNED(const float*, encoded_tile_data.data());
  const size_t v_start = qkv_dim * kTileSize;
  DecodeTileWithFn(
      qkv_dim, out,
      [&](size_t dim, size_t token) HWY_ATTR {
        return data[KOffset(transposed, qkv_dim, dim, token)];
      },
      [&](size_t dim, size_t token) HWY_ATTR {
        return data[v_start + VOffset(transposed, qkv_dim, dim, token)];
      });
}

// Inverse of EncodeTileBF16: each BF16 value is widened to float.
void DecodeTileBF16(bool transposed, size_t qkv_dim,
                    hwy::Span<const TileByte> encoded_tile_data,
                    DecodedTile* out) {
  const gcpp::BF16* data =
      HWY_RCAST_ALIGNED(const gcpp::BF16*, encoded_tile_data.data());
  const size_t v_start = qkv_dim * kTileSize;
  DecodeTileWithFn(
      qkv_dim, out,
      [&](size_t dim, size_t token) HWY_ATTR {
        return hwy::ConvertScalarTo<float>(
            data[KOffset(transposed, qkv_dim, dim, token)]);
      },
      [&](size_t dim, size_t token) HWY_ATTR {
        return hwy::ConvertScalarTo<float>(
            data[v_start + VOffset(transposed, qkv_dim, dim, token)]);
      });
}

// Inverse of EncodeTileInt8: each int8 value is multiplied by its token's
// stored K or V scale.
void DecodeTileInt8(bool transposed, size_t qkv_dim,
                    hwy::Span<const TileByte> encoded_tile_data,
                    DecodedTile* out) {
  const int8_t* k_data =
      HWY_RCAST_ALIGNED(const int8_t*, encoded_tile_data.data());
  const int8_t* v_data = k_data + qkv_dim * kTileSize;
  const gcpp::KV_microscale_t* scales = HWY_RCAST_ALIGNED(
      const gcpp::KV_microscale_t*, v_data + kTileSize * qkv_dim);
  const gcpp::KV_microscale_t* k_scales = scales;
  const gcpp::KV_microscale_t* v_scales = scales + kTileSize;
  DecodeTileWithFn(
      qkv_dim, out,
      [&](size_t dim, size_t token) HWY_ATTR {
        const float scale = hwy::ConvertScalarTo<float>(k_scales[token]);
        return k_data[KOffset(transposed, qkv_dim, dim, token)] * scale;
      },
      [&](size_t dim, size_t token) HWY_ATTR {
        const float scale = hwy::ConvertScalarTo<float>(v_scales[token]);
        return v_data[VOffset(transposed, qkv_dim, dim, token)] * scale;
      });
}

}  // namespace

// Returns true for the "TwoTranspositions" variants, which use the
// pair-interleaved element order implemented by KOffset / VOffset.
bool IsTransposed(KVEncoding encoding) {
  switch (encoding) {
    case KVEncoding::kF32TwoTranspositions:
    case KVEncoding::kBF16TwoTranspositions:
    case KVEncoding::kInt8TwoTranspositions:
      return true;
    default:
      return false;
  }
}

namespace {

// Shared precondition check for DecodeTile / EncodeTile. Returns true iff
// `encoding` is supported, `available_bytes` covers a full tile, and the
// layout can be represented without K/V offsets leaving their own block.
bool IsTileRequestValid(KVEncoding encoding, size_t qkv_dim,
                        size_t available_bytes) {
  const std::optional<size_t> required_size =
      GetTileSizeBytes(encoding, qkv_dim);
  if (!required_size.has_value()) return false;
  if (available_bytes < *required_size) return false;
  // Transposed layouts interleave pairs of dims (K) and pairs of tokens (V);
  // an odd count would make the last pair spill into the following block.
  if (IsTransposed(encoding) &&
      ((qkv_dim % 2) != 0 || (kTileSize % 2) != 0)) {
    return false;
  }
  return true;
}

}  // namespace

// Allocates an aligned buffer sized for one tile of `encoding`. Returns an
// empty pointer if the encoding is not supported.
hwy::AlignedUniquePtr<TileByte[]> AllocateEncodedTile(KVEncoding encoding,
                                                      size_t qkv_dim) {
  const std::optional<size_t> size = GetTileSizeBytes(encoding, qkv_dim);
  if (!size.has_value()) return hwy::AlignedUniquePtr<TileByte[]>();
  return hwy::MakeUniqueAlignedArray<TileByte>(*size);
}

// Decodes one tile from `encoded_tile_data` into `*out` as float.
// Returns false (leaving `*out` untouched) if `out` is null, the encoding is
// unsupported, the buffer is smaller than GetTileSizeBytes(), or a transposed
// encoding is requested with an odd `qkv_dim`.
// NOTE(review): `*out` is assumed to have been constructed for at least
// `qkv_dim` dims and kTileSize tokens; this is not checked here.
bool DecodeTile(KVEncoding encoding,
                hwy::Span<const TileByte> encoded_tile_data, size_t qkv_dim,
                DecodedTile* out) {
  if (out == nullptr) return false;
  if (!IsTileRequestValid(encoding, qkv_dim, encoded_tile_data.size())) {
    return false;
  }
  const bool transposed = IsTransposed(encoding);
  switch (encoding) {
    case gcpp::KVEncoding::kF32:
    case gcpp::KVEncoding::kF32TwoTranspositions:
      DecodeTileF32(transposed, qkv_dim, encoded_tile_data, out);
      return true;
    case gcpp::KVEncoding::kBF16:
    case gcpp::KVEncoding::kBF16TwoTranspositions:
      DecodeTileBF16(transposed, qkv_dim, encoded_tile_data, out);
      return true;
    case gcpp::KVEncoding::kInt8:
    case gcpp::KVEncoding::kInt8TwoTranspositions:
      DecodeTileInt8(transposed, qkv_dim, encoded_tile_data, out);
      return true;
    default:
      return false;
  }
}

// Encodes `decoded` into `out_encoded_tile_data` using `encoding`.
// Returns false (writing nothing) if the encoding is unsupported, the buffer
// is smaller than GetTileSizeBytes(), or a transposed encoding is requested
// with an odd `qkv_dim`.
bool EncodeTile(gcpp::KVEncoding encoding, const DecodedTile& decoded,
                size_t qkv_dim, hwy::Span<TileByte> out_encoded_tile_data) {
  if (!IsTileRequestValid(encoding, qkv_dim, out_encoded_tile_data.size())) {
    return false;
  }
  const bool transposed = IsTransposed(encoding);
  switch (encoding) {
    case gcpp::KVEncoding::kF32:
    case gcpp::KVEncoding::kF32TwoTranspositions:
      EncodeTileF32(transposed, qkv_dim, decoded, out_encoded_tile_data);
      return true;
    case gcpp::KVEncoding::kBF16:
    case gcpp::KVEncoding::kBF16TwoTranspositions:
      EncodeTileBF16(transposed, qkv_dim, decoded, out_encoded_tile_data);
      return true;
    case gcpp::KVEncoding::kInt8:
    case gcpp::KVEncoding::kInt8TwoTranspositions:
      EncodeTileInt8(transposed, qkv_dim, decoded, out_encoded_tile_data);
      return true;
    default:
      return false;
  }
}

// Converts one tile from `src_encoding` to `dst_encoding` via an intermediate
// float DecodedTile. Returns false if either the decode or the encode step
// rejects its arguments.
bool TranscodeTile(gcpp::KVEncoding src_encoding,
                   hwy::Span<const TileByte> src_data,
                   gcpp::KVEncoding dst_encoding, hwy::Span<TileByte> dst_data,
                   size_t qkv_dim) {
  DecodedTile decoded(qkv_dim, kTileSize);
  if (!DecodeTile(src_encoding, src_data, qkv_dim, &decoded)) return false;
  return EncodeTile(dst_encoding, decoded, qkv_dim, dst_data);
}

}  // namespace gcpp