Adjust vocab size to be the same as gemma_pytorch

RangerUFO 2024-03-19 22:00:52 +08:00
parent 5e0cafbdc2
commit 130e1f678f
2 changed files with 6 additions and 6 deletions
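gemma_pytorch defines the Gemma vocabulary as 256000 tokens, while the configs here previously used 256128 and had the converter pad embedder.weight with 128 rows of zeros to match. With kVocabSize set to 256000, the embedding table can now be passed through unchanged. A minimal sketch of the old and new transforms, using a dummy zero array in place of the real checkpoint tensor (the names below are illustrative, not part of the converter):

import numpy as np

MODEL_DIM = 2048     # 2b; use 3072 for 7b
VOCAB_SIZE = 256000  # vocabulary size used by gemma_pytorch

# Stand-in for the real embedder.weight tensor from the checkpoint.
embedder_weight = np.zeros((VOCAB_SIZE, MODEL_DIM), dtype=np.float32)

# Old transform: append 128 zero rows so the table matches kVocabSize = 256128.
old_transform = lambda x: np.concatenate([x, np.zeros([128, MODEL_DIM])], 0)
assert old_transform(embedder_weight).shape == (256128, MODEL_DIM)

# New transform: pass the tensor through unchanged; kVocabSize is now 256000.
new_transform = lambda x: x
assert new_transform(embedder_weight).shape == (256000, MODEL_DIM)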

View File

@@ -37,7 +37,7 @@ static constexpr size_t kTopK = GEMMA_TOPK;
struct ConfigGemma7B {
static constexpr int kSeqLen = gcpp::kSeqLen;
-static constexpr int kVocabSize = 256128;
+static constexpr int kVocabSize = 256000;
static constexpr int kLayers = 28;
static constexpr int kModelDim = 3072;
static constexpr int kFFHiddenDim = 16 * 3072 / 2; // = 24576
@@ -49,7 +49,7 @@ struct ConfigGemma7B {
struct ConfigGemma2B {
static constexpr int kSeqLen = gcpp::kSeqLen;
-static constexpr int kVocabSize = 256128;
+static constexpr int kVocabSize = 256000;
static constexpr int kLayers = 18;
static constexpr int kModelDim = 2048;
static constexpr int kFFHiddenDim = 16 * 2048 / 2; // = 16384

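Both config structs above now declare kVocabSize = 256000, which matches the (256000, kModelDim) embedder shapes enforced by the converter script below. A rough illustration of what the old padded table implied per model (arithmetic only, not code from either file):

# Extra rows the old 256128-entry vocabulary carried relative to gemma_pytorch's 256000.
for model, model_dim in (("2b", 2048), ("7b", 3072)):
    padded_rows = 256128 - 256000            # 128 rows of zero padding
    extra_params = padded_rows * model_dim   # parameters that held no real embeddings
    print(f"{model}: {padded_rows} padded rows -> {extra_params} zero parameters")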
View File

@@ -90,7 +90,7 @@ TRANSFORMATIONS = {
"2b":defaultdict(
lambda: lambda x: x,
{
"embedder.weight": lambda x: np.concatenate([x, np.zeros([128, 2048])], 0),
"embedder.weight": lambda x: x,
"self_attn.qkv_proj.weight": expand_qkv,
"self_attn.o_proj.weight": lambda x: x.reshape((2048, 8, 256)).transpose([1,0,2]),
"mlp.gate_proj.weight": lambda x: x[np.newaxis, :, :],
@@ -101,7 +101,7 @@ TRANSFORMATIONS = {
"7b":defaultdict(
lambda: lambda x: x,
{
"embedder.weight": lambda x: np.concatenate([x, np.zeros([128, 3072])], 0),
"embedder.weight": lambda x: x,
"self_attn.qkv_proj.weight": lambda x: x.reshape((3, 16, 256, 3072)).transpose([1,0,2,3]),
"self_attn.o_proj.weight": lambda x: x.reshape((3072, 16, 256)).transpose([1,0,2]),
"mlp.gate_proj.weight": lambda x: x[np.newaxis, :, :],
@@ -113,7 +113,7 @@ TRANSFORMATIONS = {
VALIDATIONS = {
"2b": {
"embedder.weight": lambda x: x.shape == (256128, 2048),
"embedder.weight": lambda x: x.shape == (256000, 2048),
"model.norm.weight": lambda x: x.shape == (2048,),
"self_attn.qkv_proj.weight": lambda x: x.shape == (8, 3, 256, 2048),
"self_attn.o_proj.weight": lambda x: x.shape == (8, 2048, 256),
@@ -124,7 +124,7 @@ VALIDATIONS = {
"post_attention_layernorm.weight": lambda x: x.shape == (2048,),
},
"7b": {
"embedder.weight": lambda x: x.shape == (256128, 3072),
"embedder.weight": lambda x: x.shape == (256000, 3072),
"model.norm.weight": lambda x: x.shape == (3072,),
"self_attn.qkv_proj.weight": lambda x: x.shape == (16, 3, 256, 3072),
"self_attn.o_proj.weight": lambda x: x.shape == (16, 3072, 256),