mirror of https://github.com/google/gemma.cpp.git
1.09x decode speedup for topk=1/temp0: fuse softmax and sample
PiperOrigin-RevId: 680589099
This commit is contained in:
parent
897f902d28
commit
2d14d796e3
|
|
@ -285,6 +285,7 @@ cc_library(
|
||||||
deps = [
|
deps = [
|
||||||
":common",
|
":common",
|
||||||
":gemma_lib",
|
":gemma_lib",
|
||||||
|
":ops",
|
||||||
"@hwy//:hwy",
|
"@hwy//:hwy",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
|
||||||
|
|
@ -13,18 +13,31 @@
|
||||||
// See the License for the specific language governing permissions and
|
// See the License for the specific language governing permissions and
|
||||||
// limitations under the License.
|
// limitations under the License.
|
||||||
|
|
||||||
#include "evals/cross_entropy.h"
|
// Compiles this file for multiple architectures via "foreach_target.h", to
|
||||||
|
// which we pass the filename via macro 'argument'.
|
||||||
|
// clang-format off
|
||||||
|
#undef HWY_TARGET_INCLUDE
|
||||||
|
#define HWY_TARGET_INCLUDE "evals/cross_entropy.cc" // NOLINT
|
||||||
|
// clang-format on
|
||||||
|
#include "hwy/foreach_target.h" // IWYU pragma: keep
|
||||||
|
#include "hwy/highway.h"
|
||||||
|
// After highway.h
|
||||||
|
#include "ops/ops-inl.h" // Softmax
|
||||||
|
|
||||||
|
#ifndef GEMMA_CROSS_ENTROPY_ONCE
|
||||||
|
#define GEMMA_CROSS_ENTROPY_ONCE
|
||||||
|
|
||||||
#include <stddef.h>
|
#include <stddef.h>
|
||||||
#include <stdio.h>
|
#include <stdio.h>
|
||||||
|
|
||||||
#include <algorithm>
|
#include <algorithm> // std::sort
|
||||||
#include <cmath>
|
#include <cmath>
|
||||||
#include <regex> // NOLINT
|
#include <regex> // NOLINT
|
||||||
#include <string>
|
#include <string>
|
||||||
#include <utility>
|
#include <utility>
|
||||||
#include <vector>
|
#include <vector>
|
||||||
|
|
||||||
|
#include "evals/cross_entropy.h"
|
||||||
#include "gemma/common.h"
|
#include "gemma/common.h"
|
||||||
#include "gemma/gemma.h"
|
#include "gemma/gemma.h"
|
||||||
#include "hwy/base.h"
|
#include "hwy/base.h"
|
||||||
|
|
@ -63,10 +76,30 @@ void LogTopK(const GemmaTokenizer& tokenizer, const float* dist, size_t len,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
} // namespace
|
} // namespace
|
||||||
|
} // namespace gcpp
|
||||||
|
#endif // GEMMA_CROSS_ENTROPY_ONCE
|
||||||
|
|
||||||
|
// SIMD code, compiled once per target.
|
||||||
|
HWY_BEFORE_NAMESPACE();
|
||||||
|
namespace gcpp {
|
||||||
|
namespace HWY_NAMESPACE {
|
||||||
|
|
||||||
|
// Per-target wrapper around `Softmax` (from ops/ops-inl.h). It is compiled once per SIMD target inside HWY_NAMESPACE so that it can be exported via HWY_EXPORT and invoked through HWY_DYNAMIC_DISPATCH. Normalizes `logits[0, vocab_size)` in place.
void CallSoftmax(float* HWY_RESTRICT logits, size_t vocab_size) {
|
||||||
|
Softmax(logits, vocab_size);
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace HWY_NAMESPACE
|
||||||
|
} // namespace gcpp
|
||||||
|
HWY_AFTER_NAMESPACE();
|
||||||
|
|
||||||
|
#if HWY_ONCE
|
||||||
|
namespace gcpp {
|
||||||
|
|
||||||
|
HWY_EXPORT(CallSoftmax);
|
||||||
|
|
||||||
float ComputeCrossEntropy(Gemma& gemma, size_t max_tokens,
|
float ComputeCrossEntropy(Gemma& gemma, size_t max_tokens,
|
||||||
const std::vector<int>& prompt,
|
const std::vector<int>& prompt, KVCache& kv_cache,
|
||||||
KVCache& kv_cache, int verbosity) {
|
int verbosity) {
|
||||||
const StreamFunc stream_token = [](int /*token*/, float) { return true; };
|
const StreamFunc stream_token = [](int /*token*/, float) { return true; };
|
||||||
|
|
||||||
// TWeight is unused, but we have to pass it to Config*.
|
// TWeight is unused, but we have to pass it to Config*.
|
||||||
|
|
@ -74,8 +107,11 @@ float ComputeCrossEntropy(Gemma& gemma, size_t max_tokens,
|
||||||
CallForModel</*TWeight=*/float, GetVocabSize>(gemma.Info().model);
|
CallForModel</*TWeight=*/float, GetVocabSize>(gemma.Info().model);
|
||||||
float cross_entropy = std::log(vocab_size); // first token
|
float cross_entropy = std::log(vocab_size); // first token
|
||||||
size_t pos = 1;
|
size_t pos = 1;
|
||||||
const SampleFunc sample_token = [&](const float* probs,
|
|
||||||
size_t vocab_size) -> int {
|
const SampleFunc sample_token = [&](float* probs,
|
||||||
|
size_t vocab_size) -> TokenAndProb {
|
||||||
|
// input is logits, not yet probabilities
|
||||||
|
HWY_DYNAMIC_DISPATCH(CallSoftmax)(probs, vocab_size);
|
||||||
// We are called for each token, but pos starts at 1. Clamping max_tokens
|
// We are called for each token, but pos starts at 1. Clamping max_tokens
|
||||||
// to prompt.size() should prevent overrun.
|
// to prompt.size() should prevent overrun.
|
||||||
HWY_ASSERT(pos < prompt.size());
|
HWY_ASSERT(pos < prompt.size());
|
||||||
|
|
@ -96,8 +132,9 @@ float ComputeCrossEntropy(Gemma& gemma, size_t max_tokens,
|
||||||
cross_entropy / std::log(2.0) / (pos + 1));
|
cross_entropy / std::log(2.0) / (pos + 1));
|
||||||
}
|
}
|
||||||
++pos;
|
++pos;
|
||||||
return token;
|
return TokenAndProb{.token = token, .prob = prob};
|
||||||
};
|
};
|
||||||
|
|
||||||
std::vector<int> prompt0 = { prompt[0] };
|
std::vector<int> prompt0 = { prompt[0] };
|
||||||
max_tokens = HWY_MIN(max_tokens, prompt.size());
|
max_tokens = HWY_MIN(max_tokens, prompt.size());
|
||||||
RuntimeConfig runtime = {
|
RuntimeConfig runtime = {
|
||||||
|
|
@ -118,3 +155,4 @@ float ComputeCrossEntropy(Gemma& gemma, size_t max_tokens,
|
||||||
}
|
}
|
||||||
|
|
||||||
} // namespace gcpp
|
} // namespace gcpp
|
||||||
|
#endif // HWY_ONCE
|
||||||
|
|
|
||||||
|
|
@ -15,20 +15,10 @@
|
||||||
|
|
||||||
// SIMD functions for Gemma/Griffin transformers.
|
// SIMD functions for Gemma/Griffin transformers.
|
||||||
|
|
||||||
// Include guard (still compiled once per target)
|
|
||||||
#if defined(THIRD_PARTY_GEMMA_CPP_GEMMA_GEMMA_INL_H_) == \
|
|
||||||
defined(HWY_TARGET_TOGGLE)
|
|
||||||
#ifdef THIRD_PARTY_GEMMA_CPP_GEMMA_GEMMA_INL_H_
|
|
||||||
#undef THIRD_PARTY_GEMMA_CPP_GEMMA_GEMMA_INL_H_
|
|
||||||
#else
|
|
||||||
#define THIRD_PARTY_GEMMA_CPP_GEMMA_GEMMA_INL_H_
|
|
||||||
#endif
|
|
||||||
|
|
||||||
#include <stddef.h>
|
#include <stddef.h>
|
||||||
#include <stdio.h>
|
#include <stdio.h>
|
||||||
|
|
||||||
#include <algorithm> // std::min
|
#include <algorithm> // std::min
|
||||||
#include <string>
|
|
||||||
#include <type_traits>
|
#include <type_traits>
|
||||||
#include <vector>
|
#include <vector>
|
||||||
|
|
||||||
|
|
@ -38,9 +28,6 @@
|
||||||
#include "gemma/gemma.h"
|
#include "gemma/gemma.h"
|
||||||
#include "gemma/weights.h"
|
#include "gemma/weights.h"
|
||||||
// Placeholder for internal test4, do not remove
|
// Placeholder for internal test4, do not remove
|
||||||
#include "ops/matmul-inl.h"
|
|
||||||
#include "ops/matvec-inl.h"
|
|
||||||
#include "ops/ops-inl.h"
|
|
||||||
#include "paligemma/image.h"
|
#include "paligemma/image.h"
|
||||||
#include "util/allocator.h"
|
#include "util/allocator.h"
|
||||||
#include "util/threading.h"
|
#include "util/threading.h"
|
||||||
|
|
@ -48,10 +35,24 @@
|
||||||
#include "hwy/base.h"
|
#include "hwy/base.h"
|
||||||
#include "hwy/bit_set.h"
|
#include "hwy/bit_set.h"
|
||||||
#include "hwy/contrib/thread_pool/thread_pool.h"
|
#include "hwy/contrib/thread_pool/thread_pool.h"
|
||||||
#include "hwy/highway.h"
|
|
||||||
#include "hwy/profiler.h"
|
|
||||||
#include "hwy/timer.h"
|
#include "hwy/timer.h"
|
||||||
|
|
||||||
|
// Include guard (still compiled once per target)
|
||||||
|
#if defined(THIRD_PARTY_GEMMA_CPP_GEMMA_GEMMA_INL_H_) == \
|
||||||
|
defined(HWY_TARGET_TOGGLE)
|
||||||
|
#ifdef THIRD_PARTY_GEMMA_CPP_GEMMA_GEMMA_INL_H_
|
||||||
|
#undef THIRD_PARTY_GEMMA_CPP_GEMMA_GEMMA_INL_H_
|
||||||
|
#else
|
||||||
|
#define THIRD_PARTY_GEMMA_CPP_GEMMA_GEMMA_INL_H_
|
||||||
|
#endif
|
||||||
|
|
||||||
|
#include "hwy/highway.h"
|
||||||
|
// After highway.h
|
||||||
|
#include "ops/matmul-inl.h"
|
||||||
|
#include "ops/matvec-inl.h"
|
||||||
|
#include "ops/ops-inl.h"
|
||||||
|
#include "hwy/profiler.h" // also uses SIMD
|
||||||
|
|
||||||
#ifndef GEMMA_CONFIG
|
#ifndef GEMMA_CONFIG
|
||||||
#if HWY_IDE
|
#if HWY_IDE
|
||||||
// Provide a definition so the IDE does not complain.
|
// Provide a definition so the IDE does not complain.
|
||||||
|
|
@ -1165,6 +1166,32 @@ class TokenStreamer {
|
||||||
hwy::BitSet4096<> is_eos_;
|
hwy::BitSet4096<> is_eos_;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
template <class TConfig>
|
||||||
|
SampleFunc ChooseSampleFunc(const RuntimeConfig& runtime_config) {
|
||||||
|
constexpr size_t kTopK = TConfig::kTopK;
|
||||||
|
|
||||||
|
// If user provided a sample_func, use it.
|
||||||
|
if (runtime_config.sample_func) return runtime_config.sample_func;
|
||||||
|
|
||||||
|
// Fast path for top-1 with no accept_token.
|
||||||
|
if (kTopK == 1 && !runtime_config.accept_token) {
|
||||||
|
PROFILER_ZONE("Gen.Sample Top1");
|
||||||
|
return [](float* logits, size_t vocab_size) -> TokenAndProb {
|
||||||
|
return Top1OfSoftmax(logits, vocab_size);
|
||||||
|
};
|
||||||
|
}
|
||||||
|
|
||||||
|
// General case: Softmax with top-k sampling.
|
||||||
|
return [&runtime_config](float* logits, size_t vocab_size) -> TokenAndProb {
|
||||||
|
PROFILER_ZONE("Gen.Sample general");
|
||||||
|
Softmax(logits, vocab_size);
|
||||||
|
const int token = SampleTopK<kTopK>(logits, vocab_size, *runtime_config.gen,
|
||||||
|
runtime_config.temperature,
|
||||||
|
runtime_config.accept_token);
|
||||||
|
return TokenAndProb{.token = token, .prob = logits[token]};
|
||||||
|
};
|
||||||
|
}
|
||||||
|
|
||||||
// Generates one continuation for each query in `queries_prompt`, which is one
|
// Generates one continuation for each query in `queries_prompt`, which is one
|
||||||
// qbatch whose size is at most the `batch_size` passed to
|
// qbatch whose size is at most the `batch_size` passed to
|
||||||
// `activations.Allocate`.
|
// `activations.Allocate`.
|
||||||
|
|
@ -1214,18 +1241,10 @@ void GenerateT(const ByteStorageT& weights_u8, Activations& activations,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// If no sample_func is provided, we use top-k sampling.
|
const SampleFunc sample_token = ChooseSampleFunc<TConfig>(runtime_config);
|
||||||
const SampleFunc sample_token =
|
|
||||||
runtime_config.sample_func
|
|
||||||
? runtime_config.sample_func
|
|
||||||
: [&](const float* logits, size_t vocab_size) -> int {
|
|
||||||
return SampleTopK<TConfig::kTopK>(logits, vocab_size, *runtime_config.gen,
|
|
||||||
runtime_config.temperature,
|
|
||||||
runtime_config.accept_token);
|
|
||||||
};
|
|
||||||
|
|
||||||
// Prefill stops before min_prompt_size - 1 because the last prompt token is
|
// Prefill stops before min_prompt_size - 1 because the last prompt
|
||||||
// the first input token for generation.
|
// token is the first input token for generation.
|
||||||
const double prefill_start = hwy::platform::Now();
|
const double prefill_start = hwy::platform::Now();
|
||||||
// If tbatch is larger than the qbatch we already have in `activations`, then
|
// If tbatch is larger than the qbatch we already have in `activations`, then
|
||||||
// allocate prefill_activations, otherwise reuse.
|
// allocate prefill_activations, otherwise reuse.
|
||||||
|
|
@ -1283,15 +1302,14 @@ void GenerateT(const ByteStorageT& weights_u8, Activations& activations,
|
||||||
for (size_t query_idx = 0; query_idx < num_queries; ++query_idx) {
|
for (size_t query_idx = 0; query_idx < num_queries; ++query_idx) {
|
||||||
float* HWY_RESTRICT logits = activations.logits.Batch(query_idx);
|
float* HWY_RESTRICT logits = activations.logits.Batch(query_idx);
|
||||||
MaybeLogitsSoftCap(TConfig::kFinalCap, logits, kVocabSize);
|
MaybeLogitsSoftCap(TConfig::kFinalCap, logits, kVocabSize);
|
||||||
Softmax(logits, kVocabSize);
|
const TokenAndProb tp = sample_token(logits, kVocabSize);
|
||||||
const int token = sample_token(logits, kVocabSize);
|
|
||||||
timing_info.NotifyGenerated(prefill_start, gen_start);
|
timing_info.NotifyGenerated(prefill_start, gen_start);
|
||||||
|
|
||||||
const bool is_eos =
|
const bool is_eos =
|
||||||
token_streamer(query_idx_start + query_idx,
|
token_streamer(query_idx_start + query_idx,
|
||||||
queries_mutable_pos[query_idx], token, logits[token]);
|
queries_mutable_pos[query_idx], tp.token, tp.prob);
|
||||||
all_queries_eos &= is_eos;
|
all_queries_eos &= is_eos;
|
||||||
gen_tokens[query_idx] = is_eos ? runtime_config.eos_id : token;
|
gen_tokens[query_idx] = is_eos ? runtime_config.eos_id : tp.token;
|
||||||
}
|
}
|
||||||
if (all_queries_eos) break;
|
if (all_queries_eos) break;
|
||||||
} // foreach token to generate
|
} // foreach token to generate
|
||||||
|
|
|
||||||
|
|
@ -57,9 +57,10 @@ using BatchStreamFunc = std::function<bool(size_t, size_t, int, float)>;
|
||||||
// If not empty, AcceptFunc is called with token. It should return false for
|
// If not empty, AcceptFunc is called with token. It should return false for
|
||||||
// tokens you don't want to generate and true for tokens you want to generate.
|
// tokens you don't want to generate and true for tokens you want to generate.
|
||||||
using AcceptFunc = std::function<bool(int, float)>;
|
using AcceptFunc = std::function<bool(int, float)>;
|
||||||
// If not empty, SampleFunc is called with the probability distribution for the
|
// If not empty, SampleFunc is called with the logits for the next token, which
|
||||||
// next token, and its return value is used as the next generated token.
|
// it may modify/overwrite, and its return value is the next generated token
|
||||||
using SampleFunc = std::function<int(const float*, size_t)>;
|
// together with its probability.
|
||||||
|
using SampleFunc = std::function<TokenAndProb(float*, size_t)>;
|
||||||
// If not empty, LayersOutputFunc is called for layer outputs, specified with:
|
// If not empty, LayersOutputFunc is called for layer outputs, specified with:
|
||||||
// - index of query within containing batch (if any); zero otherwise.
|
// - index of query within containing batch (if any); zero otherwise.
|
||||||
// - position in the tokens sequence
|
// - position in the tokens sequence
|
||||||
|
|
|
||||||
|
|
@ -28,10 +28,10 @@
|
||||||
#include <type_traits> // std::enable_if_t
|
#include <type_traits> // std::enable_if_t
|
||||||
|
|
||||||
#include "compression/compress.h"
|
#include "compression/compress.h"
|
||||||
|
#include "util/allocator.h" // TokenAndProb
|
||||||
#include "hwy/base.h"
|
#include "hwy/base.h"
|
||||||
#include "hwy/contrib/thread_pool/thread_pool.h"
|
#include "hwy/contrib/thread_pool/thread_pool.h"
|
||||||
#include "hwy/detect_targets.h"
|
#include "hwy/detect_targets.h"
|
||||||
#include "hwy/profiler.h"
|
|
||||||
#endif // THIRD_PARTY_GEMMA_CPP_OPS_OPS_INL_H_
|
#endif // THIRD_PARTY_GEMMA_CPP_OPS_OPS_INL_H_
|
||||||
|
|
||||||
// Include guard for (potentially) SIMD code.
|
// Include guard for (potentially) SIMD code.
|
||||||
|
|
@ -46,6 +46,7 @@
|
||||||
#include "ops/dot-inl.h"
|
#include "ops/dot-inl.h"
|
||||||
#include "hwy/contrib/algo/transform-inl.h"
|
#include "hwy/contrib/algo/transform-inl.h"
|
||||||
#include "hwy/contrib/math/math-inl.h"
|
#include "hwy/contrib/math/math-inl.h"
|
||||||
|
#include "hwy/profiler.h" // also uses SIMD
|
||||||
|
|
||||||
HWY_BEFORE_NAMESPACE();
|
HWY_BEFORE_NAMESPACE();
|
||||||
namespace gcpp {
|
namespace gcpp {
|
||||||
|
|
@ -602,7 +603,6 @@ HWY_INLINE float Sum(D d, const VT* HWY_RESTRICT vec, size_t num) {
|
||||||
|
|
||||||
static HWY_NOINLINE void Softmax(float* HWY_RESTRICT x, const size_t size,
|
static HWY_NOINLINE void Softmax(float* HWY_RESTRICT x, const size_t size,
|
||||||
const size_t mask_pos) {
|
const size_t mask_pos) {
|
||||||
PROFILER_FUNC;
|
|
||||||
HWY_DASSERT(size != 0);
|
HWY_DASSERT(size != 0);
|
||||||
HWY_DASSERT(mask_pos <= size);
|
HWY_DASSERT(mask_pos <= size);
|
||||||
|
|
||||||
|
|
@ -644,6 +644,71 @@ static HWY_INLINE HWY_MAYBE_UNUSED void Softmax(float* HWY_RESTRICT x,
|
||||||
Softmax(x, size, size);
|
Softmax(x, size, size);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Returns argmax of softmax and its probability. This overwrites `x`, but not
|
||||||
|
// with normalized probabilities. Only equivalent to `Softmax` + `sample_func`
|
||||||
|
// if `kTopK` == 1. This is worthwhile because `num` is
|
||||||
|
// typically `kVocabSize` == 256K, and this avoids writing that many, and then
|
||||||
|
// scanning them again for the max.
|
||||||
|
static HWY_MAYBE_UNUSED TokenAndProb Top1OfSoftmax(float* HWY_RESTRICT x,
|
||||||
|
const size_t num) {
|
||||||
|
namespace hn = hwy::HWY_NAMESPACE;
|
||||||
|
using D = hn::ScalableTag<float>;
|
||||||
|
using V = hn::Vec<D>;
|
||||||
|
using M = hn::Mask<D>;
|
||||||
|
const D d;
|
||||||
|
const hn::RebindToSigned<D> di;
|
||||||
|
using TI = hn::TFromD<decltype(di)>;
|
||||||
|
using VI = hn::Vec<decltype(di)>;
|
||||||
|
const size_t N = hn::Lanes(d);
|
||||||
|
HWY_ASSERT(num % (2 * N) == 0);
|
||||||
|
|
||||||
|
V max0 = hn::Set(d, hwy::LowestValue<float>());
|
||||||
|
V max1 = max0;
|
||||||
|
VI argmax0 = hn::Zero(di);
|
||||||
|
VI argmax1 = argmax0;
|
||||||
|
|
||||||
|
for (size_t i = 0; i < num; i += 2 * N) {
|
||||||
|
const V v0 = hn::LoadU(d, x + i);
|
||||||
|
const V v1 = hn::LoadU(d, x + i + N);
|
||||||
|
const VI vi0 = hn::Iota(di, static_cast<TI>(i));
|
||||||
|
const VI vi1 = hn::Iota(di, static_cast<TI>(i + N));
|
||||||
|
const M gt0 = hn::Gt(v0, max0);
|
||||||
|
const M gt1 = hn::Gt(v1, max1);
|
||||||
|
max0 = hn::IfThenElse(gt0, v0, max0);
|
||||||
|
max1 = hn::IfThenElse(gt1, v1, max1);
|
||||||
|
argmax0 = hn::IfThenElse(hn::RebindMask(di, gt0), vi0, argmax0);
|
||||||
|
argmax1 = hn::IfThenElse(hn::RebindMask(di, gt1), vi1, argmax1);
|
||||||
|
}
|
||||||
|
// Combine the two vectors
|
||||||
|
const M gt0 = hn::Gt(max0, max1);
|
||||||
|
max0 = hn::IfThenElse(gt0, max0, max1);
|
||||||
|
argmax0 = hn::IfThenElse(hn::RebindMask(di, gt0), argmax0, argmax1);
|
||||||
|
// Reduce to the global max
|
||||||
|
const V max = hn::MaxOfLanes(d, max0); // broadcasts
|
||||||
|
const V* pmax = &max;
|
||||||
|
// Argmax = lowest-indexed lane equal to the global max
|
||||||
|
const size_t lane = hn::FindKnownFirstTrue(d, hn::Eq(max, max0));
|
||||||
|
const TI argmax = hn::ExtractLane(argmax0, lane);
|
||||||
|
|
||||||
|
// Subtract max (avoid precision loss for large exponents) and exponentiate.
|
||||||
|
hn::Transform(d, x, num, [pmax](const auto d, const V value) HWY_ATTR {
|
||||||
|
if constexpr (HWY_TARGET & HWY_ALL_SVE) {
|
||||||
|
// Temporary workaround for buggy SVE codegen: avoid inlined Exp().
|
||||||
|
return hn::CallExp(d, hn::Sub(value, *pmax));
|
||||||
|
} else {
|
||||||
|
return hn::Exp(d, hn::Sub(value, *pmax));
|
||||||
|
}
|
||||||
|
});
|
||||||
|
|
||||||
|
// Normalize to a single probability. The exact sum seems like it should not
|
||||||
|
// make a huge difference. It halves the standard deviation of the sum of the
|
||||||
|
// normalized probabilities from 1E-7 to 5E-8, but actually also changes the
|
||||||
|
// generated text after a few hundred tokens.
|
||||||
|
const float sum_exp = Sum(d, x, num);
|
||||||
|
const float prob = x[argmax] / sum_exp;
|
||||||
|
return TokenAndProb{.token = argmax, .prob = prob};
|
||||||
|
}
|
||||||
|
|
||||||
static HWY_NOINLINE void LogitsSoftCap(const float cap, float* HWY_RESTRICT x,
|
static HWY_NOINLINE void LogitsSoftCap(const float cap, float* HWY_RESTRICT x,
|
||||||
const size_t size,
|
const size_t size,
|
||||||
const size_t max_pos) {
|
const size_t max_pos) {
|
||||||
|
|
|
||||||
|
|
@ -24,6 +24,12 @@
|
||||||
|
|
||||||
namespace gcpp {
|
namespace gcpp {
|
||||||
|
|
||||||
|
// Shared between gemma.h and ops-inl.h.
// Result of sampling: the chosen token id and the probability assigned to it.
// Members have default initializers so that a default-constructed instance
// never carries indeterminate values; the struct remains an aggregate, so
// brace and designated initialization continue to work.
struct TokenAndProb {
  int token = 0;
  float prob = 0.0f;
};
|
||||||
|
|
||||||
using ByteStorageT = hwy::AlignedFreeUniquePtr<uint8_t[]>;
|
using ByteStorageT = hwy::AlignedFreeUniquePtr<uint8_t[]>;
|
||||||
|
|
||||||
template <typename T>
|
template <typename T>
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue