From 4657fb59e256ef8d7442548e1406e76882d1de38 Mon Sep 17 00:00:00 2001 From: alielfilali01 Date: Tue, 10 Feb 2026 08:44:35 +0000 Subject: [PATCH 1/5] model: add JAIS-2 architecture support MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Add support for the JAIS-2 family of Arabic-English bilingual models from Inception AI (https://huggingface.co/inceptionai/Jais-2-8B-Chat). Architecture characteristics: - LayerNorm (not RMSNorm) with biases - ReLU² (ReLU squared) activation function - Separate Q/K/V projections with biases - Simple MLP without gate projection (up -> act -> down) - RoPE positional embeddings - GPT-2 BPE tokenizer Supported model sizes: - Jais-2-8B (32 layers, 26 heads, 3328 hidden) - Jais-2-70B (68 layers, 56 heads, 7168 hidden) Tested with quantizations: BF16, Q8_0, Q6_K, Q5_K_M, Q5_0, Q4_K_M, Q4_0, Q3_K_M, Q2_K Note: JAIS-2 requires F32 precision accumulators for numerical stability and uses standard attention (not flash attention) on CUDA backends. --- convert_hf_to_gguf.py | 29 +++++++++ convert_hf_to_gguf_update.py | 1 + gguf-py/gguf/constants.py | 15 +++++ src/CMakeLists.txt | 1 + src/llama-arch.cpp | 15 +++++ src/llama-arch.h | 1 + src/llama-graph.cpp | 12 ++-- src/llama-model.cpp | 54 ++++++++++++++++ src/llama-vocab.cpp | 3 +- src/models/jais2.cpp | 122 +++++++++++++++++++++++++++++++++++ src/models/models.h | 4 ++ 11 files changed, 251 insertions(+), 6 deletions(-) create mode 100644 src/models/jais2.cpp diff --git a/convert_hf_to_gguf.py b/convert_hf_to_gguf.py index 2afaf85fb8..9e60560a37 100755 --- a/convert_hf_to_gguf.py +++ b/convert_hf_to_gguf.py @@ -1087,6 +1087,9 @@ class TextModel(ModelBase): if chkhsh == "0ef9807a4087ebef797fc749390439009c3b9eda9ad1a097abbe738f486c01e5": # ref: https://huggingface.co/meta-llama/Meta-Llama-3-8B res = "llama-bpe" + if chkhsh == "bc5108ee1eb6a3d600cadd065f63190fbd0554dbc9e4bbd6a0d977970afc8d2a": + # ref: https://huggingface.co/inceptionai/Jais-2-8B-Chat + res = "jais-2" if chkhsh == "049ecf7629871e3041641907f3de7c733e4dbfdc736f57d882ba0b0845599754": # ref: https://huggingface.co/deepseek-ai/deepseek-llm-7b-base res = "deepseek-llm" @@ -8521,6 +8524,32 @@ class T5EncoderModel(TextModel): yield from super().modify_tensors(data_torch, name, bid) +@ModelBase.register("Jais2ForCausalLM") +class Jais2Model(TextModel): + model_arch = gguf.MODEL_ARCH.JAIS2 + + def set_vocab(self): + self._set_vocab_gpt2() + + def set_gguf_parameters(self): + super().set_gguf_parameters() + hparams = self.hparams + head_dim = hparams.get("head_dim", hparams["hidden_size"] // hparams["num_attention_heads"]) + self.gguf_writer.add_rope_dimension_count(head_dim) + + def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]: + # Jais2 uses LLaMA-style RoPE (rotate_half), requiring Q/K permutation + n_head = self.hparams["num_attention_heads"] + n_kv_head = self.hparams.get("num_key_value_heads", n_head) + + if name.endswith(("q_proj.weight", "q_proj.bias")): + data_torch = LlamaModel.permute(data_torch, n_head, n_head) + if name.endswith(("k_proj.weight", "k_proj.bias")): + data_torch = LlamaModel.permute(data_torch, n_head, n_kv_head) + + return [(self.map_tensor_name(name), data_torch)] + + @ModelBase.register("JAISLMHeadModel") class JaisModel(TextModel): model_arch = gguf.MODEL_ARCH.JAIS diff --git a/convert_hf_to_gguf_update.py b/convert_hf_to_gguf_update.py index a683451508..041b06c4ce 100755 --- a/convert_hf_to_gguf_update.py +++ 
b/convert_hf_to_gguf_update.py @@ -113,6 +113,7 @@ models = [ {"name": "gemma", "tokt": TOKENIZER_TYPE.SPM, "repo": "https://huggingface.co/google/gemma-2b", }, {"name": "gemma-2", "tokt": TOKENIZER_TYPE.SPM, "repo": "https://huggingface.co/google/gemma-2-9b", }, {"name": "jais", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/core42/jais-13b", }, + {"name": "jais-2", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/inceptionai/Jais-2-8B-Chat", }, {"name": "t5", "tokt": TOKENIZER_TYPE.UGM, "repo": "https://huggingface.co/google-t5/t5-small", }, {"name": "codeshell", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/WisdomShell/CodeShell-7B", }, {"name": "tekken", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/mistralai/Mistral-Nemo-Base-2407", }, diff --git a/gguf-py/gguf/constants.py b/gguf-py/gguf/constants.py index 9dab0df08a..b291d6bda6 100644 --- a/gguf-py/gguf/constants.py +++ b/gguf-py/gguf/constants.py @@ -429,6 +429,7 @@ class MODEL_ARCH(IntEnum): T5 = auto() T5ENCODER = auto() JAIS = auto() + JAIS2 = auto() NEMOTRON = auto() NEMOTRON_H = auto() NEMOTRON_H_MOE = auto() @@ -862,6 +863,7 @@ MODEL_ARCH_NAMES: dict[MODEL_ARCH, str] = { MODEL_ARCH.T5: "t5", MODEL_ARCH.T5ENCODER: "t5encoder", MODEL_ARCH.JAIS: "jais", + MODEL_ARCH.JAIS2: "jais2", MODEL_ARCH.NEMOTRON: "nemotron", MODEL_ARCH.NEMOTRON_H: "nemotron_h", MODEL_ARCH.NEMOTRON_H_MOE: "nemotron_h_moe", @@ -2751,6 +2753,19 @@ MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = { MODEL_TENSOR.FFN_GATE, MODEL_TENSOR.FFN_UP, ], + MODEL_ARCH.JAIS2: [ + MODEL_TENSOR.TOKEN_EMBD, + MODEL_TENSOR.OUTPUT_NORM, + MODEL_TENSOR.OUTPUT, + MODEL_TENSOR.ATTN_NORM, + MODEL_TENSOR.ATTN_Q, + MODEL_TENSOR.ATTN_K, + MODEL_TENSOR.ATTN_V, + MODEL_TENSOR.ATTN_OUT, + MODEL_TENSOR.FFN_NORM, + MODEL_TENSOR.FFN_DOWN, + MODEL_TENSOR.FFN_UP, + ], MODEL_ARCH.NEMOTRON: [ MODEL_TENSOR.TOKEN_EMBD, MODEL_TENSOR.OUTPUT_NORM, diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index fdda05d3ea..2073282da4 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -83,6 +83,7 @@ add_library(llama models/hunyuan-moe.cpp models/internlm2.cpp models/jais.cpp + models/jais2.cpp models/jamba.cpp models/kimi-linear.cpp models/lfm2.cpp diff --git a/src/llama-arch.cpp b/src/llama-arch.cpp index a943d40dc4..c481fb7637 100644 --- a/src/llama-arch.cpp +++ b/src/llama-arch.cpp @@ -78,6 +78,7 @@ static const std::map LLM_ARCH_NAMES = { { LLM_ARCH_T5, "t5" }, { LLM_ARCH_T5ENCODER, "t5encoder" }, { LLM_ARCH_JAIS, "jais" }, + { LLM_ARCH_JAIS2, "jais2" }, { LLM_ARCH_NEMOTRON, "nemotron" }, { LLM_ARCH_NEMOTRON_H, "nemotron_h" }, { LLM_ARCH_NEMOTRON_H_MOE, "nemotron_h_moe" }, @@ -1735,6 +1736,20 @@ static std::set llm_get_tensor_names(llm_arch arch) { LLM_TENSOR_FFN_GATE, LLM_TENSOR_FFN_DOWN, }; + case LLM_ARCH_JAIS2: + return { + LLM_TENSOR_TOKEN_EMBD, + LLM_TENSOR_OUTPUT_NORM, + LLM_TENSOR_OUTPUT, + LLM_TENSOR_ATTN_NORM, + LLM_TENSOR_ATTN_Q, + LLM_TENSOR_ATTN_K, + LLM_TENSOR_ATTN_V, + LLM_TENSOR_ATTN_OUT, + LLM_TENSOR_FFN_NORM, + LLM_TENSOR_FFN_UP, + LLM_TENSOR_FFN_DOWN, + }; case LLM_ARCH_NEMOTRON_H: return { LLM_TENSOR_TOKEN_EMBD, diff --git a/src/llama-arch.h b/src/llama-arch.h index 4f7b51e70d..533ab61589 100644 --- a/src/llama-arch.h +++ b/src/llama-arch.h @@ -82,6 +82,7 @@ enum llm_arch { LLM_ARCH_T5, LLM_ARCH_T5ENCODER, LLM_ARCH_JAIS, + LLM_ARCH_JAIS2, LLM_ARCH_NEMOTRON, LLM_ARCH_NEMOTRON_H, LLM_ARCH_NEMOTRON_H_MOE, diff --git a/src/llama-graph.cpp b/src/llama-graph.cpp index bba747d37b..9c5e50990f 100644 --- 
a/src/llama-graph.cpp +++ b/src/llama-graph.cpp @@ -1099,8 +1099,8 @@ ggml_tensor * llm_graph_context::build_ffn( if (down) { cur = build_lora_mm(down, cur); - if (arch == LLM_ARCH_GLM4 || arch == LLM_ARCH_GLM4_MOE) { - // GLM4 and GLM4_MOE seem to have numerical issues with half-precision accumulators + if (arch == LLM_ARCH_GLM4 || arch == LLM_ARCH_GLM4_MOE || arch == LLM_ARCH_JAIS2) { + // GLM4, GLM4_MOE, and JAIS2 seem to have numerical issues with half-precision accumulators ggml_mul_mat_set_prec(cur, GGML_PREC_F32); } } @@ -1695,7 +1695,9 @@ ggml_tensor * llm_graph_context::build_attn_mha( ggml_tensor * cur; - if (cparams.flash_attn && kq_b == nullptr) { + // JAIS2 disabled: non-power-of-2 head count (26/56) causes numerical instability in flash attention + const bool use_flash_attn = cparams.flash_attn && kq_b == nullptr && arch != LLM_ARCH_JAIS2; + if (use_flash_attn) { GGML_ASSERT(kq_b == nullptr && "Flash attention does not support KQ bias yet"); if (v_trans) { @@ -1958,8 +1960,8 @@ ggml_tensor * llm_graph_context::build_attn( if (wo) { cur = build_lora_mm(wo, cur); - if (arch == LLM_ARCH_GLM4 || arch == LLM_ARCH_GLM4_MOE) { - // GLM4 and GLM4_MOE seem to have numerical issues with half-precision accumulators + if (arch == LLM_ARCH_GLM4 || arch == LLM_ARCH_GLM4_MOE || arch == LLM_ARCH_JAIS2) { + // GLM4, GLM4_MOE, and JAIS2 seem to have numerical issues with half-precision accumulators ggml_mul_mat_set_prec(cur, GGML_PREC_F32); } } diff --git a/src/llama-model.cpp b/src/llama-model.cpp index 5816e9a954..6adec5eebb 100644 --- a/src/llama-model.cpp +++ b/src/llama-model.cpp @@ -1884,6 +1884,16 @@ void llama_model::load_hparams(llama_model_loader & ml) { default: type = LLM_TYPE_UNKNOWN; } } break; + case LLM_ARCH_JAIS2: + { + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps); + + switch (hparams.n_layer) { + case 32: type = LLM_TYPE_8B; break; + case 68: type = LLM_TYPE_70B; break; + default: type = LLM_TYPE_UNKNOWN; + } + } break; case LLM_ARCH_NEMOTRON: { ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps); @@ -5315,6 +5325,45 @@ bool llama_model::load_tensors(llama_model_loader & ml) { layer.ffn_up_b = create_tensor(tn(LLM_TENSOR_FFN_UP, "bias", i), {n_ff}, 0); } } break; + case LLM_ARCH_JAIS2: + { + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); + + // output + output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); + output_norm_b = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "bias"), {n_embd}, 0); + output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, TENSOR_NOT_REQUIRED); + if (!output) { + output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED); + } + + for (int i = 0; i < n_layer; ++i) { + auto & layer = layers[i]; + + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); + layer.attn_norm_b = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "bias", i), {n_embd}, 0); + + layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd_head_k * n_head}, 0); + layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_k_gqa}, 0); + layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_v_gqa}, 0); + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_k * n_head, n_embd}, 0); + + // attention biases - all have shape n_embd (output dimension of projections) + layer.bq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "bias", i), {n_embd}, 0); + 
layer.bk = create_tensor(tn(LLM_TENSOR_ATTN_K, "bias", i), {n_embd}, 0); + layer.bv = create_tensor(tn(LLM_TENSOR_ATTN_V, "bias", i), {n_embd}, 0); + layer.bo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd}, 0); + + layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); + layer.ffn_norm_b = create_tensor(tn(LLM_TENSOR_FFN_NORM, "bias", i), {n_embd}, 0); + + // Jais-2 uses simple MLP (no gate) with biases + layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0); + layer.ffn_up_b = create_tensor(tn(LLM_TENSOR_FFN_UP, "bias", i), {n_ff}, 0); + layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff, n_embd}, 0); + layer.ffn_down_b = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "bias", i), {n_embd}, 0); + } + } break; case LLM_ARCH_CHATGLM: { tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); @@ -8380,6 +8429,10 @@ ggml_cgraph * llama_model::build_graph(const llm_graph_params & params) const { { llm = std::make_unique(*this, params); } break; + case LLM_ARCH_JAIS2: + { + llm = std::make_unique(*this, params); + } break; case LLM_ARCH_NEMOTRON: { llm = std::make_unique(*this, params); @@ -8737,6 +8790,7 @@ llama_rope_type llama_model_rope_type(const llama_model * model) { case LLM_ARCH_ERNIE4_5: case LLM_ARCH_ERNIE4_5_MOE: case LLM_ARCH_MISTRAL3: + case LLM_ARCH_JAIS2: case LLM_ARCH_LLAMA_EMBED: case LLM_ARCH_MAINCODER: return LLAMA_ROPE_TYPE_NORM; diff --git a/src/llama-vocab.cpp b/src/llama-vocab.cpp index 62e137fb84..7b2a7ed128 100644 --- a/src/llama-vocab.cpp +++ b/src/llama-vocab.cpp @@ -1912,7 +1912,8 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) { tokenizer_pre == "jina-v2-de" || tokenizer_pre == "a.x-4.0" || tokenizer_pre == "mellum" || - tokenizer_pre == "modern-bert" ) { + tokenizer_pre == "modern-bert" || + tokenizer_pre == "jais-2") { pre_type = LLAMA_VOCAB_PRE_TYPE_GPT2; } else if ( tokenizer_pre == "jina-v1-en" || diff --git a/src/models/jais2.cpp b/src/models/jais2.cpp new file mode 100644 index 0000000000..826f9ef4da --- /dev/null +++ b/src/models/jais2.cpp @@ -0,0 +1,122 @@ +#include "models.h" + +// JAIS-2 model graph builder +// Uses: LayerNorm (not RMSNorm), relu2 activation, separate Q/K/V, RoPE embeddings +llm_build_jais2::llm_build_jais2(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { + const int64_t n_embd_head = hparams.n_embd_head_v; + + GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); + GGML_ASSERT(n_embd_head == hparams.n_rot); + + ggml_tensor * cur; + ggml_tensor * inpL; + + inpL = build_inp_embd(model.tok_embd); + + // inp_pos - contains the positions + ggml_tensor * inp_pos = build_inp_pos(); + + // KV input for attention + auto * inp_attn = build_attn_inp_kv(); + + ggml_tensor * inp_out_ids = build_inp_out_ids(); + + for (int il = 0; il < n_layer; ++il) { + // Pre-attention LayerNorm + cur = build_norm(inpL, + model.layers[il].attn_norm, + model.layers[il].attn_norm_b, + LLM_NORM, il); + cb(cur, "attn_norm", il); + + // Self-attention with separate Q, K, V projections + { + ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur); + cb(Qcur, "Qcur", il); + Qcur = ggml_add(ctx0, Qcur, model.layers[il].bq); + cb(Qcur, "Qcur_bias", il); + + ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur); + cb(Kcur, "Kcur", il); + Kcur = ggml_add(ctx0, Kcur, model.layers[il].bk); + cb(Kcur, "Kcur_bias", il); + + ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur); + 
cb(Vcur, "Vcur", il); + Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv); + cb(Vcur, "Vcur_bias", il); + + // Reshape for attention + Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); + Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); + Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); + + // Apply RoPE + Qcur = ggml_rope_ext( + ctx0, Qcur, inp_pos, nullptr, + n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow + ); + + Kcur = ggml_rope_ext( + ctx0, Kcur, inp_pos, nullptr, + n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow + ); + + cb(Qcur, "Qcur_rope", il); + cb(Kcur, "Kcur_rope", il); + + cur = build_attn(inp_attn, + model.layers[il].wo, model.layers[il].bo, + Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); + } + + if (il == n_layer - 1 && inp_out_ids) { + cur = ggml_get_rows(ctx0, cur, inp_out_ids); + inpL = ggml_get_rows(ctx0, inpL, inp_out_ids); + } + + // Residual connection + ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpL); + cb(ffn_inp, "ffn_inp", il); + + // Pre-FFN LayerNorm + cur = build_norm(ffn_inp, + model.layers[il].ffn_norm, + model.layers[il].ffn_norm_b, + LLM_NORM, il); + cb(cur, "ffn_norm", il); + + // FFN with relu2 activation (ReLU squared) - no gate projection + // up -> relu2 -> down + cur = build_ffn(cur, + model.layers[il].ffn_up, model.layers[il].ffn_up_b, NULL, + NULL, NULL, NULL, // no gate + model.layers[il].ffn_down, model.layers[il].ffn_down_b, NULL, + NULL, + LLM_FFN_RELU_SQR, LLM_FFN_SEQ, il); + cb(cur, "ffn_out", il); + + // Residual connection + inpL = ggml_add(ctx0, cur, ffn_inp); + cb(inpL, "l_out", il); + } + + // Final LayerNorm + cur = build_norm(inpL, + model.output_norm, + model.output_norm_b, + LLM_NORM, -1); + cb(cur, "result_norm", -1); + + res->t_embd = cur; + + // Output projection + cur = build_lora_mm(model.output, cur); + cb(cur, "result_output", -1); + + res->t_logits = cur; + + ggml_build_forward_expand(gf, cur); +} diff --git a/src/models/models.h b/src/models/models.h index 3c66d32531..9a92035c84 100644 --- a/src/models/models.h +++ b/src/models/models.h @@ -284,6 +284,10 @@ struct llm_build_jais : public llm_graph_context { llm_build_jais(const llama_model & model, const llm_graph_params & params); }; +struct llm_build_jais2 : public llm_graph_context { + llm_build_jais2(const llama_model & model, const llm_graph_params & params); +}; + struct llm_build_jamba : public llm_graph_context_mamba { llm_build_jamba(const llama_model & model, const llm_graph_params & params); }; From e7c25294179d3d72a72a7b298d6f9d3663a1b0fb Mon Sep 17 00:00:00 2001 From: alielfilali01 Date: Thu, 12 Feb 2026 08:21:27 +0000 Subject: [PATCH 2/5] fix: run convert_hf_to_gguf_update.py for jais-2 tokenizer hash --- convert_hf_to_gguf.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/convert_hf_to_gguf.py b/convert_hf_to_gguf.py index 9e60560a37..4bb9d594e8 100755 --- a/convert_hf_to_gguf.py +++ b/convert_hf_to_gguf.py @@ -1087,9 +1087,6 @@ class TextModel(ModelBase): if chkhsh == "0ef9807a4087ebef797fc749390439009c3b9eda9ad1a097abbe738f486c01e5": # ref: https://huggingface.co/meta-llama/Meta-Llama-3-8B res = "llama-bpe" - if chkhsh == "bc5108ee1eb6a3d600cadd065f63190fbd0554dbc9e4bbd6a0d977970afc8d2a": - # ref: https://huggingface.co/inceptionai/Jais-2-8B-Chat - res = "jais-2" if chkhsh == 
"049ecf7629871e3041641907f3de7c733e4dbfdc736f57d882ba0b0845599754": # ref: https://huggingface.co/deepseek-ai/deepseek-llm-7b-base res = "deepseek-llm" @@ -1162,6 +1159,9 @@ class TextModel(ModelBase): if chkhsh == "b53802fb28e26d645c3a310b34bfe07da813026ec7c7716883404d5e0f8b1901": # ref: https://huggingface.co/core42/jais-13b res = "jais" + if chkhsh == "bc5108ee1eb6a3d600cadd065f63190fbd0554dbc9e4bbd6a0d977970afc8d2a": + # ref: https://huggingface.co/inceptionai/Jais-2-8B-Chat + res = "jais-2" if chkhsh == "7b3e7548e4308f52a76e8229e4e6cc831195d0d1df43aed21ac6c93da05fec5f": # ref: https://huggingface.co/WisdomShell/CodeShell-7B res = "codeshell" From 56571c32385930d6abc18a52ed6fe795f41a3f70 Mon Sep 17 00:00:00 2001 From: alielfilali01 Date: Thu, 12 Feb 2026 08:22:12 +0000 Subject: [PATCH 3/5] fix: use NEOX RoPE type for JAIS2 --- src/llama-model.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/llama-model.cpp b/src/llama-model.cpp index 6adec5eebb..0eb6f71e94 100644 --- a/src/llama-model.cpp +++ b/src/llama-model.cpp @@ -8790,7 +8790,6 @@ llama_rope_type llama_model_rope_type(const llama_model * model) { case LLM_ARCH_ERNIE4_5: case LLM_ARCH_ERNIE4_5_MOE: case LLM_ARCH_MISTRAL3: - case LLM_ARCH_JAIS2: case LLM_ARCH_LLAMA_EMBED: case LLM_ARCH_MAINCODER: return LLAMA_ROPE_TYPE_NORM; @@ -8841,6 +8840,7 @@ llama_rope_type llama_model_rope_type(const llama_model * model) { case LLM_ARCH_BAILINGMOE2: case LLM_ARCH_DOTS1: case LLM_ARCH_HUNYUAN_MOE: + case LLM_ARCH_JAIS2: case LLM_ARCH_OPENAI_MOE: case LLM_ARCH_HUNYUAN_DENSE: case LLM_ARCH_LFM2: From d9a442f60221486e3abfc9387e61f33ed4b3c4e8 Mon Sep 17 00:00:00 2001 From: alielfilali01 Date: Thu, 12 Feb 2026 08:22:44 +0000 Subject: [PATCH 4/5] fix: remove Q/K permutation (NEOX RoPE doesn't need it) --- convert_hf_to_gguf.py | 12 ------------ 1 file changed, 12 deletions(-) diff --git a/convert_hf_to_gguf.py b/convert_hf_to_gguf.py index 4bb9d594e8..8f5812877f 100755 --- a/convert_hf_to_gguf.py +++ b/convert_hf_to_gguf.py @@ -8537,18 +8537,6 @@ class Jais2Model(TextModel): head_dim = hparams.get("head_dim", hparams["hidden_size"] // hparams["num_attention_heads"]) self.gguf_writer.add_rope_dimension_count(head_dim) - def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]: - # Jais2 uses LLaMA-style RoPE (rotate_half), requiring Q/K permutation - n_head = self.hparams["num_attention_heads"] - n_kv_head = self.hparams.get("num_key_value_heads", n_head) - - if name.endswith(("q_proj.weight", "q_proj.bias")): - data_torch = LlamaModel.permute(data_torch, n_head, n_head) - if name.endswith(("k_proj.weight", "k_proj.bias")): - data_torch = LlamaModel.permute(data_torch, n_head, n_kv_head) - - return [(self.map_tensor_name(name), data_torch)] - @ModelBase.register("JAISLMHeadModel") class JaisModel(TextModel): From cbe37e3b67947548b58dd4ecc4ba55826aa2f796 Mon Sep 17 00:00:00 2001 From: alielfilali01 Date: Thu, 12 Feb 2026 08:23:14 +0000 Subject: [PATCH 5/5] fix: enable flash attention for JAIS2 (fixed by #19115) --- src/llama-graph.cpp | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/src/llama-graph.cpp b/src/llama-graph.cpp index 9c5e50990f..ad0c354ba9 100644 --- a/src/llama-graph.cpp +++ b/src/llama-graph.cpp @@ -1695,8 +1695,7 @@ ggml_tensor * llm_graph_context::build_attn_mha( ggml_tensor * cur; - // JAIS2 disabled: non-power-of-2 head count (26/56) causes numerical instability in flash attention - const bool use_flash_attn = cparams.flash_attn && kq_b == 
nullptr && arch != LLM_ARCH_JAIS2; + const bool use_flash_attn = cparams.flash_attn && kq_b == nullptr; if (use_flash_attn) { GGML_ASSERT(kq_b == nullptr && "Flash attention does not support KQ bias yet");
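
---

For readers who want a quick mental model of the graph built in src/models/jais2.cpp, below is a minimal, self-contained PyTorch sketch of one decoder block, following the architecture notes in the first commit message: LayerNorm (not RMSNorm) with biases, separate Q/K/V projections with biases, NeoX-style RoPE, and a gate-less up -> ReLU² -> down MLP. The LayerNorm epsilon, RoPE base, and the use of full multi-head attention (no GQA) are illustrative assumptions here, not values taken from the released checkpoints or the reference implementation.

```python
# Hedged sketch of a JAIS-2-style decoder block; hyperparameters are assumptions.
import torch
import torch.nn as nn
import torch.nn.functional as F


def rope(x: torch.Tensor, pos: torch.Tensor, theta: float = 10000.0) -> torch.Tensor:
    # x: [batch, heads, seq, head_dim]; NeoX-style rotation over (i, i + d/2) pairs
    d = x.shape[-1]
    inv_freq = 1.0 / (theta ** (torch.arange(0, d, 2, device=x.device, dtype=torch.float32) / d))
    ang = pos.float()[:, None] * inv_freq[None, :]          # [seq, d/2]
    cos, sin = ang.cos(), ang.sin()
    x1, x2 = x[..., : d // 2], x[..., d // 2:]
    return torch.cat([x1 * cos - x2 * sin, x2 * cos + x1 * sin], dim=-1)


class Jais2Block(nn.Module):
    def __init__(self, n_embd: int, n_head: int, n_ff: int, eps: float = 1e-5):
        super().__init__()
        self.n_head, self.head_dim = n_head, n_embd // n_head
        self.attn_norm = nn.LayerNorm(n_embd, eps=eps)       # LayerNorm, not RMSNorm
        self.q_proj = nn.Linear(n_embd, n_embd, bias=True)   # separate Q/K/V, all with biases
        self.k_proj = nn.Linear(n_embd, n_embd, bias=True)
        self.v_proj = nn.Linear(n_embd, n_embd, bias=True)
        self.o_proj = nn.Linear(n_embd, n_embd, bias=True)
        self.ffn_norm = nn.LayerNorm(n_embd, eps=eps)
        self.up_proj = nn.Linear(n_embd, n_ff, bias=True)    # no gate projection
        self.down_proj = nn.Linear(n_ff, n_embd, bias=True)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        b, t, c = x.shape
        pos = torch.arange(t, device=x.device)
        # pre-attention LayerNorm, then attention with RoPE on Q and K
        h = self.attn_norm(x)
        q = self.q_proj(h).view(b, t, self.n_head, self.head_dim).transpose(1, 2)
        k = self.k_proj(h).view(b, t, self.n_head, self.head_dim).transpose(1, 2)
        v = self.v_proj(h).view(b, t, self.n_head, self.head_dim).transpose(1, 2)
        q, k = rope(q, pos), rope(k, pos)
        attn = F.scaled_dot_product_attention(q, k, v, is_causal=True)
        x = x + self.o_proj(attn.transpose(1, 2).reshape(b, t, c))
        # pre-FFN LayerNorm, then gate-less MLP: up -> ReLU^2 -> down
        h = self.ffn_norm(x)
        h = self.down_proj(torch.relu(self.up_proj(h)) ** 2)
        return x + h
```

With these patches applied, the usual llama.cpp flow should work end to end (file names below are illustrative): convert with `python convert_hf_to_gguf.py ./Jais-2-8B-Chat --outfile jais2-8b-bf16.gguf --outtype bf16`, optionally quantize with `llama-quantize jais2-8b-bf16.gguf jais2-8b-Q4_K_M.gguf Q4_K_M`, and run with `llama-cli -m jais2-8b-Q4_K_M.gguf -p "..."`.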