mirror of https://github.com/google/gemma.cpp.git
-467ms startup: parallel Reshape
Also split Softmax into Argmax helper, add comments; add profiler zones + fix IDE warning

PiperOrigin-RevId: 680954573
parent d83ad76679
commit 7d9fcda0d8
@@ -1292,13 +1292,16 @@ void GenerateT(const ByteStorageT& weights_u8, Activations& activations,
   // queries_pos are incremented by Transformer.
 
   bool all_queries_eos = true;
-  PROFILER_ZONE("Gen.Embedding");
-  // Compute logits from last layer activations.
-  MatMul</*kAdd=*/false>(
-      num_queries, ConstMat(activations.x.All(), kModelDim),
-      ConstMat(weights.embedder_input_embedding.data(), kModelDim),
-      weights.embedder_input_embedding.scale(), /*add=*/nullptr,
-      activations.env, MutableMat(activations.logits.All(), kVocabSize));
+  {
+    PROFILER_ZONE("Gen.EmbeddingMatmul");
+    // Compute logits from last layer activations.
+    MatMul</*kAdd=*/false>(
+        num_queries, ConstMat(activations.x.All(), kModelDim),
+        ConstMat(weights.embedder_input_embedding.data(), kModelDim),
+        weights.embedder_input_embedding.scale(), /*add=*/nullptr,
+        activations.env, MutableMat(activations.logits.All(), kVocabSize));
+  }
+  PROFILER_ZONE("Gen.Softcap+Sample+Stream");
   for (size_t query_idx = 0; query_idx < num_queries; ++query_idx) {
     float* HWY_RESTRICT logits = activations.logits.Batch(query_idx);
     MaybeLogitsSoftCap(TConfig::kFinalCap, logits, kVocabSize);
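Note on the braces added above: PROFILER_ZONE is a scoped timer, so a zone ends where its enclosing block ends. Wrapping the MatMul in its own block attributes only the embedding matmul to "Gen.EmbeddingMatmul", and the following zone then covers the softcap/sample/stream loop for the rest of the function. A minimal standalone sketch of that RAII pattern, using a hypothetical ScopedZone class rather than Highway's real profiler macro:

#include <chrono>
#include <cstdio>

// Hypothetical stand-in for PROFILER_ZONE: measures wall time from
// construction to destruction (i.e. the end of the enclosing scope).
class ScopedZone {
 public:
  explicit ScopedZone(const char* name)
      : name_(name), start_(std::chrono::steady_clock::now()) {}
  ~ScopedZone() {
    const auto ns = std::chrono::duration_cast<std::chrono::nanoseconds>(
        std::chrono::steady_clock::now() - start_);
    std::printf("%s: %lld ns\n", name_, static_cast<long long>(ns.count()));
  }

 private:
  const char* name_;
  std::chrono::steady_clock::time_point start_;
};

void Generate() {
  {
    // Zone ends at the closing brace, so it covers only the matmul.
    ScopedZone zone("Gen.EmbeddingMatmul");
    // ... MatMul(...) ...
  }
  // Zone for the remainder of the function (softcap, sampling, streaming).
  ScopedZone zone("Gen.Softcap+Sample+Stream");
  // ... per-query sampling loop ...
}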
@@ -76,7 +76,7 @@ struct LoadCompressedWeightsT {
     }
     {
       PROFILER_ZONE("Startup.Reshape");
-      c_weights->Reshape();
+      c_weights->Reshape(pool);
     }
     return c_weights_u8;
   }
@@ -90,7 +90,8 @@ ByteStorageT LoadCompressedWeights(const Path& weights, Model model_type,
 }
 
 namespace {
-void LogVec(const char* name, const float* data, size_t len) {
+// For reasons unknown, this is shown as potentially unused in the IDE.
+void HWY_MAYBE_UNUSED LogVec(const char* name, const float* data, size_t len) {
   hwy::Stats stats;
   for (size_t i = 0; i < len; ++i) {
     stats.Notify(data[i]);
@@ -210,10 +210,10 @@ struct CompressedWeights {
   explicit CompressedWeights(hwy::ThreadPool& pool) : c_layer_ptrs(pool) {}
 
   // Called by weights.cc after ForEachTensor.
-  void Reshape() {
-    for (size_t layer = 0; layer < TConfig::kLayers; ++layer) {
+  void Reshape(hwy::ThreadPool& pool) {
+    pool.Run(0, TConfig::kLayers, [this](uint64_t layer, size_t /*thread*/) {
       GetLayer(layer)->Reshape();
-    }
+    });
   }
 
   void ZeroInit() {
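The startup savings in the commit title come from the hunk above: the serial per-layer loop becomes hwy::ThreadPool::Run, which invokes the lambda once per index in [0, kLayers) and distributes those indices across worker threads. A self-contained sketch of the same pattern, with a placeholder Layer type and layer count standing in for the real weights structures:

#include <cstddef>
#include <cstdint>
#include <vector>

#include "hwy/contrib/thread_pool/thread_pool.h"

// Placeholder per-layer work; the real code calls GetLayer(layer)->Reshape().
struct Layer {
  void Reshape() { /* ... transpose / repack tensors ... */ }
};

void ReshapeAllLayers(std::vector<Layer>& layers, hwy::ThreadPool& pool) {
  // Runs the lambda once for each layer index, spread across worker threads.
  pool.Run(0, layers.size(), [&layers](uint64_t layer, size_t /*thread*/) {
    layers[layer].Reshape();
  });
}

int main() {
  std::vector<Layer> layers(26);  // placeholder number of transformer layers
  hwy::ThreadPool pool(4);        // placeholder number of worker threads
  ReshapeAllLayers(layers, pool);
  return 0;
}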
@@ -279,7 +279,7 @@ struct ReshapeCompressedWeights {
   void operator()(ByteStorageT& weights_u8, hwy::ThreadPool& pool) const {
     CompressedWeights<TConfig>& weights =
         *reinterpret_cast<CompressedWeights<TConfig>*>(weights_u8.get());
-    weights.Reshape();
+    weights.Reshape(pool);
   }
 };
 
@@ -601,6 +601,7 @@ HWY_INLINE float Sum(D d, const VT* HWY_RESTRICT vec, size_t num) {
   return DecompressAndCall(d_raw, MakeSpan(vec, num), SumKernelDefault<VT>());
 }
 
+// See below for a specialized version for top-1 sampling.
 static HWY_NOINLINE void Softmax(float* HWY_RESTRICT x, const size_t size,
                                  const size_t mask_pos) {
   HWY_DASSERT(size != 0);
@@ -644,13 +645,15 @@ static HWY_INLINE HWY_MAYBE_UNUSED void Softmax(float* HWY_RESTRICT x,
   Softmax(x, size, size);
 }
 
-// Returns argmax of softmax and its probability. This overwrites `x`, but not
-// with normalized probabilities. Only equivalent to `Softmax` + `sample_func`
-// if `kTopK` == 1. This is worthwhile because `num` is
-// typically `kVocabSize` == 256K, and this avoids writing that many, and then
-// scanning them again for the max.
-static HWY_MAYBE_UNUSED TokenAndProb Top1OfSoftmax(float* HWY_RESTRICT x,
-                                                   const size_t num) {
+// Note: https://arxiv.org/pdf/2001.04438 proposes to replace the three max /
+// exp / mul passes with two passes, both of which compute Exp. This is
+// reportedly only faster for very large arrays, larger even than our 256K
+// vocab size. We instead fuse the subsequent sampling pass into the softmax,
+// which already knows the max value which top-1 sampling would again seek.
+
+// Returns the argmax and x[argmax].
+static HWY_INLINE TokenAndProb ArgmaxAndMax(const float* HWY_RESTRICT x,
+                                            const size_t num) {
   namespace hn = hwy::HWY_NAMESPACE;
   using D = hn::ScalableTag<float>;
   using V = hn::Vec<D>;
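For context on the note above: the "three max / exp / mul passes" are the standard numerically stable softmax, one pass to find the maximum, one to exponentiate the shifted values and accumulate their sum, and one to scale by the reciprocal of that sum. A scalar sketch of that baseline (the real code uses Highway vectors):

#include <algorithm>
#include <cmath>
#include <cstddef>

// Scalar reference for the standard three-pass softmax over x[0, size).
void SoftmaxScalar(float* x, size_t size) {
  // Pass 1: find the maximum (for numerical stability).
  const float max = *std::max_element(x, x + size);
  // Pass 2: exponentiate the shifted values and accumulate their sum.
  float sum = 0.0f;
  for (size_t i = 0; i < size; ++i) {
    x[i] = std::exp(x[i] - max);
    sum += x[i];
  }
  // Pass 3: multiply by the reciprocal of the sum to normalize.
  const float inv_sum = 1.0f / sum;
  for (size_t i = 0; i < size; ++i) {
    x[i] *= inv_sum;
  }
}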
@@ -679,18 +682,36 @@ static HWY_MAYBE_UNUSED TokenAndProb Top1OfSoftmax(float* HWY_RESTRICT x,
     argmax0 = hn::IfThenElse(hn::RebindMask(di, gt0), vi0, argmax0);
     argmax1 = hn::IfThenElse(hn::RebindMask(di, gt1), vi1, argmax1);
   }
 
   // Combine the two vectors
   const M gt0 = hn::Gt(max0, max1);
   max0 = hn::IfThenElse(gt0, max0, max1);
   argmax0 = hn::IfThenElse(hn::RebindMask(di, gt0), argmax0, argmax1);
 
   // Reduce to the global max
   const V max = hn::MaxOfLanes(d, max0);  // broadcasts
-  const V* pmax = &max;
+
+  // Argmax = lowest-indexed lane equal to the global max
   const size_t lane = hn::FindKnownFirstTrue(d, hn::Eq(max, max0));
   const TI argmax = hn::ExtractLane(argmax0, lane);
+  return TokenAndProb{.token = argmax, .prob = hn::GetLane(max)};
+}
+
+// Returns argmax of softmax and its probability. This overwrites `x`, but not
+// with normalized probabilities. Only equivalent to `Softmax` + `sample_func`
+// if `kTopK` == 1. This is worthwhile because `num` is typically `kVocabSize`
+// == 256K, and this avoids writing and then scanning again for the max.
+static HWY_MAYBE_UNUSED TokenAndProb Top1OfSoftmax(float* HWY_RESTRICT x,
+                                                   const size_t num) {
+  namespace hn = hwy::HWY_NAMESPACE;
+  const hn::ScalableTag<float> d;
+  using V = hn::Vec<decltype(d)>;
+
+  const TokenAndProb argmax = ArgmaxAndMax(x, num);
+
   // Subtract max (avoid precision loss for large exponents) and exponentiate.
+  const V max = hn::Set(d, argmax.prob);
+  const V* pmax = &max;
   hn::Transform(d, x, num, [pmax](const auto d, const V value) HWY_ATTR {
     if constexpr (HWY_TARGET & HWY_ALL_SVE) {
       // Temporary workaround for buggy SVE codegen: avoid inlined Exp().
@@ -705,8 +726,8 @@ static HWY_MAYBE_UNUSED TokenAndProb Top1OfSoftmax(float* HWY_RESTRICT x,
   // normalized probabilities from 1E-7 to 5E-8, but actually also changes the
   // generated text after a few hundred tokens.
   const float sum_exp = Sum(d, x, num);
-  const float prob = x[argmax] / sum_exp;
-  return TokenAndProb{.token = argmax, .prob = prob};
+  const float prob = x[argmax.token] / sum_exp;
+  return TokenAndProb{.token = argmax.token, .prob = prob};
 }
 
 static HWY_NOINLINE void LogitsSoftCap(const float cap, float* HWY_RESTRICT x,
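In scalar terms, the split introduced above works as follows: ArgmaxAndMax makes a single pass that returns both the index of the maximum and the maximum itself, and Top1OfSoftmax then exponentiates, sums, and normalizes only the winning entry, so a fully normalized probability vector is never written out. A rough scalar equivalent, assuming only the TokenAndProb fields visible in the diff; the helpers below are illustrative, not the real SIMD code:

#include <cmath>
#include <cstddef>

struct TokenAndProb {
  int token;
  float prob;
};

// Scalar sketch of ArgmaxAndMax: returns the lowest index of the maximum
// together with x[argmax].
TokenAndProb ArgmaxAndMaxScalar(const float* x, size_t num) {
  int argmax = 0;
  for (size_t i = 1; i < num; ++i) {
    if (x[i] > x[argmax]) argmax = static_cast<int>(i);
  }
  return TokenAndProb{argmax, x[argmax]};
}

// Scalar sketch of Top1OfSoftmax: overwrites x with exp(x - max) and returns
// the top token with its softmax probability, without normalizing all of x.
TokenAndProb Top1OfSoftmaxScalar(float* x, size_t num) {
  const TokenAndProb top = ArgmaxAndMaxScalar(x, num);
  float sum_exp = 0.0f;
  for (size_t i = 0; i < num; ++i) {
    x[i] = std::exp(x[i] - top.prob);  // top.prob holds the max value here
    sum_exp += x[i];
  }
  // x[top.token] is exp(0) == 1, so the top probability is 1 / sum_exp.
  return TokenAndProb{top.token, x[top.token] / sum_exp};
}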