diff --git a/gemma/gemma.cc b/gemma/gemma.cc index b64cdea..00e98ed 100644 --- a/gemma/gemma.cc +++ b/gemma/gemma.cc @@ -23,13 +23,8 @@ // Must come after foreach_target.h to avoid redefinition errors. #include "compression/compress-inl.h" #include "gemma/ops.h" -#include "util/args.h" // Path #include "hwy/contrib/matvec/matvec-inl.h" #include "hwy/highway.h" -#include "hwy/profiler.h" -#include "hwy/timer.h" -// copybara:import_next_line:sentencepiece -#include "src/sentencepiece_processor.h" // Non-SIMD includes and types. Note that HWY_ONCE is only true on the last // compile pass, whereas we want this defined in the first. @@ -56,9 +51,14 @@ #include "compression/compress.h" #include "gemma/configs.h" #include "gemma/gemma.h" +#include "util/args.h" // Path #include "hwy/aligned_allocator.h" #include "hwy/base.h" #include "hwy/contrib/thread_pool/thread_pool.h" +#include "hwy/profiler.h" +#include "hwy/timer.h" +// copybara:import_next_line:sentencepiece +#include "src/sentencepiece_processor.h" // Setting this to true disables fread() calls that read the model file. constexpr bool kDryRunFread = false; @@ -416,7 +416,7 @@ struct GemmaInterface { }; template -KVCache CreateKVCache() { +KVCache CreateKVCacheT() { constexpr size_t kConv1dWidth = Config::kConv1dWidth; return CreateKVCache( Config::kGemmaLayers * Config::kKVHeads * Config::kQKVDim, @@ -429,11 +429,11 @@ KVCache CreateKVCache() { KVCache CreateKVCache(Model type) { switch (type) { case Model::GEMMA_2B: - return CreateKVCache(); + return CreateKVCacheT(); case Model::GEMMA_7B: - return CreateKVCache(); + return CreateKVCacheT(); case Model::GRIFFIN_2B: - return CreateKVCache(); + return CreateKVCacheT(); default: HWY_ABORT("Model type %d unknown.", static_cast(type)); } @@ -1407,17 +1407,17 @@ HWY_EXPORT(ComputeCrossEntropy7B); HWY_EXPORT(ComputeCrossEntropyGriffin2B); KVCache CreateKVCache(size_t size_cache_pos, size_t seq_len, - size_t conv_cache_size, size_t rglru_cache_size) { + size_t conv1d_cache_size, size_t rglru_cache_size) { KVCache kv_cache = {}; if (size_cache_pos != 0) { kv_cache.key_cache = hwy::AllocateAligned(seq_len * size_cache_pos); kv_cache.value_cache = hwy::AllocateAligned(seq_len * size_cache_pos); } - if (conv_cache_size != 0) { - kv_cache.conv1d_cache = hwy::AllocateAligned(conv_cache_size); + if (conv1d_cache_size != 0) { + kv_cache.conv1d_cache = hwy::AllocateAligned(conv1d_cache_size); hwy::ZeroBytes(kv_cache.conv1d_cache.get(), - conv_cache_size * sizeof(kv_cache.conv1d_cache[0])); + conv1d_cache_size * sizeof(kv_cache.conv1d_cache[0])); } if (rglru_cache_size != 0) { kv_cache.rglru_cache = hwy::AllocateAligned(rglru_cache_size); @@ -1589,16 +1589,14 @@ constexpr ModelTraining kModelTraining[] = { const char* ParseModelTypeAndTraining(const std::string& model_flag, Model& model, ModelTraining& training) { constexpr size_t kNum = std::end(kModelFlags) - std::begin(kModelFlags); - static char kErrorMessageBuffer[kNum * 8 + 1024]; - kErrorMessageBuffer[0] = 0; - strcat(kErrorMessageBuffer, - "Invalid or missing model flag, need to specify one of "); + static char kErrorMessageBuffer[kNum * 8 + 1024] = + "Invalid or missing model flag, need to specify one of "; for (size_t i = 0; i + 1 < kNum; i++) { - strcat(kErrorMessageBuffer, kModelFlags[i]); - strcat(kErrorMessageBuffer, ", "); + strcat(kErrorMessageBuffer, kModelFlags[i]); // NOLINT + strcat(kErrorMessageBuffer, ", "); // NOLINT } - strcat(kErrorMessageBuffer, kModelFlags[kNum - 1]); - strcat(kErrorMessageBuffer, "."); + strcat(kErrorMessageBuffer, kModelFlags[kNum - 1]); // NOLINT + strcat(kErrorMessageBuffer, "."); // NOLINT std::string model_type_lc = model_flag; std::transform(begin(model_type_lc), end(model_type_lc), begin(model_type_lc), [](unsigned char c) { return std::tolower(c); });