Merge branch 'master' into add-fh1-rebased
This commit is contained in:
commit d28c31a90c

@@ -342,7 +342,7 @@ jobs:
        cd build
        export GGML_VK_VISIBLE_DEVICES=0
        # This is using llvmpipe and runs slower than other backends
        ctest -L main --verbose --timeout 3600
        ctest -L main --verbose --timeout 4200

  ubuntu-22-cmake-hip:
    runs-on: ubuntu-22.04

@@ -815,6 +815,9 @@ class TextModel(ModelBase):
        if chkhsh == "1431a23e583c97432bc230bff598d103ddb5a1f89960c8f1d1051aaa944d0b35":
            # ref: https://huggingface.co/sapienzanlp/Minerva-7B-base-v1.0
            res = "minerva-7b"
        if chkhsh == "7e57df22b1fe23a7b1e1c7f3dc4e3f96d43a4eb0836d0c6bdc3436d7b2f1c664":
            # ref: https://huggingface.co/tencent/Hunyuan-A13B-Instruct
            res = "hunyuan"

        if res is None:
            logger.warning("\n")
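For context, the chkhsh values compared above are checksums of how a tokenizer encodes a fixed probe string, so a changed pre-tokenizer shows up as a new hash. A minimal sketch of the idea (the helper and probe text below are illustrative, not the exact code in convert_hf_to_gguf.py):

# Sketch: deriving a pre-tokenizer checksum like the chkhsh values above.
# tokenizer_checksum and probe_text are illustrative names, not llama.cpp APIs.
from hashlib import sha256
from transformers import AutoTokenizer

def tokenizer_checksum(model_dir: str, probe_text: str = "Hello y'all! 123 \u00f1 \U0001f600") -> str:
    tokenizer = AutoTokenizer.from_pretrained(model_dir, trust_remote_code=True)
    ids = tokenizer.encode(probe_text)
    # hashing the encoded ids makes the checksum sensitive to pre-tokenizer changes
    return sha256(str(ids).encode()).hexdigest()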
@@ -6666,6 +6669,156 @@ class FalconH1Model(Mamba2Model):
        # Add any other Falcon Mamba2 specific configuration
        self.gguf_writer.add_rope_freq_base(self.find_hparam(["rope_theta"]))


@ModelBase.register("HunYuanMoEV1ForCausalLM")
class HunYuanMoEModel(TextModel):
    model_arch = gguf.MODEL_ARCH.HUNYUAN_MOE

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # For handling tied embeddings
        self._tok_embd = None

    def set_vocab(self):
        from transformers import AutoTokenizer
        tokenizer = AutoTokenizer.from_pretrained(self.dir_model, trust_remote_code=True)

        # 1. Get the pre-tokenizer identifier hash
        tokpre = self.get_vocab_base_pre(tokenizer)

        # 2. Reverse-engineer the merges list from mergeable_ranks
        merges = []
        vocab = {}
        mergeable_ranks = tokenizer.mergeable_ranks
        for token, rank in mergeable_ranks.items():
            vocab[QwenModel.token_bytes_to_string(token)] = rank
            if len(token) == 1:
                continue
            merged = QwenModel.bpe(mergeable_ranks, token, max_rank=rank)
            if len(merged) == 2:  # todo this is an assert in Qwen, why?
                merges.append(' '.join(map(QwenModel.token_bytes_to_string, merged)))

        # 3. Generate the tokens and toktypes lists
        vocab_size = self.hparams["vocab_size"]
        assert tokenizer.vocab_size == vocab_size
        special_tokens = tokenizer.special_tokens
        reverse_vocab = {id_ : encoded_tok for encoded_tok, id_ in {**vocab, **special_tokens}.items()}
        tokens: list[str] = []
        toktypes: list[int] = []
        for i in range(vocab_size):
            if i not in reverse_vocab:
                tokens.append(f"[PAD{i}]")
                toktypes.append(gguf.TokenType.UNUSED)
            else:
                token = reverse_vocab[i]
                tokens.append(token)
                if i in special_tokens.values():
                    toktypes.append(gguf.TokenType.CONTROL)
                else:
                    toktypes.append(gguf.TokenType.NORMAL)

        # 4. Write all vocab-related fields to the GGUF writer
        self.gguf_writer.add_tokenizer_model("gpt2")
        self.gguf_writer.add_tokenizer_pre(tokpre)
        self.gguf_writer.add_token_list(tokens)
        self.gguf_writer.add_token_types(toktypes)
        self.gguf_writer.add_token_merges(merges)

        # 5. Add special tokens and chat templates
        special_vocab = gguf.SpecialVocab(self.dir_model, load_merges=False)
        special_vocab.add_to_gguf(self.gguf_writer)
        # FIX for BOS token: Overwrite incorrect id read from config.json
        self.gguf_writer.add_bos_token_id(127959)  # <|bos|>
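The merge-recovery step above mirrors the Qwen converter: every multi-byte token in mergeable_ranks is split back into the two lower-ranked pieces it was built from, and only pairs that split cleanly into two pieces are recorded as merges. A minimal standalone sketch of that idea (bpe_split is a hypothetical stand-in for QwenModel.bpe, not the actual helper):

# Sketch: recovering a BPE merge pair from a rank table.
def bpe_split(mergeable_ranks: dict[bytes, int], token: bytes, max_rank: int) -> list[bytes]:
    parts = [bytes([b]) for b in token]
    while len(parts) > 2:
        # merge the lowest-ranked adjacent pair that ranks strictly below max_rank
        best = min(
            range(len(parts) - 1),
            key=lambda i: mergeable_ranks.get(parts[i] + parts[i + 1], max_rank),
        )
        if mergeable_ranks.get(parts[best] + parts[best + 1], max_rank) >= max_rank:
            break
        parts[best:best + 2] = [parts[best] + parts[best + 1]]
    # for a well-formed rank table this ends as the two halves that merge into `token`
    return parts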
    def set_gguf_parameters(self):
        super().set_gguf_parameters()
        hparams = self.hparams

        self.gguf_writer.add_expert_count(hparams["num_experts"])
        self.gguf_writer.add_expert_shared_feed_forward_length(hparams["intermediate_size"])

        moe_intermediate_size = hparams["moe_intermediate_size"]
        assert all(n == moe_intermediate_size[0] for n in moe_intermediate_size)
        self.gguf_writer.add_expert_feed_forward_length(moe_intermediate_size[0])

        moe_topk = hparams["moe_topk"]
        assert all(topk == moe_topk[0] for topk in moe_topk)
        self.gguf_writer.add_expert_used_count(moe_topk[0])

        moe_shared_expert = hparams["num_shared_expert"]
        assert all(n == moe_shared_expert[0] for n in moe_shared_expert)
        self.gguf_writer.add_expert_shared_count(moe_shared_expert[0])

        # RoPE
        rope_scaling = hparams.get("rope_scaling", {})
        if rope_scaling.get("type") == "dynamic":
            # HunYuan uses NTK-aware alpha-based scaling. Original implementation:
            # https://www.reddit.com/r/LocalLLaMA/comments/14lz7j5/ntkaware_scaled_rope_allows_llama_models_to_have/
            # alpha = 1000 corresponds to a usable context length of 256k
            # (https://github.com/Tencent-Hunyuan/Hunyuan-A13B/blob/main/report/Hunyuan_A13B_Technical_Report.pdf)
            alpha = rope_scaling.get("alpha", 1000)
            base = hparams.get("rope_theta", 10000.0)
            dim = (hparams["hidden_size"] // hparams["num_attention_heads"])  # 128
            scaled_base = base * (alpha ** (dim / (dim - 2)))  # 10000 * (1000 ** (128 / 126)) = 11158839.9251
            self.gguf_writer.add_rope_freq_base(scaled_base)
            self.gguf_writer.add_rope_scaling_type(gguf.RopeScalingType.NONE)
            self.gguf_writer.add_rope_scaling_factor(1)
            # There is no consistent way to calculate the context length from alpha, and the config is incorrectly set to 32k
            self.gguf_writer.add_rope_scaling_orig_ctx_len(256 * 1024)  # 256k context length
            self.gguf_writer.add_context_length(256 * 1024)  # 256k context length

            # if any of our assumptions about the values are wrong, something has changed and this may need to be updated
            assert alpha == 1000 and base == 10000.0 and dim == 128 and self.hparams["max_position_embeddings"] in [32 * 1024, 256 * 1024], \
                "HunYuan dynamic RoPE scaling assumptions changed, please update the logic or context length manually"
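As a worked check of the numbers in the comment above, the NTK-aware scaling stretches the rotary frequency base as

\theta' = \theta \cdot \alpha^{\frac{d}{d-2}} = 10000 \cdot 1000^{\frac{128}{126}} \approx 1.11588 \times 10^{7}

which matches the 11158839.9251 value written into the GGUF as the new rope frequency base.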
    _experts: list[dict[str, Tensor]] | None = None

    def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
        if name == "model.embed_tokens.weight":
            self._tok_embd = data_torch.clone()

        if name == "lm_head.weight":
            if self.hparams.get("tie_word_embeddings", False):
                logger.info("Skipping tied output layer 'lm_head.weight'")
                return []

        if name.find("mlp.experts") != -1:
            n_experts = self.hparams["num_experts"]
            assert bid is not None

            if self._experts is None:
                self._experts = [{} for _ in range(self.block_count)]

            self._experts[bid][name] = data_torch

            if len(self._experts[bid]) >= n_experts * 3:
                # merge the experts into a single 3d tensor
                tensors: list[tuple[str, Tensor]] = []
                for w_name in ["down_proj", "gate_proj", "up_proj"]:
                    datas: list[Tensor] = []

                    for xid in range(n_experts):
                        ename = f"model.layers.{bid}.mlp.experts.{xid}.{w_name}.weight"
                        datas.append(self._experts[bid][ename])
                        del self._experts[bid][ename]

                    data_torch = torch.stack(datas, dim=0)
                    merged_name = f"model.layers.{bid}.mlp.experts.{w_name}.weight"
                    new_name = self.map_tensor_name(merged_name)
                    tensors.append((new_name, data_torch))

                return tensors
            else:
                return []

        return [(self.map_tensor_name(name), data_torch)]

    def prepare_tensors(self):
        super().prepare_tensors()
        if self._experts is not None:
            experts = [k for d in self._experts for k in d.keys()]
            if len(experts) > 0:
                raise ValueError(f"Unprocessed experts: {experts}")


###### CONVERSION LOGIC ######

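The buffering in modify_tensors above exists because each layer's expert weights arrive one tensor at a time; only once all n_experts * 3 projections for a layer are present are they stacked into one 3-D tensor per projection. A tiny standalone illustration of the stacking step (the shapes are made up for the example, not HunYuan's real sizes):

# Sketch: stacking per-expert 2-D weights into one 3-D tensor, as done per layer above.
import torch

n_experts, n_ff, n_embd = 4, 3072, 4096           # illustrative sizes only
per_expert = [torch.randn(n_ff, n_embd) for _ in range(n_experts)]

merged = torch.stack(per_expert, dim=0)            # shape: (n_experts, n_ff, n_embd)
assert merged.shape == (n_experts, n_ff, n_embd)   # one fused tensor instead of n_experts tensors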
@@ -138,6 +138,7 @@ pre_computed_hashes = [
    {"name": "chatglm-bpe", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/THUDM/glm-4-9b-chat", "chkhsh": "81d72c7348a9f0ebe86f23298d37debe0a5e71149e29bd283904c02262b27516"},
    {"name": "glm4", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/THUDM/glm-4-9b-hf", "chkhsh": "a1336059768a55c99a734006ffb02203cd450fed003e9a71886c88acf24fdbc2"},
    {"name": "minerva-7b", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/sapienzanlp/Minerva-7B-base-v1.0", "chkhsh": "1431a23e583c97432bc230bff598d103ddb5a1f89960c8f1d1051aaa944d0b35"},
    {"name": "hunyuan", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/tencent/Hunyuan-A13B-Instruct", "chkhsh": "7e57df22b1fe23a7b1e1c7f3dc4e3f96d43a4eb0836d0c6bdc3436d7b2f1c664"},
]

@@ -136,6 +136,11 @@ static bool run(llama_context * ctx, const common_params & params) {

    std::vector<llama_token> tokens = common_tokenize(ctx, params.prompt, add_bos);

    if (tokens.empty()) {
        LOG_ERR("%s : there are no input tokens to process (try to provide a prompt with '-p')\n", __func__);
        return false;
    }

    if (llama_decode(ctx, llama_batch_get_one(tokens.data(), tokens.size()))) {
        LOG_ERR("%s : failed to eval\n", __func__);
        return false;

@ -495,7 +495,7 @@ extern "C" {
|
|||
GGML_OP_POOL_1D,
|
||||
GGML_OP_POOL_2D,
|
||||
GGML_OP_POOL_2D_BACK,
|
||||
GGML_OP_UPSCALE, // nearest interpolate
|
||||
GGML_OP_UPSCALE,
|
||||
GGML_OP_PAD,
|
||||
GGML_OP_PAD_REFLECT_1D,
|
||||
GGML_OP_ROLL,
|
||||
|
|
@ -557,6 +557,8 @@ extern "C" {
|
|||
GGML_GLU_OP_REGLU,
|
||||
GGML_GLU_OP_GEGLU,
|
||||
GGML_GLU_OP_SWIGLU,
|
||||
GGML_GLU_OP_GEGLU_ERF,
|
||||
GGML_GLU_OP_GEGLU_QUICK,
|
||||
|
||||
GGML_GLU_OP_COUNT,
|
||||
};
|
||||
|
|
@ -1147,6 +1149,22 @@ extern "C" {
|
|||
struct ggml_context * ctx,
|
||||
struct ggml_tensor * a);
|
||||
|
||||
GGML_API struct ggml_tensor * ggml_geglu_erf(
|
||||
struct ggml_context * ctx,
|
||||
struct ggml_tensor * a);
|
||||
|
||||
GGML_API struct ggml_tensor * ggml_geglu_erf_swapped(
|
||||
struct ggml_context * ctx,
|
||||
struct ggml_tensor * a);
|
||||
|
||||
GGML_API struct ggml_tensor * ggml_geglu_quick(
|
||||
struct ggml_context * ctx,
|
||||
struct ggml_tensor * a);
|
||||
|
||||
GGML_API struct ggml_tensor * ggml_geglu_quick_swapped(
|
||||
struct ggml_context * ctx,
|
||||
struct ggml_tensor * a);
|
||||
|
||||
// A: n columns, r rows,
|
||||
// B: n columns, r rows,
|
||||
GGML_API struct ggml_tensor * ggml_glu_split(
|
||||
|
|
@ -1170,6 +1188,16 @@ extern "C" {
|
|||
struct ggml_tensor * a,
|
||||
struct ggml_tensor * b);
|
||||
|
||||
GGML_API struct ggml_tensor * ggml_geglu_erf_split(
|
||||
struct ggml_context * ctx,
|
||||
struct ggml_tensor * a,
|
||||
struct ggml_tensor * b);
|
||||
|
||||
GGML_API struct ggml_tensor * ggml_geglu_quick_split(
|
||||
struct ggml_context * ctx,
|
||||
struct ggml_tensor * a,
|
||||
struct ggml_tensor * b);
|
||||
|
||||
// normalize along rows
|
||||
GGML_API struct ggml_tensor * ggml_norm(
|
||||
struct ggml_context * ctx,
|
||||
|
|
|
|||
|
|
@ -67,6 +67,7 @@
|
|||
#include <aclnnop/aclnn_pow.h>
|
||||
#include <aclnnop/aclnn_grouped_matmul_v3.h>
|
||||
#include <aclnnop/aclnn_fused_infer_attention_score_v2.h>
|
||||
#include <aclnnop/aclnn_zero.h>
|
||||
#include <float.h>
|
||||
|
||||
#include <cmath>
|
||||
|
|
@ -804,10 +805,11 @@ static aclTensor* aclnn_zero(ggml_backend_cann_context& ctx, void* buffer,
|
|||
nb[i] = nb[i - 1] * ne[i - 1];
|
||||
}
|
||||
|
||||
ggml_cann_async_memset(ctx, buffer, n_bytes, 0);
|
||||
aclTensor* zero =
|
||||
ggml_cann_create_tensor(buffer, type, type_size, ne, nb, dims);
|
||||
GGML_CANN_CALL_ACLNN_OP(ctx, InplaceZero, zero);
|
||||
return zero;
|
||||
GGML_UNUSED(n_bytes);
|
||||
}
|
||||
|
||||
/**
|
||||
|
|
|
|||
|
|
@ -2172,6 +2172,8 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads) {
|
|||
case GGML_GLU_OP_REGLU:
|
||||
case GGML_GLU_OP_GEGLU:
|
||||
case GGML_GLU_OP_SWIGLU:
|
||||
case GGML_GLU_OP_GEGLU_ERF:
|
||||
case GGML_GLU_OP_GEGLU_QUICK:
|
||||
{
|
||||
n_tasks = n_threads;
|
||||
} break;
|
||||
|
|
|
|||
|
|
@ -3614,6 +3614,292 @@ static void ggml_compute_forward_swiglu(
|
|||
}
|
||||
}
|
||||
|
||||
// ggml_compute_forward_geglu_erf
|
||||
|
||||
static void ggml_compute_forward_geglu_erf_f32(
|
||||
const ggml_compute_params * params,
|
||||
ggml_tensor * dst) {
|
||||
|
||||
const ggml_tensor * src0 = dst->src[0];
|
||||
const ggml_tensor * src1 = dst->src[1];
|
||||
char * src0_d = (char *) src0->data;
|
||||
char * src1_d = (char *) (src1 ? src1->data : src0->data);
|
||||
const size_t src0_o = src0->nb[1];
|
||||
const size_t src1_o = src1 ? src1->nb[1] : src0->nb[1];
|
||||
|
||||
GGML_ASSERT(ggml_is_contiguous_1(src0));
|
||||
GGML_ASSERT(ggml_is_contiguous_1(dst));
|
||||
|
||||
if (src1) {
|
||||
GGML_ASSERT(ggml_is_contiguous_1(src1));
|
||||
GGML_ASSERT(src0->type == src1->type);
|
||||
}
|
||||
|
||||
const int ith = params->ith;
|
||||
const int nth = params->nth;
|
||||
|
||||
const int nc = src1 ? src0->ne[0] : src0->ne[0] / 2;
|
||||
const int nr = ggml_nrows(src0);
|
||||
|
||||
GGML_ASSERT(dst->ne[0] == nc);
|
||||
GGML_ASSERT(ggml_nrows(dst) == nr);
|
||||
|
||||
const int32_t swapped = ggml_get_op_params_i32(dst, 1);
|
||||
|
||||
// rows per thread
|
||||
const int dr = (nr + nth - 1)/nth;
|
||||
|
||||
// row range for this thread
|
||||
const int ir0 = dr*ith;
|
||||
const int ir1 = MIN(ir0 + dr, nr);
|
||||
|
||||
for (int i1 = ir0; i1 < ir1; i1++) {
|
||||
float * src0_p = (float *) (src0_d + i1*src0_o);
|
||||
float * src1_p = (float *) (src1_d + i1*src1_o);
|
||||
|
||||
if (!src1) {
|
||||
src0_p += swapped ? nc : 0;
|
||||
src1_p += swapped ? 0 : nc;
|
||||
}
|
||||
|
||||
ggml_vec_geglu_erf_f32(nc, (float *) ((char *) dst->data + i1*(dst->nb[1])), src0_p, src1_p);
|
||||
|
||||
#ifndef NDEBUG
|
||||
for (int k = 0; k < nc; k++) {
|
||||
const float x = ((float *) ((char *) dst->data + i1*( dst->nb[1])))[k];
|
||||
GGML_UNUSED(x);
|
||||
assert(!isnan(x));
|
||||
assert(!isinf(x));
|
||||
}
|
||||
#endif
|
||||
}
|
||||
}
|
||||
|
||||
static void ggml_compute_forward_geglu_erf_f16(
|
||||
const ggml_compute_params * params,
|
||||
ggml_tensor * dst) {
|
||||
|
||||
const ggml_tensor * src0 = dst->src[0];
|
||||
const ggml_tensor * src1 = dst->src[1];
|
||||
char * src0_d = (char *) src0->data;
|
||||
char * src1_d = (char *) (src1 ? src1->data : src0->data);
|
||||
const size_t src0_o = src0->nb[1];
|
||||
const size_t src1_o = src1 ? src1->nb[1] : src0->nb[1];
|
||||
|
||||
GGML_ASSERT(ggml_is_contiguous_1(src0));
|
||||
GGML_ASSERT(ggml_is_contiguous_1(dst));
|
||||
|
||||
if (src1) {
|
||||
GGML_ASSERT(ggml_is_contiguous_1(src1));
|
||||
GGML_ASSERT(src0->type == src1->type);
|
||||
}
|
||||
|
||||
const int ith = params->ith;
|
||||
const int nth = params->nth;
|
||||
|
||||
const int nc = src1 ? src0->ne[0] : src0->ne[0] / 2;
|
||||
const int nr = ggml_nrows(src0);
|
||||
|
||||
GGML_ASSERT(dst->ne[0] == nc);
|
||||
GGML_ASSERT(ggml_nrows(dst) == nr);
|
||||
|
||||
const int32_t swapped = ggml_get_op_params_i32(dst, 1);
|
||||
|
||||
// rows per thread
|
||||
const int dr = (nr + nth - 1)/nth;
|
||||
|
||||
// row range for this thread
|
||||
const int ir0 = dr*ith;
|
||||
const int ir1 = MIN(ir0 + dr, nr);
|
||||
|
||||
for (int i1 = ir0; i1 < ir1; i1++) {
|
||||
ggml_fp16_t * src0_p = (ggml_fp16_t *) (src0_d + i1*src0_o);
|
||||
ggml_fp16_t * src1_p = (ggml_fp16_t *) (src1_d + i1*src1_o);
|
||||
|
||||
if (!src1) {
|
||||
src0_p += swapped ? nc : 0;
|
||||
src1_p += swapped ? 0 : nc;
|
||||
}
|
||||
|
||||
ggml_vec_geglu_erf_f16(nc, (ggml_fp16_t *) ((char *) dst->data + i1*(dst->nb[1])), src0_p, src1_p);
|
||||
|
||||
#ifndef NDEBUG
|
||||
for (int k = 0; k < nc; k++) {
|
||||
const ggml_fp16_t x = ((ggml_fp16_t *) ((char *) dst->data + i1*( dst->nb[1])))[k];
|
||||
const float v = GGML_FP16_TO_FP32(x);
|
||||
GGML_UNUSED(v);
|
||||
assert(!isnan(v));
|
||||
assert(!isinf(v));
|
||||
}
|
||||
#endif
|
||||
}
|
||||
}
|
||||
|
||||
static void ggml_compute_forward_geglu_erf(
|
||||
const ggml_compute_params * params,
|
||||
ggml_tensor * dst) {
|
||||
|
||||
const ggml_tensor * src0 = dst->src[0];
|
||||
|
||||
switch (src0->type) {
|
||||
case GGML_TYPE_F32:
|
||||
{
|
||||
ggml_compute_forward_geglu_erf_f32(params, dst);
|
||||
} break;
|
||||
case GGML_TYPE_F16:
|
||||
{
|
||||
ggml_compute_forward_geglu_erf_f16(params, dst);
|
||||
} break;
|
||||
default:
|
||||
{
|
||||
GGML_ABORT("fatal error");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// ggml_compute_forward_geglu_quick
|
||||
|
||||
static void ggml_compute_forward_geglu_quick_f32(
|
||||
const ggml_compute_params * params,
|
||||
ggml_tensor * dst) {
|
||||
|
||||
const ggml_tensor * src0 = dst->src[0];
|
||||
const ggml_tensor * src1 = dst->src[1];
|
||||
char * src0_d = (char *) src0->data;
|
||||
char * src1_d = (char *) (src1 ? src1->data : src0->data);
|
||||
const size_t src0_o = src0->nb[1];
|
||||
const size_t src1_o = src1 ? src1->nb[1] : src0->nb[1];
|
||||
|
||||
GGML_ASSERT(ggml_is_contiguous_1(src0));
|
||||
GGML_ASSERT(ggml_is_contiguous_1(dst));
|
||||
|
||||
if (src1) {
|
||||
GGML_ASSERT(ggml_is_contiguous_1(src1));
|
||||
GGML_ASSERT(src0->type == src1->type);
|
||||
}
|
||||
|
||||
const int ith = params->ith;
|
||||
const int nth = params->nth;
|
||||
|
||||
const int nc = src1 ? src0->ne[0] : src0->ne[0] / 2;
|
||||
const int nr = ggml_nrows(src0);
|
||||
|
||||
GGML_ASSERT(dst->ne[0] == nc);
|
||||
GGML_ASSERT(ggml_nrows(dst) == nr);
|
||||
|
||||
const int32_t swapped = ggml_get_op_params_i32(dst, 1);
|
||||
|
||||
// rows per thread
|
||||
const int dr = (nr + nth - 1)/nth;
|
||||
|
||||
// row range for this thread
|
||||
const int ir0 = dr*ith;
|
||||
const int ir1 = MIN(ir0 + dr, nr);
|
||||
|
||||
for (int i1 = ir0; i1 < ir1; i1++) {
|
||||
float * src0_p = (float *) (src0_d + i1*src0_o);
|
||||
float * src1_p = (float *) (src1_d + i1*src1_o);
|
||||
|
||||
if (!src1) {
|
||||
src0_p += swapped ? nc : 0;
|
||||
src1_p += swapped ? 0 : nc;
|
||||
}
|
||||
|
||||
ggml_vec_geglu_quick_f32(nc, (float *) ((char *) dst->data + i1*(dst->nb[1])), src0_p, src1_p);
|
||||
|
||||
#ifndef NDEBUG
|
||||
for (int k = 0; k < nc; k++) {
|
||||
const float x = ((float *) ((char *) dst->data + i1*( dst->nb[1])))[k];
|
||||
GGML_UNUSED(x);
|
||||
assert(!isnan(x));
|
||||
assert(!isinf(x));
|
||||
}
|
||||
#endif
|
||||
}
|
||||
}
|
||||
|
||||
static void ggml_compute_forward_geglu_quick_f16(
|
||||
const ggml_compute_params * params,
|
||||
ggml_tensor * dst) {
|
||||
|
||||
const ggml_tensor * src0 = dst->src[0];
|
||||
const ggml_tensor * src1 = dst->src[1];
|
||||
char * src0_d = (char *) src0->data;
|
||||
char * src1_d = (char *) (src1 ? src1->data : src0->data);
|
||||
const size_t src0_o = src0->nb[1];
|
||||
const size_t src1_o = src1 ? src1->nb[1] : src0->nb[1];
|
||||
|
||||
GGML_ASSERT(ggml_is_contiguous_1(src0));
|
||||
GGML_ASSERT(ggml_is_contiguous_1(dst));
|
||||
|
||||
if (src1) {
|
||||
GGML_ASSERT(ggml_is_contiguous_1(src1));
|
||||
GGML_ASSERT(src0->type == src1->type);
|
||||
}
|
||||
|
||||
const int ith = params->ith;
|
||||
const int nth = params->nth;
|
||||
|
||||
const int nc = src1 ? src0->ne[0] : src0->ne[0] / 2;
|
||||
const int nr = ggml_nrows(src0);
|
||||
|
||||
GGML_ASSERT(dst->ne[0] == nc);
|
||||
GGML_ASSERT(ggml_nrows(dst) == nr);
|
||||
|
||||
const int32_t swapped = ggml_get_op_params_i32(dst, 1);
|
||||
|
||||
// rows per thread
|
||||
const int dr = (nr + nth - 1)/nth;
|
||||
|
||||
// row range for this thread
|
||||
const int ir0 = dr*ith;
|
||||
const int ir1 = MIN(ir0 + dr, nr);
|
||||
|
||||
for (int i1 = ir0; i1 < ir1; i1++) {
|
||||
ggml_fp16_t * src0_p = (ggml_fp16_t *) (src0_d + i1*src0_o);
|
||||
ggml_fp16_t * src1_p = (ggml_fp16_t *) (src1_d + i1*src1_o);
|
||||
|
||||
if (!src1) {
|
||||
src0_p += swapped ? nc : 0;
|
||||
src1_p += swapped ? 0 : nc;
|
||||
}
|
||||
|
||||
ggml_vec_geglu_quick_f16(nc, (ggml_fp16_t *) ((char *) dst->data + i1*(dst->nb[1])), src0_p, src1_p);
|
||||
|
||||
#ifndef NDEBUG
|
||||
for (int k = 0; k < nc; k++) {
|
||||
const ggml_fp16_t x = ((ggml_fp16_t *) ((char *) dst->data + i1*( dst->nb[1])))[k];
|
||||
const float v = GGML_FP16_TO_FP32(x);
|
||||
GGML_UNUSED(v);
|
||||
assert(!isnan(v));
|
||||
assert(!isinf(v));
|
||||
}
|
||||
#endif
|
||||
}
|
||||
}
|
||||
|
||||
static void ggml_compute_forward_geglu_quick(
|
||||
const ggml_compute_params * params,
|
||||
ggml_tensor * dst) {
|
||||
|
||||
const ggml_tensor * src0 = dst->src[0];
|
||||
|
||||
switch (src0->type) {
|
||||
case GGML_TYPE_F32:
|
||||
{
|
||||
ggml_compute_forward_geglu_quick_f32(params, dst);
|
||||
} break;
|
||||
case GGML_TYPE_F16:
|
||||
{
|
||||
ggml_compute_forward_geglu_quick_f16(params, dst);
|
||||
} break;
|
||||
default:
|
||||
{
|
||||
GGML_ABORT("fatal error");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// ggml_compute_forward_norm
|
||||
|
||||
static void ggml_compute_forward_norm_f32(
|
||||
|
|
@ -8779,6 +9065,14 @@ void ggml_compute_forward_glu(
|
|||
{
|
||||
ggml_compute_forward_swiglu(params, dst);
|
||||
} break;
|
||||
case GGML_GLU_OP_GEGLU_ERF:
|
||||
{
|
||||
ggml_compute_forward_geglu_erf(params, dst);
|
||||
} break;
|
||||
case GGML_GLU_OP_GEGLU_QUICK:
|
||||
{
|
||||
ggml_compute_forward_geglu_quick(params, dst);
|
||||
} break;
|
||||
default:
|
||||
{
|
||||
GGML_ABORT("fatal error");
|
||||
|
|
|
|||
|
|
@@ -959,6 +959,46 @@ inline static void ggml_vec_swiglu_f16(const int n, ggml_fp16_t * y, const ggml_
    }
}

inline static void ggml_vec_geglu_erf_f32(const int n, float * y, const float * x, const float * g) {
    for (int i = 0; i < n; ++i) {
        float xi = x[i];
        y[i] = 0.5f * xi * (1.0f + erff(xi*SQRT_2_INV)) * g[i];
    }
}

inline static void ggml_vec_geglu_erf_f16(const int n, ggml_fp16_t * y, const ggml_fp16_t * x, const ggml_fp16_t * g) {
    for (int i = 0; i < n; ++i) {
        float xi = GGML_CPU_FP16_TO_FP32(x[i]);
        float gi = GGML_CPU_FP16_TO_FP32(g[i]);
        y[i] = GGML_CPU_FP32_TO_FP16(0.5f * xi * (1.0f + erff(xi*SQRT_2_INV)) * gi);
    }
}

#ifdef GGML_GELU_QUICK_FP16
inline static void ggml_vec_geglu_quick_f32(const int n, float * y, const float * x, const float * g) {
    uint16_t t;
    for (int i = 0; i < n; ++i) {
        ggml_fp16_t fp16 = GGML_CPU_FP32_TO_FP16(x[i]);
        memcpy(&t, &fp16, sizeof(uint16_t));
        y[i] = GGML_CPU_FP16_TO_FP32(ggml_table_gelu_quick_f16[t]) * g[i];
    }
}
#else
inline static void ggml_vec_geglu_quick_f32(const int n, float * y, const float * x, const float * g) {
    for (int i = 0; i < n; ++i) {
        y[i] = ggml_gelu_quick_f32(x[i]) * g[i];
    }
}
#endif

inline static void ggml_vec_geglu_quick_f16(const int n, ggml_fp16_t * y, const ggml_fp16_t * x, const ggml_fp16_t * g) {
    const uint16_t * i16 = (const uint16_t *) x;
    for (int i = 0; i < n; ++i) {
        float v = GGML_CPU_FP16_TO_FP32(g[i]);
        y[i] = GGML_CPU_FP32_TO_FP16(GGML_CPU_FP16_TO_FP32(ggml_table_gelu_quick_f16[i16[i]]) * v);
    }
}

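For reference, the gated activations these helpers compute are, with g the gate input, x the activation input, and sigma the logistic sigmoid:

\mathrm{geglu\_erf}(x, g) = \tfrac{1}{2}\, x \left(1 + \operatorname{erf}\!\left(\tfrac{x}{\sqrt{2}}\right)\right) g
\qquad
\mathrm{geglu\_quick}(x, g) = x\, \sigma(1.702\, x)\, g \approx \mathrm{GELU}(x)\, g

The erf form is the exact GELU gate; the quick form uses the sigmoid approximation driven by the GELU_QUICK_COEF constant (-1.702) defined elsewhere in ggml.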
inline static void ggml_vec_sum_f32(const int n, float * s, const float * x) {
#ifndef GGML_USE_ACCELERATE
    ggml_float sum = 0.0;

@ -176,17 +176,20 @@ static const char * cu_get_error_str(CUresult err) {
|
|||
#endif
|
||||
|
||||
#if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && !defined(GGML_USE_MUSA)
|
||||
#define CUDA_SET_SHARED_MEMORY_LIMIT(kernel, nbytes) \
|
||||
do { \
|
||||
static bool shared_memory_limit_raised[GGML_CUDA_MAX_DEVICES] = {false}; \
|
||||
const int id = ggml_cuda_get_device(); \
|
||||
if (!shared_memory_limit_raised[id]) { \
|
||||
CUDA_CHECK(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, nbytes)); \
|
||||
shared_memory_limit_raised[id] = true; \
|
||||
} \
|
||||
} while (0)
|
||||
# define CUDA_SET_SHARED_MEMORY_LIMIT(kernel, nbytes) \
|
||||
do { \
|
||||
static bool shared_memory_limit_raised[GGML_CUDA_MAX_DEVICES] = { false }; \
|
||||
const int id = ggml_cuda_get_device(); \
|
||||
if (!shared_memory_limit_raised[id]) { \
|
||||
CUDA_CHECK(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, nbytes)); \
|
||||
shared_memory_limit_raised[id] = true; \
|
||||
} \
|
||||
} while (0)
|
||||
#else
|
||||
#define CUDA_SET_SHARED_MEMORY_LIMIT(kernel, nbytes) do {} while (0)
|
||||
# define CUDA_SET_SHARED_MEMORY_LIMIT(kernel, nbytes) \
|
||||
do { \
|
||||
GGML_UNUSED(nbytes); \
|
||||
} while (0)
|
||||
#endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && !defined(GGML_USE_MUSA)
|
||||
|
||||
#if CUDART_VERSION >= 11010 || defined(GGML_USE_MUSA)
|
||||
|
|
|
|||
|
|
@ -299,14 +299,14 @@ static __global__ void flash_attn_tile_ext_f32(
|
|||
GGML_UNUSED(dst); GGML_UNUSED(dst_meta); GGML_UNUSED(scale);
|
||||
GGML_UNUSED(max_bias); GGML_UNUSED(m0); GGML_UNUSED(m1);
|
||||
GGML_UNUSED(n_head_log2); GGML_UNUSED(logit_softcap);
|
||||
GGML_UNUSED(ne00); GGML_UNUSED(ne01); GGML_UNUSED(ne02);
|
||||
GGML_UNUSED(ne03); GGML_UNUSED(ne10); GGML_UNUSED(ne11);
|
||||
GGML_UNUSED(ne12); GGML_UNUSED(ne13); GGML_UNUSED(ne31);
|
||||
GGML_UNUSED(nb31); GGML_UNUSED(nb01); GGML_UNUSED(nb02);
|
||||
GGML_UNUSED(nb03); GGML_UNUSED(nb11); GGML_UNUSED(nb12);
|
||||
GGML_UNUSED(nb13); GGML_UNUSED(nb21); GGML_UNUSED(nb22);
|
||||
GGML_UNUSED(nb23); GGML_UNUSED(ne0); GGML_UNUSED(ne1);
|
||||
GGML_UNUSED(ne2); GGML_UNUSED(ne3);
|
||||
GGML_UNUSED(ne00); GGML_UNUSED(ne01); GGML_UNUSED(ne02); GGML_UNUSED(ne03);
|
||||
GGML_UNUSED(ne10); GGML_UNUSED(ne11); GGML_UNUSED(ne12); GGML_UNUSED(ne13);
|
||||
GGML_UNUSED(ne31); GGML_UNUSED(ne32);
|
||||
GGML_UNUSED(nb31); GGML_UNUSED(nb32);
|
||||
GGML_UNUSED(nb01); GGML_UNUSED(nb02); GGML_UNUSED(nb03);
|
||||
GGML_UNUSED(nb11); GGML_UNUSED(nb12); GGML_UNUSED(nb13);
|
||||
GGML_UNUSED(nb21); GGML_UNUSED(nb22); GGML_UNUSED(nb23);
|
||||
GGML_UNUSED(ne0); GGML_UNUSED(ne1); GGML_UNUSED(ne2); GGML_UNUSED(ne3);
|
||||
NO_DEVICE_CODE;
|
||||
#endif // FLASH_ATTN_AVAILABLE
|
||||
}
|
||||
|
|
|
|||
|
|
@ -337,13 +337,15 @@ static __global__ void flash_attn_vec_ext_f32(
|
|||
GGML_UNUSED(Q); GGML_UNUSED(K); GGML_UNUSED(V); GGML_UNUSED(mask);
|
||||
GGML_UNUSED(dst); GGML_UNUSED(dst_meta); GGML_UNUSED(scale);
|
||||
GGML_UNUSED(max_bias); GGML_UNUSED(m0); GGML_UNUSED(m1);
|
||||
GGML_UNUSED(n_head_log2); GGML_UNUSED(logit_softcap); GGML_UNUSED(ne00);
|
||||
GGML_UNUSED(ne01); GGML_UNUSED(ne02); GGML_UNUSED(ne03); GGML_UNUSED(ne10);
|
||||
GGML_UNUSED(ne11); GGML_UNUSED(ne12); GGML_UNUSED(ne13); GGML_UNUSED(ne31);
|
||||
GGML_UNUSED(nb31); GGML_UNUSED(nb01); GGML_UNUSED(nb02); GGML_UNUSED(nb03);
|
||||
GGML_UNUSED(nb11); GGML_UNUSED(nb12); GGML_UNUSED(nb13); GGML_UNUSED(nb21);
|
||||
GGML_UNUSED(nb22); GGML_UNUSED(nb23); GGML_UNUSED(ne0); GGML_UNUSED(ne1);
|
||||
GGML_UNUSED(ne2); GGML_UNUSED(ne3);
|
||||
GGML_UNUSED(n_head_log2); GGML_UNUSED(logit_softcap);
|
||||
GGML_UNUSED(ne00); GGML_UNUSED(ne01); GGML_UNUSED(ne02); GGML_UNUSED(ne03);
|
||||
GGML_UNUSED(ne10); GGML_UNUSED(ne11); GGML_UNUSED(ne12); GGML_UNUSED(ne13);
|
||||
GGML_UNUSED(ne31); GGML_UNUSED(ne32);
|
||||
GGML_UNUSED(nb31); GGML_UNUSED(nb32);
|
||||
GGML_UNUSED(nb01); GGML_UNUSED(nb02); GGML_UNUSED(nb03);
|
||||
GGML_UNUSED(nb11); GGML_UNUSED(nb12); GGML_UNUSED(nb13);
|
||||
GGML_UNUSED(nb21); GGML_UNUSED(nb22); GGML_UNUSED(nb23);
|
||||
GGML_UNUSED(ne0); GGML_UNUSED(ne1); GGML_UNUSED(ne2); GGML_UNUSED(ne3);
|
||||
NO_DEVICE_CODE;
|
||||
#endif // FLASH_ATTN_AVAILABLE
|
||||
}
|
||||
|
|
|
|||
|
|
@ -168,6 +168,10 @@ static void ggml_cuda_get_rows_switch_src0_type(
|
|||
get_rows_cuda_float((const float *) src0_d, src1_d, dst_d,
|
||||
ne00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb1, nb2, nb3, stream);
|
||||
break;
|
||||
case GGML_TYPE_I32:
|
||||
get_rows_cuda_float((const int32_t *) src0_d, src1_d, dst_d,
|
||||
ne00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb1, nb2, nb3, stream);
|
||||
break;
|
||||
case GGML_TYPE_BF16:
|
||||
get_rows_cuda_float((const nv_bfloat16 *) src0_d, src1_d, dst_d,
|
||||
ne00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb1, nb2, nb3, stream);
|
||||
|
|
@ -210,6 +214,10 @@ void get_rows_cuda(
|
|||
ggml_cuda_get_rows_switch_src0_type(src0_d, src0_type, src1_d, (float *) dst_d,
|
||||
ne00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb1, nb2, nb3, stream);
|
||||
break;
|
||||
case GGML_TYPE_I32:
|
||||
ggml_cuda_get_rows_switch_src0_type(src0_d, src0_type, src1_d, (int32_t *) dst_d,
|
||||
ne00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb1, nb2, nb3, stream);
|
||||
break;
|
||||
case GGML_TYPE_F16:
|
||||
ggml_cuda_get_rows_switch_src0_type(src0_d, src0_type, src1_d, (half *) dst_d,
|
||||
ne00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb1, nb2, nb3, stream);
|
||||
|
|
|
|||
|
|
@ -2314,6 +2314,12 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
|
|||
case GGML_GLU_OP_SWIGLU:
|
||||
ggml_cuda_op_swiglu(ctx, dst);
|
||||
break;
|
||||
case GGML_GLU_OP_GEGLU_ERF:
|
||||
ggml_cuda_op_geglu_erf(ctx, dst);
|
||||
break;
|
||||
case GGML_GLU_OP_GEGLU_QUICK:
|
||||
ggml_cuda_op_geglu_quick(ctx, dst);
|
||||
break;
|
||||
default:
|
||||
return false;
|
||||
}
|
||||
|
|
@ -3116,6 +3122,8 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
|
|||
case GGML_GLU_OP_REGLU:
|
||||
case GGML_GLU_OP_GEGLU:
|
||||
case GGML_GLU_OP_SWIGLU:
|
||||
case GGML_GLU_OP_GEGLU_ERF:
|
||||
case GGML_GLU_OP_GEGLU_QUICK:
|
||||
return ggml_is_contiguous_1(op->src[0]);
|
||||
default:
|
||||
return false;
|
||||
|
|
@ -3192,6 +3200,8 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
|
|||
switch (op->src[0]->type) {
|
||||
case GGML_TYPE_F16:
|
||||
case GGML_TYPE_F32:
|
||||
case GGML_TYPE_BF16:
|
||||
case GGML_TYPE_I32:
|
||||
case GGML_TYPE_Q4_0:
|
||||
case GGML_TYPE_Q4_1:
|
||||
case GGML_TYPE_Q5_0:
|
||||
|
|
@ -3365,7 +3375,6 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
|
|||
case GGML_OP_GROUP_NORM:
|
||||
return ggml_is_contiguous(op->src[0]);
|
||||
case GGML_OP_UPSCALE:
|
||||
return op->src[0]->type == GGML_TYPE_F32 && op->op_params[0] == GGML_SCALE_MODE_NEAREST;
|
||||
case GGML_OP_PAD:
|
||||
case GGML_OP_ARANGE:
|
||||
case GGML_OP_TIMESTEP_EMBEDDING:
|
||||
|
|
|
|||
|
|
@ -50,21 +50,19 @@ static __global__ void rope_norm(
|
|||
|
||||
const int row_dst = blockDim.x*blockIdx.x + threadIdx.x;
|
||||
|
||||
if (i0 >= n_dims) {
|
||||
const int i = row_dst*ne0 + i0;
|
||||
|
||||
dst[i + 0] = x[i + 0];
|
||||
dst[i + 1] = x[i + 1];
|
||||
|
||||
return;
|
||||
}
|
||||
|
||||
const int row_x = row_dst % ne1;
|
||||
const int channel_x = row_dst / ne1;
|
||||
|
||||
const int idst = row_dst*ne0 + i0;
|
||||
const int ix = channel_x*s2 + row_x*s1 + i0;
|
||||
|
||||
if (i0 >= n_dims) {
|
||||
dst[idst + 0] = x[ix + 0];
|
||||
dst[idst + 1] = x[ix + 1];
|
||||
|
||||
return;
|
||||
}
|
||||
|
||||
const float theta_base = pos[channel_x]*powf(theta_scale, i0/2.0f);
|
||||
|
||||
const float freq_factor = has_ff ? freq_factors[i0/2] : 1.0f;
|
||||
|
|
@ -94,21 +92,19 @@ static __global__ void rope_neox(
|
|||
|
||||
const int row_dst = blockDim.x*blockIdx.x + threadIdx.x;
|
||||
|
||||
if (i0 >= n_dims) {
|
||||
const int i = row_dst*ne0 + i0;
|
||||
|
||||
dst[i + 0] = x[i + 0];
|
||||
dst[i + 1] = x[i + 1];
|
||||
|
||||
return;
|
||||
}
|
||||
|
||||
const int row_x = row_dst % ne1;
|
||||
const int channel_x = row_dst / ne1;
|
||||
|
||||
const int idst = row_dst*ne0 + i0/2;
|
||||
const int ix = channel_x*s2 + row_x*s1 + i0/2;
|
||||
|
||||
if (i0 >= n_dims) {
|
||||
dst[idst + i0/2 + 0] = x[ix + i0/2 + 0];
|
||||
dst[idst + i0/2 + 1] = x[ix + i0/2 + 1];
|
||||
|
||||
return;
|
||||
}
|
||||
|
||||
const float theta_base = pos[channel_x]*powf(theta_scale, i0/2.0f);
|
||||
|
||||
const float freq_factor = has_ff ? freq_factors[i0/2] : 1.0f;
|
||||
|
|
@ -138,21 +134,19 @@ static __global__ void rope_multi(
|
|||
|
||||
const int row_dst = blockDim.x*blockIdx.x + threadIdx.x;
|
||||
|
||||
if (i0 >= n_dims) {
|
||||
const int i = row_dst*ne0 + i0;
|
||||
|
||||
dst[i + 0] = x[i + 0];
|
||||
dst[i + 1] = x[i + 1];
|
||||
|
||||
return;
|
||||
}
|
||||
|
||||
const int row_x = row_dst % ne1;
|
||||
const int channel_x = row_dst / ne1;
|
||||
|
||||
const int idst = row_dst*ne0 + i0/2;
|
||||
const int ix = channel_x*s2 + row_x*s1 + i0/2;
|
||||
|
||||
if (i0 >= n_dims) {
|
||||
dst[idst + i0/2 + 0] = x[ix + i0/2 + 0];
|
||||
dst[idst + i0/2 + 1] = x[ix + i0/2 + 1];
|
||||
|
||||
return;
|
||||
}
|
||||
|
||||
const int sect_dims = sections.v[0] + sections.v[1] + sections.v[2] + sections.v[3];
|
||||
const int sec_w = sections.v[1] + sections.v[0];
|
||||
const int sector = (i0 / 2) % sect_dims;
|
||||
|
|
|
|||
|
|
@ -285,6 +285,14 @@ void ggml_cuda_op_swiglu(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
|
|||
ggml_cuda_op_unary_gated<op_silu>(ctx, dst);
|
||||
}
|
||||
|
||||
void ggml_cuda_op_geglu_erf(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
|
||||
ggml_cuda_op_unary_gated<op_gelu_erf>(ctx, dst);
|
||||
}
|
||||
|
||||
void ggml_cuda_op_geglu_quick(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
|
||||
ggml_cuda_op_unary_gated<op_gelu_quick>(ctx, dst);
|
||||
}
|
||||
|
||||
/* silu_back */
|
||||
|
||||
static __device__ __forceinline__ float op_silu_back(float grad, float x) {
|
||||
|
|
|
|||
|
|
@ -64,3 +64,7 @@ void ggml_cuda_op_reglu(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
|
|||
void ggml_cuda_op_geglu(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
|
||||
|
||||
void ggml_cuda_op_swiglu(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
|
||||
|
||||
void ggml_cuda_op_geglu_erf(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
|
||||
|
||||
void ggml_cuda_op_geglu_quick(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
|
||||
|
|
|
|||
|
|
@ -22,17 +22,88 @@ static __global__ void upscale_f32(const float * x, float * dst,
|
|||
dst[index] = *( (const float *)((const char *)x + i03 * nb03 + i02 * nb02 + i01 * nb01 + i00 * nb00) );
|
||||
}
|
||||
|
||||
static __global__ void upscale_f32_bilinear(const float * x, float * dst,
|
||||
const int nb00, const int nb01, const int nb02, const int nb03,
|
||||
const int ne00_src, const int ne01_src,
|
||||
const int ne10_dst, const int ne11_dst, const int ne12_dst, const int ne13_dst,
|
||||
const float sf0, const float sf1, const float sf2, const float sf3,
|
||||
const float pixel_offset) {
|
||||
const int64_t index = threadIdx.x + blockIdx.x * blockDim.x;
|
||||
const int64_t dst_total_elements = ne10_dst * ne11_dst * ne12_dst * ne13_dst;
|
||||
|
||||
if (index >= dst_total_elements) {
|
||||
return;
|
||||
}
|
||||
|
||||
const int i10_dst = index % ne10_dst;
|
||||
const int i11_dst = (index / ne10_dst) % ne11_dst;
|
||||
const int i12_dst = (index / (ne10_dst * ne11_dst)) % ne12_dst;
|
||||
const int i13_dst = index / (ne10_dst * ne11_dst * ne12_dst);
|
||||
|
||||
const int i02_src = (int)(i12_dst / sf2);
|
||||
const int i03_src = (int)(i13_dst / sf3);
|
||||
|
||||
const float y_src_f = ((float)i11_dst + pixel_offset) / sf1 - pixel_offset;
|
||||
int y0_src = (int)floorf(y_src_f);
|
||||
int y1_src = y0_src + 1;
|
||||
|
||||
y0_src = max(0, min(y0_src, ne01_src - 1));
|
||||
y1_src = max(0, min(y1_src, ne01_src - 1));
|
||||
|
||||
float dy = y_src_f - (float)y0_src;
|
||||
dy = max(0.0f, min(dy, 1.0f));
|
||||
|
||||
float x_src_f = ((float)i10_dst + pixel_offset) / sf0 - pixel_offset;
|
||||
int x0_src = (int)floorf(x_src_f);
|
||||
int x1_src = x0_src + 1;
|
||||
|
||||
x0_src = max(0, min(x0_src, ne00_src - 1));
|
||||
x1_src = max(0, min(x1_src, ne00_src - 1));
|
||||
|
||||
float dx = x_src_f - (float)x0_src;
|
||||
dx = max(0.0f, min(dx, 1.0f));
|
||||
|
||||
const float * p_a = (const float *)((const char *)x + (int64_t)x0_src * nb00 + (int64_t)y0_src * nb01 + (int64_t)i02_src * nb02 + (int64_t)i03_src * nb03);
|
||||
const float * p_b = (const float *)((const char *)x + (int64_t)x1_src * nb00 + (int64_t)y0_src * nb01 + (int64_t)i02_src * nb02 + (int64_t)i03_src * nb03);
|
||||
const float * p_c = (const float *)((const char *)x + (int64_t)x0_src * nb00 + (int64_t)y1_src * nb01 + (int64_t)i02_src * nb02 + (int64_t)i03_src * nb03);
|
||||
const float * p_d = (const float *)((const char *)x + (int64_t)x1_src * nb00 + (int64_t)y1_src * nb01 + (int64_t)i02_src * nb02 + (int64_t)i03_src * nb03);
|
||||
|
||||
const float val_a = *p_a;
|
||||
const float val_b = *p_b;
|
||||
const float val_c = *p_c;
|
||||
const float val_d = *p_d;
|
||||
|
||||
float result = val_a * (1.0f - dx) * (1.0f - dy) +
|
||||
val_b * dx * (1.0f - dy) +
|
||||
val_c * (1.0f - dx) * dy +
|
||||
val_d * dx * dy;
|
||||
|
||||
dst[index] = result;
|
||||
}
|
||||
|
||||
static void upscale_f32_cuda(const float * x, float * dst,
|
||||
const int nb00, const int nb01, const int nb02, const int nb03,
|
||||
const int ne10, const int ne11, const int ne12, const int ne13,
|
||||
const float sf0, const float sf1, const float sf2, const float sf3,
|
||||
cudaStream_t stream) {
|
||||
int dst_size = ne10 * ne11 * ne12 * ne13;
|
||||
int num_blocks = (dst_size + CUDA_UPSCALE_BLOCK_SIZE - 1) / CUDA_UPSCALE_BLOCK_SIZE;
|
||||
const int64_t dst_size = ne10 * ne11 * ne12 * ne13;
|
||||
const int64_t num_blocks = (dst_size + CUDA_UPSCALE_BLOCK_SIZE - 1) / CUDA_UPSCALE_BLOCK_SIZE;
|
||||
|
||||
upscale_f32<<<num_blocks, CUDA_UPSCALE_BLOCK_SIZE,0,stream>>>(x, dst, nb00, nb01, nb02, nb03, ne10, ne11, ne12, ne13, sf0, sf1, sf2, sf3);
|
||||
}
|
||||
|
||||
static void upscale_f32_bilinear_cuda(const float * x, float * dst,
|
||||
const int nb00, const int nb01, const int nb02, const int nb03,
|
||||
const int ne00_src, const int ne01_src,
|
||||
const int ne10_dst, const int ne11_dst, const int ne12_dst, const int ne13_dst,
|
||||
const float sf0, const float sf1, const float sf2, const float sf3,
|
||||
const float pixel_offset, cudaStream_t stream) {
|
||||
const int64_t dst_size = ne10_dst * ne11_dst * ne12_dst * ne13_dst;
|
||||
const int64_t num_blocks = (dst_size + CUDA_UPSCALE_BLOCK_SIZE - 1) / CUDA_UPSCALE_BLOCK_SIZE;
|
||||
|
||||
upscale_f32_bilinear<<<num_blocks, CUDA_UPSCALE_BLOCK_SIZE,0,stream>>>(x, dst, nb00, nb01, nb02, nb03, ne00_src, ne01_src, ne10_dst, ne11_dst, ne12_dst, ne13_dst, sf0, sf1, sf2, sf3, pixel_offset);
|
||||
}
|
||||
|
||||
void ggml_cuda_op_upscale(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
|
||||
const ggml_tensor * src0 = dst->src[0];
|
||||
const float * src0_d = (const float *)src0->data;
|
||||
|
|
@ -42,10 +113,25 @@ void ggml_cuda_op_upscale(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
|
|||
GGML_ASSERT(src0->type == GGML_TYPE_F32);
|
||||
GGML_ASSERT( dst->type == GGML_TYPE_F32);
|
||||
|
||||
const float sf0 = (float)dst->ne[0]/src0->ne[0];
|
||||
const float sf1 = (float)dst->ne[1]/src0->ne[1];
|
||||
const float sf2 = (float)dst->ne[2]/src0->ne[2];
|
||||
const int mode_flags = dst->op_params[0];
|
||||
const ggml_scale_mode mode = (ggml_scale_mode)(mode_flags & 0xFF);
|
||||
|
||||
float sf0 = (float)dst->ne[0]/src0->ne[0];
|
||||
float sf1 = (float)dst->ne[1]/src0->ne[1];
|
||||
float sf2 = (float)dst->ne[2]/src0->ne[2];
|
||||
const float sf3 = (float)dst->ne[3]/src0->ne[3];
|
||||
|
||||
upscale_f32_cuda(src0_d, dst_d, src0->nb[0], src0->nb[1], src0->nb[2], src0->nb[3], dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3], sf0, sf1, sf2, sf3, stream);
|
||||
if (mode == GGML_SCALE_MODE_NEAREST) {
|
||||
upscale_f32_cuda(src0_d, dst_d, src0->nb[0], src0->nb[1], src0->nb[2], src0->nb[3], dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3], sf0, sf1, sf2, sf3, stream);
|
||||
} else if (mode == GGML_SCALE_MODE_BILINEAR) {
|
||||
float pixel_offset = 0.5f;
|
||||
if (mode_flags & GGML_SCALE_FLAG_ALIGN_CORNERS) {
|
||||
sf0 = (float)(dst->ne[0] - 1) / (src0->ne[0] - 1);
|
||||
sf1 = (float)(dst->ne[1] - 1) / (src0->ne[1] - 1);
|
||||
pixel_offset = 0.0f;
|
||||
}
|
||||
upscale_f32_bilinear_cuda(src0_d, dst_d, src0->nb[0], src0->nb[1], src0->nb[2], src0->nb[3],
|
||||
src0->ne[0], src0->ne[1], dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3],
|
||||
sf0, sf1, sf2, sf3, pixel_offset, stream);
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -530,6 +530,8 @@ enum ggml_metal_kernel_type {
|
|||
GGML_METAL_KERNEL_TYPE_REGLU,
|
||||
GGML_METAL_KERNEL_TYPE_GEGLU,
|
||||
GGML_METAL_KERNEL_TYPE_SWIGLU,
|
||||
GGML_METAL_KERNEL_TYPE_GEGLU_ERF,
|
||||
GGML_METAL_KERNEL_TYPE_GEGLU_QUICK,
|
||||
GGML_METAL_KERNEL_TYPE_SUM_ROWS,
|
||||
GGML_METAL_KERNEL_TYPE_MEAN,
|
||||
GGML_METAL_KERNEL_TYPE_POOL_2D_AVG_F32,
|
||||
|
|
@ -1510,6 +1512,8 @@ static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t de
|
|||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_REGLU, reglu, true);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GEGLU, geglu, true);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SWIGLU, swiglu, true);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GEGLU_ERF, geglu_erf, true);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GEGLU_QUICK, geglu_quick, true);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SUM_ROWS, sum_rows, true);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MEAN, mean, true);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ARGMAX, argmax, true);
|
||||
|
|
@ -1693,6 +1697,8 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_device_contex
|
|||
case GGML_GLU_OP_REGLU:
|
||||
case GGML_GLU_OP_GEGLU:
|
||||
case GGML_GLU_OP_SWIGLU:
|
||||
case GGML_GLU_OP_GEGLU_ERF:
|
||||
case GGML_GLU_OP_GEGLU_QUICK:
|
||||
return ggml_is_contiguous_1(op->src[0]) && op->src[0]->type == GGML_TYPE_F32;
|
||||
default:
|
||||
return false;
|
||||
|
|
@ -2456,6 +2462,12 @@ static bool ggml_metal_encode_node(
|
|||
case GGML_GLU_OP_SWIGLU:
|
||||
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SWIGLU].pipeline;
|
||||
break;
|
||||
case GGML_GLU_OP_GEGLU_ERF:
|
||||
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GEGLU_ERF].pipeline;
|
||||
break;
|
||||
case GGML_GLU_OP_GEGLU_QUICK:
|
||||
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GEGLU_QUICK].pipeline;
|
||||
break;
|
||||
default:
|
||||
GGML_ABORT("fatal error");
|
||||
}
|
||||
|
|
|
|||
|
|
@ -109,6 +109,7 @@ void dequantize_q4_0_t4(device const block_q4_0 * xb, short il, thread type4 & r
|
|||
}
|
||||
|
||||
void quantize_q4_0(device const float * src, device block_q4_0 & dst) {
|
||||
#pragma METAL fp math_mode(safe)
|
||||
float amax = 0.0f; // absolute max
|
||||
float max = 0.0f;
|
||||
|
||||
|
|
@ -167,6 +168,7 @@ void quantize_q4_1(device const float * src, device block_q4_1 & dst) {
|
|||
}
|
||||
|
||||
void quantize_q5_0(device const float * src, device block_q5_0 & dst) {
|
||||
#pragma METAL fp math_mode(safe)
|
||||
float amax = 0.0f; // absolute max
|
||||
float max = 0.0f;
|
||||
|
||||
|
|
@ -461,6 +463,7 @@ void dequantize_q8_0_t4(device const block_q8_0 *xb, short il, thread type4 & re
|
|||
}
|
||||
|
||||
void quantize_q8_0(device const float * src, device block_q8_0 & dst) {
|
||||
#pragma METAL fp math_mode(safe)
|
||||
float amax = 0.0f; // absolute max
|
||||
|
||||
for (int j = 0; j < QK8_0; j++) {
|
||||
|
|
@ -1258,6 +1261,50 @@ kernel void kernel_swiglu(
|
|||
}
|
||||
}
|
||||
|
||||
kernel void kernel_geglu_erf(
|
||||
device const char * src0,
|
||||
device const char * src1,
|
||||
device char * dst,
|
||||
constant ggml_metal_kargs_glu & args,
|
||||
uint tgpig[[threadgroup_position_in_grid]],
|
||||
uint tpitg[[thread_position_in_threadgroup]],
|
||||
uint ntg[[threads_per_threadgroup]]) {
|
||||
device const float * src0_row = (device const float *) ((device const char *) src0 + tgpig*args.nb01) + args.i00;
|
||||
device const float * src1_row = (device const float *) ((device const char *) src1 + tgpig*args.nb11) + args.i10;
|
||||
device float * dst_row = (device float *) ((device char *) dst + tgpig*args.nb1);
|
||||
|
||||
for (int i0 = tpitg; i0 < args.ne0; i0 += ntg) {
|
||||
const float x0 = src0_row[i0];
|
||||
const float x1 = src1_row[i0];
|
||||
|
||||
const float gelu_erf = 0.5f*x0*(1.0f+erf_approx<float>(x0*SQRT_2_INV));
|
||||
|
||||
dst_row[i0] = gelu_erf*x1;
|
||||
}
|
||||
}
|
||||
|
||||
kernel void kernel_geglu_quick(
|
||||
device const char * src0,
|
||||
device const char * src1,
|
||||
device char * dst,
|
||||
constant ggml_metal_kargs_glu & args,
|
||||
uint tgpig[[threadgroup_position_in_grid]],
|
||||
uint tpitg[[thread_position_in_threadgroup]],
|
||||
uint ntg[[threads_per_threadgroup]]) {
|
||||
device const float * src0_row = (device const float *) ((device const char *) src0 + tgpig*args.nb01) + args.i00;
|
||||
device const float * src1_row = (device const float *) ((device const char *) src1 + tgpig*args.nb11) + args.i10;
|
||||
device float * dst_row = (device float *) ((device char *) dst + tgpig*args.nb1);
|
||||
|
||||
for (int i0 = tpitg; i0 < args.ne0; i0 += ntg) {
|
||||
const float x0 = src0_row[i0];
|
||||
const float x1 = src1_row[i0];
|
||||
|
||||
const float gelu_quick = x0*(1.0f/(1.0f+exp(GELU_QUICK_COEF*x0)));
|
||||
|
||||
dst_row[i0] = gelu_quick*x1;
|
||||
}
|
||||
}
|
||||
|
||||
template <bool norm>
|
||||
kernel void kernel_sum_rows(
|
||||
constant ggml_metal_kargs_sum_rows & args,
|
||||
|
|
|
|||
|
|
@ -398,12 +398,13 @@ struct ggml_backend_opencl_context {
|
|||
cl_kernel kernel_scale;
|
||||
cl_kernel kernel_silu, kernel_silu_4;
|
||||
cl_kernel kernel_gelu, kernel_gelu_4;
|
||||
cl_kernel kernel_gelu_erf, kernel_gelu_erf_4;
|
||||
cl_kernel kernel_gelu_quick, kernel_gelu_quick_4;
|
||||
cl_kernel kernel_relu;
|
||||
cl_kernel kernel_sigmoid_f32, kernel_sigmoid_f16;
|
||||
cl_kernel kernel_clamp;
|
||||
cl_kernel kernel_geglu, kernel_reglu, kernel_swiglu,
|
||||
kernel_geglu_f16, kernel_reglu_f16, kernel_swiglu_f16;
|
||||
cl_kernel kernel_geglu, kernel_reglu, kernel_swiglu, kernel_geglu_erf, kernel_geglu_quick,
|
||||
kernel_geglu_f16, kernel_reglu_f16, kernel_swiglu_f16, kernel_geglu_erf_f16, kernel_geglu_quick_f16;
|
||||
cl_kernel kernel_norm;
|
||||
cl_kernel kernel_rms_norm;
|
||||
cl_kernel kernel_group_norm;
|
||||
|
|
@ -736,6 +737,8 @@ static void load_cl_kernels(ggml_backend_opencl_context *backend_ctx, ggml_cl_ve
|
|||
|
||||
CL_CHECK((backend_ctx->kernel_gelu = clCreateKernel(backend_ctx->program_gelu, "kernel_gelu", &err), err));
|
||||
CL_CHECK((backend_ctx->kernel_gelu_4 = clCreateKernel(backend_ctx->program_gelu, "kernel_gelu_4", &err), err));
|
||||
CL_CHECK((backend_ctx->kernel_gelu_erf = clCreateKernel(backend_ctx->program_gelu, "kernel_gelu_erf", &err), err));
|
||||
CL_CHECK((backend_ctx->kernel_gelu_erf_4 = clCreateKernel(backend_ctx->program_gelu, "kernel_gelu_erf_4", &err), err));
|
||||
CL_CHECK((backend_ctx->kernel_gelu_quick = clCreateKernel(backend_ctx->program_gelu, "kernel_gelu_quick", &err), err));
|
||||
CL_CHECK((backend_ctx->kernel_gelu_quick_4 = clCreateKernel(backend_ctx->program_gelu, "kernel_gelu_quick_4", &err), err));
|
||||
GGML_LOG_CONT(".");
|
||||
|
|
@ -753,12 +756,16 @@ static void load_cl_kernels(ggml_backend_opencl_context *backend_ctx, ggml_cl_ve
|
|||
backend_ctx->program_glu =
|
||||
build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts);
|
||||
|
||||
CL_CHECK((backend_ctx->kernel_geglu = clCreateKernel(backend_ctx->program_glu, "kernel_geglu", &err), err));
|
||||
CL_CHECK((backend_ctx->kernel_reglu = clCreateKernel(backend_ctx->program_glu, "kernel_reglu", &err), err));
|
||||
CL_CHECK((backend_ctx->kernel_swiglu = clCreateKernel(backend_ctx->program_glu, "kernel_swiglu", &err), err));
|
||||
CL_CHECK((backend_ctx->kernel_geglu_f16 = clCreateKernel(backend_ctx->program_glu, "kernel_geglu_f16", &err), err));
|
||||
CL_CHECK((backend_ctx->kernel_reglu_f16 = clCreateKernel(backend_ctx->program_glu, "kernel_reglu_f16", &err), err));
|
||||
CL_CHECK((backend_ctx->kernel_swiglu_f16 = clCreateKernel(backend_ctx->program_glu, "kernel_swiglu_f16", &err), err));
|
||||
CL_CHECK((backend_ctx->kernel_geglu = clCreateKernel(backend_ctx->program_glu, "kernel_geglu", &err), err));
|
||||
CL_CHECK((backend_ctx->kernel_reglu = clCreateKernel(backend_ctx->program_glu, "kernel_reglu", &err), err));
|
||||
CL_CHECK((backend_ctx->kernel_swiglu = clCreateKernel(backend_ctx->program_glu, "kernel_swiglu", &err), err));
|
||||
CL_CHECK((backend_ctx->kernel_geglu_erf = clCreateKernel(backend_ctx->program_glu, "kernel_geglu_erf", &err), err));
|
||||
CL_CHECK((backend_ctx->kernel_geglu_quick = clCreateKernel(backend_ctx->program_glu, "kernel_geglu_quick", &err), err));
|
||||
CL_CHECK((backend_ctx->kernel_geglu_f16 = clCreateKernel(backend_ctx->program_glu, "kernel_geglu_f16", &err), err));
|
||||
CL_CHECK((backend_ctx->kernel_reglu_f16 = clCreateKernel(backend_ctx->program_glu, "kernel_reglu_f16", &err), err));
|
||||
CL_CHECK((backend_ctx->kernel_swiglu_f16 = clCreateKernel(backend_ctx->program_glu, "kernel_swiglu_f16", &err), err));
|
||||
CL_CHECK((backend_ctx->kernel_geglu_erf_f16 = clCreateKernel(backend_ctx->program_glu, "kernel_geglu_erf_f16", &err), err));
|
||||
CL_CHECK((backend_ctx->kernel_geglu_quick_f16 = clCreateKernel(backend_ctx->program_glu, "kernel_geglu_quick_f16", &err), err));
|
||||
GGML_LOG_CONT(".");
|
||||
}
|
||||
|
||||
|
|
@ -2262,6 +2269,7 @@ static bool ggml_opencl_supports_op(ggml_backend_dev_t dev, const struct ggml_te
|
|||
case GGML_UNARY_OP_GELU:
|
||||
case GGML_UNARY_OP_SILU:
|
||||
case GGML_UNARY_OP_RELU:
|
||||
case GGML_UNARY_OP_GELU_ERF:
|
||||
case GGML_UNARY_OP_GELU_QUICK:
|
||||
return ggml_is_contiguous(op->src[0]) && op->src[0]->type == GGML_TYPE_F32;
|
||||
case GGML_UNARY_OP_SIGMOID:
|
||||
|
|
@ -2277,6 +2285,8 @@ static bool ggml_opencl_supports_op(ggml_backend_dev_t dev, const struct ggml_te
|
|||
case GGML_GLU_OP_GEGLU:
|
||||
case GGML_GLU_OP_REGLU:
|
||||
case GGML_GLU_OP_SWIGLU:
|
||||
case GGML_GLU_OP_GEGLU_ERF:
|
||||
case GGML_GLU_OP_GEGLU_QUICK:
|
||||
return ggml_is_contiguous_1(op->src[0]) && (op->type == GGML_TYPE_F32 || op->type == GGML_TYPE_F16);
|
||||
default:
|
||||
return false;
|
||||
|
|
@ -3864,6 +3874,44 @@ static void ggml_cl_gelu(ggml_backend_t backend, const ggml_tensor * src0, const
|
|||
backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size, dst);
|
||||
}
|
||||
|
||||
static void ggml_cl_gelu_erf(ggml_backend_t backend, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
|
||||
GGML_ASSERT(src0);
|
||||
GGML_ASSERT(src0->extra);
|
||||
GGML_ASSERT(dst);
|
||||
GGML_ASSERT(dst->extra);
|
||||
|
||||
UNUSED(src1);
|
||||
|
||||
ggml_backend_opencl_context *backend_ctx = (ggml_backend_opencl_context *)backend->context;
|
||||
|
||||
ggml_tensor_extra_cl * extra0 = (ggml_tensor_extra_cl *)src0->extra;
|
||||
ggml_tensor_extra_cl * extrad = (ggml_tensor_extra_cl *)dst->extra;
|
||||
|
||||
cl_ulong offset0 = extra0->offset + src0->view_offs;
|
||||
cl_ulong offsetd = extrad->offset + dst->view_offs;
|
||||
|
||||
cl_kernel kernel;
|
||||
|
||||
int n = ggml_nelements(dst);
|
||||
|
||||
if (n % 4 == 0) {
|
||||
kernel = backend_ctx->kernel_gelu_erf_4;
|
||||
n /= 4;
|
||||
} else {
|
||||
kernel = backend_ctx->kernel_gelu_erf;
|
||||
}
|
||||
|
||||
CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra0->data_device));
|
||||
CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_ulong), &offset0));
|
||||
CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extrad->data_device));
|
||||
CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_ulong), &offsetd));
|
||||
|
||||
size_t global_work_size[] = {(size_t)n, 1, 1};
|
||||
size_t local_work_size[] = {64, 1, 1};
|
||||
|
||||
backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size, dst);
|
||||
}
|
||||
|
||||
static void ggml_cl_gelu_quick(ggml_backend_t backend, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
|
||||
GGML_ASSERT(src0);
|
||||
GGML_ASSERT(src0->extra);
|
||||
|
|
@ -5763,19 +5811,31 @@ static void ggml_cl_soft_max(ggml_backend_t backend, const ggml_tensor * src0, c
|
|||
|
||||
cl_ulong offset1 = extra1 ? extra1->offset + src1->view_offs : offset0;
|
||||
|
||||
const int ne00 = src0 ? src0->ne[0] : 0;
|
||||
const int ne01 = src0 ? src0->ne[1] : 0;
|
||||
const int ne02 = src0 ? src0->ne[2] : 0;
|
||||
const int ne03 = src0 ? src0->ne[3] : 0;
|
||||
const int ne00 = src0->ne[0];
|
||||
const int ne01 = src0->ne[1];
|
||||
const int ne02 = src0->ne[2];
|
||||
const int ne03 = src0->ne[3];
|
||||
|
||||
const cl_long nb01 = src0->nb[1];
|
||||
const cl_long nb02 = src0->nb[2];
|
||||
const cl_long nb03 = src0->nb[3];
|
||||
|
||||
const int ne12 = src1 ? src1->ne[2] : 0;
|
||||
const int ne13 = src1 ? src1->ne[3] : 0;
|
||||
|
||||
const cl_long nb11 = src1 ? src1->nb[1] : 0;
|
||||
const cl_long nb12 = src1 ? src1->nb[2] : 0;
|
||||
const cl_long nb13 = src1 ? src1->nb[3] : 0;
|
||||
|
||||
const cl_long nb1 = dst->nb[1];
|
||||
const cl_long nb2 = dst->nb[2];
|
||||
const cl_long nb3 = dst->nb[3];
|
||||
|
||||
float scale, max_bias;
|
||||
memcpy(&scale, dst->op_params + 0, sizeof(float));
|
||||
memcpy(&max_bias, dst->op_params + 1, sizeof(float));
|
||||
|
||||
const int nrows_x = ggml_nrows(src0);
|
||||
const int nrows_y = src0->ne[1];
|
||||
|
||||
const int n_head = nrows_x/nrows_y;
|
||||
const int n_head = src0->ne[2];
|
||||
const int n_head_log2 = 1u << (uint32_t) floorf(log2f((float) n_head));
|
||||
|
||||
const float m0 = powf(2.0f, -(max_bias ) / n_head_log2);
|
||||
|
|
@ -5820,13 +5880,22 @@ static void ggml_cl_soft_max(ggml_backend_t backend, const ggml_tensor * src0, c
|
|||
CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_mem), &extrad->data_device));
|
||||
CL_CHECK(clSetKernelArg(kernel, 5, sizeof(cl_ulong), &offsetd));
|
||||
CL_CHECK(clSetKernelArg(kernel, 6, sizeof(int), &ne00));
|
||||
CL_CHECK(clSetKernelArg(kernel, 7, sizeof(int), &ne01));
|
||||
CL_CHECK(clSetKernelArg(kernel, 8, sizeof(int), &ne02));
|
||||
CL_CHECK(clSetKernelArg(kernel, 9, sizeof(float), &scale));
|
||||
CL_CHECK(clSetKernelArg(kernel, 10, sizeof(float), &max_bias));
|
||||
CL_CHECK(clSetKernelArg(kernel, 11, sizeof(float), &m0));
|
||||
CL_CHECK(clSetKernelArg(kernel, 12, sizeof(float), &m1));
|
||||
CL_CHECK(clSetKernelArg(kernel, 13, sizeof(int), &n_head_log2));
|
||||
CL_CHECK(clSetKernelArg(kernel, 7, sizeof(cl_ulong), &nb01));
|
||||
CL_CHECK(clSetKernelArg(kernel, 8, sizeof(cl_ulong), &nb02));
|
||||
CL_CHECK(clSetKernelArg(kernel, 9, sizeof(cl_ulong), &nb03));
|
||||
CL_CHECK(clSetKernelArg(kernel, 10, sizeof(int), &ne12));
|
||||
CL_CHECK(clSetKernelArg(kernel, 11, sizeof(int), &ne13));
|
||||
CL_CHECK(clSetKernelArg(kernel, 12, sizeof(cl_ulong), &nb11));
|
||||
CL_CHECK(clSetKernelArg(kernel, 13, sizeof(cl_ulong), &nb12));
|
||||
CL_CHECK(clSetKernelArg(kernel, 14, sizeof(cl_ulong), &nb13));
|
||||
CL_CHECK(clSetKernelArg(kernel, 15, sizeof(cl_ulong), &nb1));
|
||||
CL_CHECK(clSetKernelArg(kernel, 16, sizeof(cl_ulong), &nb2));
|
||||
CL_CHECK(clSetKernelArg(kernel, 17, sizeof(cl_ulong), &nb3));
|
||||
CL_CHECK(clSetKernelArg(kernel, 18, sizeof(float), &scale));
|
||||
CL_CHECK(clSetKernelArg(kernel, 19, sizeof(float), &max_bias));
|
||||
CL_CHECK(clSetKernelArg(kernel, 20, sizeof(float), &m0));
|
||||
CL_CHECK(clSetKernelArg(kernel, 21, sizeof(float), &m1));
|
||||
CL_CHECK(clSetKernelArg(kernel, 22, sizeof(int), &n_head_log2));
|
||||
|
||||
size_t global_work_size[] = {(size_t)ne01*nth, (size_t)ne02, (size_t)ne03};
|
||||
size_t local_work_size[] = {(size_t)nth, 1, 1};
|
||||
|
|
@ -6233,6 +6302,20 @@ static void ggml_cl_glu(ggml_backend_t backend, const ggml_tensor * src0, const
|
|||
kernel = backend_ctx->kernel_swiglu_f16;
|
||||
}
|
||||
break;
|
||||
case GGML_GLU_OP_GEGLU_ERF:
|
||||
if (dst->type == GGML_TYPE_F32) {
|
||||
kernel = backend_ctx->kernel_geglu_erf;
|
||||
} else {
|
||||
kernel = backend_ctx->kernel_geglu_erf_f16;
|
||||
}
|
||||
break;
|
||||
case GGML_GLU_OP_GEGLU_QUICK:
|
||||
if (dst->type == GGML_TYPE_F32) {
|
||||
kernel = backend_ctx->kernel_geglu_quick;
|
||||
} else {
|
||||
kernel = backend_ctx->kernel_geglu_quick_f16;
|
||||
}
|
||||
break;
|
||||
default:
|
||||
GGML_ABORT("Unsupported glu op");
|
||||
}
|
||||
|
|
@@ -6347,6 +6430,12 @@ bool ggml_cl_compute_forward(ggml_backend_t backend, struct ggml_tensor * tensor
}
func = ggml_cl_gelu;
break;
case GGML_UNARY_OP_GELU_ERF:
if (!any_on_device) {
return false;
}
func = ggml_cl_gelu_erf;
break;
case GGML_UNARY_OP_GELU_QUICK:
if (!any_on_device) {
return false;

@@ -6,6 +6,7 @@
#define GELU_COEF_A 0.044715f
#define GELU_QUICK_COEF -1.702f
#define SQRT_2_OVER_PI 0.79788456080286535587989211986876f
#define SQRT_2_INV 0.70710678118654752440084436210484f

kernel void kernel_gelu(
global float * src0,

@@ -35,6 +36,32 @@ kernel void kernel_gelu_4(
dst[get_global_id(0)] = 0.5f*x*(1.0f + tanh(SQRT_2_OVER_PI*x*(1.0f + GELU_COEF_A*x*x)));
}

kernel void kernel_gelu_erf(
global float * src0,
ulong offset0,
global float * dst,
ulong offsetd
) {
src0 = (global float*)((global char*)src0 + offset0);
dst = (global float*)((global char*)dst + offsetd);

float x = src0[get_global_id(0)];
dst[get_global_id(0)] = 0.5f*x*(1.0f + erf(x*SQRT_2_INV));
}

kernel void kernel_gelu_erf_4(
global float4 * src0,
ulong offset0,
global float4 * dst,
ulong offsetd
) {
src0 = (global float4*)((global char*)src0 + offset0);
dst = (global float4*)((global char*)dst + offsetd);

float4 x = src0[get_global_id(0)];
dst[get_global_id(0)] = 0.5f*x*(1.0f + erf(x*SQRT_2_INV));
}

kernel void kernel_gelu_quick(
global float * src0,
ulong offset0,

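For reference, kernel_gelu_erf above evaluates the exact GELU through the error function, while the existing kernel_gelu uses the tanh approximation. A scalar C sketch of the two forms (illustrative only, constants truncated):

    #include <math.h>

    // exact GELU, as in kernel_gelu_erf: 0.5*x*(1 + erf(x/sqrt(2)))
    static float gelu_erf(float x) {
        return 0.5f*x*(1.0f + erff(x*0.70710678f));
    }

    // tanh-based approximation, as in kernel_gelu
    static float gelu_tanh(float x) {
        return 0.5f*x*(1.0f + tanhf(0.79788456f*x*(1.0f + 0.044715f*x*x)));
    }
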
@@ -1,7 +1,9 @@
#pragma OPENCL EXTENSION cl_khr_fp16 : enable

#define GELU_COEF_A 0.044715f
#define GELU_QUICK_COEF -1.702f
#define SQRT_2_OVER_PI 0.79788456080286535587989211986876f
#define SQRT_2_INV 0.70710678118654752440084436210484f

//------------------------------------------------------------------------------
// geglu

@@ -199,3 +201,137 @@ kernel void kernel_swiglu_f16(
dst_row[i0] = silu*x1;
}
}

//------------------------------------------------------------------------------
// geglu_erf
//------------------------------------------------------------------------------
kernel void kernel_geglu_erf(
global char * src0,
ulong offset0,
global char * src1,
ulong offset1,
global char * dst,
ulong offsetd,
ulong nb01,
ulong nb11,
int ne0,
ulong nb1,
int ne00_off,
int ne10_off
) {
src0 = (global char*)((global char*)src0 + offset0);
src1 = (global char*)((global char*)src1 + offset1);
dst = (global char*)((global char*)dst + offsetd);

global float * src0_row = (global float *) ((global char *) src0 + get_group_id(0)*nb01) + ne00_off;
global float * src1_row = (global float *) ((global char *) src1 + get_group_id(0)*nb11) + ne10_off;
global float * dst_row = (global float *) ((global char *) dst + get_group_id(0)*nb1);

for (int i0 = get_local_id(0); i0 < ne0; i0 += get_local_size(0)) {
const float x0 = src0_row[i0];
const float x1 = src1_row[i0];

const float gelu_erf = 0.5f*x0*(1.0f + erf(x0*SQRT_2_INV));

dst_row[i0] = gelu_erf*x1;
}
}

kernel void kernel_geglu_erf_f16(
global char * src0,
ulong offset0,
global char * src1,
ulong offset1,
global char * dst,
ulong offsetd,
ulong nb01,
ulong nb11,
int ne0,
ulong nb1,
int ne00_off,
int ne10_off
) {
src0 = (global char*)((global char*)src0 + offset0);
src1 = (global char*)((global char*)src1 + offset1);
dst = (global char*)((global char*)dst + offsetd);

global half * src0_row = (global half *) ((global char *) src0 + get_group_id(0)*nb01) + ne00_off;
global half * src1_row = (global half *) ((global char *) src1 + get_group_id(0)*nb11) + ne10_off;
global half * dst_row = (global half *) ((global char *) dst + get_group_id(0)*nb1);

for (int i0 = get_local_id(0); i0 < ne0; i0 += get_local_size(0)) {
const half x0 = src0_row[i0];
const half x1 = src1_row[i0];

const half gelu_erf = 0.5f*x0*(1.0f + erf(x0*SQRT_2_INV));

dst_row[i0] = gelu_erf*x1;
}
}

//------------------------------------------------------------------------------
// geglu_quick
//------------------------------------------------------------------------------
kernel void kernel_geglu_quick(
global char * src0,
ulong offset0,
global char * src1,
ulong offset1,
global char * dst,
ulong offsetd,
ulong nb01,
ulong nb11,
int ne0,
ulong nb1,
int ne00_off,
int ne10_off
) {
src0 = (global char*)((global char*)src0 + offset0);
src1 = (global char*)((global char*)src1 + offset1);
dst = (global char*)((global char*)dst + offsetd);

global float * src0_row = (global float *) ((global char *) src0 + get_group_id(0)*nb01) + ne00_off;
global float * src1_row = (global float *) ((global char *) src1 + get_group_id(0)*nb11) + ne10_off;
global float * dst_row = (global float *) ((global char *) dst + get_group_id(0)*nb1);

for (int i0 = get_local_id(0); i0 < ne0; i0 += get_local_size(0)) {
const float x0 = src0_row[i0];
const float x1 = src1_row[i0];

const float gelu_quick = x0*(1.0f/(1.0f + exp(GELU_QUICK_COEF*x0)));

dst_row[i0] = gelu_quick*x1;
}
}

kernel void kernel_geglu_quick_f16(
global char * src0,
ulong offset0,
global char * src1,
ulong offset1,
global char * dst,
ulong offsetd,
ulong nb01,
ulong nb11,
int ne0,
ulong nb1,
int ne00_off,
int ne10_off
) {
src0 = (global char*)((global char*)src0 + offset0);
src1 = (global char*)((global char*)src1 + offset1);
dst = (global char*)((global char*)dst + offsetd);

global half * src0_row = (global half *) ((global char *) src0 + get_group_id(0)*nb01) + ne00_off;
global half * src1_row = (global half *) ((global char *) src1 + get_group_id(0)*nb11) + ne10_off;
global half * dst_row = (global half *) ((global char *) dst + get_group_id(0)*nb1);

for (int i0 = get_local_id(0); i0 < ne0; i0 += get_local_size(0)) {
const half x0 = src0_row[i0];
const half x1 = src1_row[i0];

const half gelu_quick = x0*(1.0f/(1.0f + exp(GELU_QUICK_COEF*x0)));

dst_row[i0] = gelu_quick*x1;
}
}

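The geglu_erf/geglu_quick kernels above follow the same gated pattern as the existing geglu/swiglu kernels: one work-group per row, src0 supplying the activation input and src1 the gate, with ne00_off/ne10_off selecting the two halves when both live in the same tensor. A plain C sketch of the per-row math (illustrative only, not part of the diff):

    #include <math.h>

    // illustrative per-row GEGLU-erf: dst[i] = gelu_erf(x[i]) * g[i]
    static void geglu_erf_row(const float * x, const float * g, float * dst, int ne0) {
        for (int i = 0; i < ne0; ++i) {
            const float gelu = 0.5f*x[i]*(1.0f + erff(x[i]*0.70710678f));
            dst[i] = gelu*g[i];
        }
    }
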
@@ -22,32 +22,45 @@
REQD_SUBGROUP_SIZE_64
#endif
kernel void kernel_soft_max_4_f16(
global float * src0,
global char * src0,
ulong offset0,
global half * src1,
global char * src1,
ulong offset1,
global float * dst,
global char * dst,
ulong offsetd,
int ne00,
int ne01,
int ne02,
ulong nb01,
ulong nb02,
ulong nb03,
int ne12,
int ne13,
ulong nb11,
ulong nb12,
ulong nb13,
ulong nb1,
ulong nb2,
ulong nb3,
float scale,
float max_bias,
float m0,
float m1,
int n_head_log2
) {
src0 = (global float *)((global char *)src0 + offset0);
src1 = (global half *)((global char *)src1 + offset1);
dst = (global float *)((global char *)dst + offsetd);
src0 = src0 + offset0;
src1 = src1 + offset1;
dst = dst + offsetd;

int i03 = get_group_id(2);
int i02 = get_group_id(1);
int i01 = get_group_id(0);

global float4 * psrc4 = (global float4 *)(src0 + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00);
global half4 * pmask = (global char *)src1 != (global char *)src0 ? (global half4 *)(src1 + i01*ne00) : 0;
global float4 * pdst4 = (global float4 *)(dst + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00);
int i13 = i03%ne13;
int i12 = i02%ne12;
int i11 = i01;

global float4 * psrc4 = (global float4 *)(src0 + i01*nb01 + i02*nb02 + i03*nb03);
global half4 * pmask = src1 != src0 ? (global half4 *)(src1 + i11*nb11 + i12*nb12 + i13*nb13) : 0;
global float4 * pdst4 = (global float4 *)(dst + i01*nb1 + i02*nb2 + i03*nb3);

float slope = 1.0f;

@@ -22,32 +22,45 @@
REQD_SUBGROUP_SIZE_64
#endif
kernel void kernel_soft_max_4(
global float * src0,
global char * src0,
ulong offset0,
global float * src1,
global char * src1,
ulong offset1,
global float * dst,
global char * dst,
ulong offsetd,
int ne00,
int ne01,
int ne02,
ulong nb01,
ulong nb02,
ulong nb03,
int ne12,
int ne13,
ulong nb11,
ulong nb12,
ulong nb13,
ulong nb1,
ulong nb2,
ulong nb3,
float scale,
float max_bias,
float m0,
float m1,
int n_head_log2
) {
src0 = (global float*)((global char*)src0 + offset0);
src1 = (global float*)((global char*)src1 + offset1);
dst = (global float*)((global char*)dst + offsetd);
src0 = src0 + offset0;
src1 = src1 + offset1;
dst = dst + offsetd;

int i03 = get_group_id(2);
int i02 = get_group_id(1);
int i01 = get_group_id(0);

global float4 * psrc4 = (global float4 *)(src0 + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00);
global float4 * pmask = src1 != src0 ? (global float4 *)(src1 + i01*ne00) : 0;
global float4 * pdst4 = (global float4 *)(dst + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00);
int i13 = i03%ne13;
int i12 = i02%ne12;
int i11 = i01;

global float4 * psrc4 = (global float4 *)(src0 + i01*nb01 + i02*nb02 + i03*nb03);
global float4 * pmask = src1 != src0 ? (global float4 *)(src1 + i11*nb11 + i12*nb12 + i13*nb13) : 0;
global float4 * pdst4 = (global float4 *)(dst + i01*nb1 + i02*nb2 + i03*nb3);

float slope = 1.0f;

@@ -22,32 +22,45 @@
REQD_SUBGROUP_SIZE_64
#endif
kernel void kernel_soft_max_f16(
global float * src0,
global char * src0,
ulong offset0,
global half * src1,
global char * src1,
ulong offset1,
global float * dst,
global char * dst,
ulong offsetd,
int ne00,
int ne01,
int ne02,
ulong nb01,
ulong nb02,
ulong nb03,
int ne12,
int ne13,
ulong nb11,
ulong nb12,
ulong nb13,
ulong nb1,
ulong nb2,
ulong nb3,
float scale,
float max_bias,
float m0,
float m1,
int n_head_log2
) {
src0 = (global float *)((global char *)src0 + offset0);
src1 = (global half *)((global char *)src1 + offset1);
dst = (global float *)((global char *)dst + offsetd);
src0 = src0 + offset0;
src1 = src1 + offset1;
dst = dst + offsetd;

int i03 = get_group_id(2);
int i02 = get_group_id(1);
int i01 = get_group_id(0);

global float * psrc0 = src0 + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
global half * pmask = (global char *)src1 != (global char *)src0 ? src1 + i01*ne00 : 0;
global float * pdst = dst + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
int i13 = i03%ne13;
int i12 = i02%ne12;
int i11 = i01;

global float * psrc0 = (global float *)(src0 + i01*nb01 + i02*nb02 + i03*nb03);
global half * pmask = src1 != src0 ? (global half *)(src1 + i11*nb11 + i12*nb12 + i13*nb13) : 0;
global float * pdst = (global float *)(dst + i01*nb1 + i02*nb2 + i03*nb3);

float slope = 1.0f;

@@ -22,32 +22,45 @@
REQD_SUBGROUP_SIZE_64
#endif
kernel void kernel_soft_max(
global float * src0,
global char * src0,
ulong offset0,
global float * src1,
global char * src1,
ulong offset1,
global float * dst,
global char * dst,
ulong offsetd,
int ne00,
int ne01,
int ne02,
ulong nb01,
ulong nb02,
ulong nb03,
int ne12,
int ne13,
ulong nb11,
ulong nb12,
ulong nb13,
ulong nb1,
ulong nb2,
ulong nb3,
float scale,
float max_bias,
float m0,
float m1,
int n_head_log2
) {
src0 = (global float*)((global char*)src0 + offset0);
src1 = (global float*)((global char*)src1 + offset1);
dst = (global float*)((global char*)dst + offsetd);
src0 = src0 + offset0;
src1 = src1 + offset1;
dst = dst + offsetd;

int i03 = get_group_id(2);
int i02 = get_group_id(1);
int i01 = get_group_id(0);

global float * psrc0 = src0 + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
global float * pmask = src1 != src0 ? src1 + i01*ne00 : 0;
global float * pdst = dst + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
int i13 = i03%ne13;
int i12 = i02%ne12;
int i11 = i01;

global float * psrc0 = (global float *)(src0 + i01*nb01 + i02*nb02 + i03*nb03);
global float * pmask = src1 != src0 ? (global float *)(src1 + i11*nb11 + i12*nb12 + i13*nb13) : 0;
global float * pdst = (global float *)(dst + i01*nb1 + i02*nb2 + i03*nb3);

float slope = 1.0f;

@@ -383,6 +383,24 @@ static void gated_op_fused_swiglu(const T * x, const T * g, T * dst, const uint6
}
}

template<typename T>
static void gated_op_fused_geglu_erf(const T * x, const T * g, T * dst, const uint64_t k, const uint64_t n, const uint64_t o0, const uint64_t o1, const sycl::nd_item<1> &item_ct1) {
SYCL_GLOBAL_ID_LOOP(k, item_ct1) {
const int64_t j0 = (i / n) * o0 + (i % n);
const int64_t j1 = o0 == o1 ? j0 : (i / n) * o1 + (i % n);
dst[i] = op_gelu_erf(x[j0]) * g[j1];
}
}

template<typename T>
static void gated_op_fused_geglu_quick(const T * x, const T * g, T * dst, const uint64_t k, const uint64_t n, const uint64_t o0, const uint64_t o1, const sycl::nd_item<1> &item_ct1) {
SYCL_GLOBAL_ID_LOOP(k, item_ct1) {
const int64_t j0 = (i / n) * o0 + (i % n);
const int64_t j1 = o0 == o1 ? j0 : (i / n) * o1 + (i % n);
dst[i] = op_gelu_quick(x[j0]) * g[j1];
}
}

namespace ggml_sycl_detail {
static void acc_f32_sycl(const float *x, const float *y, float *dst,
const int n_elements, const int ne10, const int ne11,

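In the SYCL helpers above, o0/o1 appear to be the row strides of the activation and gate inputs; j0/j1 rebuild per-half indices from the flat output index i and collapse to a single index when both strides match. A toy check of that arithmetic (illustrative only; the sizes are made up):

    #include <stdint.h>
    #include <stdio.h>

    int main(void) {
        /* made-up example: output row width n = 4, both halves packed with row stride 8 */
        const uint64_t n = 4, o0 = 8, o1 = 8;
        for (uint64_t i = 0; i < 8; ++i) {                               /* two output rows */
            const uint64_t j0 = (i / n) * o0 + (i % n);                  /* activation half */
            const uint64_t j1 = o0 == o1 ? j0 : (i / n) * o1 + (i % n);  /* gate half */
            printf("i=%llu -> j0=%llu j1=%llu\n",
                   (unsigned long long) i, (unsigned long long) j0, (unsigned long long) j1);
        }
        return 0;
    }
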
@@ -978,6 +996,28 @@ static inline void ggml_sycl_op_swiglu(ggml_backend_sycl_context & ctx, ggml_ten
});
}

static inline void ggml_sycl_op_geglu_erf(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
ggml_sycl_detail::dispatch_ggml_sycl_op_fused_glu(ctx, dst,
[](const auto* x_ptr, const auto* g_ptr, auto* dst_ptr, uint64_t k, uint64_t n, uint64_t o0, uint64_t o1, queue_ptr main_stream) {
const uint32_t num_blocks = ceil_div(k, SYCL_GELU_BLOCK_SIZE);
sycl_parallel_for(main_stream,
sycl::nd_range<1>((num_blocks * sycl::range<1>(SYCL_GELU_BLOCK_SIZE)), sycl::range<1>(SYCL_GELU_BLOCK_SIZE)), [=](sycl::nd_item<1> item_ct1) {
gated_op_fused_geglu_erf(x_ptr, g_ptr, dst_ptr, k, n, o0, o1, item_ct1);
});
});
}

static inline void ggml_sycl_op_geglu_quick(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
ggml_sycl_detail::dispatch_ggml_sycl_op_fused_glu(ctx, dst,
[](const auto* x_ptr, const auto* g_ptr, auto* dst_ptr, uint64_t k, uint64_t n, uint64_t o0, uint64_t o1, queue_ptr main_stream) {
const uint32_t num_blocks = ceil_div(k, SYCL_GELU_BLOCK_SIZE);
sycl_parallel_for(main_stream,
sycl::nd_range<1>((num_blocks * sycl::range<1>(SYCL_GELU_BLOCK_SIZE)), sycl::range<1>(SYCL_GELU_BLOCK_SIZE)), [=](sycl::nd_item<1> item_ct1) {
gated_op_fused_geglu_quick(x_ptr, g_ptr, dst_ptr, k, n, o0, o1, item_ct1);
});
});
}


void ggml_sycl_sqrt(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/1);

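ceil_div above rounds the element count k up to whole blocks of SYCL_GELU_BLOCK_SIZE so the nd_range covers every element. A one-line sketch of that helper, assuming the usual integer round-up idiom (illustrative only; ceil_div_u64 is a hypothetical name):

    /* illustrative: round-up division used for grid sizing */
    static inline unsigned long long ceil_div_u64(unsigned long long a, unsigned long long b) {
        return (a + b - 1) / b;
    }
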
@@ -1118,3 +1158,13 @@ void ggml_sycl_swiglu(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/1);
ggml_sycl_op_swiglu(ctx, dst);
}

void ggml_sycl_geglu_erf(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/1);
ggml_sycl_op_geglu_erf(ctx, dst);
}

void ggml_sycl_geglu_quick(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/1);
ggml_sycl_op_geglu_quick(ctx, dst);
}

@@ -80,5 +80,7 @@ void ggml_sycl_elu(ggml_backend_sycl_context & ctx, ggml_tensor * dst);
void ggml_sycl_geglu(ggml_backend_sycl_context & ctx, ggml_tensor * dst);
void ggml_sycl_reglu(ggml_backend_sycl_context & ctx, ggml_tensor * dst);
void ggml_sycl_swiglu(ggml_backend_sycl_context & ctx, ggml_tensor * dst);
void ggml_sycl_geglu_erf(ggml_backend_sycl_context & ctx, ggml_tensor * dst);
void ggml_sycl_geglu_quick(ggml_backend_sycl_context & ctx, ggml_tensor * dst);

#endif // GGML_SYCL_ELEMENTWISE_HPP

@@ -3687,6 +3687,12 @@ static bool ggml_sycl_compute_forward(ggml_backend_sycl_context & ctx, struct gg
case GGML_GLU_OP_SWIGLU:
ggml_sycl_swiglu(ctx, dst);
break;
case GGML_GLU_OP_GEGLU_ERF:
ggml_sycl_geglu_erf(ctx, dst);
break;
case GGML_GLU_OP_GEGLU_QUICK:
ggml_sycl_geglu_quick(ctx, dst);
break;
default:
return false;
}

@@ -4232,6 +4238,8 @@ static bool ggml_backend_sycl_device_supports_op(ggml_backend_dev_t dev, const g
case GGML_GLU_OP_REGLU:
case GGML_GLU_OP_GEGLU:
case GGML_GLU_OP_SWIGLU:
case GGML_GLU_OP_GEGLU_ERF:
case GGML_GLU_OP_GEGLU_QUICK:
return ggml_is_contiguous_1(op->src[0]);
default:
return false;

@@ -47,18 +47,17 @@ static void rope_norm(const T * x, T * dst, const int ne0, const int ne1, const

const int row = item_ct1.get_local_range(2) * item_ct1.get_group(2) + item_ct1.get_local_id(2);

if (i0 >= n_dims) {
const int i = row * ne0 + i0;
*reinterpret_cast<sycl::vec<T, 2> *>(dst + i) = *reinterpret_cast<const sycl::vec<T, 2> *>(x + i);
return;
}

const int row0 = row % ne1;
const int channel0 = row / ne1;

const int i = row * ne0 + i0;
const int i2 = channel0 * s2 + row0 * s1 + i0;

if (i0 >= n_dims) {
*reinterpret_cast<sycl::vec<T, 2> *>(dst + i) = *reinterpret_cast<const sycl::vec<T, 2> *>(x + i2);
return;
}

const float theta_base = pos[channel0] * sycl::pow(theta_scale, i0 / 2.0f);

const float freq_factor = has_ff ? freq_factors[i0 / 2] : 1.0f;

@@ -88,18 +87,17 @@ static void rope_neox(const T * x, T * dst, const int ne0, const int ne1, const

const int row = item_ct1.get_local_range(2) * item_ct1.get_group(2) + item_ct1.get_local_id(2);

if (i0 >= n_dims) {
const int i = row * ne0 + i0;
*reinterpret_cast<sycl::vec<T, 2> *>(dst + i) = *reinterpret_cast<const sycl::vec<T, 2> *>(x + i);
return;
}

const int row0 = row % ne1;
const int channel0 = row / ne1;

const int i = row * ne0 + i0 / 2;
const int i2 = channel0 * s2 + row0 * s1 + i0 / 2;

if (i0 >= n_dims) {
*reinterpret_cast<sycl::vec<T, 2> *>(dst + i + i0 / 2) = *reinterpret_cast<const sycl::vec<T, 2> *>(x + i2 + i0 / 2);
return;
}

const float theta_base = pos[channel0] * sycl::pow(theta_scale, i0 / 2.0f);

const float freq_factor = has_ff ? freq_factors[i0 / 2] : 1.0f;

@@ -129,17 +127,16 @@ static void rope_multi(const T * x, T * dst, const int ne0, const int ne1, const
}
const int row_dst = (item_ct1.get_group(2) * item_ct1.get_local_range(2)) + item_ct1.get_local_id(2);

if (i0 >= n_dims) {
const int i = row_dst*ne0 + i0;
*reinterpret_cast<sycl::vec<T, 2> *>(dst + i) = *reinterpret_cast<const sycl::vec<T, 2> *>(x + i);
return;
}

const int row_x = row_dst % ne1;
const int channel_x = row_dst / ne1;
const int idst = (row_dst * ne0) + (i0 / 2);
const size_t ix = ((size_t) channel_x * s2) + ((size_t) row_x * s1) + (i0 / 2);

if (i0 >= n_dims) {
*reinterpret_cast<sycl::vec<T, 2> *>(dst + idst + i0 / 2) = *reinterpret_cast<const sycl::vec<T, 2> *>(x + i0 / 2 + ix);
return;
}

const int sect_dims = sections.v[0] + sections.v[1] + sections.v[2] + sections.v[3];
const int sec_w = sections.v[1] + sections.v[0];
const int sector = (i0 / 2) % sect_dims;

@@ -224,6 +224,21 @@ enum vk_device_architecture {
INTEL_XE2,
};

// HSK x HSV
enum FaHeadSizes {
FA_HEAD_SIZE_64,
FA_HEAD_SIZE_80,
FA_HEAD_SIZE_96,
FA_HEAD_SIZE_112,
FA_HEAD_SIZE_128,
FA_HEAD_SIZE_192,
FA_HEAD_SIZE_192_128,
FA_HEAD_SIZE_256,
FA_HEAD_SIZE_576_512,
FA_HEAD_SIZE_UNSUPPORTED,
FA_HEAD_SIZE_COUNT = FA_HEAD_SIZE_UNSUPPORTED,
};

static vk_device_architecture get_device_architecture(const vk::PhysicalDevice& device) {
vk::PhysicalDeviceProperties props = device.getProperties();

@ -441,6 +456,8 @@ struct vk_device_struct {
|
|||
vk_pipeline pipeline_geglu[2];
|
||||
vk_pipeline pipeline_reglu[2];
|
||||
vk_pipeline pipeline_swiglu[2];
|
||||
vk_pipeline pipeline_geglu_erf[2];
|
||||
vk_pipeline pipeline_geglu_quick[2];
|
||||
|
||||
vk_pipeline pipeline_leaky_relu_f32;
|
||||
vk_pipeline pipeline_silu_back_f32;
|
||||
|
|
@ -467,26 +484,11 @@ struct vk_device_struct {
|
|||
vk_pipeline pipeline_conv2d_dw_cwhn_f32;
|
||||
|
||||
// [2][2][2] is for {f16acc,f32acc}x{large,small_rows}x{unaligned, aligned}
|
||||
vk_pipeline pipeline_flash_attn_f32_f16_D64_cm2[GGML_TYPE_COUNT][2][2][2];
|
||||
vk_pipeline pipeline_flash_attn_f32_f16_D80_cm2[GGML_TYPE_COUNT][2][2][2];
|
||||
vk_pipeline pipeline_flash_attn_f32_f16_D96_cm2[GGML_TYPE_COUNT][2][2][2];
|
||||
vk_pipeline pipeline_flash_attn_f32_f16_D112_cm2[GGML_TYPE_COUNT][2][2][2];
|
||||
vk_pipeline pipeline_flash_attn_f32_f16_D128_cm2[GGML_TYPE_COUNT][2][2][2];
|
||||
vk_pipeline pipeline_flash_attn_f32_f16_D256_cm2[GGML_TYPE_COUNT][2][2][2];
|
||||
vk_pipeline pipeline_flash_attn_f32_f16_cm2[GGML_TYPE_COUNT][FA_HEAD_SIZE_COUNT][2][2][2];
|
||||
|
||||
vk_pipeline pipeline_flash_attn_f32_f16_D64_cm1[GGML_TYPE_COUNT][2][2][2];
|
||||
vk_pipeline pipeline_flash_attn_f32_f16_D80_cm1[GGML_TYPE_COUNT][2][2][2];
|
||||
vk_pipeline pipeline_flash_attn_f32_f16_D96_cm1[GGML_TYPE_COUNT][2][2][2];
|
||||
vk_pipeline pipeline_flash_attn_f32_f16_D112_cm1[GGML_TYPE_COUNT][2][2][2];
|
||||
vk_pipeline pipeline_flash_attn_f32_f16_D128_cm1[GGML_TYPE_COUNT][2][2][2];
|
||||
vk_pipeline pipeline_flash_attn_f32_f16_D256_cm1[GGML_TYPE_COUNT][2][2][2];
|
||||
vk_pipeline pipeline_flash_attn_f32_f16_cm1[GGML_TYPE_COUNT][FA_HEAD_SIZE_COUNT][2][2][2];
|
||||
|
||||
vk_pipeline pipeline_flash_attn_f32_f16_D64[GGML_TYPE_COUNT][2][2][2];
|
||||
vk_pipeline pipeline_flash_attn_f32_f16_D80[GGML_TYPE_COUNT][2][2][2];
|
||||
vk_pipeline pipeline_flash_attn_f32_f16_D96[GGML_TYPE_COUNT][2][2][2];
|
||||
vk_pipeline pipeline_flash_attn_f32_f16_D112[GGML_TYPE_COUNT][2][2][2];
|
||||
vk_pipeline pipeline_flash_attn_f32_f16_D128[GGML_TYPE_COUNT][2][2][2];
|
||||
vk_pipeline pipeline_flash_attn_f32_f16_D256[GGML_TYPE_COUNT][2][2][2];
|
||||
vk_pipeline pipeline_flash_attn_f32_f16[GGML_TYPE_COUNT][FA_HEAD_SIZE_COUNT][2][2][2];
|
||||
|
||||
vk_pipeline pipeline_flash_attn_split_k_reduce;
|
||||
|
||||
|
|
@ -499,6 +501,8 @@ struct vk_device_struct {
|
|||
|
||||
ggml_backend_buffer_type buffer_type;
|
||||
|
||||
bool disable_fusion;
|
||||
|
||||
#ifdef GGML_VULKAN_MEMORY_DEBUG
|
||||
std::unique_ptr<vk_memory_logger> memory_logger;
|
||||
#endif
|
||||
|
|
@ -634,6 +638,7 @@ struct vk_flash_attn_push_constants {
|
|||
uint32_t nev3;
|
||||
uint32_t nem1;
|
||||
uint32_t nem2;
|
||||
uint32_t nem3;
|
||||
|
||||
uint32_t nb01;
|
||||
uint32_t nb02;
|
||||
|
|
@ -649,8 +654,7 @@ struct vk_flash_attn_push_constants {
|
|||
float max_bias;
|
||||
float logit_softcap;
|
||||
|
||||
uint32_t mask;
|
||||
uint32_t n_head_log2;
|
||||
uint32_t mask_n_head_log2;
|
||||
float m0;
|
||||
float m1;
|
||||
|
||||
|
|
@ -1003,7 +1007,7 @@ struct ggml_backend_vk_context {
|
|||
|
||||
// number of additional consecutive nodes that are being fused with the
|
||||
// node currently being processed
|
||||
uint32_t num_additional_fused_ops {};
|
||||
int num_additional_fused_ops {};
|
||||
};
|
||||
|
||||
static void * const vk_ptr_base = (void *)(uintptr_t) 0x1000; // NOLINT
|
||||
|
|
@ -1089,8 +1093,8 @@ static size_t vk_skip_checks;
|
|||
static size_t vk_output_tensor;
|
||||
|
||||
static void ggml_vk_print_tensor(const ggml_tensor * tensor, const char * name);
|
||||
static void ggml_vk_check_results_0(ggml_tensor * tensor);
|
||||
static void ggml_vk_check_results_1(ggml_tensor * tensor);
|
||||
static void ggml_vk_check_results_0(ggml_backend_vk_context * ctx, ggml_cgraph * cgraph, int tensor_idx);
|
||||
static void ggml_vk_check_results_1(ggml_backend_vk_context * ctx, ggml_cgraph * cgraph, int tensor_idx);
|
||||
#endif
|
||||
|
||||
typedef void (*ggml_vk_func_t)(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst);
|
||||
|
|
@@ -1699,6 +1703,35 @@ enum FaCodePath {
FA_COOPMAT2,
};

static FaHeadSizes fa_get_head_sizes(uint32_t hsk, uint32_t hsv) {
if (hsk != 192 && hsk != 576 && hsk != hsv) {
return FA_HEAD_SIZE_UNSUPPORTED;
}
switch (hsk) {
case 64: return FA_HEAD_SIZE_64;
case 80: return FA_HEAD_SIZE_80;
case 96: return FA_HEAD_SIZE_96;
case 112: return FA_HEAD_SIZE_112;
case 128: return FA_HEAD_SIZE_128;
case 192:
if (hsv == 192) {
return FA_HEAD_SIZE_192;
} else if (hsv == 128) {
return FA_HEAD_SIZE_192_128;
} else {
return FA_HEAD_SIZE_UNSUPPORTED;
}
case 256: return FA_HEAD_SIZE_256;
case 576:
if (hsv == 512) {
return FA_HEAD_SIZE_576_512;
} else {
return FA_HEAD_SIZE_UNSUPPORTED;
}
default: return FA_HEAD_SIZE_UNSUPPORTED;
}
}

// number of rows/cols for flash attention shader
static constexpr uint32_t flash_attention_num_small_rows = 32;
static constexpr uint32_t scalar_flash_attention_num_small_rows = 1;

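fa_get_head_sizes above collapses the (HSK, HSV) pair into a single pipeline index; 192/128 and 576/512 are the only mixed pairs it accepts. A self-contained C re-statement of part of that table, just to illustrate the mapping (subset only, names hypothetical):

    #include <stdio.h>

    /* illustrative subset of the HSK x HSV mapping above */
    static const char * head_size_name(unsigned hsk, unsigned hsv) {
        if (hsk != 192 && hsk != 576 && hsk != hsv) return "UNSUPPORTED";
        switch (hsk) {
            case 128: return "128";
            case 192: return hsv == 192 ? "192" : hsv == 128 ? "192_128" : "UNSUPPORTED";
            case 576: return hsv == 512 ? "576_512" : "UNSUPPORTED";
            default:  return "UNSUPPORTED"; /* other supported sizes omitted in this sketch */
        }
    }

    int main(void) {
        printf("%s %s %s\n", head_size_name(128, 128), head_size_name(192, 128), head_size_name(576, 512));
        return 0;
    }
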
@ -1719,8 +1752,9 @@ static uint32_t get_fa_num_small_rows(FaCodePath path) {
|
|||
}
|
||||
}
|
||||
|
||||
static std::array<uint32_t, 2> fa_rows_cols(FaCodePath path, uint32_t D, uint32_t clamp, ggml_type type, bool small_rows) {
|
||||
static std::array<uint32_t, 2> fa_rows_cols(FaCodePath path, uint32_t hsk, uint32_t hsv, uint32_t clamp, ggml_type type, bool small_rows) {
|
||||
GGML_UNUSED(clamp);
|
||||
GGML_UNUSED(hsv);
|
||||
|
||||
if (path == FA_SCALAR) {
|
||||
if (small_rows) {
|
||||
|
|
@ -1744,7 +1778,7 @@ static std::array<uint32_t, 2> fa_rows_cols(FaCodePath path, uint32_t D, uint32_
|
|||
}
|
||||
|
||||
// small cols to reduce register count
|
||||
if (ggml_is_quantized(type) || D == 256) {
|
||||
if (ggml_is_quantized(type) || hsk >= 256) {
|
||||
return {64, 32};
|
||||
}
|
||||
return {64, 64};
|
||||
|
|
@ -2037,19 +2071,21 @@ static void ggml_vk_load_shaders(vk_device& device) {
|
|||
parameter_count, wg_denoms, specialization_constants, disable_robustness, require_full_subgroups, required_subgroup_size));
|
||||
};
|
||||
|
||||
auto const &fa_wg_denoms = [&](FaCodePath path, uint32_t D, uint32_t clamp, ggml_type type, bool small_rows) -> std::array<uint32_t, 3> {
|
||||
return {fa_rows_cols(path, D, clamp, type, small_rows)[0], 1, 1};
|
||||
auto const &fa_wg_denoms = [&](FaCodePath path, uint32_t hsk, uint32_t hsv, uint32_t clamp, ggml_type type, bool small_rows) -> std::array<uint32_t, 3> {
|
||||
return {fa_rows_cols(path, hsk, hsv, clamp, type, small_rows)[0], 1, 1};
|
||||
};
|
||||
|
||||
auto const &fa_spec_constants = [&](FaCodePath path, uint32_t D, uint32_t clamp, ggml_type type, bool small_rows) -> std::vector<uint32_t> {
|
||||
auto const &fa_spec_constants = [&](FaCodePath path, uint32_t hsk, uint32_t hsv, uint32_t clamp, ggml_type type, bool small_rows) -> std::vector<uint32_t> {
|
||||
// For large number of rows, 128 invocations seems to work best.
|
||||
// For small number of rows (e.g. N==1), 256 works better. But matrix granularity for 256 is 32, so we
|
||||
// can't use 256 for D==80.
|
||||
// For scalar, use 128 (arbitrary)
|
||||
// The same D_split value is used for both HSK and HSV, so just base it on the union of the LSBs.
|
||||
const uint32_t D = (hsk|hsv);
|
||||
uint32_t wg_size = (path == FA_SCALAR || path == FA_COOPMAT1)
|
||||
? scalar_flash_attention_workgroup_size
|
||||
: ((small_rows && (D % 32) == 0) ? 256 : 128);
|
||||
auto rows_cols = fa_rows_cols(path, D, clamp, type, small_rows);
|
||||
auto rows_cols = fa_rows_cols(path, hsk, hsv, clamp, type, small_rows);
|
||||
|
||||
// D_split can't be larger than a subgroup because we use subgroupShuffle to reduce it.
|
||||
// D_split can't be larger than the LSB of D divided by 4 due to vectorization in the shader.
|
||||
|
|
@ -2058,26 +2094,29 @@ static void ggml_vk_load_shaders(vk_device& device) {
|
|||
|
||||
// mask dim1 is padded to 64, we rely on this to avoid clamping mask loads
|
||||
GGML_ASSERT((GGML_KQ_MASK_PAD % rows_cols[0]) == 0);
|
||||
return {wg_size, rows_cols[0], rows_cols[1], (D), clamp, D_split};
|
||||
return {wg_size, rows_cols[0], rows_cols[1], hsk, hsv, clamp, D_split};
|
||||
};
|
||||
|
||||
#define CREATE_FA2(TYPE, NAMELC, FAPATH, SUFFIX, D) \
|
||||
ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16_D ## D ## SUFFIX[TYPE][0][0][0], "flash_attn_f32_f16_D" #D "_f16acc" #NAMELC #SUFFIX, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, D,1,TYPE,false), fa_spec_constants(FAPATH, D,1,TYPE,false), 1, true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \
|
||||
ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16_D ## D ## SUFFIX[TYPE][0][0][1], "flash_attn_f32_f16_D" #D "_aligned_f16acc" #NAMELC #SUFFIX, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, D,0,TYPE,false), fa_spec_constants(FAPATH, D,0,TYPE,false), fa_rows_cols(FAPATH,D,0,TYPE,false)[1], true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \
|
||||
ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16_D ## D ## SUFFIX[TYPE][1][0][0], "flash_attn_f32_f16_D" #D "_f32acc" #NAMELC #SUFFIX, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, D,1,TYPE,false), fa_spec_constants(FAPATH, D,1,TYPE,false), 1, true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \
|
||||
ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16_D ## D ## SUFFIX[TYPE][1][0][1], "flash_attn_f32_f16_D" #D "_aligned_f32acc" #NAMELC #SUFFIX, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, D,0,TYPE,false), fa_spec_constants(FAPATH, D,0,TYPE,false), fa_rows_cols(FAPATH,D,0,TYPE,false)[1], true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \
|
||||
ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16_D ## D ## SUFFIX[TYPE][0][1][0], "flash_attn_f32_f16_D" #D "_f16acc_smallrows" #NAMELC #SUFFIX, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, D,1,TYPE,true), fa_spec_constants(FAPATH, D,1,TYPE,true), 1, true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \
|
||||
ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16_D ## D ## SUFFIX[TYPE][0][1][1], "flash_attn_f32_f16_D" #D "_aligned_f16acc_smallrows" #NAMELC #SUFFIX, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, D,0,TYPE,true), fa_spec_constants(FAPATH, D,0,TYPE,true), fa_rows_cols(FAPATH,D,0,TYPE,true)[1], true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \
|
||||
ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16_D ## D ## SUFFIX[TYPE][1][1][0], "flash_attn_f32_f16_D" #D "_f32acc_smallrows" #NAMELC #SUFFIX, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, D,1,TYPE,true), fa_spec_constants(FAPATH, D,1,TYPE,true), 1, true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \
|
||||
ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16_D ## D ## SUFFIX[TYPE][1][1][1], "flash_attn_f32_f16_D" #D "_aligned_f32acc_smallrows" #NAMELC #SUFFIX, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, D,0,TYPE,true), fa_spec_constants(FAPATH, D,0,TYPE,true), fa_rows_cols(FAPATH,D,0,TYPE,true)[1], true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \
|
||||
#define CREATE_FA2(TYPE, NAMELC, FAPATH, SUFFIX, HSK, HSV, HEAD_SIZES) \
|
||||
ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16 ## SUFFIX[TYPE][FA_HEAD_SIZE_##HEAD_SIZES][0][0][0], "flash_attn_f32_f16_" #HEAD_SIZES "_f16acc" #NAMELC #SUFFIX, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,1,TYPE,false), fa_spec_constants(FAPATH, HSK,HSV,1,TYPE,false), 1, true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \
|
||||
ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16 ## SUFFIX[TYPE][FA_HEAD_SIZE_##HEAD_SIZES][0][0][1], "flash_attn_f32_f16_" #HEAD_SIZES "_aligned_f16acc" #NAMELC #SUFFIX, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,0,TYPE,false), fa_spec_constants(FAPATH, HSK,HSV,0,TYPE,false), fa_rows_cols(FAPATH,HSK,HSV,0,TYPE,false)[1], true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \
|
||||
ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16 ## SUFFIX[TYPE][FA_HEAD_SIZE_##HEAD_SIZES][1][0][0], "flash_attn_f32_f16_" #HEAD_SIZES "_f32acc" #NAMELC #SUFFIX, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,1,TYPE,false), fa_spec_constants(FAPATH, HSK,HSV,1,TYPE,false), 1, true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \
|
||||
ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16 ## SUFFIX[TYPE][FA_HEAD_SIZE_##HEAD_SIZES][1][0][1], "flash_attn_f32_f16_" #HEAD_SIZES "_aligned_f32acc" #NAMELC #SUFFIX, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,0,TYPE,false), fa_spec_constants(FAPATH, HSK,HSV,0,TYPE,false), fa_rows_cols(FAPATH,HSK,HSV,0,TYPE,false)[1], true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \
|
||||
ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16 ## SUFFIX[TYPE][FA_HEAD_SIZE_##HEAD_SIZES][0][1][0], "flash_attn_f32_f16_" #HEAD_SIZES "_f16acc_smallrows" #NAMELC #SUFFIX, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,1,TYPE,true), fa_spec_constants(FAPATH, HSK,HSV,1,TYPE,true), 1, true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \
|
||||
ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16 ## SUFFIX[TYPE][FA_HEAD_SIZE_##HEAD_SIZES][0][1][1], "flash_attn_f32_f16_" #HEAD_SIZES "_aligned_f16acc_smallrows" #NAMELC #SUFFIX, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,0,TYPE,true), fa_spec_constants(FAPATH, HSK,HSV,0,TYPE,true), fa_rows_cols(FAPATH,HSK,HSV,0,TYPE,true)[1], true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \
|
||||
ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16 ## SUFFIX[TYPE][FA_HEAD_SIZE_##HEAD_SIZES][1][1][0], "flash_attn_f32_f16_" #HEAD_SIZES "_f32acc_smallrows" #NAMELC #SUFFIX, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,1,TYPE,true), fa_spec_constants(FAPATH, HSK,HSV,1,TYPE,true), 1, true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \
|
||||
ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16 ## SUFFIX[TYPE][FA_HEAD_SIZE_##HEAD_SIZES][1][1][1], "flash_attn_f32_f16_" #HEAD_SIZES "_aligned_f32acc_smallrows" #NAMELC #SUFFIX, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,0,TYPE,true), fa_spec_constants(FAPATH, HSK,HSV,0,TYPE,true), fa_rows_cols(FAPATH,HSK,HSV,0,TYPE,true)[1], true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \
|
||||
|
||||
#define CREATE_FA(TYPE, NAMELC, FAPATH, SUFFIX) \
|
||||
CREATE_FA2(TYPE, NAMELC, FAPATH, SUFFIX, 64) \
|
||||
CREATE_FA2(TYPE, NAMELC, FAPATH, SUFFIX, 80) \
|
||||
CREATE_FA2(TYPE, NAMELC, FAPATH, SUFFIX, 96) \
|
||||
CREATE_FA2(TYPE, NAMELC, FAPATH, SUFFIX, 112) \
|
||||
CREATE_FA2(TYPE, NAMELC, FAPATH, SUFFIX, 128) \
|
||||
CREATE_FA2(TYPE, NAMELC, FAPATH, SUFFIX, 256)
|
||||
CREATE_FA2(TYPE, NAMELC, FAPATH, SUFFIX, 64, 64, 64) \
|
||||
CREATE_FA2(TYPE, NAMELC, FAPATH, SUFFIX, 80, 80, 80) \
|
||||
CREATE_FA2(TYPE, NAMELC, FAPATH, SUFFIX, 96, 96, 96) \
|
||||
CREATE_FA2(TYPE, NAMELC, FAPATH, SUFFIX, 112, 112, 112) \
|
||||
CREATE_FA2(TYPE, NAMELC, FAPATH, SUFFIX, 128, 128, 128) \
|
||||
CREATE_FA2(TYPE, NAMELC, FAPATH, SUFFIX, 192, 192, 192) \
|
||||
CREATE_FA2(TYPE, NAMELC, FAPATH, SUFFIX, 192, 128, 192_128) \
|
||||
CREATE_FA2(TYPE, NAMELC, FAPATH, SUFFIX, 256, 256, 256) \
|
||||
CREATE_FA2(TYPE, NAMELC, FAPATH, SUFFIX, 576, 512, 576_512)
|
||||
|
||||
CREATE_FA(GGML_TYPE_F16, f16, FA_SCALAR, )
|
||||
CREATE_FA(GGML_TYPE_Q4_0, q4_0, FA_SCALAR, )
|
||||
|
|
@ -2786,6 +2825,8 @@ static void ggml_vk_load_shaders(vk_device& device) {
|
|||
CREATE_GLU(geglu)
|
||||
CREATE_GLU(reglu)
|
||||
CREATE_GLU(swiglu)
|
||||
CREATE_GLU(geglu_erf)
|
||||
CREATE_GLU(geglu_quick)
|
||||
#undef CREATE_GLU
|
||||
|
||||
ggml_vk_create_pipeline(device, device->pipeline_leaky_relu_f32, "leaky_relu_f32", leaky_relu_f32_len, leaky_relu_f32_data, "main", 2, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1);
|
||||
|
|
@ -3468,6 +3509,8 @@ static vk_device ggml_vk_get_device(size_t idx) {
|
|||
|
||||
device->idx = idx;
|
||||
|
||||
device->disable_fusion = getenv("GGML_VK_DISABLE_FUSION") != nullptr;
|
||||
|
||||
return device;
|
||||
}
|
||||
|
||||
|
|
@ -3688,7 +3731,6 @@ static void ggml_vk_instance_init() {
|
|||
|
||||
}
|
||||
|
||||
size_t num_available_devices = vk_instance.instance.enumeratePhysicalDevices().size();
|
||||
vk_perf_logger_enabled = getenv("GGML_VK_PERF_LOGGER") != nullptr;
|
||||
|
||||
// Emulate behavior of CUDA_VISIBLE_DEVICES for Vulkan
|
||||
|
|
@ -6002,24 +6044,47 @@ static void ggml_vk_mul_mat_id(ggml_backend_vk_context * ctx, vk_context& subctx
|
|||
}
|
||||
}
|
||||
|
||||
static bool ggml_vk_flash_attn_coopmat_shmem_support(const vk_device& device, const uint32_t D, bool f32acc) {
|
||||
static bool ggml_vk_flash_attn_scalar_shmem_support(const vk_device& device, const uint32_t hsk, uint32_t hsv) {
|
||||
// Needs to be kept up to date on shader changes
|
||||
GGML_UNUSED(hsv);
|
||||
const uint32_t wg_size = scalar_flash_attention_workgroup_size;
|
||||
const uint32_t Br = scalar_flash_attention_num_large_rows;
|
||||
const uint32_t Bc = scalar_flash_attention_Bc;
|
||||
|
||||
const uint32_t tmpsh = wg_size * sizeof(float);
|
||||
const uint32_t tmpshv4 = wg_size * 4 * sizeof(float);
|
||||
|
||||
const uint32_t masksh = Bc * Br * sizeof(float);
|
||||
|
||||
const uint32_t Qf = Br * (hsk / 4 + 2) * 4 * sizeof(float);
|
||||
|
||||
const uint32_t total_size = tmpsh + tmpshv4 + masksh + Qf;
|
||||
const bool supported = total_size <= device->properties.limits.maxComputeSharedMemorySize;
|
||||
|
||||
VK_LOG_DEBUG("ggml_vk_flash_attn_coopmat_shmem_support(HSK=" << hsk << ", HSV=" << hsv << ", total_size=" << total_size << ", supported=" << supported);
|
||||
|
||||
return supported;
|
||||
}
|
||||
|
||||
static bool ggml_vk_flash_attn_coopmat_shmem_support(const vk_device& device, const uint32_t hsk, uint32_t hsv, bool f32acc) {
|
||||
// Needs to be kept up to date on shader changes
|
||||
GGML_UNUSED(hsv);
|
||||
const uint32_t wg_size = scalar_flash_attention_workgroup_size;
|
||||
const uint32_t Br = coopmat1_flash_attention_num_large_rows;
|
||||
const uint32_t Bc = scalar_flash_attention_Bc;
|
||||
|
||||
const uint32_t acctype = f32acc ? 4 : 2;
|
||||
const uint32_t f16vec4 = 8;
|
||||
|
||||
const uint32_t tmpsh = wg_size * sizeof(float);
|
||||
const uint32_t tmpshv4 = wg_size * 4 * acctype;
|
||||
|
||||
const uint32_t Qf = Br * (D / 4 + 2) * f16vec4;
|
||||
const uint32_t Qf = Br * (hsk / 4 + 2) * f16vec4;
|
||||
|
||||
const uint32_t sfshstride = (D <= 128) ? (Br + 8) : Br;
|
||||
const uint32_t sfshstride = (hsk <= 128) ? (Br + 8) : Br;
|
||||
const uint32_t sfsh = Bc * sfshstride * acctype;
|
||||
|
||||
const uint32_t kshstride = D / 4 + 2;
|
||||
const uint32_t kshstride = hsk / 4 + 2;
|
||||
const uint32_t ksh = Bc * kshstride * f16vec4;
|
||||
|
||||
const uint32_t slope = Br * sizeof(float);
|
||||
|
|
@ -6027,7 +6092,7 @@ static bool ggml_vk_flash_attn_coopmat_shmem_support(const vk_device& device, co
|
|||
const uint32_t total_size = tmpsh + tmpshv4 + Qf + sfsh + ksh + slope;
|
||||
const bool supported = total_size <= device->properties.limits.maxComputeSharedMemorySize;
|
||||
|
||||
VK_LOG_DEBUG("ggml_vk_flash_attn_coopmat_shmem_support(D=" << D << ", f32acc=" << f32acc << ", total_size=" << total_size << ", supported=" << supported);
|
||||
VK_LOG_DEBUG("ggml_vk_flash_attn_coopmat_shmem_support(HSK=" << hsk << ", HSV=" << hsv << ", f32acc=" << f32acc << ", total_size=" << total_size << ", supported=" << supported);
|
||||
|
||||
return supported;
|
||||
}
|
||||
|
|
@ -6050,12 +6115,14 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
|
|||
|
||||
const uint32_t nem1 = mask ? mask->ne[1] : 0;
|
||||
const uint32_t nem2 = mask ? mask->ne[2] : 0;
|
||||
const uint32_t nem3 = mask ? mask->ne[3] : 0;
|
||||
|
||||
const uint32_t D = neq0;
|
||||
const uint32_t HSK = nek0;
|
||||
const uint32_t HSV = nev0;
|
||||
uint32_t N = neq1;
|
||||
const uint32_t KV = nek1;
|
||||
|
||||
GGML_ASSERT(ne0 == D);
|
||||
GGML_ASSERT(ne0 == HSV);
|
||||
GGML_ASSERT(ne2 == N);
|
||||
|
||||
// input tensor rows must be contiguous
|
||||
|
|
@ -6063,12 +6130,9 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
|
|||
GGML_ASSERT(nbk0 == ggml_type_size(k->type));
|
||||
GGML_ASSERT(nbv0 == ggml_type_size(v->type));
|
||||
|
||||
GGML_ASSERT(neq0 == D);
|
||||
GGML_ASSERT(nek0 == D);
|
||||
GGML_ASSERT(nev0 == D);
|
||||
GGML_ASSERT(neq0 == HSK);
|
||||
|
||||
GGML_ASSERT(neq1 == N);
|
||||
GGML_ASSERT(nev0 == D);
|
||||
|
||||
GGML_ASSERT(nev1 == nek1);
|
||||
|
||||
|
|
@ -6089,7 +6153,7 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
|
|||
const bool coopmat_shape_supported = (dst->op_params[3] == GGML_PREC_F32 && ctx->device->coopmat_support_16x16x16_f32acc) ||
|
||||
(dst->op_params[3] != GGML_PREC_F32 && ctx->device->coopmat_support_16x16x16_f16acc);
|
||||
|
||||
const bool coopmat_shmem_supported = ggml_vk_flash_attn_coopmat_shmem_support(ctx->device, D, dst->op_params[3] == GGML_PREC_F32);
|
||||
const bool coopmat_shmem_supported = ggml_vk_flash_attn_coopmat_shmem_support(ctx->device, HSK, HSV, dst->op_params[3] == GGML_PREC_F32);
|
||||
|
||||
if (!coopmat_shape_supported || !coopmat_shmem_supported) {
|
||||
path = FA_SCALAR;
|
||||
|
|
@ -6119,7 +6183,7 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
|
|||
}
|
||||
|
||||
if (N == 1 && qk_ratio > 1 && qk_ratio <= max_gqa &&
|
||||
qk_ratio * nek2 == neq2 && nek2 == nev2 && neq3 == 1 && nek3 == 1 && nev3 == 1) {
|
||||
qk_ratio * nek2 == neq2 && nek2 == nev2 && nem2 <= 1) {
|
||||
// grouped query attention - make the N dimension equal to gqa_ratio, reduce
|
||||
// workgroups proportionally in y dimension. The shader will detect gqa_ratio > 1
|
||||
// and change addressing calculations to index Q's dimension 2.
|
||||
|
|
@ -6142,47 +6206,25 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
|
|||
path = FA_SCALAR;
|
||||
}
|
||||
|
||||
// with large hsk/hsv, scalar path may need to use small_rows to fit in shared memory
|
||||
if (path == FA_SCALAR &&
|
||||
!ggml_vk_flash_attn_scalar_shmem_support(ctx->device, HSK, HSV)) {
|
||||
small_rows = true;
|
||||
}
|
||||
|
||||
bool f32acc = path == FA_SCALAR || dst->op_params[3] == GGML_PREC_F32;
|
||||
|
||||
FaHeadSizes head_sizes = fa_get_head_sizes(k->ne[0], v->ne[0]);
|
||||
|
||||
switch (path) {
|
||||
case FA_SCALAR:
|
||||
switch (D) {
|
||||
case 64: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D64[k->type][f32acc][small_rows][0]; break;
|
||||
case 80: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D80[k->type][f32acc][small_rows][0]; break;
|
||||
case 96: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D96[k->type][f32acc][small_rows][0]; break;
|
||||
case 112: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D112[k->type][f32acc][small_rows][0]; break;
|
||||
case 128: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D128[k->type][f32acc][small_rows][0]; break;
|
||||
case 256: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D256[k->type][f32acc][small_rows][0]; break;
|
||||
default:
|
||||
GGML_ASSERT(!"unsupported D value");
|
||||
return;
|
||||
}
|
||||
pipelines = &ctx->device->pipeline_flash_attn_f32_f16[k->type][head_sizes][f32acc][small_rows][0];
|
||||
break;
|
||||
case FA_COOPMAT1:
|
||||
switch (D) {
|
||||
case 64: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D64_cm1[k->type][f32acc][small_rows][0]; break;
|
||||
case 80: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D80_cm1[k->type][f32acc][small_rows][0]; break;
|
||||
case 96: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D96_cm1[k->type][f32acc][small_rows][0]; break;
|
||||
case 112: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D112_cm1[k->type][f32acc][small_rows][0]; break;
|
||||
case 128: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D128_cm1[k->type][f32acc][small_rows][0]; break;
|
||||
case 256: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D256_cm1[k->type][f32acc][small_rows][0]; break;
|
||||
default:
|
||||
GGML_ASSERT(!"unsupported D value");
|
||||
return;
|
||||
}
|
||||
pipelines = &ctx->device->pipeline_flash_attn_f32_f16_cm1[k->type][head_sizes][f32acc][small_rows][0];
|
||||
break;
|
||||
case FA_COOPMAT2:
|
||||
switch (D) {
|
||||
case 64: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D64_cm2[k->type][f32acc][small_rows][0]; break;
|
||||
case 80: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D80_cm2[k->type][f32acc][small_rows][0]; break;
|
||||
case 96: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D96_cm2[k->type][f32acc][small_rows][0]; break;
|
||||
case 112: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D112_cm2[k->type][f32acc][small_rows][0]; break;
|
||||
case 128: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D128_cm2[k->type][f32acc][small_rows][0]; break;
|
||||
case 256: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D256_cm2[k->type][f32acc][small_rows][0]; break;
|
||||
default:
|
||||
GGML_ASSERT(!"unsupported D value");
|
||||
return;
|
||||
}
|
||||
pipelines = &ctx->device->pipeline_flash_attn_f32_f16_cm2[k->type][head_sizes][f32acc][small_rows][0];
|
||||
break;
|
||||
default:
|
||||
GGML_ASSERT(0);
|
||||
|
|
@ -6212,7 +6254,7 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
|
|||
// Try to use split_k when KV is large enough to be worth the overhead
|
||||
if (workgroups_x == 1 && shader_core_count > 0 && KV >= 512) {
|
||||
// Try to run two workgroups per SM.
|
||||
split_k = ctx->device->shader_core_count * 2 / (workgroups_y * workgroups_z);
|
||||
split_k = shader_core_count * 2 / (workgroups_y * workgroups_z);
|
||||
if (split_k > 1) {
|
||||
// Try to evenly split KV into split_k chunks, but it needs to be a multiple
|
||||
// of "align", so recompute split_k based on that.
|
||||
|
|
@ -6224,7 +6266,7 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
|
|||
|
||||
// Reserve space for split_k temporaries. For each split x batch, we need to store the O matrix (D x ne1)
|
||||
// and the per-row m and L values (ne1 rows). We store all the matrices first, followed by the rows.
|
||||
const uint64_t split_k_size = split_k > 1 ? (D * ne1 * sizeof(float) + ne1 * sizeof(float) * 2) * split_k * ne3 : 0;
|
||||
const uint64_t split_k_size = split_k > 1 ? (HSV * ne1 * sizeof(float) + ne1 * sizeof(float) * 2) * split_k * ne3 : 0;
|
||||
if (split_k_size > ctx->device->max_memory_allocation_size) {
|
||||
GGML_ABORT("Requested preallocation size is too large");
|
||||
}
|
||||
|
|
@ -6311,17 +6353,19 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
|
|||
}
|
||||
}
|
||||
|
||||
uint32_t mask_n_head_log2 = ((mask != nullptr) << 16) | n_head_log2;
|
||||
|
||||
const vk_flash_attn_push_constants pc = { N, KV,
|
||||
(uint32_t)ne1, (uint32_t)ne2, (uint32_t)ne3,
|
||||
(uint32_t)neq2, (uint32_t)neq3,
|
||||
(uint32_t)nek2, (uint32_t)nek3,
|
||||
(uint32_t)nev2, (uint32_t)nev3,
|
||||
nem1, nem2,
|
||||
nem1, nem2, nem3,
|
||||
q_stride, (uint32_t)nbq2, (uint32_t)nbq3,
|
||||
k_stride, (uint32_t)nbk2, (uint32_t)nbk3,
|
||||
v_stride, (uint32_t)nbv2, (uint32_t)nbv3,
|
||||
scale, max_bias, logit_softcap,
|
||||
mask != nullptr, n_head_log2, m0, m1,
|
||||
mask_n_head_log2, m0, m1,
|
||||
gqa_ratio, split_kv, split_k };
|
||||
|
||||
ggml_vk_sync_buffers(subctx);
|
||||
|
|
@ -6342,7 +6386,7 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
|
|||
pc, { workgroups_x * pipeline->wg_denoms[0], workgroups_y, workgroups_z });
|
||||
|
||||
ggml_vk_sync_buffers(subctx);
|
||||
const std::array<uint32_t, 4> pc2 = { D, (uint32_t)ne1, (uint32_t)ne3, split_k };
|
||||
const std::array<uint32_t, 4> pc2 = { HSV, (uint32_t)ne1, (uint32_t)ne3, split_k };
|
||||
ggml_vk_dispatch_pipeline(ctx, subctx, ctx->device->pipeline_flash_attn_split_k_reduce,
|
||||
{
|
||||
vk_subbuffer{ctx->prealloc_split_k, 0, VK_WHOLE_SIZE},
|
||||
|
|
@ -6542,6 +6586,10 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
|
|||
return ctx->device->pipeline_reglu[dst->type == GGML_TYPE_F16];
|
||||
case GGML_GLU_OP_SWIGLU:
|
||||
return ctx->device->pipeline_swiglu[dst->type == GGML_TYPE_F16];
|
||||
case GGML_GLU_OP_GEGLU_ERF:
|
||||
return ctx->device->pipeline_geglu_erf[dst->type == GGML_TYPE_F16];
|
||||
case GGML_GLU_OP_GEGLU_QUICK:
|
||||
return ctx->device->pipeline_geglu_quick[dst->type == GGML_TYPE_F16];
|
||||
default:
|
||||
break;
|
||||
}
|
||||
|
|
@ -7610,8 +7658,7 @@ static void ggml_vk_group_norm(ggml_backend_vk_context * ctx, vk_context& subctx
|
|||
ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_GROUP_NORM, { group_size, 0, eps, 0.0f }, dryrun);
|
||||
}
|
||||
|
||||
static void ggml_vk_rms_norm(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) {
|
||||
float * op_params = (float *)dst->op_params;
|
||||
static void ggml_vk_rms_norm(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, float * op_params, bool dryrun = false) {
|
||||
const uint32_t src0_type_size = ggml_type_size(src0->type);
|
||||
const uint32_t src1_type_size = ggml_type_size(src1->type);
|
||||
const uint32_t dst_type_size = ggml_type_size(dst->type);
|
||||
|
|
@ -8841,7 +8888,7 @@ static void ggml_vk_preallocate_buffers(ggml_backend_vk_context * ctx) {
|
|||
}
|
||||
}
|
||||
|
||||
static bool ggml_vk_compute_forward(ggml_backend_vk_context* ctx, ggml_tensor* tensor, int tensor_idx, bool use_fence, bool almost_ready);
|
||||
static bool ggml_vk_compute_forward(ggml_backend_vk_context* ctx, ggml_cgraph * cgraph, ggml_tensor* tensor, int tensor_idx, bool use_fence, bool almost_ready);
|
||||
|
||||
// Returns true if node has enqueued work into the queue, false otherwise
|
||||
// If submit is true the current all operations queued so far are being submitted to Vulkan to overlap cmdlist creation and GPU execution.
|
||||
|
|
@ -8886,6 +8933,8 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr
|
|||
case GGML_GLU_OP_GEGLU:
|
||||
case GGML_GLU_OP_REGLU:
|
||||
case GGML_GLU_OP_SWIGLU:
|
||||
case GGML_GLU_OP_GEGLU_ERF:
|
||||
case GGML_GLU_OP_GEGLU_QUICK:
|
||||
break;
|
||||
default:
|
||||
return false;
|
||||
|
|
@ -9100,9 +9149,9 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr
|
|||
// fused rms_norm + mul
|
||||
ggml_tensor *mul = cgraph->nodes[node_idx + 1];
|
||||
ggml_tensor *other_src = mul->src[0] == node ? mul->src[1] : mul->src[0];
|
||||
ggml_vk_rms_norm(ctx, compute_ctx, src0, other_src, mul, dryrun);
|
||||
ggml_vk_rms_norm(ctx, compute_ctx, src0, other_src, mul, (float *)node->op_params, dryrun);
|
||||
} else {
|
||||
ggml_vk_rms_norm(ctx, compute_ctx, src0, src0, node, dryrun);
|
||||
ggml_vk_rms_norm(ctx, compute_ctx, src0, src0, node, (float *)node->op_params, dryrun);
|
||||
}
|
||||
break;
|
||||
case GGML_OP_RMS_NORM_BACK:
|
||||
|
|
@ -9133,6 +9182,8 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr
|
|||
case GGML_GLU_OP_GEGLU:
|
||||
case GGML_GLU_OP_REGLU:
|
||||
case GGML_GLU_OP_SWIGLU:
|
||||
case GGML_GLU_OP_GEGLU_ERF:
|
||||
case GGML_GLU_OP_GEGLU_QUICK:
|
||||
ggml_vk_glu(ctx, compute_ctx, src0, src1, node, dryrun);
|
||||
break;
|
||||
default:
|
||||
|
|
@@ -9260,7 +9311,7 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr

        ctx->compute_ctx.reset();

        bool ok = ggml_vk_compute_forward(ctx, node_begin, node_idx_begin, false, almost_ready);
        bool ok = ggml_vk_compute_forward(ctx, cgraph, node_begin, node_idx_begin, false, almost_ready);
        if (!ok) {
            if (node->op == GGML_OP_UNARY) {
                std::cerr << __func__ << ": error: op not supported UNARY " << node->name << " (" << ggml_unary_op_name(static_cast<ggml_unary_op>(node->op_params[0])) << ")" << std::endl;
@@ -9275,7 +9326,8 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr
    return true;
}

static bool ggml_vk_compute_forward(ggml_backend_vk_context * ctx, ggml_tensor * tensor, int tensor_idx, bool use_fence = true, bool almost_ready = false) {
static bool ggml_vk_compute_forward(ggml_backend_vk_context * ctx, ggml_cgraph * cgraph, ggml_tensor * tensor, int tensor_idx, bool use_fence = true, bool almost_ready = false) {
    GGML_UNUSED(cgraph);
    ggml_backend_buffer * buf = nullptr;

    switch (tensor->op) {
@@ -9351,6 +9403,8 @@ static bool ggml_vk_compute_forward(ggml_backend_vk_context * ctx, ggml_tensor *
        case GGML_GLU_OP_GEGLU:
        case GGML_GLU_OP_REGLU:
        case GGML_GLU_OP_SWIGLU:
        case GGML_GLU_OP_GEGLU_ERF:
        case GGML_GLU_OP_GEGLU_QUICK:
            buf = tensor->buffer;
            break;
        default:
@@ -9383,7 +9437,7 @@ static bool ggml_vk_compute_forward(ggml_backend_vk_context * ctx, ggml_tensor *
    // Only run if ctx hasn't been submitted yet
    if (!subctx->seqs.empty()) {
#ifdef GGML_VULKAN_CHECK_RESULTS
        ggml_vk_check_results_0(tensor);
        ggml_vk_check_results_0(ctx, cgraph, tensor_idx);
        use_fence = true;
#endif

@@ -9403,7 +9457,7 @@ static bool ggml_vk_compute_forward(ggml_backend_vk_context * ctx, ggml_tensor *
            ggml_vk_wait_for_fence(ctx);
        }
#ifdef GGML_VULKAN_CHECK_RESULTS
        ggml_vk_check_results_1(tensor);
        ggml_vk_check_results_1(ctx, cgraph, tensor_idx);
#endif
    }

@@ -9850,6 +9904,37 @@ static bool ggml_vk_is_empty(ggml_tensor * node) {
    return ggml_is_empty(node) || node->op == GGML_OP_NONE || node->op == GGML_OP_RESHAPE || node->op == GGML_OP_TRANSPOSE || node->op == GGML_OP_VIEW || node->op == GGML_OP_PERMUTE;
}

static bool ggml_vk_can_fuse(const struct ggml_cgraph * cgraph, int node_idx, std::initializer_list<enum ggml_op> ops) {
    if (!ggml_can_fuse(cgraph, node_idx, ops)) {
        return false;
    }

    if (ops.size() == 2 && ops.begin()[0] == GGML_OP_RMS_NORM && ops.begin()[1] == GGML_OP_MUL) {
        // additional constraints specific to this fusion
        const ggml_tensor *rms_norm = cgraph->nodes[node_idx];
        const ggml_tensor *mul = cgraph->nodes[node_idx + 1];

        GGML_ASSERT(rms_norm->src[0]->type == GGML_TYPE_F32);
        GGML_ASSERT(rms_norm->type == GGML_TYPE_F32);
        // rms_norm only supports f32
        if (mul->src[0]->type != GGML_TYPE_F32 ||
            mul->src[1]->type != GGML_TYPE_F32 ||
            mul->type != GGML_TYPE_F32) {
            return false;
        }
        // if rms_norm is the B operand, then we don't handle broadcast
        if (rms_norm == mul->src[1] &&
            mul->src[0]->ne[1] != rms_norm->ne[1]) {
            return false;
        }
        // rms_norm shader assumes contiguous rows
        if (!ggml_is_contiguous_rows(mul->src[0]) || !ggml_is_contiguous_rows(mul->src[1])) {
            return false;
        }
    }
    return true;
}

static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cgraph * cgraph) {
    VK_LOG_DEBUG("ggml_backend_vk_graph_compute(" << cgraph->n_nodes << " nodes)");
    ggml_backend_vk_context * ctx = (ggml_backend_vk_context *)backend->context;
@@ -9863,7 +9948,7 @@ static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cg

    uint64_t total_mat_mul_bytes = 0;
    for (int i = 0; i < cgraph->n_nodes; i++) {
        if (ggml_can_fuse(cgraph, i, { GGML_OP_RMS_NORM, GGML_OP_MUL })) {
        if (!ctx->device->disable_fusion && ggml_vk_can_fuse(cgraph, i, { GGML_OP_RMS_NORM, GGML_OP_MUL })) {
            ctx->num_additional_fused_ops = 1;
        }
        ggml_vk_build_graph(ctx, cgraph, i, nullptr, 0, true, false, false, false);
@@ -9933,7 +10018,7 @@ static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cg
            mul_mat_bytes += ggml_nbytes(cgraph->nodes[i]->src[0]);
        }

        if (ggml_can_fuse(cgraph, i, { GGML_OP_RMS_NORM, GGML_OP_MUL })) {
        if (!ctx->device->disable_fusion && ggml_vk_can_fuse(cgraph, i, { GGML_OP_RMS_NORM, GGML_OP_MUL })) {
            ctx->num_additional_fused_ops = 1;
        }

@@ -10161,6 +10246,8 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
            case GGML_GLU_OP_GEGLU:
            case GGML_GLU_OP_REGLU:
            case GGML_GLU_OP_SWIGLU:
            case GGML_GLU_OP_GEGLU_ERF:
            case GGML_GLU_OP_GEGLU_QUICK:
                return ggml_is_contiguous(op->src[0]) &&
                       (op->src[0]->type == GGML_TYPE_F32 || op->src[0]->type == GGML_TYPE_F16) &&
                       (op->type == GGML_TYPE_F32 || op->type == GGML_TYPE_F16) &&
@ -10241,19 +10328,8 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
|
|||
ggml_backend_vk_device_context * ctx = (ggml_backend_vk_device_context *)dev->context;
|
||||
auto device = ggml_vk_get_device(ctx->device);
|
||||
bool coopmat2 = device->coopmat2;
|
||||
switch (op->src[0]->ne[0]) {
|
||||
case 64:
|
||||
case 80:
|
||||
case 96:
|
||||
case 112:
|
||||
case 128:
|
||||
case 256:
|
||||
break;
|
||||
default:
|
||||
return false;
|
||||
}
|
||||
if (op->src[1]->ne[0] != op->src[2]->ne[0]) {
|
||||
// different head sizes of K and V are not supported yet
|
||||
FaHeadSizes head_sizes = fa_get_head_sizes(op->src[1]->ne[0], op->src[2]->ne[0]);
|
||||
if (head_sizes == FA_HEAD_SIZE_UNSUPPORTED) {
|
||||
return false;
|
||||
}
|
||||
if (op->src[0]->type != GGML_TYPE_F32) {
|
||||
|
|
@ -10265,12 +10341,6 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
|
|||
if (op->src[3] && op->src[3]->type != GGML_TYPE_F16) {
|
||||
return false;
|
||||
}
|
||||
// TODO: support broadcast
|
||||
// note: this was initially implemented in https://github.com/ggml-org/llama.cpp/pull/14449, but
|
||||
// the interface of ggml_flash_attn_ext() changed in https://github.com/ggml-org/llama.cpp/pull/14505
|
||||
if (op->src[0]->ne[3] != 1 || (op->src[3] && op->src[3]->ne[2] != 1)) {
|
||||
return false;
|
||||
}
|
||||
// It's straightforward to support different K/V dequant, but would
|
||||
// significantly increase the number of pipelines
|
||||
if (op->src[1]->type != op->src[2]->type) {
|
||||
|
|
@ -10725,11 +10795,21 @@ void * comp_result;
|
|||
size_t comp_size;
|
||||
size_t comp_nb[GGML_MAX_DIMS];
|
||||
size_t check_counter = 0;
|
||||
static void ggml_vk_check_results_0(ggml_tensor * tensor) {
|
||||
static void ggml_vk_check_results_0(ggml_backend_vk_context * ctx, ggml_cgraph * cgraph, int tensor_idx) {
|
||||
ggml_tensor * tensor = cgraph->nodes[tensor_idx];
|
||||
if (tensor->op == GGML_OP_TRANSPOSE) {
|
||||
return;
|
||||
}
|
||||
|
||||
bool fused_rms_norm_mul = false;
|
||||
int rms_norm_idx = -1;
|
||||
if (ctx->num_additional_fused_ops == 1 &&
|
||||
tensor->op == GGML_OP_RMS_NORM &&
|
||||
cgraph->nodes[tensor_idx + 1]->op == GGML_OP_MUL) {
|
||||
fused_rms_norm_mul = true;
|
||||
tensor = cgraph->nodes[tensor_idx + 1];
|
||||
}
|
||||
|
||||
check_counter++;
|
||||
if (!(vk_output_tensor > 0 && vk_output_tensor == check_counter) && check_counter <= vk_skip_checks) {
|
||||
return;
|
||||
|
|
@ -10757,6 +10837,15 @@ static void ggml_vk_check_results_0(ggml_tensor * tensor) {
|
|||
|
||||
for (int i = 0; i < 6; i++) {
|
||||
ggml_tensor * srci = tensor->src[i];
|
||||
if (fused_rms_norm_mul) {
|
||||
rms_norm_idx = tensor->src[0]->op == GGML_OP_RMS_NORM ? 0 : 1;
|
||||
ggml_tensor *rms_norm = tensor->src[rms_norm_idx];
|
||||
switch (i) {
|
||||
case 0: srci = rms_norm->src[0]; break;
|
||||
case 1: srci = tensor->src[1 - rms_norm_idx]; break;
|
||||
default: continue;
|
||||
}
|
||||
}
|
||||
if (srci == nullptr) {
|
||||
continue;
|
||||
}
|
||||
|
|
@ -10814,7 +10903,12 @@ static void ggml_vk_check_results_0(ggml_tensor * tensor) {
|
|||
} else if (tensor->op == GGML_OP_SUB) {
|
||||
tensor_clone = ggml_sub(ggml_ctx, src_clone[0], src_clone[1]);
|
||||
} else if (tensor->op == GGML_OP_MUL) {
|
||||
tensor_clone = ggml_mul(ggml_ctx, src_clone[0], src_clone[1]);
|
||||
if (fused_rms_norm_mul) {
|
||||
tensor_clone = ggml_rms_norm(ggml_ctx, src_clone[0], *(float *)tensor->src[rms_norm_idx]->op_params);
|
||||
tensor_clone = ggml_mul(ggml_ctx, tensor_clone, src_clone[1 - rms_norm_idx]);
|
||||
} else {
|
||||
tensor_clone = ggml_mul(ggml_ctx, src_clone[0], src_clone[1]);
|
||||
}
|
||||
} else if (tensor->op == GGML_OP_DIV) {
|
||||
tensor_clone = ggml_div(ggml_ctx, src_clone[0], src_clone[1]);
|
||||
} else if (tensor->op == GGML_OP_CONCAT) {
|
||||
|
|
@ -11005,10 +11099,10 @@ static void ggml_vk_check_results_0(ggml_tensor * tensor) {
|
|||
GGML_ABORT("fatal error");
|
||||
}
|
||||
|
||||
ggml_cgraph * cgraph = ggml_new_graph(ggml_ctx);
|
||||
ggml_build_forward_expand(cgraph, tensor_clone);
|
||||
ggml_cgraph * cgraph_cpu = ggml_new_graph(ggml_ctx);
|
||||
ggml_build_forward_expand(cgraph_cpu, tensor_clone);
|
||||
|
||||
ggml_graph_compute_with_ctx(ggml_ctx, cgraph, 8);
|
||||
ggml_graph_compute_with_ctx(ggml_ctx, cgraph_cpu, 8);
|
||||
|
||||
if (vk_output_tensor > 0 && vk_output_tensor == check_counter) {
|
||||
ggml_vk_print_tensor(tensor_clone, "tensor_clone");
|
||||
|
|
@ -11031,10 +11125,19 @@ static void ggml_vk_check_results_0(ggml_tensor * tensor) {
|
|||
VK_LOG_DEBUG("END ggml_vk_check_results_0(" << tensor->name << ")");
|
||||
}
|
||||
|
||||
static void ggml_vk_check_results_1(ggml_tensor * tensor) {
|
||||
static void ggml_vk_check_results_1(ggml_backend_vk_context * ctx, ggml_cgraph * cgraph, int tensor_idx) {
|
||||
ggml_tensor * tensor = cgraph->nodes[tensor_idx];
|
||||
if (tensor->op == GGML_OP_TRANSPOSE) {
|
||||
return;
|
||||
}
|
||||
bool fused_rms_norm_mul = false;
|
||||
if (ctx->num_additional_fused_ops == 1 &&
|
||||
tensor->op == GGML_OP_RMS_NORM &&
|
||||
cgraph->nodes[tensor_idx + 1]->op == GGML_OP_MUL) {
|
||||
fused_rms_norm_mul = true;
|
||||
tensor = cgraph->nodes[tensor_idx + 1];
|
||||
}
|
||||
|
||||
if (!(vk_output_tensor > 0 && vk_output_tensor == check_counter) && check_counter <= vk_skip_checks) {
|
||||
return;
|
||||
}
|
||||
|
|
|
|||
|
|
@ -11,7 +11,8 @@
|
|||
#include "types.comp"
|
||||
#include "flash_attn_base.comp"
|
||||
|
||||
const uint32_t D_per_thread = D / D_split;
|
||||
const uint32_t HSK_per_thread = HSK / D_split;
|
||||
const uint32_t HSV_per_thread = HSV / D_split;
|
||||
|
||||
const uint32_t cols_per_iter = WorkGroupSize / D_split;
|
||||
const uint32_t cols_per_thread = Bc / cols_per_iter;
|
||||
|
|
@ -29,7 +30,7 @@ layout (binding = 3) readonly buffer M {float16_t data_m[];};
|
|||
// Rows index by Q's dimension 2, and the first N rows are valid.
|
||||
D_TYPE perElemOpGqaStore(const in uint32_t r, const in uint32_t c, const in D_TYPE elem, const in uint32_t o_offset, const in uint32_t iq2, const in uint32_t N)
|
||||
{
|
||||
uint32_t offset = (iq2 + r) * D + c;
|
||||
uint32_t offset = (iq2 + r) * HSV + c;
|
||||
data_o[o_offset + offset] = D_TYPE(elem);
|
||||
return elem;
|
||||
}
|
||||
|
|
@ -38,7 +39,7 @@ shared FLOAT_TYPE tmpsh[WorkGroupSize];
|
|||
shared vec4 tmpshv4[WorkGroupSize];
|
||||
|
||||
shared float masksh[Bc][Br];
|
||||
shared vec4 Qf[Br][D / 4];
|
||||
shared vec4 Qf[Br][HSK / 4];
|
||||
|
||||
void main() {
|
||||
#ifdef NEEDS_INIT_IQ_SHMEM
|
||||
|
|
@ -53,18 +54,18 @@ void main() {
|
|||
|
||||
uint32_t q_offset = (iq2*p.nb02+iq3*p.nb03) / 4;
|
||||
|
||||
[[unroll]] for (uint32_t idx = 0; idx < Br * D / 4; idx += gl_WorkGroupSize.x) {
|
||||
uint32_t d = (idx + tid) % (D / 4);
|
||||
uint32_t r = (idx + tid) / (D / 4);
|
||||
if (r < Br && d < D / 4 &&
|
||||
[[unroll]] for (uint32_t idx = 0; idx < Br * HSK / 4; idx += gl_WorkGroupSize.x) {
|
||||
uint32_t d = (idx + tid) % (HSK / 4);
|
||||
uint32_t r = (idx + tid) / (HSK / 4);
|
||||
if (r < Br && d < HSK / 4 &&
|
||||
i * Br + r < N) {
|
||||
Qf[r][d] = vec4(data_qv4[q_offset / 4 + (i * Br + r) * q_stride / 4 + d]) * p.scale;
|
||||
}
|
||||
}
|
||||
barrier();
|
||||
|
||||
vec4 Of[Br][D_per_thread / 4];
|
||||
[[unroll]] for (uint32_t d = 0; d < D_per_thread / 4; ++d) {
|
||||
vec4 Of[Br][HSV_per_thread / 4];
|
||||
[[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) {
|
||||
[[unroll]] for (uint32_t r = 0; r < Br; ++r) {
|
||||
Of[r][d] = vec4(0.0);
|
||||
}
|
||||
|
|
@ -100,8 +101,8 @@ void main() {
|
|||
uint32_t v_offset = (iv2*p.nb22 + iv3*p.nb23) / 2;
|
||||
#endif
|
||||
uint32_t m_offset = 0;
|
||||
if (p.nem2 != 1) {
|
||||
m_offset = (iq3 % p.nem2) * p.nem1 * KV;
|
||||
if (p.nem2 != 1 || p.nem3 != 1) {
|
||||
m_offset = ((iq3 % p.nem3) * p.nem2 + (iq2 % p.nem2)) * p.nem1 * KV;
|
||||
}
|
||||
|
||||
[[dont_unroll]]
|
||||
|
|
@ -116,7 +117,7 @@ void main() {
|
|||
|
||||
|
||||
[[unroll]] for (uint32_t c = 0; c < cols_per_thread; ++c) {
|
||||
[[unroll]] for (uint32_t d = 0; d < D_per_thread / 4; ++d) {
|
||||
[[unroll]] for (uint32_t d = 0; d < HSK_per_thread / 4; ++d) {
|
||||
#if BLOCK_SIZE > 1
|
||||
uint coord = (j * Bc + c * cols_per_iter + col_tid) * k_stride * BLOCK_SIZE + 4 * (d * D_split + d_tid);
|
||||
uint ib = coord / BLOCK_SIZE;
|
||||
|
|
@ -148,7 +149,7 @@ void main() {
|
|||
}
|
||||
}
|
||||
|
||||
if (p.mask != 0) {
|
||||
if ((p.mask_n_head_log2 & MASK_ENABLE_BIT) != 0) {
|
||||
|
||||
[[unroll]] for (uint32_t idx = 0; idx < Bc * Br; idx += gl_WorkGroupSize.x) {
|
||||
uint32_t c = (idx + tid) % Bc;
|
||||
|
|
@ -195,14 +196,14 @@ void main() {
|
|||
Lf[r] = eMf[r]*Lf[r] + rowsumf[r];
|
||||
}
|
||||
|
||||
[[unroll]] for (uint32_t d = 0; d < D_per_thread / 4; ++d) {
|
||||
[[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) {
|
||||
[[unroll]] for (uint32_t r = 0; r < Br; ++r) {
|
||||
Of[r][d] = eMf[r] * Of[r][d];
|
||||
}
|
||||
}
|
||||
|
||||
[[unroll]] for (uint32_t c = 0; c < cols_per_thread; ++c) {
|
||||
[[unroll]] for (uint32_t d = 0; d < D_per_thread / 4; ++d) {
|
||||
[[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) {
|
||||
#if BLOCK_SIZE > 1
|
||||
uint coord = (j * Bc + c * cols_per_iter + col_tid) * v_stride * BLOCK_SIZE + 4 * (d * D_split + d_tid);
|
||||
uint ib = coord / BLOCK_SIZE;
|
||||
|
|
@ -259,7 +260,7 @@ void main() {
|
|||
Lf[r] = tmpsh[d_tid];
|
||||
barrier();
|
||||
|
||||
[[unroll]] for (uint32_t d = 0; d < D_per_thread / 4; ++d) {
|
||||
[[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) {
|
||||
|
||||
Of[r][d] = eMf * Of[r][d];
|
||||
tmpshv4[tid] = Of[r][d];
|
||||
|
|
@ -281,11 +282,11 @@ void main() {
|
|||
// If there is split_k, then the split_k resolve shader does the final
|
||||
// division by L. Store the intermediate O value and per-row m and L values.
|
||||
if (p.k_num > 1) {
|
||||
uint32_t o_offset = D * p.ne1 * (split_k_index + iq3 * p.k_num);
|
||||
uint32_t o_offset = HSV * p.ne1 * (split_k_index + iq3 * p.k_num);
|
||||
|
||||
[[unroll]] for (uint32_t r = 0; r < Br; ++r) {
|
||||
if (r < N) {
|
||||
[[unroll]] for (uint32_t d = 0; d < D_per_thread / 4; ++d) {
|
||||
[[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) {
|
||||
[[unroll]] for (uint32_t comp = 0; comp < 4; ++comp) {
|
||||
perElemOpGqaStore(r, 4*(d * D_split + d_tid) + comp, Of[r][d][comp], o_offset, iq2, N);
|
||||
}
|
||||
|
|
@ -293,7 +294,7 @@ void main() {
|
|||
}
|
||||
}
|
||||
|
||||
o_offset = D * p.ne1 * p.ne3 * p.k_num + p.ne1 * (split_k_index + iq3 * p.k_num) * 2;
|
||||
o_offset = HSV * p.ne1 * p.ne3 * p.k_num + p.ne1 * (split_k_index + iq3 * p.k_num) * 2;
|
||||
[[unroll]] for (uint32_t r = 0; r < Br; ++r) {
|
||||
if (r < N) {
|
||||
perElemOpStoreCol0(r, 0u, ACC_TYPE(Lf[r]), o_offset, iq2, N);
|
||||
|
|
@ -309,18 +310,18 @@ void main() {
|
|||
Lfrcp[r] = 1.0 / Lf[r];
|
||||
}
|
||||
|
||||
[[unroll]] for (uint32_t d = 0; d < D_per_thread / 4; ++d) {
|
||||
[[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) {
|
||||
[[unroll]] for (uint32_t r = 0; r < Br; ++r) {
|
||||
Of[r][d] *= Lfrcp[r];
|
||||
}
|
||||
}
|
||||
|
||||
uint32_t o_offset = iq3*p.ne2*p.ne1*D;
|
||||
uint32_t o_offset = iq3*p.ne2*p.ne1*HSV;
|
||||
|
||||
if (p.gqa_ratio > 1) {
|
||||
[[unroll]] for (uint32_t r = 0; r < Br; ++r) {
|
||||
if (r < N) {
|
||||
[[unroll]] for (uint32_t d = 0; d < D_per_thread / 4; ++d) {
|
||||
[[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) {
|
||||
[[unroll]] for (uint32_t comp = 0; comp < 4; ++comp) {
|
||||
perElemOpGqaStore(r, 4*(d * D_split + d_tid) + comp, Of[r][d][comp], o_offset, iq2, N);
|
||||
}
|
||||
|
|
@ -330,9 +331,9 @@ void main() {
|
|||
} else {
|
||||
[[unroll]] for (uint32_t r = 0; r < Br; ++r) {
|
||||
if (i * Br + r < N) {
|
||||
[[unroll]] for (uint32_t d = 0; d < D_per_thread / 4; ++d) {
|
||||
[[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) {
|
||||
[[unroll]] for (uint32_t comp = 0; comp < 4; ++comp) {
|
||||
data_o[o_offset + iq2 * D + (i * Br + r) * p.ne1 * D + 4*(d * D_split + d_tid) + comp] = D_TYPE(Of[r][d][comp]);
|
||||
data_o[o_offset + iq2 * HSV + (i * Br + r) * p.ne1 * HSV + 4*(d * D_split + d_tid) + comp] = D_TYPE(Of[r][d][comp]);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -4,10 +4,10 @@ layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
|
|||
layout (constant_id = 0) const uint32_t WorkGroupSize = 128;
|
||||
layout (constant_id = 1) const uint32_t Br = 1;
|
||||
layout (constant_id = 2) const uint32_t Bc = 32;
|
||||
layout (constant_id = 3) const uint32_t D = 32;
|
||||
layout (constant_id = 4) const uint32_t Clamp = 0;
|
||||
layout (constant_id = 5) const uint32_t D_split = 16;
|
||||
|
||||
layout (constant_id = 3) const uint32_t HSK = 32;
|
||||
layout (constant_id = 4) const uint32_t HSV = 32;
|
||||
layout (constant_id = 5) const uint32_t Clamp = 0;
|
||||
layout (constant_id = 6) const uint32_t D_split = 16;
|
||||
|
||||
layout (push_constant) uniform parameter {
|
||||
uint32_t N;
|
||||
|
|
@ -25,6 +25,7 @@ layout (push_constant) uniform parameter {
|
|||
uint32_t nev3;
|
||||
uint32_t nem1;
|
||||
uint32_t nem2;
|
||||
uint32_t nem3;
|
||||
|
||||
uint32_t nb01;
|
||||
uint32_t nb02;
|
||||
|
|
@ -40,8 +41,7 @@ layout (push_constant) uniform parameter {
|
|||
float max_bias;
|
||||
float logit_softcap;
|
||||
|
||||
uint32_t mask;
|
||||
uint32_t n_head_log2;
|
||||
uint32_t mask_n_head_log2;
|
||||
float m0;
|
||||
float m1;
|
||||
|
||||
|
|
@ -50,6 +50,9 @@ layout (push_constant) uniform parameter {
|
|||
uint32_t k_num;
|
||||
} p;
|
||||
|
||||
#define MASK_ENABLE_BIT (1<<16)
|
||||
#define N_LOG2_MASK 0xFFFF
|
||||
|
||||
layout (binding = 4) writeonly buffer O {D_TYPE data_o[];};
|
||||
|
||||
#if defined(A_TYPE_PACKED16)
|
||||
|
|
@@ -100,8 +103,10 @@ ACC_TYPE perElemOpComputeSlope(const in uint32_t r, const in uint32_t c, const i
{
    const uint32_t h = iq2 + (r % p.gqa_ratio);

    const ACC_TYPE base = ACC_TYPE(h < p.n_head_log2 ? p.m0 : p.m1);
    const int exph = int(h < p.n_head_log2 ? h + 1 : 2*(h - p.n_head_log2) + 1);
    uint32_t n_head_log2 = p.mask_n_head_log2 & N_LOG2_MASK;

    const ACC_TYPE base = ACC_TYPE(h < n_head_log2 ? p.m0 : p.m1);
    const int exph = int(h < n_head_log2 ? h + 1 : 2*(h - n_head_log2) + 1);

    return ACC_TYPE(pow(base, ACC_TYPE(exph)));
}
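The hunk above folds n_head_log2 and a mask-enable flag into the single mask_n_head_log2 push constant read by perElemOpComputeSlope. A minimal sketch of that packing, assuming the host mirrors the shader-side MASK_ENABLE_BIT and N_LOG2_MASK values; this is illustrative Python, not the actual host code:

# Hypothetical host-side packing matching the shader defines
# MASK_ENABLE_BIT (1 << 16) and N_LOG2_MASK (0xFFFF).
MASK_ENABLE_BIT = 1 << 16
N_LOG2_MASK = 0xFFFF

def pack_mask_n_head_log2(n_head_log2: int, mask_enabled: bool) -> int:
    assert 0 <= n_head_log2 <= N_LOG2_MASK
    word = n_head_log2 & N_LOG2_MASK      # low 16 bits: n_head_log2
    if mask_enabled:
        word |= MASK_ENABLE_BIT           # bit 16: mask present
    return word

def unpack(word: int) -> tuple[int, bool]:
    return word & N_LOG2_MASK, (word & MASK_ENABLE_BIT) != 0

assert unpack(pack_mask_n_head_log2(8, True)) == (8, True)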
@ -13,7 +13,9 @@
|
|||
#include "types.comp"
|
||||
#include "flash_attn_base.comp"
|
||||
|
||||
const uint32_t D_per_thread = D / D_split;
|
||||
const uint32_t HSK_per_thread = HSK / D_split;
|
||||
const uint32_t HSV_per_thread = HSV / D_split;
|
||||
|
||||
const uint32_t row_split = 4;
|
||||
const uint32_t rows_per_thread = Br / row_split;
|
||||
const uint32_t cols_per_iter = gl_WorkGroupSize.x / D_split / row_split;
|
||||
|
|
@ -32,7 +34,7 @@ layout (binding = 3) readonly buffer M {float16_t data_m[];};
|
|||
// Rows index by Q's dimension 2, and the first N rows are valid.
|
||||
D_TYPE perElemOpGqaStore(const in uint32_t r, const in uint32_t c, const in D_TYPE elem, const in uint32_t o_offset, const in uint32_t iq2, const in uint32_t N)
|
||||
{
|
||||
uint32_t offset = (iq2 + r) * D + c;
|
||||
uint32_t offset = (iq2 + r) * HSV + c;
|
||||
data_o[o_offset + offset] = D_TYPE(elem);
|
||||
return elem;
|
||||
}
|
||||
|
|
@ -44,14 +46,14 @@ const uint32_t MatBc = 16;
|
|||
shared FLOAT_TYPE tmpsh[gl_WorkGroupSize.x];
|
||||
shared ACC_TYPEV4 tmpshv4[gl_WorkGroupSize.x];
|
||||
|
||||
const uint32_t qstride = D / 4 + 2; // in units of f16vec4
|
||||
const uint32_t qstride = HSK / 4 + 2; // in units of f16vec4
|
||||
shared f16vec4 Qf[Br * qstride];
|
||||
|
||||
// Avoid padding for D==256 to make it fit in 48KB shmem.
|
||||
const uint32_t sfshstride = (D <= 128) ? (Br + 8) : Br;
|
||||
// Avoid padding for hsk==256 to make it fit in 48KB shmem.
|
||||
const uint32_t sfshstride = (HSK <= 128) ? (Br + 8) : Br;
|
||||
shared ACC_TYPE sfsh[Bc * sfshstride];
|
||||
|
||||
const uint32_t kshstride = D / 4 + 2; // in units of f16vec4
|
||||
const uint32_t kshstride = HSK / 4 + 2; // in units of f16vec4
|
||||
shared f16vec4 ksh[Bc * kshstride];
|
||||
|
||||
shared float slope[Br];
|
||||
|
|
@ -74,18 +76,18 @@ void main() {
|
|||
|
||||
uint32_t q_offset = (iq2*p.nb02+iq3*p.nb03) / 4;
|
||||
|
||||
[[unroll]] for (uint32_t idx = 0; idx < Br * D / 4; idx += gl_WorkGroupSize.x) {
|
||||
uint32_t d = (idx + tid) % (D / 4);
|
||||
uint32_t r = (idx + tid) / (D / 4);
|
||||
if (r < Br && d < D / 4 &&
|
||||
[[unroll]] for (uint32_t idx = 0; idx < Br * HSK / 4; idx += gl_WorkGroupSize.x) {
|
||||
uint32_t d = (idx + tid) % (HSK / 4);
|
||||
uint32_t r = (idx + tid) / (HSK / 4);
|
||||
if (r < Br && d < HSK / 4 &&
|
||||
i * Br + r < N) {
|
||||
Qf[r * qstride + d] = f16vec4(data_qv4[q_offset / 4 + (i * Br + r) * q_stride / 4 + d] * p.scale);
|
||||
}
|
||||
}
|
||||
barrier();
|
||||
|
||||
ACC_TYPEV4 Of[rows_per_thread][D_per_thread / 4];
|
||||
[[unroll]] for (uint32_t d = 0; d < D_per_thread / 4; ++d) {
|
||||
ACC_TYPEV4 Of[rows_per_thread][HSV_per_thread / 4];
|
||||
[[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) {
|
||||
[[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
|
||||
Of[r][d] = ACC_TYPEV4(0.0);
|
||||
}
|
||||
|
|
@ -124,17 +126,17 @@ void main() {
|
|||
uint32_t v_offset = (iv2*p.nb22 + iv3*p.nb23) / 2;
|
||||
#endif
|
||||
uint32_t m_offset = 0;
|
||||
if (p.nem2 != 1) {
|
||||
m_offset = (iq3 % p.nem2) * p.nem1 * KV;
|
||||
if (p.nem2 != 1 || p.nem3 != 1) {
|
||||
m_offset = ((iq3 % p.nem3) * p.nem2 + (iq2 % p.nem2)) * p.nem1 * KV;
|
||||
}
|
||||
|
||||
[[dont_unroll]]
|
||||
for (uint32_t j = start_j; j < end_j; ++j) {
|
||||
|
||||
[[unroll]] for (uint32_t idx = 0; idx < Bc * D / 4; idx += gl_WorkGroupSize.x) {
|
||||
uint32_t d = (idx + tid) % (D / 4);
|
||||
uint32_t c = (idx + tid) / (D / 4);
|
||||
if (c < Bc && d < D / 4) {
|
||||
[[unroll]] for (uint32_t idx = 0; idx < Bc * HSK / 4; idx += gl_WorkGroupSize.x) {
|
||||
uint32_t d = (idx + tid) % (HSK / 4);
|
||||
uint32_t c = (idx + tid) / (HSK / 4);
|
||||
if (c < Bc && d < HSK / 4) {
|
||||
#if BLOCK_SIZE > 1
|
||||
uint coord = (j * Bc + c) * k_stride * BLOCK_SIZE + 4 * d;
|
||||
uint ib = coord / BLOCK_SIZE;
|
||||
|
|
@ -149,14 +151,14 @@ void main() {
|
|||
}
|
||||
barrier();
|
||||
|
||||
// K * Q^T -> S^T: Bc x D * D x Br -> Bc x Br
|
||||
// Bc split across workgroup (four subgroups), loop over D in chunks of 16: 16 x 16 * 16 x 16 -> 16 x 16
|
||||
// K * Q^T -> S^T: Bc x HSK * HSK x Br -> Bc x Br
|
||||
// Bc split across workgroup (four subgroups), loop over HSK in chunks of 16: 16 x 16 * 16 x 16 -> 16 x 16
|
||||
// This is written transposed in order to allow for N being 8 if implementations need it
|
||||
coopmat<ACC_TYPE, gl_ScopeSubgroup, MatBc, MatBr, gl_MatrixUseAccumulator> SfMat = coopmat<ACC_TYPE, gl_ScopeSubgroup, MatBc, MatBr, gl_MatrixUseAccumulator>(0);
|
||||
coopmat<float16_t, gl_ScopeSubgroup, MatBc, 16, gl_MatrixUseA> KMat;
|
||||
coopmat<float16_t, gl_ScopeSubgroup, 16, MatBr, gl_MatrixUseB> QMat;
|
||||
|
||||
for (uint32_t d = 0; d < D / 16; ++d) {
|
||||
for (uint32_t d = 0; d < HSK / 16; ++d) {
|
||||
coopMatLoad(QMat, Qf, d * 16 / 4, qstride, gl_CooperativeMatrixLayoutColumnMajor);
|
||||
|
||||
uint coord = (gl_SubgroupID * MatBc) * kshstride + d * 16 / 4;
|
||||
|
|
@ -180,7 +182,7 @@ void main() {
|
|||
barrier();
|
||||
}
|
||||
|
||||
if (p.mask != 0) {
|
||||
if ((p.mask_n_head_log2 & MASK_ENABLE_BIT) != 0) {
|
||||
[[unroll]] for (uint32_t idx = 0; idx < Bc * Br; idx += gl_WorkGroupSize.x) {
|
||||
uint32_t c = (idx + tid) % Bc;
|
||||
uint32_t r = (idx + tid) / Bc;
|
||||
|
|
@ -206,7 +208,7 @@ void main() {
|
|||
eMf[r] = exp(Moldf - Mf[r]);
|
||||
}
|
||||
|
||||
[[unroll]] for (uint32_t d = 0; d < D_per_thread / 4; ++d) {
|
||||
[[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) {
|
||||
[[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
|
||||
Of[r][d] = float16_t(eMf[r]) * Of[r][d];
|
||||
}
|
||||
|
|
@ -221,7 +223,7 @@ void main() {
|
|||
Pf[r] = exp(sfsh[tile_row(r) + (c * cols_per_iter + col_tid) * sfshstride] - Mf[r]);
|
||||
Lf[r] += Pf[r];
|
||||
}
|
||||
[[unroll]] for (uint32_t d = 0; d < D_per_thread / 4; ++d) {
|
||||
[[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) {
|
||||
#if BLOCK_SIZE > 1
|
||||
uint coord = (j * Bc + c * cols_per_iter + col_tid) * v_stride * BLOCK_SIZE + 4 * (d * D_split + d_tid);
|
||||
uint ib = coord / BLOCK_SIZE;
|
||||
|
|
@ -284,7 +286,7 @@ void main() {
|
|||
}
|
||||
|
||||
[[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
|
||||
[[unroll]] for (uint32_t d = 0; d < D_per_thread / 4; ++d) {
|
||||
[[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) {
|
||||
|
||||
Of[r][d] = float16_t(eMf[r]) * Of[r][d];
|
||||
tmpshv4[tid] = Of[r][d];
|
||||
|
|
@ -304,11 +306,11 @@ void main() {
|
|||
// If there is split_k, then the split_k resolve shader does the final
|
||||
// division by L. Store the intermediate O value and per-row m and L values.
|
||||
if (p.k_num > 1) {
|
||||
uint32_t o_offset = D * p.ne1 * (split_k_index + iq3 * p.k_num);
|
||||
uint32_t o_offset = HSV * p.ne1 * (split_k_index + iq3 * p.k_num);
|
||||
|
||||
[[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
|
||||
if (tile_row(r) < N) {
|
||||
[[unroll]] for (uint32_t d = 0; d < D_per_thread / 4; ++d) {
|
||||
[[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) {
|
||||
[[unroll]] for (uint32_t comp = 0; comp < 4; ++comp) {
|
||||
perElemOpGqaStore(tile_row(r), 4*(d * D_split + d_tid) + comp, float(Of[r][d][comp]), o_offset, iq2, N);
|
||||
}
|
||||
|
|
@ -316,7 +318,7 @@ void main() {
|
|||
}
|
||||
}
|
||||
|
||||
o_offset = D * p.ne1 * p.ne3 * p.k_num + p.ne1 * (split_k_index + iq3 * p.k_num) * 2;
|
||||
o_offset = HSV * p.ne1 * p.ne3 * p.k_num + p.ne1 * (split_k_index + iq3 * p.k_num) * 2;
|
||||
[[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
|
||||
if (tile_row(r) < N) {
|
||||
perElemOpStoreCol0(tile_row(r), 0u, ACC_TYPE(Lf[r]), o_offset, iq2, N);
|
||||
|
|
@ -332,18 +334,18 @@ void main() {
|
|||
Lfrcp[r] = 1.0 / Lf[r];
|
||||
}
|
||||
|
||||
[[unroll]] for (uint32_t d = 0; d < D_per_thread / 4; ++d) {
|
||||
[[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) {
|
||||
[[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
|
||||
Of[r][d] *= float16_t(Lfrcp[r]);
|
||||
}
|
||||
}
|
||||
|
||||
uint32_t o_offset = iq3*p.ne2*p.ne1*D;
|
||||
uint32_t o_offset = iq3*p.ne2*p.ne1*HSV;
|
||||
|
||||
if (p.gqa_ratio > 1) {
|
||||
[[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
|
||||
if (tile_row(r) < N) {
|
||||
[[unroll]] for (uint32_t d = 0; d < D_per_thread / 4; ++d) {
|
||||
[[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) {
|
||||
[[unroll]] for (uint32_t comp = 0; comp < 4; ++comp) {
|
||||
perElemOpGqaStore(tile_row(r), 4*(d * D_split + d_tid) + comp, float(Of[r][d][comp]), o_offset, iq2, N);
|
||||
}
|
||||
|
|
@ -353,9 +355,9 @@ void main() {
|
|||
} else {
|
||||
[[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
|
||||
if (i * Br + tile_row(r) < N) {
|
||||
[[unroll]] for (uint32_t d = 0; d < D_per_thread / 4; ++d) {
|
||||
[[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) {
|
||||
[[unroll]] for (uint32_t comp = 0; comp < 4; ++comp) {
|
||||
data_o[o_offset + iq2 * D + (i * Br + tile_row(r)) * p.ne1 * D + 4*(d * D_split + d_tid) + comp] = D_TYPE(Of[r][d][comp]);
|
||||
data_o[o_offset + iq2 * HSV + (i * Br + tile_row(r)) * p.ne1 * HSV + 4*(d * D_split + d_tid) + comp] = D_TYPE(Of[r][d][comp]);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -61,8 +61,8 @@ ACC_TYPE Max(const in uint32_t row, const in uint32_t col, const in ACC_TYPE ele
|
|||
// Rows index by Q's dimension 2, and the first N rows are valid.
|
||||
D_TYPE perElemOpGqaStore(const in uint32_t r, const in uint32_t c, const in D_TYPE elem, const in uint32_t o_offset, const in uint32_t iq2, const in uint32_t N)
|
||||
{
|
||||
if (r < N && c < D) {
|
||||
uint32_t offset = (iq2 + r) * D + c;
|
||||
if (r < N && c < HSV) {
|
||||
uint32_t offset = (iq2 + r) * HSV + c;
|
||||
data_o[o_offset + offset] = D_TYPE(elem);
|
||||
}
|
||||
return elem;
|
||||
|
|
@ -86,9 +86,9 @@ void main() {
|
|||
tensorLayoutV = setTensorLayoutBlockSizeNV(tensorLayoutV, 1, BLOCK_SIZE);
|
||||
#endif
|
||||
|
||||
tensorLayoutQ = setTensorLayoutDimensionNV(tensorLayoutQ, N, D);
|
||||
tensorLayoutK = setTensorLayoutDimensionNV(tensorLayoutK, KV, D);
|
||||
tensorLayoutV = setTensorLayoutDimensionNV(tensorLayoutV, KV, D);
|
||||
tensorLayoutQ = setTensorLayoutDimensionNV(tensorLayoutQ, N, HSK);
|
||||
tensorLayoutK = setTensorLayoutDimensionNV(tensorLayoutK, KV, HSK);
|
||||
tensorLayoutV = setTensorLayoutDimensionNV(tensorLayoutV, KV, HSV);
|
||||
|
||||
// hint to the compiler that strides are aligned for the aligned variant of the shader
|
||||
if (Clamp != gl_CooperativeMatrixClampModeConstantNV)
|
||||
|
|
@ -104,16 +104,16 @@ void main() {
|
|||
tensorLayoutK = setTensorLayoutStrideNV(tensorLayoutK, k_stride, 1);
|
||||
tensorLayoutV = setTensorLayoutStrideNV(tensorLayoutV, v_stride, 1);
|
||||
|
||||
coopmat<Q_TYPE, gl_ScopeWorkgroup, Br, D, gl_MatrixUseAccumulator> Q;
|
||||
coopmat<float16_t, gl_ScopeWorkgroup, Br, D, gl_MatrixUseA> Qf16;
|
||||
coopmat<Q_TYPE, gl_ScopeWorkgroup, Br, HSK, gl_MatrixUseAccumulator> Q;
|
||||
coopmat<float16_t, gl_ScopeWorkgroup, Br, HSK, gl_MatrixUseA> Qf16;
|
||||
|
||||
uint32_t q_offset = iq2*p.nb02+iq3*p.nb03;
|
||||
coopMatLoadTensorNV(Q, data_q, q_offset, sliceTensorLayoutNV(tensorLayoutQ, i * Br, Br, 0, D));
|
||||
coopMatLoadTensorNV(Q, data_q, q_offset, sliceTensorLayoutNV(tensorLayoutQ, i * Br, Br, 0, HSK));
|
||||
|
||||
Qf16 = coopmat<float16_t, gl_ScopeWorkgroup, Br, D, gl_MatrixUseA>(Q);
|
||||
Qf16 = coopmat<float16_t, gl_ScopeWorkgroup, Br, HSK, gl_MatrixUseA>(Q);
|
||||
Qf16 *= float16_t(p.scale);
|
||||
|
||||
coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, D, gl_MatrixUseAccumulator> O = coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, D, gl_MatrixUseAccumulator>(0);
|
||||
coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, HSV, gl_MatrixUseAccumulator> O = coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, HSV, gl_MatrixUseAccumulator>(0);
|
||||
|
||||
coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator> L, M;
|
||||
|
||||
|
|
@ -131,8 +131,8 @@ void main() {
|
|||
}
|
||||
|
||||
uint32_t m_offset = 0;
|
||||
if (p.nem2 != 1) {
|
||||
m_offset = (iq3 % p.nem2) * p.nem1 * KV * 2 /*sizeof(float16_t)*/;
|
||||
if (p.nem2 != 1 || p.nem3 != 1) {
|
||||
m_offset = ((iq3 % p.nem3) * p.nem2 + (iq2 % p.nem2)) * p.nem1 * KV * 2 /*sizeof(float16_t)*/;
|
||||
}
|
||||
|
||||
[[dont_unroll]]
|
||||
|
|
@ -140,10 +140,10 @@ void main() {
|
|||
|
||||
coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator> S = coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator>(0);
|
||||
|
||||
coopmat<float16_t, gl_ScopeWorkgroup, D, Bc, gl_MatrixUseB> K_T;
|
||||
coopmat<float16_t, gl_ScopeWorkgroup, HSK, Bc, gl_MatrixUseB> K_T;
|
||||
|
||||
uint32_t k_offset = ik2*p.nb12 + ik3*p.nb13;
|
||||
coopMatLoadTensorNV(K_T, data_k, k_offset, sliceTensorLayoutNV(tensorLayoutK, j * Bc, Bc, 0, D), tensorViewTranspose DECODEFUNC);
|
||||
coopMatLoadTensorNV(K_T, data_k, k_offset, sliceTensorLayoutNV(tensorLayoutK, j * Bc, Bc, 0, HSK), tensorViewTranspose DECODEFUNC);
|
||||
S = coopMatMulAdd(Qf16, K_T, S);
|
||||
|
||||
if (p.logit_softcap != 0.0f) {
|
||||
|
|
@ -153,7 +153,7 @@ void main() {
|
|||
}
|
||||
}
|
||||
|
||||
if (p.mask != 0) {
|
||||
if ((p.mask_n_head_log2 & MASK_ENABLE_BIT) != 0) {
|
||||
tensorLayoutNV<2, Clamp> tensorLayoutM = createTensorLayoutNV(2, Clamp);
|
||||
tensorLayoutM = setTensorLayoutDimensionNV(tensorLayoutM, p.nem1, KV);
|
||||
tensorLayoutM = setTensorLayoutStrideNV(tensorLayoutM, m_stride, 1);
|
||||
|
|
@ -208,42 +208,42 @@ void main() {
|
|||
rowsum = coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator>(0.0);
|
||||
rowsum = coopMatMulAdd(P_A, One, rowsum);
|
||||
|
||||
coopmat<float16_t, gl_ScopeWorkgroup, Bc, D, gl_MatrixUseB> V;
|
||||
coopmat<float16_t, gl_ScopeWorkgroup, Bc, HSV, gl_MatrixUseB> V;
|
||||
uint32_t v_offset = iv2*p.nb22 + iv3*p.nb23;
|
||||
coopMatLoadTensorNV(V, data_v, v_offset, sliceTensorLayoutNV(tensorLayoutV, j * Bc, Bc, 0, D) DECODEFUNC);
|
||||
coopMatLoadTensorNV(V, data_v, v_offset, sliceTensorLayoutNV(tensorLayoutV, j * Bc, Bc, 0, HSV) DECODEFUNC);
|
||||
|
||||
L = eM*L + rowsum;
|
||||
|
||||
// This is the "diagonal" matrix in the paper, but since we do componentwise
|
||||
// multiply rather than matrix multiply it has the diagonal element smeared
|
||||
// across the row
|
||||
coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, D, gl_MatrixUseAccumulator> eMdiag;
|
||||
coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, HSV, gl_MatrixUseAccumulator> eMdiag;
|
||||
|
||||
// resize eM by using smear/reduce
|
||||
coopMatReduceNV(eMdiag, eM, gl_CooperativeMatrixReduceRowNV, smearReduce);
|
||||
|
||||
// multiply with fp16 accumulation, then add to O.
|
||||
coopmat<float16_t, gl_ScopeWorkgroup, Br, D, gl_MatrixUseAccumulator> PV = coopmat<float16_t, gl_ScopeWorkgroup, Br, D, gl_MatrixUseAccumulator>(0);
|
||||
coopmat<float16_t, gl_ScopeWorkgroup, Br, HSV, gl_MatrixUseAccumulator> PV = coopmat<float16_t, gl_ScopeWorkgroup, Br, HSV, gl_MatrixUseAccumulator>(0);
|
||||
PV = coopMatMulAdd(P_A, V, PV);
|
||||
|
||||
O = eMdiag * O + coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, D, gl_MatrixUseAccumulator>(PV);
|
||||
O = eMdiag * O + coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, HSV, gl_MatrixUseAccumulator>(PV);
|
||||
}
|
||||
|
||||
// If there is split_k, then the split_k resolve shader does the final
|
||||
// division by L. Store the intermediate O value and per-row m and L values.
|
||||
if (p.k_num > 1) {
|
||||
coopmat<D_TYPE, gl_ScopeWorkgroup, Br, D, gl_MatrixUseAccumulator> O_D = coopmat<D_TYPE, gl_ScopeWorkgroup, Br, D, gl_MatrixUseAccumulator>(O);
|
||||
coopmat<D_TYPE, gl_ScopeWorkgroup, Br, HSV, gl_MatrixUseAccumulator> O_D = coopmat<D_TYPE, gl_ScopeWorkgroup, Br, HSV, gl_MatrixUseAccumulator>(O);
|
||||
|
||||
uint32_t o_offset = D * p.ne1 * (split_k_index + iq3 * p.k_num);
|
||||
uint32_t o_offset = HSV * p.ne1 * (split_k_index + iq3 * p.k_num);
|
||||
coopMatPerElementNV(O_D, O_D, perElemOpGqaStore, o_offset, iq2, N);
|
||||
|
||||
o_offset = D * p.ne1 * p.ne3 * p.k_num + p.ne1 * (split_k_index + iq3 * p.k_num) * 2;
|
||||
o_offset = HSV * p.ne1 * p.ne3 * p.k_num + p.ne1 * (split_k_index + iq3 * p.k_num) * 2;
|
||||
coopMatPerElementNV(L, L, perElemOpStoreCol0, o_offset, iq2, N);
|
||||
coopMatPerElementNV(M, M, perElemOpStoreCol0, o_offset + p.ne1, iq2, N);
|
||||
return;
|
||||
}
|
||||
|
||||
coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, D, gl_MatrixUseAccumulator> Ldiag;
|
||||
coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, HSV, gl_MatrixUseAccumulator> Ldiag;
|
||||
|
||||
// resize L by using smear/reduce
|
||||
coopMatReduceNV(Ldiag, L, gl_CooperativeMatrixReduceRowNV, smearReduce);
|
||||
|
|
@ -255,18 +255,18 @@ void main() {
|
|||
|
||||
O = Ldiag*O;
|
||||
|
||||
uint32_t o_offset = iq3*p.ne2*p.ne1*D;
|
||||
uint32_t o_offset = iq3*p.ne2*p.ne1*HSV;
|
||||
|
||||
coopmat<D_TYPE, gl_ScopeWorkgroup, Br, D, gl_MatrixUseAccumulator> O_D = coopmat<D_TYPE, gl_ScopeWorkgroup, Br, D, gl_MatrixUseAccumulator>(O);
|
||||
coopmat<D_TYPE, gl_ScopeWorkgroup, Br, HSV, gl_MatrixUseAccumulator> O_D = coopmat<D_TYPE, gl_ScopeWorkgroup, Br, HSV, gl_MatrixUseAccumulator>(O);
|
||||
if (p.gqa_ratio > 1) {
|
||||
coopMatPerElementNV(O_D, O_D, perElemOpGqaStore, o_offset, iq2, N);
|
||||
} else {
|
||||
tensorLayoutNV<3, gl_CooperativeMatrixClampModeConstantNV> tensorLayoutD = createTensorLayoutNV(3, gl_CooperativeMatrixClampModeConstantNV);
|
||||
tensorLayoutD = setTensorLayoutDimensionNV(tensorLayoutD, p.ne2, p.ne1, D);
|
||||
tensorLayoutD = setTensorLayoutDimensionNV(tensorLayoutD, p.ne2, p.ne1, HSV);
|
||||
|
||||
// permute dimensions
|
||||
tensorViewNV<3, false, 1, 0, 2> tensorViewPermute = createTensorViewNV(3, false, 1, 0, 2);
|
||||
|
||||
coopMatStoreTensorNV(O_D, data_o, o_offset, sliceTensorLayoutNV(tensorLayoutD, i * Br, Br, iq2, N, 0, D), tensorViewPermute);
|
||||
coopMatStoreTensorNV(O_D, data_o, o_offset, sliceTensorLayoutNV(tensorLayoutD, i * Br, Br, iq2, N, 0, HSV), tensorViewPermute);
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@@ -0,0 +1,27 @@
#version 450

#include "glu_head.comp"

// based on Abramowitz and Stegun formula 7.1.26 or similar Hastings' approximation
// ref: https://www.johndcook.com/blog/python_erf/
const float p_erf = 0.3275911f;
const float a1_erf = 0.254829592f;
const float a2_erf = -0.284496736f;
const float a3_erf = 1.421413741f;
const float a4_erf = -1.453152027f;
const float a5_erf = 1.061405429f;

const float SQRT_2_INV = 0.70710678118654752440084436210484f;

float op(float a, float b) {
    const float a_div_sqr2 = a * SQRT_2_INV;
    const float sign_x = sign(a_div_sqr2);
    const float x = abs(a_div_sqr2);
    const float t = 1.0f / (1.0f + p_erf * x);
    const float y = 1.0f - (((((a5_erf * t + a4_erf) * t) + a3_erf) * t + a2_erf) * t + a1_erf) * t * exp(-x * x);
    const float erf_approx = sign_x * y;

    return 0.5f * a * (1.0f + erf_approx) * b;
}

#include "glu_main.comp"
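The new geglu_erf shader relies on the Abramowitz and Stegun 7.1.26 polynomial named in its comment. A quick standalone check of that approximation against a reference erf, written as an illustrative Python sketch rather than anything shipped with the patch:

# Verify the erf polynomial used by geglu_erf.comp against math.erf.
import math

P, A1, A2, A3, A4, A5 = 0.3275911, 0.254829592, -0.284496736, 1.421413741, -1.453152027, 1.061405429

def erf_approx(x: float) -> float:
    s = math.copysign(1.0, x)
    x = abs(x)
    t = 1.0 / (1.0 + P * x)
    y = 1.0 - (((((A5 * t + A4) * t) + A3) * t + A2) * t + A1) * t * math.exp(-x * x)
    return s * y

def geglu_erf(a: float, b: float) -> float:
    # 0.5 * a * (1 + erf(a / sqrt(2))) * b, same shape as the shader's op()
    return 0.5 * a * (1.0 + erf_approx(a * 2 ** -0.5)) * b

max_err = max(abs(erf_approx(x) - math.erf(x)) for x in [i / 100.0 - 5.0 for i in range(1001)])
assert max_err < 2e-7  # the approximation is accurate to roughly 1.5e-7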
@@ -0,0 +1,11 @@
#version 450

#include "glu_head.comp"

const float GELU_QUICK_COEF = -1.702f;

float op(float a, float b) {
    return a * (1.0f / (1.0f + exp(GELU_QUICK_COEF * a))) * b;
}

#include "glu_main.comp"
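For reference, the op in geglu_quick.comp is the sigmoid-based "quick" GELU applied to the gate half and multiplied by the other half; a tiny illustrative Python sketch of the same formula (not repository code):

import math

GELU_QUICK_COEF = -1.702

def geglu_quick(a: float, b: float) -> float:
    # a * sigmoid(1.702 * a) * b
    return a * (1.0 / (1.0 + math.exp(GELU_QUICK_COEF * a))) * b

print(geglu_quick(1.0, 2.0))  # about 1.69, close to 2 * exact GELU(1.0) = 1.68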
@ -500,10 +500,9 @@ void main() {
|
|||
const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a;
|
||||
const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + loadr_a * LOAD_VEC_A;
|
||||
|
||||
const uint ib = idx / 128; // 2 values per idx
|
||||
const uint ib32 = (idx % 128) / 16; // 0..7
|
||||
const uint ib8 = (idx % 128) / 4;
|
||||
const int i8 = 2 * int(idx % 4);
|
||||
const uint ib = idx / 32; // 8 values per idx
|
||||
const uint ib32 = (idx % 32) / 4; // 0..7
|
||||
const uint ib8 = idx % 32;
|
||||
|
||||
const float d = float(data_a[ib].d);
|
||||
const uint qh = data_a[ib].qh[ib32];
|
||||
|
|
@ -512,22 +511,16 @@ void main() {
|
|||
const float delta = ((qh & 0x8000) != 0) ? -IQ1S_DELTA : IQ1S_DELTA;
|
||||
const int16_t grid = int16_t(iq1s_grid[qs | (bitfieldExtract(qh, 3 * int(ib8 & 3), 3) << 8)]);
|
||||
|
||||
const ivec2 gvec = ivec2(
|
||||
bitfieldExtract(grid, 2 * (i8), 2),
|
||||
bitfieldExtract(grid, 2 * (i8 + 1), 2)
|
||||
);
|
||||
const vec2 v = dl * (vec2(gvec) + delta);
|
||||
|
||||
buf_a[buf_idx ] = FLOAT_TYPE(v.x);
|
||||
buf_a[buf_idx + 1] = FLOAT_TYPE(v.y);
|
||||
[[unroll]] for (int k = 0; k < 8; ++k) {
|
||||
buf_a[buf_idx + k] = FLOAT_TYPE(dl * (bitfieldExtract(grid, 2 * k, 2) + delta));
|
||||
}
|
||||
#elif defined(DATA_A_IQ1_M)
|
||||
const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a;
|
||||
const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + loadr_a * LOAD_VEC_A;
|
||||
|
||||
const uint ib = idx / 128; // 2 values per idx
|
||||
const uint ib8 = (idx % 128) / 4;
|
||||
const uint ib = idx / 32; // 8 values per idx
|
||||
const uint ib8 = idx % 32;
|
||||
const uint ib16 = ib8 / 2;
|
||||
const int i8 = 2 * int(idx % 4);
|
||||
|
||||
const uint16_t[4] scales = data_a[ib].scales;
|
||||
const u16vec4 s = u16vec4(scales[0], scales[1], scales[2], scales[3]) >> 12;
|
||||
|
|
@ -538,21 +531,17 @@ void main() {
|
|||
const float dl = d * (2 * bitfieldExtract(sc, 3 * int(ib16 & 3), 3) + 1);
|
||||
const float delta = ((qh & 8) != 0) ? -IQ1M_DELTA : IQ1M_DELTA;
|
||||
const int16_t grid = int16_t(iq1s_grid[qs | ((qh & 7) << 8)]);
|
||||
const ivec2 gvec = ivec2(
|
||||
bitfieldExtract(grid, 2 * (i8), 2),
|
||||
bitfieldExtract(grid, 2 * (i8 + 1), 2)
|
||||
);
|
||||
const vec2 v = dl * (vec2(gvec) + delta);
|
||||
|
||||
buf_a[buf_idx ] = FLOAT_TYPE(v.x);
|
||||
buf_a[buf_idx + 1] = FLOAT_TYPE(v.y);
|
||||
[[unroll]] for (int k = 0; k < 8; ++k) {
|
||||
buf_a[buf_idx + k] = FLOAT_TYPE(dl * (bitfieldExtract(grid, 2 * k, 2) + delta));
|
||||
}
|
||||
#elif defined(DATA_A_IQ2_XXS)
|
||||
const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a;
|
||||
const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + loadr_a * LOAD_VEC_A;
|
||||
|
||||
const uint ib = idx / 128; // 2 values per idx
|
||||
const uint ib32 = (idx % 128) / 16; // 0..7
|
||||
const uint ib8 = (idx / 4) % 4;
|
||||
const uint ib = idx / 32; // 8 values per idx
|
||||
const uint ib32 = (idx % 32) / 4; // 0..7
|
||||
const uint ib8 = idx % 4;
|
||||
|
||||
const float d = float(data_a[ib].d);
|
||||
const uint qs = data_a[ib].qs[8 * ib32 + ib8];
|
||||
|
|
@ -562,63 +551,81 @@ void main() {
|
|||
data_a[ib].qs[8*ib32 + 6],
|
||||
data_a[ib].qs[8*ib32 + 7]
|
||||
));
|
||||
const float db = d * 0.25 * (0.5 + (signs >> 28));
|
||||
const FLOAT_TYPE db = FLOAT_TYPE(d * 0.25 * (0.5 + (signs >> 28)));
|
||||
const uint32_t sign7 = bitfieldExtract(signs, 7 * int(ib8), 7);
|
||||
const uint sign = (sign7 | (bitCount(sign7) << 7)) >> (2 * (idx % 4));
|
||||
const i8vec2 sign01 = i8vec2(1 - (2 & i8vec2(int8_t(sign << 1), int8_t(sign))));
|
||||
const uint grid = iq2xxs_grid[qs][(idx % 4) / 2] >> (16 * (idx & 1));
|
||||
const vec2 v = db * vec2(sign01) * vec2(unpack8(grid).xy); // vec4 used due to #12147
|
||||
const uint sign = sign7 | (bitCount(sign7) << 7);
|
||||
const uvec2 grid = iq2xxs_grid[qs];
|
||||
const vec4 grid0 = vec4(unpack8(grid.x));
|
||||
const vec4 grid1 = vec4(unpack8(grid.y));
|
||||
|
||||
buf_a[buf_idx ] = FLOAT_TYPE(v.x);
|
||||
buf_a[buf_idx + 1] = FLOAT_TYPE(v.y);
|
||||
buf_a[buf_idx ] = db * FLOAT_TYPE((sign & 1) != 0 ? -grid0.x : grid0.x);
|
||||
buf_a[buf_idx + 1] = db * FLOAT_TYPE((sign & 2) != 0 ? -grid0.y : grid0.y);
|
||||
buf_a[buf_idx + 2] = db * FLOAT_TYPE((sign & 4) != 0 ? -grid0.z : grid0.z);
|
||||
buf_a[buf_idx + 3] = db * FLOAT_TYPE((sign & 8) != 0 ? -grid0.w : grid0.w);
|
||||
buf_a[buf_idx + 4] = db * FLOAT_TYPE((sign & 16) != 0 ? -grid1.x : grid1.x);
|
||||
buf_a[buf_idx + 5] = db * FLOAT_TYPE((sign & 32) != 0 ? -grid1.y : grid1.y);
|
||||
buf_a[buf_idx + 6] = db * FLOAT_TYPE((sign & 64) != 0 ? -grid1.z : grid1.z);
|
||||
buf_a[buf_idx + 7] = db * FLOAT_TYPE((sign & 128) != 0 ? -grid1.w : grid1.w);
|
||||
#elif defined(DATA_A_IQ2_XS)
|
||||
const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a;
|
||||
const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + loadr_a * LOAD_VEC_A;
|
||||
|
||||
const uint ib = idx / 128; // 2 values per idx
|
||||
const uint ib32 = (idx % 128) / 16; // 0..7
|
||||
const uint ib8 = (idx / 4) % 4; // 0..3
|
||||
const uint ib = idx / 32; // 8 values per idx
|
||||
const uint ib32 = (idx % 32) / 4; // 0..7
|
||||
const uint ib8 = idx % 4; // 0..3
|
||||
|
||||
const float d = float(data_a[ib].d);
|
||||
const uint scale = (data_a[ib].scales[ib32] >> (2 * (ib8 & 2))) & 0xf;
|
||||
const float db = d * 0.25 * (0.5 + scale);
|
||||
const FLOAT_TYPE db = FLOAT_TYPE(d * 0.25 * (0.5 + scale));
|
||||
const uint qs = data_a[ib].qs[4 * ib32 + ib8];
|
||||
const uint sign7 = qs >> 9;
|
||||
const uint sign = (sign7 | (bitCount(sign7) << 7)) >> (2 * (idx % 4));
|
||||
const i8vec2 sign01 = i8vec2(1 - (2 & i8vec2(int8_t(sign << 1), int8_t(sign))));
|
||||
const uint grid = iq2xs_grid[qs & 511][(idx % 4) / 2] >> (16 * (idx & 1));
|
||||
const vec2 v = db * vec2(sign01) * vec2(unpack8(grid).xy); // vec4 used due to #12147
|
||||
const uint sign = sign7 | (bitCount(sign7) << 7);
|
||||
const uvec2 grid = iq2xs_grid[qs & 511];
|
||||
const vec4 grid0 = vec4(unpack8(grid.x));
|
||||
const vec4 grid1 = vec4(unpack8(grid.y));
|
||||
|
||||
buf_a[buf_idx ] = FLOAT_TYPE(v.x);
|
||||
buf_a[buf_idx + 1] = FLOAT_TYPE(v.y);
|
||||
buf_a[buf_idx ] = db * FLOAT_TYPE((sign & 1) != 0 ? -grid0.x : grid0.x);
|
||||
buf_a[buf_idx + 1] = db * FLOAT_TYPE((sign & 2) != 0 ? -grid0.y : grid0.y);
|
||||
buf_a[buf_idx + 2] = db * FLOAT_TYPE((sign & 4) != 0 ? -grid0.z : grid0.z);
|
||||
buf_a[buf_idx + 3] = db * FLOAT_TYPE((sign & 8) != 0 ? -grid0.w : grid0.w);
|
||||
buf_a[buf_idx + 4] = db * FLOAT_TYPE((sign & 16) != 0 ? -grid1.x : grid1.x);
|
||||
buf_a[buf_idx + 5] = db * FLOAT_TYPE((sign & 32) != 0 ? -grid1.y : grid1.y);
|
||||
buf_a[buf_idx + 6] = db * FLOAT_TYPE((sign & 64) != 0 ? -grid1.z : grid1.z);
|
||||
buf_a[buf_idx + 7] = db * FLOAT_TYPE((sign & 128) != 0 ? -grid1.w : grid1.w);
|
||||
#elif defined(DATA_A_IQ2_S)
|
||||
const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a;
|
||||
const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + loadr_a * LOAD_VEC_A;
|
||||
|
||||
const uint ib = idx / 128; // 2 values per idx
|
||||
const uint ib8 = (idx % 128) / 4; // 0..31
|
||||
const uint ib32 = ib8 / 4; // 0..7
|
||||
const uint ib = idx / 32; // 8 values per idx
|
||||
const uint ib8 = idx % 32; // 0..31
|
||||
const uint ib32 = ib8 / 4; // 0..7
|
||||
|
||||
const uint scale = (data_a[ib].scales[ib32] >> (2 * (ib8 & 2))) & 0xf;
|
||||
const uint qs = data_a[ib].qs[ib8];
|
||||
const uint qh = data_a[ib].qh[ib32];
|
||||
const uint qhshift = 2 * (ib8 % 4);
|
||||
const uint sign = data_a[ib].qs[QUANT_K / 8 + ib8] >> (2 * (idx % 4));
|
||||
const uint sign = data_a[ib].qs[QUANT_K / 8 + ib8];
|
||||
|
||||
const float d = float(data_a[ib].d);
|
||||
const float db = d * 0.25 * (0.5 + scale);
|
||||
const i8vec2 sign01 = i8vec2(1 - (2 & i8vec2(int8_t(sign << 1), int8_t(sign))));
|
||||
const uint16_t grid = unpack16(iq2s_grid[qs | ((qh << (8 - qhshift)) & 0x300)][(idx & 2) >> 1])[idx & 1];
|
||||
const vec2 v = db * vec2(sign01) * vec2(unpack8(uint32_t(grid)).xy); // vec4 used due to #12147
|
||||
const FLOAT_TYPE db = FLOAT_TYPE(d * 0.25 * (0.5 + scale));
|
||||
const uvec2 grid = iq2s_grid[qs | ((qh << (8 - qhshift)) & 0x300)];
|
||||
const vec4 grid0 = vec4(unpack8(grid.x));
|
||||
const vec4 grid1 = vec4(unpack8(grid.y));
|
||||
|
||||
buf_a[buf_idx ] = FLOAT_TYPE(v.x);
|
||||
buf_a[buf_idx + 1] = FLOAT_TYPE(v.y);
|
||||
buf_a[buf_idx ] = db * FLOAT_TYPE((sign & 1) != 0 ? -grid0.x : grid0.x);
|
||||
buf_a[buf_idx + 1] = db * FLOAT_TYPE((sign & 2) != 0 ? -grid0.y : grid0.y);
|
||||
buf_a[buf_idx + 2] = db * FLOAT_TYPE((sign & 4) != 0 ? -grid0.z : grid0.z);
|
||||
buf_a[buf_idx + 3] = db * FLOAT_TYPE((sign & 8) != 0 ? -grid0.w : grid0.w);
|
||||
buf_a[buf_idx + 4] = db * FLOAT_TYPE((sign & 16) != 0 ? -grid1.x : grid1.x);
|
||||
buf_a[buf_idx + 5] = db * FLOAT_TYPE((sign & 32) != 0 ? -grid1.y : grid1.y);
|
||||
buf_a[buf_idx + 6] = db * FLOAT_TYPE((sign & 64) != 0 ? -grid1.z : grid1.z);
|
||||
buf_a[buf_idx + 7] = db * FLOAT_TYPE((sign & 128) != 0 ? -grid1.w : grid1.w);
|
||||
#elif defined(DATA_A_IQ3_XXS)
|
||||
const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a;
|
||||
const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + loadr_a * LOAD_VEC_A;
|
||||
|
||||
const uint ib = idx / 128; // 2 values per idx
|
||||
const uint iqs = (idx % 128) / 2; // 0..63
|
||||
const uint ib = idx / 64; // 4 values per idx
|
||||
const uint iqs = idx % 64; // 0..63
|
||||
const uint is = QUANT_K / 4 + 4 * (iqs / 8); // 8 values
|
||||
|
||||
const float d = float(data_a[ib].d);
|
||||
|
|
@ -631,33 +638,36 @@ void main() {
|
|||
));
|
||||
const float db = d * 0.5 * (0.5 + (signs >> 28));
|
||||
const uint32_t sign7 = bitfieldExtract(signs, 7 * (int(iqs / 2) % 4), 7);
|
||||
const uint sign = (sign7 | (bitCount(sign7) << 7)) >> (2 * (idx % 4));
|
||||
const i8vec2 sign01 = i8vec2(1 - (2 & i8vec2(int8_t(sign << 1), int8_t(sign))));
|
||||
const uint grid = iq3xxs_grid[qs] >> (16 * (idx & 1));
|
||||
const vec2 v = db * vec2(sign01) * vec2(unpack8(grid).xy); // vec4 used due to #12147
|
||||
const uint sign = (sign7 | (bitCount(sign7) << 7)) >> (4 * (idx % 2));
|
||||
const uint grid = iq3xxs_grid[qs];
|
||||
const vec4 v = db * vec4(unpack8(grid));
|
||||
|
||||
buf_a[buf_idx ] = FLOAT_TYPE(v.x);
|
||||
buf_a[buf_idx + 1] = FLOAT_TYPE(v.y);
|
||||
buf_a[buf_idx ] = FLOAT_TYPE((sign & 1) != 0 ? -v.x : v.x);
|
||||
buf_a[buf_idx + 1] = FLOAT_TYPE((sign & 2) != 0 ? -v.y : v.y);
|
||||
buf_a[buf_idx + 2] = FLOAT_TYPE((sign & 4) != 0 ? -v.z : v.z);
|
||||
buf_a[buf_idx + 3] = FLOAT_TYPE((sign & 8) != 0 ? -v.w : v.w);
|
||||
#elif defined(DATA_A_IQ3_S)
|
||||
const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a;
|
||||
const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + loadr_a * LOAD_VEC_A;
|
||||
|
||||
const uint ib = idx / 128; // 2 values per idx
|
||||
const uint iqs = (idx % 128) / 2; // 0..63
|
||||
const uint ib = idx / 64; // 4 values per idx
|
||||
const uint iqs = idx % 64; // 0..63
|
||||
const uint iqh = iqs / 8;
|
||||
|
||||
const float d = float(data_a[ib].d);
|
||||
const uint qs = data_a[ib].qs[iqs];
|
||||
const uint qh = data_a[ib].qh[iqh];
|
||||
const int8_t sign = int8_t(data_a[ib].signs[iqs / 2] >> (2 * (idx % 4)));
|
||||
const int8_t sign = int8_t(data_a[ib].signs[iqs / 2] >> (4 * (idx % 2)));
|
||||
const uint scale = data_a[ib].scales[iqs / 16];
|
||||
const i8vec2 sign01 = i8vec2(1 - (2 & i8vec2(sign << 1, sign)));
|
||||
const float db = d * (1 + 2 * ((scale >> (4 * (iqh & 1))) & 0xf));
|
||||
const uint32_t grid = iq3s_grid[qs | ((qh << (8 - (iqs % 8))) & 256)] >> (16 * (idx % 2));
|
||||
const vec2 v = db * vec2(sign01) * vec2(unpack8(grid).xy); // vec4 used due to #12147
|
||||
const uint32_t grid = iq3s_grid[qs | ((qh << (8 - (iqs % 8))) & 256)];
|
||||
const vec4 v = db * vec4(unpack8(grid));
|
||||
|
||||
buf_a[buf_idx ] = FLOAT_TYPE(v.x);
|
||||
buf_a[buf_idx + 1] = FLOAT_TYPE(v.y);
|
||||
buf_a[buf_idx ] = FLOAT_TYPE((sign & 1) != 0 ? -v.x : v.x);
|
||||
buf_a[buf_idx + 1] = FLOAT_TYPE((sign & 2) != 0 ? -v.y : v.y);
|
||||
buf_a[buf_idx + 2] = FLOAT_TYPE((sign & 4) != 0 ? -v.z : v.z);
|
||||
buf_a[buf_idx + 3] = FLOAT_TYPE((sign & 8) != 0 ? -v.w : v.w);
|
||||
#elif defined(DATA_A_IQ4_XS)
|
||||
const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a;
|
||||
const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + loadr_a * LOAD_VEC_A;
|
||||
|
|
|
|||
|
|
@@ -360,9 +360,9 @@ void matmul_shaders(bool fp16, bool matmul_id, bool coopmat, bool coopmat2, bool

    for (const auto& tname : type_names) {
        std::string load_vec_quant = "2";
        if ((tname == "q4_0") || (tname == "q4_1"))
        if ((tname == "q4_0") || (tname == "q4_1") || (tname == "iq1_s") || (tname == "iq1_m") || (tname == "iq2_xxs") || (tname == "iq2_xs") || (tname == "iq2_s"))
            load_vec_quant = "8";
        else if ((tname == "q5_0") || (tname == "q5_1") || (tname == "q8_0") || (tname == "iq4_nl"))
        else if ((tname == "q5_0") || (tname == "q5_1") || (tname == "q8_0") || (tname == "iq3_xxs") || (tname == "iq3_s") || (tname == "iq4_nl"))
            load_vec_quant = "4";

        if (tname == "bf16") {
@@ -593,6 +593,10 @@ void process_shaders() {
    string_to_spv("reglu_f32", "reglu.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
    string_to_spv("swiglu_f16", "swiglu.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}});
    string_to_spv("swiglu_f32", "swiglu.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
    string_to_spv("geglu_erf_f16", "geglu_erf.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}});
    string_to_spv("geglu_erf_f32", "geglu_erf.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
    string_to_spv("geglu_quick_f16","geglu_quick.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}});
    string_to_spv("geglu_quick_f32","geglu_quick.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});

    string_to_spv("leaky_relu_f32", "leaky_relu.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
    string_to_spv("silu_back_f32", "silu_back.comp", {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}});
@@ -1140,9 +1140,11 @@ static const char * GGML_GLU_OP_NAME[GGML_GLU_OP_COUNT] = {
    "REGLU",
    "GEGLU",
    "SWIGLU",
    "GEGLU_ERF",
    "GEGLU_QUICK",
};

static_assert(GGML_GLU_OP_COUNT == 3, "GGML_GLU_OP_COUNT != 3");
static_assert(GGML_GLU_OP_COUNT == 5, "GGML_GLU_OP_COUNT != 5");


static_assert(sizeof(struct ggml_object)%GGML_MEM_ALIGN == 0, "ggml_object size must be a multiple of GGML_MEM_ALIGN");
@ -2768,6 +2770,48 @@ struct ggml_tensor * ggml_swiglu_split(
|
|||
return ggml_glu_impl(ctx, a, b, GGML_GLU_OP_SWIGLU, false);
|
||||
}
|
||||
|
||||
// ggml_geglu_erf
|
||||
|
||||
struct ggml_tensor * ggml_geglu_erf(
|
||||
struct ggml_context * ctx,
|
||||
struct ggml_tensor * a) {
|
||||
return ggml_glu_impl(ctx, a, NULL, GGML_GLU_OP_GEGLU_ERF, false);
|
||||
}
|
||||
|
||||
struct ggml_tensor * ggml_geglu_erf_swapped(
|
||||
struct ggml_context * ctx,
|
||||
struct ggml_tensor * a) {
|
||||
return ggml_glu_impl(ctx, a, NULL, GGML_GLU_OP_GEGLU_ERF, true);
|
||||
}
|
||||
|
||||
struct ggml_tensor * ggml_geglu_erf_split(
|
||||
struct ggml_context * ctx,
|
||||
struct ggml_tensor * a,
|
||||
struct ggml_tensor * b) {
|
||||
return ggml_glu_impl(ctx, a, b, GGML_GLU_OP_GEGLU_ERF, false);
|
||||
}
|
||||
|
||||
// ggml_geglu_quick
|
||||
|
||||
struct ggml_tensor * ggml_geglu_quick(
|
||||
struct ggml_context * ctx,
|
||||
struct ggml_tensor * a) {
|
||||
return ggml_glu_impl(ctx, a, NULL, GGML_GLU_OP_GEGLU_QUICK, false);
|
||||
}
|
||||
|
||||
struct ggml_tensor * ggml_geglu_quick_swapped(
|
||||
struct ggml_context * ctx,
|
||||
struct ggml_tensor * a) {
|
||||
return ggml_glu_impl(ctx, a, NULL, GGML_GLU_OP_GEGLU_QUICK, true);
|
||||
}
|
||||
|
||||
struct ggml_tensor * ggml_geglu_quick_split(
|
||||
struct ggml_context * ctx,
|
||||
struct ggml_tensor * a,
|
||||
struct ggml_tensor * b) {
|
||||
return ggml_glu_impl(ctx, a, b, GGML_GLU_OP_GEGLU_QUICK, false);
|
||||
}
|
||||
|
||||
// ggml_norm
|
||||
|
||||
static struct ggml_tensor * ggml_norm_impl(
|
||||
|
|
@ -6050,13 +6094,28 @@ static void ggml_compute_backward(
|
|||
}
|
||||
GGML_ASSERT(!src1_needs_grads && "backward pass for labels not implemented");
|
||||
} break;
|
||||
case GGML_OP_GLU: {
|
||||
switch (ggml_get_glu_op(tensor)) {
|
||||
case GGML_GLU_OP_SWIGLU: {
|
||||
if (src0_needs_grads) {
|
||||
GGML_ASSERT(src1 && "backward pass only implemented for split swiglu");
|
||||
ggml_add_or_set(ctx, cgraph, isrc0, ggml_silu_back(ctx, ggml_mul(ctx, grad, src1), src0));
|
||||
}
|
||||
if (src1_needs_grads) {
|
||||
ggml_add_or_set(ctx, cgraph, isrc1, ggml_mul(ctx, ggml_silu(ctx, src0), grad));
|
||||
}
|
||||
} break;
|
||||
default: {
|
||||
GGML_ABORT("unsupported glu op for backward pass: %s", ggml_glu_op_name(ggml_get_glu_op(tensor)));
|
||||
} //break;
|
||||
}
|
||||
} break;
|
||||
case GGML_OP_NONE: {
|
||||
// noop
|
||||
} break;
|
||||
case GGML_OP_COUNT:
|
||||
default: {
|
||||
fprintf(stderr, "%s: unsupported ggml op for backward pass: %s\n", __func__, ggml_op_name(tensor->op));
|
||||
GGML_ABORT("fatal error");
|
||||
GGML_ABORT("%s: unsupported ggml op for backward pass: %s\n", __func__, ggml_op_name(tensor->op));
|
||||
} //break;
|
||||
}
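The new GGML_OP_GLU case only covers split SwiGLU: with out = silu(a) * b, the gradient w.r.t. a is grad * b * silu'(a) (what ggml_silu_back computes from grad*b and a) and the gradient w.r.t. b is grad * silu(a). A quick NumPy finite-difference check of those expressions (illustrative only, not the ggml kernel):

import numpy as np

def silu(x):
    return x / (1.0 + np.exp(-x))

def dsilu(x):
    s = 1.0 / (1.0 + np.exp(-x))
    return s * (1.0 + x * (1.0 - s))  # derivative of x*sigmoid(x)

rng = np.random.default_rng(0)
a, b, grad = rng.standard_normal((3, 16))

da = grad * b * dsilu(a)   # mirrors ggml_silu_back(ggml_mul(grad, src1), src0)
db = grad * silu(a)        # mirrors ggml_mul(ggml_silu(src0), grad)

eps = 1e-6
num_da = grad * b * (silu(a + eps) - silu(a - eps)) / (2 * eps)
print(np.allclose(da, num_da, atol=1e-4), np.allclose(db, grad * silu(a)))  # True True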
@ -359,6 +359,7 @@ class MODEL_ARCH(IntEnum):
|
|||
DOTS1 = auto()
|
||||
ARCEE = auto()
|
||||
ERNIE4_5 = auto()
|
||||
HUNYUAN_MOE = auto()
class VISION_PROJECTOR_TYPE(IntEnum):
|
||||
|
|
@ -663,6 +664,7 @@ MODEL_ARCH_NAMES: dict[MODEL_ARCH, str] = {
|
|||
MODEL_ARCH.ARCEE: "arcee",
|
||||
MODEL_ARCH.ERNIE4_5: "ernie4_5",
|
||||
MODEL_ARCH.FALCON_H1: "falcon_h1",
|
||||
MODEL_ARCH.HUNYUAN_MOE: "hunyuan-moe",
|
||||
}
|
||||
|
||||
VISION_PROJECTOR_TYPE_NAMES: dict[VISION_PROJECTOR_TYPE, str] = {
|
||||
|
|
@ -2248,6 +2250,27 @@ MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = {
|
|||
MODEL_TENSOR.OUTPUT_NORM, # Final layer norm
|
||||
MODEL_TENSOR.OUTPUT, # Output projection (lm_head)
|
||||
],
|
||||
MODEL_ARCH.HUNYUAN_MOE: [
|
||||
MODEL_TENSOR.TOKEN_EMBD,
|
||||
MODEL_TENSOR.OUTPUT_NORM,
|
||||
MODEL_TENSOR.OUTPUT,
|
||||
MODEL_TENSOR.ROPE_FREQS,
|
||||
MODEL_TENSOR.ATTN_NORM,
|
||||
MODEL_TENSOR.ATTN_Q,
|
||||
MODEL_TENSOR.ATTN_Q_NORM,
|
||||
MODEL_TENSOR.ATTN_K,
|
||||
MODEL_TENSOR.ATTN_K_NORM,
|
||||
MODEL_TENSOR.ATTN_V,
|
||||
MODEL_TENSOR.ATTN_OUT,
|
||||
MODEL_TENSOR.FFN_GATE_INP,
|
||||
MODEL_TENSOR.FFN_NORM,
|
||||
MODEL_TENSOR.FFN_GATE_EXP,
|
||||
MODEL_TENSOR.FFN_DOWN_EXP,
|
||||
MODEL_TENSOR.FFN_UP_EXP,
|
||||
MODEL_TENSOR.FFN_GATE_SHEXP,
|
||||
MODEL_TENSOR.FFN_DOWN_SHEXP,
|
||||
MODEL_TENSOR.FFN_UP_SHEXP,
|
||||
],
|
||||
# TODO
|
||||
}
@ -305,6 +305,7 @@ class TensorNameMap:
|
|||
"model.layers.{bid}.block_sparse_moe.router.layer", # granitemoe
|
||||
"model.layers.{bid}.feed_forward.router", # llama4
|
||||
"encoder.layers.{bid}.mlp.router.layer", # nomic-bert-moe
|
||||
"model.layers.{bid}.mlp.gate.wg", # hunyuan
|
||||
),
|
||||
|
||||
MODEL_TENSOR.FFN_GATE_INP_SHEXP: (
|
||||
|
|
@ -365,6 +366,7 @@ class TensorNameMap:
|
|||
"model.layers.{bid}.mlp.shared_experts.up_proj", # deepseek deepseek2
|
||||
"model.layers.{bid}.feed_forward.shared_expert.up_proj", # llama4
|
||||
"model.layers.{bid}.feed_forward.down_proj",
|
||||
"model.layers.{bid}.mlp.shared_mlp.up_proj", # hunyuan
|
||||
),
|
||||
|
||||
# AWQ-activation gate
|
||||
|
|
@ -401,6 +403,7 @@ class TensorNameMap:
|
|||
"model.layers.{bid}.mlp.shared_expert.gate_proj", # qwen2moe
|
||||
"model.layers.{bid}.mlp.shared_experts.gate_proj", # deepseek deepseek2
|
||||
"model.layers.{bid}.feed_forward.shared_expert.gate_proj", # llama4
|
||||
"model.layers.{bid}.mlp.shared_mlp.gate_proj", # hunyuan
|
||||
),
|
||||
|
||||
# Feed-forward down
|
||||
|
|
@ -450,11 +453,13 @@ class TensorNameMap:
|
|||
"model.layers.{bid}.mlp.shared_experts.down_proj", # deepseek deepseek2
|
||||
"model.layers.{bid}.feed_forward.shared_expert.down_proj", # llama4
|
||||
"model.layers.{bid}.shared_mlp.output_linear", # granitemoe
|
||||
"model.layers.{bid}.mlp.shared_mlp.down_proj", # hunyuan
|
||||
),
|
||||
|
||||
MODEL_TENSOR.ATTN_Q_NORM: (
|
||||
"language_model.encoder.layers.{bid}.self_attention.q_layernorm",
|
||||
"model.layers.{bid}.self_attn.q_layernorm", # persimmon
|
||||
"model.layers.{bid}.self_attn.query_layernorm", # hunyuan
|
||||
"model.layers.{bid}.self_attn.q_norm", # cohere olmoe chameleon olmo2
|
||||
"transformer.blocks.{bid}.attn.q_ln", # sea-lion
|
||||
"encoder.layer.{bid}.attention.self.layer_norm_q", # jina-bert-v2
|
||||
|
|
@ -464,6 +469,7 @@ class TensorNameMap:
|
|||
MODEL_TENSOR.ATTN_K_NORM: (
|
||||
"language_model.encoder.layers.{bid}.self_attention.k_layernorm",
|
||||
"model.layers.{bid}.self_attn.k_layernorm", # persimmon
|
||||
"model.layers.{bid}.self_attn.key_layernorm", # hunyuan
|
||||
"model.layers.{bid}.self_attn.k_norm", # cohere olmoe chameleon olmo2
|
||||
"transformer.blocks.{bid}.attn.k_ln", # sea-lion
|
||||
"encoder.layer.{bid}.attention.self.layer_norm_k", # jina-bert-v2
@ -117,6 +117,7 @@ extern "C" {
|
|||
LLAMA_VOCAB_PRE_TYPE_LLAMA4 = 33,
|
||||
LLAMA_VOCAB_PRE_TYPE_PIXTRAL = 34,
|
||||
LLAMA_VOCAB_PRE_TYPE_SEED_CODER = 35,
|
||||
LLAMA_VOCAB_PRE_TYPE_HUNYUAN = 36,
|
||||
};
|
||||
|
||||
enum llama_rope_type {
@ -79,6 +79,7 @@ static const std::map<llm_arch, const char *> LLM_ARCH_NAMES = {
|
|||
{ LLM_ARCH_DOTS1, "dots1" },
|
||||
{ LLM_ARCH_ARCEE, "arcee" },
|
||||
{ LLM_ARCH_ERNIE4_5, "ernie4_5" },
|
||||
{ LLM_ARCH_HUNYUAN_MOE, "hunyuan-moe" },
|
||||
{ LLM_ARCH_UNKNOWN, "(unknown)" },
|
||||
};
|
||||
|
||||
|
|
@ -1719,6 +1720,29 @@ static const std::map<llm_arch, std::map<llm_tensor, const char *>> LLM_TENSOR_N
|
|||
{ LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" },
|
||||
},
|
||||
},
|
||||
{
|
||||
LLM_ARCH_HUNYUAN_MOE,
|
||||
{
|
||||
{ LLM_TENSOR_TOKEN_EMBD, "token_embd" },
|
||||
{ LLM_TENSOR_OUTPUT_NORM, "output_norm" },
|
||||
{ LLM_TENSOR_OUTPUT, "output" },
|
||||
{ LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" },
|
||||
{ LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" },
|
||||
{ LLM_TENSOR_ATTN_Q_NORM, "blk.%d.attn_q_norm" },
|
||||
{ LLM_TENSOR_ATTN_K, "blk.%d.attn_k" },
|
||||
{ LLM_TENSOR_ATTN_K_NORM, "blk.%d.attn_k_norm" },
|
||||
{ LLM_TENSOR_ATTN_V, "blk.%d.attn_v" },
|
||||
{ LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" },
|
||||
{ LLM_TENSOR_FFN_GATE_INP, "blk.%d.ffn_gate_inp" },
|
||||
{ LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" },
|
||||
{ LLM_TENSOR_FFN_GATE_SHEXP, "blk.%d.ffn_gate_shexp" },
|
||||
{ LLM_TENSOR_FFN_DOWN_SHEXP, "blk.%d.ffn_down_shexp" },
|
||||
{ LLM_TENSOR_FFN_UP_SHEXP, "blk.%d.ffn_up_shexp" },
|
||||
{ LLM_TENSOR_FFN_GATE_EXPS, "blk.%d.ffn_gate_exps" },
|
||||
{ LLM_TENSOR_FFN_DOWN_EXPS, "blk.%d.ffn_down_exps" },
|
||||
{ LLM_TENSOR_FFN_UP_EXPS, "blk.%d.ffn_up_exps" },
|
||||
},
|
||||
},
|
||||
{
|
||||
LLM_ARCH_UNKNOWN,
|
||||
{
@ -83,6 +83,7 @@ enum llm_arch {
|
|||
LLM_ARCH_DOTS1,
|
||||
LLM_ARCH_ARCEE,
|
||||
LLM_ARCH_ERNIE4_5,
|
||||
LLM_ARCH_HUNYUAN_MOE,
|
||||
LLM_ARCH_UNKNOWN,
|
||||
};
@ -166,6 +166,8 @@ bool llama_batch_allocr::init(
|
|||
|
||||
// note: tracking the other way around is not necessary for now
|
||||
//seq_cpl[s0][s1] = true;
|
||||
|
||||
has_cpl = true;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -405,6 +407,10 @@ uint32_t llama_batch_allocr::get_n_outputs() const {
|
|||
return n_outputs;
|
||||
}
|
||||
|
||||
uint32_t llama_batch_allocr::get_n_used() const {
|
||||
return n_used;
|
||||
}
|
||||
|
||||
std::vector<int32_t> & llama_batch_allocr::get_out_ids() {
|
||||
return out_ids;
|
||||
}
|
||||
|
|
@ -420,6 +426,8 @@ llama_pos llama_batch_allocr::seq_pos_max(llama_seq_id seq_id) const {
|
|||
void llama_batch_allocr::split_reset() {
|
||||
out_ids.clear();
|
||||
|
||||
n_used = 0;
|
||||
|
||||
used.clear();
|
||||
used.resize(get_n_tokens(), false);
|
||||
|
||||
|
|
@ -444,6 +452,7 @@ llama_ubatch llama_batch_allocr::split_simple(uint32_t n_ubatch) {
|
|||
idxs.push_back(cur_idx);
|
||||
|
||||
used[cur_idx] = true;
|
||||
++n_used;
|
||||
|
||||
++cur_idx;
|
||||
|
||||
|
|
@ -459,9 +468,17 @@ llama_ubatch llama_batch_allocr::split_simple(uint32_t n_ubatch) {
|
|||
return ubatch_add(idxs, idxs.size(), false);
|
||||
}
|
||||
|
||||
llama_ubatch llama_batch_allocr::split_equal(uint32_t n_ubatch) {
|
||||
llama_ubatch llama_batch_allocr::split_equal(uint32_t n_ubatch, bool sequential) {
|
||||
if (sequential && has_cpl) {
|
||||
LLAMA_LOG_ERROR("%s: sequential split is not supported when there are coupled sequences in the input batch\n", __func__);
|
||||
|
||||
return {};
|
||||
}
|
||||
|
||||
std::vector<seq_set_t> cur_seq_set;
|
||||
|
||||
llama_seq_id last_seq_id = -1;
|
||||
|
||||
// determine the non-overlapping sequence sets participating in this ubatch
|
||||
for (int32_t i = 0; i < batch.n_tokens; ++i) {
|
||||
if (used[i]) {
|
||||
|
|
@ -478,9 +495,16 @@ llama_ubatch llama_batch_allocr::split_equal(uint32_t n_ubatch) {
|
|||
}
|
||||
}
|
||||
|
||||
// accept only increasing sequence ids
|
||||
if (sequential) {
|
||||
add = add && (cur_seq_set.empty() || batch.seq_id[i][0] == last_seq_id + 1);
|
||||
}
|
||||
|
||||
if (add) {
|
||||
cur_seq_set.push_back(seq_set[i]);
|
||||
|
||||
last_seq_id = batch.seq_id[i][0];
|
||||
|
||||
if (cur_seq_set.size() > n_ubatch) {
|
||||
break;
|
||||
}
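The sequential flag therefore only admits a token if its first sequence id continues directly from the last accepted one (the first accepted sequence may start anywhere). A toy Python sketch of just that acceptance rule, simplified from the loop above (the real code also rejects tokens whose sequence sets overlap sets already in the ubatch):

def accept_sequential(seq_ids):
    accepted = []
    last_seq_id = None
    for sid in seq_ids:
        if last_seq_id is None:
            accepted.append(sid)          # first sequence set: any id is fine
        elif sid == last_seq_id:
            pass                          # another token of the current sequence
        elif sid == last_seq_id + 1:
            accepted.append(sid)          # the next id continues the run
        else:
            break                         # gap or out-of-order id: stop collecting
        last_seq_id = sid
    return accepted

print(accept_sequential([4, 4, 5, 6, 9]))  # [4, 5, 6]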
|
||||
|
|
@ -529,6 +553,7 @@ llama_ubatch llama_batch_allocr::split_equal(uint32_t n_ubatch) {
|
|||
idxs_per_seq[s].push_back(idx);
|
||||
|
||||
used[idx] = true;
|
||||
++n_used;
|
||||
|
||||
++cur_idx[s];
|
||||
}
|
||||
|
|
@ -570,6 +595,7 @@ llama_ubatch llama_batch_allocr::split_seq(uint32_t n_ubatch) {
|
|||
idxs.push_back(cur_idx);
|
||||
|
||||
used[cur_idx] = true;
|
||||
++n_used;
|
||||
|
||||
if (idxs.size() >= n_ubatch) {
|
||||
break;
@ -54,6 +54,7 @@ public:
|
|||
|
||||
uint32_t get_n_tokens() const;
|
||||
uint32_t get_n_outputs() const;
|
||||
uint32_t get_n_used() const;
|
||||
|
||||
// the array of output indices in the order they were encountered during the ubatch splitting
|
||||
std::vector<int32_t> & get_out_ids();
|
||||
|
|
@ -69,7 +70,8 @@ public:
|
|||
llama_ubatch split_simple(uint32_t n_ubatch);
|
||||
|
||||
// make ubatches of equal-length sequence sets
|
||||
llama_ubatch split_equal(uint32_t n_ubatch);
|
||||
// if sequential == true, the tokens in the ubatch will have increasing sequential sequence ids
|
||||
llama_ubatch split_equal(uint32_t n_ubatch, bool sequential);
|
||||
|
||||
// sequence-set-wise split - each ubatch contains a single sequence-set
|
||||
llama_ubatch split_seq(uint32_t n_ubatch);
|
||||
|
|
@ -112,6 +114,9 @@ private:
|
|||
using pos_set_t = std::set<llama_pos>;
|
||||
using seq_cpl_t = std::vector<bool>;
|
||||
|
||||
// helper flag to quickly determine if there are any coupled sequences in the batch
|
||||
bool has_cpl;
|
||||
|
||||
std::vector<pos_set_t> seq_pos; // seq_pos[s]: the set of positions in sequence s
|
||||
std::vector<seq_cpl_t> seq_cpl; // seq_cpl[s0][s1]: if sequence s0 is coupled to sequence s1
|
||||
|
||||
|
|
@ -125,6 +130,8 @@ private:
|
|||
// batch indices of the output
|
||||
std::vector<int32_t> out_ids;
|
||||
|
||||
uint32_t n_used;
|
||||
|
||||
// used[i] indicates if token i has already been used in a previous ubatch
|
||||
std::vector<bool> used;
@ -64,6 +64,7 @@ static const std::map<std::string, llm_chat_template> LLM_CHAT_TEMPLATES = {
|
|||
{ "bailing", LLM_CHAT_TEMPLATE_BAILING },
|
||||
{ "llama4", LLM_CHAT_TEMPLATE_LLAMA4 },
|
||||
{ "smolvlm", LLM_CHAT_TEMPLATE_SMOLVLM },
|
||||
{ "hunyuan-moe", LLM_CHAT_TEMPLATE_HUNYUAN_MOE },
|
||||
};
|
||||
|
||||
llm_chat_template llm_chat_template_from_str(const std::string & name) {
|
||||
|
|
@ -185,6 +186,8 @@ llm_chat_template llm_chat_detect_template(const std::string & tmpl) {
|
|||
return LLM_CHAT_TEMPLATE_LLAMA4;
|
||||
} else if (tmpl_contains("<|endofuserprompt|>")) {
|
||||
return LLM_CHAT_TEMPLATE_DOTS1;
|
||||
} else if (tmpl_contains("<|startoftext|>") && tmpl_contains("<|extra_4|>")) {
|
||||
return LLM_CHAT_TEMPLATE_HUNYUAN_MOE;
|
||||
}
|
||||
return LLM_CHAT_TEMPLATE_UNKNOWN;
|
||||
}
|
||||
|
|
@ -665,6 +668,21 @@ int32_t llm_chat_apply_template(
|
|||
if (add_ass) {
|
||||
ss << "<|response|>";
|
||||
}
|
||||
} else if (tmpl == LLM_CHAT_TEMPLATE_HUNYUAN_MOE) {
|
||||
// tencent/Hunyuan-A13B-Instruct
|
||||
for (auto message : chat) {
|
||||
std::string role(message->role);
|
||||
if (role == "system") {
|
||||
ss << "<|startoftext|>" << message->content << "<|extra_4|>";
|
||||
} else if (role == "assistant") {
|
||||
ss << "<|startoftext|>" << message->content << "<|eos|>";
|
||||
} else {
|
||||
ss << "<|startoftext|>" << message->content << "<|extra_0|>";
|
||||
}
|
||||
}
|
||||
if (add_ass) {
|
||||
ss << "<|startoftext|>";
|
||||
}
|
||||
} else {
|
||||
// template not supported
|
||||
return -1;
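The Hunyuan-A13B formatting added above, restated as a small Python helper for readability (not part of the patch; the special-token strings are copied verbatim from the C++ branch):

def hunyuan_moe_prompt(messages, add_assistant_prefix=True):
    out = ""
    for m in messages:
        if m["role"] == "system":
            out += "<|startoftext|>" + m["content"] + "<|extra_4|>"
        elif m["role"] == "assistant":
            out += "<|startoftext|>" + m["content"] + "<|eos|>"
        else:  # user (and any other role)
            out += "<|startoftext|>" + m["content"] + "<|extra_0|>"
    if add_assistant_prefix:
        out += "<|startoftext|>"
    return out

print(hunyuan_moe_prompt([
    {"role": "system", "content": "You are a helpful assistant."},
    {"role": "user", "content": "Hi"},
]))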
@ -44,6 +44,7 @@ enum llm_chat_template {
|
|||
LLM_CHAT_TEMPLATE_LLAMA4,
|
||||
LLM_CHAT_TEMPLATE_SMOLVLM,
|
||||
LLM_CHAT_TEMPLATE_DOTS1,
|
||||
LLM_CHAT_TEMPLATE_HUNYUAN_MOE,
|
||||
LLM_CHAT_TEMPLATE_UNKNOWN,
|
||||
};
@ -1005,8 +1005,7 @@ llm_graph_input_mem_hybrid * llm_graph_context::build_inp_mem_hybrid() const {
|
|||
inp->self_k_idxs = mctx_cur->get_attn()->build_input_k_idxs(ctx0, ubatch);
|
||||
inp->self_v_idxs = mctx_cur->get_attn()->build_input_v_idxs(ctx0, ubatch);
|
||||
|
||||
inp->self_kq_mask = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
|
||||
//cb(inp->self_kq_mask, "KQ_mask", -1);
|
||||
inp->self_kq_mask = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD), 1, 1);
|
||||
ggml_set_input(inp->self_kq_mask);
|
||||
|
||||
inp->self_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask, GGML_TYPE_F16) : inp->self_kq_mask;
|
||||
|
|
@ -1143,8 +1142,7 @@ llm_graph_input_attn_no_cache * llm_graph_context::build_attn_inp_no_cache() con
|
|||
auto inp = std::make_unique<llm_graph_input_attn_no_cache>(hparams, cparams);
|
||||
|
||||
// note: there is no KV cache, so the number of KV values is equal to the number of tokens in the batch
|
||||
inp->kq_mask = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_tokens, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
|
||||
//cb(inp_kq_mask, "KQ_mask", -1);
|
||||
inp->kq_mask = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_tokens, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD), 1, 1);
|
||||
ggml_set_input(inp->kq_mask);
|
||||
|
||||
inp->kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->kq_mask, GGML_TYPE_F16) : inp->kq_mask;
|
||||
|
|
@ -1209,7 +1207,7 @@ llm_graph_input_attn_kv_unified * llm_graph_context::build_attn_inp_kv_unified()
|
|||
inp->self_k_idxs = mctx_cur->build_input_k_idxs(ctx0, ubatch);
|
||||
inp->self_v_idxs = mctx_cur->build_input_v_idxs(ctx0, ubatch);
|
||||
|
||||
inp->self_kq_mask = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
|
||||
inp->self_kq_mask = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD), 1, 1);
|
||||
ggml_set_input(inp->self_kq_mask);
|
||||
|
||||
inp->self_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask, GGML_TYPE_F16) : inp->self_kq_mask;
|
||||
|
|
@ -1343,7 +1341,7 @@ llm_graph_input_attn_cross * llm_graph_context::build_attn_inp_cross() const {
|
|||
|
||||
const int32_t n_enc = !cross->v_embd.empty() ? cross->n_enc : hparams.n_ctx_train;
|
||||
|
||||
inp->cross_kq_mask = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_enc, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
|
||||
inp->cross_kq_mask = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_enc, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD), 1, 1);
|
||||
ggml_set_input(inp->cross_kq_mask);
|
||||
|
||||
inp->cross_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->cross_kq_mask, GGML_TYPE_F16) : inp->cross_kq_mask;
|
||||
|
|
@ -1457,7 +1455,7 @@ llm_graph_input_attn_kv_unified_iswa * llm_graph_context::build_attn_inp_kv_unif
|
|||
inp->self_k_idxs = mctx_cur->get_base()->build_input_k_idxs(ctx0, ubatch);
|
||||
inp->self_v_idxs = mctx_cur->get_base()->build_input_v_idxs(ctx0, ubatch);
|
||||
|
||||
inp->self_kq_mask = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
|
||||
inp->self_kq_mask = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD), 1, 1);
|
||||
ggml_set_input(inp->self_kq_mask);
|
||||
|
||||
inp->self_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask, GGML_TYPE_F16) : inp->self_kq_mask;
|
||||
|
|
@ -1471,7 +1469,7 @@ llm_graph_input_attn_kv_unified_iswa * llm_graph_context::build_attn_inp_kv_unif
|
|||
inp->self_k_idxs_swa = mctx_cur->get_swa()->build_input_k_idxs(ctx0, ubatch);
|
||||
inp->self_v_idxs_swa = mctx_cur->get_swa()->build_input_v_idxs(ctx0, ubatch);
|
||||
|
||||
inp->self_kq_mask_swa = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
|
||||
inp->self_kq_mask_swa = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD), 1, 1);
|
||||
ggml_set_input(inp->self_kq_mask_swa);
|
||||
|
||||
inp->self_kq_mask_swa_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask_swa, GGML_TYPE_F16) : inp->self_kq_mask_swa;
@ -228,8 +228,8 @@ public:
|
|||
|
||||
ggml_tensor * get_kq_mask() const { return kq_mask_cnv; }
|
||||
|
||||
ggml_tensor * kq_mask = nullptr; // F32 [n_tokens, n_batch]
|
||||
ggml_tensor * kq_mask_cnv = nullptr; // [n_tokens, n_batch]
|
||||
ggml_tensor * kq_mask = nullptr; // F32 [n_tokens, n_batch, 1, 1]
|
||||
ggml_tensor * kq_mask_cnv = nullptr; // [n_tokens, n_batch, 1, 1]
|
||||
|
||||
const llama_hparams & hparams;
|
||||
const llama_cparams & cparams;
|
||||
|
|
@ -257,8 +257,8 @@ public:
|
|||
ggml_tensor * self_k_idxs = nullptr; // I64 [n_batch]
|
||||
ggml_tensor * self_v_idxs = nullptr; // I64 [n_batch]
|
||||
|
||||
ggml_tensor * self_kq_mask = nullptr; // F32 [n_kv, n_batch]
|
||||
ggml_tensor * self_kq_mask_cnv = nullptr; // [n_kv, n_batch]
|
||||
ggml_tensor * self_kq_mask = nullptr; // F32 [n_kv, n_batch, 1, 1]
|
||||
ggml_tensor * self_kq_mask_cnv = nullptr; // [n_kv, n_batch, 1, 1]
|
||||
|
||||
const llama_hparams & hparams;
|
||||
const llama_cparams & cparams;
|
||||
|
|
@ -293,10 +293,10 @@ public:
|
|||
ggml_tensor * self_k_idxs_swa = nullptr; // I64 [n_batch]
|
||||
ggml_tensor * self_v_idxs_swa = nullptr; // I64 [n_batch]
|
||||
|
||||
ggml_tensor * self_kq_mask = nullptr; // F32 [n_kv, n_batch]
|
||||
ggml_tensor * self_kq_mask_cnv = nullptr; // [n_kv, n_batch]
|
||||
ggml_tensor * self_kq_mask_swa = nullptr; // F32 [n_kv, n_batch]
|
||||
ggml_tensor * self_kq_mask_swa_cnv = nullptr; // [n_kv, n_batch]
|
||||
ggml_tensor * self_kq_mask = nullptr; // F32 [n_kv, n_batch, 1, 1]
|
||||
ggml_tensor * self_kq_mask_cnv = nullptr; // [n_kv, n_batch, 1, 1]
|
||||
ggml_tensor * self_kq_mask_swa = nullptr; // F32 [n_kv, n_batch, 1, 1]
|
||||
ggml_tensor * self_kq_mask_swa_cnv = nullptr; // [n_kv, n_batch, 1, 1]
|
||||
|
||||
const llama_hparams & hparams;
|
||||
const llama_cparams & cparams;
|
||||
|
|
@ -313,8 +313,8 @@ public:
|
|||
|
||||
ggml_tensor * get_kq_mask_cross() const { return cross_kq_mask_cnv; }
|
||||
|
||||
ggml_tensor * cross_kq_mask = nullptr; // F32 [n_outputs_enc, n_batch]
|
||||
ggml_tensor * cross_kq_mask_cnv = nullptr; // F32 [n_outputs_enc, n_batch]
|
||||
ggml_tensor * cross_kq_mask = nullptr; // F32 [n_outputs_enc, n_batch, 1, 1]
|
||||
ggml_tensor * cross_kq_mask_cnv = nullptr; // F32 [n_outputs_enc, n_batch, 1, 1]
|
||||
|
||||
const llama_cross * cross = nullptr;
|
||||
};
|
||||
|
|
@ -343,8 +343,8 @@ public:
|
|||
ggml_tensor * self_k_idxs = nullptr; // I64 [n_batch]
|
||||
ggml_tensor * self_v_idxs = nullptr; // I64 [n_batch]
|
||||
|
||||
ggml_tensor * self_kq_mask = nullptr; // F32 [n_kv, n_batch]
|
||||
ggml_tensor * self_kq_mask_cnv = nullptr; // [n_kv, n_batch]
|
||||
ggml_tensor * self_kq_mask = nullptr; // F32 [n_kv, n_batch, 1, 1]
|
||||
ggml_tensor * self_kq_mask_cnv = nullptr; // [n_kv, n_batch, 1, 1]
|
||||
|
||||
const llama_hparams & hparams;
|
||||
const llama_cparams & cparams;
@ -113,6 +113,11 @@ llama_memory_context_ptr llama_kv_cache_unified_iswa::init_batch(llama_batch_all
|
|||
ubatches.push_back(std::move(ubatch)); // NOLINT
|
||||
}
|
||||
|
||||
if (balloc.get_n_used() < balloc.get_n_tokens()) {
|
||||
// failed to find a suitable split
|
||||
break;
|
||||
}
|
||||
|
||||
auto sinfos_base = kv_base->prepare(ubatches);
|
||||
if (sinfos_base.empty()) {
|
||||
break;
|
||||
|
|
@ -135,7 +140,7 @@ llama_memory_context_ptr llama_kv_cache_unified_iswa::init_batch(llama_batch_all
|
|||
|
||||
std::vector<llama_ubatch> ubatches;
|
||||
while (true) {
|
||||
auto ubatch = balloc.split_equal(n_ubatch);
|
||||
auto ubatch = balloc.split_equal(n_ubatch, false);
|
||||
|
||||
if (ubatch.n_tokens == 0) {
|
||||
break;
|
||||
|
|
@ -144,6 +149,11 @@ llama_memory_context_ptr llama_kv_cache_unified_iswa::init_batch(llama_batch_all
|
|||
ubatches.push_back(std::move(ubatch)); // NOLINT
|
||||
}
|
||||
|
||||
if (balloc.get_n_used() < balloc.get_n_tokens()) {
|
||||
// failed to find a suitable split
|
||||
break;
|
||||
}
|
||||
|
||||
auto sinfos_base = kv_base->prepare(ubatches);
|
||||
if (sinfos_base.empty()) {
|
||||
break;
|
||||
@ -360,6 +360,11 @@ llama_memory_context_ptr llama_kv_cache_unified::init_batch(
|
|||
ubatches.push_back(std::move(ubatch)); // NOLINT
|
||||
}
|
||||
|
||||
if (balloc.get_n_used() < balloc.get_n_tokens()) {
|
||||
// failed to find a suitable split
|
||||
break;
|
||||
}
|
||||
|
||||
auto sinfos = prepare(ubatches);
|
||||
if (sinfos.empty()) {
|
||||
break;
|
||||
@ -70,7 +70,7 @@ llama_memory_context_ptr llama_memory_hybrid::init_batch(llama_batch_allocr & ba
|
|||
// if all tokens are output, split by sequence
|
||||
ubatch = balloc.split_seq(n_ubatch);
|
||||
} else {
|
||||
ubatch = balloc.split_equal(n_ubatch);
|
||||
ubatch = balloc.split_equal(n_ubatch, false);
|
||||
}
|
||||
|
||||
if (ubatch.n_tokens == 0) {
|
||||
|
|
@ -80,6 +80,11 @@ llama_memory_context_ptr llama_memory_hybrid::init_batch(llama_batch_allocr & ba
|
|||
ubatches.push_back(std::move(ubatch)); // NOLINT
|
||||
}
|
||||
|
||||
if (balloc.get_n_used() < balloc.get_n_tokens()) {
|
||||
// failed to find a suitable split
|
||||
break;
|
||||
}
|
||||
|
||||
// prepare the recurrent batches first
|
||||
if (!mem_recr->prepare(ubatches)) {
|
||||
// TODO: will the recurrent cache be in an undefined context at this point?
@ -374,10 +374,11 @@ llama_memory_context_ptr llama_memory_recurrent::init_batch(llama_batch_allocr &
|
|||
// if all tokens are output, split by sequence
|
||||
ubatch = balloc.split_seq(n_ubatch);
|
||||
} else {
|
||||
ubatch = balloc.split_equal(n_ubatch);
|
||||
ubatch = balloc.split_equal(n_ubatch, false);
|
||||
}
|
||||
|
||||
if (ubatch.n_tokens == 0) {
|
||||
if (balloc.get_n_used() < balloc.get_n_tokens()) {
|
||||
// failed to find a suitable split
|
||||
break;
|
||||
}
@ -102,6 +102,7 @@ const char * llm_type_name(llm_type type) {
|
|||
case LLM_TYPE_57B_A14B: return "57B.A14B";
|
||||
case LLM_TYPE_17B_16E: return "17Bx16E (Scout)";
|
||||
case LLM_TYPE_17B_128E: return "17Bx128E (Maverick)";
|
||||
case LLM_TYPE_A13B: return "A13B";
|
||||
case LLM_TYPE_30B_A3B: return "30B.A3B";
|
||||
case LLM_TYPE_235B_A22B: return "235B.A22B";
|
||||
case LLM_TYPE_E2B: return "E2B";
|
||||
|
|
@ -1574,6 +1575,16 @@ void llama_model::load_hparams(llama_model_loader & ml) {
|
|||
type = LLM_TYPE_3B; break;
|
||||
case 44:
|
||||
type = LLM_TYPE_7B; break;
|
||||
}
|
||||
} break;
|
||||
case LLM_ARCH_HUNYUAN_MOE:
|
||||
{
|
||||
ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
|
||||
ml.get_key(LLM_KV_EXPERT_FEED_FORWARD_LENGTH, hparams.n_ff_exp);
|
||||
ml.get_key(LLM_KV_EXPERT_SHARED_FEED_FORWARD_LENGTH, hparams.n_ff_shexp);
|
||||
|
||||
switch (hparams.n_layer) {
|
||||
case 32: type = LLM_TYPE_A13B; break;
|
||||
default: type = LLM_TYPE_UNKNOWN;
|
||||
}
|
||||
} break;
|
||||
|
|
@ -4578,6 +4589,42 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
|
|||
layer.ffn_up_b = create_tensor(tn(LLM_TENSOR_FFN_UP, "bias", i), {ffn_intermediate_size}, llama_model_loader::TENSOR_NOT_REQUIRED);
|
||||
}
|
||||
} break;
|
||||
case LLM_ARCH_HUNYUAN_MOE:
|
||||
{
|
||||
tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
|
||||
|
||||
// output
|
||||
output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0);
|
||||
output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, TENSOR_NOT_REQUIRED);
|
||||
// if output is NULL, init from the input tok embed
|
||||
if (output == NULL) {
|
||||
output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED);
|
||||
}
|
||||
|
||||
for (int i = 0; i < n_layer; ++i) {
|
||||
auto & layer = layers[i];
|
||||
layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0);
|
||||
|
||||
layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd_head_k * n_head}, 0);
|
||||
layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_k_gqa}, 0);
|
||||
layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_v_gqa}, 0);
|
||||
layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_k * n_head, n_embd}, 0);
|
||||
|
||||
layer.attn_k_norm = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), {n_embd_head_k}, 0);
|
||||
layer.attn_q_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), {n_embd_head_k}, 0);
|
||||
|
||||
layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0);
|
||||
|
||||
layer.ffn_gate_inp = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), {n_embd, n_expert}, 0);
|
||||
layer.ffn_gate_exps = create_tensor(tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i), {n_embd, n_ff, n_expert}, 0);
|
||||
layer.ffn_down_exps = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), { n_ff, n_embd, n_expert}, 0);
|
||||
layer.ffn_up_exps = create_tensor(tn(LLM_TENSOR_FFN_UP_EXPS, "weight", i), {n_embd, n_ff, n_expert}, 0);
|
||||
|
||||
layer.ffn_gate_shexp = create_tensor(tn(LLM_TENSOR_FFN_GATE_SHEXP, "weight", i), {n_embd, hparams.n_ff_shexp}, 0);
|
||||
layer.ffn_up_shexp = create_tensor(tn(LLM_TENSOR_FFN_UP_SHEXP, "weight", i), {n_embd, hparams.n_ff_shexp}, 0);
|
||||
layer.ffn_down_shexp = create_tensor(tn(LLM_TENSOR_FFN_DOWN_SHEXP, "weight", i), {hparams.n_ff_shexp, n_embd}, 0);
|
||||
}
|
||||
} break;
|
||||
default:
|
||||
throw std::runtime_error("unknown architecture");
|
||||
}
|
||||
|
|
@ -5773,12 +5820,10 @@ struct llm_build_falcon : public llm_graph_context {
|
|||
cur = build_lora_mm(model.layers[il].wqkv, cur);
|
||||
cb(cur, "wqkv", il);
|
||||
|
||||
ggml_tensor * Qcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd, n_tokens, cur->nb[1], 0*sizeof(float)*(n_embd)));
|
||||
ggml_tensor * Kcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd)));
|
||||
ggml_tensor * Qcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head, n_tokens, n_embd_head*sizeof(float), cur->nb[1], 0*sizeof(float)*(n_embd));
|
||||
ggml_tensor * Kcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head_kv, n_tokens, n_embd_head*sizeof(float), cur->nb[1], 1*sizeof(float)*(n_embd));
|
||||
ggml_tensor * Vcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd + n_embd_gqa)));
|
||||
|
||||
Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens);
|
||||
Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens);
|
||||
Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens);
|
||||
|
||||
// using mode = 2 for neox mode
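The same refactor is applied to several builders below: instead of ggml_cont(ggml_view_2d(...)) followed by ggml_reshape_3d, Q and K are taken directly as strided 3-D views of the fused QKV tensor. A NumPy sketch with made-up sizes showing that the strided view reads exactly the elements the old copy-and-reshape produced:

import numpy as np
from numpy.lib.stride_tricks import as_strided

n_tokens, n_head, n_head_kv, head_dim = 3, 4, 2, 2              # toy sizes
n_embd, n_embd_gqa = n_head * head_dim, n_head_kv * head_dim
qkv = np.arange(n_tokens * (n_embd + 2 * n_embd_gqa), dtype=np.float32)
qkv = qkv.reshape(n_tokens, n_embd + 2 * n_embd_gqa)            # one fused QKV row per token

isz = qkv.itemsize
# new path: strided views, no copy (Q at offset 0, K right after the Q block)
q_view = as_strided(qkv, (n_tokens, n_head, head_dim), (qkv.strides[0], head_dim * isz, isz))
k_view = as_strided(qkv[:, n_embd:], (n_tokens, n_head_kv, head_dim), (qkv.strides[0], head_dim * isz, isz))
# old path: contiguous copy of the 2-D slice, then reshape
q_copy = qkv[:, :n_embd].copy().reshape(n_tokens, n_head, head_dim)
k_copy = qkv[:, n_embd:n_embd + n_embd_gqa].copy().reshape(n_tokens, n_head_kv, head_dim)
print(np.array_equal(q_view, q_copy), np.array_equal(k_view, k_copy))  # True True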
|
||||
|
|
@ -6055,12 +6100,10 @@ struct llm_build_dbrx : public llm_graph_context {
|
|||
cur = ggml_clamp(ctx0, cur, -hparams.f_clamp_kqv, hparams.f_clamp_kqv);
|
||||
cb(cur, "wqkv_clamped", il);
|
||||
|
||||
Qcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd, n_tokens, cur->nb[1], 0*sizeof(float)*(n_embd)));
|
||||
Kcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd)));
|
||||
Qcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head, n_tokens, n_embd_head*sizeof(float), cur->nb[1], 0*sizeof(float)*(n_embd));
|
||||
Kcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head_kv, n_tokens, n_embd_head*sizeof(float), cur->nb[1], 1*sizeof(float)*(n_embd));
|
||||
Vcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd + n_embd_gqa)));
|
||||
|
||||
Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens);
|
||||
Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens);
|
||||
Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens);
|
||||
|
||||
Qcur = ggml_rope_ext(
|
||||
|
|
@ -6571,12 +6614,10 @@ struct llm_build_neo_bert : public llm_graph_context {
|
|||
cur = build_lora_mm(model.layers[il].wqkv, cur);
|
||||
cb(cur, "wqkv", il);
|
||||
|
||||
Qcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd, n_tokens, cur->nb[1], 0*sizeof(float)*(n_embd)));
|
||||
Kcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd)));
|
||||
Qcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head, n_tokens, n_embd_head*sizeof(float), cur->nb[1], 0*sizeof(float)*(n_embd));
|
||||
Kcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head_kv, n_tokens, n_embd_head*sizeof(float), cur->nb[1], 1*sizeof(float)*(n_embd));
|
||||
Vcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd + n_embd_gqa)));
|
||||
|
||||
Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens);
|
||||
Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens);
|
||||
Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens);
|
||||
|
||||
// RoPE
|
||||
|
|
@ -6806,8 +6847,8 @@ struct llm_build_mpt : public llm_graph_context {
|
|||
cb(cur, "wqkv_clamped", il);
|
||||
}
|
||||
|
||||
ggml_tensor * Qcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd, n_tokens, cur->nb[1], 0*sizeof(float)*(n_embd)));
|
||||
ggml_tensor * Kcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd)));
|
||||
ggml_tensor * Qcur = ggml_view_2d(ctx0, cur, n_embd, n_tokens, cur->nb[1], 0*sizeof(float)*(n_embd));
|
||||
ggml_tensor * Kcur = ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd));
|
||||
ggml_tensor * Vcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd + n_embd_gqa)));
|
||||
|
||||
cb(Qcur, "Qcur", il);
|
||||
|
|
@ -6827,6 +6868,12 @@ struct llm_build_mpt : public llm_graph_context {
|
|||
model.layers[il].attn_k_norm_b,
|
||||
LLM_NORM, il);
|
||||
cb(Kcur, "Kcur", il);
|
||||
} else {
|
||||
Qcur = ggml_cont(ctx0, Qcur);
|
||||
cb(Qcur, "Qcur", il);
|
||||
|
||||
Kcur = ggml_cont(ctx0, Kcur);
|
||||
cb(Kcur, "Kcur", il);
|
||||
}
|
||||
|
||||
Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens);
|
||||
|
|
@ -7081,12 +7128,10 @@ struct llm_build_qwen : public llm_graph_context {
|
|||
cur = ggml_add(ctx0, cur, model.layers[il].bqkv);
|
||||
cb(cur, "bqkv", il);
|
||||
|
||||
ggml_tensor * Qcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd, n_tokens, cur->nb[1], 0*sizeof(float)*(n_embd)));
|
||||
ggml_tensor * Kcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd)));
|
||||
ggml_tensor * Qcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head, n_tokens, n_embd_head*sizeof(float), cur->nb[1], 0*sizeof(float)*(n_embd));
|
||||
ggml_tensor * Kcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head_kv, n_tokens, n_embd_head*sizeof(float), cur->nb[1], 1*sizeof(float)*(n_embd));
|
||||
ggml_tensor * Vcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd, n_tokens, cur->nb[1], 2*sizeof(float)*(n_embd)));
|
||||
|
||||
Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens);
|
||||
Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens);
|
||||
Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens);
|
||||
|
||||
// using mode = 2 for neox mode
|
||||
|
|
@ -7851,21 +7896,21 @@ struct llm_build_phi2 : public llm_graph_context {
|
|||
cur = ggml_add(ctx0, cur, model.layers[il].bqkv);
|
||||
cb(cur, "bqkv", il);
|
||||
|
||||
Qcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd, n_tokens, cur->nb[1], 0*sizeof(float)*(n_embd)));
|
||||
Kcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd)));
|
||||
Qcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head, n_tokens, n_embd_head*sizeof(float), cur->nb[1], 0*sizeof(float)*(n_embd));
|
||||
Kcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head_kv, n_tokens, n_embd_head*sizeof(float), cur->nb[1], 1*sizeof(float)*(n_embd));
|
||||
Vcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd + n_embd_gqa)));
|
||||
} else {
|
||||
Qcur = ggml_add(ctx0, build_lora_mm(model.layers[il].wq, attn_norm_output), model.layers[il].bq);
|
||||
Kcur = ggml_add(ctx0, build_lora_mm(model.layers[il].wk, attn_norm_output), model.layers[il].bk);
|
||||
Vcur = ggml_add(ctx0, build_lora_mm(model.layers[il].wv, attn_norm_output), model.layers[il].bv);
|
||||
Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens);
|
||||
Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens);
|
||||
}
|
||||
|
||||
cb(Qcur, "Qcur", il);
|
||||
cb(Kcur, "Kcur", il);
|
||||
cb(Vcur, "Vcur", il);
|
||||
|
||||
Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens);
|
||||
Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens);
|
||||
Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens);
|
||||
|
||||
Qcur = ggml_rope_ext(
|
||||
|
|
@ -7989,21 +8034,21 @@ struct llm_build_phi3 : public llm_graph_context {
|
|||
cur = build_lora_mm(model.layers[il].wqkv, attn_norm_output);
|
||||
cb(cur, "wqkv", il);
|
||||
|
||||
Qcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd, n_tokens, cur->nb[1], 0 * sizeof(float) * (n_embd)));
|
||||
Kcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1 * sizeof(float) * (n_embd)));
|
||||
Qcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head, n_tokens, n_embd_head * sizeof(float), cur->nb[1], 0 * sizeof(float) * (n_embd));
|
||||
Kcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head_kv, n_tokens, n_embd_head * sizeof(float), cur->nb[1], 1 * sizeof(float) * (n_embd));
|
||||
Vcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1 * sizeof(float) * (n_embd + n_embd_gqa)));
|
||||
} else {
|
||||
Qcur = ggml_add(ctx0, build_lora_mm(model.layers[il].wq, attn_norm_output), model.layers[il].bq);
|
||||
Kcur = ggml_add(ctx0, build_lora_mm(model.layers[il].wk, attn_norm_output), model.layers[il].bk);
|
||||
Vcur = ggml_add(ctx0, build_lora_mm(model.layers[il].wv, attn_norm_output), model.layers[il].bv);
|
||||
Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens);
|
||||
Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens);
|
||||
}
|
||||
|
||||
cb(Qcur, "Qcur", il);
|
||||
cb(Kcur, "Kcur", il);
|
||||
cb(Vcur, "Vcur", il);
|
||||
|
||||
Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens);
|
||||
Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens);
|
||||
Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens);
|
||||
|
||||
Qcur = ggml_rope_ext(
|
||||
|
|
@ -8359,12 +8404,10 @@ struct llm_build_codeshell : public llm_graph_context {
|
|||
cur = ggml_add(ctx0, cur, model.layers[il].bqkv);
|
||||
cb(cur, "bqkv", il);
|
||||
|
||||
ggml_tensor * Qcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd, n_tokens, cur->nb[1], 0*sizeof(float)*(n_embd)));
|
||||
ggml_tensor * Kcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd)));
|
||||
ggml_tensor * Qcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head, n_tokens, n_embd_head*sizeof(float), cur->nb[1], 0*sizeof(float)*(n_embd));
|
||||
ggml_tensor * Kcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head_kv, n_tokens, n_embd_head*sizeof(float), cur->nb[1], 1*sizeof(float)*(n_embd));
|
||||
ggml_tensor * Vcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd + n_embd_gqa)));
|
||||
|
||||
Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens);
|
||||
Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens);
|
||||
Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens);
|
||||
|
||||
Qcur = ggml_rope_ext(
|
||||
|
|
@ -8780,8 +8823,6 @@ struct llm_build_minicpm3 : public llm_graph_context {
|
|||
ggml_row_size(kv_pe_compresseed->type, kv_lora_rank));
|
||||
cb(k_pe, "k_pe", il);
|
||||
|
||||
// TODO: the CUDA backend used to not support non-cont. (RMS) norm, investigate removing ggml_cont
|
||||
kv_compressed = ggml_cont(ctx0, kv_compressed);
|
||||
kv_compressed = build_norm(kv_compressed,
|
||||
model.layers[il].attn_kv_a_norm, NULL,
|
||||
LLM_NORM_RMS, il);
|
||||
|
|
@ -8808,12 +8849,6 @@ struct llm_build_minicpm3 : public llm_graph_context {
|
|||
v_states = ggml_cont(ctx0, v_states);
|
||||
cb(v_states, "v_states", il);
|
||||
|
||||
v_states = ggml_view_2d(ctx0, v_states, hparams.n_embd_head_v * n_head, n_tokens,
|
||||
ggml_row_size(kv->type, hparams.n_embd_head_v * n_head),
|
||||
0);
|
||||
cb(v_states, "v_states", il);
|
||||
|
||||
q_pe = ggml_cont(ctx0, q_pe); // TODO: the CUDA backend used to not support non-cont. RoPE, investigate removing this
|
||||
q_pe = ggml_rope_ext(
|
||||
ctx0, q_pe, inp_pos, rope_factors,
|
||||
n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
|
||||
|
|
@ -8822,7 +8857,6 @@ struct llm_build_minicpm3 : public llm_graph_context {
|
|||
cb(q_pe, "q_pe", il);
|
||||
|
||||
// shared RoPE key
|
||||
k_pe = ggml_cont(ctx0, k_pe); // TODO: the CUDA backend used to not support non-cont. RoPE, investigate removing this
|
||||
k_pe = ggml_rope_ext(
|
||||
ctx0, k_pe, inp_pos, rope_factors,
|
||||
n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
|
||||
|
|
@ -10887,10 +10921,10 @@ struct llm_build_openelm : public llm_graph_context {
|
|||
|
||||
cur = ggml_reshape_3d(ctx0, cur, n_embd_head_k, n_head_qkv, n_tokens);
|
||||
|
||||
ggml_tensor * Qcur = ggml_cont(ctx0, ggml_view_3d(ctx0, cur, n_embd_head, n_head, n_tokens, cur->nb[1], cur->nb[2], 0));
|
||||
ggml_tensor * Qcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head, n_tokens, cur->nb[1], cur->nb[2], 0);
|
||||
cb(Qcur, "Qcur", il);
|
||||
|
||||
ggml_tensor * Kcur = ggml_cont(ctx0, ggml_view_3d(ctx0, cur, n_embd_head, n_head_kv, n_tokens, cur->nb[1], cur->nb[2], cur->nb[1]*n_head));
|
||||
ggml_tensor * Kcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head_kv, n_tokens, cur->nb[1], cur->nb[2], cur->nb[1]*n_head);
|
||||
cb(Kcur, "Kcur", il);
|
||||
|
||||
ggml_tensor * Vcur = ggml_cont(ctx0, ggml_view_3d(ctx0, cur, n_embd_head, n_head_kv, n_tokens, cur->nb[1], cur->nb[2], cur->nb[1]*(n_head+n_head_kv)));
|
||||
|
|
@ -11012,12 +11046,10 @@ struct llm_build_gptneox : public llm_graph_context {
|
|||
cur = ggml_add(ctx0, cur, model.layers[il].bqkv);
|
||||
cb(cur, "bqkv", il);
|
||||
|
||||
ggml_tensor * Qcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd, n_tokens, cur->nb[1], 0*sizeof(float)*(n_embd)));
|
||||
ggml_tensor * Kcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd)));
|
||||
ggml_tensor * Qcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head, n_tokens, n_embd_head*sizeof(float), cur->nb[1], 0*sizeof(float)*(n_embd));
|
||||
ggml_tensor * Kcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head_kv, n_tokens, n_embd_head*sizeof(float), cur->nb[1], 1*sizeof(float)*(n_embd));
|
||||
ggml_tensor * Vcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd + n_embd_gqa)));
|
||||
|
||||
Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens);
|
||||
Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens);
|
||||
Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens);
|
||||
|
||||
Qcur = ggml_rope_ext(
|
||||
|
|
@ -12262,6 +12294,8 @@ struct llm_build_chatglm : public llm_graph_context {
|
|||
if (model.layers[il].bv) {
|
||||
Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv);
|
||||
}
|
||||
Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens);
|
||||
Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens);
|
||||
} else {
|
||||
cur = build_lora_mm(model.layers[il].wqkv, cur);
|
||||
cb(cur, "wqkv", il);
|
||||
|
|
@ -12269,13 +12303,11 @@ struct llm_build_chatglm : public llm_graph_context {
|
|||
cur = ggml_add(ctx0, cur, model.layers[il].bqkv);
|
||||
cb(cur, "bqkv", il);
|
||||
}
|
||||
Qcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd, n_tokens, cur->nb[1], 0*sizeof(float)*(n_embd)));
|
||||
Kcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd)));
|
||||
Qcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head, n_tokens, n_embd_head*sizeof(float), cur->nb[1], 0*sizeof(float)*(n_embd));
|
||||
Kcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head_kv, n_tokens, n_embd_head*sizeof(float), cur->nb[1], 1*sizeof(float)*(n_embd));
|
||||
Vcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd + n_embd_gqa)));
|
||||
}
|
||||
|
||||
Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens);
|
||||
Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens);
|
||||
Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens);
|
||||
|
||||
//printf("freq_base: %f freq_scale: %f ext_factor: %f attn_factor: %f\n", freq_base, freq_scale, ext_factor, attn_factor);
|
||||
|
|
@ -12396,6 +12428,8 @@ struct llm_build_glm4 : public llm_graph_context {
|
|||
if (model.layers[il].bv) {
|
||||
Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv);
|
||||
}
|
||||
Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens);
|
||||
Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens);
|
||||
} else {
|
||||
cur = build_lora_mm(model.layers[il].wqkv, cur);
|
||||
cb(cur, "wqkv", il);
|
||||
|
|
@ -12403,13 +12437,11 @@ struct llm_build_glm4 : public llm_graph_context {
|
|||
cur = ggml_add(ctx0, cur, model.layers[il].bqkv);
|
||||
cb(cur, "bqkv", il);
|
||||
}
|
||||
Qcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd, n_tokens, cur->nb[1], 0*sizeof(float)*(n_embd)));
|
||||
Kcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd)));
|
||||
Qcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head, n_tokens, n_embd_head*sizeof(float), cur->nb[1], 0*sizeof(float)*(n_embd));
|
||||
Kcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head_kv, n_tokens, n_embd_head*sizeof(float), cur->nb[1], 1*sizeof(float)*(n_embd));
|
||||
Vcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd + n_embd_gqa)));
|
||||
}
|
||||
|
||||
Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens);
|
||||
Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens);
|
||||
Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens);
|
||||
|
||||
Qcur = ggml_rope_ext(
|
||||
|
|
@ -15024,6 +15056,168 @@ struct llm_build_arcee : public llm_graph_context {
|
|||
}
|
||||
};
|
||||
|
||||
struct llm_build_hunyuan_moe : public llm_graph_context {
|
||||
llm_build_hunyuan_moe(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params) {
|
||||
const int64_t n_embd_head = hparams.n_embd_head_v;
|
||||
|
||||
GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
|
||||
GGML_ASSERT(n_embd_head == hparams.n_rot);
|
||||
|
||||
ggml_tensor * cur;
|
||||
ggml_tensor * inpL;
|
||||
|
||||
inpL = build_inp_embd(model.tok_embd);
|
||||
|
||||
// inp_pos - contains the positions
|
||||
ggml_tensor * inp_pos = build_inp_pos();
|
||||
|
||||
auto * inp_attn = build_attn_inp_kv_unified();
|
||||
|
||||
const float kq_scale = 1.0f / sqrtf(float(n_embd_head));
|
||||
|
||||
ggml_tensor * inp_out_ids = build_inp_out_ids();
|
||||
|
||||
for (int il = 0; il < n_layer; ++il) {
|
||||
ggml_tensor * inpSA = inpL;
|
||||
|
||||
// norm
|
||||
cur = build_norm(inpL,
|
||||
model.layers[il].attn_norm, NULL,
|
||||
LLM_NORM_RMS, il);
|
||||
cb(cur, "attn_norm", il);
|
||||
|
||||
// self-attention
|
||||
{
|
||||
// rope freq factors for llama3; may return nullptr for llama2 and other models
|
||||
ggml_tensor * rope_factors = model.get_rope_factors(cparams, il);
|
||||
|
||||
// compute Q and K and RoPE them
|
||||
ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur);
|
||||
cb(Qcur, "Qcur", il);
|
||||
if (model.layers[il].bq) {
|
||||
Qcur = ggml_add(ctx0, Qcur, model.layers[il].bq);
|
||||
cb(Qcur, "Qcur", il);
|
||||
}
|
||||
|
||||
ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur);
|
||||
cb(Kcur, "Kcur", il);
|
||||
if (model.layers[il].bk) {
|
||||
Kcur = ggml_add(ctx0, Kcur, model.layers[il].bk);
|
||||
cb(Kcur, "Kcur", il);
|
||||
}
|
||||
|
||||
ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur);
|
||||
cb(Vcur, "Vcur", il);
|
||||
if (model.layers[il].bv) {
|
||||
Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv);
|
||||
cb(Vcur, "Vcur", il);
|
||||
}
|
||||
|
||||
Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens);
|
||||
Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens);
|
||||
Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens);
|
||||
|
||||
Qcur = ggml_rope_ext(
|
||||
ctx0, Qcur, inp_pos, rope_factors,
|
||||
n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
|
||||
ext_factor, attn_factor, beta_fast, beta_slow
|
||||
);
|
||||
|
||||
cb(Qcur, "Qcur", il);
|
||||
cb(Kcur, "Kcur", il);
|
||||
cb(Vcur, "Vcur", il);
|
||||
|
||||
Kcur = ggml_rope_ext(
|
||||
ctx0, Kcur, inp_pos, rope_factors,
|
||||
n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
|
||||
ext_factor, attn_factor, beta_fast, beta_slow
|
||||
);
|
||||
|
||||
Kcur = build_norm(Kcur,
|
||||
model.layers[il].attn_k_norm, nullptr,
|
||||
LLM_NORM_RMS, il);
|
||||
cb(Kcur, "Kcur_norm", il);
|
||||
|
||||
Qcur = build_norm(Qcur,
|
||||
model.layers[il].attn_q_norm, nullptr,
|
||||
LLM_NORM_RMS, il);
|
||||
cb(Qcur, "Qcur_norm", il);
|
||||
|
||||
cur = build_attn(inp_attn, gf,
|
||||
model.layers[il].wo, model.layers[il].bo,
|
||||
Qcur, Kcur, Vcur, nullptr, nullptr, kq_scale, il);
|
||||
cb(cur, "attn_out", il);
|
||||
}
|
||||
|
||||
if (il == n_layer - 1 && inp_out_ids) {
|
||||
cur = ggml_get_rows(ctx0, cur, inp_out_ids);
|
||||
inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
|
||||
}
|
||||
|
||||
ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA);
|
||||
cb(ffn_inp, "ffn_inp", il);
|
||||
|
||||
cur = build_norm(ffn_inp,
|
||||
model.layers[il].ffn_norm, NULL,
|
||||
LLM_NORM_RMS, il);
|
||||
cb(cur, "ffn_norm", il);
|
||||
|
||||
// feed-forward network (non-MoE)
|
||||
ggml_tensor * cur_mlp = build_ffn(cur,
|
||||
model.layers[il].ffn_up_shexp, NULL, NULL,
|
||||
model.layers[il].ffn_gate_shexp, NULL, NULL,
|
||||
model.layers[il].ffn_down_shexp, NULL, NULL,
|
||||
NULL,
|
||||
LLM_FFN_SILU, LLM_FFN_PAR, il);
|
||||
cb(cur_mlp, "ffn_mlp", il);
|
||||
|
||||
// MoE branch
|
||||
ggml_tensor * cur_moe = build_moe_ffn(cur,
|
||||
model.layers[il].ffn_gate_inp,
|
||||
model.layers[il].ffn_up_exps,
|
||||
model.layers[il].ffn_gate_exps,
|
||||
model.layers[il].ffn_down_exps,
|
||||
nullptr,
|
||||
n_expert, n_expert_used,
|
||||
LLM_FFN_SILU,
|
||||
true, // norm_topk_prob
|
||||
false,
|
||||
0.0,
|
||||
LLAMA_EXPERT_GATING_FUNC_TYPE_SOFTMAX,
|
||||
il);
|
||||
cb(cur_moe, "ffn_moe_out", il);
|
||||
|
||||
ggml_tensor * ffn_out = ggml_add(ctx0, cur_moe, cur_mlp);
|
||||
cb(ffn_out, "ffn_out", il);
|
||||
|
||||
cur = ggml_add(ctx0, ffn_out, ffn_inp);
|
||||
|
||||
cur = build_cvec(cur, il);
|
||||
cb(cur, "l_out", il);
|
||||
|
||||
// input for next layer
|
||||
inpL = cur;
|
||||
}
|
||||
|
||||
cur = inpL;
|
||||
|
||||
cur = build_norm(cur,
|
||||
model.output_norm, NULL,
|
||||
LLM_NORM_RMS, -1);
|
||||
|
||||
cb(cur, "result_norm", -1);
|
||||
res->t_embd = cur;
|
||||
|
||||
// lm_head
|
||||
cur = build_lora_mm(model.output, cur);
|
||||
|
||||
cb(cur, "result_output", -1);
|
||||
res->t_logits = cur;
|
||||
|
||||
ggml_build_forward_expand(gf, cur);
|
||||
}
|
||||
};
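Per layer, after the FFN norm the Hunyuan block evaluates two parallel branches on the same input and sums them: a dense shared-expert SwiGLU MLP (the *_shexp tensors) and the routed expert mixture. A schematic NumPy version of that combination; it assumes build_moe_ffn's softmax-over-all-experts gating with the top-k weights renormalized (norm_topk_prob), and all weights and sizes here are placeholders:

import numpy as np

def swiglu_mlp(x, w_gate, w_up, w_down):
    gate = x @ w_gate
    return ((gate / (1.0 + np.exp(-gate))) * (x @ w_up)) @ w_down

def hunyuan_ffn(x, shared, experts, w_router, n_used):
    out = swiglu_mlp(x, *shared)                    # dense shared-expert branch
    logits = x @ w_router                           # routed branch
    probs = np.exp(logits - logits.max())
    probs /= probs.sum()
    top = np.argsort(probs)[-n_used:]
    weights = probs[top] / probs[top].sum()         # renormalize the selected experts
    for w, e in zip(weights, top):
        out = out + w * swiglu_mlp(x, *experts[e])
    return out                                      # the caller adds this to the residual

d, d_ff, n_expert, n_used = 8, 16, 4, 2
rng = np.random.default_rng(0)
mk = lambda: (rng.standard_normal((d, d_ff)), rng.standard_normal((d, d_ff)), rng.standard_normal((d_ff, d)))
x = rng.standard_normal(d)
print(hunyuan_ffn(x, mk(), [mk() for _ in range(n_expert)], rng.standard_normal((d, n_expert)), n_used).shape)  # (8,)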
|
||||
|
||||
llama_memory_i * llama_model::create_memory(const llama_memory_params & params, llama_cparams & cparams) const {
|
||||
llama_memory_i * res;
|
||||
|
||||
|
|
@ -15407,6 +15601,10 @@ llm_graph_result_ptr llama_model::build_graph(
|
|||
{
|
||||
llm = std::make_unique<llm_build_ernie4_5>(*this, params, gf);
|
||||
} break;
|
||||
case LLM_ARCH_HUNYUAN_MOE:
|
||||
{
|
||||
llm = std::make_unique<llm_build_hunyuan_moe>(*this, params, gf);
|
||||
} break;
|
||||
case LLM_ARCH_FALCON_H1:
|
||||
{
|
||||
llm = std::make_unique<llm_build_falcon_h1>(*this, params, gf);
|
||||
|
|
@ -15600,6 +15798,7 @@ llama_rope_type llama_model_rope_type(const llama_model * model) {
|
|||
case LLM_ARCH_EXAONE:
|
||||
case LLM_ARCH_MINICPM3:
|
||||
case LLM_ARCH_DOTS1:
|
||||
case LLM_ARCH_HUNYUAN_MOE:
|
||||
return LLAMA_ROPE_TYPE_NEOX;
|
||||
|
||||
case LLM_ARCH_QWEN2VL:
@ -94,6 +94,7 @@ enum llm_type {
|
|||
LLM_TYPE_57B_A14B,
|
||||
LLM_TYPE_17B_16E, // llama4 Scout
|
||||
LLM_TYPE_17B_128E, // llama4 Maverick
|
||||
LLM_TYPE_A13B,
|
||||
LLM_TYPE_30B_A3B,
|
||||
LLM_TYPE_235B_A22B,
|
||||
LLM_TYPE_E2B,
@ -351,6 +351,7 @@ struct llm_tokenizer_bpe : llm_tokenizer {
|
|||
break;
|
||||
case LLAMA_VOCAB_PRE_TYPE_STABLELM2:
|
||||
case LLAMA_VOCAB_PRE_TYPE_QWEN2:
|
||||
case LLAMA_VOCAB_PRE_TYPE_HUNYUAN:
|
||||
regex_exprs = {
|
||||
// original regex from tokenizer.json
|
||||
// "(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+"
|
||||
|
|
@ -1657,6 +1658,10 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) {
|
|||
tokenizer_pre == "seed-coder") {
|
||||
pre_type = LLAMA_VOCAB_PRE_TYPE_SEED_CODER;
|
||||
clean_spaces = false;
|
||||
} else if (
|
||||
tokenizer_pre == "hunyuan") {
|
||||
pre_type = LLAMA_VOCAB_PRE_TYPE_HUNYUAN;
|
||||
clean_spaces = false;
|
||||
} else {
|
||||
throw std::runtime_error(format("unknown pre-tokenizer type: '%s'", tokenizer_pre.c_str()));
|
||||
}
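The new pre-tokenizer id reuses the Qwen2/StableLM2 split regex quoted in the hunk above. Its behaviour can be previewed with the third-party regex package, which supports the \p{...} character classes that Python's built-in re lacks (the example text and output are just an illustration):

import regex  # pip install regex

PRE_TOKENIZE = regex.compile(
    r"(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}|"
    r" ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+"
)

print(PRE_TOKENIZE.findall("Hello world, it's 2024!"))
# ['Hello', ' world', ',', ' it', "'s", ' ', '2', '0', '2', '4', '!']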
File diff suppressed because it is too large
@ -1405,8 +1405,7 @@ struct clip_graph {
|
|||
ggml_tensor * x = embeddings;
|
||||
embeddings = ggml_mul_mat(ctx0, model.mm_model_mlp_2_w, embeddings);
|
||||
x = ggml_mul_mat(ctx0, model.mm_model_mlp_1_w,x);
|
||||
embeddings = ggml_silu_inplace(ctx0, embeddings);
|
||||
embeddings = ggml_mul(ctx0, embeddings,x);
|
||||
embeddings = ggml_swiglu_split(ctx0, embeddings, x);
|
||||
embeddings = ggml_mul_mat(ctx0, model.mm_model_mlp_3_w, embeddings);
|
||||
}
|
||||
// arrangement of BOI/EOI token embeddings
|
||||
|
|
@ -1502,15 +1501,8 @@ struct clip_graph {
|
|||
cur = ggml_mul_mat(ctx0, model.mm_1_w, cur);
|
||||
|
||||
// swiglu
|
||||
{
|
||||
int64_t split_point = cur->ne[0] / 2;
|
||||
ggml_tensor * x0 = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, split_point, cur->ne[1], cur->nb[1], 0));
|
||||
ggml_tensor * x1 = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, split_point, cur->ne[1], cur->nb[1], split_point * ggml_element_size(cur)));
|
||||
|
||||
// see SwiGLU in ultravox_model.py, the second half passed through is silu, not the first half
|
||||
x1 = ggml_silu(ctx0, x1);
|
||||
cur = ggml_mul(ctx0, x0, x1);
|
||||
}
|
||||
// see SwiGLU in ultravox_model.py, the second half passed through is silu, not the first half
|
||||
cur = ggml_swiglu_swapped(ctx0, cur);
|
||||
|
||||
// mid-norm
|
||||
cur = ggml_rms_norm(ctx0, cur, 1e-6);
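As the comment notes, ultravox gates with the second half of the projection, so the fused op has to be the swapped variant. A small NumPy sketch of the two conventions (shapes and data are arbitrary; this mirrors the GLU split semantics, not the backend kernel):

import numpy as np

def silu(x):
    return x / (1.0 + np.exp(-x))

def swiglu(x):           # ggml_swiglu: activate the first half, gate with the second
    a, b = np.split(x, 2, axis=-1)
    return silu(a) * b

def swiglu_swapped(x):   # ggml_swiglu_swapped: activate the second half instead
    a, b = np.split(x, 2, axis=-1)
    return a * silu(b)

x = np.random.randn(4, 6).astype(np.float32)
print(swiglu(x).shape, swiglu_swapped(x).shape)  # (4, 3) (4, 3)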
|
||||
|
|
@ -1769,35 +1761,42 @@ private:
|
|||
cur = tmp;
|
||||
}
|
||||
|
||||
// we only support parallel ffn for now
|
||||
switch (type_op) {
|
||||
case FFN_SILU:
|
||||
{
|
||||
if (gate) {
|
||||
cur = ggml_swiglu_split(ctx0, cur, tmp);
|
||||
cb(cur, "ffn_swiglu", il);
|
||||
} else {
|
||||
cur = ggml_silu(ctx0, cur);
|
||||
cb(cur, "ffn_silu", il);
|
||||
} break;
|
||||
case FFN_GELU:
|
||||
{
|
||||
if (gate) {
|
||||
cur = ggml_geglu_split(ctx0, cur, tmp);
|
||||
cb(cur, "ffn_geglu", il);
|
||||
} else {
|
||||
cur = ggml_gelu(ctx0, cur);
|
||||
cb(cur, "ffn_gelu", il);
|
||||
} break;
|
||||
case FFN_GELU_ERF:
|
||||
{
|
||||
if (gate) {
|
||||
cur = ggml_geglu_erf_split(ctx0, cur, tmp);
|
||||
cb(cur, "ffn_geglu_erf", il);
|
||||
} else {
|
||||
cur = ggml_gelu_erf(ctx0, cur);
|
||||
cb(cur, "ggml_gelu_erf", il);
|
||||
cb(cur, "ffn_gelu_erf", il);
|
||||
} break;
|
||||
case FFN_GELU_QUICK:
|
||||
{
|
||||
if (gate) {
|
||||
cur = ggml_geglu_quick_split(ctx0, cur, tmp);
|
||||
cb(cur, "ffn_geglu_quick", il);
|
||||
} else {
|
||||
cur = ggml_gelu_quick(ctx0, cur);
|
||||
cb(cur, "ffn_relu", il);
|
||||
cb(cur, "ffn_gelu_quick", il);
|
||||
} break;
|
||||
}
|
||||
|
||||
// we only support parallel ffn for now
|
||||
if (gate) {
|
||||
cur = ggml_mul(ctx0, cur, tmp);
|
||||
cb(cur, "ffn_gate_par", il);
|
||||
}
|
||||
|
||||
if (down) {
|
||||
cur = ggml_mul_mat(ctx0, down, cur);
|
||||
}
@ -132,6 +132,28 @@ def test_chat_template():
|
|||
assert res.body["__verbose"]["prompt"] == "<s> <|start_header_id|>system<|end_header_id|>\n\nBook<|eot_id|><|start_header_id|>user<|end_header_id|>\n\nWhat is the best book<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
|
||||
|
||||
|
||||
@pytest.mark.parametrize("prefill,re_prefill", [
|
||||
("Whill", "Whill"),
|
||||
([{"type": "text", "text": "Wh"}, {"type": "text", "text": "ill"}], "Whill"),
|
||||
])
|
||||
def test_chat_template_assistant_prefill(prefill, re_prefill):
|
||||
global server
|
||||
server.chat_template = "llama3"
|
||||
server.debug = True # to get the "__verbose" object in the response
|
||||
server.start()
|
||||
res = server.make_request("POST", "/chat/completions", data={
|
||||
"max_tokens": 8,
|
||||
"messages": [
|
||||
{"role": "system", "content": "Book"},
|
||||
{"role": "user", "content": "What is the best book"},
|
||||
{"role": "assistant", "content": prefill},
|
||||
]
|
||||
})
|
||||
assert res.status_code == 200
|
||||
assert "__verbose" in res.body
|
||||
assert res.body["__verbose"]["prompt"] == f"<s> <|start_header_id|>system<|end_header_id|>\n\nBook<|eot_id|><|start_header_id|>user<|end_header_id|>\n\nWhat is the best book<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n{re_prefill}"
|
||||
|
||||
|
||||
def test_apply_chat_template():
|
||||
global server
|
||||
server.chat_template = "command-r"
|
||||
|
|
@ -228,6 +250,7 @@ def test_completion_with_grammar(jinja: bool, grammar: str, n_predicted: int, re
|
|||
[{"role": "system", "content": 123}],
|
||||
# [{"content": "hello"}], # TODO: should not be a valid case
|
||||
[{"role": "system", "content": "test"}, {}],
|
||||
[{"role": "user", "content": "test"}, {"role": "assistant", "content": "test"}, {"role": "assistant", "content": "test"}],
|
||||
])
|
||||
def test_invalid_chat_completion_req(messages):
|
||||
global server
@ -792,7 +792,13 @@ static json oaicompat_chat_params_parse(
|
|||
|
||||
/* Append assistant prefilled message */
|
||||
if (prefill_assistant_message) {
|
||||
chat_params.prompt += last_message.content;
|
||||
if (!last_message.content_parts.empty()) {
|
||||
for (auto & p : last_message.content_parts) {
|
||||
chat_params.prompt += p.text;
|
||||
}
|
||||
} else {
|
||||
chat_params.prompt += last_message.content;
|
||||
}
|
||||
}
|
||||
|
||||
llama_params["chat_format"] = static_cast<int>(chat_params.format);