Support for Sparse Attention added

Pavel Rykov 2026-03-29 21:35:26 +03:00
parent ec16a072f0
commit 1c6a5e61a6
13 changed files with 527 additions and 1 deletion

convert_hf_to_gguf.py
View File

@@ -5223,10 +5223,31 @@ class GPT2Model(TextModel):
@ModelBase.register("RuGPT3XLForCausalLM")
class RuGPT3XLModel(TextModel):
-   model_arch = gguf.MODEL_ARCH.GPT2
+   model_arch = gguf.MODEL_ARCH.RUGPT3XL
_qkv_parts: list[dict[str, Tensor]] | None = None
def set_gguf_parameters(self):
super().set_gguf_parameters()
sparse_mode = self.hparams.get("sparse_mode", "none")
if sparse_mode == "alternating":
self.gguf_writer.add_uint32(
gguf.Keys.Attention.SPARSE_BLOCK_SIZE.format(arch=self.gguf_writer.arch),
self.hparams.get("sparse_block_size", 16),
)
self.gguf_writer.add_uint32(
gguf.Keys.Attention.SPARSE_NUM_LOCAL_BLOCKS.format(arch=self.gguf_writer.arch),
self.hparams.get("sparse_num_local_blocks", 8),
)
self.gguf_writer.add_uint32(
gguf.Keys.Attention.SPARSE_NUM_GLOBAL_BLOCKS.format(arch=self.gguf_writer.arch),
self.hparams.get("sparse_num_global_blocks", 1),
)
self.gguf_writer.add_uint32(
gguf.Keys.Attention.SPARSE_NUM_GLOBAL_PATTERNS.format(arch=self.gguf_writer.arch),
self.hparams.get("sparse_num_different_global_patterns", 8),
)
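# with sparse_mode == "alternating" and no overrides in the HF config, the
# writes above produce the following GGUF metadata (note the HF key
# "sparse_num_different_global_patterns" maps to "...num_global_patterns"):
#   rugpt3xl.attention.sparse_block_size          = 16
#   rugpt3xl.attention.sparse_num_local_blocks    = 8
#   rugpt3xl.attention.sparse_num_global_blocks   = 1
#   rugpt3xl.attention.sparse_num_global_patterns = 8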
def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
# Fuse separate Q, K, V projections into a single QKV tensor
if ".self_attn.q_proj." in name or ".self_attn.k_proj." in name or ".self_attn.v_proj." in name:

gguf-py/gguf/constants.py
View File

@@ -183,6 +183,10 @@ class Keys:
SHARED_KV_LAYERS = "{arch}.attention.shared_kv_layers"
SLIDING_WINDOW_PATTERN = "{arch}.attention.sliding_window_pattern"
TEMPERATURE_SCALE = "{arch}.attention.temperature_scale"
SPARSE_BLOCK_SIZE = "{arch}.attention.sparse_block_size"
SPARSE_NUM_LOCAL_BLOCKS = "{arch}.attention.sparse_num_local_blocks"
SPARSE_NUM_GLOBAL_BLOCKS = "{arch}.attention.sparse_num_global_blocks"
SPARSE_NUM_GLOBAL_PATTERNS = "{arch}.attention.sparse_num_global_patterns"
class Indexer:
HEAD_COUNT = "{arch}.attention.indexer.head_count"
@@ -493,6 +497,7 @@ class MODEL_ARCH(IntEnum):
LLAMA_EMBED = auto()
MAINCODER = auto()
KIMI_LINEAR = auto()
RUGPT3XL = auto()
class VISION_PROJECTOR_TYPE(IntEnum):
@@ -957,6 +962,7 @@ MODEL_ARCH_NAMES: dict[MODEL_ARCH, str] = {
MODEL_ARCH.LLAMA_EMBED: "llama-embed",
MODEL_ARCH.MAINCODER: "maincoder",
MODEL_ARCH.KIMI_LINEAR: "kimi-linear",
MODEL_ARCH.RUGPT3XL: "rugpt3xl",
}
VISION_PROJECTOR_TYPE_NAMES: dict[VISION_PROJECTOR_TYPE, str] = {
@@ -2059,6 +2065,18 @@ MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = {
MODEL_TENSOR.FFN_DOWN,
MODEL_TENSOR.FFN_UP,
],
MODEL_ARCH.RUGPT3XL: [
MODEL_TENSOR.TOKEN_EMBD,
MODEL_TENSOR.POS_EMBD,
MODEL_TENSOR.OUTPUT_NORM,
MODEL_TENSOR.OUTPUT,
MODEL_TENSOR.ATTN_NORM,
MODEL_TENSOR.ATTN_QKV,
MODEL_TENSOR.ATTN_OUT,
MODEL_TENSOR.FFN_NORM,
MODEL_TENSOR.FFN_DOWN,
MODEL_TENSOR.FFN_UP,
],
MODEL_ARCH.PHI2: [
MODEL_TENSOR.TOKEN_EMBD,
MODEL_TENSOR.OUTPUT_NORM,

src/CMakeLists.txt
View File

@@ -132,6 +132,7 @@ add_library(llama
models/qwen3vl.cpp
models/refact.cpp
models/rnd1.cpp
models/rugpt3xl.cpp
models/rwkv6-base.cpp
models/rwkv6.cpp
models/rwkv6qwen2.cpp

src/llama-arch.cpp
View File

@@ -131,6 +131,7 @@ static const std::map<llm_arch, const char *> LLM_ARCH_NAMES = {
{ LLM_ARCH_LLAMA_EMBED, "llama-embed" },
{ LLM_ARCH_MAINCODER, "maincoder" },
{ LLM_ARCH_KIMI_LINEAR, "kimi-linear" },
{ LLM_ARCH_RUGPT3XL, "rugpt3xl" },
{ LLM_ARCH_UNKNOWN, "(unknown)" },
};
@@ -238,6 +239,10 @@ static const std::map<llm_kv, const char *> LLM_KV_NAMES = {
{ LLM_KV_ATTENTION_INDEXER_HEAD_COUNT, "%s.attention.indexer.head_count" },
{ LLM_KV_ATTENTION_INDEXER_KEY_LENGTH, "%s.attention.indexer.key_length" },
{ LLM_KV_ATTENTION_INDEXER_TOP_K, "%s.attention.indexer.top_k" },
{ LLM_KV_ATTENTION_SPARSE_BLOCK_SIZE, "%s.attention.sparse_block_size" },
{ LLM_KV_ATTENTION_SPARSE_NUM_LOCAL_BLOCKS, "%s.attention.sparse_num_local_blocks" },
{ LLM_KV_ATTENTION_SPARSE_NUM_GLOBAL_BLOCKS, "%s.attention.sparse_num_global_blocks" },
{ LLM_KV_ATTENTION_SPARSE_NUM_GLOBAL_PATTERNS, "%s.attention.sparse_num_global_patterns" },
{ LLM_KV_ROPE_DIMENSION_COUNT, "%s.rope.dimension_count" },
{ LLM_KV_ROPE_DIMENSION_COUNT_SWA, "%s.rope.dimension_count_swa" },
@@ -721,6 +726,19 @@ static std::set<llm_tensor> llm_get_tensor_names(llm_arch arch) {
LLM_TENSOR_FFN_UP,
LLM_TENSOR_FFN_DOWN,
};
case LLM_ARCH_RUGPT3XL:
return {
LLM_TENSOR_TOKEN_EMBD,
LLM_TENSOR_POS_EMBD,
LLM_TENSOR_OUTPUT_NORM,
LLM_TENSOR_OUTPUT,
LLM_TENSOR_ATTN_NORM,
LLM_TENSOR_ATTN_QKV,
LLM_TENSOR_ATTN_OUT,
LLM_TENSOR_FFN_NORM,
LLM_TENSOR_FFN_UP,
LLM_TENSOR_FFN_DOWN,
};
case LLM_ARCH_GPTNEOX:
return {
LLM_TENSOR_TOKEN_EMBD,

src/llama-arch.h
View File

@@ -135,6 +135,7 @@ enum llm_arch {
LLM_ARCH_LLAMA_EMBED,
LLM_ARCH_MAINCODER,
LLM_ARCH_KIMI_LINEAR,
LLM_ARCH_RUGPT3XL,
LLM_ARCH_UNKNOWN,
};
@@ -242,6 +243,10 @@ enum llm_kv {
LLM_KV_ATTENTION_INDEXER_HEAD_COUNT,
LLM_KV_ATTENTION_INDEXER_KEY_LENGTH,
LLM_KV_ATTENTION_INDEXER_TOP_K,
LLM_KV_ATTENTION_SPARSE_BLOCK_SIZE,
LLM_KV_ATTENTION_SPARSE_NUM_LOCAL_BLOCKS,
LLM_KV_ATTENTION_SPARSE_NUM_GLOBAL_BLOCKS,
LLM_KV_ATTENTION_SPARSE_NUM_GLOBAL_PATTERNS,
LLM_KV_ROPE_DIMENSION_COUNT,
LLM_KV_ROPE_DIMENSION_COUNT_SWA,

src/llama-graph.cpp
View File

@@ -446,6 +446,24 @@ bool llm_graph_input_attn_kv::can_reuse(const llm_graph_params & params) {
return res;
}
void llm_graph_input_attn_kv_sparse::set_input(const llama_ubatch * ubatch) {
mctx->set_input_k_idxs(self_k_idxs, ubatch);
mctx->set_input_v_idxs(self_v_idxs, ubatch);
mctx->set_input_kq_mask(self_kq_mask, ubatch, cparams.causal_attn);
mctx->set_input_kq_mask_sparse(self_kq_mask_sparse, ubatch, hparams);
}
bool llm_graph_input_attn_kv_sparse::can_reuse(const llm_graph_params & params) {
const auto * mctx_new = static_cast<const llama_kv_cache_context *>(params.mctx);
this->mctx = mctx_new;
bool good = true;
good &= self_k_idxs->ne[0] == params.ubatch.n_tokens;
good &= can_reuse_kq_mask(self_kq_mask, mctx_new, params.ubatch, params.cparams);
return good;
}
void llm_graph_input_attn_k::set_input(const llama_ubatch * ubatch) {
mctx->set_input_k_idxs(self_k_idxs, ubatch);
@@ -2076,6 +2094,146 @@ ggml_tensor * llm_graph_context::build_attn(
return cur;
}
llm_graph_input_attn_kv_sparse * llm_graph_context::build_attn_inp_kv_sparse() const {
const auto * mctx_cur = static_cast<const llama_kv_cache_context *>(mctx);
auto inp = std::make_unique<llm_graph_input_attn_kv_sparse>(hparams, cparams, mctx_cur);
GGML_ASSERT(hparams.swa_type == LLAMA_SWA_TYPE_NONE);
inp->self_k_idxs = mctx_cur->build_input_k_idxs(ctx0, ubatch);
inp->self_v_idxs = mctx_cur->build_input_v_idxs(ctx0, ubatch);
// standard causal mask [n_kv, n_tokens, 1, n_stream]
inp->self_kq_mask = build_kq_mask(ctx0, mctx_cur, ubatch, cparams);
ggml_set_input(inp->self_kq_mask);
inp->self_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask, GGML_TYPE_F16) : inp->self_kq_mask;
// per-head sparse mask [n_kv, n_tokens, n_head, n_stream]
const auto n_kv = mctx_cur->get_n_kv();
const auto n_tokens = ubatch.n_tokens;
const auto n_stream = cparams.kv_unified ? 1 : ubatch.n_seqs_unq;
const auto n_head_q = hparams.n_head();
inp->self_kq_mask_sparse = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_kv, n_tokens/n_stream, n_head_q, n_stream);
ggml_set_input(inp->self_kq_mask_sparse);
ggml_set_name(inp->self_kq_mask_sparse, "kq_mask_sparse");
return (llm_graph_input_attn_kv_sparse *) res->add_input(std::move(inp));
}
ggml_tensor * llm_graph_context::build_attn_dense(
llm_graph_input_attn_kv_sparse * inp,
ggml_tensor * wo,
ggml_tensor * wo_b,
ggml_tensor * q_cur,
ggml_tensor * k_cur,
ggml_tensor * v_cur,
float kq_scale,
int il) const {
ggml_build_forward_expand(gf, q_cur);
ggml_build_forward_expand(gf, v_cur);
ggml_build_forward_expand(gf, k_cur);
const auto * mctx_cur = inp->mctx;
{
const auto & k_idxs = inp->get_k_idxs();
const auto & v_idxs = inp->get_v_idxs();
ggml_build_forward_expand(gf, mctx_cur->cpy_k(ctx0, k_cur, k_idxs, il));
ggml_build_forward_expand(gf, mctx_cur->cpy_v(ctx0, v_cur, v_idxs, il));
}
const auto & kq_mask = inp->get_kq_mask();
ggml_tensor * q = q_cur;
ggml_tensor * k = mctx_cur->get_k(ctx0, il);
ggml_tensor * v = mctx_cur->get_v(ctx0, il);
ggml_tensor * cur = build_attn_mha(q, k, v, nullptr, kq_mask, nullptr, nullptr, kq_scale, il);
cb(cur, "kqv_out", il);
if (wo) {
cur = build_lora_mm(wo, cur);
}
if (wo_b) {
cur = ggml_add(ctx0, cur, wo_b);
}
return cur;
}
ggml_tensor * llm_graph_context::build_attn_sparse(
llm_graph_input_attn_kv_sparse * inp,
ggml_tensor * wo,
ggml_tensor * wo_b,
ggml_tensor * q_cur,
ggml_tensor * k_cur,
ggml_tensor * v_cur,
float kq_scale,
int il) const {
ggml_build_forward_expand(gf, q_cur);
ggml_build_forward_expand(gf, v_cur);
ggml_build_forward_expand(gf, k_cur);
const auto * mctx_cur = inp->mctx;
{
const auto & k_idxs = inp->get_k_idxs();
const auto & v_idxs = inp->get_v_idxs();
ggml_build_forward_expand(gf, mctx_cur->cpy_k(ctx0, k_cur, k_idxs, il));
ggml_build_forward_expand(gf, mctx_cur->cpy_v(ctx0, v_cur, v_idxs, il));
}
const auto & kq_mask = inp->get_kq_mask();
const auto & kq_mask_sparse = inp->get_kq_mask_sparse();
ggml_tensor * q = q_cur;
ggml_tensor * k = mctx_cur->get_k(ctx0, il);
ggml_tensor * v = mctx_cur->get_v(ctx0, il);
const bool v_trans = v->nb[1] > v->nb[2];
const auto n_stream_val = k->ne[3];
q = ggml_view_4d(ctx0, q, q->ne[0], q->ne[1], q->ne[2]/n_stream_val, n_stream_val,
q->nb[1], q->nb[2], q->nb[3]/n_stream_val, 0);
q = ggml_permute(ctx0, q, 0, 2, 1, 3);
k = ggml_permute(ctx0, k, 0, 2, 1, 3);
v = ggml_permute(ctx0, v, 0, 2, 1, 3);
// QK^T: [n_kv, n_tokens, n_head, n_stream]
ggml_tensor * kq = ggml_mul_mat(ctx0, k, q);
ggml_mul_mat_set_prec(kq, GGML_PREC_F32);
cb(kq, "kq", il);
// add per-head sparse mask (0 or -INF), same shape as kq
kq = ggml_add(ctx0, kq, kq_mask_sparse);
cb(kq, "kq_sparse_masked", il);
// apply scale + causal mask + softmax via standard path
kq = ggml_soft_max_ext(ctx0, kq, kq_mask, kq_scale, 0.0f);
cb(kq, "kq_soft_max_ext", il);
ggml_tensor * kqv;
if (v_trans) {
kqv = ggml_mul_mat(ctx0, v, kq);
} else {
kqv = ggml_mul_mat(ctx0, ggml_cont(ctx0, ggml_transpose(ctx0, v)), kq);
}
cb(kqv, "kqv", il);
ggml_tensor * cur = ggml_permute(ctx0, kqv, 0, 2, 1, 3);
cur = ggml_cont_2d(ctx0, cur, cur->ne[0]*cur->ne[1], cur->ne[2]);
cb(cur, "kqv_out", il);
if (wo) {
cur = build_lora_mm(wo, cur);
}
if (wo_b) {
cur = ggml_add(ctx0, cur, wo_b);
}
return cur;
}
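build_attn_sparse takes the manual QK^T route because the sparse mask carries a distinct [n_kv, n_tokens] slice per head, while the mask consumed by the fused attention helper is broadcast across heads. A NumPy sketch of the masking order above, with illustrative sizes (ggml additionally subtracts the row max before exponentiating):

import numpy as np

n_head, n_tokens, n_kv, scale = 2, 4, 4, 0.5
scores = np.random.randn(n_head, n_tokens, n_kv)            # QK^T, one slice per head
causal = np.triu(np.full((n_tokens, n_kv), -np.inf), k=1)   # 0 on/below the diagonal
sparse = np.zeros((n_head, n_tokens, n_kv))                 # 0 = keep, -inf = drop
sparse[1, :, 2] = -np.inf                                   # e.g. head 1 never sees key 2
kq = scores + sparse                                        # per-head mask added first
p = np.exp(kq * scale + causal)                             # soft_max_ext(kq, causal, scale)
p = p / p.sum(axis=-1, keepdims=True)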
static std::unique_ptr<llm_graph_input_attn_k> build_attn_inp_k_impl(
ggml_context * ctx0,
const llama_ubatch & ubatch,

src/llama-graph.h
View File

@@ -317,6 +317,42 @@ public:
const llama_kv_cache_context * mctx;
};
// KV cache input with additional per-head sparse attention mask (ruGPT3XL)
class llm_graph_input_attn_kv_sparse : public llm_graph_input_i {
public:
llm_graph_input_attn_kv_sparse(
const llama_hparams & hparams,
const llama_cparams & cparams,
const llama_kv_cache_context * mctx) :
hparams(hparams),
cparams(cparams),
mctx(mctx) {
}
~llm_graph_input_attn_kv_sparse() = default;
void set_input(const llama_ubatch * ubatch) override;
bool can_reuse(const llm_graph_params & params) override;
ggml_tensor * get_k_idxs() const { return self_k_idxs; }
ggml_tensor * get_v_idxs() const { return self_v_idxs; }
ggml_tensor * get_kq_mask() const { return self_kq_mask_cnv; }
ggml_tensor * get_kq_mask_sparse() const { return self_kq_mask_sparse; }
ggml_tensor * self_k_idxs = nullptr;
ggml_tensor * self_v_idxs = nullptr;
ggml_tensor * self_kq_mask = nullptr; // F32 [n_kv, n_batch/n_stream, 1, n_stream]
ggml_tensor * self_kq_mask_cnv = nullptr;
ggml_tensor * self_kq_mask_sparse = nullptr; // F32 [n_kv, n_batch/n_stream, n_head, n_stream]
const llama_hparams hparams;
const llama_cparams cparams;
const llama_kv_cache_context * mctx;
};
// V-less input for the KV cache
// ref: https://github.com/ggml-org/llama.cpp/pull/19067
class llm_graph_input_attn_k : public llm_graph_input_i {
@@ -906,6 +942,28 @@ struct llm_graph_context {
float kq_scale,
int il) const;
llm_graph_input_attn_kv_sparse * build_attn_inp_kv_sparse() const;
ggml_tensor * build_attn_sparse(
llm_graph_input_attn_kv_sparse * inp,
ggml_tensor * wo,
ggml_tensor * wo_b,
ggml_tensor * q_cur,
ggml_tensor * k_cur,
ggml_tensor * v_cur,
float kq_scale,
int il) const;
ggml_tensor * build_attn_dense(
llm_graph_input_attn_kv_sparse * inp,
ggml_tensor * wo,
ggml_tensor * wo_b,
ggml_tensor * q_cur,
ggml_tensor * k_cur,
ggml_tensor * v_cur,
float kq_scale,
int il) const;
llm_graph_input_attn_k * build_attn_inp_k() const;
ggml_tensor * build_attn(

src/llama-hparams.h
View File

@@ -206,6 +206,12 @@ struct llama_hparams {
uint32_t indexer_head_size = 0;
uint32_t indexer_top_k = 0;
// ruGPT3XL block-sparse attention (DeepSpeed FixedSparsityConfig)
uint32_t sparse_block_size = 0;
uint32_t sparse_num_local_blocks = 0;
uint32_t sparse_num_global_blocks = 0;
uint32_t sparse_num_global_patterns = 0;
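// the zero defaults keep the sparse path disabled: llm_build_rugpt3xl only
// routes a layer through build_attn_sparse when sparse_block_size > 0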
// qwen3vl deepstack
uint32_t n_deepstack_layers = 0;

src/llama-kv-cache.cpp
View File

@@ -1482,6 +1482,106 @@ void llama_kv_cache::set_input_kq_mask(ggml_tensor * dst, const llama_ubatch * u
//LLAMA_LOG_ERROR("%s: kq mask time: %0.3f ms\n", __func__, (t_end - t_start)/1000.0);
}
void llama_kv_cache::set_input_kq_mask_sparse(ggml_tensor * dst, const llama_ubatch * ubatch, const llama_hparams & hp) const {
const uint32_t block_size = hp.sparse_block_size;
const uint32_t num_local_blocks = hp.sparse_num_local_blocks;
const uint32_t num_global_blocks = hp.sparse_num_global_blocks;
const uint32_t num_global_patterns = hp.sparse_num_global_patterns;
const uint32_t n_head_val = hp.n_head();
GGML_ASSERT(ggml_backend_buffer_is_host(dst->buffer));
float * data = (float *) dst->data;
if (block_size == 0 || num_local_blocks == 0) {
memset(data, 0, ggml_nbytes(dst));
return;
}
// dst shape: [n_kv, n_tokens_per_stream, n_head, n_stream]
const int64_t n_kv_t = dst->ne[0];
const int64_t n_tps = dst->ne[1];
const int64_t n_head_t = dst->ne[2];
const int64_t n_strm = dst->ne[3];
GGML_ASSERT(n_head_t == (int64_t) n_head_val);
GGML_ASSERT(n_strm == (int64_t) n_stream);
const uint32_t total_blocks = hp.n_ctx_train / block_size;
const uint32_t regular_end = total_blocks - (total_blocks % num_local_blocks);
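// blocks in [0, regular_end) form complete local windows; any remainder
// past regular_end is covered by the tail case below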
for (int64_t s = 0; s < n_strm; ++s) {
const auto & cells = v_cells[seq_to_stream.empty() ? 0 : s];
for (int64_t it = 0; it < n_tps; ++it) {
const uint32_t i_token = (uint32_t)(s * n_tps + it);
if (i_token >= ubatch->n_tokens) break;
const llama_pos p1 = ubatch->pos[i_token];
const llama_seq_id seq_id = ubatch->seq_id[i_token][0];
const uint32_t q_blk = (uint32_t) p1 / block_size;
const uint32_t q_win = q_blk / num_local_blocks;
for (int64_t h = 0; h < n_head_t; ++h) {
const int32_t first_global = (int32_t) num_local_blocks
- (int32_t)(1 + h % num_global_patterns) * (int32_t) num_global_blocks;
// offset in the flat data array
// layout: [n_kv, n_tps, n_head, n_stream]
// idx = s*(n_head*n_tps*n_kv) + h*(n_tps*n_kv) + it*n_kv + j
const int64_t base = s * (n_head_t * n_tps * n_kv_t)
+ h * (n_tps * n_kv_t)
+ it * n_kv_t;
for (int64_t j = 0; j < n_kv_t; ++j) {
if (cells.is_empty(j) || !cells.seq_has(j, seq_id)) {
data[base + j] = -INFINITY;
continue;
}
const llama_pos p0 = cells.pos_get(j);
if (p0 > p1) {
data[base + j] = -INFINITY;
continue;
}
const uint32_t k_blk = (uint32_t) p0 / block_size;
const uint32_t k_win = k_blk / num_local_blocks;
// local window check
if (q_win == k_win && k_blk <= q_blk) {
data[base + j] = 0.0f;
continue;
}
// global block check
bool global = false;
for (int32_t gi = first_global; gi < (int32_t) regular_end; gi += (int32_t) num_local_blocks) {
if (gi < 0) continue;
if (k_blk >= (uint32_t) gi && k_blk < (uint32_t) gi + num_global_blocks && q_blk >= k_blk) {
global = true;
break;
}
}
if (!global && regular_end < total_blocks) {
int32_t tail = (int32_t)(regular_end + first_global);
if (tail > (int32_t)(total_blocks - num_global_blocks)) {
tail = (int32_t)(total_blocks - num_global_blocks);
}
if (tail >= 0 && k_blk >= (uint32_t) tail
&& k_blk < (uint32_t) tail + num_global_blocks
&& q_blk >= k_blk) {
global = true;
}
}
data[base + j] = global ? 0.0f : -INFINITY;
}
}
}
}
}
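The loop above realizes DeepSpeed's fixed block-sparse pattern per head. A standalone sketch of the visibility rule, handy for inspecting the pattern offline (parameter defaults mirror the converter's fallbacks; total_blocks is illustrative, the C++ derives it as n_ctx_train / block_size):

def block_visible(q_blk: int, k_blk: int, h: int, *, num_local_blocks: int = 8,
                  num_global_blocks: int = 1, num_global_patterns: int = 8,
                  total_blocks: int = 128) -> bool:
    if k_blk > q_blk:
        return False                                   # block-level causality
    if q_blk // num_local_blocks == k_blk // num_local_blocks:
        return True                                    # same local window
    first_global = num_local_blocks - (1 + h % num_global_patterns) * num_global_blocks
    regular_end  = total_blocks - total_blocks % num_local_blocks
    for gi in range(first_global, regular_end, num_local_blocks):
        if gi >= 0 and gi <= k_blk < gi + num_global_blocks:
            return True                                # this head's global block
    if regular_end < total_blocks:                     # partial tail window
        tail = min(regular_end + first_global, total_blocks - num_global_blocks)
        if tail >= 0 and tail <= k_blk < tail + num_global_blocks:
            return True
    return False

With the defaults, first_global = 7 - h % 8: head 0 additionally attends to the last block of every earlier window, head 1 to the second-to-last, and so on, cycling every 8 heads.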
void llama_kv_cache::set_input_pos_bucket(ggml_tensor * dst, const llama_ubatch * ubatch) const {
const int64_t n_tokens = ubatch->n_tokens;
@@ -2279,6 +2379,11 @@ void llama_kv_cache_context::set_input_kq_mask(ggml_tensor * dst, const llama_ub
kv->set_input_kq_mask(dst, ubatch, causal_attn);
}
void llama_kv_cache_context::set_input_kq_mask_sparse(
ggml_tensor * dst, const llama_ubatch * ubatch, const llama_hparams & hp) const {
kv->set_input_kq_mask_sparse(dst, ubatch, hp);
}
void llama_kv_cache_context::set_input_pos_bucket(ggml_tensor * dst, const llama_ubatch * ubatch) const {
kv->set_input_pos_bucket(dst, ubatch);
}

src/llama-kv-cache.h
View File

@@ -197,6 +197,7 @@ public:
void set_input_k_shift(ggml_tensor * dst) const;
void set_input_kq_mask (ggml_tensor * dst, const llama_ubatch * ubatch, bool causal_attn) const;
void set_input_kq_mask_sparse(ggml_tensor * dst, const llama_ubatch * ubatch, const llama_hparams & hp) const;
void set_input_pos_bucket(ggml_tensor * dst, const llama_ubatch * ubatch) const;
private:
@@ -352,6 +353,7 @@ public:
void set_input_k_shift (ggml_tensor * dst) const;
void set_input_kq_mask (ggml_tensor * dst, const llama_ubatch * ubatch, bool causal_attn) const;
void set_input_kq_mask_sparse(ggml_tensor * dst, const llama_ubatch * ubatch, const llama_hparams & hp) const;
void set_input_pos_bucket(ggml_tensor * dst, const llama_ubatch * ubatch) const;
private:

src/llama-model.cpp
View File

@@ -1143,6 +1143,18 @@ void llama_model::load_hparams(llama_model_loader & ml) {
default: type = LLM_TYPE_UNKNOWN;
}
} break;
case LLM_ARCH_RUGPT3XL:
{
ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps);
ml.get_key(LLM_KV_ATTENTION_SPARSE_BLOCK_SIZE, hparams.sparse_block_size, false);
ml.get_key(LLM_KV_ATTENTION_SPARSE_NUM_LOCAL_BLOCKS, hparams.sparse_num_local_blocks, false);
ml.get_key(LLM_KV_ATTENTION_SPARSE_NUM_GLOBAL_BLOCKS, hparams.sparse_num_global_blocks, false);
ml.get_key(LLM_KV_ATTENTION_SPARSE_NUM_GLOBAL_PATTERNS, hparams.sparse_num_global_patterns, false);
switch (hparams.n_layer) {
case 24: type = LLM_TYPE_XL; break;
default: type = LLM_TYPE_UNKNOWN;
}
} break;
case LLM_ARCH_CODESHELL:
{
ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps);
@@ -3960,6 +3972,7 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
}
} break;
case LLM_ARCH_GPT2:
case LLM_ARCH_RUGPT3XL:
{
tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
pos_embd = create_tensor(tn(LLM_TENSOR_POS_EMBD, "weight"), {n_embd, n_ctx_train}, 0);
@@ -8450,6 +8463,10 @@ ggml_cgraph * llama_model::build_graph(const llm_graph_params & params) const {
{
llm = std::make_unique<llm_build_gpt2>(*this, params);
} break;
case LLM_ARCH_RUGPT3XL:
{
llm = std::make_unique<llm_build_rugpt3xl>(*this, params);
} break;
case LLM_ARCH_CODESHELL:
{
llm = std::make_unique<llm_build_codeshell>(*this, params);
@@ -8916,6 +8933,7 @@ llama_rope_type llama_model_rope_type(const llama_model * model) {
// these models do not use RoPE
case LLM_ARCH_CLIP:
case LLM_ARCH_GPT2:
case LLM_ARCH_RUGPT3XL:
case LLM_ARCH_GPTJ:
case LLM_ARCH_MPT:
case LLM_ARCH_REFACT:

src/models/models.h
View File

@@ -646,6 +646,10 @@ struct llm_build_rnd1 : public llm_graph_context {
llm_build_rnd1(const llama_model & model, const llm_graph_params & params);
};
struct llm_build_rugpt3xl : public llm_graph_context {
llm_build_rugpt3xl(const llama_model & model, const llm_graph_params & params);
};
struct llm_build_rwkv6 : public llm_build_rwkv6_base {
llm_build_rwkv6(const llama_model & model, const llm_graph_params & params);
};

src/models/rugpt3xl.cpp Normal file (112 additions)
View File

@@ -0,0 +1,112 @@
#include "models.h"
llm_build_rugpt3xl::llm_build_rugpt3xl(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) {
const int64_t n_embd_head = hparams.n_embd_head_v();
const int64_t n_embd_gqa = hparams.n_embd_v_gqa();
GGML_ASSERT(n_embd_head == hparams.n_embd_head_k());
const bool has_sparse = hparams.sparse_block_size > 0;
ggml_tensor * cur;
ggml_tensor * pos;
ggml_tensor * inpL;
inpL = build_inp_embd(model.tok_embd);
ggml_tensor * inp_pos = build_inp_pos();
auto * inp_attn = build_attn_inp_kv_sparse();
pos = ggml_get_rows(ctx0, model.pos_embd, inp_pos);
cb(pos, "pos_embd", -1);
inpL = ggml_add(ctx0, inpL, pos);
cb(inpL, "inpL", -1);
ggml_tensor * inp_out_ids = build_inp_out_ids();
for (int il = 0; il < n_layer; ++il) {
cur = build_norm(inpL,
model.layers[il].attn_norm,
model.layers[il].attn_norm_b,
LLM_NORM, il);
cb(cur, "attn_norm", il);
// self-attention
{
cur = build_lora_mm(model.layers[il].wqkv, cur);
cb(cur, "wqkv", il);
cur = ggml_add(ctx0, cur, model.layers[il].bqkv);
cb(cur, "bqkv", il);
ggml_tensor * Qcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head, n_tokens, n_embd_head*sizeof(float), cur->nb[1], 0*sizeof(float)*(n_embd));
ggml_tensor * Kcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head_kv, n_tokens, n_embd_head*sizeof(float), cur->nb[1], 1*sizeof(float)*(n_embd));
ggml_tensor * Vcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head_kv, n_tokens, n_embd_head*sizeof(float), cur->nb[1], 1*sizeof(float)*(n_embd + n_embd_gqa));
cb(Qcur, "Qcur", il);
cb(Kcur, "Kcur", il);
cb(Vcur, "Vcur", il);
const bool is_sparse_layer = has_sparse && (il % 2 == 0);
if (is_sparse_layer) {
cur = build_attn_sparse(inp_attn,
model.layers[il].wo, model.layers[il].bo,
Qcur, Kcur, Vcur, 1.0f/sqrtf(float(n_embd_head)), il);
} else {
cur = build_attn_dense(inp_attn,
model.layers[il].wo, model.layers[il].bo,
Qcur, Kcur, Vcur, 1.0f/sqrtf(float(n_embd_head)), il);
}
}
if (il == n_layer - 1 && inp_out_ids) {
cur = ggml_get_rows(ctx0, cur, inp_out_ids);
inpL = ggml_get_rows(ctx0, inpL, inp_out_ids);
}
ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpL);
cb(ffn_inp, "ffn_inp", il);
// FF
{
cur = build_norm(ffn_inp,
model.layers[il].ffn_norm,
model.layers[il].ffn_norm_b,
LLM_NORM, il);
cb(cur, "ffn_norm", il);
cur = build_ffn(cur,
model.layers[il].ffn_up, model.layers[il].ffn_up_b, NULL,
NULL, NULL, NULL,
model.layers[il].ffn_down, model.layers[il].ffn_down_b, NULL,
NULL,
LLM_FFN_GELU, LLM_FFN_SEQ, il);
cb(cur, "ffn_out", il);
}
cur = ggml_add(ctx0, cur, ffn_inp);
cur = build_cvec(cur, il);
cb(cur, "l_out", il);
inpL = cur;
}
cur = build_norm(inpL,
model.output_norm,
model.output_norm_b,
LLM_NORM, -1);
cb(cur, "result_norm", -1);
res->t_embd = cur;
cur = build_lora_mm(model.output, cur);
cb(cur, "result_output", -1);
res->t_logits = cur;
ggml_build_forward_expand(gf, cur);
}
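Per the il % 2 == 0 test above, even-indexed layers take the sparse path and odd-indexed layers stay dense, which is how the builder realizes the converter's "alternating" sparse_mode. For the 24-layer XL checkpoint:

sparse_layers = [il for il in range(24) if il % 2 == 0]   # [0, 2, ..., 22] -> 12 sparse, 12 dense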