diff --git a/BUILD.bazel b/BUILD.bazel
index 10b1e42..1bd177e 100644
--- a/BUILD.bazel
+++ b/BUILD.bazel
@@ -99,6 +99,7 @@ cc_library(
     srcs = ["gemma/tokenizer.cc"],
     hdrs = ["gemma/tokenizer.h"],
     deps = [
+        ":common",
         "//compression:io",
         "@hwy//:hwy",
         "@hwy//:nanobenchmark",  # timer
@@ -121,8 +122,27 @@ cc_library(
     name = "gemma_lib",
     srcs = [
         "gemma/gemma.cc",
+        "gemma/instantiations/27b_bf16.cc",
+        "gemma/instantiations/27b_f32.cc",
+        "gemma/instantiations/27b_sfp.cc",
+        "gemma/instantiations/2b_bf16.cc",
+        "gemma/instantiations/2b_f32.cc",
+        "gemma/instantiations/2b_sfp.cc",
+        "gemma/instantiations/7b_bf16.cc",
+        "gemma/instantiations/7b_f32.cc",
+        "gemma/instantiations/7b_sfp.cc",
+        "gemma/instantiations/9b_bf16.cc",
+        "gemma/instantiations/9b_f32.cc",
+        "gemma/instantiations/9b_sfp.cc",
+        "gemma/instantiations/tiny_bf16.cc",
+        "gemma/instantiations/tiny_f32.cc",
+        "gemma/instantiations/tiny_sfp.cc",
+        "gemma/instantiations/gr2b_bf16.cc",
+        "gemma/instantiations/gr2b_f32.cc",
+        "gemma/instantiations/gr2b_sfp.cc",
     ],
     hdrs = [
+        "gemma/activations.h",
         "gemma/gemma.h",
     ],
     exec_properties = {
@@ -130,7 +150,7 @@ cc_library(
         "mem": "28g",
     },
     textual_hdrs = [
-        # Placeholder for internal file1, do not remove,
+        "gemma/gemma-inl.h",
         # Placeholder for internal file2, do not remove,
     ],
     deps = [
diff --git a/CMakeLists.txt b/CMakeLists.txt
index 7da189b..8b4fc49 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -63,13 +63,33 @@ set(SOURCES
   backprop/optimizer.h
   evals/cross_entropy.cc
   evals/cross_entropy.h
+  gemma/activations.h
   gemma/benchmark_helper.cc
   gemma/benchmark_helper.h
   gemma/common.cc
   gemma/common.h
   gemma/configs.h
+  gemma/gemma-inl.h
   gemma/gemma.cc
   gemma/gemma.h
+  gemma/instantiations/27b_bf16.cc
+  gemma/instantiations/27b_f32.cc
+  gemma/instantiations/27b_sfp.cc
+  gemma/instantiations/2b_bf16.cc
+  gemma/instantiations/2b_f32.cc
+  gemma/instantiations/2b_sfp.cc
+  gemma/instantiations/7b_bf16.cc
+  gemma/instantiations/7b_f32.cc
+  gemma/instantiations/7b_sfp.cc
+  gemma/instantiations/9b_bf16.cc
+  gemma/instantiations/9b_f32.cc
+  gemma/instantiations/9b_sfp.cc
+  gemma/instantiations/gr2b_bf16.cc
+  gemma/instantiations/gr2b_f32.cc
+  gemma/instantiations/gr2b_sfp.cc
+  gemma/instantiations/tiny_bf16.cc
+  gemma/instantiations/tiny_f32.cc
+  gemma/instantiations/tiny_sfp.cc
   gemma/kv_cache.cc
   gemma/kv_cache.h
   gemma/ops.h
diff --git a/gemma/activations.h b/gemma/activations.h
new file mode 100644
index 0000000..ffaa726
--- /dev/null
+++ b/gemma/activations.h
@@ -0,0 +1,93 @@
+// Copyright 2024 Google LLC
+// SPDX-License-Identifier: Apache-2.0
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//      https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef THIRD_PARTY_GEMMA_CPP_GEMMA_ACTIVATIONS_H_
+#define THIRD_PARTY_GEMMA_CPP_GEMMA_ACTIVATIONS_H_
+
+#include <stddef.h>
+
+#include <array>
+
+#include "gemma/common.h"  // AllocateSizeof
+#include "hwy/base.h"  // hwy::bfloat16_t
+
+namespace gcpp {
+
+// Must be aligned.
+template +struct Activations { + static constexpr size_t kModelDim = TConfig::kModelDim; + static constexpr size_t kQKVDim = TConfig::kQKVDim; + static constexpr size_t kHeads = TConfig::kHeads; + static constexpr size_t kKVHeads = TConfig::kKVHeads; + static constexpr bool kIsMHA = kHeads == kKVHeads; // Multi-Head Attention + // Stride between subsequent queries. Each of Q, K, V are of length kQKVDim, + // but for MHA we store them as Q,K,V, Q,K,V, .. instead of Q..Q, K..K, V..V. + static constexpr size_t kQStride = kQKVDim * (kIsMHA ? 3 : 1); + + std::array x; // input + std::array pre_att_rms_out; + std::array q; // query vector + std::array + att; // attention vector + std::array att_out; // attention output + std::array + att_post1; // attention output after linear transformation, per head + std::array + att_post2; // accumulation of attention outputs over heads + std::array bf_pre_ffw_rms_out; + std::array ffw_hidden; + + // For FFW MatMul. + std::array C1; + std::array C2; + + std::array ffw_out; + std::array logits; + + // For bf16/f32 vectors * bf16 matrix: faster to unpack once beforehand, into + // per-thread storage. + std::array even_odd; + + // Griffin layer internal activations + static constexpr size_t kGriffinDim = + TConfig::kGriffinLayers > 0 ? kModelDim : 0; + std::array griffin_x; + std::array griffin_y; + std::array griffin_gate_x; + std::array griffin_multiplier; +}; + +template +struct AllocateState { + void operator()(ByteStorageT& prefill, ByteStorageT& decode) const { + // When batching queries, the prefill batch size is reduced by a factor + // of kBatchedQueryBatchSize + prefill = + AllocateSizeof>(); + decode = AllocateSizeof< + Activations>(); + } +}; + +template +Activations& GetActivations(const ByteStorageT& state_u8) { + return *reinterpret_cast*>(state_u8.get()); +} + +} // namespace gcpp + +#endif // THIRD_PARTY_GEMMA_CPP_GEMMA_ACTIVATIONS_H_ diff --git a/gemma/common.cc b/gemma/common.cc index fd10809..7fa60c9 100644 --- a/gemma/common.cc +++ b/gemma/common.cc @@ -53,7 +53,8 @@ constexpr ModelTraining kModelTraining[] = { ModelTraining::GEMMA_IT, // Gemma Tiny }; -constexpr size_t kNumModelFlags = std::end(kModelFlags) - std::begin(kModelFlags); +constexpr size_t kNumModelFlags = + std::end(kModelFlags) - std::begin(kModelFlags); static_assert(kNumModelFlags == std::end(kModelTypes) - std::begin(kModelTypes)); static_assert(kNumModelFlags == @@ -123,4 +124,15 @@ const char* ParseType(const std::string& type_string, Type& type) { return kErrorMessageBuffer; } +void Wrap(const ModelInfo& info, size_t pos, std::string& prompt) { + + // Instruction-tuned models are trained to expect control tokens. + if (info.training == ModelTraining::GEMMA_IT) { + // Prepend "" if this is a multi-turn dialogue continuation. + const std::string start = (pos == 0) + ? "user\n" + : "\nuser\n"; + prompt = start + prompt + "\nmodel\n"; + } +} } // namespace gcpp diff --git a/gemma/common.h b/gemma/common.h index f4ed10d..151c6b0 100644 --- a/gemma/common.h +++ b/gemma/common.h @@ -42,7 +42,8 @@ constexpr size_t kBatchedQueryBatchSize = 16; constexpr size_t kMinAdjustedPrefillBatchSize = HWY_MAX((size_t)1, kPrefillBatchSize / kBatchedQueryBatchSize); -// Model variants: see configs.h for details. +// Model variants: see configs.h for details. When adding a new one, also +// update GEMMA_FOREACH* and Call* below, and add instantiations/*.cc. 
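// Illustration for the Wrap() helper added in common.cc above: the control
// tokens shown here follow the published Gemma instruction-tuned chat format
// and are an assumption, since the patch itself only shows the concatenation.
//
//   pos == 0:  "<start_of_turn>user\n" + prompt +
//              "<end_of_turn>\n<start_of_turn>model\n"
//   pos != 0:  "<end_of_turn>\n<start_of_turn>user\n" + prompt +
//              "<end_of_turn>\n<start_of_turn>model\n"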
enum class Model { GEMMA_2B, GEMMA_7B, @@ -55,16 +56,29 @@ enum class Model { // Instruction-tuned models require extra 'turn structure' tokens in prompts. enum class ModelTraining { GEMMA_IT, GEMMA_PT }; -// Tensor types for loading weights. +// Tensor types for loading weights. When adding a new one, also +// update GEMMA_FOREACH* and Call* below, and add instantiations/*.cc. enum class Type { kF32, kBF16, kSFP }; -// TODO(janwas): merge with parser/ToString. +// TODO(janwas): merge with functions below. struct ModelInfo { Model model; ModelTraining training; Type weight; }; +// Returns error string or nullptr if OK. +// Thread-hostile. +const char* ParseModelTypeAndTraining(const std::string& model_flag, + Model& model, ModelTraining& training); +const char* ParseType(const std::string& type_string, Type& type); + +// Inverse of ParseModelTypeAndTraining. +const char* ModelString(Model model, ModelTraining training); +const char* StringFromType(Type type); + +void Wrap(const ModelInfo& info, size_t pos, std::string& prompt); + // Returns the return value of FuncT>().operator()(args), where // Config* is selected via `model`. Typically called by CallForModelAndWeight, // but can also be called directly when FuncT does not actually use TWeight. @@ -122,6 +136,20 @@ decltype(auto) CallForModelAndWeight(Model model, Type weight, } } +#define GEMMA_FOREACH_WEIGHT(X, CONFIGT) \ + X(CONFIGT, float) \ + X(CONFIGT, hwy::bfloat16_t) \ + X(CONFIGT, SfpStream) + +#define GEMMA_FOREACH_CONFIG_AND_WEIGHT(X) \ + GEMMA_FOREACH_WEIGHT(X, ConfigGemmaTiny) \ + GEMMA_FOREACH_WEIGHT(X, ConfigGemma2B) \ + GEMMA_FOREACH_WEIGHT(X, ConfigGemma7B) \ + GEMMA_FOREACH_WEIGHT(X, ConfigGemma9B) \ + GEMMA_FOREACH_WEIGHT(X, ConfigGemma27B) \ + GEMMA_FOREACH_WEIGHT(X, ConfigGriffin2B) \ + static_assert(true, "Allow trailing ;") + // Used by GEMMA_EXPORT_AND_DISPATCH. For a given TWEIGHT (e.g. float), // calls FUNC> where ConfigT is chosen via MODEL enum. #define GEMMA_DISPATCH_MODEL(MODEL, TWEIGHT, FUNC, ARGS) \ @@ -163,6 +191,8 @@ decltype(auto) CallForModelAndWeight(Model model, Type weight, // Like CallForModelAndWeight, but for SIMD function templates. This is a macro // because it boils down to N_SSE4::FUNC, which would not work if FUNC was a // normal function argument. MODEL and WEIGHT are enums. +// For gemma.cc, we use overloaded extern functions for faster builds. However, +// this is still used in compress_weights because its compile time is OK. #define GEMMA_EXPORT_AND_DISPATCH(MODEL, WEIGHT, FUNC, ARGS) \ switch (WEIGHT) { \ case Type::kF32: \ @@ -178,16 +208,6 @@ decltype(auto) CallForModelAndWeight(Model model, Type weight, HWY_ABORT("Weight type %d unknown.", static_cast(WEIGHT)); \ } -// Returns error string or nullptr if OK. -// Thread-hostile. -const char* ParseModelTypeAndTraining(const std::string& model_flag, - Model& model, ModelTraining& training); -const char* ParseType(const std::string& type_string, Type& type); - -// Inverse of ParseModelTypeAndTraining. 
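// Expansion sketch for the X-macros defined above (purely illustrative, not
// part of the patch): passing a two-argument macro X yields one expansion per
// (config, weight type) pair, and the trailing static_assert only exists so
// call sites can end with a ';'.
//
//   GEMMA_FOREACH_WEIGHT(X, ConfigGemmaTiny)
//   // -> X(ConfigGemmaTiny, float) X(ConfigGemmaTiny, hwy::bfloat16_t)
//   //    X(ConfigGemmaTiny, SfpStream)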
-const char* ModelString(Model model, ModelTraining training); -const char* StringFromType(Type type); - // ---------------------------------------------------------------------------- // diff --git a/gemma/gemma-inl.h b/gemma/gemma-inl.h new file mode 100644 index 0000000..247d953 --- /dev/null +++ b/gemma/gemma-inl.h @@ -0,0 +1,906 @@ +// Copyright 2024 Google LLC +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// SIMD functions for Gemma/Griffin transformers. + +// Include guard (still compiled once per target) +#if defined(THIRD_PARTY_GEMMA_CPP_GEMMA_GEMMA_INL_H_) == \ + defined(HWY_TARGET_TOGGLE) +#ifdef THIRD_PARTY_GEMMA_CPP_GEMMA_GEMMA_INL_H_ +#undef THIRD_PARTY_GEMMA_CPP_GEMMA_GEMMA_INL_H_ +#else +#define THIRD_PARTY_GEMMA_CPP_GEMMA_GEMMA_INL_H_ +#endif + +#include +#include +#include // memcpy + +#include +#include +#include + +#include "gemma/activations.h" +#include "gemma/common.h" +#include "gemma/gemma.h" +#include "gemma/ops.h" +#include "gemma/weights.h" +// Placeholder for internal test4, do not remove +#include "hwy/aligned_allocator.h" +#include "hwy/base.h" +#include "hwy/contrib/matvec/matvec-inl.h" +#include "hwy/contrib/thread_pool/thread_pool.h" +#include "hwy/highway.h" +#include "hwy/profiler.h" +#include "hwy/timer.h" + +#ifndef GEMMA_CONFIG +#if HWY_IDE +// Provide a definition so the IDE does not complain. +#define GEMMA_CONFIG ConfigGemmaTiny +#else +#error "Only include from instantiations/*.cc, which must define GEMMA_CONFIG" +#endif // HWY_IDE +#endif // GEMMA_CONFIG + +HWY_BEFORE_NAMESPACE(); +namespace gcpp { +namespace HWY_NAMESPACE { + +template +HWY_NOINLINE void GriffinRecurrent( + size_t batch_start, size_t num_tokens, size_t num_queries, size_t layer, + Activations& activations, + const CompressedLayer* layer_weights, + const std::vector& kv_caches, hwy::ThreadPool& pool) { + PROFILER_ZONE("Gen.Griffin"); + static_assert(kQueryBatchSize == 1, + "Griffin does not support batched queries."); + HWY_DASSERT(num_queries == 1); // TODO: add batch query support for Griffin. + KVCache& kv_cache = *kv_caches[0]; + namespace hn = hwy::HWY_NAMESPACE; + using D = hn::ScalableTag; + HWY_DASSERT(num_tokens <= kBatchSize); + static constexpr size_t kModelDim = + gcpp::Activations::kModelDim; + static constexpr size_t kConv1dWidth = TConfig::kConv1dWidth; + static constexpr size_t kHeads = TConfig::kHeads; + + // X / Y linear layers. 
+ for (size_t batch_idx = 0; batch_idx < num_tokens; ++batch_idx) { + const size_t batch_offset = batch_idx * kModelDim; + float* HWY_RESTRICT y = activations.griffin_y.data() + batch_offset; + float* HWY_RESTRICT x = activations.griffin_x.data() + batch_offset; + TwoMatVecAdd( + layer_weights->griffin.linear_x_w, layer_weights->griffin.linear_y_w, 0, + activations.pre_att_rms_out.data() + batch_offset, + /*add0=*/layer_weights->griffin.linear_x_biases.data(), + /*add1=*/layer_weights->griffin.linear_y_biases.data(), /*out0=*/x, + /*out1=*/y, pool); + Gelu(y, kModelDim); + } + + // Conv1D. + for (size_t batch_idx = 0; batch_idx < num_tokens; ++batch_idx) { + const size_t batch_offset = batch_idx * kModelDim; + const size_t pos = batch_start + batch_idx; + float* HWY_RESTRICT x = activations.griffin_x.data() + batch_offset; + HWY_FULL(float) df; + HWY_DASSERT(kModelDim % hn::Lanes(df) == 0); + const size_t layer_offset = layer * kModelDim * (kConv1dWidth - 1); + + // cache[i] = input at time t-i. + float* HWY_RESTRICT cache[HWY_MAX(kConv1dWidth, 1)]; + cache[0] = x; + for (size_t i = 1; i < kConv1dWidth; i++) { + cache[i] = + kv_cache.conv1d_cache.get() + layer_offset + + ((pos + kConv1dWidth - 1 - i) % (kConv1dWidth - 1)) * kModelDim; + } + for (size_t i = 0; i < kModelDim; i += hn::Lanes(df)) { + auto xv = hn::Load(df, x + i); + auto accum0 = + hn::Load(df, layer_weights->griffin.conv_biases.data() + i); + auto accum1 = hn::Zero(df); + static_assert(kConv1dWidth % 2 == 0, "Conv width must be even"); + for (size_t l = 0; 2 * l < kConv1dWidth; l++) { + auto wv0 = hn::Load(df, layer_weights->griffin.conv_w.data() + + (kConv1dWidth - 1 - 2 * l) * kModelDim + i); + auto wv1 = hn::Load(df, layer_weights->griffin.conv_w.data() + + (kConv1dWidth - 2 - 2 * l) * kModelDim + i); + accum0 = hn::MulAdd(wv0, hn::Load(df, cache[l * 2] + i), accum0); + accum1 = hn::MulAdd(wv1, hn::Load(df, cache[l * 2 + 1] + i), accum1); + } + hn::Store(hn::Add(accum0, accum1), df, x + i); + hn::Store(xv, df, cache[HWY_MAX(kConv1dWidth, 1) - 1] + i); + } + } + + // RGLRU + for (size_t batch_idx = 0; batch_idx < num_tokens; ++batch_idx) { + const size_t batch_offset = batch_idx * kModelDim; + const size_t pos = batch_start + batch_idx; + float* HWY_RESTRICT y = activations.griffin_y.data() + batch_offset; + float* HWY_RESTRICT x = activations.griffin_x.data() + batch_offset; + float* HWY_RESTRICT gate_x = + activations.griffin_gate_x.data() + batch_offset; + float* HWY_RESTRICT a = + activations.griffin_multiplier.data() + batch_offset; + float* HWY_RESTRICT rnn_state = + kv_cache.rglru_cache.get() + layer * kModelDim; + + pool.Run(0, kHeads, [&](const uint64_t head, size_t /*thread*/) HWY_ATTR { + constexpr size_t kHeadDim = kModelDim / kHeads; + constexpr size_t kMatrixSize = kHeadDim * kHeadDim; + size_t head_offset = head * kHeadDim; + TwoOfsMatVecAddLoop( + layer_weights->griffin.gate_w, kMatrixSize * head, + kMatrixSize * (kHeads + head), x + head_offset, + /*add0=*/layer_weights->griffin.gate_biases.data() + head_offset, + /*add1=*/layer_weights->griffin.gate_biases.data() + kModelDim + + head_offset, + /*out0=*/gate_x + head_offset, /*out1=*/a + head_offset); + Sigmoid(gate_x + head_offset, kHeadDim); + Sigmoid(a + head_offset, kHeadDim); + const auto fn_mul = [](D d, hn::Vec x, hn::Vec gate_x) + HWY_ATTR { return hn::Mul(x, gate_x); }; + hn::Transform1(D(), a + head_offset, kHeadDim, + layer_weights->griffin.a.data() + head_offset, fn_mul); + hn::Transform1(D(), x + head_offset, kHeadDim, gate_x + head_offset, + 
fn_mul); + // RNN scan + HWY_FULL(float) df; + HWY_DASSERT(kHeadDim % hn::Lanes(df) == 0); + for (size_t i = 0; i < kHeadDim; i += hn::Lanes(df)) { + auto log_a = hn::Load(df, a + head_offset + i); + auto gated_x = hn::Load(df, x + head_offset + i); + auto rnn = hn::Load(df, rnn_state + head_offset + i); + auto a = hn::Exp(df, log_a); + auto x_multiplier = hn::Sqrt(hn::NegMulAdd(a, a, hn::Set(df, 1.0f))); + if (pos == 0) { + x_multiplier = hn::Set(df, 1.0f); + } + auto new_x = hn::MulAdd(x_multiplier, gated_x, hn::Mul(a, rnn)); + hn::Store(new_x, df, rnn_state + head_offset + i); + + // Join branches. + auto yv = hn::Load(df, y + head_offset + i); + auto pre_out = hn::Mul(yv, new_x); + hn::Store(pre_out, df, x + head_offset + i); + } + }); + } + + // Final linear layer. + for (size_t batch_idx = 0; batch_idx < num_tokens; ++batch_idx) { + const size_t batch_offset = batch_idx * kModelDim; + float* HWY_RESTRICT x = activations.griffin_x.data() + batch_offset; + float* out_ptr = activations.att_post2.data() + batch_idx * kModelDim; + MatVecAdd( + layer_weights->griffin.linear_out_w, 0, x, + layer_weights->griffin.linear_out_biases.data(), + activations.even_odd.data(), out_ptr, pool); + } +} + +template +HWY_NOINLINE void Attention( + size_t batch_and_query_start, size_t num_tokens, size_t num_queries, + size_t layer, + Activations& activations, + const CompressedLayer* layer_weights, + const std::vector& kv_caches, + hwy::ThreadPool& pool) { + PROFILER_ZONE("Gen.Attention"); + HWY_DASSERT(num_tokens <= kBatchSize); + HWY_DASSERT(num_queries <= kQueryBatchSize); + HWY_DASSERT(batch_and_query_start % num_queries == 0); + using TActivations = Activations; + constexpr size_t kQKVDim = TActivations::kQKVDim; + constexpr size_t kQStride = TActivations::kQStride; + constexpr size_t kCachePosSize = CachePosSize()(); + constexpr size_t kCacheLayerSize = CacheLayerSize()(); + constexpr size_t kModelDim = TActivations::kModelDim; + constexpr size_t kHeads = TConfig::kHeads; + constexpr size_t kKVHeads = TConfig::kKVHeads; + constexpr size_t kSeqLen = TConfig::kSeqLen; + GEMMA_CONSTEXPR_SQRT const float kQueryScale = + 1.0f / Sqrt(static_cast(kQKVDim)); + constexpr bool kIsMHA = TActivations::kIsMHA; // Multi-Head Attention + const size_t batch_start = batch_and_query_start / num_queries; + const size_t num_tokens_and_queries = num_tokens * num_queries; + + // If MHA, this also computes KV, which we copy to the KV cache below. + static_assert(!kIsMHA || TConfig::kInterleaveQKV); // MHA => interleaved + MatMul_4x4_Batch( + num_tokens_and_queries, activations.pre_att_rms_out.data(), + layer_weights->qkv_einsum_w.data(), activations.q.data(), pool); + + for (size_t batch_and_query_idx = 0; + batch_and_query_idx < num_tokens_and_queries; ++batch_and_query_idx) { + const float* x = activations.pre_att_rms_out.data() + batch_and_query_idx + * kModelDim; + const size_t query_idx = batch_and_query_idx % num_queries; + const size_t batch_idx = batch_and_query_idx / num_queries; + KVCache& kv_cache = *kv_caches[query_idx]; + // QKV projections: + if constexpr (!kIsMHA) { + const size_t pos = batch_start + batch_idx; + const size_t cache_pos = pos % (kSeqLen + kPrefillBatchSize); + const size_t kv_offset = + cache_pos * kCachePosSize + layer * kCacheLayerSize; + float* HWY_RESTRICT kv = kv_cache.kv_cache.get() + kv_offset; + // TODO: requires MatMul support for offsets. 
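// Layout note (illustrative, not part of the patch): when kIsMHA, the batched
// MatMul above produces interleaved Q,K,V per head (kQStride == 3 * kQKVDim),
// and the K,V halves are memcpy'd into the KV cache by the positional-encoding
// pass below. Otherwise kQStride == kQKVDim and K,V come from the separate
// MatVec below. With grouped heads (kHeads > kKVHeads), kGroupHeads query
// heads share one cached K,V pair; e.g. kHeads = 8, kKVHeads = 1 means all
// eight query heads attend over the same cached keys and values.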
+ MatVec( + layer_weights->qkv_einsum_w, kHeads * kQKVDim * kModelDim, x, + activations.even_odd.data(), kv, pool); + } + } + + // Positional encodings for kv: + pool.Run( + 0, kKVHeads * num_tokens_and_queries, + [&](uint64_t task, size_t thread) HWY_ATTR { + const size_t head = task % kKVHeads; + const size_t batch_and_query_idx = task / kKVHeads; + const size_t query_idx = batch_and_query_idx % num_queries; + const size_t batch_idx = batch_and_query_idx / num_queries; + const size_t pos = batch_start + batch_idx; + const size_t cache_pos = pos % (kSeqLen + kPrefillBatchSize); + const size_t kv_offset = cache_pos * kCachePosSize + + layer * kCacheLayerSize + head * kQKVDim * 2; + KVCache& kv_cache = *kv_caches[query_idx]; + float* HWY_RESTRICT kv = kv_cache.kv_cache.get() + kv_offset; + if constexpr (kIsMHA) { + // For MHA, copy kv into the KV cache from scratch space (see above). + const float* HWY_RESTRICT q = + activations.q.data() + (batch_and_query_idx * kHeads + + head) * kQStride; + // Skip past the Q part of `q`, and copy KV to `kv`. + memcpy(kv, q + kQKVDim, 2 * kQKVDim * sizeof(float)); + } + Rope(kv, TConfig::kUseHalfRope ? kQKVDim / 2 : kQKVDim, pos); + }); + + static_assert((kHeads % kKVHeads) == 0, + "query heads must be a multiple of key-value heads"); + constexpr size_t kGroupHeads = kHeads / kKVHeads; + pool.Run(0, kHeads * num_tokens_and_queries, + [&](uint64_t task, size_t thread) HWY_ATTR { + const size_t head = task % kHeads; + const size_t batch_and_query_idx = task / kHeads; + const size_t query_idx = batch_and_query_idx % num_queries; + const size_t batch_idx = batch_and_query_idx / num_queries; + const size_t head_offset = (head / kGroupHeads) * kQKVDim * 2; + KVCache& kv_cache = *kv_caches[query_idx]; + float* HWY_RESTRICT q = + activations.q.data() + (batch_and_query_idx * kHeads + head) * kQStride; + + const size_t pos = batch_start + batch_idx; + // Calculate scores + float* HWY_RESTRICT head_att = + activations.att.data() + head * kSeqLen + + batch_and_query_idx * kHeads * kSeqLen; + + Rope(q, TConfig::kUseHalfRope ? 
kQKVDim / 2 : kQKVDim, pos); + MulByConst(kQueryScale, q, kQKVDim); + + // Compute Q dot K scores + const size_t start_pos = + pos - std::min(TConfig::kAttentionWindowSizes[layer] - 1, pos); + for (size_t pos2 = start_pos; pos2 <= pos; ++pos2) { + const size_t cache_pos = pos2 % (kSeqLen + kPrefillBatchSize); + const size_t kv_offset = + cache_pos * kCachePosSize + layer * kCacheLayerSize + head_offset; + const float* HWY_RESTRICT k2 = kv_cache.kv_cache.get() + kv_offset; + const float score = Dot(q, k2, kQKVDim); + head_att[pos2 % kSeqLen] = score; + } + const size_t head_att_len = std::min(pos + 1, kSeqLen); + if constexpr (TConfig::kAttCap > 0.0f) { + LogitsSoftCap(TConfig::kAttCap, head_att, head_att_len); + } + Softmax(head_att, head_att_len); + + // Weighted summation + float* HWY_RESTRICT att_out = activations.att_out.data() + head * kQKVDim + + batch_and_query_idx * kHeads * kQKVDim; + hwy::ZeroBytes(att_out, kQKVDim * sizeof(*att_out)); + for (size_t pos2 = start_pos; pos2 <= pos; ++pos2) { + const size_t cache_pos = pos2 % (kSeqLen + kPrefillBatchSize); + const size_t kv_offset = + cache_pos * kCachePosSize + layer * kCacheLayerSize + head_offset; + float* HWY_RESTRICT v2 = kv_cache.kv_cache.get() + kv_offset + kQKVDim; + MulByConstAndAdd(head_att[pos2 % kSeqLen], v2, att_out, kQKVDim); + } + }); + + for (size_t batch_and_query_idx = 0; + batch_and_query_idx < num_tokens_and_queries; ++batch_and_query_idx) { + // TODO(szabadka) Use a single MatVecAdd like in GriffinRecurrent() after + // rearranging the weights. + float* HWY_RESTRICT att_out = + activations.att_out.data() + batch_and_query_idx * kHeads * kQKVDim; + float* HWY_RESTRICT layer_out = + activations.att_post2.data() + batch_and_query_idx * kModelDim; + MatVecT( + layer_weights->attn_vec_einsum_w, 0, att_out, + layer_weights->attention_output_biases.data(), + activations.even_odd.data(), layer_out, pool); + for (size_t head = 1; head < kHeads; ++head) { + // TODO(patrickms): Check this calculation + float* HWY_RESTRICT head_out = + activations.att_post1.data() + + head * kBatchSize * kQueryBatchSize * kModelDim; + // TODO: requires MatMul support for offsets. + MatVec( + layer_weights->attn_vec_einsum_w, head * kModelDim * kQKVDim, + att_out + head * kQKVDim, + activations.even_odd.data(), head_out, pool); + AddFrom(head_out, layer_out, kModelDim); + } + } +} + +template +HWY_NOINLINE void FFW(Activations& activations, + size_t num_tokens, + const CompressedLayer* layer_weights, + hwy::ThreadPool& pool) { + HWY_DASSERT(num_tokens <= kBatchSize); + constexpr size_t kModelDim = TConfig::kModelDim; + constexpr size_t kFFHiddenDim = TConfig::kFFHiddenDim; + float* HWY_RESTRICT even_odd = activations.even_odd.data(); + + // TODO: MatMul does not yet support adding another matrix to the result. + if constexpr (!TConfig::kFFBiases) { + PROFILER_ZONE("Gen.FFW.GatedGELU"); + + // MatMul expects col-major B, which is what we have: kModelDim consecutive + // elements in memory, repeated kFFHiddenDim times. + const auto b1 = layer_weights->gating_einsum_w.data(); + constexpr size_t kColsA = kModelDim; + constexpr size_t kColsB = kFFHiddenDim; + const auto b2 = b1 + kColsA * kColsB; + auto A = activations.bf_pre_ffw_rms_out.data(); + // Will go through GELU. + MatMul_4x4_Batch(num_tokens, A, b1, activations.C1.data(), + pool); + // What to multiply by. + MatMul_4x4_Batch(num_tokens, A, b2, activations.C2.data(), + pool); + + // Gelu and multiply by gate. 
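// Summary of the non-biased path above and below (illustrative, not part of
// the patch): per token x,
//   ffw_out = ( Gelu(x * b1) ⊙ (x * b2) ) * linear_w,
// where b1/b2 are the two halves of gating_einsum_w, ⊙ is an element-wise
// product, and C1 ends up holding the gated hidden activations that feed the
// final MatMul.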
+ namespace hn = hwy::HWY_NAMESPACE; + using DF = hn::ScalableTag; + using VF = hn::Vec; + hn::Transform1(DF(), activations.C1.data(), kFFHiddenDim * num_tokens, + activations.C2.data(), [](DF df, VF v, VF mul) HWY_ATTR { + return hn::Mul(mul, Gelu(df, v)); + }); + + MatMul_4x4_Batch(num_tokens, activations.C1.data(), + layer_weights->linear_w.data(), + activations.ffw_out.data(), pool); + } else { // TConfig::kFFBiases == true + for (size_t batch_idx = 0; batch_idx < num_tokens; ++batch_idx) { + const size_t hidden_offset = batch_idx * kFFHiddenDim * 2; + const hwy::bfloat16_t* HWY_RESTRICT vec = + activations.bf_pre_ffw_rms_out.data() + batch_idx * kModelDim; + float* HWY_RESTRICT out = activations.ffw_hidden.data() + hidden_offset; + float* HWY_RESTRICT out_mul = out + kFFHiddenDim; + + PROFILER_ZONE("Gen.FFW.GatedGELU"); + // Same matrix, first and second half of rows. Could fuse into one MatVec. + MatVecT( + layer_weights->gating_einsum_w, kFFHiddenDim * kModelDim, vec, + layer_weights->ffw_gating_biases.data() + kFFHiddenDim, even_odd, + out_mul, pool); + // Gate, will go through the nonlinearity. + MatVecT( + layer_weights->gating_einsum_w, 0, vec, + layer_weights->ffw_gating_biases.data(), even_odd, out, pool); + + namespace hn = hwy::HWY_NAMESPACE; + using DF = hn::ScalableTag; + using VF = hn::Vec; + hn::Transform1(DF(), out, kFFHiddenDim, out_mul, + [](DF df, VF v, VF mul) + HWY_ATTR { return hn::Mul(mul, Gelu(df, v)); }); + + MatVecT( + layer_weights->linear_w, 0, + activations.ffw_hidden.data() + hidden_offset, + layer_weights->ffw_output_biases.data(), even_odd, + activations.ffw_out.data() + batch_idx * kModelDim, pool); + } + } +} + +template +HWY_NOINLINE void EmbedToken(int token, size_t token_idx, size_t pos, + const CompressedWeights& weights, + Activations& activations) { + constexpr size_t kModelDim = TConfig::kModelDim; + GEMMA_CONSTEXPR_EMBSCALING const float kEmbScaling = + EmbeddingScaling(); + HWY_DASSERT(token >= 0); + HWY_DASSERT(token < TConfig::kVocabSize); + Decompress(weights.embedder_input_embedding, token * kModelDim, + activations.x.data() + token_idx * kModelDim, kModelDim); + MulByConst(kEmbScaling, activations.x.data() + token_idx * kModelDim, + kModelDim); + if constexpr (TConfig::kAbsolutePE) { + AddAbsolutePositionalEmbeddings( + activations.x.data() + token_idx * kModelDim, kModelDim, + pos + token_idx); + }; +} + +template +HWY_NOINLINE void TransformerLayer( + size_t num_tokens, size_t num_queries, size_t pos, size_t layer, + const CompressedLayer* layer_weights, + Activations& activations, + const std::vector& kv_caches, hwy::ThreadPool& pool) { + constexpr size_t kModelDim = TConfig::kModelDim; + const size_t num_tokens_and_queries = num_tokens * num_queries; + auto type = TConfig::kLayerConfig[layer]; + size_t layer_of_type = + NumLayersOfTypeBefore(TConfig::kLayerConfig, type, layer); + RMSNormBatched( + num_tokens_and_queries, activations.x.data(), + layer_weights->pre_attention_norm_scale.data(), + activations.pre_att_rms_out.data(), kModelDim); + if (type == LayerAttentionType::kGemma) { + Attention( + pos, num_tokens, num_queries, layer_of_type, activations, + layer_weights, kv_caches, pool); + } else { + // This Griffin layers should never exist unless the model is a Griffin + // model. This conditional prevents the compiler from generating code for + // this branch when building a non-Griffin model, since we have static + // asserts about the query batch size for Griffin layers. 
+ if constexpr (TConfig::kGriffinLayers > 0) { + GriffinRecurrent( + pos, num_tokens, num_queries, layer_of_type, activations, + layer_weights, kv_caches, pool); + } + } + if (TConfig::kPostNormScale) { + RMSNormInplaceBatched( + num_tokens_and_queries, + layer_weights->post_attention_norm_scale.data(), + activations.att_post2.data(), kModelDim); + } + AddFromBatched(num_tokens_and_queries, + activations.att_post2.data(), + activations.x.data(), kModelDim); + RMSNormBatched( + num_tokens_and_queries, activations.x.data(), + layer_weights->pre_ffw_norm_scale.data(), + activations.bf_pre_ffw_rms_out.data(), kModelDim); + FFW( + activations, num_tokens_and_queries, layer_weights, pool); + if (TConfig::kPostNormScale) { + RMSNormInplaceBatched( + num_tokens_and_queries, layer_weights->post_ffw_norm_scale.data(), + activations.ffw_out.data(), kModelDim); + } + AddFromBatched( + num_tokens_and_queries, activations.ffw_out.data(), + activations.x.data(), kModelDim); +} + +template +HWY_NOINLINE void Prefill( + const int* tokens, size_t num_tokens, size_t num_queries, size_t pos, + const CompressedWeights& weights, + Activations& activations, + const std::vector& kv_caches, hwy::ThreadPool& pool) { + HWY_DASSERT(num_queries <= kQueryBatchSize); + const size_t minibatch_size = std::min(num_tokens, kBatchSize); + PROFILER_ZONE("Gen.Prefill\\Att\\FFW"); + // TODO(patrickms): Try to hoist pool.Run out of the loop. + for (size_t i = 0; i < num_tokens; i += minibatch_size) { + const size_t offset = i * num_queries; + const size_t current_token_count = std::min( + minibatch_size, num_tokens - i); + pool.Run(0, current_token_count * num_queries, + [&](const uint64_t token_idx, size_t /*thread*/) HWY_ATTR { + EmbedToken( + tokens[token_idx + offset], token_idx, pos + offset, + weights, activations); + }); + + for (size_t layer = 0; layer < TConfig::kLayers; ++layer) { + const auto* layer_weights = weights.GetLayer(layer); + TransformerLayer( + current_token_count, num_queries, pos + offset , layer, layer_weights, + activations, kv_caches, pool); + } + } +} + +// Compute the transformer for a batch of input tokens. During generation, +// we usually have num_tokens == 1 (and also kBatchSize == 1). +template +HWY_NOINLINE void Transformer( + const int* tokens, size_t num_tokens, size_t num_queries, size_t pos, + const CompressedWeights& weights, + Activations& activations, + const std::vector& kv_caches, + hwy::ThreadPool& pool, + const LayersOutputFunc& layers_output) { + HWY_ASSERT(num_tokens <= kBatchSize); + const size_t num_tokens_and_queries = num_tokens * num_queries; + if (layers_output) { + for (size_t token_idx = 0; token_idx < num_tokens_and_queries; + ++token_idx) { + float token_f = tokens[token_idx]; + layers_output(pos + token_idx, "Tokens", &token_f, 1); + } + } + constexpr size_t kModelDim = TConfig::kModelDim; + for (size_t token_idx = 0; token_idx < num_tokens_and_queries; ++token_idx) { + EmbedToken( + tokens[token_idx], token_idx, pos, weights, activations); + } + + for (size_t layer = 0; layer < TConfig::kLayers; ++layer) { + const CompressedLayer* layer_weights = weights.GetLayer(layer); + TransformerLayer( + num_tokens, num_queries, pos, layer, layer_weights, + activations, kv_caches, pool); + + if (layers_output) { + const std::string block_name = "blocks." 
+ std::to_string(layer); + for (size_t token_idx = 0; token_idx < num_tokens_and_queries; + ++token_idx) { + layers_output(pos + token_idx, block_name, + activations.x.data() + token_idx * kModelDim, kModelDim); + } + } + } + + RMSNormInplaceBatched( + num_tokens * num_queries, weights.final_norm_scale.data(), + activations.x.data(), kModelDim); + if (layers_output) { + for (size_t token_idx = 0; token_idx < num_tokens_and_queries; + ++token_idx) { + layers_output(pos + token_idx, "final_norm", + activations.x.data() + token_idx * kModelDim, kModelDim); + } + } +} + +template +void RangeChecks(size_t& max_tokens, size_t& max_generated_tokens, + size_t& prompt_size) { + if (!TConfig::kUseLocalAttention) { + if (max_tokens > TConfig::kSeqLen) { + fprintf(stderr, "WARNING: max_tokens %zu > kSeqLen %d, truncating.\n", + max_tokens, TConfig::kSeqLen); + max_tokens = static_cast(TConfig::kSeqLen); + } + } + + if (max_generated_tokens > max_tokens) { + fprintf(stderr, + "WARNING: max_generated_tokens %zu > max_tokens %zu, truncating.\n", + max_generated_tokens, max_tokens); + max_generated_tokens = max_tokens - 1; + } + + if (!TConfig::kUseLocalAttention) { + if (prompt_size + max_generated_tokens > max_tokens) { + fprintf(stderr, + "WARNING: prompt_size %zu + max_generated_tokens %zu > " + "max_tokens %zu, truncating to ", + prompt_size, max_generated_tokens, max_tokens); + prompt_size = std::min(prompt_size, max_tokens - max_generated_tokens); + fprintf(stderr, "%zu\n", prompt_size); + } + } + + HWY_ASSERT(prompt_size > 0); +} + +// Placeholder for internal test3, do not remove + +template +void GenerateT(const ByteStorageT& weights_u8, const ByteStorageT& prefill_u8, + const ByteStorageT& decode_u8, + const RuntimeConfig& runtime_config, + const hwy::Span>& prompts, size_t pos, + const size_t query_index_offset, + const std::vector& kv_caches, hwy::ThreadPool& pool, + TimingInfo& timing_info) { + constexpr size_t kAdjustedPrefillBatchSize = + std::max((size_t)1, kPrefillBatchSize / kQueryBatchSize); + static_assert(kAdjustedPrefillBatchSize >= kMinAdjustedPrefillBatchSize); + const size_t num_queries = prompts.size(); + HWY_DASSERT(num_queries <= kQueryBatchSize); + pos *= num_queries; // position in (num_queries) interleaved token sequence. + const CompressedWeights& weights = + *reinterpret_cast*>(weights_u8.get()); + auto& prefill_activations = + GetActivations(prefill_u8); + auto& activations = GetActivations(decode_u8); + + size_t min_prompt_size = (size_t)-1; + size_t max_prompt_size = 0; + for (int i=0; i < prompts.size(); ++i) { + min_prompt_size = std::min(min_prompt_size, prompts[i].size()); + max_prompt_size = std::max(max_prompt_size, prompts[i].size()); + } + + std::vector prompt; + prompt.reserve(max_prompt_size * prompts.size()); + for (int i = 0; i < max_prompt_size; ++i) { + for (int j=0; j < prompts.size(); ++j) { + if (i < prompts[j].size()) { + prompt.push_back(prompts[j][i]); + } else { + prompt.push_back(0); + } + } + } + + constexpr size_t kVocabSize = TConfig::kVocabSize; + + size_t max_tokens = runtime_config.max_tokens; + size_t max_generated_tokens = runtime_config.max_generated_tokens; + RangeChecks(max_tokens, max_generated_tokens, max_prompt_size); + if (pos >= max_tokens) { + fprintf(stderr, "Warning: pos %zu >= max_tokens %zu, aborting.\n", pos, + max_tokens); + return; + } + + // If no sample_func is provided, we use top-k sampling. + const SampleFunc sample_token = + runtime_config.sample_func + ? 
runtime_config.sample_func + : [&](const float* logits, size_t vocab_size) -> int { + return SampleTopK(logits, vocab_size, *runtime_config.gen, + runtime_config.temperature, + runtime_config.accept_token); + }; + + std::vector reached_eos(num_queries); + std::fill(reached_eos.begin(), reached_eos.end(), false); + + // pos indexes the KV cache. In the first turn of a chat, pos = 0. + // + // After the first turn, pos gets passed in with > 0 corresponding to the + // current token position in the KV cache. + // + // pos_offset keeps track of the relative position within the turn, starting + // at 0 each turn. During prefill, pos_offset corresponds to the index into + // the prompt vector. + // + // In single-turn (non-chat) usage, pos and pos_offset start at 0 and are + // always equal. + size_t pos_offset = 0; // offset relative to pos + // Used to keep track of how many tokens are processed per prompt, + // so that we know when to start generating tokens. + size_t single_prompt_pos_offset = 0; + const double prefill_start = hwy::platform::Now(); + + // Prefill stops before prompt_size - 1 since the last prompt token is the + // first input token for generation. + while (single_prompt_pos_offset < min_prompt_size - 1) { + const size_t batch_size = std::min( + kPrefillBatchSize, min_prompt_size - 1 - single_prompt_pos_offset); + const size_t batch_and_query_size = batch_size * num_queries; + HWY_DASSERT(batch_size <= kPrefillBatchSize); + HWY_DASSERT(single_prompt_pos_offset + batch_size <= min_prompt_size - 1); + HWY_DASSERT(pos_offset + batch_size <= (min_prompt_size - 1) * num_queries); + const int* batch_tokens = prompt.data() + pos_offset; + Prefill( + batch_tokens, batch_size, num_queries, pos, weights, + prefill_activations, kv_caches, pool); + for (size_t idx = 0; idx < batch_size; ++idx) { + bool all_tokens_eos = true; + for (size_t query_idx = 0; query_idx < num_queries; ++query_idx) { + if (reached_eos[query_idx]) continue; + if (runtime_config.StreamToken( + query_idx + query_index_offset, single_prompt_pos_offset, + batch_tokens[idx * num_queries + query_idx], 0.0f)) { + all_tokens_eos = false; + } else { + reached_eos[query_idx] = true; + } + } + if (all_tokens_eos) { + return; + } + } + pos += batch_and_query_size; + pos_offset += batch_and_query_size; + single_prompt_pos_offset += batch_size; + } + + timing_info.prefill_tok_sec = + static_cast(pos_offset) / (hwy::platform::Now() - prefill_start); + + // Start generation. + const double gen_start = hwy::platform::Now(); + HWY_DASSERT(single_prompt_pos_offset == min_prompt_size - 1); + size_t pos_gen_start = pos_offset; + int token = prompt.at(pos_offset); + std::vector::const_iterator first = prompt.begin() + pos_offset; + std::vector::const_iterator last = first + num_queries; + std::vector gen_tokens(first, last); + // The loop below is not yet prepared for decode batch size > 1. 
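// Worked example for the interleaved prompt built in GenerateT above
// (illustrative, not part of the patch): with num_queries = 2 and prompts
// {1, 2, 3} and {4, 5}, `prompt` becomes {1,4, 2,5, 3,0} (shorter prompts are
// padded with 0), and pos / pos_offset advance in steps of num_queries so
// that token i of every query is processed in one batched call.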
+ HWY_ASSERT(kDecodeBatchSize == 1); + bool all_tokens_eos = true; + for (size_t i=0; i < num_queries; ++i) { + if (reached_eos[i]) continue; + if (runtime_config.StreamToken(i + query_index_offset, + single_prompt_pos_offset, gen_tokens[i], + 0.0f)) { + all_tokens_eos = false; + } else { + reached_eos[i] = true; + } + } + if (all_tokens_eos) { + return; + } + for (size_t generate_pos = 0; + generate_pos < max_tokens && generate_pos < max_generated_tokens; + ++single_prompt_pos_offset, ++generate_pos) { + Transformer( + gen_tokens.data(), kDecodeBatchSize, num_queries, pos, weights, + activations, kv_caches, pool, runtime_config.layers_output); + float token_logit = 0.0f; + // The condition below is always true if we are doing Prefill above. + // We keep it here for clarity so that the code is correct even if Prefill + // is disabled. + bool all_tokens_eos = true; + float* x = activations.x.data(); + float* logits = activations.logits.data(); + for (size_t i = 0; i < num_queries; ++i, ++pos, ++pos_offset, + x += TConfig::kModelDim, logits += kVocabSize) { + const size_t prompt_size = prompts[i].size(); + const bool is_generating_phase = + (single_prompt_pos_offset >= prompt_size - 1); + if (is_generating_phase) { + PROFILER_ZONE("Gen.Embedding"); + // Compute logits from last layer activations. + MatVec( + weights.embedder_input_embedding, 0, x, activations.even_odd.data(), + logits, pool); + if constexpr (TConfig::kFinalCap > 0.0f) { + LogitsSoftCap(TConfig::kFinalCap, activations.logits.data(), + kVocabSize); + } + // Barrier: must have all logits so we can subtract max. + Softmax(logits, kVocabSize); + token = sample_token(logits, kVocabSize); + token_logit = logits[token]; + if (generate_pos == 0) { + timing_info.time_to_first_token = hwy::platform::Now() - gen_start; + } + } else { + // We would take this branch if we were not doing Prefill but would + // process the tokens of the prompt one at a time. + token = prompt.at(pos_offset); + token_logit = 0.0f; + } + + if (!reached_eos[i]) { + if (!runtime_config.StreamToken(i + query_index_offset, + single_prompt_pos_offset + 1, token, + token_logit)) { + token = runtime_config.eos_id; + } + if (token != runtime_config.eos_id) { + all_tokens_eos = false; + } else { + reached_eos[i] = true; + } + } + gen_tokens[i] = token; + } + if (all_tokens_eos) { + break; + } + } + timing_info.gen_tok_sec = static_cast(pos_offset - pos_gen_start) / + (hwy::platform::Now() - gen_start); +} + +template +void GenerateSingleT(const ByteStorageT& weights_u8, + const ByteStorageT& prefill_u8, + const ByteStorageT& decode_u8, + const RuntimeConfig& runtime_config, + const std::vector& prompt, size_t pos, + KVCache& kv_cache, hwy::ThreadPool& pool, + TimingInfo& timing_info) { + // TODO: the input should also be span, not a vector. + const hwy::Span prompt_span(const_cast(prompt.data()), + prompt.size()); + const hwy::Span> prompts(&prompt_span, 1); + // TODO: also span of kv_cache. + std::vector kv_caches = {&kv_cache}; + const size_t query_index_offset = 0; + GenerateT( + weights_u8, prefill_u8, decode_u8, runtime_config, prompts, pos, + query_index_offset, kv_caches, pool, timing_info); +} + +template +void GenerateBatchT(const ByteStorageT& weights_u8, + const ByteStorageT& prefill_u8, + const ByteStorageT& decode_u8, + const RuntimeConfig& runtime_config, + const hwy::Span>& prompts, + size_t pos, const std::vector& kv_caches, + hwy::ThreadPool& pool, + TimingInfo& timing_info) { + // Disable query batching for Griffin models. 
+ constexpr size_t kQueryBatchSize = + (TConfig::kGriffinLayers > 0) ? 1 : kBatchedQueryBatchSize; + for (size_t i = 0; i < prompts.size(); i += kQueryBatchSize) { + const size_t num_queries = std::min(prompts.size() - i, kQueryBatchSize); + const hwy::Span> current_prompts( + prompts.data() + i, num_queries); + GenerateT(weights_u8, prefill_u8, decode_u8, + runtime_config, current_prompts, + pos, i, kv_caches, pool, timing_info); + } +} + +} // namespace HWY_NAMESPACE + +#if HWY_ONCE + +// These are extern functions defined by instantiations/*.cc, which include this +// 'header' after defining GEMMA_CONFIG, which is for function overloading. +void GenerateSingle( // NOLINT(misc-definitions-in-headers) + GEMMA_CONFIG, const ByteStorageT& weights_u8, + const ByteStorageT& prefill_u8, const ByteStorageT& decode_u8, + const RuntimeConfig& runtime_config, const std::vector& prompt, + size_t pos, KVCache& kv_cache, hwy::ThreadPool& pool, + TimingInfo& timing_info) { + HWY_EXPORT_AND_DYNAMIC_DISPATCH_T(GenerateSingleT) + (weights_u8, prefill_u8, decode_u8, runtime_config, prompt, pos, kv_cache, + pool, timing_info); +} + +void GenerateBatch( // NOLINT(misc-definitions-in-headers) + GEMMA_CONFIG, const ByteStorageT& weights_u8, + const ByteStorageT& prefill_u8, const ByteStorageT& decode_u8, + const RuntimeConfig& runtime_config, + const hwy::Span>& prompts, size_t pos, + const std::vector& kv_caches, hwy::ThreadPool& pool, + TimingInfo& timing_info) { + HWY_EXPORT_AND_DYNAMIC_DISPATCH_T(GenerateBatchT) + (weights_u8, prefill_u8, decode_u8, runtime_config, prompts, pos, kv_caches, + pool, timing_info); +} + +#endif // HWY_ONCE + +} // namespace gcpp +HWY_AFTER_NAMESPACE(); + +#endif // THIRD_PARTY_GEMMA_CPP_GEMMA_GEMMA_INL_H_ diff --git a/gemma/gemma.cc b/gemma/gemma.cc index c82d193..e1aaf7b 100644 --- a/gemma/gemma.cc +++ b/gemma/gemma.cc @@ -13,956 +13,29 @@ // See the License for the specific language governing permissions and // limitations under the License. -// Lightweight C++ implementation of the gemma model. +// Defines Gemma member functions; the actual implementations are in +// gemma-inl.h, included from instantiations/*.cc. -// Compiles this file for multiple architectures via "foreach_target.h", to -// which we pass the filename via macro 'argument'. -#undef HWY_TARGET_INCLUDE -#define HWY_TARGET_INCLUDE "gemma/gemma.cc" // NOLINT -#include "hwy/foreach_target.h" // IWYU pragma: keep -// Must come after foreach_target.h to avoid redefinition errors. -#include "gemma/ops.h" -#include "hwy/contrib/matvec/matvec-inl.h" -#include "hwy/highway.h" - -// Non-SIMD includes and types. Note that HWY_ONCE is only true on the last -// compile pass, whereas we want this defined in the first. -#ifndef GEMMA_ONCE -#define GEMMA_ONCE +#include "gemma/gemma.h" #include #include #include #include -#include -#include -#include #include // std::move #include #include "compression/io.h" // Path +#include "gemma/activations.h" #include "gemma/common.h" -#include "gemma/configs.h" -#include "gemma/gemma.h" #include "gemma/weights.h" -// Placeholder for internal test1, do not remove -// Placeholder for internal test4, do not remove -#include "hwy/aligned_allocator.h" -#include "hwy/base.h" +#include "hwy/aligned_allocator.h" // Span #include "hwy/contrib/thread_pool/thread_pool.h" -#include "hwy/profiler.h" -#include "hwy/timer.h" +#include "hwy/highway.h" namespace gcpp { -// Must be aligned. 
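// Sketch of one of the new instantiations/*.cc translation units referenced
// by the build files and by the HWY_ONCE block above (assumed contents; the
// exact GEMMA_CONFIG spelling is not shown in this patch):
//
//   // gemma/instantiations/tiny_bf16.cc (hypothetical)
//   #undef HWY_TARGET_INCLUDE
//   #define HWY_TARGET_INCLUDE "gemma/instantiations/tiny_bf16.cc"
//   #include "hwy/foreach_target.h"  // IWYU pragma: keep
//
//   #define GEMMA_CONFIG ConfigGemmaTiny
//   #include "gemma/gemma-inl.h"
//
// foreach_target.h recompiles gemma-inl.h once per SIMD target; on the final
// pass (HWY_ONCE) gemma-inl.h emits the GenerateSingle / GenerateBatch
// overloads for this config, which the slimmed-down gemma.cc dispatches to.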
-template -struct Activations { - static constexpr size_t kModelDim = TConfig::kModelDim; - static constexpr size_t kQKVDim = TConfig::kQKVDim; - static constexpr size_t kHeads = TConfig::kHeads; - static constexpr size_t kKVHeads = TConfig::kKVHeads; - static constexpr bool kIsMHA = kHeads == kKVHeads; // Multi-Head Attention - // Stride between subsequent queries. Each of Q, K, V are of length kQKVDim, - // but for MHA we store them as Q,K,V, Q,K,V, .. instead of Q..Q, K..K, V..V. - static constexpr size_t kQStride = kQKVDim * (kIsMHA ? 3 : 1); - - std::array x; // input - std::array pre_att_rms_out; - std::array - q; // query vector - std::array - att; // attention vector - std::array - att_out; // attention output - std::array - att_post1; // attention output after linear transformation, per head - std::array - att_post2; // accumulation of attention outputs over heads - std::array - bf_pre_ffw_rms_out; - std::array - ffw_hidden; - - // For FFW MatMul. - std::array C1; - std::array C2; - - // bf_ version can't be used until GeluMulToBF16 issue in FFW() is resolved. - // std::array - // bf_ffw_hidden; - std::array ffw_out; - std::array logits; - - // For bf16/f32 vectors * bf16 matrix: faster to unpack once beforehand, into - // per-thread storage. - std::array even_odd; - - // Griffin layer internal activations - static constexpr size_t kGriffinDim = - TConfig::kGriffinLayers > 0 ? kModelDim : 0; - std::array griffin_x; - std::array griffin_y; - std::array griffin_gate_x; - std::array griffin_multiplier; -}; - -template -struct AllocateState { - void operator()(ByteStorageT& prefill, ByteStorageT& decode) const { - // When batching queries, the prefill batch size is reduced by a factor - // of kBatchedQueryBatchSize - prefill = - AllocateSizeof>(); - decode = AllocateSizeof< - Activations>(); - } -}; - -template -Activations& GetActivations(const ByteStorageT& state_u8) { - return *reinterpret_cast*>(state_u8.get()); -} - -// Placeholder for internal test2, do not remove - -} // namespace gcpp -#endif // GEMMA_ONCE - -// SIMD code, compiled once per target. -HWY_BEFORE_NAMESPACE(); -namespace gcpp { -namespace HWY_NAMESPACE { -namespace { - -template -HWY_NOINLINE void GriffinRecurrent( - size_t batch_start, size_t num_tokens, size_t num_queries, size_t layer, - Activations& activations, - const CompressedLayer* layer_weights, - const std::vector& kv_caches, hwy::ThreadPool& pool) { - PROFILER_ZONE("Gen.Griffin"); - static_assert(kQueryBatchSize == 1, - "Griffin does not support batched queries."); - HWY_DASSERT(num_queries == 1); // TODO: add batch query support for Griffin. - KVCache& kv_cache = *kv_caches[0]; - namespace hn = hwy::HWY_NAMESPACE; - using D = hn::ScalableTag; - HWY_DASSERT(num_tokens <= kBatchSize); - static constexpr size_t kModelDim = - gcpp::Activations::kModelDim; - static constexpr size_t kConv1dWidth = TConfig::kConv1dWidth; - static constexpr size_t kHeads = TConfig::kHeads; - - // X / Y linear layers. 
- for (size_t batch_idx = 0; batch_idx < num_tokens; ++batch_idx) { - const size_t batch_offset = batch_idx * kModelDim; - float* HWY_RESTRICT y = activations.griffin_y.data() + batch_offset; - float* HWY_RESTRICT x = activations.griffin_x.data() + batch_offset; - TwoMatVecAdd( - layer_weights->griffin.linear_x_w, layer_weights->griffin.linear_y_w, 0, - activations.pre_att_rms_out.data() + batch_offset, - /*add0=*/layer_weights->griffin.linear_x_biases.data(), - /*add1=*/layer_weights->griffin.linear_y_biases.data(), /*out0=*/x, - /*out1=*/y, pool); - Gelu(y, kModelDim); - } - - // Conv1D. - for (size_t batch_idx = 0; batch_idx < num_tokens; ++batch_idx) { - const size_t batch_offset = batch_idx * kModelDim; - const size_t pos = batch_start + batch_idx; - float* HWY_RESTRICT x = activations.griffin_x.data() + batch_offset; - HWY_FULL(float) df; - HWY_DASSERT(kModelDim % hn::Lanes(df) == 0); - const size_t layer_offset = layer * kModelDim * (kConv1dWidth - 1); - - // cache[i] = input at time t-i. - float* HWY_RESTRICT cache[HWY_MAX(kConv1dWidth, 1)]; - cache[0] = x; - for (size_t i = 1; i < kConv1dWidth; i++) { - cache[i] = - kv_cache.conv1d_cache.get() + layer_offset + - ((pos + kConv1dWidth - 1 - i) % (kConv1dWidth - 1)) * kModelDim; - } - for (size_t i = 0; i < kModelDim; i += hn::Lanes(df)) { - auto xv = hn::Load(df, x + i); - auto accum0 = - hn::Load(df, layer_weights->griffin.conv_biases.data() + i); - auto accum1 = hn::Zero(df); - static_assert(kConv1dWidth % 2 == 0, "Conv width must be even"); - for (size_t l = 0; 2 * l < kConv1dWidth; l++) { - auto wv0 = hn::Load(df, layer_weights->griffin.conv_w.data() + - (kConv1dWidth - 1 - 2 * l) * kModelDim + i); - auto wv1 = hn::Load(df, layer_weights->griffin.conv_w.data() + - (kConv1dWidth - 2 - 2 * l) * kModelDim + i); - accum0 = hn::MulAdd(wv0, hn::Load(df, cache[l * 2] + i), accum0); - accum1 = hn::MulAdd(wv1, hn::Load(df, cache[l * 2 + 1] + i), accum1); - } - hn::Store(hn::Add(accum0, accum1), df, x + i); - hn::Store(xv, df, cache[HWY_MAX(kConv1dWidth, 1) - 1] + i); - } - } - - // RGLRU - for (size_t batch_idx = 0; batch_idx < num_tokens; ++batch_idx) { - const size_t batch_offset = batch_idx * kModelDim; - const size_t pos = batch_start + batch_idx; - float* HWY_RESTRICT y = activations.griffin_y.data() + batch_offset; - float* HWY_RESTRICT x = activations.griffin_x.data() + batch_offset; - float* HWY_RESTRICT gate_x = - activations.griffin_gate_x.data() + batch_offset; - float* HWY_RESTRICT a = - activations.griffin_multiplier.data() + batch_offset; - float* HWY_RESTRICT rnn_state = - kv_cache.rglru_cache.get() + layer * kModelDim; - - pool.Run(0, kHeads, [&](const uint64_t head, size_t /*thread*/) HWY_ATTR { - constexpr size_t kHeadDim = kModelDim / kHeads; - constexpr size_t kMatrixSize = kHeadDim * kHeadDim; - size_t head_offset = head * kHeadDim; - TwoOfsMatVecAddLoop( - layer_weights->griffin.gate_w, kMatrixSize * head, - kMatrixSize * (kHeads + head), x + head_offset, - /*add0=*/layer_weights->griffin.gate_biases.data() + head_offset, - /*add1=*/layer_weights->griffin.gate_biases.data() + kModelDim + - head_offset, - /*out0=*/gate_x + head_offset, /*out1=*/a + head_offset); - Sigmoid(gate_x + head_offset, kHeadDim); - Sigmoid(a + head_offset, kHeadDim); - const auto fn_mul = [](D d, hn::Vec x, hn::Vec gate_x) - HWY_ATTR { return hn::Mul(x, gate_x); }; - hn::Transform1(D(), a + head_offset, kHeadDim, - layer_weights->griffin.a.data() + head_offset, fn_mul); - hn::Transform1(D(), x + head_offset, kHeadDim, gate_x + head_offset, - 
fn_mul); - // RNN scan - HWY_FULL(float) df; - HWY_DASSERT(kHeadDim % hn::Lanes(df) == 0); - for (size_t i = 0; i < kHeadDim; i += hn::Lanes(df)) { - auto log_a = hn::Load(df, a + head_offset + i); - auto gated_x = hn::Load(df, x + head_offset + i); - auto rnn = hn::Load(df, rnn_state + head_offset + i); - auto a = hn::Exp(df, log_a); - auto x_multiplier = hn::Sqrt(hn::NegMulAdd(a, a, hn::Set(df, 1.0f))); - if (pos == 0) { - x_multiplier = hn::Set(df, 1.0f); - } - auto new_x = hn::MulAdd(x_multiplier, gated_x, hn::Mul(a, rnn)); - hn::Store(new_x, df, rnn_state + head_offset + i); - - // Join branches. - auto yv = hn::Load(df, y + head_offset + i); - auto pre_out = hn::Mul(yv, new_x); - hn::Store(pre_out, df, x + head_offset + i); - } - }); - } - - // Final linear layer. - for (size_t batch_idx = 0; batch_idx < num_tokens; ++batch_idx) { - const size_t batch_offset = batch_idx * kModelDim; - float* HWY_RESTRICT x = activations.griffin_x.data() + batch_offset; - float* out_ptr = activations.att_post2.data() + batch_idx * kModelDim; - MatVecAdd( - layer_weights->griffin.linear_out_w, 0, x, - layer_weights->griffin.linear_out_biases.data(), - activations.even_odd.data(), out_ptr, pool); - } -} - -template -HWY_NOINLINE void Attention( - size_t batch_and_query_start, size_t num_tokens, size_t num_queries, - size_t layer, - Activations& activations, - const CompressedLayer* layer_weights, - const std::vector& kv_caches, - hwy::ThreadPool& pool) { - PROFILER_ZONE("Gen.Attention"); - HWY_DASSERT(num_tokens <= kBatchSize); - HWY_DASSERT(num_queries <= kQueryBatchSize); - HWY_DASSERT(batch_and_query_start % num_queries == 0); - using TActivations = Activations; - constexpr size_t kQKVDim = TActivations::kQKVDim; - constexpr size_t kQStride = TActivations::kQStride; - constexpr size_t kCachePosSize = CachePosSize()(); - constexpr size_t kCacheLayerSize = CacheLayerSize()(); - constexpr size_t kModelDim = TActivations::kModelDim; - constexpr size_t kHeads = TConfig::kHeads; - constexpr size_t kKVHeads = TConfig::kKVHeads; - constexpr size_t kSeqLen = TConfig::kSeqLen; - GEMMA_CONSTEXPR_SQRT const float kQueryScale = - 1.0f / Sqrt(static_cast(kQKVDim)); - constexpr bool kIsMHA = TActivations::kIsMHA; // Multi-Head Attention - const size_t batch_start = batch_and_query_start / num_queries; - const size_t num_tokens_and_queries = num_tokens * num_queries; - - // If MHA, this also computes KV, which we copy to the KV cache below. - static_assert(!kIsMHA || TConfig::kInterleaveQKV); // MHA => interleaved - MatMul_4x4_Batch( - num_tokens_and_queries, activations.pre_att_rms_out.data(), - layer_weights->qkv_einsum_w.data(), activations.q.data(), pool); - - for (size_t batch_and_query_idx = 0; - batch_and_query_idx < num_tokens_and_queries; ++batch_and_query_idx) { - const float* x = activations.pre_att_rms_out.data() + batch_and_query_idx - * kModelDim; - const size_t query_idx = batch_and_query_idx % num_queries; - const size_t batch_idx = batch_and_query_idx / num_queries; - KVCache& kv_cache = *kv_caches[query_idx]; - // QKV projections: - if constexpr (!kIsMHA) { - const size_t pos = batch_start + batch_idx; - const size_t cache_pos = pos % (kSeqLen + kPrefillBatchSize); - const size_t kv_offset = - cache_pos * kCachePosSize + layer * kCacheLayerSize; - float* HWY_RESTRICT kv = kv_cache.kv_cache.get() + kv_offset; - // TODO: requires MatMul support for offsets. 
- MatVec( - layer_weights->qkv_einsum_w, kHeads * kQKVDim * kModelDim, x, - activations.even_odd.data(), kv, pool); - } - } - - // Positional encodings for kv: - pool.Run( - 0, kKVHeads * num_tokens_and_queries, - [&](uint64_t task, size_t thread) HWY_ATTR { - const size_t head = task % kKVHeads; - const size_t batch_and_query_idx = task / kKVHeads; - const size_t query_idx = batch_and_query_idx % num_queries; - const size_t batch_idx = batch_and_query_idx / num_queries; - const size_t pos = batch_start + batch_idx; - const size_t cache_pos = pos % (kSeqLen + kPrefillBatchSize); - const size_t kv_offset = cache_pos * kCachePosSize + - layer * kCacheLayerSize + head * kQKVDim * 2; - KVCache& kv_cache = *kv_caches[query_idx]; - float* HWY_RESTRICT kv = kv_cache.kv_cache.get() + kv_offset; - if constexpr (kIsMHA) { - // For MHA, copy kv into the KV cache from scratch space (see above). - const float* HWY_RESTRICT q = - activations.q.data() + (batch_and_query_idx * kHeads - + head) * kQStride; - // Skip past the Q part of `q`, and copy KV to `kv`. - memcpy(kv, q + kQKVDim, 2 * kQKVDim * sizeof(float)); - } - Rope(kv, TConfig::kUseHalfRope ? kQKVDim / 2 : kQKVDim, pos); - }); - - static_assert((kHeads % kKVHeads) == 0, - "query heads must be a multiple of key-value heads"); - constexpr size_t kGroupHeads = kHeads / kKVHeads; - pool.Run(0, kHeads * num_tokens_and_queries, - [&](uint64_t task, size_t thread) HWY_ATTR { - const size_t head = task % kHeads; - const size_t batch_and_query_idx = task / kHeads; - const size_t query_idx = batch_and_query_idx % num_queries; - const size_t batch_idx = batch_and_query_idx / num_queries; - const size_t head_offset = (head / kGroupHeads) * kQKVDim * 2; - KVCache& kv_cache = *kv_caches[query_idx]; - float* HWY_RESTRICT q = - activations.q.data() + (batch_and_query_idx * kHeads + head) * kQStride; - - const size_t pos = batch_start + batch_idx; - // Calculate scores - float* HWY_RESTRICT head_att = - activations.att.data() + head * kSeqLen - + batch_and_query_idx * kHeads * kSeqLen; - - Rope(q, TConfig::kUseHalfRope ? 
kQKVDim / 2 : kQKVDim, pos); - MulByConst(kQueryScale, q, kQKVDim); - - // Compute Q dot K scores - const size_t start_pos = - pos - std::min(TConfig::kAttentionWindowSizes[layer] - 1, pos); - for (size_t pos2 = start_pos; pos2 <= pos; ++pos2) { - const size_t cache_pos = pos2 % (kSeqLen + kPrefillBatchSize); - const size_t kv_offset = - cache_pos * kCachePosSize + layer * kCacheLayerSize + head_offset; - const float* HWY_RESTRICT k2 = kv_cache.kv_cache.get() + kv_offset; - const float score = Dot(q, k2, kQKVDim); - head_att[pos2 % kSeqLen] = score; - } - const size_t head_att_len = std::min(pos + 1, kSeqLen); - if constexpr (TConfig::kAttCap > 0.0f) { - LogitsSoftCap(TConfig::kAttCap, head_att, head_att_len); - } - Softmax(head_att, head_att_len); - - // Weighted summation - float* HWY_RESTRICT att_out = activations.att_out.data() + head * kQKVDim + - batch_and_query_idx * kHeads * kQKVDim; - hwy::ZeroBytes(att_out, kQKVDim * sizeof(*att_out)); - for (size_t pos2 = start_pos; pos2 <= pos; ++pos2) { - const size_t cache_pos = pos2 % (kSeqLen + kPrefillBatchSize); - const size_t kv_offset = - cache_pos * kCachePosSize + layer * kCacheLayerSize + head_offset; - float* HWY_RESTRICT v2 = kv_cache.kv_cache.get() + kv_offset + kQKVDim; - MulByConstAndAdd(head_att[pos2 % kSeqLen], v2, att_out, kQKVDim); - } - }); - - for (size_t batch_and_query_idx = 0; - batch_and_query_idx < num_tokens_and_queries; ++batch_and_query_idx) { - // TODO(szabadka) Use a single MatVecAdd like in GriffinRecurrent() after - // rearranging the weights. - float* HWY_RESTRICT att_out = - activations.att_out.data() + batch_and_query_idx * kHeads * kQKVDim; - float* HWY_RESTRICT layer_out = - activations.att_post2.data() + batch_and_query_idx * kModelDim; - MatVecT( - layer_weights->attn_vec_einsum_w, 0, att_out, - layer_weights->attention_output_biases.data(), - activations.even_odd.data(), layer_out, pool); - for (size_t head = 1; head < kHeads; ++head) { - // TODO(patrickms): Check this calculation - float* HWY_RESTRICT head_out = - activations.att_post1.data() + - head * kBatchSize * kQueryBatchSize * kModelDim; - // TODO: requires MatMul support for offsets. - MatVec( - layer_weights->attn_vec_einsum_w, head * kModelDim * kQKVDim, - att_out + head * kQKVDim, - activations.even_odd.data(), head_out, pool); - AddFrom(head_out, layer_out, kModelDim); - } - } -} - -template -HWY_NOINLINE void FFW(Activations& activations, - size_t num_tokens, - const CompressedLayer* layer_weights, - hwy::ThreadPool& pool) { - HWY_DASSERT(num_tokens <= kBatchSize); - constexpr size_t kModelDim = TConfig::kModelDim; - constexpr size_t kFFHiddenDim = TConfig::kFFHiddenDim; - float* HWY_RESTRICT even_odd = activations.even_odd.data(); - - // TODO: MatMul does not yet support adding another matrix to the result. - if constexpr (!TConfig::kFFBiases) { - PROFILER_ZONE("Gen.FFW.GatedGELU"); - - // MatMul expects col-major B, which is what we have: kModelDim consecutive - // elements in memory, repeated kFFHiddenDim times. - const auto b1 = layer_weights->gating_einsum_w.data(); - constexpr size_t kColsA = kModelDim; - constexpr size_t kColsB = kFFHiddenDim; - const auto b2 = b1 + kColsA * kColsB; - auto A = activations.bf_pre_ffw_rms_out.data(); - // Will go through GELU. - MatMul_4x4_Batch(num_tokens, A, b1, activations.C1.data(), - pool); - // What to multiply by. - MatMul_4x4_Batch(num_tokens, A, b2, activations.C2.data(), - pool); - - // Gelu and multiply by gate. 
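// Sketch for reference (not part of the patch): a scalar version of the gated
// GELU applied by the hn::Transform1 call just below, i.e. C1[i] is replaced
// by Gelu(C1[i]) * C2[i] before the final MatMul with linear_w. GeluApprox is
// the usual tanh approximation and is illustrative only; the real code uses
// the vectorized Gelu from ops.h.
#include <stddef.h>
#include <cmath>

float GeluApprox(float x) {
  const float kSqrt2OverPi = 0.7978845608f;
  return 0.5f * x *
         (1.0f + std::tanh(kSqrt2OverPi * (x + 0.044715f * x * x * x)));
}

void GatedGelu(float* c1, const float* c2, size_t n) {
  for (size_t i = 0; i < n; ++i) {
    c1[i] = GeluApprox(c1[i]) * c2[i];  // gate goes through GELU, scales C2
  }
}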
- namespace hn = hwy::HWY_NAMESPACE; - using DF = hn::ScalableTag; - using VF = hn::Vec; - hn::Transform1(DF(), activations.C1.data(), kFFHiddenDim * num_tokens, - activations.C2.data(), [](DF df, VF v, VF mul) HWY_ATTR { - return hn::Mul(mul, Gelu(df, v)); - }); - - MatMul_4x4_Batch(num_tokens, activations.C1.data(), - layer_weights->linear_w.data(), - activations.ffw_out.data(), pool); - } else { // TConfig::kFFBiases == true - for (size_t batch_idx = 0; batch_idx < num_tokens; ++batch_idx) { - const size_t hidden_offset = batch_idx * kFFHiddenDim * 2; - const hwy::bfloat16_t* HWY_RESTRICT vec = - activations.bf_pre_ffw_rms_out.data() + batch_idx * kModelDim; - float* HWY_RESTRICT out = activations.ffw_hidden.data() + hidden_offset; - float* HWY_RESTRICT out_mul = out + kFFHiddenDim; - - PROFILER_ZONE("Gen.FFW.GatedGELU"); - // Same matrix, first and second half of rows. Could fuse into one MatVec. - MatVecT( - layer_weights->gating_einsum_w, kFFHiddenDim * kModelDim, vec, - layer_weights->ffw_gating_biases.data() + kFFHiddenDim, even_odd, - out_mul, pool); - // Gate, will go through the nonlinearity. - MatVecT( - layer_weights->gating_einsum_w, 0, vec, - layer_weights->ffw_gating_biases.data(), even_odd, out, pool); - - namespace hn = hwy::HWY_NAMESPACE; - using DF = hn::ScalableTag; - using VF = hn::Vec; - hn::Transform1(DF(), out, kFFHiddenDim, out_mul, - [](DF df, VF v, VF mul) - HWY_ATTR { return hn::Mul(mul, Gelu(df, v)); }); - - MatVecT( - layer_weights->linear_w, 0, - activations.ffw_hidden.data() + hidden_offset, - layer_weights->ffw_output_biases.data(), even_odd, - activations.ffw_out.data() + batch_idx * kModelDim, pool); - } - } -} - -template -HWY_NOINLINE void EmbedToken(int token, size_t token_idx, size_t pos, - const CompressedWeights& weights, - Activations& activations) { - constexpr size_t kModelDim = TConfig::kModelDim; - GEMMA_CONSTEXPR_EMBSCALING const float kEmbScaling = - EmbeddingScaling(); - HWY_DASSERT(token >= 0); - HWY_DASSERT(token < TConfig::kVocabSize); - Decompress(weights.embedder_input_embedding, token * kModelDim, - activations.x.data() + token_idx * kModelDim, kModelDim); - MulByConst(kEmbScaling, activations.x.data() + token_idx * kModelDim, - kModelDim); - if constexpr (TConfig::kAbsolutePE) { - AddAbsolutePositionalEmbeddings( - activations.x.data() + token_idx * kModelDim, kModelDim, - pos + token_idx); - }; -} - -template -HWY_NOINLINE void TransformerLayer( - size_t num_tokens, size_t num_queries, size_t pos, size_t layer, - const CompressedLayer* layer_weights, - Activations& activations, - const std::vector& kv_caches, hwy::ThreadPool& pool) { - constexpr size_t kModelDim = TConfig::kModelDim; - const size_t num_tokens_and_queries = num_tokens * num_queries; - auto type = TConfig::kLayerConfig[layer]; - size_t layer_of_type = - NumLayersOfTypeBefore(TConfig::kLayerConfig, type, layer); - RMSNormBatched( - num_tokens_and_queries, activations.x.data(), - layer_weights->pre_attention_norm_scale.data(), - activations.pre_att_rms_out.data(), kModelDim); - if (type == LayerAttentionType::kGemma) { - Attention( - pos, num_tokens, num_queries, layer_of_type, activations, - layer_weights, kv_caches, pool); - } else { - // This Griffin layers should never exist unless the model is a Griffin - // model. This conditional prevents the compiler from generating code for - // this branch when building a non-Griffin model, since we have static - // asserts about the query batch size for Griffin layers. 
- if constexpr (TConfig::kGriffinLayers > 0) { - GriffinRecurrent( - pos, num_tokens, num_queries, layer_of_type, activations, - layer_weights, kv_caches, pool); - } - } - if (TConfig::kPostNormScale) { - RMSNormInplaceBatched( - num_tokens_and_queries, - layer_weights->post_attention_norm_scale.data(), - activations.att_post2.data(), kModelDim); - } - AddFromBatched(num_tokens_and_queries, - activations.att_post2.data(), - activations.x.data(), kModelDim); - RMSNormBatched( - num_tokens_and_queries, activations.x.data(), - layer_weights->pre_ffw_norm_scale.data(), - activations.bf_pre_ffw_rms_out.data(), kModelDim); - FFW( - activations, num_tokens_and_queries, layer_weights, pool); - if (TConfig::kPostNormScale) { - RMSNormInplaceBatched( - num_tokens_and_queries, layer_weights->post_ffw_norm_scale.data(), - activations.ffw_out.data(), kModelDim); - } - AddFromBatched( - num_tokens_and_queries, activations.ffw_out.data(), - activations.x.data(), kModelDim); -} - -template -HWY_NOINLINE void Prefill( - const int* tokens, size_t num_tokens, size_t num_queries, size_t pos, - const CompressedWeights& weights, - Activations& activations, - const std::vector& kv_caches, hwy::ThreadPool& pool) { - HWY_DASSERT(num_queries <= kQueryBatchSize); - const size_t minibatch_size = std::min(num_tokens, kBatchSize); - PROFILER_ZONE("Gen.Prefill\\Att\\FFW"); - // TODO(patrickms): Try to hoist pool.Run out of the loop. - for (size_t i = 0; i < num_tokens; i += minibatch_size) { - const size_t offset = i * num_queries; - const size_t current_token_count = std::min( - minibatch_size, num_tokens - i); - pool.Run(0, current_token_count * num_queries, - [&](const uint64_t token_idx, size_t /*thread*/) HWY_ATTR { - EmbedToken( - tokens[token_idx + offset], token_idx, pos + offset, - weights, activations); - }); - - for (size_t layer = 0; layer < TConfig::kLayers; ++layer) { - const auto* layer_weights = weights.GetLayer(layer); - TransformerLayer( - current_token_count, num_queries, pos + offset , layer, layer_weights, - activations, kv_caches, pool); - } - } -} - -// Compute the transformer for a batch of input tokens. During generation, -// we usually have num_tokens == 1 (and also kBatchSize == 1). -template -HWY_NOINLINE void Transformer( - const int* tokens, size_t num_tokens, size_t num_queries, size_t pos, - const CompressedWeights& weights, - Activations& activations, - const std::vector& kv_caches, - hwy::ThreadPool& pool, - const LayersOutputFunc& layers_output) { - HWY_ASSERT(num_tokens <= kBatchSize); - const size_t num_tokens_and_queries = num_tokens * num_queries; - if (layers_output) { - for (size_t token_idx = 0; token_idx < num_tokens_and_queries; - ++token_idx) { - float token_f = tokens[token_idx]; - layers_output(pos + token_idx, "Tokens", &token_f, 1); - } - } - constexpr size_t kModelDim = TConfig::kModelDim; - for (size_t token_idx = 0; token_idx < num_tokens_and_queries; ++token_idx) { - EmbedToken( - tokens[token_idx], token_idx, pos, weights, activations); - } - - for (size_t layer = 0; layer < TConfig::kLayers; ++layer) { - const CompressedLayer* layer_weights = weights.GetLayer(layer); - TransformerLayer( - num_tokens, num_queries, pos, layer, layer_weights, - activations, kv_caches, pool); - - if (layers_output) { - const std::string block_name = "blocks." 
+ std::to_string(layer); - for (size_t token_idx = 0; token_idx < num_tokens_and_queries; - ++token_idx) { - layers_output(pos + token_idx, block_name, - activations.x.data() + token_idx * kModelDim, kModelDim); - } - } - } - - RMSNormInplaceBatched( - num_tokens * num_queries, weights.final_norm_scale.data(), - activations.x.data(), kModelDim); - if (layers_output) { - for (size_t token_idx = 0; token_idx < num_tokens_and_queries; - ++token_idx) { - layers_output(pos + token_idx, "final_norm", - activations.x.data() + token_idx * kModelDim, kModelDim); - } - } -} - -template -void RangeChecks(size_t& max_tokens, size_t& max_generated_tokens, - size_t& prompt_size) { - if (!TConfig::kUseLocalAttention) { - if (max_tokens > TConfig::kSeqLen) { - fprintf(stderr, "WARNING: max_tokens %zu > kSeqLen %d, truncating.\n", - max_tokens, TConfig::kSeqLen); - max_tokens = static_cast(TConfig::kSeqLen); - } - } - - if (max_generated_tokens > max_tokens) { - fprintf(stderr, - "WARNING: max_generated_tokens %zu > max_tokens %zu, truncating.\n", - max_generated_tokens, max_tokens); - max_generated_tokens = max_tokens - 1; - } - - if (!TConfig::kUseLocalAttention) { - if (prompt_size + max_generated_tokens > max_tokens) { - fprintf(stderr, - "WARNING: prompt_size %zu + max_generated_tokens %zu > " - "max_tokens %zu, truncating to ", - prompt_size, max_generated_tokens, max_tokens); - prompt_size = std::min(prompt_size, max_tokens - max_generated_tokens); - fprintf(stderr, "%zu\n", prompt_size); - } - } - - HWY_ASSERT(prompt_size > 0); -} - -} // namespace - -// TODO(janwas): move into RuntimeConfig -bool StreamToken(size_t query_idx, size_t pos, int token, float prob, - const RuntimeConfig& runtime_config) { - if (runtime_config.batch_stream_token) { - return runtime_config.batch_stream_token(query_idx, pos, token, prob); - } - return runtime_config.stream_token(token, prob); -} - -// Placeholder for internal test3, do not remove - -template -void GenerateT(const ByteStorageT& weights_u8, const ByteStorageT& prefill_u8, - const ByteStorageT& decode_u8, - const RuntimeConfig& runtime_config, - const hwy::Span>& prompts, size_t pos, - const size_t query_index_offset, - const std::vector& kv_caches, hwy::ThreadPool& pool, - TimingInfo& timing_info) { - constexpr size_t kAdjustedPrefillBatchSize = - std::max((size_t)1, kPrefillBatchSize / kQueryBatchSize); - static_assert(kAdjustedPrefillBatchSize >= kMinAdjustedPrefillBatchSize); - const size_t num_queries = prompts.size(); - HWY_DASSERT(num_queries <= kQueryBatchSize); - pos *= num_queries; // position in (num_queries) interleaved token sequence. 
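// Sketch for reference (not part of the patch): the interleaved prompt layout
// that the loop a few lines below constructs. Token i of query j ends up at
// index i * num_queries + j, and shorter prompts are padded with 0 up to the
// longest prompt, e.g. {a0,a1,a2} and {b0,b1} become {a0,b0,a1,b1,a2,0}.
#include <stddef.h>
#include <vector>

std::vector<int> InterleavePrompts(
    const std::vector<std::vector<int>>& prompts) {
  size_t max_size = 0;
  for (const auto& p : prompts) {
    if (p.size() > max_size) max_size = p.size();
  }
  std::vector<int> interleaved;
  interleaved.reserve(max_size * prompts.size());
  for (size_t i = 0; i < max_size; ++i) {
    for (const auto& p : prompts) {
      interleaved.push_back(i < p.size() ? p[i] : 0);  // 0-pad short prompts
    }
  }
  return interleaved;
}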
- const CompressedWeights& weights = - *reinterpret_cast*>(weights_u8.get()); - auto& prefill_activations = - GetActivations(prefill_u8); - auto& activations = GetActivations(decode_u8); - - size_t min_prompt_size = (size_t)-1; - size_t max_prompt_size = 0; - for (int i=0; i < prompts.size(); ++i) { - min_prompt_size = std::min(min_prompt_size, prompts[i].size()); - max_prompt_size = std::max(max_prompt_size, prompts[i].size()); - } - - std::vector prompt; - prompt.reserve(max_prompt_size * prompts.size()); - for (int i = 0; i < max_prompt_size; ++i) { - for (int j=0; j < prompts.size(); ++j) { - if (i < prompts[j].size()) { - prompt.push_back(prompts[j][i]); - } else { - prompt.push_back(0); - } - } - } - - constexpr size_t kVocabSize = TConfig::kVocabSize; - - size_t max_tokens = runtime_config.max_tokens; - size_t max_generated_tokens = runtime_config.max_generated_tokens; - RangeChecks(max_tokens, max_generated_tokens, max_prompt_size); - if (pos >= max_tokens) { - fprintf(stderr, "Warning: pos %zu >= max_tokens %zu, aborting.\n", pos, - max_tokens); - return; - } - - // If no sample_func is provided, we use top-k sampling. - const SampleFunc sample_token = - runtime_config.sample_func - ? runtime_config.sample_func - : [&](const float* logits, size_t vocab_size) -> int { - return SampleTopK(logits, vocab_size, *runtime_config.gen, - runtime_config.temperature, - runtime_config.accept_token); - }; - - std::vector reached_eos(num_queries); - std::fill(reached_eos.begin(), reached_eos.end(), false); - - // pos indexes the KV cache. In the first turn of a chat, pos = 0. - // - // After the first turn, pos gets passed in with > 0 corresponding to the - // current token position in the KV cache. - // - // pos_offset keeps track of the relative position within the turn, starting - // at 0 each turn. During prefill, pos_offset corresponds to the index into - // the prompt vector. - // - // In single-turn (non-chat) usage, pos and pos_offset start at 0 and are - // always equal. - size_t pos_offset = 0; // offset relative to pos - // Used to keep track of how many tokens are processed per prompt, - // so that we know when to start generating tokens. - size_t single_prompt_pos_offset = 0; - const double prefill_start = hwy::platform::Now(); - - // Prefill stops before prompt_size - 1 since the last prompt token is the - // first input token for generation. 
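// Sketch for reference (not part of the patch): the pos / pos_offset
// bookkeeping described in the comment block above, for a single query.
// pos indexes the KV cache and persists across turns; pos_offset restarts at
// 0 for each turn's prompt, and prefill covers all but the last prompt token.
#include <stddef.h>
#include <stdio.h>

// Assumes prompt_size >= 2.
void IllustrateTurn(size_t pos, size_t prompt_size) {
  const size_t prefill_tokens = prompt_size - 1;  // last token starts generation
  printf("prefill pos %zu..%zu, generation begins at pos %zu\n",
         pos, pos + prefill_tokens - 1, pos + prefill_tokens);
}
// IllustrateTurn(0, 5):  prefill pos 0..3, generation begins at pos 4.
// IllustrateTurn(21, 3): prefill pos 21..22, generation begins at pos 23.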
- while (single_prompt_pos_offset < min_prompt_size - 1) { - const size_t batch_size = std::min( - kPrefillBatchSize, min_prompt_size - 1 - single_prompt_pos_offset); - const size_t batch_and_query_size = batch_size * num_queries; - HWY_DASSERT(batch_size <= kPrefillBatchSize); - HWY_DASSERT(single_prompt_pos_offset + batch_size <= min_prompt_size - 1); - HWY_DASSERT(pos_offset + batch_size <= (min_prompt_size - 1) * num_queries); - const int* batch_tokens = prompt.data() + pos_offset; - Prefill( - batch_tokens, batch_size, num_queries, pos, weights, - prefill_activations, kv_caches, pool); - for (size_t idx = 0; idx < batch_size; ++idx) { - bool all_tokens_eos = true; - for (size_t query_idx = 0; query_idx < num_queries; ++query_idx) { - if (reached_eos[query_idx]) continue; - if (StreamToken( - query_idx + query_index_offset, single_prompt_pos_offset, - batch_tokens[idx * num_queries + query_idx], 0.0f, - runtime_config)) { - all_tokens_eos = false; - } else { - reached_eos[query_idx] = true; - } - } - if (all_tokens_eos) { - return; - } - } - pos += batch_and_query_size; - pos_offset += batch_and_query_size; - single_prompt_pos_offset += batch_size; - } - - timing_info.prefill_tok_sec = - static_cast(pos_offset) / (hwy::platform::Now() - prefill_start); - - // Start generation. - const double gen_start = hwy::platform::Now(); - HWY_DASSERT(single_prompt_pos_offset == min_prompt_size - 1); - size_t pos_gen_start = pos_offset; - int token = prompt.at(pos_offset); - std::vector::const_iterator first = prompt.begin() + pos_offset; - std::vector::const_iterator last = first + num_queries; - std::vector gen_tokens(first, last); - // The loop below is not yet prepared for decode batch size > 1. - HWY_ASSERT(kDecodeBatchSize == 1); - bool all_tokens_eos = true; - for (size_t i=0; i < num_queries; ++i) { - if (reached_eos[i]) continue; - if (StreamToken(i + query_index_offset, - single_prompt_pos_offset, gen_tokens[i], 0.0f, - runtime_config)) { - all_tokens_eos = false; - } else { - reached_eos[i] = true; - } - } - if (all_tokens_eos) { - return; - } - for (size_t generate_pos = 0; - generate_pos < max_tokens && generate_pos < max_generated_tokens; - ++single_prompt_pos_offset, ++generate_pos) { - Transformer( - gen_tokens.data(), kDecodeBatchSize, num_queries, pos, weights, - activations, kv_caches, pool, runtime_config.layers_output); - float token_logit = 0.0f; - // The condition below is always true if we are doing Prefill above. - // We keep it here for clarity so that the code is correct even if Prefill - // is disabled. - bool all_tokens_eos = true; - float* x = activations.x.data(); - float* logits = activations.logits.data(); - for (size_t i = 0; i < num_queries; ++i, ++pos, ++pos_offset, - x += TConfig::kModelDim, logits += kVocabSize) { - const size_t prompt_size = prompts[i].size(); - const bool is_generating_phase = - (single_prompt_pos_offset >= prompt_size - 1); - if (is_generating_phase) { - PROFILER_ZONE("Gen.Embedding"); - // Compute logits from last layer activations. - MatVec( - weights.embedder_input_embedding, 0, x, activations.even_odd.data(), - logits, pool); - if constexpr (TConfig::kFinalCap > 0.0f) { - LogitsSoftCap(TConfig::kFinalCap, activations.logits.data(), - kVocabSize); - } - // Barrier: must have all logits so we can subtract max. 
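// Sketch for reference (not part of the patch): a scalar version of the
// numerically stable softmax invoked right after this point; the "barrier"
// comment above refers to needing the maximum over all logits before any
// exponentiation. Illustrative only, not the vectorized ops.h implementation.
#include <stddef.h>
#include <cmath>

void SoftmaxScalar(float* logits, size_t n) {
  float max_val = logits[0];
  for (size_t i = 1; i < n; ++i) {
    if (logits[i] > max_val) max_val = logits[i];
  }
  float sum = 0.0f;
  for (size_t i = 0; i < n; ++i) {
    logits[i] = std::exp(logits[i] - max_val);  // subtract max for stability
    sum += logits[i];
  }
  const float inv_sum = 1.0f / sum;
  for (size_t i = 0; i < n; ++i) logits[i] *= inv_sum;
}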
- Softmax(logits, kVocabSize); - token = sample_token(logits, kVocabSize); - token_logit = logits[token]; - if (generate_pos == 0) { - timing_info.time_to_first_token = hwy::platform::Now() - gen_start; - } - } else { - // We would take this branch if we were not doing Prefill but would - // process the tokens of the prompt one at a time. - token = prompt.at(pos_offset); - token_logit = 0.0f; - } - - if (!reached_eos[i]) { - if (!StreamToken(i + query_index_offset, single_prompt_pos_offset+1, - token, token_logit, runtime_config)) { - token = runtime_config.eos_id; - } - if (token != runtime_config.eos_id) { - all_tokens_eos = false; - } else { - reached_eos[i] = true; - } - } - gen_tokens[i] = token; - } - if (all_tokens_eos) { - break; - } - } - timing_info.gen_tok_sec = static_cast(pos_offset - pos_gen_start) / - (hwy::platform::Now() - gen_start); -} - -template -void GenerateOneQueryT(const ByteStorageT& weights_u8, - const ByteStorageT& prefill_u8, - const ByteStorageT& decode_u8, - const RuntimeConfig& runtime_config, - const std::vector& prompt, size_t pos, - KVCache& kv_cache, hwy::ThreadPool& pool, - TimingInfo& timing_info) { - std::vector> prompt_vector = { - hwy::Span(const_cast(prompt.data()), prompt.size())}; - const hwy::Span> prompts( - prompt_vector.data(), prompt_vector.size()); - std::vector kv_caches = {&kv_cache}; - GenerateT(weights_u8, prefill_u8, decode_u8, - runtime_config, prompts, pos, 0, - kv_caches, pool, timing_info); -} - -template -void GenerateBatchT(const ByteStorageT& weights_u8, - const ByteStorageT& prefill_u8, - const ByteStorageT& decode_u8, - const RuntimeConfig& runtime_config, - const hwy::Span>& prompts, - size_t pos, const std::vector& kv_caches, - hwy::ThreadPool& pool, - TimingInfo& timing_info) { - // Disable query batching for Griffin models. - constexpr size_t kQueryBatchSize = - (TConfig::kGriffinLayers > 0) ? 1 : kBatchedQueryBatchSize; - for (size_t i = 0; i < prompts.size(); i += kQueryBatchSize) { - const size_t num_queries = std::min(prompts.size() - i, kQueryBatchSize); - const hwy::Span> current_prompts( - prompts.data() + i, num_queries); - GenerateT(weights_u8, prefill_u8, decode_u8, - runtime_config, current_prompts, - pos, i, kv_caches, pool, timing_info); - } -} - -} // namespace HWY_NAMESPACE -} // namespace gcpp -HWY_AFTER_NAMESPACE(); - -#if HWY_ONCE -namespace gcpp { - Gemma::Gemma(const Path& tokenizer_path, const Path& weights, const ModelInfo& info, hwy::ThreadPool& pool) : pool_(pool), tokenizer_(tokenizer_path), info_(info) { @@ -986,15 +59,61 @@ Gemma::~Gemma() { weights_u8_); } +// There are >100 instantiations of the inference code. To reduce compile time, +// we shard them across multiple translation units in instantiations/*.cc. +// This declares the functions defined there. We use overloading because +// explicit instantiations are still too slow to compile. 
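// Sketch for reference (not part of the patch): a toy model of the
// overload-per-config scheme declared below. ToyConfigA/ToyConfigB/ToyGenerate
// are hypothetical names; in the real code each overload is defined in its own
// instantiations/*.cc translation unit, and the config tag argument lets
// overload resolution pick the right definition without explicit template
// instantiation.
#include <stdio.h>

struct ToyConfigA {};
struct ToyConfigB {};

// Declared in a shared header; defined inline here only to keep the sketch
// self-contained (the real definitions live in separate .cc files).
void ToyGenerate(ToyConfigA, int token) { printf("config A: %d\n", token); }
void ToyGenerate(ToyConfigB, int token) { printf("config B: %d\n", token); }

int main() {
  ToyGenerate(ToyConfigA(), 42);  // overload resolution selects the A variant
  ToyGenerate(ToyConfigB(), 42);
  return 0;
}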
+#define GEMMA_DECLARE(CONFIGT, TWEIGHT) \ + extern void GenerateSingle( \ + CONFIGT, const ByteStorageT& weights_u8, \ + const ByteStorageT& prefill_u8, const ByteStorageT& decode_u8, \ + const RuntimeConfig& runtime_config, const std::vector& prompt, \ + size_t pos, KVCache& kv_cache, hwy::ThreadPool& pool, \ + TimingInfo& timing_info); \ + extern void GenerateBatch( \ + CONFIGT, const ByteStorageT& weights_u8, \ + const ByteStorageT& prefill_u8, const ByteStorageT& decode_u8, \ + const RuntimeConfig& runtime_config, \ + const hwy::Span>& prompts, size_t pos, \ + const std::vector& kv_caches, hwy::ThreadPool& pool, \ + TimingInfo& timing_info); +GEMMA_FOREACH_CONFIG_AND_WEIGHT(GEMMA_DECLARE); + +// Adapters to select from the above overloads via CallForModelAndWeight. +// TODO: gather all ByteStorageT into a type-erased model struct? +template +struct GenerateSingleT { + void operator()(const ByteStorageT& weights_u8, + const ByteStorageT& prefill_u8, const ByteStorageT& decode_u8, + const RuntimeConfig& runtime_config, + const std::vector& prompt, size_t pos, KVCache& kv_cache, + hwy::ThreadPool& pool, TimingInfo& timing_info) const { + GenerateSingle(TConfig(), weights_u8, prefill_u8, decode_u8, runtime_config, + prompt, pos, kv_cache, pool, timing_info); + } +}; + +template +struct GenerateBatchT { + void operator()(const ByteStorageT& weights_u8, + const ByteStorageT& prefill_u8, const ByteStorageT& decode_u8, + const RuntimeConfig& runtime_config, + const hwy::Span>& prompts, size_t pos, + const std::vector& kv_caches, hwy::ThreadPool& pool, + TimingInfo& timing_info) const { + GenerateBatch(TConfig(), weights_u8, prefill_u8, decode_u8, runtime_config, + prompts, pos, kv_caches, pool, timing_info); + } +}; + void Gemma::Generate(const RuntimeConfig& runtime_config, const std::vector& prompt, size_t start_pos, KVCache& kv_cache, TimingInfo& timing_info) { pool_.SetWaitMode(hwy::PoolWaitMode::kSpin); - GEMMA_EXPORT_AND_DISPATCH( - info_.model, info_.weight, GenerateOneQueryT, - (weights_u8_, prefill_u8_, decode_u8_, runtime_config, prompt, start_pos, - kv_cache, pool_, timing_info)); + CallForModelAndWeight( + info_.model, info_.weight, weights_u8_, prefill_u8_, decode_u8_, + runtime_config, prompt, start_pos, kv_cache, pool_, timing_info); pool_.SetWaitMode(hwy::PoolWaitMode::kBlock); } @@ -1006,40 +125,11 @@ void Gemma::GenerateBatch(const RuntimeConfig& runtime_config, TimingInfo& timing_info) { pool_.SetWaitMode(hwy::PoolWaitMode::kSpin); - GEMMA_EXPORT_AND_DISPATCH( - info_.model, info_.weight, GenerateBatchT, - (weights_u8_, prefill_u8_, decode_u8_, runtime_config, prompts, start_pos, - kv_caches, pool_, timing_info)); + CallForModelAndWeight( + info_.model, info_.weight, weights_u8_, prefill_u8_, decode_u8_, + runtime_config, prompts, start_pos, kv_caches, pool_, timing_info); pool_.SetWaitMode(hwy::PoolWaitMode::kBlock); } -// TODO(janwas): move to common.h. -void Wrap(const ModelInfo& info, size_t pos, std::string& prompt) { - - // Instruction-tuned models are trained to expect control tokens. - if (info.training == ModelTraining::GEMMA_IT) { - // Prepend "" if this is a multi-turn dialogue continuation. - const std::string start = (pos == 0) - ? 
"user\n" - : "\nuser\n"; - prompt = start + prompt + "\nmodel\n"; - } -} - -std::vector WrapAndTokenize(const GemmaTokenizer& tokenizer, - const ModelInfo& info, size_t pos, - std::string& prompt) { - Wrap(info, pos, prompt); - - std::vector tokens; - HWY_ASSERT(tokenizer.Encode(prompt, &tokens)); - // Both pre-trained and instruction-tuned require BOS as first token. - if (pos == 0) { - tokens.insert(tokens.begin(), BOS_ID); - } - return tokens; -} - } // namespace gcpp -#endif // HWY_ONCE diff --git a/gemma/gemma.h b/gemma/gemma.h index 73bc3dc..f0091e3 100644 --- a/gemma/gemma.h +++ b/gemma/gemma.h @@ -56,6 +56,13 @@ using LayersOutputFunc = std::function; struct RuntimeConfig { + bool StreamToken(size_t query_idx, size_t pos, int token, float prob) const { + if (batch_stream_token) { + return batch_stream_token(query_idx, pos, token, prob); + } + return stream_token(token, prob); + } + size_t max_tokens; size_t max_generated_tokens; float temperature; diff --git a/gemma/instantiations/27b_bf16.cc b/gemma/instantiations/27b_bf16.cc new file mode 100644 index 0000000..bd68679 --- /dev/null +++ b/gemma/instantiations/27b_bf16.cc @@ -0,0 +1,21 @@ +// Copyright 2024 Google LLC +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#undef HWY_TARGET_INCLUDE +#define HWY_TARGET_INCLUDE \ + "gemma/instantiations/27b_bf16.cc" +#include "hwy/foreach_target.h" // IWYU pragma: keep +#define GEMMA_CONFIG ConfigGemma27B +#include "gemma/gemma-inl.h" diff --git a/gemma/instantiations/27b_f32.cc b/gemma/instantiations/27b_f32.cc new file mode 100644 index 0000000..75a200d --- /dev/null +++ b/gemma/instantiations/27b_f32.cc @@ -0,0 +1,21 @@ +// Copyright 2024 Google LLC +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#undef HWY_TARGET_INCLUDE +#define HWY_TARGET_INCLUDE \ + "gemma/instantiations/27b_f32.cc" +#include "hwy/foreach_target.h" // IWYU pragma: keep +#define GEMMA_CONFIG ConfigGemma27B +#include "gemma/gemma-inl.h" diff --git a/gemma/instantiations/27b_sfp.cc b/gemma/instantiations/27b_sfp.cc new file mode 100644 index 0000000..5a268e4 --- /dev/null +++ b/gemma/instantiations/27b_sfp.cc @@ -0,0 +1,21 @@ +// Copyright 2024 Google LLC +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#undef HWY_TARGET_INCLUDE +#define HWY_TARGET_INCLUDE \ + "gemma/instantiations/27b_sfp.cc" +#include "hwy/foreach_target.h" // IWYU pragma: keep +#define GEMMA_CONFIG ConfigGemma27B +#include "gemma/gemma-inl.h" diff --git a/gemma/instantiations/2b_bf16.cc b/gemma/instantiations/2b_bf16.cc new file mode 100644 index 0000000..bd03de7 --- /dev/null +++ b/gemma/instantiations/2b_bf16.cc @@ -0,0 +1,21 @@ +// Copyright 2024 Google LLC +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#undef HWY_TARGET_INCLUDE +#define HWY_TARGET_INCLUDE \ + "gemma/instantiations/2b_bf16.cc" +#include "hwy/foreach_target.h" // IWYU pragma: keep +#define GEMMA_CONFIG ConfigGemma2B +#include "gemma/gemma-inl.h" diff --git a/gemma/instantiations/2b_f32.cc b/gemma/instantiations/2b_f32.cc new file mode 100644 index 0000000..fd49571 --- /dev/null +++ b/gemma/instantiations/2b_f32.cc @@ -0,0 +1,21 @@ +// Copyright 2024 Google LLC +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#undef HWY_TARGET_INCLUDE +#define HWY_TARGET_INCLUDE \ + "gemma/instantiations/2b_f32.cc" +#include "hwy/foreach_target.h" // IWYU pragma: keep +#define GEMMA_CONFIG ConfigGemma2B +#include "gemma/gemma-inl.h" diff --git a/gemma/instantiations/2b_sfp.cc b/gemma/instantiations/2b_sfp.cc new file mode 100644 index 0000000..c93f1d1 --- /dev/null +++ b/gemma/instantiations/2b_sfp.cc @@ -0,0 +1,21 @@ +// Copyright 2024 Google LLC +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#undef HWY_TARGET_INCLUDE +#define HWY_TARGET_INCLUDE \ + "gemma/instantiations/2b_sfp.cc" +#include "hwy/foreach_target.h" // IWYU pragma: keep +#define GEMMA_CONFIG ConfigGemma2B +#include "gemma/gemma-inl.h" diff --git a/gemma/instantiations/7b_bf16.cc b/gemma/instantiations/7b_bf16.cc new file mode 100644 index 0000000..03bc369 --- /dev/null +++ b/gemma/instantiations/7b_bf16.cc @@ -0,0 +1,21 @@ +// Copyright 2024 Google LLC +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#undef HWY_TARGET_INCLUDE +#define HWY_TARGET_INCLUDE \ + "gemma/instantiations/7b_bf16.cc" +#include "hwy/foreach_target.h" // IWYU pragma: keep +#define GEMMA_CONFIG ConfigGemma7B +#include "gemma/gemma-inl.h" diff --git a/gemma/instantiations/7b_f32.cc b/gemma/instantiations/7b_f32.cc new file mode 100644 index 0000000..7f09e85 --- /dev/null +++ b/gemma/instantiations/7b_f32.cc @@ -0,0 +1,21 @@ +// Copyright 2024 Google LLC +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#undef HWY_TARGET_INCLUDE +#define HWY_TARGET_INCLUDE \ + "gemma/instantiations/7b_f32.cc" +#include "hwy/foreach_target.h" // IWYU pragma: keep +#define GEMMA_CONFIG ConfigGemma7B +#include "gemma/gemma-inl.h" diff --git a/gemma/instantiations/7b_sfp.cc b/gemma/instantiations/7b_sfp.cc new file mode 100644 index 0000000..78a768d --- /dev/null +++ b/gemma/instantiations/7b_sfp.cc @@ -0,0 +1,21 @@ +// Copyright 2024 Google LLC +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#undef HWY_TARGET_INCLUDE +#define HWY_TARGET_INCLUDE \ + "gemma/instantiations/7b_sfp.cc" +#include "hwy/foreach_target.h" // IWYU pragma: keep +#define GEMMA_CONFIG ConfigGemma7B +#include "gemma/gemma-inl.h" diff --git a/gemma/instantiations/9b_bf16.cc b/gemma/instantiations/9b_bf16.cc new file mode 100644 index 0000000..df72954 --- /dev/null +++ b/gemma/instantiations/9b_bf16.cc @@ -0,0 +1,21 @@ +// Copyright 2024 Google LLC +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#undef HWY_TARGET_INCLUDE +#define HWY_TARGET_INCLUDE \ + "gemma/instantiations/9b_bf16.cc" +#include "hwy/foreach_target.h" // IWYU pragma: keep +#define GEMMA_CONFIG ConfigGemma9B +#include "gemma/gemma-inl.h" diff --git a/gemma/instantiations/9b_f32.cc b/gemma/instantiations/9b_f32.cc new file mode 100644 index 0000000..26f0eed --- /dev/null +++ b/gemma/instantiations/9b_f32.cc @@ -0,0 +1,21 @@ +// Copyright 2024 Google LLC +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#undef HWY_TARGET_INCLUDE +#define HWY_TARGET_INCLUDE \ + "gemma/instantiations/9b_f32.cc" +#include "hwy/foreach_target.h" // IWYU pragma: keep +#define GEMMA_CONFIG ConfigGemma9B +#include "gemma/gemma-inl.h" diff --git a/gemma/instantiations/9b_sfp.cc b/gemma/instantiations/9b_sfp.cc new file mode 100644 index 0000000..17aefb1 --- /dev/null +++ b/gemma/instantiations/9b_sfp.cc @@ -0,0 +1,21 @@ +// Copyright 2024 Google LLC +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#undef HWY_TARGET_INCLUDE +#define HWY_TARGET_INCLUDE \ + "gemma/instantiations/9b_sfp.cc" +#include "hwy/foreach_target.h" // IWYU pragma: keep +#define GEMMA_CONFIG ConfigGemma9B +#include "gemma/gemma-inl.h" diff --git a/gemma/instantiations/gr2b_bf16.cc b/gemma/instantiations/gr2b_bf16.cc new file mode 100644 index 0000000..4c0b36e --- /dev/null +++ b/gemma/instantiations/gr2b_bf16.cc @@ -0,0 +1,21 @@ +// Copyright 2024 Google LLC +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#undef HWY_TARGET_INCLUDE +#define HWY_TARGET_INCLUDE \ + "gemma/instantiations/gr2b_bf16.cc" +#include "hwy/foreach_target.h" // IWYU pragma: keep +#define GEMMA_CONFIG ConfigGriffin2B +#include "gemma/gemma-inl.h" diff --git a/gemma/instantiations/gr2b_f32.cc b/gemma/instantiations/gr2b_f32.cc new file mode 100644 index 0000000..8d12b1a --- /dev/null +++ b/gemma/instantiations/gr2b_f32.cc @@ -0,0 +1,21 @@ +// Copyright 2024 Google LLC +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#undef HWY_TARGET_INCLUDE +#define HWY_TARGET_INCLUDE \ + "gemma/instantiations/gr2b_f32.cc" +#include "hwy/foreach_target.h" // IWYU pragma: keep +#define GEMMA_CONFIG ConfigGriffin2B +#include "gemma/gemma-inl.h" diff --git a/gemma/instantiations/gr2b_sfp.cc b/gemma/instantiations/gr2b_sfp.cc new file mode 100644 index 0000000..32f40ff --- /dev/null +++ b/gemma/instantiations/gr2b_sfp.cc @@ -0,0 +1,21 @@ +// Copyright 2024 Google LLC +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#undef HWY_TARGET_INCLUDE +#define HWY_TARGET_INCLUDE \ + "gemma/instantiations/gr2b_sfp.cc" +#include "hwy/foreach_target.h" // IWYU pragma: keep +#define GEMMA_CONFIG ConfigGriffin2B +#include "gemma/gemma-inl.h" diff --git a/gemma/instantiations/tiny_bf16.cc b/gemma/instantiations/tiny_bf16.cc new file mode 100644 index 0000000..53dad72 --- /dev/null +++ b/gemma/instantiations/tiny_bf16.cc @@ -0,0 +1,21 @@ +// Copyright 2024 Google LLC +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#undef HWY_TARGET_INCLUDE +#define HWY_TARGET_INCLUDE \ + "gemma/instantiations/tiny_bf16.cc" +#include "hwy/foreach_target.h" // IWYU pragma: keep +#define GEMMA_CONFIG ConfigGemmaTiny +#include "gemma/gemma-inl.h" diff --git a/gemma/instantiations/tiny_f32.cc b/gemma/instantiations/tiny_f32.cc new file mode 100644 index 0000000..6f11ddc --- /dev/null +++ b/gemma/instantiations/tiny_f32.cc @@ -0,0 +1,21 @@ +// Copyright 2024 Google LLC +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#undef HWY_TARGET_INCLUDE +#define HWY_TARGET_INCLUDE \ + "gemma/instantiations/tiny_f32.cc" +#include "hwy/foreach_target.h" // IWYU pragma: keep +#define GEMMA_CONFIG ConfigGemmaTiny +#include "gemma/gemma-inl.h" diff --git a/gemma/instantiations/tiny_sfp.cc b/gemma/instantiations/tiny_sfp.cc new file mode 100644 index 0000000..2eaa86f --- /dev/null +++ b/gemma/instantiations/tiny_sfp.cc @@ -0,0 +1,21 @@ +// Copyright 2024 Google LLC +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+
+#undef HWY_TARGET_INCLUDE
+#define HWY_TARGET_INCLUDE \
+  "gemma/instantiations/tiny_sfp.cc"
+#include "hwy/foreach_target.h"  // IWYU pragma: keep
+#define GEMMA_CONFIG ConfigGemmaTiny
+#include "gemma/gemma-inl.h"
diff --git a/gemma/tokenizer.cc b/gemma/tokenizer.cc
index 0142573..07b3f98 100644
--- a/gemma/tokenizer.cc
+++ b/gemma/tokenizer.cc
@@ -22,7 +22,8 @@
 #include
 
 #include "compression/io.h"  // Path
-#include "hwy/base.h"
+#include "gemma/common.h"  // Wrap
+#include "hwy/base.h"  // HWY_ASSERT
 #include "hwy/profiler.h"
 // copybara:import_next_line:sentencepiece
 #include "src/sentencepiece_processor.h"
@@ -95,4 +96,18 @@ bool GemmaTokenizer::Decode(const std::vector<int>& ids,
   return impl_->Decode(ids, detokenized);
 }
 
+std::vector<int> WrapAndTokenize(const GemmaTokenizer& tokenizer,
+                                 const ModelInfo& info, size_t pos,
+                                 std::string& prompt) {
+  Wrap(info, pos, prompt);
+
+  std::vector<int> tokens;
+  HWY_ASSERT(tokenizer.Encode(prompt, &tokens));
+  // Both pre-trained and instruction-tuned require BOS as first token.
+  if (pos == 0) {
+    tokens.insert(tokens.begin(), BOS_ID);
+  }
+  return tokens;
+}
+
 }  // namespace gcpp
diff --git a/gemma/tokenizer.h b/gemma/tokenizer.h
index f42daa7..f0bb0fc 100644
--- a/gemma/tokenizer.h
+++ b/gemma/tokenizer.h
@@ -16,11 +16,14 @@
 #ifndef THIRD_PARTY_GEMMA_CPP_GEMMA_TOKENIZER_H_
 #define THIRD_PARTY_GEMMA_CPP_GEMMA_TOKENIZER_H_
 
+#include
+
 #include
 #include
 #include
 
 #include "compression/io.h"  // Path
+#include "gemma/common.h"  // ModelInfo
 
 namespace gcpp {
 
@@ -47,6 +50,10 @@ class GemmaTokenizer {
   std::unique_ptr impl_;
 };
 
+std::vector<int> WrapAndTokenize(const GemmaTokenizer& tokenizer,
+                                 const ModelInfo& info, size_t pos,
+                                 std::string& prompt);
+
 }  // namespace gcpp
 
 #endif  // THIRD_PARTY_GEMMA_CPP_GEMMA_TOKENIZER_H_
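// Usage sketch for WrapAndTokenize (not part of the patch). The tokenizer and
// ModelInfo are assumed to be already constructed; the prompt text is a
// placeholder. For instruction-tuned models, Wrap() adds the Gemma turn
// markers (start_of_turn / end_of_turn) around the prompt, and BOS_ID is
// prepended whenever pos == 0.
#include <string>
#include <vector>

#include "gemma/common.h"     // ModelInfo
#include "gemma/tokenizer.h"  // GemmaTokenizer, WrapAndTokenize

std::vector<int> TokenizeFirstTurn(const gcpp::GemmaTokenizer& tokenizer,
                                   const gcpp::ModelInfo& info) {
  std::string prompt = "Write a haiku about tokenizers.";  // placeholder
  // pos == 0: first turn of the dialogue, so BOS is inserted.
  return gcpp::WrapAndTokenize(tokenizer, info, /*pos=*/0, prompt);
}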