mirror of https://github.com/google/gemma.cpp.git
Merge 123bf7eebb into 23dac72463
commit c5c85e09fd
@@ -431,7 +431,7 @@ struct CompressTraits<NuqStream> {
                                  size_t num, CompressPerThread& tls,
                                  const PackedSpan<Packed>& packed,
                                  const size_t packed_ofs) {
-    NuqCodec::Enc(df, raw, num, tls.buf, packed, packed_ofs);
+    NuqCodec::EncInterleaved(df, raw, num, tls.buf, packed, packed_ofs);

     if (COMPRESS_STATS) {
       for (size_t i = 0; i < num; ++i) {

@@ -441,8 +441,8 @@ struct CompressTraits<NuqStream> {
       const hn::Repartition<BF16, DF> dbf;
       const size_t N16 = hn::Lanes(dbf);
       auto distorted = hwy::AllocateAligned<BF16>(hwy::RoundUpTo(num, N16));
-      NuqCodec::DecompressAndZeroPad(dbf, MakeConst(packed), packed_ofs,
-                                     distorted.get(), num);
+      NuqCodec::DecompressAndZeroPadInterleaved(
+          dbf, MakeConst(packed), packed_ofs, distorted.get(), num);
       DistortionStats stats;
       for (size_t i = 0; i < num; ++i) {
         stats.Notify(raw[i], hwy::F32FromBF16(distorted[i]));

@@ -455,7 +455,7 @@ struct CompressTraits<NuqStream> {
   static HWY_INLINE void Load2(D d, const PackedSpan<const Packed>& packed,
                                const size_t packed_ofs, hn::Vec<D>& raw0,
                                hn::Vec<D>& raw1) {
-    NuqCodec::Dec2(d, packed, packed_ofs, raw0, raw1);
+    NuqCodec::Dec2Interleaved(d, packed, packed_ofs, raw0, raw1);
   }

   // Store2 is not yet implemented.

@@ -464,7 +464,7 @@ struct CompressTraits<NuqStream> {
   static HWY_INLINE void DecompressAndZeroPad(
       D d, const PackedSpan<const Packed>& packed, const size_t packed_ofs,
       Raw* raw, const size_t num) {
-    NuqCodec::DecompressAndZeroPad(d, packed, packed_ofs, raw, num);
+    NuqCodec::DecompressAndZeroPadInterleaved(d, packed, packed_ofs, raw, num);
   }
 };

@@ -110,7 +110,7 @@ class MatPtr : public IFields {
   size_t NumElements() const { return num_elements_; }

   // Returns the number of bytes in the array.
-  size_t SizeBytes() const { return num_elements_ * element_size_; }
+  virtual size_t SizeBytes() const { return num_elements_ * element_size_; }

   // Returns the number of rows in the 2-d array (outer dimension).
   size_t Rows() const { return rows_; }

@@ -248,10 +248,13 @@ class MatPtrT : public MatPtr {
     return name;
   }

-  // Sets the number of elements in the array. For use when the number of
-  // elements is != rows * cols ONLY.
-  void SetNumElements(size_t num_elements) {
-    num_elements_ = CompressedArrayElements<MatT>(num_elements);
+  // Returns the number of bytes in the array. Overrides MatPtr::SizeBytes()
+  // to account for NUQ's differing packed size.
+  size_t SizeBytes() const override {
+    if (hwy::IsSame<hwy::RemoveCvRef<MatT>, NuqStream>()) {
+      return NuqStream::PackedEnd(num_elements_);
+    }
+    return num_elements_ * element_size_;
   }

   // 2-d Accessor for a specific type but with a dynamic inner dimension.

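Editor's note: to make the effect of the override concrete, here is a small hypothetical sketch (not part of the commit). It assumes the MatPtrT(name, rows, cols) constructor used elsewhere in this change and only illustrates that calls through the MatPtr base class now dispatch to the NUQ-aware size.

// Hypothetical illustration only.
size_t NuqWeightBytes() {
  MatPtrT<NuqStream> w("w", /*rows=*/1, /*cols=*/4096);
  MatPtr* base = &w;
  // Virtual dispatch reaches MatPtrT<NuqStream>::SizeBytes(), which returns
  // NuqStream::PackedEnd(num_elements_) rather than
  // num_elements_ * element_size_.
  return base->SizeBytes();
}
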
@@ -334,6 +337,12 @@ class MatStorageT : public MatPtrT<MatT> {
   // from the current num_elements_ which was set by the constructor from the
   // rows and cols.
   void Allocate(size_t num_elements = 0) {
+    // size_t num_elements = 0;
+    // TODO: optimize this check or obviate it.
+    if (hwy::IsSame<hwy::RemoveCvRef<MatT>, NuqStream>()) {
+      HWY_DASSERT(num_elements == 0);
+    }
+
     if (num_elements == 0) {
       num_elements = hwy::DivCeil(this->SizeBytes(), sizeof(MatT));
     } else {

@@ -40,6 +40,7 @@
 // After highway.h
 #include "compression/sfp-inl.h"
 #include "hwy/contrib/sort/vqsort-inl.h"
+#include "hwy/profiler.h"  // uses SIMD

 HWY_BEFORE_NAMESPACE();
 namespace gcpp {

@@ -529,12 +530,21 @@ class NuqCodec {
     return (!HWY_HAVE_SCALABLE && du.MaxBytes() >= 32) ? 1 : 2;
   }

+  // Offset (in bytes) of a group's table for packed_ofs (in elements) within a
+  // set of groups.
+  static constexpr size_t TableByteOffset(size_t packed_ofs) {
+    const size_t kBytesPerGroup =
+        (kClusters * sizeof(SfpStream)) + kGroupSize / 2;
+    return (packed_ofs / kGroupSize) * kBytesPerGroup;
+  }
+
   // Unpacks `centers` from SFP into bf16 and loads them into one or two vectors
   // for use by [Two]TableLookups. Returns as u16 because TableLookupLanes might
   // not be available for bf16.
   template <class DU, HWY_IF_U16_D(DU)>
   static HWY_INLINE hn::Vec<DU> LoadTable(DU du, const uint8_t* centers,
                                           hn::Vec<DU>* HWY_RESTRICT tbl1) {
+    PROFILER_FUNC;
     // Cap to the table size (kClusters) for decoding SFP - sufficient, and may
     // be faster than a large vector.
     const hn::CappedTag<BF16, kClusters> d_table;

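Editor's note: for orientation, a standalone sketch (not part of the commit) of the arithmetic behind TableByteOffset. It assumes the usual NUQ constants kClusters = 16 and kGroupSize = 256, with 1-byte SfpStream table entries; under those assumptions each group occupies 16 table bytes plus 128 index bytes.

#include <cstddef>

// Assumed values; the real constants are members of NuqStream.
constexpr size_t kClusters = 16;    // 4-bit indices -> 16 centers per group
constexpr size_t kGroupSize = 256;  // elements per group

constexpr size_t TableByteOffsetSketch(size_t packed_ofs) {
  const size_t kBytesPerGroup =
      (kClusters * /*sizeof(SfpStream)=*/1) + kGroupSize / 2;
  return (packed_ofs / kGroupSize) * kBytesPerGroup;
}

// Group g starts at g * 144 bytes: a 16-byte SFP table for the centers,
// followed by 128 bytes holding 256 packed 4-bit cluster indices.
static_assert(TableByteOffsetSketch(0) == 0, "group 0");
static_assert(TableByteOffsetSketch(256) == 144, "group 1 = 16 + 128");
static_assert(TableByteOffsetSketch(512) == 288, "group 2");
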
@@ -606,6 +616,81 @@ class NuqCodec {
   }

  public:
+  // Encodes `num` floats from `raw` into `packed`. `packed` points to
+  // compressed storage and `packed_ofs` indicates the destination offset within
+  // it, in number of elements. Tables are interleaved with indices (clustered
+  // elements) to allow for easier unpacking. Returns the total number of
+  // unused clusters, which is typically zero.
+  template <class DF, HWY_IF_F32_D(DF)>
+  static HWY_INLINE size_t EncInterleaved(DF df, const float* HWY_RESTRICT raw,
+                                          const size_t num,
+                                          NuqStream::ClusterBuf& buf,
+                                          const PackedSpan<NuqStream>& packed,
+                                          size_t packed_ofs) {
+    const hn::Repartition<uint16_t, DF> d16;
+    const hn::Repartition<uint8_t, DF> d8;
+    using V16 = hn::Vec<decltype(d16)>;
+    using V8 = hn::Vec<decltype(d8)>;
+    const size_t N16 = hn::Lanes(d16);
+
+    HWY_ASSERT(packed_ofs % kGroupSize == 0);
+
+    const size_t num_groups = hwy::DivCeil(num, kGroupSize);
+    // TODO: dynamic resize should be removed; it is no longer necessary as
+    // interleaved encoding uses only a single buffer of the same size.
+    buf.Resize(1);
+
+    size_t unused_clusters = 0;
+    size_t current_offset = packed_ofs;
+    for (size_t g = 0; g < num_groups; ++g) {
+      const size_t g_num = HWY_MIN(num - g * kGroupSize, kGroupSize);
+      const float* HWY_RESTRICT g_in = raw + g * kGroupSize;
+
+      float* HWY_RESTRICT g_centers = buf.centers.get();
+      uint16_t* HWY_RESTRICT g_idx = buf.idx.get();
+
+      unused_clusters +=
+          NuqClustering::ClusterExactL2(df, g_in, g_num, buf, g_centers, g_idx);
+
+      uint8_t* centers = &packed.ptr->byte + TableByteOffset(current_offset);
+      SfpCodec::Enc(df, buf.centers.get(), kClusters,
+                    reinterpret_cast<SfpStream*>(centers));
+      uint8_t* packed_start = centers + kClusters;
+
+      current_offset += g_num;
+
+      HWY_DASSERT(g_num % (4 * N16) == 0);
+
+      size_t i = 0;
+      HWY_UNROLL(1)
+      for (; i < g_num; i += 4 * N16) {
+        const V16 idx0 = hn::LoadU(d16, g_idx + i + 0 * N16);
+        const V16 idx1 = hn::LoadU(d16, g_idx + i + 1 * N16);
+        const V16 idx2 = hn::LoadU(d16, g_idx + i + 2 * N16);
+        const V16 idx3 = hn::LoadU(d16, g_idx + i + 3 * N16);
+        const V8 nibbles =
+            NibbleCodec::OrderedPackU16(d16, idx0, idx1, idx2, idx3);
+        hn::StoreU(nibbles, d8, packed_start + i / 2);
+      }
+
+      const size_t remaining = g_num - i;
+
+      HWY_DASSERT(remaining < 4 * N16);
+      if (HWY_UNLIKELY(remaining != 0)) {
+        const V16 idx0 = hn::LoadU(d16, g_idx + i + 0 * N16);
+        const V16 idx1 = hn::LoadU(d16, g_idx + i + 1 * N16);
+        const V16 idx2 = hn::LoadU(d16, g_idx + i + 2 * N16);
+        const V16 idx3 = hn::LoadU(d16, g_idx + i + 3 * N16);
+        const V8 nibbles =
+            NibbleCodec::OrderedPackU16(d16, idx0, idx1, idx2, idx3);
+        // i is even, but remaining might not be.
+        hn::StoreN(nibbles, d8, packed_start + i / 2,
+                   hwy::DivCeil(remaining, 2));
+      }
+    }
+    return unused_clusters;
+  }
+
   // Encodes `num` floats from `raw`. `packed` points to compressed storage and
   // `packed_ofs` indicates the destination offset within it, in units of float
   // values, for parallel encoding by multiple threads. Returns the total

@@ -733,6 +818,8 @@ class NuqCodec {
     raw1 = BitCast(dbf, c1);
   }

+  // TODO(philculliton): Remove non-interleaved function versions now that
+  // interleaved is working / the default.
   // Decompresses to two f32 vectors. `packed_ofs` must be a multiple of two
   // vectors so that we only have to load one group's table.
   template <class DF, HWY_IF_F32_D(DF)>

@@ -765,6 +852,107 @@ class NuqCodec {
     raw1 = hn::PromoteUpperTo(df, BitCast(dbf, c0));
   }

+  // Decompresses to two bf16 vectors. `packed_ofs` must be a multiple of two
+  // vectors so that we only have to load one group's table.
+  template <class DBF, HWY_IF_BF16_D(DBF)>
+  static HWY_INLINE void Dec2Interleaved(
+      DBF dbf, const PackedSpan<const NuqStream>& packed,
+      const size_t packed_ofs, hn::Vec<DBF>& raw0, hn::Vec<DBF>& raw1) {
+    PROFILER_FUNC;
+    const hn::RebindToUnsigned<decltype(dbf)> d16;
+    const D8HFromD16<DBF> d8h;
+    using V16 = hn::Vec<decltype(d16)>;
+    using V8H = hn::Vec<decltype(d8h)>;
+
+    const size_t within_group = packed_ofs % kGroupSize;
+    HWY_DASSERT(within_group % (2 * hn::Lanes(d16)) == 0);
+    const uint8_t* table = &packed.ptr->byte + TableByteOffset(packed_ofs);
+    const uint8_t* indices = table + kClusters + hwy::DivCeil(within_group, 2);
+
+    V16 tbl1 = Zero(d16);
+    const V16 tbl0 = LoadTable(d16, table, &tbl1);
+
+    const V8H nibbles = hn::LoadU(d8h, indices);
+
+    V16 c0, c1;
+    TableLookups(d16, tbl0, tbl1, nibbles, c0, c1);
+    raw0 = BitCast(dbf, c0);
+    raw1 = BitCast(dbf, c1);
+  }
+
+  // Decompresses to two f32 vectors. `packed_ofs` must be a multiple of two
+  // vectors so that we only have to load one group's table.
+  template <class DF, HWY_IF_F32_D(DF)>
+  static HWY_INLINE void Dec2Interleaved(
+      DF df, const PackedSpan<const NuqStream>& packed, const size_t packed_ofs,
+      hn::Vec<DF>& raw0, hn::Vec<DF>& raw1) {
+    const hn::Repartition<BF16, decltype(df)> dbf;
+    const hn::RebindToUnsigned<decltype(dbf)> d16;
+    const hn::Half<D8HFromD16<decltype(d16)>> d8q;
+    using V8Q = hn::Vec<decltype(d8q)>;
+    using V16 = hn::Vec<decltype(d16)>;
+
+    const size_t within_group = packed_ofs % kGroupSize;
+    HWY_DASSERT(within_group % (2 * hn::Lanes(df)) == 0);
+    const uint8_t* table = &packed.ptr->byte + TableByteOffset(packed_ofs);
+    const uint8_t* indices = table + kClusters + hwy::DivCeil(within_group, 2);
+
+    V16 tbl1 = Zero(d16);
+    const V16 tbl0 = LoadTable(d16, table, &tbl1);
+
+    // The single-vector TableLookups overload only calls OrderedUnpackU16<0>,
+    // which expects a quarter vector of bytes.
+    const V8Q nibbles = hn::LoadU(d8q, indices);
+
+    // TODO(janwas): From janwas: on AVX-512 I imagine we can get a
+    // bit more speed for this function by changing LoadTable to return floats,
+    // then we could have a single lookup here instead of PromoteUpperTo which
+    // is not cheap.
+    const V16 c0 = TableLookups(d16, tbl0, tbl1, nibbles);
+    raw0 = hn::PromoteLowerTo(df, BitCast(dbf, c0));
+    raw1 = hn::PromoteUpperTo(df, BitCast(dbf, c0));
+  }
+
+  template <class D, typename Raw = hn::TFromD<D>>
+  static HWY_INLINE void DecompressAndZeroPadInterleaved(
+      D d, const PackedSpan<const NuqStream>& packed, size_t packed_ofs,
+      Raw* HWY_RESTRICT raw, size_t num) {
+    // If unaligned, load elements from the first group and update the args,
+    // from which we compute new tables/indices below.
+    size_t current_offset = packed_ofs;
+    if (size_t within_group = packed_ofs % kGroupSize; within_group != 0) {
+      const uint8_t* tables =
+          &packed.ptr->byte + TableByteOffset(current_offset);
+      const uint8_t* indices = tables + kClusters;
+      const size_t remaining = HWY_MIN(num, kGroupSize - within_group);
+
+      DecPartialGroup(d, tables, indices, raw, remaining);
+      packed_ofs += remaining;
+      current_offset += remaining;
+      raw += remaining;
+      num -= remaining;
+      if (num == 0) return;
+    }
+
+    HWY_DASSERT(packed_ofs % kGroupSize == 0);
+
+    const size_t num_groups = hwy::DivCeil(num, kGroupSize);
+    HWY_UNROLL(1)
+    for (size_t g = 0; g < num_groups - 1; ++g) {
+      const uint8_t* tables =
+          &packed.ptr->byte + TableByteOffset(current_offset);
+      const uint8_t* indices = tables + kClusters;
+      DecWholeGroup(d, tables, indices, raw + g * kGroupSize);
+      current_offset += kGroupSize;
+    }
+
+    const size_t g = num_groups - 1;
+    const uint8_t* tables = &packed.ptr->byte + TableByteOffset(current_offset);
+    const uint8_t* indices = tables + kClusters;
+    DecPartialGroup(d, tables, indices, raw + g * kGroupSize,
+                    num - g * kGroupSize);
+  }
+
   // Decompresses from `packed`, starting at (any) `packed_ofs`, to (any) `num`
   // elements in `raw`, then appends `[0, hn::Lanes(d))` zeroes as required to
   // round `num` up to one vector, if it is not already.

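Editor's note: the new tests further down exercise these entry points; as a quick reference, a minimal round-trip sketch (not part of the commit). It assumes it is compiled inside the usual gemma.cpp SIMD-target namespace with nuq-inl.h included, so that hn, NuqCodec, NuqStream, MakeSpan, and MakeConst are in scope.

// Minimal sketch; `out` should have capacity rounded up to a whole vector
// because DecompressAndZeroPadInterleaved zero-pads the tail.
void NuqRoundTripSketch(const float* HWY_RESTRICT in, float* HWY_RESTRICT out,
                        size_t total) {
  const hn::ScalableTag<float> df;
  auto nuq = hwy::AllocateAligned<NuqStream>(NuqStream::PackedEnd(total));
  const auto nuq_span = MakeSpan(nuq.get(), total);

  NuqStream::ClusterBuf buf;
  // Encode all `total` values at element offset 0 within the stream.
  (void)NuqCodec::EncInterleaved(df, in, total, buf, nuq_span,
                                 /*packed_ofs=*/0);
  // Decode them back into f32.
  NuqCodec::DecompressAndZeroPadInterleaved(df, MakeConst(nuq_span),
                                            /*packed_ofs=*/0, out, total);
}
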
@@ -955,6 +1143,7 @@ class NuqCodec {
     }

     const size_t remaining = num - i;
+
     HWY_DASSERT(remaining < 4 * NF);
     if (HWY_UNLIKELY(remaining != 0)) {
       // i is even, but remaining might not be.

@@ -277,6 +277,189 @@ struct TestOffset {
 void TestOffsetBF16() { hn::ForGEVectors<128, TestOffset>()(BF16()); }
 void TestOffsetF32() { hn::ForGEVectors<128, TestOffset>()(float()); }

+// Can encode and decode sub-regions.
+struct TestUnalignedOffset {
+  template <typename T, class D>
+  HWY_INLINE void operator()(T /*unused*/, D d) {
+    const hn::Repartition<float, D> df;
+    const size_t total = 10 * kGroupSize;  // already padded
+
+    const int num_unaligned_offsets = 4;
+    const std::array<size_t, num_unaligned_offsets> unaligned_offsets = {
+        4, kGroupSize + 100, 2 * kGroupSize + 100, 3 * kGroupSize + 100};
+    const std::array<size_t, num_unaligned_offsets> num = {4, 16, 32, 64};
+
+    for (int i = 0; i < num_unaligned_offsets; ++i) {
+      const size_t unaligned_offset = unaligned_offsets[i];
+      const size_t num_decompressed = num[i];
+
+      auto in = hwy::AllocateAligned<float>(total);  // Enc() requires f32
+      auto dec1 = hwy::AllocateAligned<T>(total);
+      auto nuq = hwy::AllocateAligned<NuqStream>(NuqStream::PackedEnd(total));
+      auto dec2 = hwy::AllocateAligned<T>(num_decompressed);
+      HWY_ASSERT(in && dec1 && dec2 && nuq);
+      const auto nuq_span = MakeSpan(nuq.get(), total);
+
+      hwy::RandomState rng;
+      for (size_t i = 0; i < total; ++i) {
+        in[i] = static_cast<float>(RandomGaussian(rng));
+      }
+
+      // Encode + decode everything
+      NuqStream::ClusterBuf buf;
+      (void)NuqCodec::Enc(df, in.get(), total, buf, nuq_span, 0);
+      NuqCodec::DecompressAndZeroPad(d, MakeConst(nuq_span), 0, dec1.get(),
+                                     total);
+
+      NuqCodec::DecompressAndZeroPad(d, MakeConst(nuq_span), unaligned_offset,
+                                     dec2.get(), num_decompressed);
+
+      for (size_t i = 0; i < num_decompressed; ++i) {
+        T expected = hwy::ConvertScalarTo<T>(dec1[unaligned_offset + i]);
+        T actual = hwy::ConvertScalarTo<T>(dec2[i]);
+
+        HWY_ASSERT_EQ(expected, actual);
+      }
+    }
+  }
+};
+
+void TestUnalignedOffsetBF16() {
+  hn::ForGEVectors<128, TestUnalignedOffset>()(BF16());
+}
+void TestUnalignedOffsetF32() {
+  hn::ForGEVectors<128, TestUnalignedOffset>()(float());
+}
+
+// Can encode and decode sub-regions.
+// Uses Dec2Interleaved to decode all elements in the packed buffer, then
+// compares against the non-interleaved decode.
+struct TestDec2Interleaved {
+  template <typename T, class D>
+  HWY_INLINE void operator()(T /*unused*/, D d) {
+    const hn::Repartition<float, D> df;
+    // incl. partial group to test partial group handling
+    const size_t total = 1 * kGroupSize + (kGroupSize / 2);
+    const size_t kMidLen = 2 * kGroupSize;  // length of middle piece
+
+    auto in = hwy::AllocateAligned<float>(total);  // Enc() requires f32
+    auto dec0 = hwy::AllocateAligned<T>(total);
+    auto dec1 = hwy::AllocateAligned<T>(total);
+    auto dec2 = hwy::AllocateAligned<T>(kMidLen);
+    auto nuq = hwy::AllocateAligned<NuqStream>(NuqStream::PackedEnd(total));
+    HWY_ASSERT(in && dec0 && dec1 && dec2 && nuq);
+    const auto nuq_span = MakeSpan(nuq.get(), total);
+
+    hwy::RandomState rng;
+    for (size_t i = 0; i < total; ++i) {
+      in[i] = static_cast<float>(RandomGaussian(rng));
+    }
+
+    // Non-interleaved encode + decode for comparison
+    NuqStream::ClusterBuf buf0;
+    (void)NuqCodec::Enc(df, in.get(), total, buf0, nuq_span, 0);
+    NuqCodec::DecompressAndZeroPad(d, MakeConst(nuq_span), 0, dec0.get(),
+                                   total);
+
+    // Encode + decode everything
+    NuqStream::ClusterBuf buf;
+    (void)NuqCodec::EncInterleaved(df, in.get(), total, buf, nuq_span, 0);
+
+    using V = hn::Vec<decltype(d)>;
+    const size_t N = Lanes(d);
+
+    for (size_t i = 0; i < total; i += 2 * N) {
+      V f0, f1;
+      NuqCodec::Dec2Interleaved(d, MakeConst(nuq_span), i, f0, f1);
+
+      hn::StoreU(f0, d, dec1.get() + i + 0 * N);
+      hn::StoreU(f1, d, dec1.get() + i + 1 * N);
+    }
+
+    for (size_t i = 0; i < total; ++i) {
+      if (dec0[i] != dec1[i]) {
+        fprintf(stderr, "dec0[%zu] = %g, dec1[%zu] = %g\n", i, (float)dec0[i],
+                i, (float)dec1[i]);
+      }
+
+      HWY_ASSERT(dec0[i] == dec1[i]);
+    }
+  }
+};
+
+void TestDec2BF16Interleaved() {
+  hn::ForGEVectors<128, TestDec2Interleaved>()(BF16());
+}
+void TestDec2F32Interleaved() {
+  hn::ForGEVectors<128, TestDec2Interleaved>()(float());
+}
+
+// Can encode and decode sub-regions.
+struct TestOffsetInterleaved {
+  template <typename T, class D>
+  HWY_INLINE void operator()(T /*unused*/, D d) {
+    const hn::Repartition<float, D> df;
+    const size_t total =
+        10 * kGroupSize +
+        (kGroupSize /
+         2);  // adding a partial group to test... partial group handling!
+    const size_t kMidLen = 2 * kGroupSize;  // length of middle piece
+
+    auto in = hwy::AllocateAligned<float>(total);  // Enc() requires f32
+    auto dec0 = hwy::AllocateAligned<T>(total);
+    auto dec1 = hwy::AllocateAligned<T>(total);
+    auto dec2 = hwy::AllocateAligned<T>(kMidLen);
+    auto nuq = hwy::AllocateAligned<NuqStream>(NuqStream::PackedEnd(total));
+    HWY_ASSERT(in && dec0 && dec1 && dec2 && nuq);
+    const auto nuq_span = MakeSpan(nuq.get(), total);
+
+    hwy::RandomState rng;
+    for (size_t i = 0; i < total; ++i) {
+      in[i] = static_cast<float>(RandomGaussian(rng));
+    }
+
+    // Non-interleaved encode + decode for comparison
+    NuqStream::ClusterBuf buf0;
+    (void)NuqCodec::Enc(df, in.get(), total, buf0, nuq_span, 0);
+    NuqCodec::DecompressAndZeroPad(d, MakeConst(nuq_span), 0, dec0.get(),
+                                   total);
+
+    // Encode + decode everything
+    NuqStream::ClusterBuf buf;
+    (void)NuqCodec::EncInterleaved(df, in.get(), total, buf, nuq_span, 0);
+    NuqCodec::DecompressAndZeroPadInterleaved(d, MakeConst(nuq_span), 0,
+                                              dec1.get(), total);
+
+    for (size_t i = 0; i < total; ++i) {
+      if (dec0[i] != dec1[i]) {
+        fprintf(stderr, "dec0[%zu] = %g, dec1[%zu] = %g\n", i, (float)dec0[i],
+                i, (float)dec1[i]);
+      }
+
+      HWY_ASSERT(dec0[i] == dec1[i]);
+    }
+
+    // Overwrite middle with first inputs
+    const size_t offset = 5 * kGroupSize;
+    (void)NuqCodec::EncInterleaved(df, in.get(), kMidLen, buf, nuq_span,
+                                   offset);
+
+    // Decoded middle now matches previously decoded first
+    NuqCodec::DecompressAndZeroPadInterleaved(d, MakeConst(nuq_span), offset,
+                                              dec2.get(), kMidLen);
+    for (size_t i = 0; i < kMidLen; ++i) {
+      HWY_ASSERT(dec1[i] == dec2[i]);
+    }
+  }
+};
+
+void TestOffsetBF16Interleaved() {
+  hn::ForGEVectors<128, TestOffsetInterleaved>()(BF16());
+}
+void TestOffsetF32Interleaved() {
+  hn::ForGEVectors<128, TestOffsetInterleaved>()(float());
+}
+
 struct TestNibble {
   template <typename T, class D>
   HWY_INLINE void operator()(T /*unused*/, D d) {

@@ -409,6 +592,12 @@ HWY_EXPORT_AND_TEST_P(NuqTest, TestAllRamp);
 HWY_EXPORT_AND_TEST_P(NuqTest, TestAllNormal);
 HWY_EXPORT_AND_TEST_P(NuqTest, TestOffsetBF16);
 HWY_EXPORT_AND_TEST_P(NuqTest, TestOffsetF32);
+HWY_EXPORT_AND_TEST_P(NuqTest, TestDec2BF16Interleaved);
+HWY_EXPORT_AND_TEST_P(NuqTest, TestDec2F32Interleaved);
+HWY_EXPORT_AND_TEST_P(NuqTest, TestUnalignedOffsetBF16);
+HWY_EXPORT_AND_TEST_P(NuqTest, TestUnalignedOffsetF32);
+HWY_EXPORT_AND_TEST_P(NuqTest, TestOffsetBF16Interleaved);
+HWY_EXPORT_AND_TEST_P(NuqTest, TestOffsetF32Interleaved);
 HWY_EXPORT_AND_TEST_P(NuqTest, TestAllNibble);
 HWY_EXPORT_AND_TEST_P(NuqTest, TestEncDecBF16);
 HWY_EXPORT_AND_TEST_P(NuqTest, TestEncDecF32);

@@ -81,8 +81,7 @@ class SbsWriterImpl : public WriterInterface {
   template <typename Packed>
   void AllocateAndCompress(const std::string& name,
                            absl::Span<const float> weights) {
-    const size_t num_packed = CompressedArrayElements<Packed>(weights.size());
-    MatPtrT<Packed> storage(name, 1, num_packed);
+    MatPtrT<Packed> storage(name, 1, weights.size());
     model_memory_.push_back(storage);
     model_memory_.back().Allocate();
     storage.SetPtr(model_memory_.back());

@@ -159,16 +159,20 @@ struct NuqStream {

   // Returns offset of packed indices from the start of the stream. This matches
   // the (padded) total table size because table entries are bytes.
+  // TODO(philculliton): Remove when removing non-interleaved functions.
   static constexpr size_t PackedStart(size_t capacity) {
     // Round up to avoid cache-line splits when loading indices. No effect on
     // size as long as capacity / kGroupSize is a multiple of 4.
-    return hwy::RoundUpTo(hwy::DivCeil(capacity, kGroupSize) * kClusters, 64);
+    return kClusters;  // hwy::RoundUpTo(hwy::DivCeil(capacity, kGroupSize) *
+                       // kClusters, 64);
   }

   // Returns number of NuqStream to allocate for the stream, which matches its
   // size in bytes.
   static constexpr size_t PackedEnd(size_t capacity) {
-    return PackedStart(capacity) + hwy::DivCeil(capacity, 2);  // 2x 4-bit/byte
+    const size_t num_groups = hwy::DivCeil(capacity, kGroupSize);
+    return (kClusters * num_groups) +
+           hwy::DivCeil(capacity, 2);  // 2x 4-bit/byte
   }

   uint8_t byte;

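Editor's note: a quick sanity check on the new sizing (illustration only, not part of the commit), assuming kClusters = 16 and kGroupSize = 256. A capacity of 640 elements spans three groups, so the stream holds three 16-byte tables plus 320 bytes of packed indices.

#include "compression/shared.h"

// ceil(640 / 256) = 3 tables of 16 bytes, plus ceil(640 / 2) = 320 index bytes.
static_assert(gcpp::NuqStream::PackedEnd(640) == 3 * 16 + 320,
              "368 bytes under the assumed constants");
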
@@ -26,6 +26,9 @@
 #include <unordered_set>
 #include <vector>

+// TODO: remove compress-inl.h and highway.h when no longer required - i.e.
+// necessary functionality in Reshape() is moved to weights.cc.
+#include "compression/compress-inl.h"
 #include "compression/compress.h"
 #include "compression/shared.h"
 #include "gemma/common.h"

@@ -34,6 +37,7 @@
 #include "hwy/aligned_allocator.h"
 #include "hwy/base.h"
 #include "hwy/contrib/thread_pool/thread_pool.h"
+#include "hwy/highway.h"

 namespace gcpp {

@@ -179,6 +183,7 @@ struct LayerWeightsPtrs {
   // Initializes att_weights from attn_vec_einsum_w, hence this must be called
   // after loading weights via ForEachTensor.
   // TODO: update compression/convert_weights to bake this in.
+  // TODO(janwas): shift to weights.cc.
   void Reshape(MatStorage* storage) {
     if (attn_vec_einsum_w.data() == nullptr) return;

@@ -194,6 +199,41 @@ struct LayerWeightsPtrs {
       storage->Allocate();
       att_weights.SetPtr(*storage);
     }
+
+    if (hwy::IsSame<Weight, NuqStream>()) {
+      namespace hn = hwy::HWY_NAMESPACE;
+      const hn::ScalableTag<float> df;
+
+      hwy::AlignedFreeUniquePtr<float[]> attn_vec_einsum_w_tmp =
+          hwy::AllocateAligned<float>(model_dim * heads * qkv_dim);
+      hwy::AlignedFreeUniquePtr<float[]> att_weights_tmp =
+          hwy::AllocateAligned<float>(model_dim * heads * qkv_dim);
+
+      HWY_NAMESPACE::DecompressAndZeroPad(
+          df, MakeSpan(attn_vec_einsum_w.data(), model_dim * heads * qkv_dim),
+          0, attn_vec_einsum_w_tmp.get(), model_dim * heads * qkv_dim);
+
+      for (size_t m = 0; m < model_dim; ++m) {
+        float* HWY_RESTRICT out_row =
+            att_weights_tmp.get() + m * heads * qkv_dim;
+        for (size_t h = 0; h < heads; ++h) {
+          hwy::CopyBytes(attn_vec_einsum_w_tmp.get() + h * model_dim * qkv_dim +
+                             m * qkv_dim,
+                         out_row + h * qkv_dim, qkv_dim * sizeof(float));
+        }
+      }
+
+      CompressWorkingSet work;
+      hwy::ThreadPool pool(0);
+
+      HWY_NAMESPACE::Compress(
+          att_weights_tmp.get(), model_dim * heads * qkv_dim, work,
+          MakeSpan(att_weights.data(), model_dim * heads * qkv_dim),
+          /*packed_ofs=*/0, pool);
+
+      return;
+    }
+
     for (size_t m = 0; m < model_dim; ++m) {
       Weight* HWY_RESTRICT out_row = att_weights.data() + m * heads * qkv_dim;
       for (size_t h = 0; h < heads; ++h) {