Add 8-bit integer quantization (I8Stream) to Gemma.cpp.

PiperOrigin-RevId: 819787856
Author: Phil Culliton, 2025-10-15 09:24:38 -07:00 (committed by Copybara-Service)
Parent: ee18916abf
Commit: 503aaddd65
25 changed files with 1428 additions and 64 deletions

View File

@ -349,6 +349,7 @@ cc_library(
"ops/matmul_static_f32.cc",
"ops/matmul_static_nuq.cc",
"ops/matmul_static_sfp.cc",
"ops/matmul_static_i8.cc",
],
hdrs = [
"ops/matmul_static.h",

View File

@ -68,6 +68,7 @@ set(SOURCES
compression/compress.h
compression/nuq-inl.h
compression/sfp-inl.h
compression/int-inl.h
compression/types.h
compression/test_util-inl.h
evals/benchmark_helper.cc
@ -109,6 +110,7 @@ set(SOURCES
ops/matmul_static_f32.cc
ops/matmul_static_nuq.cc
ops/matmul_static_sfp.cc
ops/matmul_static_i8.cc
ops/matmul-inl.h
ops/matmul.cc
ops/matmul.h

View File

@ -80,6 +80,37 @@ cc_library(
],
)
cc_library(
name = "int",
textual_hdrs = ["int-inl.h"],
deps = [
":types",
"//:basics",
"@highway//:hwy",
],
)
cc_test(
name = "int_test",
size = "small",
timeout = "long",
srcs = ["int_test.cc"],
features = ["fully_static_link"],
linkstatic = True,
local_defines = ["HWY_IS_TEST"],
# for test_suite.
tags = ["hwy_ops_test"],
deps = [
":distortion",
":int",
"@googletest//:gtest_main", # buildcleaner: keep
"//:test_util",
"@highway//:hwy",
"@highway//:hwy_test_util",
"@highway//:nanobenchmark",
],
)
cc_library(
name = "test_util",
textual_hdrs = [
@ -144,6 +175,7 @@ cc_library(
textual_hdrs = ["compress-inl.h"],
deps = [
":distortion",
":int",
":nuq",
":sfp",
"//:basics",
@ -182,6 +214,7 @@ cc_library(
name = "analyze",
textual_hdrs = ["analyze.h"],
deps = [
":int",
":nuq",
":sfp",
":types",

View File

@ -47,6 +47,7 @@
#include "hwy/highway.h"
// After highway.h
#include "compression/int-inl.h"
#include "compression/nuq-inl.h"
#include "compression/sfp-inl.h"
@ -416,6 +417,34 @@ struct CompressTraits<SfpStream> {
}
};
// Integer quantization.
template <>
struct CompressTraits<I8Stream> {
using Packed = I8Stream;
template <class DF, HWY_IF_F32_D(DF)>
static HWY_INLINE void Compress(DF df, const float* HWY_RESTRICT raw,
size_t num, CompressPerThread& tls,
const PackedSpan<Packed>& packed,
const size_t packed_ofs) {
IntCodec::Enc(df, raw, num, packed, packed_ofs);
}
template <class D> // Caller checks this is f32 or bf16
static HWY_INLINE void Load2(D d, const PackedSpan<const Packed>& packed,
const size_t packed_ofs, hn::Vec<D>& raw0,
hn::Vec<D>& raw1) {
IntCodec::Dec2(d, packed, packed_ofs, raw0, raw1);
}
template <class D, typename Raw>
static HWY_INLINE void DecompressAndZeroPad(
D d, const PackedSpan<const Packed>& packed, const size_t packed_ofs,
Raw* raw, const size_t num) {
IntCodec::DecompressAndZeroPad(d, packed, packed_ofs, raw, num);
}
};
// Nonuniform quantization, 4.5 bits per element, two separate streams.
template <>
struct CompressTraits<NuqStream> {
@ -737,9 +766,10 @@ template <class DF, typename T, typename T1, class Func>
HWY_INLINE void Decompress1AndCompressInplace(DF df, T* HWY_RESTRICT inout,
size_t num,
const T1* HWY_RESTRICT p1,
const size_t p1_ofs,
Func&& func) {
const auto packed_inout = MakeSpan(inout, num);
const auto packed1 = MakeSpan(p1, num);
const auto packed1 = MakeSpan(p1, p1_ofs + num);
using VF = hn::Vec<decltype(df)>;
HWY_LANES_CONSTEXPR const size_t NF = hn::Lanes(df);
@ -749,7 +779,7 @@ HWY_INLINE void Decompress1AndCompressInplace(DF df, T* HWY_RESTRICT inout,
VF v0, v1;
Decompress2(df, packed_inout, i, v0, v1);
VF v10, v11;
Decompress2(df, packed1, i, v10, v11);
Decompress2(df, packed1, p1_ofs + i, v10, v11);
const VF out0 = func(df, v0, v10);
const VF out1 = func(df, v1, v11);
Compress2(df, out0, out1, packed_inout, i);
@ -765,7 +795,7 @@ HWY_INLINE void Decompress1AndCompressInplace(DF df, T* HWY_RESTRICT inout,
hn::Store(hn::Zero(df), df, buf_inout + NF);
hn::Store(hn::Zero(df), df, buf1 + NF);
DecompressAndZeroPad(df, packed_inout, i, buf_inout, remaining);
DecompressAndZeroPad(df, packed1, i, buf1, remaining);
DecompressAndZeroPad(df, packed1, p1_ofs + i, buf1, remaining);
const VF v0 = hn::Load(df, buf_inout);
const VF v1 = hn::Load(df, buf_inout + NF);
const VF v10 = hn::Load(df, buf1);
@ -827,10 +857,10 @@ template <class DF, typename T, typename T1, typename T2, class Func>
HWY_INLINE void Decompress2AndCompressTo(DF df, T* HWY_RESTRICT out, size_t num,
const T1* HWY_RESTRICT p1,
const T2* HWY_RESTRICT p2,
Func&& func) {
const size_t p2_ofs, Func&& func) {
const auto packed_out = MakeSpan(out, num);
const auto packed1 = MakeSpan(p1, num);
const auto packed2 = MakeSpan(p2, num);
const auto packed2 = MakeSpan(p2, p2_ofs + num);
using VF = hn::Vec<decltype(df)>;
HWY_LANES_CONSTEXPR const size_t NF = hn::Lanes(df);
@ -839,7 +869,7 @@ HWY_INLINE void Decompress2AndCompressTo(DF df, T* HWY_RESTRICT out, size_t num,
for (; i <= num - 2 * NF; i += 2 * NF) {
VF v10, v11, v20, v21;
Decompress2(df, packed1, i, v10, v11);
Decompress2(df, packed2, i, v20, v21);
Decompress2(df, packed2, p2_ofs + i, v20, v21);
const VF out0 = func(df, v10, v20);
const VF out1 = func(df, v11, v21);
Compress2(df, out0, out1, packed_out, i);
@ -856,7 +886,7 @@ HWY_INLINE void Decompress2AndCompressTo(DF df, T* HWY_RESTRICT out, size_t num,
hn::Store(hn::Zero(df), df, buf1 + NF);
hn::Store(hn::Zero(df), df, buf2 + NF);
DecompressAndZeroPad(df, packed1, i, buf1, remaining);
DecompressAndZeroPad(df, packed2, i, buf2, remaining);
DecompressAndZeroPad(df, packed2, p2_ofs + i, buf2, remaining);
const VF v10 = hn::Load(df, buf1);
const VF v11 = hn::Load(df, buf1 + NF);
const VF v20 = hn::Load(df, buf2);
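
The new `p1_ofs` / `p2_ofs` parameters above let the second packed operand be addressed at an element offset instead of always starting at element 0, which group-based formats such as I8Stream need because they do not support pointer arithmetic. A rough scalar sketch of the updated semantics (assuming plain f32 operands and func = Add, as in the test file that follows; the real functions operate on packed types and SIMD vectors):

#include <stddef.h>

// Decompress1AndCompressInplace: inout[i] = func(inout[i], p1[p1_ofs + i]).
// Decompress2AndCompressTo:      out[i]   = func(p1[i], p2[p2_ofs + i]).
static void InplaceAddRef(float* inout, size_t num, const float* p1,
                          size_t p1_ofs) {
  for (size_t i = 0; i < num; ++i) {
    inout[i] += p1[p1_ofs + i];
  }
}
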

View File

@ -243,7 +243,7 @@ class TestDecompressAndCompress {
// Uses `out` so as not to overwrite `p`.
Decompress1AndCompressInplace(
df, out.get(), num, p1.get(),
df, out.get(), num, p1.get(), /*p1_ofs=*/0,
[](DF, VF v, VF v1) HWY_ATTR -> VF { return hn::Add(v, v1); });
HWY_ASSERT_ARRAY_EQ(expected2.get(), out.get(), num);
@ -251,9 +251,9 @@ class TestDecompressAndCompress {
[](DF, VF v) HWY_ATTR -> VF { return v; });
HWY_ASSERT_ARRAY_EQ(expected1.get(), out.get(), num);
Decompress2AndCompressTo(df, out.get(), num, p.get(), p1.get(),
[](DF, VF v, VF v1)
HWY_ATTR -> VF { return hn::Add(v, v1); });
Decompress2AndCompressTo(
df, out.get(), num, p.get(), p1.get(), /*p2_ofs=*/0,
[](DF, VF v, VF v1) HWY_ATTR -> VF { return hn::Add(v, v1); });
HWY_ASSERT_ARRAY_EQ(expected2.get(), out.get(), num);
Decompress3AndCompressTo(

compression/int-inl.h (new file, 474 lines)
View File

@ -0,0 +1,474 @@
// Copyright 2023 Google LLC
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// Normal include guard.
#ifndef THIRD_PARTY_GEMMA_CPP_COMPRESSION_INT_INL_H_
#define THIRD_PARTY_GEMMA_CPP_COMPRESSION_INT_INL_H_
#include <stddef.h>
#include <stdint.h>
#include <stdio.h>
#include <cstdint>
#include <cstdio>
#include "compression/types.h"
#include "util/basics.h"
#include "hwy/base.h"
#include "hwy/print-inl.h"
#endif // THIRD_PARTY_GEMMA_CPP_COMPRESSION_INT_INL_H_
// Actual per-target include guard.
#if defined(THIRD_PARTY_GEMMA_CPP_COMPRESSION_INT_INL_TOGGLE) == \
defined(HWY_TARGET_TOGGLE)
#ifdef THIRD_PARTY_GEMMA_CPP_COMPRESSION_INT_INL_TOGGLE
#undef THIRD_PARTY_GEMMA_CPP_COMPRESSION_INT_INL_TOGGLE
#else
#define THIRD_PARTY_GEMMA_CPP_COMPRESSION_INT_INL_TOGGLE
#endif
#include "hwy/highway.h"
HWY_BEFORE_NAMESPACE();
namespace gcpp {
namespace HWY_NAMESPACE {
namespace hn = hwy::HWY_NAMESPACE;
// Encode/decode functions.
class IntCodec {
using ScaleT = hwy::bfloat16_t;
static constexpr size_t kGroupSize = I8Stream::kGroupSize;
// Returns the byte offset of the start of the group that contains element
// index `packed_ofs`.
static constexpr size_t GroupByteOffset(size_t packed_ofs) {
const size_t kBytesPerGroup = (2 * sizeof(ScaleT)) + kGroupSize;
return (packed_ofs / kGroupSize) * kBytesPerGroup;
}
public:
template <class DBF, HWY_IF_BF16_D(DBF)>
static HWY_INLINE void DequantizeGroup(
DBF dbf, const PackedSpan<const I8Stream>& packed, size_t packed_ofs,
hwy::bfloat16_t* HWY_RESTRICT raw, size_t num) {
using T = ScaleT;
const hn::ScalableTag<float> df;
const hn::Rebind<int32_t, decltype(df)> di32;
const hn::Rebind<int16_t, decltype(di32)> di16;
const hn::Rebind<int8_t, decltype(di16)> di8;
const hn::Twice<hn::Rebind<hwy::bfloat16_t, decltype(df)>> dbf16;
const size_t N = hn::Lanes(di8);
const size_t N16 = hn::Lanes(dbf16);
using VI8 = hn::Vec<decltype(di8)>;
using VF = hn::Vec<decltype(df)>;
T inv_scale, zeropoint;
hwy::CopyBytes(&packed.ptr->i + GroupByteOffset(packed_ofs), &inv_scale,
sizeof(T));
hwy::CopyBytes(&packed.ptr->i + GroupByteOffset(packed_ofs) + sizeof(T),
&zeropoint, sizeof(T));
float inv_scale_f = hwy::ConvertScalarTo<float>(inv_scale);
float zeropoint_f = hwy::ConvertScalarTo<float>(zeropoint);
VF inv_scale_vec = hn::Set(df, inv_scale_f);
VF zeroscale_vec = hn::Set(df, -zeropoint_f * (inv_scale_f));
// Then iterate over remainder of packed, extracting num / N vectors and
// inserting into raw.
const size_t g_num = HWY_MIN(num, kGroupSize);
const size_t current_offset = GroupByteOffset(packed_ofs) +
(2 * sizeof(T)) + (packed_ofs % kGroupSize);
size_t i = 0;
for (i = 0; i + 4 * N <= g_num; i += 4 * N) {
const VI8 val0 =
hn::LoadU(di8, &packed.ptr->i + current_offset + i + 0 * N);
const VI8 val1 =
hn::LoadU(di8, &packed.ptr->i + current_offset + i + 1 * N);
const VI8 val2 =
hn::LoadU(di8, &packed.ptr->i + current_offset + i + 2 * N);
const VI8 val3 =
hn::LoadU(di8, &packed.ptr->i + current_offset + i + 3 * N);
const VF val0_f =
hn::ConvertTo(df, hn::PromoteTo(di32, hn::PromoteTo(di16, val0)));
const VF val1_f =
hn::ConvertTo(df, hn::PromoteTo(di32, hn::PromoteTo(di16, val1)));
const VF val2_f =
hn::ConvertTo(df, hn::PromoteTo(di32, hn::PromoteTo(di16, val2)));
const VF val3_f =
hn::ConvertTo(df, hn::PromoteTo(di32, hn::PromoteTo(di16, val3)));
VF dequantized_val0 = hn::MulAdd(inv_scale_vec, val0_f, zeroscale_vec);
VF dequantized_val1 = hn::MulAdd(inv_scale_vec, val1_f, zeroscale_vec);
VF dequantized_val2 = hn::MulAdd(inv_scale_vec, val2_f, zeroscale_vec);
VF dequantized_val3 = hn::MulAdd(inv_scale_vec, val3_f, zeroscale_vec);
hn::StoreU(
hn::OrderedDemote2To(dbf16, dequantized_val0, dequantized_val1),
dbf16, raw + i + 0 * N16);
hn::StoreU(
hn::OrderedDemote2To(dbf16, dequantized_val2, dequantized_val3),
dbf16, raw + i + 1 * N16);
}
for (; i + N <= g_num; i += N) {
const VI8 val0 = hn::LoadU(di8, &packed.ptr->i + current_offset + i);
const VF val0_f =
hn::ConvertTo(df, hn::PromoteTo(di32, hn::PromoteTo(di16, val0)));
VF dequantized_val0 = hn::MulAdd(inv_scale_vec, val0_f, zeroscale_vec);
const hn::Rebind<hwy::bfloat16_t, decltype(df)> dbf_half;
hn::StoreU(hn::DemoteTo(dbf_half, dequantized_val0), dbf_half, raw + i);
}
if (i < g_num) {
const size_t remaining = g_num - i;
const VI8 val0 =
hn::LoadN(di8, &packed.ptr->i + current_offset + i, remaining);
const VF val0_f =
hn::ConvertTo(df, hn::PromoteTo(di32, hn::PromoteTo(di16, val0)));
VF dequantized_val0 = hn::MulAdd(inv_scale_vec, val0_f, zeroscale_vec);
const hn::Rebind<hwy::bfloat16_t, decltype(df)> dbf_half;
hn::StoreN(hn::DemoteTo(dbf_half, dequantized_val0), dbf_half, raw + i,
remaining);
}
}
// Dequantizes `num` floats from `packed` into `raw`. `packed` points to
// compressed storage and `packed_ofs` indicates the source offset within it,
// in number of elements. Scaling values are interleaved with the int values
// to allow for easier unpacking.
template <class DF, HWY_IF_F32_D(DF)>
static HWY_INLINE void DequantizeGroup(
DF df, const PackedSpan<const I8Stream>& packed, size_t packed_ofs,
float* HWY_RESTRICT raw, size_t num) {
using T = ScaleT;
const hn::Rebind<int32_t, decltype(df)> di32;
const hn::Rebind<int16_t, decltype(di32)> di16;
const hn::Rebind<int8_t, decltype(di16)> di8;
const hn::Rebind<int8_t, decltype(df)> df8;
const size_t N = hn::Lanes(di8);
const size_t N32 = hn::Lanes(df);
using VI8 = hn::Vec<decltype(di8)>;
using VF = hn::Vec<decltype(df)>;
// HWY_ASSERT(num % 2 * N == 0);
// Load scale and zero point from the beginning - ensure correct pointer
// offset.
T inv_scale, zeropoint;
hwy::CopyBytes(&packed.ptr->i + GroupByteOffset(packed_ofs), &inv_scale,
sizeof(T));
hwy::CopyBytes(&packed.ptr->i + GroupByteOffset(packed_ofs) + sizeof(T),
&zeropoint, sizeof(T));
float inv_scale_f = hwy::ConvertScalarTo<float>(inv_scale);
float zeropoint_f = hwy::ConvertScalarTo<float>(zeropoint);
VF inv_scale_vec = hn::Set(df, inv_scale_f);
VF zeroscale_vec = hn::Set(df, -zeropoint_f * (inv_scale_f));
// Then iterate over remainder of packed, extracting num / N vectors and
// inserting into raw.
const size_t g_num = HWY_MIN(num, kGroupSize);
const size_t current_offset = GroupByteOffset(packed_ofs) +
(2 * sizeof(T)) + (packed_ofs % kGroupSize);
size_t i = 0;
for (; i + 2 * N <= g_num; i += 2 * N) {
const VI8 val0 =
hn::LoadU(di8, &packed.ptr->i + current_offset + i + 0 * N);
const VI8 val1 =
hn::LoadU(di8, &packed.ptr->i + current_offset + i + 1 * N);
const VF val0_f =
hn::ConvertTo(df, hn::PromoteTo(di32, hn::PromoteTo(di16, val0)));
const VF val1_f =
hn::ConvertTo(df, hn::PromoteTo(di32, hn::PromoteTo(di16, val1)));
VF dequantized_val0 = hn::MulAdd(inv_scale_vec, val0_f, zeroscale_vec);
VF dequantized_val1 = hn::MulAdd(inv_scale_vec, val1_f, zeroscale_vec);
hn::StoreU(dequantized_val0, df, raw + i + 0 * N32);
hn::StoreU(dequantized_val1, df, raw + i + 1 * N32);
}
for (; i + N <= g_num; i += N) {
const VI8 val0 = hn::LoadU(di8, &packed.ptr->i + current_offset + i);
const VF val0_f =
hn::ConvertTo(df, hn::PromoteTo(di32, hn::PromoteTo(di16, val0)));
VF dequantized_val0 = hn::MulAdd(inv_scale_vec, val0_f, zeroscale_vec);
hn::StoreU(dequantized_val0, df, raw + i);
}
if (i < g_num) {
const size_t remaining = g_num - i;
const VI8 val0 =
hn::LoadN(di8, &packed.ptr->i + current_offset + i, remaining);
const VF val0_f =
hn::ConvertTo(df, hn::PromoteTo(di32, hn::PromoteTo(di16, val0)));
VF dequantized_val0 = hn::MulAdd(inv_scale_vec, val0_f, zeroscale_vec);
hn::StoreN(dequantized_val0, df, raw + i, remaining);
}
}
// Quantizes `num` floats from `raw` into `packed`. `packed` points to
// compressed storage and `packed_ofs` indicates the destination offset
// within it, in number of elements. Scaling values are interleaved with
// int values to allow for easier unpacking.
template <class DF, HWY_IF_F32_D(DF)>
static HWY_INLINE void QuantizeGroup(DF df, const float* HWY_RESTRICT raw,
size_t num,
const PackedSpan<I8Stream>& packed,
size_t packed_ofs) {
using T = ScaleT;
const hn::Repartition<int32_t, DF> di32;
const hn::Half<hn::Repartition<int16_t, decltype(di32)>> di16;
const hn::Half<hn::Repartition<int8_t, decltype(di16)>> di8;
const size_t N = hn::Lanes(df);
using VI8 = hn::Vec<decltype(di8)>;
using VF = hn::Vec<decltype(df)>;
HWY_DASSERT(packed_ofs % kGroupSize == 0);
HWY_DASSERT(num % 2 * N == 0);
// Calculate min/max using SIMD
float min_val = hwy::HighestValue<float>();
float max_val = hwy::LowestValue<float>();
VF vmin = hn::Set(df, hwy::HighestValue<float>());
VF vmax = hn::Set(df, hwy::LowestValue<float>());
size_t j = 0;
for (; j + N <= num; j += N) {
const VF xi = hn::LoadU(df, raw + j);
vmin = hn::Min(vmin, xi);
vmax = hn::Max(vmax, xi);
}
min_val = hn::ReduceMin(df, vmin);
max_val = hn::ReduceMax(df, vmax);
for (; j < num; ++j) {
min_val = HWY_MIN(min_val, raw[j]);
max_val = HWY_MAX(max_val, raw[j]);
}
// Calculate range, scale and zeropoint
float x_range = max_val - min_val;
x_range = x_range == 0.0f ? 1.0f : x_range;
const float scale_f = 255.0f / x_range;
const float zeropoint_f = static_cast<float>(
static_cast<int32_t>(-scale_f * min_val - 128.0f)); // Correct casting
const T scale = hwy::ConvertScalarTo<T>(scale_f);
// inv_scale is used for all dequantization.
const T inv_scale = hwy::ConvertScalarTo<T>(1.0f / scale_f);
const T zeropoint = hwy::ConvertScalarTo<T>(zeropoint_f);
memcpy(&packed.ptr->i + GroupByteOffset(packed_ofs), &inv_scale, sizeof(T));
memcpy(&packed.ptr->i + GroupByteOffset(packed_ofs) + sizeof(T), &zeropoint,
sizeof(T));
const size_t g_num = HWY_MIN(num, kGroupSize);
VF mul = hn::Set(df, hwy::ConvertScalarTo<float>(scale));
VF add = hn::Set(df, hwy::ConvertScalarTo<float>(zeropoint));
const size_t current_offset = GroupByteOffset(packed_ofs) +
(2 * sizeof(T)) + (packed_ofs % kGroupSize);
size_t i = 0;
for (; i + 2 * N <= g_num; i += 2 * N) {
const VI8 val0 = hn::DemoteTo(
di8,
hn::DemoteTo(di16, NearestInt(hn::MulAdd(
mul, hn::LoadU(df, raw + i + 0 * N), add))));
const VI8 val1 = hn::DemoteTo(
di8,
hn::DemoteTo(di16, NearestInt(hn::MulAdd(
mul, hn::LoadU(df, raw + i + 1 * N), add))));
hn::StoreU(val0, di8, &packed.ptr->i + current_offset + i + 0 * N);
hn::StoreU(val1, di8, &packed.ptr->i + current_offset + i + 1 * N);
}
size_t remaining = g_num - i;
HWY_DASSERT(remaining < 2 * N);
if (HWY_UNLIKELY(remaining == 0)) return;
if (remaining > N) {
const VI8 val0 = hn::DemoteTo(
di8, hn::DemoteTo(di16, NearestInt(hn::MulAdd(
mul, hn::LoadU(df, raw + i), add))));
hn::StoreU(val0, di8, &packed.ptr->i + current_offset + i);
const size_t remaining1 = remaining - N;
const VI8 val1 = hn::DemoteTo(
di8,
hn::DemoteTo(di16,
NearestInt(hn::MulAdd(
mul, hn::LoadN(df, raw + i + N, remaining1), add))));
hn::StoreN(val1, di8, &packed.ptr->i + current_offset + i + N,
remaining1);
} else { // remaining <= N
const VI8 val0 = hn::DemoteTo(
di8, hn::DemoteTo(di16,
NearestInt(hn::MulAdd(
mul, hn::LoadN(df, raw + i, remaining), add))));
hn::StoreN(val0, di8, &packed.ptr->i + current_offset + i, remaining);
}
}
// Encodes `num` floats from `raw` into `packed`. `packed` points to
// compressed storage and `packed_ofs` indicates the destination offset
// within it, in number of elements. Scaling values are interleaved with
// int values to allow for easier unpacking.
template <class DF, HWY_IF_F32_D(DF)>
static HWY_INLINE void Enc(DF df, const float* HWY_RESTRICT raw,
const size_t num,
const PackedSpan<I8Stream>& packed,
size_t packed_ofs) {
HWY_ASSERT(packed_ofs % kGroupSize == 0);
const size_t num_groups = hwy::DivCeil(num, kGroupSize);
size_t current_offset = packed_ofs;
for (size_t g = 0; g < num_groups; ++g) {
const size_t g_num = HWY_MIN(num - g * kGroupSize, kGroupSize);
const float* HWY_RESTRICT g_in = raw + g * kGroupSize;
QuantizeGroup(df, g_in, g_num, packed, current_offset);
current_offset += g_num;
}
}
// Decompresses to two bf16 vectors. `packed_ofs` must be a multiple of two
// vectors so that we only have to load one group's scale and zero-point.
template <class DBF, HWY_IF_BF16_D(DBF)>
static HWY_INLINE void Dec2(DBF dbf, const PackedSpan<const I8Stream>& packed,
const size_t packed_ofs, hn::Vec<DBF>& raw0,
hn::Vec<DBF>& raw1) {
const hn::Repartition<float, decltype(dbf)> df;
using VF = hn::Vec<decltype(df)>;
const size_t NF = hn::Lanes(df);
HWY_ASSERT(packed_ofs % (2 * NF) == 0);
VF raw0_f, raw1_f, raw2_f, raw3_f;
Dec2(df, packed, packed_ofs + 0 * 2 * NF, raw0_f, raw1_f);
Dec2(df, packed, packed_ofs + 1 * 2 * NF, raw2_f, raw3_f);
raw0 = hn::OrderedDemote2To(dbf, raw0_f, raw1_f);
raw1 = hn::OrderedDemote2To(dbf, raw2_f, raw3_f);
}
// Decompresses to two f32 vectors. `packed_ofs` must be a multiple of two
// vectors so that we only have to load one group's scale and zero-point.
template <class DF, HWY_IF_F32_D(DF)>
static HWY_INLINE void Dec2(DF df, const PackedSpan<const I8Stream>& packed,
const size_t packed_ofs, hn::Vec<DF>& raw0,
hn::Vec<DF>& raw1) {
using T = ScaleT;
const hn::Rebind<int32_t, decltype(df)> di32;
const hn::Rebind<int16_t, decltype(di32)> di16;
const hn::Rebind<int8_t, decltype(di16)> di8;
const hn::Rebind<int8_t, decltype(df)> df8;
const size_t N = hn::Lanes(di8);
using VI8 = hn::Vec<decltype(di8)>;
using VF = hn::Vec<decltype(df)>;
T inv_scale, zeropoint;
hwy::CopyBytes(&packed.ptr->i + GroupByteOffset(packed_ofs), &inv_scale,
sizeof(T));
hwy::CopyBytes(&packed.ptr->i + GroupByteOffset(packed_ofs) + sizeof(T),
&zeropoint, sizeof(T));
float inv_scale_f = hwy::ConvertScalarTo<float>(inv_scale);
float zeropoint_f = hwy::ConvertScalarTo<float>(zeropoint);
VF inv_scale_vec = hn::Set(df, inv_scale_f);
VF zeroscale_vec = hn::Set(df, -zeropoint_f * (inv_scale_f));
const size_t current_offset = GroupByteOffset(packed_ofs) +
(2 * sizeof(T)) + (packed_ofs % kGroupSize);
const VI8 val0 = hn::LoadU(di8, &packed.ptr->i + current_offset + 0 * N);
const VI8 val1 = hn::LoadU(di8, &packed.ptr->i + current_offset + 1 * N);
const VF val0_f =
hn::ConvertTo(df, hn::PromoteTo(di32, hn::PromoteTo(di16, val0)));
const VF val1_f =
hn::ConvertTo(df, hn::PromoteTo(di32, hn::PromoteTo(di16, val1)));
raw0 = hn::MulAdd(inv_scale_vec, val0_f, zeroscale_vec);
raw1 = hn::MulAdd(inv_scale_vec, val1_f, zeroscale_vec);
}
template <class D, typename Raw = hn::TFromD<D>>
static HWY_INLINE void DecompressAndZeroPad(
D d, const PackedSpan<const I8Stream>& packed, size_t packed_ofs,
Raw* HWY_RESTRICT raw, size_t num) {
if (num == 0) return;
const size_t N = hn::Lanes(d);
const size_t padded_num = hwy::RoundUpTo(num, N);
if (padded_num > num) {
hwy::ZeroBytes(raw + num, (padded_num - num) * sizeof(Raw));
}
size_t current_packed_ofs = packed_ofs;
Raw* HWY_RESTRICT current_raw = raw;
size_t num_to_decompress = num;
if (size_t within_group = current_packed_ofs % kGroupSize;
within_group != 0) {
const size_t remaining_in_group = kGroupSize - within_group;
const size_t num_in_first_group =
HWY_MIN(num_to_decompress, remaining_in_group);
DequantizeGroup(d, packed, current_packed_ofs, current_raw,
num_in_first_group);
current_packed_ofs += num_in_first_group;
current_raw += num_in_first_group;
num_to_decompress -= num_in_first_group;
}
if (num_to_decompress == 0) return;
HWY_DASSERT(current_packed_ofs % kGroupSize == 0);
const size_t num_full_groups = num_to_decompress / kGroupSize;
for (size_t g = 0; g < num_full_groups; ++g) {
DequantizeGroup(d, packed, current_packed_ofs, current_raw, kGroupSize);
current_packed_ofs += kGroupSize;
current_raw += kGroupSize;
}
const size_t remaining = num_to_decompress % kGroupSize;
if (remaining != 0) {
DequantizeGroup(d, packed, current_packed_ofs, current_raw, remaining);
}
}
}; // IntCodec
// NOLINTNEXTLINE(google-readability-namespace-comments)
} // namespace HWY_NAMESPACE
} // namespace gcpp
HWY_AFTER_NAMESPACE();
#endif // THIRD_PARTY_GEMMA_CPP_COMPRESSION_INT_INL_H_
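
For orientation, the per-group affine scheme implemented by IntCodec reduces to the following scalar math. This is a minimal sketch under two simplifications: the scale and zero-point stay in f32 (the codec stores them as bf16, which loses precision and is why the test tolerance below is loose), and saturation is an explicit clamp (the SIMD DemoteTo saturates implicitly).

#include <stddef.h>
#include <stdint.h>
#include <algorithm>
#include <cmath>

// Quantize one group of `num` floats (num <= 128), mirroring QuantizeGroup.
void QuantizeGroupRef(const float* in, size_t num, int8_t* q,
                      float* inv_scale, float* zeropoint) {
  float lo = in[0], hi = in[0];
  for (size_t i = 1; i < num; ++i) {
    lo = std::min(lo, in[i]);
    hi = std::max(hi, in[i]);
  }
  const float range = (hi == lo) ? 1.0f : (hi - lo);
  const float scale = 255.0f / range;                 // spread the range over 255 steps
  const float zp = std::trunc(-scale * lo - 128.0f);  // integral zero-point
  for (size_t i = 0; i < num; ++i) {
    const float v = std::round(scale * in[i] + zp);
    q[i] = static_cast<int8_t>(std::max(-128.0f, std::min(127.0f, v)));
  }
  *inv_scale = 1.0f / scale;
  *zeropoint = zp;
}

// Dequantize, mirroring Dec2: x_hat = inv_scale * q - zeropoint * inv_scale.
float DequantizeRef(int8_t q, float inv_scale, float zeropoint) {
  return inv_scale * static_cast<float>(q) - zeropoint * inv_scale;
}
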

compression/int_test.cc (new file, 494 lines)
View File

@ -0,0 +1,494 @@
// Copyright 2023 Google LLC
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// Skip SVE and scalar targets for faster tests.
#ifndef HWY_DISABLED_TARGETS
#define HWY_DISABLED_TARGETS (HWY_SCALAR | HWY_SVE)
#endif
#include <stddef.h>
#include <stdint.h>
#include <stdio.h>

#include <algorithm>  // std::fill
#include <array>      // std::array
#include "util/test_util.h"
#include "hwy/aligned_allocator.h"
#include "hwy/base.h"
#include "hwy/tests/hwy_gtest.h"
#include "hwy/tests/test_util.h"
// clang-format off
#undef HWY_TARGET_INCLUDE
#define HWY_TARGET_INCLUDE "compression/int_test.cc" // NOLINT
// clang-format on
#include "hwy/foreach_target.h" // IWYU pragma: keep
#include "hwy/highway.h"
// After highway.h
#include "compression/int-inl.h"
#include "hwy/tests/test_util-inl.h"
HWY_BEFORE_NAMESPACE();
namespace gcpp {
namespace HWY_NAMESPACE {
namespace hn = hwy::HWY_NAMESPACE;
static constexpr size_t kGroupSize = I8Stream::kGroupSize;
static constexpr float kTolerance = 50000.0f;
// Quantizes and de-quantizes a single (potentially partial) group to check
// that the quantizer is working correctly.
struct TestQuantize {
template <typename T, class D>
HWY_INLINE void operator()(T /*unused*/, D d) {
const size_t total = kGroupSize / 2; // already padded
const hn::ScalableTag<float> df;
auto in = hwy::AllocateAligned<float>(total); // Enc() requires f32
auto dec1 = hwy::AllocateAligned<T>(total);
auto dec2 = hwy::AllocateAligned<T>(total);
auto dec3 = hwy::AllocateAligned<T>(total);
auto i8_stream = hwy::AllocateAligned<I8Stream>(I8Stream::PackedEnd(total));
HWY_ASSERT(in && dec1 && dec2 && dec3 && i8_stream);
const auto int_span = MakeSpan(i8_stream.get(), total);
hwy::RandomState rng;
for (size_t i = 0; i < total; ++i) {
in[i] = static_cast<float>(RandomGaussian(rng));
}
IntCodec::QuantizeGroup(df, in.get(), total, int_span, 0);
IntCodec::DequantizeGroup(d, MakeConst(int_span), 0, dec1.get(), total);
const float epsilon =
hwy::ConvertScalarTo<float>(hwy::Epsilon<hwy::bfloat16_t>());
const float tolerance = kTolerance * epsilon;
for (size_t i = 0; i < total; ++i) {
const float expected_value = static_cast<float>(in[i]);
const float actual_value = hwy::ConvertScalarTo<float>(dec1[i]);
if (!(expected_value - tolerance <= actual_value &&
actual_value <= expected_value + tolerance)) {
fprintf(stderr,
"in[%zu] = %f, dec1[%zu] = %f, tolerance = %f, epsilon = %f\n",
i, expected_value, i, actual_value, tolerance, epsilon);
}
}
// Check that ::Enc works correctly as well.
IntCodec::Enc(df, in.get(), total, int_span, 0);
IntCodec::DequantizeGroup(d, MakeConst(int_span), 0, dec2.get(), total);
for (size_t i = 0; i < total; ++i) {
const float expected_value = static_cast<float>(in[i]);
const float actual_value = hwy::ConvertScalarTo<float>(dec2[i]);
if (!(expected_value - tolerance <= actual_value &&
actual_value <= expected_value + tolerance)) {
fprintf(stderr,
"in[%zu] = %f, dec2[%zu] = %f, tolerance = %f, epsilon = %f\n",
i, expected_value, i, actual_value, tolerance, epsilon);
}
}
// Check that ::DecompressAndZeroPad works correctly for one group as well.
IntCodec::Enc(df, in.get(), total, int_span, 0);
IntCodec::DecompressAndZeroPad(d, MakeConst(int_span), 0, dec3.get(),
total);
for (size_t i = 0; i < total; ++i) {
const float expected_value = static_cast<float>(in[i]);
const float actual_value = hwy::ConvertScalarTo<float>(dec3[i]);
if (!(expected_value - tolerance <= actual_value &&
actual_value <= expected_value + tolerance)) {
fprintf(stderr,
"in[%zu] = %f, dec3[%zu] = %f, tolerance = %f, epsilon = %f\n",
i, expected_value, i, actual_value, tolerance, epsilon);
HWY_ASSERT(false);
}
}
}
};
void TestQuantizeBF16() { hn::ForGEVectors<128, TestQuantize>()(BF16()); }
void TestQuantizeF32() { hn::ForGEVectors<128, TestQuantize>()(float()); }
// Quantizes and de-quantizes multiple (potentially partial) groups to check
// that DecompressAndZeroPad is working correctly.
struct TestMultiGroup {
template <typename T, class D>
HWY_INLINE void operator()(T /*unused*/, D d) {
const hn::Repartition<float, D> df;
const size_t total = kGroupSize * 2 + kGroupSize / 4; // already padded
auto in = hwy::AllocateAligned<float>(total); // Enc() requires f32
auto dec1 = hwy::AllocateAligned<T>(total);
auto dec2 = hwy::AllocateAligned<T>(total);
auto i8_stream = hwy::AllocateAligned<I8Stream>(I8Stream::PackedEnd(total));
HWY_ASSERT(in && dec1 && dec2 && i8_stream);
const auto int_span = MakeSpan(i8_stream.get(), total);
hwy::RandomState rng;
for (size_t i = 0; i < total; ++i) {
in[i] = static_cast<float>(RandomGaussian(rng));
}
const float epsilon =
hwy::ConvertScalarTo<float>(hwy::Epsilon<hwy::bfloat16_t>());
const float tolerance = kTolerance * epsilon;
// Check that ::DecompressAndZeroPad works correctly for one group as well.
IntCodec::Enc(df, in.get(), total, int_span, 0);
IntCodec::DecompressAndZeroPad(d, MakeConst(int_span), 0, dec2.get(),
total);
for (size_t i = 0; i < total; ++i) {
const float expected_value = static_cast<float>(in[i]);
const float actual_value = hwy::ConvertScalarTo<float>(dec2[i]);
if (!(expected_value - tolerance <= actual_value &&
actual_value <= expected_value + tolerance)) {
fprintf(stderr,
"in[%zu] = %f, dec2[%zu] = %f, tolerance = %f, epsilon = %f\n",
i, expected_value, i, actual_value, tolerance, epsilon);
HWY_ASSERT(false);
}
}
}
};
void TestMultiGroupBF16() { hn::ForGEVectors<128, TestMultiGroup>()(BF16()); }
void TestMultiGroupF32() { hn::ForGEVectors<128, TestMultiGroup>()(float()); }
// Can encode and decode sub-regions.
struct TestOffset {
template <typename T, class D>
HWY_INLINE void operator()(T /*unused*/, D d) {
const hn::Repartition<float, D> df;
const size_t total = 10 * kGroupSize; // already padded
const size_t kMidLen = 2 * kGroupSize; // length of middle piece
auto in = hwy::AllocateAligned<float>(total); // Enc() requires f32
auto dec1 = hwy::AllocateAligned<T>(total);
auto dec2 = hwy::AllocateAligned<T>(kMidLen);
auto i8_stream = hwy::AllocateAligned<I8Stream>(I8Stream::PackedEnd(total));
HWY_ASSERT(in && dec1 && dec2 && i8_stream);
const auto int_span = MakeSpan(i8_stream.get(), total);
hwy::RandomState rng;
for (size_t i = 0; i < total; ++i) {
in[i] = static_cast<float>(RandomGaussian(rng));
}
// Encode + decode everything
(void)IntCodec::Enc(df, in.get(), total, int_span, 0);
IntCodec::DecompressAndZeroPad(d, MakeConst(int_span), 0, dec1.get(),
total);
MaybeCheckInitialized(dec1.get(), total * sizeof(T));
// Overwrite middle with first inputs
const size_t offset = 5 * kGroupSize;
(void)IntCodec::Enc(df, in.get(), kMidLen, int_span, offset);
// Decoded middle now matches previously decoded first
IntCodec::DecompressAndZeroPad(d, MakeConst(int_span), offset, dec2.get(),
kMidLen);
MaybeCheckInitialized(dec2.get(), kMidLen * sizeof(T));
for (size_t i = 0; i < kMidLen; ++i) {
HWY_ASSERT(dec1[i] == dec2[i]);
}
}
};
void TestOffsetBF16() { hn::ForGEVectors<128, TestOffset>()(BF16()); }
void TestOffsetF32() { hn::ForGEVectors<128, TestOffset>()(float()); }
// Can encode and decode sub-regions.
struct TestUnalignedOffset {
template <typename T, class D>
HWY_INLINE void operator()(T /*unused*/, D d) {
const hn::Repartition<float, D> df;
const size_t total = 10 * kGroupSize; // already padded
const int num_unaligned_offsets = 4;
const std::array<size_t, num_unaligned_offsets> unaligned_offsets = {
4, kGroupSize + 100, 2 * kGroupSize + 100, 3 * kGroupSize + 100};
const std::array<size_t, num_unaligned_offsets> num = {4, 16, 32, 64};
for (int i = 0; i < num_unaligned_offsets; ++i) {
const size_t unaligned_offset = unaligned_offsets[i];
const size_t num_decompressed = num[i];
auto in = hwy::AllocateAligned<float>(total); // Enc() requires f32
auto dec1 = hwy::AllocateAligned<T>(total);
auto i8_stream =
hwy::AllocateAligned<I8Stream>(I8Stream::PackedEnd(total));
auto dec2 = hwy::AllocateAligned<T>(num_decompressed);
HWY_ASSERT(in && dec1 && dec2 && i8_stream);
const auto int_span = MakeSpan(i8_stream.get(), total);
hwy::RandomState rng;
for (size_t i = 0; i < total; ++i) {
in[i] = static_cast<float>(RandomGaussian(rng));
}
// Encode + decode everything
(void)IntCodec::Enc(df, in.get(), total, int_span, 0);
IntCodec::DecompressAndZeroPad(d, MakeConst(int_span), 0, dec1.get(),
total);
IntCodec::DecompressAndZeroPad(d, MakeConst(int_span), unaligned_offset,
dec2.get(), num_decompressed);
for (size_t i = 0; i < num_decompressed; ++i) {
T expected = hwy::ConvertScalarTo<T>(dec1[unaligned_offset + i]);
T actual = hwy::ConvertScalarTo<T>(dec2[i]);
HWY_ASSERT_EQ(expected, actual);
}
}
}
};
void TestUnalignedOffsetBF16() {
hn::ForGEVectors<128, TestUnalignedOffset>()(BF16());
}
void TestUnalignedOffsetF32() {
hn::ForGEVectors<128, TestUnalignedOffset>()(float());
}
// Can encode and decode sub-regions.
// Uses Dec2 to decode all elements in the packed buffer, then
// compares against DecompressAndZeroPad.
struct TestDec2 {
template <typename T, class D>
HWY_INLINE void operator()(T /*unused*/, D d) {
const hn::Repartition<float, D> df;
// incl. partial group to test partial group handling
const size_t total = kGroupSize * 10 + kGroupSize / 2;
auto in = hwy::AllocateAligned<float>(total); // Enc() requires f32
auto dec0 = hwy::AllocateAligned<T>(total);
auto dec1 = hwy::AllocateAligned<T>(total);
auto i8_stream = hwy::AllocateAligned<I8Stream>(I8Stream::PackedEnd(total));
HWY_ASSERT(in && dec0 && dec1 && i8_stream);
const auto int_span = MakeSpan(i8_stream.get(), total);
hwy::RandomState rng;
for (size_t i = 0; i < total; ++i) {
in[i] = static_cast<float>(RandomGaussian(rng));
}
// Non-interleaved encode + decode for comparison
(void)IntCodec::Enc(df, in.get(), total, int_span, 0);
IntCodec::DecompressAndZeroPad(d, MakeConst(int_span), 0, dec0.get(),
total);
// Encode + decode everything
(void)IntCodec::Enc(df, in.get(), total, int_span, 0);
using V = hn::Vec<decltype(d)>;
const size_t N = Lanes(d);
for (size_t i = 0; i < total; i += 2 * N) {
V f0, f1;
IntCodec::Dec2(d, MakeConst(int_span), i, f0, f1);
hn::StoreU(f0, d, dec1.get() + i + 0 * N);
hn::StoreU(f1, d, dec1.get() + i + 1 * N);
}
for (size_t i = 0; i < total; ++i) {
if (dec0[i] != dec1[i]) {
fprintf(stderr, "dec0[%zu] = %g, dec1[%zu] = %g\n", i,
hwy::ConvertScalarTo<float>(dec0[i]), i,
hwy::ConvertScalarTo<float>(dec1[i]));
}
HWY_ASSERT(dec0[i] == dec1[i]);
}
}
};
void TestDec2BF16() { hn::ForGEVectors<128, TestDec2>()(BF16()); }
void TestDec2F32() { hn::ForGEVectors<128, TestDec2>()(float()); }
// Tests that DecompressAndZeroPad fully populates the output array.
// This is intended to catch uninitialized value errors.
struct TestDequantizeAndZeroPad {
template <typename T, class D>
HWY_INLINE void operator()(T /*unused*/, D d) {
const hn::ScalableTag<float> df;
constexpr size_t kSize = 4096;
auto in = hwy::AllocateAligned<float>(kSize);
auto actual_dec = hwy::AllocateAligned<T>(kSize);
auto i8_stream = hwy::AllocateAligned<I8Stream>(I8Stream::PackedEnd(kSize));
HWY_ASSERT(in && actual_dec && i8_stream);
const auto int_span = MakeSpan(i8_stream.get(), kSize);
// Fill with a known pattern.
for (size_t i = 0; i < kSize; ++i) {
in[i] = static_cast<float>(i) - 128.0f;
}
IntCodec::Enc(df, in.get(), kSize, int_span, 0);
// Initialize with a sentinel value to detect if it's overwritten.
const T sentinel = hwy::ConvertScalarTo<T>(-999.0f);
std::fill(actual_dec.get(), actual_dec.get() + kSize, sentinel);
IntCodec::DecompressAndZeroPad(d, MakeConst(int_span), 0, actual_dec.get(),
kSize);
MaybeCheckInitialized(actual_dec.get(), kSize * sizeof(T));
// Check that all sentinels were overwritten.
for (size_t i = 0; i < kSize; ++i) {
EXPECT_NE(hwy::ConvertScalarTo<float>(actual_dec[i]),
hwy::ConvertScalarTo<float>(sentinel))
<< " at index " << i;
}
}
};
void TestAllDequantizeAndZeroPad() {
hn::ForGEVectors<128, TestDequantizeAndZeroPad>()(BF16());
hn::ForGEVectors<128, TestDequantizeAndZeroPad>()(float());
}
// Tests that DecompressAndZeroPad works correctly for small and unaligned
// inputs. This is intended to catch uninitialized value errors in remainder
// handling.
struct TestSmallDequantize {
template <typename T, class D>
HWY_INLINE void operator()(T /*unused*/, D d) {
const hn::ScalableTag<float> df;
constexpr size_t kGroupSize = I8Stream::kGroupSize;
constexpr size_t kMaxNum = kGroupSize * 3;
auto in = hwy::AllocateAligned<float>(kMaxNum);
auto actual_dec = hwy::AllocateAligned<T>(kMaxNum);
auto i8_stream =
hwy::AllocateAligned<I8Stream>(I8Stream::PackedEnd(kMaxNum));
HWY_ASSERT(in && actual_dec && i8_stream);
const auto int_span =
MakeSpan(i8_stream.get(), I8Stream::PackedEnd(kMaxNum));
// Fill with a known pattern.
for (size_t i = 0; i < kMaxNum; ++i) {
in[i] = static_cast<float>(i) - 128.0f;
}
IntCodec::Enc(df, in.get(), kMaxNum, int_span, 0);
for (size_t num = 1; num < kGroupSize * 2; ++num) {
for (size_t offset = 0; offset < kGroupSize; offset += 16) {
const T sentinel = hwy::ConvertScalarTo<T>(-999.0f);
std::fill(actual_dec.get(), actual_dec.get() + num, sentinel);
IntCodec::DecompressAndZeroPad(d, MakeConst(int_span), offset,
actual_dec.get(), num);
MaybeCheckInitialized(actual_dec.get(), num * sizeof(T));
// Check that all sentinels were overwritten.
for (size_t i = 0; i < num; ++i) {
EXPECT_NE(hwy::ConvertScalarTo<float>(actual_dec[i]),
hwy::ConvertScalarTo<float>(sentinel))
<< " at index " << i << " for num=" << num
<< " offset=" << offset;
}
}
}
}
};
void TestAllSmallDequantize() {
hn::ForGEVectors<128, TestSmallDequantize>()(BF16());
hn::ForGEVectors<128, TestSmallDequantize>()(float());
}
// Tests that DecompressAndZeroPad works correctly for a specific failing input.
struct TestSpecificDequantize {
template <typename T, class D>
HWY_INLINE void operator()(T /*unused*/, D d) {
const hn::ScalableTag<float> df;
constexpr size_t kSize = 737280;
auto in = hwy::AllocateAligned<float>(kSize);
auto actual_dec = hwy::AllocateAligned<T>(kSize);
auto i8_stream = hwy::AllocateAligned<I8Stream>(I8Stream::PackedEnd(kSize));
HWY_ASSERT(in && actual_dec && i8_stream);
const auto int_span = MakeSpan(i8_stream.get(), kSize);
// Fill with a known pattern.
for (size_t i = 0; i < kSize; ++i) {
in[i] = static_cast<float>(i) - 128.0f;
}
IntCodec::Enc(df, in.get(), kSize, int_span, 0);
const size_t num = 64;
const size_t offset = 392704;
const T sentinel = hwy::ConvertScalarTo<T>(-999.0f);
std::fill(actual_dec.get(), actual_dec.get() + num, sentinel);
IntCodec::DecompressAndZeroPad(d, MakeConst(int_span), offset,
actual_dec.get(), num);
// Check that all sentinels were overwritten.
for (size_t i = 0; i < num; ++i) {
EXPECT_NE(hwy::ConvertScalarTo<float>(actual_dec[i]),
hwy::ConvertScalarTo<float>(sentinel))
<< " at index " << i << " for num=" << num << " offset=" << offset;
}
}
};
void TestAllSpecificDequantize() {
hn::ForGEVectors<128, TestSpecificDequantize>()(BF16());
hn::ForGEVectors<128, TestSpecificDequantize>()(float());
}
// NOLINTNEXTLINE(google-readability-namespace-comments)
} // namespace HWY_NAMESPACE
} // namespace gcpp
HWY_AFTER_NAMESPACE();
#if HWY_ONCE
namespace gcpp {
HWY_BEFORE_TEST(IntTest);
HWY_EXPORT_AND_TEST_P(IntTest, TestOffsetF32);
HWY_EXPORT_AND_TEST_P(IntTest, TestOffsetBF16);
HWY_EXPORT_AND_TEST_P(IntTest, TestQuantizeF32);
HWY_EXPORT_AND_TEST_P(IntTest, TestQuantizeBF16);
HWY_EXPORT_AND_TEST_P(IntTest, TestDec2BF16);
HWY_EXPORT_AND_TEST_P(IntTest, TestDec2F32);
HWY_EXPORT_AND_TEST_P(IntTest, TestMultiGroupF32);
HWY_EXPORT_AND_TEST_P(IntTest, TestMultiGroupBF16);
HWY_EXPORT_AND_TEST_P(IntTest, TestUnalignedOffsetBF16);
HWY_EXPORT_AND_TEST_P(IntTest, TestUnalignedOffsetF32);
HWY_EXPORT_AND_TEST_P(IntTest, TestAllDequantizeAndZeroPad);
HWY_EXPORT_AND_TEST_P(IntTest, TestAllSmallDequantize);
HWY_EXPORT_AND_TEST_P(IntTest, TestAllSpecificDequantize);
HWY_AFTER_TEST();
} // namespace gcpp
#endif // HWY_ONCE

View File

@ -113,6 +113,9 @@ class SbsWriterImpl : public ISbsWriter {
case Type::kF32:
InsertT<float>(name, weights, tensor_info);
break;
case Type::kI8:
InsertT<I8Stream>(name, weights, tensor_info);
break;
default:
HWY_ABORT("Unsupported destination (compressed) type %s",
TypeName(type));

View File

@ -90,6 +90,13 @@ class CompressionTest(absltest.TestCase):
info_256,
)
writer.insert(
"tensor_i8",
np.array([0.000375] * 128 + [0.00006] * 128, dtype=np.float32),
configs.Type.kI8,
info_256,
)
config = configs.ModelConfig(
configs.Model.GEMMA2_2B,
configs.Type.kSFP,
@ -140,6 +147,11 @@ class CompressionTest(absltest.TestCase):
self.assertEqual(mat.type, configs.Type.kF32)
self.assertAlmostEqual(mat.scale, 1.0)
mat = reader.find_mat("tensor_i8")
self.assertEqual(mat.cols, 256)
self.assertEqual(mat.rows, 1)
self.assertEqual(mat.type, configs.Type.kI8)
self.assertAlmostEqual(mat.scale, 1.0)
if __name__ == "__main__":
absltest.main()

View File

@ -89,6 +89,26 @@ struct SfpStream {
};
#pragma pack(pop)
#pragma pack(push, 1)
struct I8Stream {
static constexpr size_t kGroupSize = 128;
using ScaleT = hwy::bfloat16_t;
// Returns number of I8Stream to allocate for the stream, which matches its
// size in bytes.
// TODO: should support scale and zero-point types other than the current
// bf16.
static constexpr size_t PackedEnd(size_t capacity) {
const size_t num_groups = hwy::DivCeil(capacity, kGroupSize);
return (sizeof(ScaleT) * num_groups) + // scale
(sizeof(ScaleT) * num_groups) + // zero-point
capacity; // 1 value per byte
}
int8_t i;
};
#pragma pack(pop)
// Non-uniform quantization: a compressed representation of f32 inputs that
// supports seeking at a granularity of 1 (for `DecompressAndZeroPad`) or
// two vectors (for `Decompress2`), and decoding to bf16/f32.
@ -187,18 +207,23 @@ constexpr bool IsNuqStream() {
return hwy::IsSame<hwy::RemoveCvRef<Packed>, NuqStream>();
}
template <typename Packed>
constexpr bool IsI8Stream() {
return hwy::IsSame<hwy::RemoveCvRef<Packed>, I8Stream>();
}
template <typename Packed>
constexpr bool SupportsPointerArithmetic() {
return !IsNuqStream<Packed>();
return !IsNuqStream<Packed>() && !IsI8Stream<Packed>();
}
// Tensor types for loading weights. Not all of these are supported weight
// types, some are only used for `Activations`.
enum class Type { kUnknown, kF32, kBF16, kSFP, kNUQ, kF64, kU32, kU64 };
enum class Type { kUnknown, kF32, kBF16, kSFP, kNUQ, kF64, kU32, kU64, kI8 };
// These are used in `ModelConfig.Specifier`, hence the strings will not
// change, though new ones may be added.
static constexpr const char* kTypeStrings[] = {"unknown", "f32", "bf16", "sfp",
"nuq", "f64", "u32", "u64"};
static constexpr const char* kTypeStrings[] = {
"unknown", "f32", "bf16", "sfp", "nuq", "f64", "u32", "u64", "i8"};
static constexpr size_t kNumTypes =
sizeof(kTypeStrings) / sizeof(kTypeStrings[0]);
static constexpr size_t kTypeBits[] = {
@ -210,6 +235,7 @@ static constexpr size_t kTypeBits[] = {
8 * sizeof(double),
8 * sizeof(uint32_t),
8 * sizeof(uint64_t),
8 * sizeof(I8Stream),
};
static inline bool EnumValid(Type type) {
@ -234,6 +260,8 @@ Type TypeEnum() {
return Type::kU32;
} else if constexpr (hwy::IsSame<Packed, uint64_t>()) {
return Type::kU64;
} else if constexpr (hwy::IsSame<Packed, I8Stream>()) {
return Type::kI8;
} else {
HWY_DASSERT(false);
return Type::kUnknown;
@ -254,7 +282,9 @@ const char* TypeName() {
template <typename Packed>
constexpr bool IsCompressed() {
return hwy::IsSameEither<hwy::RemoveCvRef<Packed>, SfpStream, NuqStream>();
return hwy::IsSame<hwy::RemoveCvRef<Packed>, SfpStream>() ||
hwy::IsSame<hwy::RemoveCvRef<Packed>, NuqStream>() ||
hwy::IsSame<hwy::RemoveCvRef<Packed>, I8Stream>();
}
// Returns the number of `MatT` elements required to store `capacity` values,
@ -265,6 +295,8 @@ template <typename Packed>
constexpr size_t CompressedArrayElements(size_t capacity) {
if constexpr (hwy::IsSame<hwy::RemoveCvRef<Packed>, NuqStream>()) {
return NuqStream::PackedEnd(capacity);
} else if constexpr (hwy::IsSame<hwy::RemoveCvRef<Packed>, I8Stream>()) {
return I8Stream::PackedEnd(capacity);
} else {
return capacity;
}
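
As a concrete check of the layout defined by `I8Stream::PackedEnd`: each group of 128 values carries a 4-byte header (two bf16 values, inverse scale and zero-point) followed by one byte per value, so 256 values occupy 264 bytes. A small sketch, assuming kGroupSize == 128 and sizeof(ScaleT) == 2:

#include <stddef.h>

constexpr size_t I8PackedBytes(size_t capacity) {
  const size_t num_groups = (capacity + 127) / 128;  // DivCeil(capacity, 128)
  return 2 * 2 * num_groups + capacity;  // two 2-byte bf16 fields per group + payload
}
static_assert(I8PackedBytes(256) == 264, "two full groups");
static_assert(I8PackedBytes(1) == 5, "a single value still pays a full group header");
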

View File

@ -143,8 +143,8 @@ void SingleDotSoftmaxWeightedSum(
// Apply rope and scaling to Q.
if (layer.query_norm_scale.HasPtr()) {
CallUpcasted(&layer.query_norm_scale, [&](const auto* weights_t) {
RMSNormInplace(weights_t->PackedScale1(), q, layer.layer_config.qkv_dim,
p, worker);
RMSNormInplace(weights_t->PackedScale1(), /*w_ofs=*/0, q,
layer.layer_config.qkv_dim, p, worker);
});
}
@ -315,8 +315,8 @@ static HWY_INLINE void ComputeQKV(size_t num_tokens, const size_t layer_idx,
// Apply further processing to K.
if (layer.key_norm_scale.HasPtr()) {
CallUpcasted(&layer.key_norm_scale, [&](const auto* weights_t) {
RMSNormInplace(weights_t->PackedScale1(), kv_f32, qkv_dim,
env.ctx.profiler, worker);
RMSNormInplace(weights_t->PackedScale1(), /*w_ofs=*/0, kv_f32,
qkv_dim, env.ctx.profiler, worker);
});
}

View File

@ -114,7 +114,7 @@ void RMSNormAndPositionalEncoding(const size_t num_tokens, const QBatch& qbatch,
// Apply rope and scaling to Q.
if (layer.query_norm_scale.HasPtr()) {
CallUpcasted(&layer.query_norm_scale, [&](const auto* weights_t) {
RMSNormInplace(weights_t->PackedScale1(), q_row,
RMSNormInplace(weights_t->PackedScale1(), /*w_ofs=*/0, q_row,
layer.layer_config.qkv_dim, ctx.profiler, worker);
});
}

View File

@ -59,7 +59,7 @@ void Activation(ActivationType activation, T1* HWY_RESTRICT c1,
return;
};
// Has multiplier, Gelu(c1) * c2.
Decompress1AndCompressInplace(DF(), c1, count, c2,
Decompress1AndCompressInplace(DF(), c1, count, c2, /*p1_ofs=*/0,
[](DF df, VF v1, VF v2) HWY_ATTR -> VF {
return hn::Mul(v2, Gelu(df, v1));
});
@ -101,8 +101,9 @@ static inline void Activation(ActivationType activation, const RowPtrsBF C1,
for (size_t ir = 0; ir < range_r.Num(); ++ir) {
Decompress1AndCompressInplace(
DF(), C1.Row(range_r.begin() + ir) + range_c.begin(), cols, C2.Row(ir),
[](DF df, VF v1, VF v2)
HWY_ATTR -> VF { return hn::Mul(v2, Gelu(df, v1)); });
/*p1_ofs=*/0, [](DF df, VF v1, VF v2) HWY_ATTR -> VF {
return hn::Mul(v2, Gelu(df, v1));
});
}
}

View File

@ -112,6 +112,8 @@ class TypePrefix {
return Type::kSFP;
case '2':
return Type::kNUQ;
case 'I':
return Type::kI8;
default:
// The other types were not written to pre-2025 files, hence no need to
// encode and check for them here.

View File

@ -46,7 +46,7 @@ struct TensorInfo {
// The highest permissible compression for this tensor. Other values such as
// kBF16 or kF32 can be used to limit the compression to a specific type.
Type min_size = Type::kNUQ;
Type min_size = Type::kI8;
// Whether to apply scaled softplus to the data.
bool scaled_softplus = false;
// Whether the columns or the rows take any extra dimensions.

View File

@ -332,8 +332,8 @@ void PrefillVit(const ModelConfig& model_config, const WeightsPtrs& weights,
// Apply soft embedding norm before input projection.
CallUpcasted(&weights.mm_embed_norm, [&](const auto* weights_t) {
RMSNormInplace(weights_t->PackedScale1(), activations.x.Row(0),
vit_model_dim, env.ctx.profiler,
RMSNormInplace(weights_t->PackedScale1(), /*w_ofs=*/0,
activations.x.Row(0), vit_model_dim, env.ctx.profiler,
hwy::Profiler::GlobalIdx());
});
}

View File

@ -147,16 +147,223 @@ void LayerWeightsPtrs::SplitAttW1() {
qkv_einsum_w.SetPtr(nullptr, qkv_einsum_w.Cols());
}
static void HWY_MAYBE_UNUSED InitAttWeightsI8(
const LayerConfig& layer_config, MatPtrT<I8Stream>& attn_vec_einsum_w,
MatPtrT<I8Stream>& att_weights, std::vector<MatOwner>& mat_owners,
const Allocator& allocator) {
if (!attn_vec_einsum_w.HasPtr()) return;
HWY_ASSERT(attn_vec_einsum_w.GetType() == Type::kI8);
att_weights.SetType(Type::kI8);
{
static std::mutex m;
std::lock_guard<std::mutex> lock(m);
mat_owners.emplace_back();
mat_owners.back().AllocateFor(att_weights, allocator, MatPadding::kPacked);
}
const size_t model_dim = layer_config.model_dim;
const size_t heads = layer_config.heads;
const size_t qkv_dim = layer_config.qkv_dim;
// Reshape [kHeads, kModelDim, kQKVDim] to [kModelDim, kHeads * kQKVDim].
hwy::AlignedFreeUniquePtr<float[]> attn_vec_einsum_w_tmp =
hwy::AllocateAligned<float>(model_dim * heads * qkv_dim);
hwy::AlignedFreeUniquePtr<float[]> att_weights_tmp =
hwy::AllocateAligned<float>(model_dim * heads * qkv_dim);
const hwy::HWY_NAMESPACE::ScalableTag<float> df;
HWY_NAMESPACE::DecompressAndZeroPad(df, attn_vec_einsum_w.Span(), 0,
attn_vec_einsum_w_tmp.get(),
model_dim * heads * qkv_dim);
for (size_t m = 0; m < model_dim; ++m) {
float* HWY_RESTRICT out_row = att_weights_tmp.get() + m * heads * qkv_dim;
for (size_t h = 0; h < heads; ++h) {
hwy::CopyBytes(
attn_vec_einsum_w_tmp.get() + h * model_dim * qkv_dim + m * qkv_dim,
out_row + h * qkv_dim, qkv_dim * sizeof(float));
}
}
CompressWorkingSet work;
hwy::ThreadPool pool(0);
HWY_NAMESPACE::Compress(att_weights_tmp.get(), model_dim * heads * qkv_dim,
work, att_weights.Span(),
/*packed_ofs=*/0, pool);
att_weights.SetScale(attn_vec_einsum_w.Scale());
}
static void HWY_MAYBE_UNUSED SplitW1I8(const LayerConfig& layer_config,
MatPtrT<I8Stream>& gating_einsum_w,
MatPtrT<I8Stream>& gating_einsum_w1,
MatPtrT<I8Stream>& gating_einsum_w2,
std::vector<MatOwner>& mat_owners,
const Allocator& allocator) {
// Files have both or neither of w1 and w2.
HWY_ASSERT(gating_einsum_w1.HasPtr() == gating_einsum_w2.HasPtr());
// w is mutually exclusive with w1 and w2 in the file.
HWY_ASSERT(gating_einsum_w.HasPtr() ^ gating_einsum_w1.HasPtr());
// Done if we already read split tensors.
if (gating_einsum_w1.HasPtr() && !gating_einsum_w.HasPtr()) return;
// Nothing to do if w is not present.
if (!gating_einsum_w.HasPtr()) return;
HWY_ASSERT(gating_einsum_w.GetType() == Type::kI8);
const size_t ff_hidden_dim = layer_config.ff_hidden_dim;
const size_t model_dim = gating_einsum_w.Cols();
HWY_ASSERT(gating_einsum_w.Rows() == 2 * ff_hidden_dim);
HWY_ASSERT(gating_einsum_w1.Rows() == ff_hidden_dim);
HWY_ASSERT(gating_einsum_w2.Rows() == ff_hidden_dim);
HWY_ASSERT(gating_einsum_w1.Cols() == model_dim);
HWY_ASSERT(gating_einsum_w2.Cols() == model_dim);
gating_einsum_w1.SetType(Type::kI8);
gating_einsum_w2.SetType(Type::kI8);
{
static std::mutex m;
std::lock_guard<std::mutex> lock(m);
mat_owners.emplace_back();
mat_owners.back().AllocateFor(gating_einsum_w1, allocator,
MatPadding::kPacked);
mat_owners.emplace_back();
mat_owners.back().AllocateFor(gating_einsum_w2, allocator,
MatPadding::kPacked);
}
const size_t total_size = gating_einsum_w.Rows() * gating_einsum_w.Cols();
hwy::AlignedFreeUniquePtr<float[]> w_tmp =
hwy::AllocateAligned<float>(total_size);
const hwy::HWY_NAMESPACE::ScalableTag<float> df;
HWY_NAMESPACE::DecompressAndZeroPad(df, gating_einsum_w.Span(), 0,
w_tmp.get(), total_size);
const size_t split_size = ff_hidden_dim * model_dim;
float* w1_tmp = w_tmp.get();
float* w2_tmp = w_tmp.get() + split_size;
CompressWorkingSet work;
hwy::ThreadPool pool(0);
HWY_NAMESPACE::Compress(w1_tmp, split_size, work, gating_einsum_w1.Span(), 0,
pool);
HWY_NAMESPACE::Compress(w2_tmp, split_size, work, gating_einsum_w2.Span(), 0,
pool);
gating_einsum_w1.SetScale(1.0f);
gating_einsum_w2.SetScale(1.0f);
gating_einsum_w.SetPtr(nullptr, gating_einsum_w.Cols());
}
static void HWY_MAYBE_UNUSED SplitAttW1I8(const LayerConfig& layer_config,
MatPtrT<I8Stream>& qkv_einsum_w,
MatPtrT<I8Stream>& qkv_einsum_w1,
MatPtrT<I8Stream>& qkv_einsum_w2,
std::vector<MatOwner>& mat_owners,
const Allocator& allocator) {
// w is mutually exclusive with w1 in the file.
HWY_ASSERT(qkv_einsum_w.HasPtr() ^ qkv_einsum_w1.HasPtr());
// Done if we already read split tensors.
if (qkv_einsum_w1.HasPtr() && !qkv_einsum_w.HasPtr()) return;
// Nothing to do if w is not present.
if (!qkv_einsum_w.HasPtr()) return;
HWY_ASSERT(qkv_einsum_w.GetType() == Type::kI8);
const size_t model_dim = qkv_einsum_w.Cols();
const size_t w1_rows = layer_config.heads * layer_config.qkv_dim;
const size_t w2_rows = layer_config.kv_heads * 2 * layer_config.qkv_dim;
HWY_ASSERT(qkv_einsum_w.Rows() == w1_rows + w2_rows);
HWY_ASSERT(qkv_einsum_w1.Rows() == w1_rows);
HWY_ASSERT(qkv_einsum_w2.Rows() == w2_rows);
HWY_ASSERT(qkv_einsum_w1.Cols() == model_dim);
HWY_ASSERT(qkv_einsum_w2.Cols() == model_dim);
qkv_einsum_w1.SetType(Type::kI8);
qkv_einsum_w2.SetType(Type::kI8);
{
static std::mutex m;
std::lock_guard<std::mutex> lock(m);
mat_owners.emplace_back();
mat_owners.back().AllocateFor(qkv_einsum_w1, allocator,
MatPadding::kPacked);
mat_owners.emplace_back();
mat_owners.back().AllocateFor(qkv_einsum_w2, allocator,
MatPadding::kPacked);
}
const size_t total_size = qkv_einsum_w.Rows() * qkv_einsum_w.Cols();
hwy::AlignedFreeUniquePtr<float[]> w_tmp =
hwy::AllocateAligned<float>(total_size);
const hwy::HWY_NAMESPACE::ScalableTag<float> df;
HWY_NAMESPACE::DecompressAndZeroPad(df, qkv_einsum_w.Span(), 0, w_tmp.get(),
total_size);
const size_t w1_size = w1_rows * model_dim;
const size_t w2_size = w2_rows * model_dim;
float* w1_tmp = w_tmp.get();
float* w2_tmp = w_tmp.get() + w1_size;
CompressWorkingSet work;
hwy::ThreadPool pool(0);
HWY_NAMESPACE::Compress(w1_tmp, w1_size, work, qkv_einsum_w1.Span(), 0, pool);
HWY_NAMESPACE::Compress(w2_tmp, w2_size, work, qkv_einsum_w2.Span(), 0, pool);
qkv_einsum_w1.SetScale(1.0f);
qkv_einsum_w2.SetScale(1.0f);
qkv_einsum_w.SetPtr(nullptr, qkv_einsum_w.Cols());
}
// Must be called after reading weights via `ForEachTensor`.
// TODO: exporters should bake this into the weights already.
// WARNING: called from multiple threads; `mat_owners` requires a lock.
void LayerWeightsPtrs::Fixup(std::vector<MatOwner>& mat_owners,
const Allocator& allocator) {
// TODO(janwas): handle NUQ
if (attn_vec_einsum_w.GetType() == Type::kI8) {
MatPtrT<I8Stream> attn_vec_einsum_w_i8(attn_vec_einsum_w);
MatPtrT<I8Stream> att_weights_i8(att_weights);
InitAttWeightsI8(layer_config, attn_vec_einsum_w_i8, att_weights_i8,
mat_owners, allocator);
attn_vec_einsum_w = attn_vec_einsum_w_i8;
att_weights = att_weights_i8;
} else {
InitAttWeights(mat_owners, allocator);
}
if (gating_einsum_w.GetType() == Type::kI8) {
MatPtrT<I8Stream> gating_einsum_w_i8(gating_einsum_w);
MatPtrT<I8Stream> gating_einsum_w1_i8(gating_einsum_w1);
MatPtrT<I8Stream> gating_einsum_w2_i8(gating_einsum_w2);
SplitW1I8(layer_config, gating_einsum_w_i8, gating_einsum_w1_i8,
gating_einsum_w2_i8, mat_owners, allocator);
gating_einsum_w = gating_einsum_w_i8;
gating_einsum_w1 = gating_einsum_w1_i8;
gating_einsum_w2 = gating_einsum_w2_i8;
} else {
SplitW1();
}
if (qkv_einsum_w.GetType() == Type::kI8) {
MatPtrT<I8Stream> qkv_einsum_w_i8(qkv_einsum_w);
MatPtrT<I8Stream> qkv_einsum_w1_i8(qkv_einsum_w1);
MatPtrT<I8Stream> qkv_einsum_w2_i8(qkv_einsum_w2);
SplitAttW1I8(layer_config, qkv_einsum_w_i8, qkv_einsum_w1_i8,
qkv_einsum_w2_i8, mat_owners, allocator);
qkv_einsum_w = qkv_einsum_w_i8;
qkv_einsum_w1 = qkv_einsum_w1_i8;
qkv_einsum_w2 = qkv_einsum_w2_i8;
} else {
SplitAttW1();
}
}
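
The I8 fixup path above follows the same recipe as the existing InitAttWeights / InitAttWeightsNUQ paths: decompress the packed tensor into an f32 scratch buffer, rearrange, then recompress into the newly allocated I8 tensor. The attention-weight rearrangement itself is just an index remap; a scalar sketch of the copy performed in InitAttWeightsI8 (assuming row-major layout, as in the code above):

#include <stddef.h>

// [heads, model_dim, qkv_dim] -> [model_dim, heads * qkv_dim]
void ReshapeAttWeightsRef(const float* in, float* out, size_t heads,
                          size_t model_dim, size_t qkv_dim) {
  for (size_t m = 0; m < model_dim; ++m) {
    for (size_t h = 0; h < heads; ++h) {
      for (size_t k = 0; k < qkv_dim; ++k) {
        out[(m * heads + h) * qkv_dim + k] =
            in[(h * model_dim + m) * qkv_dim + k];
      }
    }
  }
}
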
static void HWY_MAYBE_UNUSED InitAttWeightsNUQ(
const LayerConfig& layer_config, MatPtrT<NuqStream>& attn_vec_einsum_w,
@ -427,8 +634,6 @@ static void ReadAllToBF16(const std::vector<TensorToRead>& tensors,
static std::vector<IOBatch> MakeBatches(
const std::vector<TensorToRead>& tensors, const uint64_t file_bytes) {
PROFILER_ZONE("Startup.Weights.MakeBatches");
// Batches must be contiguous but blobs are padded, hence at least one
// batch per tensor, and more when tensor rows exceed the batch size.
std::vector<IOBatch> batches;
batches.reserve(tensors.size());
@ -439,11 +644,21 @@ static std::vector<IOBatch> MakeBatches(
HWY_ASSERT(range.End() <= file_bytes);
batches.emplace_back(offset, range.key_idx);
if (mat.IsPacked()) {
HWY_ASSERT(range.bytes == mat.PackedBytes());
if (!batches.back().Add(mat.Packed(), range.bytes)) {
// This should not happen if tensors are < 2GB.
// If it does, we need to chunk. For now, let's assume it doesn't.
HWY_ABORT("Packed tensor too large for a single IO batch.");
}
offset += range.bytes;
} else {
const size_t file_bytes_per_row = mat.Cols() * mat.ElementBytes();
const size_t mem_stride_bytes = mat.Stride() * mat.ElementBytes();
uint8_t* row_bytes = mat.RowBytes(0);
for (size_t r = 0; r < mat.Rows(); ++r) {
if (!batches.back().Add(row_bytes, file_bytes_per_row)) { // Full batch.
if (!batches.back().Add(row_bytes,
file_bytes_per_row)) { // Full batch.
batches.emplace_back(offset, range.key_idx);
// Adding to an empty batch is always successful.
HWY_ASSERT(batches.back().Add(row_bytes, file_bytes_per_row));
@ -451,6 +666,7 @@ static std::vector<IOBatch> MakeBatches(
offset += file_bytes_per_row;
row_bytes += mem_stride_bytes;
}
}
HWY_ASSERT(offset == range.End());
}

View File

@ -50,6 +50,7 @@
GEMMA_MATMUL_FOR_B(float) \
GEMMA_MATMUL_FOR_B(NuqStream) \
GEMMA_MATMUL_FOR_B(SfpStream) \
GEMMA_MATMUL_FOR_B(I8Stream) \
/* NOLINTNEXTLINE(google-readability-namespace-comments) */ \
} // namespace NAMESPACE

ops/matmul_static_i8.cc (new file, 29 lines)
View File

@ -0,0 +1,29 @@
// Copyright 2025 Google LLC
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "compression/types.h" // GEMMA_DISABLED_TARGETS
#ifndef HWY_DISABLED_TARGETS
#define HWY_DISABLED_TARGETS GEMMA_DISABLED_TARGETS
#endif // HWY_DISABLED_TARGETS
// Compiles this file for multiple architectures via "foreach_target.h", to
// which we pass the filename via macro 'argument'.
// clang-format off
#undef HWY_TARGET_INCLUDE
#define HWY_TARGET_INCLUDE "ops/matmul_static_i8.cc" // NOLINT
// clang-format on
#include "hwy/foreach_target.h" // IWYU pragma: keep
#define GEMMA_MATMUL_TB I8Stream
#include "ops/matmul_static-inl.h"

View File

@ -220,6 +220,7 @@ float RMSNormMul(const VT* HWY_RESTRICT x, const size_t size, hwy::Profiler& p,
template <typename XT, typename WT, typename OT>
HWY_NOINLINE HWY_MAYBE_UNUSED void RMSNorm(const XT* HWY_RESTRICT x,
const WT* HWY_RESTRICT weight,
const size_t w_ofs,
OT* HWY_RESTRICT out,
const size_t size, hwy::Profiler& p,
const size_t worker) {
@ -232,7 +233,7 @@ HWY_NOINLINE HWY_MAYBE_UNUSED void RMSNorm(const XT* HWY_RESTRICT x,
const VF mul = hn::Set(DF(), detail::RMSNormMul(x, size, p, worker));
const VF* HWY_RESTRICT pmul = &mul;
Decompress2AndCompressTo(DF(), out, size, x, weight,
Decompress2AndCompressTo(DF(), out, size, x, weight, w_ofs,
[pmul](DF /*df*/, VF vx, VF vw) HWY_ATTR -> VF {
const VF m = hn::Mul(*pmul, vx);
// (1+weight) * m = m + weight*m = one FMA.
@ -242,13 +243,10 @@ HWY_NOINLINE HWY_MAYBE_UNUSED void RMSNorm(const XT* HWY_RESTRICT x,
// Same as RMSNorm, whose HWY_RESTRICT forbids passing the same pointer for
// input and output; this variant updates the buffer in place.
template <typename WT, typename XT>
HWY_NOINLINE HWY_MAYBE_UNUSED void RMSNormInplace(const WT* HWY_RESTRICT weight,
XT* HWY_RESTRICT inout,
const size_t size,
hwy::Profiler& p,
const size_t worker) {
HWY_NOINLINE HWY_MAYBE_UNUSED void RMSNormInplace(
const WT* HWY_RESTRICT weight, const size_t w_ofs, XT* HWY_RESTRICT inout,
const size_t size, hwy::Profiler& p, const size_t worker) {
PROFILER_ZONE3(p, worker, GetProfilerZone(Zones::kOpsRmsNormInplace));
namespace hn = hwy::HWY_NAMESPACE;
using DF = hn::ScalableTag<float>;
using VF = hn::Vec<DF>;
@ -256,7 +254,7 @@ HWY_NOINLINE HWY_MAYBE_UNUSED void RMSNormInplace(const WT* HWY_RESTRICT weight,
const VF mul = hn::Set(DF(), detail::RMSNormMul(inout, size, p, worker));
const VF* HWY_RESTRICT pmul = &mul;
Decompress1AndCompressInplace(DF(), inout, size, weight,
Decompress1AndCompressInplace(DF(), inout, size, weight, w_ofs,
[pmul](DF /*df*/, VF vx, VF vw) HWY_ATTR -> VF {
const VF m = hn::Mul(*pmul, vx);
// (1+weight) * m = m + weight*m = one FMA.
@ -489,7 +487,7 @@ static HWY_NOINLINE HWY_MAYBE_UNUSED void AddFrom(const XT* HWY_RESTRICT x,
namespace hn = hwy::HWY_NAMESPACE;
using DF = hn::ScalableTag<float>;
using VF = hn::Vec<DF>;
Decompress1AndCompressInplace(DF(), out, size, x,
Decompress1AndCompressInplace(DF(), out, size, x, /*p1_ofs=*/0,
[&](DF /*df*/, VF out, VF x)
HWY_ATTR -> VF { return hn::Add(x, out); });
}
@ -507,8 +505,8 @@ void RMSNormBatched(const MatPtrT<XT>& activations, const MatPtr& weights,
ParallelFor(ParallelismStrategy::kFlat, activations.Rows(), ctx,
cluster_idx, [&](uint64_t token_idx, size_t worker) {
RMSNorm(activations.Row(token_idx), weights_t->PackedScale1(),
out.Row(token_idx), activations.Cols(), ctx.profiler,
worker);
/*w_ofs=*/0, out.Row(token_idx), activations.Cols(),
ctx.profiler, worker);
});
});
}
@ -522,7 +520,7 @@ void RMSNormInplaceBatched(const MatPtr& weights, MatPtrT<XT>& inout,
CallUpcasted(&weights, [&](const auto* weights_t) {
ParallelFor(ParallelismStrategy::kFlat, inout.Rows(), ctx, cluster_idx,
[&](uint64_t token_idx, size_t worker) {
RMSNormInplace(weights_t->PackedScale1(),
RMSNormInplace(weights_t->PackedScale1(), /*w_ofs=*/0,
inout.Row(token_idx), inout.Cols(),
ctx.profiler, worker);
});
@ -604,7 +602,7 @@ HWY_NOINLINE HWY_MAYBE_UNUSED void MulByConstAndAdd(const float c,
const VF vc = hn::Set(DF(), c);
const VF* HWY_RESTRICT pc = &vc;
Decompress1AndCompressInplace(DF(), out, size, x,
Decompress1AndCompressInplace(DF(), out, size, x, /*p1_ofs=*/0,
[&](DF /*df*/, VF out, VF x) HWY_ATTR -> VF {
return hn::MulAdd(x, *pc, out);
});
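
The new w_ofs/p1_ofs arguments let these kernels read the weight (or second operand) starting at an offset within a packed stream instead of assuming it begins at the pointer, which matters for byte-stream types such as I8Stream. In scalar terms, the RMSNorm above computes roughly the following; the epsilon constant is an assumption, and ScalarRMSNormSketch is an illustrative name, not the repo's ScalarRMSNorm:

// Scalar reference (sketch) for the vectorized RMSNorm with a weight offset.
#include <cmath>
#include <cstddef>

void ScalarRMSNormSketch(const float* x, const float* weight, size_t w_ofs,
                         float* out, size_t size) {
  constexpr float kEps = 1e-6f;  // assumed stabilizer; the real value may differ
  double sum_sq = 0.0;
  for (size_t i = 0; i < size; ++i) sum_sq += double(x[i]) * x[i];
  const float mul = 1.0f / std::sqrt(float(sum_sq / size) + kEps);
  for (size_t i = 0; i < size; ++i) {
    const float m = x[i] * mul;
    // (1 + weight) * m = m + weight * m: a single FMA in the SIMD code.
    out[i] = m + weight[w_ofs + i] * m;
  }
}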

View File

@ -558,7 +558,8 @@ struct TestRMSNorm {
ScalarRMSNorm(vec, weight, expected, kSize);
InitProfilerZones(hwy::Profiler::Get());
RMSNorm(vec, weight, actual, kSize, hwy::Profiler::Get(), /*worker=*/0);
RMSNorm(vec, weight, /*w_ofs=*/0, actual, kSize, hwy::Profiler::Get(),
/*worker=*/0);
for (size_t i = 0; i < kSize; i++) {
const float e = hwy::ConvertScalarTo<float>(expected[i]);
@ -593,7 +594,7 @@ struct TestRMSNormInplace {
ScalarRMSNorm(expected, weight, expected, kSize);
InitProfilerZones(hwy::Profiler::Get());
RMSNormInplace(weight, actual, kSize, hwy::Profiler::Get(),
RMSNormInplace(weight, /*w_ofs=*/0, actual, kSize, hwy::Profiler::Get(),
/*worker=*/0);
for (size_t i = 0; i < kSize; i++) {

View File

@ -53,7 +53,11 @@ PYBIND11_MODULE(configs, py_module) {
.value("kF32", Type::kF32)
.value("kBF16", Type::kBF16)
.value("kSFP", Type::kSFP)
.value("kNUQ", Type::kNUQ);
.value("kNUQ", Type::kNUQ)
.value("kF64", Type::kF64)
.value("kU32", Type::kU32)
.value("kU64", Type::kU64)
.value("kI8", Type::kI8);
enum_<LayerAttentionType>(py_module, "LayerAttentionType")
.value("kGemma", LayerAttentionType::kGemma)

View File

@ -59,6 +59,25 @@ static inline void MaybeCheckInitialized(const void* ptr, size_t size) {
#endif
}
static inline void MaybePrintInitialized(const void* ptr, size_t size) {
#if HWY_IS_MSAN
__msan_print_shadow(ptr, size);
#else
(void)ptr;
(void)size;
#endif
}
static inline intptr_t MaybeTestInitialized(const void* ptr, size_t size) {
#if HWY_IS_MSAN
return __msan_test_shadow(ptr, size);
#else
(void)ptr;
(void)size;
return 0;
#endif
}
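
These wrappers expose the MSan shadow API with no-op fallbacks: __msan_print_shadow dumps the initialization state of a range, and __msan_test_shadow returns -1 when the range is fully initialized, otherwise the offset of the first poisoned byte. A hypothetical debugging use, assuming the two helpers above are in scope; CheckDecoded is an illustrative name:

// Hypothetical sketch: check that a decoded buffer is fully initialized
// before use. Only does real work under MemorySanitizer builds.
#include <cstddef>
#include <cstdint>
#include <cstdio>
#include <vector>

void CheckDecoded(const std::vector<float>& decoded) {
  const size_t bytes = decoded.size() * sizeof(float);
  const intptr_t bad = MaybeTestInitialized(decoded.data(), bytes);
  if (bad >= 0) {  // -1 means the whole range is initialized.
    std::fprintf(stderr, "uninitialized data at byte offset %lld\n",
                 static_cast<long long>(bad));
    MaybePrintInitialized(decoded.data(), bytes);
  }
}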
// Shared between gemma.h and ops-inl.h.
#pragma pack(push, 1)
struct TokenAndProb {

View File

@ -80,11 +80,13 @@ size_t Stride(MatPadding padding, size_t cols, size_t element_bytes,
void MatOwner::AllocateFor(MatPtr& mat, const Allocator& allocator,
MatPadding padding) {
const bool is_nuq = mat.GetType() == Type::kNUQ;
if (is_nuq) padding = MatPadding::kPacked;
const bool is_compressed_and_packed =
mat.GetType() == Type::kNUQ || mat.GetType() == Type::kI8;
if (is_compressed_and_packed) padding = MatPadding::kPacked;
const size_t stride =
Stride(padding, mat.Cols(), mat.ElementBytes(), allocator.LineBytes());
const size_t num = is_nuq ? mat.PackedBytes() : mat.Rows() * stride;
const size_t num =
is_compressed_and_packed ? mat.PackedBytes() : mat.Rows() * stride;
// `compress-inl` requires up to 2 BF16 vectors of padding. `MatPadding`
// might not be enough, hence add extra. `MatT` is at least one byte, which
// is half of BF16, hence adding `VectorBytes` *elements* is enough.
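
The net effect: byte-stream types (NUQ and now I8) are always allocated as one packed buffer sized via PackedBytes, whereas padded matrices get rows * stride elements; both then reserve extra tail space for vector loads. A simplified sketch of that sizing decision, with placeholder parameters instead of the real MatPtr/Allocator queries:

// Sketch of the size computation above; all parameters are placeholders.
#include <cstddef>

size_t AllocationElements(bool is_packed_stream, size_t rows, size_t stride,
                          size_t packed_elements, size_t vector_bytes) {
  // Packed streams (NUQ, I8) define their own length via PackedEnd/PackedBytes;
  // padded matrices are rows * stride elements.
  const size_t num = is_packed_stream ? packed_elements : rows * stride;
  // compress-inl may read up to two BF16 vectors past the end; since an
  // element is at least one byte, `vector_bytes` extra elements suffice.
  return num + vector_bytes;
}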

View File

@ -240,6 +240,8 @@ class MatPtr : public IFields {
// `CompressedArrayElements` is a wrapper function that has the same
// effect, but that requires a template argument, not `type`.
num_elements = NuqStream::PackedEnd(num_elements);
} else if (type == Type::kI8) {
num_elements = I8Stream::PackedEnd(num_elements);
}
return num_elements;
}
@ -324,7 +326,8 @@ class MatPtrT : public MatPtr {
}
PackedSpan<const MatT> PaddedSpan() const {
return MakeConstSpan(HWY_RCAST_ALIGNED(MatT*, ptr_), Rows() * Stride());
const size_t num = IsPacked() ? num_elements_ : Rows() * Stride();
return MakeConstSpan(HWY_RCAST_ALIGNED(MatT*, ptr_), num);
}
// For `compress-inl.h` functions, which assume contiguous streams and thus
@ -379,6 +382,9 @@ decltype(auto) CallUpcasted(const MatPtr* base, const Func& func,
} else if (base->GetType() == Type::kSFP) {
const MatPtrT<SfpStream> mat(*base);
return func(&mat, std::forward<Args>(args)...);
} else if (base->GetType() == Type::kI8) {
const MatPtrT<I8Stream> mat(*base);
return func(&mat, std::forward<Args>(args)...);
} else {
HWY_ABORT("Unhandled type %s.", TypeName(base->GetType()));
}
@ -410,6 +416,10 @@ decltype(auto) CallUpcastedSame(const MatPtr* base1, const MatPtr* base2,
const MatPtrT<SfpStream> mat1(*base1);
const MatPtrT<SfpStream> mat2(*base2);
return func(&mat1, &mat2, std::forward<Args>(args)...);
} else if (base1->GetType() == Type::kI8) {
const MatPtrT<I8Stream> mat1(*base1);
const MatPtrT<I8Stream> mat2(*base2);
return func(&mat1, &mat2, std::forward<Args>(args)...);
} else {
HWY_ABORT("Unhandled type %s.", TypeName(base1->GetType()));
}
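
CallUpcasted and CallUpcastedSame convert a type-erased MatPtr into a concrete MatPtrT<T> and invoke the callable with it, so supporting I8 weights only requires the extra branches above; call sites stay generic. A usage sketch based on the RMSNormInplaceBatched change earlier in this commit, assuming the surrounding Gemma headers are included and with an illustrative wrapper name (ApplyNormSketch):

// Usage sketch: the generic lambda is instantiated once per supported weight
// type, including the new MatPtrT<I8Stream> branch.
void ApplyNormSketch(const MatPtr& weights, float* inout, size_t size) {
  CallUpcasted(&weights, [&](const auto* weights_t) {
    // weights_t is a const MatPtrT<T>* for the concrete T stored in weights.
    RMSNormInplace(weights_t->PackedScale1(), /*w_ofs=*/0, inout, size,
                   hwy::Profiler::Get(), /*worker=*/0);
  });
}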