1.09x decode speedup for topk=1/temp0: fuse softmax and sample

PiperOrigin-RevId: 680589099
Jan Wassenberg 2024-09-30 08:37:01 -07:00 committed by Copybara-Service
parent 897f902d28
commit 2d14d796e3
6 changed files with 171 additions and 42 deletions
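
The fusion pays off because top-1 sampling at temperature 0 only needs the argmax token and its probability, so the full 256K-entry normalized distribution never has to be written out and re-scanned. A minimal scalar sketch of the idea (the function name is hypothetical; the commit's actual implementation is the Highway-vectorized Top1OfSoftmax in ops/ops-inl.h below):

#include <cmath>
#include <cstddef>
#include <utility>

// Scalar illustration only: one pass finds the argmax logit, a second pass
// accumulates the softmax denominator. The normalized distribution is never
// materialized. prob(argmax) = exp(max - max) / sum_exp = 1 / sum_exp.
std::pair<int, float> Top1OfSoftmaxScalar(const float* logits, size_t num) {
  size_t argmax = 0;
  for (size_t i = 1; i < num; ++i) {
    if (logits[i] > logits[argmax]) argmax = i;
  }
  double sum_exp = 0.0;  // subtract the max before exp() for stability
  for (size_t i = 0; i < num; ++i) {
    sum_exp += std::exp(static_cast<double>(logits[i] - logits[argmax]));
  }
  return {static_cast<int>(argmax), static_cast<float>(1.0 / sum_exp)};
}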

Bazel BUILD file

@@ -285,6 +285,7 @@ cc_library(
     deps = [
         ":common",
         ":gemma_lib",
+        ":ops",
         "@hwy//:hwy",
     ],
 )

evals/cross_entropy.cc

@@ -13,18 +13,31 @@
 // See the License for the specific language governing permissions and
 // limitations under the License.
 
-#include "evals/cross_entropy.h"
+// Compiles this file for multiple architectures via "foreach_target.h", to
+// which we pass the filename via macro 'argument'.
+// clang-format off
+#undef HWY_TARGET_INCLUDE
+#define HWY_TARGET_INCLUDE "evals/cross_entropy.cc"  // NOLINT
+// clang-format on
+#include "hwy/foreach_target.h"  // IWYU pragma: keep
+#include "hwy/highway.h"
+// After highway.h
+#include "ops/ops-inl.h"  // Softmax
+
+#ifndef GEMMA_CROSS_ENTROPY_ONCE
+#define GEMMA_CROSS_ENTROPY_ONCE
 
 #include <stddef.h>
 #include <stdio.h>
 
-#include <algorithm>
+#include <algorithm>  // std::sort
 #include <cmath>
 #include <regex>  // NOLINT
 #include <string>
 #include <utility>
 #include <vector>
 
+#include "evals/cross_entropy.h"
 #include "gemma/common.h"
 #include "gemma/gemma.h"
 #include "hwy/base.h"
@@ -63,10 +76,30 @@ void LogTopK(const GemmaTokenizer& tokenizer, const float* dist, size_t len,
   }
 }
 
 }  // namespace
+}  // namespace gcpp
+#endif  // GEMMA_CROSS_ENTROPY_ONCE
+
+// SIMD code, compiled once per target.
+HWY_BEFORE_NAMESPACE();
+namespace gcpp {
+namespace HWY_NAMESPACE {
+
+void CallSoftmax(float* HWY_RESTRICT logits, size_t vocab_size) {
+  Softmax(logits, vocab_size);
+}
+
+}  // namespace HWY_NAMESPACE
+}  // namespace gcpp
+HWY_AFTER_NAMESPACE();
+
+#if HWY_ONCE
+namespace gcpp {
+
+HWY_EXPORT(CallSoftmax);
 
 float ComputeCrossEntropy(Gemma& gemma, size_t max_tokens,
-                          const std::vector<int>& prompt,
-                          KVCache& kv_cache, int verbosity) {
+                          const std::vector<int>& prompt, KVCache& kv_cache,
+                          int verbosity) {
   const StreamFunc stream_token = [](int /*token*/, float) { return true; };
 
   // TWeight is unused, but we have to pass it to Config*.
@@ -74,8 +107,11 @@ float ComputeCrossEntropy(Gemma& gemma, size_t max_tokens,
       CallForModel</*TWeight=*/float, GetVocabSize>(gemma.Info().model);
   float cross_entropy = std::log(vocab_size);  // first token
   size_t pos = 1;
-  const SampleFunc sample_token = [&](const float* probs,
-                                      size_t vocab_size) -> int {
+  const SampleFunc sample_token = [&](float* probs,
+                                      size_t vocab_size) -> TokenAndProb {
+    // input is logits, not yet probabilities
+    HWY_DYNAMIC_DISPATCH(CallSoftmax)(probs, vocab_size);
+
     // We are called for each token, but pos starts at 1. Clamping max_tokens
     // to prompt.size() should prevent overrun.
     HWY_ASSERT(pos < prompt.size());
@@ -96,8 +132,9 @@ float ComputeCrossEntropy(Gemma& gemma, size_t max_tokens,
              cross_entropy / std::log(2.0) / (pos + 1));
     }
     ++pos;
-    return token;
+    return TokenAndProb{.token = token, .prob = prob};
   };
 
   std::vector<int> prompt0 = { prompt[0] };
   max_tokens = HWY_MIN(max_tokens, prompt.size());
   RuntimeConfig runtime = {
@@ -118,3 +155,4 @@ float ComputeCrossEntropy(Gemma& gemma, size_t max_tokens,
 }
 
 }  // namespace gcpp
+#endif  // HWY_ONCE
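
This file is now compiled once per architecture so it can call the vectorized Softmax directly: the SIMD wrapper CallSoftmax is built inside HWY_NAMESPACE, exported once under HWY_ONCE, and selected at runtime via HWY_DYNAMIC_DISPATCH. A stripped-down sketch of that standard Highway foreach_target pattern, with hypothetical file and function names (ScaleInPlace is not part of this commit):

// scale.cc (hypothetical) -- compiled once per SIMD target.
#undef HWY_TARGET_INCLUDE
#define HWY_TARGET_INCLUDE "scale.cc"  // must match this file's path in the build
#include "hwy/foreach_target.h"  // IWYU pragma: keep
#include "hwy/highway.h"

HWY_BEFORE_NAMESPACE();
namespace demo {
namespace HWY_NAMESPACE {
namespace hn = hwy::HWY_NAMESPACE;

// Per-target body; assumes num is a multiple of the vector length,
// like the vocab-sized arrays in this commit.
void ScaleInPlace(float* HWY_RESTRICT x, size_t num, float factor) {
  const hn::ScalableTag<float> d;
  const auto vfactor = hn::Set(d, factor);
  for (size_t i = 0; i < num; i += hn::Lanes(d)) {
    hn::StoreU(hn::Mul(hn::LoadU(d, x + i), vfactor), d, x + i);
  }
}

}  // namespace HWY_NAMESPACE
}  // namespace demo
HWY_AFTER_NAMESPACE();

#if HWY_ONCE
namespace demo {
HWY_EXPORT(ScaleInPlace);

// Compiled once; picks the best available target at runtime.
void CallScaleInPlace(float* x, size_t num, float factor) {
  HWY_DYNAMIC_DISPATCH(ScaleInPlace)(x, num, factor);
}

}  // namespace demo
#endif  // HWY_ONCE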

gemma/gemma-inl.h

@@ -15,20 +15,10 @@
 
 // SIMD functions for Gemma/Griffin transformers.
 
-// Include guard (still compiled once per target)
-#if defined(THIRD_PARTY_GEMMA_CPP_GEMMA_GEMMA_INL_H_) == \
-    defined(HWY_TARGET_TOGGLE)
-#ifdef THIRD_PARTY_GEMMA_CPP_GEMMA_GEMMA_INL_H_
-#undef THIRD_PARTY_GEMMA_CPP_GEMMA_GEMMA_INL_H_
-#else
-#define THIRD_PARTY_GEMMA_CPP_GEMMA_GEMMA_INL_H_
-#endif
-
 #include <stddef.h>
 #include <stdio.h>
 
 #include <algorithm>  // std::min
-#include <string>
 #include <type_traits>
 #include <vector>
@@ -38,9 +28,6 @@
 #include "gemma/gemma.h"
 #include "gemma/weights.h"
 // Placeholder for internal test4, do not remove
-#include "ops/matmul-inl.h"
-#include "ops/matvec-inl.h"
-#include "ops/ops-inl.h"
 #include "paligemma/image.h"
 #include "util/allocator.h"
 #include "util/threading.h"
@@ -48,10 +35,24 @@
 #include "hwy/base.h"
 #include "hwy/bit_set.h"
 #include "hwy/contrib/thread_pool/thread_pool.h"
-#include "hwy/highway.h"
-#include "hwy/profiler.h"
 #include "hwy/timer.h"
 
+// Include guard (still compiled once per target)
+#if defined(THIRD_PARTY_GEMMA_CPP_GEMMA_GEMMA_INL_H_) == \
+    defined(HWY_TARGET_TOGGLE)
+#ifdef THIRD_PARTY_GEMMA_CPP_GEMMA_GEMMA_INL_H_
+#undef THIRD_PARTY_GEMMA_CPP_GEMMA_GEMMA_INL_H_
+#else
+#define THIRD_PARTY_GEMMA_CPP_GEMMA_GEMMA_INL_H_
+#endif
+
+#include "hwy/highway.h"
+// After highway.h
+#include "ops/matmul-inl.h"
+#include "ops/matvec-inl.h"
+#include "ops/ops-inl.h"
+#include "hwy/profiler.h"  // also uses SIMD
+
 #ifndef GEMMA_CONFIG
 #if HWY_IDE
 // Provide a definition so the IDE does not complain.
@@ -1165,6 +1166,32 @@ class TokenStreamer {
   hwy::BitSet4096<> is_eos_;
 };
 
+template <class TConfig>
+SampleFunc ChooseSampleFunc(const RuntimeConfig& runtime_config) {
+  constexpr size_t kTopK = TConfig::kTopK;
+
+  // If user provided a sample_func, use it.
+  if (runtime_config.sample_func) return runtime_config.sample_func;
+
+  // Fast path for top-1 with no accept_token.
+  if (kTopK == 1 && !runtime_config.accept_token) {
+    PROFILER_ZONE("Gen.Sample Top1");
+    return [](float* logits, size_t vocab_size) -> TokenAndProb {
+      return Top1OfSoftmax(logits, vocab_size);
+    };
+  }
+
+  // General case: Softmax with top-k sampling.
+  return [&runtime_config](float* logits, size_t vocab_size) -> TokenAndProb {
+    PROFILER_ZONE("Gen.Sample general");
+    Softmax(logits, vocab_size);
+    const int token = SampleTopK<kTopK>(logits, vocab_size, *runtime_config.gen,
+                                        runtime_config.temperature,
+                                        runtime_config.accept_token);
+    return TokenAndProb{.token = token, .prob = logits[token]};
+  };
+}
+
 // Generates one continuation for each query in `queries_prompt`, which is one
 // qbatch whose size is at most the `batch_size` passed to
 // `activations.Allocate`.
@@ -1214,18 +1241,10 @@ void GenerateT(const ByteStorageT& weights_u8, Activations& activations,
     }
   }
 
-  // If no sample_func is provided, we use top-k sampling.
-  const SampleFunc sample_token =
-      runtime_config.sample_func
-          ? runtime_config.sample_func
-          : [&](const float* logits, size_t vocab_size) -> int {
-              return SampleTopK<TConfig::kTopK>(logits, vocab_size, *runtime_config.gen,
-                                                runtime_config.temperature,
-                                                runtime_config.accept_token);
-            };
+  const SampleFunc sample_token = ChooseSampleFunc<TConfig>(runtime_config);
 
-  // Prefill stops before min_prompt_size - 1 because the last prompt token is
-  // the first input token for generation.
+  // Prefill stops before min_prompt_size - 1 because the last prompt
+  // token is the first input token for generation.
   const double prefill_start = hwy::platform::Now();
 
   // If tbatch is larger than the qbatch we already have in `activations`, then
   // allocate prefill_activations, otherwise reuse.
@@ -1283,15 +1302,14 @@ void GenerateT(const ByteStorageT& weights_u8, Activations& activations,
     for (size_t query_idx = 0; query_idx < num_queries; ++query_idx) {
       float* HWY_RESTRICT logits = activations.logits.Batch(query_idx);
       MaybeLogitsSoftCap(TConfig::kFinalCap, logits, kVocabSize);
-      Softmax(logits, kVocabSize);
-      const int token = sample_token(logits, kVocabSize);
+      const TokenAndProb tp = sample_token(logits, kVocabSize);
       timing_info.NotifyGenerated(prefill_start, gen_start);
       const bool is_eos =
           token_streamer(query_idx_start + query_idx,
-                         queries_mutable_pos[query_idx], token, logits[token]);
+                         queries_mutable_pos[query_idx], tp.token, tp.prob);
       all_queries_eos &= is_eos;
-      gen_tokens[query_idx] = is_eos ? runtime_config.eos_id : token;
+      gen_tokens[query_idx] = is_eos ? runtime_config.eos_id : tp.token;
     }
     if (all_queries_eos) break;
   }  // foreach token to generate

gemma/gemma.h

@@ -57,9 +57,10 @@ using BatchStreamFunc = std::function<bool(size_t, size_t, int, float)>;
 // If not empty, AcceptFunc is called with token. It should return false for
 // tokens you don't want to generate and true for tokens you want to generate.
 using AcceptFunc = std::function<bool(int, float)>;
 
-// If not empty, SampleFunc is called with the probability distribution for the
-// next token, and its return value is used as the next generated token.
-using SampleFunc = std::function<int(const float*, size_t)>;
+// If not empty, SampleFunc is called with the logits for the next token, which
+// it may modify/overwrite, and its return value is the next generated token
+// together with its probability.
+using SampleFunc = std::function<TokenAndProb(float*, size_t)>;
 
 // If not empty, LayersOutputFunc is called for layer outputs, specified with:
 // - index of query within containing batch (if any); zero otherwise.
 // - position in the tokens sequence
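
Under the new contract, a caller-supplied sampler receives raw, mutable logits rather than a normalized distribution, and reports the chosen token together with its probability. A sketch of installing such a callback (InstallGreedySampler and its greedy argmax body are illustrative, not part of this commit):

#include <cmath>
#include <cstddef>

#include "gemma/gemma.h"  // RuntimeConfig, SampleFunc, TokenAndProb

// Sketch: a custom SampleFunc that greedily picks the argmax logit and reports
// its softmax probability, without materializing the full distribution.
void InstallGreedySampler(gcpp::RuntimeConfig& runtime_config) {
  runtime_config.sample_func = [](float* logits,
                                  size_t vocab_size) -> gcpp::TokenAndProb {
    size_t best = 0;
    for (size_t i = 1; i < vocab_size; ++i) {
      if (logits[i] > logits[best]) best = i;
    }
    double sum_exp = 0.0;  // softmax denominator, for the reported probability
    for (size_t i = 0; i < vocab_size; ++i) {
      sum_exp += std::exp(static_cast<double>(logits[i] - logits[best]));
    }
    return gcpp::TokenAndProb{.token = static_cast<int>(best),
                              .prob = static_cast<float>(1.0 / sum_exp)};
  };
}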

ops/ops-inl.h

@@ -28,10 +28,10 @@
 #include <type_traits>  // std::enable_if_t
 
 #include "compression/compress.h"
-#include "util/allocator.h"
+#include "util/allocator.h"  // TokenAndProb
 #include "hwy/base.h"
 #include "hwy/contrib/thread_pool/thread_pool.h"
 #include "hwy/detect_targets.h"
-#include "hwy/profiler.h"
 
 #endif  // THIRD_PARTY_GEMMA_CPP_OPS_OPS_INL_H_
 
 // Include guard for (potentially) SIMD code.
@@ -46,6 +46,7 @@
 #include "ops/dot-inl.h"
 #include "hwy/contrib/algo/transform-inl.h"
 #include "hwy/contrib/math/math-inl.h"
+#include "hwy/profiler.h"  // also uses SIMD
 
 HWY_BEFORE_NAMESPACE();
 namespace gcpp {
@@ -602,7 +603,6 @@ HWY_INLINE float Sum(D d, const VT* HWY_RESTRICT vec, size_t num) {
 
 static HWY_NOINLINE void Softmax(float* HWY_RESTRICT x, const size_t size,
                                  const size_t mask_pos) {
-  PROFILER_FUNC;
   HWY_DASSERT(size != 0);
   HWY_DASSERT(mask_pos <= size);
@@ -644,6 +644,71 @@ static HWY_INLINE HWY_MAYBE_UNUSED void Softmax(float* HWY_RESTRICT x,
   Softmax(x, size, size);
 }
 
+// Returns argmax of softmax and its probability. This overwrites `x`, but not
+// with normalized probabilities. Only equivalent to `Softmax` + `sample_func`
+// if `kTopK` == 1. This is worthwhile because `num` is typically
+// `kVocabSize` == 256K, and this avoids writing that many, and then scanning
+// them again for the max.
+static HWY_MAYBE_UNUSED TokenAndProb Top1OfSoftmax(float* HWY_RESTRICT x,
+                                                   const size_t num) {
+  namespace hn = hwy::HWY_NAMESPACE;
+  using D = hn::ScalableTag<float>;
+  using V = hn::Vec<D>;
+  using M = hn::Mask<D>;
+  const D d;
+  const hn::RebindToSigned<D> di;
+  using TI = hn::TFromD<decltype(di)>;
+  using VI = hn::Vec<decltype(di)>;
+  const size_t N = hn::Lanes(d);
+  HWY_ASSERT(num % (2 * N) == 0);
+
+  V max0 = hn::Set(d, hwy::LowestValue<float>());
+  V max1 = max0;
+  VI argmax0 = hn::Zero(di);
+  VI argmax1 = argmax0;
+
+  for (size_t i = 0; i < num; i += 2 * N) {
+    const V v0 = hn::LoadU(d, x + i);
+    const V v1 = hn::LoadU(d, x + i + N);
+    const VI vi0 = hn::Iota(di, static_cast<TI>(i));
+    const VI vi1 = hn::Iota(di, static_cast<TI>(i + N));
+    const M gt0 = hn::Gt(v0, max0);
+    const M gt1 = hn::Gt(v1, max1);
+    max0 = hn::IfThenElse(gt0, v0, max0);
+    max1 = hn::IfThenElse(gt1, v1, max1);
+    argmax0 = hn::IfThenElse(hn::RebindMask(di, gt0), vi0, argmax0);
+    argmax1 = hn::IfThenElse(hn::RebindMask(di, gt1), vi1, argmax1);
+  }
+
+  // Combine the two vectors
+  const M gt0 = hn::Gt(max0, max1);
+  max0 = hn::IfThenElse(gt0, max0, max1);
+  argmax0 = hn::IfThenElse(hn::RebindMask(di, gt0), argmax0, argmax1);
+
+  // Reduce to the global max
+  const V max = hn::MaxOfLanes(d, max0);  // broadcasts
+  const V* pmax = &max;
+
+  // Argmax = lowest-indexed lane equal to the global max
+  const size_t lane = hn::FindKnownFirstTrue(d, hn::Eq(max, max0));
+  const TI argmax = hn::ExtractLane(argmax0, lane);
+
+  // Subtract max (avoid precision loss for large exponents) and exponentiate.
+  hn::Transform(d, x, num, [pmax](const auto d, const V value) HWY_ATTR {
+    if constexpr (HWY_TARGET & HWY_ALL_SVE) {
+      // Temporary workaround for buggy SVE codegen: avoid inlined Exp().
+      return hn::CallExp(d, hn::Sub(value, *pmax));
+    } else {
+      return hn::Exp(d, hn::Sub(value, *pmax));
+    }
+  });
+
+  // Normalize to a single probability. The exact sum seems like it should not
+  // make a huge difference. It halves the standard deviation of the sum of the
+  // normalized probabilities from 1E-7 to 5E-8, but actually also changes the
+  // generated text after a few hundred tokens.
+  const float sum_exp = Sum(d, x, num);
+  const float prob = x[argmax] / sum_exp;
+  return TokenAndProb{.token = argmax, .prob = prob};
+}
+
 static HWY_NOINLINE void LogitsSoftCap(const float cap, float* HWY_RESTRICT x,
                                        const size_t size,
                                        const size_t max_pos) {
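
Since the fast path skips full normalization, a natural check is that it still picks the same token and reports the same probability as the unfused "Softmax, then argmax" path. A sketch of such a check, assuming it sits in a per-target translation unit that already includes ops-inl.h (so Softmax, Top1OfSoftmax, TokenAndProb and HWY_ASSERT are in scope); the helper name and constants are illustrative, not part of this commit:

#include <cmath>
#include <cstddef>
#include <random>
#include <vector>

void CheckTop1MatchesSoftmax() {
  // Assumed to be a multiple of 2 * Lanes(d) on the targets under test,
  // as Top1OfSoftmax requires.
  constexpr size_t kNum = 4096;
  std::vector<float> logits(kNum);
  std::mt19937 rng(123);
  std::normal_distribution<float> dist(0.0f, 3.0f);
  for (float& v : logits) v = dist(rng);

  // Unfused reference: materialize all probabilities, then scan for the max.
  std::vector<float> probs = logits;
  Softmax(probs.data(), kNum);
  size_t ref_token = 0;
  for (size_t i = 1; i < kNum; ++i) {
    if (probs[i] > probs[ref_token]) ref_token = i;
  }

  // Fused path; overwrites `logits`, which is fine for this check.
  const TokenAndProb tp = Top1OfSoftmax(logits.data(), kNum);
  HWY_ASSERT(tp.token == static_cast<int>(ref_token));
  HWY_ASSERT(std::abs(tp.prob - probs[ref_token]) < 1e-5f);
}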

util/allocator.h

@@ -24,6 +24,12 @@
 
 namespace gcpp {
 
+// Shared between gemma.h and ops-inl.h.
+struct TokenAndProb {
+  int token;
+  float prob;
+};
+
 using ByteStorageT = hwy::AlignedFreeUniquePtr<uint8_t[]>;
 
 template <typename T>