llama : add support for EmbeddingGemma 300m (#15798)

This commit adds support for EmbeddingGemma 300m. This model uses
sliding window attention (SWA), and a new swa_type
(LLAMA_SWA_TYPE_SYMMETRIC) is introduced to support symmetric SWA
masking, where a token can attend to positions up to half the window
size both before and after it, rather than only to past positions.

This commit also moves the SWA masking logic out of
llama_kv_cache::is_masked_swa into llama_hparams::is_masked_swa, so
that the logic can be shared by both
llm_graph_input_attn_no_cache::set_input and
llama_kv_cache::set_input_kq_mask.

With this commit the EmbeddingGemma 300m model can be converted to
GGUF and used with llama.cpp.
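
For a local checkout of the Hugging Face checkpoint, a conversion along
these lines should work (paths and output names are placeholders; the
usual convert_hf_to_gguf.py options --outfile/--outtype are assumed):
```console
# hypothetical paths; adjust to where the checkpoint lives
python convert_hf_to_gguf.py /path/to/embeddinggemma-300m \
    --outfile embeddinggemma-300m-f16.gguf --outtype f16
```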

Once the model has been uploaded to HuggingFace it can be used like
this:
```console
./build/bin/llama-cli -hf ggml-org/embeddinggemma-300m-GGUF:Q8_0
```
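For embedding extraction specifically, something like the following
should also work (the llama-embedding example and its --pooling option
are assumed here; adjust as needed):
```console
./build/bin/llama-embedding -hf ggml-org/embeddinggemma-300m-GGUF:Q8_0 \
    --pooling mean -p "Hello world"
```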
Daniel Bevenius 2025-09-04 18:10:29 +02:00 committed by GitHub
parent 856ed0947f
commit fb15d649ed
15 changed files with 328 additions and 47 deletions

```diff
@@ -5122,6 +5122,15 @@ class Gemma3Model(TextModel):
         return [(self.map_tensor_name(name), data_torch)]
 
 
+@ModelBase.register("Gemma3TextModel")
+class EmbeddingGemma(Gemma3Model):
+    model_arch = gguf.MODEL_ARCH.GEMMA_EMBEDDING
+
+    def set_gguf_parameters(self):
+        super().set_gguf_parameters()
+        self._try_set_pooling_type()
+
+
 @ModelBase.register("Gemma3ForConditionalGeneration")
 class Gemma3VisionModel(MmprojModel):
     def set_gguf_parameters(self):
```

```diff
@@ -340,6 +340,7 @@ class MODEL_ARCH(IntEnum):
     GEMMA2 = auto()
     GEMMA3 = auto()
     GEMMA3N = auto()
+    GEMMA_EMBEDDING = auto()
     STARCODER2 = auto()
     RWKV6 = auto()
     RWKV6QWEN2 = auto()
@@ -674,6 +675,7 @@ MODEL_ARCH_NAMES: dict[MODEL_ARCH, str] = {
     MODEL_ARCH.GEMMA2: "gemma2",
     MODEL_ARCH.GEMMA3: "gemma3",
     MODEL_ARCH.GEMMA3N: "gemma3n",
+    MODEL_ARCH.GEMMA_EMBEDDING: "gemma-embedding",
     MODEL_ARCH.STARCODER2: "starcoder2",
     MODEL_ARCH.RWKV6: "rwkv6",
     MODEL_ARCH.RWKV6QWEN2: "rwkv6qwen2",
@@ -1719,6 +1721,24 @@ MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = {
         MODEL_TENSOR.LAUREL_R,
         MODEL_TENSOR.LAUREL_POST_NORM,
     ],
+    MODEL_ARCH.GEMMA_EMBEDDING: [
+        MODEL_TENSOR.TOKEN_EMBD,
+        MODEL_TENSOR.OUTPUT,
+        MODEL_TENSOR.OUTPUT_NORM,
+        MODEL_TENSOR.ATTN_Q,
+        MODEL_TENSOR.ATTN_Q_NORM,
+        MODEL_TENSOR.ATTN_K,
+        MODEL_TENSOR.ATTN_K_NORM,
+        MODEL_TENSOR.ATTN_V,
+        MODEL_TENSOR.ATTN_OUT,
+        MODEL_TENSOR.FFN_GATE,
+        MODEL_TENSOR.FFN_DOWN,
+        MODEL_TENSOR.FFN_UP,
+        MODEL_TENSOR.ATTN_NORM,
+        MODEL_TENSOR.ATTN_POST_NORM,
+        MODEL_TENSOR.FFN_PRE_NORM,
+        MODEL_TENSOR.FFN_POST_NORM,
+    ],
     MODEL_ARCH.STARCODER2: [
         MODEL_TENSOR.TOKEN_EMBD,
         MODEL_TENSOR.OUTPUT_NORM,
```

```diff
@@ -14,6 +14,7 @@ class TensorNameMap:
             "transformer.word_embeddings", # falcon
             "word_embeddings", # bloom
             "model.embed_tokens", # llama-hf nemotron olmoe olmo2 rwkv6qwen2 glm4-0414 plamo2 granite-hybrid
+            "embed_tokens", # embeddinggemma
             "tok_embeddings", # llama-pth
             "embeddings.word_embeddings", # bert nomic-bert
             "language_model.embedding.word_embeddings", # persimmon
@@ -141,6 +142,7 @@ class TensorNameMap:
             "rwkv.blocks.{bid}.ln1", # rwkv6
             "model.layers.{bid}.ln1", # rwkv7
             "model.layers.{bid}.input_layernorm", # llama4
+            "layers.{bid}.input_layernorm", # embeddinggemma
             "transformer_encoder.{bid}.attention_norm", # neobert
             "model.layers.{bid}.operator_norm", # lfm2
             "model.transformer.blocks.{bid}.attn_norm", # llada
@@ -179,6 +181,7 @@ class TensorNameMap:
         # Attention query
         MODEL_TENSOR.ATTN_Q: (
             "model.layers.{bid}.self_attn.q_proj", # llama-hf nemotron olmoe olmo2 phimoe
+            "layers.{bid}.self_attn.q_proj", # embeddinggemma
             "model.layers.{bid}.self_attn.q_proj_no_perm", # llama-custom
             "layers.{bid}.attention.wq", # llama-pth
             "encoder.layer.{bid}.attention.self.query", # bert
@@ -197,6 +200,7 @@ class TensorNameMap:
         # Attention key
         MODEL_TENSOR.ATTN_K: (
             "model.layers.{bid}.self_attn.k_proj", # llama-hf nemotron olmoe olmo2 phimoe
+            "layers.{bid}.self_attn.k_proj", # embeddinggemma
             "model.layers.{bid}.self_attn.k_proj_no_perm", # llama-custom
             "layers.{bid}.attention.wk", # llama-pth
             "encoder.layer.{bid}.attention.self.key", # bert
@@ -216,6 +220,7 @@ class TensorNameMap:
         # Attention value
         MODEL_TENSOR.ATTN_V: (
             "model.layers.{bid}.self_attn.v_proj", # llama-hf nemotron olmoe olmo2 phimoe
+            "layers.{bid}.self_attn.v_proj", # embeddinggemma
             "layers.{bid}.attention.wv", # llama-pth
             "encoder.layer.{bid}.attention.self.value", # bert
             "transformer.layer.{bid}.attention.v_lin", # distillbert
@@ -239,6 +244,7 @@ class TensorNameMap:
             "transformer.h.{bid}.self_attention.dense", # falcon
             "h.{bid}.self_attention.dense", # bloom
             "model.layers.{bid}.self_attn.o_proj", # llama-hf nemotron olmoe olmo2 phimoe
+            "layers.{bid}.self_attn.o_proj", # embeddinggemma
             "model.layers.{bid}.self_attn.out_proj", # lfm2
             "model.layers.{bid}.self_attn.linear_attn", # deci
             "layers.{bid}.attention.wo", # llama-pth
@@ -277,6 +283,7 @@ class TensorNameMap:
         MODEL_TENSOR.ATTN_POST_NORM: (
             "model.layers.{bid}.post_attention_layernorm", # gemma2 olmo2 # ge
+            "layers.{bid}.post_attention_layernorm", # embeddinggemma
             "model.layers.{bid}.post_self_attn_layernorm", # glm-4-0414
             "model.layers.layers.{bid}.post_mixer_norm.weight", # plamo2
         ),
@@ -320,12 +327,14 @@ class TensorNameMap:
         # Post feed-forward norm
         MODEL_TENSOR.FFN_PRE_NORM: (
             "model.layers.{bid}.pre_feedforward_layernorm", # gemma2
+            "layers.{bid}.pre_feedforward_layernorm", # embeddinggemma
             "model.layers.{bid}.pre_ff_layernorm.weight",
         ),
 
         # Post feed-forward norm
         MODEL_TENSOR.FFN_POST_NORM: (
             "model.layers.{bid}.post_feedforward_layernorm", # gemma2 olmo2
+            "layers.{bid}.post_feedforward_layernorm", # embeddinggemma
             "model.layers.{bid}.post_mlp_layernorm", # glm-4-0414
             "model.layers.layers.{bid}.post_mlp_norm.weight", # plamo2
             "model.layers.{bid}.feed_forward.up_proj",
@@ -362,6 +371,7 @@ class TensorNameMap:
             "transformer.h.{bid}.mlp.dense_h_to_4h", # falcon
             "h.{bid}.mlp.dense_h_to_4h", # bloom
             "model.layers.{bid}.mlp.up_proj", # llama-hf refact nemotron olmo2
+            "layers.{bid}.mlp.up_proj", # embeddinggemma
             "layers.{bid}.feed_forward.w3", # llama-pth
             "encoder.layer.{bid}.intermediate.dense", # bert
             "transformer.layer.{bid}.ffn.lin1", # distillbert
@@ -421,6 +431,7 @@ class TensorNameMap:
         # Feed-forward gate
         MODEL_TENSOR.FFN_GATE: (
             "model.layers.{bid}.mlp.gate_proj", # llama-hf refact olmo2
+            "layers.{bid}.mlp.gate_proj", # embeddinggemma
             "layers.{bid}.feed_forward.w1", # llama-pth
             "transformer.h.{bid}.mlp.w2", # qwen
             "transformer.h.{bid}.mlp.c_fc2", # jais
@@ -461,6 +472,7 @@ class TensorNameMap:
             "transformer.h.{bid}.mlp.dense_4h_to_h", # falcon
             "h.{bid}.mlp.dense_4h_to_h", # bloom
             "model.layers.{bid}.mlp.down_proj", # llama-hf nemotron olmo2
+            "layers.{bid}.mlp.down_proj", # embeddinggemma
             "layers.{bid}.feed_forward.w2", # llama-pth
             "encoder.layer.{bid}.output.dense", # bert
             "transformer.layer.{bid}.ffn.lin2", # distillbert
@@ -513,6 +525,7 @@ class TensorNameMap:
             "model.layers.{bid}.self_attn.q_layernorm", # persimmon
             "model.layers.{bid}.self_attn.query_layernorm", # hunyuan
             "model.layers.{bid}.self_attn.q_norm", # cohere olmoe chameleon olmo2
+            "layers.{bid}.self_attn.q_norm", # embeddinggemma
             "transformer.blocks.{bid}.attn.q_ln", # sea-lion
             "encoder.layer.{bid}.attention.self.layer_norm_q", # jina-bert-v2
             "transformer.layers.{bid}.attn.q_norm", # openelm
@@ -525,6 +538,7 @@ class TensorNameMap:
             "model.layers.{bid}.self_attn.k_layernorm", # persimmon
             "model.layers.{bid}.self_attn.key_layernorm", # hunyuan
             "model.layers.{bid}.self_attn.k_norm", # cohere olmoe chameleon olmo2
+            "layers.{bid}.self_attn.k_norm", # embeddinggemma
             "transformer.blocks.{bid}.attn.k_ln", # sea-lion
             "encoder.layer.{bid}.attention.self.layer_norm_k", # jina-bert-v2
             "transformer.layers.{bid}.attn.k_norm", # openelm
```

```diff
@@ -45,6 +45,7 @@ static const std::map<llm_arch, const char *> LLM_ARCH_NAMES = {
     { LLM_ARCH_GEMMA2, "gemma2" },
     { LLM_ARCH_GEMMA3, "gemma3" },
     { LLM_ARCH_GEMMA3N, "gemma3n" },
+    { LLM_ARCH_GEMMA_EMBEDDING, "gemma-embedding" },
     { LLM_ARCH_STARCODER2, "starcoder2" },
     { LLM_ARCH_MAMBA, "mamba" },
     { LLM_ARCH_MAMBA2, "mamba2" },
@@ -1038,6 +1039,27 @@ static const std::map<llm_arch, std::map<llm_tensor, const char *>> LLM_TENSOR_N
             { LLM_TENSOR_LAUREL_POST_NORM, "blk.%d.laurel_post_norm" },
         },
     },
+    {
+        LLM_ARCH_GEMMA_EMBEDDING,
+        {
+            { LLM_TENSOR_TOKEN_EMBD, "token_embd" },
+            { LLM_TENSOR_OUTPUT_NORM, "output_norm" },
+            { LLM_TENSOR_OUTPUT, "output" },
+            { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" },
+            { LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" },
+            { LLM_TENSOR_ATTN_Q_NORM, "blk.%d.attn_q_norm" },
+            { LLM_TENSOR_ATTN_K, "blk.%d.attn_k" },
+            { LLM_TENSOR_ATTN_K_NORM, "blk.%d.attn_k_norm" },
+            { LLM_TENSOR_ATTN_V, "blk.%d.attn_v" },
+            { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" },
+            { LLM_TENSOR_ATTN_POST_NORM, "blk.%d.post_attention_norm" },
+            { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" },
+            { LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" },
+            { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" },
+            { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" },
+            { LLM_TENSOR_FFN_POST_NORM, "blk.%d.post_ffw_norm" },
+        },
+    },
     {
         LLM_ARCH_STARCODER2,
         {
```

```diff
@@ -49,6 +49,7 @@ enum llm_arch {
     LLM_ARCH_GEMMA2,
     LLM_ARCH_GEMMA3,
     LLM_ARCH_GEMMA3N,
+    LLM_ARCH_GEMMA_EMBEDDING,
     LLM_ARCH_STARCODER2,
     LLM_ARCH_MAMBA,
     LLM_ARCH_MAMBA2,
```

```diff
@@ -258,6 +258,36 @@ void llm_graph_input_cross_embd::set_input(const llama_ubatch * ubatch) {
     }
 }
 
+static void print_mask(float * data, int64_t n_tokens, int64_t n_kv, int64_t n_swa, llama_swa_type swa_type) {
+    LLAMA_LOG_DEBUG("%s: === Attention mask ===\n", __func__);
+    const char * swa_type_str = (swa_type == LLAMA_SWA_TYPE_NONE)      ? "LLAMA_SWA_TYPE_NONE" :
+                                (swa_type == LLAMA_SWA_TYPE_STANDARD)  ? "LLAMA_SWA_TYPE_STANDARD" :
+                                (swa_type == LLAMA_SWA_TYPE_CHUNKED)   ? "LLAMA_SWA_TYPE_CHUNKED" :
+                                (swa_type == LLAMA_SWA_TYPE_SYMMETRIC) ? "LLAMA_SWA_TYPE_SYMMETRIC" : "unknown";
+    LLAMA_LOG_DEBUG("%s: n_swa : %d, n_kv: %d, swq_type: %s\n", __func__, (int)n_swa, (int)n_kv, swa_type_str);
+    LLAMA_LOG_DEBUG("%s: '0' = can attend, '∞' = masked\n", __func__);
+    LLAMA_LOG_DEBUG("%s: Rows = query tokens, Columns = key/value tokens\n\n", __func__);
+
+    LLAMA_LOG_DEBUG("    ");
+    for (int j = 0; j < std::min((int64_t)20, n_kv); ++j) {
+        LLAMA_LOG_DEBUG("%2d", j);
+    }
+    LLAMA_LOG_DEBUG("\n");
+
+    for (int i = 0; i < std::min((int64_t)20, n_tokens); ++i) {
+        LLAMA_LOG_DEBUG(" %2d ", i);
+        for (int j = 0; j < std::min((int64_t)20, n_kv); ++j) {
+            float val = data[i * n_kv + j];
+            if (val == -INFINITY) {
+                LLAMA_LOG_DEBUG(" ∞");
+            } else {
+                LLAMA_LOG_DEBUG(" 0");
+            }
+        }
+        LLAMA_LOG_DEBUG("\n");
+    }
+}
+
 void llm_graph_input_attn_no_cache::set_input(const llama_ubatch * ubatch) {
     const int64_t n_kv     = ubatch->n_tokens;
     const int64_t n_tokens = ubatch->n_tokens;
@@ -277,21 +307,32 @@ void llm_graph_input_attn_no_cache::set_input(const llama_ubatch * ubatch) {
                 for (int s = 0; s < ubatch->n_seq_id[i0]; ++s) {
                     const llama_seq_id s0 = ubatch->seq_id[i0][0];
 
-                    // TODO: reimplement this like in llama_kv_cache
-                    if (s0 == s1 && (!cparams.causal_attn || ubatch->pos[i0] <= ubatch->pos[i1])) {
-                        if (hparams.use_alibi) {
-                            f = -std::abs(ubatch->pos[i0] - ubatch->pos[i1]);
-                        } else {
-                            f = 0.0f;
-                        }
-                        break;
+                    if (s0 != s1) {
+                        continue; // skip different sequences
+                    }
+
+                    if (cparams.causal_attn && ubatch->pos[i0] > ubatch->pos[i1]) {
+                        continue; // skip future tokens for causal attention
+                    }
+
+                    if (hparams.is_masked_swa(ubatch->pos[i0], ubatch->pos[i1])) {
+                        continue; // skip masked tokens for SWA
+                    }
+
+                    // TODO: reimplement this like in llama_kv_cache_unified
+                    if (hparams.use_alibi) {
+                        f = -std::abs(ubatch->pos[i0] - ubatch->pos[i1]);
+                    } else {
+                        f = 0.0f;
                     }
                 }
 
                 data[h*(n_kv*n_tokens) + i1*n_kv + i0] = f;
             }
         }
     }
+
+    if (debug) {
+        print_mask(data, n_tokens, n_kv, hparams.n_swa, hparams.swa_type);
+    }
 }
 
 void llm_graph_input_attn_kv::set_input(const llama_ubatch * ubatch) {
```

```diff
@@ -78,6 +78,11 @@ struct llm_graph_params;
 
 class llm_graph_input_i {
 public:
+    llm_graph_input_i() {
+        const char * LLAMA_GRAPH_INPUT_DEBUG = getenv("LLAMA_GRAPH_INPUT_DEBUG");
+        debug = LLAMA_GRAPH_INPUT_DEBUG ? atoi(LLAMA_GRAPH_INPUT_DEBUG) : 0;
+    }
+
     virtual ~llm_graph_input_i() = default;
 
     virtual void set_input(const llama_ubatch * ubatch) = 0;
@@ -90,6 +95,9 @@ public:
         GGML_UNUSED(params);
         return false;
     }
+
+protected:
+    // env: LLAMA_GRAPH_INPUT_DEBUG
+    int debug = 0;
 };
 
 using llm_graph_input_ptr = std::unique_ptr<llm_graph_input_i>;
```
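The new `debug` member is read from the `LLAMA_GRAPH_INPUT_DEBUG`
environment variable and gates the `print_mask` output added above. A
hedged way to try it (the mask is printed at debug log level, so a
debug build or a raised log verbosity is assumed):
```console
LLAMA_GRAPH_INPUT_DEBUG=1 ./build/bin/llama-embedding \
    -hf ggml-org/embeddinggemma-300m-GGUF:Q8_0 -p "test"
```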

```diff
@@ -1,6 +1,7 @@
 #include "llama-hparams.h"
 
 #include "ggml.h"
+#include <cassert>
 
 void llama_hparams::set_swa_pattern(uint32_t n_pattern, bool dense_first) {
     if (dense_first) {
@@ -178,3 +179,39 @@
     return res;
 }
+
+bool llama_hparams::is_masked_swa(llama_pos p0, llama_pos p1) const {
+    assert(p0 >= 0 && p1 >= 0);
+
+    switch (swa_type) {
+        case LLAMA_SWA_TYPE_NONE:
+            {
+            } break;
+        case LLAMA_SWA_TYPE_STANDARD:
+            {
+                if (p1 - p0 >= (int32_t) n_swa) {
+                    return true;
+                }
+            } break;
+        case LLAMA_SWA_TYPE_CHUNKED:
+            {
+                const llama_pos pos_chunk_start = (p1 / n_swa) * n_swa;
+
+                if (p0 < pos_chunk_start) {
+                    return true;
+                }
+            } break;
+        case LLAMA_SWA_TYPE_SYMMETRIC:
+            {
+                const int32_t half_n_swa = (int32_t) n_swa / 2;
+                const int32_t pos_diff   = p1 - p0;
+
+                // Mask if outside the symmetric window
+                if (pos_diff < -half_n_swa || pos_diff > half_n_swa) {
+                    return true;
+                }
+            } break;
+    }
+
+    return false;
+}
```
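As a quick illustration of the symmetric case above, here is a minimal
standalone C++ sketch (not llama.cpp code) that prints which query/key
position pairs would be masked for a small window, assuming n_swa = 4:
```cpp
// Standalone sketch of the LLAMA_SWA_TYPE_SYMMETRIC rule: mask when |p1 - p0| > n_swa / 2.
#include <cstdio>

static bool is_masked_symmetric(int p0, int p1, int n_swa) {
    const int half_n_swa = n_swa / 2;
    const int pos_diff   = p1 - p0;
    return pos_diff < -half_n_swa || pos_diff > half_n_swa;
}

int main() {
    const int n_swa = 4; // assumed window size for illustration
    // rows: query position p1, columns: key/value position p0 ('x' = masked, '0' = can attend)
    for (int p1 = 0; p1 < 8; ++p1) {
        for (int p0 = 0; p0 < 8; ++p0) {
            std::printf("%c ", is_masked_symmetric(p0, p1, n_swa) ? 'x' : '0');
        }
        std::printf("\n");
    }
    return 0;
}
```
With n_swa = 4 each query attends to at most two positions before and
two after itself, matching the half-window check in
llama_hparams::is_masked_swa.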

```diff
@@ -16,9 +16,10 @@ enum llama_expert_gating_func_type {
 };
 
 enum llama_swa_type {
     LLAMA_SWA_TYPE_NONE      = 0,
     LLAMA_SWA_TYPE_STANDARD  = 1,
     LLAMA_SWA_TYPE_CHUNKED   = 2,
+    LLAMA_SWA_TYPE_SYMMETRIC = 3,
 };
 
 struct llama_hparams_posnet {
@@ -227,6 +228,8 @@ struct llama_hparams {
 
     // number of layers for which has_kv() returns true
     uint32_t n_layer_kv() const;
+
+    bool is_masked_swa(llama_pos p0, llama_pos p1) const;
 };
 
 static_assert(std::is_trivially_copyable<llama_hparams>::value, "llama_hparams must be trivially copyable");
```

```diff
@@ -60,14 +60,14 @@ llama_kv_cache_iswa::llama_kv_cache_iswa(
     kv_base = std::make_unique<llama_kv_cache>(
             model, type_k, type_v,
             v_trans, offload, unified, size_base, n_seq_max, n_pad,
-            0, LLAMA_SWA_TYPE_NONE, filter_base, reuse);
+            0, filter_base, reuse);
 
     LLAMA_LOG_INFO("%s: creating SWA KV cache, size = %u cells\n", __func__, size_swa);
 
     kv_swa = std::make_unique<llama_kv_cache>(
             model, type_k, type_v,
             v_trans, offload, unified, size_swa, n_seq_max, n_pad,
-            hparams.n_swa, hparams.swa_type, filter_swa, reuse);
+            hparams.n_swa, filter_swa, reuse);
 }
 
 void llama_kv_cache_iswa::clear(bool data) {
```

```diff
@@ -27,11 +27,10 @@ llama_kv_cache::llama_kv_cache(
         uint32_t n_seq_max,
         uint32_t n_pad,
         uint32_t n_swa,
-        llama_swa_type swa_type,
         const layer_filter_cb & filter,
         const layer_reuse_cb & reuse) :
     model(model), hparams(model.hparams), v_trans(v_trans),
-    n_seq_max(n_seq_max), n_stream(unified ? 1 : n_seq_max), n_pad(n_pad), n_swa(n_swa), swa_type(swa_type) {
+    n_seq_max(n_seq_max), n_stream(unified ? 1 : n_seq_max), n_pad(n_pad), n_swa(n_swa) {
 
     GGML_ASSERT(kv_size % n_pad == 0);
@@ -1393,29 +1392,7 @@ ggml_cgraph * llama_kv_cache::build_graph_shift(llm_graph_result * res, llama_co
 }
 
 bool llama_kv_cache::is_masked_swa(llama_pos p0, llama_pos p1) const {
-    assert(p0 >= 0 && p1 >= 0);
-
-    switch (swa_type) {
-        case LLAMA_SWA_TYPE_NONE:
-            {
-            } break;
-        case LLAMA_SWA_TYPE_STANDARD:
-            {
-                if (p1 - p0 >= (int32_t) n_swa) {
-                    return true;
-                }
-            } break;
-        case LLAMA_SWA_TYPE_CHUNKED:
-            {
-                const llama_pos pos_chunk_start = (p1 / n_swa) * n_swa;
-
-                if (p0 < pos_chunk_start) {
-                    return true;
-                }
-            } break;
-    }
-
-    return false;
+    return hparams.is_masked_swa(p0, p1);
 }
 
 void llama_kv_cache::state_write(llama_io_write_i & io, llama_seq_id seq_id, llama_state_seq_flags flags) const {
```

```diff
@@ -89,7 +89,6 @@ public:
             uint32_t n_seq_max,
             uint32_t n_pad,
             uint32_t n_swa,
-            llama_swa_type swa_type,
             const layer_filter_cb & filter,
             const layer_reuse_cb & reuse);
 
@@ -212,8 +211,6 @@ private:
     // env: LLAMA_KV_CACHE_DEBUG
     int debug = 0;
 
-    const llama_swa_type swa_type = LLAMA_SWA_TYPE_NONE;
-
     std::vector<ggml_context_ptr> ctxs;
     std::vector<ggml_backend_buffer_ptr> bufs;
```

```diff
@@ -17,7 +17,6 @@ llama_memory_hybrid::llama_memory_hybrid(
         uint32_t kv_size,
         uint32_t n_pad,
         uint32_t n_swa,
-        llama_swa_type swa_type,
     /* recurrent */
         ggml_type type_r,
         ggml_type type_s,
@@ -41,7 +40,6 @@ llama_memory_hybrid::llama_memory_hybrid(
             n_seq_max,
             n_pad,
             n_swa,
-            swa_type,
             filter_attn == nullptr ?
                 [&](int32_t il) { return !hparams.is_recurrent(il); }
                 : filter_attn,
```

```diff
@@ -27,7 +27,6 @@ public:
         uint32_t kv_size,
         uint32_t n_pad,
         uint32_t n_swa,
-        llama_swa_type swa_type,
     /* recurrent */
         ggml_type type_r,
         ggml_type type_s,
```

```diff
@@ -1142,6 +1142,26 @@ void llama_model::load_hparams(llama_model_loader & ml) {
                     default: type = LLM_TYPE_UNKNOWN;
                 }
             } break;
+        case LLM_ARCH_GEMMA_EMBEDDING:
+            {
+                hparams.swa_type = LLAMA_SWA_TYPE_SYMMETRIC;
+                hparams.set_swa_pattern(6);
+
+                hparams.causal_attn = false; // embeddings do not use causal attention
+                hparams.rope_freq_base_train_swa  = 10000.0f;
+                hparams.rope_freq_scale_train_swa = 1.0f;
+
+                ml.get_key(LLM_KV_ATTENTION_SLIDING_WINDOW,    hparams.n_swa);
+                ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
+                ml.get_key(LLM_KV_POOLING_TYPE,                hparams.pooling_type);
+
+                switch (hparams.n_layer) {
+                    case 24: type = LLM_TYPE_0_3B; break;
+                    default: type = LLM_TYPE_UNKNOWN;
+                }
+
+                hparams.f_attention_scale = 1.0f / std::sqrt(float(hparams.n_embd_head_k));
+            } break;
         case LLM_ARCH_STARCODER2:
             {
                 ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps);
@@ -3484,6 +3504,7 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
                 }
             } break;
         case LLM_ARCH_GEMMA3:
+        case LLM_ARCH_GEMMA_EMBEDDING:
             {
                 tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
@@ -11045,6 +11066,136 @@ struct llm_build_gemma3n_iswa : public llm_graph_context {
     }
 };
 
+struct llm_build_gemma_embedding_iswa : public llm_graph_context {
+    llm_build_gemma_embedding_iswa(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) {
+        const int64_t n_embd_head = hparams.n_embd_head_k;
+
+        ggml_tensor * cur;
+        ggml_tensor * inpL;
+
+        inpL = build_inp_embd(model.tok_embd);
+
+        // important: do not normalize weights for raw embeddings input (i.e. encoded image emdeddings)
+        if (ubatch.token) {
+            inpL = ggml_scale(ctx0, inpL, sqrtf(n_embd));
+            cb(inpL, "inp_scaled", -1);
+        }
+
+        // inp_pos - contains the positions
+        ggml_tensor * inp_pos = build_inp_pos();
+
+        auto * inp_attn = build_attn_inp_no_cache();
+
+        ggml_tensor * inp_out_ids = build_inp_out_ids();
+
+        for (int il = 0; il < n_layer; ++il) {
+            const float freq_base_l  = model.get_rope_freq_base (cparams, il);
+            const float freq_scale_l = model.get_rope_freq_scale(cparams, il);
+
+            // norm
+            cur = build_norm(inpL, model.layers[il].attn_norm, NULL, LLM_NORM_RMS, il);
+            cb(cur, "attn_norm", il);
+
+            // self-attention
+            {
+                // compute Q and K and RoPE them
+                ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur);
+                cb(Qcur, "Qcur", il);
+
+                ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur);
+                cb(Kcur, "Kcur", il);
+
+                ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur);
+                cb(Vcur, "Vcur", il);
+
+                Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head,    n_tokens);
+                Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens);
+                Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens);
+
+                Qcur = build_norm(Qcur, model.layers[il].attn_q_norm, NULL, LLM_NORM_RMS, il);
+                cb(Qcur, "Qcur_normed", il);
+
+                Qcur = ggml_rope_ext(
+                        ctx0, Qcur, inp_pos, nullptr,
+                        n_rot, rope_type, n_ctx_orig, freq_base_l, freq_scale_l,
+                        ext_factor, attn_factor, beta_fast, beta_slow);
+
+                Kcur = build_norm(Kcur, model.layers[il].attn_k_norm, NULL, LLM_NORM_RMS, il);
+                cb(Kcur, "Kcur_normed", il);
+
+                Kcur = ggml_rope_ext(
+                        ctx0, Kcur, inp_pos, nullptr,
+                        n_rot, rope_type, n_ctx_orig, freq_base_l, freq_scale_l,
+                        ext_factor, attn_factor, beta_fast, beta_slow);
+
+                cb(Qcur, "Qcur", il);
+                cb(Kcur, "Kcur", il);
+                cb(Vcur, "Vcur", il);
+
+                // ref: https://github.com/google/gemma_pytorch/blob/014acb7ac4563a5f77c76d7ff98f31b568c16508/gemma/model.py#L315
+                Qcur = ggml_scale(ctx0, Qcur, hparams.f_attention_scale);
+
+                cur = build_attn(inp_attn,
+                        model.layers[il].wo, NULL,
+                        Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f, il);
+            }
+
+            if (il == n_layer - 1 && inp_out_ids) {
+                cur  = ggml_get_rows(ctx0, cur,  inp_out_ids);
+                inpL = ggml_get_rows(ctx0, inpL, inp_out_ids);
+            }
+
+            cur = build_norm(cur,
+                    model.layers[il].attn_post_norm, NULL,
+                    LLM_NORM_RMS, il);
+            cb(cur, "attn_post_norm", il);
+
+            ggml_tensor * sa_out = ggml_add(ctx0, cur, inpL);
+            cb(sa_out, "sa_out", il);
+
+            cur = build_norm(sa_out,
+                    model.layers[il].ffn_norm, NULL,
+                    LLM_NORM_RMS, il);
+            cb(cur, "ffn_norm", il);
+
+            // feed-forward network
+            {
+                cur = build_ffn(cur,
+                        model.layers[il].ffn_up,   NULL, NULL,
+                        model.layers[il].ffn_gate, NULL, NULL,
+                        model.layers[il].ffn_down, NULL, NULL,
+                        NULL,
+                        LLM_FFN_GELU, LLM_FFN_PAR, il);
+                cb(cur, "ffn_out", il);
+            }
+
+            cur = build_norm(cur,
+                    model.layers[il].ffn_post_norm, NULL,
+                    LLM_NORM_RMS, -1);
+            cb(cur, "ffn_post_norm", -1);
+
+            cur = ggml_add(ctx0, cur, sa_out);
+
+            cur = build_cvec(cur, il);
+            cb(cur, "l_out", il);
+
+            // input for next layer
+            inpL = cur;
+        }
+
+        cur = inpL;
+
+        cur = build_norm(cur,
+                model.output_norm, NULL,
+                LLM_NORM_RMS, -1);
+        cb(cur, "result_norm", -1);
+
+        res->t_embd = cur;
+
+        ggml_build_forward_expand(gf, cur);
+    }
+};
+
 // TODO: move up next to build_starcoder
 struct llm_build_starcoder2 : public llm_graph_context {
     llm_build_starcoder2(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) {
@@ -18481,6 +18632,7 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params,
         case LLM_ARCH_NOMIC_BERT_MOE:
         case LLM_ARCH_NEO_BERT:
         case LLM_ARCH_WAVTOKENIZER_DEC:
+        case LLM_ARCH_GEMMA_EMBEDDING:
        case LLM_ARCH_DREAM:
         case LLM_ARCH_LLADA:
             {
@@ -18529,7 +18681,6 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params,
                     /* attn_kv_size  */ cparams.n_ctx,
                     /* attn_n_pad    */ padding,
                     /* attn_n_swa    */ hparams.n_swa,
-                    /* attn_swa_type */ hparams.swa_type,
                     /* recurrent_type_k */ GGML_TYPE_F32,
                     /* recurrent_type_v */ GGML_TYPE_F32,
                     /* recurrent_kv_size */ std::max((uint32_t) 1, cparams.n_seq_max),
@@ -18599,7 +18750,6 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params,
                         cparams.n_seq_max,
                         padding,
                         hparams.n_swa,
-                        hparams.swa_type,
                         nullptr,
                         nullptr);
                 }
@@ -18761,6 +18911,10 @@ ggml_cgraph * llama_model::build_graph(const llm_graph_params & params) const {
             {
                 llm = std::make_unique<llm_build_gemma3n_iswa>(*this, params);
             } break;
+        case LLM_ARCH_GEMMA_EMBEDDING:
+            {
+                llm = std::make_unique<llm_build_gemma_embedding_iswa>(*this, params);
+            } break;
         case LLM_ARCH_STARCODER2:
             {
                 llm = std::make_unique<llm_build_starcoder2>(*this, params);
@@ -19161,6 +19315,7 @@ llama_rope_type llama_model_rope_type(const llama_model * model) {
         case LLM_ARCH_GEMMA2:
         case LLM_ARCH_GEMMA3:
        case LLM_ARCH_GEMMA3N:
+        case LLM_ARCH_GEMMA_EMBEDDING:
         case LLM_ARCH_STARCODER2:
         case LLM_ARCH_OPENELM:
         case LLM_ARCH_GPTNEOX:
```