Merge pull request #139 from ufownl:feature/public_layers

PiperOrigin-RevId: 623254705
This commit is contained in:
Copybara-Service 2024-04-09 12:54:23 -07:00
commit 827fec1904
3 changed files with 37 additions and 32 deletions

View File

@ -61,12 +61,29 @@ constexpr std::array<LayerAttentionType, kNum> FixedLayerConfig(
return config;
}
template <size_t kNumLayers>
constexpr size_t NumLayersOfTypeBefore(
const std::array<LayerAttentionType, kNumLayers>& layers,
LayerAttentionType type, size_t num) {
size_t count = 0;
for (size_t i = 0; i < num; i++) {
if (layers[i] == type) count++;
}
return count;
}
struct ConfigGemma7B {
static constexpr int kSeqLen = gcpp::kSeqLen;
static constexpr int kVocabSize = 256000;
static constexpr std::array<LayerAttentionType, 28> kLayerConfig =
FixedLayerConfig<28>(LayerAttentionType::kGemma);
static constexpr int kLayers = kLayerConfig.size();
static constexpr int kGemmaLayers =
NumLayersOfTypeBefore(kLayerConfig, LayerAttentionType::kGemma, kLayers);
static constexpr int kGriffinLayers =
NumLayersOfTypeBefore(kLayerConfig,
LayerAttentionType::kGriffinRecurrentBlock,
kLayers);
static constexpr int kModelDim = 3072;
static constexpr int kFFHiddenDim = 16 * 3072 / 2; // = 24576
static constexpr int kHeads = 16;
@ -91,6 +108,12 @@ struct ConfigGemma2B {
static constexpr std::array<LayerAttentionType, 18> kLayerConfig =
FixedLayerConfig<18>(LayerAttentionType::kGemma);
static constexpr int kLayers = kLayerConfig.size();
static constexpr int kGemmaLayers =
NumLayersOfTypeBefore(kLayerConfig, LayerAttentionType::kGemma, kLayers);
static constexpr int kGriffinLayers =
NumLayersOfTypeBefore(kLayerConfig,
LayerAttentionType::kGriffinRecurrentBlock,
kLayers);
static constexpr int kModelDim = 2048;
static constexpr int kFFHiddenDim = 16 * 2048 / 2; // = 16384
static constexpr int kHeads = 8;
@ -143,6 +166,12 @@ struct ConfigGriffin2B {
LayerAttentionType::kGriffinRecurrentBlock,
};
static constexpr int kLayers = kLayerConfig.size();
static constexpr int kGemmaLayers =
NumLayersOfTypeBefore(kLayerConfig, LayerAttentionType::kGemma, kLayers);
static constexpr int kGriffinLayers =
NumLayersOfTypeBefore(kLayerConfig,
LayerAttentionType::kGriffinRecurrentBlock,
kLayers);
static constexpr int kModelDim = 2560;
static constexpr int kFFHiddenDim = 7680;
static constexpr int kHeads = 10;

View File

@ -71,30 +71,6 @@ constexpr bool kShowTokenization = false;
namespace gcpp {
template <size_t kNumLayers>
constexpr size_t NumLayersOfTypeBefore(
const std::array<LayerAttentionType, kNumLayers>& layers,
LayerAttentionType type, size_t num) {
size_t count = 0;
for (size_t i = 0; i < num; i++) {
if (layers[i] == type) count++;
}
return count;
}
template <typename TConfig>
constexpr size_t NumGemmaLayers() {
return NumLayersOfTypeBefore(TConfig::kLayerConfig,
LayerAttentionType::kGemma, TConfig::kLayers);
}
template <typename TConfig>
constexpr size_t NumGriffinLayers() {
return NumLayersOfTypeBefore(TConfig::kLayerConfig,
LayerAttentionType::kGriffinRecurrentBlock,
TConfig::kLayers);
}
template <class TConfig>
struct Layer {
Layer() = default;
@ -389,7 +365,7 @@ struct Activations {
static constexpr size_t kHeads = TConfig::kHeads;
static constexpr size_t kKVHeads = TConfig::kKVHeads;
static constexpr size_t kCachePosSize =
NumGemmaLayers<TConfig>() * kKVHeads * kQKVDim;
TConfig::kGemmaLayers * kKVHeads * kQKVDim;
static constexpr size_t kCacheLayerSize = kKVHeads * kQKVDim;
std::array<float, kBatchSize * kModelDim> x; // input
@ -443,11 +419,11 @@ template <class Config>
KVCache CreateKVCache() {
constexpr size_t kConv1dWidth = Config::kConv1dWidth;
return CreateKVCache(
NumGemmaLayers<Config>() * Config::kKVHeads * Config::kQKVDim,
Config::kGemmaLayers * Config::kKVHeads * Config::kQKVDim,
Config::kSeqLen,
NumGriffinLayers<Config>() * (kConv1dWidth == 0 ? 0 : kConv1dWidth - 1) *
Config::kGriffinLayers * (kConv1dWidth == 0 ? 0 : kConv1dWidth - 1) *
Config::kModelDim,
NumGriffinLayers<Config>() * Config::kModelDim);
Config::kGriffinLayers * Config::kModelDim);
}
KVCache CreateKVCache(Model type) {

View File

@ -37,13 +37,13 @@ constexpr bool kSystemPrompt = false;
struct KVCache {
hwy::AlignedFreeUniquePtr<float[]>
key_cache; // kSeqLen * kNumGemmaLayers * kKVHeads * kQKVDim
key_cache; // kSeqLen * kGemmaLayers * kKVHeads * kQKVDim
hwy::AlignedFreeUniquePtr<float[]>
value_cache; // kSeqLen * kNumGemmaLayers * kKVHeads * kQKVDim
value_cache; // kSeqLen * kGemmaLayers * kKVHeads * kQKVDim
hwy::AlignedFreeUniquePtr<float[]>
conv1d_cache; // (kConv1dWidth - 1) * kModelDim * kNumGriffinLayers
conv1d_cache; // (kConv1dWidth - 1) * kModelDim * kGriffinLayers
hwy::AlignedFreeUniquePtr<float[]>
rglru_cache; // kModelDim * kNumGriffinLayers
rglru_cache; // kModelDim * kGriffinLayers
};
// Model variants: see configs.h for details.