mirror of https://github.com/google/gemma.cpp.git
1.07x batch decode speedup: more BF16 weights and activations
BF16 att_sums and ffw_out Support BF16 B views without decompression Support arbitrary types in MulByConstAndAdd, AddFrom Also update profiler annotations in ops-inl.h PiperOrigin-RevId: 766995010
This commit is contained in:
parent
839a642992
commit
9efdcfd45c
|
|
@ -214,6 +214,7 @@ cc_library(
|
|||
hdrs = ["gemma/weights.h"],
|
||||
deps = [
|
||||
":configs",
|
||||
":gemma_args",
|
||||
":mat",
|
||||
":model_store",
|
||||
":ops",
|
||||
|
|
@ -223,9 +224,7 @@ cc_library(
|
|||
"//io:blob_store",
|
||||
"@highway//:hwy",
|
||||
"@highway//:profiler",
|
||||
"@highway//:stats",
|
||||
"@highway//:thread_pool",
|
||||
"@highway//:timer",
|
||||
],
|
||||
)
|
||||
|
||||
|
|
@ -500,7 +499,6 @@ cc_library(
|
|||
# Placeholder for internal dep, do not remove.,
|
||||
"//io:blob_store",
|
||||
"//io",
|
||||
"//compression:types",
|
||||
"//paligemma:image",
|
||||
"@highway//:hwy",
|
||||
"@highway//:nanobenchmark", # timer
|
||||
|
|
|
|||
|
|
@ -49,7 +49,7 @@ void InitGenerator(const InferenceArgs& inference, std::mt19937& gen) {
|
|||
|
||||
GemmaEnv::GemmaEnv(const LoaderArgs& loader, const ThreadingArgs& threading,
|
||||
const InferenceArgs& inference)
|
||||
: env_(MakeMatMulEnv(threading)), gemma_(loader, env_) {
|
||||
: env_(MakeMatMulEnv(threading)), gemma_(loader, inference, env_) {
|
||||
const ModelConfig& config = gemma_.GetModelConfig();
|
||||
// Only allocate one for starters because GenerateBatch might not be called.
|
||||
kv_caches_.push_back(KVCache(config, inference.prefill_tbatch_size));
|
||||
|
|
@ -229,8 +229,9 @@ void ShowConfig(const LoaderArgs& loader, const ThreadingArgs& threading,
|
|||
threading.Print(inference.verbosity);
|
||||
loader.Print(inference.verbosity);
|
||||
inference.Print(inference.verbosity);
|
||||
fprintf(stderr, "Model : %s, mmap %d\n",
|
||||
config.Specifier().c_str(), static_cast<int>(loader.map));
|
||||
fprintf(stderr, "Model : %s, to_bf16 %d, mmap %d\n",
|
||||
config.Specifier().c_str(), static_cast<int>(loader.to_bf16),
|
||||
static_cast<int>(loader.map));
|
||||
|
||||
if (inference.verbosity >= 2) {
|
||||
time_t now = time(nullptr);
|
||||
|
|
|
|||
|
|
@ -37,7 +37,7 @@ class GemmaTest : public ::testing::Test {
|
|||
protected:
|
||||
std::vector<std::string> BatchGemmaReply(
|
||||
const std::vector<std::string>& inputs) {
|
||||
s_env->SetMaxGeneratedTokens(32);
|
||||
s_env->SetMaxGeneratedTokens(16);
|
||||
s_env->MutableConfig().temperature = 0.0f; // deterministic
|
||||
s_env->MutableConfig().verbosity = 2;
|
||||
std::vector<std::string> replies;
|
||||
|
|
|
|||
|
|
@ -52,7 +52,7 @@ int main(int argc, char** argv) {
|
|||
|
||||
// Instantiate model and KV Cache
|
||||
gcpp::MatMulEnv env(MakeMatMulEnv(threading));
|
||||
gcpp::Gemma gemma(loader, env);
|
||||
gcpp::Gemma gemma(loader, inference, env);
|
||||
gcpp::KVCache kv_cache(gemma.GetModelConfig(), inference.prefill_tbatch_size);
|
||||
size_t generated = 0;
|
||||
|
||||
|
|
|
|||
|
|
@ -39,7 +39,7 @@ class SimplifiedGemma {
|
|||
threading_(threading),
|
||||
inference_(inference),
|
||||
env_(MakeMatMulEnv(threading_)),
|
||||
gemma_(loader_, env_),
|
||||
gemma_(loader_, inference_, env_),
|
||||
kv_cache_(gemma_.GetModelConfig(), inference_.prefill_tbatch_size) {
|
||||
// Initialize random number generator
|
||||
std::random_device rd;
|
||||
|
|
|
|||
|
|
@ -88,23 +88,14 @@ struct Activations {
|
|||
HWY_ASSERT(batch_size != 0);
|
||||
|
||||
// For MatMul outputs, precompute their row pointers.
|
||||
const auto init_row_ptrs = [&](MatPtrT<float>& mat) {
|
||||
if (!mat.HasPtr()) return;
|
||||
row_ptrs.push_back(hwy::AllocateAligned<uint8_t*>(mat.Rows()));
|
||||
uint8_t** ptrs = row_ptrs.back().get();
|
||||
for (size_t r = 0; r < mat.Rows(); ++r) {
|
||||
ptrs[r] = mat.RowBytes(r);
|
||||
}
|
||||
mat.AttachRowPtrs(ptrs);
|
||||
};
|
||||
// If we forget any MatMul outputs here, debug builds print a warning but
|
||||
// fill them in each MatMul call.
|
||||
init_row_ptrs(q);
|
||||
init_row_ptrs(logits);
|
||||
init_row_ptrs(att_sums);
|
||||
init_row_ptrs(C1);
|
||||
init_row_ptrs(C2);
|
||||
init_row_ptrs(ffw_out);
|
||||
q.AllocateAndAttachRowPtrs(row_ptrs);
|
||||
logits.AllocateAndAttachRowPtrs(row_ptrs);
|
||||
att_sums.AllocateAndAttachRowPtrs(row_ptrs);
|
||||
C1.AllocateAndAttachRowPtrs(row_ptrs);
|
||||
C2.AllocateAndAttachRowPtrs(row_ptrs);
|
||||
ffw_out.AllocateAndAttachRowPtrs(row_ptrs);
|
||||
// TODO: also init rows for image_tokens.
|
||||
|
||||
// Note that BindC on any MatMul output considerably slows down Prefill.
|
||||
|
|
@ -142,7 +133,7 @@ struct Activations {
|
|||
const MatPadding pad_ = MatPadding::kOdd;
|
||||
|
||||
MatStorageT<float> x; // input
|
||||
MatStorageT<float> q; // query, also KV if MHA.
|
||||
MatStorageT<float> q; // query
|
||||
MatStorageT<float> logits;
|
||||
|
||||
// Attention
|
||||
|
|
@ -150,13 +141,13 @@ struct Activations {
|
|||
MatStorageT<float> att; // attention vector
|
||||
MatStorageT<float> att_out; // attention output
|
||||
// Accumulation of attention outputs over heads
|
||||
MatStorageT<float> att_sums;
|
||||
MatStorageT<BF16> att_sums;
|
||||
|
||||
// Gated FFW
|
||||
MatStorageT<BF16> pre_ffw_rms_out;
|
||||
MatStorageT<float> C1;
|
||||
MatStorageT<float> C1; // TODO: BF16 after Activation() supports it
|
||||
MatStorageT<float> C2;
|
||||
MatStorageT<float> ffw_out;
|
||||
MatStorageT<BF16> ffw_out;
|
||||
|
||||
// Griffin
|
||||
MatStorageT<float> griffin_x;
|
||||
|
|
|
|||
|
|
@ -109,7 +109,7 @@ GemmaContext::GemmaContext(const LoaderArgs& loader,
|
|||
threading_args(threading_args),
|
||||
matmul_env(MakeMatMulEnv(threading_args)),
|
||||
active_conversation_name("default"),
|
||||
model(loader, matmul_env) {
|
||||
model(loader, inference_args, matmul_env) {
|
||||
std::stringstream ss;
|
||||
|
||||
LogDebug("Creating initial ConversationData");
|
||||
|
|
|
|||
|
|
@ -210,17 +210,9 @@ static HWY_NOINLINE void GriffinRecurrent(const QueriesPos& queries_pos,
|
|||
} // interleaved_idx
|
||||
|
||||
// Final linear layer.
|
||||
// TODO: MatMul
|
||||
CallUpcasted(
|
||||
&layer_weights->griffin.linear_out_w, [&](const auto* weights_t) {
|
||||
for (size_t r = 0; r < num_interleaved; ++r) {
|
||||
float* HWY_RESTRICT x = activations.griffin_x.Row(r);
|
||||
float* out_ptr = activations.att_sums.Row(r);
|
||||
MatVecAdd(*weights_t, 0, model_dim, model_dim, x,
|
||||
layer_weights->griffin.linear_out_biases.PackedScale1(),
|
||||
out_ptr, pool);
|
||||
}
|
||||
});
|
||||
CallMatMul(activations.griffin_x, layer_weights->griffin.linear_out_w,
|
||||
layer_weights->griffin.linear_out_biases.PackedScale1(),
|
||||
*activations.env, activations.att_sums);
|
||||
} // GriffinRecurrent
|
||||
|
||||
// Wrapper class; holds arguments in member variables to shorten call sites.
|
||||
|
|
@ -1142,7 +1134,10 @@ static HWY_NOINLINE void EmbedImagePatches(const Image& image,
|
|||
}
|
||||
});
|
||||
// Add position embeddings.
|
||||
AddFromBatched(weights.vit_img_pos_embedding, activations.x);
|
||||
CallUpcastedActivation(&weights.vit_img_pos_embedding,
|
||||
[&](const auto* weights_t) {
|
||||
AddFromBatched(*weights_t, activations.x);
|
||||
});
|
||||
}
|
||||
|
||||
// Prefills the image tokens with the ViT encoder.
|
||||
|
|
|
|||
|
|
@ -68,13 +68,14 @@ MatMulEnv MakeMatMulEnv(const ThreadingArgs& threading_args) {
|
|||
return MatMulEnv(ThreadingContext::Get());
|
||||
}
|
||||
|
||||
Gemma::Gemma(const LoaderArgs& loader, MatMulEnv& env)
|
||||
Gemma::Gemma(const LoaderArgs& loader, const InferenceArgs& inference,
|
||||
MatMulEnv& env)
|
||||
: env_(env),
|
||||
reader_(loader.weights),
|
||||
model_(reader_, loader.tokenizer, loader.wrapping),
|
||||
weights_(model_.Config()),
|
||||
chat_template_(model_.Tokenizer(), model_.Config().model) {
|
||||
weights_.ReadFromBlobs(model_, reader_, loader.map, mat_owners_,
|
||||
weights_.ReadFromBlobs(model_, reader_, loader, inference, mat_owners_,
|
||||
env.ctx.pools.Pool());
|
||||
reader_.CloseFile();
|
||||
}
|
||||
|
|
|
|||
|
|
@ -105,7 +105,8 @@ class Gemma {
|
|||
public:
|
||||
// Reads weights/config/tokenizer from the `BlobStore` at `loader.weights`.
|
||||
// `env` must remain valid for the lifetime of this Gemma.
|
||||
Gemma(const LoaderArgs& loader, MatMulEnv& env);
|
||||
Gemma(const LoaderArgs& loader, const InferenceArgs& inference,
|
||||
MatMulEnv& env);
|
||||
|
||||
~Gemma();
|
||||
|
||||
|
|
|
|||
|
|
@ -52,6 +52,7 @@ struct LoaderArgs : public ArgsBase<LoaderArgs> {
|
|||
Path tokenizer;
|
||||
Path weights; // weights file location
|
||||
Tristate map;
|
||||
Tristate to_bf16;
|
||||
Tristate wrapping;
|
||||
|
||||
template <class Visitor>
|
||||
|
|
@ -62,6 +63,8 @@ struct LoaderArgs : public ArgsBase<LoaderArgs> {
|
|||
"Path name of model weights (.sbs) file.\n Required argument.\n");
|
||||
visitor(map, "map", Tristate::kDefault,
|
||||
"Enable memory-mapping? -1 = auto, 0 = no, 1 = yes.");
|
||||
visitor(to_bf16, "to_bf16", Tristate::kDefault,
|
||||
"Convert weights to bf16? -1 = auto, 0 = no, 1 = yes.");
|
||||
visitor(wrapping, "wrapping", Tristate::kDefault,
|
||||
"Enable prompt wrapping? Specify 0 for pre-2025 format PT models.");
|
||||
}
|
||||
|
|
|
|||
|
|
@ -254,7 +254,7 @@ void Run(const LoaderArgs& loader, const ThreadingArgs& threading,
|
|||
|
||||
MatMulEnv env(MakeMatMulEnv(threading));
|
||||
if (inference.verbosity >= 2) env.print_best = true;
|
||||
const Gemma gemma(loader, env);
|
||||
const Gemma gemma(loader, inference, env);
|
||||
KVCache kv_cache(gemma.GetModelConfig(), inference.prefill_tbatch_size);
|
||||
|
||||
if (inference.verbosity >= 1) {
|
||||
|
|
|
|||
220
gemma/weights.cc
220
gemma/weights.cc
|
|
@ -27,6 +27,7 @@
|
|||
#include "compression/compress.h"
|
||||
#include "compression/types.h"
|
||||
#include "gemma/configs.h"
|
||||
#include "gemma/gemma_args.h"
|
||||
#include "gemma/model_store.h"
|
||||
#include "io/blob_store.h"
|
||||
#include "ops/matmul.h" // MMParallel
|
||||
|
|
@ -37,7 +38,7 @@
|
|||
#include "hwy/highway.h"
|
||||
#include "hwy/profiler.h"
|
||||
|
||||
// TODO: move into foreach_target; this is only used for NUQ Fixup.
|
||||
// TODO: move into foreach_target
|
||||
#include "compression/compress-inl.h"
|
||||
|
||||
namespace gcpp {
|
||||
|
|
@ -246,16 +247,80 @@ std::vector<uint32_t> ModelWeightsPtrs::AddTensorDataToWriter(
|
|||
return serialized_mat_ptrs;
|
||||
}
|
||||
|
||||
enum class Mode {
|
||||
// Parallel I/O, decompress to BF16. Best for large batch sizes.
|
||||
kReadBF16,
|
||||
// Parallel I/O, insert row-wise padding. Safe default.
|
||||
kRead,
|
||||
// Best for large weights relative to available memory, especially for
|
||||
// frequent invocations of small batches and short sequences. Adds noise to
|
||||
// performance measurements due to I/O variability.
|
||||
kMap
|
||||
};
|
||||
|
||||
// Decides whether to read or map based on heuristics and user override.
|
||||
static Mode ChooseMode(uint64_t file_bytes, const LoaderArgs& loader,
|
||||
const InferenceArgs& inference) {
|
||||
Tristate to_bf16 = loader.to_bf16;
|
||||
Tristate map = loader.map;
|
||||
|
||||
// Disable mapping if not padded to the base page size.
|
||||
const Allocator& allocator = ThreadingContext::Get().allocator;
|
||||
if (file_bytes % allocator.BasePageBytes() != 0) {
|
||||
if (map != Tristate::kFalse) { // Do not complain if anyway disabled.
|
||||
HWY_WARN("Unable to map non-padded file (%zu, %zu), reading instead.",
|
||||
static_cast<size_t>(file_bytes >> 10),
|
||||
allocator.BasePageBytes());
|
||||
map = Tristate::kFalse;
|
||||
}
|
||||
}
|
||||
|
||||
// Check for user override:
|
||||
if (to_bf16 == Tristate::kTrue && map == Tristate::kTrue) {
|
||||
HWY_WARN("Cannot have to_bf16 && map, to_bf16 takes precedence.");
|
||||
}
|
||||
if (to_bf16 == Tristate::kTrue) return Mode::kReadBF16;
|
||||
if (map == Tristate::kTrue) return Mode::kMap;
|
||||
|
||||
if (to_bf16 == Tristate::kDefault) {
|
||||
// Heuristic: sub-bf16 compression is not helpful if compute-bound.
|
||||
const size_t batch_size =
|
||||
HWY_MAX(inference.prefill_tbatch_size, inference.decode_qbatch_size);
|
||||
to_bf16 = (batch_size >= 128) ? Tristate::kTrue : Tristate::kFalse;
|
||||
}
|
||||
|
||||
if (map == Tristate::kDefault) {
|
||||
// Heuristic: map if large fraction of total. Do not decide based on
|
||||
// `FreeMiB` because it is generally low.
|
||||
const size_t file_mib = file_bytes >> 20;
|
||||
const size_t total_mib = allocator.TotalMiB();
|
||||
if (file_mib > total_mib) {
|
||||
HWY_WARN("Weight file %zu MiB > detected memory %zu MiB.",
|
||||
static_cast<size_t>(file_mib), total_mib);
|
||||
}
|
||||
// Large fraction of total.
|
||||
map = (file_mib >= total_mib / 3) ? Tristate::kTrue : Tristate::kFalse;
|
||||
}
|
||||
|
||||
// If the `map` heuristic triggers, use that for safety.
|
||||
if (map == Tristate::kTrue) return Mode::kMap;
|
||||
return (to_bf16 == Tristate::kTrue) ? Mode::kReadBF16 : Mode::kRead;
|
||||
}
|
||||
|
||||
struct TensorToRead {
|
||||
MatPtr* mat;
|
||||
BlobRange range;
|
||||
// Some tensors opt out of padding via kPacked flags.
|
||||
MatPadding padding;
|
||||
|
||||
// only for kReadBF16
|
||||
bool keep_type = false;
|
||||
Type prev_type;
|
||||
};
|
||||
|
||||
// Allocates multiple in parallel and binds to NUMA nodes.
|
||||
static void AllocateAndBindAll(const std::vector<TensorToRead>& tensors,
|
||||
std::vector<MatOwner>& owners,
|
||||
static void AllocateAndBindAll(std::vector<TensorToRead>& tensors,
|
||||
const Mode mode, std::vector<MatOwner>& owners,
|
||||
hwy::ThreadPool& pool) {
|
||||
const size_t start = owners.size();
|
||||
owners.resize(start + tensors.size());
|
||||
|
|
@ -264,52 +329,25 @@ static void AllocateAndBindAll(const std::vector<TensorToRead>& tensors,
|
|||
|
||||
// Allocate in parallel because faulting in large tensors is slow.
|
||||
pool.Run(0, tensors.size(), [&](uint64_t task, size_t /*thread*/) {
|
||||
owners[start + task].AllocateFor(*tensors[task].mat, tensors[task].padding);
|
||||
TensorToRead& tensor = tensors[task];
|
||||
MatPtr& mat = *tensor.mat;
|
||||
|
||||
tensor.prev_type = mat.GetType();
|
||||
// We only care about MatMul inputs; skip F32 or small tensors.
|
||||
if (tensor.prev_type == Type::kF32 || mat.Rows() < 1024) {
|
||||
tensor.keep_type = true;
|
||||
tensor.padding = MatPadding::kPacked; // single I/O for simplicity
|
||||
} else if (mode == Mode::kReadBF16) {
|
||||
mat.SetType(Type::kBF16);
|
||||
}
|
||||
|
||||
owners[start + task].AllocateFor(*tensor.mat, tensor.padding);
|
||||
// TODO(janwas): MatMul outputs will later also be BF16.
|
||||
BindB(*tensors[task].mat, sizeof(float), parallel);
|
||||
BindB(*tensor.mat, sizeof(float), parallel);
|
||||
});
|
||||
}
|
||||
|
||||
// Parallel I/O into allocated memory, or mapped view of file. The latter is
|
||||
// better when the file is huge, but page faults add noise to measurements.
|
||||
enum class Mode { kRead, kMap };
|
||||
|
||||
// Decides whether to read or map based on heuristics and user override.
|
||||
static Mode ChooseMode(uint64_t file_bytes, Tristate map) {
|
||||
const Allocator& allocator = ThreadingContext::Get().allocator;
|
||||
// User has explicitly requested a map or read via args.
|
||||
if (map == Tristate::kTrue) return Mode::kMap;
|
||||
if (map == Tristate::kFalse) return Mode::kRead;
|
||||
// Else: use heuristics to choose. Note that `FreeMiB` is generally low
|
||||
// because idle memory is used as cache, so do not use it to decide.
|
||||
const size_t file_mib = file_bytes >> 20;
|
||||
const size_t total_mib = allocator.TotalMiB();
|
||||
if (file_mib > total_mib) {
|
||||
HWY_WARN("Weight file %zu MiB > detected memory %zu MiB.",
|
||||
static_cast<size_t>(file_mib), total_mib);
|
||||
}
|
||||
// Large fraction of total.
|
||||
if (file_mib >= total_mib / 3) return Mode::kMap;
|
||||
// Big enough that even parallel loading wouldn't be quick.
|
||||
if (file_mib > 50 * 1024) return Mode::kMap;
|
||||
return Mode::kRead;
|
||||
}
|
||||
|
||||
static MapPtr MapFileOrNull(File& file, uint64_t file_bytes) {
|
||||
const Allocator& allocator = ThreadingContext::Get().allocator;
|
||||
if (file_bytes % allocator.BasePageBytes() == 0) {
|
||||
MapPtr mapped = file.Map();
|
||||
if (!mapped) {
|
||||
HWY_WARN("Failed to map file (%zu KiB), reading instead.",
|
||||
static_cast<size_t>(file_bytes >> 10));
|
||||
}
|
||||
} else {
|
||||
HWY_WARN("Unable to map non-padded file (%zu, %zu), reading instead.",
|
||||
static_cast<size_t>(file_bytes >> 10), allocator.BasePageBytes());
|
||||
}
|
||||
return MapPtr();
|
||||
}
|
||||
|
||||
// Mode == kMap
|
||||
static void MapAll(const std::vector<TensorToRead>& tensors,
|
||||
const MapPtr& mapped) {
|
||||
PROFILER_ZONE("Startup.Weights.Map");
|
||||
|
|
@ -326,6 +364,65 @@ static void MapAll(const std::vector<TensorToRead>& tensors,
|
|||
}
|
||||
}
|
||||
|
||||
// Mode == kReadBF16:
|
||||
|
||||
template <typename T>
|
||||
static void DecompressToBF16(const MatPtr& mat,
|
||||
const hwy::AlignedFreeUniquePtr<uint8_t[]>& buf) {
|
||||
hwy::HWY_NAMESPACE::ScalableTag<BF16> dbf;
|
||||
const size_t cols = mat.Cols();
|
||||
|
||||
const size_t num_packed = CompressedArrayElements<T>(mat.Extents().Area());
|
||||
const PackedSpan<T> packed{HWY_RCAST_ALIGNED(T*, buf.get()), num_packed};
|
||||
|
||||
size_t packed_ofs = 0;
|
||||
for (size_t r = 0; r < mat.Rows(); ++r, packed_ofs += cols) {
|
||||
HWY_NAMESPACE::DecompressAndZeroPad(
|
||||
dbf, packed, packed_ofs, HWY_RCAST_ALIGNED(BF16*, mat.RowBytes(r)),
|
||||
cols);
|
||||
}
|
||||
}
|
||||
|
||||
static void ReadAllToBF16(const std::vector<TensorToRead>& tensors,
|
||||
const BlobReader& reader, hwy::ThreadPool& pool) {
|
||||
PROFILER_ZONE("Startup.Weights.ReadBF16");
|
||||
|
||||
pool.Run(0, tensors.size(), [&](uint64_t task, size_t /*thread*/) {
|
||||
const TensorToRead& tensor = tensors[task];
|
||||
MatPtr& mat = *tensor.mat;
|
||||
|
||||
if (tensor.keep_type) {
|
||||
HWY_ASSERT(reader.file().Read(tensor.range.offset, tensor.range.bytes,
|
||||
mat.Packed()));
|
||||
return;
|
||||
}
|
||||
|
||||
// Read to a temporary buffer.
|
||||
const hwy::AlignedFreeUniquePtr<uint8_t[]> buf =
|
||||
hwy::AllocateAligned<uint8_t>(tensor.range.bytes);
|
||||
HWY_ASSERT(
|
||||
reader.file().Read(tensor.range.offset, tensor.range.bytes, buf.get()));
|
||||
|
||||
if constexpr (GEMMA_ENABLE_NUQ) {
|
||||
if (tensor.prev_type == Type::kNUQ) {
|
||||
return DecompressToBF16<NuqStream>(*tensor.mat, buf);
|
||||
}
|
||||
}
|
||||
switch (tensor.prev_type) {
|
||||
case Type::kF32:
|
||||
return DecompressToBF16<float>(*tensor.mat, buf);
|
||||
case Type::kBF16:
|
||||
return DecompressToBF16<BF16>(*tensor.mat, buf);
|
||||
case Type::kSFP:
|
||||
return DecompressToBF16<SfpStream>(*tensor.mat, buf);
|
||||
default:
|
||||
HWY_ABORT("Unsupported type %s", TypeName(tensor.prev_type));
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
// Mode == kRead:
|
||||
|
||||
static std::vector<IOBatch> MakeBatches(
|
||||
const std::vector<TensorToRead>& tensors, const uint64_t file_bytes) {
|
||||
PROFILER_ZONE("Startup.Weights.MakeBatches");
|
||||
|
|
@ -382,31 +479,36 @@ static void ReadBatches(const BlobReader& reader,
|
|||
}
|
||||
|
||||
// Aborts on error.
|
||||
static void MapOrReadAll(const std::vector<TensorToRead>& tensors,
|
||||
BlobReader& reader, Tristate map,
|
||||
std::vector<MatOwner>& mat_owners,
|
||||
static void MapOrReadAll(std::vector<TensorToRead>& tensors, BlobReader& reader,
|
||||
Mode mode, std::vector<MatOwner>& mat_owners,
|
||||
hwy::ThreadPool& pool) {
|
||||
if (ChooseMode(reader.file_bytes(), map) == Mode::kMap) {
|
||||
MapPtr mapped = MapFileOrNull(reader.file(), reader.file_bytes());
|
||||
if (mapped) {
|
||||
MapAll(tensors, mapped);
|
||||
return;
|
||||
}
|
||||
} // otherwise fall through to read mode
|
||||
if (mode == Mode::kMap) {
|
||||
MapPtr mapped = reader.file().Map();
|
||||
if (mapped) return MapAll(tensors, mapped);
|
||||
HWY_WARN("Failed to map file (%zu KiB), reading instead.",
|
||||
static_cast<size_t>(reader.file_bytes() >> 10));
|
||||
// If we wanted to map but failed, memory is probably not plentiful, so
|
||||
// fall through to kRead because kReadBF16 requires more memory.
|
||||
mode = Mode::kRead;
|
||||
}
|
||||
|
||||
{
|
||||
PROFILER_ZONE("Startup.Weights.Allocate");
|
||||
// NOTE: this changes the stride of `mats`!
|
||||
AllocateAndBindAll(tensors, mat_owners, pool);
|
||||
AllocateAndBindAll(tensors, mode, mat_owners, pool);
|
||||
}
|
||||
|
||||
if (mode == Mode::kReadBF16) return ReadAllToBF16(tensors, reader, pool);
|
||||
|
||||
const std::vector<IOBatch> batches =
|
||||
MakeBatches(tensors, reader.file_bytes());
|
||||
ReadBatches(reader, batches, pool);
|
||||
}
|
||||
|
||||
void ModelWeightsPtrs::ReadFromBlobs(const ModelStore& model,
|
||||
BlobReader& reader, Tristate map,
|
||||
BlobReader& reader,
|
||||
const LoaderArgs& loader,
|
||||
const InferenceArgs& inference,
|
||||
std::vector<MatOwner>& mat_owners,
|
||||
hwy::ThreadPool& pool) {
|
||||
// List of tensors to read/map, and where from.
|
||||
|
|
@ -427,7 +529,9 @@ void ModelWeightsPtrs::ReadFromBlobs(const ModelStore& model,
|
|||
HWY_ABORT("Tensor %s is required but not found in file.", t.mat.Name());
|
||||
});
|
||||
|
||||
MapOrReadAll(tensors, reader, map, mat_owners, pool);
|
||||
const Mode mode = ChooseMode(reader.file_bytes(), loader, inference);
|
||||
|
||||
MapOrReadAll(tensors, reader, mode, mat_owners, pool);
|
||||
|
||||
{
|
||||
PROFILER_ZONE("Startup.Fixup");
|
||||
|
|
|
|||
|
|
@ -24,6 +24,7 @@
|
|||
|
||||
#include "compression/types.h"
|
||||
#include "gemma/configs.h" // ModelConfig
|
||||
#include "gemma/gemma_args.h" // InferenceArgs
|
||||
#include "gemma/model_store.h" // ModelStore
|
||||
#include "gemma/tensor_info.h" // TensorInfoRegistry
|
||||
#include "io/blob_store.h" // BlobWriter
|
||||
|
|
@ -424,7 +425,8 @@ struct ModelWeightsPtrs {
|
|||
|
||||
// Reads tensor data from `BlobStore` or aborts on error. `map` is a user
|
||||
// override for whether to map blobs or read them.
|
||||
void ReadFromBlobs(const ModelStore& model, BlobReader& reader, Tristate map,
|
||||
void ReadFromBlobs(const ModelStore& model, BlobReader& reader,
|
||||
const LoaderArgs& loader, const InferenceArgs& inference,
|
||||
std::vector<MatOwner>& mat_owners, hwy::ThreadPool& pool);
|
||||
|
||||
// Adds one blob for each tensor's data and returns all serialized MatPtr.
|
||||
|
|
|
|||
|
|
@ -73,7 +73,6 @@ cc_library(
|
|||
"//:basics",
|
||||
"//:threading_context",
|
||||
"@highway//:hwy",
|
||||
"@highway//:profiler",
|
||||
"@highway//:thread_pool",
|
||||
],
|
||||
)
|
||||
|
|
|
|||
|
|
@ -19,10 +19,6 @@
|
|||
|
||||
#include <vector>
|
||||
|
||||
#pragma push_macro("PROFILER_ENABLED")
|
||||
#undef PROFILER_ENABLED
|
||||
#define PROFILER_ENABLED 0
|
||||
|
||||
#include "compression/types.h"
|
||||
#include "ops/matmul.h" // IWYU pragma: export
|
||||
#include "util/allocator.h"
|
||||
|
|
@ -952,15 +948,11 @@ class MMPerPackage {
|
|||
range_np_, MultipleNP(sizeof(TC)), inner_tasks_, pkg_idx_,
|
||||
[&](const IndexRange& range_nc) HWY_ATTR {
|
||||
HWY_ALIGN BF16 B_storage[B_storage_max_]; // TLS
|
||||
const RowPtrBF B_view(B_storage, K, B_stride);
|
||||
const RowPtrBF B_storage_view(B_storage, K, B_stride);
|
||||
|
||||
for (size_t row_b = range_nc.begin(); row_b < range_nc.end();
|
||||
row_b += kNR) {
|
||||
{
|
||||
MMZone zone;
|
||||
zone.MaybeEnter("MM.NT.DecB", args_);
|
||||
DecompressB(B, row_b, range_K, B_view);
|
||||
}
|
||||
RowPtrBF B_view = DecompressB(B, row_b, range_K, B_storage_view);
|
||||
MMKernel::A2C0(A_view, B_view, mr_, range_M, row_b, K, MMSetC(),
|
||||
args_, C_rows);
|
||||
}
|
||||
|
|
@ -985,17 +977,13 @@ class MMPerPackage {
|
|||
auto out_tag) HWY_ATTR {
|
||||
const size_t kc = range_kc.Num();
|
||||
const RowPtrBF& A_view = A_.View(range_mc.begin(), range_kc.begin(), kc);
|
||||
const RowPtrBF B_view(
|
||||
const RowPtrBF B_storage_view(
|
||||
B_storage, kc,
|
||||
Stride(MatPadding::kOdd, kc, sizeof(BF16), line_bytes_));
|
||||
|
||||
for (size_t row_b = range_nc.begin(); row_b < range_nc.end();
|
||||
row_b += kNR) {
|
||||
{
|
||||
MMZone zone;
|
||||
zone.MaybeEnter("MM.NT_K.DecB", args_);
|
||||
DecompressB(B, row_b, range_kc, B_view);
|
||||
}
|
||||
RowPtrBF B_view = DecompressB(B, row_b, range_kc, B_storage_view);
|
||||
MMKernel::A2C0(A_view, B_view, mr_, range_mc, row_b, kc, out_tag, args_,
|
||||
C_rows);
|
||||
}
|
||||
|
|
@ -1051,15 +1039,11 @@ class MMPerPackage {
|
|||
[&](const IndexRange& range_mc, const IndexRange& range_nc) HWY_ATTR {
|
||||
const RowPtrBF& A_view = A_.View(range_mc.begin(), 0, K);
|
||||
HWY_ALIGN BF16 B_storage[B_storage_max_]; // TLS
|
||||
const RowPtrBF B_view(B_storage, K, B_stride);
|
||||
const RowPtrBF B_storage_view(B_storage, K, B_stride);
|
||||
|
||||
for (size_t row_b = range_nc.begin(); row_b < range_nc.end();
|
||||
row_b += kNR) {
|
||||
{
|
||||
MMZone zone;
|
||||
zone.MaybeEnter("MM.NT_MT.DecB", args_);
|
||||
DecompressB(B, row_b, range_K, B_view);
|
||||
}
|
||||
RowPtrBF B_view = DecompressB(B, row_b, range_K, B_storage_view);
|
||||
MMKernel::A2C0(A_view, B_view, mr_, range_mc, row_b, K, MMSetC(),
|
||||
args_, C_rows);
|
||||
}
|
||||
|
|
@ -1081,7 +1065,8 @@ class MMPerPackage {
|
|||
// Sequential loop over NC/MC/KC, for when the M/N loops are
|
||||
// already parallel. This is B3A2C0 in MOMMS terminology: we read
|
||||
// `mc x kc` of A, `nc x kc` of B, update `mc x nc` of `partial`.
|
||||
const auto loop_nc = [&](const RowPtrBF& B_view, const IndexRange& range_mc,
|
||||
const auto loop_nc = [&](const RowPtrBF& B_storage_view,
|
||||
const IndexRange& range_mc,
|
||||
const IndexRange& range_kc,
|
||||
const IndexRange& range_nc,
|
||||
auto out_tag) HWY_ATTR {
|
||||
|
|
@ -1090,11 +1075,7 @@ class MMPerPackage {
|
|||
|
||||
for (size_t row_b = range_nc.begin(); row_b < range_nc.end();
|
||||
row_b += kNR) {
|
||||
{
|
||||
MMZone zone;
|
||||
zone.MaybeEnter("MM.NT_MT_K.DecB", args_);
|
||||
DecompressB(B, row_b, range_kc, B_view);
|
||||
}
|
||||
RowPtrBF B_view = DecompressB(B, row_b, range_kc, B_storage_view);
|
||||
MMKernel::A2C0(A_view, B_view, mr_, range_mc, row_b, kc, out_tag, args_,
|
||||
C_rows);
|
||||
}
|
||||
|
|
@ -1103,15 +1084,17 @@ class MMPerPackage {
|
|||
ranges_mc_, ranges_nc_, pkg_idx_,
|
||||
[&](const IndexRange& range_mc, const IndexRange& range_nc) HWY_ATTR {
|
||||
HWY_ALIGN BF16 B_storage[B_storage_max_]; // TLS
|
||||
const RowPtrBF B_view(B_storage, kc_max, B_stride);
|
||||
const RowPtrBF B_storage_view(B_storage, kc_max, B_stride);
|
||||
|
||||
// Peel off the first iteration of the kc loop: avoid
|
||||
// zero-initializing `partial` by writing into it.
|
||||
ranges_kc_.VisitFirst([&](const IndexRange& range_kc) {
|
||||
loop_nc(B_view, range_mc, range_kc, range_nc, MMSetPartial());
|
||||
loop_nc(B_storage_view, range_mc, range_kc, range_nc,
|
||||
MMSetPartial());
|
||||
});
|
||||
ranges_kc_.VisitRemaining([&](const IndexRange& range_kc) {
|
||||
loop_nc(B_view, range_mc, range_kc, range_nc, MMAddPartial());
|
||||
loop_nc(B_storage_view, range_mc, range_kc, range_nc,
|
||||
MMAddPartial());
|
||||
});
|
||||
|
||||
// Already in parallel section, hence no `kParM`, and
|
||||
|
|
@ -1228,11 +1211,15 @@ class MMPerPackage {
|
|||
// col 0 of `B_view`. Decompressing SFP is relatively cheap on `AVX3_DL`
|
||||
// thanks to its large table lookups, and less so on other targets.
|
||||
template <typename TB>
|
||||
HWY_INLINE void DecompressB(const MatPtrT<TB>& B, const size_t row_b,
|
||||
const IndexRange& range_kc,
|
||||
const RowPtrBF& B_view) const {
|
||||
const hn::ScalableTag<BF16> dbf;
|
||||
HWY_INLINE RowPtrBF DecompressB(const MatPtrT<TB>& B, const size_t row_b,
|
||||
const IndexRange& range_kc,
|
||||
const RowPtrBF& B_view) const {
|
||||
if constexpr (hwy::IsSame<TB, BF16>()) {
|
||||
return RowPtrBF(const_cast<BF16*>(B.Row(row_b)) + range_kc.begin(),
|
||||
range_kc.Num(), B.Stride());
|
||||
}
|
||||
|
||||
const hn::ScalableTag<BF16> dbf;
|
||||
const PackedSpan<const TB> B_span = B.PaddedSpan();
|
||||
|
||||
const size_t kc = range_kc.Num();
|
||||
|
|
@ -1249,6 +1236,7 @@ class MMPerPackage {
|
|||
}
|
||||
}
|
||||
}
|
||||
return B_view;
|
||||
}
|
||||
|
||||
const MMArgs args_; // copy for locality
|
||||
|
|
@ -1421,5 +1409,3 @@ HWY_NOINLINE MMPerKey* MatMul(const MatPtrT<TA>& A, const MatPtrT<TB>& B,
|
|||
HWY_AFTER_NAMESPACE();
|
||||
|
||||
#endif // NOLINT
|
||||
|
||||
#pragma pop_macro("PROFILER_ENABLED")
|
||||
|
|
|
|||
12
ops/matmul.h
12
ops/matmul.h
|
|
@ -567,6 +567,16 @@ class MMAutoTune {
|
|||
|
||||
// Map of previously seen dimensions to index via linear search.
|
||||
class MMKeys {
|
||||
// Group batch size into buckets to reduce #auto-tunes.
|
||||
static size_t BucketM(size_t M) {
|
||||
// The first 4 may require their own bucket because `kNT` only works for a
|
||||
// single M range, but that depends on the config's `MR()`.
|
||||
if (M <= 4) return M;
|
||||
if (M <= 16) return 16;
|
||||
if (M <= 64) return 64;
|
||||
return 256;
|
||||
}
|
||||
|
||||
public:
|
||||
using Key = uint64_t;
|
||||
// KeyFromDims will only return this if all dims are zero, which is invalid.
|
||||
|
|
@ -577,7 +587,7 @@ class MMKeys {
|
|||
HWY_DASSERT(M < (Key{1} << 16)); // batch sizes are smaller
|
||||
HWY_DASSERT(K < (Key{1} << 24));
|
||||
HWY_DASSERT(N < (Key{1} << 24));
|
||||
const Key key = static_cast<Key>(M) | (static_cast<Key>(K) << 16) |
|
||||
const Key key = static_cast<Key>(BucketM(M)) | (static_cast<Key>(K) << 16) |
|
||||
(static_cast<Key>(N) << 40);
|
||||
HWY_DASSERT(key != kPadding);
|
||||
return key;
|
||||
|
|
|
|||
127
ops/ops-inl.h
127
ops/ops-inl.h
|
|
@ -117,6 +117,7 @@ HWY_INLINE hn::Vec<D> Gelu(D d, hn::Vec<D> v) {
|
|||
|
||||
static HWY_NOINLINE HWY_MAYBE_UNUSED void Gelu(float* HWY_RESTRICT x,
|
||||
size_t size) {
|
||||
PROFILER_ZONE("ops.Gelu");
|
||||
namespace hn = hwy::HWY_NAMESPACE;
|
||||
using D = hn::ScalableTag<float>;
|
||||
hn::Transform(D(), x, size,
|
||||
|
|
@ -181,6 +182,8 @@ namespace detail {
|
|||
// Shared by RMSNorm and RMSNormInplace.
|
||||
template <typename VT>
|
||||
float RMSNormMul(const VT* HWY_RESTRICT x, size_t size) {
|
||||
PROFILER_ZONE("ops.RMSNormMul");
|
||||
|
||||
const hn::ScalableTag<float> d;
|
||||
const float l2 = DecompressAndCall(d, MakeSpan(x, size), DotKernelDefault());
|
||||
constexpr float kEps = 1e-6f; // avoid divide by zero
|
||||
|
|
@ -195,7 +198,7 @@ HWY_NOINLINE HWY_MAYBE_UNUSED void RMSNorm(const XT* HWY_RESTRICT x,
|
|||
const WT* HWY_RESTRICT weight,
|
||||
size_t w_ofs, OT* HWY_RESTRICT out,
|
||||
const size_t size) {
|
||||
PROFILER_FUNC;
|
||||
PROFILER_ZONE("ops.RMSNorm");
|
||||
|
||||
namespace hn = hwy::HWY_NAMESPACE;
|
||||
const hn::ScalableTag<float> df;
|
||||
|
|
@ -228,7 +231,7 @@ HWY_NOINLINE HWY_MAYBE_UNUSED void RMSNormInplace(const WT* HWY_RESTRICT weight,
|
|||
size_t w_ofs,
|
||||
XT* HWY_RESTRICT inout,
|
||||
const size_t size) {
|
||||
PROFILER_FUNC;
|
||||
PROFILER_ZONE("ops.RMSNormInplace");
|
||||
|
||||
namespace hn = hwy::HWY_NAMESPACE;
|
||||
const hn::ScalableTag<float> df;
|
||||
|
|
@ -280,7 +283,7 @@ HWY_NOINLINE void ComputeMoments(const XT* HWY_RESTRICT x, size_t size,
|
|||
template <typename XT, typename WT, typename OT>
|
||||
HWY_NOINLINE void LayerNorm(const XT* x, const WT* HWY_RESTRICT scale,
|
||||
const WT* HWY_RESTRICT bias, OT* out, size_t size) {
|
||||
PROFILER_FUNC;
|
||||
PROFILER_ZONE("ops.LayerNorm");
|
||||
|
||||
namespace hn = hwy::HWY_NAMESPACE;
|
||||
const hn::ScalableTag<float> df;
|
||||
|
|
@ -360,7 +363,7 @@ static HWY_NOINLINE HWY_MAYBE_UNUSED void AddAbsolutePositionalEmbeddings(
|
|||
}
|
||||
}
|
||||
|
||||
/* RoPE as in Rotary Position Embeddings from the RoFormer paper
|
||||
/* RoPE as in Rotary Position Embeddings from the `RoFormer` paper
|
||||
(https://arxiv.org/abs/2104.09864v5). The query and key vectors are rotated
|
||||
as a function of their absolute position using the rotation matrix R before
|
||||
the self-attention operation. R is a d x d matrix.
|
||||
|
|
@ -391,7 +394,7 @@ static HWY_NOINLINE HWY_MAYBE_UNUSED void AddAbsolutePositionalEmbeddings(
|
|||
static HWY_NOINLINE HWY_MAYBE_UNUSED void Rope(
|
||||
float* HWY_RESTRICT x, size_t dim_qkv,
|
||||
const float* HWY_RESTRICT inv_timescale, int pos) {
|
||||
PROFILER_FUNC;
|
||||
PROFILER_ZONE("ops.Rope");
|
||||
HWY_DASSERT(dim_qkv % 2 == 0);
|
||||
const size_t half_dim_qkv = dim_qkv / 2;
|
||||
for (size_t dim = 0; dim < half_dim_qkv; ++dim) {
|
||||
|
|
@ -409,7 +412,7 @@ static HWY_NOINLINE HWY_MAYBE_UNUSED void Rope(
|
|||
static HWY_NOINLINE HWY_MAYBE_UNUSED void RopeAndMulBy(
|
||||
const float mul, float* HWY_RESTRICT x, size_t dim_qkv,
|
||||
const float* HWY_RESTRICT inv_timescale, int pos) {
|
||||
PROFILER_FUNC;
|
||||
PROFILER_ZONE("ops.RopeAndMulBy");
|
||||
HWY_DASSERT(dim_qkv % 2 == 0);
|
||||
const size_t half_dim_qkv = dim_qkv / 2;
|
||||
|
||||
|
|
@ -477,15 +480,45 @@ static HWY_NOINLINE HWY_MAYBE_UNUSED void RopeAndMulBy(
|
|||
}
|
||||
}
|
||||
|
||||
static HWY_NOINLINE HWY_MAYBE_UNUSED void AddFrom(
|
||||
const float* HWY_RESTRICT other, float* HWY_RESTRICT x, const size_t size) {
|
||||
namespace hn = hwy::HWY_NAMESPACE;
|
||||
using D = hn::ScalableTag<float>;
|
||||
using V = hn::Vec<D>;
|
||||
template <typename XT>
|
||||
static HWY_NOINLINE HWY_MAYBE_UNUSED void AddFrom(const XT* HWY_RESTRICT x,
|
||||
float* HWY_RESTRICT out,
|
||||
const size_t size) {
|
||||
PROFILER_ZONE("ops.AddFrom");
|
||||
|
||||
hn::Transform1(D(), x, size, other,
|
||||
[](const auto d, const V x, const V other)
|
||||
HWY_ATTR { return hn::Add(x, other); });
|
||||
namespace hn = hwy::HWY_NAMESPACE;
|
||||
const hn::ScalableTag<float> df;
|
||||
const size_t NF = hn::Lanes(df);
|
||||
using VF = hn::Vec<decltype(df)>;
|
||||
|
||||
const auto packed_x = MakeSpan(x, size);
|
||||
|
||||
size_t i = 0;
|
||||
if (size >= 2 * NF) {
|
||||
for (; i <= size - 2 * NF; i += 2 * NF) {
|
||||
VF x0, x1;
|
||||
Decompress2(df, packed_x, i, x0, x1);
|
||||
VF out0 = hn::Load(df, out + i);
|
||||
VF out1 = hn::Load(df, out + i + NF);
|
||||
hn::Store(hn::Add(x0, out0), df, out + i);
|
||||
hn::Store(hn::Add(x1, out1), df, out + i + NF);
|
||||
}
|
||||
}
|
||||
|
||||
const size_t remaining = size - i;
|
||||
const size_t remaining1 = remaining - HWY_MIN(remaining, NF);
|
||||
HWY_DASSERT(remaining < 2 * NF);
|
||||
HWY_DASSERT(remaining1 < NF);
|
||||
if (HWY_UNLIKELY(remaining != 0)) {
|
||||
HWY_ALIGN float buf_x[2 * hn::MaxLanes(df)];
|
||||
DecompressAndZeroPad(df, packed_x, i, buf_x, remaining);
|
||||
const VF x0 = hn::Load(df, buf_x);
|
||||
const VF x1 = hn::Load(df, buf_x + NF);
|
||||
const VF out0 = hn::LoadN(df, out + i, remaining);
|
||||
const VF out1 = hn::LoadN(df, out + i + NF, remaining1);
|
||||
hn::StoreN(hn::Add(x0, out0), df, out + i, remaining);
|
||||
hn::StoreN(hn::Add(x1, out1), df, out + i + NF, remaining1);
|
||||
}
|
||||
}
|
||||
|
||||
// Simple loops unless/until batch sizes are large enough to parallelize.
|
||||
|
|
@ -534,17 +567,19 @@ void LayerNormBatched(const MatPtrT<XT>& x, const MatPtr& weight,
|
|||
});
|
||||
}
|
||||
|
||||
static HWY_INLINE void AddFromBatched(const MatPtrT<float>& other,
|
||||
MatPtrT<float>& x) {
|
||||
HWY_DASSERT(x.SameShape(other));
|
||||
for (size_t token_idx = 0; token_idx < x.Rows(); ++token_idx) {
|
||||
AddFrom(other.Row(token_idx), x.Row(token_idx), x.Cols());
|
||||
template <typename XT>
|
||||
static HWY_INLINE void AddFromBatched(const MatPtrT<XT>& x,
|
||||
MatPtrT<float>& out) {
|
||||
HWY_DASSERT(out.SameShape(x));
|
||||
for (size_t token_idx = 0; token_idx < out.Rows(); ++token_idx) {
|
||||
AddFrom(x.Row(token_idx), out.Row(token_idx), x.Cols());
|
||||
}
|
||||
}
|
||||
|
||||
static HWY_NOINLINE void MulBy(const float* HWY_RESTRICT other,
|
||||
float* HWY_RESTRICT x, const size_t size,
|
||||
const size_t max_pos) {
|
||||
PROFILER_ZONE("ops.MulBy");
|
||||
HWY_DASSERT(max_pos <= size);
|
||||
namespace hn = hwy::HWY_NAMESPACE;
|
||||
using D = hn::ScalableTag<float>;
|
||||
|
|
@ -563,6 +598,7 @@ static HWY_INLINE HWY_MAYBE_UNUSED void MulBy(const float* HWY_RESTRICT other,
|
|||
|
||||
static HWY_NOINLINE void MulByConst(const float c, float* HWY_RESTRICT x,
|
||||
const size_t size, const size_t max_pos) {
|
||||
PROFILER_ZONE("ops.MulByConst");
|
||||
HWY_DASSERT(max_pos <= size);
|
||||
namespace hn = hwy::HWY_NAMESPACE;
|
||||
using D = hn::ScalableTag<float>;
|
||||
|
|
@ -578,22 +614,56 @@ static HWY_INLINE HWY_MAYBE_UNUSED void MulByConst(const float c,
|
|||
MulByConst(c, x, size, size);
|
||||
}
|
||||
|
||||
static HWY_NOINLINE HWY_MAYBE_UNUSED void MulByConstAndAdd(
|
||||
float c, const float* HWY_RESTRICT x, float* HWY_RESTRICT out,
|
||||
size_t size) {
|
||||
template <typename XT, typename OT>
|
||||
HWY_NOINLINE HWY_MAYBE_UNUSED void MulByConstAndAdd(float c,
|
||||
const XT* HWY_RESTRICT x,
|
||||
OT* HWY_RESTRICT out,
|
||||
size_t size) {
|
||||
PROFILER_ZONE("ops.MulByConstAndAdd");
|
||||
namespace hn = hwy::HWY_NAMESPACE;
|
||||
using D = hn::ScalableTag<float>;
|
||||
using V = hn::Vec<D>;
|
||||
hn::Transform1(D(), out, size, x,
|
||||
[c](const auto d, const V v_out, const V v_x) HWY_ATTR {
|
||||
return hn::MulAdd(v_x, hn::Set(d, c), v_out);
|
||||
});
|
||||
const hn::ScalableTag<float> df;
|
||||
const size_t NF = hn::Lanes(df);
|
||||
using VF = hn::Vec<decltype(df)>;
|
||||
|
||||
const VF v_c = hn::Set(df, c);
|
||||
const auto packed_x = MakeSpan(x, size);
|
||||
const auto packed_out = MakeSpan(out, size);
|
||||
|
||||
size_t i = 0;
|
||||
if (size >= 2 * NF) {
|
||||
for (; i <= size - 2 * NF; i += 2 * NF) {
|
||||
VF x0, x1, out0, out1;
|
||||
Decompress2(df, packed_x, i, x0, x1);
|
||||
Decompress2(df, packed_out, i, out0, out1);
|
||||
out0 = hn::MulAdd(x0, v_c, out0);
|
||||
out1 = hn::MulAdd(x1, v_c, out1);
|
||||
Compress2(df, out0, out1, packed_out, i);
|
||||
}
|
||||
}
|
||||
|
||||
const size_t remaining = size - i;
|
||||
HWY_DASSERT(remaining < 2 * NF);
|
||||
if (HWY_UNLIKELY(remaining != 0)) {
|
||||
HWY_ALIGN float buf_x[2 * hn::MaxLanes(df)];
|
||||
HWY_ALIGN float buf_out[2 * hn::MaxLanes(df)];
|
||||
DecompressAndZeroPad(df, packed_x, i, buf_x, remaining);
|
||||
DecompressAndZeroPad(df, packed_out, i, buf_out, remaining);
|
||||
const VF x0 = hn::Load(df, buf_x);
|
||||
const VF x1 = hn::Load(df, buf_x + NF);
|
||||
VF out0 = hn::Load(df, buf_out);
|
||||
VF out1 = hn::Load(df, buf_out + NF);
|
||||
out0 = hn::MulAdd(x0, v_c, out0);
|
||||
out1 = hn::MulAdd(x1, v_c, out1);
|
||||
Compress2(df, out0, out1, MakeSpan(buf_out, 2 * NF), 0);
|
||||
hwy::CopyBytes(buf_out, out + i, remaining * sizeof(OT));
|
||||
}
|
||||
}
|
||||
|
||||
// See below for a specialized version for top-1 sampling.
|
||||
static HWY_NOINLINE void Softmax(float* HWY_RESTRICT x, const size_t size,
|
||||
const size_t mask_pos,
|
||||
float temperature = 1.0f) {
|
||||
PROFILER_ZONE("ops.Softmax");
|
||||
HWY_DASSERT(size != 0);
|
||||
HWY_DASSERT(mask_pos <= size);
|
||||
|
||||
|
|
@ -733,6 +803,7 @@ static HWY_MAYBE_UNUSED TokenAndProb Top1OfSoftmax(float* HWY_RESTRICT x,
|
|||
static HWY_NOINLINE void LogitsSoftCap(const float cap, float* HWY_RESTRICT x,
|
||||
const size_t size,
|
||||
const size_t max_pos) {
|
||||
PROFILER_ZONE("ops.LogitsSoftCap");
|
||||
HWY_DASSERT(max_pos <= size);
|
||||
|
||||
namespace hn = hwy::HWY_NAMESPACE;
|
||||
|
|
|
|||
22
util/mat.h
22
util/mat.h
|
|
@ -96,6 +96,17 @@ class MatPtr : public IFields {
|
|||
}
|
||||
}
|
||||
|
||||
void AllocateAndAttachRowPtrs(
|
||||
std::vector<hwy::AlignedFreeUniquePtr<uint8_t*[]>>& row_ptrs) {
|
||||
if (!HasPtr()) return;
|
||||
row_ptrs.push_back(hwy::AllocateAligned<uint8_t*>(Rows()));
|
||||
uint8_t** ptrs = row_ptrs.back().get();
|
||||
for (size_t r = 0; r < Rows(); ++r) {
|
||||
ptrs[r] = RowBytes(r);
|
||||
}
|
||||
AttachRowPtrs(ptrs);
|
||||
};
|
||||
|
||||
uint8_t** GetRowPtrs() const { return row_ptrs_; }
|
||||
|
||||
// A single row counts as packed because there is no padding between rows.
|
||||
|
|
@ -328,7 +339,7 @@ decltype(auto) CallUpcasted(const MatPtr* base, const Func& func,
|
|||
template <class Func, typename... Args>
|
||||
decltype(auto) CallUpcastedSame(const MatPtr* base1, const MatPtr* base2,
|
||||
const Func& func, Args&&... args) {
|
||||
HWY_ASSERT(base1->GetType() == base2->GetType());
|
||||
HWY_DASSERT(base1->GetType() == base2->GetType());
|
||||
|
||||
#if GEMMA_ENABLE_NUQ
|
||||
if (base1->GetType() == Type::kNUQ) {
|
||||
|
|
@ -359,13 +370,12 @@ decltype(auto) CallUpcastedSame(const MatPtr* base1, const MatPtr* base2,
|
|||
template <class Func, typename... Args>
|
||||
decltype(auto) CallUpcastedActivation(const MatPtr* base, const Func& func,
|
||||
Args&&... args) {
|
||||
HWY_ASSERT(base != nullptr);
|
||||
if (base->GetType() == Type::kF32) {
|
||||
return func(dynamic_cast<const MatPtrT<float>*>(base),
|
||||
std::forward<Args>(args)...);
|
||||
const MatPtrT<float> mat(*base);
|
||||
return func(&mat, std::forward<Args>(args)...);
|
||||
} else if (base->GetType() == Type::kBF16) {
|
||||
return func(dynamic_cast<const MatPtrT<BF16>*>(base),
|
||||
std::forward<Args>(args)...);
|
||||
const MatPtrT<BF16> mat(*base);
|
||||
return func(&mat, std::forward<Args>(args)...);
|
||||
} else {
|
||||
HWY_ABORT("Unhandled type %s.", TypeName(base->GetType()));
|
||||
}
|
||||
|
|
|
|||
Loading…
Reference in New Issue