mirror of https://github.com/google/gemma.cpp.git

1.07x batch decode speedup: more BF16 weights and activations

BF16 att_sums and ffw_out.
Support BF16 B views without decompression.
Support arbitrary types in MulByConstAndAdd, AddFrom.
Also update profiler annotations in ops-inl.h.

PiperOrigin-RevId: 766995010

commit 9efdcfd45c, parent 839a642992
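The visible API change throughout the call sites below is that constructing `Gemma` now also requires `InferenceArgs`, because weight loading consults the batch sizes when deciding whether to decompress weights to BF16 at startup. A minimal sketch of the updated pattern, assuming `threading`, `loader` and `inference` have already been parsed (variable names are illustrative):

    // Sketch only; mirrors the call sites changed in this commit.
    gcpp::MatMulEnv env(MakeMatMulEnv(threading));
    gcpp::Gemma gemma(loader, inference, env);  // InferenceArgs is the new parameter.
    gcpp::KVCache kv_cache(gemma.GetModelConfig(), inference.prefill_tbatch_size);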
@@ -214,6 +214,7 @@ cc_library(
     hdrs = ["gemma/weights.h"],
     deps = [
         ":configs",
+        ":gemma_args",
         ":mat",
         ":model_store",
         ":ops",
@@ -223,9 +224,7 @@ cc_library(
         "//io:blob_store",
         "@highway//:hwy",
         "@highway//:profiler",
-        "@highway//:stats",
         "@highway//:thread_pool",
-        "@highway//:timer",
     ],
 )
@@ -500,7 +499,6 @@ cc_library(
         # Placeholder for internal dep, do not remove.,
         "//io:blob_store",
         "//io",
-        "//compression:types",
         "//paligemma:image",
         "@highway//:hwy",
         "@highway//:nanobenchmark",  # timer
@@ -49,7 +49,7 @@ void InitGenerator(const InferenceArgs& inference, std::mt19937& gen) {
 
 GemmaEnv::GemmaEnv(const LoaderArgs& loader, const ThreadingArgs& threading,
                    const InferenceArgs& inference)
-    : env_(MakeMatMulEnv(threading)), gemma_(loader, env_) {
+    : env_(MakeMatMulEnv(threading)), gemma_(loader, inference, env_) {
   const ModelConfig& config = gemma_.GetModelConfig();
   // Only allocate one for starters because GenerateBatch might not be called.
   kv_caches_.push_back(KVCache(config, inference.prefill_tbatch_size));
@@ -229,8 +229,9 @@ void ShowConfig(const LoaderArgs& loader, const ThreadingArgs& threading,
   threading.Print(inference.verbosity);
   loader.Print(inference.verbosity);
   inference.Print(inference.verbosity);
-  fprintf(stderr, "Model : %s, mmap %d\n",
-          config.Specifier().c_str(), static_cast<int>(loader.map));
+  fprintf(stderr, "Model : %s, to_bf16 %d, mmap %d\n",
+          config.Specifier().c_str(), static_cast<int>(loader.to_bf16),
+          static_cast<int>(loader.map));
 
   if (inference.verbosity >= 2) {
     time_t now = time(nullptr);
@@ -37,7 +37,7 @@ class GemmaTest : public ::testing::Test {
  protected:
  std::vector<std::string> BatchGemmaReply(
      const std::vector<std::string>& inputs) {
-    s_env->SetMaxGeneratedTokens(32);
+    s_env->SetMaxGeneratedTokens(16);
    s_env->MutableConfig().temperature = 0.0f;  // deterministic
    s_env->MutableConfig().verbosity = 2;
    std::vector<std::string> replies;
@@ -52,7 +52,7 @@ int main(int argc, char** argv) {
 
   // Instantiate model and KV Cache
   gcpp::MatMulEnv env(MakeMatMulEnv(threading));
-  gcpp::Gemma gemma(loader, env);
+  gcpp::Gemma gemma(loader, inference, env);
   gcpp::KVCache kv_cache(gemma.GetModelConfig(), inference.prefill_tbatch_size);
   size_t generated = 0;
@@ -39,7 +39,7 @@ class SimplifiedGemma {
         threading_(threading),
         inference_(inference),
         env_(MakeMatMulEnv(threading_)),
-        gemma_(loader_, env_),
+        gemma_(loader_, inference_, env_),
         kv_cache_(gemma_.GetModelConfig(), inference_.prefill_tbatch_size) {
     // Initialize random number generator
     std::random_device rd;
@@ -88,23 +88,14 @@ struct Activations {
     HWY_ASSERT(batch_size != 0);
 
     // For MatMul outputs, precompute their row pointers.
-    const auto init_row_ptrs = [&](MatPtrT<float>& mat) {
-      if (!mat.HasPtr()) return;
-      row_ptrs.push_back(hwy::AllocateAligned<uint8_t*>(mat.Rows()));
-      uint8_t** ptrs = row_ptrs.back().get();
-      for (size_t r = 0; r < mat.Rows(); ++r) {
-        ptrs[r] = mat.RowBytes(r);
-      }
-      mat.AttachRowPtrs(ptrs);
-    };
     // If we forget any MatMul outputs here, debug builds print a warning but
     // fill them in each MatMul call.
-    init_row_ptrs(q);
-    init_row_ptrs(logits);
-    init_row_ptrs(att_sums);
-    init_row_ptrs(C1);
-    init_row_ptrs(C2);
-    init_row_ptrs(ffw_out);
+    q.AllocateAndAttachRowPtrs(row_ptrs);
+    logits.AllocateAndAttachRowPtrs(row_ptrs);
+    att_sums.AllocateAndAttachRowPtrs(row_ptrs);
+    C1.AllocateAndAttachRowPtrs(row_ptrs);
+    C2.AllocateAndAttachRowPtrs(row_ptrs);
+    ffw_out.AllocateAndAttachRowPtrs(row_ptrs);
     // TODO: also init rows for image_tokens.
 
     // Note that BindC on any MatMul output considerably slows down Prefill.
@@ -142,7 +133,7 @@ struct Activations {
   const MatPadding pad_ = MatPadding::kOdd;
 
   MatStorageT<float> x;  // input
-  MatStorageT<float> q;  // query, also KV if MHA.
+  MatStorageT<float> q;  // query
   MatStorageT<float> logits;
 
   // Attention
@@ -150,13 +141,13 @@ struct Activations {
   MatStorageT<float> att;      // attention vector
   MatStorageT<float> att_out;  // attention output
   // Accumulation of attention outputs over heads
-  MatStorageT<float> att_sums;
+  MatStorageT<BF16> att_sums;
 
   // Gated FFW
   MatStorageT<BF16> pre_ffw_rms_out;
-  MatStorageT<float> C1;
+  MatStorageT<float> C1;  // TODO: BF16 after Activation() supports it
   MatStorageT<float> C2;
-  MatStorageT<float> ffw_out;
+  MatStorageT<BF16> ffw_out;
 
   // Griffin
   MatStorageT<float> griffin_x;
@@ -109,7 +109,7 @@ GemmaContext::GemmaContext(const LoaderArgs& loader,
       threading_args(threading_args),
       matmul_env(MakeMatMulEnv(threading_args)),
       active_conversation_name("default"),
-      model(loader, matmul_env) {
+      model(loader, inference_args, matmul_env) {
   std::stringstream ss;
 
   LogDebug("Creating initial ConversationData");
@@ -210,17 +210,9 @@ static HWY_NOINLINE void GriffinRecurrent(const QueriesPos& queries_pos,
   }  // interleaved_idx
 
   // Final linear layer.
-  // TODO: MatMul
-  CallUpcasted(
-      &layer_weights->griffin.linear_out_w, [&](const auto* weights_t) {
-        for (size_t r = 0; r < num_interleaved; ++r) {
-          float* HWY_RESTRICT x = activations.griffin_x.Row(r);
-          float* out_ptr = activations.att_sums.Row(r);
-          MatVecAdd(*weights_t, 0, model_dim, model_dim, x,
-                    layer_weights->griffin.linear_out_biases.PackedScale1(),
-                    out_ptr, pool);
-        }
-      });
+  CallMatMul(activations.griffin_x, layer_weights->griffin.linear_out_w,
+             layer_weights->griffin.linear_out_biases.PackedScale1(),
+             *activations.env, activations.att_sums);
 }  // GriffinRecurrent
 
 // Wrapper class; holds arguments in member variables to shorten call sites.
@@ -1142,7 +1134,10 @@ static HWY_NOINLINE void EmbedImagePatches(const Image& image,
         }
       });
   // Add position embeddings.
-  AddFromBatched(weights.vit_img_pos_embedding, activations.x);
+  CallUpcastedActivation(&weights.vit_img_pos_embedding,
+                         [&](const auto* weights_t) {
+                           AddFromBatched(*weights_t, activations.x);
+                         });
 }
 
 // Prefills the image tokens with the ViT encoder.
@@ -68,13 +68,14 @@ MatMulEnv MakeMatMulEnv(const ThreadingArgs& threading_args) {
   return MatMulEnv(ThreadingContext::Get());
 }
 
-Gemma::Gemma(const LoaderArgs& loader, MatMulEnv& env)
+Gemma::Gemma(const LoaderArgs& loader, const InferenceArgs& inference,
+             MatMulEnv& env)
     : env_(env),
       reader_(loader.weights),
      model_(reader_, loader.tokenizer, loader.wrapping),
      weights_(model_.Config()),
      chat_template_(model_.Tokenizer(), model_.Config().model) {
-  weights_.ReadFromBlobs(model_, reader_, loader.map, mat_owners_,
+  weights_.ReadFromBlobs(model_, reader_, loader, inference, mat_owners_,
                          env.ctx.pools.Pool());
  reader_.CloseFile();
 }
@@ -105,7 +105,8 @@ class Gemma {
  public:
   // Reads weights/config/tokenizer from the `BlobStore` at `loader.weights`.
   // `env` must remain valid for the lifetime of this Gemma.
-  Gemma(const LoaderArgs& loader, MatMulEnv& env);
+  Gemma(const LoaderArgs& loader, const InferenceArgs& inference,
+        MatMulEnv& env);
 
   ~Gemma();
@@ -52,6 +52,7 @@ struct LoaderArgs : public ArgsBase<LoaderArgs> {
   Path tokenizer;
   Path weights;  // weights file location
   Tristate map;
+  Tristate to_bf16;
   Tristate wrapping;
 
   template <class Visitor>
@@ -62,6 +63,8 @@ struct LoaderArgs : public ArgsBase<LoaderArgs> {
             "Path name of model weights (.sbs) file.\n Required argument.\n");
     visitor(map, "map", Tristate::kDefault,
             "Enable memory-mapping? -1 = auto, 0 = no, 1 = yes.");
+    visitor(to_bf16, "to_bf16", Tristate::kDefault,
+            "Convert weights to bf16? -1 = auto, 0 = no, 1 = yes.");
     visitor(wrapping, "wrapping", Tristate::kDefault,
             "Enable prompt wrapping? Specify 0 for pre-2025 format PT models.");
   }
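Callers that build `LoaderArgs` programmatically can set the new member directly instead of passing `--to_bf16` on the command line; a hedged sketch (only the two member names come from the hunk above, the surrounding setup is assumed):

    // Sketch only: force BF16 conversion regardless of the batch-size heuristic.
    loader.to_bf16 = Tristate::kTrue;
    loader.map = Tristate::kFalse;  // to_bf16 takes precedence over map anyway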
@@ -254,7 +254,7 @@ void Run(const LoaderArgs& loader, const ThreadingArgs& threading,
 
   MatMulEnv env(MakeMatMulEnv(threading));
   if (inference.verbosity >= 2) env.print_best = true;
-  const Gemma gemma(loader, env);
+  const Gemma gemma(loader, inference, env);
   KVCache kv_cache(gemma.GetModelConfig(), inference.prefill_tbatch_size);
 
   if (inference.verbosity >= 1) {
gemma/weights.cc (218 changed lines)
@@ -27,6 +27,7 @@
 #include "compression/compress.h"
 #include "compression/types.h"
 #include "gemma/configs.h"
+#include "gemma/gemma_args.h"
 #include "gemma/model_store.h"
 #include "io/blob_store.h"
 #include "ops/matmul.h"  // MMParallel
@@ -37,7 +38,7 @@
 #include "hwy/highway.h"
 #include "hwy/profiler.h"
 
-// TODO: move into foreach_target; this is only used for NUQ Fixup.
+// TODO: move into foreach_target
 #include "compression/compress-inl.h"
 
 namespace gcpp {
@@ -246,16 +247,80 @@ std::vector<uint32_t> ModelWeightsPtrs::AddTensorDataToWriter(
   return serialized_mat_ptrs;
 }
 
+enum class Mode {
+  // Parallel I/O, decompress to BF16. Best for large batch sizes.
+  kReadBF16,
+  // Parallel I/O, insert row-wise padding. Safe default.
+  kRead,
+  // Best for large weights relative to available memory, especially for
+  // frequent invocations of small batches and short sequences. Adds noise to
+  // performance measurements due to I/O variability.
+  kMap
+};
+
+// Decides whether to read or map based on heuristics and user override.
+static Mode ChooseMode(uint64_t file_bytes, const LoaderArgs& loader,
+                       const InferenceArgs& inference) {
+  Tristate to_bf16 = loader.to_bf16;
+  Tristate map = loader.map;
+
+  // Disable mapping if not padded to the base page size.
+  const Allocator& allocator = ThreadingContext::Get().allocator;
+  if (file_bytes % allocator.BasePageBytes() != 0) {
+    if (map != Tristate::kFalse) {  // Do not complain if anyway disabled.
+      HWY_WARN("Unable to map non-padded file (%zu, %zu), reading instead.",
+               static_cast<size_t>(file_bytes >> 10),
+               allocator.BasePageBytes());
+      map = Tristate::kFalse;
+    }
+  }
+
+  // Check for user override:
+  if (to_bf16 == Tristate::kTrue && map == Tristate::kTrue) {
+    HWY_WARN("Cannot have to_bf16 && map, to_bf16 takes precedence.");
+  }
+  if (to_bf16 == Tristate::kTrue) return Mode::kReadBF16;
+  if (map == Tristate::kTrue) return Mode::kMap;
+
+  if (to_bf16 == Tristate::kDefault) {
+    // Heuristic: sub-bf16 compression is not helpful if compute-bound.
+    const size_t batch_size =
+        HWY_MAX(inference.prefill_tbatch_size, inference.decode_qbatch_size);
+    to_bf16 = (batch_size >= 128) ? Tristate::kTrue : Tristate::kFalse;
+  }
+
+  if (map == Tristate::kDefault) {
+    // Heuristic: map if large fraction of total. Do not decide based on
+    // `FreeMiB` because it is generally low.
+    const size_t file_mib = file_bytes >> 20;
+    const size_t total_mib = allocator.TotalMiB();
+    if (file_mib > total_mib) {
+      HWY_WARN("Weight file %zu MiB > detected memory %zu MiB.",
+               static_cast<size_t>(file_mib), total_mib);
+    }
+    // Large fraction of total.
+    map = (file_mib >= total_mib / 3) ? Tristate::kTrue : Tristate::kFalse;
+  }
+
+  // If the `map` heuristic triggers, use that for safety.
+  if (map == Tristate::kTrue) return Mode::kMap;
+  return (to_bf16 == Tristate::kTrue) ? Mode::kReadBF16 : Mode::kRead;
+}
+
 struct TensorToRead {
   MatPtr* mat;
   BlobRange range;
   // Some tensors opt out of padding via kPacked flags.
   MatPadding padding;
+
+  // only for kReadBF16
+  bool keep_type = false;
+  Type prev_type;
 };
 
 // Allocates multiple in parallel and binds to NUMA nodes.
-static void AllocateAndBindAll(const std::vector<TensorToRead>& tensors,
-                               std::vector<MatOwner>& owners,
+static void AllocateAndBindAll(std::vector<TensorToRead>& tensors,
+                               const Mode mode, std::vector<MatOwner>& owners,
                                hwy::ThreadPool& pool) {
   const size_t start = owners.size();
   owners.resize(start + tensors.size());
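To make the defaults easier to reason about, here is a standalone simplification of the `Tristate::kDefault` path of `ChooseMode` above; the explicit user overrides and the page-alignment check are omitted, and the function name is invented for this sketch:

    #include <algorithm>
    #include <cstddef>
    #include <cstdint>

    enum class Mode { kReadBF16, kRead, kMap };

    // BF16 conversion pays off once the batch is large enough to be
    // compute-bound; mapping wins when the file is a large fraction of memory
    // and takes precedence for safety.
    Mode SketchChooseMode(uint64_t file_mib, uint64_t total_mib,
                          size_t prefill_tbatch, size_t decode_qbatch) {
      const bool to_bf16 = std::max(prefill_tbatch, decode_qbatch) >= 128;
      const bool map = file_mib >= total_mib / 3;
      if (map) return Mode::kMap;
      return to_bf16 ? Mode::kReadBF16 : Mode::kRead;
    }
    // Example outcomes: an 8 GiB file on a 64 GiB machine with
    // decode_qbatch_size 256 reads to BF16; the same file with batch size 1
    // uses plain kRead; a 32 GiB file on the same machine maps.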
@@ -264,52 +329,25 @@ static void AllocateAndBindAll(const std::vector<TensorToRead>& tensors,
 
   // Allocate in parallel because faulting in large tensors is slow.
   pool.Run(0, tensors.size(), [&](uint64_t task, size_t /*thread*/) {
-    owners[start + task].AllocateFor(*tensors[task].mat, tensors[task].padding);
+    TensorToRead& tensor = tensors[task];
+    MatPtr& mat = *tensor.mat;
+
+    tensor.prev_type = mat.GetType();
+    // We only care about MatMul inputs; skip F32 or small tensors.
+    if (tensor.prev_type == Type::kF32 || mat.Rows() < 1024) {
+      tensor.keep_type = true;
+      tensor.padding = MatPadding::kPacked;  // single I/O for simplicity
+    } else if (mode == Mode::kReadBF16) {
+      mat.SetType(Type::kBF16);
+    }
+
+    owners[start + task].AllocateFor(*tensor.mat, tensor.padding);
     // TODO(janwas): MatMul outputs will later also be BF16.
-    BindB(*tensors[task].mat, sizeof(float), parallel);
+    BindB(*tensor.mat, sizeof(float), parallel);
   });
 }
 
-// Parallel I/O into allocated memory, or mapped view of file. The latter is
-// better when the file is huge, but page faults add noise to measurements.
-enum class Mode { kRead, kMap };
-
-// Decides whether to read or map based on heuristics and user override.
-static Mode ChooseMode(uint64_t file_bytes, Tristate map) {
-  const Allocator& allocator = ThreadingContext::Get().allocator;
-  // User has explicitly requested a map or read via args.
-  if (map == Tristate::kTrue) return Mode::kMap;
-  if (map == Tristate::kFalse) return Mode::kRead;
-  // Else: use heuristics to choose. Note that `FreeMiB` is generally low
-  // because idle memory is used as cache, so do not use it to decide.
-  const size_t file_mib = file_bytes >> 20;
-  const size_t total_mib = allocator.TotalMiB();
-  if (file_mib > total_mib) {
-    HWY_WARN("Weight file %zu MiB > detected memory %zu MiB.",
-             static_cast<size_t>(file_mib), total_mib);
-  }
-  // Large fraction of total.
-  if (file_mib >= total_mib / 3) return Mode::kMap;
-  // Big enough that even parallel loading wouldn't be quick.
-  if (file_mib > 50 * 1024) return Mode::kMap;
-  return Mode::kRead;
-}
-
-static MapPtr MapFileOrNull(File& file, uint64_t file_bytes) {
-  const Allocator& allocator = ThreadingContext::Get().allocator;
-  if (file_bytes % allocator.BasePageBytes() == 0) {
-    MapPtr mapped = file.Map();
-    if (!mapped) {
-      HWY_WARN("Failed to map file (%zu KiB), reading instead.",
-               static_cast<size_t>(file_bytes >> 10));
-    }
-  } else {
-    HWY_WARN("Unable to map non-padded file (%zu, %zu), reading instead.",
-             static_cast<size_t>(file_bytes >> 10), allocator.BasePageBytes());
-  }
-  return MapPtr();
-}
+// Mode == kMap
 
 static void MapAll(const std::vector<TensorToRead>& tensors,
                    const MapPtr& mapped) {
   PROFILER_ZONE("Startup.Weights.Map");
@@ -326,6 +364,65 @@ static void MapAll(const std::vector<TensorToRead>& tensors,
   }
 }
 
+// Mode == kReadBF16:
+
+template <typename T>
+static void DecompressToBF16(const MatPtr& mat,
+                             const hwy::AlignedFreeUniquePtr<uint8_t[]>& buf) {
+  hwy::HWY_NAMESPACE::ScalableTag<BF16> dbf;
+  const size_t cols = mat.Cols();
+
+  const size_t num_packed = CompressedArrayElements<T>(mat.Extents().Area());
+  const PackedSpan<T> packed{HWY_RCAST_ALIGNED(T*, buf.get()), num_packed};
+
+  size_t packed_ofs = 0;
+  for (size_t r = 0; r < mat.Rows(); ++r, packed_ofs += cols) {
+    HWY_NAMESPACE::DecompressAndZeroPad(
+        dbf, packed, packed_ofs, HWY_RCAST_ALIGNED(BF16*, mat.RowBytes(r)),
+        cols);
+  }
+}
+
+static void ReadAllToBF16(const std::vector<TensorToRead>& tensors,
+                          const BlobReader& reader, hwy::ThreadPool& pool) {
+  PROFILER_ZONE("Startup.Weights.ReadBF16");
+
+  pool.Run(0, tensors.size(), [&](uint64_t task, size_t /*thread*/) {
+    const TensorToRead& tensor = tensors[task];
+    MatPtr& mat = *tensor.mat;
+
+    if (tensor.keep_type) {
+      HWY_ASSERT(reader.file().Read(tensor.range.offset, tensor.range.bytes,
+                                    mat.Packed()));
+      return;
+    }
+
+    // Read to a temporary buffer.
+    const hwy::AlignedFreeUniquePtr<uint8_t[]> buf =
+        hwy::AllocateAligned<uint8_t>(tensor.range.bytes);
+    HWY_ASSERT(
+        reader.file().Read(tensor.range.offset, tensor.range.bytes, buf.get()));
+
+    if constexpr (GEMMA_ENABLE_NUQ) {
+      if (tensor.prev_type == Type::kNUQ) {
+        return DecompressToBF16<NuqStream>(*tensor.mat, buf);
+      }
+    }
+    switch (tensor.prev_type) {
+      case Type::kF32:
+        return DecompressToBF16<float>(*tensor.mat, buf);
+      case Type::kBF16:
+        return DecompressToBF16<BF16>(*tensor.mat, buf);
+      case Type::kSFP:
+        return DecompressToBF16<SfpStream>(*tensor.mat, buf);
+      default:
+        HWY_ABORT("Unsupported type %s", TypeName(tensor.prev_type));
+    }
+  });
+}
+
+// Mode == kRead:
+
 static std::vector<IOBatch> MakeBatches(
     const std::vector<TensorToRead>& tensors, const uint64_t file_bytes) {
   PROFILER_ZONE("Startup.Weights.MakeBatches");
@@ -382,31 +479,36 @@ static void ReadBatches(const BlobReader& reader,
 }
 
 // Aborts on error.
-static void MapOrReadAll(const std::vector<TensorToRead>& tensors,
-                         BlobReader& reader, Tristate map,
-                         std::vector<MatOwner>& mat_owners,
+static void MapOrReadAll(std::vector<TensorToRead>& tensors, BlobReader& reader,
+                         Mode mode, std::vector<MatOwner>& mat_owners,
                          hwy::ThreadPool& pool) {
-  if (ChooseMode(reader.file_bytes(), map) == Mode::kMap) {
-    MapPtr mapped = MapFileOrNull(reader.file(), reader.file_bytes());
-    if (mapped) {
-      MapAll(tensors, mapped);
-      return;
-    }
-  }  // otherwise fall through to read mode
+  if (mode == Mode::kMap) {
+    MapPtr mapped = reader.file().Map();
+    if (mapped) return MapAll(tensors, mapped);
+    HWY_WARN("Failed to map file (%zu KiB), reading instead.",
+             static_cast<size_t>(reader.file_bytes() >> 10));
+    // If we wanted to map but failed, memory is probably not plentiful, so
+    // fall through to kRead because kReadBF16 requires more memory.
+    mode = Mode::kRead;
+  }
 
   {
     PROFILER_ZONE("Startup.Weights.Allocate");
     // NOTE: this changes the stride of `mats`!
-    AllocateAndBindAll(tensors, mat_owners, pool);
+    AllocateAndBindAll(tensors, mode, mat_owners, pool);
   }
 
+  if (mode == Mode::kReadBF16) return ReadAllToBF16(tensors, reader, pool);
+
   const std::vector<IOBatch> batches =
       MakeBatches(tensors, reader.file_bytes());
   ReadBatches(reader, batches, pool);
 }
 
 void ModelWeightsPtrs::ReadFromBlobs(const ModelStore& model,
-                                     BlobReader& reader, Tristate map,
+                                     BlobReader& reader,
+                                     const LoaderArgs& loader,
+                                     const InferenceArgs& inference,
                                      std::vector<MatOwner>& mat_owners,
                                      hwy::ThreadPool& pool) {
   // List of tensors to read/map, and where from.
@@ -427,7 +529,9 @@ void ModelWeightsPtrs::ReadFromBlobs(const ModelStore& model,
     HWY_ABORT("Tensor %s is required but not found in file.", t.mat.Name());
   });
 
-  MapOrReadAll(tensors, reader, map, mat_owners, pool);
+  const Mode mode = ChooseMode(reader.file_bytes(), loader, inference);
+
+  MapOrReadAll(tensors, reader, mode, mat_owners, pool);
 
   {
     PROFILER_ZONE("Startup.Fixup");
@@ -24,6 +24,7 @@
 
 #include "compression/types.h"
 #include "gemma/configs.h"  // ModelConfig
+#include "gemma/gemma_args.h"  // InferenceArgs
 #include "gemma/model_store.h"  // ModelStore
 #include "gemma/tensor_info.h"  // TensorInfoRegistry
 #include "io/blob_store.h"  // BlobWriter
@@ -424,7 +425,8 @@ struct ModelWeightsPtrs {
 
   // Reads tensor data from `BlobStore` or aborts on error. `map` is a user
   // override for whether to map blobs or read them.
-  void ReadFromBlobs(const ModelStore& model, BlobReader& reader, Tristate map,
+  void ReadFromBlobs(const ModelStore& model, BlobReader& reader,
+                     const LoaderArgs& loader, const InferenceArgs& inference,
                      std::vector<MatOwner>& mat_owners, hwy::ThreadPool& pool);
 
   // Adds one blob for each tensor's data and returns all serialized MatPtr.
@@ -73,7 +73,6 @@ cc_library(
         "//:basics",
        "//:threading_context",
        "@highway//:hwy",
-       "@highway//:profiler",
        "@highway//:thread_pool",
    ],
)
@@ -19,10 +19,6 @@
 
 #include <vector>
 
-#pragma push_macro("PROFILER_ENABLED")
-#undef PROFILER_ENABLED
-#define PROFILER_ENABLED 0
-
 #include "compression/types.h"
 #include "ops/matmul.h"  // IWYU pragma: export
 #include "util/allocator.h"
@@ -952,15 +948,11 @@ class MMPerPackage {
         range_np_, MultipleNP(sizeof(TC)), inner_tasks_, pkg_idx_,
         [&](const IndexRange& range_nc) HWY_ATTR {
           HWY_ALIGN BF16 B_storage[B_storage_max_];  // TLS
-          const RowPtrBF B_view(B_storage, K, B_stride);
+          const RowPtrBF B_storage_view(B_storage, K, B_stride);
 
           for (size_t row_b = range_nc.begin(); row_b < range_nc.end();
                row_b += kNR) {
-            {
-              MMZone zone;
-              zone.MaybeEnter("MM.NT.DecB", args_);
-              DecompressB(B, row_b, range_K, B_view);
-            }
+            RowPtrBF B_view = DecompressB(B, row_b, range_K, B_storage_view);
             MMKernel::A2C0(A_view, B_view, mr_, range_M, row_b, K, MMSetC(),
                            args_, C_rows);
           }
@@ -985,17 +977,13 @@ class MMPerPackage {
                         auto out_tag) HWY_ATTR {
     const size_t kc = range_kc.Num();
     const RowPtrBF& A_view = A_.View(range_mc.begin(), range_kc.begin(), kc);
-    const RowPtrBF B_view(
+    const RowPtrBF B_storage_view(
         B_storage, kc,
         Stride(MatPadding::kOdd, kc, sizeof(BF16), line_bytes_));
 
     for (size_t row_b = range_nc.begin(); row_b < range_nc.end();
          row_b += kNR) {
-      {
-        MMZone zone;
-        zone.MaybeEnter("MM.NT_K.DecB", args_);
-        DecompressB(B, row_b, range_kc, B_view);
-      }
+      RowPtrBF B_view = DecompressB(B, row_b, range_kc, B_storage_view);
       MMKernel::A2C0(A_view, B_view, mr_, range_mc, row_b, kc, out_tag, args_,
                      C_rows);
     }
@@ -1051,15 +1039,11 @@ class MMPerPackage {
         [&](const IndexRange& range_mc, const IndexRange& range_nc) HWY_ATTR {
           const RowPtrBF& A_view = A_.View(range_mc.begin(), 0, K);
           HWY_ALIGN BF16 B_storage[B_storage_max_];  // TLS
-          const RowPtrBF B_view(B_storage, K, B_stride);
+          const RowPtrBF B_storage_view(B_storage, K, B_stride);
 
           for (size_t row_b = range_nc.begin(); row_b < range_nc.end();
                row_b += kNR) {
-            {
-              MMZone zone;
-              zone.MaybeEnter("MM.NT_MT.DecB", args_);
-              DecompressB(B, row_b, range_K, B_view);
-            }
+            RowPtrBF B_view = DecompressB(B, row_b, range_K, B_storage_view);
             MMKernel::A2C0(A_view, B_view, mr_, range_mc, row_b, K, MMSetC(),
                            args_, C_rows);
           }
@@ -1081,7 +1065,8 @@ class MMPerPackage {
     // Sequential loop over NC/MC/KC, for when the M/N loops are
     // already parallel. This is B3A2C0 in MOMMS terminology: we read
     // `mc x kc` of A, `nc x kc` of B, update `mc x nc` of `partial`.
-    const auto loop_nc = [&](const RowPtrBF& B_view, const IndexRange& range_mc,
+    const auto loop_nc = [&](const RowPtrBF& B_storage_view,
+                             const IndexRange& range_mc,
                              const IndexRange& range_kc,
                              const IndexRange& range_nc,
                              auto out_tag) HWY_ATTR {
@@ -1090,11 +1075,7 @@ class MMPerPackage {
 
       for (size_t row_b = range_nc.begin(); row_b < range_nc.end();
           row_b += kNR) {
-        {
-          MMZone zone;
-          zone.MaybeEnter("MM.NT_MT_K.DecB", args_);
-          DecompressB(B, row_b, range_kc, B_view);
-        }
+        RowPtrBF B_view = DecompressB(B, row_b, range_kc, B_storage_view);
         MMKernel::A2C0(A_view, B_view, mr_, range_mc, row_b, kc, out_tag, args_,
                        C_rows);
       }
@@ -1103,15 +1084,17 @@ class MMPerPackage {
         ranges_mc_, ranges_nc_, pkg_idx_,
         [&](const IndexRange& range_mc, const IndexRange& range_nc) HWY_ATTR {
           HWY_ALIGN BF16 B_storage[B_storage_max_];  // TLS
-          const RowPtrBF B_view(B_storage, kc_max, B_stride);
+          const RowPtrBF B_storage_view(B_storage, kc_max, B_stride);
 
           // Peel off the first iteration of the kc loop: avoid
           // zero-initializing `partial` by writing into it.
          ranges_kc_.VisitFirst([&](const IndexRange& range_kc) {
-            loop_nc(B_view, range_mc, range_kc, range_nc, MMSetPartial());
+            loop_nc(B_storage_view, range_mc, range_kc, range_nc,
+                    MMSetPartial());
          });
          ranges_kc_.VisitRemaining([&](const IndexRange& range_kc) {
-            loop_nc(B_view, range_mc, range_kc, range_nc, MMAddPartial());
+            loop_nc(B_storage_view, range_mc, range_kc, range_nc,
+                    MMAddPartial());
          });
 
           // Already in parallel section, hence no `kParM`, and
@@ -1228,11 +1211,15 @@ class MMPerPackage {
   // col 0 of `B_view`. Decompressing SFP is relatively cheap on `AVX3_DL`
   // thanks to its large table lookups, and less so on other targets.
   template <typename TB>
-  HWY_INLINE void DecompressB(const MatPtrT<TB>& B, const size_t row_b,
+  HWY_INLINE RowPtrBF DecompressB(const MatPtrT<TB>& B, const size_t row_b,
                               const IndexRange& range_kc,
                               const RowPtrBF& B_view) const {
-    const hn::ScalableTag<BF16> dbf;
+    if constexpr (hwy::IsSame<TB, BF16>()) {
+      return RowPtrBF(const_cast<BF16*>(B.Row(row_b)) + range_kc.begin(),
+                      range_kc.Num(), B.Stride());
+    }
+
+    const hn::ScalableTag<BF16> dbf;
     const PackedSpan<const TB> B_span = B.PaddedSpan();
 
     const size_t kc = range_kc.Num();
@@ -1249,6 +1236,7 @@ class MMPerPackage {
         }
       }
     }
+    return B_view;
   }
 
   const MMArgs args_;  // copy for locality
@@ -1421,5 +1409,3 @@ HWY_NOINLINE MMPerKey* MatMul(const MatPtrT<TA>& A, const MatPtrT<TB>& B,
 HWY_AFTER_NAMESPACE();
 
 #endif  // NOLINT
-
-#pragma pop_macro("PROFILER_ENABLED")
ops/matmul.h (12 changed lines)
@@ -567,6 +567,16 @@ class MMAutoTune {
 
 // Map of previously seen dimensions to index via linear search.
 class MMKeys {
+  // Group batch size into buckets to reduce #auto-tunes.
+  static size_t BucketM(size_t M) {
+    // The first 4 may require their own bucket because `kNT` only works for a
+    // single M range, but that depends on the config's `MR()`.
+    if (M <= 4) return M;
+    if (M <= 16) return 16;
+    if (M <= 64) return 64;
+    return 256;
+  }
+
  public:
   using Key = uint64_t;
   // KeyFromDims will only return this if all dims are zero, which is invalid.
@@ -577,7 +587,7 @@ class MMKeys {
     HWY_DASSERT(M < (Key{1} << 16));  // batch sizes are smaller
     HWY_DASSERT(K < (Key{1} << 24));
     HWY_DASSERT(N < (Key{1} << 24));
-    const Key key = static_cast<Key>(M) | (static_cast<Key>(K) << 16) |
+    const Key key = static_cast<Key>(BucketM(M)) | (static_cast<Key>(K) << 16) |
                     (static_cast<Key>(N) << 40);
     HWY_DASSERT(key != kPadding);
     return key;
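The effect of the new bucketing on auto-tune keys is easiest to see with a few values; a standalone copy of the helper for experimentation (same logic as `MMKeys::BucketM` above):

    #include <cstddef>

    size_t BucketM(size_t M) {
      if (M <= 4) return M;   // tiny batches keep their own auto-tune entry
      if (M <= 16) return 16;
      if (M <= 64) return 64;
      return 256;
    }
    // BucketM(3) == 3, BucketM(5) == BucketM(16) == 16, BucketM(40) == 64,
    // BucketM(512) == 256: batch sizes 5..16 now share one MatMul auto-tune
    // key instead of twelve distinct ones.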
ops/ops-inl.h (125 changed lines)
@@ -117,6 +117,7 @@ HWY_INLINE hn::Vec<D> Gelu(D d, hn::Vec<D> v) {
 
 static HWY_NOINLINE HWY_MAYBE_UNUSED void Gelu(float* HWY_RESTRICT x,
                                                size_t size) {
+  PROFILER_ZONE("ops.Gelu");
   namespace hn = hwy::HWY_NAMESPACE;
   using D = hn::ScalableTag<float>;
   hn::Transform(D(), x, size,
@@ -181,6 +182,8 @@ namespace detail {
 // Shared by RMSNorm and RMSNormInplace.
 template <typename VT>
 float RMSNormMul(const VT* HWY_RESTRICT x, size_t size) {
+  PROFILER_ZONE("ops.RMSNormMul");
+
   const hn::ScalableTag<float> d;
   const float l2 = DecompressAndCall(d, MakeSpan(x, size), DotKernelDefault());
   constexpr float kEps = 1e-6f;  // avoid divide by zero
@@ -195,7 +198,7 @@ HWY_NOINLINE HWY_MAYBE_UNUSED void RMSNorm(const XT* HWY_RESTRICT x,
                                            const WT* HWY_RESTRICT weight,
                                            size_t w_ofs, OT* HWY_RESTRICT out,
                                            const size_t size) {
-  PROFILER_FUNC;
+  PROFILER_ZONE("ops.RMSNorm");
 
   namespace hn = hwy::HWY_NAMESPACE;
   const hn::ScalableTag<float> df;
@@ -228,7 +231,7 @@ HWY_NOINLINE HWY_MAYBE_UNUSED void RMSNormInplace(const WT* HWY_RESTRICT weight,
                                                   size_t w_ofs,
                                                   XT* HWY_RESTRICT inout,
                                                   const size_t size) {
-  PROFILER_FUNC;
+  PROFILER_ZONE("ops.RMSNormInplace");
 
   namespace hn = hwy::HWY_NAMESPACE;
   const hn::ScalableTag<float> df;
@@ -280,7 +283,7 @@ HWY_NOINLINE void ComputeMoments(const XT* HWY_RESTRICT x, size_t size,
 template <typename XT, typename WT, typename OT>
 HWY_NOINLINE void LayerNorm(const XT* x, const WT* HWY_RESTRICT scale,
                             const WT* HWY_RESTRICT bias, OT* out, size_t size) {
-  PROFILER_FUNC;
+  PROFILER_ZONE("ops.LayerNorm");
 
   namespace hn = hwy::HWY_NAMESPACE;
   const hn::ScalableTag<float> df;
@@ -360,7 +363,7 @@ static HWY_NOINLINE HWY_MAYBE_UNUSED void AddAbsolutePositionalEmbeddings(
   }
 }
 
-/* RoPE as in Rotary Position Embeddings from the RoFormer paper
+/* RoPE as in Rotary Position Embeddings from the `RoFormer` paper
    (https://arxiv.org/abs/2104.09864v5). The query and key vectors are rotated
   as a function of their absolute position using the rotation matrix R before
   the self-attention operation. R is a d x d matrix.
@@ -391,7 +394,7 @@ static HWY_NOINLINE HWY_MAYBE_UNUSED void AddAbsolutePositionalEmbeddings(
 static HWY_NOINLINE HWY_MAYBE_UNUSED void Rope(
     float* HWY_RESTRICT x, size_t dim_qkv,
     const float* HWY_RESTRICT inv_timescale, int pos) {
-  PROFILER_FUNC;
+  PROFILER_ZONE("ops.Rope");
   HWY_DASSERT(dim_qkv % 2 == 0);
   const size_t half_dim_qkv = dim_qkv / 2;
   for (size_t dim = 0; dim < half_dim_qkv; ++dim) {
@@ -409,7 +412,7 @@ static HWY_NOINLINE HWY_MAYBE_UNUSED void Rope(
 static HWY_NOINLINE HWY_MAYBE_UNUSED void RopeAndMulBy(
     const float mul, float* HWY_RESTRICT x, size_t dim_qkv,
     const float* HWY_RESTRICT inv_timescale, int pos) {
-  PROFILER_FUNC;
+  PROFILER_ZONE("ops.RopeAndMulBy");
   HWY_DASSERT(dim_qkv % 2 == 0);
   const size_t half_dim_qkv = dim_qkv / 2;
@@ -477,15 +480,45 @@ static HWY_NOINLINE HWY_MAYBE_UNUSED void RopeAndMulBy(
   }
 }
 
-static HWY_NOINLINE HWY_MAYBE_UNUSED void AddFrom(
-    const float* HWY_RESTRICT other, float* HWY_RESTRICT x, const size_t size) {
-  namespace hn = hwy::HWY_NAMESPACE;
-  using D = hn::ScalableTag<float>;
-  using V = hn::Vec<D>;
-
-  hn::Transform1(D(), x, size, other,
-                 [](const auto d, const V x, const V other)
-                     HWY_ATTR { return hn::Add(x, other); });
+template <typename XT>
+static HWY_NOINLINE HWY_MAYBE_UNUSED void AddFrom(const XT* HWY_RESTRICT x,
+                                                  float* HWY_RESTRICT out,
+                                                  const size_t size) {
+  PROFILER_ZONE("ops.AddFrom");
+
+  namespace hn = hwy::HWY_NAMESPACE;
+  const hn::ScalableTag<float> df;
+  const size_t NF = hn::Lanes(df);
+  using VF = hn::Vec<decltype(df)>;
+
+  const auto packed_x = MakeSpan(x, size);
+
+  size_t i = 0;
+  if (size >= 2 * NF) {
+    for (; i <= size - 2 * NF; i += 2 * NF) {
+      VF x0, x1;
+      Decompress2(df, packed_x, i, x0, x1);
+      VF out0 = hn::Load(df, out + i);
+      VF out1 = hn::Load(df, out + i + NF);
+      hn::Store(hn::Add(x0, out0), df, out + i);
+      hn::Store(hn::Add(x1, out1), df, out + i + NF);
+    }
+  }
+
+  const size_t remaining = size - i;
+  const size_t remaining1 = remaining - HWY_MIN(remaining, NF);
+  HWY_DASSERT(remaining < 2 * NF);
+  HWY_DASSERT(remaining1 < NF);
+  if (HWY_UNLIKELY(remaining != 0)) {
+    HWY_ALIGN float buf_x[2 * hn::MaxLanes(df)];
+    DecompressAndZeroPad(df, packed_x, i, buf_x, remaining);
+    const VF x0 = hn::Load(df, buf_x);
+    const VF x1 = hn::Load(df, buf_x + NF);
+    const VF out0 = hn::LoadN(df, out + i, remaining);
+    const VF out1 = hn::LoadN(df, out + i + NF, remaining1);
+    hn::StoreN(hn::Add(x0, out0), df, out + i, remaining);
+    hn::StoreN(hn::Add(x1, out1), df, out + i + NF, remaining1);
+  }
 }
 
 // Simple loops unless/until batch sizes are large enough to parallelize.
@@ -534,17 +567,19 @@ void LayerNormBatched(const MatPtrT<XT>& x, const MatPtr& weight,
   });
 }
 
-static HWY_INLINE void AddFromBatched(const MatPtrT<float>& other,
-                                      MatPtrT<float>& x) {
-  HWY_DASSERT(x.SameShape(other));
-  for (size_t token_idx = 0; token_idx < x.Rows(); ++token_idx) {
-    AddFrom(other.Row(token_idx), x.Row(token_idx), x.Cols());
+template <typename XT>
+static HWY_INLINE void AddFromBatched(const MatPtrT<XT>& x,
+                                      MatPtrT<float>& out) {
+  HWY_DASSERT(out.SameShape(x));
+  for (size_t token_idx = 0; token_idx < out.Rows(); ++token_idx) {
+    AddFrom(x.Row(token_idx), out.Row(token_idx), x.Cols());
   }
 }
 
 static HWY_NOINLINE void MulBy(const float* HWY_RESTRICT other,
                                float* HWY_RESTRICT x, const size_t size,
                                const size_t max_pos) {
+  PROFILER_ZONE("ops.MulBy");
   HWY_DASSERT(max_pos <= size);
   namespace hn = hwy::HWY_NAMESPACE;
   using D = hn::ScalableTag<float>;
@@ -563,6 +598,7 @@ static HWY_INLINE HWY_MAYBE_UNUSED void MulBy(const float* HWY_RESTRICT other,
 
 static HWY_NOINLINE void MulByConst(const float c, float* HWY_RESTRICT x,
                                     const size_t size, const size_t max_pos) {
+  PROFILER_ZONE("ops.MulByConst");
   HWY_DASSERT(max_pos <= size);
   namespace hn = hwy::HWY_NAMESPACE;
   using D = hn::ScalableTag<float>;
@@ -578,22 +614,56 @@ static HWY_INLINE HWY_MAYBE_UNUSED void MulByConst(const float c,
   MulByConst(c, x, size, size);
 }
 
-static HWY_NOINLINE HWY_MAYBE_UNUSED void MulByConstAndAdd(
-    float c, const float* HWY_RESTRICT x, float* HWY_RESTRICT out,
+template <typename XT, typename OT>
+HWY_NOINLINE HWY_MAYBE_UNUSED void MulByConstAndAdd(float c,
+                                                    const XT* HWY_RESTRICT x,
+                                                    OT* HWY_RESTRICT out,
     size_t size) {
+  PROFILER_ZONE("ops.MulByConstAndAdd");
   namespace hn = hwy::HWY_NAMESPACE;
-  using D = hn::ScalableTag<float>;
-  using V = hn::Vec<D>;
-  hn::Transform1(D(), out, size, x,
-                 [c](const auto d, const V v_out, const V v_x) HWY_ATTR {
-                   return hn::MulAdd(v_x, hn::Set(d, c), v_out);
-                 });
+  const hn::ScalableTag<float> df;
+  const size_t NF = hn::Lanes(df);
+  using VF = hn::Vec<decltype(df)>;
+
+  const VF v_c = hn::Set(df, c);
+  const auto packed_x = MakeSpan(x, size);
+  const auto packed_out = MakeSpan(out, size);
+
+  size_t i = 0;
+  if (size >= 2 * NF) {
+    for (; i <= size - 2 * NF; i += 2 * NF) {
+      VF x0, x1, out0, out1;
+      Decompress2(df, packed_x, i, x0, x1);
+      Decompress2(df, packed_out, i, out0, out1);
+      out0 = hn::MulAdd(x0, v_c, out0);
+      out1 = hn::MulAdd(x1, v_c, out1);
+      Compress2(df, out0, out1, packed_out, i);
+    }
+  }
+
+  const size_t remaining = size - i;
+  HWY_DASSERT(remaining < 2 * NF);
+  if (HWY_UNLIKELY(remaining != 0)) {
+    HWY_ALIGN float buf_x[2 * hn::MaxLanes(df)];
+    HWY_ALIGN float buf_out[2 * hn::MaxLanes(df)];
+    DecompressAndZeroPad(df, packed_x, i, buf_x, remaining);
+    DecompressAndZeroPad(df, packed_out, i, buf_out, remaining);
+    const VF x0 = hn::Load(df, buf_x);
+    const VF x1 = hn::Load(df, buf_x + NF);
+    VF out0 = hn::Load(df, buf_out);
+    VF out1 = hn::Load(df, buf_out + NF);
+    out0 = hn::MulAdd(x0, v_c, out0);
+    out1 = hn::MulAdd(x1, v_c, out1);
+    Compress2(df, out0, out1, MakeSpan(buf_out, 2 * NF), 0);
+    hwy::CopyBytes(buf_out, out + i, remaining * sizeof(OT));
+  }
 }
 
 // See below for a specialized version for top-1 sampling.
 static HWY_NOINLINE void Softmax(float* HWY_RESTRICT x, const size_t size,
                                  const size_t mask_pos,
                                  float temperature = 1.0f) {
+  PROFILER_ZONE("ops.Softmax");
   HWY_DASSERT(size != 0);
   HWY_DASSERT(mask_pos <= size);
@@ -733,6 +803,7 @@ static HWY_MAYBE_UNUSED TokenAndProb Top1OfSoftmax(float* HWY_RESTRICT x,
 static HWY_NOINLINE void LogitsSoftCap(const float cap, float* HWY_RESTRICT x,
                                        const size_t size,
                                        const size_t max_pos) {
+  PROFILER_ZONE("ops.LogitsSoftCap");
   HWY_DASSERT(max_pos <= size);
 
   namespace hn = hwy::HWY_NAMESPACE;
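The templated `AddFrom` and `MulByConstAndAdd` are what let the now-BF16 `att_sums`/`ffw_out` activations feed these ops directly. A hedged usage sketch, assuming it sits inside the same SIMD namespace as ops-inl.h so the helpers, `BF16` and `HWY_ALIGN` are visible (buffer contents are arbitrary):

    HWY_ALIGN BF16 bf_row[256] = {};    // e.g. one row of the BF16 att_sums
    HWY_ALIGN float f32_acc[256] = {};

    AddFrom(bf_row, f32_acc, 256);                   // f32_acc += bf_row (BF16 input)
    MulByConstAndAdd(0.125f, bf_row, f32_acc, 256);  // f32_acc += 0.125f * bf_row
    // Per the commit message, BF16 outputs are also accepted by MulByConstAndAdd.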
util/mat.h (22 changed lines)
@@ -96,6 +96,17 @@ class MatPtr : public IFields {
     }
   }
 
+  void AllocateAndAttachRowPtrs(
+      std::vector<hwy::AlignedFreeUniquePtr<uint8_t*[]>>& row_ptrs) {
+    if (!HasPtr()) return;
+    row_ptrs.push_back(hwy::AllocateAligned<uint8_t*>(Rows()));
+    uint8_t** ptrs = row_ptrs.back().get();
+    for (size_t r = 0; r < Rows(); ++r) {
+      ptrs[r] = RowBytes(r);
+    }
+    AttachRowPtrs(ptrs);
+  };
+
   uint8_t** GetRowPtrs() const { return row_ptrs_; }
 
   // A single row counts as packed because there is no padding between rows.
@@ -328,7 +339,7 @@ decltype(auto) CallUpcasted(const MatPtr* base, const Func& func,
 template <class Func, typename... Args>
 decltype(auto) CallUpcastedSame(const MatPtr* base1, const MatPtr* base2,
                                 const Func& func, Args&&... args) {
-  HWY_ASSERT(base1->GetType() == base2->GetType());
+  HWY_DASSERT(base1->GetType() == base2->GetType());
 
 #if GEMMA_ENABLE_NUQ
   if (base1->GetType() == Type::kNUQ) {
@@ -359,13 +370,12 @@ decltype(auto) CallUpcastedSame(const MatPtr* base1, const MatPtr* base2,
 template <class Func, typename... Args>
 decltype(auto) CallUpcastedActivation(const MatPtr* base, const Func& func,
                                       Args&&... args) {
-  HWY_ASSERT(base != nullptr);
   if (base->GetType() == Type::kF32) {
-    return func(dynamic_cast<const MatPtrT<float>*>(base),
-                std::forward<Args>(args)...);
+    const MatPtrT<float> mat(*base);
+    return func(&mat, std::forward<Args>(args)...);
   } else if (base->GetType() == Type::kBF16) {
-    return func(dynamic_cast<const MatPtrT<BF16>*>(base),
-                std::forward<Args>(args)...);
+    const MatPtrT<BF16> mat(*base);
+    return func(&mat, std::forward<Args>(args)...);
   } else {
    HWY_ABORT("Unhandled type %s.", TypeName(base->GetType()));
  }