mirror of https://github.com/google/gemma.cpp.git

Add 8-bit integer quantization (I8Stream) to Gemma.cpp.

PiperOrigin-RevId: 819787856

parent ee18916abf
commit 503aaddd65
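For orientation, here is a minimal scalar sketch of the per-group affine int8 scheme that the new IntCodec (compression/int-inl.h, below) implements. It is illustrative only and not part of the commit: helper names are made up, and it omits the bf16 rounding of the stored scale/zero-point as well as the packed group layout.

// Scalar sketch (not part of the commit): q = round(scale*x + zeropoint),
// saturated to int8; x ~= inv_scale*q - zeropoint*inv_scale on decode.
#include <algorithm>
#include <cmath>
#include <cstddef>
#include <cstdint>

struct GroupParams {
  float inv_scale;  // stored as bf16 per group in I8Stream
  float zeropoint;  // stored as bf16 per group in I8Stream
};

GroupParams QuantizeGroupScalar(const float* x, size_t n, int8_t* q) {
  float mn = x[0], mx = x[0];
  for (size_t i = 1; i < n; ++i) {
    mn = std::min(mn, x[i]);
    mx = std::max(mx, x[i]);
  }
  const float range = (mx == mn) ? 1.0f : (mx - mn);
  const float scale = 255.0f / range;
  // Truncation toward zero mirrors the static_cast<int32_t> in QuantizeGroup.
  const float zeropoint = std::trunc(-scale * mn - 128.0f);
  for (size_t i = 0; i < n; ++i) {
    const float v = std::round(scale * x[i] + zeropoint);
    q[i] = static_cast<int8_t>(std::max(-128.0f, std::min(127.0f, v)));
  }
  return {1.0f / scale, zeropoint};
}

float DequantizeScalar(GroupParams p, int8_t q) {
  return p.inv_scale * static_cast<float>(q) - p.zeropoint * p.inv_scale;
}

Each group of I8Stream::kGroupSize = 128 values carries its own inv_scale and zeropoint, so decoding a group only needs its two bf16 header values plus the int8 bytes.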
@@ -349,6 +349,7 @@ cc_library(
         "ops/matmul_static_f32.cc",
         "ops/matmul_static_nuq.cc",
         "ops/matmul_static_sfp.cc",
+        "ops/matmul_static_i8.cc",
     ],
     hdrs = [
         "ops/matmul_static.h",
@@ -68,6 +68,7 @@ set(SOURCES
  compression/compress.h
  compression/nuq-inl.h
  compression/sfp-inl.h
+ compression/int-inl.h
  compression/types.h
  compression/test_util-inl.h
  evals/benchmark_helper.cc
@@ -109,6 +110,7 @@ set(SOURCES
  ops/matmul_static_f32.cc
  ops/matmul_static_nuq.cc
  ops/matmul_static_sfp.cc
+ ops/matmul_static_i8.cc
  ops/matmul-inl.h
  ops/matmul.cc
  ops/matmul.h
@@ -80,6 +80,37 @@ cc_library(
     ],
 )
 
+cc_library(
+    name = "int",
+    textual_hdrs = ["int-inl.h"],
+    deps = [
+        ":types",
+        "//:basics",
+        "@highway//:hwy",
+    ],
+)
+
+cc_test(
+    name = "int_test",
+    size = "small",
+    timeout = "long",
+    srcs = ["int_test.cc"],
+    features = ["fully_static_link"],
+    linkstatic = True,
+    local_defines = ["HWY_IS_TEST"],
+    # for test_suite.
+    tags = ["hwy_ops_test"],
+    deps = [
+        ":distortion",
+        ":int",
+        "@googletest//:gtest_main",  # buildcleaner: keep
+        "//:test_util",
+        "@highway//:hwy",
+        "@highway//:hwy_test_util",
+        "@highway//:nanobenchmark",
+    ],
+)
+
 cc_library(
     name = "test_util",
     textual_hdrs = [
@@ -144,6 +175,7 @@ cc_library(
     textual_hdrs = ["compress-inl.h"],
     deps = [
         ":distortion",
+        ":int",
         ":nuq",
         ":sfp",
         "//:basics",
@@ -182,6 +214,7 @@ cc_library(
     name = "analyze",
     textual_hdrs = ["analyze.h"],
     deps = [
+        ":int",
         ":nuq",
         ":sfp",
         ":types",
@@ -47,6 +47,7 @@
 
 #include "hwy/highway.h"
 // After highway.h
+#include "compression/int-inl.h"
 #include "compression/nuq-inl.h"
 #include "compression/sfp-inl.h"
 
@@ -416,6 +417,34 @@ struct CompressTraits<SfpStream> {
   }
 };
 
+// Integer quantization.
+template <>
+struct CompressTraits<I8Stream> {
+  using Packed = I8Stream;
+
+  template <class DF, HWY_IF_F32_D(DF)>
+  static HWY_INLINE void Compress(DF df, const float* HWY_RESTRICT raw,
+                                  size_t num, CompressPerThread& tls,
+                                  const PackedSpan<Packed>& packed,
+                                  const size_t packed_ofs) {
+    IntCodec::Enc(df, raw, num, packed, packed_ofs);
+  }
+
+  template <class D>  // Caller checks this is f32 or bf16
+  static HWY_INLINE void Load2(D d, const PackedSpan<const Packed>& packed,
+                               const size_t packed_ofs, hn::Vec<D>& raw0,
+                               hn::Vec<D>& raw1) {
+    IntCodec::Dec2(d, packed, packed_ofs, raw0, raw1);
+  }
+
+  template <class D, typename Raw>
+  static HWY_INLINE void DecompressAndZeroPad(
+      D d, const PackedSpan<const Packed>& packed, const size_t packed_ofs,
+      Raw* raw, const size_t num) {
+    IntCodec::DecompressAndZeroPad(d, packed, packed_ofs, raw, num);
+  }
+};
+
 // Nonuniform quantization, 4.5 bits per element, two separate streams.
 template <>
 struct CompressTraits<NuqStream> {
@@ -737,9 +766,10 @@ template <class DF, typename T, typename T1, class Func>
 HWY_INLINE void Decompress1AndCompressInplace(DF df, T* HWY_RESTRICT inout,
                                               size_t num,
                                               const T1* HWY_RESTRICT p1,
+                                              const size_t p1_ofs,
                                               Func&& func) {
   const auto packed_inout = MakeSpan(inout, num);
-  const auto packed1 = MakeSpan(p1, num);
+  const auto packed1 = MakeSpan(p1, p1_ofs + num);
 
   using VF = hn::Vec<decltype(df)>;
   HWY_LANES_CONSTEXPR const size_t NF = hn::Lanes(df);
@@ -749,7 +779,7 @@ HWY_INLINE void Decompress1AndCompressInplace(DF df, T* HWY_RESTRICT inout,
     VF v0, v1;
     Decompress2(df, packed_inout, i, v0, v1);
     VF v10, v11;
-    Decompress2(df, packed1, i, v10, v11);
+    Decompress2(df, packed1, p1_ofs + i, v10, v11);
     const VF out0 = func(df, v0, v10);
     const VF out1 = func(df, v1, v11);
     Compress2(df, out0, out1, packed_inout, i);
@@ -765,7 +795,7 @@ HWY_INLINE void Decompress1AndCompressInplace(DF df, T* HWY_RESTRICT inout,
     hn::Store(hn::Zero(df), df, buf_inout + NF);
     hn::Store(hn::Zero(df), df, buf1 + NF);
     DecompressAndZeroPad(df, packed_inout, i, buf_inout, remaining);
-    DecompressAndZeroPad(df, packed1, i, buf1, remaining);
+    DecompressAndZeroPad(df, packed1, p1_ofs + i, buf1, remaining);
     const VF v0 = hn::Load(df, buf_inout);
     const VF v1 = hn::Load(df, buf_inout + NF);
     const VF v10 = hn::Load(df, buf1);
@@ -827,10 +857,10 @@ template <class DF, typename T, typename T1, typename T2, class Func>
 HWY_INLINE void Decompress2AndCompressTo(DF df, T* HWY_RESTRICT out, size_t num,
                                          const T1* HWY_RESTRICT p1,
                                          const T2* HWY_RESTRICT p2,
-                                         Func&& func) {
+                                         const size_t p2_ofs, Func&& func) {
   const auto packed_out = MakeSpan(out, num);
   const auto packed1 = MakeSpan(p1, num);
-  const auto packed2 = MakeSpan(p2, num);
+  const auto packed2 = MakeSpan(p2, p2_ofs + num);
 
   using VF = hn::Vec<decltype(df)>;
   HWY_LANES_CONSTEXPR const size_t NF = hn::Lanes(df);
@@ -839,7 +869,7 @@ HWY_INLINE void Decompress2AndCompressTo(DF df, T* HWY_RESTRICT out, size_t num,
   for (; i <= num - 2 * NF; i += 2 * NF) {
     VF v10, v11, v20, v21;
     Decompress2(df, packed1, i, v10, v11);
-    Decompress2(df, packed2, i, v20, v21);
+    Decompress2(df, packed2, p2_ofs + i, v20, v21);
     const VF out0 = func(df, v10, v20);
     const VF out1 = func(df, v11, v21);
     Compress2(df, out0, out1, packed_out, i);
@@ -856,7 +886,7 @@ HWY_INLINE void Decompress2AndCompressTo(DF df, T* HWY_RESTRICT out, size_t num,
     hn::Store(hn::Zero(df), df, buf1 + NF);
     hn::Store(hn::Zero(df), df, buf2 + NF);
     DecompressAndZeroPad(df, packed1, i, buf1, remaining);
-    DecompressAndZeroPad(df, packed2, i, buf2, remaining);
+    DecompressAndZeroPad(df, packed2, p2_ofs + i, buf2, remaining);
     const VF v10 = hn::Load(df, buf1);
     const VF v11 = hn::Load(df, buf1 + NF);
     const VF v20 = hn::Load(df, buf2);
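The two helpers above now take an explicit element offset into the second packed stream (`p1_ofs` / `p2_ofs`). A hedged usage sketch, mirroring the call sites in the test hunk that follows but with a nonzero offset; the setup (df, inout, p1, ofs, num and the DF/VF aliases) is assumed:

// Sketch only: adds p1[ofs + i] into inout[i] for i in [0, num), reading the
// second stream at element offset ofs via the new p1_ofs parameter.
Decompress1AndCompressInplace(
    df, inout, num, p1, /*p1_ofs=*/ofs,
    [](DF, VF v, VF v1) HWY_ATTR -> VF { return hn::Add(v, v1); });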
@@ -243,7 +243,7 @@ class TestDecompressAndCompress {
 
     // Uses `out` so as not to overwrite `p`.
     Decompress1AndCompressInplace(
-        df, out.get(), num, p1.get(),
+        df, out.get(), num, p1.get(), /*p1_ofs=*/0,
         [](DF, VF v, VF v1) HWY_ATTR -> VF { return hn::Add(v, v1); });
     HWY_ASSERT_ARRAY_EQ(expected2.get(), out.get(), num);
 
@@ -251,9 +251,9 @@ class TestDecompressAndCompress {
         [](DF, VF v) HWY_ATTR -> VF { return v; });
     HWY_ASSERT_ARRAY_EQ(expected1.get(), out.get(), num);
 
-    Decompress2AndCompressTo(df, out.get(), num, p.get(), p1.get(),
-                             [](DF, VF v, VF v1)
-                                 HWY_ATTR -> VF { return hn::Add(v, v1); });
+    Decompress2AndCompressTo(
+        df, out.get(), num, p.get(), p1.get(), /*p2_ofs=*/0,
+        [](DF, VF v, VF v1) HWY_ATTR -> VF { return hn::Add(v, v1); });
     HWY_ASSERT_ARRAY_EQ(expected2.get(), out.get(), num);
 
     Decompress3AndCompressTo(
@@ -0,0 +1,474 @@
|
||||||
|
// Copyright 2023 Google LLC
|
||||||
|
// SPDX-License-Identifier: Apache-2.0
|
||||||
|
//
|
||||||
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
// you may not use this file except in compliance with the License.
|
||||||
|
// You may obtain a copy of the License at
|
||||||
|
//
|
||||||
|
// http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
//
|
||||||
|
// Unless required by applicable law or agreed to in writing, software
|
||||||
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
// See the License for the specific language governing permissions and
|
||||||
|
// limitations under the License.
|
||||||
|
|
||||||
|
// Normal include guard.
|
||||||
|
#ifndef THIRD_PARTY_GEMMA_CPP_COMPRESSION_INT_INL_H_
|
||||||
|
#define THIRD_PARTY_GEMMA_CPP_COMPRESSION_INT_INL_H_
|
||||||
|
|
||||||
|
#include <stddef.h>
|
||||||
|
#include <stdint.h>
|
||||||
|
#include <stdio.h>
|
||||||
|
|
||||||
|
#include <cstdint>
|
||||||
|
#include <cstdio>
|
||||||
|
|
||||||
|
#include "compression/types.h"
|
||||||
|
#include "util/basics.h"
|
||||||
|
#include "hwy/base.h"
|
||||||
|
#include "hwy/print-inl.h"
|
||||||
|
|
||||||
|
#endif // THIRD_PARTY_GEMMA_CPP_COMPRESSION_INT_INL_H_
|
||||||
|
|
||||||
|
// Actual per-target include guard.
|
||||||
|
#if defined(THIRD_PARTY_GEMMA_CPP_COMPRESSION_INT_INL_TOGGLE) == \
|
||||||
|
defined(HWY_TARGET_TOGGLE)
|
||||||
|
#ifdef THIRD_PARTY_GEMMA_CPP_COMPRESSION_INT_INL_TOGGLE
|
||||||
|
#undef THIRD_PARTY_GEMMA_CPP_COMPRESSION_INT_INL_TOGGLE
|
||||||
|
#else
|
||||||
|
#define THIRD_PARTY_GEMMA_CPP_COMPRESSION_INT_INL_TOGGLE
|
||||||
|
#endif
|
||||||
|
|
||||||
|
#include "hwy/highway.h"
|
||||||
|
|
||||||
|
HWY_BEFORE_NAMESPACE();
|
||||||
|
namespace gcpp {
|
||||||
|
namespace HWY_NAMESPACE {
|
||||||
|
namespace hn = hwy::HWY_NAMESPACE;
|
||||||
|
|
||||||
|
// Encode/decode functions.
|
||||||
|
class IntCodec {
|
||||||
|
using ScaleT = hwy::bfloat16_t;
|
||||||
|
static constexpr size_t kGroupSize = I8Stream::kGroupSize;
|
||||||
|
|
||||||
|
// Offset (in bytes) of a group's start for packed_ofs (in elements) within a
|
||||||
|
// set of groups.
|
||||||
|
static constexpr size_t GroupByteOffset(size_t packed_ofs) {
|
||||||
|
const size_t kBytesPerGroup = (2 * sizeof(ScaleT)) + kGroupSize;
|
||||||
|
return (packed_ofs / kGroupSize) * kBytesPerGroup;
|
||||||
|
}
|
||||||
|
|
||||||
|
public:
|
||||||
|
template <class DBF, HWY_IF_BF16_D(DBF)>
|
||||||
|
static HWY_INLINE void DequantizeGroup(
|
||||||
|
DBF dbf, const PackedSpan<const I8Stream>& packed, size_t packed_ofs,
|
||||||
|
hwy::bfloat16_t* HWY_RESTRICT raw, size_t num) {
|
||||||
|
using T = ScaleT;
|
||||||
|
const hn::ScalableTag<float> df;
|
||||||
|
const hn::Rebind<int32_t, decltype(df)> di32;
|
||||||
|
const hn::Rebind<int16_t, decltype(di32)> di16;
|
||||||
|
const hn::Rebind<int8_t, decltype(di16)> di8;
|
||||||
|
const hn::Twice<hn::Rebind<hwy::bfloat16_t, decltype(df)>> dbf16;
|
||||||
|
|
||||||
|
const size_t N = hn::Lanes(di8);
|
||||||
|
const size_t N16 = hn::Lanes(dbf16);
|
||||||
|
using VI8 = hn::Vec<decltype(di8)>;
|
||||||
|
using VF = hn::Vec<decltype(df)>;
|
||||||
|
|
||||||
|
T inv_scale, zeropoint;
|
||||||
|
hwy::CopyBytes(&packed.ptr->i + GroupByteOffset(packed_ofs), &inv_scale,
|
||||||
|
sizeof(T));
|
||||||
|
hwy::CopyBytes(&packed.ptr->i + GroupByteOffset(packed_ofs) + sizeof(T),
|
||||||
|
&zeropoint, sizeof(T));
|
||||||
|
|
||||||
|
float inv_scale_f = hwy::ConvertScalarTo<float>(inv_scale);
|
||||||
|
float zeropoint_f = hwy::ConvertScalarTo<float>(zeropoint);
|
||||||
|
|
||||||
|
VF inv_scale_vec = hn::Set(df, inv_scale_f);
|
||||||
|
VF zeroscale_vec = hn::Set(df, -zeropoint_f * (inv_scale_f));
|
||||||
|
|
||||||
|
// Then iterate over remainder of packed, extracting num / N vectors and
|
||||||
|
// inserting into raw.
|
||||||
|
const size_t g_num = HWY_MIN(num, kGroupSize);
|
||||||
|
|
||||||
|
const size_t current_offset = GroupByteOffset(packed_ofs) +
|
||||||
|
(2 * sizeof(T)) + (packed_ofs % kGroupSize);
|
||||||
|
size_t i = 0;
|
||||||
|
for (i = 0; i + 4 * N <= g_num; i += 4 * N) {
|
||||||
|
const VI8 val0 =
|
||||||
|
hn::LoadU(di8, &packed.ptr->i + current_offset + i + 0 * N);
|
||||||
|
const VI8 val1 =
|
||||||
|
hn::LoadU(di8, &packed.ptr->i + current_offset + i + 1 * N);
|
||||||
|
const VI8 val2 =
|
||||||
|
hn::LoadU(di8, &packed.ptr->i + current_offset + i + 2 * N);
|
||||||
|
const VI8 val3 =
|
||||||
|
hn::LoadU(di8, &packed.ptr->i + current_offset + i + 3 * N);
|
||||||
|
|
||||||
|
const VF val0_f =
|
||||||
|
hn::ConvertTo(df, hn::PromoteTo(di32, hn::PromoteTo(di16, val0)));
|
||||||
|
const VF val1_f =
|
||||||
|
hn::ConvertTo(df, hn::PromoteTo(di32, hn::PromoteTo(di16, val1)));
|
||||||
|
const VF val2_f =
|
||||||
|
hn::ConvertTo(df, hn::PromoteTo(di32, hn::PromoteTo(di16, val2)));
|
||||||
|
const VF val3_f =
|
||||||
|
hn::ConvertTo(df, hn::PromoteTo(di32, hn::PromoteTo(di16, val3)));
|
||||||
|
|
||||||
|
VF dequantized_val0 = hn::MulAdd(inv_scale_vec, val0_f, zeroscale_vec);
|
||||||
|
VF dequantized_val1 = hn::MulAdd(inv_scale_vec, val1_f, zeroscale_vec);
|
||||||
|
VF dequantized_val2 = hn::MulAdd(inv_scale_vec, val2_f, zeroscale_vec);
|
||||||
|
VF dequantized_val3 = hn::MulAdd(inv_scale_vec, val3_f, zeroscale_vec);
|
||||||
|
|
||||||
|
hn::StoreU(
|
||||||
|
hn::OrderedDemote2To(dbf16, dequantized_val0, dequantized_val1),
|
||||||
|
dbf16, raw + i + 0 * N16);
|
||||||
|
hn::StoreU(
|
||||||
|
hn::OrderedDemote2To(dbf16, dequantized_val2, dequantized_val3),
|
||||||
|
dbf16, raw + i + 1 * N16);
|
||||||
|
}
|
||||||
|
for (; i + N <= g_num; i += N) {
|
||||||
|
const VI8 val0 = hn::LoadU(di8, &packed.ptr->i + current_offset + i);
|
||||||
|
const VF val0_f =
|
||||||
|
hn::ConvertTo(df, hn::PromoteTo(di32, hn::PromoteTo(di16, val0)));
|
||||||
|
VF dequantized_val0 = hn::MulAdd(inv_scale_vec, val0_f, zeroscale_vec);
|
||||||
|
const hn::Rebind<hwy::bfloat16_t, decltype(df)> dbf_half;
|
||||||
|
hn::StoreU(hn::DemoteTo(dbf_half, dequantized_val0), dbf_half, raw + i);
|
||||||
|
}
|
||||||
|
if (i < g_num) {
|
||||||
|
const size_t remaining = g_num - i;
|
||||||
|
const VI8 val0 =
|
||||||
|
hn::LoadN(di8, &packed.ptr->i + current_offset + i, remaining);
|
||||||
|
const VF val0_f =
|
||||||
|
hn::ConvertTo(df, hn::PromoteTo(di32, hn::PromoteTo(di16, val0)));
|
||||||
|
VF dequantized_val0 = hn::MulAdd(inv_scale_vec, val0_f, zeroscale_vec);
|
||||||
|
const hn::Rebind<hwy::bfloat16_t, decltype(df)> dbf_half;
|
||||||
|
hn::StoreN(hn::DemoteTo(dbf_half, dequantized_val0), dbf_half, raw + i,
|
||||||
|
remaining);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Dequantizes `num` floats from `packed` into `raw`. `packed` points to
|
||||||
|
// compressed storage and `packed_ofs` indicates the destination offset
|
||||||
|
// within it, in number of elements. Scaling values are interleaved with int
|
||||||
|
// values to allow for easier unpacking.
|
||||||
|
template <class DF, HWY_IF_F32_D(DF)>
|
||||||
|
static HWY_INLINE void DequantizeGroup(
|
||||||
|
DF df, const PackedSpan<const I8Stream>& packed, size_t packed_ofs,
|
||||||
|
float* HWY_RESTRICT raw, size_t num) {
|
||||||
|
using T = ScaleT;
|
||||||
|
const hn::Rebind<int32_t, decltype(df)> di32;
|
||||||
|
const hn::Rebind<int16_t, decltype(di32)> di16;
|
||||||
|
const hn::Rebind<int8_t, decltype(di16)> di8;
|
||||||
|
const hn::Rebind<int8_t, decltype(df)> df8;
|
||||||
|
|
||||||
|
const size_t N = hn::Lanes(di8);
|
||||||
|
const size_t N32 = hn::Lanes(df);
|
||||||
|
using VI8 = hn::Vec<decltype(di8)>;
|
||||||
|
using VF = hn::Vec<decltype(df)>;
|
||||||
|
|
||||||
|
// HWY_ASSERT(num % 2 * N == 0);
|
||||||
|
|
||||||
|
// Load scale and zero point from the beginning - ensure correct pointer
|
||||||
|
// offset.
|
||||||
|
T inv_scale, zeropoint;
|
||||||
|
hwy::CopyBytes(&packed.ptr->i + GroupByteOffset(packed_ofs), &inv_scale,
|
||||||
|
sizeof(T));
|
||||||
|
hwy::CopyBytes(&packed.ptr->i + GroupByteOffset(packed_ofs) + sizeof(T),
|
||||||
|
&zeropoint, sizeof(T));
|
||||||
|
|
||||||
|
float inv_scale_f = hwy::ConvertScalarTo<float>(inv_scale);
|
||||||
|
float zeropoint_f = hwy::ConvertScalarTo<float>(zeropoint);
|
||||||
|
|
||||||
|
VF inv_scale_vec = hn::Set(df, inv_scale_f);
|
||||||
|
VF zeroscale_vec = hn::Set(df, -zeropoint_f * (inv_scale_f));
|
||||||
|
|
||||||
|
// Then iterate over remainder of packed, extracting num / N vectors and
|
||||||
|
// inserting into raw.
|
||||||
|
const size_t g_num = HWY_MIN(num, kGroupSize);
|
||||||
|
|
||||||
|
const size_t current_offset = GroupByteOffset(packed_ofs) +
|
||||||
|
(2 * sizeof(T)) + (packed_ofs % kGroupSize);
|
||||||
|
|
||||||
|
size_t i = 0;
|
||||||
|
for (; i + 2 * N <= g_num; i += 2 * N) {
|
||||||
|
const VI8 val0 =
|
||||||
|
hn::LoadU(di8, &packed.ptr->i + current_offset + i + 0 * N);
|
||||||
|
const VI8 val1 =
|
||||||
|
hn::LoadU(di8, &packed.ptr->i + current_offset + i + 1 * N);
|
||||||
|
|
||||||
|
const VF val0_f =
|
||||||
|
hn::ConvertTo(df, hn::PromoteTo(di32, hn::PromoteTo(di16, val0)));
|
||||||
|
const VF val1_f =
|
||||||
|
hn::ConvertTo(df, hn::PromoteTo(di32, hn::PromoteTo(di16, val1)));
|
||||||
|
|
||||||
|
VF dequantized_val0 = hn::MulAdd(inv_scale_vec, val0_f, zeroscale_vec);
|
||||||
|
VF dequantized_val1 = hn::MulAdd(inv_scale_vec, val1_f, zeroscale_vec);
|
||||||
|
|
||||||
|
hn::StoreU(dequantized_val0, df, raw + i + 0 * N32);
|
||||||
|
hn::StoreU(dequantized_val1, df, raw + i + 1 * N32);
|
||||||
|
}
|
||||||
|
for (; i + N <= g_num; i += N) {
|
||||||
|
const VI8 val0 = hn::LoadU(di8, &packed.ptr->i + current_offset + i);
|
||||||
|
const VF val0_f =
|
||||||
|
hn::ConvertTo(df, hn::PromoteTo(di32, hn::PromoteTo(di16, val0)));
|
||||||
|
VF dequantized_val0 = hn::MulAdd(inv_scale_vec, val0_f, zeroscale_vec);
|
||||||
|
hn::StoreU(dequantized_val0, df, raw + i);
|
||||||
|
}
|
||||||
|
if (i < g_num) {
|
||||||
|
const size_t remaining = g_num - i;
|
||||||
|
const VI8 val0 =
|
||||||
|
hn::LoadN(di8, &packed.ptr->i + current_offset + i, remaining);
|
||||||
|
const VF val0_f =
|
||||||
|
hn::ConvertTo(df, hn::PromoteTo(di32, hn::PromoteTo(di16, val0)));
|
||||||
|
VF dequantized_val0 = hn::MulAdd(inv_scale_vec, val0_f, zeroscale_vec);
|
||||||
|
hn::StoreN(dequantized_val0, df, raw + i, remaining);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Quantizes `num` floats from `raw` into `packed`. `packed` points to
|
||||||
|
// compressed storage and `packed_ofs` indicates the destination offset
|
||||||
|
// within it, in number of elements. Scaling values are interleaved with
|
||||||
|
// int values to allow for easier unpacking.
|
||||||
|
template <class DF, HWY_IF_F32_D(DF)>
|
||||||
|
static HWY_INLINE void QuantizeGroup(DF df, const float* HWY_RESTRICT raw,
|
||||||
|
size_t num,
|
||||||
|
const PackedSpan<I8Stream>& packed,
|
||||||
|
size_t packed_ofs) {
|
||||||
|
using T = ScaleT;
|
||||||
|
const hn::Repartition<int32_t, DF> di32;
|
||||||
|
const hn::Half<hn::Repartition<int16_t, decltype(di32)>> di16;
|
||||||
|
const hn::Half<hn::Repartition<int8_t, decltype(di16)>> di8;
|
||||||
|
|
||||||
|
const size_t N = hn::Lanes(df);
|
||||||
|
using VI8 = hn::Vec<decltype(di8)>;
|
||||||
|
using VF = hn::Vec<decltype(df)>;
|
||||||
|
|
||||||
|
HWY_DASSERT(packed_ofs % kGroupSize == 0);
|
||||||
|
HWY_DASSERT(num % 2 * N == 0);
|
||||||
|
|
||||||
|
// Calculate min/max using SIMD
|
||||||
|
float min_val = hwy::HighestValue<float>();
|
||||||
|
float max_val = hwy::LowestValue<float>();
|
||||||
|
VF vmin = hn::Set(df, hwy::HighestValue<float>());
|
||||||
|
VF vmax = hn::Set(df, hwy::LowestValue<float>());
|
||||||
|
|
||||||
|
size_t j = 0;
|
||||||
|
for (; j + N <= num; j += N) {
|
||||||
|
const VF xi = hn::LoadU(df, raw + j);
|
||||||
|
vmin = hn::Min(vmin, xi);
|
||||||
|
vmax = hn::Max(vmax, xi);
|
||||||
|
}
|
||||||
|
|
||||||
|
min_val = hn::ReduceMin(df, vmin);
|
||||||
|
max_val = hn::ReduceMax(df, vmax);
|
||||||
|
|
||||||
|
for (; j < num; ++j) {
|
||||||
|
min_val = HWY_MIN(min_val, raw[j]);
|
||||||
|
max_val = HWY_MAX(max_val, raw[j]);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Calculate range, scale and zeropoint
|
||||||
|
float x_range = max_val - min_val;
|
||||||
|
x_range = x_range == 0.0f ? 1.0f : x_range;
|
||||||
|
const float scale_f = 255.0f / x_range;
|
||||||
|
const float zeropoint_f = static_cast<float>(
|
||||||
|
static_cast<int32_t>(-scale_f * min_val - 128.0f)); // Correct casting
|
||||||
|
|
||||||
|
const T scale = hwy::ConvertScalarTo<T>(scale_f);
|
||||||
|
// inv_scale is used for all dequantization.
|
||||||
|
const T inv_scale = hwy::ConvertScalarTo<T>(1.0f / scale_f);
|
||||||
|
const T zeropoint = hwy::ConvertScalarTo<T>(zeropoint_f);
|
||||||
|
memcpy(&packed.ptr->i + GroupByteOffset(packed_ofs), &inv_scale, sizeof(T));
|
||||||
|
memcpy(&packed.ptr->i + GroupByteOffset(packed_ofs) + sizeof(T), &zeropoint,
|
||||||
|
sizeof(T));
|
||||||
|
|
||||||
|
const size_t g_num = HWY_MIN(num, kGroupSize);
|
||||||
|
|
||||||
|
VF mul = hn::Set(df, hwy::ConvertScalarTo<float>(scale));
|
||||||
|
VF add = hn::Set(df, hwy::ConvertScalarTo<float>(zeropoint));
|
||||||
|
|
||||||
|
const size_t current_offset = GroupByteOffset(packed_ofs) +
|
||||||
|
(2 * sizeof(T)) + (packed_ofs % kGroupSize);
|
||||||
|
|
||||||
|
size_t i = 0;
|
||||||
|
for (; i + 2 * N <= g_num; i += 2 * N) {
|
||||||
|
const VI8 val0 = hn::DemoteTo(
|
||||||
|
di8,
|
||||||
|
hn::DemoteTo(di16, NearestInt(hn::MulAdd(
|
||||||
|
mul, hn::LoadU(df, raw + i + 0 * N), add))));
|
||||||
|
const VI8 val1 = hn::DemoteTo(
|
||||||
|
di8,
|
||||||
|
hn::DemoteTo(di16, NearestInt(hn::MulAdd(
|
||||||
|
mul, hn::LoadU(df, raw + i + 1 * N), add))));
|
||||||
|
|
||||||
|
hn::StoreU(val0, di8, &packed.ptr->i + current_offset + i + 0 * N);
|
||||||
|
hn::StoreU(val1, di8, &packed.ptr->i + current_offset + i + 1 * N);
|
||||||
|
}
|
||||||
|
|
||||||
|
size_t remaining = g_num - i;
|
||||||
|
|
||||||
|
HWY_DASSERT(remaining < 2 * N);
|
||||||
|
if (HWY_UNLIKELY(remaining == 0)) return;
|
||||||
|
|
||||||
|
if (remaining > N) {
|
||||||
|
const VI8 val0 = hn::DemoteTo(
|
||||||
|
di8, hn::DemoteTo(di16, NearestInt(hn::MulAdd(
|
||||||
|
mul, hn::LoadU(df, raw + i), add))));
|
||||||
|
hn::StoreU(val0, di8, &packed.ptr->i + current_offset + i);
|
||||||
|
|
||||||
|
const size_t remaining1 = remaining - N;
|
||||||
|
const VI8 val1 = hn::DemoteTo(
|
||||||
|
di8,
|
||||||
|
hn::DemoteTo(di16,
|
||||||
|
NearestInt(hn::MulAdd(
|
||||||
|
mul, hn::LoadN(df, raw + i + N, remaining1), add))));
|
||||||
|
hn::StoreN(val1, di8, &packed.ptr->i + current_offset + i + N,
|
||||||
|
remaining1);
|
||||||
|
} else { // remaining <= N
|
||||||
|
const VI8 val0 = hn::DemoteTo(
|
||||||
|
di8, hn::DemoteTo(di16,
|
||||||
|
NearestInt(hn::MulAdd(
|
||||||
|
mul, hn::LoadN(df, raw + i, remaining), add))));
|
||||||
|
hn::StoreN(val0, di8, &packed.ptr->i + current_offset + i, remaining);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Encodes `num` floats from `raw` into `packed`. `packed` points to
|
||||||
|
// compressed storage and `packed_ofs` indicates the destination offset
|
||||||
|
// within it, in number of elements. Scaling values are interleaved with int
// values to allow for easier unpacking.
|
||||||
|
template <class DF, HWY_IF_F32_D(DF)>
|
||||||
|
static HWY_INLINE void Enc(DF df, const float* HWY_RESTRICT raw,
|
||||||
|
const size_t num,
|
||||||
|
const PackedSpan<I8Stream>& packed,
|
||||||
|
size_t packed_ofs) {
|
||||||
|
HWY_ASSERT(packed_ofs % kGroupSize == 0);
|
||||||
|
|
||||||
|
const size_t num_groups = hwy::DivCeil(num, kGroupSize);
|
||||||
|
|
||||||
|
size_t current_offset = packed_ofs;
|
||||||
|
for (size_t g = 0; g < num_groups; ++g) {
|
||||||
|
const size_t g_num = HWY_MIN(num - g * kGroupSize, kGroupSize);
|
||||||
|
const float* HWY_RESTRICT g_in = raw + g * kGroupSize;
|
||||||
|
|
||||||
|
QuantizeGroup(df, g_in, g_num, packed, current_offset);
|
||||||
|
current_offset += g_num;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Decompresses to two bf16 vectors. `packed_ofs` must be a multiple of two
|
||||||
|
// vectors so that we only have to load one group's table.
|
||||||
|
template <class DBF, HWY_IF_BF16_D(DBF)>
|
||||||
|
static HWY_INLINE void Dec2(DBF dbf, const PackedSpan<const I8Stream>& packed,
|
||||||
|
const size_t packed_ofs, hn::Vec<DBF>& raw0,
|
||||||
|
hn::Vec<DBF>& raw1) {
|
||||||
|
const hn::Repartition<float, decltype(dbf)> df;
|
||||||
|
using VF = hn::Vec<decltype(df)>;
|
||||||
|
const size_t NF = hn::Lanes(df);
|
||||||
|
|
||||||
|
HWY_ASSERT(packed_ofs % 2 * NF == 0);
|
||||||
|
|
||||||
|
VF raw0_f, raw1_f, raw2_f, raw3_f;
|
||||||
|
Dec2(df, packed, packed_ofs + 0 * 2 * NF, raw0_f, raw1_f);
|
||||||
|
Dec2(df, packed, packed_ofs + 1 * 2 * NF, raw2_f, raw3_f);
|
||||||
|
|
||||||
|
raw0 = hn::OrderedDemote2To(dbf, raw0_f, raw1_f);
|
||||||
|
raw1 = hn::OrderedDemote2To(dbf, raw2_f, raw3_f);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Decompresses to two f32 vectors. `packed_ofs` must be a multiple of two
|
||||||
|
// vectors so that we only have to load one group's table.
|
||||||
|
template <class DF, HWY_IF_F32_D(DF)>
|
||||||
|
static HWY_INLINE void Dec2(DF df, const PackedSpan<const I8Stream>& packed,
|
||||||
|
const size_t packed_ofs, hn::Vec<DF>& raw0,
|
||||||
|
hn::Vec<DF>& raw1) {
|
||||||
|
using T = ScaleT;
|
||||||
|
const hn::Rebind<int32_t, decltype(df)> di32;
|
||||||
|
const hn::Rebind<int16_t, decltype(di32)> di16;
|
||||||
|
const hn::Rebind<int8_t, decltype(di16)> di8;
|
||||||
|
const hn::Rebind<int8_t, decltype(df)> df8;
|
||||||
|
|
||||||
|
const size_t N = hn::Lanes(di8);
|
||||||
|
using VI8 = hn::Vec<decltype(di8)>;
|
||||||
|
using VF = hn::Vec<decltype(df)>;
|
||||||
|
|
||||||
|
T inv_scale, zeropoint;
|
||||||
|
hwy::CopyBytes(&packed.ptr->i + GroupByteOffset(packed_ofs), &inv_scale,
|
||||||
|
sizeof(T));
|
||||||
|
hwy::CopyBytes(&packed.ptr->i + GroupByteOffset(packed_ofs) + sizeof(T),
|
||||||
|
&zeropoint, sizeof(T));
|
||||||
|
|
||||||
|
float inv_scale_f = hwy::ConvertScalarTo<float>(inv_scale);
|
||||||
|
float zeropoint_f = hwy::ConvertScalarTo<float>(zeropoint);
|
||||||
|
|
||||||
|
VF inv_scale_vec = hn::Set(df, inv_scale_f);
|
||||||
|
VF zeroscale_vec = hn::Set(df, -zeropoint_f * (inv_scale_f));
|
||||||
|
|
||||||
|
const size_t current_offset = GroupByteOffset(packed_ofs) +
|
||||||
|
(2 * sizeof(T)) + (packed_ofs % kGroupSize);
|
||||||
|
|
||||||
|
const VI8 val0 = hn::LoadU(di8, &packed.ptr->i + current_offset + 0 * N);
|
||||||
|
const VI8 val1 = hn::LoadU(di8, &packed.ptr->i + current_offset + 1 * N);
|
||||||
|
|
||||||
|
const VF val0_f =
|
||||||
|
hn::ConvertTo(df, hn::PromoteTo(di32, hn::PromoteTo(di16, val0)));
|
||||||
|
const VF val1_f =
|
||||||
|
hn::ConvertTo(df, hn::PromoteTo(di32, hn::PromoteTo(di16, val1)));
|
||||||
|
|
||||||
|
raw0 = hn::MulAdd(inv_scale_vec, val0_f, zeroscale_vec);
|
||||||
|
raw1 = hn::MulAdd(inv_scale_vec, val1_f, zeroscale_vec);
|
||||||
|
}
|
||||||
|
|
||||||
|
template <class D, typename Raw = hn::TFromD<D>>
|
||||||
|
static HWY_INLINE void DecompressAndZeroPad(
|
||||||
|
D d, const PackedSpan<const I8Stream>& packed, size_t packed_ofs,
|
||||||
|
Raw* HWY_RESTRICT raw, size_t num) {
|
||||||
|
if (num == 0) return;
|
||||||
|
|
||||||
|
const size_t N = hn::Lanes(d);
|
||||||
|
const size_t padded_num = hwy::RoundUpTo(num, N);
|
||||||
|
if (padded_num > num) {
|
||||||
|
hwy::ZeroBytes(raw + num, (padded_num - num) * sizeof(Raw));
|
||||||
|
}
|
||||||
|
|
||||||
|
size_t current_packed_ofs = packed_ofs;
|
||||||
|
Raw* HWY_RESTRICT current_raw = raw;
|
||||||
|
size_t num_to_decompress = num;
|
||||||
|
|
||||||
|
if (size_t within_group = current_packed_ofs % kGroupSize;
|
||||||
|
within_group != 0) {
|
||||||
|
const size_t remaining_in_group = kGroupSize - within_group;
|
||||||
|
const size_t num_in_first_group =
|
||||||
|
HWY_MIN(num_to_decompress, remaining_in_group);
|
||||||
|
DequantizeGroup(d, packed, current_packed_ofs, current_raw,
|
||||||
|
num_in_first_group);
|
||||||
|
current_packed_ofs += num_in_first_group;
|
||||||
|
current_raw += num_in_first_group;
|
||||||
|
num_to_decompress -= num_in_first_group;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (num_to_decompress == 0) return;
|
||||||
|
|
||||||
|
HWY_DASSERT(current_packed_ofs % kGroupSize == 0);
|
||||||
|
|
||||||
|
const size_t num_full_groups = num_to_decompress / kGroupSize;
|
||||||
|
for (size_t g = 0; g < num_full_groups; ++g) {
|
||||||
|
DequantizeGroup(d, packed, current_packed_ofs, current_raw, kGroupSize);
|
||||||
|
current_packed_ofs += kGroupSize;
|
||||||
|
current_raw += kGroupSize;
|
||||||
|
}
|
||||||
|
|
||||||
|
const size_t remaining = num_to_decompress % kGroupSize;
|
||||||
|
if (remaining != 0) {
|
||||||
|
DequantizeGroup(d, packed, current_packed_ofs, current_raw, remaining);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}; // IntCodec
|
||||||
|
|
||||||
|
// NOLINTNEXTLINE(google-readability-namespace-comments)
|
||||||
|
} // namespace HWY_NAMESPACE
|
||||||
|
} // namespace gcpp
|
||||||
|
HWY_AFTER_NAMESPACE();
|
||||||
|
|
||||||
|
#endif // THIRD_PARTY_GEMMA_CPP_COMPRESSION_INT_INL_H_
|
||||||
|
|
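For reference, a round-trip sketch of the IntCodec API defined above; it mirrors the tests in the next file. `MakeSpan`/`MakeConst` and `hwy::AllocateAligned` are the existing gcpp/Highway helpers, and this would sit inside the usual HWY_NAMESPACE block.

// Sketch: quantize `total` floats into an I8Stream buffer and decode them
// back. For Enc, packed_ofs must be a multiple of I8Stream::kGroupSize.
const hn::ScalableTag<float> df;
const size_t total = 2 * I8Stream::kGroupSize;
auto raw = hwy::AllocateAligned<float>(total);
auto out = hwy::AllocateAligned<float>(total);
auto packed = hwy::AllocateAligned<I8Stream>(I8Stream::PackedEnd(total));
for (size_t i = 0; i < total; ++i) raw[i] = static_cast<float>(i) * 0.01f;

const auto span = MakeSpan(packed.get(), total);
IntCodec::Enc(df, raw.get(), total, span, /*packed_ofs=*/0);
IntCodec::DecompressAndZeroPad(df, MakeConst(span), /*packed_ofs=*/0,
                               out.get(), total);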
@@ -0,0 +1,494 @@
|
||||||
|
// Copyright 2023 Google LLC
|
||||||
|
// SPDX-License-Identifier: Apache-2.0
|
||||||
|
//
|
||||||
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
// you may not use this file except in compliance with the License.
|
||||||
|
// You may obtain a copy of the License at
|
||||||
|
//
|
||||||
|
// http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
//
|
||||||
|
// Unless required by applicable law or agreed to in writing, software
|
||||||
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
// See the License for the specific language governing permissions and
|
||||||
|
// limitations under the License.
|
||||||
|
|
||||||
|
// SFP uses ConcatEven/Odd which are not supported; skip SVE for faster tests.
|
||||||
|
#ifndef HWY_DISABLED_TARGETS
|
||||||
|
#define HWY_DISABLED_TARGETS (HWY_SCALAR | HWY_SVE)
|
||||||
|
#endif
|
||||||
|
|
||||||
|
#include <stddef.h>
|
||||||
|
#include <stdint.h>
|
||||||
|
#include <stdio.h>
|
||||||
|
|
||||||
|
#include "util/test_util.h"
|
||||||
|
#include "hwy/aligned_allocator.h"
|
||||||
|
#include "hwy/base.h"
|
||||||
|
#include "hwy/tests/hwy_gtest.h"
|
||||||
|
#include "hwy/tests/test_util.h"
|
||||||
|
|
||||||
|
// clang-format off
|
||||||
|
#undef HWY_TARGET_INCLUDE
|
||||||
|
#define HWY_TARGET_INCLUDE "compression/int_test.cc" // NOLINT
|
||||||
|
// clang-format on
|
||||||
|
#include "hwy/foreach_target.h" // IWYU pragma: keep
|
||||||
|
#include "hwy/highway.h"
|
||||||
|
// After highway.h
|
||||||
|
#include "compression/int-inl.h"
|
||||||
|
#include "hwy/tests/test_util-inl.h"
|
||||||
|
|
||||||
|
HWY_BEFORE_NAMESPACE();
|
||||||
|
namespace gcpp {
|
||||||
|
namespace HWY_NAMESPACE {
|
||||||
|
namespace hn = hwy::HWY_NAMESPACE;
|
||||||
|
|
||||||
|
static constexpr size_t kGroupSize = I8Stream::kGroupSize;
|
||||||
|
static constexpr float kTolerance = 50000.0f;
|
||||||
|
|
||||||
|
// Can encode and decode sub-regions.
|
||||||
|
// Quantizes and de-quantizes a single (potentially partial) group to check
|
||||||
|
// that the quantizer is working correctly.
|
||||||
|
struct TestQuantize {
|
||||||
|
template <typename T, class D>
|
||||||
|
HWY_INLINE void operator()(T /*unused*/, D d) {
|
||||||
|
const size_t total = kGroupSize / 2; // already padded
|
||||||
|
const hn::ScalableTag<float> df;
|
||||||
|
|
||||||
|
auto in = hwy::AllocateAligned<float>(total); // Enc() requires f32
|
||||||
|
auto dec1 = hwy::AllocateAligned<T>(total);
|
||||||
|
auto dec2 = hwy::AllocateAligned<T>(total);
|
||||||
|
auto dec3 = hwy::AllocateAligned<T>(total);
|
||||||
|
auto i8_stream = hwy::AllocateAligned<I8Stream>(I8Stream::PackedEnd(total));
|
||||||
|
HWY_ASSERT(in && dec1 && dec2 && dec3 && i8_stream);
|
||||||
|
const auto int_span = MakeSpan(i8_stream.get(), total);
|
||||||
|
|
||||||
|
hwy::RandomState rng;
|
||||||
|
for (size_t i = 0; i < total; ++i) {
|
||||||
|
in[i] = static_cast<float>(RandomGaussian(rng));
|
||||||
|
}
|
||||||
|
|
||||||
|
IntCodec::QuantizeGroup(df, in.get(), total, int_span, 0);
|
||||||
|
|
||||||
|
IntCodec::DequantizeGroup(d, MakeConst(int_span), 0, dec1.get(), total);
|
||||||
|
|
||||||
|
const float epsilon =
|
||||||
|
hwy::ConvertScalarTo<float>(hwy::Epsilon<hwy::bfloat16_t>());
|
||||||
|
const float tolerance = kTolerance * epsilon;
|
||||||
|
|
||||||
|
for (size_t i = 0; i < total; ++i) {
|
||||||
|
const float expected_value = static_cast<float>(in[i]);
|
||||||
|
const float actual_value = hwy::ConvertScalarTo<float>(dec1[i]);
|
||||||
|
|
||||||
|
if (!(expected_value - tolerance <= actual_value &&
|
||||||
|
actual_value <= expected_value + tolerance)) {
|
||||||
|
fprintf(stderr,
|
||||||
|
"in[%zu] = %f, dec1[%zu] = %f, tolerance = %f, epsilon = %f\n",
|
||||||
|
i, expected_value, i, actual_value, tolerance, epsilon);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check that ::Enc works correctly as well.
|
||||||
|
IntCodec::Enc(df, in.get(), total, int_span, 0);
|
||||||
|
|
||||||
|
IntCodec::DequantizeGroup(d, MakeConst(int_span), 0, dec2.get(), total);
|
||||||
|
|
||||||
|
for (size_t i = 0; i < total; ++i) {
|
||||||
|
const float expected_value = static_cast<float>(in[i]);
|
||||||
|
const float actual_value = hwy::ConvertScalarTo<float>(dec2[i]);
|
||||||
|
|
||||||
|
if (!(expected_value - tolerance <= actual_value &&
|
||||||
|
actual_value <= expected_value + tolerance)) {
|
||||||
|
fprintf(stderr,
|
||||||
|
"in[%zu] = %f, dec2[%zu] = %f, tolerance = %f, epsilon = %f\n",
|
||||||
|
i, expected_value, i, actual_value, tolerance, epsilon);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check that ::DecompressAndZeroPad works correctly for one group as well.
|
||||||
|
IntCodec::Enc(df, in.get(), total, int_span, 0);
|
||||||
|
|
||||||
|
IntCodec::DecompressAndZeroPad(d, MakeConst(int_span), 0, dec3.get(),
|
||||||
|
total);
|
||||||
|
|
||||||
|
for (size_t i = 0; i < total; ++i) {
|
||||||
|
const float expected_value = static_cast<float>(in[i]);
|
||||||
|
const float actual_value = hwy::ConvertScalarTo<float>(dec3[i]);
|
||||||
|
|
||||||
|
if (!(expected_value - tolerance <= actual_value &&
|
||||||
|
actual_value <= expected_value + tolerance)) {
|
||||||
|
fprintf(stderr,
|
||||||
|
"in[%zu] = %f, dec3[%zu] = %f, tolerance = %f, epsilon = %f\n",
|
||||||
|
i, expected_value, i, actual_value, tolerance, epsilon);
|
||||||
|
HWY_ASSERT(false);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
void TestQuantizeBF16() { hn::ForGEVectors<128, TestQuantize>()(BF16()); }
|
||||||
|
void TestQuantizeF32() { hn::ForGEVectors<128, TestQuantize>()(float()); }
|
||||||
|
|
||||||
|
// Can encode and decode sub-regions.
|
||||||
|
// Quantizes and de-quantizes multiple (potentially partial) groups to check
|
||||||
|
// that DecompressAndZeroPad is working correctly.
|
||||||
|
struct TestMultiGroup {
|
||||||
|
template <typename T, class D>
|
||||||
|
HWY_INLINE void operator()(T /*unused*/, D d) {
|
||||||
|
const hn::Repartition<float, D> df;
|
||||||
|
const size_t total = kGroupSize * 2 + kGroupSize / 4; // already padded
|
||||||
|
|
||||||
|
auto in = hwy::AllocateAligned<float>(total); // Enc() requires f32
|
||||||
|
auto dec1 = hwy::AllocateAligned<T>(total);
|
||||||
|
auto dec2 = hwy::AllocateAligned<T>(total);
|
||||||
|
auto i8_stream = hwy::AllocateAligned<I8Stream>(I8Stream::PackedEnd(total));
|
||||||
|
HWY_ASSERT(in && dec1 && i8_stream);
|
||||||
|
const auto int_span = MakeSpan(i8_stream.get(), total);
|
||||||
|
|
||||||
|
hwy::RandomState rng;
|
||||||
|
for (size_t i = 0; i < total; ++i) {
|
||||||
|
in[i] = static_cast<float>(RandomGaussian(rng));
|
||||||
|
}
|
||||||
|
|
||||||
|
const float epsilon =
|
||||||
|
hwy::ConvertScalarTo<float>(hwy::Epsilon<hwy::bfloat16_t>());
|
||||||
|
const float tolerance = kTolerance * epsilon;
|
||||||
|
|
||||||
|
// Check that ::DecompressAndZeroPad works correctly for one group as well.
|
||||||
|
IntCodec::Enc(df, in.get(), total, int_span, 0);
|
||||||
|
|
||||||
|
IntCodec::DecompressAndZeroPad(d, MakeConst(int_span), 0, dec2.get(),
|
||||||
|
total);
|
||||||
|
|
||||||
|
for (size_t i = 0; i < total; ++i) {
|
||||||
|
const float expected_value = static_cast<float>(in[i]);
|
||||||
|
const float actual_value = hwy::ConvertScalarTo<float>(dec2[i]);
|
||||||
|
|
||||||
|
if (!(expected_value - tolerance <= actual_value &&
|
||||||
|
actual_value <= expected_value + tolerance)) {
|
||||||
|
fprintf(stderr,
|
||||||
|
"in[%zu] = %f, dec2[%zu] = %f, tolerance = %f, epsilon = %f\n",
|
||||||
|
i, expected_value, i, actual_value, tolerance, epsilon);
|
||||||
|
HWY_ASSERT(false);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
void TestMultiGroupBF16() { hn::ForGEVectors<128, TestMultiGroup>()(BF16()); }
|
||||||
|
void TestMultiGroupF32() { hn::ForGEVectors<128, TestMultiGroup>()(float()); }
|
||||||
|
|
||||||
|
// Can encode and decode sub-regions.
|
||||||
|
struct TestOffset {
|
||||||
|
template <typename T, class D>
|
||||||
|
HWY_INLINE void operator()(T /*unused*/, D d) {
|
||||||
|
const hn::Repartition<float, D> df;
|
||||||
|
const size_t total = 10 * kGroupSize; // already padded
|
||||||
|
const size_t kMidLen = 2 * kGroupSize; // length of middle piece
|
||||||
|
|
||||||
|
auto in = hwy::AllocateAligned<float>(total); // Enc() requires f32
|
||||||
|
auto dec1 = hwy::AllocateAligned<T>(total);
|
||||||
|
auto dec2 = hwy::AllocateAligned<T>(kMidLen);
|
||||||
|
auto i8_stream = hwy::AllocateAligned<I8Stream>(I8Stream::PackedEnd(total));
|
||||||
|
HWY_ASSERT(in && dec1 && dec2 && i8_stream);
|
||||||
|
const auto int_span = MakeSpan(i8_stream.get(), total);
|
||||||
|
|
||||||
|
hwy::RandomState rng;
|
||||||
|
for (size_t i = 0; i < total; ++i) {
|
||||||
|
in[i] = static_cast<float>(RandomGaussian(rng));
|
||||||
|
}
|
||||||
|
|
||||||
|
// Encode + decode everything
|
||||||
|
(void)IntCodec::Enc(df, in.get(), total, int_span, 0);
|
||||||
|
IntCodec::DecompressAndZeroPad(d, MakeConst(int_span), 0, dec1.get(),
|
||||||
|
total);
|
||||||
|
|
||||||
|
MaybeCheckInitialized(dec1.get(), total * sizeof(T));
|
||||||
|
|
||||||
|
// Overwrite middle with first inputs
|
||||||
|
const size_t offset = 5 * kGroupSize;
|
||||||
|
(void)IntCodec::Enc(df, in.get(), kMidLen, int_span, offset);
|
||||||
|
|
||||||
|
// Decoded middle now matches previously decoded first
|
||||||
|
IntCodec::DecompressAndZeroPad(d, MakeConst(int_span), offset, dec2.get(),
|
||||||
|
kMidLen);
|
||||||
|
MaybeCheckInitialized(dec2.get(), kMidLen * sizeof(T));
|
||||||
|
|
||||||
|
for (size_t i = 0; i < kMidLen; ++i) {
|
||||||
|
HWY_ASSERT(dec1[i] == dec2[i]);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
void TestOffsetBF16() { hn::ForGEVectors<128, TestOffset>()(BF16()); }
|
||||||
|
void TestOffsetF32() { hn::ForGEVectors<128, TestOffset>()(float()); }
|
||||||
|
|
||||||
|
// Can encode and decode sub-regions.
|
||||||
|
struct TestUnalignedOffset {
|
||||||
|
template <typename T, class D>
|
||||||
|
HWY_INLINE void operator()(T /*unused*/, D d) {
|
||||||
|
const hn::Repartition<float, D> df;
|
||||||
|
const size_t total = 10 * kGroupSize; // already padded
|
||||||
|
|
||||||
|
const int num_unaligned_offsets = 4;
|
||||||
|
const std::array<size_t, num_unaligned_offsets> unaligned_offsets = {
|
||||||
|
4, kGroupSize + 100, 2 * kGroupSize + 100, 3 * kGroupSize + 100};
|
||||||
|
const std::array<size_t, num_unaligned_offsets> num = {4, 16, 32, 64};
|
||||||
|
|
||||||
|
for (int i = 0; i < num_unaligned_offsets; ++i) {
|
||||||
|
const size_t unaligned_offset = unaligned_offsets[i];
|
||||||
|
const size_t num_decompressed = num[i];
|
||||||
|
|
||||||
|
auto in = hwy::AllocateAligned<float>(total); // Enc() requires f32
|
||||||
|
auto dec1 = hwy::AllocateAligned<T>(total);
|
||||||
|
auto i8_stream =
|
||||||
|
hwy::AllocateAligned<I8Stream>(I8Stream::PackedEnd(total));
|
||||||
|
auto dec2 = hwy::AllocateAligned<T>(num_decompressed);
|
||||||
|
HWY_ASSERT(in && dec1 && dec2 && i8_stream);
|
||||||
|
const auto int_span = MakeSpan(i8_stream.get(), total);
|
||||||
|
|
||||||
|
hwy::RandomState rng;
|
||||||
|
for (size_t i = 0; i < total; ++i) {
|
||||||
|
in[i] = static_cast<float>(RandomGaussian(rng));
|
||||||
|
}
|
||||||
|
|
||||||
|
// Encode + decode everything
|
||||||
|
(void)IntCodec::Enc(df, in.get(), total, int_span, 0);
|
||||||
|
IntCodec::DecompressAndZeroPad(d, MakeConst(int_span), 0, dec1.get(),
|
||||||
|
total);
|
||||||
|
|
||||||
|
IntCodec::DecompressAndZeroPad(d, MakeConst(int_span), unaligned_offset,
|
||||||
|
dec2.get(), num_decompressed);
|
||||||
|
|
||||||
|
for (size_t i = 0; i < num_decompressed; ++i) {
|
||||||
|
T expected = hwy::ConvertScalarTo<T>(dec1[unaligned_offset + i]);
|
||||||
|
T actual = hwy::ConvertScalarTo<T>(dec2[i]);
|
||||||
|
|
||||||
|
HWY_ASSERT_EQ(expected, actual);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
void TestUnalignedOffsetBF16() {
|
||||||
|
hn::ForGEVectors<128, TestUnalignedOffset>()(BF16());
|
||||||
|
}
|
||||||
|
void TestUnalignedOffsetF32() {
|
||||||
|
hn::ForGEVectors<128, TestUnalignedOffset>()(float());
|
||||||
|
}
|
||||||
|
|
||||||
|
// Can encode and decode sub-regions.
|
||||||
|
// Uses Dec2 to decode all elements in the packed buffer, then
|
||||||
|
// compares against DecompressAndZeroPad.
|
||||||
|
struct TestDec2 {
|
||||||
|
template <typename T, class D>
|
||||||
|
HWY_INLINE void operator()(T /*unused*/, D d) {
|
||||||
|
const hn::Repartition<float, D> df;
|
||||||
|
// incl. partial group to test partial group handling
|
||||||
|
const size_t total = kGroupSize * 10 + kGroupSize / 2;
|
||||||
|
|
||||||
|
auto in = hwy::AllocateAligned<float>(total); // Enc() requires f32
|
||||||
|
auto dec0 = hwy::AllocateAligned<T>(total);
|
||||||
|
auto dec1 = hwy::AllocateAligned<T>(total);
|
||||||
|
auto i8_stream = hwy::AllocateAligned<I8Stream>(I8Stream::PackedEnd(total));
|
||||||
|
HWY_ASSERT(in && dec0 && dec1 && i8_stream);
|
||||||
|
const auto int_span = MakeSpan(i8_stream.get(), total);
|
||||||
|
|
||||||
|
hwy::RandomState rng;
|
||||||
|
for (size_t i = 0; i < total; ++i) {
|
||||||
|
in[i] = static_cast<float>(RandomGaussian(rng));
|
||||||
|
}
|
||||||
|
|
||||||
|
// Non-interleaved encode + decode for comparison
|
||||||
|
(void)IntCodec::Enc(df, in.get(), total, int_span, 0);
|
||||||
|
IntCodec::DecompressAndZeroPad(d, MakeConst(int_span), 0, dec0.get(),
|
||||||
|
total);
|
||||||
|
|
||||||
|
// Encode + decode everything
|
||||||
|
(void)IntCodec::Enc(df, in.get(), total, int_span, 0);
|
||||||
|
|
||||||
|
using V = hn::Vec<decltype(d)>;
|
||||||
|
const size_t N = Lanes(d);
|
||||||
|
|
||||||
|
for (size_t i = 0; i < total; i += 2 * N) {
|
||||||
|
V f0, f1;
|
||||||
|
IntCodec::Dec2(d, MakeConst(int_span), i, f0, f1);
|
||||||
|
|
||||||
|
hn::StoreU(f0, d, dec1.get() + i + 0 * N);
|
||||||
|
hn::StoreU(f1, d, dec1.get() + i + 1 * N);
|
||||||
|
}
|
||||||
|
|
||||||
|
for (size_t i = 0; i < total; ++i) {
|
||||||
|
if (dec0[i] != dec1[i]) {
|
||||||
|
fprintf(stderr, "dec0[%zu] = %g, dec1[%zu] = %g\n", i,
|
||||||
|
hwy::ConvertScalarTo<float>(dec0[i]), i,
|
||||||
|
hwy::ConvertScalarTo<float>(dec1[i]));
|
||||||
|
}
|
||||||
|
|
||||||
|
HWY_ASSERT(dec0[i] == dec1[i]);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
void TestDec2BF16() { hn::ForGEVectors<128, TestDec2>()(BF16()); }
|
||||||
|
void TestDec2F32() { hn::ForGEVectors<128, TestDec2>()(float()); }
|
||||||
|
|
||||||
|
// Tests that DecompressAndZeroPad fully populates the output array.
|
||||||
|
// This is intended to catch uninitialized value errors.
|
||||||
|
struct TestDequantizeAndZeroPad {
|
||||||
|
template <typename T, class D>
|
||||||
|
HWY_INLINE void operator()(T /*unused*/, D d) {
|
||||||
|
const hn::ScalableTag<float> df;
|
||||||
|
constexpr size_t kSize = 4096;
|
||||||
|
auto in = hwy::AllocateAligned<float>(kSize);
|
||||||
|
auto actual_dec = hwy::AllocateAligned<T>(kSize);
|
||||||
|
auto i8_stream = hwy::AllocateAligned<I8Stream>(I8Stream::PackedEnd(kSize));
|
||||||
|
HWY_ASSERT(in && actual_dec && i8_stream);
|
||||||
|
const auto int_span = MakeSpan(i8_stream.get(), kSize);
|
||||||
|
|
||||||
|
// Fill with a known pattern.
|
||||||
|
for (size_t i = 0; i < kSize; ++i) {
|
||||||
|
in[i] = static_cast<float>(i) - 128.0f;
|
||||||
|
}
|
||||||
|
|
||||||
|
IntCodec::Enc(df, in.get(), kSize, int_span, 0);
|
||||||
|
|
||||||
|
// Initialize with a sentinel value to detect if it's overwritten.
|
||||||
|
const T sentinel = hwy::ConvertScalarTo<T>(-999.0f);
|
||||||
|
std::fill(actual_dec.get(), actual_dec.get() + kSize, sentinel);
|
||||||
|
|
||||||
|
IntCodec::DecompressAndZeroPad(d, MakeConst(int_span), 0, actual_dec.get(),
|
||||||
|
kSize);
|
||||||
|
|
||||||
|
MaybeCheckInitialized(actual_dec.get(), kSize * sizeof(T));
|
||||||
|
|
||||||
|
// Check that all sentinels were overwritten.
|
||||||
|
for (size_t i = 0; i < kSize; ++i) {
|
||||||
|
EXPECT_NE(hwy::ConvertScalarTo<float>(actual_dec[i]),
|
||||||
|
hwy::ConvertScalarTo<float>(sentinel))
|
||||||
|
<< " at index " << i;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
void TestAllDequantizeAndZeroPad() {
|
||||||
|
hn::ForGEVectors<128, TestDequantizeAndZeroPad>()(BF16());
|
||||||
|
hn::ForGEVectors<128, TestDequantizeAndZeroPad>()(float());
|
||||||
|
}
|
||||||
|
|
||||||
|
// Tests that DecompressAndZeroPad works correctly for small and unaligned
|
||||||
|
// inputs. This is intended to catch uninitialized value errors in remainder
|
||||||
|
// handling.
|
||||||
|
struct TestSmallDequantize {
|
||||||
|
template <typename T, class D>
|
||||||
|
HWY_INLINE void operator()(T /*unused*/, D d) {
|
||||||
|
const hn::ScalableTag<float> df;
|
||||||
|
constexpr size_t kGroupSize = I8Stream::kGroupSize;
|
||||||
|
constexpr size_t kMaxNum = kGroupSize * 3;
|
||||||
|
auto in = hwy::AllocateAligned<float>(kMaxNum);
|
||||||
|
auto actual_dec = hwy::AllocateAligned<T>(kMaxNum);
|
||||||
|
auto i8_stream =
|
||||||
|
hwy::AllocateAligned<I8Stream>(I8Stream::PackedEnd(kMaxNum));
|
||||||
|
HWY_ASSERT(in && actual_dec && i8_stream);
|
||||||
|
const auto int_span =
|
||||||
|
MakeSpan(i8_stream.get(), I8Stream::PackedEnd(kMaxNum));
|
||||||
|
|
||||||
|
// Fill with a known pattern.
|
||||||
|
for (size_t i = 0; i < kMaxNum; ++i) {
|
||||||
|
in[i] = static_cast<float>(i) - 128.0f;
|
||||||
|
}
|
||||||
|
|
||||||
|
IntCodec::Enc(df, in.get(), kMaxNum, int_span, 0);
|
||||||
|
|
||||||
|
for (size_t num = 1; num < kGroupSize * 2; ++num) {
|
||||||
|
for (size_t offset = 0; offset < kGroupSize; offset += 16) {
|
||||||
|
const T sentinel = hwy::ConvertScalarTo<T>(-999.0f);
|
||||||
|
std::fill(actual_dec.get(), actual_dec.get() + num, sentinel);
|
||||||
|
|
||||||
|
IntCodec::DecompressAndZeroPad(d, MakeConst(int_span), offset,
|
||||||
|
actual_dec.get(), num);
|
||||||
|
|
||||||
|
MaybeCheckInitialized(actual_dec.get(), num);
|
||||||
|
|
||||||
|
// Check that all sentinels were overwritten.
|
||||||
|
for (size_t i = 0; i < num; ++i) {
|
||||||
|
EXPECT_NE(hwy::ConvertScalarTo<float>(actual_dec[i]),
|
||||||
|
hwy::ConvertScalarTo<float>(sentinel))
|
||||||
|
<< " at index " << i << " for num=" << num
|
||||||
|
<< " offset=" << offset;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
void TestAllSmallDequantize() {
|
||||||
|
hn::ForGEVectors<128, TestSmallDequantize>()(BF16());
|
||||||
|
hn::ForGEVectors<128, TestSmallDequantize>()(float());
|
||||||
|
}
|
||||||
|
|
||||||
|
// Tests that DecompressAndZeroPad works correctly for a specific failing input.
|
||||||
|
struct TestSpecificDequantize {
|
||||||
|
template <typename T, class D>
|
||||||
|
HWY_INLINE void operator()(T /*unused*/, D d) {
|
||||||
|
const hn::ScalableTag<float> df;
|
||||||
|
constexpr size_t kSize = 737280;
|
||||||
|
auto in = hwy::AllocateAligned<float>(kSize);
|
||||||
|
auto actual_dec = hwy::AllocateAligned<T>(kSize);
|
||||||
|
auto i8_stream = hwy::AllocateAligned<I8Stream>(I8Stream::PackedEnd(kSize));
|
||||||
|
HWY_ASSERT(in && actual_dec && i8_stream);
|
||||||
|
const auto int_span = MakeSpan(i8_stream.get(), kSize);
|
||||||
|
|
||||||
|
// Fill with a known pattern.
|
||||||
|
for (size_t i = 0; i < kSize; ++i) {
|
||||||
|
in[i] = static_cast<float>(i) - 128.0f;
|
||||||
|
}
|
||||||
|
|
||||||
|
IntCodec::Enc(df, in.get(), kSize, int_span, 0);
|
||||||
|
|
||||||
|
const size_t num = 64;
|
||||||
|
const size_t offset = 392704;
|
||||||
|
const T sentinel = hwy::ConvertScalarTo<T>(-999.0f);
|
||||||
|
std::fill(actual_dec.get(), actual_dec.get() + num, sentinel);
|
||||||
|
|
||||||
|
IntCodec::DecompressAndZeroPad(d, MakeConst(int_span), offset,
|
||||||
|
actual_dec.get(), num);
|
||||||
|
|
||||||
|
// Check that all sentinels were overwritten.
|
||||||
|
for (size_t i = 0; i < num; ++i) {
|
||||||
|
EXPECT_NE(hwy::ConvertScalarTo<float>(actual_dec[i]),
|
||||||
|
hwy::ConvertScalarTo<float>(sentinel))
|
||||||
|
<< " at index " << i << " for num=" << num << " offset=" << offset;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
void TestAllSpecificDequantize() {
|
||||||
|
hn::ForGEVectors<128, TestSpecificDequantize>()(BF16());
|
||||||
|
hn::ForGEVectors<128, TestSpecificDequantize>()(float());
|
||||||
|
}
|
||||||
|
|
||||||
|
// NOLINTNEXTLINE(google-readability-namespace-comments)
|
||||||
|
} // namespace HWY_NAMESPACE
|
||||||
|
} // namespace gcpp
|
||||||
|
HWY_AFTER_NAMESPACE();
|
||||||
|
|
||||||
|
#if HWY_ONCE
|
||||||
|
namespace gcpp {
|
||||||
|
HWY_BEFORE_TEST(IntTest);
|
||||||
|
HWY_EXPORT_AND_TEST_P(IntTest, TestOffsetF32);
|
||||||
|
HWY_EXPORT_AND_TEST_P(IntTest, TestOffsetBF16);
|
||||||
|
HWY_EXPORT_AND_TEST_P(IntTest, TestQuantizeF32);
|
||||||
|
HWY_EXPORT_AND_TEST_P(IntTest, TestQuantizeBF16);
|
||||||
|
HWY_EXPORT_AND_TEST_P(IntTest, TestDec2BF16);
|
||||||
|
HWY_EXPORT_AND_TEST_P(IntTest, TestDec2F32);
|
||||||
|
HWY_EXPORT_AND_TEST_P(IntTest, TestMultiGroupF32);
|
||||||
|
HWY_EXPORT_AND_TEST_P(IntTest, TestMultiGroupBF16);
|
||||||
|
HWY_EXPORT_AND_TEST_P(IntTest, TestUnalignedOffsetBF16);
|
||||||
|
HWY_EXPORT_AND_TEST_P(IntTest, TestUnalignedOffsetF32);
|
||||||
|
HWY_EXPORT_AND_TEST_P(IntTest, TestAllDequantizeAndZeroPad);
|
||||||
|
HWY_EXPORT_AND_TEST_P(IntTest, TestAllSmallDequantize);
|
||||||
|
HWY_EXPORT_AND_TEST_P(IntTest, TestAllSpecificDequantize);
|
||||||
|
HWY_AFTER_TEST();
|
||||||
|
} // namespace gcpp
|
||||||
|
#endif // HWY_ONCE
|
||||||
|
|
@@ -113,6 +113,9 @@ class SbsWriterImpl : public ISbsWriter {
       case Type::kF32:
         InsertT<float>(name, weights, tensor_info);
         break;
+      case Type::kI8:
+        InsertT<I8Stream>(name, weights, tensor_info);
+        break;
       default:
         HWY_ABORT("Unsupported destination (compressed) type %s",
                   TypeName(type));
@@ -90,6 +90,13 @@ class CompressionTest(absltest.TestCase):
         info_256,
     )
 
+    writer.insert(
+        "tensor_i8",
+        np.array([0.000375] * 128 + [0.00006] * 128, dtype=np.float32),
+        configs.Type.kI8,
+        info_256,
+    )
+
     config = configs.ModelConfig(
         configs.Model.GEMMA2_2B,
         configs.Type.kSFP,
@@ -140,6 +147,11 @@ class CompressionTest(absltest.TestCase):
     self.assertEqual(mat.type, configs.Type.kF32)
     self.assertAlmostEqual(mat.scale, 1.0)
 
+    mat = reader.find_mat("tensor_i8")
+    self.assertEqual(mat.cols, 256)
+    self.assertEqual(mat.rows, 1)
+    self.assertEqual(mat.type, configs.Type.kI8)
+    self.assertAlmostEqual(mat.scale, 1.0)
 
 if __name__ == "__main__":
   absltest.main()
@@ -89,6 +89,26 @@ struct SfpStream {
 };
 #pragma pack(pop)

+#pragma pack(push, 1)
+struct I8Stream {
+  static constexpr size_t kGroupSize = 128;
+  using ScaleT = hwy::bfloat16_t;
+
+  // Returns number of I8Stream to allocate for the stream, which matches its
+  // size in bytes.
+  // TODO: should support other types beyond hwy::float32_t for scale and
+  // zero-point.
+  static constexpr size_t PackedEnd(size_t capacity) {
+    const size_t num_groups = hwy::DivCeil(capacity, kGroupSize);
+    return (sizeof(ScaleT) * num_groups) +  // scale
+           (sizeof(ScaleT) * num_groups) +  // zero-point
+           capacity;                        // 1 value per byte
+  }
+
+  int8_t i;
+};
+#pragma pack(pop)
+
 // Non-uniform quantization: a compressed representation of f32 inputs that
 // supports seeking at a granularity of 1 (for `DecompressAndZeroPad`) or
 // two vectors (for `Decompress2`), and decoding to bf16/f32.
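Note: the size arithmetic in `I8Stream::PackedEnd` implies roughly 3% metadata overhead at the chosen group size: each group of 128 values carries one bf16 scale and one bf16 zero-point in addition to the int8 payload. A standalone sketch of the same arithmetic (`PackedBytes` is an illustrative name, not part of the API):

    #include <cstddef>
    #include <cstdio>

    // Mirrors I8Stream::PackedEnd: per group of 128 values, one bf16 scale and
    // one bf16 zero-point (2 bytes each), plus one byte per value.
    constexpr size_t kGroupSize = 128;
    constexpr size_t kScaleBytes = 2;  // sizeof(hwy::bfloat16_t)

    constexpr size_t PackedBytes(size_t capacity) {
      const size_t num_groups = (capacity + kGroupSize - 1) / kGroupSize;
      return kScaleBytes * num_groups    // scales
             + kScaleBytes * num_groups  // zero-points
             + capacity;                 // one int8 per value
    }

    int main() {
      // A 256-value tensor spans two groups: 2*2 + 2*2 + 256 = 264 bytes,
      // i.e. about 1.03 bytes per value vs. 4 for f32.
      printf("%zu\n", PackedBytes(256));  // prints 264
      return 0;
    }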
@@ -187,18 +207,23 @@ constexpr bool IsNuqStream() {
   return hwy::IsSame<hwy::RemoveCvRef<Packed>, NuqStream>();
 }

+template <typename Packed>
+constexpr bool IsI8Stream() {
+  return hwy::IsSame<hwy::RemoveCvRef<Packed>, I8Stream>();
+}
+
 template <typename Packed>
 constexpr bool SupportsPointerArithmetic() {
-  return !IsNuqStream<Packed>();
+  return !IsNuqStream<Packed>() && !IsI8Stream<Packed>();
 }

 // Tensor types for loading weights. Not all of these are supported weight
 // types, some are only used for `Activations`.
-enum class Type { kUnknown, kF32, kBF16, kSFP, kNUQ, kF64, kU32, kU64 };
+enum class Type { kUnknown, kF32, kBF16, kSFP, kNUQ, kF64, kU32, kU64, kI8 };
 // These are used in `ModelConfig.Specifier`, hence the strings will not
 // change, though new ones may be added.
-static constexpr const char* kTypeStrings[] = {"unknown", "f32", "bf16", "sfp",
-                                               "nuq", "f64", "u32", "u64"};
+static constexpr const char* kTypeStrings[] = {
+    "unknown", "f32", "bf16", "sfp", "nuq", "f64", "u32", "u64", "i8"};
 static constexpr size_t kNumTypes =
     sizeof(kTypeStrings) / sizeof(kTypeStrings[0]);
 static constexpr size_t kTypeBits[] = {

@@ -210,6 +235,7 @@ static constexpr size_t kTypeBits[] = {
     8 * sizeof(double),
     8 * sizeof(uint32_t),
     8 * sizeof(uint64_t),
+    8 * sizeof(I8Stream),
 };

 static inline bool EnumValid(Type type) {

@@ -234,6 +260,8 @@ Type TypeEnum() {
     return Type::kU32;
   } else if constexpr (hwy::IsSame<Packed, uint64_t>()) {
     return Type::kU64;
+  } else if constexpr (hwy::IsSame<Packed, I8Stream>()) {
+    return Type::kI8;
   } else {
     HWY_DASSERT(false);
     return Type::kUnknown;

@@ -254,7 +282,9 @@ const char* TypeName() {
 template <typename Packed>
 constexpr bool IsCompressed() {
-  return hwy::IsSameEither<hwy::RemoveCvRef<Packed>, SfpStream, NuqStream>();
+  return hwy::IsSame<hwy::RemoveCvRef<Packed>, SfpStream>() ||
+         hwy::IsSame<hwy::RemoveCvRef<Packed>, NuqStream>() ||
+         hwy::IsSame<hwy::RemoveCvRef<Packed>, I8Stream>();
 }

 // Returns the number of `MatT` elements required to store `capacity` values,

@@ -265,6 +295,8 @@ template <typename Packed>
 constexpr size_t CompressedArrayElements(size_t capacity) {
   if constexpr (hwy::IsSame<hwy::RemoveCvRef<Packed>, NuqStream>()) {
     return NuqStream::PackedEnd(capacity);
+  } else if constexpr (hwy::IsSame<hwy::RemoveCvRef<Packed>, I8Stream>()) {
+    return I8Stream::PackedEnd(capacity);
   } else {
     return capacity;
   }
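Note: taken together, the trait changes make `I8Stream` behave like the other packed formats at compile time. A minimal sketch of what they now guarantee (it assumes the constexpr declarations above are visible, i.e. it only compiles inside the repo against compression/types.h):

    #include "compression/types.h"

    namespace gcpp {
    // kI8 is a compressed format that, like NuqStream, does not support
    // pointer arithmetic: callers must address it via packed offsets.
    static_assert(IsCompressed<I8Stream>());
    static_assert(!SupportsPointerArithmetic<I8Stream>());
    // "Elements" for I8Stream are bytes of the packed stream, not values.
    static_assert(CompressedArrayElements<I8Stream>(256) ==
                  I8Stream::PackedEnd(256));
    }  // namespace gcpp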
@@ -143,8 +143,8 @@ void SingleDotSoftmaxWeightedSum(
   // Apply rope and scaling to Q.
   if (layer.query_norm_scale.HasPtr()) {
     CallUpcasted(&layer.query_norm_scale, [&](const auto* weights_t) {
-      RMSNormInplace(weights_t->PackedScale1(), q, layer.layer_config.qkv_dim,
-                     p, worker);
+      RMSNormInplace(weights_t->PackedScale1(), /*w_ofs=*/0, q,
+                     layer.layer_config.qkv_dim, p, worker);
     });
   }

@@ -315,8 +315,8 @@ static HWY_INLINE void ComputeQKV(size_t num_tokens, const size_t layer_idx,
   // Apply further processing to K.
   if (layer.key_norm_scale.HasPtr()) {
     CallUpcasted(&layer.key_norm_scale, [&](const auto* weights_t) {
-      RMSNormInplace(weights_t->PackedScale1(), kv_f32, qkv_dim,
-                     env.ctx.profiler, worker);
+      RMSNormInplace(weights_t->PackedScale1(), /*w_ofs=*/0, kv_f32,
+                     qkv_dim, env.ctx.profiler, worker);
     });
   }

@@ -114,7 +114,7 @@ void RMSNormAndPositionalEncoding(const size_t num_tokens, const QBatch& qbatch,
   // Apply rope and scaling to Q.
   if (layer.query_norm_scale.HasPtr()) {
     CallUpcasted(&layer.query_norm_scale, [&](const auto* weights_t) {
-      RMSNormInplace(weights_t->PackedScale1(), q_row,
+      RMSNormInplace(weights_t->PackedScale1(), /*w_ofs=*/0, q_row,
                      layer.layer_config.qkv_dim, ctx.profiler, worker);
     });
   }

@@ -59,7 +59,7 @@ void Activation(ActivationType activation, T1* HWY_RESTRICT c1,
     return;
   };
   // Has multiplier, Gelu(c1) * c2.
-  Decompress1AndCompressInplace(DF(), c1, count, c2,
+  Decompress1AndCompressInplace(DF(), c1, count, c2, /*p1_ofs=*/0,
                                 [](DF df, VF v1, VF v2) HWY_ATTR -> VF {
                                   return hn::Mul(v2, Gelu(df, v1));
                                 });

@@ -101,8 +101,9 @@ static inline void Activation(ActivationType activation, const RowPtrsBF C1,
   for (size_t ir = 0; ir < range_r.Num(); ++ir) {
     Decompress1AndCompressInplace(
         DF(), C1.Row(range_r.begin() + ir) + range_c.begin(), cols, C2.Row(ir),
-        [](DF df, VF v1, VF v2)
-            HWY_ATTR -> VF { return hn::Mul(v2, Gelu(df, v1)); });
+        /*p1_ofs*/ 0, [](DF df, VF v1, VF v2) HWY_ATTR -> VF {
+          return hn::Mul(v2, Gelu(df, v1));
+        });
   }
 }

@@ -112,6 +112,8 @@ class TypePrefix {
       return Type::kSFP;
     case '2':
       return Type::kNUQ;
+    case 'I':
+      return Type::kI8;
     default:
       // The other types were not written to pre-2025 files, hence no need to
       // encode and check for them here.

@@ -46,7 +46,7 @@ struct TensorInfo {
   // The highest permissible compression for this tensor. The default is
   // kNUQ, which provides maximum compression. Other values such as kBF16
   // or kF32 can be used to limit the compression to a specific type.
-  Type min_size = Type::kNUQ;
+  Type min_size = Type::kI8;
   // Whether to apply scaled softplus to the data.
   bool scaled_softplus = false;
   // Whether the columns or the rows take any extra dimensions.

@@ -332,8 +332,8 @@ void PrefillVit(const ModelConfig& model_config, const WeightsPtrs& weights,

   // Apply soft embedding norm before input projection.
   CallUpcasted(&weights.mm_embed_norm, [&](const auto* weights_t) {
-    RMSNormInplace(weights_t->PackedScale1(), activations.x.Row(0),
-                   vit_model_dim, env.ctx.profiler,
+    RMSNormInplace(weights_t->PackedScale1(), /*w_ofs=*/0,
+                   activations.x.Row(0), vit_model_dim, env.ctx.profiler,
                    hwy::Profiler::GlobalIdx());
   });
 }

gemma/weights.cc (224 changed lines)
@@ -147,16 +147,223 @@ void LayerWeightsPtrs::SplitAttW1() {
   qkv_einsum_w.SetPtr(nullptr, qkv_einsum_w.Cols());
 }

+static void HWY_MAYBE_UNUSED InitAttWeightsI8(
+    const LayerConfig& layer_config, MatPtrT<I8Stream>& attn_vec_einsum_w,
+    MatPtrT<I8Stream>& att_weights, std::vector<MatOwner>& mat_owners,
+    const Allocator& allocator) {
+  if (!attn_vec_einsum_w.HasPtr()) return;
+  HWY_ASSERT(attn_vec_einsum_w.GetType() == Type::kI8);
+
+  att_weights.SetType(Type::kI8);
+
+  {
+    static std::mutex m;
+    std::lock_guard<std::mutex> lock(m);
+    mat_owners.emplace_back();
+    mat_owners.back().AllocateFor(att_weights, allocator, MatPadding::kPacked);
+  }
+
+  const size_t model_dim = layer_config.model_dim;
+  const size_t heads = layer_config.heads;
+  const size_t qkv_dim = layer_config.qkv_dim;
+
+  // Reshape [kHeads, kModelDim, kQKVDim] to [kModelDim, kHeads * kQKVDim].
+  hwy::AlignedFreeUniquePtr<float[]> attn_vec_einsum_w_tmp =
+      hwy::AllocateAligned<float>(model_dim * heads * qkv_dim);
+  hwy::AlignedFreeUniquePtr<float[]> att_weights_tmp =
+      hwy::AllocateAligned<float>(model_dim * heads * qkv_dim);
+
+  const hwy::HWY_NAMESPACE::ScalableTag<float> df;
+  HWY_NAMESPACE::DecompressAndZeroPad(df, attn_vec_einsum_w.Span(), 0,
+                                      attn_vec_einsum_w_tmp.get(),
+                                      model_dim * heads * qkv_dim);
+
+  for (size_t m = 0; m < model_dim; ++m) {
+    float* HWY_RESTRICT out_row = att_weights_tmp.get() + m * heads * qkv_dim;
+    for (size_t h = 0; h < heads; ++h) {
+      hwy::CopyBytes(
+          attn_vec_einsum_w_tmp.get() + h * model_dim * qkv_dim + m * qkv_dim,
+          out_row + h * qkv_dim, qkv_dim * sizeof(float));
+    }
+  }
+
+  CompressWorkingSet work;
+  hwy::ThreadPool pool(0);
+  HWY_NAMESPACE::Compress(att_weights_tmp.get(), model_dim * heads * qkv_dim,
+                          work, att_weights.Span(),
+                          /*packed_ofs=*/0, pool);
+
+  att_weights.SetScale(attn_vec_einsum_w.Scale());
+}
+
+static void HWY_MAYBE_UNUSED SplitW1I8(const LayerConfig& layer_config,
+                                       MatPtrT<I8Stream>& gating_einsum_w,
+                                       MatPtrT<I8Stream>& gating_einsum_w1,
+                                       MatPtrT<I8Stream>& gating_einsum_w2,
+                                       std::vector<MatOwner>& mat_owners,
+                                       const Allocator& allocator) {
+  // Files have both or neither of w1 and w2.
+  HWY_ASSERT(gating_einsum_w1.HasPtr() == gating_einsum_w2.HasPtr());
+  // w is mutually exclusive with w1 and w2 in the file.
+  HWY_ASSERT(gating_einsum_w.HasPtr() ^ gating_einsum_w1.HasPtr());
+  // Done if we already read split tensors.
+  if (gating_einsum_w1.HasPtr() && !gating_einsum_w.HasPtr()) return;
+  // Nothing to do if w is not present.
+  if (!gating_einsum_w.HasPtr()) return;
+
+  HWY_ASSERT(gating_einsum_w.GetType() == Type::kI8);
+
+  const size_t ff_hidden_dim = layer_config.ff_hidden_dim;
+  const size_t model_dim = gating_einsum_w.Cols();
+  HWY_ASSERT(gating_einsum_w.Rows() == 2 * ff_hidden_dim);
+  HWY_ASSERT(gating_einsum_w1.Rows() == ff_hidden_dim);
+  HWY_ASSERT(gating_einsum_w2.Rows() == ff_hidden_dim);
+  HWY_ASSERT(gating_einsum_w1.Cols() == model_dim);
+  HWY_ASSERT(gating_einsum_w2.Cols() == model_dim);
+
+  gating_einsum_w1.SetType(Type::kI8);
+  gating_einsum_w2.SetType(Type::kI8);
+
+  {
+    static std::mutex m;
+    std::lock_guard<std::mutex> lock(m);
+    mat_owners.emplace_back();
+    mat_owners.back().AllocateFor(gating_einsum_w1, allocator,
+                                  MatPadding::kPacked);
+    mat_owners.emplace_back();
+    mat_owners.back().AllocateFor(gating_einsum_w2, allocator,
+                                  MatPadding::kPacked);
+  }
+
+  const size_t total_size = gating_einsum_w.Rows() * gating_einsum_w.Cols();
+  hwy::AlignedFreeUniquePtr<float[]> w_tmp =
+      hwy::AllocateAligned<float>(total_size);
+
+  const hwy::HWY_NAMESPACE::ScalableTag<float> df;
+  HWY_NAMESPACE::DecompressAndZeroPad(df, gating_einsum_w.Span(), 0,
+                                      w_tmp.get(), total_size);
+
+  const size_t split_size = ff_hidden_dim * model_dim;
+  float* w1_tmp = w_tmp.get();
+  float* w2_tmp = w_tmp.get() + split_size;
+
+  CompressWorkingSet work;
+  hwy::ThreadPool pool(0);
+  HWY_NAMESPACE::Compress(w1_tmp, split_size, work, gating_einsum_w1.Span(), 0,
+                          pool);
+  HWY_NAMESPACE::Compress(w2_tmp, split_size, work, gating_einsum_w2.Span(), 0,
+                          pool);
+
+  gating_einsum_w1.SetScale(1.0f);
+  gating_einsum_w2.SetScale(1.0f);
+
+  gating_einsum_w.SetPtr(nullptr, gating_einsum_w.Cols());
+}
+
+static void HWY_MAYBE_UNUSED SplitAttW1I8(const LayerConfig& layer_config,
+                                          MatPtrT<I8Stream>& qkv_einsum_w,
+                                          MatPtrT<I8Stream>& qkv_einsum_w1,
+                                          MatPtrT<I8Stream>& qkv_einsum_w2,
+                                          std::vector<MatOwner>& mat_owners,
+                                          const Allocator& allocator) {
+  // w is mutually exclusive with w1 in the file.
+  HWY_ASSERT(qkv_einsum_w.HasPtr() ^ qkv_einsum_w1.HasPtr());
+  // Done if we already read split tensors.
+  if (qkv_einsum_w1.HasPtr() && !qkv_einsum_w.HasPtr()) return;
+  // Nothing to do if w is not present.
+  if (!qkv_einsum_w.HasPtr()) return;
+
+  HWY_ASSERT(qkv_einsum_w.GetType() == Type::kI8);
+
+  const size_t model_dim = qkv_einsum_w.Cols();
+  const size_t w1_rows = layer_config.heads * layer_config.qkv_dim;
+  const size_t w2_rows = layer_config.kv_heads * 2 * layer_config.qkv_dim;
+  HWY_ASSERT(qkv_einsum_w.Rows() == w1_rows + w2_rows);
+  HWY_ASSERT(qkv_einsum_w1.Rows() == w1_rows);
+  HWY_ASSERT(qkv_einsum_w2.Rows() == w2_rows);
+  HWY_ASSERT(qkv_einsum_w1.Cols() == model_dim);
+  HWY_ASSERT(qkv_einsum_w2.Cols() == model_dim);
+
+  qkv_einsum_w1.SetType(Type::kI8);
+  qkv_einsum_w2.SetType(Type::kI8);
+
+  {
+    static std::mutex m;
+    std::lock_guard<std::mutex> lock(m);
+    mat_owners.emplace_back();
+    mat_owners.back().AllocateFor(qkv_einsum_w1, allocator,
+                                  MatPadding::kPacked);
+    mat_owners.emplace_back();
+    mat_owners.back().AllocateFor(qkv_einsum_w2, allocator,
+                                  MatPadding::kPacked);
+  }
+
+  const size_t total_size = qkv_einsum_w.Rows() * qkv_einsum_w.Cols();
+  hwy::AlignedFreeUniquePtr<float[]> w_tmp =
+      hwy::AllocateAligned<float>(total_size);
+
+  const hwy::HWY_NAMESPACE::ScalableTag<float> df;
+  HWY_NAMESPACE::DecompressAndZeroPad(df, qkv_einsum_w.Span(), 0, w_tmp.get(),
+                                      total_size);
+
+  const size_t w1_size = w1_rows * model_dim;
+  const size_t w2_size = w2_rows * model_dim;
+  float* w1_tmp = w_tmp.get();
+  float* w2_tmp = w_tmp.get() + w1_size;
+
+  CompressWorkingSet work;
+  hwy::ThreadPool pool(0);
+  HWY_NAMESPACE::Compress(w1_tmp, w1_size, work, qkv_einsum_w1.Span(), 0, pool);
+  HWY_NAMESPACE::Compress(w2_tmp, w2_size, work, qkv_einsum_w2.Span(), 0, pool);
+
+  qkv_einsum_w1.SetScale(1.0f);
+  qkv_einsum_w2.SetScale(1.0f);
+
+  qkv_einsum_w.SetPtr(nullptr, qkv_einsum_w.Cols());
+}
+
 // Must be called after reading weights via `ForEachTensor`.
 // TODO: exporters should bake this into the weights already.
 // WARNING: called from multiple threads; `mat_owners` requires a lock.
 void LayerWeightsPtrs::Fixup(std::vector<MatOwner>& mat_owners,
                              const Allocator& allocator) {
-  // TODO(janwas): handle NUQ
+  if (attn_vec_einsum_w.GetType() == Type::kI8) {
+    MatPtrT<I8Stream> attn_vec_einsum_w_i8(attn_vec_einsum_w);
+    MatPtrT<I8Stream> att_weights_i8(att_weights);
+    InitAttWeightsI8(layer_config, attn_vec_einsum_w_i8, att_weights_i8,
+                     mat_owners, allocator);
+    attn_vec_einsum_w = attn_vec_einsum_w_i8;
+    att_weights = att_weights_i8;
+  } else {
     InitAttWeights(mat_owners, allocator);
+  }
+
+  if (gating_einsum_w.GetType() == Type::kI8) {
+    MatPtrT<I8Stream> gating_einsum_w_i8(gating_einsum_w);
+    MatPtrT<I8Stream> gating_einsum_w1_i8(gating_einsum_w1);
+    MatPtrT<I8Stream> gating_einsum_w2_i8(gating_einsum_w2);
+    SplitW1I8(layer_config, gating_einsum_w_i8, gating_einsum_w1_i8,
+              gating_einsum_w2_i8, mat_owners, allocator);
+    gating_einsum_w = gating_einsum_w_i8;
+    gating_einsum_w1 = gating_einsum_w1_i8;
+    gating_einsum_w2 = gating_einsum_w2_i8;
+  } else {
     SplitW1();
+  }
+
+  if (qkv_einsum_w.GetType() == Type::kI8) {
+    MatPtrT<I8Stream> qkv_einsum_w_i8(qkv_einsum_w);
+    MatPtrT<I8Stream> qkv_einsum_w1_i8(qkv_einsum_w1);
+    MatPtrT<I8Stream> qkv_einsum_w2_i8(qkv_einsum_w2);
+    SplitAttW1I8(layer_config, qkv_einsum_w_i8, qkv_einsum_w1_i8,
+                 qkv_einsum_w2_i8, mat_owners, allocator);
+    qkv_einsum_w = qkv_einsum_w_i8;
+    qkv_einsum_w1 = qkv_einsum_w1_i8;
+    qkv_einsum_w2 = qkv_einsum_w2_i8;
+  } else {
     SplitAttW1();
+  }
 }

 static void HWY_MAYBE_UNUSED InitAttWeightsNUQ(
     const LayerConfig& layer_config, MatPtrT<NuqStream>& attn_vec_einsum_w,
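Note: because I8Stream does not support pointer arithmetic, the helpers above round-trip through a temporary f32 buffer: decompress the packed tensor, reshape or split it, then Compress back into the newly allocated packed destination. The one non-obvious step in `InitAttWeightsI8` is the reshape from head-major `[heads, model_dim, qkv_dim]` to model-major `[model_dim, heads * qkv_dim]`; a standalone sketch of the index mapping performed by the `CopyBytes` loop (helper names are illustrative):

    #include <cstddef>

    // Source layout: [heads, model_dim, qkv_dim];
    // destination layout: [model_dim, heads * qkv_dim].
    constexpr size_t SrcIndex(size_t h, size_t m, size_t k, size_t model_dim,
                              size_t qkv_dim) {
      return (h * model_dim + m) * qkv_dim + k;
    }
    constexpr size_t DstIndex(size_t h, size_t m, size_t k, size_t heads,
                              size_t qkv_dim) {
      return m * (heads * qkv_dim) + h * qkv_dim + k;
    }

    // Example with heads=2, model_dim=3, qkv_dim=4: element (h=1, m=0, k=0)
    // moves from source index 12 to destination index 4.
    static_assert(SrcIndex(1, 0, 0, /*model_dim=*/3, /*qkv_dim=*/4) == 12);
    static_assert(DstIndex(1, 0, 0, /*heads=*/2, /*qkv_dim=*/4) == 4);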
@@ -427,8 +634,6 @@ static void ReadAllToBF16(const std::vector<TensorToRead>& tensors,
 static std::vector<IOBatch> MakeBatches(
     const std::vector<TensorToRead>& tensors, const uint64_t file_bytes) {
   PROFILER_ZONE("Startup.Weights.MakeBatches");
-  // Batches must be contiguous but blobs are padded, hence at least one
-  // batch per tensor, and more when tensor rows exceed the batch size.
   std::vector<IOBatch> batches;
   batches.reserve(tensors.size());

@@ -439,11 +644,21 @@ static std::vector<IOBatch> MakeBatches(
     HWY_ASSERT(range.End() <= file_bytes);

     batches.emplace_back(offset, range.key_idx);
+    if (mat.IsPacked()) {
+      HWY_ASSERT(range.bytes == mat.PackedBytes());
+      if (!batches.back().Add(mat.Packed(), range.bytes)) {
+        // This should not happen if tensors are < 2GB.
+        // If it does, we need to chunk. For now, let's assume it doesn't.
+        HWY_ABORT("Packed tensor too large for a single IO batch.");
+      }
+      offset += range.bytes;
+    } else {
       const size_t file_bytes_per_row = mat.Cols() * mat.ElementBytes();
       const size_t mem_stride_bytes = mat.Stride() * mat.ElementBytes();
       uint8_t* row_bytes = mat.RowBytes(0);
       for (size_t r = 0; r < mat.Rows(); ++r) {
-      if (!batches.back().Add(row_bytes, file_bytes_per_row)) {  // Full batch.
+        if (!batches.back().Add(row_bytes,
+                                file_bytes_per_row)) {  // Full batch.
           batches.emplace_back(offset, range.key_idx);
           // Adding to an empty batch is always successful.
           HWY_ASSERT(batches.back().Add(row_bytes, file_bytes_per_row));

@@ -451,6 +666,7 @@ static std::vector<IOBatch> MakeBatches(
         offset += file_bytes_per_row;
         row_bytes += mem_stride_bytes;
       }
+    }
     HWY_ASSERT(offset == range.End());
   }

@@ -50,6 +50,7 @@
   GEMMA_MATMUL_FOR_B(float)                                   \
   GEMMA_MATMUL_FOR_B(NuqStream)                               \
   GEMMA_MATMUL_FOR_B(SfpStream)                               \
+  GEMMA_MATMUL_FOR_B(I8Stream)                                \
   /* NOLINTNEXTLINE(google-readability-namespace-comments) */ \
   }  // namespace NAMESPACE
@@ -0,0 +1,29 @@
+// Copyright 2025 Google LLC
+// SPDX-License-Identifier: Apache-2.0
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//      https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "compression/types.h"  // GEMMA_DISABLED_TARGETS
+#ifndef HWY_DISABLED_TARGETS
+#define HWY_DISABLED_TARGETS GEMMA_DISABLED_TARGETS
+#endif  // HWY_DISABLED_TARGETS
+
+// Compiles this file for multiple architectures via "foreach_target.h", to
+// which we pass the filename via macro 'argument'.
+// clang-format off
+#undef HWY_TARGET_INCLUDE
+#define HWY_TARGET_INCLUDE "ops/matmul_static_i8.cc"  // NOLINT
+// clang-format on
+#include "hwy/foreach_target.h"  // IWYU pragma: keep
+#define GEMMA_MATMUL_TB I8Stream
+#include "ops/matmul_static-inl.h"
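Note: as with SFP and NUQ, the I8 MatMul kernels get their own translation unit so Highway's `foreach_target.h` can compile them once per target; the new file is also listed in the BUILD/CMake sources and in the `GEMMA_MATMUL_FOR_B` table above. Adding a further packed type would presumably follow the same recipe; a sketch with a placeholder `XyzStream` type and file name (hypothetical, mirroring matmul_static_i8.cc):

    // ops/matmul_static_xyz.cc -- hypothetical; mirrors matmul_static_i8.cc.
    #include "compression/types.h"  // GEMMA_DISABLED_TARGETS
    #ifndef HWY_DISABLED_TARGETS
    #define HWY_DISABLED_TARGETS GEMMA_DISABLED_TARGETS
    #endif  // HWY_DISABLED_TARGETS

    #undef HWY_TARGET_INCLUDE
    #define HWY_TARGET_INCLUDE "ops/matmul_static_xyz.cc"  // NOLINT
    #include "hwy/foreach_target.h"  // IWYU pragma: keep
    // GEMMA_MATMUL_TB selects the packed type of the B operand;
    // matmul_static-inl.h instantiates the static MatMul entry points for it.
    #define GEMMA_MATMUL_TB XyzStream
    #include "ops/matmul_static-inl.h"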
@@ -220,6 +220,7 @@ float RMSNormMul(const VT* HWY_RESTRICT x, const size_t size, hwy::Profiler& p,
 template <typename XT, typename WT, typename OT>
 HWY_NOINLINE HWY_MAYBE_UNUSED void RMSNorm(const XT* HWY_RESTRICT x,
                                            const WT* HWY_RESTRICT weight,
+                                           const size_t w_ofs,
                                            OT* HWY_RESTRICT out,
                                            const size_t size, hwy::Profiler& p,
                                            const size_t worker) {

@@ -232,7 +233,7 @@ HWY_NOINLINE HWY_MAYBE_UNUSED void RMSNorm(const XT* HWY_RESTRICT x,
   const VF mul = hn::Set(DF(), detail::RMSNormMul(x, size, p, worker));
   const VF* HWY_RESTRICT pmul = &mul;

-  Decompress2AndCompressTo(DF(), out, size, x, weight,
+  Decompress2AndCompressTo(DF(), out, size, x, weight, w_ofs,
                            [pmul](DF /*df*/, VF vx, VF vw) HWY_ATTR -> VF {
                              const VF m = hn::Mul(*pmul, vx);
                              // (1+weight) * m = m + weight*m = one FMA.

@@ -242,13 +243,10 @@ HWY_NOINLINE HWY_MAYBE_UNUSED void RMSNorm(const XT* HWY_RESTRICT x,

 // Same as RMSNorm, but its HWY_RESTRICT forbids passing the same pointer.
 template <typename WT, typename XT>
-HWY_NOINLINE HWY_MAYBE_UNUSED void RMSNormInplace(const WT* HWY_RESTRICT weight,
-                                                  XT* HWY_RESTRICT inout,
-                                                  const size_t size,
-                                                  hwy::Profiler& p,
-                                                  const size_t worker) {
+HWY_NOINLINE HWY_MAYBE_UNUSED void RMSNormInplace(
+    const WT* HWY_RESTRICT weight, const size_t w_ofs, XT* HWY_RESTRICT inout,
+    const size_t size, hwy::Profiler& p, const size_t worker) {
   PROFILER_ZONE3(p, worker, GetProfilerZone(Zones::kOpsRmsNormInplace));

   namespace hn = hwy::HWY_NAMESPACE;
   using DF = hn::ScalableTag<float>;
   using VF = hn::Vec<DF>;

@@ -256,7 +254,7 @@ HWY_NOINLINE HWY_MAYBE_UNUSED void RMSNormInplace(
   const VF mul = hn::Set(DF(), detail::RMSNormMul(inout, size, p, worker));
   const VF* HWY_RESTRICT pmul = &mul;

-  Decompress1AndCompressInplace(DF(), inout, size, weight,
+  Decompress1AndCompressInplace(DF(), inout, size, weight, w_ofs,
                                 [pmul](DF /*df*/, VF vx, VF vw) HWY_ATTR -> VF {
                                   const VF m = hn::Mul(*pmul, vx);
                                   // (1+weight) * m = m + weight*m = one FMA.
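Note: the new `w_ofs` parameter is an offset into the (possibly packed) weight stream, forwarded to the decompress helpers rather than applied to the pointer, since packed formats (NUQ, and now I8) do not support pointer arithmetic; all existing call sites pass 0. A scalar model of what `RMSNorm` computes with the offset, assuming the usual RMSNorm definition (the epsilon value and the helper name are illustrative only):

    #include <cmath>
    #include <cstddef>

    // out[i] = (1 + weight[w_ofs + i]) * (mul * x[i]),
    // where mul = 1 / sqrt(mean(x^2) + epsilon), matching the
    // "(1+weight) * m = m + weight*m = one FMA" comment above.
    void ScalarRMSNormWithOfs(const float* x, const float* weight, size_t w_ofs,
                              float* out, size_t size) {
      double sum_sq = 0.0;
      for (size_t i = 0; i < size; ++i) sum_sq += x[i] * x[i];
      const float mul = 1.0f / std::sqrt(sum_sq / size + 1e-6);
      for (size_t i = 0; i < size; ++i) {
        const float m = mul * x[i];
        out[i] = m + weight[w_ofs + i] * m;
      }
    }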
@@ -489,7 +487,7 @@ static HWY_NOINLINE HWY_MAYBE_UNUSED void AddFrom(const XT* HWY_RESTRICT x,
   namespace hn = hwy::HWY_NAMESPACE;
   using DF = hn::ScalableTag<float>;
   using VF = hn::Vec<DF>;
-  Decompress1AndCompressInplace(DF(), out, size, x,
+  Decompress1AndCompressInplace(DF(), out, size, x, /*p1_ofs=*/0,
                                 [&](DF /*df*/, VF out, VF x)
                                     HWY_ATTR -> VF { return hn::Add(x, out); });
 }

@@ -507,8 +505,8 @@ void RMSNormBatched(const MatPtrT<XT>& activations, const MatPtr& weights,
     ParallelFor(ParallelismStrategy::kFlat, activations.Rows(), ctx,
                 cluster_idx, [&](uint64_t token_idx, size_t worker) {
                   RMSNorm(activations.Row(token_idx), weights_t->PackedScale1(),
-                          out.Row(token_idx), activations.Cols(), ctx.profiler,
-                          worker);
+                          /*w_ofs=*/0, out.Row(token_idx), activations.Cols(),
+                          ctx.profiler, worker);
                 });
   });
 }

@@ -522,7 +520,7 @@ void RMSNormInplaceBatched(const MatPtr& weights, MatPtrT<XT>& inout,
   CallUpcasted(&weights, [&](const auto* weights_t) {
     ParallelFor(ParallelismStrategy::kFlat, inout.Rows(), ctx, cluster_idx,
                 [&](uint64_t token_idx, size_t worker) {
-                  RMSNormInplace(weights_t->PackedScale1(),
+                  RMSNormInplace(weights_t->PackedScale1(), /*w_ofs=*/0,
                                  inout.Row(token_idx), inout.Cols(),
                                  ctx.profiler, worker);
                 });

@@ -604,7 +602,7 @@ HWY_NOINLINE HWY_MAYBE_UNUSED void MulByConstAndAdd(const float c,
   const VF vc = hn::Set(DF(), c);
   const VF* HWY_RESTRICT pc = &vc;

-  Decompress1AndCompressInplace(DF(), out, size, x,
+  Decompress1AndCompressInplace(DF(), out, size, x, /*p1_ofs=*/0,
                                 [&](DF /*df*/, VF out, VF x) HWY_ATTR -> VF {
                                   return hn::MulAdd(x, *pc, out);
                                 });

@@ -558,7 +558,8 @@ struct TestRMSNorm {

     ScalarRMSNorm(vec, weight, expected, kSize);
     InitProfilerZones(hwy::Profiler::Get());
-    RMSNorm(vec, weight, actual, kSize, hwy::Profiler::Get(), /*worker=*/0);
+    RMSNorm(vec, weight, /*w_ofs=*/0, actual, kSize, hwy::Profiler::Get(),
+            /*worker=*/0);

     for (size_t i = 0; i < kSize; i++) {
       const float e = hwy::ConvertScalarTo<float>(expected[i]);

@@ -593,7 +594,7 @@ struct TestRMSNormInplace {

     ScalarRMSNorm(expected, weight, expected, kSize);
     InitProfilerZones(hwy::Profiler::Get());
-    RMSNormInplace(weight, actual, kSize, hwy::Profiler::Get(),
-                   /*worker=*/0);
+    RMSNormInplace(weight, /*w_ofs=*/0, actual, kSize, hwy::Profiler::Get(),
+                   /*worker=*/0);

     for (size_t i = 0; i < kSize; i++) {

@@ -53,7 +53,11 @@ PYBIND11_MODULE(configs, py_module) {
       .value("kF32", Type::kF32)
      .value("kBF16", Type::kBF16)
       .value("kSFP", Type::kSFP)
-      .value("kNUQ", Type::kNUQ);
+      .value("kNUQ", Type::kNUQ)
+      .value("kF64", Type::kF64)
+      .value("kU32", Type::kU32)
+      .value("kU64", Type::kU64)
+      .value("kI8", Type::kI8);

   enum_<LayerAttentionType>(py_module, "LayerAttentionType")
       .value("kGemma", LayerAttentionType::kGemma)

@@ -59,6 +59,25 @@ static inline void MaybeCheckInitialized(const void* ptr, size_t size) {
 #endif
 }

+static inline void MaybePrintInitialized(const void* ptr, size_t size) {
+#if HWY_IS_MSAN
+  __msan_print_shadow(ptr, size);
+#else
+  (void)ptr;
+  (void)size;
+#endif
+}
+
+static inline intptr_t MaybeTestInitialized(const void* ptr, size_t size) {
+#if HWY_IS_MSAN
+  return __msan_test_shadow(ptr, size);
+#else
+  (void)ptr;
+  (void)size;
+  return 0;
+#endif
+}
+
 // Shared between gemma.h and ops-inl.h.
 #pragma pack(push, 1)
 struct TokenAndProb {

@@ -80,11 +80,13 @@ size_t Stride(MatPadding padding, size_t cols, size_t element_bytes,

 void MatOwner::AllocateFor(MatPtr& mat, const Allocator& allocator,
                            MatPadding padding) {
-  const bool is_nuq = mat.GetType() == Type::kNUQ;
-  if (is_nuq) padding = MatPadding::kPacked;
+  const bool is_compressed_and_packed =
+      mat.GetType() == Type::kNUQ || mat.GetType() == Type::kI8;
+  if (is_compressed_and_packed) padding = MatPadding::kPacked;
   const size_t stride =
       Stride(padding, mat.Cols(), mat.ElementBytes(), allocator.LineBytes());
-  const size_t num = is_nuq ? mat.PackedBytes() : mat.Rows() * stride;
+  const size_t num =
+      is_compressed_and_packed ? mat.PackedBytes() : mat.Rows() * stride;
   // `compress-inl` requires up to 2 BF16 vectors of padding. `MatPadding`
   // might not be enough, hence add extra. `MatT` is at least one byte, which
   // is half of BF16, hence adding `VectorBytes` *elements* is enough.
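Note: for NUQ and now I8, "elements" are bytes of one contiguous packed stream sized by `PackedEnd` over the whole tensor, so `AllocateFor` forces `MatPadding::kPacked` and allocates `PackedBytes()` rather than `Rows() * stride`. A small sketch of the two sizing modes, reusing the same arithmetic as `I8Stream::PackedEnd` (shapes and names are illustrative):

    #include <cstddef>

    constexpr size_t kGroupSize = 128;
    // Same arithmetic as I8Stream::PackedEnd.
    constexpr size_t PackedBytes(size_t values) {
      const size_t groups = (values + kGroupSize - 1) / kGroupSize;
      return 2 * groups + 2 * groups + values;
    }

    // A hypothetical 4 x 300 I8 tensor is one packed stream over all 1200
    // values (10 groups -> 1240 bytes); a 4 x 300 f32 tensor would instead be
    // 4 rows of a cache-line-padded stride.
    static_assert(PackedBytes(4 * 300) == 1240);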
util/mat.h (12 changed lines)

@@ -240,6 +240,8 @@ class MatPtr : public IFields {
       // `CompressedArrayElements` is a wrapper function that has the same
       // effect, but that requires a template argument, not `type`.
       num_elements = NuqStream::PackedEnd(num_elements);
+    } else if (type == Type::kI8) {
+      num_elements = I8Stream::PackedEnd(num_elements);
     }
     return num_elements;
   }

@@ -324,7 +326,8 @@ class MatPtrT : public MatPtr {
   }

   PackedSpan<const MatT> PaddedSpan() const {
-    return MakeConstSpan(HWY_RCAST_ALIGNED(MatT*, ptr_), Rows() * Stride());
+    const size_t num = IsPacked() ? num_elements_ : Rows() * Stride();
+    return MakeConstSpan(HWY_RCAST_ALIGNED(MatT*, ptr_), num);
   }

   // For `compress-inl.h` functions, which assume contiguous streams and thus

@@ -379,6 +382,9 @@ decltype(auto) CallUpcasted(const MatPtr* base, const Func& func,
   } else if (base->GetType() == Type::kSFP) {
     const MatPtrT<SfpStream> mat(*base);
     return func(&mat, std::forward<Args>(args)...);
+  } else if (base->GetType() == Type::kI8) {
+    const MatPtrT<I8Stream> mat(*base);
+    return func(&mat, std::forward<Args>(args)...);
   } else {
     HWY_ABORT("Unhandled type %s.", TypeName(base->GetType()));
   }

@@ -410,6 +416,10 @@ decltype(auto) CallUpcastedSame(const MatPtr* base1, const MatPtr* base2,
   } else if (base1->GetType() == Type::kSFP) {
     const MatPtrT<SfpStream> mat1(*base1);
     const MatPtrT<SfpStream> mat2(*base2);
     return func(&mat1, &mat2, std::forward<Args>(args)...);
+  } else if (base1->GetType() == Type::kI8) {
+    const MatPtrT<I8Stream> mat1(*base1);
+    const MatPtrT<I8Stream> mat2(*base2);
+    return func(&mat1, &mat2, std::forward<Args>(args)...);
   } else {
     HWY_ABORT("Unhandled type %s.", TypeName(base1->GetType()));
   }
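Note: with the `kI8` branches added here, every existing `CallUpcasted`/`CallUpcastedSame` call site (e.g. the norm-scale uses earlier in this change) picks up I8 weights without modification, because the generic lambda is instantiated once per concrete `MatPtrT<T>`. A standalone model of that dispatch pattern (types and names below are illustrative, not gemma.cpp APIs):

    #include <cstdio>

    enum class Tag { kF32, kI8 };
    struct F32Mat { const char* name = "f32"; };
    struct I8Mat  { const char* name = "i8"; };

    // Maps a runtime type tag to a typed pointer and hands it to a generic
    // callable, so one call site serves every weight format.
    template <class Func>
    void Dispatch(Tag tag, const Func& func) {
      if (tag == Tag::kF32) {
        F32Mat mat;
        func(&mat);
      } else {  // Tag::kI8 -- the branch this commit adds to CallUpcasted.
        I8Mat mat;
        func(&mat);
      }
    }

    int main() {
      Dispatch(Tag::kI8, [](const auto* mat) { std::puts(mat->name); });  // prints "i8"
      return 0;
    }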