diff --git a/src/llama-model.cpp b/src/llama-model.cpp index 06e0645352..f8caad2889 100644 --- a/src/llama-model.cpp +++ b/src/llama-model.cpp @@ -2058,7 +2058,7 @@ void llama_model::load_hparams(llama_model_loader & ml) { switch (hparams.n_embd) { case 768: type = LLM_TYPE_350M; break; - case 1536: type = (hparams.n_embd == 2048 ? LLM_TYPE_7B_A1B : LLM_TYPE_1B); break; + case 1536: type = (hparams.n_ff() == 512 ? LLM_TYPE_7B_A1B : LLM_TYPE_1B); break; case 2048: case 2560: type = LLM_TYPE_3B; break; case 4096: type = LLM_TYPE_32B; break; default: type = LLM_TYPE_UNKNOWN;