Merge branch 'dev' into main

This commit is contained in:
Ola Otesile 2026-01-12 06:18:51 -08:00 committed by GitHub
commit a0bb7b5527
20 changed files with 243 additions and 148 deletions

View File

@@ -523,6 +523,7 @@ cc_library(
":configs",
":gemma_args",
":mat",
"//compression:types",
"@highway//:hwy",
],
)
@@ -575,6 +576,7 @@ cc_library(
":configs",
":mat",
":threading_context",
"//compression:types",
"//io",
"@highway//:hwy",
"@highway//:profiler",
@@ -608,7 +610,9 @@ cc_library(
],
deps = [
":activations",
":basics",
":configs",
":kv_cache",
":mat",
":matmul",
":matmul_env",

View File

@@ -293,11 +293,11 @@ class IntCodec {
for (; i + 2 * N <= g_num; i += 2 * N) {
const VI8 val0 = hn::DemoteTo(
di8,
hn::DemoteTo(di16, NearestInt(hn::MulAdd(
hn::DemoteTo(di16, hn::NearestInt(hn::MulAdd(
mul, hn::LoadU(df, raw + i + 0 * N), add))));
const VI8 val1 = hn::DemoteTo(
di8,
hn::DemoteTo(di16, NearestInt(hn::MulAdd(
hn::DemoteTo(di16, hn::NearestInt(hn::MulAdd(
mul, hn::LoadU(df, raw + i + 1 * N), add))));
hn::StoreU(val0, di8, &packed.ptr->i + current_offset + i + 0 * N);
@@ -311,7 +311,7 @@ class IntCodec {
if (remaining > N) {
const VI8 val0 = hn::DemoteTo(
di8, hn::DemoteTo(di16, NearestInt(hn::MulAdd(
di8, hn::DemoteTo(di16, hn::NearestInt(hn::MulAdd(
mul, hn::LoadU(df, raw + i), add))));
hn::StoreU(val0, di8, &packed.ptr->i + current_offset + i);
@@ -319,14 +319,14 @@ class IntCodec {
const VI8 val1 = hn::DemoteTo(
di8,
hn::DemoteTo(di16,
NearestInt(hn::MulAdd(
hn::NearestInt(hn::MulAdd(
mul, hn::LoadN(df, raw + i + N, remaining1), add))));
hn::StoreN(val1, di8, &packed.ptr->i + current_offset + i + N,
remaining1);
} else { // remaining <= N
const VI8 val0 = hn::DemoteTo(
di8, hn::DemoteTo(di16,
NearestInt(hn::MulAdd(
hn::NearestInt(hn::MulAdd(
mul, hn::LoadN(df, raw + i, remaining), add))));
hn::StoreN(val0, di8, &packed.ptr->i + current_offset + i, remaining);
}
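For orientation, here is a rough scalar equivalent of the encode step these hunks touch. It is an editorial sketch only: std::lrintf stands in for hn::NearestInt, and the plain casts stand in for the saturating hn::DemoteTo, so it assumes values are already in int8 range.

// Scalar sketch of the SIMD encode path above: scale/offset each float,
// round to nearest, then narrow via int16 to int8. Unlike hn::DemoteTo,
// these casts do not saturate.
#include <cmath>
#include <cstddef>
#include <cstdint>
inline void EncodeScalarSketch(const float* raw, size_t num, float mul,
                               float add, int8_t* out) {
  for (size_t i = 0; i < num; ++i) {
    const float scaled = mul * raw[i] + add;                 // hn::MulAdd
    const long rounded = std::lrintf(scaled);                // hn::NearestInt
    const int16_t narrow16 = static_cast<int16_t>(rounded);  // DemoteTo(di16)
    out[i] = static_cast<int8_t>(narrow16);                  // DemoteTo(di8)
  }
}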

View File

@@ -17,6 +17,10 @@
#ifndef THIRD_PARTY_GEMMA_CPP_COMPRESSION_TEST_UTIL_INL_H_
#define THIRD_PARTY_GEMMA_CPP_COMPRESSION_TEST_UTIL_INL_H_
#include <stddef.h>
#include <vector>
// IWYU pragma: begin_exports
#include "compression/distortion.h"
#include "util/mat.h"
@@ -153,6 +157,126 @@ MatStorageT<MatT> GenerateTransposedMat(const Extents2D extents,
return compressed;
}
// Returns 1-norm, used for estimating tolerable numerical differences.
inline double MaxRowAbsSum(const MatStorageT<float>& a) {
double max_row_abs_sum = 0.0;
for (size_t r = 0; r < a.Rows(); r++) {
const float* row = a.Row(r);
double row_abs_sum = 0.0;
for (size_t c = 0; c < a.Cols(); c++) {
row_abs_sum += hwy::ScalarAbs(row[c]);
}
max_row_abs_sum = HWY_MAX(max_row_abs_sum, row_abs_sum);
}
return max_row_abs_sum;
}
// Returns the maximum absolute value of `a`.
inline float MaxAbs(const MatStorageT<float>& a) {
float max_abs = 0.0f;
for (size_t c = 0; c < a.Cols(); c++) {
for (size_t r = 0; r < a.Rows(); r++) {
const float* row = a.Row(r);
max_abs = HWY_MAX(max_abs, hwy::ScalarAbs(row[c]));
}
}
return max_abs;
}
// B is already transposed.
template <typename TA, typename TB, typename TC>
void AssertClose(const MatPtrT<TA>& A, const MatPtrT<TB>& B,
const MatPtrT<TC>& C_slow, const MatPtrT<TC>& C,
const Allocator& allocator,
std::vector<hwy::AlignedFreeUniquePtr<uint8_t*[]>>& row_ptrs,
int line) {
const hn::ScalableTag<float> df;
const size_t cols = A.Cols();
const size_t B_rows = B.Rows();
// Round up for DecompressAndZeroPad.
MatStorageT<float> a_batch("a_batch", A.Extents(), allocator,
MatPadding::kOdd);
MatStorageT<float> b_trans_batch("b_trans_batch", B.Extents(), allocator,
MatPadding::kOdd);
MatStorageT<float> c_batch("c_batch", Extents2D(A.Rows(), B_rows), allocator,
MatPadding::kOdd);
c_batch.AllocateAndAttachRowPtrs(row_ptrs);
MatStorageT<float> c_slow_batch("c_slow_batch", Extents2D(A.Rows(), B_rows),
allocator, MatPadding::kOdd);
for (size_t m = 0; m < A.Rows(); ++m) {
DecompressAndZeroPad(df, MakeSpan(A.Row(m), cols), 0, a_batch.Row(m), cols);
DecompressAndZeroPad(df, MakeSpan(C.Row(m), B_rows), 0, c_batch.Row(m),
B_rows);
DecompressAndZeroPad(df, MakeSpan(C_slow.Row(m), B_rows), 0,
c_slow_batch.Row(m), B_rows);
}
for (size_t n = 0; n < B_rows; ++n) {
DecompressAndZeroPad(df, MakeSpan(B.Row(n), cols), 0, b_trans_batch.Row(n),
cols);
}
// MatMul rounds inputs to BF16, so error is proportional to the max input
// magnitude, but also to f32 accumulation of rows in A and B.
const double norm = MaxRowAbsSum(a_batch) * MaxRowAbsSum(b_trans_batch);
const float max_abs = MaxAbs(a_batch) * MaxAbs(b_trans_batch);
const double eps_bf16 = hwy::ConvertScalarTo<double>(hwy::Epsilon<BF16>());
const double eps_f32 = hwy::ConvertScalarTo<double>(hwy::Epsilon<float>());
// Dot() uses double-precision summation.
double tolerance = 20 * norm * eps_f32;
// If either is F32, Dot() promotes F32 or even F64, but MatMul demotes the
// F32 to BF16, so add extra tolerance.
if (IsF32<TA>() || IsF32<TB>()) {
tolerance += 2 * max_abs * eps_bf16;
}
if (tolerance > 500.0) {
HWY_WARN("high tolerance %f norm %f maxabs %f\n", tolerance, norm, max_abs);
}
const double rel_tolerance =
1.0 + hwy::ConvertScalarTo<double>(hwy::Epsilon<TC>());
double max_rel = 0.0;
size_t worst_r = 0;
size_t worst_c = 0;
double worst_actual = 0.0;
double worst_expected = 0.0;
size_t num_outside = 0;
for (size_t r = 0; r < A.Rows(); r++) {
const float* expected_row = c_slow_batch.Row(r);
const float* actual_row = c_batch.Row(r);
for (size_t c = 0; c < B.Rows(); c++) {
const double expected_value = static_cast<double>(expected_row[c]);
const double actual_value = static_cast<double>(actual_row[c]);
const bool in_range = expected_value - tolerance <= actual_value &&
actual_value <= expected_value + tolerance;
if (!in_range) {
const double max = HWY_MAX(expected_value, actual_value);
const double min = HWY_MIN(expected_value, actual_value);
const double rel = max / HWY_MAX(min, 1E-6);
if (rel > max_rel) {
worst_expected = expected_value;
worst_actual = actual_value;
worst_r = r;
worst_c = c;
max_rel = rel;
++num_outside;
}
}
}
}
if (max_rel > rel_tolerance) {
hwy::Abort(__FILE__, line,
"(%zu,%zu): expected %f, actual %f, norm %f maxabs %f "
"tolerance %f rel %E max_rel %E num_outside %zu\n",
worst_r, worst_c, worst_expected, worst_actual, norm, max_abs,
tolerance, max_rel, rel_tolerance, num_outside);
}
HWY_ASSERT(hn::AllFalse(
df, hn::IsEitherNaN(hn::Set(df, norm), hn::Set(df, max_abs))));
}
// NOLINTNEXTLINE(google-readability-namespace-comments)
} // namespace HWY_NAMESPACE
} // namespace gcpp
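To make the tolerance arithmetic in AssertClose concrete, a short worked example with purely illustrative numbers (not taken from any actual test run):

// Illustrative numbers only. Suppose MaxRowAbsSum(a_batch) = 50,
// MaxRowAbsSum(b_trans_batch) = 40, MaxAbs(a_batch) = 2, MaxAbs(b_trans_batch) = 3.
//   norm      = 50 * 40 = 2000
//   max_abs   = 2 * 3   = 6
//   eps_f32  ~= 1.19e-7 (2^-23), eps_bf16 ~= 7.81e-3 (2^-7)
//   tolerance = 20 * 2000 * 1.19e-7              ~= 4.8e-3
//   if TA or TB is F32: tolerance += 2 * 6 * 7.81e-3 ~= 0.094 extra
// So whenever the BF16-demotion term applies, it dominates the error budget.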

View File

@@ -218,12 +218,23 @@ constexpr bool SupportsPointerArithmetic() {
}
// Tensor types for loading weights. Not all of these are supported weight
// types, some are only used for `Activations`.
enum class Type { kUnknown, kF32, kBF16, kSFP, kNUQ, kF64, kU32, kU64, kI8 };
// types, some are only used for `Activations`. Append-only.
enum class Type {
kUnknown,
kF32,
kBF16,
kSFP,
kNUQ,
kF64,
kU32,
kU64,
kI8,
kU16
};
// These are used in `ModelConfig.Specifier`, hence the strings will not
// change, though new ones may be added.
static constexpr const char* kTypeStrings[] = {
"unknown", "f32", "bf16", "sfp", "nuq", "f64", "u32", "u64", "i8"};
"unknown", "f32", "bf16", "sfp", "nuq", "f64", "u32", "u64", "i8", "u16"};
static constexpr size_t kNumTypes =
sizeof(kTypeStrings) / sizeof(kTypeStrings[0]);
static constexpr size_t kTypeBits[] = {
@@ -236,6 +247,7 @@ static constexpr size_t kTypeBits[] = {
8 * sizeof(uint32_t),
8 * sizeof(uint64_t),
8 * sizeof(I8Stream),
8 * sizeof(uint16_t),
};
static inline bool EnumValid(Type type) {
@@ -262,6 +274,8 @@ Type TypeEnum() {
return Type::kU64;
} else if constexpr (hwy::IsSame<Packed, I8Stream>()) {
return Type::kI8;
} else if constexpr (hwy::IsSame<Packed, uint16_t>()) {
return Type::kU16;
} else {
HWY_DASSERT(false);
return Type::kUnknown;
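Since `kTypeStrings` and `kTypeBits` must stay in lockstep with the append-only `Type` enum, a compile-time guard along the following lines could catch a missed entry. This is an editorial sketch, not part of the commit, and it assumes `kU16` remains the last enumerator:

// Sketch: tie the parallel arrays to the enum so that appending a Type
// without updating both arrays fails to compile.
static_assert(kNumTypes == static_cast<size_t>(Type::kU16) + 1,
              "kTypeStrings out of sync with Type");
static_assert(sizeof(kTypeBits) / sizeof(kTypeBits[0]) == kNumTypes,
              "kTypeBits out of sync with Type");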

View File

@@ -35,6 +35,7 @@
namespace gcpp {
typedef std::vector<float, hwy::AlignedAllocator<float>> AlignedFloatVector;
typedef std::vector<BF16, hwy::AlignedAllocator<BF16>> AlignedBF16Vector;
// Returns the scale value to use for the query in the attention computation.
// Also called by ops_test.

View File

@@ -428,6 +428,18 @@ struct ModelConfig : public IFields {
// The third ctor also expects a string returned by this.
std::string Specifier() const;
// Overwrites `max_seq_len` with `new_max_seq_len` and updates all global
// layers' attention window sizes to `new_max_seq_len`. This function must be
// called before instantiating the KVCache object.
void SetMaxSeqLen(size_t new_max_seq_len) {
for (size_t i = 0; i < attention_window_sizes.size(); ++i) {
if (attention_window_sizes[i] == max_seq_len) {
attention_window_sizes[i] = new_max_seq_len;
}
}
max_seq_len = new_max_seq_len;
}
void AddLayerConfig(const LayerConfig& layer_config) {
layer_configs.push_back(layer_config);
HWY_ASSERT(layer_configs.size() <= num_layers);
@@ -516,7 +528,7 @@ ModelConfig GetVitConfig(const ModelConfig& config);
enum DeducedLayerTypes {
kDeducedViT = 2,
kDeduced448 = 4, // For ViT, 448x448 resolution instead of 224x224.
kDeduced448 = 4, // For ViT, 448x448 resolution instead of 224x224.
kDeducedKqNorm = 8,
};

View File

@@ -22,10 +22,14 @@
#include <cstdlib>
#include <iostream>
#include <limits>
#include <type_traits>
#include <vector>
#include "compression/types.h" // GEMMA_DISABLED_TARGETS
#include "gemma/flash_structs.h"
#include "gemma/kv_cache.h"
#include "gemma/query.h"
#include "util/basics.h"
#include "util/threading_context.h"
#include "util/zones.h"
#include "hwy/base.h"
@@ -448,7 +452,7 @@ static VF4 HWY_INLINE Reduce4(DF df, VF x_0, VF x_1, VF x_2, VF x_3,
constexpr size_t kMaxLanes = hn::MaxLanes(df);
HWY_LANES_CONSTEXPR size_t kLanes = hn::Lanes(df);
HWY_ALIGN T x_transposed[4 * kMaxLanes];
hn::StoreInterleaved4<DF>(x_0, x_1, x_2, x_3, df, x_transposed);
hn::StoreInterleaved4(x_0, x_1, x_2, x_3, df, x_transposed);
VF4 result = hn::Load(df4, x_transposed);
for (int i = 1; i < kLanes; ++i) {
result = reducer(result, hn::Load(df4, x_transposed + i * 4));

View File

@@ -18,6 +18,9 @@
#include "gemma/gemma.h"
#include <cmath>
#include <cstddef>
#include <cstdint>
#include <optional>
#include "compression/types.h" // GEMMA_DISABLED_TARGETS
@@ -556,10 +559,10 @@ static size_t PrefillTBatchOrQBatch(const ModelConfig& config,
}
static void StreamAndUpdateEOSAfterPrefill(const ModelConfig& config,
const RuntimeConfig& runtime_config,
QBatch& qbatch,
hwy::BitSet4096<>& non_eos,
size_t qi) {
const RuntimeConfig& runtime_config,
QBatch& qbatch,
hwy::BitSet4096<>& non_eos,
size_t qi) {
const size_t last_pos_in_prompt = qbatch.Pos(qi) - qbatch.InitialPos(qi);
const size_t pos = qbatch.Pos(qi); // during prefill, pos is still correct.
@@ -745,6 +748,12 @@ Gemma::Gemma(const GemmaArgs& args, ThreadingContext& ctx)
chat_template_(model_.Tokenizer(), model_.Config().model),
inference_(args.inference),
aes_ctr_engine_(args.inference.deterministic) {
if (args.inference.seq_len > model_.Config().max_seq_len) {
HWY_WARN(
"Overriding model's max_seq_len=%u with user provided seq_len=%zu.",
model_.Config().max_seq_len, args.inference.seq_len);
model_.MutableConfig().SetMaxSeqLen(args.inference.seq_len);
}
// Negligible CPU time in the ctor body (except ReadFromBlobs).
weight_read_mode_ = weights_.ReadFromBlobs(model_, reader_, args.loader,
args.inference, mat_owners_, ctx);

View File

@@ -22,8 +22,10 @@
#include <stdio.h>
#include <functional>
#include <optional>
#include <string>
#include "compression/types.h"
#include "gemma/configs.h"
#include "io/io.h" // Path
#include "util/args.h" // IWYU pragma: export

View File

@@ -16,8 +16,12 @@
#include "gemma/kv_cache.h"
#include <stddef.h>
#include <algorithm>
#include <utility>
#include <vector>
#include "compression/types.h"
#include "gemma/configs.h"
#include "gemma/gemma_args.h"
#include "util/mat.h" // ZeroInit

View File

@@ -19,12 +19,14 @@
#include <stddef.h>
#include <optional>
#include <utility>
#include <vector>
#include "gemma/configs.h" // ModelConfig
#include "gemma/gemma_args.h" // InferenceArgs
#include "util/basics.h" // BF16
#include "util/mat.h"
#include "hwy/base.h"
namespace gcpp {

View File

@@ -35,8 +35,13 @@ TEST(KVCacheTest, KVCacheToPtrs) {
std::vector<KVCachePtr> ptrs = ToKVCachePtrs({caches.data(), caches.size()});
ASSERT_EQ(ptrs.size(), 2);
EXPECT_EQ(ptrs[0].kv_cache.Row(0), caches[0].kv_cache.Row(0));
EXPECT_EQ(ptrs[1].kv_cache.Row(0), caches[1].kv_cache.Row(0));
if (caches[0].IsTiled()) {
EXPECT_EQ(ptrs[0].cache, &caches[0]);
EXPECT_EQ(ptrs[1].cache, &caches[1]);
} else {
EXPECT_EQ(ptrs[0].kv_cache.Row(0), caches[0].kv_cache.Row(0));
EXPECT_EQ(ptrs[1].kv_cache.Row(0), caches[1].kv_cache.Row(0));
}
}
} // namespace

View File

@@ -387,7 +387,7 @@ void ModelStore::CreateMatPtrs(BlobReader& reader) {
}
ModelStore::ModelStore(BlobReader& reader, const Path& tokenizer_path,
Tristate wrapping)
Tristate wrapping)
: config_(ReadOrDeduceConfig(reader, wrapping)),
tokenizer_(ReadTokenizer(reader, tokenizer_path)) {
if (!ReadMatPtrs(reader)) { // Pre-2025 format.

View File

@@ -39,6 +39,8 @@
namespace gcpp {
class Gemma;
// Reads and holds the model config, tokenizer and all `MatPtr`: everything
// except the tensor data, which are read/written by `weights.cc`.
//
@@ -60,6 +62,11 @@ class ModelStore {
return config_;
}
ModelConfig& MutableConfig() {
HWY_ASSERT(config_.model != Model::UNKNOWN);
return config_;
}
const GemmaTokenizer& Tokenizer() const { return tokenizer_; }
// Returns nullptr if `name` is not available for loading, otherwise the

View File

@@ -53,7 +53,7 @@ extern int64_t first_target;
namespace HWY_NAMESPACE {
void PrintSpeed(const Extents2D& A_extents, const Extents2D& B_extents,
std::vector<double>& times, MMPerKey* per_key) {
std::vector<double>& times) {
std::sort(times.begin(), times.end());
// bench_dnn reports the best and average, but the median seems more
// consistent and resistant to outliers.
@@ -134,7 +134,7 @@ void BenchMatMul(size_t M, size_t K, size_t N, bool add, MatMulEnv& env) {
}
hwy::PreventElision(keep);
env.ctx.pools.MaybeStopSpinning(use_spinning);
PrintSpeed(A_extents, B_extents, times, per_key);
PrintSpeed(A_extents, B_extents, times);
}
using F32 = float;

View File

@@ -58,122 +58,6 @@ extern int64_t first_target;
namespace HWY_NAMESPACE {
namespace hn = hwy::HWY_NAMESPACE;
// Returns 1-norm, used for estimating tolerable numerical differences.
double MaxRowAbsSum(const MatStorageT<float>& a) {
double max_row_abs_sum = 0.0;
for (size_t r = 0; r < a.Rows(); r++) {
const float* row = a.Row(r);
double row_abs_sum = 0.0;
for (size_t c = 0; c < a.Cols(); c++) {
row_abs_sum += hwy::ScalarAbs(row[c]);
}
max_row_abs_sum = HWY_MAX(max_row_abs_sum, row_abs_sum);
}
return max_row_abs_sum;
}
// Returns the maximum absolute value of `a`.
float MaxAbs(const MatStorageT<float>& a) {
float max_abs = 0.0f;
for (size_t c = 0; c < a.Cols(); c++) {
for (size_t r = 0; r < a.Rows(); r++) {
const float* row = a.Row(r);
max_abs = HWY_MAX(max_abs, hwy::ScalarAbs(row[c]));
}
}
return max_abs;
}
// B is already transposed.
template <typename TA, typename TB, typename TC>
void AssertClose(const MatPtrT<TA>& A, const MatPtrT<TB>& B,
const MatPtrT<TC>& C_slow, const MatPtrT<TC>& C,
MatMulEnv& env, int line) {
const hn::ScalableTag<float> df;
const size_t cols = A.Cols();
const size_t B_rows = B.Rows();
// Round up for DecompressAndZeroPad.
MatStorageT<float> a_batch("a_batch", A.Extents(), env.ctx.allocator,
MatPadding::kOdd);
MatStorageT<float> b_trans_batch("b_trans_batch", B.Extents(),
env.ctx.allocator, MatPadding::kOdd);
MatStorageT<float> c_batch("c_batch", Extents2D(A.Rows(), B_rows),
env.ctx.allocator, MatPadding::kOdd);
c_batch.AllocateAndAttachRowPtrs(env.row_ptrs);
MatStorageT<float> c_slow_batch("c_slow_batch", Extents2D(A.Rows(), B_rows),
env.ctx.allocator, MatPadding::kOdd);
for (size_t m = 0; m < A.Rows(); ++m) {
DecompressAndZeroPad(df, MakeSpan(A.Row(m), cols), 0, a_batch.Row(m), cols);
DecompressAndZeroPad(df, MakeSpan(C.Row(m), B_rows), 0, c_batch.Row(m),
B_rows);
DecompressAndZeroPad(df, MakeSpan(C_slow.Row(m), B_rows), 0,
c_slow_batch.Row(m), B_rows);
}
for (size_t n = 0; n < B_rows; ++n) {
DecompressAndZeroPad(df, MakeSpan(B.Row(n), cols), 0, b_trans_batch.Row(n),
cols);
}
// MatMul rounds inputs to BF16, so error is proportional to the max input
// magnitude, but also to f32 accumulation of rows in A and B.
const double norm = MaxRowAbsSum(a_batch) * MaxRowAbsSum(b_trans_batch);
const float max_abs = MaxAbs(a_batch) * MaxAbs(b_trans_batch);
const double eps_bf16 = hwy::ConvertScalarTo<double>(hwy::Epsilon<BF16>());
const double eps_f32 = hwy::ConvertScalarTo<double>(hwy::Epsilon<float>());
// Dot() uses double-precision summation.
double tolerance = 20 * norm * eps_f32;
// If either is F32, Dot() promotes F32 or even F64, but MatMul demotes the
// F32 to BF16, so add extra tolerance.
if (IsF32<TA>() || IsF32<TB>()) {
tolerance += 2 * max_abs * eps_bf16;
}
if (tolerance > 500.0) {
HWY_WARN("high tolerance %f norm %f maxabs %f\n", tolerance, norm, max_abs);
}
const double rel_tolerance =
1.0 + hwy::ConvertScalarTo<double>(hwy::Epsilon<TC>());
double max_rel = 0.0;
size_t worst_r = 0;
size_t worst_c = 0;
double worst_actual = 0.0;
double worst_expected = 0.0;
size_t num_outside = 0;
for (size_t r = 0; r < A.Rows(); r++) {
const float* expected_row = c_slow_batch.Row(r);
const float* actual_row = c_batch.Row(r);
for (size_t c = 0; c < B.Rows(); c++) {
const double expected_value = static_cast<double>(expected_row[c]);
const double actual_value = static_cast<double>(actual_row[c]);
const bool in_range = expected_value - tolerance <= actual_value &&
actual_value <= expected_value + tolerance;
if (!in_range) {
const double max = HWY_MAX(expected_value, actual_value);
const double min = HWY_MIN(expected_value, actual_value);
const double rel = max / HWY_MAX(min, 1E-6);
if (rel > max_rel) {
worst_expected = expected_value;
worst_actual = actual_value;
worst_r = r;
worst_c = c;
max_rel = rel;
++num_outside;
}
}
}
}
if (max_rel > rel_tolerance) {
hwy::Abort(__FILE__, line,
"(%zu,%zu): expected %f, actual %f, norm %f maxabs %f "
"tolerance %f rel %E max_rel %E num_outside %zu\n",
worst_r, worst_c, worst_expected, worst_actual, norm, max_abs,
tolerance, max_rel, rel_tolerance, num_outside);
}
}
// B is already transposed.
template <typename TA, typename TB, typename TC>
HWY_INLINE void MatMulSlow(const MatPtrT<TA> A, const MatPtrT<TB> B,
@@ -211,14 +95,6 @@ HWY_INLINE void MatMulSlow(const MatPtrT<TA> A, const MatPtrT<TB> B,
});
}
void PrintSpeed(const char* algo, const Extents2D& A_extents,
const Extents2D& B_extents, double elapsed) {
const size_t num_b = B_extents.Area();
// 2x because of FMA.
fprintf(stderr, " %10s: %f seconds, %.1f GFLOPS.\n", algo,
elapsed, 2 * 1E-9 * A_extents.rows * num_b / elapsed);
}
template <typename TA, typename TB = TA, typename TC = float>
void TestMatMul(size_t rows_ac, size_t cols_a_rows_b, size_t cols_bc, bool add,
MatMulEnv& env, int line) {
@@ -257,7 +133,7 @@ void TestMatMul(size_t rows_ac, size_t cols_a_rows_b, size_t cols_bc, bool add,
MMOptions options;
for (size_t rep = 0; rep < 16; ++rep) {
MMPerKey* per_key = MatMulStatic(A, BT, add_row, env, C, options);
AssertClose(A, BT, C_slow, C, env, line);
AssertClose(A, BT, C_slow, C, env.ctx.allocator, env.row_ptrs, line);
// Check before TwoMatMulStatic(), which can invalidate per_key.
const bool autotune_done = !!per_key->autotune.Best();
@@ -295,7 +171,7 @@ void TestMatMul(size_t rows_ac, size_t cols_a_rows_b, size_t cols_bc, bool add,
// TwoMatMulStatic() does not support adding a bias vector.
if (!add) {
AssertClose(A, BT, C, C2, env, line);
AssertClose(A, BT, C, C2, env.ctx.allocator, env.row_ptrs, line);
}
}

View File

@@ -65,7 +65,6 @@ cc_test(
"//:benchmark_helper",
"//:configs",
"//:gemma_lib",
"//io",
"@highway//:hwy_test_util",
],
)

View File

@@ -469,6 +469,38 @@ decltype(auto) CallUpcastedKV(const MatPtr* base, const Func& func,
}
}
// Calls 'func' with a span of MatPtrT<T> for all elements in `base`.
// T is dynamic type, read from base. It is assumed that all elements in `base`
// have the same type.
template <class Func, typename... Args>
decltype(auto) CallUpcastedKVs(hwy::Span<const MatPtr> base, const Func& func,
Args&&... args) {
Type type = base[0].GetType();
for ([[maybe_unused]] auto&& mat : base) {
HWY_DASSERT(mat.GetType() == type);
}
auto convert_to_matptr_t = [&base]<typename T>() {
std::vector<MatPtrT<T>> matptrs;
matptrs.reserve(base.size());
for (auto&& mat : base) {
matptrs.emplace_back(mat);
}
return matptrs;
};
if (type == Type::kF32) {
auto matptrs = convert_to_matptr_t.template operator()<float>();
hwy::Span<const MatPtrT<float>> matptrs_span(matptrs.data(),
matptrs.size());
return func(matptrs_span, std::forward<Args>(args)...);
} else if (type == Type::kBF16) {
auto matptrs = convert_to_matptr_t.template operator()<BF16>();
hwy::Span<const MatPtrT<BF16>> matptrs_span(matptrs.data(), matptrs.size());
return func(matptrs_span, std::forward<Args>(args)...);
} else {
HWY_ABORT("Unhandled type %s.", TypeName(type));
}
}
void CopyMat(const MatPtr& from, MatPtr& to);
void ZeroInit(MatPtr& mat);
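A hedged usage sketch for the CallUpcastedKVs helper added above. The function TotalKVElements and its purpose are invented for illustration; only the dispatcher itself comes from this commit:

// Hypothetical caller: count elements across a group of same-typed KV
// tensors without knowing at compile time whether they hold F32 or BF16.
size_t TotalKVElements(hwy::Span<const MatPtr> kv_mats) {
  return CallUpcastedKVs(kv_mats, [](auto span) {
    size_t elems = 0;
    for (const auto& mat : span) elems += mat.Rows() * mat.Cols();
    return elems;
  });
}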

View File

@@ -115,7 +115,7 @@ template <typename T>
void PrintMatPtr(MatPtrT<T> mat) {
for (int i = 0; i < mat.Rows(); ++i) {
for (int j = 0; j < mat.Cols(); ++j) {
std::cerr << mat.Row(i)[j] << " ,";
std::cerr << hwy::ConvertScalarTo<float>(mat.Row(i)[j]) << " ,";
}
std::cerr << std::endl;
}

View File

@@ -10,7 +10,7 @@
namespace gcpp {
// Zones for the profiler.
enum class Zones { // Keep sorted
enum class Zones { // Keep sorted
kFlashAttentionFlashAttention,
kFlashAttentionInclusive,
kFlashAttentionRmsNormAndPositionalEncoding,