mirror of https://github.com/google/gemma.cpp.git
Adjust vocab size to be the same as gemma_pytorch
parent 5e0cafbdc2
commit 130e1f678f
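The gemma_pytorch reference configs use a 256000-token vocabulary, while gemma.cpp previously padded the embedding table to 256128 rows (the old transform below appended 128 zero rows to embedder.weight). A minimal sketch for confirming the tokenizer's vocabulary size, assuming sentencepiece is installed and that "tokenizer.model" is a local copy of the Gemma tokenizer (the path is hypothetical):

    # Hedged sketch: check that the tokenizer vocabulary matches the new
    # kVocabSize of 256000. "tokenizer.model" is a hypothetical local path.
    import sentencepiece as spm

    sp = spm.SentencePieceProcessor()
    sp.Load("tokenizer.model")
    print(sp.GetPieceSize())  # expected: 256000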
@@ -37,7 +37,7 @@ static constexpr size_t kTopK = GEMMA_TOPK;
 struct ConfigGemma7B {
   static constexpr int kSeqLen = gcpp::kSeqLen;
-  static constexpr int kVocabSize = 256128;
+  static constexpr int kVocabSize = 256000;
   static constexpr int kLayers = 28;
   static constexpr int kModelDim = 3072;
   static constexpr int kFFHiddenDim = 16 * 3072 / 2;  // = 24576
@@ -49,7 +49,7 @@ struct ConfigGemma7B {
 struct ConfigGemma2B {
   static constexpr int kSeqLen = gcpp::kSeqLen;
-  static constexpr int kVocabSize = 256128;
+  static constexpr int kVocabSize = 256000;
   static constexpr int kLayers = 18;
   static constexpr int kModelDim = 2048;
   static constexpr int kFFHiddenDim = 16 * 2048 / 2;  // = 16384
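The only change on the C++ side is kVocabSize; with it now equal to the tokenizer size, the embedder table for each model drops the 128 padding rows. A rough back-of-the-envelope check, using only the constants shown in the hunks above:

    # Rough arithmetic only; kModelDim values are taken from the configs above.
    for name, model_dim in {"2b": 2048, "7b": 3072}.items():
        old_rows, new_rows = 256128, 256000
        saved = (old_rows - new_rows) * model_dim
        print(f"{name}: embedder shrinks by {saved:,} weights "
              f"({old_rows - new_rows} rows x {model_dim})")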
@@ -90,7 +90,7 @@ TRANSFORMATIONS = {
     "2b":defaultdict(
         lambda: lambda x: x,
         {
-            "embedder.weight": lambda x: np.concatenate([x, np.zeros([128, 2048])], 0),
+            "embedder.weight": lambda x: x,
             "self_attn.qkv_proj.weight": expand_qkv,
             "self_attn.o_proj.weight": lambda x: x.reshape((2048, 8, 256)).transpose([1,0,2]),
             "mlp.gate_proj.weight": lambda x: x[np.newaxis, :, :],
@@ -101,7 +101,7 @@ TRANSFORMATIONS = {
     "7b":defaultdict(
         lambda: lambda x: x,
         {
-            "embedder.weight": lambda x: np.concatenate([x, np.zeros([128, 3072])], 0),
+            "embedder.weight": lambda x: x,
             "self_attn.qkv_proj.weight": lambda x: x.reshape((3, 16, 256, 3072)).transpose([1,0,2,3]),
             "self_attn.o_proj.weight": lambda x: x.reshape((3072, 16, 256)).transpose([1,0,2]),
             "mlp.gate_proj.weight": lambda x: x[np.newaxis, :, :],
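For context, TRANSFORMATIONS is a per-key table of reshape functions, with defaultdict supplying an identity transform for any checkpoint key that has no explicit entry; after this change the embedder weights pass through unchanged instead of being zero-padded by 128 rows. A self-contained sketch of that pattern (the tiny dummy shapes and the `weights` dict are illustrative, not the converter's actual data):

    from collections import defaultdict
    import numpy as np

    # Identity fallback for unlisted keys; explicit entries override it.
    transforms = defaultdict(
        lambda: lambda x: x,
        {"embedder.weight": lambda x: x},  # post-change: no zero-row padding
    )

    # Tiny stand-in arrays; real shapes would be e.g. (256000, 2048) for 2B.
    weights = {
        "embedder.weight": np.zeros((8, 4), dtype=np.float32),
        "model.norm.weight": np.zeros((4,), dtype=np.float32),  # no entry -> identity
    }
    converted = {k: transforms[k](v) for k, v in weights.items()}
    print({k: v.shape for k, v in converted.items()})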
@@ -113,7 +113,7 @@ TRANSFORMATIONS = {
 VALIDATIONS = {
     "2b": {
-        "embedder.weight": lambda x: x.shape == (256128, 2048),
+        "embedder.weight": lambda x: x.shape == (256000, 2048),
         "model.norm.weight": lambda x: x.shape == (2048,),
         "self_attn.qkv_proj.weight": lambda x: x.shape == (8, 3, 256, 2048),
         "self_attn.o_proj.weight": lambda x: x.shape == (8, 2048, 256),
@@ -124,7 +124,7 @@ VALIDATIONS = {
         "post_attention_layernorm.weight": lambda x: x.shape == (2048,),
     },
     "7b": {
-        "embedder.weight": lambda x: x.shape == (256128, 3072),
+        "embedder.weight": lambda x: x.shape == (256000, 3072),
         "model.norm.weight": lambda x: x.shape == (3072,),
         "self_attn.qkv_proj.weight": lambda x: x.shape == (16, 3, 256, 3072),
         "self_attn.o_proj.weight": lambda x: x.shape == (16, 3072, 256),
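The VALIDATIONS table now expects the unpadded 256000-row embedder. A hedged sketch of how such a table can be used to fail fast before writing converted weights (the check_weights helper is illustrative, not the converter's actual code):

    # Illustrative only: reject tensors whose shapes do not match the table.
    validations_2b = {
        "embedder.weight": lambda x: x.shape == (256000, 2048),
        "model.norm.weight": lambda x: x.shape == (2048,),
    }

    def check_weights(weights, validations):
        for name, ok in validations.items():
            if name in weights and not ok(weights[name]):
                raise ValueError(f"{name}: unexpected shape {weights[name].shape}")

    # A checkpoint still padded to (256128, 2048) would be rejected here.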