mirror of https://github.com/google/gemma.cpp.git
Replace RowVectorBatch with MatStorageT
KVCache: add ctor required for MatStorageT, remove Create; bf_pre_ffw_rms_out -> pre_ffw_rms_out
optimize_test: larger vocab_size requires more steps
shared.h: Remove unused u128 type
correctly set Activation matrix rows, avoid passing as arg
ops: pass Mat instead of pointers/sizes; vectorize LayerNorm; support any weight type
mat: add OverrideRows, used by SetBatchSize

PiperOrigin-RevId: 757790736
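The gist of the API change at call sites, as a rough sketch drawn from the hunks below (the surrounding gemma.cpp types are taken from the diff, while `config`, `allocator` and `qkv_dim` are assumed local variables, so this is illustrative rather than a standalone program):

// Before this commit: factory function plus RowVectorBatch activations.
KVCache kv_cache = KVCache::Create(config, /*prefill_tbatch_size=*/16);
RowVectorBatch<float> inv_timescale =
    CreateInvTimescale(allocator, qkv_dim, /*half_rope=*/false);

// After this commit: plain constructor plus MatStorageT everywhere.
// Activations now sizes its matrices in its constructor and shrinks them
// later via SetBatchSize(), which calls MatStorageT::OverrideRows().
KVCache kv_cache(config, /*prefill_tbatch_size=*/16);
MatStorageT<float> inv_timescale =
    CreateInvTimescale(allocator, qkv_dim, /*half_rope=*/false);

The SetBatchSize/OverrideRows pattern is visible in the activations.h hunk further down.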
This commit is contained in:
parent cf7dd80c17
commit 45ad847a41
@@ -415,6 +415,7 @@ cc_library(
     hdrs = ["gemma/kv_cache.h"],
     deps = [
         ":configs",
+        ":mat",
         "@highway//:hwy",
     ],
 )

@@ -425,6 +426,7 @@ cc_library(
     deps = [
         ":args",
         ":basics",
+        ":mat",
         ":ops",  # matmul.h
         "//io",
         "@highway//:hwy",
@@ -38,8 +38,8 @@ struct ForwardLayer {
         att_post1(MakePacked<T>("att_post1", seq_len, config.model_dim)),
         attention_out(
             MakePacked<T>("attention_out", seq_len, config.model_dim)),
-        bf_pre_ffw_rms_out(
-            MakePacked<T>("bf_preFF_rms_out", seq_len, config.model_dim)),
+        pre_ffw_rms_out(
+            MakePacked<T>("preFF_rms_out", seq_len, config.model_dim)),
         ffw_hidden(
             MakePacked<T>("ffw_hidden", seq_len, config.ff_hidden_dim * 2)),
         ffw_hidden_gated(

@@ -53,7 +53,7 @@ struct ForwardLayer {
   MatStorageT<T> att_out;
   MatStorageT<T> att_post1;
   MatStorageT<T> attention_out;
-  MatStorageT<T> bf_pre_ffw_rms_out;
+  MatStorageT<T> pre_ffw_rms_out;
   MatStorageT<T> ffw_hidden;
   MatStorageT<T> ffw_hidden_gated;
   const LayerConfig& layer_config;
@@ -170,8 +170,7 @@ void LayerVJP(const LayerWeightsPtrs<T>& weights,
               const ForwardLayer<float>& forward,
               const float* HWY_RESTRICT next_layer_grad, size_t num_tokens,
               LayerWeightsPtrs<T>& grad, ForwardLayer<float>& backward,
-              const RowVectorBatch<float>& inv_timescale,
-              hwy::ThreadPool& pool) {
+              const MatStorageT<float>& inv_timescale, hwy::ThreadPool& pool) {
   const LayerConfig& config = weights.layer_config;
   const size_t model_dim = config.model_dim;
   const size_t qkv_dim = config.qkv_dim;

@@ -207,15 +206,14 @@ void LayerVJP(const LayerWeightsPtrs<T>& weights,
     }
   }

-  MatMulVJP(weights.gating_einsum_w.Packed(),
-            forward.bf_pre_ffw_rms_out.Packed(), backward.ffw_hidden.Packed(),
-            model_dim, ff_hidden_dim * 2, num_tokens,
-            grad.gating_einsum_w.Packed(), backward.bf_pre_ffw_rms_out.Packed(),
-            pool);
-  RMSNormVJP(
-      weights.pre_ffw_norm_scale.Packed(), forward.attention_out.Packed(),
-      backward.bf_pre_ffw_rms_out.Packed(), model_dim, num_tokens,
-      grad.pre_ffw_norm_scale.Packed(), backward.attention_out.Packed(), pool);
+  MatMulVJP(weights.gating_einsum_w.Packed(), forward.pre_ffw_rms_out.Packed(),
+            backward.ffw_hidden.Packed(), model_dim, ff_hidden_dim * 2,
+            num_tokens, grad.gating_einsum_w.Packed(),
+            backward.pre_ffw_rms_out.Packed(), pool);
+  RMSNormVJP(weights.pre_ffw_norm_scale.Packed(),
+             forward.attention_out.Packed(), backward.pre_ffw_rms_out.Packed(),
+             model_dim, num_tokens, grad.pre_ffw_norm_scale.Packed(),
+             backward.attention_out.Packed(), pool);

   for (size_t pos = 0; pos < num_tokens; ++pos) {
     AddFrom(next_layer_grad + pos * model_dim,

@@ -275,7 +273,7 @@ void LayerVJP(const LayerWeightsPtrs<T>& weights,
   for (int pos = 0; pos < static_cast<int>(num_tokens); ++pos) {
     float* HWY_RESTRICT b_kv =
         backward.qkv.Packed() + (pos * (heads + 2) + heads) * qkv_dim;
-    Rope(b_kv, qkv_dim, inv_timescale.Const(), -pos);
+    Rope(b_kv, qkv_dim, inv_timescale.PackedScale1(), -pos);
   }

   for (size_t head = 0; head < heads; ++head) {

@@ -283,7 +281,7 @@ void LayerVJP(const LayerWeightsPtrs<T>& weights,
       float* HWY_RESTRICT b_q =
           backward.qkv.Packed() + (pos * (heads + 2) + head) * qkv_dim;
       MulByConst(query_scale, b_q, qkv_dim);
-      Rope(b_q, qkv_dim, inv_timescale.Const(), -pos);
+      Rope(b_q, qkv_dim, inv_timescale.PackedScale1(), -pos);
     }
   }
@@ -342,7 +340,7 @@ void CrossEntropyLossBackwardPassInl(const Prompt& prompt,
                                      const ForwardPass<float>& forward,
                                      ModelWeightsPtrs<T>& grad,
                                      ForwardPass<float>& backward,
-                                     RowVectorBatch<float>& inv_timescale,
+                                     MatStorageT<float>& inv_timescale,
                                      hwy::ThreadPool& pool) {
   const ModelConfig& config = weights.weights_config;
   const size_t kVocabSize = config.vocab_size;

@@ -42,7 +42,7 @@ void CrossEntropyLossBackwardPassT(const Prompt& prompt,
                                    const ForwardPass<float>& forward,
                                    ModelWeightsPtrs<float>& grad,
                                    ForwardPass<float>& backward,
-                                   RowVectorBatch<float>& inv_timescale,
+                                   MatStorageT<float>& inv_timescale,
                                    hwy::ThreadPool& pool) {
   CrossEntropyLossBackwardPassInl(prompt, weights, forward, grad, backward,
                                   inv_timescale, pool);

@@ -62,7 +62,7 @@ void CrossEntropyLossBackwardPass(const Prompt& prompt,
                                   const ForwardPass<float>& forward,
                                   ModelWeightsPtrs<float>& grad,
                                   ForwardPass<float>& backward,
-                                  RowVectorBatch<float>& inv_timescale,
+                                  MatStorageT<float>& inv_timescale,
                                   hwy::ThreadPool& pool) {
   return HWY_DYNAMIC_DISPATCH(CrossEntropyLossBackwardPassT)(
       prompt, weights, forward, grad, backward, inv_timescale, pool);

@@ -29,7 +29,7 @@ void CrossEntropyLossBackwardPass(const Prompt& prompt,
                                   const ForwardPass<float>& forward,
                                   ModelWeightsPtrs<float>& grad,
                                   ForwardPass<float>& backward,
-                                  RowVectorBatch<float>& inv_timescale,
+                                  MatStorageT<float>& inv_timescale,
                                   hwy::ThreadPool& pool);

 }  // namespace gcpp
@@ -218,16 +218,15 @@ void LayerVJP(const LayerWeightsPtrs<T>& weights,
   GatedGeluVJP(forward.ffw_hidden.Packed(), backward.ffw_hidden_gated.Packed(),
                backward.ffw_hidden.Packed(), kFFHiddenDim, num_tokens);

-  MatMulVJPT(weights.gating_einsum_w.Packed(),
-             forward.bf_pre_ffw_rms_out.Packed(), backward.ffw_hidden.Packed(),
-             grad.gating_einsum_w.Packed(),
-             backward.bf_pre_ffw_rms_out.Packed(), kFFHiddenDim * 2, model_dim,
+  MatMulVJPT(weights.gating_einsum_w.Packed(), forward.pre_ffw_rms_out.Packed(),
+             backward.ffw_hidden.Packed(), grad.gating_einsum_w.Packed(),
+             backward.pre_ffw_rms_out.Packed(), kFFHiddenDim * 2, model_dim,
              num_tokens);

-  RMSNormVJPT(
-      weights.pre_ffw_norm_scale.Packed(), forward.attention_out.Packed(),
-      backward.bf_pre_ffw_rms_out.Packed(), grad.pre_ffw_norm_scale.Packed(),
-      backward.attention_out.Packed(), model_dim, num_tokens);
+  RMSNormVJPT(weights.pre_ffw_norm_scale.Packed(),
+              forward.attention_out.Packed(), backward.pre_ffw_rms_out.Packed(),
+              grad.pre_ffw_norm_scale.Packed(), backward.attention_out.Packed(),
+              model_dim, num_tokens);

   AddFromT(dy, backward.attention_out.Packed(), num_tokens * model_dim);
@@ -202,7 +202,7 @@ void TestEndToEnd() {
   ReverseSequenceSampler training_task({0, 0, 1, 1});
   std::vector<Prompt> batch = training_task.SampleBatch(3, gen);

-  RowVectorBatch<float> inv_timescale = CreateInvTimescale(
+  MatStorageT<float> inv_timescale = CreateInvTimescale(
       ThreadingContext::Get().allocator, config.layer_configs[0].qkv_dim,
       config.layer_configs[0].post_qk == PostQKType::HalfRope);
   for (const Prompt& prompt : batch) {
@@ -74,7 +74,7 @@ void ApplyRMSNorm(const WT* HWY_RESTRICT weights, const XT* HWY_RESTRICT x,
                   hwy::ThreadPool& pool) {
   for (size_t pos = 0; pos < num_tokens; ++pos) {
     const size_t offset = pos * model_dim;
-    RMSNorm(x + offset, weights, output + offset, model_dim);
+    RMSNorm(x + offset, weights, 0, output + offset, model_dim);
   }
 }

@@ -100,7 +100,7 @@ template <typename T>
 void ApplyForwardLayer(const LayerWeightsPtrs<T>& weights,
                        ForwardLayer<float>& activations, size_t num_tokens,
                        float* HWY_RESTRICT output,
-                       const RowVectorBatch<float>& inv_timescale,
+                       const MatStorageT<float>& inv_timescale,
                        hwy::ThreadPool& pool) {
   const LayerConfig& config = weights.layer_config;
   const size_t model_dim = config.model_dim;

@@ -125,14 +125,14 @@ void ApplyForwardLayer(const LayerWeightsPtrs<T>& weights,
   for (size_t pos = 0; pos < num_tokens; ++pos) {
     float* HWY_RESTRICT k =
         activations.qkv.Packed() + (pos * (kHeads + 2) + kHeads) * kQKVDim;
-    Rope(k, kQKVDim, inv_timescale.Const(), pos);
+    Rope(k, kQKVDim, inv_timescale.PackedScale1(), pos);
   }
   pool.Run(0, num_tasks, [&](const uint64_t task, size_t thread) HWY_ATTR {
     const size_t head = task % kHeads;
     const size_t pos = task / kHeads;
     float* HWY_RESTRICT q =
         activations.qkv.Packed() + (pos * (kHeads + 2) + head) * kQKVDim;
-    Rope(q, kQKVDim, inv_timescale.Const(), pos);
+    Rope(q, kQKVDim, inv_timescale.PackedScale1(), pos);
     MulByConst(query_scale, q, kQKVDim);
   });

@@ -194,11 +194,11 @@ void ApplyForwardLayer(const LayerWeightsPtrs<T>& weights,

   ApplyRMSNorm(weights.pre_ffw_norm_scale.Packed(),
                activations.attention_out.Packed(), model_dim, num_tokens,
-               activations.bf_pre_ffw_rms_out.Packed(), pool);
+               activations.pre_ffw_rms_out.Packed(), pool);
   const size_t kFFHiddenDim = config.ff_hidden_dim;
   for (size_t pos = 0; pos < num_tokens; ++pos) {
     MatVec(weights.gating_einsum_w, 0, kFFHiddenDim * 2, model_dim,
-           activations.bf_pre_ffw_rms_out.Packed() + pos * model_dim,
+           activations.pre_ffw_rms_out.Packed() + pos * model_dim,
            activations.ffw_hidden.Packed() + pos * kFFHiddenDim * 2, pool);
   }
   for (size_t pos = 0; pos < num_tokens; ++pos) {

@@ -233,7 +233,7 @@ float CrossEntropyLossForwardPass(const std::vector<int>& prompt,
                                   size_t context_size,
                                   const ModelWeightsPtrs<T>& weights,
                                   ForwardPass<float>& forward,
-                                  const RowVectorBatch<float>& inv_timescale,
+                                  const MatStorageT<float>& inv_timescale,
                                   hwy::ThreadPool& pool) {
   const ModelConfig& config = weights.weights_config;
   const size_t vocab_size = config.vocab_size;
@@ -38,7 +38,7 @@ namespace HWY_NAMESPACE {
 float CrossEntropyLossForwardPassT(const Prompt& prompt,
                                    const ModelWeightsPtrs<float>& weights,
                                    ForwardPass<float>& forward,
-                                   RowVectorBatch<float>& inv_timescale,
+                                   MatStorageT<float>& inv_timescale,
                                    hwy::ThreadPool& pool) {
   return CrossEntropyLossForwardPass(prompt.tokens, prompt.context_size,
                                      weights, forward, inv_timescale, pool);

@@ -56,7 +56,7 @@ HWY_EXPORT(CrossEntropyLossForwardPassT);
 float CrossEntropyLossForwardPass(const Prompt& prompt,
                                   const ModelWeightsPtrs<float>& weights,
                                   ForwardPass<float>& forward,
-                                  RowVectorBatch<float>& inv_timescale,
+                                  MatStorageT<float>& inv_timescale,
                                   hwy::ThreadPool& pool) {
   return HWY_DYNAMIC_DISPATCH(CrossEntropyLossForwardPassT)(
       prompt, weights, forward, inv_timescale, pool);

@@ -19,7 +19,7 @@
 #include "backprop/activations.h"
 #include "backprop/prompt.h"
 #include "gemma/weights.h"
-#include "util/allocator.h"
+#include "util/mat.h"
 #include "hwy/contrib/thread_pool/thread_pool.h"

 namespace gcpp {

@@ -27,7 +27,7 @@ namespace gcpp {
 float CrossEntropyLossForwardPass(const Prompt& prompt,
                                   const ModelWeightsPtrs<float>& weights,
                                   ForwardPass<float>& forward,
-                                  RowVectorBatch<float>& inv_timescale,
+                                  MatStorageT<float>& inv_timescale,
                                   hwy::ThreadPool& pool);

 }  // namespace gcpp
@@ -219,12 +219,11 @@ void ApplyLayer(const LayerWeightsPtrs<T>& weights,

   RMSNormT(weights.pre_ffw_norm_scale.Packed(),
            activations.attention_out.Packed(),
-           activations.bf_pre_ffw_rms_out.Packed(), model_dim, num_tokens);
+           activations.pre_ffw_rms_out.Packed(), model_dim, num_tokens);

   MatMulT(weights.gating_einsum_w.Packed(),
-          activations.bf_pre_ffw_rms_out.Packed(),
-          activations.ffw_hidden.Packed(), ff_hidden_dim * 2, model_dim,
-          num_tokens);
+          activations.pre_ffw_rms_out.Packed(), activations.ffw_hidden.Packed(),
+          ff_hidden_dim * 2, model_dim, num_tokens);

   GatedGelu(activations.ffw_hidden.Packed(),
             activations.ffw_hidden_gated.Packed(), ff_hidden_dim, num_tokens);
@@ -62,9 +62,9 @@ TEST(OptimizeTest, GradientDescent) {
   grad_m.ZeroInit();
   grad_v.ZeroInit();
   ForwardPass<float> forward(config), backward(config);
-  KVCache kv_cache = KVCache::Create(config, /*prefill_tbatch_size=*/16);
+  KVCache kv_cache(config, /*prefill_tbatch_size=*/16);

-  RowVectorBatch<float> inv_timescale = CreateInvTimescale(
+  MatStorageT<float> inv_timescale = CreateInvTimescale(
       allocator, config.layer_configs[0].qkv_dim,
       config.layer_configs[0].post_qk == PostQKType::HalfRope);

@@ -147,7 +147,7 @@ TEST(OptimizeTest, GradientDescent) {
   printf("Num steps: %zu\n", steps);
   printf("Final weights:\n");
   gemma.MutableWeights().LogWeightStatsF32();
-  EXPECT_LT(steps, 50);
+  EXPECT_LT(steps, 80);
   EXPECT_EQ(num_ok, kBatchSize);
 }
@@ -23,6 +23,7 @@
 #include <stdio.h>

+#include <algorithm>  // std::shuffle
 #include <array>
 #include <random>

 #include "compression/distortion.h"

@@ -104,7 +105,7 @@ struct TestPlateaus {
       HWY_ASSERT(-0.5f <= in[i] && in[i] < 0.5f);
     }

-    std::random_device rd;
+    std::random_device rd;  // NOLINT
     std::mt19937 rng(rd());
     std::shuffle(in.get(), in.get() + kGroupSize, rng);

@@ -151,7 +152,7 @@ struct TestRamp {
       HWY_ASSERT(-0.45f <= in[i] && in[i] < 0.55f);
     }

-    std::random_device rd;
+    std::random_device rd;  // NOLINT
     std::mt19937 rng(rd());
     std::shuffle(in.get(), in.get() + kGroupSize, rng);

@@ -246,7 +247,8 @@ struct TestOffset {
     auto in = hwy::AllocateAligned<float>(total);  // Enc() requires f32
     auto dec1 = hwy::AllocateAligned<T>(total);
     auto dec2 = hwy::AllocateAligned<T>(kMidLen);
-    auto nuq = hwy::AllocateAligned<NuqStream>(NuqStream::PackedEnd(total));
+    auto nuq = hwy::AllocateAligned<NuqStream>(
+        hwy::RoundUpTo(NuqStream::PackedEnd(total), hwy::VectorBytes()));
     HWY_ASSERT(in && dec1 && dec2 && nuq);
     const auto nuq_span = MakeSpan(nuq.get(), total);

@@ -296,7 +298,8 @@ struct TestUnalignedOffset {

     auto in = hwy::AllocateAligned<float>(total);  // Enc() requires f32
     auto dec1 = hwy::AllocateAligned<T>(total);
-    auto nuq = hwy::AllocateAligned<NuqStream>(NuqStream::PackedEnd(total));
+    auto nuq = hwy::AllocateAligned<NuqStream>(
+        hwy::RoundUpTo(NuqStream::PackedEnd(total), hwy::VectorBytes()));
     auto dec2 = hwy::AllocateAligned<T>(num_decompressed);
     HWY_ASSERT(in && dec1 && dec2 && nuq);
     const auto nuq_span = MakeSpan(nuq.get(), total);

@@ -347,7 +350,8 @@ struct TestDec2 {
     auto dec0 = hwy::AllocateAligned<T>(total);
     auto dec1 = hwy::AllocateAligned<T>(total);
     auto dec2 = hwy::AllocateAligned<T>(kMidLen);
-    auto nuq = hwy::AllocateAligned<NuqStream>(NuqStream::PackedEnd(total));
+    auto nuq = hwy::AllocateAligned<NuqStream>(
+        hwy::RoundUpTo(NuqStream::PackedEnd(total), hwy::VectorBytes()));
     HWY_ASSERT(in && dec0 && dec1 && dec2 && nuq);
     const auto nuq_span = MakeSpan(nuq.get(), total);

@@ -449,7 +453,8 @@ struct TestEncDec {
     const size_t num = 4 * kGroupSize;
     auto in = hwy::AllocateAligned<float>(num);  // Enc() requires f32
     auto out = hwy::AllocateAligned<T>(num);     // already padded
-    auto nuq = hwy::AllocateAligned<NuqStream>(NuqStream::PackedEnd(num));
+    auto nuq = hwy::AllocateAligned<NuqStream>(
+        hwy::RoundUpTo(NuqStream::PackedEnd(num), hwy::VectorBytes()));
     HWY_ASSERT(in && out && nuq);
     const auto nuq_span = MakeSpan(nuq.get(), num);
@@ -164,11 +164,11 @@ constexpr bool IsNuqStream() {
 // weights for a model, but can be used for other purposes, such as types for
 // `WeightsPtrs`. When adding a new type that is supported, also
 // update gemma.cc, weights.*, and add instantiations/new_one.cc.
-enum class Type { kUnknown, kF32, kBF16, kSFP, kNUQ, kF64, kC64, kU128 };
+enum class Type { kUnknown, kF32, kBF16, kSFP, kNUQ, kF64, kC64 };
 // These are used in `ModelConfig.Specifier`, hence the strings will not
 // change, though new ones may be added.
-static constexpr const char* kTypeStrings[] = {
-    "unknown", "f32", "bf16", "sfp", "nuq", "f64", "c64", "u128"};
+static constexpr const char* kTypeStrings[] = {"unknown", "f32", "bf16", "sfp",
+                                               "nuq", "f64", "c64"};
 static constexpr size_t kNumTypes =
     sizeof(kTypeStrings) / sizeof(kTypeStrings[0]);
 static constexpr size_t kTypeBits[] = {0,

@@ -177,8 +177,7 @@ static constexpr size_t kTypeBits[] = {0,
                                        8 * sizeof(SfpStream),
                                        4 /* NuqStream, actually 4.5 */,
                                        8 * sizeof(double),
-                                       8 * sizeof(std::complex<double>),
-                                       8 * sizeof(hwy::uint128_t)};
+                                       8 * sizeof(std::complex<double>)};

 static inline bool EnumValid(Type type) {
   return static_cast<size_t>(type) < kNumTypes;

@@ -200,8 +199,6 @@ Type TypeEnum() {
     return Type::kF64;
   } else if constexpr (hwy::IsSame<Packed, std::complex<double>>()) {
     return Type::kC64;
-  } else if constexpr (hwy::IsSame<Packed, hwy::uint128_t>()) {
-    return Type::kU128;
   } else {
     HWY_DASSERT(false);
     return Type::kUnknown;
@@ -73,7 +73,7 @@ int BenchmarkCrossEntropy(GemmaEnv& env, const Path& text,
     size_t num_tokens = std::min<size_t>(prompt.size() - pos, batch_tokens);
     std::vector<int> prompt_slice(prompt.begin() + pos,
                                   prompt.begin() + pos + num_tokens);
-    KVCache kv_cache = KVCache::Create(env.GetGemma()->GetModelConfig(),
+    KVCache kv_cache(env.GetGemma()->GetModelConfig(),
                      env.MutableConfig().prefill_tbatch_size);
     float entropy = ComputeCrossEntropy(
         *env.GetGemma(), num_tokens, prompt_slice, kv_cache, env.Verbosity());
@@ -52,9 +52,8 @@ GemmaEnv::GemmaEnv(const LoaderArgs& loader,
                    const InferenceArgs& inference)
     : env_(MakeMatMulEnv(threading_args)), gemma_(loader, env_) {
   // Only allocate one for starters because GenerateBatch might not be called.
-  kv_caches_.resize(1);
-  kv_caches_[0] =
-      KVCache::Create(gemma_.GetModelConfig(), inference.prefill_tbatch_size);
+  kv_caches_.push_back(
+      KVCache(gemma_.GetModelConfig(), inference.prefill_tbatch_size));

   InitGenerator(inference, gen_);

@@ -131,15 +130,10 @@ std::vector<QueryResult> GemmaEnv::BatchQueryModel(
                          runtime_config_.decode_qbatch_size);
   }

-  // Ensure we have one KVCache per query.
-  if (kv_caches_.size() < num_queries) {
-    kv_caches_.resize(num_queries);
-  }
-  for (size_t i = 1; i < num_queries; ++i) {
-    if (kv_caches_[i].seq_len == 0) {
-      kv_caches_[i] = KVCache::Create(gemma_.GetModelConfig(),
-                                      runtime_config_.prefill_tbatch_size);
-    }
+  // Ensure we have at least one KVCache per query.
+  while (kv_caches_.size() < num_queries) {
+    kv_caches_.push_back(
+        KVCache(gemma_.GetModelConfig(), runtime_config_.prefill_tbatch_size));
   }

   gcpp::TimingInfo timing_info = {.verbosity = runtime_config_.verbosity};
@@ -53,8 +53,7 @@ int main(int argc, char** argv) {
   // Instantiate model and KV Cache
   gcpp::MatMulEnv env(MakeMatMulEnv(threading));
   gcpp::Gemma gemma(loader, env);
-  gcpp::KVCache kv_cache = gcpp::KVCache::Create(gemma.GetModelConfig(),
-                                                 inference.prefill_tbatch_size);
+  gcpp::KVCache kv_cache(gemma.GetModelConfig(), inference.prefill_tbatch_size);
   size_t generated = 0;

   // Initialize random number generator
@@ -39,11 +39,8 @@ class SimplifiedGemma {
         threading_(threading),
         inference_(inference),
         env_(MakeMatMulEnv(threading_)),
-        gemma_(loader_, env_) {
-    // Instantiate model and KV Cache
-    kv_cache_ = gcpp::KVCache::Create(gemma_.GetModelConfig(),
-                                      inference_.prefill_tbatch_size);
-
+        gemma_(loader_, env_),
+        kv_cache_(gemma_.GetModelConfig(), inference_.prefill_tbatch_size) {
     // Initialize random number generator
     std::random_device rd;
     gen_.seed(rd());
@@ -23,106 +23,127 @@
 #include "ops/ops.h"        // CreateInvTimescale
 #include "util/allocator.h"  // Allocator
 #include "util/basics.h"     // BF16
-#include "util/mat.h"        // RowVectorBatch
+#include "util/mat.h"        // MatStorageT

 namespace gcpp {

 struct Activations {
-  explicit Activations(const ModelConfig& config)
+  Activations(const ModelConfig& config, size_t batch_size, MatMulEnv* env)
       : weights_config(config),
         layer_config(config.layer_configs[0]),
         seq_len(config.seq_len),
-        cache_pos_size(config.CachePosSize()) {}
+        cache_pos_size(config.CachePosSize()),
+        is_griffin(layer_config.type ==
+                   LayerAttentionType::kGriffinRecurrentBlock),

-  RowVectorBatch<float> x;  // input
-  RowVectorBatch<float> q;  // query, also KV if MHA.
-  RowVectorBatch<float> logits;
+        x("x", Extents2D(batch_size, config.model_dim), pad_),
+        q("q",
+          Extents2D(batch_size, layer_config.heads * layer_config.QStride()),
+          pad_),
+        logits("logits", Extents2D(batch_size, config.vocab_size), pad_),

-  // Attention
-  RowVectorBatch<float> pre_att_rms_out;
-  RowVectorBatch<float> att;      // attention vector
-  RowVectorBatch<float> att_out;  // attention output
-  // Accumulation of attention outputs over heads
-  RowVectorBatch<float> att_sums;
+        pre_att_rms_out("pre_att_rms_out",
+                        Extents2D(batch_size, config.model_dim), pad_),
+        att("att", Extents2D(batch_size, layer_config.heads * config.seq_len),
+            pad_),
+        att_out(
+            "att_out",
+            Extents2D(batch_size, layer_config.heads * layer_config.qkv_dim),
+            pad_),
+        att_sums("att_sums", Extents2D(batch_size, config.model_dim), pad_),

-  // Gated FFW
-  RowVectorBatch<BF16> bf_pre_ffw_rms_out;
-  RowVectorBatch<float> C1;
-  RowVectorBatch<float> C2;
-  RowVectorBatch<float> ffw_out;
+        pre_ffw_rms_out("pre_ffw_rms_out",
+                        Extents2D(batch_size, config.model_dim), pad_),
+        C1("C1", Extents2D(batch_size, layer_config.ff_hidden_dim), pad_),
+        C2("C2", Extents2D(batch_size, layer_config.ff_hidden_dim), pad_),
+        ffw_out("ffw_out", Extents2D(batch_size, config.model_dim), pad_),

-  // Griffin
-  RowVectorBatch<float> griffin_x;
-  RowVectorBatch<float> griffin_y;
-  RowVectorBatch<float> griffin_gate_x;
-  RowVectorBatch<float> griffin_multiplier;
+        // No padding for Griffin because it does not always use Row().
+        griffin_x("griffin_x",
+                  is_griffin ? Extents2D(batch_size, config.model_dim) : none_,
+                  MatPadding::kPacked),
+        griffin_y("griffin_y",
+                  is_griffin ? Extents2D(batch_size, config.model_dim) : none_,
+                  MatPadding::kPacked),
+        griffin_gate_x(
+            "griffin_gate_x",
+            is_griffin ? Extents2D(batch_size, config.model_dim) : none_,
+            MatPadding::kPacked),
+        griffin_multiplier(
+            "griffin_mul",
+            is_griffin ? Extents2D(batch_size, config.model_dim) : none_,
+            MatPadding::kPacked),

-  // Rope
-  RowVectorBatch<float> inv_timescale;
-  RowVectorBatch<float> inv_timescale_global;
+        inv_timescale(
+            CreateInvTimescale(env->ctx.allocator, layer_config.qkv_dim,
+                               layer_config.post_qk == PostQKType::HalfRope)),
+        inv_timescale_global(CreateInvTimescale(
+            env->ctx.allocator, layer_config.qkv_dim,
+            layer_config.post_qk == PostQKType::HalfRope, 1000000.0)),

-  // Dynamic because no default ctor and only initialized in `Allocate`.
-  MatMulEnv* env;
+        env(env) {
+    HWY_ASSERT(batch_size != 0);
+  }

+  void SetBatchSize(size_t batch_size) {
+    x.OverrideRows(batch_size);
+    q.OverrideRows(batch_size);
+    logits.OverrideRows(batch_size);

+    pre_att_rms_out.OverrideRows(batch_size);
+    att.OverrideRows(batch_size);
+    att_out.OverrideRows(batch_size);
+    att_sums.OverrideRows(batch_size);

+    pre_ffw_rms_out.OverrideRows(batch_size);
+    C1.OverrideRows(batch_size);
+    C2.OverrideRows(batch_size);
+    ffw_out.OverrideRows(batch_size);

+    if (is_griffin) {
+      griffin_x.OverrideRows(batch_size);
+      griffin_y.OverrideRows(batch_size);
+      griffin_gate_x.OverrideRows(batch_size);
+      griffin_multiplier.OverrideRows(batch_size);
+    }
+  }

   PostQKType post_qk = PostQKType::Rope;
   // And the config.
   const ModelConfig& weights_config;
   const LayerConfig& layer_config;
   size_t seq_len;
-  size_t cache_pos_size = 0;
+  size_t cache_pos_size = 0;  // TODO: after moving KVCache to MatStorageT.
+  bool is_griffin = false;
+  const Extents2D none_ = Extents2D();
+  const MatPadding pad_ = MatPadding::kOdd;

-  void Allocate(size_t batch_size, MatMulEnv* env) {
-    const Allocator& allocator = env->ctx.allocator;
+  MatStorageT<float> x;  // input
+  MatStorageT<float> q;  // query, also KV if MHA.
+  MatStorageT<float> logits;

-    post_qk = layer_config.post_qk;
-    const size_t model_dim = weights_config.model_dim;
-    const size_t ff_hidden_dim = layer_config.ff_hidden_dim;
-    const size_t vocab_size = weights_config.vocab_size;
-    const size_t qkv_dim = layer_config.qkv_dim;
-    const size_t heads = layer_config.heads;
+  // Attention
+  MatStorageT<float> pre_att_rms_out;
+  MatStorageT<float> att;      // attention vector
+  MatStorageT<float> att_out;  // attention output
+  // Accumulation of attention outputs over heads
+  MatStorageT<float> att_sums;

-    x = RowVectorBatch<float>(allocator, Extents2D(batch_size, model_dim));
-    q = RowVectorBatch<float>(
-        allocator, Extents2D(batch_size, heads * layer_config.QStride()));
-    if (vocab_size > 0) {
-      logits =
-          RowVectorBatch<float>(allocator, Extents2D(batch_size, vocab_size));
-    }
+  // Gated FFW
+  MatStorageT<BF16> pre_ffw_rms_out;
+  MatStorageT<float> C1;
+  MatStorageT<float> C2;
+  MatStorageT<float> ffw_out;

-    pre_att_rms_out =
-        RowVectorBatch<float>(allocator, Extents2D(batch_size, model_dim));
-    att = RowVectorBatch<float>(
-        allocator, Extents2D(batch_size, heads * weights_config.seq_len));
-    att_out = RowVectorBatch<float>(allocator,
-                                    Extents2D(batch_size, heads * qkv_dim));
-    att_sums =
-        RowVectorBatch<float>(allocator, Extents2D(batch_size, model_dim));
+  // Griffin
+  MatStorageT<float> griffin_x;
+  MatStorageT<float> griffin_y;
+  MatStorageT<float> griffin_gate_x;
+  MatStorageT<float> griffin_multiplier;

-    bf_pre_ffw_rms_out =
-        RowVectorBatch<BF16>(allocator, Extents2D(batch_size, model_dim));
-    C1 = RowVectorBatch<float>(allocator, Extents2D(batch_size, ff_hidden_dim));
-    C2 = RowVectorBatch<float>(allocator, Extents2D(batch_size, ff_hidden_dim));
-    ffw_out =
-        RowVectorBatch<float>(allocator, Extents2D(batch_size, model_dim));
+  // Rope
+  MatStorageT<float> inv_timescale;
+  MatStorageT<float> inv_timescale_global;

-    if (layer_config.type == LayerAttentionType::kGriffinRecurrentBlock) {
-      griffin_x =
-          RowVectorBatch<float>(allocator, Extents2D(batch_size, model_dim));
-      griffin_y =
-          RowVectorBatch<float>(allocator, Extents2D(batch_size, model_dim));
-      griffin_gate_x =
-          RowVectorBatch<float>(allocator, Extents2D(batch_size, model_dim));
-      griffin_multiplier =
-          RowVectorBatch<float>(allocator, Extents2D(batch_size, model_dim));
-    }
-
-    inv_timescale = CreateInvTimescale(allocator, layer_config.qkv_dim,
-                                       post_qk == PostQKType::HalfRope);
-    inv_timescale_global = CreateInvTimescale(
-        allocator, qkv_dim, post_qk == PostQKType::HalfRope, 1000000.0);
-
-    this->env = env;
-  }
+  MatMulEnv* env;
 };

 }  // namespace gcpp
@@ -46,8 +46,7 @@ ConversationData::ConversationData(const ModelConfig& model_config,
                                    size_t prefill_tbatch_size)
     : model_config_ref_(model_config),
       prefill_tbatch_size_(prefill_tbatch_size),
-      kv_cache(std::make_unique<KVCache>(
-          KVCache::Create(model_config, prefill_tbatch_size))),
+      kv_cache(std::make_unique<KVCache>(model_config, prefill_tbatch_size)),
       abs_pos(0) {}

 // ConversationData copy constructor implementation
@@ -184,25 +183,28 @@ int GemmaContext::GenerateInternal(const char* prompt_string,
   inference_args.CopyTo(runtime_config);
   size_t prefix_end = 0;

+  const ModelConfig& model_config = model.GetModelConfig();
+
   // generate
   std::vector<int> prompt;
-  ImageTokens image_tokens;
+  const size_t pool_dim = model_config.vit_config.pool_dim;
+  ImageTokens image_tokens(
+      "image_tokens",
+      image_data
+          ? Extents2D(model_config.vit_config.seq_len / (pool_dim * pool_dim),
+                      model_config.model_dim)
+          : Extents2D(0, 0),
+      MatPadding::kOdd);
   if (image_data != nullptr) {
-    size_t pool_dim = model.GetModelConfig().vit_config.pool_dim;
-    image_tokens =
-        ImageTokens(model.Env().ctx.allocator,
-                    Extents2D(model.GetModelConfig().vit_config.seq_len /
-                                  (pool_dim * pool_dim),
-                              model.GetModelConfig().model_dim));
-    HWY_ASSERT(model.GetModelConfig().wrapping == PromptWrapping::PALIGEMMA ||
-               model.GetModelConfig().wrapping == PromptWrapping::GEMMA_VLM);
+    HWY_ASSERT(model_config.wrapping == PromptWrapping::PALIGEMMA ||
+               model_config.wrapping == PromptWrapping::GEMMA_VLM);

     Image image;
     image.Set(image_width, image_height, static_cast<const float*>(image_data));

     // We may need to resize the supplied image depending on whether we're using
     // PaliGemma or Gemma 3.
-    const size_t image_size = model.GetModelConfig().vit_config.image_size;
+    const size_t image_size = model_config.vit_config.image_size;
     image.Resize(image_size, image_size);

     // Use the existing runtime_config defined earlier in the function.

@@ -217,10 +219,9 @@ int GemmaContext::GenerateInternal(const char* prompt_string,
     ss << static_cast<int>(image_tokens_duration * 1000) << " ms\n",
         LogDebug(ss.str().c_str());

-    prompt = WrapAndTokenize(model.Tokenizer(), model.ChatTemplate(),
-                             model.GetModelConfig().wrapping,
-                             active_conversation->abs_pos, prompt_string,
-                             image_tokens.BatchSize());
+    prompt = WrapAndTokenize(
+        model.Tokenizer(), model.ChatTemplate(), model_config.wrapping,
+        active_conversation->abs_pos, prompt_string, image_tokens.Rows());
     runtime_config.image_tokens = &image_tokens;
     prompt_size = prompt.size();
     // The end of the prefix for prefix-LM style attention in Paligemma.

@@ -230,7 +231,7 @@ int GemmaContext::GenerateInternal(const char* prompt_string,
     // Text-only case (original logic)
     // Use abs_pos from the active conversation
     prompt = WrapAndTokenize(model.Tokenizer(), model.ChatTemplate(),
-                             model.GetModelConfig().wrapping,
+                             model_config.wrapping,
                              active_conversation->abs_pos, prompt_string);
     prompt_size = prompt.size();
   }

@@ -251,7 +252,7 @@ int GemmaContext::GenerateInternal(const char* prompt_string,

   // prepare for next turn
   if (!inference_args.multiturn ||
-      model.GetModelConfig().wrapping == PromptWrapping::PALIGEMMA) {
+      model_config.wrapping == PromptWrapping::PALIGEMMA) {
     // If not multiturn, or Paligemma (which handles turns differently),
     // reset the *active* conversation's position.
     active_conversation->abs_pos = 0;
@@ -188,8 +188,8 @@ class GemmaContext {
       // rewind to initial state.
       active_conversation->abs_pos = 0;
       // Replace the cache within the current ConversationData object
-      active_conversation->kv_cache = std::make_unique<KVCache>(KVCache::Create(
-          model.GetModelConfig(), inference_args.prefill_tbatch_size));
+      active_conversation->kv_cache = std::make_unique<KVCache>(
+          model.GetModelConfig(), inference_args.prefill_tbatch_size);

       LogDebug((log_prefix + "Successfully rewound to initial state.").c_str());
     } else {
@@ -89,11 +89,11 @@ HWY_NOINLINE void GriffinRecurrent(size_t batch_start, size_t num_tokens,

   // X / Y linear layers.
   for (size_t batch_idx = 0; batch_idx < num_tokens; ++batch_idx) {
-    float* HWY_RESTRICT y = activations.griffin_y.Batch(batch_idx);
-    float* HWY_RESTRICT x = activations.griffin_x.Batch(batch_idx);
+    float* HWY_RESTRICT y = activations.griffin_y.Row(batch_idx);
+    float* HWY_RESTRICT x = activations.griffin_x.Row(batch_idx);
     TwoMatVecAdd(layer_weights->griffin.linear_x_w,
                  layer_weights->griffin.linear_y_w, 0, model_dim, model_dim,
-                 activations.pre_att_rms_out.Batch(batch_idx),
+                 activations.pre_att_rms_out.Row(batch_idx),
                  /*add0=*/layer_weights->griffin.linear_x_biases.PackedScale1(),
                  /*add1=*/layer_weights->griffin.linear_y_biases.PackedScale1(),
                  /*out0=*/x, /*out1=*/y, pool);

@@ -103,17 +103,16 @@ HWY_NOINLINE void GriffinRecurrent(size_t batch_start, size_t num_tokens,
   // Conv1D.
   for (size_t batch_idx = 0; batch_idx < num_tokens; ++batch_idx) {
     const size_t pos = batch_start + batch_idx;
-    float* HWY_RESTRICT x = activations.griffin_x.Batch(batch_idx);
+    float* HWY_RESTRICT x = activations.griffin_x.Row(batch_idx);
     HWY_FULL(float) df;
     HWY_DASSERT(model_dim % hn::Lanes(df) == 0);
-    const size_t layer_offset = layer * model_dim * (conv_1d_width - 1);

     // cache[i] = input at time t-i.
     float* HWY_RESTRICT cache[kMaxConv1DWidth];
     cache[0] = x;
     for (size_t i = 1; i < conv_1d_width; i++) {
       cache[i] =
-          kv_cache.conv1d_cache.get() + layer_offset +
+          kv_cache.conv1d_cache.Row(layer) +
           ((pos + conv_1d_width - 1 - i) % (conv_1d_width - 1)) * model_dim;
     }
     for (size_t i = 0; i < model_dim; i += hn::Lanes(df)) {

@@ -140,12 +139,11 @@ HWY_NOINLINE void GriffinRecurrent(size_t batch_start, size_t num_tokens,
   // RGLRU
   for (size_t batch_idx = 0; batch_idx < num_tokens; ++batch_idx) {
     const size_t pos = batch_start + batch_idx;
-    float* HWY_RESTRICT y = activations.griffin_y.Batch(batch_idx);
-    float* HWY_RESTRICT x = activations.griffin_x.Batch(batch_idx);
-    float* HWY_RESTRICT gate_x = activations.griffin_gate_x.Batch(batch_idx);
-    float* HWY_RESTRICT a = activations.griffin_multiplier.Batch(batch_idx);
-    float* HWY_RESTRICT rnn_state =
-        kv_cache.rglru_cache.get() + layer * model_dim;
+    float* HWY_RESTRICT y = activations.griffin_y.Row(batch_idx);
+    float* HWY_RESTRICT x = activations.griffin_x.Row(batch_idx);
+    float* HWY_RESTRICT gate_x = activations.griffin_gate_x.Row(batch_idx);
+    float* HWY_RESTRICT a = activations.griffin_multiplier.Row(batch_idx);
+    float* HWY_RESTRICT rnn_state = kv_cache.rglru_cache.Row(layer);

     pool.Run(0, heads, [&](const uint64_t head, size_t /*thread*/) HWY_ATTR {
       const size_t kHeadDim = model_dim / heads;

@@ -193,8 +191,8 @@ HWY_NOINLINE void GriffinRecurrent(size_t batch_start, size_t num_tokens,

   // Final linear layer.
   for (size_t batch_idx = 0; batch_idx < num_tokens; ++batch_idx) {
-    float* HWY_RESTRICT x = activations.griffin_x.Batch(batch_idx);
-    float* out_ptr = activations.att_sums.Batch(batch_idx);
+    float* HWY_RESTRICT x = activations.griffin_x.Row(batch_idx);
+    float* out_ptr = activations.att_sums.Row(batch_idx);
     MatVecAdd(layer_weights->griffin.linear_out_w, 0, model_dim, model_dim, x,
               layer_weights->griffin.linear_out_biases.PackedScale1(), out_ptr,
               pool);

@@ -217,7 +215,7 @@ class GemmaAttention {
                                  const float mul) {
     // qk is either q or k, so qkv_dim is the length we operate on.
     const size_t qkv_dim = layer_config_.qkv_dim;
-    const float* inv_timescale = activations_.inv_timescale.Const();
+    const float* inv_timescale = activations_.inv_timescale.Packed();
     bool is_global_layer =
         activations_.weights_config.attention_window_sizes[layer] ==
         activations_.seq_len;

@@ -227,7 +225,7 @@ class GemmaAttention {
          activations_.weights_config.model == Model::GEMMA3_12B ||
          activations_.weights_config.model == Model::GEMMA3_27B ||
          activations_.weights_config.model == Model::GEMMA3_1B)) {
-      inv_timescale = activations_.inv_timescale_global.Const();
+      inv_timescale = activations_.inv_timescale_global.Packed();
     }
     // PostQKType::Rope
     (void)layer;
@@ -249,11 +247,10 @@ class GemmaAttention {
     const size_t heads = layer_config_.heads;
     const size_t kv_heads = layer_config_.kv_heads;

-    const auto pre_att_rms_out =
-        ConstMatFromBatch(num_interleaved, activations_.pre_att_rms_out);
-    auto w_q1 = layer_weights_.qkv_einsum_w.HasPtr()
-                    ? ConstMatFromWeights(layer_weights_.qkv_einsum_w)
-                    : ConstMatFromWeights(layer_weights_.qkv_einsum_w1);
+    using WeightT = typename decltype(layer_weights_.qkv_einsum_w)::T;
+    ConstMat<WeightT> w_q1(layer_weights_.qkv_einsum_w.HasPtr()
+                               ? layer_weights_.qkv_einsum_w
+                               : layer_weights_.qkv_einsum_w1);
     // The original qkv_einsum_w has shape [(heads + kv_heads * 2), kKQVDim,
     // model_dim], which we reshaped to (heads + kv_heads * 2) * kKQVDim rows.
     // We must shrink to the actual size because MatMul verifies

@@ -262,20 +259,19 @@ class GemmaAttention {
     // computed in the second MatMul.
     const size_t w1_rows = heads * layer_config_.QStride();
     w_q1.ShrinkRows(w1_rows);
-    MatMul(pre_att_rms_out, w_q1,
+    MatMul(activations_.pre_att_rms_out, w_q1,
            /*add=*/nullptr, *activations_.env,
-           RowPtrFromBatch(allocator_, activations_.q));
+           RowPtrFromMat(allocator_, activations_.q));

     if (is_mha_) {
       // Multi-Head Attention a.k.a. "use_qkv_einsum" computed QKV already.
     } else {
-      decltype(w_q1) w_q2;
+      decltype(w_q1) w_q2(layer_weights_.qkv_einsum_w.HasPtr()
+                              ? layer_weights_.qkv_einsum_w
+                              : layer_weights_.qkv_einsum_w2);
       if (layer_weights_.qkv_einsum_w.HasPtr()) {
-        w_q2 = ConstMatFromWeights(layer_weights_.qkv_einsum_w);
         // Skip first half of the matrix.
         w_q2.ofs = w_q2.Row(w1_rows);
-      } else {
-        w_q2 = ConstMatFromWeights(layer_weights_.qkv_einsum_w2);
       }
       // KV structure is [k, v, k, v, ....] = kv_heads pairs of (k, v).
       const size_t w_rows_kv_cols = kv_heads * 2 * qkv_dim;

@@ -290,13 +286,13 @@ class GemmaAttention {
       float* HWY_RESTRICT kv = kv_caches_[0].kv_cache.get() + kv_ofs;
       RowPtrF kv_rows(allocator_, kv, w_rows_kv_cols);
       kv_rows.SetStride(cache_pos_size_);
-      MatMul(pre_att_rms_out, w_q2,
+      MatMul(activations_.pre_att_rms_out, w_q2,
              /*add=*/nullptr, *activations_.env, kv_rows);
     } else {
       // Proceed row by row because there will be wraparound.
       for (size_t interleaved_idx = 0; interleaved_idx < num_interleaved;
            ++interleaved_idx) {
-        const float* x = activations_.pre_att_rms_out.Batch(interleaved_idx);
+        const float* x = activations_.pre_att_rms_out.Row(interleaved_idx);
         const size_t query_idx = interleaved_idx % num_queries_;
         const size_t batch_idx = interleaved_idx / num_queries_;
         KVCache& kv_cache = kv_caches_[query_idx];
@@ -327,15 +323,15 @@ class GemmaAttention {
           // If MHA, copy computed K and V into KVCache.
           if (is_mha_) {
             const float* HWY_RESTRICT mha_kv =
-                activations_.q.Batch(interleaved_idx) + head * q_stride_ +
+                activations_.q.Row(interleaved_idx) + head * q_stride_ +
                 qkv_dim;
             hwy::CopyBytes(mha_kv, kv, 2 * qkv_dim * sizeof(*kv));
           }

           // Apply further processing to K.
           if (layer_weights_.key_norm_scale.HasPtr()) {
-            RMSNormInplace(layer_weights_.key_norm_scale.Row(0), kv,
-                           qkv_dim);
+            RMSNormInplace(layer_weights_.key_norm_scale.PackedScale1(),
+                           0, kv, qkv_dim);
           }
           PositionalEncodingQK(kv, pos, layer_, /*mul=*/1.0f);
         });

@@ -402,7 +398,8 @@ class GemmaAttention {

       // Apply rope and scaling to Q.
       if (layer_weights_.query_norm_scale.HasPtr()) {
-        RMSNormInplace(layer_weights_.query_norm_scale.Row(0), q, qkv_dim);
+        RMSNormInplace(layer_weights_.query_norm_scale.PackedScale1(), 0, q,
+                       qkv_dim);
       }
       PositionalEncodingQK(q, pos, layer_, query_scale);
@@ -435,13 +432,12 @@ class GemmaAttention {
         const size_t head_offset = (head / kHeadGroups) * qkv_dim * 2;

         float* HWY_RESTRICT q =
-            activations_.q.Batch(interleaved_idx) + head * q_stride_;
+            activations_.q.Row(interleaved_idx) + head * q_stride_;
         float* HWY_RESTRICT att =
-            activations_.att.Batch(interleaved_idx) +
+            activations_.att.Row(interleaved_idx) +
             head * activations_.seq_len;
         float* HWY_RESTRICT att_out =
-            activations_.att_out.Batch(interleaved_idx) +
-            head * qkv_dim;
+            activations_.att_out.Row(interleaved_idx) + head * qkv_dim;

         // Make strided views into the kv cache entries for the current
         // query and head.
@@ -476,28 +472,25 @@ class GemmaAttention {
  private:
  // Sums encoded (`att_out`) over num_heads (`layer_config_.heads`) and
  // head_dim (`qkv_dim`) into output (`layer_out`).
-  HWY_NOINLINE void SumHeads(const size_t num_interleaved) {
+  HWY_NOINLINE void SumHeads() {
    PROFILER_ZONE("Gen.Attention.SumHeads");
    // att_weights and att_out are concatenated heads, each of length
    // layer_config_.qkv_dim. Thus the [num_interleaved,
    // layer_config_.model_dim] matmul output is the sum over heads. Compare
    // gemma/modules.py: attn_output = self.attn_vec_einsum('BTNH,NHD->BTD',
    // encoded)
-    HWY_DASSERT(layer_config_.model_dim > 0);
-    HWY_DASSERT(layer_config_.heads > 0);
-    HWY_DASSERT(layer_config_.qkv_dim > 0);
+    HWY_DASSERT(layer_config_.model_dim != 0 && layer_config_.heads != 0 &&
+                layer_config_.qkv_dim != 0);
    HWY_DASSERT(layer_weights_.att_weights.HasPtr());
-    HWY_DASSERT(activations_.att_out.All() != nullptr);
-    HWY_DASSERT(activations_.att_sums.All() != nullptr);
+    HWY_DASSERT(activations_.att_out.HasPtr());
+    HWY_DASSERT(activations_.att_sums.HasPtr());

    const float* add =
        layer_weights_.layer_config.softmax_attn_output_biases
            ? layer_weights_.attention_output_biases.PackedScale1()
            : nullptr;
-    MatMul(ConstMatFromBatch(num_interleaved, activations_.att_out),
-           ConstMatFromWeights(layer_weights_.att_weights), add,
-           *activations_.env,
-           RowPtrFromBatch(allocator_, activations_.att_sums));
+    MatMul(activations_.att_out, layer_weights_.att_weights, add,
+           *activations_.env, RowPtrFromMat(allocator_, activations_.att_sums));
  }

 public:

@@ -524,7 +517,7 @@ class GemmaAttention {
    const size_t num_interleaved = num_tokens_ * num_queries_;
    ComputeQKV(num_interleaved);
    DotSoftmaxWeightedSum(num_interleaved);
-    SumHeads(num_interleaved);
+    SumHeads();
  }

 private:
@@ -618,12 +611,11 @@ class VitAttention {
  HWY_NOINLINE void ComputeQKV() {
    PROFILER_ZONE("Gen.VitAttention.QKV");
    auto& qkv = activations_.q;
-    HWY_ASSERT(qkv.BatchSize() == num_tokens_);
+    HWY_ASSERT(qkv.Rows() == num_tokens_);
    HWY_ASSERT(qkv.Cols() == layer_config_.heads * 3 * layer_config_.qkv_dim);
-    MatMul(ConstMatFromBatch(num_tokens_, activations_.pre_att_rms_out),
-           ConstMatFromWeights(layer_weights_.vit.qkv_einsum_w),
+    MatMul(activations_.pre_att_rms_out, layer_weights_.vit.qkv_einsum_w,
           layer_weights_.vit.qkv_einsum_b.PackedScale1(), *activations_.env,
-           RowPtrFromBatch(allocator_, qkv));
+           RowPtrFromMat(allocator_, qkv));
  }

  // TODO(philculliton): transition fully to MatMul.
@@ -635,52 +627,49 @@ class VitAttention {
    const float query_scale = 1.0f / sqrtf(static_cast<float>(qkv_dim));
    PROFILER_ZONE("Gen.VitAttention.DotSoftmax");

-    // Shift Q, K, VT to RowVectorBatches with AllocateAlignedRows(extents)
-    RowVectorBatch<float> Q =
-        AllocateAlignedRows<float>(allocator_, Extents2D(num_tokens_, qkv_dim));
-    RowVectorBatch<float> K =
-        AllocateAlignedRows<float>(allocator_, Extents2D(seq_len, qkv_dim));
-    RowVectorBatch<float> C(allocator_, Extents2D(num_tokens_, seq_len));
+    // Shift Q, K, VT to MatStorageT.
+    MatStorageT<float> Q("Q2", Extents2D(num_tokens_, qkv_dim),
+                         MatPadding::kPacked);
+    MatStorageT<float> K("K2", Extents2D(seq_len, qkv_dim),
+                         MatPadding::kPacked);
+    MatStorageT<float> C("C2", Extents2D(num_tokens_, seq_len),
+                         MatPadding::kPacked);

    // Initialize att_out to zero prior to head loop.
-    hwy::ZeroBytes(activations_.att_out.All(),
-                   num_tokens_ * heads * qkv_dim * sizeof(float));
+    ZeroInit(activations_.att_out);

    for (size_t head = 0; head < heads; ++head) {
      pool_.Run(0, num_tokens_, [&](uint64_t task, size_t /*thread*/) HWY_ATTR {
        const size_t token = task;
-        float* HWY_RESTRICT q =
-            activations_.q.Batch(token) + head * 3 * qkv_dim;
+        float* HWY_RESTRICT q = activations_.q.Row(token) + head * 3 * qkv_dim;
        // TODO: shift to MatMul with A.scale once MatMul is confirmed working
        MulByConst(query_scale, q, qkv_dim);
-        hwy::CopyBytes(q, Q.Batch(token), qkv_dim * sizeof(float));
+        hwy::CopyBytes(q, Q.Row(token), qkv_dim * sizeof(float));
      });

      pool_.Run(0, seq_len, [&](uint64_t task, size_t /*thread*/) HWY_ATTR {
        const size_t seq_idx = task;
        float* HWY_RESTRICT k =
-            activations_.q.Batch(seq_idx) + head * 3 * qkv_dim + qkv_dim;
-        hwy::CopyBytes(k, K.Batch(seq_idx), qkv_dim * sizeof(float));
+            activations_.q.Row(seq_idx) + head * 3 * qkv_dim + qkv_dim;
+        hwy::CopyBytes(k, K.Row(seq_idx), qkv_dim * sizeof(float));
      });

      // this produces C, a (num_tokens_, seq_len) matrix of dot products
-      MatMul(ConstMatFromBatch(Q.BatchSize(), Q),
-             ConstMatFromBatch(K.BatchSize(), K), nullptr, *activations_.env,
-             RowPtrFromBatch(allocator_, C));
+      MatMul(Q, K, nullptr, *activations_.env, RowPtrFromMat(allocator_, C));

      pool_.Run(0, num_tokens_, [&](uint64_t task, size_t /*thread*/) HWY_ATTR {
-        float* HWY_RESTRICT c = C.Batch(task);
+        float* HWY_RESTRICT c = C.Row(task);
        Softmax(c, C.Cols());
      });

      pool_.Run(0, num_tokens_, [&](uint64_t task, size_t /*thread*/) HWY_ATTR {
        size_t token = task;
        float* HWY_RESTRICT att_out =
-            activations_.att_out.Batch(token) + head * qkv_dim;
+            activations_.att_out.Row(token) + head * qkv_dim;
        for (size_t i = 0; i < seq_len; ++i) {
          float* HWY_RESTRICT v =
-              activations_.q.Batch(i) + head * 3 * qkv_dim + 2 * qkv_dim;
-          MulByConstAndAdd(C.Batch(token)[i], v, att_out, qkv_dim);
+              activations_.q.Row(i) + head * 3 * qkv_dim + 2 * qkv_dim;
+          MulByConstAndAdd(C.Row(token)[i], v, att_out, qkv_dim);
        }
      });
    }
@@ -701,24 +690,24 @@ class VitAttention {
      const size_t token = task / layer_config_.heads;
      // Compute Q.K scores, which are "logits" stored in head_att.
      float* HWY_RESTRICT q =
-          activations_.q.Batch(token) + head * 3 * qkv_dim;
+          activations_.q.Row(token) + head * 3 * qkv_dim;
      MulByConst(query_scale, q, qkv_dim);
      float* HWY_RESTRICT head_att =
-          activations_.att.Batch(token) + head * activations_.seq_len;
+          activations_.att.Row(token) + head * activations_.seq_len;
      for (size_t i = 0; i < seq_len; ++i) {
        float* HWY_RESTRICT k =
-            activations_.q.Batch(i) + head * 3 * qkv_dim + qkv_dim;
+            activations_.q.Row(i) + head * 3 * qkv_dim + qkv_dim;
        head_att[i] = Dot(q, k, qkv_dim);  // score = q.k
      }
      // SoftMax yields "probabilities" in head_att.
      Softmax(head_att, seq_len);
      // Compute weighted sum of v into att_out.
      float* HWY_RESTRICT att_out =
-          activations_.att_out.Batch(token) + head * qkv_dim;
+          activations_.att_out.Row(token) + head * qkv_dim;
      hwy::ZeroBytes(att_out, qkv_dim * sizeof(*att_out));
      for (size_t i = 0; i < seq_len; ++i) {
-        float* HWY_RESTRICT v = activations_.q.Batch(i) +
-                                head * 3 * qkv_dim + 2 * qkv_dim;
+        float* HWY_RESTRICT v =
+            activations_.q.Row(i) + head * 3 * qkv_dim + 2 * qkv_dim;
        MulByConstAndAdd(head_att[i], v, att_out, qkv_dim);
      }
    });
@@ -732,10 +721,9 @@ class VitAttention {
    // att_weights and att_out are concatenated heads, each of length
    // qkv_dim. Thus the [num_tokens_, layer_config_.model_dim]
    // matmul output is the sum over heads.
-    auto att_out = ConstMatFromBatch(num_tokens_, activations_.att_out);
-    auto att_weights = ConstMatFromWeights(layer_weights_.vit.attn_out_w);
-    auto att_sums = RowPtrFromBatch(allocator_, activations_.att_sums);
-    MatMul(att_out, att_weights, bias, *activations_.env, att_sums);
+    auto att_sums = RowPtrFromMat(allocator_, activations_.att_sums);
+    MatMul(activations_.att_out, layer_weights_.vit.attn_out_w, bias,
+           *activations_.env, att_sums);
  }

 public:
@@ -771,7 +759,7 @@ class VitAttention {

 template <typename T>
 HWY_NOINLINE void Activation(ActivationType activation, T* HWY_RESTRICT c1,
-                             T* HWY_RESTRICT c2, size_t count) {
+                             const T* HWY_RESTRICT c2, size_t count) {
   PROFILER_ZONE("Gen.Activation");
   namespace hn = hwy::HWY_NAMESPACE;
   using DF = hn::ScalableTag<T>;
@@ -787,12 +775,38 @@ HWY_NOINLINE void Activation(ActivationType activation, T* HWY_RESTRICT c1,
   });
 }

+// No C2 multiplier.
+template <class Mat>
+void ActivationBatched(ActivationType activation, Mat& c1) {
+  using T = typename Mat::T;
+  for (size_t i = 0; i < c1.Rows(); ++i) {
+    // Cast to correct type so type deduction works.
+    Activation(activation, c1.Row(i), static_cast<const T*>(nullptr),
+               c1.Cols());
+  }
+}
+
+template <class Mat>
+void ActivationBatched(ActivationType activation, Mat& c1, const Mat* c2) {
+  using T = typename Mat::T;
+  HWY_DASSERT(c1.SameShape(*c2));
+  if (c2 && c2->HasPtr()) {
+    for (size_t i = 0; i < c1.Rows(); ++i) {
+      Activation(activation, c1.Row(i), c2->Row(i), c1.Cols());
+    }
+  } else {  // No multiplier
+    for (size_t i = 0; i < c1.Rows(); ++i) {
+      Activation(activation, c1.Row(i), static_cast<const T*>(nullptr),
+                 c1.Cols());
+    }
+  }
+}
+
 template <typename T>
-HWY_NOINLINE void FFWNoVit(Activations& activations, size_t num_interleaved,
+HWY_NOINLINE void FFWNoVit(Activations& activations,
                            const LayerWeightsPtrs<T>* layer_weights) {
   PROFILER_ZONE("Gen.FFW");
   const size_t ffh_hidden_dim = layer_weights->layer_config.ff_hidden_dim;
-  HWY_DASSERT(num_interleaved <= activations.bf_pre_ffw_rms_out.BatchSize());

   const bool add_bias = layer_weights->layer_config.ff_biases;
   const float* bias1 =
@ -802,56 +816,48 @@ HWY_NOINLINE void FFWNoVit(Activations& activations, size_t num_interleaved,
|
|||
add_bias ? layer_weights->ffw_output_biases.PackedScale1() : nullptr;
|
||||
|
||||
// Define slightly more readable names for the weights and activations.
|
||||
const auto x =
|
||||
ConstMatFromBatch(num_interleaved, activations.bf_pre_ffw_rms_out);
|
||||
|
||||
const Allocator& allocator = activations.env->ctx.allocator;
|
||||
auto hidden_activations = RowPtrFromBatch(allocator, activations.C1);
|
||||
auto multiplier = RowPtrFromBatch(allocator, activations.C2);
|
||||
auto ffw_out = RowPtrFromBatch(allocator, activations.ffw_out);
|
||||
auto hidden_activations = RowPtrFromMat(allocator, activations.C1);
|
||||
auto multiplier = RowPtrFromMat(allocator, activations.C2);
|
||||
auto ffw_out = RowPtrFromMat(allocator, activations.ffw_out);
|
||||
|
||||
using WeightT = typename decltype(layer_weights->gating_einsum_w)::T;
|
||||
|
||||
// gating_einsum_w holds two half-matrices. We plan to change the importer to
|
||||
// avoid this confusion by splitting into gating_einsum_w1 and
|
||||
// gating_einsum_w2.
|
||||
// gating_einsum_w2. TODO: move into Reshape().
|
||||
const bool split = layer_weights->gating_einsum_w.HasPtr();
|
||||
auto w1 = split ? ConstMatFromWeights(layer_weights->gating_einsum_w)
|
||||
: ConstMatFromWeights(layer_weights->gating_einsum_w1);
|
||||
decltype(w1) w2;
|
||||
ConstMat<WeightT> w1(split ? layer_weights->gating_einsum_w
|
||||
: layer_weights->gating_einsum_w1);
|
||||
ConstMat<WeightT> w2(split ? layer_weights->gating_einsum_w
|
||||
: layer_weights->gating_einsum_w2);
|
||||
if (split) {
|
||||
w2 = ConstMatFromWeights(layer_weights->gating_einsum_w);
|
||||
w2.ofs = w2.Row(ffh_hidden_dim);
|
||||
// Ensure that B.Extents().row matches C.Cols() because MatMul checks that.
|
||||
w1.ShrinkRows(ffh_hidden_dim);
|
||||
w2.ShrinkRows(ffh_hidden_dim);
|
||||
} else {
|
||||
w2 = ConstMatFromWeights(layer_weights->gating_einsum_w2);
|
||||
}
|
||||
auto w_output = ConstMatFromWeights(layer_weights->linear_w);
|
||||
|
||||
// Compute the hidden layer activations.
|
||||
MatMul(x, w1, bias1, *activations.env, hidden_activations);
|
||||
MatMul(x, w2, bias2, *activations.env, multiplier);
|
||||
MatMul(activations.pre_ffw_rms_out, w1, bias1, *activations.env,
|
||||
hidden_activations);
|
||||
MatMul(activations.pre_ffw_rms_out, w2, bias2, *activations.env, multiplier);
|
||||
|
||||
// Activation (Gelu) and maybe multiply by gate. Store activations in act.
|
||||
Activation(layer_weights->layer_config.activation, hidden_activations.Row(0),
|
||||
multiplier.Row(0), ffh_hidden_dim * num_interleaved);
|
||||
ActivationBatched(layer_weights->layer_config.activation, activations.C1,
|
||||
&activations.C2);
|
||||
|
||||
// Hidden layer -> output layer.
|
||||
auto activations_mat = MakeConstMat(
|
||||
hidden_activations.Row(0), Extents2D(num_interleaved, ffh_hidden_dim),
|
||||
hidden_activations.Stride());
|
||||
|
||||
MatMul(activations_mat, w_output, output_bias, *activations.env, ffw_out);
|
||||
MatMul(activations.C1, layer_weights->linear_w, output_bias, *activations.env,
|
||||
ffw_out);
|
||||
}
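The split handling above, in isolation: the fused gating_einsum_w appears to store w1 stacked on top of w2, so the second view starts its ofs at row ff_hidden_dim and both views are shrunk so that B rows match C cols. A sketch of the idea (helper name is hypothetical; the ConstMat members are the ones visible in this diff):

// Sketch: view a fused [2 * ff_hidden_dim, model_dim] weight as two halves.
template <typename T>
void SplitGatingViews(const MatPtrT<T>& fused, size_t ff_hidden_dim,
                      ConstMat<T>& w1, ConstMat<T>& w2) {
  w1 = ConstMat<T>(fused);
  w2 = ConstMat<T>(fused);
  w2.ofs = w2.Row(ff_hidden_dim);  // second half starts after the first
  w1.ShrinkRows(ff_hidden_dim);    // MatMul checks B.Extents().rows == C.Cols()
  w2.ShrinkRows(ff_hidden_dim);
}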
|
||||
|
||||
// Same as FFWNoVit, but with different layer_weights members and no second
|
||||
// gating matrix.
|
||||
template <typename T>
|
||||
HWY_NOINLINE void FFWVit(Activations& activations, size_t num_interleaved,
|
||||
HWY_NOINLINE void FFWVit(Activations& activations,
|
||||
const LayerWeightsPtrs<T>* layer_weights) {
|
||||
PROFILER_ZONE("Gen.FFW");
|
||||
const size_t ff_hidden_dim = layer_weights->layer_config.ff_hidden_dim;
|
||||
HWY_DASSERT(num_interleaved <= activations.bf_pre_ffw_rms_out.BatchSize());
|
||||
PROFILER_ZONE("Gen.FFW.ViT");
|
||||
|
||||
const bool add_bias = layer_weights->layer_config.ff_biases;
|
||||
const float* bias1 =
|
||||
|
|
@ -860,30 +866,21 @@ HWY_NOINLINE void FFWVit(Activations& activations, size_t num_interleaved,
|
|||
add_bias ? layer_weights->vit.linear_1_b.PackedScale1() : nullptr;
|
||||
|
||||
// Define slightly more readable names for the weights and activations.
|
||||
const auto x =
|
||||
ConstMatFromBatch(num_interleaved, activations.bf_pre_ffw_rms_out);
|
||||
|
||||
const Allocator& allocator = activations.env->ctx.allocator;
|
||||
auto hidden_activations = RowPtrFromBatch(allocator, activations.C1);
|
||||
auto ffw_out = RowPtrFromBatch(allocator, activations.ffw_out);
|
||||
|
||||
auto w1 = ConstMatFromWeights(layer_weights->vit.linear_0_w);
|
||||
auto w_output = ConstMatFromWeights(layer_weights->vit.linear_1_w);
|
||||
auto hidden_activations = RowPtrFromMat(allocator, activations.C1);
|
||||
auto ffw_out = RowPtrFromMat(allocator, activations.ffw_out);
|
||||
|
||||
// Compute the hidden layer activations.
|
||||
MatMul(x, w1, bias1, *activations.env, hidden_activations);
|
||||
MatMul(activations.pre_ffw_rms_out, layer_weights->vit.linear_0_w, bias1,
|
||||
*activations.env, hidden_activations);
|
||||
|
||||
// Activation (Gelu), store in act.
|
||||
RowPtrF multiplier = RowPtrF(allocator, nullptr, 0);
|
||||
Activation(layer_weights->layer_config.activation, hidden_activations.Row(0),
|
||||
multiplier.Row(0), ff_hidden_dim * num_interleaved);
|
||||
ActivationBatched(layer_weights->layer_config.activation, activations.C1);
|
||||
|
||||
// Hidden layer -> output layer.
|
||||
auto activations_mat = MakeConstMat(hidden_activations.Row(0),
|
||||
Extents2D(num_interleaved, ff_hidden_dim),
|
||||
hidden_activations.Stride());
|
||||
|
||||
MatMul(activations_mat, w_output, output_bias, *activations.env, ffw_out);
|
||||
MatMul(activations.C1, layer_weights->vit.linear_1_w, output_bias,
|
||||
*activations.env, ffw_out);
|
||||
}
|
||||
|
||||
// `batch_idx` indicates which row of `x` to write to.
|
||||
|
|
@ -898,23 +895,23 @@ template <typename T>
|
|||
HWY_NOINLINE void EmbedMMToken(int token, size_t batch_idx, size_t pos,
|
||||
size_t pos_in_prompt,
|
||||
const ModelWeightsPtrs<T>& weights,
|
||||
RowVectorBatch<float>& x,
|
||||
MatStorageT<float>& x,
|
||||
const ImageTokens* image_tokens,
|
||||
size_t& image_token_position) {
|
||||
// Image tokens just need to be copied.
|
||||
if (weights.weights_config.wrapping == PromptWrapping::GEMMA_VLM &&
|
||||
image_tokens != nullptr && token == -2 &&
|
||||
image_token_position < image_tokens->BatchSize()) {
|
||||
hwy::CopyBytes(image_tokens->Batch(image_token_position),
|
||||
x.Batch(batch_idx), x.Cols() * sizeof(x.Const()[0]));
|
||||
image_token_position < image_tokens->Rows()) {
|
||||
hwy::CopyBytes(image_tokens->Row(image_token_position), x.Row(batch_idx),
|
||||
x.Cols() * x.ElementBytes());
|
||||
image_token_position++;
|
||||
return;
|
||||
}
|
||||
|
||||
if (weights.weights_config.wrapping == PromptWrapping::PALIGEMMA &&
|
||||
image_tokens != nullptr && pos_in_prompt < image_tokens->BatchSize()) {
|
||||
hwy::CopyBytes(image_tokens->Batch(pos_in_prompt), x.Batch(batch_idx),
|
||||
x.Cols() * sizeof(x.Const()[0]));
|
||||
image_tokens != nullptr && pos_in_prompt < image_tokens->Rows()) {
|
||||
hwy::CopyBytes(image_tokens->Row(pos_in_prompt), x.Row(batch_idx),
|
||||
x.Cols() * x.ElementBytes());
|
||||
return;
|
||||
}
|
||||
|
||||
|
|
@ -934,12 +931,12 @@ HWY_NOINLINE void EmbedMMToken(int token, size_t batch_idx, size_t pos,
|
|||
HWY_ASSERT(weights.embedder_input_embedding.Cols() == model_dim);
|
||||
const auto embedding_span = MakeSpan(weights.embedder_input_embedding.Row(0),
|
||||
embedding_ofs + model_dim);
|
||||
DecompressAndZeroPad(df, embedding_span, embedding_ofs, x.Batch(batch_idx),
|
||||
DecompressAndZeroPad(df, embedding_span, embedding_ofs, x.Row(batch_idx),
|
||||
model_dim);
|
||||
MulByConst(emb_scaling * weights.embedder_input_embedding.Scale(),
|
||||
x.Batch(batch_idx), model_dim);
|
||||
x.Row(batch_idx), model_dim);
|
||||
if (weights.weights_config.absolute_pe) {
|
||||
AddAbsolutePositionalEmbeddings(x.Batch(batch_idx), model_dim, pos);
|
||||
AddAbsolutePositionalEmbeddings(x.Row(batch_idx), model_dim, pos);
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -951,29 +948,28 @@ template <typename T>
|
|||
HWY_NOINLINE void EmbedToken(int token, size_t batch_idx, size_t pos,
|
||||
size_t pos_in_prompt,
|
||||
const ModelWeightsPtrs<T>& weights,
|
||||
RowVectorBatch<float>& x,
|
||||
MatStorageT<float>& x,
|
||||
const ImageTokens* image_tokens) {
|
||||
size_t image_token_position = 0;
|
||||
EmbedMMToken<T>(token, batch_idx, pos, pos_in_prompt, weights, x,
|
||||
image_tokens, image_token_position);
|
||||
}
|
||||
|
||||
template <typename Weights, typename T>
|
||||
HWY_NOINLINE void ResidualConnection(
|
||||
size_t num_interleaved, const T* HWY_RESTRICT other, T* HWY_RESTRICT x,
|
||||
const LayerWeightsPtrs<Weights>* layer_weights, bool is_attention) {
|
||||
template <typename T2, class LayerWeights>
|
||||
HWY_NOINLINE void ResidualConnection(const MatPtrT<T2>& other,
|
||||
MatPtrT<float>& HWY_RESTRICT x,
|
||||
const LayerWeights* layer_weights,
|
||||
bool is_attention) {
|
||||
// ResidualType::Add
|
||||
AddFromBatched(num_interleaved, other, x,
|
||||
layer_weights->layer_config.model_dim);
|
||||
AddFromBatched(other, x);
|
||||
}
|
||||
|
||||
template <typename WeightT, typename InOutT>
|
||||
void PostNorm(PostNormType post_norm, size_t num_interleaved,
|
||||
const WeightT& weights, InOutT* inout) {
|
||||
void PostNorm(PostNormType post_norm, const MatPtrT<WeightT>& weights,
|
||||
MatPtrT<InOutT>& inout) {
|
||||
HWY_DASSERT(weights.Rows() == 1);
|
||||
if (post_norm == PostNormType::Scale) {
|
||||
RMSNormInplaceBatched(num_interleaved, weights.PackedScale1(), inout,
|
||||
weights.Cols());
|
||||
RMSNormInplaceBatched(weights, inout);
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -985,39 +981,33 @@ HWY_NOINLINE void TransformerLayer(const QueriesPos& queries_pos,
|
|||
Activations& activations,
|
||||
const hwy::Divisor& div_seq_len,
|
||||
const KVCaches& kv_caches) {
|
||||
const size_t model_dim = activations.weights_config.model_dim;
|
||||
const size_t num_interleaved = num_tokens * queries_pos.size();
|
||||
auto type = layer_weights->layer_config.type;
|
||||
|
||||
RMSNormBatched(num_interleaved, activations.x.All(),
|
||||
layer_weights->pre_attention_norm_scale.PackedScale1(),
|
||||
activations.pre_att_rms_out.All(), model_dim);
|
||||
RMSNormBatched(activations.x, layer_weights->pre_attention_norm_scale,
|
||||
activations.pre_att_rms_out);
|
||||
|
||||
Attention(type, queries_pos, queries_prefix_end, num_tokens, cache_layer_idx,
|
||||
activations, layer_weights, div_seq_len, kv_caches);
|
||||
|
||||
PostNorm(layer_weights->layer_config.post_norm, num_interleaved,
|
||||
layer_weights->post_attention_norm_scale,
|
||||
activations.att_sums.All());
|
||||
PostNorm(layer_weights->layer_config.post_norm,
|
||||
layer_weights->post_attention_norm_scale, activations.att_sums);
|
||||
|
||||
ResidualConnection(num_interleaved, activations.att_sums.All(),
|
||||
activations.x.All(), layer_weights, /*is_attention=*/true);
|
||||
ResidualConnection(activations.att_sums, activations.x, layer_weights,
|
||||
/*is_attention=*/true);
|
||||
|
||||
RMSNormBatched(num_interleaved, activations.x.All(),
|
||||
layer_weights->pre_ffw_norm_scale.PackedScale1(),
|
||||
activations.bf_pre_ffw_rms_out.All(), model_dim);
|
||||
RMSNormBatched(activations.x, layer_weights->pre_ffw_norm_scale,
|
||||
activations.pre_ffw_rms_out);
|
||||
|
||||
if (layer_weights->layer_config.type == LayerAttentionType::kVit) {
|
||||
FFWVit(activations, num_interleaved, layer_weights);
|
||||
FFWVit(activations, layer_weights);
|
||||
} else {
|
||||
FFWNoVit(activations, num_interleaved, layer_weights);
|
||||
FFWNoVit(activations, layer_weights);
|
||||
}
|
||||
|
||||
PostNorm(layer_weights->layer_config.post_norm, num_interleaved,
|
||||
layer_weights->post_ffw_norm_scale, activations.ffw_out.All());
|
||||
PostNorm(layer_weights->layer_config.post_norm,
|
||||
layer_weights->post_ffw_norm_scale, activations.ffw_out);
|
||||
|
||||
ResidualConnection(num_interleaved, activations.ffw_out.All(),
|
||||
activations.x.All(), layer_weights,
|
||||
ResidualConnection(activations.ffw_out, activations.x, layer_weights,
|
||||
/*is_attention=*/false);
|
||||
}
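Note on why the num_interleaved and model_dim arguments could disappear from this function: per the commit description, SetBatchSize now relies on OverrideRows, so the activation matrices themselves carry the current row count and the batched ops read it from their Mat arguments. The resulting calling pattern, assembled from lines elsewhere in this diff (the comment about OverrideRows is an assumption about its effect):

activations.SetBatchSize(tbatch_size);  // presumably OverrideRows(tbatch_size) on each matrix
RMSNormBatched(activations.x, layer_weights->pre_ffw_norm_scale,
               activations.pre_ffw_rms_out);         // row count taken from x
AddFromBatched(activations.ffw_out, activations.x);  // likewise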
|
||||
|
||||
|
|
@ -1034,38 +1024,37 @@ HWY_NOINLINE void VitTransformerLayer(size_t num_tokens, size_t layer,
|
|||
auto type = layer_weights->layer_config.type;
|
||||
HWY_DASSERT(type == LayerAttentionType::kVit);
|
||||
(void)type;
|
||||
(void)model_dim;
|
||||
|
||||
auto& x = activations.x;
|
||||
HWY_DASSERT(x.BatchSize() == num_tokens);
|
||||
HWY_DASSERT(x.Rows() == num_tokens);
|
||||
HWY_DASSERT(x.Cols() == model_dim);
|
||||
|
||||
// y = nn.LayerNorm()(x)
|
||||
// y ~ pre_att_rms_out
|
||||
LayerNormBatched(num_tokens, x.All(),
|
||||
layer_weights->vit.layer_norm_0_scale.PackedScale1(),
|
||||
layer_weights->vit.layer_norm_0_bias.PackedScale1(),
|
||||
activations.pre_att_rms_out.All(), model_dim);
|
||||
LayerNormBatched(x, layer_weights->vit.layer_norm_0_scale,
|
||||
layer_weights->vit.layer_norm_0_bias,
|
||||
activations.pre_att_rms_out);
|
||||
|
||||
// y = out["sa"] = nn.MultiHeadDotProductAttention(...)(y, y)
|
||||
// y ~ att_sums
|
||||
VitAttention<T>(num_tokens, layer, activations, layer_weights)();
|
||||
|
||||
// x = out["+sa"] = x + y
|
||||
AddFromBatched(num_tokens, activations.att_sums.All(), x.All(), model_dim);
|
||||
AddFromBatched(activations.att_sums, x);
|
||||
|
||||
// y = nn.LayerNorm()(x)
|
||||
// y ~ bf_pre_ffw_rms_out
|
||||
LayerNormBatched(num_tokens, x.All(),
|
||||
layer_weights->vit.layer_norm_1_scale.PackedScale1(),
|
||||
layer_weights->vit.layer_norm_1_bias.PackedScale1(),
|
||||
activations.bf_pre_ffw_rms_out.All(), model_dim);
|
||||
// y ~ pre_ffw_rms_out
|
||||
LayerNormBatched(x, layer_weights->vit.layer_norm_1_scale,
|
||||
layer_weights->vit.layer_norm_1_bias,
|
||||
activations.pre_ffw_rms_out);
|
||||
|
||||
// y = out["mlp"] = MlpBlock(...)(y)
|
||||
// y ~ ffw_out
|
||||
FFWVit(activations, num_tokens, layer_weights);
|
||||
FFWVit(activations, layer_weights);
|
||||
|
||||
// x = out["+mlp"] = x + y
|
||||
AddFromBatched(num_tokens, activations.ffw_out.All(), x.All(), model_dim);
|
||||
AddFromBatched(activations.ffw_out, x);
|
||||
}
|
||||
|
||||
// Prefill() and Transformer() increment positions in-place.
|
||||
|
|
@ -1094,7 +1083,7 @@ HWY_NOINLINE void Prefill(
|
|||
// intensity, and so we are eventually compute-limited. We could devote some
|
||||
// threads to parallelizing over queries, but for simplicity we assign them
|
||||
// all to MatMul.
|
||||
const size_t max_tbatch_size = activations.x.BatchSize();
|
||||
const size_t max_tbatch_size = activations.x.Rows();
|
||||
|
||||
// For each query. `qi` is within the batch, not the global query index.
|
||||
for (size_t qi = 0; qi < num_queries; ++qi) {
|
||||
|
|
@ -1131,6 +1120,7 @@ HWY_NOINLINE void Prefill(
|
|||
tbatch_start += max_tbatch_size) {
|
||||
const size_t tbatch_size =
|
||||
HWY_MIN(max_tbatch_size, prefill_this_query - tbatch_start);
|
||||
activations.SetBatchSize(tbatch_size);
|
||||
|
||||
// Fill activations.x (much faster than TransformerLayer).
|
||||
size_t image_token_position = 0;
|
||||
|
|
@ -1201,13 +1191,14 @@ HWY_NOINLINE void EmbedImagePatches(const Image& image,
|
|||
// H x W x C x D transposed to D x (H x W x C) so here (1152, 14 * 14 * 3)
|
||||
// image_patches is (256, 14 * 14 * 3)
|
||||
// This could be done as one MatMul like:
|
||||
// RowVectorBatch<float> image_patches(kSeqLen, kPatchSize);
|
||||
// MatStorageT<float> image_patches("patches", Extents2D(kSeqLen,
|
||||
// kPatchSize), MatPadding::kPacked);
|
||||
// [Get patches]
|
||||
// MatMul(
|
||||
// MatFromBatch(kVitSeqLen, image_patches),
|
||||
// MatFromWeights(weights.vit_img_embedding_kernel),
|
||||
// weights.vit_img_embedding_bias.PackedScale1(), *activations.env,
|
||||
// RowPtrF(activations.x.All(), kVitModelDim));
|
||||
// RowPtrF(activations.x.Row(0), kVitModelDim));
|
||||
// However, MatMul currently requires that
|
||||
// A.cols % (2 * hn::Lanes(hn::ScalableTag<MulT>())) == 0
|
||||
// which is not the case here. We should relax that requirement on MatMul and
|
||||
|
|
@ -1216,11 +1207,10 @@ HWY_NOINLINE void EmbedImagePatches(const Image& image,
|
|||
MatVecAdd(weights.vit_img_embedding_kernel, 0, model_dim, patch_size,
|
||||
image_patches[i].get(),
|
||||
weights.vit_img_embedding_bias.PackedScale1(),
|
||||
activations.x.Batch(i), activations.env->ctx.pools.Pool(0));
|
||||
activations.x.Row(i), activations.env->ctx.pools.Pool(0));
|
||||
}
|
||||
// Add position embeddings.
|
||||
AddFrom(weights.vit_img_pos_embedding.PackedScale1(), activations.x.All(),
|
||||
seq_len * model_dim);
|
||||
AddFromBatched(weights.vit_img_pos_embedding, activations.x);
|
||||
}
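To make the lane-multiple restriction from the comment above concrete, a quick check under the assumption of 512-bit vectors (32 BF16 lanes), which is only one possible target:

// 14 x 14 patches with 3 channels -> 588 columns, but MatMul would need a
// multiple of 2 * Lanes = 64 at the assumed vector width, hence the MatVecAdd loop.
static_assert(14 * 14 * 3 == 588, "patch columns");
static_assert((14 * 14 * 3) % 64 == 12, "not a multiple of 64");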
|
||||
|
||||
// Prefills the image tokens with the ViT encoder.
|
||||
|
|
@ -1232,7 +1222,7 @@ HWY_NOINLINE void PrefillVit(const ModelWeightsPtrs<T>& weights,
|
|||
PROFILER_ZONE("Gen.PrefillVit");
|
||||
const size_t num_tokens = weights.weights_config.vit_config.seq_len;
|
||||
const size_t vit_model_dim = weights.weights_config.vit_config.model_dim;
|
||||
HWY_ASSERT(num_tokens == activations.x.BatchSize());
|
||||
HWY_ASSERT(num_tokens == activations.x.Rows());
|
||||
// Embed the image patches.
|
||||
EmbedImagePatches(image, weights, activations);
|
||||
// Go through all layers.
|
||||
|
|
@ -1243,24 +1233,21 @@ HWY_NOINLINE void PrefillVit(const ModelWeightsPtrs<T>& weights,
|
|||
VitTransformerLayer(num_tokens, layer, layer_weights, activations);
|
||||
}
|
||||
// Final Layernorm.
|
||||
LayerNormBatched(num_tokens, activations.x.All(),
|
||||
weights.vit_encoder_norm_scale.PackedScale1(),
|
||||
weights.vit_encoder_norm_bias.PackedScale1(),
|
||||
activations.x.All(), vit_model_dim);
|
||||
LayerNormBatched(activations.x, weights.vit_encoder_norm_scale,
|
||||
weights.vit_encoder_norm_bias, activations.x);
|
||||
|
||||
if (weights.weights_config.wrapping == PromptWrapping::GEMMA_VLM) {
|
||||
activations.x = AvgPool4x4(activations.x);
|
||||
|
||||
// Apply soft embedding norm before input projection.
|
||||
RMSNormInplace(weights.mm_embed_norm.PackedScale1(), activations.x.All(),
|
||||
vit_model_dim);
|
||||
RMSNormInplace(weights.mm_embed_norm.PackedScale1(), 0,
|
||||
activations.x.Row(0), vit_model_dim);
|
||||
}
|
||||
|
||||
// Apply head embedding into image_tokens of size of the LLM kModelDim.
|
||||
MatMul(ConstMatFromBatch(activations.x.BatchSize(), activations.x),
|
||||
ConstMatFromWeights(weights.vit_img_head_kernel),
|
||||
MatMul(activations.x, weights.vit_img_head_kernel,
|
||||
weights.vit_img_head_bias.PackedScale1(), *activations.env,
|
||||
RowPtrFromBatch(activations.env->ctx.allocator, image_tokens));
|
||||
RowPtrFromMat(activations.env->ctx.allocator, image_tokens));
|
||||
}
|
||||
|
||||
// Generates one token for each query. `queries_token` is the previous token
|
||||
|
|
@ -1272,7 +1259,6 @@ HWY_NOINLINE void Transformer(
|
|||
Activations& activations, const hwy::Divisor& div_seq_len,
|
||||
const KVCaches& kv_caches, const LayersOutputFunc& layers_output,
|
||||
const ActivationsObserverFunc& activations_observer) {
|
||||
const size_t model_dim = weights.weights_config.model_dim;
|
||||
const size_t num_queries = queries_token.size();
|
||||
HWY_DASSERT(queries_pos.size() == num_queries);
|
||||
HWY_DASSERT(queries_prefix_end.size() == num_queries);
|
||||
|
|
@ -1302,8 +1288,7 @@ HWY_NOINLINE void Transformer(
|
|||
}
|
||||
}
|
||||
|
||||
RMSNormInplaceBatched(num_queries, weights.final_norm_scale.PackedScale1(),
|
||||
activations.x.All(), model_dim);
|
||||
RMSNormInplaceBatched(weights.final_norm_scale, activations.x);
|
||||
|
||||
if (activations_observer) {
|
||||
activations_observer(queries_pos, -1, activations);
|
||||
|
|
@ -1395,18 +1380,18 @@ bool DecodeStepT(const ModelConfig& config, const ModelWeightsPtrs<T>& weights,
|
|||
runtime_config.activations_observer);
|
||||
// queries_pos are incremented by Transformer.
|
||||
|
||||
HWY_DASSERT(num_queries == activations.x.Rows());
|
||||
bool all_queries_eos = true;
|
||||
{
|
||||
PROFILER_ZONE("Gen.EmbeddingMatmul");
|
||||
// Compute logits from last layer activations.
|
||||
MatMul(ConstMatFromBatch(num_queries, activations.x),
|
||||
ConstMatFromWeights(weights.embedder_input_embedding),
|
||||
MatMul(activations.x, weights.embedder_input_embedding,
|
||||
/*add=*/nullptr, *activations.env,
|
||||
RowPtrFromBatch(activations.env->ctx.allocator, activations.logits));
|
||||
RowPtrFromMat(activations.env->ctx.allocator, activations.logits));
|
||||
}
|
||||
PROFILER_ZONE("Gen.Softcap+Sample+Stream");
|
||||
for (size_t query_idx = 0; query_idx < num_queries; ++query_idx) {
|
||||
float* HWY_RESTRICT logits = activations.logits.Batch(query_idx);
|
||||
float* HWY_RESTRICT logits = activations.logits.Row(query_idx);
|
||||
MaybeLogitsSoftCap(weights.weights_config.final_cap, logits, vocab_size);
|
||||
const TokenAndProb tp = sample_token(logits, vocab_size);
|
||||
timing_info.NotifyGenerated();
|
||||
|
|
@ -1460,7 +1445,7 @@ void GenerateT(const ModelStore& model, const ModelWeightsPtrs<T>& weights,
|
|||
|
||||
const size_t num_queries = queries_prompt.size();
|
||||
HWY_ASSERT(num_queries <= 4096); // TokenStreamer uses BitSet4096.
|
||||
HWY_ASSERT(num_queries <= activations.x.BatchSize());
|
||||
HWY_ASSERT(num_queries <= activations.x.Rows());
|
||||
HWY_ASSERT(queries_pos_in.size() == num_queries);
|
||||
HWY_ASSERT(kv_caches.size() == num_queries);
|
||||
const hwy::Divisor div_seq_len(static_cast<uint32_t>(kv_caches[0].seq_len));
|
||||
|
|
@ -1475,12 +1460,11 @@ void GenerateT(const ModelStore& model, const ModelWeightsPtrs<T>& weights,
|
|||
// If tbatch is larger than the qbatch we already have in `activations`, then
|
||||
// allocate prefill_activations, otherwise reuse.
|
||||
const bool use_prefill_activations =
|
||||
runtime_config.prefill_tbatch_size > activations.x.BatchSize();
|
||||
Activations prefill_activations(weights.weights_config);
|
||||
if (use_prefill_activations) {
|
||||
prefill_activations.Allocate(runtime_config.prefill_tbatch_size,
|
||||
runtime_config.prefill_tbatch_size > activations.x.Rows();
|
||||
Activations prefill_activations(
|
||||
weights.weights_config,
|
||||
use_prefill_activations ? runtime_config.prefill_tbatch_size : 0,
|
||||
activations.env);
|
||||
}
|
||||
Prefill(queries_prompt, queries_mutable_pos, queries_prefix_end,
|
||||
query_idx_start, weights,
|
||||
use_prefill_activations ? prefill_activations : activations,
|
||||
|
|
@ -1534,8 +1518,7 @@ void GenerateSingleT(const ModelStore& model,
|
|||
const size_t qbatch_start = 0;
|
||||
|
||||
// TODO: move into Gemma?
|
||||
Activations activations(model.Config());
|
||||
activations.Allocate(kNumQueries, env);
|
||||
Activations activations(model.Config(), kNumQueries, env);
|
||||
|
||||
const QueriesPromptTokens queries_prompt(&prompt, kNumQueries);
|
||||
QueriesPos queries_pos(&pos, kNumQueries);
|
||||
|
|
@ -1558,7 +1541,7 @@ void GenerateBatchT(const ModelStore& model,
|
|||
TimingInfo& timing_info) {
|
||||
const size_t num_queries = queries_prompt.size();
|
||||
HWY_ASSERT(queries_pos.size() == num_queries);
|
||||
HWY_ASSERT(kv_caches.size() == num_queries);
|
||||
HWY_ASSERT(kv_caches.size() >= num_queries);
|
||||
// Griffin does not support query batching.
|
||||
size_t max_qbatch_size = runtime_config.decode_qbatch_size;
|
||||
for (const LayerConfig& layer_config : model.Config().layer_configs) {
|
||||
|
|
@ -1568,14 +1551,14 @@ void GenerateBatchT(const ModelStore& model,
|
|||
}
|
||||
}
|
||||
|
||||
Activations activations(model.Config());
|
||||
activations.Allocate(max_qbatch_size, env);
|
||||
Activations activations(model.Config(), max_qbatch_size, env);
|
||||
|
||||
for (size_t qbatch_start = 0; qbatch_start < num_queries;
|
||||
qbatch_start += max_qbatch_size) {
|
||||
// Generate one batch of tokens from `qbatch_size` queries.
|
||||
const size_t qbatch_size =
|
||||
HWY_MIN(num_queries - qbatch_start, max_qbatch_size);
|
||||
activations.SetBatchSize(qbatch_size);
|
||||
const QueriesPromptTokens qbatch_prompts(&queries_prompt[qbatch_start],
|
||||
qbatch_size);
|
||||
QueriesPos qbatch_pos(&queries_pos[qbatch_start], qbatch_size);
|
||||
|
|
@ -1601,8 +1584,7 @@ void GenerateImageTokensT(const ModelStore& model,
|
|||
ModelConfig vit_config = GetVitConfig(model.Config());
|
||||
prefill_runtime_config.prefill_tbatch_size =
|
||||
vit_config.seq_len / (vit_config.pool_dim * vit_config.pool_dim);
|
||||
Activations prefill_activations(vit_config);
|
||||
prefill_activations.Allocate(vit_config.seq_len, env);
|
||||
Activations prefill_activations(vit_config, vit_config.seq_len, env);
|
||||
// Weights are for the full PaliGemma model, not just the ViT part.
|
||||
PrefillVit(weights, prefill_runtime_config, image, image_tokens,
|
||||
prefill_activations);
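For intuition about the sizes chosen above, a worked example with hypothetical config values (the real ones come from vit_config):

// Hypothetical: vit_config.seq_len = 4096 patch positions, pool_dim = 4 (AvgPool4x4).
//   prefill_tbatch_size = 4096 / (4 * 4) = 256 pooled image tokens,
//   while prefill_activations is still sized for all 4096 ViT positions.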
|
||||
|
|
|
|||
|
|
@ -32,7 +32,6 @@
|
|||
#include "ops/matmul.h" // MatMulEnv
|
||||
#include "paligemma/image.h"
|
||||
#include "util/basics.h" // TokenAndProb
|
||||
#include "util/mat.h" // RowVectorBatch
|
||||
#include "util/threading_context.h"
|
||||
#include "hwy/timer.h"
|
||||
// IWYU pragma: end_exports
|
||||
|
|
|
|||
|
|
@ -29,6 +29,7 @@
|
|||
#include "ops/matmul.h" // MMStorage::kMax*
|
||||
#include "util/args.h"
|
||||
#include "util/basics.h" // Tristate
|
||||
#include "util/mat.h"
|
||||
#include "hwy/aligned_allocator.h" // Span
|
||||
#include "hwy/base.h" // HWY_ABORT
|
||||
|
||||
|
|
@ -74,9 +75,9 @@ using QueriesPromptTokens = hwy::Span<const PromptTokens>;
|
|||
using QueriesToken = hwy::Span<const int>;
|
||||
using QueriesPos = hwy::Span<const size_t>;
|
||||
|
||||
// ImageTokens are represented as a RowVectorBatch, where each "batch" index
|
||||
// corresponds to a token for an image patch as computed by the image encoder.
|
||||
using ImageTokens = RowVectorBatch<float>;
|
||||
// ImageTokens are represented as a matrix, where each row corresponds
|
||||
// to a token for an image patch as computed by the image encoder.
|
||||
using ImageTokens = MatStorageT<float>;
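A minimal sketch of the new ImageTokens in use; the constructor arguments mirror the run.cc change later in this commit, and the loop is illustrative only (token_rows and model_dim are placeholders):

ImageTokens image_tokens("image_tokens", Extents2D(token_rows, model_dim),
                         MatPadding::kOdd);
for (size_t i = 0; i < image_tokens.Rows(); ++i) {
  float* HWY_RESTRICT row = image_tokens.Row(i);  // one image-patch token per row
  hwy::ZeroBytes(row, image_tokens.Cols() * sizeof(float));  // placeholder fill
}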
|
||||
|
||||
// StreamFunc is called with (token, probability). For prompt tokens,
|
||||
// probability is 0.0f. StreamFunc should return false to stop generation and
|
||||
|
|
|
|||
|
|
@ -15,91 +15,69 @@
|
|||
|
||||
#include "gemma/kv_cache.h"
|
||||
|
||||
#include <algorithm>
|
||||
#include <algorithm> // std::copy
|
||||
|
||||
#include "gemma/configs.h"
|
||||
#include "util/mat.h" // ZeroInit
|
||||
#include "hwy/aligned_allocator.h"
|
||||
#include "hwy/base.h" // ZeroBytes
|
||||
|
||||
namespace gcpp {
|
||||
|
||||
void KVCache::ZeroGriffinCache() {
|
||||
if (conv1d_cache_size != 0) {
|
||||
hwy::ZeroBytes(conv1d_cache.get(),
|
||||
conv1d_cache_size * sizeof(conv1d_cache[0]));
|
||||
if (conv1d_cache.HasPtr()) ZeroInit(conv1d_cache);
|
||||
if (rglru_cache.HasPtr()) ZeroInit(rglru_cache);
|
||||
}
|
||||
if (rglru_cache_size != 0) {
|
||||
hwy::ZeroBytes(rglru_cache.get(),
|
||||
rglru_cache_size * sizeof(rglru_cache[0]));
|
||||
|
||||
static size_t GriffinConv1dCols(const ModelConfig& config) {
|
||||
size_t conv1d_width = 0;
|
||||
for (const auto& layer_config : config.layer_configs) {
|
||||
conv1d_width = HWY_MAX(conv1d_width, layer_config.conv1d_width);
|
||||
}
|
||||
return conv1d_width == 0 ? 0 : conv1d_width - 1;
|
||||
}
|
||||
|
||||
// prefill_tbatch_size is the maximum number of tokens from one query to
|
||||
// prefill at a time.
|
||||
KVCache KVCache::Create(const ModelConfig& weights_config,
|
||||
size_t prefill_tbatch_size) {
|
||||
KVCache kv_cache = {};
|
||||
|
||||
const size_t size_cache_pos = weights_config.CachePosSize();
|
||||
KVCache::KVCache(const ModelConfig& config, size_t prefill_tbatch_size)
|
||||
: griffin_layers(
|
||||
config.NumLayersOfType(LayerAttentionType::kGriffinRecurrentBlock)),
|
||||
griffin_conv1d_cols(GriffinConv1dCols(config)),
|
||||
// TODO(patrickms): Add query batching support for Griffin.
|
||||
conv1d_cache(
|
||||
"conv1d_cache",
|
||||
Extents2D(griffin_layers, griffin_conv1d_cols * config.model_dim),
|
||||
MatPadding::kOdd),
|
||||
rglru_cache("rglru_cache", Extents2D(griffin_layers, config.model_dim),
|
||||
MatPadding::kOdd) {
|
||||
// TODO: move to MatStorageT.
|
||||
const size_t size_cache_pos = config.CachePosSize();
|
||||
if (size_cache_pos != 0) {
|
||||
// Allocate more so that prefill can always access one batch, even if
|
||||
// near the end of the sequence.
|
||||
kv_cache.seq_len = weights_config.seq_len + prefill_tbatch_size;
|
||||
kv_cache.kv_cache =
|
||||
hwy::AllocateAligned<float>(kv_cache.seq_len * size_cache_pos);
|
||||
seq_len = config.seq_len + prefill_tbatch_size;
|
||||
kv_cache = hwy::AllocateAligned<float>(seq_len * size_cache_pos);
|
||||
}
|
||||
|
||||
const size_t num_griffin_layers = weights_config.NumLayersOfType(
|
||||
LayerAttentionType::kGriffinRecurrentBlock);
|
||||
// TODO(patrickms): Add query batching support for Griffin.
|
||||
if (num_griffin_layers > 0) {
|
||||
uint32_t conv1d_width = 0;
|
||||
for (const auto& layer_config : weights_config.layer_configs) {
|
||||
conv1d_width = std::max(conv1d_width, layer_config.conv1d_width);
|
||||
}
|
||||
const size_t conv1d_cache_size =
|
||||
num_griffin_layers * (conv1d_width == 0 ? 0 : conv1d_width - 1) *
|
||||
weights_config.model_dim;
|
||||
kv_cache.conv1d_cache_size = conv1d_cache_size;
|
||||
if (conv1d_cache_size != 0) {
|
||||
kv_cache.conv1d_cache = hwy::AllocateAligned<float>(conv1d_cache_size);
|
||||
}
|
||||
|
||||
const size_t rglru_cache_size =
|
||||
num_griffin_layers * weights_config.model_dim;
|
||||
kv_cache.rglru_cache_size = rglru_cache_size;
|
||||
if (rglru_cache_size != 0) {
|
||||
kv_cache.rglru_cache = hwy::AllocateAligned<float>(rglru_cache_size);
|
||||
}
|
||||
} // num_griffin_layers
|
||||
|
||||
return kv_cache;
|
||||
}
|
||||
|
||||
KVCache KVCache::Copy(const ModelConfig& weights_config,
|
||||
size_t prefill_tbatch_size) {
|
||||
KVCache kv_cache_copy = Create(weights_config, prefill_tbatch_size);
|
||||
KVCache copy(weights_config, prefill_tbatch_size);
|
||||
|
||||
const size_t size_cache_pos = weights_config.CachePosSize();
|
||||
if (size_cache_pos != 0) {
|
||||
std::copy(kv_cache.get(), kv_cache.get() + size_cache_pos * seq_len,
|
||||
kv_cache_copy.kv_cache.get());
|
||||
copy.kv_cache.get());
|
||||
}
|
||||
|
||||
const size_t num_griffin_layers = weights_config.NumLayersOfType(
|
||||
LayerAttentionType::kGriffinRecurrentBlock);
|
||||
if (num_griffin_layers > 0) {
|
||||
if (conv1d_cache_size != 0) {
|
||||
std::copy(conv1d_cache.get(), conv1d_cache.get() + conv1d_cache_size,
|
||||
kv_cache_copy.conv1d_cache.get());
|
||||
if (conv1d_cache.HasPtr()) {
|
||||
CopyMat(conv1d_cache, copy.conv1d_cache);
|
||||
}
|
||||
if (rglru_cache_size != 0) {
|
||||
std::copy(rglru_cache.get(),
|
||||
rglru_cache.get() + rglru_cache_size * sizeof(rglru_cache[0]),
|
||||
kv_cache_copy.rglru_cache.get());
|
||||
if (rglru_cache.HasPtr()) {
|
||||
CopyMat(rglru_cache, copy.rglru_cache);
|
||||
}
|
||||
}
|
||||
return kv_cache_copy;
|
||||
|
||||
return copy;
|
||||
}
|
||||
|
||||
} // namespace gcpp
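A quick numeric illustration of the Griffin cache shapes produced by the constructor above (all values hypothetical):

// conv1d_width = 4, model_dim = 2560, 12 Griffin layers:
//   griffin_conv1d_cols = 4 - 1 = 3
//   conv1d_cache extents: (12, 3 * 2560)  // one row per Griffin layer
//   rglru_cache  extents: (12, 2560)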
|
||||
|
|
|
|||
|
|
@ -19,33 +19,31 @@
|
|||
#include <stddef.h>
|
||||
|
||||
#include "gemma/configs.h" // ModelConfig
|
||||
#include "util/mat.h"
|
||||
#include "hwy/aligned_allocator.h"
|
||||
|
||||
namespace gcpp {
|
||||
|
||||
struct KVCache {
|
||||
size_t seq_len = 0; // = kSeqLen + prefill_tbatch_size
|
||||
KVCache() = default; // for std::vector.
|
||||
KVCache(const ModelConfig& weights_config, size_t prefill_tbatch_size);
|
||||
|
||||
// seq_len * kGemmaLayers * kKVHeads * kQKVDim * 2
|
||||
hwy::AlignedFreeUniquePtr<float[]> kv_cache;
|
||||
|
||||
// (kConv1dWidth - 1) * kModelDim * kGriffinLayers
|
||||
hwy::AlignedFreeUniquePtr<float[]> conv1d_cache;
|
||||
size_t conv1d_cache_size = 0;
|
||||
|
||||
// kModelDim * kGriffinLayers
|
||||
hwy::AlignedFreeUniquePtr<float[]> rglru_cache;
|
||||
size_t rglru_cache_size = 0;
|
||||
// Returns a deep copy of the KVCache.
|
||||
KVCache Copy(const ModelConfig& weights_config, size_t prefill_tbatch_size);
|
||||
|
||||
size_t griffin_layers = 0;
|
||||
size_t griffin_conv1d_cols = 0;
|
||||
// griffin_layers, griffin_conv1d_cols * config.model_dim
|
||||
MatStorageT<float> conv1d_cache;
|
||||
MatStorageT<float> rglru_cache; // griffin_layers, config.model_dim
|
||||
// Zero-initialize the Griffin recurrent block cache, i.e. the conv1d_cache
|
||||
// and rglru_cache.
|
||||
void ZeroGriffinCache();
|
||||
|
||||
static KVCache Create(const ModelConfig& weights_config,
|
||||
size_t prefill_tbatch_size);
|
||||
size_t seq_len = 0; // = kSeqLen + prefill_tbatch_size
|
||||
|
||||
// Returns a deep copy of the KVCache.
|
||||
KVCache Copy(const ModelConfig& weights_config, size_t prefill_tbatch_size);
|
||||
// seq_len * kGemmaLayers * kKVHeads * kQKVDim * 2
|
||||
hwy::AlignedFreeUniquePtr<float[]> kv_cache;
|
||||
};
|
||||
|
||||
} // namespace gcpp
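Putting the new interface together, a short usage sketch (config and inference are placeholders); direct construction replaces the removed Create factory, and Griffin state is reset with ZeroGriffinCache:

KVCache kv_cache(config, inference.prefill_tbatch_size);   // was KVCache::Create(...)
kv_cache.ZeroGriffinCache();  // no-op unless Griffin layers are present
KVCache copy = kv_cache.Copy(config, inference.prefill_tbatch_size);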
|
||||
|
|
|
|||
36
gemma/run.cc
|
|
@ -95,24 +95,25 @@ void ReplGemma(const ThreadingArgs& threading, const InferenceArgs& inference,
|
|||
size_t abs_pos = 0; // across turns
|
||||
size_t tokens_generated_this_turn = 0; // differentiates prefill from reply
|
||||
size_t prompt_size = 0;
|
||||
const ModelConfig& config = gemma.GetModelConfig();
|
||||
|
||||
std::mt19937 gen;
|
||||
InitGenerator(inference, gen);
|
||||
|
||||
const bool have_image = !inference.image_file.path.empty();
|
||||
Image image;
|
||||
ImageTokens image_tokens;
|
||||
const size_t pool_dim = config.vit_config.pool_dim;
|
||||
ImageTokens image_tokens(
|
||||
"image_tokens",
|
||||
have_image ? Extents2D(config.vit_config.seq_len / (pool_dim * pool_dim),
|
||||
config.model_dim)
|
||||
: Extents2D(0, 0),
|
||||
MatPadding::kOdd);
|
||||
if (have_image) {
|
||||
size_t pool_dim = gemma.GetModelConfig().vit_config.pool_dim;
|
||||
image_tokens =
|
||||
ImageTokens(gemma.Env().ctx.allocator,
|
||||
Extents2D(gemma.GetModelConfig().vit_config.seq_len /
|
||||
(pool_dim * pool_dim),
|
||||
gemma.GetModelConfig().model_dim));
|
||||
HWY_ASSERT(gemma.GetModelConfig().wrapping == PromptWrapping::PALIGEMMA ||
|
||||
gemma.GetModelConfig().wrapping == PromptWrapping::GEMMA_VLM);
|
||||
HWY_ASSERT(config.wrapping == PromptWrapping::PALIGEMMA ||
|
||||
config.wrapping == PromptWrapping::GEMMA_VLM);
|
||||
HWY_ASSERT(image.ReadPPM(inference.image_file.path));
|
||||
const size_t image_size = gemma.GetModelConfig().vit_config.image_size;
|
||||
const size_t image_size = config.vit_config.image_size;
|
||||
image.Resize(image_size, image_size);
|
||||
RuntimeConfig runtime_config = {.gen = &gen,
|
||||
.verbosity = inference.verbosity,
|
||||
|
|
@ -138,7 +139,7 @@ void ReplGemma(const ThreadingArgs& threading, const InferenceArgs& inference,
|
|||
std::cerr << "." << std::flush;
|
||||
}
|
||||
return true;
|
||||
} else if (gemma.GetModelConfig().IsEOS(token)) {
|
||||
} else if (config.IsEOS(token)) {
|
||||
if (inference.verbosity >= 2) {
|
||||
std::cout << "\n[ End ]\n";
|
||||
}
|
||||
|
|
@ -191,8 +192,8 @@ void ReplGemma(const ThreadingArgs& threading, const InferenceArgs& inference,
|
|||
size_t prefix_end = 0;
|
||||
if (have_image) {
|
||||
prompt = WrapAndTokenize(gemma.Tokenizer(), gemma.ChatTemplate(),
|
||||
gemma.GetModelConfig().wrapping, abs_pos,
|
||||
prompt_string, image_tokens.BatchSize());
|
||||
config.wrapping, abs_pos, prompt_string,
|
||||
image_tokens.Rows());
|
||||
runtime_config.image_tokens = &image_tokens;
|
||||
prompt_size = prompt.size();
|
||||
// The end of the prefix for prefix-LM style attention in Paligemma.
|
||||
|
|
@ -203,8 +204,7 @@ void ReplGemma(const ThreadingArgs& threading, const InferenceArgs& inference,
|
|||
// runtime_config.prefill_tbatch_size = prompt_size;
|
||||
} else {
|
||||
prompt = WrapAndTokenize(gemma.Tokenizer(), gemma.ChatTemplate(),
|
||||
gemma.GetModelConfig().wrapping, abs_pos,
|
||||
prompt_string);
|
||||
config.wrapping, abs_pos, prompt_string);
|
||||
prompt_size = prompt.size();
|
||||
}
|
||||
|
||||
|
|
@ -228,8 +228,7 @@ void ReplGemma(const ThreadingArgs& threading, const InferenceArgs& inference,
|
|||
}
|
||||
|
||||
// Prepare for the next turn. Works only for PaliGemma.
|
||||
if (!inference.multiturn ||
|
||||
gemma.GetModelConfig().wrapping == PromptWrapping::PALIGEMMA) {
|
||||
if (!inference.multiturn || config.wrapping == PromptWrapping::PALIGEMMA) {
|
||||
abs_pos = 0; // Start a new turn at position 0.
|
||||
InitGenerator(inference, gen);
|
||||
} else {
|
||||
|
|
@ -254,8 +253,7 @@ void Run(const LoaderArgs& loader, const ThreadingArgs& threading,
|
|||
MatMulEnv env(MakeMatMulEnv(threading));
|
||||
if (inference.verbosity >= 2) env.print_best = true;
|
||||
const Gemma gemma(loader, env);
|
||||
KVCache kv_cache =
|
||||
KVCache::Create(gemma.GetModelConfig(), inference.prefill_tbatch_size);
|
||||
KVCache kv_cache(gemma.GetModelConfig(), inference.prefill_tbatch_size);
|
||||
|
||||
if (inference.verbosity >= 1) {
|
||||
std::string instructions =
|
||||
|
|
|
|||
|
|
@ -92,9 +92,8 @@ void BenchMatMul(size_t M, size_t K, size_t N, bool add, MatMulEnv& env) {
|
|||
const Extents2D B_extents(N, K); // already transposed
|
||||
const Extents2D C_extents(M, N);
|
||||
|
||||
RowVectorBatch<TC> c_slow_batch =
|
||||
AllocateAlignedRows<TC>(allocator, C_extents);
|
||||
RowVectorBatch<TC> c_batch = AllocateAlignedRows<TC>(allocator, C_extents);
|
||||
MatStorageT<TC> c_slow_batch("c_slow_batch", C_extents, MatPadding::kOdd);
|
||||
MatStorageT<TC> c_batch("c_batch", C_extents, MatPadding::kOdd);
|
||||
|
||||
MatStorageT<float> add_storage("add", Extents2D(), MatPadding::kPacked);
|
||||
if (add) {
|
||||
|
|
@ -104,11 +103,9 @@ void BenchMatMul(size_t M, size_t K, size_t N, bool add, MatMulEnv& env) {
|
|||
|
||||
MatStorageT<TA> a = GenerateMat<TA>(A_extents, pool);
|
||||
MatStorageT<TB> b_trans = GenerateTransposedMat<TB>(B_extents, pool);
|
||||
const auto A = ConstMatFromWeights(a);
|
||||
const auto B = ConstMatFromWeights(b_trans);
|
||||
|
||||
const float* add_row = add ? add_storage.PackedScale1() : nullptr;
|
||||
const RowPtr<TC> C = RowPtrFromBatch(allocator, c_batch);
|
||||
const RowPtr<TC> C = RowPtrFromMat(allocator, c_batch);
|
||||
|
||||
// Fewer reps for large batch sizes, which take longer.
|
||||
const size_t num_samples = M < 32 ? 20 : 12;
|
||||
|
|
@ -118,7 +115,8 @@ void BenchMatMul(size_t M, size_t K, size_t N, bool add, MatMulEnv& env) {
|
|||
// Ensure usage conditions are set before autotuning. Both binding and
|
||||
// spinning may materially affect the choice of config. No harm in calling
|
||||
// BindB/C if there is a single package: they will be a no-op.
|
||||
BindB(allocator, B_extents.rows, sizeof(TC), B, env.parallel);
|
||||
BindB(allocator, B_extents.rows, sizeof(TC), ConstMat<TB>(b_trans),
|
||||
env.parallel);
|
||||
BindC(allocator, A_extents.rows, C, env.parallel);
|
||||
|
||||
Tristate use_spinning = Tristate::kDefault;
|
||||
|
|
@ -133,7 +131,7 @@ void BenchMatMul(size_t M, size_t K, size_t N, bool add, MatMulEnv& env) {
|
|||
// Until enough samples collected *after* autotuning finished:
|
||||
while (times.size() < num_samples) {
|
||||
const double t0 = hwy::platform::Now();
|
||||
per_key = MatMul(A, B, add_row, env, C);
|
||||
per_key = MatMul(a, b_trans, add_row, env, C);
|
||||
const double t1 = hwy::platform::Now();
|
||||
double elapsed = t1 - t0;
|
||||
keep += hwy::ConvertScalarTo<double>(C.Row(0)[hwy::Unpredictable1()]);
|
||||
|
|
|
|||
|
|
@ -26,6 +26,7 @@
|
|||
#include <cmath>
|
||||
#include <random>
|
||||
|
||||
#include "compression/compress.h"
|
||||
#include "compression/shared.h"
|
||||
#include "util/allocator.h"
|
||||
#include "util/test_util.h"
|
||||
|
|
@ -999,7 +1000,6 @@ struct TestShortDotsT {
|
|||
const size_t N = hn::Lanes(d);
|
||||
const hn::ScalableTag<float> df; // for CallDot
|
||||
|
||||
const Allocator& allocator = gcpp::ThreadingContext::Get().allocator;
|
||||
CompressWorkingSet work;
|
||||
std::mt19937 rng;
|
||||
rng.seed(12345);
|
||||
|
|
@ -1010,22 +1010,22 @@ struct TestShortDotsT {
|
|||
// GenerateWellConditionedInputs calls DecompressAndZeroPad to `raw*`,
|
||||
// hence they require padding to one vector.
|
||||
const size_t padded_num = hwy::RoundUpTo(num, N);
|
||||
const size_t packed_num = CompressedArrayElements<Packed>(num);
|
||||
RowVectorBatch<float> raw_w(allocator, Extents2D(1, padded_num));
|
||||
RowVectorBatch<float> raw_v(allocator, Extents2D(1, padded_num));
|
||||
RowVectorBatch<Packed> weights(allocator, Extents2D(1, packed_num));
|
||||
const PackedSpan<Packed> w(weights.Batch(0), packed_num);
|
||||
RowVectorBatch<T> vectors(allocator, Extents2D(1, num));
|
||||
const PackedSpan<T> v(vectors.Batch(0), num);
|
||||
MatStorageT<float> raw_w("raw_w", padded_num);
|
||||
MatStorageT<float> raw_v("raw_v", padded_num);
|
||||
MatStorageT<Packed> weights("weights", padded_num);
|
||||
const PackedSpan<Packed> w = weights.Span();
|
||||
MatStorageT<T> vectors("vectors", padded_num);
|
||||
const PackedSpan<T> v = vectors.Span();
|
||||
|
||||
RowVectorBatch<double> bufs(allocator, Extents2D(1, num));
|
||||
double* HWY_RESTRICT buf = bufs.Batch(0);
|
||||
MatStorageT<double> bufs("bufs", num);
|
||||
double* HWY_RESTRICT buf = bufs.Packed();
|
||||
|
||||
for (size_t rep = 0; rep < hn::AdjustedReps(20); ++rep) {
|
||||
GenerateWellConditionedInputs(num, raw_w.All(), rng, w, work);
|
||||
GenerateWellConditionedInputs(num, raw_v.All(), rng, v, work);
|
||||
GenerateWellConditionedInputs(num, raw_w.Packed(), rng, w, work);
|
||||
GenerateWellConditionedInputs(num, raw_v.Packed(), rng, v, work);
|
||||
|
||||
const float dot_exact = ExactDot(raw_w.All(), raw_v.All(), num, buf);
|
||||
const float dot_exact =
|
||||
ExactDot(raw_w.Packed(), raw_v.Packed(), num, buf);
|
||||
float dots[kVariants];
|
||||
for (size_t variant = 0; variant < kVariants; ++variant) {
|
||||
// Here Packed is not always float, so we must not call kDouble.
|
||||
|
|
@ -1106,7 +1106,6 @@ void TestAllDot() {
|
|||
threading_args.max_lps = kMaxWorkers - 1;
|
||||
ThreadingContext::SetArgs(threading_args);
|
||||
ThreadingContext& ctx = ThreadingContext::Get();
|
||||
const Allocator& allocator = ctx.allocator;
|
||||
|
||||
{ // ensure no profiler zones are active
|
||||
const hn::ScalableTag<float> df;
|
||||
|
|
@ -1118,16 +1117,17 @@ void TestAllDot() {
|
|||
|
||||
constexpr size_t kReps = hn::AdjustedReps(40);
|
||||
const size_t num = 24 * 1024;
|
||||
RowVectorBatch<float> a(allocator, Extents2D(kMaxWorkers, num));
|
||||
RowVectorBatch<float> b(allocator, Extents2D(kMaxWorkers, num));
|
||||
RowVectorBatch<double> bufs(allocator, Extents2D(kMaxWorkers, num));
|
||||
MatStorageT<float> a("a", Extents2D(kMaxWorkers, num), MatPadding::kOdd);
|
||||
MatStorageT<float> b("b", Extents2D(kMaxWorkers, num), MatPadding::kOdd);
|
||||
MatStorageT<double> bufs("bufs", Extents2D(kMaxWorkers, num),
|
||||
MatPadding::kOdd);
|
||||
std::array<DotStats, kMaxWorkers> all_stats;
|
||||
|
||||
ctx.pools.Cluster(0, 0).Run(
|
||||
0, kReps, [&](const uint32_t rep, size_t thread) {
|
||||
float* HWY_RESTRICT pa = a.Batch(thread);
|
||||
float* HWY_RESTRICT pb = b.Batch(thread);
|
||||
double* HWY_RESTRICT buf = bufs.Batch(thread);
|
||||
float* HWY_RESTRICT pa = a.Row(thread);
|
||||
float* HWY_RESTRICT pb = b.Row(thread);
|
||||
double* HWY_RESTRICT buf = bufs.Row(thread);
|
||||
const PackedSpan<const float> a_span(pa, num);
|
||||
DotStats& stats = all_stats[thread];
|
||||
const double cond =
|
||||
|
|
|
|||
|
|
@ -693,7 +693,6 @@ class MMScaleDemoteAdd {
|
|||
// We manually unroll 2x for higher IPC in batch=1.
|
||||
size_t col_c = range_nc.begin();
|
||||
if (HWY_LIKELY(range_nc.Num() >= 2 * ND)) {
|
||||
HWY_UNROLL(1)
|
||||
for (; col_c <= range_nc.end() - 2 * ND; col_c += 2 * ND) {
|
||||
VD a0, a1; // unused if !kAdd
|
||||
if constexpr (kAdd) {
|
||||
|
|
@ -860,9 +859,8 @@ class MMScaleDemoteAdd {
|
|||
class MMPerPackage {
|
||||
public:
|
||||
template <typename TA>
|
||||
MMPerPackage(const ConstMat<TA>& A, const MMArgs& args,
|
||||
const MMConfig& config, size_t pkg_idx,
|
||||
const IndexRange& range_np)
|
||||
MMPerPackage(const MatPtrT<TA>& A, const MMArgs& args, const MMConfig& config,
|
||||
size_t pkg_idx, const IndexRange& range_np)
|
||||
: args_(args),
|
||||
pkg_idx_(pkg_idx),
|
||||
// May be overwritten with a view of A, if already BF16.
|
||||
|
|
@ -1114,12 +1112,12 @@ class MMPerPackage {
|
|||
});
|
||||
}
|
||||
|
||||
// Decompresses all `M x K` from `A` into `pkg_A`. Assumes `TA` is a seekable
|
||||
// Decompresses all `M x K` from `A` into `A_`. Assumes `TA` is a seekable
|
||||
// type (i.e., not NUQ) so we can use pointer arithmetic.
|
||||
template <typename TA>
|
||||
HWY_NOINLINE void DoDecompressA(const ConstMat<TA>& A, MMParA par_a) const {
|
||||
const IndexRange all_M(0, A.extents.rows);
|
||||
const IndexRange all_K(0, A.extents.cols);
|
||||
HWY_NOINLINE void DoDecompressA(const MatPtrT<TA>& A, MMParA par_a) const {
|
||||
const IndexRange all_M(0, A.Rows());
|
||||
const IndexRange all_K(0, A.Cols());
|
||||
HWY_DASSERT(all_K.Num() == A_.Cols());
|
||||
|
||||
const hn::ScalableTag<BF16> dbf;
|
||||
|
|
@ -1131,10 +1129,9 @@ class MMPerPackage {
|
|||
const size_t col0 = range_K.begin();
|
||||
const size_t cols = range_K.Num();
|
||||
// otherwise, padding overwrites neighbors
|
||||
HWY_DASSERT(cols % NBF == 0 || cols == A.extents.cols);
|
||||
HWY_DASSERT(cols % NBF == 0 || cols == A.Cols());
|
||||
for (size_t row_a : range_M) {
|
||||
const PackedSpan<const TA> from =
|
||||
MakeSpan(A.ptr + A.Row(row_a) + col0, cols);
|
||||
const PackedSpan<const TA> from = MakeSpan(A.Row(row_a) + col0, cols);
|
||||
BF16* HWY_RESTRICT to = A_.Row(row_a) + col0;
|
||||
DecompressAndZeroPad(dbf, from, 0, to, cols);
|
||||
// Verify that we zero-padded.
|
||||
|
|
@ -1174,18 +1171,14 @@ class MMPerPackage {
|
|||
|
||||
// Autotuning wrapper for `DoDecompressA`.
|
||||
template <typename TA>
|
||||
HWY_INLINE RowPtrBF DecompressA(const ConstMat<TA>& A) const {
|
||||
HWY_INLINE RowPtrBF DecompressA(const MatPtrT<TA>& A) const {
|
||||
const Allocator& allocator = args_.env->ctx.allocator;
|
||||
MMAutoTune<MMParA>& autotune = args_.per_key->autotune_par_a[pkg_idx_];
|
||||
// If already BF16, maybe return a view:
|
||||
if constexpr (hwy::IsSame<TA, BF16>()) {
|
||||
// Only if no zero-padding required.
|
||||
const size_t NBF = hn::Lanes(hn::ScalableTag<BF16>());
|
||||
if (HWY_LIKELY(A.extents.cols % NBF == 0)) {
|
||||
const BF16* pos = A.ptr + A.Row(0);
|
||||
return RowPtrBF(allocator, const_cast<BF16*>(pos), A.extents.cols,
|
||||
A.Stride());
|
||||
}
|
||||
if (HWY_LIKELY(A.Cols() % NBF == 0)) return RowPtrFromMat(allocator, A);
|
||||
}
|
||||
|
||||
if (HWY_LIKELY(autotune.Best())) {
|
||||
|
|
@ -1196,7 +1189,7 @@ class MMPerPackage {
|
|||
// First call: generate candidates.
|
||||
if (HWY_UNLIKELY(!autotune.HasCandidates())) {
|
||||
std::vector<MMParA> candidates = {MMParA::kK1, MMParA::kK2, MMParA::kK4};
|
||||
if (A.extents.rows == 1) {
|
||||
if (A.Rows() == 1) {
|
||||
candidates.push_back(MMParA::kNone);
|
||||
} else {
|
||||
candidates.push_back(MMParA::kM);
|
||||
|
|
@ -1247,7 +1240,7 @@ class MMPerPackage {
|
|||
|
||||
const MMArgs args_; // copy for locality
|
||||
const size_t pkg_idx_;
|
||||
RowPtrBF A_; // points into A or storage.
|
||||
RowPtrBF A_; // points into A or pkg_A.
|
||||
|
||||
const IndexRange range_np_;
|
||||
// From MMConfig:
|
||||
|
|
@ -1276,9 +1269,8 @@ struct MMImpl {
|
|||
// Called from `MatMul` from two places: either with the next autotune config,
|
||||
// or with the best config.
|
||||
template <typename TA, typename TB, typename TC>
|
||||
static HWY_NOINLINE void DoMatMul(const ConstMat<TA>& A,
|
||||
const ConstMat<TB>& B, const RowPtr<TC>& C,
|
||||
const MMArgs& args,
|
||||
static HWY_NOINLINE void DoMatMul(const MatPtrT<TA>& A, const ConstMat<TB>& B,
|
||||
const RowPtr<TC>& C, const MMArgs& args,
|
||||
const MMConfig& config) {
|
||||
MMZone matmul_zone;
|
||||
matmul_zone.MaybeEnter("MM.DoMatMul", args);
|
||||
|
|
@ -1313,7 +1305,7 @@ struct MMImpl {
|
|||
//
|
||||
// Uses considerable stack space: at least 40 KiB per thread.
|
||||
template <typename TA, typename TB, typename TC>
|
||||
HWY_NOINLINE MMPerKey* MatMul(const ConstMat<TA>& A, const ConstMat<TB>& B,
|
||||
HWY_NOINLINE MMPerKey* MatMul(const MatPtrT<TA>& A, const ConstMat<TB>& B,
|
||||
const float* HWY_RESTRICT add, MatMulEnv& env,
|
||||
const RowPtr<TC>& C) {
|
||||
const Allocator& allocator = env.ctx.allocator;
|
||||
|
|
@ -1340,7 +1332,7 @@ HWY_NOINLINE MMPerKey* MatMul(const ConstMat<TA>& A, const ConstMat<TB>& B,
|
|||
MMPerKey& per_key = env.per_key[index];
|
||||
MMAutoTune<MMConfig>& tuner = per_key.autotune;
|
||||
|
||||
const MMArgs args(env, per_key, static_cast<double>(A.scale) * B.scale, add,
|
||||
const MMArgs args(env, per_key, static_cast<double>(A.Scale()) * B.scale, add,
|
||||
env.storage.Partial());
|
||||
if (HWY_LIKELY(tuner.Best())) {
|
||||
MMImpl::DoMatMul(A, B, C, args, *tuner.Best());
|
||||
|
|
@ -1396,6 +1388,13 @@ HWY_NOINLINE MMPerKey* MatMul(const ConstMat<TA>& A, const ConstMat<TB>& B,
|
|||
return &per_key;
|
||||
}
|
||||
|
||||
template <typename TA, typename TB, typename TC>
|
||||
HWY_NOINLINE MMPerKey* MatMul(const MatPtrT<TA>& A, const MatPtrT<TB>& B,
|
||||
const float* HWY_RESTRICT add, MatMulEnv& env,
|
||||
const RowPtr<TC>& C) {
|
||||
return MatMul(A, ConstMat<TB>(B), add, env, C);
|
||||
}
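With this overload and the implicit ConstMat(const MatPtrT<T>&) conversion added in matmul.h, callers can pass MatStorageT/MatPtrT operands directly; the BenchMatMul and TestMatMul changes elsewhere in this diff use exactly this shape. A condensed sketch (M, K, N, allocator, env and c_storage are placeholders):

MatStorageT<BF16> a("a", Extents2D(M, K), MatPadding::kOdd);
MatStorageT<BF16> b_trans("b_trans", Extents2D(N, K), MatPadding::kOdd);  // B transposed
const RowPtr<float> C = RowPtrFromMat(allocator, c_storage);
MatMul(a, b_trans, /*add=*/nullptr, env, C);  // B wrapped via the overload above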
|
||||
|
||||
// NOLINTNEXTLINE(google-readability-namespace-comments)
|
||||
} // namespace HWY_NAMESPACE
|
||||
} // namespace gcpp
|
||||
|
|
|
|||
72
ops/matmul.h
|
|
@ -21,6 +21,7 @@
|
|||
#include <stddef.h>
|
||||
#include <stdint.h>
|
||||
|
||||
#include <memory> // std::unique_ptr
|
||||
#include <vector>
|
||||
|
||||
// IWYU pragma: begin_exports
|
||||
|
|
@ -207,24 +208,28 @@ class MMStorage {
|
|||
// of BF16 A and B fit in 32 KiB L1, but there may be `kMaxMR` and `kNR`.
|
||||
static constexpr size_t kMaxKC = 8 * 1024;
|
||||
|
||||
// Internally threaded; must not be called concurrently with the same
|
||||
// `ThreadingContext` (used via `parallel`).
|
||||
MMStorage(const Allocator& allocator, MMParallel& parallel)
|
||||
// Per-worker copies of `partial` would be wasteful. We instead allocate
|
||||
// one instance of the maximum matrix extents because threads write at
|
||||
// false-sharing-free granularity.
|
||||
: partial_storage_(
|
||||
AllocateAlignedRows<double>(allocator, Extents2D(kMaxM, kMaxN))),
|
||||
: partial_storage_("partial_storage", Extents2D(kMaxM, kMaxN),
|
||||
MatPadding::kOdd),
|
||||
// Same stride independent of the actual C.Cols() so we can pre-bind.
|
||||
partial_(allocator, partial_storage_.All(), kMaxN,
|
||||
StrideForCyclicOffsets(kMaxN, allocator.Quantum<double>())) {
|
||||
partial_(allocator, partial_storage_.Row(0), kMaxN,
|
||||
partial_storage_.Stride()) {
|
||||
// Per-package allocation so each can decompress A into its own copy.
|
||||
parallel.ForPkg(MMParallel::kMaxPackages, [&](size_t pkg_idx) {
|
||||
pkg_A_[pkg_idx] =
|
||||
AllocateAlignedRows<BF16>(allocator, Extents2D(kMaxM, kMaxK));
|
||||
pkg_A_[pkg_idx].reset(new MatStorageT<BF16>(
|
||||
"pkg_A", Extents2D(kMaxM, kMaxK), MatPadding::kOdd));
|
||||
|
||||
if (allocator.ShouldBind()) {
|
||||
const size_t node = parallel.Node(pkg_idx);
|
||||
if (!allocator.BindMemory(pkg_A_[pkg_idx].All(),
|
||||
pkg_A_[pkg_idx].NumBytes(), node)) {
|
||||
size_t bytes = pkg_A_[pkg_idx]->Rows() * pkg_A_[pkg_idx]->Stride() *
|
||||
pkg_A_[pkg_idx]->ElementBytes();
|
||||
bytes = hwy::RoundDownTo(bytes, allocator.QuantumBytes());
|
||||
if (!allocator.BindMemory(pkg_A_[pkg_idx]->Row(0), bytes, node)) {
|
||||
HWY_WARN("Failed to bind memory for package %zu", pkg_idx);
|
||||
}
|
||||
}
|
||||
|
|
@ -234,22 +239,20 @@ class MMStorage {
|
|||
BindC(allocator, kMaxM, partial_, parallel);
|
||||
}
|
||||
|
||||
// Returns per-package matrix view. Non-const so that `RowVectorBatch` is
|
||||
// non-const, because `RowPtr` requires a non-const pointer.
|
||||
// Returns per-package matrix view.
|
||||
RowPtrBF A(const Allocator& allocator, size_t pkg_idx,
|
||||
const Extents2D& extents) {
|
||||
const Extents2D& extents) const {
|
||||
HWY_DASSERT(extents.rows <= kMaxM);
|
||||
HWY_DASSERT(extents.cols <= kMaxK);
|
||||
const size_t stride =
|
||||
StrideForCyclicOffsets(extents.cols, allocator.Quantum<BF16>());
|
||||
return RowPtrBF(allocator, pkg_A_[pkg_idx].All(), extents.cols, stride);
|
||||
return RowPtrBF(allocator, const_cast<BF16*>(pkg_A_[pkg_idx]->Row(0)),
|
||||
extents.cols, pkg_A_[pkg_idx]->Stride());
|
||||
}
|
||||
|
||||
RowPtrD Partial() const { return partial_; }
|
||||
|
||||
private:
|
||||
RowVectorBatch<BF16> pkg_A_[MMParallel::kMaxPackages];
|
||||
RowVectorBatch<double> partial_storage_;
|
||||
std::unique_ptr<MatStorageT<BF16>> pkg_A_[MMParallel::kMaxPackages];
|
||||
MatStorageT<double> partial_storage_;
|
||||
RowPtrD partial_;
|
||||
};
|
||||
|
||||
|
|
@ -608,6 +611,8 @@ struct MMPerKey {
|
|||
// Stores state shared across MatMul calls. Non-copyable. `ctx` must outlive
|
||||
// `MatMulEnv`.
|
||||
struct MatMulEnv {
|
||||
// Internally threaded; must not be called concurrently with the same
|
||||
// `ThreadingContext`.
|
||||
explicit MatMulEnv(ThreadingContext& ctx);
|
||||
|
||||
ThreadingContext& ctx;
|
||||
|
|
@ -679,8 +684,8 @@ struct MMZone {
|
|||
#endif // PROFILER_ENABLED
|
||||
|
||||
// Used for the A and B arguments of `MatMul`, which are always const.
|
||||
// Create via MakeConstMat. This differs from `RowPtr` in that it supports the
|
||||
// `ofs` required for compressed T.
|
||||
// This differs from `RowPtr` in supporting the `ofs` required for compressed T.
|
||||
// TODO: remove after splitting W1/W2 and updating QDotK to RowPtr.
|
||||
template <typename T>
|
||||
struct ConstMat {
|
||||
ConstMat() = default;
|
||||
|
|
@ -689,6 +694,12 @@ struct ConstMat {
|
|||
HWY_DASSERT(ptr != nullptr);
|
||||
HWY_DASSERT(stride >= extents.cols);
|
||||
}
|
||||
// Non-explicit so that we can pass `MatPtr` directly to MatMul.
|
||||
ConstMat(const MatPtrT<T>& m)
|
||||
: ConstMat(const_cast<T*>(m.Row(0)), m.Extents(), m.Stride()) {
|
||||
scale = m.Scale();
|
||||
}
|
||||
|
||||
size_t Row(size_t r) const {
|
||||
if constexpr (HWY_IS_DEBUG_BUILD) {
|
||||
if (r >= extents.rows) {
|
||||
|
|
@ -727,31 +738,6 @@ struct ConstMat {
|
|||
size_t ofs;
|
||||
};
|
||||
|
||||
// For deducing T.
|
||||
template <typename T>
|
||||
ConstMat<T> MakeConstMat(T* HWY_RESTRICT ptr, Extents2D extents,
|
||||
size_t stride) {
|
||||
return ConstMat<T>(ptr, extents, stride);
|
||||
}
|
||||
|
||||
// For A argument to MatMul (activations).
|
||||
template <typename T>
|
||||
ConstMat<T> ConstMatFromBatch(size_t batch_size,
|
||||
const RowVectorBatch<T>& row_vectors) {
|
||||
HWY_DASSERT(batch_size <= row_vectors.BatchSize());
|
||||
return MakeConstMat(const_cast<T*>(row_vectors.Const()),
|
||||
Extents2D(batch_size, row_vectors.Cols()),
|
||||
row_vectors.Stride());
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
ConstMat<T> ConstMatFromWeights(const MatPtrT<T>& m) {
|
||||
ConstMat<T> mat =
|
||||
MakeConstMat(const_cast<T*>(m.Row(0)), m.Extents(), m.Stride());
|
||||
mat.scale = m.Scale();
|
||||
return mat;
|
||||
}
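These removed helpers are superseded by the non-explicit ConstMat(const MatPtrT<T>&) constructor added above, which also carries the scale. A before/after sketch taken from the DecodeStepT change earlier in this commit:

// Before: explicit wrapping of A and B.
//   MatMul(ConstMatFromBatch(num_queries, activations.x),
//          ConstMatFromWeights(weights.embedder_input_embedding), ...);
// After: MatPtrT/MatStorageT convert implicitly.
MatMul(activations.x, weights.embedder_input_embedding,
       /*add=*/nullptr, *activations.env,
       RowPtrFromMat(activations.env->ctx.allocator, activations.logits));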
|
||||
|
||||
template <typename TB>
|
||||
void BindB(const Allocator& allocator, size_t N, size_t sizeof_TC,
|
||||
const ConstMat<TB>& B, MMParallel& parallel) {
|
||||
|
|
|
|||
|
|
@ -57,10 +57,10 @@ namespace HWY_NAMESPACE {
|
|||
namespace hn = hwy::HWY_NAMESPACE;
|
||||
|
||||
// Returns 1-norm, used for estimating tolerable numerical differences.
|
||||
double MaxRowAbsSum(const RowVectorBatch<float>& a) {
|
||||
double MaxRowAbsSum(const MatStorageT<float>& a) {
|
||||
double max_row_abs_sum = 0.0;
|
||||
for (size_t r = 0; r < a.BatchSize(); r++) {
|
||||
const float* row = a.Batch(r);
|
||||
for (size_t r = 0; r < a.Rows(); r++) {
|
||||
const float* row = a.Row(r);
|
||||
double row_abs_sum = 0.0;
|
||||
for (size_t c = 0; c < a.Cols(); c++) {
|
||||
row_abs_sum += hwy::ScalarAbs(row[c]);
|
||||
|
|
@ -71,11 +71,11 @@ double MaxRowAbsSum(const RowVectorBatch<float>& a) {
|
|||
}
|
||||
|
||||
// Returns the maximum absolute value of `a`.
|
||||
float MaxAbs(const RowVectorBatch<float>& a) {
|
||||
float MaxAbs(const MatStorageT<float>& a) {
|
||||
float max_abs = 0.0f;
|
||||
for (size_t c = 0; c < a.Cols(); c++) {
|
||||
for (size_t r = 0; r < a.BatchSize(); r++) {
|
||||
const float* row = a.Batch(r);
|
||||
for (size_t r = 0; r < a.Rows(); r++) {
|
||||
const float* row = a.Row(r);
|
||||
max_abs = HWY_MAX(max_abs, hwy::ScalarAbs(row[c]));
|
||||
}
|
||||
}
|
||||
|
|
@ -84,33 +84,29 @@ float MaxAbs(const RowVectorBatch<float>& a) {
|
|||
|
||||
// B is already transposed.
|
||||
template <typename TA, typename TB, typename TC>
|
||||
void AssertClose(const ConstMat<TA>& A, const ConstMat<TB>& B,
|
||||
void AssertClose(const MatPtrT<TA>& A, const MatPtrT<TB>& B,
|
||||
const RowPtr<TC>& C_slow, const RowPtr<TC>& C, int line) {
|
||||
const Allocator& allocator = ThreadingContext::Get().allocator;
|
||||
const hn::ScalableTag<float> df;
|
||||
const size_t cols = A.extents.cols;
|
||||
const size_t B_rows = B.extents.rows;
|
||||
const size_t cols = A.Cols();
|
||||
const size_t B_rows = B.Rows();
|
||||
// Round up for DecompressAndZeroPad.
|
||||
RowVectorBatch<float> a_batch =
|
||||
AllocateAlignedRows<float>(allocator, A.extents);
|
||||
RowVectorBatch<float> b_trans_batch =
|
||||
AllocateAlignedRows<float>(allocator, B.extents);
|
||||
RowVectorBatch<float> c_batch =
|
||||
AllocateAlignedRows<float>(allocator, Extents2D(A.extents.rows, B_rows));
|
||||
RowVectorBatch<float> c_slow_batch =
|
||||
AllocateAlignedRows<float>(allocator, Extents2D(A.extents.rows, B_rows));
|
||||
HWY_ASSERT(A.ofs == 0 && B.ofs == 0);
|
||||
for (size_t m = 0; m < A.extents.rows; ++m) {
|
||||
DecompressAndZeroPad(df, MakeSpan(A.ptr + A.Row(m), cols), 0,
|
||||
a_batch.Batch(m), cols);
|
||||
DecompressAndZeroPad(df, MakeSpan(C.Row(m), B_rows), 0, c_batch.Batch(m),
|
||||
MatStorageT<float> a_batch("a_batch", A.Extents(), MatPadding::kOdd);
|
||||
MatStorageT<float> b_trans_batch("b_trans_batch", B.Extents(),
|
||||
MatPadding::kOdd);
|
||||
MatStorageT<float> c_batch("c_batch", Extents2D(A.Rows(), B_rows),
|
||||
MatPadding::kOdd);
|
||||
MatStorageT<float> c_slow_batch("c_slow_batch", Extents2D(A.Rows(), B_rows),
|
||||
MatPadding::kOdd);
|
||||
for (size_t m = 0; m < A.Rows(); ++m) {
|
||||
DecompressAndZeroPad(df, MakeSpan(A.Row(m), cols), 0, a_batch.Row(m), cols);
|
||||
DecompressAndZeroPad(df, MakeSpan(C.Row(m), B_rows), 0, c_batch.Row(m),
|
||||
B_rows);
|
||||
DecompressAndZeroPad(df, MakeSpan(C_slow.Row(m), B_rows), 0,
|
||||
c_slow_batch.Batch(m), B_rows);
|
||||
c_slow_batch.Row(m), B_rows);
|
||||
}
|
||||
for (size_t n = 0; n < B_rows; ++n) {
|
||||
DecompressAndZeroPad(df, MakeSpan(B.ptr + B.Row(n), cols), 0,
|
||||
b_trans_batch.Batch(n), cols);
|
||||
DecompressAndZeroPad(df, MakeSpan(B.Row(n), cols), 0, b_trans_batch.Row(n),
|
||||
cols);
|
||||
}
|
||||
|
||||
// MatMul rounds inputs to BF16, so error is proportional to the max input
|
||||
|
|
@ -130,10 +126,10 @@ void AssertClose(const ConstMat<TA>& A, const ConstMat<TB>& B,
|
|||
}
|
||||
const double max_rel = 1.0 + hwy::ConvertScalarTo<double>(hwy::Epsilon<TC>());
|
||||
|
||||
for (size_t r = 0; r < A.extents.rows; r++) {
|
||||
const float* expected_row = c_slow_batch.Batch(r);
|
||||
const float* actual_row = c_batch.Batch(r);
|
||||
for (size_t c = 0; c < B.extents.rows; c++) {
|
||||
for (size_t r = 0; r < A.Rows(); r++) {
|
||||
const float* expected_row = c_slow_batch.Row(r);
|
||||
const float* actual_row = c_batch.Row(r);
|
||||
for (size_t c = 0; c < B.Rows(); c++) {
|
||||
const double expected_value = static_cast<double>(expected_row[c]);
|
||||
const double actual_value = static_cast<double>(actual_row[c]);
|
||||
const bool in_range = expected_value - tolerance <= actual_value &&
|
||||
|
|
@ -157,18 +153,17 @@ void AssertClose(const ConstMat<TA>& A, const ConstMat<TB>& B,
|
|||
|
||||
// B is already transposed.
|
||||
template <typename TA, typename TB, typename TC>
|
||||
HWY_INLINE void MatMulSlow(const ConstMat<TA> A, const ConstMat<TB> B,
|
||||
HWY_INLINE void MatMulSlow(const MatPtrT<TA> A, const MatPtrT<TB> B,
|
||||
const float* HWY_RESTRICT add_row, MatMulEnv& env,
|
||||
const RowPtr<TC>& C) {
|
||||
// TA can be any Packed except NuqStream because it uses pointer
|
||||
// arithmetic, because it is the second argument to Dot, which does not
|
||||
// support a v_ofs.
|
||||
static_assert(sizeof(TA) >= sizeof(BF16), "A matrix must be BF16/f32");
|
||||
const float scale = A.scale * B.scale;
|
||||
const float scale = A.Scale() * B.Scale();
|
||||
|
||||
const hn::ScalableTag<float> df; // lane type is ignored
|
||||
const PackedSpan<const TB> b_span =
|
||||
MakeSpan(B.ptr, B.ofs + B.Stride() * B.Extents().rows);
|
||||
const PackedSpan<const TB> b_span = B.Span();
|
||||
const IndexRange all_rows_c(0, A.Extents().rows);
|
||||
const IndexRange all_cols_c(0, C.Cols());
|
||||
|
||||
|
|
@ -191,8 +186,8 @@ HWY_INLINE void MatMulSlow(const ConstMat<TA> A, const ConstMat<TB> B,
|
|||
for (size_t c : cols_c) {
|
||||
const float add = add_row ? add_row[c] : 0.0f;
|
||||
C_row[c] = hwy::ConvertScalarTo<TC>(
|
||||
add + scale * Dot(df, b_span, c * B.Stride(),
|
||||
A.ptr + A.Row(r), A.extents.cols));
|
||||
add + scale * Dot(df, b_span, c * B.Stride(), A.Row(r),
|
||||
A.Cols()));
|
||||
}
|
||||
}
|
||||
});
|
||||
|
|
@ -225,26 +220,23 @@ void TestMatMul(size_t rows_ac, size_t cols_a_rows_b, size_t cols_bc, bool add,
|
|||
|
||||
MatStorageT<TA> a(GenerateMat<TA>(A_extents, pool));
|
||||
MatStorageT<TB> b_trans(GenerateTransposedMat<TB>(B_extents, pool));
|
||||
RowVectorBatch<TC> c_slow_batch =
|
||||
AllocateAlignedRows<TC>(allocator, C_extents);
|
||||
RowVectorBatch<TC> c_batch = AllocateAlignedRows<TC>(allocator, C_extents);
|
||||
MatStorageT<TC> c_slow_batch("c_slow_batch", C_extents, MatPadding::kOdd);
|
||||
MatStorageT<TC> c_batch("c_batch", C_extents, MatPadding::kOdd);
|
||||
|
||||
MatStorageT<float> add_storage =
|
||||
add ? GenerateMat<float>(Extents2D(1, cols_bc), pool)
|
||||
: MatStorageT<float>("add", Extents2D(), MatPadding::kPacked);
|
||||
add_storage.SetScale(1.0f);
|
||||
|
||||
const auto A = ConstMatFromWeights(a);
|
||||
const auto B = ConstMatFromWeights(b_trans);
|
||||
const float* add_row = add ? add_storage.PackedScale1() : nullptr;
|
||||
const RowPtr<TC> C_slow = RowPtrFromBatch(allocator, c_slow_batch);
|
||||
const RowPtr<TC> C = RowPtrFromBatch(allocator, c_batch);
|
||||
const RowPtr<TC> C_slow = RowPtrFromMat(allocator, c_slow_batch);
|
||||
const RowPtr<TC> C = RowPtrFromMat(allocator, c_batch);
|
||||
|
||||
MatMulSlow(A, B, add_row, env, C_slow);
|
||||
MatMulSlow(a, b_trans, add_row, env, C_slow);
|
||||
// A few reps to get coverage of the various autotuned code paths.
|
||||
for (size_t rep = 0; rep < 16; ++rep) {
|
||||
MMPerKey* per_key = MatMul(A, B, add_row, env, C);
|
||||
AssertClose(A, B, C_slow, C, line);
|
||||
MMPerKey* per_key = MatMul(a, b_trans, add_row, env, C);
|
||||
AssertClose(a, b_trans, C_slow, C, line);
|
||||
if (per_key->autotune.Best()) break;
|
||||
}
|
||||
}

257 ops/ops-inl.h

@ -189,10 +189,11 @@ float RMSNormMul(const VT* HWY_RESTRICT x, size_t size) {
|
|||
|
||||
} // namespace detail
|
||||
|
||||
template <typename VecT, typename WeightT, typename OutT>
|
||||
HWY_NOINLINE HWY_MAYBE_UNUSED void RMSNorm(const VecT* HWY_RESTRICT x,
|
||||
const WeightT* HWY_RESTRICT weight,
|
||||
OutT* HWY_RESTRICT out,
|
||||
// `x_ofs` is the offset within `x`, required for NuqStream.
|
||||
template <typename XT, typename WT, typename OT>
|
||||
HWY_NOINLINE HWY_MAYBE_UNUSED void RMSNorm(const XT* HWY_RESTRICT x,
|
||||
const WT* HWY_RESTRICT weight,
|
||||
size_t w_ofs, OT* HWY_RESTRICT out,
|
||||
const size_t size) {
|
||||
PROFILER_FUNC;
|
||||
|
||||
|
|
@ -203,17 +204,17 @@ HWY_NOINLINE HWY_MAYBE_UNUSED void RMSNorm(const VecT* HWY_RESTRICT x,
|
|||
|
||||
const VF mul = hn::Set(df, detail::RMSNormMul(x, size));
|
||||
|
||||
const auto packed_w = MakeSpan(weight, size);
|
||||
const auto packed_v = MakeSpan(x, size);
|
||||
const auto packed_x = MakeSpan(x, size);
|
||||
const auto packed_w = MakeSpan(weight, w_ofs + size);
|
||||
const auto packed_out = MakeSpan(out, size);
|
||||
|
||||
HWY_DASSERT(size % (2 * MaxLanes(df)) == 0);
|
||||
HWY_DASSERT(size % (2 * NF) == 0);
|
||||
for (size_t i = 0; i < size; i += 2 * NF) {
|
||||
VF v0, v1, w0, w1;
|
||||
Decompress2(df, packed_v, i, v0, v1);
|
||||
Decompress2(df, packed_w, i, w0, w1);
|
||||
const VF m0 = hn::Mul(mul, v0);
|
||||
const VF m1 = hn::Mul(mul, v1);
|
||||
VF x0, x1, w0, w1;
|
||||
Decompress2(df, packed_x, i, x0, x1);
|
||||
Decompress2(df, packed_w, w_ofs + i, w0, w1);
|
||||
const VF m0 = hn::Mul(mul, x0);
|
||||
const VF m1 = hn::Mul(mul, x1);
|
||||
// (1+weight) * m = m + weight*m = one FMA.
|
||||
const VF out0 = hn::MulAdd(m0, w0, m0);
|
||||
const VF out1 = hn::MulAdd(m1, w1, m1);
|
||||
|
|
@ -222,9 +223,10 @@ HWY_NOINLINE HWY_MAYBE_UNUSED void RMSNorm(const VecT* HWY_RESTRICT x,
|
|||
}
|
||||
|
||||
// Same as RMSNorm, but its HWY_RESTRICT forbids passing the same pointer.
|
||||
template <typename WeightT, typename VecT>
|
||||
HWY_NOINLINE HWY_MAYBE_UNUSED void RMSNormInplace(
|
||||
const WeightT* HWY_RESTRICT weight, VecT* HWY_RESTRICT inout,
|
||||
template <typename WT, typename XT>
|
||||
HWY_NOINLINE HWY_MAYBE_UNUSED void RMSNormInplace(const WT* HWY_RESTRICT weight,
|
||||
size_t w_ofs,
|
||||
XT* HWY_RESTRICT inout,
|
||||
const size_t size) {
|
||||
PROFILER_FUNC;
|
||||
|
||||
|
|
@ -235,72 +237,112 @@ HWY_NOINLINE HWY_MAYBE_UNUSED void RMSNormInplace(
|
|||
|
||||
const VF mul = hn::Set(df, detail::RMSNormMul(inout, size));
|
||||
|
||||
const auto packed_w = MakeSpan(weight, size);
|
||||
const auto packed_v = MakeSpan(inout, size);
|
||||
const auto packed_w = MakeSpan(weight, w_ofs + size);
|
||||
const auto packed_x = MakeSpan(inout, size);
|
||||
|
||||
HWY_DASSERT(size % (2 * MaxLanes(df)) == 0);
|
||||
HWY_DASSERT(size % (2 * NF) == 0);
|
||||
for (size_t i = 0; i < size; i += 2 * NF) {
|
||||
VF v0, v1, w0, w1;
|
||||
Decompress2(df, MakeConst(packed_v), i, v0, v1);
|
||||
Decompress2(df, packed_w, i, w0, w1);
|
||||
const VF m0 = hn::Mul(mul, v0);
|
||||
const VF m1 = hn::Mul(mul, v1);
|
||||
VF x0, x1, w0, w1;
|
||||
Decompress2(df, packed_x, i, x0, x1);
|
||||
Decompress2(df, packed_w, w_ofs + i, w0, w1);
|
||||
const VF m0 = hn::Mul(mul, x0);
|
||||
const VF m1 = hn::Mul(mul, x1);
|
||||
// (1+weight) * m = m + weight*m = one FMA.
|
||||
const VF out0 = hn::MulAdd(m0, w0, m0);
|
||||
const VF out1 = hn::MulAdd(m1, w1, m1);
|
||||
Compress2(df, out0, out1, packed_v, i);
|
||||
Compress2(df, out0, out1, packed_x, i);
|
||||
}
|
||||
}

// Computes mean mu and mean of squares mu2 of a vector. Used in LayerNorm.
template <typename T>
HWY_NOINLINE void ScalarMus(const T* HWY_RESTRICT a, size_t size, T& mu,
                            T& mu2) {
template <typename XT>
HWY_NOINLINE void ComputeMoments(const XT* HWY_RESTRICT x, size_t size,
                                 double& mu, double& mu2) {
  HWY_ASSERT(size > 0);
  double sum = 0.0;
  double sum2 = 0.0;
  for (size_t i = 0; i < size; ++i) {
    const float f = hwy::ConvertScalarTo<float>(a[i]);
    sum += f;
    sum2 += f * f;
  }
  mu = sum / size;
  mu2 = sum2 / size;
  const hn::ScalableTag<float> df;

  // Use the existing Sum and Dot kernels for simplicity. The second pass
  // is likely not too expensive because it will be in L1.
  const double sum = Sum(df, x, size);
  // We only have one array, so calling `DecompressAndCall` instead of `Dot``
  // avoids loading the 'second' vector again.
  const double sum2 =
      DecompressAndCall(df, MakeSpan(x, size), DotKernelDouble());

  const double inv_size = 1.0 / static_cast<double>(size);
  mu = sum * inv_size;
  mu2 = sum2 * inv_size;
}

// Compare py/flax/linen/normalization.py.
// out = (x - mean) * scale * rsqrt(var + epsilon) + bias
template <typename VecT, typename WeightT, typename OutT>
HWY_NOINLINE void ScalarLayerNorm(const VecT* x,
                                  const WeightT* HWY_RESTRICT scale,
                                  const WeightT* HWY_RESTRICT bias,
                                  OutT* out,
                                  size_t size) {
  constexpr float kEps = 1e-6f;
  VecT mu, mu2;
  ScalarMus(x, size, mu, mu2);
  VecT var = mu2 - mu * mu;
  VecT zero = 0.0f;
  var = HWY_MAX(var, zero);
  var = 1.0f / sqrtf(var + kEps);
  for (size_t j = 0; j < size; j++) {
    const float v = hwy::ConvertScalarTo<float>(x[j]);
    const float s = hwy::ConvertScalarTo<float>(scale[j]);
    const float b = hwy::ConvertScalarTo<float>(bias[j]);
    out[j] = hwy::ConvertScalarTo<OutT>((v - mu) * s * var + b);
// x and out may be the same.
template <typename XT, typename WT, typename OT>
HWY_NOINLINE void LayerNorm(const XT* x, const WT* HWY_RESTRICT scale,
                            const WT* HWY_RESTRICT bias, OT* out, size_t size) {
  PROFILER_FUNC;

  namespace hn = hwy::HWY_NAMESPACE;
  const hn::ScalableTag<float> df;
  using VF = hn::Vec<decltype(df)>;
  const size_t NF = hn::Lanes(df);

  double mu, mu2;
  ComputeMoments(x, size, mu, mu2);
  double var = mu2 - mu * mu;
  var = HWY_MAX(var, 0.0);
  var = 1.0 / sqrt(var + 1E-6);
  const VF vmu = hn::Set(df, static_cast<float>(mu));
  const VF vvar = hn::Set(df, static_cast<float>(var));
  const VF* HWY_RESTRICT pmu = &vmu;
  const VF* HWY_RESTRICT pvar = &vvar;

  const auto packed_x = MakeSpan(x, size);
  const auto packed_scale = MakeSpan(scale, size);
  const auto packed_bias = MakeSpan(bias, size);
  const auto packed_out = MakeSpan(out, size);

  // Loop body for one vector, called from main loop and remainder loop.
  const auto norm = [pmu, pvar](VF x, VF s, VF add) HWY_ATTR -> VF {
    const VF centered = hn::Sub(x, *pmu);
    const VF mul = hn::Mul(s, *pvar);
    return hn::MulAdd(centered, mul, add);
  };

  size_t i = 0;
  if (size >= 2 * NF) {
    for (; i <= size - 2 * NF; i += 2 * NF) {
      VF x0, x1, s0, s1, add0, add1;
      Decompress2(df, packed_x, i, x0, x1);
      Decompress2(df, packed_scale, i, s0, s1);
      Decompress2(df, packed_bias, i, add0, add1);
      const VF n0 = norm(x0, s0, add0);
      const VF n1 = norm(x1, s1, add1);
      Compress2(df, n0, n1, packed_out, i);
    }
  }

template <typename VecT, typename WeightT, typename OutT>
HWY_NOINLINE HWY_MAYBE_UNUSED void LayerNorm(const VecT* x,
                                             const WeightT* HWY_RESTRICT weight,
                                             const WeightT* HWY_RESTRICT bias,
                                             OutT* out,
                                             const size_t size) {
  PROFILER_FUNC;
  // For now we only delegate to the scalar version.
  // TODO: implement vectorized version.
  ScalarLayerNorm(x, weight, bias, out, size);
  const size_t remaining = size - i;
  HWY_DASSERT(remaining < 2 * NF);
  if (HWY_UNLIKELY(remaining != 0)) {
    HWY_ALIGN float buf_x[2 * hn::MaxLanes(df)];
    HWY_ALIGN float buf_scale[2 * hn::MaxLanes(df)];
    HWY_ALIGN float buf_bias[2 * hn::MaxLanes(df)];
    HWY_ALIGN OT buf_out[2 * hn::MaxLanes(df)];
    DecompressAndZeroPad(df, packed_x, i, buf_x, remaining);
    DecompressAndZeroPad(df, packed_scale, i, buf_scale, remaining);
    DecompressAndZeroPad(df, packed_bias, i, buf_bias, remaining);
    const VF x0 = hn::Load(df, buf_x);
    const VF x1 = hn::Load(df, buf_x + NF);
    const VF s0 = hn::Load(df, buf_scale);
    const VF s1 = hn::Load(df, buf_scale + NF);
    const VF add0 = hn::Load(df, buf_bias);
    const VF add1 = hn::Load(df, buf_bias + NF);
    const VF n0 = norm(x0, s0, add0);
    const VF n1 = norm(x1, s1, add1);
    Compress2(df, n0, n1, MakeSpan(buf_out, 2 * NF), 0);
    hwy::CopyBytes(buf_out, out + i, remaining * sizeof(OT));
  }
}

static HWY_NOINLINE HWY_MAYBE_UNUSED void AddAbsolutePositionalEmbeddings(
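
For reference (not part of the diff): a standalone scalar sketch of the formula the vectorized kernel above implements, out = (x - mean) * scale * rsqrt(var + epsilon) + bias, using plain float arrays. It mirrors the ScalarLayerNorm reference kept in ops_test.cc; the function name is hypothetical.

#include <algorithm>
#include <cmath>
#include <cstddef>

// Illustrative scalar LayerNorm over float arrays (eps = 1e-6, as above).
void LayerNormScalarSketch(const float* x, const float* scale,
                           const float* bias, float* out, size_t size) {
  double sum = 0.0, sum2 = 0.0;
  for (size_t i = 0; i < size; ++i) {
    sum += x[i];
    sum2 += x[i] * x[i];
  }
  const double mu = sum / size;
  const double var = std::max(sum2 / size - mu * mu, 0.0);
  const double rsqrt = 1.0 / std::sqrt(var + 1e-6);
  for (size_t i = 0; i < size; ++i) {
    out[i] = static_cast<float>((x[i] - mu) * scale[i] * rsqrt + bias[i]);
  }
}
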

@ -447,39 +489,56 @@ static HWY_NOINLINE HWY_MAYBE_UNUSED void AddFrom(
}

// Simple loops unless/until batch sizes are large enough to parallelize.
template <typename WeightT, typename OutT>
void RMSNormBatched(size_t num_tokens, const float* activations,
                    const WeightT* weights, OutT* out, const size_t model_dim) {
  for (size_t token_idx = 0; token_idx < num_tokens; ++token_idx) {
    RMSNorm(activations + token_idx * model_dim, weights,
            out + token_idx * model_dim, model_dim);
template <typename XT, typename OT>
void RMSNormBatched(const MatPtrT<XT>& activations, const MatPtr& weights,
                    MatPtrT<OT>& out) {
  HWY_DASSERT(weights.Rows() == 1);
  HWY_DASSERT(weights.Cols() == activations.Cols());
  HWY_DASSERT(activations.SameShape(out));

  CallUpcasted(&weights, [&](const auto* weights_t) {
    for (size_t token_idx = 0; token_idx < activations.Rows(); ++token_idx) {
      RMSNorm(activations.Row(token_idx), weights_t->PackedScale1(), 0,
              out.Row(token_idx), activations.Cols());
    }
  });
}

// TODO: pass RowVectorBatch argument.
template <typename WeightT, typename InOutT>
void RMSNormInplaceBatched(size_t num_tokens, const WeightT* weights,
                           InOutT* inout, const size_t model_dim) {
  for (size_t token_idx = 0; token_idx < num_tokens; ++token_idx) {
    RMSNormInplace(weights, inout + token_idx * model_dim, model_dim);
template <typename XT>
void RMSNormInplaceBatched(const MatPtr& weights, MatPtrT<XT>& inout) {
  HWY_DASSERT(weights.Rows() == 1);
  HWY_DASSERT(weights.Cols() == inout.Cols());

  CallUpcasted(&weights, [&](const auto* weights_t) {
    for (size_t token_idx = 0; token_idx < inout.Rows(); ++token_idx) {
      RMSNormInplace(weights_t->PackedScale1(), 0, inout.Row(token_idx),
                     inout.Cols());
    }
  });
}

template <typename VecT, typename WeightT, typename OutT>
void LayerNormBatched(size_t num_tokens, const VecT* x,
                      const WeightT* HWY_RESTRICT weight,
                      const WeightT* HWY_RESTRICT bias, OutT* out,
                      const size_t size) {
  for (size_t token_idx = 0; token_idx < num_tokens; ++token_idx) {
    LayerNorm(x + token_idx * size, weight, bias, out + token_idx * size, size);
// x and out may be the same.
template <typename XT, typename OT>
void LayerNormBatched(const MatPtrT<XT>& x, const MatPtr& weight,
                      const MatPtr& bias, MatPtrT<OT>& out) {
  HWY_DASSERT(weight.Cols() == bias.Cols());
  HWY_DASSERT(weight.Cols() == x.Cols());
  HWY_DASSERT(x.SameShape(out));

  CallUpcastedSame(
      &weight, &bias, [&](const auto* weight_t, const auto* bias_t) {
        for (size_t token_idx = 0; token_idx < x.Rows(); ++token_idx) {
          LayerNorm(x.Row(token_idx), weight_t->PackedScale1(),
                    bias_t->PackedScale1(), out.Row(token_idx), x.Cols());
        }
      });
}

static HWY_INLINE void AddFromBatched(size_t num_tokens, const float* other,
                                      float* x, const size_t model_dim) {
  for (size_t token_idx = 0; token_idx < num_tokens; ++token_idx) {
    AddFrom(other + token_idx * model_dim, x + token_idx * model_dim,
            model_dim);
static HWY_INLINE void AddFromBatched(const MatPtrT<float>& other,
                                      MatPtrT<float>& x) {
  HWY_DASSERT(x.SameShape(other));
  for (size_t token_idx = 0; token_idx < x.Rows(); ++token_idx) {
    AddFrom(other.Row(token_idx), x.Row(token_idx), x.Cols());
  }
}

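A hedged usage sketch of the new batched signatures above: activations and outputs are row-major `MatPtrT`/`MatStorageT`, while the single-row weight tensor stays type-erased and is upcast internally via CallUpcasted. The tensor names are hypothetical; in practice the weights come from the loaded model and the call sites live inside the usual HWY_NAMESPACE blocks.

// Sketch: per-token RMSNorm over a batch of activations.
const size_t num_tokens = 4, model_dim = 2048;
MatStorageT<float> activations("x", Extents2D(num_tokens, model_dim),
                               MatPadding::kOdd);
MatStorageT<float> normed("x_norm", Extents2D(num_tokens, model_dim),
                          MatPadding::kOdd);
// Stand-in for a 1 x model_dim weight tensor; real weights are model-owned.
MatStorageT<float> rms_weights("w_rms", Extents2D(1, model_dim),
                               MatPadding::kPacked);
RMSNormBatched(activations, rms_weights, normed);   // checks shapes via DASSERT
RMSNormInplaceBatched(rms_weights, activations);    // in-place variant
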
@ -743,8 +802,8 @@ HWY_NOINLINE HWY_MAYBE_UNUSED std::vector<TokenAndProb> TopK(
|
|||
HWY_ASSERT(k != 0);
|
||||
HWY_ASSERT(k <= vocab_size);
|
||||
std::vector<double> packed_token_probs;
|
||||
for (int32_t i = 0; i < vocab_size; ++i) {
|
||||
if (accept_token && !accept_token(StaticCast<int>(i), probabilities[i])) {
|
||||
for (int32_t i = 0; i < static_cast<int32_t>(vocab_size); ++i) {
|
||||
if (accept_token && !accept_token(i, probabilities[i])) {
|
||||
continue;
|
||||
}
|
||||
packed_token_probs.push_back(PackTokenAndProb(i, probabilities[i]));
|
||||
|
|
@ -756,7 +815,7 @@ HWY_NOINLINE HWY_MAYBE_UNUSED std::vector<TokenAndProb> TopK(
|
|||
|
||||
std::vector<TokenAndProb> token_probs;
|
||||
token_probs.reserve(k);
|
||||
for (int32_t i = 0; i < k; ++i) {
|
||||
for (int32_t i = 0; i < static_cast<int32_t>(k); ++i) {
|
||||
token_probs.push_back(UnpackTokenAndProb(packed_token_probs[i]));
|
||||
}
|
||||
return token_probs;
|
||||
|
|
@ -770,7 +829,7 @@ HWY_NOINLINE HWY_MAYBE_UNUSED int SampleTopK(
|
|||
TopK(probabilities, vocab_size, k, accept_token);
|
||||
std::vector<int> topk_indices(k);
|
||||
std::vector<float> topk_probs(k);
|
||||
for (int i = 0; i < k; ++i) {
|
||||
for (size_t i = 0; i < k; ++i) {
|
||||
topk_indices[i] = token_probs[i].token;
|
||||
topk_probs[i] = token_probs[i].prob;
|
||||
}
|
||||
|
|
@ -788,7 +847,7 @@ HWY_NOINLINE HWY_MAYBE_UNUSED TokenAndProb FusedSoftmaxAndSampleTopK(
|
|||
TopK(logits, vocab_size, k, accept_token);
|
||||
std::vector<int> topk_indices(k);
|
||||
std::vector<float> topk_logits(k);
|
||||
for (int i = 0; i < token_logits.size(); ++i) {
|
||||
for (size_t i = 0; i < token_logits.size(); ++i) {
|
||||
topk_indices[i] = token_logits[i].token;
|
||||
topk_logits[i] = token_logits[i].prob;
|
||||
}
|
||||
|
|
@ -807,20 +866,20 @@ HWY_NOINLINE HWY_MAYBE_UNUSED TokenAndProb FusedSoftmaxAndSampleTopK(
|
|||
// Input has 4096 (64*64) rows, output has 256 (16*16) rows
|
||||
// Each output row is the average of a 4x4 block of input rows
|
||||
template <typename T>
|
||||
RowVectorBatch<T> AvgPool4x4(RowVectorBatch<T>& input) {
|
||||
const Allocator& allocator = ThreadingContext::Get().allocator;
|
||||
MatStorageT<T> AvgPool4x4(MatStorageT<T>& input) {
|
||||
const Extents2D extents = input.Extents();
|
||||
// Input validation
|
||||
HWY_DASSERT(extents.rows == 4096); // 64 * 64 = 4096 input rows
|
||||
// Create output with 256 rows and same number of columns
|
||||
const size_t out_rows = 256; // 16 * 16 = 256 output rows
|
||||
RowVectorBatch<T> result(allocator, Extents2D(out_rows, extents.cols));
|
||||
MatStorageT<T> result("pool4x4", Extents2D(out_rows, extents.cols),
|
||||
MatPadding::kOdd);
|
||||
const size_t input_dim = 64; // Input is 64×64
|
||||
const size_t output_dim = 16; // Output is 16×16
|
||||
for (size_t out_row_idx = 0; out_row_idx < output_dim; ++out_row_idx) {
|
||||
for (size_t out_col_idx = 0; out_col_idx < output_dim; ++out_col_idx) {
|
||||
size_t out_idx = out_row_idx * output_dim + out_col_idx;
|
||||
T* output_row = result.Batch(out_idx);
|
||||
T* output_row = result.Row(out_idx);
|
||||
// Initialize output row to zeros
|
||||
std::fill(output_row, output_row + extents.cols, 0);
|
||||
// Average 16 row vectors from a 4x4 block
|
||||
|
|
@ -829,9 +888,9 @@ RowVectorBatch<T> AvgPool4x4(RowVectorBatch<T>& input) {
|
|||
size_t in_row_idx = out_row_idx * 4 + i;
|
||||
size_t in_col_idx = out_col_idx * 4 + j;
|
||||
size_t in_idx = in_row_idx * input_dim + in_col_idx;
|
||||
const T* input_row = input.Batch(in_idx);
|
||||
const T* input_row = input.Row(in_idx);
|
||||
// Add each input row to the output
|
||||
// TODO(philculliton): use AddFrom in ops-inl for a vectorized loop.
|
||||
// TODO(philculliton): use AddFrom in `ops-inl` for a vectorized loop.
|
||||
for (size_t col = 0; col < extents.cols; ++col) {
|
||||
output_row[col] += input_row[col];
|
||||
}

@ -20,23 +20,22 @@

#include <cmath>

#include "util/allocator.h"
#include "util/mat.h"
#include "hwy/base.h"

namespace gcpp {

static inline HWY_MAYBE_UNUSED RowVectorBatch<float> CreateInvTimescale(
static inline HWY_MAYBE_UNUSED MatStorageT<float> CreateInvTimescale(
    const Allocator& allocator, size_t qkv_dim, bool half_rope,
    double base_frequency = 10000.0) {
  const size_t rope_dim = half_rope ? qkv_dim / 2 : qkv_dim;
  RowVectorBatch<float> inv_timescale(allocator, Extents2D(1, rope_dim / 2));
  MatStorageT<float> inv_timescale("inv_timescale", rope_dim / 2);
  for (size_t dim = 0; dim < rope_dim / 2; ++dim) {
    const double freq_exponents =
        static_cast<double>(2 * dim) / static_cast<double>(rope_dim);
    // Replacing with expf(ln(1E4) * freq_exponents) changes results
    // noticeably.
    inv_timescale.Batch(0)[dim] =
    inv_timescale.Packed()[dim] =
        static_cast<float>(1.0 / std::pow(base_frequency, freq_exponents));
  }
  return inv_timescale;

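For context (illustrative, not part of the diff): the table above stores rope_dim / 2 inverse timescales, entry d being base_frequency^(-2d / rope_dim). A minimal sketch of how it is consumed, assuming an `allocator` from the active ThreadingContext; the exact pairing of rotated elements follows ScalarRopeAndMulBy in ops_test.cc.

// Sketch: build the table and derive one rotation angle.
MatStorageT<float> inv_ts = CreateInvTimescale(allocator, /*qkv_dim=*/256,
                                               /*half_rope=*/false);
// RopeAndMulBy rotates each embedding pair by roughly pos * inv_ts.Packed()[d].
const float angle0 = /*pos=*/7.0f * inv_ts.Packed()[0];  // d = 0 -> 7.0f
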

141 ops/ops_test.cc

@ -34,7 +34,7 @@
|
|||
#include "gemma/common.h" // ChooseQueryScale
|
||||
#include "util/allocator.h"
|
||||
#include "util/basics.h" // BF16
|
||||
#include "util/mat.h" // RowVectorBatch
|
||||
#include "util/mat.h" // MatStorageT
|
||||
#include "util/test_util.h"
|
||||
#include "util/threading_context.h"
|
||||
#include "hwy/tests/hwy_gtest.h"
|
||||
|
|
@ -391,7 +391,7 @@ void TestRopeAndMulBy() {
|
|||
ModelConfig config(Model::GEMMA2_9B, Type::kSFP,
|
||||
ChooseWrapping(Model::GEMMA2_9B));
|
||||
int dim_qkv = config.layer_configs[0].qkv_dim;
|
||||
RowVectorBatch<float> x(allocator, Extents2D(1, dim_qkv));
|
||||
MatStorageT<float> x("x", dim_qkv);
|
||||
|
||||
std::mt19937 gen;
|
||||
gen.seed(0x12345678);
|
||||
|
|
@ -399,43 +399,43 @@ void TestRopeAndMulBy() {
|
|||
auto random_float = [&r, &gen] { return r(gen); };
|
||||
|
||||
for (int i = 0; i < dim_qkv; ++i) {
|
||||
x.All()[i] = random_float();
|
||||
x.Packed()[i] = random_float();
|
||||
}
|
||||
|
||||
const float qmul = ChooseQueryScale(config);
|
||||
const float kmul = 1.0;
|
||||
|
||||
std::vector<float> qexpected(dim_qkv);
|
||||
std::vector<float> qactual(dim_qkv);
|
||||
std::vector<float> kexpected(dim_qkv);
|
||||
std::vector<float> kactual(dim_qkv);
|
||||
RowVectorBatch<float> inv_timescale = CreateInvTimescale(
|
||||
MatStorageT<float> qexpected("qexpected", dim_qkv);
|
||||
MatStorageT<float> qactual("qactual", dim_qkv);
|
||||
MatStorageT<float> kexpected("kexpected", dim_qkv);
|
||||
MatStorageT<float> kactual("kactual", dim_qkv);
|
||||
MatStorageT<float> inv_timescale = CreateInvTimescale(
|
||||
allocator, config.layer_configs[0].qkv_dim,
|
||||
config.layer_configs[0].post_qk == PostQKType::HalfRope);
|
||||
// Assert VectorizedRope computation is same as regular rope at different pos.
|
||||
for (int pos = 1; pos < 500; pos++) {
|
||||
// Rope'd Q embeddings
|
||||
hwy::CopyBytes(x.Const(), qactual.data(), dim_qkv);
|
||||
hwy::CopyBytes(x.Const(), qexpected.data(), dim_qkv);
|
||||
ScalarRopeAndMulBy(qmul, qexpected.data(), dim_qkv, inv_timescale.Const(),
|
||||
pos);
|
||||
RopeAndMulBy(qmul, qactual.data(), dim_qkv, inv_timescale.Const(), pos);
|
||||
CopyMat(x, qactual);
|
||||
CopyMat(x, qexpected);
|
||||
ScalarRopeAndMulBy(qmul, qexpected.Packed(), dim_qkv,
|
||||
inv_timescale.Packed(), pos);
|
||||
RopeAndMulBy(qmul, qactual.Packed(), dim_qkv, inv_timescale.Packed(), pos);
|
||||
|
||||
for (int i = 0; i < dim_qkv; ++i) {
|
||||
EXPECT_NEAR(qactual[i], qexpected[i], 1e-4)
|
||||
<< "qIndex:" << i << "qInput:" << qactual[i];
|
||||
EXPECT_NEAR(qactual.Packed()[i], qexpected.Packed()[i], 1e-4)
|
||||
<< "qIndex:" << i << "qInput:" << qactual.Packed()[i];
|
||||
}
|
||||
|
||||
// Rope'd K embeddings
|
||||
hwy::CopyBytes(x.Const(), kactual.data(), dim_qkv);
|
||||
hwy::CopyBytes(x.Const(), kexpected.data(), dim_qkv);
|
||||
ScalarRopeAndMulBy(kmul, kexpected.data(), dim_qkv, inv_timescale.Const(),
|
||||
pos);
|
||||
RopeAndMulBy(kmul, kactual.data(), dim_qkv, inv_timescale.Const(), pos);
|
||||
CopyMat(x, kactual);
|
||||
CopyMat(x, kexpected);
|
||||
ScalarRopeAndMulBy(kmul, kexpected.Packed(), dim_qkv,
|
||||
inv_timescale.Packed(), pos);
|
||||
RopeAndMulBy(kmul, kactual.Packed(), dim_qkv, inv_timescale.Packed(), pos);
|
||||
|
||||
for (int i = 0; i < dim_qkv; ++i) {
|
||||
EXPECT_NEAR(kactual[i], kexpected[i], 1e-4)
|
||||
<< "kIndex:" << i << "kInput:" << kactual[i];
|
||||
EXPECT_NEAR(kactual.Packed()[i], kexpected.Packed()[i], 1e-4)
|
||||
<< "kIndex:" << i << "kInput:" << kactual.Packed()[i];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -451,10 +451,9 @@ HWY_NOINLINE float ScalarSquaredL2(const T* HWY_RESTRICT a, size_t size) {
|
|||
}
|
||||
|
||||
// Supports bf16 and f32 inputs/outputs, which can be in-place.
|
||||
template <typename VecT, typename WeightT, typename OutT>
|
||||
HWY_NOINLINE void ScalarRMSNorm(const VecT* x,
|
||||
const WeightT* HWY_RESTRICT weight, OutT* out,
|
||||
size_t size) {
|
||||
template <typename XT, typename WT, typename OT>
|
||||
HWY_NOINLINE void ScalarRMSNorm(const XT* x, const WT* HWY_RESTRICT weight,
|
||||
OT* out, size_t size) {
|
||||
constexpr float kEps = 1e-6f;
|
||||
float ss = ScalarSquaredL2(x, size);
|
||||
ss = 1.0f / sqrtf(ss / StaticCast<float>(size) + kEps);
|
||||
|
|
@ -462,32 +461,32 @@ HWY_NOINLINE void ScalarRMSNorm(const VecT* x,
|
|||
const float v = hwy::ConvertScalarTo<float>(x[j]);
|
||||
const float w = hwy::ConvertScalarTo<float>(weight[j]);
|
||||
// Note 1.0f centering here
|
||||
out[j] = hwy::ConvertScalarTo<OutT>((1.0f + w) * (ss * v));
|
||||
out[j] = hwy::ConvertScalarTo<OT>((1.0f + w) * (ss * v));
|
||||
}
|
||||
}
|
||||
|
||||
template <typename VecT, typename WeightT, typename OutT>
|
||||
template <typename XT, typename WT, typename OT>
|
||||
void TestRMSNorm(hwy::RandomState& rng) {
|
||||
constexpr size_t kSize = 128;
|
||||
HWY_ALIGN VecT vec[kSize];
|
||||
HWY_ALIGN WeightT weight[kSize];
|
||||
HWY_ALIGN OutT expected[kSize];
|
||||
HWY_ALIGN OutT actual[kSize];
|
||||
HWY_ALIGN XT vec[kSize];
|
||||
HWY_ALIGN WT weight[kSize];
|
||||
HWY_ALIGN OT expected[kSize];
|
||||
HWY_ALIGN OT actual[kSize];
|
||||
|
||||
for (size_t i = 0; i < kSize; ++i) {
|
||||
vec[i] = hwy::ConvertScalarTo<VecT>(RandomGaussian(rng));
|
||||
weight[i] = hwy::ConvertScalarTo<WeightT>(RandomGaussian(rng));
|
||||
vec[i] = hwy::ConvertScalarTo<XT>(RandomGaussian(rng));
|
||||
weight[i] = hwy::ConvertScalarTo<WT>(RandomGaussian(rng));
|
||||
}
|
||||
|
||||
ScalarRMSNorm(vec, weight, expected, kSize);
|
||||
RMSNorm(vec, weight, actual, kSize);
|
||||
RMSNorm(vec, weight, 0, actual, kSize);
|
||||
|
||||
for (size_t i = 0; i < kSize; i++) {
|
||||
const float e = hwy::ConvertScalarTo<float>(expected[i]);
|
||||
const float a = hwy::ConvertScalarTo<float>(actual[i]);
|
||||
if (!IsNear(e, a, 1e-5f)) {
|
||||
HWY_ABORT("RMSNorm %s %s %s mismatch at %zu: %E %E\n", TypeName<VecT>(),
|
||||
TypeName<WeightT>(), TypeName<OutT>(), i, e, a);
|
||||
HWY_ABORT("RMSNorm %s %s %s mismatch at %zu: %E %E\n", TypeName<XT>(),
|
||||
TypeName<WT>(), TypeName<OT>(), i, e, a);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -526,24 +525,64 @@ void TestLayerNormSimple() {
|
|||
}
|
||||
}
|
||||
|
||||
// Note: there is no vectorized implementation of LayerNorm yet. So this test
|
||||
// currently only checks that the scalar version can be called for the below
|
||||
// combinations of float/BF16 inputs and outputs.
|
||||
template <typename VecT, typename WeightT, typename OutT>
|
||||
// Computes mean mu and mean of squares mu2 of a vector. Used in
|
||||
// ScalarLayerNorm.
|
||||
template <typename T>
|
||||
HWY_NOINLINE void ScalarMus(const T* HWY_RESTRICT a, size_t size, double& mu,
|
||||
double& mu2) {
|
||||
HWY_ASSERT(size > 0);
|
||||
double sum = 0.0;
|
||||
double sum2 = 0.0;
|
||||
for (size_t i = 0; i < size; ++i) {
|
||||
const float f = hwy::ConvertScalarTo<float>(a[i]);
|
||||
sum += f;
|
||||
sum2 += f * f;
|
||||
}
|
||||
mu = sum / size;
|
||||
mu2 = sum2 / size;
|
||||
}
|
||||
|
||||
// Compare py/flax/linen/normalization.py.
|
||||
// out = (x - mean) * scale * rsqrt(var + epsilon) + bias
|
||||
template <typename XT, typename WT, typename OT>
|
||||
HWY_NOINLINE void ScalarLayerNorm(const XT* x, const WT* HWY_RESTRICT scale,
|
||||
const WT* HWY_RESTRICT bias, OT* out,
|
||||
size_t size) {
|
||||
constexpr double kEps = 1e-6;
|
||||
double mu, mu2;
|
||||
ScalarMus(x, size, mu, mu2);
|
||||
double var = mu2 - mu * mu;
|
||||
constexpr double kZero = 0.0;
|
||||
var = HWY_MAX(var, kZero);
|
||||
var = 1.0 / sqrt(var + kEps);
|
||||
for (size_t j = 0; j < size; j++) {
|
||||
const float v = hwy::ConvertScalarTo<float>(x[j]);
|
||||
const float s = hwy::ConvertScalarTo<float>(scale[j]);
|
||||
const float b = hwy::ConvertScalarTo<float>(bias[j]);
|
||||
out[j] = hwy::ConvertScalarTo<OT>((v - mu) * s * var + b);
|
||||
}
|
||||
}
|
||||
|
||||
template <typename XT, typename WT, typename OT>
|
||||
void TestLayerNorm(hwy::RandomState& rng) {
|
||||
constexpr size_t kSize = 128;
|
||||
VecT vec[kSize];
|
||||
WeightT weight[kSize];
|
||||
WeightT bias[kSize];
|
||||
OutT expected[kSize];
|
||||
OutT actual[kSize];
|
||||
XT vec[kSize];
|
||||
WT weight[kSize];
|
||||
WT bias[kSize];
|
||||
OT expected[kSize];
|
||||
OT actual[kSize];
|
||||
|
||||
for (size_t i = 0; i < kSize; ++i) {
|
||||
vec[i] = hwy::ConvertScalarTo<VecT>(RandomGaussian(rng));
|
||||
weight[i] = hwy::ConvertScalarTo<WeightT>(RandomGaussian(rng));
|
||||
bias[i] = hwy::ConvertScalarTo<WeightT>(RandomGaussian(rng));
|
||||
vec[i] = hwy::ConvertScalarTo<XT>(RandomGaussian(rng));
|
||||
weight[i] = hwy::ConvertScalarTo<WT>(RandomGaussian(rng));
|
||||
bias[i] = hwy::ConvertScalarTo<WT>(RandomGaussian(rng));
|
||||
}
|
||||
|
||||
double expected_mu, expected_mu2;
|
||||
ScalarMus(vec, kSize, expected_mu, expected_mu2);
|
||||
double actual_mu, actual_mu2;
|
||||
ComputeMoments(vec, kSize, actual_mu, actual_mu2);
|
||||
|
||||
ScalarLayerNorm(vec, weight, bias, expected, kSize);
|
||||
LayerNorm(vec, weight, bias, actual, kSize);
|
||||
|
||||
|
|
@ -551,8 +590,8 @@ void TestLayerNorm(hwy::RandomState& rng) {
|
|||
const float e = hwy::ConvertScalarTo<float>(expected[i]);
|
||||
const float a = hwy::ConvertScalarTo<float>(actual[i]);
|
||||
if (!IsNear(e, a, 1e-5f)) {
|
||||
HWY_ABORT("LayerNorm %s %s %s mismatch at %zu: %E %E\n", TypeName<VecT>(),
|
||||
TypeName<WeightT>(), TypeName<OutT>(), i, e, a);
|
||||
HWY_ABORT("LayerNorm %s %s %s mismatch at %zu: %E %E\n", TypeName<XT>(),
|
||||
TypeName<WT>(), TypeName<OT>(), i, e, a);
|
||||
}
|
||||
}
|
||||
}

@ -55,8 +55,7 @@ PYBIND11_MODULE(configs, py_module) {
      .value("kSFP", Type::kSFP)
      .value("kNUQ", Type::kNUQ)
      .value("kF64", Type::kF64)
      .value("kC64", Type::kC64)
      .value("kU128", Type::kU128);
      .value("kC64", Type::kC64);

  enum_<LayerAttentionType>(py_module, "LayerAttentionType")
      .value("kGemma", LayerAttentionType::kGemma)

@ -168,9 +168,9 @@ class GemmaModel {
|
|||
void SetImage(const py::array_t<float, py::array::c_style |
|
||||
py::array::forcecast>& image) {
|
||||
const gcpp::Gemma& gemma = *gemma_.GetGemma();
|
||||
const gcpp::Allocator& allocator = gemma_.Env().ctx.allocator;
|
||||
if (gemma.GetModelConfig().wrapping != gcpp::PromptWrapping::PALIGEMMA &&
|
||||
gemma.GetModelConfig().wrapping != gcpp::PromptWrapping::GEMMA_VLM) {
|
||||
const gcpp::ModelConfig& config = gemma.GetModelConfig();
|
||||
if (config.wrapping != gcpp::PromptWrapping::PALIGEMMA &&
|
||||
config.wrapping != gcpp::PromptWrapping::GEMMA_VLM) {
|
||||
throw std::invalid_argument("Not a PaliGemma model.");
|
||||
}
|
||||
py::buffer_info buffer = image.request();
|
||||
|
|
@ -182,14 +182,15 @@ class GemmaModel {
|
|||
float* ptr = static_cast<float*>(buffer.ptr);
|
||||
gcpp::Image c_image;
|
||||
c_image.Set(height, width, ptr);
|
||||
const size_t image_size = gemma.GetModelConfig().vit_config.image_size;
|
||||
const size_t image_size = config.vit_config.image_size;
|
||||
c_image.Resize(image_size, image_size);
|
||||
image_tokens_ = gcpp::ImageTokens(
|
||||
allocator, gcpp::Extents2D(gemma.GetModelConfig().vit_config.seq_len,
|
||||
gemma.GetModelConfig().model_dim));
|
||||
image_tokens_.reset(new gcpp::ImageTokens(
|
||||
"image_tokens",
|
||||
gcpp::Extents2D(config.vit_config.seq_len, config.model_dim),
|
||||
gcpp::MatPadding::kOdd));
|
||||
gcpp::RuntimeConfig runtime_config = {.gen = &gemma_.MutableGen(),
|
||||
.verbosity = 0};
|
||||
gemma.GenerateImageTokens(runtime_config, c_image, image_tokens_);
|
||||
gemma.GenerateImageTokens(runtime_config, c_image, *image_tokens_);
|
||||
}
|
||||
|
||||
// Generates a response to the given prompt, using the last set image.
|
||||
|
|
@ -197,9 +198,7 @@ class GemmaModel {
|
|||
std::pair<std::string, std::vector<int>> GenerateWithImage(
|
||||
std::string prompt, size_t max_generated_tokens, float temperature,
|
||||
float seed, gcpp::AcceptFunc accept, std::vector<int> prompt_tokens) {
|
||||
if (image_tokens_.Cols() == 0) {
|
||||
throw std::invalid_argument("No image set.");
|
||||
}
|
||||
if (!image_tokens_) throw std::invalid_argument("No image set.");
|
||||
const gcpp::Gemma& model = *gemma_.GetGemma();
|
||||
gemma_.MutableGen().seed(seed);
|
||||
gcpp::RuntimeConfig& config = gemma_.MutableConfig();
|
||||
|
|
@ -207,7 +206,7 @@ class GemmaModel {
|
|||
config.temperature = temperature;
|
||||
config.verbosity = 0;
|
||||
config.accept_token = accept;
|
||||
config.image_tokens = &image_tokens_;
|
||||
config.image_tokens = image_tokens_.get();
|
||||
std::vector<int> tokens;
|
||||
if (!prompt_tokens.empty()) {
|
||||
if (!prompt.empty()) {
|
||||
|
|
@ -219,7 +218,7 @@ class GemmaModel {
|
|||
} else {
|
||||
tokens = gemma_.WrapAndTokenize(prompt);
|
||||
}
|
||||
tokens.insert(tokens.begin(), image_tokens_.BatchSize(), 0);
|
||||
tokens.insert(tokens.begin(), image_tokens_->Rows(), 0);
|
||||
size_t num_tokens = tokens.size();
|
||||
size_t prefix_end = num_tokens;
|
||||
config.prefill_tbatch_size = num_tokens;
|
||||
|
|
@ -252,7 +251,7 @@ class GemmaModel {
|
|||
|
||||
private:
|
||||
gcpp::GemmaEnv gemma_;
|
||||
gcpp::ImageTokens image_tokens_;
|
||||
std::unique_ptr<gcpp::ImageTokens> image_tokens_;
|
||||
float last_prob_;
|
||||
};
|
||||
|
||||
|
|

@ -117,11 +117,11 @@ static size_t Stride(const Allocator& allocator, const MatPtr& mat,
  }
}

void MatOwner::AllocateFor(MatPtr& mat, MatPadding padding) {
  if (mat.GetType() == Type::kNUQ) padding = MatPadding::kPacked;
void MatOwner::AllocateFor(MatPtr& mat, const MatPadding padding) {
  const bool is_nuq = mat.GetType() == Type::kNUQ;
  const Allocator& allocator = ThreadingContext::Get().allocator;
  const size_t stride = Stride(allocator, mat, padding);
  const size_t num = mat.Rows() * stride;
  const size_t stride = is_nuq ? mat.Cols() : Stride(allocator, mat, padding);
  const size_t num = is_nuq ? mat.PackedBytes() : mat.Rows() * stride;
  // `compress-inl` requires up to 2 BF16 vectors of padding. `MatPadding`
  // might not be enough, hence add extra. `MatT` is at least one byte, which
  // is half of BF16, hence adding `VectorBytes` *elements* is enough.

198 util/mat.h

@ -28,7 +28,7 @@
|
|||
#include "compression/shared.h" // Type
|
||||
#include "gemma/tensor_info.h"
|
||||
#include "io/fields.h"
|
||||
#include "util/allocator.h"
|
||||
#include "util/allocator.h" // AlignedPtr2
|
||||
#include "util/basics.h" // Extents2D
|
||||
// IWYU pragma: end_exports
|
||||
#include "hwy/base.h"
|
||||
|
|
@ -47,7 +47,7 @@ class MatPtr : public IFields {
|
|||
// `name`: see `SetName`. Note that `stride` is initially `cols` and only
|
||||
// differs after deserializing, or calling `SetPtr`.
|
||||
MatPtr(const char* name, Type type, Extents2D extents)
|
||||
: rows_(static_cast<uint32_t>(extents.rows)),
|
||||
: private_rows_(static_cast<uint32_t>(extents.rows)),
|
||||
cols_(static_cast<uint32_t>(extents.cols)) {
|
||||
SetName(name);
|
||||
SetType(type);
|
||||
|
|
@ -74,7 +74,7 @@ class MatPtr : public IFields {
|
|||
bool HasPtr() const { return ptr_ != nullptr; }
|
||||
|
||||
// A single row counts as packed because there is no padding between rows.
|
||||
bool IsPacked() const { return (stride_ == cols_) || (rows_ == 1); }
|
||||
bool IsPacked() const { return (stride_ == cols_) || (Rows() == 1); }
|
||||
|
||||
const void* Packed() const {
|
||||
HWY_DASSERT_M(IsPacked(), name_.c_str());
|
||||
|
|
@ -96,17 +96,17 @@ class MatPtr : public IFields {
|
|||
// Works for any kind of padding.
|
||||
template <typename T>
|
||||
T* MutableRowT(size_t row) const {
|
||||
HWY_DASSERT(row < rows_);
|
||||
HWY_DASSERT(row < Rows());
|
||||
return HWY_RCAST_ALIGNED(T*, ptr_) + row * stride_;
|
||||
}
|
||||
template <typename T>
|
||||
T* RowT(size_t row) {
|
||||
HWY_DASSERT(row < rows_);
|
||||
HWY_DASSERT(row < Rows());
|
||||
return HWY_RCAST_ALIGNED(T*, ptr_) + row * stride_;
|
||||
}
|
||||
template <typename T>
|
||||
const T* RowT(size_t row) const {
|
||||
HWY_DASSERT(row < rows_);
|
||||
HWY_DASSERT(row < Rows());
|
||||
return HWY_RCAST_ALIGNED(const T*, ptr_) + row * stride_;
|
||||
}
|
||||
|
||||
|
|
@ -118,10 +118,22 @@ class MatPtr : public IFields {
    HWY_DASSERT(0 != element_bytes_ && element_bytes_ <= 16);
  }

  bool IsEmpty() const { return rows_ == 0 || cols_ == 0; }
  size_t Rows() const { return rows_; }
  size_t Rows() const {
    return override_rows_ == 0 ? private_rows_ : override_rows_;
  }
  size_t Cols() const { return cols_; }
  Extents2D Extents() const { return Extents2D(rows_, cols_); }
  Extents2D Extents() const { return Extents2D(Rows(), cols_); }
  bool IsEmpty() const { return Rows() == 0 || cols_ == 0; }
  bool SameShape(const MatPtr& other) const {
    return Rows() == other.Rows() && cols_ == other.cols_;
  }
  // Future calls to `Rows()` during this class' lifetime (not serialized)
  // will return this value. Used to set the actual number of rows for
  // activations preallocated according to the batch size.
  void OverrideRows(size_t rows) {
    HWY_ASSERT(rows <= private_rows_);
    override_rows_ = static_cast<uint32_t>(rows);
  }

  // Offset by which to advance pointers to the next row.
  size_t Stride() const { return stride_; }

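A sketch of how OverrideRows is meant to be used by SetBatchSize, per the commit description: activations are allocated once for the maximum batch and then logically shrunk, so Rows(), Extents() and SameShape() all reflect the current batch. The tensor name and sizes below are hypothetical.

// Allocate once for the largest batch.
const size_t model_dim = 2048;
MatStorageT<float> att_out("att_out", Extents2D(/*max_batch=*/64, model_dim),
                           MatPadding::kOdd);
// Later, e.g. from SetBatchSize(16): only the first 16 rows are "live".
att_out.OverrideRows(16);
HWY_ASSERT(att_out.Rows() == 16);  // Rows() now returns the override
// Storage is untouched; OverrideRows(0) falls back to the allocated row count.
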
@ -150,7 +162,7 @@ class MatPtr : public IFields {
|
|||
visitor(type_);
|
||||
visitor(element_bytes_);
|
||||
visitor(num_elements_);
|
||||
visitor(rows_);
|
||||
visitor(private_rows_);
|
||||
visitor(cols_);
|
||||
visitor(scale_);
|
||||
visitor(stride_);
|
||||
|
|
@ -164,11 +176,11 @@ class MatPtr : public IFields {
|
|||
// padding, which is anyway not supported for NUQ because `compress-inl.h`
|
||||
// assumes a contiguous stream for its group indexing.
|
||||
static size_t ComputeNumElements(Type type, Extents2D extents) {
|
||||
const size_t num_elements = extents.Area();
|
||||
size_t num_elements = extents.Area();
|
||||
if (type == Type::kNUQ) {
|
||||
// `CompressedArrayElements` is a wrapper function that has the same
|
||||
// effect, but that requires a template argument, not `type`.
|
||||
return NuqStream::PackedEnd(num_elements);
|
||||
num_elements = NuqStream::PackedEnd(num_elements);
|
||||
}
|
||||
return num_elements;
|
||||
}
|
||||
|
|
@ -184,9 +196,10 @@ class MatPtr : public IFields {
|
|||
// Number of elements to store (including NUQ tables but not padding).
|
||||
// This a function of `type_` and `Extents()` and stored for compatibility.
|
||||
uint32_t num_elements_ = 0;
|
||||
uint32_t rows_ = 0;
|
||||
uint32_t private_rows_ = 0; // Only access via Rows()! See OverrideRows().
|
||||
uint32_t cols_ = 0;
|
||||
float scale_ = 1.0f; // multiplier for each value, for MatMul.
|
||||
|
||||
uint32_t override_rows_ = 0; // not serialized
|
||||
|
||||
// Non-owning pointer, must not be freed. The underlying memory must outlive
|
||||
// this object.
|
||||
|
|
@ -194,6 +207,8 @@ class MatPtr : public IFields {
|
|||
|
||||
// Offset by which to advance pointers to the next row, >= `cols_`.
|
||||
uint32_t stride_;
|
||||
|
||||
float scale_ = 1.0f; // multiplier for each value, for MatMul.
|
||||
};
|
||||
|
||||
// Non-type erased version of `MatPtr`. Although `MatPtr` also provides
|
||||
|
|
@ -202,6 +217,8 @@ class MatPtr : public IFields {
|
|||
template <typename MatT>
|
||||
class MatPtrT : public MatPtr {
|
||||
public:
|
||||
using T = MatT;
|
||||
|
||||
// Called by `MatStorageT`.
|
||||
MatPtrT(const char* name, Extents2D extents)
|
||||
: MatPtr(name, TypeEnum<MatT>(), extents) {}
|
||||
|
|
@ -253,26 +270,67 @@ class MatPtrT : public MatPtr {
|
|||
};
|
||||
|
||||
// Calls `func` with a dynamic_cast of `MatPtr` to `MatPtrT<T>`, plus the
|
||||
// optional `args`. Currently unused but may be used after we move toward
|
||||
// type-erased `WeightsPtrs`.
|
||||
// optional `args`. This supports all types used as weights, which excludes
|
||||
// `kC64` and `kF64` (used only in `backprop/`).
|
||||
template <class Func, typename... Args>
|
||||
decltype(auto) CallUpcasted(Type type, MatPtr* base, const Func& func,
|
||||
decltype(auto) CallUpcasted(const MatPtr* base, const Func& func,
|
||||
Args&&... args) {
|
||||
HWY_ASSERT(base != nullptr);
|
||||
if (type == Type::kF32) {
|
||||
return func(dynamic_cast<MatPtrT<float>*>(base),
|
||||
if (base->GetType() == Type::kF32) {
|
||||
return func(dynamic_cast<const MatPtrT<float>*>(base),
|
||||
std::forward<Args>(args)...);
|
||||
} else if (type == Type::kBF16) {
|
||||
return func(dynamic_cast<MatPtrT<BF16>*>(base),
|
||||
} else if (base->GetType() == Type::kBF16) {
|
||||
return func(dynamic_cast<const MatPtrT<BF16>*>(base),
|
||||
std::forward<Args>(args)...);
|
||||
} else if (type == Type::kSFP) {
|
||||
return func(dynamic_cast<MatPtrT<SfpStream>*>(base),
|
||||
} else if (base->GetType() == Type::kSFP) {
|
||||
return func(dynamic_cast<const MatPtrT<SfpStream>*>(base),
|
||||
std::forward<Args>(args)...);
|
||||
} else if (type == Type::kNUQ) {
|
||||
return func(dynamic_cast<MatPtrT<NuqStream>*>(base),
|
||||
} else if (base->GetType() == Type::kNUQ) {
|
||||
return func(dynamic_cast<const MatPtrT<NuqStream>*>(base),
|
||||
std::forward<Args>(args)...);
|
||||
} else {
|
||||
HWY_ABORT("Type %d unknown.", static_cast<int>(type));
|
||||
HWY_ABORT("Unhandled type %s.", TypeName(base->GetType()));
|
||||
}
|
||||
}
|
||||
|
||||
// Calls `func(base1, base2, args...)`.
|
||||
template <class Func, typename... Args>
|
||||
decltype(auto) CallUpcastedSame(const MatPtr* base1, const MatPtr* base2,
|
||||
const Func& func, Args&&... args) {
|
||||
HWY_ASSERT(base1->GetType() == base2->GetType());
|
||||
if (base1->GetType() == Type::kF32) {
|
||||
return func(dynamic_cast<const MatPtrT<float>*>(base1),
|
||||
dynamic_cast<const MatPtrT<float>*>(base2),
|
||||
std::forward<Args>(args)...);
|
||||
} else if (base1->GetType() == Type::kBF16) {
|
||||
return func(dynamic_cast<const MatPtrT<BF16>*>(base1),
|
||||
dynamic_cast<const MatPtrT<BF16>*>(base2),
|
||||
std::forward<Args>(args)...);
|
||||
} else if (base1->GetType() == Type::kSFP) {
|
||||
return func(dynamic_cast<const MatPtrT<SfpStream>*>(base1),
|
||||
dynamic_cast<const MatPtrT<SfpStream>*>(base2),
|
||||
std::forward<Args>(args)...);
|
||||
} else if (base1->GetType() == Type::kNUQ) {
|
||||
return func(dynamic_cast<const MatPtrT<NuqStream>*>(base1),
|
||||
dynamic_cast<const MatPtrT<NuqStream>*>(base2),
|
||||
std::forward<Args>(args)...);
|
||||
} else {
|
||||
HWY_ABORT("Unhandled type %s.", TypeName(base1->GetType()));
|
||||
}
|
||||
}
|
||||
|
||||
// Like CallUpcasted, but only for activation types: kBF16 and kF32.
|
||||
template <class Func, typename... Args>
|
||||
decltype(auto) CallUpcastedActivation(const MatPtr* base, const Func& func,
|
||||
Args&&... args) {
|
||||
HWY_ASSERT(base != nullptr);
|
||||
if (base->GetType() == Type::kF32) {
|
||||
return func(dynamic_cast<const MatPtrT<float>*>(base),
|
||||
std::forward<Args>(args)...);
|
||||
} else if (base->GetType() == Type::kBF16) {
|
||||
return func(dynamic_cast<const MatPtrT<BF16>*>(base),
|
||||
std::forward<Args>(args)...);
|
||||
} else {
|
||||
HWY_ABORT("Unhandled type %s.", TypeName(base->GetType()));
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -362,8 +420,11 @@ class MatStorageT : public MatPtrT<MatT> {
|
|||
public:
|
||||
MatStorageT(const char* name, Extents2D extents, MatPadding padding)
|
||||
: MatPtrT<MatT>(name, extents) {
|
||||
owner_.AllocateFor(*this, padding);
|
||||
if (extents.Area() != 0) owner_.AllocateFor(*this, padding);
|
||||
}
|
||||
// Shorthand for 1D tensors: packing does not help, hence `kPacked`.
|
||||
MatStorageT(const char* name, size_t cols)
|
||||
: MatStorageT(name, Extents2D(1, cols), MatPadding::kPacked) {}
|
||||
~MatStorageT() = default;
|
||||
|
||||
// Allow move for backprop/activations.
|
||||
|
|
@ -467,81 +528,14 @@ using RowPtrBF = RowPtr<BF16>;
|
|||
using RowPtrF = RowPtr<float>;
|
||||
using RowPtrD = RowPtr<double>;
|
||||
|
||||
// Owns dynamically-allocated aligned memory for a batch of row vectors.
|
||||
// This can be seen as a (batch_size x cols) matrix. Unlike `RowPtr`, this owns
|
||||
// the memory. Unlike `MatPtr`, this lacks metadata.
|
||||
// TODO: replace with `MatStorageT`.
|
||||
// TODO: remove allocator arg once kCyclic is removed.
|
||||
template <typename T>
|
||||
class RowVectorBatch {
|
||||
public:
|
||||
// Default ctor for Activations ctor.
|
||||
RowVectorBatch() = default;
|
||||
// Main ctor, called from Activations::Allocate. If `stride` = 0, the default,
|
||||
// we default to tightly packed rows (`stride = cols`).
|
||||
// WARNING: not all call sites support `stride` != cols.
|
||||
// TODO: once they do, remove stride and behave like AllocateAlignedRows here.
|
||||
RowVectorBatch(const Allocator& allocator, Extents2D extents,
|
||||
size_t stride = 0)
|
||||
: extents_(extents) {
|
||||
if (stride == 0) {
|
||||
stride_ = extents_.cols;
|
||||
} else {
|
||||
HWY_ASSERT(stride >= extents_.cols);
|
||||
stride_ = stride;
|
||||
}
|
||||
// Allow binding the entire matrix.
|
||||
const size_t padded = hwy::RoundUpTo(extents_.rows * stride_,
|
||||
allocator.QuantumBytes() / sizeof(T));
|
||||
mem_ = allocator.Alloc<T>(padded);
|
||||
}
|
||||
|
||||
// Move-only
|
||||
RowVectorBatch(RowVectorBatch&) noexcept = delete;
|
||||
RowVectorBatch& operator=(RowVectorBatch&) noexcept = delete;
|
||||
RowVectorBatch(RowVectorBatch&&) noexcept = default;
|
||||
RowVectorBatch& operator=(RowVectorBatch&&) noexcept = default;
|
||||
|
||||
size_t BatchSize() const { return extents_.rows; }
|
||||
size_t Cols() const { return extents_.cols; }
|
||||
size_t Stride() const { return stride_; }
|
||||
Extents2D Extents() const { return extents_; }
|
||||
|
||||
// Returns the given row vector of length `Cols()`.
|
||||
T* Batch(size_t batch_idx) {
|
||||
HWY_DASSERT(batch_idx < BatchSize());
|
||||
return mem_.get() + batch_idx * stride_;
|
||||
}
|
||||
const T* Batch(size_t batch_idx) const {
|
||||
HWY_DASSERT(batch_idx < BatchSize());
|
||||
return mem_.get() + batch_idx * stride_;
|
||||
}
|
||||
|
||||
// For MatMul or other operations that process the entire batch at once.
|
||||
// TODO: remove once we only use Mat.
|
||||
T* All() { return mem_.get(); }
|
||||
const T* Const() const { return mem_.get(); }
|
||||
size_t NumBytes() const { return BatchSize() * stride_ * sizeof(T); }
|
||||
|
||||
private:
|
||||
AlignedPtr2<T[]> mem_;
|
||||
Extents2D extents_;
|
||||
size_t stride_;
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
RowPtr<T> RowPtrFromBatch(const Allocator& allocator,
|
||||
RowVectorBatch<T>& row_vectors) {
|
||||
return RowPtr<T>(allocator, row_vectors.All(), row_vectors.Cols(),
|
||||
row_vectors.Stride());
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
RowVectorBatch<T> AllocateAlignedRows(const Allocator& allocator,
|
||||
Extents2D extents) {
|
||||
return RowVectorBatch<T>(
|
||||
allocator, extents,
|
||||
StrideForCyclicOffsets(extents.cols,
|
||||
allocator.QuantumBytes() / sizeof(T)));
|
||||
RowPtr<T> RowPtrFromMat(const Allocator& allocator,
|
||||
const MatPtrT<T>& row_vectors) {
|
||||
// RowPtr is non-const for MatMul C, but is also used for A which is const.
|
||||
// Callers are responsible for checking their usage of RowPtr.
|
||||
return RowPtr<T>(allocator, const_cast<T*>(row_vectors.Row(0)),
|
||||
row_vectors.Cols(), row_vectors.Stride());
|
||||
}
|
||||
|
||||
} // namespace gcpp