From 130e1f678fbc5cfb6c17a7f15e7c882ff21a4d46 Mon Sep 17 00:00:00 2001
From: RangerUFO
Date: Tue, 19 Mar 2024 22:00:52 +0800
Subject: [PATCH] Adjust vocab size to be the same as gemma_pytorch

---
 configs.h               | 4 ++--
 util/convert_weights.py | 8 ++++----
 2 files changed, 6 insertions(+), 6 deletions(-)

diff --git a/configs.h b/configs.h
index 7b420b5..58c053f 100644
--- a/configs.h
+++ b/configs.h
@@ -37,7 +37,7 @@ static constexpr size_t kTopK = GEMMA_TOPK;
 
 struct ConfigGemma7B {
   static constexpr int kSeqLen = gcpp::kSeqLen;
-  static constexpr int kVocabSize = 256128;
+  static constexpr int kVocabSize = 256000;
   static constexpr int kLayers = 28;
   static constexpr int kModelDim = 3072;
   static constexpr int kFFHiddenDim = 16 * 3072 / 2; // = 24576
@@ -49,7 +49,7 @@ struct ConfigGemma7B {
 
 struct ConfigGemma2B {
   static constexpr int kSeqLen = gcpp::kSeqLen;
-  static constexpr int kVocabSize = 256128;
+  static constexpr int kVocabSize = 256000;
   static constexpr int kLayers = 18;
   static constexpr int kModelDim = 2048;
   static constexpr int kFFHiddenDim = 16 * 2048 / 2; // = 16384
diff --git a/util/convert_weights.py b/util/convert_weights.py
index bd6750a..6552d89 100644
--- a/util/convert_weights.py
+++ b/util/convert_weights.py
@@ -90,7 +90,7 @@ TRANSFORMATIONS = {
     "2b":defaultdict(
         lambda: lambda x: x,
         {
-            "embedder.weight": lambda x: np.concatenate([x, np.zeros([128, 2048])], 0),
+            "embedder.weight": lambda x: x,
             "self_attn.qkv_proj.weight": expand_qkv,
             "self_attn.o_proj.weight": lambda x: x.reshape((2048, 8, 256)).transpose([1,0,2]),
             "mlp.gate_proj.weight": lambda x: x[np.newaxis, :, :],
@@ -101,7 +101,7 @@ TRANSFORMATIONS = {
     "7b":defaultdict(
         lambda: lambda x: x,
         {
-            "embedder.weight": lambda x: np.concatenate([x, np.zeros([128, 3072])], 0),
+            "embedder.weight": lambda x: x,
             "self_attn.qkv_proj.weight": lambda x: x.reshape((3, 16, 256, 3072)).transpose([1,0,2,3]),
             "self_attn.o_proj.weight": lambda x: x.reshape((3072, 16, 256)).transpose([1,0,2]),
             "mlp.gate_proj.weight": lambda x: x[np.newaxis, :, :],
@@ -113,7 +113,7 @@ TRANSFORMATIONS = {
 
 VALIDATIONS = {
     "2b": {
-        "embedder.weight": lambda x: x.shape == (256128, 2048),
+        "embedder.weight": lambda x: x.shape == (256000, 2048),
         "model.norm.weight": lambda x: x.shape == (2048,),
         "self_attn.qkv_proj.weight": lambda x: x.shape == (8, 3, 256, 2048),
         "self_attn.o_proj.weight": lambda x: x.shape == (8, 2048, 256),
@@ -124,7 +124,7 @@ VALIDATIONS = {
         "post_attention_layernorm.weight": lambda x: x.shape == (2048,),
     },
     "7b": {
-        "embedder.weight": lambda x: x.shape == (256128, 3072),
+        "embedder.weight": lambda x: x.shape == (256000, 3072),
         "model.norm.weight": lambda x: x.shape == (3072,),
         "self_attn.qkv_proj.weight": lambda x: x.shape == (16, 3, 256, 3072),
         "self_attn.o_proj.weight": lambda x: x.shape == (16, 3072, 256),
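
---

Note on the change (not part of the patch itself): the old converter padded
"embedder.weight" with 128 zero rows via np.concatenate([x, np.zeros([128,
kModelDim])], 0), which is where the 256128 figure came from (256000 + 128).
gemma_pytorch checkpoints store the unpadded 256000-row embedding table, so
kVocabSize, the transformations, and the validations now all agree on 256000
and the embedding passes through unchanged.

A quick way to confirm a checkpoint against the new sizes -- a minimal
sketch, assuming a local gemma_pytorch 2B checkpoint and its SentencePiece
tokenizer; the file names "gemma-2b.ckpt" and "tokenizer.model" are
placeholders, and the "model_state_dict" nesting is the usual gemma_pytorch
checkpoint layout, so adjust if yours differs:

    import torch
    import sentencepiece as spm

    K_VOCAB_SIZE = 256000  # kVocabSize in configs.h after this patch
    K_MODEL_DIM = 2048     # kModelDim for the 2B config

    # The Gemma tokenizer has exactly 256000 pieces, so no zero-padding of
    # the embedding table is needed to cover the token ids it can emit.
    sp = spm.SentencePieceProcessor()
    sp.Load("tokenizer.model")
    assert sp.GetPieceSize() == K_VOCAB_SIZE

    # With the padding transformation removed, embedder.weight passes
    # through convert_weights.py as-is, so the checkpoint tensor itself
    # must already have 256000 rows.
    ckpt = torch.load("gemma-2b.ckpt", map_location="cpu")
    emb = ckpt["model_state_dict"]["embedder.weight"]
    assert tuple(emb.shape) == (K_VOCAB_SIZE, K_MODEL_DIM)

For the 7B config the same check applies with K_MODEL_DIM = 3072, matching
the (256000, 3072) shape asserted in VALIDATIONS above.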