Lint fixes: strcat, includes, arg naming

PiperOrigin-RevId: 623435210
This commit is contained in:
Jan Wassenberg 2024-04-10 03:12:02 -07:00 committed by Copybara-Service
parent da91f4c4be
commit 881eeffe0a
1 changed files with 19 additions and 21 deletions

View File

@ -23,13 +23,8 @@
// Must come after foreach_target.h to avoid redefinition errors. // Must come after foreach_target.h to avoid redefinition errors.
#include "compression/compress-inl.h" #include "compression/compress-inl.h"
#include "gemma/ops.h" #include "gemma/ops.h"
#include "util/args.h" // Path
#include "hwy/contrib/matvec/matvec-inl.h" #include "hwy/contrib/matvec/matvec-inl.h"
#include "hwy/highway.h" #include "hwy/highway.h"
#include "hwy/profiler.h"
#include "hwy/timer.h"
// copybara:import_next_line:sentencepiece
#include "src/sentencepiece_processor.h"
// Non-SIMD includes and types. Note that HWY_ONCE is only true on the last // Non-SIMD includes and types. Note that HWY_ONCE is only true on the last
// compile pass, whereas we want this defined in the first. // compile pass, whereas we want this defined in the first.
@ -56,9 +51,14 @@
#include "compression/compress.h" #include "compression/compress.h"
#include "gemma/configs.h" #include "gemma/configs.h"
#include "gemma/gemma.h" #include "gemma/gemma.h"
#include "util/args.h" // Path
#include "hwy/aligned_allocator.h" #include "hwy/aligned_allocator.h"
#include "hwy/base.h" #include "hwy/base.h"
#include "hwy/contrib/thread_pool/thread_pool.h" #include "hwy/contrib/thread_pool/thread_pool.h"
#include "hwy/profiler.h"
#include "hwy/timer.h"
// copybara:import_next_line:sentencepiece
#include "src/sentencepiece_processor.h"
// Setting this to true disables fread() calls that read the model file. // Setting this to true disables fread() calls that read the model file.
constexpr bool kDryRunFread = false; constexpr bool kDryRunFread = false;
@ -416,7 +416,7 @@ struct GemmaInterface {
}; };
template <class Config> template <class Config>
KVCache CreateKVCache() { KVCache CreateKVCacheT() {
constexpr size_t kConv1dWidth = Config::kConv1dWidth; constexpr size_t kConv1dWidth = Config::kConv1dWidth;
return CreateKVCache( return CreateKVCache(
Config::kGemmaLayers * Config::kKVHeads * Config::kQKVDim, Config::kGemmaLayers * Config::kKVHeads * Config::kQKVDim,
@ -429,11 +429,11 @@ KVCache CreateKVCache() {
KVCache CreateKVCache(Model type) { KVCache CreateKVCache(Model type) {
switch (type) { switch (type) {
case Model::GEMMA_2B: case Model::GEMMA_2B:
return CreateKVCache<ConfigGemma2B>(); return CreateKVCacheT<ConfigGemma2B>();
case Model::GEMMA_7B: case Model::GEMMA_7B:
return CreateKVCache<ConfigGemma7B>(); return CreateKVCacheT<ConfigGemma7B>();
case Model::GRIFFIN_2B: case Model::GRIFFIN_2B:
return CreateKVCache<ConfigGriffin2B>(); return CreateKVCacheT<ConfigGriffin2B>();
default: default:
HWY_ABORT("Model type %d unknown.", static_cast<int>(type)); HWY_ABORT("Model type %d unknown.", static_cast<int>(type));
} }
@ -1407,17 +1407,17 @@ HWY_EXPORT(ComputeCrossEntropy7B);
HWY_EXPORT(ComputeCrossEntropyGriffin2B); HWY_EXPORT(ComputeCrossEntropyGriffin2B);
KVCache CreateKVCache(size_t size_cache_pos, size_t seq_len, KVCache CreateKVCache(size_t size_cache_pos, size_t seq_len,
size_t conv_cache_size, size_t rglru_cache_size) { size_t conv1d_cache_size, size_t rglru_cache_size) {
KVCache kv_cache = {}; KVCache kv_cache = {};
if (size_cache_pos != 0) { if (size_cache_pos != 0) {
kv_cache.key_cache = hwy::AllocateAligned<float>(seq_len * size_cache_pos); kv_cache.key_cache = hwy::AllocateAligned<float>(seq_len * size_cache_pos);
kv_cache.value_cache = kv_cache.value_cache =
hwy::AllocateAligned<float>(seq_len * size_cache_pos); hwy::AllocateAligned<float>(seq_len * size_cache_pos);
} }
if (conv_cache_size != 0) { if (conv1d_cache_size != 0) {
kv_cache.conv1d_cache = hwy::AllocateAligned<float>(conv_cache_size); kv_cache.conv1d_cache = hwy::AllocateAligned<float>(conv1d_cache_size);
hwy::ZeroBytes(kv_cache.conv1d_cache.get(), hwy::ZeroBytes(kv_cache.conv1d_cache.get(),
conv_cache_size * sizeof(kv_cache.conv1d_cache[0])); conv1d_cache_size * sizeof(kv_cache.conv1d_cache[0]));
} }
if (rglru_cache_size != 0) { if (rglru_cache_size != 0) {
kv_cache.rglru_cache = hwy::AllocateAligned<float>(rglru_cache_size); kv_cache.rglru_cache = hwy::AllocateAligned<float>(rglru_cache_size);
@ -1589,16 +1589,14 @@ constexpr ModelTraining kModelTraining[] = {
const char* ParseModelTypeAndTraining(const std::string& model_flag, const char* ParseModelTypeAndTraining(const std::string& model_flag,
Model& model, ModelTraining& training) { Model& model, ModelTraining& training) {
constexpr size_t kNum = std::end(kModelFlags) - std::begin(kModelFlags); constexpr size_t kNum = std::end(kModelFlags) - std::begin(kModelFlags);
static char kErrorMessageBuffer[kNum * 8 + 1024]; static char kErrorMessageBuffer[kNum * 8 + 1024] =
kErrorMessageBuffer[0] = 0; "Invalid or missing model flag, need to specify one of ";
strcat(kErrorMessageBuffer,
"Invalid or missing model flag, need to specify one of ");
for (size_t i = 0; i + 1 < kNum; i++) { for (size_t i = 0; i + 1 < kNum; i++) {
strcat(kErrorMessageBuffer, kModelFlags[i]); strcat(kErrorMessageBuffer, kModelFlags[i]); // NOLINT
strcat(kErrorMessageBuffer, ", "); strcat(kErrorMessageBuffer, ", "); // NOLINT
} }
strcat(kErrorMessageBuffer, kModelFlags[kNum - 1]); strcat(kErrorMessageBuffer, kModelFlags[kNum - 1]); // NOLINT
strcat(kErrorMessageBuffer, "."); strcat(kErrorMessageBuffer, "."); // NOLINT
std::string model_type_lc = model_flag; std::string model_type_lc = model_flag;
std::transform(begin(model_type_lc), end(model_type_lc), begin(model_type_lc), std::transform(begin(model_type_lc), end(model_type_lc), begin(model_type_lc),
[](unsigned char c) { return std::tolower(c); }); [](unsigned char c) { return std::tolower(c); });