// Copyright 2024 Google LLC // SPDX-License-Identifier: Apache-2.0 // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // https://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. // Model configurations #ifndef THIRD_PARTY_GEMMA_CPP_CONFIGS_H_ #define THIRD_PARTY_GEMMA_CPP_CONFIGS_H_ // Allow changing pre-allocated kv cache size as a compiler flag #ifndef GEMMA_MAX_SEQLEN #define GEMMA_MAX_SEQLEN 4096 #endif // !GEMMA_MAX_SEQLEN // Allow changing k parameter of `SampleTopK` as a compiler flag #ifndef GEMMA_TOPK #define GEMMA_TOPK 1 #endif // !GEMMA_TOPK #include // copybara:import_next_line:gemma_cpp #include "compression/sfp.h" #include "hwy/base.h" // hwy::bfloat16_t // Allowable types for GEMMA_WEIGHT_T (can be specified at compilation time): // float, hwy::bfloat16_t, SfpStream, NuqStream #ifndef GEMMA_WEIGHT_T #define GEMMA_WEIGHT_T SfpStream #endif // !GEMMA_WEIGHT_T namespace gcpp { static constexpr size_t kSeqLen = GEMMA_MAX_SEQLEN; static constexpr size_t kTopK = GEMMA_TOPK; struct ConfigGemma7B { static constexpr int kSeqLen = gcpp::kSeqLen; static constexpr int kVocabSize = 256000; static constexpr int kLayers = 28; static constexpr int kModelDim = 3072; static constexpr int kFFHiddenDim = 16 * 3072 / 2; // = 24576 static constexpr int kHeads = 16; static constexpr int kKVHeads = 16; // standard MHA static constexpr int kQKVDim = 256; // query size == key size == value size static constexpr int kTopK = gcpp::kTopK; static constexpr int kNumTensorScales = 0; using WeightT = GEMMA_WEIGHT_T; }; struct ConfigGemma2B { static constexpr int kSeqLen = gcpp::kSeqLen; static constexpr int kVocabSize = 256000; static constexpr int kLayers = 18; static constexpr int kModelDim = 2048; static constexpr int kFFHiddenDim = 16 * 2048 / 2; // = 16384 static constexpr int kHeads = 8; static constexpr int kKVHeads = 1; static constexpr int kQKVDim = 256; // query size == key size == value size static constexpr int kTopK = gcpp::kTopK; static constexpr int kNumTensorScales = 0; using WeightT = GEMMA_WEIGHT_T; }; } // namespace gcpp #endif // THIRD_PARTY_GEMMA_CPP_CONFIGS_H_