mirror of https://github.com/google/gemma.cpp.git
Int8 + microscaling support for kv cache formats.
Right now multiplication is done by converting to corresponding float format. Can yield up to 2x improvements for membw constrained shapes PiperOrigin-RevId: 880748493
This commit is contained in:
parent
d2806fb1dd
commit
029cfd0b33
|
|
@ -444,6 +444,142 @@ struct CompressTraits<SfpStream> {
|
|||
}
|
||||
};
|
||||
|
||||
template <>
|
||||
struct CompressTraits<int8_t> {
|
||||
using Packed = int8_t;
|
||||
|
||||
static size_t CompressBound(size_t num) { return num * sizeof(Packed); }
|
||||
|
||||
template <class DF, HWY_IF_F32_D(DF)>
|
||||
static HWY_INLINE void Compress(DF df, const float* HWY_RESTRICT raw,
|
||||
size_t num, CompressPerThread& /*tls*/,
|
||||
const PackedSpan<Packed>& packed,
|
||||
const size_t packed_ofs) {
|
||||
const hn::Repartition<int32_t, DF> di32;
|
||||
const hn::Repartition<int16_t, DF> di16;
|
||||
const hn::Repartition<int8_t, DF> di8;
|
||||
const auto di16_16 = hn::Half<decltype(di16)>();
|
||||
const auto di8_16 = hn::Half<decltype(di8)>();
|
||||
using VF = hn::Vec<DF>;
|
||||
const size_t NF = hn::Lanes(df);
|
||||
|
||||
size_t i = 0;
|
||||
if (num >= 2 * NF) {
|
||||
for (; i <= num - 2 * NF; i += 2 * NF) {
|
||||
const VF v0 = hn::LoadU(df, raw + i);
|
||||
const VF v1 = hn::LoadU(df, raw + i + NF);
|
||||
const auto vi32_0 = hn::NearestInt(v0);
|
||||
const auto vi32_1 = hn::NearestInt(v1);
|
||||
const auto vi16 = hn::OrderedDemote2To(di16, vi32_0, vi32_1);
|
||||
const auto vi8 = hn::OrderedDemote2To(
|
||||
di8_16, hn::UpperHalf(di16_16, vi16), hn::LowerHalf(di16_16, vi16));
|
||||
hn::StoreU(vi8, di8_16, packed.ptr + packed_ofs + i);
|
||||
}
|
||||
}
|
||||
const size_t remaining = num - i;
|
||||
if (remaining > 0) {
|
||||
HWY_ALIGN float buf[2 * NF];
|
||||
hwy::ZeroBytes(buf, 2 * NF * sizeof(float));
|
||||
for (size_t j = 0; j < remaining; ++j) buf[j] = raw[i + j];
|
||||
const VF v0 = hn::LoadU(df, buf);
|
||||
const VF v1 = hn::LoadU(df, buf + NF);
|
||||
const auto vi32_0 = hn::NearestInt(v0);
|
||||
const auto vi32_1 = hn::NearestInt(v1);
|
||||
const auto vi16 = hn::OrderedDemote2To(di16, vi32_0, vi32_1);
|
||||
const auto vi8 = hn::OrderedDemote2To(
|
||||
di8_16, hn::UpperHalf(di16_16, vi16), hn::LowerHalf(di16_16, vi16));
|
||||
hn::StoreN(vi8, di8_16, packed.ptr + packed_ofs + i, remaining);
|
||||
}
|
||||
}
|
||||
|
||||
static float ToFloatSlow(const Packed x) { return static_cast<float>(x); }
|
||||
|
||||
template <class DF, HWY_IF_F32_D(DF)>
|
||||
static HWY_INLINE void Load2(DF df, const PackedSpan<const Packed>& packed,
|
||||
const size_t packed_ofs, hn::Vec<DF>& raw0,
|
||||
hn::Vec<DF>& raw1) {
|
||||
const hn::Repartition<int32_t, DF> di32;
|
||||
const hn::Repartition<int16_t, DF> di16;
|
||||
const hn::Rebind<int8_t, decltype(di16)> di8_half;
|
||||
|
||||
const auto vec_i8 = hn::LoadU(di8_half, packed.ptr + packed_ofs);
|
||||
const auto vec_i16 = hn::PromoteTo(di16, vec_i8);
|
||||
const auto vec_i32_0 = hn::PromoteLowerTo(di32, vec_i16);
|
||||
const auto vec_i32_1 = hn::PromoteUpperTo(di32, vec_i16);
|
||||
|
||||
raw0 = hn::ConvertTo(df, vec_i32_0);
|
||||
raw1 = hn::ConvertTo(df, vec_i32_1);
|
||||
}
|
||||
|
||||
template <class DBF, HWY_IF_BF16_D(DBF)>
|
||||
static HWY_INLINE void Load2(DBF dbf, const PackedSpan<const Packed>& packed,
|
||||
const size_t packed_ofs, hn::Vec<DBF>& raw0,
|
||||
hn::Vec<DBF>& raw1) {
|
||||
const hn::Repartition<float, DBF> df;
|
||||
using VF = hn::Vec<decltype(df)>;
|
||||
const size_t NF = hn::Lanes(df);
|
||||
|
||||
VF f0, f1, f2, f3;
|
||||
Load2(df, packed, packed_ofs, f0, f1);
|
||||
Load2(df, packed, packed_ofs + 2 * NF, f2, f3);
|
||||
|
||||
raw0 = hn::OrderedDemote2To(dbf, f0, f1);
|
||||
raw1 = hn::OrderedDemote2To(dbf, f2, f3);
|
||||
}
|
||||
|
||||
template <class DF, HWY_IF_F32_D(DF)>
|
||||
static HWY_INLINE void DecompressAndZeroPad(
|
||||
DF df, const PackedSpan<const Packed>& packed, const size_t packed_ofs,
|
||||
float* HWY_RESTRICT raw, size_t num) {
|
||||
using VF = hn::Vec<decltype(df)>;
|
||||
const size_t NF = hn::Lanes(df);
|
||||
|
||||
size_t i = 0;
|
||||
if (num >= 2 * NF) {
|
||||
for (; i <= num - 2 * NF; i += 2 * NF) {
|
||||
VF raw0, raw1;
|
||||
Load2(df, packed, packed_ofs + i, raw0, raw1);
|
||||
hn::StoreU(raw0, df, raw + i);
|
||||
hn::StoreU(raw1, df, raw + i + NF);
|
||||
}
|
||||
}
|
||||
|
||||
const size_t remaining = num - i;
|
||||
if (HWY_UNLIKELY(remaining != 0)) {
|
||||
for (size_t j = 0; j < remaining; ++j) {
|
||||
raw[i + j] = static_cast<float>(packed.ptr[packed_ofs + i + j]);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <class DBF, HWY_IF_BF16_D(DBF)>
|
||||
static HWY_INLINE void DecompressAndZeroPad(
|
||||
DBF dbf, const PackedSpan<const Packed>& packed, const size_t packed_ofs,
|
||||
BF16* HWY_RESTRICT raw, size_t num) {
|
||||
const hn::Repartition<float, DBF> df;
|
||||
const size_t NF = hn::Lanes(df);
|
||||
size_t i = 0;
|
||||
const size_t NBF = hn::Lanes(dbf);
|
||||
if (num >= NBF) {
|
||||
for (; i <= num - NBF; i += NBF) {
|
||||
hn::Vec<decltype(df)> f0, f1;
|
||||
Load2(df, packed, packed_ofs + i, f0, f1);
|
||||
auto vbf = hn::OrderedDemote2To(dbf, f0, f1);
|
||||
hn::StoreU(vbf, dbf, raw + i);
|
||||
}
|
||||
}
|
||||
const size_t remaining = num - i;
|
||||
if (remaining > 0) {
|
||||
HWY_ALIGN float buf[2 * hn::MaxLanes(df)];
|
||||
DecompressAndZeroPad(df, packed, packed_ofs + i, buf, remaining);
|
||||
auto f0 = hn::LoadU(df, buf);
|
||||
auto f1 = hn::LoadU(df, buf + NF);
|
||||
auto vbf = hn::OrderedDemote2To(dbf, f0, f1);
|
||||
hn::StoreN(vbf, dbf, raw + i, remaining);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
// Integer quantization.
|
||||
template <>
|
||||
struct CompressTraits<I8Stream> {
|
||||
|
|
|
|||
|
|
@ -126,6 +126,8 @@ struct TestDecompress2 {
|
|||
HWY_ASSERT(stats.L1().Max() <= 0.08f);
|
||||
HWY_ASSERT(IsInside(0.02, 0.05, stats.WeightedAverageL1()));
|
||||
HWY_ASSERT(IsInside(18.0, 62.0, stats.GeomeanValueDivL1()));
|
||||
} else if constexpr (hwy::IsSame<Packed, int8_t>()) {
|
||||
HWY_ASSERT(stats.L1().Max() <= 0.6f);
|
||||
} else {
|
||||
HWY_ABORT("Unhandled type requested by ForeachPackedAndRawType");
|
||||
}
|
||||
|
|
@ -200,6 +202,8 @@ struct TestShortLengths {
|
|||
HWY_ASSERT(stats.L1().Max() <= 0.14f);
|
||||
HWY_ASSERT(IsInside(7E-5, 0.06, stats.WeightedAverageL1()));
|
||||
HWY_ASSERT(IsInside(11.0, 180.0, stats.GeomeanValueDivL1()));
|
||||
} else if constexpr (hwy::IsSame<Packed, int8_t>()) {
|
||||
HWY_ASSERT(stats.L1().Max() <= 0.6f);
|
||||
} else {
|
||||
HWY_ABORT("Unhandled type requested by ForeachPackedAndRawType");
|
||||
}
|
||||
|
|
|
|||
|
|
@ -192,6 +192,11 @@ constexpr bool IsF32() {
|
|||
return hwy::IsSame<hwy::RemoveCvRef<Packed>, float>();
|
||||
}
|
||||
|
||||
template <typename Packed>
|
||||
constexpr bool IsInt8() {
|
||||
return hwy::IsSame<hwy::RemoveCvRef<Packed>, int8_t>();
|
||||
}
|
||||
|
||||
template <typename Packed>
|
||||
constexpr bool IsBF16() {
|
||||
return hwy::IsSame<hwy::RemoveCvRef<Packed>, BF16>();
|
||||
|
|
@ -231,12 +236,13 @@ enum class Type {
|
|||
kI8,
|
||||
kU16,
|
||||
kU8,
|
||||
kInt8,
|
||||
};
|
||||
// These are used in `ModelConfig.Specifier`, hence the strings will not
|
||||
// change, though new ones may be added.
|
||||
static constexpr const char* kTypeStrings[] = {"unknown", "f32", "bf16", "sfp",
|
||||
"nuq", "f64", "u32", "u64",
|
||||
"i8", "u16", "u8"};
|
||||
static constexpr const char* kTypeStrings[] = {
|
||||
"unknown", "f32", "bf16", "sfp", "nuq", "f64",
|
||||
"u32", "u64", "i8", "u16", "u8", "int8"};
|
||||
static constexpr size_t kNumTypes =
|
||||
sizeof(kTypeStrings) / sizeof(kTypeStrings[0]);
|
||||
static constexpr size_t kTypeBits[] = {
|
||||
|
|
@ -251,6 +257,7 @@ static constexpr size_t kTypeBits[] = {
|
|||
8 * sizeof(I8Stream),
|
||||
8 * sizeof(uint16_t),
|
||||
8 * sizeof(uint8_t),
|
||||
8 * sizeof(int8_t),
|
||||
};
|
||||
|
||||
static inline bool EnumValid(Type type) {
|
||||
|
|
@ -281,6 +288,8 @@ constexpr Type TypeEnum() {
|
|||
return Type::kU16;
|
||||
} else if constexpr (hwy::IsSame<Packed, uint8_t>()) {
|
||||
return Type::kU8;
|
||||
} else if constexpr (hwy::IsSame<Packed, int8_t>()) {
|
||||
return Type::kInt8;
|
||||
} else {
|
||||
return Type::kUnknown;
|
||||
}
|
||||
|
|
|
|||
|
|
@ -1260,6 +1260,52 @@ static HWY_NOINLINE void ApplyMasking(
|
|||
}
|
||||
}
|
||||
|
||||
template <int kNumQueries, class DF, class VF = hn::Vec<DF>>
|
||||
static HWY_INLINE void MultiplyByScale(DF df, const BF16* scales, VF& x0_p0,
|
||||
VF& x0_p1, VF& x1_p0, VF& x1_p1,
|
||||
VF& x2_p0, VF& x2_p1, VF& x3_p0,
|
||||
VF& x3_p1, VF& x4_p0, VF& x4_p1,
|
||||
VF& x5_p0, VF& x5_p1, VF& x6_p0,
|
||||
VF& x6_p1, VF& x7_p0, VF& x7_p1) {
|
||||
const size_t kTileSize = hn::Lanes(df);
|
||||
const PackedSpan<const BF16> scales_span =
|
||||
MakeConstSpan(scales, 2 * kTileSize);
|
||||
VF scales_p0, scales_p1;
|
||||
Decompress2(df, scales_span, 0, scales_p0, scales_p1);
|
||||
if constexpr (kNumQueries >= 1) {
|
||||
x0_p0 = hn::Mul(x0_p0, scales_p0);
|
||||
x0_p1 = hn::Mul(x0_p1, scales_p1);
|
||||
}
|
||||
if constexpr (kNumQueries >= 2) {
|
||||
x1_p0 = hn::Mul(x1_p0, scales_p0);
|
||||
x1_p1 = hn::Mul(x1_p1, scales_p1);
|
||||
}
|
||||
if constexpr (kNumQueries >= 3) {
|
||||
x2_p0 = hn::Mul(x2_p0, scales_p0);
|
||||
x2_p1 = hn::Mul(x2_p1, scales_p1);
|
||||
}
|
||||
if constexpr (kNumQueries >= 4) {
|
||||
x3_p0 = hn::Mul(x3_p0, scales_p0);
|
||||
x3_p1 = hn::Mul(x3_p1, scales_p1);
|
||||
}
|
||||
if constexpr (kNumQueries >= 5) {
|
||||
x4_p0 = hn::Mul(x4_p0, scales_p0);
|
||||
x4_p1 = hn::Mul(x4_p1, scales_p1);
|
||||
}
|
||||
if constexpr (kNumQueries >= 6) {
|
||||
x5_p0 = hn::Mul(x5_p0, scales_p0);
|
||||
x5_p1 = hn::Mul(x5_p1, scales_p1);
|
||||
}
|
||||
if constexpr (kNumQueries >= 7) {
|
||||
x6_p0 = hn::Mul(x6_p0, scales_p0);
|
||||
x6_p1 = hn::Mul(x6_p1, scales_p1);
|
||||
}
|
||||
if constexpr (kNumQueries >= 8) {
|
||||
x7_p0 = hn::Mul(x7_p0, scales_p0);
|
||||
x7_p1 = hn::Mul(x7_p1, scales_p1);
|
||||
}
|
||||
}
|
||||
|
||||
// Performs tiled flash attention for arbitrary number of queries
|
||||
// It depends on kv being tiled.
|
||||
// Runs 2 loops one over tiles, and inner one over queries(up to 4 at a time).
|
||||
|
|
@ -1400,6 +1446,21 @@ HWY_NOINLINE void TileFlashAttentionReturnExpSumsAndMaxLogits(
|
|||
false,
|
||||
"Query type type not supported, only float and BF16 are supported");
|
||||
}
|
||||
// microscaling
|
||||
// TODO: Change to more generic function to inform if we should use
|
||||
// microscaling or not.
|
||||
constexpr bool kUseMicroScaling = IsInt8<KV_T>();
|
||||
if constexpr (kUseMicroScaling) {
|
||||
// After end of the tile, we have kTileSize * 2 bfloat16 for the
|
||||
// microscaling scales for K and V.
|
||||
const BF16* microscaling_scales_k =
|
||||
reinterpret_cast<const BF16*>(tile_base + qkv_dim * 2 * kTileSize) +
|
||||
pos_in_tile;
|
||||
MultiplyByScale<kNumQueries>(df, microscaling_scales_k, x_0_p_0, x_0_p_1,
|
||||
x_1_p_0, x_1_p_1, x_2_p_0, x_2_p_1, x_3_p_0,
|
||||
x_3_p_1, x_4_p_0, x_4_p_1, x_5_p_0, x_5_p_1,
|
||||
x_6_p_0, x_6_p_1, x_7_p_0, x_7_p_1);
|
||||
}
|
||||
|
||||
constexpr int kFirstHalfAmountOfQueries = std::min(kNumQueries, 4);
|
||||
constexpr int kSecondHalfAmountOfQueries =
|
||||
|
|
@ -1433,6 +1494,15 @@ HWY_NOINLINE void TileFlashAttentionReturnExpSumsAndMaxLogits(
|
|||
x_3_p_0, x_3_p_1, x_4_p_0, x_4_p_1, x_5_p_0, x_5_p_1, x_6_p_0, x_6_p_1,
|
||||
x_7_p_0, x_7_p_1, max_logits, exp_denominator_sums, scales, q_group_idx,
|
||||
kNumQueriesPerGroup);
|
||||
if constexpr (kUseMicroScaling) {
|
||||
const BF16* microscaling_scales_v =
|
||||
reinterpret_cast<const BF16*>(tile_base + qkv_dim * 2 * kTileSize) +
|
||||
kTileSize + pos_in_tile;
|
||||
MultiplyByScale<kNumQueries>(df, microscaling_scales_v, x_0_p_0, x_0_p_1,
|
||||
x_1_p_0, x_1_p_1, x_2_p_0, x_2_p_1, x_3_p_0,
|
||||
x_3_p_1, x_4_p_0, x_4_p_1, x_5_p_0, x_5_p_1,
|
||||
x_6_p_0, x_6_p_1, x_7_p_0, x_7_p_1);
|
||||
}
|
||||
if constexpr (IsF32<Q_T>()) {
|
||||
MulByConstAndAddTileUpTo8<kNumQueries>(
|
||||
df, scales, x_0_p_0, x_0_p_1, x_1_p_0, x_1_p_1, x_2_p_0, x_2_p_1,
|
||||
|
|
|
|||
|
|
@ -492,6 +492,139 @@ void TestTiledFlashAttentionBF16() {
|
|||
}
|
||||
}
|
||||
|
||||
void TestTiledFlashAttentionInt8() {
|
||||
int qkv_dim = 64;
|
||||
int kv_seq_len = 60; // number of tokens we will attend to. Not divisible by
|
||||
// tiles size to test the padding logic.
|
||||
int padded_kv_seq_len = hwy::RoundUpTo(kv_seq_len, gcpp::KVCache::kTileSize);
|
||||
float att_cap = 10.0f;
|
||||
int num_queries = 8;
|
||||
int num_queries_per_timestep = 4;
|
||||
int num_tokens = num_queries / num_queries_per_timestep;
|
||||
int kv_seq_end =
|
||||
kv_seq_len - hwy::DivCeil(num_queries, num_queries_per_timestep);
|
||||
ThreadingArgs threading_args;
|
||||
ThreadingContext ctx(threading_args);
|
||||
|
||||
int num_tiles = padded_kv_seq_len / gcpp::KVCache::kTileSize;
|
||||
int tile_size_bytes = 2 * qkv_dim * gcpp::KVCache::kTileSize +
|
||||
2 * sizeof(BF16) * gcpp::KVCache::kTileSize;
|
||||
|
||||
MatStorageT<int8_t> kv("kv", Extents2D(num_tiles, tile_size_bytes),
|
||||
ctx.allocator, MatPadding::kPacked);
|
||||
|
||||
// fill in kvs with predictable, synthetic data
|
||||
for (int i = 0; i < padded_kv_seq_len; ++i) {
|
||||
int tile_idx = i / gcpp::KVCache::kTileSize;
|
||||
int in_tile_offset = i % gcpp::KVCache::kTileSize;
|
||||
int8_t* tile_ptr = kv.Row(tile_idx);
|
||||
BF16* scales_ptr = HWY_RCAST_ALIGNED(
|
||||
BF16*, tile_ptr + 2 * qkv_dim * gcpp::KVCache::kTileSize);
|
||||
|
||||
// Generate float values for K and V
|
||||
std::vector<float> k_vals(qkv_dim);
|
||||
std::vector<float> v_vals(qkv_dim);
|
||||
float max_abs_k = 0.0f;
|
||||
float max_abs_v = 0.0f;
|
||||
|
||||
for (int j = 0; j < qkv_dim; ++j) {
|
||||
k_vals[j] = 0.01f * (i + 1) / (j + 1);
|
||||
v_vals[j] = 0.02f * (i + 1) / (j + 1);
|
||||
max_abs_k = std::max(max_abs_k, std::abs(k_vals[j]));
|
||||
max_abs_v = std::max(max_abs_v, std::abs(v_vals[j]));
|
||||
}
|
||||
|
||||
// Quantize K
|
||||
float scale_k = max_abs_k / 127.0f;
|
||||
if (scale_k == 0.0f) scale_k = 1.0f;
|
||||
scales_ptr[in_tile_offset] = hwy::ConvertScalarTo<BF16>(scale_k);
|
||||
for (int j = 0; j < qkv_dim; ++j) {
|
||||
int val = std::round(k_vals[j] / scale_k);
|
||||
val = std::max(-127, std::min(127, val));
|
||||
tile_ptr[j * gcpp::KVCache::kTileSize + in_tile_offset] =
|
||||
static_cast<int8_t>(val);
|
||||
}
|
||||
|
||||
// Quantize V
|
||||
float scale_v = max_abs_v / 127.0f;
|
||||
if (scale_v == 0.0f) scale_v = 1.0f;
|
||||
scales_ptr[gcpp::KVCache::kTileSize + in_tile_offset] =
|
||||
hwy::ConvertScalarTo<BF16>(scale_v);
|
||||
size_t v_offset = qkv_dim * gcpp::KVCache::kTileSize;
|
||||
for (int j = 0; j < qkv_dim; ++j) {
|
||||
int val = std::round(v_vals[j] / scale_v);
|
||||
val = std::max(-127, std::min(127, val));
|
||||
tile_ptr[v_offset + in_tile_offset * qkv_dim + j] =
|
||||
static_cast<int8_t>(val);
|
||||
}
|
||||
}
|
||||
|
||||
std::vector<float> q_float(4 * qkv_dim);
|
||||
std::vector<float> q_float2(4 * qkv_dim);
|
||||
// fill in qs with predictable, synthetic data
|
||||
for (int i = 0; i < 4; ++i) {
|
||||
for (int j = 0; j < qkv_dim; j++) {
|
||||
float val_1 = 0.01f * (i + 1) / (j + 1);
|
||||
float val_2 = 0.01f * (i + 4 + 1) / (j + 1);
|
||||
q_float[j * 4 + i] = val_1;
|
||||
q_float2[j * 4 + i] = val_2;
|
||||
}
|
||||
}
|
||||
const float* q_T[2] = {q_float.data(), q_float2.data()};
|
||||
|
||||
MatStorageT<float> att_out("att_out", Extents2D(num_queries, qkv_dim),
|
||||
ctx.allocator, MatPadding::kPacked);
|
||||
using DF = hn::ScalableTag<float>;
|
||||
const DF df;
|
||||
HWY_LANES_CONSTEXPR size_t lanes = hn::Lanes(df);
|
||||
size_t num_queries_rounded_to_laness = hwy::RoundUpTo(num_queries, lanes);
|
||||
std::vector<float> exp_denominator_sums(num_queries_rounded_to_laness);
|
||||
std::vector<float> max_logits(num_queries_rounded_to_laness);
|
||||
for (size_t i = 0; i < num_queries; ++i) {
|
||||
hwy::ZeroBytes(att_out.Row(i),
|
||||
qkv_dim * sizeof(decltype(att_out.Row(i)[0])));
|
||||
exp_denominator_sums[i] = 0.0f;
|
||||
max_logits[i] = -std::numeric_limits<float>::max() / 2.0f;
|
||||
}
|
||||
std::vector<size_t, hwy::AlignedAllocator<size_t>> start_pos_per_query;
|
||||
std::vector<size_t, hwy::AlignedAllocator<size_t>> last_pos_per_query;
|
||||
start_pos_per_query.reserve(num_queries);
|
||||
last_pos_per_query.reserve(num_queries);
|
||||
for (int token_idx = 0; token_idx < num_tokens; ++token_idx) {
|
||||
ssize_t query_last_pos = kv_seq_end + token_idx;
|
||||
ssize_t query_start_pos =
|
||||
std::max(query_last_pos - 100000 + 1, static_cast<ssize_t>(0));
|
||||
for (int q_head_idx = 0; q_head_idx < num_queries_per_timestep;
|
||||
++q_head_idx) {
|
||||
start_pos_per_query.push_back(query_start_pos);
|
||||
last_pos_per_query.push_back(query_last_pos);
|
||||
}
|
||||
}
|
||||
|
||||
hwy::Span<const MatPtr> kvs(&kv, 1);
|
||||
DispatchTileFlashAttentionReturnExpSumsAndMaxLogits(
|
||||
kvs, num_queries, hwy::Span<const float*>(q_T, 2),
|
||||
hwy::Span<const size_t>(start_pos_per_query),
|
||||
hwy::Span<const size_t>(last_pos_per_query), att_cap, att_out,
|
||||
exp_denominator_sums.data(), max_logits.data());
|
||||
|
||||
// TODO: Replace with Other implementation for generating goldens.
|
||||
// Current values are taken from a point in time where code was run with gemma
|
||||
// and output looked good. Not ideal but should be good enough to test the
|
||||
// plumbing and detect regressions.
|
||||
PrintMatPtr(att_out);
|
||||
for (int i = 0; i < num_queries; ++i) {
|
||||
std::cerr << "exp_d: " << exp_denominator_sums[i]
|
||||
<< " max_logit: " << max_logits[i] << std::endl;
|
||||
EXPECT_NEAR(exp_denominator_sums[i], exp_denominator_sums_gold[i], 1e-2f)
|
||||
<< "i=" << i;
|
||||
EXPECT_NEAR(max_logits[i], max_logits_gold[i], 1e-3f) << "i=" << i;
|
||||
for (int j = 0; j < qkv_dim; ++j) {
|
||||
EXPECT_NEAR(att_out.Row(i)[j], att_out_gold[i * qkv_dim + j], 5e-3f);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// NOLINTNEXTLINE(google-readability-namespace-comments)
|
||||
} // namespace HWY_NAMESPACE
|
||||
} // namespace gcpp
|
||||
|
|
@ -502,6 +635,9 @@ HWY_AFTER_NAMESPACE();
|
|||
namespace gcpp {
|
||||
HWY_BEFORE_TEST(FlashAttentionTest);
|
||||
HWY_EXPORT_AND_TEST_P(FlashAttentionTest, TestAttention);
|
||||
HWY_EXPORT_AND_TEST_P(FlashAttentionTest, TestTiledFlashAttention);
|
||||
HWY_EXPORT_AND_TEST_P(FlashAttentionTest, TestTiledFlashAttentionBF16);
|
||||
HWY_EXPORT_AND_TEST_P(FlashAttentionTest, TestTiledFlashAttentionInt8);
|
||||
HWY_AFTER_TEST();
|
||||
|
||||
} // namespace gcpp
|
||||
|
|
|
|||
|
|
@ -152,6 +152,7 @@ struct RuntimeConfig {
|
|||
// If not set, it will be set based on the attention_impl.
|
||||
// F32 for tiled
|
||||
// BF16 for tiled bf16
|
||||
// Int8 works for both tiled and tiled bf16.
|
||||
// If you want to use type other than F32 or BF16, you might need to update
|
||||
// call upcasted.
|
||||
std::optional<Type> kv_cache_type = {};
|
||||
|
|
|
|||
|
|
@ -84,7 +84,6 @@ KVCache::KVCache(const ModelConfig& config, const InferenceArgs& inference_args,
|
|||
const size_t num_tiles =
|
||||
hwy::DivCeil(CappedSeqLen(config, inference_args), kTileSize);
|
||||
tiled_seq_len = num_tiles * kTileSize;
|
||||
int tile_length = 2 * config.layer_configs[0].qkv_dim * kTileSize;
|
||||
Type kv_cache_type;
|
||||
if (runtime_config.attention_impl == AttentionImpl::kFlashTransposedQsBF16
|
||||
|| hwy::IsSame<KV_t, BF16>()) {
|
||||
|
|
@ -92,6 +91,11 @@ KVCache::KVCache(const ModelConfig& config, const InferenceArgs& inference_args,
|
|||
} else {
|
||||
kv_cache_type = runtime_config.kv_cache_type.value_or(Type::kF32);
|
||||
}
|
||||
|
||||
int tile_length = 2 * config.layer_configs[0].qkv_dim * kTileSize;
|
||||
if (kv_cache_type == Type::kInt8) {
|
||||
tile_length += 2 * sizeof(BF16) * kTileSize;
|
||||
}
|
||||
auto num_tiles_per_head = [](size_t window_size, size_t prefill_tbatch_size,
|
||||
size_t max_seq_len) {
|
||||
return hwy::DivCeil(
|
||||
|
|
|
|||
|
|
@ -31,6 +31,7 @@
|
|||
namespace gcpp {
|
||||
|
||||
using KV_t = BF16;
|
||||
using KV_microscale_t = BF16;
|
||||
struct KVCache;
|
||||
|
||||
// A non-owning view of a KVCache.
|
||||
|
|
|
|||
|
|
@ -69,6 +69,27 @@ static HWY_INLINE void MergeOnlineSoftmax(
|
|||
accumulator_softmax_d = d_new;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
T AbsMaxOfSpan(hwy::Span<const T> span) {
|
||||
hn::ScalableTag<T> dt;
|
||||
using VT = hn::Vec<decltype(dt)>;
|
||||
VT max_vec = hn::Set(dt, 0.0f);
|
||||
const size_t lanes = hn::Lanes(dt);
|
||||
size_t i = 0;
|
||||
// Process full vectors using LoadU.
|
||||
for (; i + lanes <= span.size(); i += lanes) {
|
||||
const VT vec = hn::Abs(hn::LoadU(dt, span.data() + i));
|
||||
max_vec = hn::Max(max_vec, vec);
|
||||
}
|
||||
// Process remaining elements using LoadN.
|
||||
const size_t remaining = span.size() - i;
|
||||
if (HWY_UNLIKELY(remaining != 0)) {
|
||||
const VT vec = hn::Abs(hn::LoadN(dt, span.data() + i, remaining));
|
||||
max_vec = hn::Max(max_vec, vec);
|
||||
}
|
||||
return hn::ReduceMax(dt, max_vec);
|
||||
}
|
||||
|
||||
// Forked from ComputeQKV. But it stores the K/V in the tiled format
|
||||
// KV_T is type stored in the KV cache (typically float or BF16).
|
||||
template <typename KV_T>
|
||||
|
|
@ -168,9 +189,9 @@ static HWY_INLINE void ComputeQKVTransposedTile(
|
|||
const float* kv_row =
|
||||
kv_out_data +
|
||||
(token_in_tile_idx * qbatch.Size() + query_idx) * kv_out_cols;
|
||||
const float* k_ptr = kv_row + kv_head * 2 * qkv_dim;
|
||||
const float* v_ptr = kv_row + kv_head * 2 * qkv_dim + qkv_dim;
|
||||
hwy::CopyBytes(k_ptr, k_f32, qkv_dim * sizeof(float));
|
||||
const float* k_values = kv_row + kv_head * 2 * qkv_dim;
|
||||
const float* v_values = kv_row + kv_head * 2 * qkv_dim + qkv_dim;
|
||||
hwy::CopyBytes(k_values, k_f32, qkv_dim * sizeof(float));
|
||||
if (layer.key_norm_scale.HasPtr()) {
|
||||
CallUpcasted(&layer.key_norm_scale, [&](const auto* weights_t) {
|
||||
RMSNormInplace(weights_t->PackedScale1(), /*w_ofs=*/0, k_f32,
|
||||
|
|
@ -183,7 +204,53 @@ static HWY_INLINE void ComputeQKVTransposedTile(
|
|||
/*mul=*/1.0f);
|
||||
|
||||
const size_t in_tile_idx = current_pos_mod % KVCache::kTileSize;
|
||||
if (attention_impl == AttentionImpl::kFlashTransposedQsBF16) {
|
||||
// `v_cache_values` is a pointer to the V data that will be
|
||||
// compressed and stored in the KV cache. By default, it points to
|
||||
// the raw `v_values`.
|
||||
const float* v_cache_values = v_values;
|
||||
// `v_buf` is a temporary buffer used only when quantizing V values
|
||||
// to int8_t.
|
||||
HWY_ALIGN float v_buf[kMaxQKVDim];
|
||||
|
||||
if constexpr (IsInt8<KV_T>()) {
|
||||
BF16* scales_ptr = HWY_RCAST_ALIGNED(
|
||||
BF16*, tile_ptr + 2 * qkv_dim * KVCache::kTileSize);
|
||||
|
||||
auto scale_and_store = [&](float* values, int dim,
|
||||
size_t scale_idx) HWY_ATTR {
|
||||
const float max_abs =
|
||||
AbsMaxOfSpan(hwy::Span<const float>(values, dim));
|
||||
float scale = max_abs / 127.0f;
|
||||
if (scale == 0.0f) scale = 1.0f;
|
||||
scales_ptr[scale_idx] = hwy::ConvertScalarTo<BF16>(scale);
|
||||
const float inv_scale = 1.0f / scale;
|
||||
const hn::Vec<decltype(df)> v_inv_scale =
|
||||
hn::Set(df, inv_scale);
|
||||
const size_t lanes = hn::Lanes(df);
|
||||
size_t i = 0;
|
||||
for (; i + lanes <= dim; i += lanes) {
|
||||
hn::StoreU(hn::Mul(hn::LoadU(df, values + i), v_inv_scale),
|
||||
df, values + i);
|
||||
}
|
||||
if (HWY_UNLIKELY(i < dim)) {
|
||||
hn::StoreN(
|
||||
hn::Mul(hn::LoadN(df, values + i, dim - i), v_inv_scale),
|
||||
df, values + i, dim - i);
|
||||
}
|
||||
};
|
||||
|
||||
// K Scaling
|
||||
scale_and_store(k_f32, qkv_dim, in_tile_idx);
|
||||
|
||||
// V Scaling: Copy `v_values` to `v_buf`, scale `v_buf` in-place,
|
||||
// and then update `v_cache_values` to point to `v_buf`.
|
||||
hwy::CopyBytes(v_values, v_buf, qkv_dim * sizeof(float));
|
||||
scale_and_store(v_buf, qkv_dim, KVCache::kTileSize + in_tile_idx);
|
||||
v_cache_values = v_buf;
|
||||
}
|
||||
|
||||
if (attention_impl == AttentionImpl::kFlashTransposedQsBF16 &&
|
||||
!IsInt8<KV_T>()) {
|
||||
const int in_tile_idx_mod_2 = in_tile_idx % 2;
|
||||
for (int dim = 0; dim < qkv_dim; dim += 2) {
|
||||
const int dim_mod_2 = dim % 2;
|
||||
|
|
@ -196,16 +263,17 @@ static HWY_INLINE void ComputeQKVTransposedTile(
|
|||
in_tile_idx * 2 + 1] = k_f32[dim + 1];
|
||||
// Pack v's in pairs
|
||||
v_tile_vec[(in_tile_idx - in_tile_idx_mod_2) * qkv_dim +
|
||||
dim * 2 + in_tile_idx_mod_2] = v_ptr[dim];
|
||||
dim * 2 + in_tile_idx_mod_2] = v_cache_values[dim];
|
||||
v_tile_vec[(in_tile_idx - in_tile_idx_mod_2) * qkv_dim +
|
||||
(dim + 1) * 2 + in_tile_idx_mod_2] = v_ptr[dim + 1];
|
||||
(dim + 1) * 2 + in_tile_idx_mod_2] =
|
||||
v_cache_values[dim + 1];
|
||||
}
|
||||
|
||||
} else {
|
||||
for (int i = 0; i < qkv_dim; ++i) {
|
||||
k_tile_vec[i * KVCache::kTileSize + in_tile_idx] = k_f32[i];
|
||||
}
|
||||
Compress(v_ptr, qkv_dim, tls, tile_packed_span,
|
||||
Compress(v_cache_values, qkv_dim, tls, tile_packed_span,
|
||||
qkv_dim * (KVCache::kTileSize + in_tile_idx));
|
||||
}
|
||||
|
||||
|
|
@ -640,12 +708,21 @@ void TiledAttention(AttentionImpl attention_impl, size_t num_tokens,
|
|||
HWY_DASSERT_M((layer_config.heads % layer_config.kv_heads) == 0,
|
||||
"query heads must be a multiple of key-value heads");
|
||||
(void)layer_config; // only used in HWY_DASSERT
|
||||
|
||||
if (qbatch.KV(0).cache->compact_kv_cache_ptr.GetType() == Type::kBF16) {
|
||||
ComputeQKVTransposedTile<BF16>(num_tokens, layer_idx, layer, attention_impl,
|
||||
activations, qbatch, flags, env);
|
||||
} else {
|
||||
} else if (qbatch.KV(0).cache->compact_kv_cache_ptr.GetType() == Type::kF32) {
|
||||
ComputeQKVTransposedTile<KV_t>(num_tokens, layer_idx, layer, attention_impl,
|
||||
activations, qbatch, flags, env);
|
||||
} else if (qbatch.KV(0).cache->compact_kv_cache_ptr.GetType() ==
|
||||
Type::kInt8) {
|
||||
ComputeQKVTransposedTile<int8_t>(num_tokens, layer_idx, layer,
|
||||
attention_impl, activations, qbatch, flags,
|
||||
env);
|
||||
} else {
|
||||
HWY_ABORT("Unsupported KV cache type: %d",
|
||||
qbatch.KV(0).cache->compact_kv_cache_ptr.GetType());
|
||||
}
|
||||
RMSNormAndPositionalEncoding(num_tokens, qbatch, activations.q,
|
||||
layer.query_norm_scale, layer_idx, activations,
|
||||
|
|
|
|||
|
|
@ -1,7 +1,10 @@
|
|||
#include <stddef.h>
|
||||
|
||||
#include <algorithm>
|
||||
#include <cmath>
|
||||
#include <iostream>
|
||||
#include <memory>
|
||||
#include <optional>
|
||||
#include <tuple>
|
||||
#include <vector>
|
||||
|
||||
|
|
@ -42,7 +45,7 @@ struct AttentionTestEnv {
|
|||
int qkv_dim, int kv_seq_len, int attention_window_size, int num_kv_heads,
|
||||
int num_heads, int num_tokens, int last_pos, float att_cap, int layer_idx,
|
||||
int layers_total, int qbatch_size, AttentionImpl attention_impl,
|
||||
)
|
||||
std::optional<Type> kv_cache_type = {} )
|
||||
: ctx(threading_args), env(ctx) {
|
||||
layer_config.heads = num_heads;
|
||||
layer_config.kv_heads = num_kv_heads;
|
||||
|
|
@ -65,6 +68,7 @@ struct AttentionTestEnv {
|
|||
*tensor_info_registry);
|
||||
|
||||
runtime_config.attention_impl = attention_impl;
|
||||
runtime_config.kv_cache_type = kv_cache_type;
|
||||
inference_args.seq_len = kv_seq_len;
|
||||
|
||||
all_queries.Reserve(qbatch_size);
|
||||
|
|
@ -72,7 +76,8 @@ struct AttentionTestEnv {
|
|||
for (int q = 0; q < qbatch_size; ++q) {
|
||||
kv_caches.emplace_back(model_config, inference_args, runtime_config,
|
||||
ctx.allocator);
|
||||
if (attention_impl == AttentionImpl::kFlashTransposedQsBF16) {
|
||||
if (attention_impl == AttentionImpl::kFlashTransposedQsBF16 &&
|
||||
kv_caches.back().compact_kv_cache_ptr.GetType() == Type::kBF16) {
|
||||
MatPtrT<BF16> compact_kv_cache = kv_caches.back().compact_kv_cache_ptr;
|
||||
for (int i = 0; i < compact_kv_cache.Rows(); ++i) {
|
||||
for (int j = 0; j < compact_kv_cache.Cols(); ++j) {
|
||||
|
|
@ -98,8 +103,65 @@ struct AttentionTestEnv {
|
|||
}
|
||||
}
|
||||
} else if (kv_caches.back().compact_kv_cache_ptr.HasPtr()) {
|
||||
MatPtrT<KV_t> compact_kv_cache = kv_caches.back().compact_kv_cache_ptr;
|
||||
FillMatPtrT(compact_kv_cache);
|
||||
if (kv_caches.back().compact_kv_cache_ptr.GetType() == Type::kInt8) {
|
||||
MatPtrT<int8_t> compact_kv_cache =
|
||||
kv_caches.back().compact_kv_cache_ptr;
|
||||
for (int i = 0; i < compact_kv_cache.Rows(); ++i) {
|
||||
BF16* scales_ptr = HWY_RCAST_ALIGNED(
|
||||
BF16*, compact_kv_cache.Row(i) +
|
||||
2 * qkv_dim * gcpp::KVCache::kTileSize);
|
||||
for (int in_tile_idx = 0; in_tile_idx < gcpp::KVCache::kTileSize;
|
||||
++in_tile_idx) {
|
||||
// Compute scale and fill K
|
||||
float max_k = 0.0f;
|
||||
for (int dim = 0; dim < qkv_dim; ++dim) {
|
||||
int j = dim * gcpp::KVCache::kTileSize + in_tile_idx;
|
||||
float expected = hwy::Unpredictable1() * 0.01f * (i + j + 1);
|
||||
max_k = std::max(max_k, expected);
|
||||
}
|
||||
float scale_k = max_k / 127.0f;
|
||||
if (scale_k == 0.0f) scale_k = 1.0f;
|
||||
scales_ptr[in_tile_idx] = hwy::ConvertScalarTo<BF16>(scale_k);
|
||||
|
||||
for (int dim = 0; dim < qkv_dim; ++dim) {
|
||||
int j = dim * gcpp::KVCache::kTileSize + in_tile_idx;
|
||||
float expected = hwy::Unpredictable1() * 0.01f * (i + j + 1);
|
||||
compact_kv_cache.Row(i)[j] =
|
||||
static_cast<int8_t>(std::round(expected / scale_k));
|
||||
}
|
||||
|
||||
// Compute scale and fill V
|
||||
float max_v = 0.0f;
|
||||
for (int dim = 0; dim < qkv_dim; ++dim) {
|
||||
int j = qkv_dim * gcpp::KVCache::kTileSize +
|
||||
in_tile_idx * qkv_dim + dim;
|
||||
float expected = hwy::Unpredictable1() * 0.01f * (i + j + 1);
|
||||
max_v = std::max(max_v, expected);
|
||||
}
|
||||
float scale_v = max_v / 127.0f;
|
||||
if (scale_v == 0.0f) scale_v = 1.0f;
|
||||
scales_ptr[gcpp::KVCache::kTileSize + in_tile_idx] =
|
||||
hwy::ConvertScalarTo<BF16>(scale_v);
|
||||
|
||||
for (int dim = 0; dim < qkv_dim; ++dim) {
|
||||
int j = qkv_dim * gcpp::KVCache::kTileSize +
|
||||
in_tile_idx * qkv_dim + dim;
|
||||
float expected = hwy::Unpredictable1() * 0.01f * (i + j + 1);
|
||||
compact_kv_cache.Row(i)[j] =
|
||||
static_cast<int8_t>(std::round(expected / scale_v));
|
||||
}
|
||||
}
|
||||
}
|
||||
} else if (kv_caches.back().compact_kv_cache_ptr.GetType() ==
|
||||
Type::kBF16) {
|
||||
MatPtrT<BF16> compact_kv_cache =
|
||||
kv_caches.back().compact_kv_cache_ptr;
|
||||
FillMatPtrT(compact_kv_cache);
|
||||
} else {
|
||||
MatPtrT<float> compact_kv_cache =
|
||||
kv_caches.back().compact_kv_cache_ptr;
|
||||
FillMatPtrT(compact_kv_cache);
|
||||
}
|
||||
} else {
|
||||
FillMatPtrT(kv_caches.back().kv_cache);
|
||||
}
|
||||
|
|
@ -725,6 +787,50 @@ void TestAttentionMultipleTokensBF16() {
|
|||
}
|
||||
}
|
||||
|
||||
void TestAttentionMultipleTokensInt8() {
|
||||
int qkv_dim = 64;
|
||||
int kv_seq_len = 64;
|
||||
int num_kv_heads = 2;
|
||||
int num_heads = 4;
|
||||
int num_tokens = 2;
|
||||
int last_pos = 62; // so in the tbatch token 0 will have 63 and token 1
|
||||
// will have 64 tokens to attend to.
|
||||
float att_cap = 10.0f;
|
||||
int layer_idx = 0;
|
||||
int layers_total = 1;
|
||||
int qbatch_size = 2;
|
||||
AttentionImpl attention_impl = AttentionImpl::kFlashTransposedQsBF16;
|
||||
AttentionTestEnv test_env(qkv_dim, kv_seq_len, kv_seq_len, num_kv_heads,
|
||||
num_heads, num_tokens, last_pos, att_cap, layer_idx,
|
||||
layers_total, qbatch_size, attention_impl,
|
||||
Type::kInt8);
|
||||
test_env.SetupWeights();
|
||||
FillMatPtrT(test_env.activations->attention.pre_att_rms_out);
|
||||
FillMatPtrT(test_env.activations->attention.q);
|
||||
FillMatPtrT(test_env.activations->attention.vit_Q);
|
||||
FillMatPtrT(test_env.activations->attention.vit_K);
|
||||
FillMatPtrT(test_env.activations->attention.att);
|
||||
FillMatPtrT(test_env.activations->attention.att_out);
|
||||
FillMatPtrT(test_env.activations->attention.softmax_max);
|
||||
FillMatPtrT(test_env.activations->attention.softmax_d);
|
||||
|
||||
int flags = AttentionImplToFlags(attention_impl, HWY_NATIVE_DOT_BF16);
|
||||
TiledAttention(attention_impl, num_tokens, layer_idx, *test_env.layer,
|
||||
test_env.activations->attention, *test_env.qbatch,
|
||||
test_env.env, flags);
|
||||
std::cerr << "att_out\n";
|
||||
PrintMatPtr(test_env.activations->attention.att_out);
|
||||
for (size_t i = 0; i < test_env.activations->attention.att_out.Rows(); ++i) {
|
||||
EXPECT_TRUE(hwy::CompareArraySimilar(
|
||||
AttentionMultipleTokensAttentionGoldens.data() +
|
||||
i * test_env.activations->attention.att_out.Cols(),
|
||||
test_env.activations->attention.att_out.Row(i),
|
||||
test_env.activations->attention.att_out.Cols(), 1e-1,
|
||||
hwy::TargetName(HWY_TARGET), __FILE__, __LINE__))
|
||||
<< "att_out mismatch for query: " << i;
|
||||
}
|
||||
}
|
||||
|
||||
// NOLINTNEXTLINE(google-readability-namespace-comments)
|
||||
} // namespace HWY_NAMESPACE
|
||||
} // namespace gcpp
|
||||
|
|
|
|||
|
|
@ -508,6 +508,11 @@ decltype(auto) CallUpcastedKVs(hwy::Span<const MatPtr> base, const Func& func,
|
|||
auto matptrs = make_matptr_vec(BF16{});
|
||||
hwy::Span<const MatPtrT<BF16>> matptrs_span(matptrs.data(), matptrs.size());
|
||||
return func(matptrs_span, std::forward<Args>(args)...);
|
||||
} else if (type == Type::kInt8) {
|
||||
auto matptrs = make_matptr_vec(int8_t{});
|
||||
hwy::Span<const MatPtrT<int8_t>> matptrs_span(matptrs.data(),
|
||||
matptrs.size());
|
||||
return func(matptrs_span, std::forward<Args>(args)...);
|
||||
} else {
|
||||
HWY_ABORT("Unhandled type %s.", TypeName(type));
|
||||
}
|
||||
|
|
|
|||
Loading…
Reference in New Issue