standard SWA type can be used instead of adding a new LLAMA_SWA_TYPE_LOCAL enum value

ryan-mangeno 2025-09-26 14:12:15 -04:00
parent 35667f27b3
commit 3cdd6503bd
4 changed files with 17 additions and 17 deletions
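
The reasoning behind the change, in short: the existing LLAMA_SWA_TYPE_STANDARD mask already expresses the local sliding-window attention ModernBERT needs, so a dedicated LLAMA_SWA_TYPE_LOCAL value is redundant. A minimal sketch of the window test, illustrative only (the helper name and the tiny mask printer are assumptions, not code from the tree):

```cpp
#include <cstdint>
#include <cstdio>

// Illustrative only (helper name is an assumption, not llama.cpp API):
// the "standard" sliding-window rule masks a key at position p0 for a
// query at position p1 once the key falls n_swa or more positions behind.
static bool is_outside_swa_window(int64_t p0, int64_t p1, int64_t n_swa) {
    return p1 - p0 >= n_swa;
}

int main() {
    const int64_t n_swa = 4;
    // print a tiny mask in the same spirit as print_mask(): rows = queries,
    // columns = keys, '0' = can attend, 'x' = masked
    for (int64_t p1 = 0; p1 < 8; ++p1) {
        for (int64_t p0 = 0; p0 <= p1; ++p0) {
            std::printf("%c ", is_outside_swa_window(p0, p1, n_swa) ? 'x' : '0');
        }
        std::printf("\n");
    }
    return 0;
}
```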

View File

@@ -263,8 +263,7 @@ static void print_mask(float * data, int64_t n_tokens, int64_t n_kv, int64_t n_s
const char * swa_type_str = (swa_type == LLAMA_SWA_TYPE_NONE) ? "LLAMA_SWA_TYPE_NONE" :
(swa_type == LLAMA_SWA_TYPE_STANDARD) ? "LLAMA_SWA_TYPE_STANDARD" :
(swa_type == LLAMA_SWA_TYPE_CHUNKED) ? "LLAMA_SWA_TYPE_CHUNKED" :
- (swa_type == LLAMA_SWA_TYPE_SYMMETRIC) ? "LLAMA_SWA_TYPE_SYMMETRIC" :
- (swa_type == LLAMA_SWA_TYPE_LOCAL) ? "LLAMA_SWA_TYPE_LOCAL" : "unknown";
+ (swa_type == LLAMA_SWA_TYPE_SYMMETRIC) ? "LLAMA_SWA_TYPE_SYMMETRIC" : "unknown";
LLAMA_LOG_DEBUG("%s: n_swa : %d, n_kv: %d, swq_type: %s\n", __func__, (int)n_swa, (int)n_kv, swa_type_str);
LLAMA_LOG_DEBUG("%s: '0' = can attend, '∞' = masked\n", __func__);
LLAMA_LOG_DEBUG("%s: Rows = query tokens, Columns = key/value tokens\n\n", __func__);

View File

@@ -20,7 +20,6 @@ enum llama_swa_type {
LLAMA_SWA_TYPE_STANDARD = 1,
LLAMA_SWA_TYPE_CHUNKED = 2,
LLAMA_SWA_TYPE_SYMMETRIC = 3,
- LLAMA_SWA_TYPE_LOCAL = 4,
};
struct llama_hparams_posnet {

View File

@@ -786,7 +786,7 @@ void llama_model::load_hparams(llama_model_loader & ml) {
case LLM_ARCH_MODERN_BERT:
{
- hparams.swa_type = LLAMA_SWA_TYPE_LOCAL;
+ hparams.swa_type = LLAMA_SWA_TYPE_STANDARD;
hparams.set_swa_pattern(3, 0);
hparams.rope_freq_base_train_swa = 10000.f;
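
With this, the architecture reuses the standard SWA machinery together with a layer pattern and a separate RoPE base for the sliding-window layers. A rough sketch of the intended every-third-layer layout, assuming set_swa_pattern(3, 0) yields one full-attention layer per group of three with the global layers at il % 3 == 0 (the phase is inferred from the rope_theta selection in the builder further down, not from that helper's definition):

```cpp
#include <cstdio>
#include <vector>

// Hypothetical per-layer layout for a 12-layer ModernBERT-style model:
// two sliding-window layers followed by one global layer, repeating.
// This mirrors hparams.set_swa_pattern(3, 0) only under the assumption
// stated above; it is not a reading of that helper.
int main() {
    const int n_layer = 12;
    std::vector<bool> is_swa(n_layer);
    for (int il = 0; il < n_layer; ++il) {
        is_swa[il] = (il % 3) != 0;  // global attention every third layer, starting at layer 0
        std::printf("layer %2d: %s\n", il, is_swa[il] ? "sliding-window" : "global");
    }
    return 0;
}
```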
@@ -7845,15 +7845,16 @@ template <bool iswa>
struct llm_build_modern_bert : public llm_graph_context {
llm_build_modern_bert(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) {
const int64_t n_embd_head = hparams.n_embd_head_v;
- const int64_t n_embd_gqa = hparams.n_embd_v_gqa();
const float rope_theta_global = hparams.rope_freq_base_train;
const float rope_theta_local = hparams.rope_freq_base_train_swa;
//const uint32_t n_swa_local = hparams.n_swa;
//const uint32_t n_swa_global = 4096;
GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
- ggml_tensor * cur;
- ggml_tensor * inpL;
- ggml_tensor * inp_pos = build_inp_pos(); // Initialize inp_pos with build_inp_pos()
+ ggml_tensor * cur = nullptr;
+ ggml_tensor * inpL = nullptr;
+ ggml_tensor * inp_pos = build_inp_pos();
// construct input embeddings (token, type, position)
inpL = build_inp_embd(model.tok_embd);
@@ -7869,11 +7870,12 @@ struct llm_build_modern_bert : public llm_graph_context {
for (int il = 0; il < n_layer; ++il) {
ggml_tensor * cur = inpL;
- ggml_tensor * Qcur;
- ggml_tensor * Kcur;
- ggml_tensor * Vcur;
+ ggml_tensor * Qcur = nullptr;
+ ggml_tensor * Kcur = nullptr;
+ ggml_tensor * Vcur = nullptr;
- const float rope_theta = (il+1) % 3 == 0 ? rope_theta_global : rope_theta_local;
+ const float rope_theta = il % 3 == 0 ? rope_theta_global : rope_theta_local;
// attention layer norm
if (model.layers[il].attn_norm) {
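
The rope_theta change above follows the same every-third-layer rule: global-attention layers get the full-context RoPE base while sliding-window layers use rope_freq_base_train_swa (set to 10000.f in the hparams hunk). A small sketch of the selection; the global value of 160000 is an assumption taken from the published ModernBERT config, not something this diff sets:

```cpp
#include <cstdio>

// Illustrative per-layer RoPE base selection matching the new
// `il % 3 == 0 ? rope_theta_global : rope_theta_local` line above.
// 160000.f for the global base is an assumption (ModernBERT config);
// 10000.f matches rope_freq_base_train_swa in the hparams hunk.
int main() {
    const float rope_theta_global = 160000.0f;
    const float rope_theta_local  =  10000.0f;
    for (int il = 0; il < 6; ++il) {
        const float rope_theta = il % 3 == 0 ? rope_theta_global : rope_theta_local;
        std::printf("layer %d: rope_theta = %.0f\n", il, rope_theta);
    }
    return 0;
}
```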
@@ -7887,9 +7889,9 @@ struct llm_build_modern_bert : public llm_graph_context {
cur = build_lora_mm(model.layers[il].wqkv, cur);
cb(cur, "wqkv", il);
- Qcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd, n_tokens, cur->nb[1], 0*sizeof(float)*(n_embd)));
- Kcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd)));
- Vcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd + n_embd_gqa)));
+ Qcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd, n_tokens, cur->nb[1], 0*sizeof(float)*(n_embd)));
+ Kcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd)));
+ Vcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd*2)));
Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens);
Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens);
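
The view changes above split each fused wqkv output row into Q, K and V slices that are all n_embd wide, which is consistent with ModernBERT using plain multi-head attention (n_head_kv == n_head, so n_embd_gqa == n_embd and the n_embd_gqa widths can be dropped). A sketch of the byte offsets behind the three ggml_view_2d() calls, assuming F32 activations; the width of 768 is only an example value:

```cpp
#include <cstdint>
#include <cstdio>

// Offset arithmetic for slicing one fused QKV row laid out as
// [ Q (n_embd) | K (n_embd) | V (n_embd) ]. The 768 is an example width;
// the point is the 0 / n_embd / 2*n_embd float offsets used by the views above.
int main() {
    const int64_t n_embd = 768;
    const size_t  q_off  = 0 * sizeof(float) * n_embd;
    const size_t  k_off  = 1 * sizeof(float) * n_embd;
    const size_t  v_off  = sizeof(float) * n_embd * 2;
    std::printf("Q at byte %zu, K at byte %zu, V at byte %zu of each fused row\n",
                q_off, k_off, v_off);
    return 0;
}
```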

View File

@@ -1864,8 +1864,8 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) {
tokenizer_pre == "jina-v2-es" ||
tokenizer_pre == "jina-v2-de" ||
tokenizer_pre == "a.x-4.0" ||
tokenizer_pre == "mellum" ||
tokenizer_pre == "modern-bert") {
tokenizer_pre == "mellum" ||
tokenizer_pre == "modern-bert" ) {
pre_type = LLAMA_VOCAB_PRE_TYPE_GPT2;
} else if (
tokenizer_pre == "jina-v1-en" ||