diff --git a/BUILD.bazel b/BUILD.bazel index f5fad45..ffd5435 100644 --- a/BUILD.bazel +++ b/BUILD.bazel @@ -349,6 +349,7 @@ cc_library( "ops/matmul_static_f32.cc", "ops/matmul_static_nuq.cc", "ops/matmul_static_sfp.cc", + "ops/matmul_static_i8.cc", ], hdrs = [ "ops/matmul_static.h", diff --git a/CMakeLists.txt b/CMakeLists.txt index 983d643..3eb2046 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -68,6 +68,7 @@ set(SOURCES compression/compress.h compression/nuq-inl.h compression/sfp-inl.h + compression/int-inl.h compression/types.h compression/test_util-inl.h evals/benchmark_helper.cc @@ -109,6 +110,7 @@ set(SOURCES ops/matmul_static_f32.cc ops/matmul_static_nuq.cc ops/matmul_static_sfp.cc + ops/matmul_static_i8.cc ops/matmul-inl.h ops/matmul.cc ops/matmul.h diff --git a/compression/BUILD.bazel b/compression/BUILD.bazel index a72db0b..c7232e6 100644 --- a/compression/BUILD.bazel +++ b/compression/BUILD.bazel @@ -80,6 +80,37 @@ cc_library( ], ) +cc_library( + name = "int", + textual_hdrs = ["int-inl.h"], + deps = [ + ":types", + "//:basics", + "@highway//:hwy", + ], +) + +cc_test( + name = "int_test", + size = "small", + timeout = "long", + srcs = ["int_test.cc"], + features = ["fully_static_link"], + linkstatic = True, + local_defines = ["HWY_IS_TEST"], + # for test_suite. + tags = ["hwy_ops_test"], + deps = [ + ":distortion", + ":int", + "@googletest//:gtest_main", # buildcleaner: keep + "//:test_util", + "@highway//:hwy", + "@highway//:hwy_test_util", + "@highway//:nanobenchmark", + ], +) + cc_library( name = "test_util", textual_hdrs = [ @@ -144,6 +175,7 @@ cc_library( textual_hdrs = ["compress-inl.h"], deps = [ ":distortion", + ":int", ":nuq", ":sfp", "//:basics", @@ -182,6 +214,7 @@ cc_library( name = "analyze", textual_hdrs = ["analyze.h"], deps = [ + ":int", ":nuq", ":sfp", ":types", diff --git a/compression/compress-inl.h b/compression/compress-inl.h index 18d8e35..35f0433 100644 --- a/compression/compress-inl.h +++ b/compression/compress-inl.h @@ -47,6 +47,7 @@ #include "hwy/highway.h" // After highway.h +#include "compression/int-inl.h" #include "compression/nuq-inl.h" #include "compression/sfp-inl.h" @@ -416,6 +417,34 @@ struct CompressTraits { } }; +// Integer quantization. +template <> +struct CompressTraits { + using Packed = I8Stream; + + template + static HWY_INLINE void Compress(DF df, const float* HWY_RESTRICT raw, + size_t num, CompressPerThread& tls, + const PackedSpan& packed, + const size_t packed_ofs) { + IntCodec::Enc(df, raw, num, packed, packed_ofs); + } + + template // Caller checks this is f32 or bf16 + static HWY_INLINE void Load2(D d, const PackedSpan& packed, + const size_t packed_ofs, hn::Vec& raw0, + hn::Vec& raw1) { + IntCodec::Dec2(d, packed, packed_ofs, raw0, raw1); + } + + template + static HWY_INLINE void DecompressAndZeroPad( + D d, const PackedSpan& packed, const size_t packed_ofs, + Raw* raw, const size_t num) { + IntCodec::DecompressAndZeroPad(d, packed, packed_ofs, raw, num); + } +}; + // Nonuniform quantization, 4.5 bits per element, two separate streams. 
template <> struct CompressTraits { @@ -737,9 +766,10 @@ template HWY_INLINE void Decompress1AndCompressInplace(DF df, T* HWY_RESTRICT inout, size_t num, const T1* HWY_RESTRICT p1, + const size_t p1_ofs, Func&& func) { const auto packed_inout = MakeSpan(inout, num); - const auto packed1 = MakeSpan(p1, num); + const auto packed1 = MakeSpan(p1, p1_ofs + num); using VF = hn::Vec; HWY_LANES_CONSTEXPR const size_t NF = hn::Lanes(df); @@ -749,7 +779,7 @@ HWY_INLINE void Decompress1AndCompressInplace(DF df, T* HWY_RESTRICT inout, VF v0, v1; Decompress2(df, packed_inout, i, v0, v1); VF v10, v11; - Decompress2(df, packed1, i, v10, v11); + Decompress2(df, packed1, p1_ofs + i, v10, v11); const VF out0 = func(df, v0, v10); const VF out1 = func(df, v1, v11); Compress2(df, out0, out1, packed_inout, i); @@ -765,7 +795,7 @@ HWY_INLINE void Decompress1AndCompressInplace(DF df, T* HWY_RESTRICT inout, hn::Store(hn::Zero(df), df, buf_inout + NF); hn::Store(hn::Zero(df), df, buf1 + NF); DecompressAndZeroPad(df, packed_inout, i, buf_inout, remaining); - DecompressAndZeroPad(df, packed1, i, buf1, remaining); + DecompressAndZeroPad(df, packed1, p1_ofs + i, buf1, remaining); const VF v0 = hn::Load(df, buf_inout); const VF v1 = hn::Load(df, buf_inout + NF); const VF v10 = hn::Load(df, buf1); @@ -827,10 +857,10 @@ template HWY_INLINE void Decompress2AndCompressTo(DF df, T* HWY_RESTRICT out, size_t num, const T1* HWY_RESTRICT p1, const T2* HWY_RESTRICT p2, - Func&& func) { + const size_t p2_ofs, Func&& func) { const auto packed_out = MakeSpan(out, num); const auto packed1 = MakeSpan(p1, num); - const auto packed2 = MakeSpan(p2, num); + const auto packed2 = MakeSpan(p2, p2_ofs + num); using VF = hn::Vec; HWY_LANES_CONSTEXPR const size_t NF = hn::Lanes(df); @@ -839,7 +869,7 @@ HWY_INLINE void Decompress2AndCompressTo(DF df, T* HWY_RESTRICT out, size_t num, for (; i <= num - 2 * NF; i += 2 * NF) { VF v10, v11, v20, v21; Decompress2(df, packed1, i, v10, v11); - Decompress2(df, packed2, i, v20, v21); + Decompress2(df, packed2, p2_ofs + i, v20, v21); const VF out0 = func(df, v10, v20); const VF out1 = func(df, v11, v21); Compress2(df, out0, out1, packed_out, i); @@ -856,7 +886,7 @@ HWY_INLINE void Decompress2AndCompressTo(DF df, T* HWY_RESTRICT out, size_t num, hn::Store(hn::Zero(df), df, buf1 + NF); hn::Store(hn::Zero(df), df, buf2 + NF); DecompressAndZeroPad(df, packed1, i, buf1, remaining); - DecompressAndZeroPad(df, packed2, i, buf2, remaining); + DecompressAndZeroPad(df, packed2, p2_ofs + i, buf2, remaining); const VF v10 = hn::Load(df, buf1); const VF v11 = hn::Load(df, buf1 + NF); const VF v20 = hn::Load(df, buf2); diff --git a/compression/compress_test.cc b/compression/compress_test.cc index 5455b1d..2ee7f63 100644 --- a/compression/compress_test.cc +++ b/compression/compress_test.cc @@ -243,7 +243,7 @@ class TestDecompressAndCompress { // Uses `out` so as not to overwrite `p`. 
Decompress1AndCompressInplace( - df, out.get(), num, p1.get(), + df, out.get(), num, p1.get(), /*p1_ofs=*/0, [](DF, VF v, VF v1) HWY_ATTR -> VF { return hn::Add(v, v1); }); HWY_ASSERT_ARRAY_EQ(expected2.get(), out.get(), num); @@ -251,9 +251,9 @@ class TestDecompressAndCompress { [](DF, VF v) HWY_ATTR -> VF { return v; }); HWY_ASSERT_ARRAY_EQ(expected1.get(), out.get(), num); - Decompress2AndCompressTo(df, out.get(), num, p.get(), p1.get(), - [](DF, VF v, VF v1) - HWY_ATTR -> VF { return hn::Add(v, v1); }); + Decompress2AndCompressTo( + df, out.get(), num, p.get(), p1.get(), /*p2_ofs=*/0, + [](DF, VF v, VF v1) HWY_ATTR -> VF { return hn::Add(v, v1); }); HWY_ASSERT_ARRAY_EQ(expected2.get(), out.get(), num); Decompress3AndCompressTo( diff --git a/compression/int-inl.h b/compression/int-inl.h new file mode 100644 index 0000000..969ec6d --- /dev/null +++ b/compression/int-inl.h @@ -0,0 +1,474 @@ +// Copyright 2023 Google LLC +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Normal include guard. +#ifndef THIRD_PARTY_GEMMA_CPP_COMPRESSION_INT_INL_H_ +#define THIRD_PARTY_GEMMA_CPP_COMPRESSION_INT_INL_H_ + +#include +#include +#include + +#include +#include + +#include "compression/types.h" +#include "util/basics.h" +#include "hwy/base.h" +#include "hwy/print-inl.h" + +#endif // THIRD_PARTY_GEMMA_CPP_COMPRESSION_INT_INL_H_ + +// Actual per-target include guard. +#if defined(THIRD_PARTY_GEMMA_CPP_COMPRESSION_INT_INL_TOGGLE) == \ + defined(HWY_TARGET_TOGGLE) +#ifdef THIRD_PARTY_GEMMA_CPP_COMPRESSION_INT_INL_TOGGLE +#undef THIRD_PARTY_GEMMA_CPP_COMPRESSION_INT_INL_TOGGLE +#else +#define THIRD_PARTY_GEMMA_CPP_COMPRESSION_INT_INL_TOGGLE +#endif + +#include "hwy/highway.h" + +HWY_BEFORE_NAMESPACE(); +namespace gcpp { +namespace HWY_NAMESPACE { +namespace hn = hwy::HWY_NAMESPACE; + +// Encode/decode functions. +class IntCodec { + using ScaleT = hwy::bfloat16_t; + static constexpr size_t kGroupSize = I8Stream::kGroupSize; + + // Offset (in bytes) of a group's start for packed_ofs (in elements) within a + // set of groups. 
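For orientation before the offset helper that follows: each I8Stream group stores two bf16 scalars (inverse scale, then zero point) followed by kGroupSize int8 values, so with kGroupSize = 128 a group occupies 2 * 2 + 128 = 132 bytes. A minimal standalone sketch of that arithmetic, with the constants written out for illustration (they follow from I8Stream::kGroupSize and ScaleT = hwy::bfloat16_t above):

#include <cstddef>
#include <cstdio>

int main() {
  constexpr size_t kGroupSize = 128;     // I8Stream::kGroupSize
  constexpr size_t kScaleBytes = 2;      // sizeof(hwy::bfloat16_t)
  constexpr size_t kBytesPerGroup = 2 * kScaleBytes + kGroupSize;  // 132
  const size_t offsets[] = {0, 127, 128, 300};
  for (size_t packed_ofs : offsets) {
    const size_t group_start = (packed_ofs / kGroupSize) * kBytesPerGroup;
    const size_t within = packed_ofs % kGroupSize;
    // e.g. packed_ofs = 300 -> byte 264 (third group), element 44 within it.
    printf("packed_ofs %3zu -> group byte %3zu, element %3zu in group\n",
           packed_ofs, group_start, within);
  }
  return 0;
}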
+ static constexpr size_t GroupByteOffset(size_t packed_ofs) { + const size_t kBytesPerGroup = (2 * sizeof(ScaleT)) + kGroupSize; + return (packed_ofs / kGroupSize) * kBytesPerGroup; + } + + public: + template + static HWY_INLINE void DequantizeGroup( + DBF dbf, const PackedSpan& packed, size_t packed_ofs, + hwy::bfloat16_t* HWY_RESTRICT raw, size_t num) { + using T = ScaleT; + const hn::ScalableTag df; + const hn::Rebind di32; + const hn::Rebind di16; + const hn::Rebind di8; + const hn::Twice> dbf16; + + const size_t N = hn::Lanes(di8); + const size_t N16 = hn::Lanes(dbf16); + using VI8 = hn::Vec; + using VF = hn::Vec; + + T inv_scale, zeropoint; + hwy::CopyBytes(&packed.ptr->i + GroupByteOffset(packed_ofs), &inv_scale, + sizeof(T)); + hwy::CopyBytes(&packed.ptr->i + GroupByteOffset(packed_ofs) + sizeof(T), + &zeropoint, sizeof(T)); + + float inv_scale_f = hwy::ConvertScalarTo(inv_scale); + float zeropoint_f = hwy::ConvertScalarTo(zeropoint); + + VF inv_scale_vec = hn::Set(df, inv_scale_f); + VF zeroscale_vec = hn::Set(df, -zeropoint_f * (inv_scale_f)); + + // Then iterate over remainder of packed, extracting num / N vectors and + // inserting into raw. + const size_t g_num = HWY_MIN(num, kGroupSize); + + const size_t current_offset = GroupByteOffset(packed_ofs) + + (2 * sizeof(T)) + (packed_ofs % kGroupSize); + size_t i = 0; + for (i = 0; i + 4 * N <= g_num; i += 4 * N) { + const VI8 val0 = + hn::LoadU(di8, &packed.ptr->i + current_offset + i + 0 * N); + const VI8 val1 = + hn::LoadU(di8, &packed.ptr->i + current_offset + i + 1 * N); + const VI8 val2 = + hn::LoadU(di8, &packed.ptr->i + current_offset + i + 2 * N); + const VI8 val3 = + hn::LoadU(di8, &packed.ptr->i + current_offset + i + 3 * N); + + const VF val0_f = + hn::ConvertTo(df, hn::PromoteTo(di32, hn::PromoteTo(di16, val0))); + const VF val1_f = + hn::ConvertTo(df, hn::PromoteTo(di32, hn::PromoteTo(di16, val1))); + const VF val2_f = + hn::ConvertTo(df, hn::PromoteTo(di32, hn::PromoteTo(di16, val2))); + const VF val3_f = + hn::ConvertTo(df, hn::PromoteTo(di32, hn::PromoteTo(di16, val3))); + + VF dequantized_val0 = hn::MulAdd(inv_scale_vec, val0_f, zeroscale_vec); + VF dequantized_val1 = hn::MulAdd(inv_scale_vec, val1_f, zeroscale_vec); + VF dequantized_val2 = hn::MulAdd(inv_scale_vec, val2_f, zeroscale_vec); + VF dequantized_val3 = hn::MulAdd(inv_scale_vec, val3_f, zeroscale_vec); + + hn::StoreU( + hn::OrderedDemote2To(dbf16, dequantized_val0, dequantized_val1), + dbf16, raw + i + 0 * N16); + hn::StoreU( + hn::OrderedDemote2To(dbf16, dequantized_val2, dequantized_val3), + dbf16, raw + i + 1 * N16); + } + for (; i + N <= g_num; i += N) { + const VI8 val0 = hn::LoadU(di8, &packed.ptr->i + current_offset + i); + const VF val0_f = + hn::ConvertTo(df, hn::PromoteTo(di32, hn::PromoteTo(di16, val0))); + VF dequantized_val0 = hn::MulAdd(inv_scale_vec, val0_f, zeroscale_vec); + const hn::Rebind dbf_half; + hn::StoreU(hn::DemoteTo(dbf_half, dequantized_val0), dbf_half, raw + i); + } + if (i < g_num) { + const size_t remaining = g_num - i; + const VI8 val0 = + hn::LoadN(di8, &packed.ptr->i + current_offset + i, remaining); + const VF val0_f = + hn::ConvertTo(df, hn::PromoteTo(di32, hn::PromoteTo(di16, val0))); + VF dequantized_val0 = hn::MulAdd(inv_scale_vec, val0_f, zeroscale_vec); + const hn::Rebind dbf_half; + hn::StoreN(hn::DemoteTo(dbf_half, dequantized_val0), dbf_half, raw + i, + remaining); + } + } + + // Dequantizes `num` floats from `packed` into `raw`. 
`packed` points to + // compressed storage and `packed_ofs` indicates the destination offset + // within it, in number of elements. Scaling values are interleaved with int + // values to allow for easier unpacking. + template + static HWY_INLINE void DequantizeGroup( + DF df, const PackedSpan& packed, size_t packed_ofs, + float* HWY_RESTRICT raw, size_t num) { + using T = ScaleT; + const hn::Rebind di32; + const hn::Rebind di16; + const hn::Rebind di8; + const hn::Rebind df8; + + const size_t N = hn::Lanes(di8); + const size_t N32 = hn::Lanes(df); + using VI8 = hn::Vec; + using VF = hn::Vec; + + // HWY_ASSERT(num % 2 * N == 0); + + // Load scale and zero point from the beginning - ensure correct pointer + // offset. + T inv_scale, zeropoint; + hwy::CopyBytes(&packed.ptr->i + GroupByteOffset(packed_ofs), &inv_scale, + sizeof(T)); + hwy::CopyBytes(&packed.ptr->i + GroupByteOffset(packed_ofs) + sizeof(T), + &zeropoint, sizeof(T)); + + float inv_scale_f = hwy::ConvertScalarTo(inv_scale); + float zeropoint_f = hwy::ConvertScalarTo(zeropoint); + + VF inv_scale_vec = hn::Set(df, inv_scale_f); + VF zeroscale_vec = hn::Set(df, -zeropoint_f * (inv_scale_f)); + + // Then iterate over remainder of packed, extracting num / N vectors and + // inserting into raw. + const size_t g_num = HWY_MIN(num, kGroupSize); + + const size_t current_offset = GroupByteOffset(packed_ofs) + + (2 * sizeof(T)) + (packed_ofs % kGroupSize); + + size_t i = 0; + for (; i + 2 * N <= g_num; i += 2 * N) { + const VI8 val0 = + hn::LoadU(di8, &packed.ptr->i + current_offset + i + 0 * N); + const VI8 val1 = + hn::LoadU(di8, &packed.ptr->i + current_offset + i + 1 * N); + + const VF val0_f = + hn::ConvertTo(df, hn::PromoteTo(di32, hn::PromoteTo(di16, val0))); + const VF val1_f = + hn::ConvertTo(df, hn::PromoteTo(di32, hn::PromoteTo(di16, val1))); + + VF dequantized_val0 = hn::MulAdd(inv_scale_vec, val0_f, zeroscale_vec); + VF dequantized_val1 = hn::MulAdd(inv_scale_vec, val1_f, zeroscale_vec); + + hn::StoreU(dequantized_val0, df, raw + i + 0 * N32); + hn::StoreU(dequantized_val1, df, raw + i + 1 * N32); + } + for (; i + N <= g_num; i += N) { + const VI8 val0 = hn::LoadU(di8, &packed.ptr->i + current_offset + i); + const VF val0_f = + hn::ConvertTo(df, hn::PromoteTo(di32, hn::PromoteTo(di16, val0))); + VF dequantized_val0 = hn::MulAdd(inv_scale_vec, val0_f, zeroscale_vec); + hn::StoreU(dequantized_val0, df, raw + i); + } + if (i < g_num) { + const size_t remaining = g_num - i; + const VI8 val0 = + hn::LoadN(di8, &packed.ptr->i + current_offset + i, remaining); + const VF val0_f = + hn::ConvertTo(df, hn::PromoteTo(di32, hn::PromoteTo(di16, val0))); + VF dequantized_val0 = hn::MulAdd(inv_scale_vec, val0_f, zeroscale_vec); + hn::StoreN(dequantized_val0, df, raw + i, remaining); + } + } + + // Quantizes `num` floats from `raw` into `packed`. `packed` points to + // compressed storage and `packed_ofs` indicates the destination offset + // within it, in number of elements. Scaling values are interleaved with + // int values to allow for easier unpacking. 
+ template + static HWY_INLINE void QuantizeGroup(DF df, const float* HWY_RESTRICT raw, + size_t num, + const PackedSpan& packed, + size_t packed_ofs) { + using T = ScaleT; + const hn::Repartition di32; + const hn::Half> di16; + const hn::Half> di8; + + const size_t N = hn::Lanes(df); + using VI8 = hn::Vec; + using VF = hn::Vec; + + HWY_DASSERT(packed_ofs % kGroupSize == 0); + HWY_DASSERT(num % 2 * N == 0); + + // Calculate min/max using SIMD + float min_val = hwy::HighestValue(); + float max_val = hwy::LowestValue(); + VF vmin = hn::Set(df, hwy::HighestValue()); + VF vmax = hn::Set(df, hwy::LowestValue()); + + size_t j = 0; + for (; j + N <= num; j += N) { + const VF xi = hn::LoadU(df, raw + j); + vmin = hn::Min(vmin, xi); + vmax = hn::Max(vmax, xi); + } + + min_val = hn::ReduceMin(df, vmin); + max_val = hn::ReduceMax(df, vmax); + + for (; j < num; ++j) { + min_val = HWY_MIN(min_val, raw[j]); + max_val = HWY_MAX(max_val, raw[j]); + } + + // Calculate range, scale and zeropoint + float x_range = max_val - min_val; + x_range = x_range == 0.0f ? 1.0f : x_range; + const float scale_f = 255.0f / x_range; + const float zeropoint_f = static_cast( + static_cast(-scale_f * min_val - 128.0f)); // Correct casting + + const T scale = hwy::ConvertScalarTo(scale_f); + // inv_scale is used for all dequantization. + const T inv_scale = hwy::ConvertScalarTo(1.0f / scale_f); + const T zeropoint = hwy::ConvertScalarTo(zeropoint_f); + memcpy(&packed.ptr->i + GroupByteOffset(packed_ofs), &inv_scale, sizeof(T)); + memcpy(&packed.ptr->i + GroupByteOffset(packed_ofs) + sizeof(T), &zeropoint, + sizeof(T)); + + const size_t g_num = HWY_MIN(num, kGroupSize); + + VF mul = hn::Set(df, hwy::ConvertScalarTo(scale)); + VF add = hn::Set(df, hwy::ConvertScalarTo(zeropoint)); + + const size_t current_offset = GroupByteOffset(packed_ofs) + + (2 * sizeof(T)) + (packed_ofs % kGroupSize); + + size_t i = 0; + for (; i + 2 * N <= g_num; i += 2 * N) { + const VI8 val0 = hn::DemoteTo( + di8, + hn::DemoteTo(di16, NearestInt(hn::MulAdd( + mul, hn::LoadU(df, raw + i + 0 * N), add)))); + const VI8 val1 = hn::DemoteTo( + di8, + hn::DemoteTo(di16, NearestInt(hn::MulAdd( + mul, hn::LoadU(df, raw + i + 1 * N), add)))); + + hn::StoreU(val0, di8, &packed.ptr->i + current_offset + i + 0 * N); + hn::StoreU(val1, di8, &packed.ptr->i + current_offset + i + 1 * N); + } + + size_t remaining = g_num - i; + + HWY_DASSERT(remaining < 2 * N); + if (HWY_UNLIKELY(remaining == 0)) return; + + if (remaining > N) { + const VI8 val0 = hn::DemoteTo( + di8, hn::DemoteTo(di16, NearestInt(hn::MulAdd( + mul, hn::LoadU(df, raw + i), add)))); + hn::StoreU(val0, di8, &packed.ptr->i + current_offset + i); + + const size_t remaining1 = remaining - N; + const VI8 val1 = hn::DemoteTo( + di8, + hn::DemoteTo(di16, + NearestInt(hn::MulAdd( + mul, hn::LoadN(df, raw + i + N, remaining1), add)))); + hn::StoreN(val1, di8, &packed.ptr->i + current_offset + i + N, + remaining1); + } else { // remaining <= N + const VI8 val0 = hn::DemoteTo( + di8, hn::DemoteTo(di16, + NearestInt(hn::MulAdd( + mul, hn::LoadN(df, raw + i, remaining), add)))); + hn::StoreN(val0, di8, &packed.ptr->i + current_offset + i, remaining); + } + } + + // Encodes `num` floats from `raw` into `packed`. `packed` points to + // compressed storage and `packed_ofs` indicates the destination offset + // within it, in number of elements. Scaling values are interleaved with + // int + // values to allow for easier unpacking. 
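As a reference for what QuantizeGroup above computes per group, and what Enc below drives group by group, here is a hedged scalar model of the affine int8 mapping with made-up sample values. Saturation stands in for Highway's saturating DemoteTo, and the real code stores the scale and zero point as bf16, which adds a little extra rounding not modeled here:

#include <algorithm>
#include <cmath>
#include <cstddef>
#include <cstdint>
#include <cstdio>

int main() {
  const float raw[] = {-1.5f, -0.25f, 0.0f, 0.75f, 2.0f};
  const size_t n = sizeof(raw) / sizeof(raw[0]);

  float mn = raw[0], mx = raw[0];
  for (size_t i = 1; i < n; ++i) {
    mn = std::min(mn, raw[i]);
    mx = std::max(mx, raw[i]);
  }
  float range = mx - mn;
  if (range == 0.0f) range = 1.0f;
  const float scale = 255.0f / range;                        // 255 int8 steps
  const float zeropoint = std::trunc(-scale * mn - 128.0f);  // maps mn near -128
  const float inv_scale = 1.0f / scale;

  for (size_t i = 0; i < n; ++i) {
    const long q_wide = std::lround(scale * raw[i] + zeropoint);
    const int8_t q = static_cast<int8_t>(std::clamp(q_wide, -128L, 127L));
    // Dequantization matches the MulAdd(inv_scale, q, -zeropoint * inv_scale)
    // form used by the decode paths above.
    const float dq = inv_scale * q - zeropoint * inv_scale;
    printf("%+.3f -> %4d -> %+.4f\n", raw[i], q, dq);
  }
  return 0;
}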
+ template + static HWY_INLINE void Enc(DF df, const float* HWY_RESTRICT raw, + const size_t num, + const PackedSpan& packed, + size_t packed_ofs) { + HWY_ASSERT(packed_ofs % kGroupSize == 0); + + const size_t num_groups = hwy::DivCeil(num, kGroupSize); + + size_t current_offset = packed_ofs; + for (size_t g = 0; g < num_groups; ++g) { + const size_t g_num = HWY_MIN(num - g * kGroupSize, kGroupSize); + const float* HWY_RESTRICT g_in = raw + g * kGroupSize; + + QuantizeGroup(df, g_in, g_num, packed, current_offset); + current_offset += g_num; + } + } + + // Decompresses to two bf16 vectors. `packed_ofs` must be a multiple of two + // vectors so that we only have to load one group's table. + template + static HWY_INLINE void Dec2(DBF dbf, const PackedSpan& packed, + const size_t packed_ofs, hn::Vec& raw0, + hn::Vec& raw1) { + const hn::Repartition df; + using VF = hn::Vec; + const size_t NF = hn::Lanes(df); + + HWY_ASSERT(packed_ofs % 2 * NF == 0); + + VF raw0_f, raw1_f, raw2_f, raw3_f; + Dec2(df, packed, packed_ofs + 0 * 2 * NF, raw0_f, raw1_f); + Dec2(df, packed, packed_ofs + 1 * 2 * NF, raw2_f, raw3_f); + + raw0 = hn::OrderedDemote2To(dbf, raw0_f, raw1_f); + raw1 = hn::OrderedDemote2To(dbf, raw2_f, raw3_f); + } + + // Decompresses to two f32 vectors. `packed_ofs` must be a multiple of two + // vectors so that we only have to load one group's table. + template + static HWY_INLINE void Dec2(DF df, const PackedSpan& packed, + const size_t packed_ofs, hn::Vec& raw0, + hn::Vec& raw1) { + using T = ScaleT; + const hn::Rebind di32; + const hn::Rebind di16; + const hn::Rebind di8; + const hn::Rebind df8; + + const size_t N = hn::Lanes(di8); + using VI8 = hn::Vec; + using VF = hn::Vec; + + T inv_scale, zeropoint; + hwy::CopyBytes(&packed.ptr->i + GroupByteOffset(packed_ofs), &inv_scale, + sizeof(T)); + hwy::CopyBytes(&packed.ptr->i + GroupByteOffset(packed_ofs) + sizeof(T), + &zeropoint, sizeof(T)); + + float inv_scale_f = hwy::ConvertScalarTo(inv_scale); + float zeropoint_f = hwy::ConvertScalarTo(zeropoint); + + VF inv_scale_vec = hn::Set(df, inv_scale_f); + VF zeroscale_vec = hn::Set(df, -zeropoint_f * (inv_scale_f)); + + const size_t current_offset = GroupByteOffset(packed_ofs) + + (2 * sizeof(T)) + (packed_ofs % kGroupSize); + + const VI8 val0 = hn::LoadU(di8, &packed.ptr->i + current_offset + 0 * N); + const VI8 val1 = hn::LoadU(di8, &packed.ptr->i + current_offset + 1 * N); + + const VF val0_f = + hn::ConvertTo(df, hn::PromoteTo(di32, hn::PromoteTo(di16, val0))); + const VF val1_f = + hn::ConvertTo(df, hn::PromoteTo(di32, hn::PromoteTo(di16, val1))); + + raw0 = hn::MulAdd(inv_scale_vec, val0_f, zeroscale_vec); + raw1 = hn::MulAdd(inv_scale_vec, val1_f, zeroscale_vec); + } + + template > + static HWY_INLINE void DecompressAndZeroPad( + D d, const PackedSpan& packed, size_t packed_ofs, + Raw* HWY_RESTRICT raw, size_t num) { + if (num == 0) return; + + const size_t N = hn::Lanes(d); + const size_t padded_num = hwy::RoundUpTo(num, N); + if (padded_num > num) { + hwy::ZeroBytes(raw + num, (padded_num - num) * sizeof(Raw)); + } + + size_t current_packed_ofs = packed_ofs; + Raw* HWY_RESTRICT current_raw = raw; + size_t num_to_decompress = num; + + if (size_t within_group = current_packed_ofs % kGroupSize; + within_group != 0) { + const size_t remaining_in_group = kGroupSize - within_group; + const size_t num_in_first_group = + HWY_MIN(num_to_decompress, remaining_in_group); + DequantizeGroup(d, packed, current_packed_ofs, current_raw, + num_in_first_group); + current_packed_ofs += 
num_in_first_group; + current_raw += num_in_first_group; + num_to_decompress -= num_in_first_group; + } + + if (num_to_decompress == 0) return; + + HWY_DASSERT(current_packed_ofs % kGroupSize == 0); + + const size_t num_full_groups = num_to_decompress / kGroupSize; + for (size_t g = 0; g < num_full_groups; ++g) { + DequantizeGroup(d, packed, current_packed_ofs, current_raw, kGroupSize); + current_packed_ofs += kGroupSize; + current_raw += kGroupSize; + } + + const size_t remaining = num_to_decompress % kGroupSize; + if (remaining != 0) { + DequantizeGroup(d, packed, current_packed_ofs, current_raw, remaining); + } + } +}; // IntCodec + +// NOLINTNEXTLINE(google-readability-namespace-comments) +} // namespace HWY_NAMESPACE +} // namespace gcpp +HWY_AFTER_NAMESPACE(); + +#endif // THIRD_PARTY_GEMMA_CPP_COMPRESSION_INT_INL_H_ diff --git a/compression/int_test.cc b/compression/int_test.cc new file mode 100644 index 0000000..f427384 --- /dev/null +++ b/compression/int_test.cc @@ -0,0 +1,494 @@ +// Copyright 2023 Google LLC +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// SFP uses ConcatEven/Odd which are not supported; skip SVE for faster tests. +#ifndef HWY_DISABLED_TARGETS +#define HWY_DISABLED_TARGETS (HWY_SCALAR | HWY_SVE) +#endif + +#include +#include +#include + +#include "util/test_util.h" +#include "hwy/aligned_allocator.h" +#include "hwy/base.h" +#include "hwy/tests/hwy_gtest.h" +#include "hwy/tests/test_util.h" + +// clang-format off +#undef HWY_TARGET_INCLUDE +#define HWY_TARGET_INCLUDE "compression/int_test.cc" // NOLINT +// clang-format on +#include "hwy/foreach_target.h" // IWYU pragma: keep +#include "hwy/highway.h" +// After highway.h +#include "compression/int-inl.h" +#include "hwy/tests/test_util-inl.h" + +HWY_BEFORE_NAMESPACE(); +namespace gcpp { +namespace HWY_NAMESPACE { +namespace hn = hwy::HWY_NAMESPACE; + +static constexpr size_t kGroupSize = I8Stream::kGroupSize; +static constexpr float kTolerance = 50000.0f; + +// Can encode and decode sub-regions. +// Quantizes and de-quantizes a single (potentially partial) group to check +// that the quantizer is working correctly. 
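DecompressAndZeroPad above, like the tests that follow, depends on splitting an arbitrary (packed_ofs, num) request into a partial leading group, whole groups, and a partial trailing group. A tiny standalone sketch of that decomposition, with the group size hard-coded to 128 and an arbitrary example request:

#include <algorithm>
#include <cstddef>
#include <cstdio>

int main() {
  constexpr size_t kGroupSize = 128;
  const size_t packed_ofs = 300, num = 500;  // unaligned example request
  const size_t within = packed_ofs % kGroupSize;
  const size_t head = (within != 0) ? std::min(num, kGroupSize - within) : 0;
  const size_t full_groups = (num - head) / kGroupSize;
  const size_t tail = (num - head) % kGroupSize;
  // Prints: head 84, full groups 3, tail 32  (84 + 3 * 128 + 32 == 500).
  printf("head %zu, full groups %zu, tail %zu\n", head, full_groups, tail);
  return 0;
}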
+struct TestQuantize { + template + HWY_INLINE void operator()(T /*unused*/, D d) { + const size_t total = kGroupSize / 2; // already padded + const hn::ScalableTag df; + + auto in = hwy::AllocateAligned(total); // Enc() requires f32 + auto dec1 = hwy::AllocateAligned(total); + auto dec2 = hwy::AllocateAligned(total); + auto dec3 = hwy::AllocateAligned(total); + auto i8_stream = hwy::AllocateAligned(I8Stream::PackedEnd(total)); + HWY_ASSERT(in && dec1 && dec2 && dec3 && i8_stream); + const auto int_span = MakeSpan(i8_stream.get(), total); + + hwy::RandomState rng; + for (size_t i = 0; i < total; ++i) { + in[i] = static_cast(RandomGaussian(rng)); + } + + IntCodec::QuantizeGroup(df, in.get(), total, int_span, 0); + + IntCodec::DequantizeGroup(d, MakeConst(int_span), 0, dec1.get(), total); + + const float epsilon = + hwy::ConvertScalarTo(hwy::Epsilon()); + const float tolerance = kTolerance * epsilon; + + for (size_t i = 0; i < total; ++i) { + const float expected_value = static_cast(in[i]); + const float actual_value = hwy::ConvertScalarTo(dec1[i]); + + if (!(expected_value - tolerance <= actual_value && + actual_value <= expected_value + tolerance)) { + fprintf(stderr, + "in[%zu] = %f, dec1[%zu] = %f, tolerance = %f, epsilon = %f\n", + i, expected_value, i, actual_value, tolerance, epsilon); + } + } + + // Check that ::Enc works correctly as well. + IntCodec::Enc(df, in.get(), total, int_span, 0); + + IntCodec::DequantizeGroup(d, MakeConst(int_span), 0, dec2.get(), total); + + for (size_t i = 0; i < total; ++i) { + const float expected_value = static_cast(in[i]); + const float actual_value = hwy::ConvertScalarTo(dec2[i]); + + if (!(expected_value - tolerance <= actual_value && + actual_value <= expected_value + tolerance)) { + fprintf(stderr, + "in[%zu] = %f, dec2[%zu] = %f, tolerance = %f, epsilon = %f\n", + i, expected_value, i, actual_value, tolerance, epsilon); + } + } + + // Check that ::DecompressAndZeroPad works correctly for one group as well. + IntCodec::Enc(df, in.get(), total, int_span, 0); + + IntCodec::DecompressAndZeroPad(d, MakeConst(int_span), 0, dec3.get(), + total); + + for (size_t i = 0; i < total; ++i) { + const float expected_value = static_cast(in[i]); + const float actual_value = hwy::ConvertScalarTo(dec3[i]); + + if (!(expected_value - tolerance <= actual_value && + actual_value <= expected_value + tolerance)) { + fprintf(stderr, + "in[%zu] = %f, dec3[%zu] = %f, tolerance = %f, epsilon = %f\n", + i, expected_value, i, actual_value, tolerance, epsilon); + HWY_ASSERT(false); + } + } + } +}; + +void TestQuantizeBF16() { hn::ForGEVectors<128, TestQuantize>()(BF16()); } +void TestQuantizeF32() { hn::ForGEVectors<128, TestQuantize>()(float()); } + +// Can encode and decode sub-regions. +// Quantizes and de-quantizes multiple (potentially partial) groups to check +// that DecompressAndZeroPad is working correctly. 
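Each of these tests sizes its packed buffer with I8Stream::PackedEnd, added to compression/types.h later in this diff: per group of 128 values it reserves one bf16 inverse scale, one bf16 zero point, and one byte per value. A worked example of that sizing, with the constants spelled out for illustration:

#include <cstddef>
#include <cstdio>

// Mirrors I8Stream::PackedEnd: DivCeil(capacity, 128) groups, two bf16 scalars
// per group, plus one int8 per value; the result doubles as a byte count.
size_t PackedEndSketch(size_t capacity) {
  constexpr size_t kGroupSize = 128;
  constexpr size_t kScaleBytes = 2;  // sizeof(hwy::bfloat16_t)
  const size_t num_groups = (capacity + kGroupSize - 1) / kGroupSize;
  return 2 * kScaleBytes * num_groups + capacity;
}

int main() {
  const size_t capacities[] = {128, 320};
  // 128 values -> 1 group  -> 4 + 128  = 132 bytes
  // 320 values -> 3 groups -> 12 + 320 = 332 bytes
  for (size_t capacity : capacities) {
    printf("%zu values -> %zu bytes\n", capacity, PackedEndSketch(capacity));
  }
  return 0;
}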
+struct TestMultiGroup { + template + HWY_INLINE void operator()(T /*unused*/, D d) { + const hn::Repartition df; + const size_t total = kGroupSize * 2 + kGroupSize / 4; // already padded + + auto in = hwy::AllocateAligned(total); // Enc() requires f32 + auto dec1 = hwy::AllocateAligned(total); + auto dec2 = hwy::AllocateAligned(total); + auto i8_stream = hwy::AllocateAligned(I8Stream::PackedEnd(total)); + HWY_ASSERT(in && dec1 && i8_stream); + const auto int_span = MakeSpan(i8_stream.get(), total); + + hwy::RandomState rng; + for (size_t i = 0; i < total; ++i) { + in[i] = static_cast(RandomGaussian(rng)); + } + + const float epsilon = + hwy::ConvertScalarTo(hwy::Epsilon()); + const float tolerance = kTolerance * epsilon; + + // Check that ::DecompressAndZeroPad works correctly for one group as well. + IntCodec::Enc(df, in.get(), total, int_span, 0); + + IntCodec::DecompressAndZeroPad(d, MakeConst(int_span), 0, dec2.get(), + total); + + for (size_t i = 0; i < total; ++i) { + const float expected_value = static_cast(in[i]); + const float actual_value = hwy::ConvertScalarTo(dec2[i]); + + if (!(expected_value - tolerance <= actual_value && + actual_value <= expected_value + tolerance)) { + fprintf(stderr, + "in[%zu] = %f, dec2[%zu] = %f, tolerance = %f, epsilon = %f\n", + i, expected_value, i, actual_value, tolerance, epsilon); + HWY_ASSERT(false); + } + } + } +}; + +void TestMultiGroupBF16() { hn::ForGEVectors<128, TestMultiGroup>()(BF16()); } +void TestMultiGroupF32() { hn::ForGEVectors<128, TestMultiGroup>()(float()); } + +// Can encode and decode sub-regions. +struct TestOffset { + template + HWY_INLINE void operator()(T /*unused*/, D d) { + const hn::Repartition df; + const size_t total = 10 * kGroupSize; // already padded + const size_t kMidLen = 2 * kGroupSize; // length of middle piece + + auto in = hwy::AllocateAligned(total); // Enc() requires f32 + auto dec1 = hwy::AllocateAligned(total); + auto dec2 = hwy::AllocateAligned(kMidLen); + auto i8_stream = hwy::AllocateAligned(I8Stream::PackedEnd(total)); + HWY_ASSERT(in && dec1 && dec2 && i8_stream); + const auto int_span = MakeSpan(i8_stream.get(), total); + + hwy::RandomState rng; + for (size_t i = 0; i < total; ++i) { + in[i] = static_cast(RandomGaussian(rng)); + } + + // Encode + decode everything + (void)IntCodec::Enc(df, in.get(), total, int_span, 0); + IntCodec::DecompressAndZeroPad(d, MakeConst(int_span), 0, dec1.get(), + total); + + MaybeCheckInitialized(dec1.get(), total * sizeof(T)); + + // Overwrite middle with first inputs + const size_t offset = 5 * kGroupSize; + (void)IntCodec::Enc(df, in.get(), kMidLen, int_span, offset); + + // Decoded middle now matches previously decoded first + IntCodec::DecompressAndZeroPad(d, MakeConst(int_span), offset, dec2.get(), + kMidLen); + MaybeCheckInitialized(dec2.get(), kMidLen * sizeof(T)); + + for (size_t i = 0; i < kMidLen; ++i) { + HWY_ASSERT(dec1[i] == dec2[i]); + } + } +}; + +void TestOffsetBF16() { hn::ForGEVectors<128, TestOffset>()(BF16()); } +void TestOffsetF32() { hn::ForGEVectors<128, TestOffset>()(float()); } + +// Can encode and decode sub-regions. 
+struct TestUnalignedOffset { + template + HWY_INLINE void operator()(T /*unused*/, D d) { + const hn::Repartition df; + const size_t total = 10 * kGroupSize; // already padded + + const int num_unaligned_offsets = 4; + const std::array unaligned_offsets = { + 4, kGroupSize + 100, 2 * kGroupSize + 100, 3 * kGroupSize + 100}; + const std::array num = {4, 16, 32, 64}; + + for (int i = 0; i < num_unaligned_offsets; ++i) { + const size_t unaligned_offset = unaligned_offsets[i]; + const size_t num_decompressed = num[i]; + + auto in = hwy::AllocateAligned(total); // Enc() requires f32 + auto dec1 = hwy::AllocateAligned(total); + auto i8_stream = + hwy::AllocateAligned(I8Stream::PackedEnd(total)); + auto dec2 = hwy::AllocateAligned(num_decompressed); + HWY_ASSERT(in && dec1 && dec2 && i8_stream); + const auto int_span = MakeSpan(i8_stream.get(), total); + + hwy::RandomState rng; + for (size_t i = 0; i < total; ++i) { + in[i] = static_cast(RandomGaussian(rng)); + } + + // // Encode + decode everything + (void)IntCodec::Enc(df, in.get(), total, int_span, 0); + IntCodec::DecompressAndZeroPad(d, MakeConst(int_span), 0, dec1.get(), + total); + + IntCodec::DecompressAndZeroPad(d, MakeConst(int_span), unaligned_offset, + dec2.get(), num_decompressed); + + for (size_t i = 0; i < num_decompressed; ++i) { + T expected = hwy::ConvertScalarTo(dec1[unaligned_offset + i]); + T actual = hwy::ConvertScalarTo(dec2[i]); + + HWY_ASSERT_EQ(expected, actual); + } + } + } +}; + +void TestUnalignedOffsetBF16() { + hn::ForGEVectors<128, TestUnalignedOffset>()(BF16()); +} +void TestUnalignedOffsetF32() { + hn::ForGEVectors<128, TestUnalignedOffset>()(float()); +} + +// Can encode and decode sub-regions. +// Uses Dec2 to decode all elements in the packed buffer, then +// compares against DecompressAndZeroPad. +struct TestDec2 { + template + HWY_INLINE void operator()(T /*unused*/, D d) { + const hn::Repartition df; + // incl. partial group to test partial group handling + const size_t total = kGroupSize * 10 + kGroupSize / 2; + + auto in = hwy::AllocateAligned(total); // Enc() requires f32 + auto dec0 = hwy::AllocateAligned(total); + auto dec1 = hwy::AllocateAligned(total); + auto i8_stream = hwy::AllocateAligned(I8Stream::PackedEnd(total)); + HWY_ASSERT(in && dec0 && dec1 && i8_stream); + const auto int_span = MakeSpan(i8_stream.get(), total); + + hwy::RandomState rng; + for (size_t i = 0; i < total; ++i) { + in[i] = static_cast(RandomGaussian(rng)); + } + + // Non-interleaved encode + decode for comparison + (void)IntCodec::Enc(df, in.get(), total, int_span, 0); + IntCodec::DecompressAndZeroPad(d, MakeConst(int_span), 0, dec0.get(), + total); + + // Encode + decode everything + (void)IntCodec::Enc(df, in.get(), total, int_span, 0); + + using V = hn::Vec; + const size_t N = Lanes(d); + + for (size_t i = 0; i < total; i += 2 * N) { + V f0, f1; + IntCodec::Dec2(d, MakeConst(int_span), i, f0, f1); + + hn::StoreU(f0, d, dec1.get() + i + 0 * N); + hn::StoreU(f1, d, dec1.get() + i + 1 * N); + } + + for (size_t i = 0; i < total; ++i) { + if (dec0[i] != dec1[i]) { + fprintf(stderr, "dec0[%zu] = %g, dec1[%zu] = %g\n", i, + hwy::ConvertScalarTo(dec0[i]), i, + hwy::ConvertScalarTo(dec1[i])); + } + + HWY_ASSERT(dec0[i] == dec1[i]); + } + } +}; + +void TestDec2BF16() { hn::ForGEVectors<128, TestDec2>()(BF16()); } +void TestDec2F32() { hn::ForGEVectors<128, TestDec2>()(float()); } + +// Tests that DecompressAndZeroPad fully populates the output array. +// This is intended to catch uninitialized value errors. 
+struct TestDequantizeAndZeroPad { + template + HWY_INLINE void operator()(T /*unused*/, D d) { + const hn::ScalableTag df; + constexpr size_t kSize = 4096; + auto in = hwy::AllocateAligned(kSize); + auto actual_dec = hwy::AllocateAligned(kSize); + auto i8_stream = hwy::AllocateAligned(I8Stream::PackedEnd(kSize)); + HWY_ASSERT(in && actual_dec && i8_stream); + const auto int_span = MakeSpan(i8_stream.get(), kSize); + + // Fill with a known pattern. + for (size_t i = 0; i < kSize; ++i) { + in[i] = static_cast(i) - 128.0f; + } + + IntCodec::Enc(df, in.get(), kSize, int_span, 0); + + // Initialize with a sentinel value to detect if it's overwritten. + const T sentinel = hwy::ConvertScalarTo(-999.0f); + std::fill(actual_dec.get(), actual_dec.get() + kSize, sentinel); + + IntCodec::DecompressAndZeroPad(d, MakeConst(int_span), 0, actual_dec.get(), + kSize); + + MaybeCheckInitialized(actual_dec.get(), kSize * sizeof(T)); + + // Check that all sentinels were overwritten. + for (size_t i = 0; i < kSize; ++i) { + EXPECT_NE(hwy::ConvertScalarTo(actual_dec[i]), + hwy::ConvertScalarTo(sentinel)) + << " at index " << i; + } + } +}; + +void TestAllDequantizeAndZeroPad() { + hn::ForGEVectors<128, TestDequantizeAndZeroPad>()(BF16()); + hn::ForGEVectors<128, TestDequantizeAndZeroPad>()(float()); +} + +// Tests that DecompressAndZeroPad works correctly for small and unaligned +// inputs. This is intended to catch uninitialized value errors in remainder +// handling. +struct TestSmallDequantize { + template + HWY_INLINE void operator()(T /*unused*/, D d) { + const hn::ScalableTag df; + constexpr size_t kGroupSize = I8Stream::kGroupSize; + constexpr size_t kMaxNum = kGroupSize * 3; + auto in = hwy::AllocateAligned(kMaxNum); + auto actual_dec = hwy::AllocateAligned(kMaxNum); + auto i8_stream = + hwy::AllocateAligned(I8Stream::PackedEnd(kMaxNum)); + HWY_ASSERT(in && actual_dec && i8_stream); + const auto int_span = + MakeSpan(i8_stream.get(), I8Stream::PackedEnd(kMaxNum)); + + // Fill with a known pattern. + for (size_t i = 0; i < kMaxNum; ++i) { + in[i] = static_cast(i) - 128.0f; + } + + IntCodec::Enc(df, in.get(), kMaxNum, int_span, 0); + + for (size_t num = 1; num < kGroupSize * 2; ++num) { + for (size_t offset = 0; offset < kGroupSize; offset += 16) { + const T sentinel = hwy::ConvertScalarTo(-999.0f); + std::fill(actual_dec.get(), actual_dec.get() + num, sentinel); + + IntCodec::DecompressAndZeroPad(d, MakeConst(int_span), offset, + actual_dec.get(), num); + + MaybeCheckInitialized(actual_dec.get(), num); + + // Check that all sentinels were overwritten. + for (size_t i = 0; i < num; ++i) { + EXPECT_NE(hwy::ConvertScalarTo(actual_dec[i]), + hwy::ConvertScalarTo(sentinel)) + << " at index " << i << " for num=" << num + << " offset=" << offset; + } + } + } + } +}; + +void TestAllSmallDequantize() { + hn::ForGEVectors<128, TestSmallDequantize>()(BF16()); + hn::ForGEVectors<128, TestSmallDequantize>()(float()); +} + +// Tests that DecompressAndZeroPad works correctly for a specific failing input. +struct TestSpecificDequantize { + template + HWY_INLINE void operator()(T /*unused*/, D d) { + const hn::ScalableTag df; + constexpr size_t kSize = 737280; + auto in = hwy::AllocateAligned(kSize); + auto actual_dec = hwy::AllocateAligned(kSize); + auto i8_stream = hwy::AllocateAligned(I8Stream::PackedEnd(kSize)); + HWY_ASSERT(in && actual_dec && i8_stream); + const auto int_span = MakeSpan(i8_stream.get(), kSize); + + // Fill with a known pattern. 
+ for (size_t i = 0; i < kSize; ++i) { + in[i] = static_cast(i) - 128.0f; + } + + IntCodec::Enc(df, in.get(), kSize, int_span, 0); + + const size_t num = 64; + const size_t offset = 392704; + const T sentinel = hwy::ConvertScalarTo(-999.0f); + std::fill(actual_dec.get(), actual_dec.get() + num, sentinel); + + IntCodec::DecompressAndZeroPad(d, MakeConst(int_span), offset, + actual_dec.get(), num); + + // Check that all sentinels were overwritten. + for (size_t i = 0; i < num; ++i) { + EXPECT_NE(hwy::ConvertScalarTo(actual_dec[i]), + hwy::ConvertScalarTo(sentinel)) + << " at index " << i << " for num=" << num << " offset=" << offset; + } + } +}; + +void TestAllSpecificDequantize() { + hn::ForGEVectors<128, TestSpecificDequantize>()(BF16()); + hn::ForGEVectors<128, TestSpecificDequantize>()(float()); +} + +// NOLINTNEXTLINE(google-readability-namespace-comments) +} // namespace HWY_NAMESPACE +} // namespace gcpp +HWY_AFTER_NAMESPACE(); + +#if HWY_ONCE +namespace gcpp { +HWY_BEFORE_TEST(IntTest); +HWY_EXPORT_AND_TEST_P(IntTest, TestOffsetF32); +HWY_EXPORT_AND_TEST_P(IntTest, TestOffsetBF16); +HWY_EXPORT_AND_TEST_P(IntTest, TestQuantizeF32); +HWY_EXPORT_AND_TEST_P(IntTest, TestQuantizeBF16); +HWY_EXPORT_AND_TEST_P(IntTest, TestDec2BF16); +HWY_EXPORT_AND_TEST_P(IntTest, TestDec2F32); +HWY_EXPORT_AND_TEST_P(IntTest, TestMultiGroupF32); +HWY_EXPORT_AND_TEST_P(IntTest, TestMultiGroupBF16); +HWY_EXPORT_AND_TEST_P(IntTest, TestUnalignedOffsetBF16); +HWY_EXPORT_AND_TEST_P(IntTest, TestUnalignedOffsetF32); +HWY_EXPORT_AND_TEST_P(IntTest, TestAllDequantizeAndZeroPad); +HWY_EXPORT_AND_TEST_P(IntTest, TestAllSmallDequantize); +HWY_EXPORT_AND_TEST_P(IntTest, TestAllSpecificDequantize); +HWY_AFTER_TEST(); +} // namespace gcpp +#endif // HWY_ONCE diff --git a/compression/python/compression_clif_aux.cc b/compression/python/compression_clif_aux.cc index 5e729cc..5f227ac 100644 --- a/compression/python/compression_clif_aux.cc +++ b/compression/python/compression_clif_aux.cc @@ -113,6 +113,9 @@ class SbsWriterImpl : public ISbsWriter { case Type::kF32: InsertT(name, weights, tensor_info); break; + case Type::kI8: + InsertT(name, weights, tensor_info); + break; default: HWY_ABORT("Unsupported destination (compressed) type %s", TypeName(type)); diff --git a/compression/python/compression_test.py b/compression/python/compression_test.py index 957f0ec..16e6bf9 100644 --- a/compression/python/compression_test.py +++ b/compression/python/compression_test.py @@ -90,6 +90,13 @@ class CompressionTest(absltest.TestCase): info_256, ) + writer.insert( + "tensor_i8", + np.array([0.000375] * 128 + [0.00006] * 128, dtype=np.float32), + configs.Type.kI8, + info_256, + ) + config = configs.ModelConfig( configs.Model.GEMMA2_2B, configs.Type.kSFP, @@ -140,6 +147,11 @@ class CompressionTest(absltest.TestCase): self.assertEqual(mat.type, configs.Type.kF32) self.assertAlmostEqual(mat.scale, 1.0) + mat = reader.find_mat("tensor_i8") + self.assertEqual(mat.cols, 256) + self.assertEqual(mat.rows, 1) + self.assertEqual(mat.type, configs.Type.kI8) + self.assertAlmostEqual(mat.scale, 1.0) if __name__ == "__main__": absltest.main() diff --git a/compression/types.h b/compression/types.h index c3be52a..8f11591 100644 --- a/compression/types.h +++ b/compression/types.h @@ -89,6 +89,26 @@ struct SfpStream { }; #pragma pack(pop) +#pragma pack(push, 1) +struct I8Stream { + static constexpr size_t kGroupSize = 128; + using ScaleT = hwy::bfloat16_t; + + // Returns number of I8Stream to allocate for the stream, which matches its + // size in bytes. 
+ // TODO: should support other types beyond hwy::float32_t for scale and + // zero-point. + static constexpr size_t PackedEnd(size_t capacity) { + const size_t num_groups = hwy::DivCeil(capacity, kGroupSize); + return (sizeof(ScaleT) * num_groups) + // scale + (sizeof(ScaleT) * num_groups) + // zero-point + capacity; // 1 value per byte + } + + int8_t i; +}; +#pragma pack(pop) + // Non-uniform quantization: a compressed representation of f32 inputs that // supports seeking at a granularity of 1 (for `DecompressAndZeroPad`) or // two vectors (for `Decompress2`), and decoding to bf16/f32. @@ -187,18 +207,23 @@ constexpr bool IsNuqStream() { return hwy::IsSame, NuqStream>(); } +template +constexpr bool IsI8Stream() { + return hwy::IsSame, I8Stream>(); +} + template constexpr bool SupportsPointerArithmetic() { - return !IsNuqStream(); + return !IsNuqStream() && !IsI8Stream(); } // Tensor types for loading weights. Not all of these are supported weight // types, some are only used for `Activations`. -enum class Type { kUnknown, kF32, kBF16, kSFP, kNUQ, kF64, kU32, kU64 }; +enum class Type { kUnknown, kF32, kBF16, kSFP, kNUQ, kF64, kU32, kU64, kI8 }; // These are used in `ModelConfig.Specifier`, hence the strings will not // change, though new ones may be added. -static constexpr const char* kTypeStrings[] = {"unknown", "f32", "bf16", "sfp", - "nuq", "f64", "u32", "u64"}; +static constexpr const char* kTypeStrings[] = { + "unknown", "f32", "bf16", "sfp", "nuq", "f64", "u32", "u64", "i8"}; static constexpr size_t kNumTypes = sizeof(kTypeStrings) / sizeof(kTypeStrings[0]); static constexpr size_t kTypeBits[] = { @@ -210,6 +235,7 @@ static constexpr size_t kTypeBits[] = { 8 * sizeof(double), 8 * sizeof(uint32_t), 8 * sizeof(uint64_t), + 8 * sizeof(I8Stream), }; static inline bool EnumValid(Type type) { @@ -234,6 +260,8 @@ Type TypeEnum() { return Type::kU32; } else if constexpr (hwy::IsSame()) { return Type::kU64; + } else if constexpr (hwy::IsSame()) { + return Type::kI8; } else { HWY_DASSERT(false); return Type::kUnknown; @@ -254,7 +282,9 @@ const char* TypeName() { template constexpr bool IsCompressed() { - return hwy::IsSameEither, SfpStream, NuqStream>(); + return hwy::IsSame, SfpStream>() || + hwy::IsSame, NuqStream>() || + hwy::IsSame, I8Stream>(); } // Returns the number of `MatT` elements required to store `capacity` values, @@ -265,6 +295,8 @@ template constexpr size_t CompressedArrayElements(size_t capacity) { if constexpr (hwy::IsSame, NuqStream>()) { return NuqStream::PackedEnd(capacity); + } else if constexpr (hwy::IsSame, I8Stream>()) { + return I8Stream::PackedEnd(capacity); } else { return capacity; } diff --git a/gemma/attention.cc b/gemma/attention.cc index bf39702..1269e53 100644 --- a/gemma/attention.cc +++ b/gemma/attention.cc @@ -143,8 +143,8 @@ void SingleDotSoftmaxWeightedSum( // Apply rope and scaling to Q. if (layer.query_norm_scale.HasPtr()) { CallUpcasted(&layer.query_norm_scale, [&](const auto* weights_t) { - RMSNormInplace(weights_t->PackedScale1(), q, layer.layer_config.qkv_dim, - p, worker); + RMSNormInplace(weights_t->PackedScale1(), /*w_ofs=*/0, q, + layer.layer_config.qkv_dim, p, worker); }); } @@ -315,8 +315,8 @@ static HWY_INLINE void ComputeQKV(size_t num_tokens, const size_t layer_idx, // Apply further processing to K. 
if (layer.key_norm_scale.HasPtr()) { CallUpcasted(&layer.key_norm_scale, [&](const auto* weights_t) { - RMSNormInplace(weights_t->PackedScale1(), kv_f32, qkv_dim, - env.ctx.profiler, worker); + RMSNormInplace(weights_t->PackedScale1(), /*w_ofs=*/0, kv_f32, + qkv_dim, env.ctx.profiler, worker); }); } diff --git a/gemma/flash_attention.cc b/gemma/flash_attention.cc index df6efd1..c6a2fba 100644 --- a/gemma/flash_attention.cc +++ b/gemma/flash_attention.cc @@ -114,8 +114,8 @@ void RMSNormAndPositionalEncoding(const size_t num_tokens, const QBatch& qbatch, // Apply rope and scaling to Q. if (layer.query_norm_scale.HasPtr()) { CallUpcasted(&layer.query_norm_scale, [&](const auto* weights_t) { - RMSNormInplace(weights_t->PackedScale1(), q_row, - layer.layer_config.qkv_dim, ctx.profiler, worker); + RMSNormInplace(weights_t->PackedScale1(), /*w_ofs=*/0, q_row, + layer.layer_config.qkv_dim, ctx.profiler, worker); }); } PositionalEncodingQK(q_row, layer_idx, layer, activations, ctx.profiler, diff --git a/gemma/gemma-inl.h b/gemma/gemma-inl.h index ecfbe47..0034f3f 100644 --- a/gemma/gemma-inl.h +++ b/gemma/gemma-inl.h @@ -59,7 +59,7 @@ void Activation(ActivationType activation, T1* HWY_RESTRICT c1, return; }; // Has multiplier, Gelu(c1) * c2. - Decompress1AndCompressInplace(DF(), c1, count, c2, + Decompress1AndCompressInplace(DF(), c1, count, c2, /*p1_ofs=*/0, [](DF df, VF v1, VF v2) HWY_ATTR -> VF { return hn::Mul(v2, Gelu(df, v1)); }); @@ -101,8 +101,9 @@ static inline void Activation(ActivationType activation, const RowPtrsBF C1, for (size_t ir = 0; ir < range_r.Num(); ++ir) { Decompress1AndCompressInplace( DF(), C1.Row(range_r.begin() + ir) + range_c.begin(), cols, C2.Row(ir), - [](DF df, VF v1, VF v2) - HWY_ATTR -> VF { return hn::Mul(v2, Gelu(df, v1)); }); + /*p1_ofs*/ 0, [](DF df, VF v1, VF v2) HWY_ATTR -> VF { + return hn::Mul(v2, Gelu(df, v1)); + }); } } diff --git a/gemma/model_store.cc b/gemma/model_store.cc index a20caf2..2f3e1ec 100644 --- a/gemma/model_store.cc +++ b/gemma/model_store.cc @@ -112,6 +112,8 @@ class TypePrefix { return Type::kSFP; case '2': return Type::kNUQ; + case 'I': + return Type::kI8; default: // The other types were not written to pre-2025 files, hence no need to // encode and check for them here. diff --git a/gemma/tensor_info.h b/gemma/tensor_info.h index d2b25d9..6becb29 100644 --- a/gemma/tensor_info.h +++ b/gemma/tensor_info.h @@ -46,7 +46,7 @@ struct TensorInfo { // The highest permissible compression for this tensor. The default is // kNUQ, which provides maximum compression. Other values such as kBF16 // or kF32 can be used to limit the compression to a specific type. - Type min_size = Type::kNUQ; + Type min_size = Type::kI8; // Whether to apply scaled softplus to the data. bool scaled_softplus = false; // Whether the columns or the rows take any extra dimensions. diff --git a/gemma/vit.cc b/gemma/vit.cc index d21be16..abe0a37 100644 --- a/gemma/vit.cc +++ b/gemma/vit.cc @@ -332,8 +332,8 @@ void PrefillVit(const ModelConfig& model_config, const WeightsPtrs& weights, // Apply soft embedding norm before input projection. 
CallUpcasted(&weights.mm_embed_norm, [&](const auto* weights_t) { - RMSNormInplace(weights_t->PackedScale1(), activations.x.Row(0), - vit_model_dim, env.ctx.profiler, + RMSNormInplace(weights_t->PackedScale1(), /*w_ofs=*/0, + activations.x.Row(0), vit_model_dim, env.ctx.profiler, hwy::Profiler::GlobalIdx()); }); } diff --git a/gemma/weights.cc b/gemma/weights.cc index cd8875b..d871c6f 100644 --- a/gemma/weights.cc +++ b/gemma/weights.cc @@ -147,15 +147,222 @@ void LayerWeightsPtrs::SplitAttW1() { qkv_einsum_w.SetPtr(nullptr, qkv_einsum_w.Cols()); } +static void HWY_MAYBE_UNUSED InitAttWeightsI8( + const LayerConfig& layer_config, MatPtrT& attn_vec_einsum_w, + MatPtrT& att_weights, std::vector& mat_owners, + const Allocator& allocator) { + if (!attn_vec_einsum_w.HasPtr()) return; + HWY_ASSERT(attn_vec_einsum_w.GetType() == Type::kI8); + + att_weights.SetType(Type::kI8); + + { + static std::mutex m; + std::lock_guard lock(m); + mat_owners.emplace_back(); + mat_owners.back().AllocateFor(att_weights, allocator, MatPadding::kPacked); + } + + const size_t model_dim = layer_config.model_dim; + const size_t heads = layer_config.heads; + const size_t qkv_dim = layer_config.qkv_dim; + + // Reshape [kHeads, kModelDim, kQKVDim] to [kModelDim, kHeads * kQKVDim]. + hwy::AlignedFreeUniquePtr attn_vec_einsum_w_tmp = + hwy::AllocateAligned(model_dim * heads * qkv_dim); + hwy::AlignedFreeUniquePtr att_weights_tmp = + hwy::AllocateAligned(model_dim * heads * qkv_dim); + + const hwy::HWY_NAMESPACE::ScalableTag df; + HWY_NAMESPACE::DecompressAndZeroPad(df, attn_vec_einsum_w.Span(), 0, + attn_vec_einsum_w_tmp.get(), + model_dim * heads * qkv_dim); + + for (size_t m = 0; m < model_dim; ++m) { + float* HWY_RESTRICT out_row = att_weights_tmp.get() + m * heads * qkv_dim; + for (size_t h = 0; h < heads; ++h) { + hwy::CopyBytes( + attn_vec_einsum_w_tmp.get() + h * model_dim * qkv_dim + m * qkv_dim, + out_row + h * qkv_dim, qkv_dim * sizeof(float)); + } + } + + CompressWorkingSet work; + hwy::ThreadPool pool(0); + HWY_NAMESPACE::Compress(att_weights_tmp.get(), model_dim * heads * qkv_dim, + work, att_weights.Span(), + /*packed_ofs=*/0, pool); + + att_weights.SetScale(attn_vec_einsum_w.Scale()); +} + +static void HWY_MAYBE_UNUSED SplitW1I8(const LayerConfig& layer_config, + MatPtrT& gating_einsum_w, + MatPtrT& gating_einsum_w1, + MatPtrT& gating_einsum_w2, + std::vector& mat_owners, + const Allocator& allocator) { + // Files have both or neither of w1 and w2. + HWY_ASSERT(gating_einsum_w1.HasPtr() == gating_einsum_w2.HasPtr()); + // w is mutually exclusive with w1 and w2 in the file. + HWY_ASSERT(gating_einsum_w.HasPtr() ^ gating_einsum_w1.HasPtr()); + // Done if we already read split tensors. + if (gating_einsum_w1.HasPtr() && !gating_einsum_w.HasPtr()) return; + // Nothing to do if w is not present. 
+ if (!gating_einsum_w.HasPtr()) return; + + HWY_ASSERT(gating_einsum_w.GetType() == Type::kI8); + + const size_t ff_hidden_dim = layer_config.ff_hidden_dim; + const size_t model_dim = gating_einsum_w.Cols(); + HWY_ASSERT(gating_einsum_w.Rows() == 2 * ff_hidden_dim); + HWY_ASSERT(gating_einsum_w1.Rows() == ff_hidden_dim); + HWY_ASSERT(gating_einsum_w2.Rows() == ff_hidden_dim); + HWY_ASSERT(gating_einsum_w1.Cols() == model_dim); + HWY_ASSERT(gating_einsum_w2.Cols() == model_dim); + + gating_einsum_w1.SetType(Type::kI8); + gating_einsum_w2.SetType(Type::kI8); + + { + static std::mutex m; + std::lock_guard lock(m); + mat_owners.emplace_back(); + mat_owners.back().AllocateFor(gating_einsum_w1, allocator, + MatPadding::kPacked); + mat_owners.emplace_back(); + mat_owners.back().AllocateFor(gating_einsum_w2, allocator, + MatPadding::kPacked); + } + + const size_t total_size = gating_einsum_w.Rows() * gating_einsum_w.Cols(); + hwy::AlignedFreeUniquePtr w_tmp = + hwy::AllocateAligned(total_size); + + const hwy::HWY_NAMESPACE::ScalableTag df; + HWY_NAMESPACE::DecompressAndZeroPad(df, gating_einsum_w.Span(), 0, + w_tmp.get(), total_size); + + const size_t split_size = ff_hidden_dim * model_dim; + float* w1_tmp = w_tmp.get(); + float* w2_tmp = w_tmp.get() + split_size; + + CompressWorkingSet work; + hwy::ThreadPool pool(0); + HWY_NAMESPACE::Compress(w1_tmp, split_size, work, gating_einsum_w1.Span(), 0, + pool); + HWY_NAMESPACE::Compress(w2_tmp, split_size, work, gating_einsum_w2.Span(), 0, + pool); + + gating_einsum_w1.SetScale(1.0f); + gating_einsum_w2.SetScale(1.0f); + + gating_einsum_w.SetPtr(nullptr, gating_einsum_w.Cols()); +} + +static void HWY_MAYBE_UNUSED SplitAttW1I8(const LayerConfig& layer_config, + MatPtrT& qkv_einsum_w, + MatPtrT& qkv_einsum_w1, + MatPtrT& qkv_einsum_w2, + std::vector& mat_owners, + const Allocator& allocator) { + // w is mutually exclusive with w1 in the file. + HWY_ASSERT(qkv_einsum_w.HasPtr() ^ qkv_einsum_w1.HasPtr()); + // Done if we already read split tensors. + if (qkv_einsum_w1.HasPtr() && !qkv_einsum_w.HasPtr()) return; + // Nothing to do if w is not present. 
+ if (!qkv_einsum_w.HasPtr()) return; + + HWY_ASSERT(qkv_einsum_w.GetType() == Type::kI8); + + const size_t model_dim = qkv_einsum_w.Cols(); + const size_t w1_rows = layer_config.heads * layer_config.qkv_dim; + const size_t w2_rows = layer_config.kv_heads * 2 * layer_config.qkv_dim; + HWY_ASSERT(qkv_einsum_w.Rows() == w1_rows + w2_rows); + HWY_ASSERT(qkv_einsum_w1.Rows() == w1_rows); + HWY_ASSERT(qkv_einsum_w2.Rows() == w2_rows); + HWY_ASSERT(qkv_einsum_w1.Cols() == model_dim); + HWY_ASSERT(qkv_einsum_w2.Cols() == model_dim); + + qkv_einsum_w1.SetType(Type::kI8); + qkv_einsum_w2.SetType(Type::kI8); + + { + static std::mutex m; + std::lock_guard lock(m); + mat_owners.emplace_back(); + mat_owners.back().AllocateFor(qkv_einsum_w1, allocator, + MatPadding::kPacked); + mat_owners.emplace_back(); + mat_owners.back().AllocateFor(qkv_einsum_w2, allocator, + MatPadding::kPacked); + } + + const size_t total_size = qkv_einsum_w.Rows() * qkv_einsum_w.Cols(); + hwy::AlignedFreeUniquePtr w_tmp = + hwy::AllocateAligned(total_size); + + const hwy::HWY_NAMESPACE::ScalableTag df; + HWY_NAMESPACE::DecompressAndZeroPad(df, qkv_einsum_w.Span(), 0, w_tmp.get(), + total_size); + + const size_t w1_size = w1_rows * model_dim; + const size_t w2_size = w2_rows * model_dim; + float* w1_tmp = w_tmp.get(); + float* w2_tmp = w_tmp.get() + w1_size; + + CompressWorkingSet work; + hwy::ThreadPool pool(0); + HWY_NAMESPACE::Compress(w1_tmp, w1_size, work, qkv_einsum_w1.Span(), 0, pool); + HWY_NAMESPACE::Compress(w2_tmp, w2_size, work, qkv_einsum_w2.Span(), 0, pool); + + qkv_einsum_w1.SetScale(1.0f); + qkv_einsum_w2.SetScale(1.0f); + + qkv_einsum_w.SetPtr(nullptr, qkv_einsum_w.Cols()); +} + // Must be called after reading weights via `ForEachTensor`. // TODO: exporters should bake this into the weights already. // WARNING: called from multiple threads; `mat_owners` requires a lock. 
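The Fixup below routes kI8 tensors through the three helpers above. The row splits in SplitW1I8 and SplitAttW1I8 are contiguous (w1 rows first, then w2 rows), so only InitAttWeightsI8 needs an index remap; here is a toy sketch of its [heads, model_dim, qkv_dim] to [model_dim, heads * qkv_dim] copy, with made-up sizes:

#include <cstddef>
#include <cstdio>

int main() {
  const size_t heads = 2, model_dim = 3, qkv_dim = 4;  // toy sizes
  for (size_t m = 0; m < model_dim; ++m) {
    for (size_t h = 0; h < heads; ++h) {
      // Source row (h, m) of [heads][model_dim][qkv_dim] lands in column block
      // h of destination row m in [model_dim][heads * qkv_dim].
      const size_t src = h * model_dim * qkv_dim + m * qkv_dim;
      const size_t dst = m * heads * qkv_dim + h * qkv_dim;
      printf("copy %zu floats: src offset %2zu -> dst offset %2zu\n", qkv_dim,
             src, dst);
    }
  }
  return 0;
}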
void LayerWeightsPtrs::Fixup(std::vector& mat_owners, const Allocator& allocator) { - // TODO(janwas): handle NUQ - InitAttWeights(mat_owners, allocator); - SplitW1(); - SplitAttW1(); + if (attn_vec_einsum_w.GetType() == Type::kI8) { + MatPtrT attn_vec_einsum_w_i8(attn_vec_einsum_w); + MatPtrT att_weights_i8(att_weights); + InitAttWeightsI8(layer_config, attn_vec_einsum_w_i8, att_weights_i8, + mat_owners, allocator); + attn_vec_einsum_w = attn_vec_einsum_w_i8; + att_weights = att_weights_i8; + } else { + InitAttWeights(mat_owners, allocator); + } + + if (gating_einsum_w.GetType() == Type::kI8) { + MatPtrT gating_einsum_w_i8(gating_einsum_w); + MatPtrT gating_einsum_w1_i8(gating_einsum_w1); + MatPtrT gating_einsum_w2_i8(gating_einsum_w2); + SplitW1I8(layer_config, gating_einsum_w_i8, gating_einsum_w1_i8, + gating_einsum_w2_i8, mat_owners, allocator); + gating_einsum_w = gating_einsum_w_i8; + gating_einsum_w1 = gating_einsum_w1_i8; + gating_einsum_w2 = gating_einsum_w2_i8; + } else { + SplitW1(); + } + + if (qkv_einsum_w.GetType() == Type::kI8) { + MatPtrT qkv_einsum_w_i8(qkv_einsum_w); + MatPtrT qkv_einsum_w1_i8(qkv_einsum_w1); + MatPtrT qkv_einsum_w2_i8(qkv_einsum_w2); + SplitAttW1I8(layer_config, qkv_einsum_w_i8, qkv_einsum_w1_i8, + qkv_einsum_w2_i8, mat_owners, allocator); + qkv_einsum_w = qkv_einsum_w_i8; + qkv_einsum_w1 = qkv_einsum_w1_i8; + qkv_einsum_w2 = qkv_einsum_w2_i8; + } else { + SplitAttW1(); + } } static void HWY_MAYBE_UNUSED InitAttWeightsNUQ( @@ -427,8 +634,6 @@ static void ReadAllToBF16(const std::vector& tensors, static std::vector MakeBatches( const std::vector& tensors, const uint64_t file_bytes) { PROFILER_ZONE("Startup.Weights.MakeBatches"); - // Batches must be contiguous but blobs are padded, hence at least one - // batch per tensor, and more when tensor rows exceed the batch size. std::vector batches; batches.reserve(tensors.size()); @@ -439,17 +644,28 @@ static std::vector MakeBatches( HWY_ASSERT(range.End() <= file_bytes); batches.emplace_back(offset, range.key_idx); - const size_t file_bytes_per_row = mat.Cols() * mat.ElementBytes(); - const size_t mem_stride_bytes = mat.Stride() * mat.ElementBytes(); - uint8_t* row_bytes = mat.RowBytes(0); - for (size_t r = 0; r < mat.Rows(); ++r) { - if (!batches.back().Add(row_bytes, file_bytes_per_row)) { // Full batch. - batches.emplace_back(offset, range.key_idx); - // Adding to an empty batch is always successful. - HWY_ASSERT(batches.back().Add(row_bytes, file_bytes_per_row)); + if (mat.IsPacked()) { + HWY_ASSERT(range.bytes == mat.PackedBytes()); + if (!batches.back().Add(mat.Packed(), range.bytes)) { + // This should not happen if tensors are < 2GB. + // If it does, we need to chunk. For now, let's assume it doesn't. + HWY_ABORT("Packed tensor too large for a single IO batch."); + } + offset += range.bytes; + } else { + const size_t file_bytes_per_row = mat.Cols() * mat.ElementBytes(); + const size_t mem_stride_bytes = mat.Stride() * mat.ElementBytes(); + uint8_t* row_bytes = mat.RowBytes(0); + for (size_t r = 0; r < mat.Rows(); ++r) { + if (!batches.back().Add(row_bytes, + file_bytes_per_row)) { // Full batch. + batches.emplace_back(offset, range.key_idx); + // Adding to an empty batch is always successful. 
diff --git a/ops/matmul_static.h b/ops/matmul_static.h
index 6b93d92..d2ab677 100644
--- a/ops/matmul_static.h
+++ b/ops/matmul_static.h
@@ -50,6 +50,7 @@
   GEMMA_MATMUL_FOR_B(float)                                     \
   GEMMA_MATMUL_FOR_B(NuqStream)                                 \
   GEMMA_MATMUL_FOR_B(SfpStream)                                 \
+  GEMMA_MATMUL_FOR_B(I8Stream)                                  \
   /* NOLINTNEXTLINE(google-readability-namespace-comments) */   \
   }  // namespace NAMESPACE
diff --git a/ops/matmul_static_i8.cc b/ops/matmul_static_i8.cc
new file mode 100644
index 0000000..b21bc27
--- /dev/null
+++ b/ops/matmul_static_i8.cc
@@ -0,0 +1,29 @@
+// Copyright 2025 Google LLC
+// SPDX-License-Identifier: Apache-2.0
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//      https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "compression/types.h"  // GEMMA_DISABLED_TARGETS
+#ifndef HWY_DISABLED_TARGETS
+#define HWY_DISABLED_TARGETS GEMMA_DISABLED_TARGETS
+#endif  // HWY_DISABLED_TARGETS
+
+// Compiles this file for multiple architectures via "foreach_target.h", to
+// which we pass the filename via macro 'argument'.
+// clang-format off
+#undef HWY_TARGET_INCLUDE
+#define HWY_TARGET_INCLUDE "ops/matmul_static_i8.cc"  // NOLINT
+// clang-format on
+#include "hwy/foreach_target.h"  // IWYU pragma: keep
+#define GEMMA_MATMUL_TB I8Stream
+#include "ops/matmul_static-inl.h"
\ No newline at end of file
diff --git a/ops/ops-inl.h b/ops/ops-inl.h
index c966a68..4ff2c7d 100644
--- a/ops/ops-inl.h
+++ b/ops/ops-inl.h
@@ -220,6 +220,7 @@ float RMSNormMul(const VT* HWY_RESTRICT x, const size_t size, hwy::Profiler& p,
 template <typename XT, typename WT, typename OT>
 HWY_NOINLINE HWY_MAYBE_UNUSED void RMSNorm(const XT* HWY_RESTRICT x,
                                            const WT* HWY_RESTRICT weight,
+                                           const size_t w_ofs,
                                            OT* HWY_RESTRICT out,
                                            const size_t size, hwy::Profiler& p,
                                            const size_t worker) {
@@ -232,7 +233,7 @@ HWY_NOINLINE HWY_MAYBE_UNUSED void RMSNorm(const XT* HWY_RESTRICT x,
   const VF mul = hn::Set(DF(), detail::RMSNormMul(x, size, p, worker));
   const VF* HWY_RESTRICT pmul = &mul;
 
-  Decompress2AndCompressTo(DF(), out, size, x, weight,
+  Decompress2AndCompressTo(DF(), out, size, x, weight, w_ofs,
                            [pmul](DF /*df*/, VF vx, VF vw) HWY_ATTR -> VF {
                              const VF m = hn::Mul(*pmul, vx);
                              // (1+weight) * m = m + weight*m = one FMA.
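
Note (added for review, not part of this patch): a scalar model of what the new
w_ofs argument means; the SIMD code above reads weight elements
[w_ofs, w_ofs + size) from the (possibly compressed) weight stream.
ScalarRMSNormAtOffset is an invented name and the epsilon below is
illustrative, not taken from RMSNormMul.

#include <cmath>
#include <cstddef>

inline void ScalarRMSNormAtOffset(const float* x, const float* weight,
                                  size_t w_ofs, float* out, size_t size) {
  double sum_sq = 0.0;
  for (size_t i = 0; i < size; ++i) sum_sq += x[i] * x[i];
  const float mul =
      1.0f / std::sqrt(static_cast<float>(sum_sq / size) + 1e-6f);
  for (size_t i = 0; i < size; ++i) {
    // Matches the "(1+weight) * m" comment above, with the weight read at an
    // element offset into its stream.
    out[i] = (1.0f + weight[w_ofs + i]) * (mul * x[i]);
  }
}
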
@@ -242,13 +243,10 @@
 
 // Same as RMSNorm, but its HWY_RESTRICT forbids passing the same pointer.
 template <typename WT, typename XT>
-HWY_NOINLINE HWY_MAYBE_UNUSED void RMSNormInplace(const WT* HWY_RESTRICT weight,
-                                                  XT* HWY_RESTRICT inout,
-                                                  const size_t size,
-                                                  hwy::Profiler& p,
-                                                  const size_t worker) {
+HWY_NOINLINE HWY_MAYBE_UNUSED void RMSNormInplace(
+    const WT* HWY_RESTRICT weight, const size_t w_ofs, XT* HWY_RESTRICT inout,
+    const size_t size, hwy::Profiler& p, const size_t worker) {
   PROFILER_ZONE3(p, worker, GetProfilerZone(Zones::kOpsRmsNormInplace));
-  namespace hn = hwy::HWY_NAMESPACE;
   using DF = hn::ScalableTag<float>;
   using VF = hn::Vec<DF>;
 
@@ -256,7 +254,7 @@ HWY_NOINLINE HWY_MAYBE_UNUSED void RMSNormInplace(const WT* HWY_RESTRICT weight,
   const VF mul = hn::Set(DF(), detail::RMSNormMul(inout, size, p, worker));
   const VF* HWY_RESTRICT pmul = &mul;
 
-  Decompress1AndCompressInplace(DF(), inout, size, weight,
+  Decompress1AndCompressInplace(DF(), inout, size, weight, w_ofs,
                                 [pmul](DF /*df*/, VF vx, VF vw) HWY_ATTR -> VF {
                                   const VF m = hn::Mul(*pmul, vx);
                                   // (1+weight) * m = m + weight*m = one FMA.
@@ -489,7 +487,7 @@ static HWY_NOINLINE HWY_MAYBE_UNUSED void AddFrom(const XT* HWY_RESTRICT x,
   namespace hn = hwy::HWY_NAMESPACE;
   using DF = hn::ScalableTag<float>;
   using VF = hn::Vec<DF>;
-  Decompress1AndCompressInplace(DF(), out, size, x,
+  Decompress1AndCompressInplace(DF(), out, size, x, /*p1_ofs=*/0,
                                 [&](DF /*df*/, VF out, VF x) HWY_ATTR
                                 -> VF { return hn::Add(x, out); });
 }
@@ -507,8 +505,8 @@ void RMSNormBatched(const MatPtrT<float>& activations, const MatPtr& weights,
     ParallelFor(ParallelismStrategy::kFlat, activations.Rows(), ctx,
                 cluster_idx, [&](uint64_t token_idx, size_t worker) {
                   RMSNorm(activations.Row(token_idx), weights_t->PackedScale1(),
-                          out.Row(token_idx), activations.Cols(), ctx.profiler,
-                          worker);
+                          /*w_ofs=*/0, out.Row(token_idx), activations.Cols(),
+                          ctx.profiler, worker);
                 });
   });
 }
@@ -522,7 +520,7 @@ void RMSNormInplaceBatched(const MatPtr& weights, MatPtrT<float>& inout,
   CallUpcasted(&weights, [&](const auto* weights_t) {
     ParallelFor(ParallelismStrategy::kFlat, inout.Rows(), ctx, cluster_idx,
                 [&](uint64_t token_idx, size_t worker) {
-                  RMSNormInplace(weights_t->PackedScale1(),
+                  RMSNormInplace(weights_t->PackedScale1(), /*w_ofs=*/0,
                                  inout.Row(token_idx), inout.Cols(),
                                  ctx.profiler, worker);
                 });
@@ -604,7 +602,7 @@ HWY_NOINLINE HWY_MAYBE_UNUSED void MulByConstAndAdd(const float c,
   const VF vc = hn::Set(DF(), c);
   const VF* HWY_RESTRICT pc = &vc;
 
-  Decompress1AndCompressInplace(DF(), out, size, x,
+  Decompress1AndCompressInplace(DF(), out, size, x, /*p1_ofs=*/0,
                                 [&](DF /*df*/, VF out, VF x) HWY_ATTR -> VF {
                                   return hn::MulAdd(x, *pc, out);
                                 });
diff --git a/ops/ops_test.cc b/ops/ops_test.cc
index dd8e4e8..d46bb5c 100644
--- a/ops/ops_test.cc
+++ b/ops/ops_test.cc
@@ -558,7 +558,8 @@ struct TestRMSNorm {
     ScalarRMSNorm(vec, weight, expected, kSize);
 
     InitProfilerZones(hwy::Profiler::Get());
-    RMSNorm(vec, weight, actual, kSize, hwy::Profiler::Get(), /*worker=*/0);
+    RMSNorm(vec, weight, /*w_ofs=*/0, actual, kSize, hwy::Profiler::Get(),
+            /*worker=*/0);
 
     for (size_t i = 0; i < kSize; i++) {
       const float e = hwy::ConvertScalarTo<float>(expected[i]);
@@ -593,7 +594,7 @@ struct TestRMSNormInplace {
     ScalarRMSNorm(expected, weight, expected, kSize);
 
     InitProfilerZones(hwy::Profiler::Get());
-    RMSNormInplace(weight, actual, kSize, hwy::Profiler::Get(),
+    RMSNormInplace(weight, /*w_ofs=*/0, actual, kSize, hwy::Profiler::Get(),
                    /*worker=*/0);
 
     for (size_t i = 0; i < kSize; i++) {
diff --git a/python/configs.cc b/python/configs.cc
index 086c691..e544bb0 100644
--- a/python/configs.cc
+++ b/python/configs.cc
@@ -53,7 +53,11 @@ PYBIND11_MODULE(configs, py_module) {
       .value("kF32", Type::kF32)
       .value("kBF16", Type::kBF16)
       .value("kSFP", Type::kSFP)
-      .value("kNUQ", Type::kNUQ);
+      .value("kNUQ", Type::kNUQ)
+      .value("kF64", Type::kF64)
+      .value("kU32", Type::kU32)
+      .value("kU64", Type::kU64)
+      .value("kI8", Type::kI8);
 
   enum_<LayerAttentionType>(py_module, "LayerAttentionType")
       .value("kGemma", LayerAttentionType::kGemma)
diff --git a/util/basics.h b/util/basics.h
index 0211a0e..5a7f0d5 100644
--- a/util/basics.h
+++ b/util/basics.h
@@ -59,6 +59,25 @@ static inline void MaybeCheckInitialized(const void* ptr, size_t size) {
 #endif
 }
 
+static inline void MaybePrintInitialized(const void* ptr, size_t size) {
+#if HWY_IS_MSAN
+  __msan_print_shadow(ptr, size);
+#else
+  (void)ptr;
+  (void)size;
+#endif
+}
+
+static inline intptr_t MaybeTestInitialized(const void* ptr, size_t size) {
+#if HWY_IS_MSAN
+  return __msan_test_shadow(ptr, size);
+#else
+  (void)ptr;
+  (void)size;
+  return 0;
+#endif
+}
+
 // Shared between gemma.h and ops-inl.h.
 #pragma pack(push, 1)
 struct TokenAndProb {
diff --git a/util/mat.cc b/util/mat.cc
index f81767d..6d9c9bf 100644
--- a/util/mat.cc
+++ b/util/mat.cc
@@ -80,11 +80,13 @@ size_t Stride(MatPadding padding, size_t cols, size_t element_bytes,
 
 void MatOwner::AllocateFor(MatPtr& mat, const Allocator& allocator,
                            MatPadding padding) {
-  const bool is_nuq = mat.GetType() == Type::kNUQ;
-  if (is_nuq) padding = MatPadding::kPacked;
+  const bool is_compressed_and_packed =
+      mat.GetType() == Type::kNUQ || mat.GetType() == Type::kI8;
+  if (is_compressed_and_packed) padding = MatPadding::kPacked;
   const size_t stride =
       Stride(padding, mat.Cols(), mat.ElementBytes(), allocator.LineBytes());
-  const size_t num = is_nuq ? mat.PackedBytes() : mat.Rows() * stride;
+  const size_t num =
+      is_compressed_and_packed ? mat.PackedBytes() : mat.Rows() * stride;
   // `compress-inl` requires up to 2 BF16 vectors of padding. `MatPadding`
   // might not be enough, hence add extra. `MatT` is at least one byte, which
   // is half of BF16, hence adding `VectorBytes` *elements* is enough.
diff --git a/util/mat.h b/util/mat.h
index 6f9a243..59eceaa 100644
--- a/util/mat.h
+++ b/util/mat.h
@@ -240,6 +240,8 @@ class MatPtr : public IFields {
       // `CompressedArrayElements` is a wrapper function that has the same
       // effect, but that requires a template argument, not `type`.
       num_elements = NuqStream::PackedEnd(num_elements);
+    } else if (type == Type::kI8) {
+      num_elements = I8Stream::PackedEnd(num_elements);
     }
     return num_elements;
   }
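
Note (added for review, not part of this patch): a hedged usage sketch for the
MaybeTestInitialized / MaybePrintInitialized helpers added to util/basics.h
above. Under MSAN, __msan_test_shadow returns -1 when the whole range is
initialized and otherwise the offset of the first poisoned byte; the non-MSAN
stub returns 0, so the check is guarded. AssertFullyInitialized is a
hypothetical helper and assumes util/basics.h and hwy/base.h are included.

#include <cstddef>
#include <cstdint>

inline void AssertFullyInitialized(const void* ptr, size_t bytes) {
#if HWY_IS_MSAN
  const intptr_t first_poisoned = MaybeTestInitialized(ptr, bytes);
  if (first_poisoned != -1) {
    MaybePrintInitialized(ptr, bytes);  // Dump the shadow to aid debugging.
    HWY_ABORT("First uninitialized byte at offset %d",
              static_cast<int>(first_poisoned));
  }
#else
  (void)ptr;
  (void)bytes;
#endif
}
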
@@ -324,7 +326,8 @@ class MatPtrT : public MatPtr {
   }
 
   PackedSpan<const MatT> PaddedSpan() const {
-    return MakeConstSpan(HWY_RCAST_ALIGNED(MatT*, ptr_), Rows() * Stride());
+    const size_t num = IsPacked() ? num_elements_ : Rows() * Stride();
+    return MakeConstSpan(HWY_RCAST_ALIGNED(MatT*, ptr_), num);
   }
 
   // For `compress-inl.h` functions, which assume contiguous streams and thus
@@ -379,6 +382,9 @@ decltype(auto) CallUpcasted(const MatPtr* base, const Func& func,
   } else if (base->GetType() == Type::kSFP) {
     const MatPtrT<SfpStream> mat(*base);
     return func(&mat, std::forward<Args>(args)...);
+  } else if (base->GetType() == Type::kI8) {
+    const MatPtrT<I8Stream> mat(*base);
+    return func(&mat, std::forward<Args>(args)...);
   } else {
     HWY_ABORT("Unhandled type %s.", TypeName(base->GetType()));
   }
@@ -410,6 +416,10 @@ decltype(auto) CallUpcastedSame(const MatPtr* base1, const MatPtr* base2,
     const MatPtrT<SfpStream> mat1(*base1);
     const MatPtrT<SfpStream> mat2(*base2);
     return func(&mat1, &mat2, std::forward<Args>(args)...);
+  } else if (base1->GetType() == Type::kI8) {
+    const MatPtrT<I8Stream> mat1(*base1);
+    const MatPtrT<I8Stream> mat2(*base2);
+    return func(&mat1, &mat2, std::forward<Args>(args)...);
   } else {
     HWY_ABORT("Unhandled type %s.", TypeName(base1->GetType()));
   }
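
Note (added for review, not part of this patch): the kI8 branches above follow
the existing CallUpcasted pattern of dispatching on the runtime Type tag and
handing the functor a statically typed pointer. A self-contained toy version of
that pattern (TypeTag and DispatchOnTag are invented names; the real code
dispatches MatPtr to MatPtrT<T>):

#include <cstdio>

enum class TypeTag { kBF16, kSFP, kI8 };  // Illustrative subset of Type.

// Calls func with a value whose static type depends on the runtime tag.
template <typename Func>
void DispatchOnTag(TypeTag tag, const Func& func) {
  switch (tag) {
    case TypeTag::kBF16: func(float{1}); break;  // stands in for MatPtrT<BF16>*
    case TypeTag::kSFP:  func(short{1}); break;  // ... for MatPtrT<SfpStream>*
    case TypeTag::kI8:   func(char{1}); break;   // ... for MatPtrT<I8Stream>*
  }
}

int main() {
  DispatchOnTag(TypeTag::kI8, [](auto v) {
    std::printf("element bytes: %zu\n", sizeof(v));  // prints 1 for kI8
  });
  return 0;
}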