1.07x batch decode speedup: more BF16 weights and activations

BF16 att_sums and ffw_out
Support BF16 B views without decompression
Support arbitrary types in MulByConstAndAdd, AddFrom

Also update profiler annotations in ops-inl.h

PiperOrigin-RevId: 766995010
This commit is contained in:
Jan Wassenberg 2025-06-03 23:28:57 -07:00 committed by Copybara-Service
parent 839a642992
commit 9efdcfd45c
19 changed files with 349 additions and 177 deletions

View File

@ -214,6 +214,7 @@ cc_library(
hdrs = ["gemma/weights.h"],
deps = [
":configs",
":gemma_args",
":mat",
":model_store",
":ops",
@ -223,9 +224,7 @@ cc_library(
"//io:blob_store",
"@highway//:hwy",
"@highway//:profiler",
"@highway//:stats",
"@highway//:thread_pool",
"@highway//:timer",
],
)
@ -500,7 +499,6 @@ cc_library(
# Placeholder for internal dep, do not remove.,
"//io:blob_store",
"//io",
"//compression:types",
"//paligemma:image",
"@highway//:hwy",
"@highway//:nanobenchmark", # timer

View File

@ -49,7 +49,7 @@ void InitGenerator(const InferenceArgs& inference, std::mt19937& gen) {
GemmaEnv::GemmaEnv(const LoaderArgs& loader, const ThreadingArgs& threading,
const InferenceArgs& inference)
: env_(MakeMatMulEnv(threading)), gemma_(loader, env_) {
: env_(MakeMatMulEnv(threading)), gemma_(loader, inference, env_) {
const ModelConfig& config = gemma_.GetModelConfig();
// Only allocate one for starters because GenerateBatch might not be called.
kv_caches_.push_back(KVCache(config, inference.prefill_tbatch_size));
@ -229,8 +229,9 @@ void ShowConfig(const LoaderArgs& loader, const ThreadingArgs& threading,
threading.Print(inference.verbosity);
loader.Print(inference.verbosity);
inference.Print(inference.verbosity);
fprintf(stderr, "Model : %s, mmap %d\n",
config.Specifier().c_str(), static_cast<int>(loader.map));
fprintf(stderr, "Model : %s, to_bf16 %d, mmap %d\n",
config.Specifier().c_str(), static_cast<int>(loader.to_bf16),
static_cast<int>(loader.map));
if (inference.verbosity >= 2) {
time_t now = time(nullptr);

View File

@ -37,7 +37,7 @@ class GemmaTest : public ::testing::Test {
protected:
std::vector<std::string> BatchGemmaReply(
const std::vector<std::string>& inputs) {
s_env->SetMaxGeneratedTokens(32);
s_env->SetMaxGeneratedTokens(16);
s_env->MutableConfig().temperature = 0.0f; // deterministic
s_env->MutableConfig().verbosity = 2;
std::vector<std::string> replies;

View File

@ -52,7 +52,7 @@ int main(int argc, char** argv) {
// Instantiate model and KV Cache
gcpp::MatMulEnv env(MakeMatMulEnv(threading));
gcpp::Gemma gemma(loader, env);
gcpp::Gemma gemma(loader, inference, env);
gcpp::KVCache kv_cache(gemma.GetModelConfig(), inference.prefill_tbatch_size);
size_t generated = 0;

View File

@ -39,7 +39,7 @@ class SimplifiedGemma {
threading_(threading),
inference_(inference),
env_(MakeMatMulEnv(threading_)),
gemma_(loader_, env_),
gemma_(loader_, inference_, env_),
kv_cache_(gemma_.GetModelConfig(), inference_.prefill_tbatch_size) {
// Initialize random number generator
std::random_device rd;

View File

@ -88,23 +88,14 @@ struct Activations {
HWY_ASSERT(batch_size != 0);
// For MatMul outputs, precompute their row pointers.
const auto init_row_ptrs = [&](MatPtrT<float>& mat) {
if (!mat.HasPtr()) return;
row_ptrs.push_back(hwy::AllocateAligned<uint8_t*>(mat.Rows()));
uint8_t** ptrs = row_ptrs.back().get();
for (size_t r = 0; r < mat.Rows(); ++r) {
ptrs[r] = mat.RowBytes(r);
}
mat.AttachRowPtrs(ptrs);
};
// If we forget any MatMul outputs here, debug builds print a warning but
// fill them in each MatMul call.
init_row_ptrs(q);
init_row_ptrs(logits);
init_row_ptrs(att_sums);
init_row_ptrs(C1);
init_row_ptrs(C2);
init_row_ptrs(ffw_out);
q.AllocateAndAttachRowPtrs(row_ptrs);
logits.AllocateAndAttachRowPtrs(row_ptrs);
att_sums.AllocateAndAttachRowPtrs(row_ptrs);
C1.AllocateAndAttachRowPtrs(row_ptrs);
C2.AllocateAndAttachRowPtrs(row_ptrs);
ffw_out.AllocateAndAttachRowPtrs(row_ptrs);
// TODO: also init rows for image_tokens.
// Note that BindC on any MatMul output considerably slows down Prefill.
@ -142,7 +133,7 @@ struct Activations {
const MatPadding pad_ = MatPadding::kOdd;
MatStorageT<float> x; // input
MatStorageT<float> q; // query, also KV if MHA.
MatStorageT<float> q; // query
MatStorageT<float> logits;
// Attention
@ -150,13 +141,13 @@ struct Activations {
MatStorageT<float> att; // attention vector
MatStorageT<float> att_out; // attention output
// Accumulation of attention outputs over heads
MatStorageT<float> att_sums;
MatStorageT<BF16> att_sums;
// Gated FFW
MatStorageT<BF16> pre_ffw_rms_out;
MatStorageT<float> C1;
MatStorageT<float> C1; // TODO: BF16 after Activation() supports it
MatStorageT<float> C2;
MatStorageT<float> ffw_out;
MatStorageT<BF16> ffw_out;
// Griffin
MatStorageT<float> griffin_x;

View File

@ -109,7 +109,7 @@ GemmaContext::GemmaContext(const LoaderArgs& loader,
threading_args(threading_args),
matmul_env(MakeMatMulEnv(threading_args)),
active_conversation_name("default"),
model(loader, matmul_env) {
model(loader, inference_args, matmul_env) {
std::stringstream ss;
LogDebug("Creating initial ConversationData");

View File

@ -210,17 +210,9 @@ static HWY_NOINLINE void GriffinRecurrent(const QueriesPos& queries_pos,
} // interleaved_idx
// Final linear layer.
// TODO: MatMul
CallUpcasted(
&layer_weights->griffin.linear_out_w, [&](const auto* weights_t) {
for (size_t r = 0; r < num_interleaved; ++r) {
float* HWY_RESTRICT x = activations.griffin_x.Row(r);
float* out_ptr = activations.att_sums.Row(r);
MatVecAdd(*weights_t, 0, model_dim, model_dim, x,
layer_weights->griffin.linear_out_biases.PackedScale1(),
out_ptr, pool);
}
});
CallMatMul(activations.griffin_x, layer_weights->griffin.linear_out_w,
layer_weights->griffin.linear_out_biases.PackedScale1(),
*activations.env, activations.att_sums);
} // GriffinRecurrent
// Wrapper class; holds arguments in member variables to shorten call sites.
@ -1142,7 +1134,10 @@ static HWY_NOINLINE void EmbedImagePatches(const Image& image,
}
});
// Add position embeddings.
AddFromBatched(weights.vit_img_pos_embedding, activations.x);
CallUpcastedActivation(&weights.vit_img_pos_embedding,
[&](const auto* weights_t) {
AddFromBatched(*weights_t, activations.x);
});
}
// Prefills the image tokens with the ViT encoder.

View File

@ -68,13 +68,14 @@ MatMulEnv MakeMatMulEnv(const ThreadingArgs& threading_args) {
return MatMulEnv(ThreadingContext::Get());
}
Gemma::Gemma(const LoaderArgs& loader, MatMulEnv& env)
Gemma::Gemma(const LoaderArgs& loader, const InferenceArgs& inference,
MatMulEnv& env)
: env_(env),
reader_(loader.weights),
model_(reader_, loader.tokenizer, loader.wrapping),
weights_(model_.Config()),
chat_template_(model_.Tokenizer(), model_.Config().model) {
weights_.ReadFromBlobs(model_, reader_, loader.map, mat_owners_,
weights_.ReadFromBlobs(model_, reader_, loader, inference, mat_owners_,
env.ctx.pools.Pool());
reader_.CloseFile();
}

View File

@ -105,7 +105,8 @@ class Gemma {
public:
// Reads weights/config/tokenizer from the `BlobStore` at `loader.weights`.
// `env` must remain valid for the lifetime of this Gemma.
Gemma(const LoaderArgs& loader, MatMulEnv& env);
Gemma(const LoaderArgs& loader, const InferenceArgs& inference,
MatMulEnv& env);
~Gemma();

View File

@ -52,6 +52,7 @@ struct LoaderArgs : public ArgsBase<LoaderArgs> {
Path tokenizer;
Path weights; // weights file location
Tristate map;
Tristate to_bf16;
Tristate wrapping;
template <class Visitor>
@ -62,6 +63,8 @@ struct LoaderArgs : public ArgsBase<LoaderArgs> {
"Path name of model weights (.sbs) file.\n Required argument.\n");
visitor(map, "map", Tristate::kDefault,
"Enable memory-mapping? -1 = auto, 0 = no, 1 = yes.");
visitor(to_bf16, "to_bf16", Tristate::kDefault,
"Convert weights to bf16? -1 = auto, 0 = no, 1 = yes.");
visitor(wrapping, "wrapping", Tristate::kDefault,
"Enable prompt wrapping? Specify 0 for pre-2025 format PT models.");
}

View File

@ -254,7 +254,7 @@ void Run(const LoaderArgs& loader, const ThreadingArgs& threading,
MatMulEnv env(MakeMatMulEnv(threading));
if (inference.verbosity >= 2) env.print_best = true;
const Gemma gemma(loader, env);
const Gemma gemma(loader, inference, env);
KVCache kv_cache(gemma.GetModelConfig(), inference.prefill_tbatch_size);
if (inference.verbosity >= 1) {

View File

@ -27,6 +27,7 @@
#include "compression/compress.h"
#include "compression/types.h"
#include "gemma/configs.h"
#include "gemma/gemma_args.h"
#include "gemma/model_store.h"
#include "io/blob_store.h"
#include "ops/matmul.h" // MMParallel
@ -37,7 +38,7 @@
#include "hwy/highway.h"
#include "hwy/profiler.h"
// TODO: move into foreach_target; this is only used for NUQ Fixup.
// TODO: move into foreach_target
#include "compression/compress-inl.h"
namespace gcpp {
@ -246,16 +247,80 @@ std::vector<uint32_t> ModelWeightsPtrs::AddTensorDataToWriter(
return serialized_mat_ptrs;
}
enum class Mode {
// Parallel I/O, decompress to BF16. Best for large batch sizes.
kReadBF16,
// Parallel I/O, insert row-wise padding. Safe default.
kRead,
// Best for large weights relative to available memory, especially for
// frequent invocations of small batches and short sequences. Adds noise to
// performance measurements due to I/O variability.
kMap
};
// Decides whether to read or map based on heuristics and user override.
static Mode ChooseMode(uint64_t file_bytes, const LoaderArgs& loader,
const InferenceArgs& inference) {
Tristate to_bf16 = loader.to_bf16;
Tristate map = loader.map;
// Disable mapping if not padded to the base page size.
const Allocator& allocator = ThreadingContext::Get().allocator;
if (file_bytes % allocator.BasePageBytes() != 0) {
if (map != Tristate::kFalse) { // Do not complain if anyway disabled.
HWY_WARN("Unable to map non-padded file (%zu, %zu), reading instead.",
static_cast<size_t>(file_bytes >> 10),
allocator.BasePageBytes());
map = Tristate::kFalse;
}
}
// Check for user override:
if (to_bf16 == Tristate::kTrue && map == Tristate::kTrue) {
HWY_WARN("Cannot have to_bf16 && map, to_bf16 takes precedence.");
}
if (to_bf16 == Tristate::kTrue) return Mode::kReadBF16;
if (map == Tristate::kTrue) return Mode::kMap;
if (to_bf16 == Tristate::kDefault) {
// Heuristic: sub-bf16 compression is not helpful if compute-bound.
const size_t batch_size =
HWY_MAX(inference.prefill_tbatch_size, inference.decode_qbatch_size);
to_bf16 = (batch_size >= 128) ? Tristate::kTrue : Tristate::kFalse;
}
if (map == Tristate::kDefault) {
// Heuristic: map if large fraction of total. Do not decide based on
// `FreeMiB` because it is generally low.
const size_t file_mib = file_bytes >> 20;
const size_t total_mib = allocator.TotalMiB();
if (file_mib > total_mib) {
HWY_WARN("Weight file %zu MiB > detected memory %zu MiB.",
static_cast<size_t>(file_mib), total_mib);
}
// Large fraction of total.
map = (file_mib >= total_mib / 3) ? Tristate::kTrue : Tristate::kFalse;
}
// If the `map` heuristic triggers, use that for safety.
if (map == Tristate::kTrue) return Mode::kMap;
return (to_bf16 == Tristate::kTrue) ? Mode::kReadBF16 : Mode::kRead;
}
struct TensorToRead {
MatPtr* mat;
BlobRange range;
// Some tensors opt out of padding via kPacked flags.
MatPadding padding;
// only for kReadBF16
bool keep_type = false;
Type prev_type;
};
// Allocates multiple in parallel and binds to NUMA nodes.
static void AllocateAndBindAll(const std::vector<TensorToRead>& tensors,
std::vector<MatOwner>& owners,
static void AllocateAndBindAll(std::vector<TensorToRead>& tensors,
const Mode mode, std::vector<MatOwner>& owners,
hwy::ThreadPool& pool) {
const size_t start = owners.size();
owners.resize(start + tensors.size());
@ -264,52 +329,25 @@ static void AllocateAndBindAll(const std::vector<TensorToRead>& tensors,
// Allocate in parallel because faulting in large tensors is slow.
pool.Run(0, tensors.size(), [&](uint64_t task, size_t /*thread*/) {
owners[start + task].AllocateFor(*tensors[task].mat, tensors[task].padding);
TensorToRead& tensor = tensors[task];
MatPtr& mat = *tensor.mat;
tensor.prev_type = mat.GetType();
// We only care about MatMul inputs; skip F32 or small tensors.
if (tensor.prev_type == Type::kF32 || mat.Rows() < 1024) {
tensor.keep_type = true;
tensor.padding = MatPadding::kPacked; // single I/O for simplicity
} else if (mode == Mode::kReadBF16) {
mat.SetType(Type::kBF16);
}
owners[start + task].AllocateFor(*tensor.mat, tensor.padding);
// TODO(janwas): MatMul outputs will later also be BF16.
BindB(*tensors[task].mat, sizeof(float), parallel);
BindB(*tensor.mat, sizeof(float), parallel);
});
}
// Parallel I/O into allocated memory, or mapped view of file. The latter is
// better when the file is huge, but page faults add noise to measurements.
enum class Mode { kRead, kMap };
// Decides whether to read or map based on heuristics and user override.
static Mode ChooseMode(uint64_t file_bytes, Tristate map) {
const Allocator& allocator = ThreadingContext::Get().allocator;
// User has explicitly requested a map or read via args.
if (map == Tristate::kTrue) return Mode::kMap;
if (map == Tristate::kFalse) return Mode::kRead;
// Else: use heuristics to choose. Note that `FreeMiB` is generally low
// because idle memory is used as cache, so do not use it to decide.
const size_t file_mib = file_bytes >> 20;
const size_t total_mib = allocator.TotalMiB();
if (file_mib > total_mib) {
HWY_WARN("Weight file %zu MiB > detected memory %zu MiB.",
static_cast<size_t>(file_mib), total_mib);
}
// Large fraction of total.
if (file_mib >= total_mib / 3) return Mode::kMap;
// Big enough that even parallel loading wouldn't be quick.
if (file_mib > 50 * 1024) return Mode::kMap;
return Mode::kRead;
}
static MapPtr MapFileOrNull(File& file, uint64_t file_bytes) {
const Allocator& allocator = ThreadingContext::Get().allocator;
if (file_bytes % allocator.BasePageBytes() == 0) {
MapPtr mapped = file.Map();
if (!mapped) {
HWY_WARN("Failed to map file (%zu KiB), reading instead.",
static_cast<size_t>(file_bytes >> 10));
}
} else {
HWY_WARN("Unable to map non-padded file (%zu, %zu), reading instead.",
static_cast<size_t>(file_bytes >> 10), allocator.BasePageBytes());
}
return MapPtr();
}
// Mode == kMap
static void MapAll(const std::vector<TensorToRead>& tensors,
const MapPtr& mapped) {
PROFILER_ZONE("Startup.Weights.Map");
@ -326,6 +364,65 @@ static void MapAll(const std::vector<TensorToRead>& tensors,
}
}
// Mode == kReadBF16:
template <typename T>
static void DecompressToBF16(const MatPtr& mat,
const hwy::AlignedFreeUniquePtr<uint8_t[]>& buf) {
hwy::HWY_NAMESPACE::ScalableTag<BF16> dbf;
const size_t cols = mat.Cols();
const size_t num_packed = CompressedArrayElements<T>(mat.Extents().Area());
const PackedSpan<T> packed{HWY_RCAST_ALIGNED(T*, buf.get()), num_packed};
size_t packed_ofs = 0;
for (size_t r = 0; r < mat.Rows(); ++r, packed_ofs += cols) {
HWY_NAMESPACE::DecompressAndZeroPad(
dbf, packed, packed_ofs, HWY_RCAST_ALIGNED(BF16*, mat.RowBytes(r)),
cols);
}
}
static void ReadAllToBF16(const std::vector<TensorToRead>& tensors,
const BlobReader& reader, hwy::ThreadPool& pool) {
PROFILER_ZONE("Startup.Weights.ReadBF16");
pool.Run(0, tensors.size(), [&](uint64_t task, size_t /*thread*/) {
const TensorToRead& tensor = tensors[task];
MatPtr& mat = *tensor.mat;
if (tensor.keep_type) {
HWY_ASSERT(reader.file().Read(tensor.range.offset, tensor.range.bytes,
mat.Packed()));
return;
}
// Read to a temporary buffer.
const hwy::AlignedFreeUniquePtr<uint8_t[]> buf =
hwy::AllocateAligned<uint8_t>(tensor.range.bytes);
HWY_ASSERT(
reader.file().Read(tensor.range.offset, tensor.range.bytes, buf.get()));
if constexpr (GEMMA_ENABLE_NUQ) {
if (tensor.prev_type == Type::kNUQ) {
return DecompressToBF16<NuqStream>(*tensor.mat, buf);
}
}
switch (tensor.prev_type) {
case Type::kF32:
return DecompressToBF16<float>(*tensor.mat, buf);
case Type::kBF16:
return DecompressToBF16<BF16>(*tensor.mat, buf);
case Type::kSFP:
return DecompressToBF16<SfpStream>(*tensor.mat, buf);
default:
HWY_ABORT("Unsupported type %s", TypeName(tensor.prev_type));
}
});
}
// Mode == kRead:
static std::vector<IOBatch> MakeBatches(
const std::vector<TensorToRead>& tensors, const uint64_t file_bytes) {
PROFILER_ZONE("Startup.Weights.MakeBatches");
@ -382,31 +479,36 @@ static void ReadBatches(const BlobReader& reader,
}
// Aborts on error.
static void MapOrReadAll(const std::vector<TensorToRead>& tensors,
BlobReader& reader, Tristate map,
std::vector<MatOwner>& mat_owners,
static void MapOrReadAll(std::vector<TensorToRead>& tensors, BlobReader& reader,
Mode mode, std::vector<MatOwner>& mat_owners,
hwy::ThreadPool& pool) {
if (ChooseMode(reader.file_bytes(), map) == Mode::kMap) {
MapPtr mapped = MapFileOrNull(reader.file(), reader.file_bytes());
if (mapped) {
MapAll(tensors, mapped);
return;
}
} // otherwise fall through to read mode
if (mode == Mode::kMap) {
MapPtr mapped = reader.file().Map();
if (mapped) return MapAll(tensors, mapped);
HWY_WARN("Failed to map file (%zu KiB), reading instead.",
static_cast<size_t>(reader.file_bytes() >> 10));
// If we wanted to map but failed, memory is probably not plentiful, so
// fall through to kRead because kReadBF16 requires more memory.
mode = Mode::kRead;
}
{
PROFILER_ZONE("Startup.Weights.Allocate");
// NOTE: this changes the stride of `mats`!
AllocateAndBindAll(tensors, mat_owners, pool);
AllocateAndBindAll(tensors, mode, mat_owners, pool);
}
if (mode == Mode::kReadBF16) return ReadAllToBF16(tensors, reader, pool);
const std::vector<IOBatch> batches =
MakeBatches(tensors, reader.file_bytes());
ReadBatches(reader, batches, pool);
}
void ModelWeightsPtrs::ReadFromBlobs(const ModelStore& model,
BlobReader& reader, Tristate map,
BlobReader& reader,
const LoaderArgs& loader,
const InferenceArgs& inference,
std::vector<MatOwner>& mat_owners,
hwy::ThreadPool& pool) {
// List of tensors to read/map, and where from.
@ -427,7 +529,9 @@ void ModelWeightsPtrs::ReadFromBlobs(const ModelStore& model,
HWY_ABORT("Tensor %s is required but not found in file.", t.mat.Name());
});
MapOrReadAll(tensors, reader, map, mat_owners, pool);
const Mode mode = ChooseMode(reader.file_bytes(), loader, inference);
MapOrReadAll(tensors, reader, mode, mat_owners, pool);
{
PROFILER_ZONE("Startup.Fixup");

View File

@ -24,6 +24,7 @@
#include "compression/types.h"
#include "gemma/configs.h" // ModelConfig
#include "gemma/gemma_args.h" // InferenceArgs
#include "gemma/model_store.h" // ModelStore
#include "gemma/tensor_info.h" // TensorInfoRegistry
#include "io/blob_store.h" // BlobWriter
@ -424,7 +425,8 @@ struct ModelWeightsPtrs {
// Reads tensor data from `BlobStore` or aborts on error. `map` is a user
// override for whether to map blobs or read them.
void ReadFromBlobs(const ModelStore& model, BlobReader& reader, Tristate map,
void ReadFromBlobs(const ModelStore& model, BlobReader& reader,
const LoaderArgs& loader, const InferenceArgs& inference,
std::vector<MatOwner>& mat_owners, hwy::ThreadPool& pool);
// Adds one blob for each tensor's data and returns all serialized MatPtr.

View File

@ -73,7 +73,6 @@ cc_library(
"//:basics",
"//:threading_context",
"@highway//:hwy",
"@highway//:profiler",
"@highway//:thread_pool",
],
)

View File

@ -19,10 +19,6 @@
#include <vector>
#pragma push_macro("PROFILER_ENABLED")
#undef PROFILER_ENABLED
#define PROFILER_ENABLED 0
#include "compression/types.h"
#include "ops/matmul.h" // IWYU pragma: export
#include "util/allocator.h"
@ -952,15 +948,11 @@ class MMPerPackage {
range_np_, MultipleNP(sizeof(TC)), inner_tasks_, pkg_idx_,
[&](const IndexRange& range_nc) HWY_ATTR {
HWY_ALIGN BF16 B_storage[B_storage_max_]; // TLS
const RowPtrBF B_view(B_storage, K, B_stride);
const RowPtrBF B_storage_view(B_storage, K, B_stride);
for (size_t row_b = range_nc.begin(); row_b < range_nc.end();
row_b += kNR) {
{
MMZone zone;
zone.MaybeEnter("MM.NT.DecB", args_);
DecompressB(B, row_b, range_K, B_view);
}
RowPtrBF B_view = DecompressB(B, row_b, range_K, B_storage_view);
MMKernel::A2C0(A_view, B_view, mr_, range_M, row_b, K, MMSetC(),
args_, C_rows);
}
@ -985,17 +977,13 @@ class MMPerPackage {
auto out_tag) HWY_ATTR {
const size_t kc = range_kc.Num();
const RowPtrBF& A_view = A_.View(range_mc.begin(), range_kc.begin(), kc);
const RowPtrBF B_view(
const RowPtrBF B_storage_view(
B_storage, kc,
Stride(MatPadding::kOdd, kc, sizeof(BF16), line_bytes_));
for (size_t row_b = range_nc.begin(); row_b < range_nc.end();
row_b += kNR) {
{
MMZone zone;
zone.MaybeEnter("MM.NT_K.DecB", args_);
DecompressB(B, row_b, range_kc, B_view);
}
RowPtrBF B_view = DecompressB(B, row_b, range_kc, B_storage_view);
MMKernel::A2C0(A_view, B_view, mr_, range_mc, row_b, kc, out_tag, args_,
C_rows);
}
@ -1051,15 +1039,11 @@ class MMPerPackage {
[&](const IndexRange& range_mc, const IndexRange& range_nc) HWY_ATTR {
const RowPtrBF& A_view = A_.View(range_mc.begin(), 0, K);
HWY_ALIGN BF16 B_storage[B_storage_max_]; // TLS
const RowPtrBF B_view(B_storage, K, B_stride);
const RowPtrBF B_storage_view(B_storage, K, B_stride);
for (size_t row_b = range_nc.begin(); row_b < range_nc.end();
row_b += kNR) {
{
MMZone zone;
zone.MaybeEnter("MM.NT_MT.DecB", args_);
DecompressB(B, row_b, range_K, B_view);
}
RowPtrBF B_view = DecompressB(B, row_b, range_K, B_storage_view);
MMKernel::A2C0(A_view, B_view, mr_, range_mc, row_b, K, MMSetC(),
args_, C_rows);
}
@ -1081,7 +1065,8 @@ class MMPerPackage {
// Sequential loop over NC/MC/KC, for when the M/N loops are
// already parallel. This is B3A2C0 in MOMMS terminology: we read
// `mc x kc` of A, `nc x kc` of B, update `mc x nc` of `partial`.
const auto loop_nc = [&](const RowPtrBF& B_view, const IndexRange& range_mc,
const auto loop_nc = [&](const RowPtrBF& B_storage_view,
const IndexRange& range_mc,
const IndexRange& range_kc,
const IndexRange& range_nc,
auto out_tag) HWY_ATTR {
@ -1090,11 +1075,7 @@ class MMPerPackage {
for (size_t row_b = range_nc.begin(); row_b < range_nc.end();
row_b += kNR) {
{
MMZone zone;
zone.MaybeEnter("MM.NT_MT_K.DecB", args_);
DecompressB(B, row_b, range_kc, B_view);
}
RowPtrBF B_view = DecompressB(B, row_b, range_kc, B_storage_view);
MMKernel::A2C0(A_view, B_view, mr_, range_mc, row_b, kc, out_tag, args_,
C_rows);
}
@ -1103,15 +1084,17 @@ class MMPerPackage {
ranges_mc_, ranges_nc_, pkg_idx_,
[&](const IndexRange& range_mc, const IndexRange& range_nc) HWY_ATTR {
HWY_ALIGN BF16 B_storage[B_storage_max_]; // TLS
const RowPtrBF B_view(B_storage, kc_max, B_stride);
const RowPtrBF B_storage_view(B_storage, kc_max, B_stride);
// Peel off the first iteration of the kc loop: avoid
// zero-initializing `partial` by writing into it.
ranges_kc_.VisitFirst([&](const IndexRange& range_kc) {
loop_nc(B_view, range_mc, range_kc, range_nc, MMSetPartial());
loop_nc(B_storage_view, range_mc, range_kc, range_nc,
MMSetPartial());
});
ranges_kc_.VisitRemaining([&](const IndexRange& range_kc) {
loop_nc(B_view, range_mc, range_kc, range_nc, MMAddPartial());
loop_nc(B_storage_view, range_mc, range_kc, range_nc,
MMAddPartial());
});
// Already in parallel section, hence no `kParM`, and
@ -1228,11 +1211,15 @@ class MMPerPackage {
// col 0 of `B_view`. Decompressing SFP is relatively cheap on `AVX3_DL`
// thanks to its large table lookups, and less so on other targets.
template <typename TB>
HWY_INLINE void DecompressB(const MatPtrT<TB>& B, const size_t row_b,
const IndexRange& range_kc,
const RowPtrBF& B_view) const {
const hn::ScalableTag<BF16> dbf;
HWY_INLINE RowPtrBF DecompressB(const MatPtrT<TB>& B, const size_t row_b,
const IndexRange& range_kc,
const RowPtrBF& B_view) const {
if constexpr (hwy::IsSame<TB, BF16>()) {
return RowPtrBF(const_cast<BF16*>(B.Row(row_b)) + range_kc.begin(),
range_kc.Num(), B.Stride());
}
const hn::ScalableTag<BF16> dbf;
const PackedSpan<const TB> B_span = B.PaddedSpan();
const size_t kc = range_kc.Num();
@ -1249,6 +1236,7 @@ class MMPerPackage {
}
}
}
return B_view;
}
const MMArgs args_; // copy for locality
@ -1421,5 +1409,3 @@ HWY_NOINLINE MMPerKey* MatMul(const MatPtrT<TA>& A, const MatPtrT<TB>& B,
HWY_AFTER_NAMESPACE();
#endif // NOLINT
#pragma pop_macro("PROFILER_ENABLED")

View File

@ -567,6 +567,16 @@ class MMAutoTune {
// Map of previously seen dimensions to index via linear search.
class MMKeys {
// Group batch size into buckets to reduce #auto-tunes.
static size_t BucketM(size_t M) {
// The first 4 may require their own bucket because `kNT` only works for a
// single M range, but that depends on the config's `MR()`.
if (M <= 4) return M;
if (M <= 16) return 16;
if (M <= 64) return 64;
return 256;
}
public:
using Key = uint64_t;
// KeyFromDims will only return this if all dims are zero, which is invalid.
@ -577,7 +587,7 @@ class MMKeys {
HWY_DASSERT(M < (Key{1} << 16)); // batch sizes are smaller
HWY_DASSERT(K < (Key{1} << 24));
HWY_DASSERT(N < (Key{1} << 24));
const Key key = static_cast<Key>(M) | (static_cast<Key>(K) << 16) |
const Key key = static_cast<Key>(BucketM(M)) | (static_cast<Key>(K) << 16) |
(static_cast<Key>(N) << 40);
HWY_DASSERT(key != kPadding);
return key;

View File

@ -117,6 +117,7 @@ HWY_INLINE hn::Vec<D> Gelu(D d, hn::Vec<D> v) {
static HWY_NOINLINE HWY_MAYBE_UNUSED void Gelu(float* HWY_RESTRICT x,
size_t size) {
PROFILER_ZONE("ops.Gelu");
namespace hn = hwy::HWY_NAMESPACE;
using D = hn::ScalableTag<float>;
hn::Transform(D(), x, size,
@ -181,6 +182,8 @@ namespace detail {
// Shared by RMSNorm and RMSNormInplace.
template <typename VT>
float RMSNormMul(const VT* HWY_RESTRICT x, size_t size) {
PROFILER_ZONE("ops.RMSNormMul");
const hn::ScalableTag<float> d;
const float l2 = DecompressAndCall(d, MakeSpan(x, size), DotKernelDefault());
constexpr float kEps = 1e-6f; // avoid divide by zero
@ -195,7 +198,7 @@ HWY_NOINLINE HWY_MAYBE_UNUSED void RMSNorm(const XT* HWY_RESTRICT x,
const WT* HWY_RESTRICT weight,
size_t w_ofs, OT* HWY_RESTRICT out,
const size_t size) {
PROFILER_FUNC;
PROFILER_ZONE("ops.RMSNorm");
namespace hn = hwy::HWY_NAMESPACE;
const hn::ScalableTag<float> df;
@ -228,7 +231,7 @@ HWY_NOINLINE HWY_MAYBE_UNUSED void RMSNormInplace(const WT* HWY_RESTRICT weight,
size_t w_ofs,
XT* HWY_RESTRICT inout,
const size_t size) {
PROFILER_FUNC;
PROFILER_ZONE("ops.RMSNormInplace");
namespace hn = hwy::HWY_NAMESPACE;
const hn::ScalableTag<float> df;
@ -280,7 +283,7 @@ HWY_NOINLINE void ComputeMoments(const XT* HWY_RESTRICT x, size_t size,
template <typename XT, typename WT, typename OT>
HWY_NOINLINE void LayerNorm(const XT* x, const WT* HWY_RESTRICT scale,
const WT* HWY_RESTRICT bias, OT* out, size_t size) {
PROFILER_FUNC;
PROFILER_ZONE("ops.LayerNorm");
namespace hn = hwy::HWY_NAMESPACE;
const hn::ScalableTag<float> df;
@ -360,7 +363,7 @@ static HWY_NOINLINE HWY_MAYBE_UNUSED void AddAbsolutePositionalEmbeddings(
}
}
/* RoPE as in Rotary Position Embeddings from the RoFormer paper
/* RoPE as in Rotary Position Embeddings from the `RoFormer` paper
(https://arxiv.org/abs/2104.09864v5). The query and key vectors are rotated
as a function of their absolute position using the rotation matrix R before
the self-attention operation. R is a d x d matrix.
@ -391,7 +394,7 @@ static HWY_NOINLINE HWY_MAYBE_UNUSED void AddAbsolutePositionalEmbeddings(
static HWY_NOINLINE HWY_MAYBE_UNUSED void Rope(
float* HWY_RESTRICT x, size_t dim_qkv,
const float* HWY_RESTRICT inv_timescale, int pos) {
PROFILER_FUNC;
PROFILER_ZONE("ops.Rope");
HWY_DASSERT(dim_qkv % 2 == 0);
const size_t half_dim_qkv = dim_qkv / 2;
for (size_t dim = 0; dim < half_dim_qkv; ++dim) {
@ -409,7 +412,7 @@ static HWY_NOINLINE HWY_MAYBE_UNUSED void Rope(
static HWY_NOINLINE HWY_MAYBE_UNUSED void RopeAndMulBy(
const float mul, float* HWY_RESTRICT x, size_t dim_qkv,
const float* HWY_RESTRICT inv_timescale, int pos) {
PROFILER_FUNC;
PROFILER_ZONE("ops.RopeAndMulBy");
HWY_DASSERT(dim_qkv % 2 == 0);
const size_t half_dim_qkv = dim_qkv / 2;
@ -477,15 +480,45 @@ static HWY_NOINLINE HWY_MAYBE_UNUSED void RopeAndMulBy(
}
}
static HWY_NOINLINE HWY_MAYBE_UNUSED void AddFrom(
const float* HWY_RESTRICT other, float* HWY_RESTRICT x, const size_t size) {
namespace hn = hwy::HWY_NAMESPACE;
using D = hn::ScalableTag<float>;
using V = hn::Vec<D>;
template <typename XT>
static HWY_NOINLINE HWY_MAYBE_UNUSED void AddFrom(const XT* HWY_RESTRICT x,
float* HWY_RESTRICT out,
const size_t size) {
PROFILER_ZONE("ops.AddFrom");
hn::Transform1(D(), x, size, other,
[](const auto d, const V x, const V other)
HWY_ATTR { return hn::Add(x, other); });
namespace hn = hwy::HWY_NAMESPACE;
const hn::ScalableTag<float> df;
const size_t NF = hn::Lanes(df);
using VF = hn::Vec<decltype(df)>;
const auto packed_x = MakeSpan(x, size);
size_t i = 0;
if (size >= 2 * NF) {
for (; i <= size - 2 * NF; i += 2 * NF) {
VF x0, x1;
Decompress2(df, packed_x, i, x0, x1);
VF out0 = hn::Load(df, out + i);
VF out1 = hn::Load(df, out + i + NF);
hn::Store(hn::Add(x0, out0), df, out + i);
hn::Store(hn::Add(x1, out1), df, out + i + NF);
}
}
const size_t remaining = size - i;
const size_t remaining1 = remaining - HWY_MIN(remaining, NF);
HWY_DASSERT(remaining < 2 * NF);
HWY_DASSERT(remaining1 < NF);
if (HWY_UNLIKELY(remaining != 0)) {
HWY_ALIGN float buf_x[2 * hn::MaxLanes(df)];
DecompressAndZeroPad(df, packed_x, i, buf_x, remaining);
const VF x0 = hn::Load(df, buf_x);
const VF x1 = hn::Load(df, buf_x + NF);
const VF out0 = hn::LoadN(df, out + i, remaining);
const VF out1 = hn::LoadN(df, out + i + NF, remaining1);
hn::StoreN(hn::Add(x0, out0), df, out + i, remaining);
hn::StoreN(hn::Add(x1, out1), df, out + i + NF, remaining1);
}
}
// Simple loops unless/until batch sizes are large enough to parallelize.
@ -534,17 +567,19 @@ void LayerNormBatched(const MatPtrT<XT>& x, const MatPtr& weight,
});
}
static HWY_INLINE void AddFromBatched(const MatPtrT<float>& other,
MatPtrT<float>& x) {
HWY_DASSERT(x.SameShape(other));
for (size_t token_idx = 0; token_idx < x.Rows(); ++token_idx) {
AddFrom(other.Row(token_idx), x.Row(token_idx), x.Cols());
template <typename XT>
static HWY_INLINE void AddFromBatched(const MatPtrT<XT>& x,
MatPtrT<float>& out) {
HWY_DASSERT(out.SameShape(x));
for (size_t token_idx = 0; token_idx < out.Rows(); ++token_idx) {
AddFrom(x.Row(token_idx), out.Row(token_idx), x.Cols());
}
}
static HWY_NOINLINE void MulBy(const float* HWY_RESTRICT other,
float* HWY_RESTRICT x, const size_t size,
const size_t max_pos) {
PROFILER_ZONE("ops.MulBy");
HWY_DASSERT(max_pos <= size);
namespace hn = hwy::HWY_NAMESPACE;
using D = hn::ScalableTag<float>;
@ -563,6 +598,7 @@ static HWY_INLINE HWY_MAYBE_UNUSED void MulBy(const float* HWY_RESTRICT other,
static HWY_NOINLINE void MulByConst(const float c, float* HWY_RESTRICT x,
const size_t size, const size_t max_pos) {
PROFILER_ZONE("ops.MulByConst");
HWY_DASSERT(max_pos <= size);
namespace hn = hwy::HWY_NAMESPACE;
using D = hn::ScalableTag<float>;
@ -578,22 +614,56 @@ static HWY_INLINE HWY_MAYBE_UNUSED void MulByConst(const float c,
MulByConst(c, x, size, size);
}
static HWY_NOINLINE HWY_MAYBE_UNUSED void MulByConstAndAdd(
float c, const float* HWY_RESTRICT x, float* HWY_RESTRICT out,
size_t size) {
template <typename XT, typename OT>
HWY_NOINLINE HWY_MAYBE_UNUSED void MulByConstAndAdd(float c,
const XT* HWY_RESTRICT x,
OT* HWY_RESTRICT out,
size_t size) {
PROFILER_ZONE("ops.MulByConstAndAdd");
namespace hn = hwy::HWY_NAMESPACE;
using D = hn::ScalableTag<float>;
using V = hn::Vec<D>;
hn::Transform1(D(), out, size, x,
[c](const auto d, const V v_out, const V v_x) HWY_ATTR {
return hn::MulAdd(v_x, hn::Set(d, c), v_out);
});
const hn::ScalableTag<float> df;
const size_t NF = hn::Lanes(df);
using VF = hn::Vec<decltype(df)>;
const VF v_c = hn::Set(df, c);
const auto packed_x = MakeSpan(x, size);
const auto packed_out = MakeSpan(out, size);
size_t i = 0;
if (size >= 2 * NF) {
for (; i <= size - 2 * NF; i += 2 * NF) {
VF x0, x1, out0, out1;
Decompress2(df, packed_x, i, x0, x1);
Decompress2(df, packed_out, i, out0, out1);
out0 = hn::MulAdd(x0, v_c, out0);
out1 = hn::MulAdd(x1, v_c, out1);
Compress2(df, out0, out1, packed_out, i);
}
}
const size_t remaining = size - i;
HWY_DASSERT(remaining < 2 * NF);
if (HWY_UNLIKELY(remaining != 0)) {
HWY_ALIGN float buf_x[2 * hn::MaxLanes(df)];
HWY_ALIGN float buf_out[2 * hn::MaxLanes(df)];
DecompressAndZeroPad(df, packed_x, i, buf_x, remaining);
DecompressAndZeroPad(df, packed_out, i, buf_out, remaining);
const VF x0 = hn::Load(df, buf_x);
const VF x1 = hn::Load(df, buf_x + NF);
VF out0 = hn::Load(df, buf_out);
VF out1 = hn::Load(df, buf_out + NF);
out0 = hn::MulAdd(x0, v_c, out0);
out1 = hn::MulAdd(x1, v_c, out1);
Compress2(df, out0, out1, MakeSpan(buf_out, 2 * NF), 0);
hwy::CopyBytes(buf_out, out + i, remaining * sizeof(OT));
}
}
// See below for a specialized version for top-1 sampling.
static HWY_NOINLINE void Softmax(float* HWY_RESTRICT x, const size_t size,
const size_t mask_pos,
float temperature = 1.0f) {
PROFILER_ZONE("ops.Softmax");
HWY_DASSERT(size != 0);
HWY_DASSERT(mask_pos <= size);
@ -733,6 +803,7 @@ static HWY_MAYBE_UNUSED TokenAndProb Top1OfSoftmax(float* HWY_RESTRICT x,
static HWY_NOINLINE void LogitsSoftCap(const float cap, float* HWY_RESTRICT x,
const size_t size,
const size_t max_pos) {
PROFILER_ZONE("ops.LogitsSoftCap");
HWY_DASSERT(max_pos <= size);
namespace hn = hwy::HWY_NAMESPACE;

View File

@ -96,6 +96,17 @@ class MatPtr : public IFields {
}
}
void AllocateAndAttachRowPtrs(
std::vector<hwy::AlignedFreeUniquePtr<uint8_t*[]>>& row_ptrs) {
if (!HasPtr()) return;
row_ptrs.push_back(hwy::AllocateAligned<uint8_t*>(Rows()));
uint8_t** ptrs = row_ptrs.back().get();
for (size_t r = 0; r < Rows(); ++r) {
ptrs[r] = RowBytes(r);
}
AttachRowPtrs(ptrs);
};
uint8_t** GetRowPtrs() const { return row_ptrs_; }
// A single row counts as packed because there is no padding between rows.
@ -328,7 +339,7 @@ decltype(auto) CallUpcasted(const MatPtr* base, const Func& func,
template <class Func, typename... Args>
decltype(auto) CallUpcastedSame(const MatPtr* base1, const MatPtr* base2,
const Func& func, Args&&... args) {
HWY_ASSERT(base1->GetType() == base2->GetType());
HWY_DASSERT(base1->GetType() == base2->GetType());
#if GEMMA_ENABLE_NUQ
if (base1->GetType() == Type::kNUQ) {
@ -359,13 +370,12 @@ decltype(auto) CallUpcastedSame(const MatPtr* base1, const MatPtr* base2,
template <class Func, typename... Args>
decltype(auto) CallUpcastedActivation(const MatPtr* base, const Func& func,
Args&&... args) {
HWY_ASSERT(base != nullptr);
if (base->GetType() == Type::kF32) {
return func(dynamic_cast<const MatPtrT<float>*>(base),
std::forward<Args>(args)...);
const MatPtrT<float> mat(*base);
return func(&mat, std::forward<Args>(args)...);
} else if (base->GetType() == Type::kBF16) {
return func(dynamic_cast<const MatPtrT<BF16>*>(base),
std::forward<Args>(args)...);
const MatPtrT<BF16> mat(*base);
return func(&mat, std::forward<Args>(args)...);
} else {
HWY_ABORT("Unhandled type %s.", TypeName(base->GetType()));
}