standard swa method can be used instead of a new enum being LLAMA_SWA_TYPE_LOCAL
This commit is contained in:
parent
35667f27b3
commit
3cdd6503bd
|
|
@ -263,8 +263,7 @@ static void print_mask(float * data, int64_t n_tokens, int64_t n_kv, int64_t n_s
|
|||
const char * swa_type_str = (swa_type == LLAMA_SWA_TYPE_NONE) ? "LLAMA_SWA_TYPE_NONE" :
|
||||
(swa_type == LLAMA_SWA_TYPE_STANDARD) ? "LLAMA_SWA_TYPE_STANDARD" :
|
||||
(swa_type == LLAMA_SWA_TYPE_CHUNKED) ? "LLAMA_SWA_TYPE_CHUNKED" :
|
||||
(swa_type == LLAMA_SWA_TYPE_SYMMETRIC) ? "LLAMA_SWA_TYPE_SYMMETRIC" :
|
||||
(swa_type == LLAMA_SWA_TYPE_LOCAL) ? "LLAMA_SWA_TYPE_LOCAL" : "unknown";
|
||||
(swa_type == LLAMA_SWA_TYPE_SYMMETRIC) ? "LLAMA_SWA_TYPE_SYMMETRIC" : "unknown";
|
||||
LLAMA_LOG_DEBUG("%s: n_swa : %d, n_kv: %d, swq_type: %s\n", __func__, (int)n_swa, (int)n_kv, swa_type_str);
|
||||
LLAMA_LOG_DEBUG("%s: '0' = can attend, '∞' = masked\n", __func__);
|
||||
LLAMA_LOG_DEBUG("%s: Rows = query tokens, Columns = key/value tokens\n\n", __func__);
|
||||
|
|
|
|||
|
|
@ -20,7 +20,6 @@ enum llama_swa_type {
|
|||
LLAMA_SWA_TYPE_STANDARD = 1,
|
||||
LLAMA_SWA_TYPE_CHUNKED = 2,
|
||||
LLAMA_SWA_TYPE_SYMMETRIC = 3,
|
||||
LLAMA_SWA_TYPE_LOCAL = 4,
|
||||
};
|
||||
|
||||
struct llama_hparams_posnet {
|
||||
|
|
|
|||
|
|
@ -786,7 +786,7 @@ void llama_model::load_hparams(llama_model_loader & ml) {
|
|||
case LLM_ARCH_MODERN_BERT:
|
||||
{
|
||||
|
||||
hparams.swa_type = LLAMA_SWA_TYPE_LOCAL;
|
||||
hparams.swa_type = LLAMA_SWA_TYPE_STANDARD;
|
||||
|
||||
hparams.set_swa_pattern(3, 0);
|
||||
hparams.rope_freq_base_train_swa = 10000.f;
|
||||
|
|
@ -7845,15 +7845,16 @@ template <bool iswa>
|
|||
struct llm_build_modern_bert : public llm_graph_context {
|
||||
llm_build_modern_bert(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) {
|
||||
const int64_t n_embd_head = hparams.n_embd_head_v;
|
||||
const int64_t n_embd_gqa = hparams.n_embd_v_gqa();
|
||||
const float rope_theta_global = hparams.rope_freq_base_train;
|
||||
const float rope_theta_local = hparams.rope_freq_base_train_swa;
|
||||
//const uint32_t n_swa_local = hparams.n_swa;
|
||||
//const uint32_t n_swa_global = 4096;
|
||||
|
||||
GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
|
||||
|
||||
ggml_tensor * cur;
|
||||
ggml_tensor * inpL;
|
||||
ggml_tensor * inp_pos = build_inp_pos(); // Initialize inp_pos with build_inp_pos()
|
||||
ggml_tensor * cur = nullptr;
|
||||
ggml_tensor * inpL = nullptr;
|
||||
ggml_tensor * inp_pos = build_inp_pos();
|
||||
|
||||
// construct input embeddings (token, type, position)
|
||||
inpL = build_inp_embd(model.tok_embd);
|
||||
|
|
@ -7869,11 +7870,12 @@ struct llm_build_modern_bert : public llm_graph_context {
|
|||
for (int il = 0; il < n_layer; ++il) {
|
||||
ggml_tensor * cur = inpL;
|
||||
|
||||
ggml_tensor * Qcur;
|
||||
ggml_tensor * Kcur;
|
||||
ggml_tensor * Vcur;
|
||||
ggml_tensor * Qcur = nullptr;
|
||||
ggml_tensor * Kcur = nullptr;
|
||||
ggml_tensor * Vcur = nullptr;
|
||||
|
||||
const float rope_theta = (il+1) % 3 == 0 ? rope_theta_global : rope_theta_local;
|
||||
|
||||
const float rope_theta = il % 3 == 0 ? rope_theta_global : rope_theta_local;
|
||||
|
||||
// attention layer norm
|
||||
if (model.layers[il].attn_norm) {
|
||||
|
|
@ -7887,9 +7889,9 @@ struct llm_build_modern_bert : public llm_graph_context {
|
|||
cur = build_lora_mm(model.layers[il].wqkv, cur);
|
||||
cb(cur, "wqkv", il);
|
||||
|
||||
Qcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd, n_tokens, cur->nb[1], 0*sizeof(float)*(n_embd)));
|
||||
Kcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd)));
|
||||
Vcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd + n_embd_gqa)));
|
||||
Qcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd, n_tokens, cur->nb[1], 0*sizeof(float)*(n_embd)));
|
||||
Kcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd)));
|
||||
Vcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd*2)));
|
||||
|
||||
Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens);
|
||||
Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens);
|
||||
|
|
|
|||
|
|
@ -1864,8 +1864,8 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) {
|
|||
tokenizer_pre == "jina-v2-es" ||
|
||||
tokenizer_pre == "jina-v2-de" ||
|
||||
tokenizer_pre == "a.x-4.0" ||
|
||||
tokenizer_pre == "mellum" ||
|
||||
tokenizer_pre == "modern-bert") {
|
||||
tokenizer_pre == "mellum" ||
|
||||
tokenizer_pre == "modern-bert" ) {
|
||||
pre_type = LLAMA_VOCAB_PRE_TYPE_GPT2;
|
||||
} else if (
|
||||
tokenizer_pre == "jina-v1-en" ||
|
||||
|
|
|
|||
Loading…
Reference in New Issue