Merge pull request #139 from ufownl:feature/public_layers

PiperOrigin-RevId: 623254705
This commit is contained in:
Copybara-Service 2024-04-09 12:54:23 -07:00
commit 827fec1904
3 changed files with 37 additions and 32 deletions

View File

@ -61,12 +61,29 @@ constexpr std::array<LayerAttentionType, kNum> FixedLayerConfig(
return config; return config;
} }
// Returns how many of the first `num` entries of `layers` compare equal to
// `type`.
//
// Entries are read with operator[] for every index in [0, num) and no bounds
// check is performed here, so callers must pass num <= kNumLayers.
template <size_t kNumLayers>
constexpr size_t NumLayersOfTypeBefore(
    const std::array<LayerAttentionType, kNumLayers>& layers,
    LayerAttentionType type, size_t num) {
  size_t matches = 0;
  for (size_t idx = 0; idx != num; ++idx) {
    // Add one for a matching layer, zero otherwise.
    matches += (layers[idx] == type) ? 1 : 0;
  }
  return matches;
}
struct ConfigGemma7B { struct ConfigGemma7B {
static constexpr int kSeqLen = gcpp::kSeqLen; static constexpr int kSeqLen = gcpp::kSeqLen;
static constexpr int kVocabSize = 256000; static constexpr int kVocabSize = 256000;
static constexpr std::array<LayerAttentionType, 28> kLayerConfig = static constexpr std::array<LayerAttentionType, 28> kLayerConfig =
FixedLayerConfig<28>(LayerAttentionType::kGemma); FixedLayerConfig<28>(LayerAttentionType::kGemma);
static constexpr int kLayers = kLayerConfig.size(); static constexpr int kLayers = kLayerConfig.size();
static constexpr int kGemmaLayers =
NumLayersOfTypeBefore(kLayerConfig, LayerAttentionType::kGemma, kLayers);
static constexpr int kGriffinLayers =
NumLayersOfTypeBefore(kLayerConfig,
LayerAttentionType::kGriffinRecurrentBlock,
kLayers);
static constexpr int kModelDim = 3072; static constexpr int kModelDim = 3072;
static constexpr int kFFHiddenDim = 16 * 3072 / 2; // = 24576 static constexpr int kFFHiddenDim = 16 * 3072 / 2; // = 24576
static constexpr int kHeads = 16; static constexpr int kHeads = 16;
@ -91,6 +108,12 @@ struct ConfigGemma2B {
static constexpr std::array<LayerAttentionType, 18> kLayerConfig = static constexpr std::array<LayerAttentionType, 18> kLayerConfig =
FixedLayerConfig<18>(LayerAttentionType::kGemma); FixedLayerConfig<18>(LayerAttentionType::kGemma);
static constexpr int kLayers = kLayerConfig.size(); static constexpr int kLayers = kLayerConfig.size();
static constexpr int kGemmaLayers =
NumLayersOfTypeBefore(kLayerConfig, LayerAttentionType::kGemma, kLayers);
static constexpr int kGriffinLayers =
NumLayersOfTypeBefore(kLayerConfig,
LayerAttentionType::kGriffinRecurrentBlock,
kLayers);
static constexpr int kModelDim = 2048; static constexpr int kModelDim = 2048;
static constexpr int kFFHiddenDim = 16 * 2048 / 2; // = 16384 static constexpr int kFFHiddenDim = 16 * 2048 / 2; // = 16384
static constexpr int kHeads = 8; static constexpr int kHeads = 8;
@ -143,6 +166,12 @@ struct ConfigGriffin2B {
LayerAttentionType::kGriffinRecurrentBlock, LayerAttentionType::kGriffinRecurrentBlock,
}; };
static constexpr int kLayers = kLayerConfig.size(); static constexpr int kLayers = kLayerConfig.size();
static constexpr int kGemmaLayers =
NumLayersOfTypeBefore(kLayerConfig, LayerAttentionType::kGemma, kLayers);
static constexpr int kGriffinLayers =
NumLayersOfTypeBefore(kLayerConfig,
LayerAttentionType::kGriffinRecurrentBlock,
kLayers);
static constexpr int kModelDim = 2560; static constexpr int kModelDim = 2560;
static constexpr int kFFHiddenDim = 7680; static constexpr int kFFHiddenDim = 7680;
static constexpr int kHeads = 10; static constexpr int kHeads = 10;

View File

@ -71,30 +71,6 @@ constexpr bool kShowTokenization = false;
namespace gcpp { namespace gcpp {
// Counts the entries among layers[0] .. layers[num - 1] that equal `type`.
//
// The array is indexed directly for each position below `num`; nothing in this
// function validates `num` against kNumLayers, so that is the caller's job.
template <size_t kNumLayers>
constexpr size_t NumLayersOfTypeBefore(
    const std::array<LayerAttentionType, kNumLayers>& layers,
    LayerAttentionType type, size_t num) {
  size_t total = 0;
  size_t pos = 0;
  while (pos < num) {
    if (layers[pos] == type) {
      ++total;
    }
    ++pos;
  }
  return total;
}
template <typename TConfig>
constexpr size_t NumGemmaLayers() {
return NumLayersOfTypeBefore(TConfig::kLayerConfig,
LayerAttentionType::kGemma, TConfig::kLayers);
}
template <typename TConfig>
constexpr size_t NumGriffinLayers() {
return NumLayersOfTypeBefore(TConfig::kLayerConfig,
LayerAttentionType::kGriffinRecurrentBlock,
TConfig::kLayers);
}
template <class TConfig> template <class TConfig>
struct Layer { struct Layer {
Layer() = default; Layer() = default;
@ -389,7 +365,7 @@ struct Activations {
static constexpr size_t kHeads = TConfig::kHeads; static constexpr size_t kHeads = TConfig::kHeads;
static constexpr size_t kKVHeads = TConfig::kKVHeads; static constexpr size_t kKVHeads = TConfig::kKVHeads;
static constexpr size_t kCachePosSize = static constexpr size_t kCachePosSize =
NumGemmaLayers<TConfig>() * kKVHeads * kQKVDim; TConfig::kGemmaLayers * kKVHeads * kQKVDim;
static constexpr size_t kCacheLayerSize = kKVHeads * kQKVDim; static constexpr size_t kCacheLayerSize = kKVHeads * kQKVDim;
std::array<float, kBatchSize * kModelDim> x; // input std::array<float, kBatchSize * kModelDim> x; // input
@ -443,11 +419,11 @@ template <class Config>
KVCache CreateKVCache() { KVCache CreateKVCache() {
constexpr size_t kConv1dWidth = Config::kConv1dWidth; constexpr size_t kConv1dWidth = Config::kConv1dWidth;
return CreateKVCache( return CreateKVCache(
NumGemmaLayers<Config>() * Config::kKVHeads * Config::kQKVDim, Config::kGemmaLayers * Config::kKVHeads * Config::kQKVDim,
Config::kSeqLen, Config::kSeqLen,
NumGriffinLayers<Config>() * (kConv1dWidth == 0 ? 0 : kConv1dWidth - 1) * Config::kGriffinLayers * (kConv1dWidth == 0 ? 0 : kConv1dWidth - 1) *
Config::kModelDim, Config::kModelDim,
NumGriffinLayers<Config>() * Config::kModelDim); Config::kGriffinLayers * Config::kModelDim);
} }
KVCache CreateKVCache(Model type) { KVCache CreateKVCache(Model type) {

View File

@ -37,13 +37,13 @@ constexpr bool kSystemPrompt = false;
struct KVCache { struct KVCache {
hwy::AlignedFreeUniquePtr<float[]> hwy::AlignedFreeUniquePtr<float[]>
key_cache; // kSeqLen * kNumGemmaLayers * kKVHeads * kQKVDim key_cache; // kSeqLen * kGemmaLayers * kKVHeads * kQKVDim
hwy::AlignedFreeUniquePtr<float[]> hwy::AlignedFreeUniquePtr<float[]>
value_cache; // kSeqLen * kNumGemmaLayers * kKVHeads * kQKVDim value_cache; // kSeqLen * kGemmaLayers * kKVHeads * kQKVDim
hwy::AlignedFreeUniquePtr<float[]> hwy::AlignedFreeUniquePtr<float[]>
conv1d_cache; // (kConv1dWidth - 1) * kModelDim * kNumGriffinLayers conv1d_cache; // (kConv1dWidth - 1) * kModelDim * kGriffinLayers
hwy::AlignedFreeUniquePtr<float[]> hwy::AlignedFreeUniquePtr<float[]>
rglru_cache; // kModelDim * kNumGriffinLayers rglru_cache; // kModelDim * kGriffinLayers
}; };
// Model variants: see configs.h for details. // Model variants: see configs.h for details.