diff --git a/gemma/gemma-inl.h b/gemma/gemma-inl.h index 99606ac..0677d6c 100644 --- a/gemma/gemma-inl.h +++ b/gemma/gemma-inl.h @@ -1292,13 +1292,16 @@ void GenerateT(const ByteStorageT& weights_u8, Activations& activations, // queries_pos are incremented by Transformer. bool all_queries_eos = true; - PROFILER_ZONE("Gen.Embedding"); - // Compute logits from last layer activations. - MatMul( - num_queries, ConstMat(activations.x.All(), kModelDim), - ConstMat(weights.embedder_input_embedding.data(), kModelDim), - weights.embedder_input_embedding.scale(), /*add=*/nullptr, - activations.env, MutableMat(activations.logits.All(), kVocabSize)); + { + PROFILER_ZONE("Gen.EmbeddingMatmul"); + // Compute logits from last layer activations. + MatMul( + num_queries, ConstMat(activations.x.All(), kModelDim), + ConstMat(weights.embedder_input_embedding.data(), kModelDim), + weights.embedder_input_embedding.scale(), /*add=*/nullptr, + activations.env, MutableMat(activations.logits.All(), kVocabSize)); + } + PROFILER_ZONE("Gen.Softcap+Sample+Stream"); for (size_t query_idx = 0; query_idx < num_queries; ++query_idx) { float* HWY_RESTRICT logits = activations.logits.Batch(query_idx); MaybeLogitsSoftCap(TConfig::kFinalCap, logits, kVocabSize); diff --git a/gemma/weights.cc b/gemma/weights.cc index 405f409..77f0628 100644 --- a/gemma/weights.cc +++ b/gemma/weights.cc @@ -76,7 +76,7 @@ struct LoadCompressedWeightsT { } { PROFILER_ZONE("Startup.Reshape"); - c_weights->Reshape(); + c_weights->Reshape(pool); } return c_weights_u8; } @@ -90,7 +90,8 @@ ByteStorageT LoadCompressedWeights(const Path& weights, Model model_type, } namespace { -void LogVec(const char* name, const float* data, size_t len) { +// For reasons unknown, this is shown as potentially unused in the IDE. +void HWY_MAYBE_UNUSED LogVec(const char* name, const float* data, size_t len) { hwy::Stats stats; for (size_t i = 0; i < len; ++i) { stats.Notify(data[i]); diff --git a/gemma/weights.h b/gemma/weights.h index 42fc23b..73c1e22 100644 --- a/gemma/weights.h +++ b/gemma/weights.h @@ -210,10 +210,10 @@ struct CompressedWeights { explicit CompressedWeights(hwy::ThreadPool& pool) : c_layer_ptrs(pool) {} // Called by weights.cc after ForEachTensor. - void Reshape() { - for (size_t layer = 0; layer < TConfig::kLayers; ++layer) { + void Reshape(hwy::ThreadPool& pool) { + pool.Run(0, TConfig::kLayers, [this](uint64_t layer, size_t /*thread*/) { GetLayer(layer)->Reshape(); - } + }); } void ZeroInit() { @@ -279,7 +279,7 @@ struct ReshapeCompressedWeights { void operator()(ByteStorageT& weights_u8, hwy::ThreadPool& pool) const { CompressedWeights& weights = *reinterpret_cast*>(weights_u8.get()); - weights.Reshape(); + weights.Reshape(pool); } }; diff --git a/ops/ops-inl.h b/ops/ops-inl.h index 5b9b007..e1d1e0f 100644 --- a/ops/ops-inl.h +++ b/ops/ops-inl.h @@ -601,6 +601,7 @@ HWY_INLINE float Sum(D d, const VT* HWY_RESTRICT vec, size_t num) { return DecompressAndCall(d_raw, MakeSpan(vec, num), SumKernelDefault()); } +// See below for a specialized version for top-1 sampling. static HWY_NOINLINE void Softmax(float* HWY_RESTRICT x, const size_t size, const size_t mask_pos) { HWY_DASSERT(size != 0); @@ -644,13 +645,15 @@ static HWY_INLINE HWY_MAYBE_UNUSED void Softmax(float* HWY_RESTRICT x, Softmax(x, size, size); } -// Returns argmax of softmax and its probability. This overwrites `x`, but not -// with normalized probabilities. Only equivalent to `Softmax` + `sample_func` -// if `kTopK` == 1. This is worthwhile because `num` is -// typically `kVocabSize` == 256K, and this avoids writing that many, and then -// scanning them again for the max. -static HWY_MAYBE_UNUSED TokenAndProb Top1OfSoftmax(float* HWY_RESTRICT x, - const size_t num) { +// Note: https://arxiv.org/pdf/2001.04438 proposes to replace the three max / +// exp / mul passes with two passes, both of which compute Exp. This is +// reportedly only faster for very large arrays, larger even than our 256K +// vocab size. We instead fuse the subsequent sampling pass into the softmax, +// which already knows the max value which top-1 sampling would again seek. + +// Returns the argmax and x[argmax]. +static HWY_INLINE TokenAndProb ArgmaxAndMax(const float* HWY_RESTRICT x, + const size_t num) { namespace hn = hwy::HWY_NAMESPACE; using D = hn::ScalableTag; using V = hn::Vec; @@ -679,18 +682,36 @@ static HWY_MAYBE_UNUSED TokenAndProb Top1OfSoftmax(float* HWY_RESTRICT x, argmax0 = hn::IfThenElse(hn::RebindMask(di, gt0), vi0, argmax0); argmax1 = hn::IfThenElse(hn::RebindMask(di, gt1), vi1, argmax1); } + // Combine the two vectors const M gt0 = hn::Gt(max0, max1); max0 = hn::IfThenElse(gt0, max0, max1); argmax0 = hn::IfThenElse(hn::RebindMask(di, gt0), argmax0, argmax1); + // Reduce to the global max const V max = hn::MaxOfLanes(d, max0); // broadcasts - const V* pmax = &max; + // Argmax = lowest-indexed lane equal to the global max const size_t lane = hn::FindKnownFirstTrue(d, hn::Eq(max, max0)); const TI argmax = hn::ExtractLane(argmax0, lane); + return TokenAndProb{.token = argmax, .prob = hn::GetLane(max)}; +} + +// Returns argmax of softmax and its probability. This overwrites `x`, but not +// with normalized probabilities. Only equivalent to `Softmax` + `sample_func` +// if `kTopK` == 1. This is worthwhile because `num` is typically `kVocabSize` +// == 256K, and this avoids writing and then scanning again for the max. +static HWY_MAYBE_UNUSED TokenAndProb Top1OfSoftmax(float* HWY_RESTRICT x, + const size_t num) { + namespace hn = hwy::HWY_NAMESPACE; + const hn::ScalableTag d; + using V = hn::Vec; + + const TokenAndProb argmax = ArgmaxAndMax(x, num); // Subtract max (avoid precision loss for large exponents) and exponentiate. + const V max = hn::Set(d, argmax.prob); + const V* pmax = &max; hn::Transform(d, x, num, [pmax](const auto d, const V value) HWY_ATTR { if constexpr (HWY_TARGET & HWY_ALL_SVE) { // Temporary workaround for buggy SVE codegen: avoid inlined Exp(). @@ -705,8 +726,8 @@ static HWY_MAYBE_UNUSED TokenAndProb Top1OfSoftmax(float* HWY_RESTRICT x, // normalized probabilities from 1E-7 to 5E-8, but actually also changes the // generated text after a few hundred tokens. const float sum_exp = Sum(d, x, num); - const float prob = x[argmax] / sum_exp; - return TokenAndProb{.token = argmax, .prob = prob}; + const float prob = x[argmax.token] / sum_exp; + return TokenAndProb{.token = argmax.token, .prob = prob}; } static HWY_NOINLINE void LogitsSoftCap(const float cap, float* HWY_RESTRICT x,