diff --git a/BUILD.bazel b/BUILD.bazel index 51882ea..558c4e7 100644 --- a/BUILD.bazel +++ b/BUILD.bazel @@ -285,6 +285,7 @@ cc_library( deps = [ ":common", ":gemma_lib", + ":ops", "@hwy//:hwy", ], ) diff --git a/evals/cross_entropy.cc b/evals/cross_entropy.cc index 566ab85..39600b5 100644 --- a/evals/cross_entropy.cc +++ b/evals/cross_entropy.cc @@ -13,18 +13,31 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "evals/cross_entropy.h" +// Compiles this file for multiple architectures via "foreach_target.h", to +// which we pass the filename via macro 'argument'. +// clang-format off +#undef HWY_TARGET_INCLUDE +#define HWY_TARGET_INCLUDE "evals/cross_entropy.cc" // NOLINT +// clang-format on +#include "hwy/foreach_target.h" // IWYU pragma: keep +#include "hwy/highway.h" +// After highway.h +#include "ops/ops-inl.h" // Softmax + +#ifndef GEMMA_CROSS_ENTROPY_ONCE +#define GEMMA_CROSS_ENTROPY_ONCE #include #include -#include +#include // std::sort #include #include // NOLINT #include #include #include +#include "evals/cross_entropy.h" #include "gemma/common.h" #include "gemma/gemma.h" #include "hwy/base.h" @@ -63,10 +76,30 @@ void LogTopK(const GemmaTokenizer& tokenizer, const float* dist, size_t len, } } } // namespace +} // namespace gcpp +#endif // GEMMA_CROSS_ENTROPY_ONCE + +// SIMD code, compiled once per target. +HWY_BEFORE_NAMESPACE(); +namespace gcpp { +namespace HWY_NAMESPACE { + +void CallSoftmax(float* HWY_RESTRICT logits, size_t vocab_size) { + Softmax(logits, vocab_size); +} + +} // namespace HWY_NAMESPACE +} // namespace gcpp +HWY_AFTER_NAMESPACE(); + +#if HWY_ONCE +namespace gcpp { + +HWY_EXPORT(CallSoftmax); float ComputeCrossEntropy(Gemma& gemma, size_t max_tokens, - const std::vector& prompt, - KVCache& kv_cache, int verbosity) { + const std::vector& prompt, KVCache& kv_cache, + int verbosity) { const StreamFunc stream_token = [](int /*token*/, float) { return true; }; // TWeight is unused, but we have to pass it to Config*. @@ -74,8 +107,11 @@ float ComputeCrossEntropy(Gemma& gemma, size_t max_tokens, CallForModel(gemma.Info().model); float cross_entropy = std::log(vocab_size); // first token size_t pos = 1; - const SampleFunc sample_token = [&](const float* probs, - size_t vocab_size) -> int { + + const SampleFunc sample_token = [&](float* probs, + size_t vocab_size) -> TokenAndProb { + // input is logits, not yet probabilities + HWY_DYNAMIC_DISPATCH(CallSoftmax)(probs, vocab_size); // We are called for each token, but pos starts at 1. Clamping max_tokens // to prompt.size() should prevent overrun. HWY_ASSERT(pos < prompt.size()); @@ -96,8 +132,9 @@ float ComputeCrossEntropy(Gemma& gemma, size_t max_tokens, cross_entropy / std::log(2.0) / (pos + 1)); } ++pos; - return token; + return TokenAndProb{.token = token, .prob = prob}; }; + std::vector prompt0 = { prompt[0] }; max_tokens = HWY_MIN(max_tokens, prompt.size()); RuntimeConfig runtime = { @@ -118,3 +155,4 @@ float ComputeCrossEntropy(Gemma& gemma, size_t max_tokens, } } // namespace gcpp +#endif // HWY_ONCE diff --git a/gemma/gemma-inl.h b/gemma/gemma-inl.h index f49c959..99606ac 100644 --- a/gemma/gemma-inl.h +++ b/gemma/gemma-inl.h @@ -15,20 +15,10 @@ // SIMD functions for Gemma/Griffin transformers. -// Include guard (still compiled once per target) -#if defined(THIRD_PARTY_GEMMA_CPP_GEMMA_GEMMA_INL_H_) == \ - defined(HWY_TARGET_TOGGLE) -#ifdef THIRD_PARTY_GEMMA_CPP_GEMMA_GEMMA_INL_H_ -#undef THIRD_PARTY_GEMMA_CPP_GEMMA_GEMMA_INL_H_ -#else -#define THIRD_PARTY_GEMMA_CPP_GEMMA_GEMMA_INL_H_ -#endif - #include #include #include // std::min -#include #include #include @@ -38,9 +28,6 @@ #include "gemma/gemma.h" #include "gemma/weights.h" // Placeholder for internal test4, do not remove -#include "ops/matmul-inl.h" -#include "ops/matvec-inl.h" -#include "ops/ops-inl.h" #include "paligemma/image.h" #include "util/allocator.h" #include "util/threading.h" @@ -48,10 +35,24 @@ #include "hwy/base.h" #include "hwy/bit_set.h" #include "hwy/contrib/thread_pool/thread_pool.h" -#include "hwy/highway.h" -#include "hwy/profiler.h" #include "hwy/timer.h" +// Include guard (still compiled once per target) +#if defined(THIRD_PARTY_GEMMA_CPP_GEMMA_GEMMA_INL_H_) == \ + defined(HWY_TARGET_TOGGLE) +#ifdef THIRD_PARTY_GEMMA_CPP_GEMMA_GEMMA_INL_H_ +#undef THIRD_PARTY_GEMMA_CPP_GEMMA_GEMMA_INL_H_ +#else +#define THIRD_PARTY_GEMMA_CPP_GEMMA_GEMMA_INL_H_ +#endif + +#include "hwy/highway.h" +// After highway.h +#include "ops/matmul-inl.h" +#include "ops/matvec-inl.h" +#include "ops/ops-inl.h" +#include "hwy/profiler.h" // also uses SIMD + #ifndef GEMMA_CONFIG #if HWY_IDE // Provide a definition so the IDE does not complain. @@ -1165,6 +1166,32 @@ class TokenStreamer { hwy::BitSet4096<> is_eos_; }; +template +SampleFunc ChooseSampleFunc(const RuntimeConfig& runtime_config) { + constexpr size_t kTopK = TConfig::kTopK; + + // If user provided a sample_func, use it. + if (runtime_config.sample_func) return runtime_config.sample_func; + + // Fast path for top-1 with no accept_token. + if (kTopK == 1 && !runtime_config.accept_token) { + PROFILER_ZONE("Gen.Sample Top1"); + return [](float* logits, size_t vocab_size) -> TokenAndProb { + return Top1OfSoftmax(logits, vocab_size); + }; + } + + // General case: Softmax with top-k sampling. + return [&runtime_config](float* logits, size_t vocab_size) -> TokenAndProb { + PROFILER_ZONE("Gen.Sample general"); + Softmax(logits, vocab_size); + const int token = SampleTopK(logits, vocab_size, *runtime_config.gen, + runtime_config.temperature, + runtime_config.accept_token); + return TokenAndProb{.token = token, .prob = logits[token]}; + }; +} + // Generates one continuation for each query in `queries_prompt`, which is one // qbatch whose size is at most the `batch_size` passed to // `activations.Allocate`. @@ -1214,18 +1241,10 @@ void GenerateT(const ByteStorageT& weights_u8, Activations& activations, } } - // If no sample_func is provided, we use top-k sampling. - const SampleFunc sample_token = - runtime_config.sample_func - ? runtime_config.sample_func - : [&](const float* logits, size_t vocab_size) -> int { - return SampleTopK(logits, vocab_size, *runtime_config.gen, - runtime_config.temperature, - runtime_config.accept_token); - }; + const SampleFunc sample_token = ChooseSampleFunc(runtime_config); - // Prefill stops before min_prompt_size - 1 because the last prompt token is - // the first input token for generation. + // Prefill stops before min_prompt_size - 1 because the last prompt + // token is the first input token for generation. const double prefill_start = hwy::platform::Now(); // If tbatch is larger than the qbatch we already have in `activations`, then // allocate prefill_activations, otherwise reuse. @@ -1283,15 +1302,14 @@ void GenerateT(const ByteStorageT& weights_u8, Activations& activations, for (size_t query_idx = 0; query_idx < num_queries; ++query_idx) { float* HWY_RESTRICT logits = activations.logits.Batch(query_idx); MaybeLogitsSoftCap(TConfig::kFinalCap, logits, kVocabSize); - Softmax(logits, kVocabSize); - const int token = sample_token(logits, kVocabSize); + const TokenAndProb tp = sample_token(logits, kVocabSize); timing_info.NotifyGenerated(prefill_start, gen_start); const bool is_eos = token_streamer(query_idx_start + query_idx, - queries_mutable_pos[query_idx], token, logits[token]); + queries_mutable_pos[query_idx], tp.token, tp.prob); all_queries_eos &= is_eos; - gen_tokens[query_idx] = is_eos ? runtime_config.eos_id : token; + gen_tokens[query_idx] = is_eos ? runtime_config.eos_id : tp.token; } if (all_queries_eos) break; } // foreach token to generate diff --git a/gemma/gemma.h b/gemma/gemma.h index f05cab5..ea25281 100644 --- a/gemma/gemma.h +++ b/gemma/gemma.h @@ -57,9 +57,10 @@ using BatchStreamFunc = std::function; // If not empty, AcceptFunc is called with token. It should return false for // tokens you don't want to generate and true for tokens you want to generate. using AcceptFunc = std::function; -// If not empty, SampleFunc is called with the probability distribution for the -// next token, and its return value is used as the next generated token. -using SampleFunc = std::function; +// If not empty, SampleFunc is called with the logits for the next token, which +// it may modify/overwrite, and its return value is the next generated token +// together with its probability. +using SampleFunc = std::function; // If not empty, LayersOutputFunc is called for layer outputs, specified with: // - index of query within containing batch (if any); zero otherwise. // - position in the tokens sequence diff --git a/ops/ops-inl.h b/ops/ops-inl.h index ef4b48e..31b3b00 100644 --- a/ops/ops-inl.h +++ b/ops/ops-inl.h @@ -28,10 +28,10 @@ #include // std::enable_if_t #include "compression/compress.h" +#include "util/allocator.h" // TokenAndProb #include "hwy/base.h" #include "hwy/contrib/thread_pool/thread_pool.h" #include "hwy/detect_targets.h" -#include "hwy/profiler.h" #endif // THIRD_PARTY_GEMMA_CPP_OPS_OPS_INL_H_ // Include guard for (potentially) SIMD code. @@ -46,6 +46,7 @@ #include "ops/dot-inl.h" #include "hwy/contrib/algo/transform-inl.h" #include "hwy/contrib/math/math-inl.h" +#include "hwy/profiler.h" // also uses SIMD HWY_BEFORE_NAMESPACE(); namespace gcpp { @@ -602,7 +603,6 @@ HWY_INLINE float Sum(D d, const VT* HWY_RESTRICT vec, size_t num) { static HWY_NOINLINE void Softmax(float* HWY_RESTRICT x, const size_t size, const size_t mask_pos) { - PROFILER_FUNC; HWY_DASSERT(size != 0); HWY_DASSERT(mask_pos <= size); @@ -644,6 +644,71 @@ static HWY_INLINE HWY_MAYBE_UNUSED void Softmax(float* HWY_RESTRICT x, Softmax(x, size, size); } +// Returns argmax of softmax and its probability. This overwrites `x`, but not +// with normalized probabilities. Only equivalent to `Softmax` + `sample_func` +// if `kTopK` == 1. This is worthwhile because `num` is +// typically `kVocabSize` == 256K, and this avoids writing that many, and then +// scanning them again for the max. +static HWY_MAYBE_UNUSED TokenAndProb Top1OfSoftmax(float* HWY_RESTRICT x, + const size_t num) { + namespace hn = hwy::HWY_NAMESPACE; + using D = hn::ScalableTag; + using V = hn::Vec; + using M = hn::Mask; + const D d; + const hn::RebindToSigned di; + using TI = hn::TFromD; + using VI = hn::Vec; + const size_t N = hn::Lanes(d); + HWY_ASSERT(num % (2 * N) == 0); + + V max0 = hn::Set(d, hwy::LowestValue()); + V max1 = max0; + VI argmax0 = hn::Zero(di); + VI argmax1 = argmax0; + + for (size_t i = 0; i < num; i += 2 * N) { + const V v0 = hn::LoadU(d, x + i); + const V v1 = hn::LoadU(d, x + i + N); + const VI vi0 = hn::Iota(di, static_cast(i)); + const VI vi1 = hn::Iota(di, static_cast(i + N)); + const M gt0 = hn::Gt(v0, max0); + const M gt1 = hn::Gt(v1, max1); + max0 = hn::IfThenElse(gt0, v0, max0); + max1 = hn::IfThenElse(gt1, v1, max1); + argmax0 = hn::IfThenElse(hn::RebindMask(di, gt0), vi0, argmax0); + argmax1 = hn::IfThenElse(hn::RebindMask(di, gt1), vi1, argmax1); + } + // Combine the two vectors + const M gt0 = hn::Gt(max0, max1); + max0 = hn::IfThenElse(gt0, max0, max1); + argmax0 = hn::IfThenElse(hn::RebindMask(di, gt0), argmax0, argmax1); + // Reduce to the global max + const V max = hn::MaxOfLanes(d, max0); // broadcasts + const V* pmax = &max; + // Argmax = lowest-indexed lane equal to the global max + const size_t lane = hn::FindKnownFirstTrue(d, hn::Eq(max, max0)); + const TI argmax = hn::ExtractLane(argmax0, lane); + + // Subtract max (avoid precision loss for large exponents) and exponentiate. + hn::Transform(d, x, num, [pmax](const auto d, const V value) HWY_ATTR { + if constexpr (HWY_TARGET & HWY_ALL_SVE) { + // Temporary workaround for buggy SVE codegen: avoid inlined Exp(). + return hn::CallExp(d, hn::Sub(value, *pmax)); + } else { + return hn::Exp(d, hn::Sub(value, *pmax)); + } + }); + + // Normalize to a single probability. The exact sum seems like it should not + // make a huge difference. It halves the standard deviation of the sum of the + // normalized probabilities from 1E-7 to 5E-8, but actually also changes the + // generated text after a few hundred tokens. + const float sum_exp = Sum(d, x, num); + const float prob = x[argmax] / sum_exp; + return TokenAndProb{.token = argmax, .prob = prob}; +} + static HWY_NOINLINE void LogitsSoftCap(const float cap, float* HWY_RESTRICT x, const size_t size, const size_t max_pos) { diff --git a/util/allocator.h b/util/allocator.h index e8fa41c..38ff84e 100644 --- a/util/allocator.h +++ b/util/allocator.h @@ -24,6 +24,12 @@ namespace gcpp { +// Shared between gemma.h and ops-inl.h. +struct TokenAndProb { + int token; + float prob; +}; + using ByteStorageT = hwy::AlignedFreeUniquePtr; template