diff --git a/BUILD.bazel b/BUILD.bazel
index da42e58..2792314 100644
--- a/BUILD.bazel
+++ b/BUILD.bazel
@@ -214,6 +214,7 @@ cc_library(
     hdrs = ["gemma/weights.h"],
     deps = [
         ":configs",
+        ":gemma_args",
         ":mat",
         ":model_store",
         ":ops",
@@ -223,9 +224,7 @@ cc_library(
         "//io:blob_store",
         "@highway//:hwy",
         "@highway//:profiler",
-        "@highway//:stats",
         "@highway//:thread_pool",
-        "@highway//:timer",
     ],
 )

@@ -500,7 +499,6 @@ cc_library(
         # Placeholder for internal dep, do not remove.,
         "//io:blob_store",
         "//io",
-        "//compression:types",
         "//paligemma:image",
         "@highway//:hwy",
         "@highway//:nanobenchmark",  # timer
diff --git a/evals/benchmark_helper.cc b/evals/benchmark_helper.cc
index 2c203f2..6ebf930 100644
--- a/evals/benchmark_helper.cc
+++ b/evals/benchmark_helper.cc
@@ -49,7 +49,7 @@ void InitGenerator(const InferenceArgs& inference, std::mt19937& gen) {
 
 GemmaEnv::GemmaEnv(const LoaderArgs& loader, const ThreadingArgs& threading,
                    const InferenceArgs& inference)
-    : env_(MakeMatMulEnv(threading)), gemma_(loader, env_) {
+    : env_(MakeMatMulEnv(threading)), gemma_(loader, inference, env_) {
   const ModelConfig& config = gemma_.GetModelConfig();
   // Only allocate one for starters because GenerateBatch might not be called.
   kv_caches_.push_back(KVCache(config, inference.prefill_tbatch_size));
@@ -229,8 +229,9 @@ void ShowConfig(const LoaderArgs& loader, const ThreadingArgs& threading,
   threading.Print(inference.verbosity);
   loader.Print(inference.verbosity);
   inference.Print(inference.verbosity);
-  fprintf(stderr, "Model : %s, mmap %d\n",
-          config.Specifier().c_str(), static_cast<int>(loader.map));
+  fprintf(stderr, "Model : %s, to_bf16 %d, mmap %d\n",
+          config.Specifier().c_str(), static_cast<int>(loader.to_bf16),
+          static_cast<int>(loader.map));
 
   if (inference.verbosity >= 2) {
     time_t now = time(nullptr);
diff --git a/evals/gemma_batch_bench.cc b/evals/gemma_batch_bench.cc
index 5411d95..9d6f827 100644
--- a/evals/gemma_batch_bench.cc
+++ b/evals/gemma_batch_bench.cc
@@ -37,7 +37,7 @@ class GemmaTest : public ::testing::Test {
  protected:
   std::vector<std::string> BatchGemmaReply(
       const std::vector<std::string>& inputs) {
-    s_env->SetMaxGeneratedTokens(32);
+    s_env->SetMaxGeneratedTokens(16);
     s_env->MutableConfig().temperature = 0.0f;  // deterministic
     s_env->MutableConfig().verbosity = 2;
     std::vector<std::string> replies;
diff --git a/examples/hello_world/run.cc b/examples/hello_world/run.cc
index 683ff88..faf3f42 100644
--- a/examples/hello_world/run.cc
+++ b/examples/hello_world/run.cc
@@ -52,7 +52,7 @@ int main(int argc, char** argv) {
 
   // Instantiate model and KV Cache
   gcpp::MatMulEnv env(MakeMatMulEnv(threading));
-  gcpp::Gemma gemma(loader, env);
+  gcpp::Gemma gemma(loader, inference, env);
   gcpp::KVCache kv_cache(gemma.GetModelConfig(), inference.prefill_tbatch_size);
 
   size_t generated = 0;
diff --git a/examples/simplified_gemma/gemma.hpp b/examples/simplified_gemma/gemma.hpp
index 0899096..2f6f5be 100644
--- a/examples/simplified_gemma/gemma.hpp
+++ b/examples/simplified_gemma/gemma.hpp
@@ -39,7 +39,7 @@ class SimplifiedGemma {
         threading_(threading),
         inference_(inference),
         env_(MakeMatMulEnv(threading_)),
-        gemma_(loader_, env_),
+        gemma_(loader_, inference_, env_),
         kv_cache_(gemma_.GetModelConfig(), inference_.prefill_tbatch_size) {
     // Initialize random number generator
     std::random_device rd;
diff --git a/gemma/activations.h b/gemma/activations.h
index 0c012dc..7a94960 100644
--- a/gemma/activations.h
+++ b/gemma/activations.h
@@ -88,23 +88,14 @@ struct Activations {
     HWY_ASSERT(batch_size != 0);
 
     // For MatMul outputs, precompute their row pointers.
-    const auto init_row_ptrs = [&](MatPtrT<float>& mat) {
-      if (!mat.HasPtr()) return;
-      row_ptrs.push_back(hwy::AllocateAligned<uint8_t*>(mat.Rows()));
-      uint8_t** ptrs = row_ptrs.back().get();
-      for (size_t r = 0; r < mat.Rows(); ++r) {
-        ptrs[r] = mat.RowBytes(r);
-      }
-      mat.AttachRowPtrs(ptrs);
-    };
     // If we forget any MatMul outputs here, debug builds print a warning but
     // fill them in each MatMul call.
-    init_row_ptrs(q);
-    init_row_ptrs(logits);
-    init_row_ptrs(att_sums);
-    init_row_ptrs(C1);
-    init_row_ptrs(C2);
-    init_row_ptrs(ffw_out);
+    q.AllocateAndAttachRowPtrs(row_ptrs);
+    logits.AllocateAndAttachRowPtrs(row_ptrs);
+    att_sums.AllocateAndAttachRowPtrs(row_ptrs);
+    C1.AllocateAndAttachRowPtrs(row_ptrs);
+    C2.AllocateAndAttachRowPtrs(row_ptrs);
+    ffw_out.AllocateAndAttachRowPtrs(row_ptrs);
     // TODO: also init rows for image_tokens.
 
     // Note that BindC on any MatMul output considerably slows down Prefill.
@@ -142,7 +133,7 @@ struct Activations {
   const MatPadding pad_ = MatPadding::kOdd;
 
   MatStorageT<float> x;  // input
-  MatStorageT<float> q;  // query, also KV if MHA.
+  MatStorageT<float> q;  // query
   MatStorageT<float> logits;
 
   // Attention
@@ -150,13 +141,13 @@ struct Activations {
   MatStorageT<float> att;      // attention vector
   MatStorageT<float> att_out;  // attention output
   // Accumulation of attention outputs over heads
-  MatStorageT<float> att_sums;
+  MatStorageT<BF16> att_sums;
 
   // Gated FFW
   MatStorageT<float> pre_ffw_rms_out;
-  MatStorageT<float> C1;
+  MatStorageT<float> C1;  // TODO: BF16 after Activation() supports it
   MatStorageT<float> C2;
-  MatStorageT<float> ffw_out;
+  MatStorageT<BF16> ffw_out;
 
   // Griffin
   MatStorageT<float> griffin_x;
diff --git a/gemma/bindings/context.cc b/gemma/bindings/context.cc
index d81f0b8..47f4ad7 100644
--- a/gemma/bindings/context.cc
+++ b/gemma/bindings/context.cc
@@ -109,7 +109,7 @@ GemmaContext::GemmaContext(const LoaderArgs& loader,
       threading_args(threading_args),
       matmul_env(MakeMatMulEnv(threading_args)),
       active_conversation_name("default"),
-      model(loader, matmul_env) {
+      model(loader, inference_args, matmul_env) {
   std::stringstream ss;
 
   LogDebug("Creating initial ConversationData");
diff --git a/gemma/gemma-inl.h b/gemma/gemma-inl.h
index d215596..5fbf85a 100644
--- a/gemma/gemma-inl.h
+++ b/gemma/gemma-inl.h
@@ -210,17 +210,9 @@ static HWY_NOINLINE void GriffinRecurrent(const QueriesPos& queries_pos,
   }  // interleaved_idx
 
   // Final linear layer.
-  // TODO: MatMul
-  CallUpcasted(
-      &layer_weights->griffin.linear_out_w, [&](const auto* weights_t) {
-        for (size_t r = 0; r < num_interleaved; ++r) {
-          float* HWY_RESTRICT x = activations.griffin_x.Row(r);
-          float* out_ptr = activations.att_sums.Row(r);
-          MatVecAdd(*weights_t, 0, model_dim, model_dim, x,
-                    layer_weights->griffin.linear_out_biases.PackedScale1(),
-                    out_ptr, pool);
-        }
-      });
+  CallMatMul(activations.griffin_x, layer_weights->griffin.linear_out_w,
+             layer_weights->griffin.linear_out_biases.PackedScale1(),
+             *activations.env, activations.att_sums);
 }  // GriffinRecurrent
 
 // Wrapper class; holds arguments in member variables to shorten call sites.
@@ -1142,7 +1134,10 @@ static HWY_NOINLINE void EmbedImagePatches(const Image& image,
     }
   });
   // Add position embeddings.
-  AddFromBatched(weights.vit_img_pos_embedding, activations.x);
+  CallUpcastedActivation(&weights.vit_img_pos_embedding,
+                         [&](const auto* weights_t) {
+                           AddFromBatched(*weights_t, activations.x);
+                         });
 }
 
 // Prefills the image tokens with the ViT encoder.
diff --git a/gemma/gemma.cc b/gemma/gemma.cc
index 36f97a9..bd37dc5 100644
--- a/gemma/gemma.cc
+++ b/gemma/gemma.cc
@@ -68,13 +68,14 @@ MatMulEnv MakeMatMulEnv(const ThreadingArgs& threading_args) {
   return MatMulEnv(ThreadingContext::Get());
 }
 
-Gemma::Gemma(const LoaderArgs& loader, MatMulEnv& env)
+Gemma::Gemma(const LoaderArgs& loader, const InferenceArgs& inference,
+             MatMulEnv& env)
     : env_(env),
       reader_(loader.weights),
       model_(reader_, loader.tokenizer, loader.wrapping),
       weights_(model_.Config()),
       chat_template_(model_.Tokenizer(), model_.Config().model) {
-  weights_.ReadFromBlobs(model_, reader_, loader.map, mat_owners_,
+  weights_.ReadFromBlobs(model_, reader_, loader, inference, mat_owners_,
                          env.ctx.pools.Pool());
   reader_.CloseFile();
 }
diff --git a/gemma/gemma.h b/gemma/gemma.h
index 66a8531..18018c8 100644
--- a/gemma/gemma.h
+++ b/gemma/gemma.h
@@ -105,7 +105,8 @@ class Gemma {
  public:
   // Reads weights/config/tokenizer from the `BlobStore` at `loader.weights`.
   // `env` must remain valid for the lifetime of this Gemma.
-  Gemma(const LoaderArgs& loader, MatMulEnv& env);
+  Gemma(const LoaderArgs& loader, const InferenceArgs& inference,
+        MatMulEnv& env);
 
   ~Gemma();
 
diff --git a/gemma/gemma_args.h b/gemma/gemma_args.h
index 478470f..b842d44 100644
--- a/gemma/gemma_args.h
+++ b/gemma/gemma_args.h
@@ -52,6 +52,7 @@ struct LoaderArgs : public ArgsBase<LoaderArgs> {
   Path tokenizer;
   Path weights;  // weights file location
   Tristate map;
+  Tristate to_bf16;
   Tristate wrapping;
 
   template <class Visitor>
@@ -62,6 +63,8 @@
             "Path name of model weights (.sbs) file.\n Required argument.\n");
     visitor(map, "map", Tristate::kDefault,
             "Enable memory-mapping? -1 = auto, 0 = no, 1 = yes.");
+    visitor(to_bf16, "to_bf16", Tristate::kDefault,
+            "Convert weights to bf16? -1 = auto, 0 = no, 1 = yes.");
     visitor(wrapping, "wrapping", Tristate::kDefault,
             "Enable prompt wrapping? Specify 0 for pre-2025 format PT models.");
   }
diff --git a/gemma/run.cc b/gemma/run.cc
index cba3500..2afbecb 100644
--- a/gemma/run.cc
+++ b/gemma/run.cc
@@ -254,7 +254,7 @@ void Run(const LoaderArgs& loader, const ThreadingArgs& threading,
   MatMulEnv env(MakeMatMulEnv(threading));
   if (inference.verbosity >= 2) env.print_best = true;
 
-  const Gemma gemma(loader, env);
+  const Gemma gemma(loader, inference, env);
   KVCache kv_cache(gemma.GetModelConfig(), inference.prefill_tbatch_size);
 
   if (inference.verbosity >= 1) {
diff --git a/gemma/weights.cc b/gemma/weights.cc
index ab2c679..4c1f482 100644
--- a/gemma/weights.cc
+++ b/gemma/weights.cc
@@ -27,6 +27,7 @@
 #include "compression/compress.h"
 #include "compression/types.h"
 #include "gemma/configs.h"
+#include "gemma/gemma_args.h"
 #include "gemma/model_store.h"
 #include "io/blob_store.h"
 #include "ops/matmul.h"  // MMParallel
@@ -37,7 +38,7 @@
 #include "hwy/highway.h"
 #include "hwy/profiler.h"
 
-// TODO: move into foreach_target; this is only used for NUQ Fixup.
+// TODO: move into foreach_target
 #include "compression/compress-inl.h"
 
 namespace gcpp {
@@ -246,16 +247,80 @@ std::vector<uint32_t> ModelWeightsPtrs::AddTensorDataToWriter(
   return serialized_mat_ptrs;
 }
 
+enum class Mode {
+  // Parallel I/O, decompress to BF16. Best for large batch sizes.
+  kReadBF16,
+  // Parallel I/O, insert row-wise padding. Safe default.
+  kRead,
+  // Best for large weights relative to available memory, especially for
+  // frequent invocations of small batches and short sequences. Adds noise to
+  // performance measurements due to I/O variability.
+  kMap
+};
+
+// Decides whether to read or map based on heuristics and user override.
+static Mode ChooseMode(uint64_t file_bytes, const LoaderArgs& loader,
+                       const InferenceArgs& inference) {
+  Tristate to_bf16 = loader.to_bf16;
+  Tristate map = loader.map;
+
+  // Disable mapping if not padded to the base page size.
+  const Allocator& allocator = ThreadingContext::Get().allocator;
+  if (file_bytes % allocator.BasePageBytes() != 0) {
+    if (map != Tristate::kFalse) {  // Do not complain if anyway disabled.
+      HWY_WARN("Unable to map non-padded file (%zu, %zu), reading instead.",
+               static_cast<size_t>(file_bytes >> 10),
+               allocator.BasePageBytes());
+      map = Tristate::kFalse;
+    }
+  }
+
+  // Check for user override:
+  if (to_bf16 == Tristate::kTrue && map == Tristate::kTrue) {
+    HWY_WARN("Cannot have to_bf16 && map, to_bf16 takes precedence.");
+  }
+  if (to_bf16 == Tristate::kTrue) return Mode::kReadBF16;
+  if (map == Tristate::kTrue) return Mode::kMap;
+
+  if (to_bf16 == Tristate::kDefault) {
+    // Heuristic: sub-bf16 compression is not helpful if compute-bound.
+    const size_t batch_size =
+        HWY_MAX(inference.prefill_tbatch_size, inference.decode_qbatch_size);
+    to_bf16 = (batch_size >= 128) ? Tristate::kTrue : Tristate::kFalse;
+  }
+
+  if (map == Tristate::kDefault) {
+    // Heuristic: map if large fraction of total. Do not decide based on
+    // `FreeMiB` because it is generally low.
+    const size_t file_mib = file_bytes >> 20;
+    const size_t total_mib = allocator.TotalMiB();
+    if (file_mib > total_mib) {
+      HWY_WARN("Weight file %zu MiB > detected memory %zu MiB.",
+               static_cast<size_t>(file_mib), total_mib);
+    }
+    // Large fraction of total.
+    map = (file_mib >= total_mib / 3) ? Tristate::kTrue : Tristate::kFalse;
+  }
+
+  // If the `map` heuristic triggers, use that for safety.
+  if (map == Tristate::kTrue) return Mode::kMap;
+  return (to_bf16 == Tristate::kTrue) ? Mode::kReadBF16 : Mode::kRead;
+}
+
 struct TensorToRead {
   MatPtr* mat;
   BlobRange range;
   // Some tensors opt out of padding via kPacked flags.
   MatPadding padding;
+
+  // only for kReadBF16
+  bool keep_type = false;
+  Type prev_type;
 };
 
 // Allocates multiple in parallel and binds to NUMA nodes.
-static void AllocateAndBindAll(const std::vector<TensorToRead>& tensors,
-                               std::vector<MatOwner>& owners,
+static void AllocateAndBindAll(std::vector<TensorToRead>& tensors,
+                               const Mode mode, std::vector<MatOwner>& owners,
                                hwy::ThreadPool& pool) {
   const size_t start = owners.size();
   owners.resize(start + tensors.size());
@@ -264,52 +329,25 @@ static void AllocateAndBindAll(const std::vector<TensorToRead>& tensors,
 
   // Allocate in parallel because faulting in large tensors is slow.
   pool.Run(0, tensors.size(), [&](uint64_t task, size_t /*thread*/) {
-    owners[start + task].AllocateFor(*tensors[task].mat, tensors[task].padding);
+    TensorToRead& tensor = tensors[task];
+    MatPtr& mat = *tensor.mat;
+
+    tensor.prev_type = mat.GetType();
+    // We only care about MatMul inputs; skip F32 or small tensors.
+    if (tensor.prev_type == Type::kF32 || mat.Rows() < 1024) {
+      tensor.keep_type = true;
+      tensor.padding = MatPadding::kPacked;  // single I/O for simplicity
+    } else if (mode == Mode::kReadBF16) {
+      mat.SetType(Type::kBF16);
+    }
+
+    owners[start + task].AllocateFor(*tensor.mat, tensor.padding);
     // TODO(janwas): MatMul outputs will later also be BF16.
-    BindB(*tensors[task].mat, sizeof(float), parallel);
+    BindB(*tensor.mat, sizeof(float), parallel);
   });
 }
 
-// Parallel I/O into allocated memory, or mapped view of file. The latter is
-// better when the file is huge, but page faults add noise to measurements.
-enum class Mode { kRead, kMap };
-
-// Decides whether to read or map based on heuristics and user override.
-static Mode ChooseMode(uint64_t file_bytes, Tristate map) {
-  const Allocator& allocator = ThreadingContext::Get().allocator;
-  // User has explicitly requested a map or read via args.
-  if (map == Tristate::kTrue) return Mode::kMap;
-  if (map == Tristate::kFalse) return Mode::kRead;
-  // Else: use heuristics to choose. Note that `FreeMiB` is generally low
-  // because idle memory is used as cache, so do not use it to decide.
-  const size_t file_mib = file_bytes >> 20;
-  const size_t total_mib = allocator.TotalMiB();
-  if (file_mib > total_mib) {
-    HWY_WARN("Weight file %zu MiB > detected memory %zu MiB.",
-             static_cast<size_t>(file_mib), total_mib);
-  }
-  // Large fraction of total.
-  if (file_mib >= total_mib / 3) return Mode::kMap;
-  // Big enough that even parallel loading wouldn't be quick.
-  if (file_mib > 50 * 1024) return Mode::kMap;
-  return Mode::kRead;
-}
-
-static MapPtr MapFileOrNull(File& file, uint64_t file_bytes) {
-  const Allocator& allocator = ThreadingContext::Get().allocator;
-  if (file_bytes % allocator.BasePageBytes() == 0) {
-    MapPtr mapped = file.Map();
-    if (!mapped) {
-      HWY_WARN("Failed to map file (%zu KiB), reading instead.",
-               static_cast<size_t>(file_bytes >> 10));
-    }
-  } else {
-    HWY_WARN("Unable to map non-padded file (%zu, %zu), reading instead.",
-             static_cast<size_t>(file_bytes >> 10), allocator.BasePageBytes());
-  }
-  return MapPtr();
-}
-
+// Mode == kMap
 static void MapAll(const std::vector<TensorToRead>& tensors,
                    const MapPtr& mapped) {
   PROFILER_ZONE("Startup.Weights.Map");
@@ -326,6 +364,65 @@ static void MapAll(const std::vector<TensorToRead>& tensors,
   }
 }
 
+// Mode == kReadBF16:
+
+template <typename T>
+static void DecompressToBF16(const MatPtr& mat,
+                             const hwy::AlignedFreeUniquePtr<uint8_t[]>& buf) {
+  hwy::HWY_NAMESPACE::ScalableTag<BF16> dbf;
+  const size_t cols = mat.Cols();
+
+  const size_t num_packed = CompressedArrayElements<T>(mat.Extents().Area());
+  const PackedSpan<T> packed{HWY_RCAST_ALIGNED(T*, buf.get()), num_packed};
+
+  size_t packed_ofs = 0;
+  for (size_t r = 0; r < mat.Rows(); ++r, packed_ofs += cols) {
+    HWY_NAMESPACE::DecompressAndZeroPad(
+        dbf, packed, packed_ofs, HWY_RCAST_ALIGNED(BF16*, mat.RowBytes(r)),
+        cols);
+  }
+}
+
+static void ReadAllToBF16(const std::vector<TensorToRead>& tensors,
+                          const BlobReader& reader, hwy::ThreadPool& pool) {
+  PROFILER_ZONE("Startup.Weights.ReadBF16");
+
+  pool.Run(0, tensors.size(), [&](uint64_t task, size_t /*thread*/) {
+    const TensorToRead& tensor = tensors[task];
+    MatPtr& mat = *tensor.mat;
+
+    if (tensor.keep_type) {
+      HWY_ASSERT(reader.file().Read(tensor.range.offset, tensor.range.bytes,
+                                    mat.Packed()));
+      return;
+    }
+
+    // Read to a temporary buffer.
+    const hwy::AlignedFreeUniquePtr<uint8_t[]> buf =
+        hwy::AllocateAligned<uint8_t>(tensor.range.bytes);
+    HWY_ASSERT(
+        reader.file().Read(tensor.range.offset, tensor.range.bytes, buf.get()));
+
+    if constexpr (GEMMA_ENABLE_NUQ) {
+      if (tensor.prev_type == Type::kNUQ) {
+        return DecompressToBF16<NuqStream>(*tensor.mat, buf);
+      }
+    }
+    switch (tensor.prev_type) {
+      case Type::kF32:
+        return DecompressToBF16<float>(*tensor.mat, buf);
+      case Type::kBF16:
+        return DecompressToBF16<BF16>(*tensor.mat, buf);
+      case Type::kSFP:
+        return DecompressToBF16<SfpStream>(*tensor.mat, buf);
+      default:
+        HWY_ABORT("Unsupported type %s", TypeName(tensor.prev_type));
+    }
+  });
+}
+
+// Mode == kRead:
+
 static std::vector<IOBatch> MakeBatches(
     const std::vector<TensorToRead>& tensors, const uint64_t file_bytes) {
   PROFILER_ZONE("Startup.Weights.MakeBatches");
@@ -382,31 +479,36 @@ static void ReadBatches(const BlobReader& reader,
 }
 
 // Aborts on error.
-static void MapOrReadAll(const std::vector<TensorToRead>& tensors,
-                         BlobReader& reader, Tristate map,
-                         std::vector<MatOwner>& mat_owners,
+static void MapOrReadAll(std::vector<TensorToRead>& tensors, BlobReader& reader,
+                         Mode mode, std::vector<MatOwner>& mat_owners,
                          hwy::ThreadPool& pool) {
-  if (ChooseMode(reader.file_bytes(), map) == Mode::kMap) {
-    MapPtr mapped = MapFileOrNull(reader.file(), reader.file_bytes());
-    if (mapped) {
-      MapAll(tensors, mapped);
-      return;
-    }
-  }  // otherwise fall through to read mode
+  if (mode == Mode::kMap) {
+    MapPtr mapped = reader.file().Map();
+    if (mapped) return MapAll(tensors, mapped);
+    HWY_WARN("Failed to map file (%zu KiB), reading instead.",
+             static_cast<size_t>(reader.file_bytes() >> 10));
+    // If we wanted to map but failed, memory is probably not plentiful, so
+    // fall through to kRead because kReadBF16 requires more memory.
+    mode = Mode::kRead;
+  }
 
   {
     PROFILER_ZONE("Startup.Weights.Allocate");
     // NOTE: this changes the stride of `mats`!
-    AllocateAndBindAll(tensors, mat_owners, pool);
+    AllocateAndBindAll(tensors, mode, mat_owners, pool);
   }
 
+  if (mode == Mode::kReadBF16) return ReadAllToBF16(tensors, reader, pool);
+
   const std::vector<IOBatch> batches =
       MakeBatches(tensors, reader.file_bytes());
   ReadBatches(reader, batches, pool);
 }
 
 void ModelWeightsPtrs::ReadFromBlobs(const ModelStore& model,
-                                     BlobReader& reader, Tristate map,
+                                     BlobReader& reader,
+                                     const LoaderArgs& loader,
+                                     const InferenceArgs& inference,
                                      std::vector<MatOwner>& mat_owners,
                                      hwy::ThreadPool& pool) {
   // List of tensors to read/map, and where from.
@@ -427,7 +529,9 @@ void ModelWeightsPtrs::ReadFromBlobs(const ModelStore& model,
     HWY_ABORT("Tensor %s is required but not found in file.", t.mat.Name());
   });
 
-  MapOrReadAll(tensors, reader, map, mat_owners, pool);
+  const Mode mode = ChooseMode(reader.file_bytes(), loader, inference);
+
+  MapOrReadAll(tensors, reader, mode, mat_owners, pool);
 
   {
     PROFILER_ZONE("Startup.Fixup");
diff --git a/gemma/weights.h b/gemma/weights.h
index ba20627..ac26340 100644
--- a/gemma/weights.h
+++ b/gemma/weights.h
@@ -24,6 +24,7 @@
 
 #include "compression/types.h"
 #include "gemma/configs.h"      // ModelConfig
+#include "gemma/gemma_args.h"   // InferenceArgs
 #include "gemma/model_store.h"  // ModelStore
 #include "gemma/tensor_info.h"  // TensorInfoRegistry
 #include "io/blob_store.h"      // BlobWriter
@@ -424,7 +425,8 @@ struct ModelWeightsPtrs {
 
   // Reads tensor data from `BlobStore` or aborts on error. `map` is a user
   // override for whether to map blobs or read them.
-  void ReadFromBlobs(const ModelStore& model, BlobReader& reader, Tristate map,
+  void ReadFromBlobs(const ModelStore& model, BlobReader& reader,
+                     const LoaderArgs& loader, const InferenceArgs& inference,
                      std::vector<MatOwner>& mat_owners, hwy::ThreadPool& pool);
 
   // Adds one blob for each tensor's data and returns all serialized MatPtr.
diff --git a/io/BUILD.bazel b/io/BUILD.bazel
index 1811b4e..cd02c78 100644
--- a/io/BUILD.bazel
+++ b/io/BUILD.bazel
@@ -73,7 +73,6 @@ cc_library(
         "//:basics",
         "//:threading_context",
         "@highway//:hwy",
-        "@highway//:profiler",
         "@highway//:thread_pool",
     ],
 )
diff --git a/ops/matmul-inl.h b/ops/matmul-inl.h
index cbb76a1..0d7b664 100644
--- a/ops/matmul-inl.h
+++ b/ops/matmul-inl.h
@@ -19,10 +19,6 @@
 
 #include <stddef.h>
 
-#pragma push_macro("PROFILER_ENABLED")
-#undef PROFILER_ENABLED
-#define PROFILER_ENABLED 0
-
 #include "compression/types.h"
 #include "ops/matmul.h"  // IWYU pragma: export
 #include "util/allocator.h"
@@ -952,15 +948,11 @@ class MMPerPackage {
               range_np_, MultipleNP(sizeof(TC)), inner_tasks_, pkg_idx_,
               [&](const IndexRange& range_nc) HWY_ATTR {
                 HWY_ALIGN BF16 B_storage[B_storage_max_];  // TLS
-                const RowPtrBF B_view(B_storage, K, B_stride);
+                const RowPtrBF B_storage_view(B_storage, K, B_stride);
 
                 for (size_t row_b = range_nc.begin(); row_b < range_nc.end();
                      row_b += kNR) {
-                  {
-                    MMZone zone;
-                    zone.MaybeEnter("MM.NT.DecB", args_);
-                    DecompressB(B, row_b, range_K, B_view);
-                  }
+                  RowPtrBF B_view = DecompressB(B, row_b, range_K, B_storage_view);
                   MMKernel::A2C0(A_view, B_view, mr_, range_M, row_b, K,
                                  MMSetC(), args_, C_rows);
                 }
@@ -985,17 +977,13 @@ class MMPerPackage {
                           auto out_tag) HWY_ATTR {
       const size_t kc = range_kc.Num();
       const RowPtrBF& A_view = A_.View(range_mc.begin(), range_kc.begin(), kc);
-      const RowPtrBF B_view(
+      const RowPtrBF B_storage_view(
           B_storage, kc,
           Stride(MatPadding::kOdd, kc, sizeof(BF16), line_bytes_));
 
       for (size_t row_b = range_nc.begin(); row_b < range_nc.end();
            row_b += kNR) {
-        {
-          MMZone zone;
-          zone.MaybeEnter("MM.NT_K.DecB", args_);
-          DecompressB(B, row_b, range_kc, B_view);
-        }
+        RowPtrBF B_view = DecompressB(B, row_b, range_kc, B_storage_view);
         MMKernel::A2C0(A_view, B_view, mr_, range_mc, row_b, kc, out_tag,
                        args_, C_rows);
       }
@@ -1051,15 +1039,11 @@ class MMPerPackage {
           [&](const IndexRange& range_mc, const IndexRange& range_nc) HWY_ATTR {
             const RowPtrBF& A_view = A_.View(range_mc.begin(), 0, K);
             HWY_ALIGN BF16 B_storage[B_storage_max_];  // TLS
-            const RowPtrBF B_view(B_storage, K, B_stride);
+            const RowPtrBF B_storage_view(B_storage, K, B_stride);
 
             for (size_t row_b = range_nc.begin(); row_b < range_nc.end();
                  row_b += kNR) {
-              {
-                MMZone zone;
-                zone.MaybeEnter("MM.NT_MT.DecB", args_);
-                DecompressB(B, row_b, range_K, B_view);
-              }
+              RowPtrBF B_view = DecompressB(B, row_b, range_K, B_storage_view);
               MMKernel::A2C0(A_view, B_view, mr_, range_mc, row_b, K, MMSetC(),
                              args_, C_rows);
             }
@@ -1081,7 +1065,8 @@ class MMPerPackage {
     // Sequential loop over NC/MC/KC, for when the M/N loops are
     // already parallel. This is B3A2C0 in MOMMS terminology: we read
     // `mc x kc` of A, `nc x kc` of B, update `mc x nc` of `partial`.
-    const auto loop_nc = [&](const RowPtrBF& B_view, const IndexRange& range_mc,
+    const auto loop_nc = [&](const RowPtrBF& B_storage_view,
+                             const IndexRange& range_mc,
                              const IndexRange& range_kc,
                              const IndexRange& range_nc,
                              auto out_tag) HWY_ATTR {
@@ -1090,11 +1075,7 @@ class MMPerPackage {
 
       for (size_t row_b = range_nc.begin(); row_b < range_nc.end();
            row_b += kNR) {
-        {
-          MMZone zone;
-          zone.MaybeEnter("MM.NT_MT_K.DecB", args_);
-          DecompressB(B, row_b, range_kc, B_view);
-        }
+        RowPtrBF B_view = DecompressB(B, row_b, range_kc, B_storage_view);
         MMKernel::A2C0(A_view, B_view, mr_, range_mc, row_b, kc, out_tag,
                        args_, C_rows);
       }
@@ -1103,15 +1084,17 @@ class MMPerPackage {
         ranges_mc_, ranges_nc_, pkg_idx_,
         [&](const IndexRange& range_mc, const IndexRange& range_nc) HWY_ATTR {
           HWY_ALIGN BF16 B_storage[B_storage_max_];  // TLS
-          const RowPtrBF B_view(B_storage, kc_max, B_stride);
+          const RowPtrBF B_storage_view(B_storage, kc_max, B_stride);
 
           // Peel off the first iteration of the kc loop: avoid
           // zero-initializing `partial` by writing into it.
           ranges_kc_.VisitFirst([&](const IndexRange& range_kc) {
-            loop_nc(B_view, range_mc, range_kc, range_nc, MMSetPartial());
+            loop_nc(B_storage_view, range_mc, range_kc, range_nc,
+                    MMSetPartial());
           });
           ranges_kc_.VisitRemaining([&](const IndexRange& range_kc) {
-            loop_nc(B_view, range_mc, range_kc, range_nc, MMAddPartial());
+            loop_nc(B_storage_view, range_mc, range_kc, range_nc,
+                    MMAddPartial());
           });
 
           // Already in parallel section, hence no `kParM`, and
@@ -1228,11 +1211,15 @@
   // col 0 of `B_view`. Decompressing SFP is relatively cheap on `AVX3_DL`
   // thanks to its large table lookups, and less so on other targets.
   template <typename TB>
-  HWY_INLINE void DecompressB(const MatPtrT<TB>& B, const size_t row_b,
-                              const IndexRange& range_kc,
-                              const RowPtrBF& B_view) const {
-    const hn::ScalableTag<BF16> dbf;
+  HWY_INLINE RowPtrBF DecompressB(const MatPtrT<TB>& B, const size_t row_b,
+                                  const IndexRange& range_kc,
+                                  const RowPtrBF& B_view) const {
+    if constexpr (hwy::IsSame<TB, BF16>()) {
+      return RowPtrBF(const_cast<BF16*>(B.Row(row_b)) + range_kc.begin(),
+                      range_kc.Num(), B.Stride());
+    }
 
+    const hn::ScalableTag<BF16> dbf;
     const PackedSpan<const TB> B_span = B.PaddedSpan();
     const size_t kc = range_kc.Num();
@@ -1249,6 +1236,7 @@ class MMPerPackage {
         }
       }
     }
+    return B_view;
   }
 
   const MMArgs args_;  // copy for locality
@@ -1421,5 +1409,3 @@ HWY_NOINLINE MMPerKey* MatMul(const MatPtrT<TA>& A, const MatPtrT<TB>& B,
 HWY_AFTER_NAMESPACE();
 
 #endif  // NOLINT
-
-#pragma pop_macro("PROFILER_ENABLED")
diff --git a/ops/matmul.h b/ops/matmul.h
index 06cb3f1..c82956c 100644
--- a/ops/matmul.h
+++ b/ops/matmul.h
@@ -567,6 +567,16 @@ class MMAutoTune {
 
 // Map of previously seen dimensions to index via linear search.
 class MMKeys {
+  // Group batch size into buckets to reduce #auto-tunes.
+  static size_t BucketM(size_t M) {
+    // The first 4 may require their own bucket because `kNT` only works for a
+    // single M range, but that depends on the config's `MR()`.
+    if (M <= 4) return M;
+    if (M <= 16) return 16;
+    if (M <= 64) return 64;
+    return 256;
+  }
+
  public:
   using Key = uint64_t;
   // KeyFromDims will only return this if all dims are zero, which is invalid.
@@ -577,7 +587,7 @@ class MMKeys {
     HWY_DASSERT(M < (Key{1} << 16));  // batch sizes are smaller
     HWY_DASSERT(K < (Key{1} << 24));
     HWY_DASSERT(N < (Key{1} << 24));
-    const Key key = static_cast<Key>(M) | (static_cast<Key>(K) << 16) |
+    const Key key = static_cast<Key>(BucketM(M)) | (static_cast<Key>(K) << 16) |
                     (static_cast<Key>(N) << 40);
     HWY_DASSERT(key != kPadding);
     return key;
diff --git a/ops/ops-inl.h b/ops/ops-inl.h
index ea27c68..219c006 100644
--- a/ops/ops-inl.h
+++ b/ops/ops-inl.h
@@ -117,6 +117,7 @@ HWY_INLINE hn::Vec<D> Gelu(D d, hn::Vec<D> v) {
 
 static HWY_NOINLINE HWY_MAYBE_UNUSED void Gelu(float* HWY_RESTRICT x,
                                                size_t size) {
+  PROFILER_ZONE("ops.Gelu");
   namespace hn = hwy::HWY_NAMESPACE;
   using D = hn::ScalableTag<float>;
   hn::Transform(D(), x, size,
@@ -181,6 +182,8 @@ namespace detail {
 // Shared by RMSNorm and RMSNormInplace.
 template <typename VT>
 float RMSNormMul(const VT* HWY_RESTRICT x, size_t size) {
+  PROFILER_ZONE("ops.RMSNormMul");
+
   const hn::ScalableTag<float> d;
   const float l2 = DecompressAndCall(d, MakeSpan(x, size), DotKernelDefault());
   constexpr float kEps = 1e-6f;  // avoid divide by zero
@@ -195,7 +198,7 @@ float RMSNormMul(const VT* HWY_RESTRICT x, size_t size) {
 HWY_NOINLINE HWY_MAYBE_UNUSED void RMSNorm(const XT* HWY_RESTRICT x,
                                            const WT* HWY_RESTRICT weight,
                                            size_t w_ofs, OT* HWY_RESTRICT out,
                                            const size_t size) {
-  PROFILER_FUNC;
+  PROFILER_ZONE("ops.RMSNorm");
   namespace hn = hwy::HWY_NAMESPACE;
 
   const hn::ScalableTag<float> df;
@@ -228,7 +231,7 @@ HWY_NOINLINE HWY_MAYBE_UNUSED void RMSNormInplace(const WT* HWY_RESTRICT weight,
                                                   size_t w_ofs,
                                                   XT* HWY_RESTRICT inout,
                                                   const size_t size) {
-  PROFILER_FUNC;
+  PROFILER_ZONE("ops.RMSNormInplace");
   namespace hn = hwy::HWY_NAMESPACE;
 
   const hn::ScalableTag<float> df;
@@ -280,7 +283,7 @@ HWY_NOINLINE void ComputeMoments(const XT* HWY_RESTRICT x, size_t size,
 template <typename XT, typename WT, typename OT>
 HWY_NOINLINE void LayerNorm(const XT* x, const WT* HWY_RESTRICT scale,
                             const WT* HWY_RESTRICT bias, OT* out, size_t size) {
-  PROFILER_FUNC;
+  PROFILER_ZONE("ops.LayerNorm");
   namespace hn = hwy::HWY_NAMESPACE;
 
   const hn::ScalableTag<float> df;
@@ -360,7 +363,7 @@ static HWY_NOINLINE HWY_MAYBE_UNUSED void AddAbsolutePositionalEmbeddings(
   }
 }
 
-/* RoPE as in Rotary Position Embeddings from the RoFormer paper
+/* RoPE as in Rotary Position Embeddings from the `RoFormer` paper
    (https://arxiv.org/abs/2104.09864v5). The query and key vectors are rotated
    as a function of their absolute position using the rotation matrix R before
    the self-attention operation. R is a d x d matrix.
@@ -391,7 +394,7 @@ static HWY_NOINLINE HWY_MAYBE_UNUSED void AddAbsolutePositionalEmbeddings(
 static HWY_NOINLINE HWY_MAYBE_UNUSED void Rope(
     float* HWY_RESTRICT x, size_t dim_qkv,
     const float* HWY_RESTRICT inv_timescale, int pos) {
-  PROFILER_FUNC;
+  PROFILER_ZONE("ops.Rope");
   HWY_DASSERT(dim_qkv % 2 == 0);
   const size_t half_dim_qkv = dim_qkv / 2;
   for (size_t dim = 0; dim < half_dim_qkv; ++dim) {
@@ -409,7 +412,7 @@ static HWY_NOINLINE HWY_MAYBE_UNUSED void Rope(
 static HWY_NOINLINE HWY_MAYBE_UNUSED void RopeAndMulBy(
     const float mul, float* HWY_RESTRICT x, size_t dim_qkv,
     const float* HWY_RESTRICT inv_timescale, int pos) {
-  PROFILER_FUNC;
+  PROFILER_ZONE("ops.RopeAndMulBy");
   HWY_DASSERT(dim_qkv % 2 == 0);
   const size_t half_dim_qkv = dim_qkv / 2;
 
@@ -477,15 +480,45 @@ static HWY_NOINLINE HWY_MAYBE_UNUSED void RopeAndMulBy(
   }
 }
 
-static HWY_NOINLINE HWY_MAYBE_UNUSED void AddFrom(
-    const float* HWY_RESTRICT other, float* HWY_RESTRICT x, const size_t size) {
-  namespace hn = hwy::HWY_NAMESPACE;
-  using D = hn::ScalableTag<float>;
-  using V = hn::Vec<D>;
+template <typename XT>
+static HWY_NOINLINE HWY_MAYBE_UNUSED void AddFrom(const XT* HWY_RESTRICT x,
+                                                  float* HWY_RESTRICT out,
+                                                  const size_t size) {
+  PROFILER_ZONE("ops.AddFrom");
 
-  hn::Transform1(D(), x, size, other,
-                 [](const auto d, const V x, const V other)
-                     HWY_ATTR { return hn::Add(x, other); });
+  namespace hn = hwy::HWY_NAMESPACE;
+  const hn::ScalableTag<float> df;
+  const size_t NF = hn::Lanes(df);
+  using VF = hn::Vec<decltype(df)>;
+
+  const auto packed_x = MakeSpan(x, size);
+
+  size_t i = 0;
+  if (size >= 2 * NF) {
+    for (; i <= size - 2 * NF; i += 2 * NF) {
+      VF x0, x1;
+      Decompress2(df, packed_x, i, x0, x1);
+      VF out0 = hn::Load(df, out + i);
+      VF out1 = hn::Load(df, out + i + NF);
+      hn::Store(hn::Add(x0, out0), df, out + i);
+      hn::Store(hn::Add(x1, out1), df, out + i + NF);
+    }
+  }
+
+  const size_t remaining = size - i;
+  const size_t remaining1 = remaining - HWY_MIN(remaining, NF);
+  HWY_DASSERT(remaining < 2 * NF);
+  HWY_DASSERT(remaining1 < NF);
+  if (HWY_UNLIKELY(remaining != 0)) {
+    HWY_ALIGN float buf_x[2 * hn::MaxLanes(df)];
+    DecompressAndZeroPad(df, packed_x, i, buf_x, remaining);
+    const VF x0 = hn::Load(df, buf_x);
+    const VF x1 = hn::Load(df, buf_x + NF);
+    const VF out0 = hn::LoadN(df, out + i, remaining);
+    const VF out1 = hn::LoadN(df, out + i + NF, remaining1);
+    hn::StoreN(hn::Add(x0, out0), df, out + i, remaining);
+    hn::StoreN(hn::Add(x1, out1), df, out + i + NF, remaining1);
+  }
 }
 
 // Simple loops unless/until batch sizes are large enough to parallelize.
@@ -534,17 +567,19 @@ void LayerNormBatched(const MatPtrT<float>& x, const MatPtr& weight,
   });
 }
 
-static HWY_INLINE void AddFromBatched(const MatPtrT<float>& other,
-                                      MatPtrT<float>& x) {
-  HWY_DASSERT(x.SameShape(other));
-  for (size_t token_idx = 0; token_idx < x.Rows(); ++token_idx) {
-    AddFrom(other.Row(token_idx), x.Row(token_idx), x.Cols());
+template <typename XT>
+static HWY_INLINE void AddFromBatched(const MatPtrT<XT>& x,
+                                      MatPtrT<float>& out) {
+  HWY_DASSERT(out.SameShape(x));
+  for (size_t token_idx = 0; token_idx < out.Rows(); ++token_idx) {
+    AddFrom(x.Row(token_idx), out.Row(token_idx), x.Cols());
   }
 }
 
 static HWY_NOINLINE void MulBy(const float* HWY_RESTRICT other,
                                float* HWY_RESTRICT x, const size_t size,
                                const size_t max_pos) {
+  PROFILER_ZONE("ops.MulBy");
   HWY_DASSERT(max_pos <= size);
   namespace hn = hwy::HWY_NAMESPACE;
   using D = hn::ScalableTag<float>;
@@ -563,6 +598,7 @@ static HWY_INLINE HWY_MAYBE_UNUSED void MulBy(const float* HWY_RESTRICT other,
 
 static HWY_NOINLINE void MulByConst(const float c, float* HWY_RESTRICT x,
                                     const size_t size, const size_t max_pos) {
+  PROFILER_ZONE("ops.MulByConst");
   HWY_DASSERT(max_pos <= size);
   namespace hn = hwy::HWY_NAMESPACE;
   using D = hn::ScalableTag<float>;
@@ -578,22 +614,56 @@ static HWY_INLINE HWY_MAYBE_UNUSED void MulByConst(const float c,
   MulByConst(c, x, size, size);
 }
 
-static HWY_NOINLINE HWY_MAYBE_UNUSED void MulByConstAndAdd(
-    float c, const float* HWY_RESTRICT x, float* HWY_RESTRICT out,
-    size_t size) {
+template <typename XT, typename OT>
+HWY_NOINLINE HWY_MAYBE_UNUSED void MulByConstAndAdd(float c,
+                                                    const XT* HWY_RESTRICT x,
+                                                    OT* HWY_RESTRICT out,
+                                                    size_t size) {
+  PROFILER_ZONE("ops.MulByConstAndAdd");
   namespace hn = hwy::HWY_NAMESPACE;
-  using D = hn::ScalableTag<float>;
-  using V = hn::Vec<D>;
-  hn::Transform1(D(), out, size, x,
-                 [c](const auto d, const V v_out, const V v_x) HWY_ATTR {
-                   return hn::MulAdd(v_x, hn::Set(d, c), v_out);
-                 });
+  const hn::ScalableTag<float> df;
+  const size_t NF = hn::Lanes(df);
+  using VF = hn::Vec<decltype(df)>;
+
+  const VF v_c = hn::Set(df, c);
+  const auto packed_x = MakeSpan(x, size);
+  const auto packed_out = MakeSpan(out, size);
+
+  size_t i = 0;
+  if (size >= 2 * NF) {
+    for (; i <= size - 2 * NF; i += 2 * NF) {
+      VF x0, x1, out0, out1;
+      Decompress2(df, packed_x, i, x0, x1);
+      Decompress2(df, packed_out, i, out0, out1);
+      out0 = hn::MulAdd(x0, v_c, out0);
+      out1 = hn::MulAdd(x1, v_c, out1);
+      Compress2(df, out0, out1, packed_out, i);
+    }
+  }
+
+  const size_t remaining = size - i;
+  HWY_DASSERT(remaining < 2 * NF);
+  if (HWY_UNLIKELY(remaining != 0)) {
+    HWY_ALIGN float buf_x[2 * hn::MaxLanes(df)];
+    HWY_ALIGN float buf_out[2 * hn::MaxLanes(df)];
+    DecompressAndZeroPad(df, packed_x, i, buf_x, remaining);
+    DecompressAndZeroPad(df, packed_out, i, buf_out, remaining);
+    const VF x0 = hn::Load(df, buf_x);
+    const VF x1 = hn::Load(df, buf_x + NF);
+    VF out0 = hn::Load(df, buf_out);
+    VF out1 = hn::Load(df, buf_out + NF);
+    out0 = hn::MulAdd(x0, v_c, out0);
+    out1 = hn::MulAdd(x1, v_c, out1);
+    Compress2(df, out0, out1, MakeSpan(buf_out, 2 * NF), 0);
+    hwy::CopyBytes(buf_out, out + i, remaining * sizeof(OT));
+  }
 }
 
 // See below for a specialized version for top-1 sampling.
 static HWY_NOINLINE void Softmax(float* HWY_RESTRICT x, const size_t size,
                                  const size_t mask_pos,
                                  float temperature = 1.0f) {
+  PROFILER_ZONE("ops.Softmax");
   HWY_DASSERT(size != 0);
   HWY_DASSERT(mask_pos <= size);
 
@@ -733,6 +803,7 @@ static HWY_MAYBE_UNUSED TokenAndProb Top1OfSoftmax(float* HWY_RESTRICT x,
 
 static HWY_NOINLINE void LogitsSoftCap(const float cap, float* HWY_RESTRICT x,
                                        const size_t size,
                                        const size_t max_pos) {
+  PROFILER_ZONE("ops.LogitsSoftCap");
   HWY_DASSERT(max_pos <= size);
   namespace hn = hwy::HWY_NAMESPACE;
diff --git a/util/mat.h b/util/mat.h
index f350ce7..a61bedd 100644
--- a/util/mat.h
+++ b/util/mat.h
@@ -96,6 +96,17 @@ class MatPtr : public IFields {
     }
   }
 
+  void AllocateAndAttachRowPtrs(
+      std::vector<hwy::AlignedFreeUniquePtr<uint8_t*[]>>& row_ptrs) {
+    if (!HasPtr()) return;
+    row_ptrs.push_back(hwy::AllocateAligned<uint8_t*>(Rows()));
+    uint8_t** ptrs = row_ptrs.back().get();
+    for (size_t r = 0; r < Rows(); ++r) {
+      ptrs[r] = RowBytes(r);
+    }
+    AttachRowPtrs(ptrs);
+  };
+
   uint8_t** GetRowPtrs() const { return row_ptrs_; }
 
   // A single row counts as packed because there is no padding between rows.
@@ -328,7 +339,7 @@ decltype(auto) CallUpcasted(const MatPtr* base, const Func& func,
 template <class Func, typename... Args>
 decltype(auto) CallUpcastedSame(const MatPtr* base1, const MatPtr* base2,
                                 const Func& func, Args&&... args) {
-  HWY_ASSERT(base1->GetType() == base2->GetType());
+  HWY_DASSERT(base1->GetType() == base2->GetType());
 
 #if GEMMA_ENABLE_NUQ
   if (base1->GetType() == Type::kNUQ) {
@@ -359,13 +370,12 @@ decltype(auto) CallUpcastedSame(const MatPtr* base1, const MatPtr* base2,
 template <class Func, typename... Args>
 decltype(auto) CallUpcastedActivation(const MatPtr* base, const Func& func,
                                       Args&&... args) {
-  HWY_ASSERT(base != nullptr);
   if (base->GetType() == Type::kF32) {
-    return func(dynamic_cast<const MatPtrT<float>*>(base),
-                std::forward<Args>(args)...);
+    const MatPtrT<float> mat(*base);
+    return func(&mat, std::forward<Args>(args)...);
   } else if (base->GetType() == Type::kBF16) {
-    return func(dynamic_cast<const MatPtrT<BF16>*>(base),
-                std::forward<Args>(args)...);
+    const MatPtrT<BF16> mat(*base);
+    return func(&mat, std::forward<Args>(args)...);
   } else {
     HWY_ABORT("Unhandled type %s.", TypeName(base->GetType()));
   }
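
Usage note (not part of the patch): `Gemma` now takes `InferenceArgs` in addition to `LoaderArgs`, and `LoaderArgs` gains a `to_bf16` tristate that, together with the inference batch sizes, drives the new `ChooseMode()` in gemma/weights.cc. The sketch below mirrors examples/hello_world/run.cc after this change; the `(argc, argv)` argument constructors and the surrounding includes are assumed from that existing example rather than shown in this diff.

  // Sketch only: updated call site for Gemma(loader, inference, env).
  // Includes and error handling as in examples/hello_world/run.cc.
  gcpp::LoaderArgs loader(argc, argv);        // e.g. --weights m.sbs --to_bf16 1
  gcpp::ThreadingArgs threading(argc, argv);
  gcpp::InferenceArgs inference(argc, argv);  // batch sizes feed ChooseMode()

  gcpp::MatMulEnv env(MakeMatMulEnv(threading));
  gcpp::Gemma gemma(loader, inference, env);  // was: Gemma(loader, env)
  gcpp::KVCache kv_cache(gemma.GetModelConfig(), inference.prefill_tbatch_size);

With `--to_bf16 1`, non-F32 tensors with at least 1024 rows are decompressed to BF16 at load time (Mode::kReadBF16); with the default `-1`, this is enabled only when max(prefill_tbatch_size, decode_qbatch_size) >= 128. An explicit `--map 1` still selects mapping unless `--to_bf16 1` is also set, in which case to_bf16 takes precedence and a warning is printed.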