Merge branch 'master' into compilade/test-model-random

This commit is contained in:
Francis Couture-Harpin 2025-07-08 16:41:45 -04:00
commit 18d2055124
31 changed files with 955 additions and 164 deletions

View File

@ -342,7 +342,7 @@ jobs:
cd build cd build
export GGML_VK_VISIBLE_DEVICES=0 export GGML_VK_VISIBLE_DEVICES=0
# This is using llvmpipe and runs slower than other backends # This is using llvmpipe and runs slower than other backends
ctest -L main --verbose --timeout 3600 ctest -L main --verbose --timeout 4200
ubuntu-22-cmake-hip: ubuntu-22-cmake-hip:
runs-on: ubuntu-22.04 runs-on: ubuntu-22.04

View File

@ -2734,6 +2734,13 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
params.public_path = value; params.public_path = value;
} }
).set_examples({LLAMA_EXAMPLE_SERVER}).set_env("LLAMA_ARG_STATIC_PATH")); ).set_examples({LLAMA_EXAMPLE_SERVER}).set_env("LLAMA_ARG_STATIC_PATH"));
add_opt(common_arg(
{"--api-prefix"}, "PREFIX",
string_format("prefix path the server serves from, without the trailing slash (default: %s)", params.api_prefix.c_str()),
[](common_params & params, const std::string & value) {
params.api_prefix = value;
}
).set_examples({LLAMA_EXAMPLE_SERVER}).set_env("LLAMA_ARG_API_PREFIX"));
add_opt(common_arg( add_opt(common_arg(
{"--no-webui"}, {"--no-webui"},
string_format("Disable the Web UI (default: %s)", params.webui ? "enabled" : "disabled"), string_format("Disable the Web UI (default: %s)", params.webui ? "enabled" : "disabled"),

View File

@ -370,6 +370,7 @@ struct common_params {
std::string hostname = "127.0.0.1"; std::string hostname = "127.0.0.1";
std::string public_path = ""; // NOLINT std::string public_path = ""; // NOLINT
std::string api_prefix = ""; // NOLINT
std::string chat_template = ""; // NOLINT std::string chat_template = ""; // NOLINT
bool use_jinja = false; // NOLINT bool use_jinja = false; // NOLINT
bool enable_chat_template = true; bool enable_chat_template = true;

View File

@ -815,6 +815,9 @@ class TextModel(ModelBase):
if chkhsh == "1431a23e583c97432bc230bff598d103ddb5a1f89960c8f1d1051aaa944d0b35": if chkhsh == "1431a23e583c97432bc230bff598d103ddb5a1f89960c8f1d1051aaa944d0b35":
# ref: https://huggingface.co/sapienzanlp/Minerva-7B-base-v1.0 # ref: https://huggingface.co/sapienzanlp/Minerva-7B-base-v1.0
res = "minerva-7b" res = "minerva-7b"
if chkhsh == "7e57df22b1fe23a7b1e1c7f3dc4e3f96d43a4eb0836d0c6bdc3436d7b2f1c664":
# ref: https://huggingface.co/tencent/Hunyuan-A13B-Instruct
res = "hunyuan"
if res is None: if res is None:
logger.warning("\n") logger.warning("\n")
@ -6535,6 +6538,160 @@ class UltravoxWhisperEncoderModel(WhisperEncoderModel):
super().set_gguf_parameters() super().set_gguf_parameters()
self.gguf_writer.add_audio_stack_factor(self.global_config["stack_factor"]) self.gguf_writer.add_audio_stack_factor(self.global_config["stack_factor"])
@ModelBase.register("HunYuanMoEV1ForCausalLM")
class HunYuanMoEModel(TextModel):
model_arch = gguf.MODEL_ARCH.HUNYUAN_MOE
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
# For handling tied embeddings
self._tok_embd = None
def set_vocab(self):
from transformers import AutoTokenizer
tokenizer = AutoTokenizer.from_pretrained(self.dir_model, trust_remote_code=True)
# 1. Get the pre-tokenizer identifier hash
tokpre = self.get_vocab_base_pre(tokenizer)
# 2. Reverse-engineer the merges list from mergeable_ranks
merges = []
vocab = {}
mergeable_ranks = tokenizer.mergeable_ranks
for token, rank in mergeable_ranks.items():
vocab[QwenModel.token_bytes_to_string(token)] = rank
if len(token) == 1:
continue
merged = QwenModel.bpe(mergeable_ranks, token, max_rank=rank)
if len(merged) == 2: # todo this is an assert in Qwen, why?
merges.append(' '.join(map(QwenModel.token_bytes_to_string, merged)))
# 3. Generate the tokens and toktypes lists
vocab_size = self.hparams["vocab_size"]
assert tokenizer.vocab_size == vocab_size
special_tokens = tokenizer.special_tokens
reverse_vocab = {id_ : encoded_tok for encoded_tok, id_ in {**vocab, **special_tokens}.items()}
tokens: list[str] = []
toktypes: list[int] = []
for i in range(vocab_size):
if i not in reverse_vocab:
tokens.append(f"[PAD{i}]")
toktypes.append(gguf.TokenType.UNUSED)
else:
token = reverse_vocab[i]
tokens.append(token)
if i in special_tokens.values():
toktypes.append(gguf.TokenType.CONTROL)
else:
toktypes.append(gguf.TokenType.NORMAL)
# 4. Write all vocab-related fields to the GGUF writer
self.gguf_writer.add_tokenizer_model("gpt2")
self.gguf_writer.add_tokenizer_pre(tokpre)
self.gguf_writer.add_token_list(tokens)
self.gguf_writer.add_token_types(toktypes)
self.gguf_writer.add_token_merges(merges)
# 5. Add special tokens and chat templates
special_vocab = gguf.SpecialVocab(self.dir_model, load_merges=False)
special_vocab.add_to_gguf(self.gguf_writer)
# FIX for BOS token: Overwrite incorrect id read from config.json
self.gguf_writer.add_bos_token_id(127959) # <|bos|>
def set_gguf_parameters(self):
super().set_gguf_parameters()
hparams = self.hparams
self.gguf_writer.add_expert_count(hparams["num_experts"])
self.gguf_writer.add_expert_shared_feed_forward_length(hparams["intermediate_size"])
moe_intermediate_size = hparams["moe_intermediate_size"]
assert all(n == moe_intermediate_size[0] for n in moe_intermediate_size)
self.gguf_writer.add_expert_feed_forward_length(moe_intermediate_size[0])
moe_topk = hparams["moe_topk"]
assert all(topk == moe_topk[0] for topk in moe_topk)
self.gguf_writer.add_expert_used_count(moe_topk[0])
moe_shared_expert = hparams["num_shared_expert"]
assert all(n == moe_shared_expert[0] for n in moe_shared_expert)
self.gguf_writer.add_expert_shared_count(moe_shared_expert[0])
# Rope
rope_scaling = hparams.get("rope_scaling", {})
if rope_scaling.get("type") == "dynamic":
# HunYuan uses NTK Aware Alpha based scaling. Original implementation: https://www.reddit.com/r/LocalLLaMA/comments/14lz7j5/ntkaware_scaled_rope_allows_llama_models_to_have/
# 1000 corresponds to a usable context length of 256k (https://github.com/Tencent-Hunyuan/Hunyuan-A13B/blob/main/report/Hunyuan_A13B_Technical_Report.pdf)
alpha = rope_scaling.get("alpha", 1000)
base = hparams.get("rope_theta", 10000.0)
dim = (hparams["hidden_size"] // hparams["num_attention_heads"]) # 128
scaled_base = base * (alpha ** (dim / (dim - 2))) # 10000 * (1000 ** (128 / 126)) = 11158839.9251
self.gguf_writer.add_rope_freq_base(scaled_base)
self.gguf_writer.add_rope_scaling_type(gguf.RopeScalingType.NONE)
self.gguf_writer.add_rope_scaling_factor(1)
# There is no consistent way to calculate ctx from alpha, and the config is incorrectly set to 32k
self.gguf_writer.add_rope_scaling_orig_ctx_len(256 * 1024) # 256k context length
self.gguf_writer.add_context_length(256 * 1024) # 256k context length
# if any of our assumptions about the values are wrong, something has changed and this may need to be updated
assert alpha == 1000 and base == 10000.0 and dim == 128 and self.hparams["max_position_embeddings"] in [32 * 1024, 256 * 1024] , \
"HunYuan dynamic RoPE scaling assumptions changed, please update the logic or context length manually"
_experts: list[dict[str, Tensor]] | None = None
def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
if name == "model.embed_tokens.weight":
self._tok_embd = data_torch.clone()
if name == "lm_head.weight":
if self.hparams.get("tie_word_embeddings", False):
logger.info("Skipping tied output layer 'lm_head.weight'")
return []
if name.find("mlp.experts") != -1:
n_experts = self.hparams["num_experts"]
assert bid is not None
if self._experts is None:
self._experts = [{} for _ in range(self.block_count)]
self._experts[bid][name] = data_torch
if len(self._experts[bid]) >= n_experts * 3:
# merge the experts into a single 3d tensor
tensors: list[tuple[str, Tensor]] = []
for w_name in ["down_proj", "gate_proj", "up_proj"]:
datas: list[Tensor] = []
for xid in range(n_experts):
ename = f"model.layers.{bid}.mlp.experts.{xid}.{w_name}.weight"
datas.append(self._experts[bid][ename])
del self._experts[bid][ename]
data_torch = torch.stack(datas, dim=0)
merged_name = f"model.layers.{bid}.mlp.experts.{w_name}.weight"
new_name = self.map_tensor_name(merged_name)
tensors.append((new_name, data_torch))
return tensors
else:
return []
return [(self.map_tensor_name(name), data_torch)]
def prepare_tensors(self):
super().prepare_tensors()
if self._experts is not None:
experts = [k for d in self._experts for k in d.keys()]
if len(experts) > 0:
raise ValueError(f"Unprocessed experts: {experts}")
@ModelBase.register("SmolLM3ForCausalLM")
class SmolLM3Model(LlamaModel):
model_arch = gguf.MODEL_ARCH.SMOLLM3
###### CONVERSION LOGIC ###### ###### CONVERSION LOGIC ######

View File

@ -137,6 +137,7 @@ pre_computed_hashes = [
{"name": "chatglm-bpe", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/THUDM/glm-4-9b-chat", "chkhsh": "81d72c7348a9f0ebe86f23298d37debe0a5e71149e29bd283904c02262b27516"}, {"name": "chatglm-bpe", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/THUDM/glm-4-9b-chat", "chkhsh": "81d72c7348a9f0ebe86f23298d37debe0a5e71149e29bd283904c02262b27516"},
{"name": "glm4", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/THUDM/glm-4-9b-hf", "chkhsh": "a1336059768a55c99a734006ffb02203cd450fed003e9a71886c88acf24fdbc2"}, {"name": "glm4", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/THUDM/glm-4-9b-hf", "chkhsh": "a1336059768a55c99a734006ffb02203cd450fed003e9a71886c88acf24fdbc2"},
{"name": "minerva-7b", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/sapienzanlp/Minerva-7B-base-v1.0", "chkhsh": "1431a23e583c97432bc230bff598d103ddb5a1f89960c8f1d1051aaa944d0b35"}, {"name": "minerva-7b", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/sapienzanlp/Minerva-7B-base-v1.0", "chkhsh": "1431a23e583c97432bc230bff598d103ddb5a1f89960c8f1d1051aaa944d0b35"},
{"name": "hunyuan", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/tencent/Hunyuan-A13B-Instruct", "chkhsh": "7e57df22b1fe23a7b1e1c7f3dc4e3f96d43a4eb0836d0c6bdc3436d7b2f1c664"},
] ]

View File

@ -83,20 +83,22 @@ NOTE: Tensor names must end with `.weight` or `.bias` suffixes, that is the conv
### 2. Define the model architecture in `llama.cpp` ### 2. Define the model architecture in `llama.cpp`
The model params and tensors layout must be defined in `llama.cpp`: The model params and tensors layout must be defined in `llama.cpp` source files:
1. Define a new `llm_arch` 1. Define a new `llm_arch` enum value in `src/llama-arch.h`.
2. Define the tensors layout in `LLM_TENSOR_NAMES` 2. In `src/llama-arch.cpp`:
3. Add any non-standard metadata in `llm_load_hparams` - Add the architecture name to the `LLM_ARCH_NAMES` map.
4. Create the tensors for inference in `llm_load_tensors` - Add the tensor mappings to the `LLM_TENSOR_NAMES` map.
5. If the model has a RoPE operation, add the rope type in `llama_rope_type` 3. Add any non-standard metadata loading in the `llama_model_loader` constructor in `src/llama-model-loader.cpp`.
4. If the model has a RoPE operation, add a case for the architecture in `llama_model_rope_type` function in `src/llama-model.cpp`.
NOTE: The dimensions in `ggml` are typically in the reverse order of the `pytorch` dimensions. NOTE: The dimensions in `ggml` are typically in the reverse order of the `pytorch` dimensions.
### 3. Build the GGML graph implementation ### 3. Build the GGML graph implementation
This is the funniest part, you have to provide the inference graph implementation of the new model architecture in `llama_build_graph`. This is the funniest part, you have to provide the inference graph implementation of the new model architecture in `src/llama-model.cpp`.
Create a new struct that inherits from `llm_graph_context` and implement the graph-building logic in its constructor.
Have a look at existing implementations like `build_llama`, `build_dbrx` or `build_bert`. Have a look at existing implementations like `llm_build_llama`, `llm_build_dbrx` or `llm_build_bert`.
Then, in the `llama_model::build_graph` method, add a case for your architecture to instantiate your new graph-building struct.
Some `ggml` backends do not support all operations. Backend implementations can be added in a separate PR. Some `ggml` backends do not support all operations. Backend implementations can be added in a separate PR.

View File

@ -495,7 +495,7 @@ extern "C" {
GGML_OP_POOL_1D, GGML_OP_POOL_1D,
GGML_OP_POOL_2D, GGML_OP_POOL_2D,
GGML_OP_POOL_2D_BACK, GGML_OP_POOL_2D_BACK,
GGML_OP_UPSCALE, // nearest interpolate GGML_OP_UPSCALE,
GGML_OP_PAD, GGML_OP_PAD,
GGML_OP_PAD_REFLECT_1D, GGML_OP_PAD_REFLECT_1D,
GGML_OP_ROLL, GGML_OP_ROLL,

View File

@ -176,17 +176,20 @@ static const char * cu_get_error_str(CUresult err) {
#endif #endif
#if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && !defined(GGML_USE_MUSA) #if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && !defined(GGML_USE_MUSA)
#define CUDA_SET_SHARED_MEMORY_LIMIT(kernel, nbytes) \ # define CUDA_SET_SHARED_MEMORY_LIMIT(kernel, nbytes) \
do { \ do { \
static bool shared_memory_limit_raised[GGML_CUDA_MAX_DEVICES] = {false}; \ static bool shared_memory_limit_raised[GGML_CUDA_MAX_DEVICES] = { false }; \
const int id = ggml_cuda_get_device(); \ const int id = ggml_cuda_get_device(); \
if (!shared_memory_limit_raised[id]) { \ if (!shared_memory_limit_raised[id]) { \
CUDA_CHECK(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, nbytes)); \ CUDA_CHECK(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, nbytes)); \
shared_memory_limit_raised[id] = true; \ shared_memory_limit_raised[id] = true; \
} \ } \
} while (0) } while (0)
#else #else
#define CUDA_SET_SHARED_MEMORY_LIMIT(kernel, nbytes) do {} while (0) # define CUDA_SET_SHARED_MEMORY_LIMIT(kernel, nbytes) \
do { \
GGML_UNUSED(nbytes); \
} while (0)
#endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && !defined(GGML_USE_MUSA) #endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && !defined(GGML_USE_MUSA)
#if CUDART_VERSION >= 11010 || defined(GGML_USE_MUSA) #if CUDART_VERSION >= 11010 || defined(GGML_USE_MUSA)

View File

@ -299,14 +299,14 @@ static __global__ void flash_attn_tile_ext_f32(
GGML_UNUSED(dst); GGML_UNUSED(dst_meta); GGML_UNUSED(scale); GGML_UNUSED(dst); GGML_UNUSED(dst_meta); GGML_UNUSED(scale);
GGML_UNUSED(max_bias); GGML_UNUSED(m0); GGML_UNUSED(m1); GGML_UNUSED(max_bias); GGML_UNUSED(m0); GGML_UNUSED(m1);
GGML_UNUSED(n_head_log2); GGML_UNUSED(logit_softcap); GGML_UNUSED(n_head_log2); GGML_UNUSED(logit_softcap);
GGML_UNUSED(ne00); GGML_UNUSED(ne01); GGML_UNUSED(ne02); GGML_UNUSED(ne00); GGML_UNUSED(ne01); GGML_UNUSED(ne02); GGML_UNUSED(ne03);
GGML_UNUSED(ne03); GGML_UNUSED(ne10); GGML_UNUSED(ne11); GGML_UNUSED(ne10); GGML_UNUSED(ne11); GGML_UNUSED(ne12); GGML_UNUSED(ne13);
GGML_UNUSED(ne12); GGML_UNUSED(ne13); GGML_UNUSED(ne31); GGML_UNUSED(ne31); GGML_UNUSED(ne32);
GGML_UNUSED(nb31); GGML_UNUSED(nb01); GGML_UNUSED(nb02); GGML_UNUSED(nb31); GGML_UNUSED(nb32);
GGML_UNUSED(nb03); GGML_UNUSED(nb11); GGML_UNUSED(nb12); GGML_UNUSED(nb01); GGML_UNUSED(nb02); GGML_UNUSED(nb03);
GGML_UNUSED(nb13); GGML_UNUSED(nb21); GGML_UNUSED(nb22); GGML_UNUSED(nb11); GGML_UNUSED(nb12); GGML_UNUSED(nb13);
GGML_UNUSED(nb23); GGML_UNUSED(ne0); GGML_UNUSED(ne1); GGML_UNUSED(nb21); GGML_UNUSED(nb22); GGML_UNUSED(nb23);
GGML_UNUSED(ne2); GGML_UNUSED(ne3); GGML_UNUSED(ne0); GGML_UNUSED(ne1); GGML_UNUSED(ne2); GGML_UNUSED(ne3);
NO_DEVICE_CODE; NO_DEVICE_CODE;
#endif // FLASH_ATTN_AVAILABLE #endif // FLASH_ATTN_AVAILABLE
} }

View File

@ -337,13 +337,15 @@ static __global__ void flash_attn_vec_ext_f32(
GGML_UNUSED(Q); GGML_UNUSED(K); GGML_UNUSED(V); GGML_UNUSED(mask); GGML_UNUSED(Q); GGML_UNUSED(K); GGML_UNUSED(V); GGML_UNUSED(mask);
GGML_UNUSED(dst); GGML_UNUSED(dst_meta); GGML_UNUSED(scale); GGML_UNUSED(dst); GGML_UNUSED(dst_meta); GGML_UNUSED(scale);
GGML_UNUSED(max_bias); GGML_UNUSED(m0); GGML_UNUSED(m1); GGML_UNUSED(max_bias); GGML_UNUSED(m0); GGML_UNUSED(m1);
GGML_UNUSED(n_head_log2); GGML_UNUSED(logit_softcap); GGML_UNUSED(ne00); GGML_UNUSED(n_head_log2); GGML_UNUSED(logit_softcap);
GGML_UNUSED(ne01); GGML_UNUSED(ne02); GGML_UNUSED(ne03); GGML_UNUSED(ne10); GGML_UNUSED(ne00); GGML_UNUSED(ne01); GGML_UNUSED(ne02); GGML_UNUSED(ne03);
GGML_UNUSED(ne11); GGML_UNUSED(ne12); GGML_UNUSED(ne13); GGML_UNUSED(ne31); GGML_UNUSED(ne10); GGML_UNUSED(ne11); GGML_UNUSED(ne12); GGML_UNUSED(ne13);
GGML_UNUSED(nb31); GGML_UNUSED(nb01); GGML_UNUSED(nb02); GGML_UNUSED(nb03); GGML_UNUSED(ne31); GGML_UNUSED(ne32);
GGML_UNUSED(nb11); GGML_UNUSED(nb12); GGML_UNUSED(nb13); GGML_UNUSED(nb21); GGML_UNUSED(nb31); GGML_UNUSED(nb32);
GGML_UNUSED(nb22); GGML_UNUSED(nb23); GGML_UNUSED(ne0); GGML_UNUSED(ne1); GGML_UNUSED(nb01); GGML_UNUSED(nb02); GGML_UNUSED(nb03);
GGML_UNUSED(ne2); GGML_UNUSED(ne3); GGML_UNUSED(nb11); GGML_UNUSED(nb12); GGML_UNUSED(nb13);
GGML_UNUSED(nb21); GGML_UNUSED(nb22); GGML_UNUSED(nb23);
GGML_UNUSED(ne0); GGML_UNUSED(ne1); GGML_UNUSED(ne2); GGML_UNUSED(ne3);
NO_DEVICE_CODE; NO_DEVICE_CODE;
#endif // FLASH_ATTN_AVAILABLE #endif // FLASH_ATTN_AVAILABLE
} }

View File

@ -3375,7 +3375,6 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
case GGML_OP_GROUP_NORM: case GGML_OP_GROUP_NORM:
return ggml_is_contiguous(op->src[0]); return ggml_is_contiguous(op->src[0]);
case GGML_OP_UPSCALE: case GGML_OP_UPSCALE:
return op->src[0]->type == GGML_TYPE_F32 && op->op_params[0] == GGML_SCALE_MODE_NEAREST;
case GGML_OP_PAD: case GGML_OP_PAD:
case GGML_OP_ARANGE: case GGML_OP_ARANGE:
case GGML_OP_TIMESTEP_EMBEDDING: case GGML_OP_TIMESTEP_EMBEDDING:

View File

@ -50,21 +50,19 @@ static __global__ void rope_norm(
const int row_dst = blockDim.x*blockIdx.x + threadIdx.x; const int row_dst = blockDim.x*blockIdx.x + threadIdx.x;
if (i0 >= n_dims) {
const int i = row_dst*ne0 + i0;
dst[i + 0] = x[i + 0];
dst[i + 1] = x[i + 1];
return;
}
const int row_x = row_dst % ne1; const int row_x = row_dst % ne1;
const int channel_x = row_dst / ne1; const int channel_x = row_dst / ne1;
const int idst = row_dst*ne0 + i0; const int idst = row_dst*ne0 + i0;
const int ix = channel_x*s2 + row_x*s1 + i0; const int ix = channel_x*s2 + row_x*s1 + i0;
if (i0 >= n_dims) {
dst[idst + 0] = x[ix + 0];
dst[idst + 1] = x[ix + 1];
return;
}
const float theta_base = pos[channel_x]*powf(theta_scale, i0/2.0f); const float theta_base = pos[channel_x]*powf(theta_scale, i0/2.0f);
const float freq_factor = has_ff ? freq_factors[i0/2] : 1.0f; const float freq_factor = has_ff ? freq_factors[i0/2] : 1.0f;
@ -94,21 +92,19 @@ static __global__ void rope_neox(
const int row_dst = blockDim.x*blockIdx.x + threadIdx.x; const int row_dst = blockDim.x*blockIdx.x + threadIdx.x;
if (i0 >= n_dims) {
const int i = row_dst*ne0 + i0;
dst[i + 0] = x[i + 0];
dst[i + 1] = x[i + 1];
return;
}
const int row_x = row_dst % ne1; const int row_x = row_dst % ne1;
const int channel_x = row_dst / ne1; const int channel_x = row_dst / ne1;
const int idst = row_dst*ne0 + i0/2; const int idst = row_dst*ne0 + i0/2;
const int ix = channel_x*s2 + row_x*s1 + i0/2; const int ix = channel_x*s2 + row_x*s1 + i0/2;
if (i0 >= n_dims) {
dst[idst + i0/2 + 0] = x[ix + i0/2 + 0];
dst[idst + i0/2 + 1] = x[ix + i0/2 + 1];
return;
}
const float theta_base = pos[channel_x]*powf(theta_scale, i0/2.0f); const float theta_base = pos[channel_x]*powf(theta_scale, i0/2.0f);
const float freq_factor = has_ff ? freq_factors[i0/2] : 1.0f; const float freq_factor = has_ff ? freq_factors[i0/2] : 1.0f;
@ -138,21 +134,19 @@ static __global__ void rope_multi(
const int row_dst = blockDim.x*blockIdx.x + threadIdx.x; const int row_dst = blockDim.x*blockIdx.x + threadIdx.x;
if (i0 >= n_dims) {
const int i = row_dst*ne0 + i0;
dst[i + 0] = x[i + 0];
dst[i + 1] = x[i + 1];
return;
}
const int row_x = row_dst % ne1; const int row_x = row_dst % ne1;
const int channel_x = row_dst / ne1; const int channel_x = row_dst / ne1;
const int idst = row_dst*ne0 + i0/2; const int idst = row_dst*ne0 + i0/2;
const int ix = channel_x*s2 + row_x*s1 + i0/2; const int ix = channel_x*s2 + row_x*s1 + i0/2;
if (i0 >= n_dims) {
dst[idst + i0/2 + 0] = x[ix + i0/2 + 0];
dst[idst + i0/2 + 1] = x[ix + i0/2 + 1];
return;
}
const int sect_dims = sections.v[0] + sections.v[1] + sections.v[2] + sections.v[3]; const int sect_dims = sections.v[0] + sections.v[1] + sections.v[2] + sections.v[3];
const int sec_w = sections.v[1] + sections.v[0]; const int sec_w = sections.v[1] + sections.v[0];
const int sector = (i0 / 2) % sect_dims; const int sector = (i0 / 2) % sect_dims;

View File

@ -22,17 +22,88 @@ static __global__ void upscale_f32(const float * x, float * dst,
dst[index] = *( (const float *)((const char *)x + i03 * nb03 + i02 * nb02 + i01 * nb01 + i00 * nb00) ); dst[index] = *( (const float *)((const char *)x + i03 * nb03 + i02 * nb02 + i01 * nb01 + i00 * nb00) );
} }
static __global__ void upscale_f32_bilinear(const float * x, float * dst,
const int nb00, const int nb01, const int nb02, const int nb03,
const int ne00_src, const int ne01_src,
const int ne10_dst, const int ne11_dst, const int ne12_dst, const int ne13_dst,
const float sf0, const float sf1, const float sf2, const float sf3,
const float pixel_offset) {
const int64_t index = threadIdx.x + blockIdx.x * blockDim.x;
const int64_t dst_total_elements = ne10_dst * ne11_dst * ne12_dst * ne13_dst;
if (index >= dst_total_elements) {
return;
}
const int i10_dst = index % ne10_dst;
const int i11_dst = (index / ne10_dst) % ne11_dst;
const int i12_dst = (index / (ne10_dst * ne11_dst)) % ne12_dst;
const int i13_dst = index / (ne10_dst * ne11_dst * ne12_dst);
const int i02_src = (int)(i12_dst / sf2);
const int i03_src = (int)(i13_dst / sf3);
const float y_src_f = ((float)i11_dst + pixel_offset) / sf1 - pixel_offset;
int y0_src = (int)floorf(y_src_f);
int y1_src = y0_src + 1;
y0_src = max(0, min(y0_src, ne01_src - 1));
y1_src = max(0, min(y1_src, ne01_src - 1));
float dy = y_src_f - (float)y0_src;
dy = max(0.0f, min(dy, 1.0f));
float x_src_f = ((float)i10_dst + pixel_offset) / sf0 - pixel_offset;
int x0_src = (int)floorf(x_src_f);
int x1_src = x0_src + 1;
x0_src = max(0, min(x0_src, ne00_src - 1));
x1_src = max(0, min(x1_src, ne00_src - 1));
float dx = x_src_f - (float)x0_src;
dx = max(0.0f, min(dx, 1.0f));
const float * p_a = (const float *)((const char *)x + (int64_t)x0_src * nb00 + (int64_t)y0_src * nb01 + (int64_t)i02_src * nb02 + (int64_t)i03_src * nb03);
const float * p_b = (const float *)((const char *)x + (int64_t)x1_src * nb00 + (int64_t)y0_src * nb01 + (int64_t)i02_src * nb02 + (int64_t)i03_src * nb03);
const float * p_c = (const float *)((const char *)x + (int64_t)x0_src * nb00 + (int64_t)y1_src * nb01 + (int64_t)i02_src * nb02 + (int64_t)i03_src * nb03);
const float * p_d = (const float *)((const char *)x + (int64_t)x1_src * nb00 + (int64_t)y1_src * nb01 + (int64_t)i02_src * nb02 + (int64_t)i03_src * nb03);
const float val_a = *p_a;
const float val_b = *p_b;
const float val_c = *p_c;
const float val_d = *p_d;
float result = val_a * (1.0f - dx) * (1.0f - dy) +
val_b * dx * (1.0f - dy) +
val_c * (1.0f - dx) * dy +
val_d * dx * dy;
dst[index] = result;
}
static void upscale_f32_cuda(const float * x, float * dst, static void upscale_f32_cuda(const float * x, float * dst,
const int nb00, const int nb01, const int nb02, const int nb03, const int nb00, const int nb01, const int nb02, const int nb03,
const int ne10, const int ne11, const int ne12, const int ne13, const int ne10, const int ne11, const int ne12, const int ne13,
const float sf0, const float sf1, const float sf2, const float sf3, const float sf0, const float sf1, const float sf2, const float sf3,
cudaStream_t stream) { cudaStream_t stream) {
int dst_size = ne10 * ne11 * ne12 * ne13; const int64_t dst_size = ne10 * ne11 * ne12 * ne13;
int num_blocks = (dst_size + CUDA_UPSCALE_BLOCK_SIZE - 1) / CUDA_UPSCALE_BLOCK_SIZE; const int64_t num_blocks = (dst_size + CUDA_UPSCALE_BLOCK_SIZE - 1) / CUDA_UPSCALE_BLOCK_SIZE;
upscale_f32<<<num_blocks, CUDA_UPSCALE_BLOCK_SIZE,0,stream>>>(x, dst, nb00, nb01, nb02, nb03, ne10, ne11, ne12, ne13, sf0, sf1, sf2, sf3); upscale_f32<<<num_blocks, CUDA_UPSCALE_BLOCK_SIZE,0,stream>>>(x, dst, nb00, nb01, nb02, nb03, ne10, ne11, ne12, ne13, sf0, sf1, sf2, sf3);
} }
static void upscale_f32_bilinear_cuda(const float * x, float * dst,
const int nb00, const int nb01, const int nb02, const int nb03,
const int ne00_src, const int ne01_src,
const int ne10_dst, const int ne11_dst, const int ne12_dst, const int ne13_dst,
const float sf0, const float sf1, const float sf2, const float sf3,
const float pixel_offset, cudaStream_t stream) {
const int64_t dst_size = ne10_dst * ne11_dst * ne12_dst * ne13_dst;
const int64_t num_blocks = (dst_size + CUDA_UPSCALE_BLOCK_SIZE - 1) / CUDA_UPSCALE_BLOCK_SIZE;
upscale_f32_bilinear<<<num_blocks, CUDA_UPSCALE_BLOCK_SIZE,0,stream>>>(x, dst, nb00, nb01, nb02, nb03, ne00_src, ne01_src, ne10_dst, ne11_dst, ne12_dst, ne13_dst, sf0, sf1, sf2, sf3, pixel_offset);
}
void ggml_cuda_op_upscale(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { void ggml_cuda_op_upscale(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
const ggml_tensor * src0 = dst->src[0]; const ggml_tensor * src0 = dst->src[0];
const float * src0_d = (const float *)src0->data; const float * src0_d = (const float *)src0->data;
@ -42,10 +113,25 @@ void ggml_cuda_op_upscale(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
GGML_ASSERT(src0->type == GGML_TYPE_F32); GGML_ASSERT(src0->type == GGML_TYPE_F32);
GGML_ASSERT( dst->type == GGML_TYPE_F32); GGML_ASSERT( dst->type == GGML_TYPE_F32);
const float sf0 = (float)dst->ne[0]/src0->ne[0]; const int mode_flags = dst->op_params[0];
const float sf1 = (float)dst->ne[1]/src0->ne[1]; const ggml_scale_mode mode = (ggml_scale_mode)(mode_flags & 0xFF);
const float sf2 = (float)dst->ne[2]/src0->ne[2];
float sf0 = (float)dst->ne[0]/src0->ne[0];
float sf1 = (float)dst->ne[1]/src0->ne[1];
float sf2 = (float)dst->ne[2]/src0->ne[2];
const float sf3 = (float)dst->ne[3]/src0->ne[3]; const float sf3 = (float)dst->ne[3]/src0->ne[3];
upscale_f32_cuda(src0_d, dst_d, src0->nb[0], src0->nb[1], src0->nb[2], src0->nb[3], dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3], sf0, sf1, sf2, sf3, stream); if (mode == GGML_SCALE_MODE_NEAREST) {
upscale_f32_cuda(src0_d, dst_d, src0->nb[0], src0->nb[1], src0->nb[2], src0->nb[3], dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3], sf0, sf1, sf2, sf3, stream);
} else if (mode == GGML_SCALE_MODE_BILINEAR) {
float pixel_offset = 0.5f;
if (mode_flags & GGML_SCALE_FLAG_ALIGN_CORNERS) {
sf0 = (float)(dst->ne[0] - 1) / (src0->ne[0] - 1);
sf1 = (float)(dst->ne[1] - 1) / (src0->ne[1] - 1);
pixel_offset = 0.0f;
}
upscale_f32_bilinear_cuda(src0_d, dst_d, src0->nb[0], src0->nb[1], src0->nb[2], src0->nb[3],
src0->ne[0], src0->ne[1], dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3],
sf0, sf1, sf2, sf3, pixel_offset, stream);
}
} }

View File

@ -47,18 +47,17 @@ static void rope_norm(const T * x, T * dst, const int ne0, const int ne1, const
const int row = item_ct1.get_local_range(2) * item_ct1.get_group(2) + item_ct1.get_local_id(2); const int row = item_ct1.get_local_range(2) * item_ct1.get_group(2) + item_ct1.get_local_id(2);
if (i0 >= n_dims) {
const int i = row * ne0 + i0;
*reinterpret_cast<sycl::vec<T, 2> *>(dst + i) = *reinterpret_cast<const sycl::vec<T, 2> *>(x + i);
return;
}
const int row0 = row % ne1; const int row0 = row % ne1;
const int channel0 = row / ne1; const int channel0 = row / ne1;
const int i = row * ne0 + i0; const int i = row * ne0 + i0;
const int i2 = channel0 * s2 + row0 * s1 + i0; const int i2 = channel0 * s2 + row0 * s1 + i0;
if (i0 >= n_dims) {
*reinterpret_cast<sycl::vec<T, 2> *>(dst + i) = *reinterpret_cast<const sycl::vec<T, 2> *>(x + i2);
return;
}
const float theta_base = pos[channel0] * sycl::pow(theta_scale, i0 / 2.0f); const float theta_base = pos[channel0] * sycl::pow(theta_scale, i0 / 2.0f);
const float freq_factor = has_ff ? freq_factors[i0 / 2] : 1.0f; const float freq_factor = has_ff ? freq_factors[i0 / 2] : 1.0f;
@ -88,18 +87,17 @@ static void rope_neox(const T * x, T * dst, const int ne0, const int ne1, const
const int row = item_ct1.get_local_range(2) * item_ct1.get_group(2) + item_ct1.get_local_id(2); const int row = item_ct1.get_local_range(2) * item_ct1.get_group(2) + item_ct1.get_local_id(2);
if (i0 >= n_dims) {
const int i = row * ne0 + i0;
*reinterpret_cast<sycl::vec<T, 2> *>(dst + i) = *reinterpret_cast<const sycl::vec<T, 2> *>(x + i);
return;
}
const int row0 = row % ne1; const int row0 = row % ne1;
const int channel0 = row / ne1; const int channel0 = row / ne1;
const int i = row * ne0 + i0 / 2; const int i = row * ne0 + i0 / 2;
const int i2 = channel0 * s2 + row0 * s1 + i0 / 2; const int i2 = channel0 * s2 + row0 * s1 + i0 / 2;
if (i0 >= n_dims) {
*reinterpret_cast<sycl::vec<T, 2> *>(dst + i + i0 / 2) = *reinterpret_cast<const sycl::vec<T, 2> *>(x + i2 + i0 / 2);
return;
}
const float theta_base = pos[channel0] * sycl::pow(theta_scale, i0 / 2.0f); const float theta_base = pos[channel0] * sycl::pow(theta_scale, i0 / 2.0f);
const float freq_factor = has_ff ? freq_factors[i0 / 2] : 1.0f; const float freq_factor = has_ff ? freq_factors[i0 / 2] : 1.0f;
@ -129,17 +127,16 @@ static void rope_multi(const T * x, T * dst, const int ne0, const int ne1, const
} }
const int row_dst = (item_ct1.get_group(2) * item_ct1.get_local_range(2)) + item_ct1.get_local_id(2); const int row_dst = (item_ct1.get_group(2) * item_ct1.get_local_range(2)) + item_ct1.get_local_id(2);
if (i0 >= n_dims) {
const int i = row_dst*ne0 + i0;
*reinterpret_cast<sycl::vec<T, 2> *>(dst + i) = *reinterpret_cast<const sycl::vec<T, 2> *>(x + i);
return;
}
const int row_x = row_dst % ne1; const int row_x = row_dst % ne1;
const int channel_x = row_dst / ne1; const int channel_x = row_dst / ne1;
const int idst = (row_dst * ne0) + (i0 / 2); const int idst = (row_dst * ne0) + (i0 / 2);
const size_t ix = ((size_t) channel_x * s2) + ((size_t) row_x * s1) + (i0 / 2); const size_t ix = ((size_t) channel_x * s2) + ((size_t) row_x * s1) + (i0 / 2);
if (i0 >= n_dims) {
*reinterpret_cast<sycl::vec<T, 2> *>(dst + idst + i0 / 2) = *reinterpret_cast<const sycl::vec<T, 2> *>(x + i0 / 2 + ix);
return;
}
const int sect_dims = sections.v[0] + sections.v[1] + sections.v[2] + sections.v[3]; const int sect_dims = sections.v[0] + sections.v[1] + sections.v[2] + sections.v[3];
const int sec_w = sections.v[1] + sections.v[0]; const int sec_w = sections.v[1] + sections.v[0];
const int sector = (i0 / 2) % sect_dims; const int sector = (i0 / 2) % sect_dims;

View File

@ -2706,7 +2706,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_IQ4_NL], "get_rows_iq4_nl_f32", get_rows_iq4_nl_f32_len, get_rows_iq4_nl_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_IQ4_NL], "get_rows_iq4_nl_f32", get_rows_iq4_nl_f32_len, get_rows_iq4_nl_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);
ggml_vk_create_pipeline(device, device->pipeline_matmul_split_k_reduce, "split_k_reduce", split_k_reduce_len, split_k_reduce_data, "main", 2, 2 * sizeof(uint32_t), {256 * 4, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_matmul_split_k_reduce, "split_k_reduce", split_k_reduce_len, split_k_reduce_data, "main", 2, 2 * sizeof(uint32_t), {256 * 4, 1, 1}, {}, 1);
ggml_vk_create_pipeline(device, device->pipeline_flash_attn_split_k_reduce, "fa_split_k_reduce", fa_split_k_reduce_len, fa_split_k_reduce_data, "main", 2, 3 * sizeof(uint32_t), {1, 1, 1}, {}, 1, true); ggml_vk_create_pipeline(device, device->pipeline_flash_attn_split_k_reduce, "fa_split_k_reduce", fa_split_k_reduce_len, fa_split_k_reduce_data, "main", 2, 4 * sizeof(uint32_t), {1, device->subgroup_size, 1}, {device->subgroup_size}, 1, true);
ggml_vk_create_pipeline(device, device->pipeline_quantize_q8_1, "quantize_q8_1", quantize_q8_1_len, quantize_q8_1_data, "main", 2, 1 * sizeof(uint32_t), {32 * device->subgroup_size / 8, 1, 1}, { device->subgroup_size }, 1); ggml_vk_create_pipeline(device, device->pipeline_quantize_q8_1, "quantize_q8_1", quantize_q8_1_len, quantize_q8_1_data, "main", 2, 1 * sizeof(uint32_t), {32 * device->subgroup_size / 8, 1, 1}, { device->subgroup_size }, 1);
for (uint32_t i = 0; i < p021_max_gqa_ratio; ++i) { for (uint32_t i = 0; i < p021_max_gqa_ratio; ++i) {
@ -6252,13 +6252,13 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
const uint32_t shader_core_count = ctx->device->shader_core_count ? ctx->device->shader_core_count : 16; const uint32_t shader_core_count = ctx->device->shader_core_count ? ctx->device->shader_core_count : 16;
// Try to use split_k when KV is large enough to be worth the overhead // Try to use split_k when KV is large enough to be worth the overhead
if (workgroups_x == 1 && shader_core_count > 0 && KV >= 512) { if (workgroups_x == 1 && shader_core_count > 0) {
// Try to run two workgroups per SM. // Try to run two workgroups per SM.
split_k = shader_core_count * 2 / (workgroups_y * workgroups_z); split_k = shader_core_count * 2 / (workgroups_y * workgroups_z);
if (split_k > 1) { if (split_k > 1) {
// Try to evenly split KV into split_k chunks, but it needs to be a multiple // Try to evenly split KV into split_k chunks, but it needs to be a multiple
// of "align", so recompute split_k based on that. // of "align", so recompute split_k based on that.
split_kv = ROUNDUP_POW2(KV / split_k, pipelines[1]->align); split_kv = ROUNDUP_POW2(std::max(1u, KV / split_k), pipelines[1]->align);
split_k = CEIL_DIV(KV, split_kv); split_k = CEIL_DIV(KV, split_kv);
workgroups_x = split_k; workgroups_x = split_k;
} }
@ -6392,7 +6392,7 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
vk_subbuffer{ctx->prealloc_split_k, 0, VK_WHOLE_SIZE}, vk_subbuffer{ctx->prealloc_split_k, 0, VK_WHOLE_SIZE},
vk_subbuffer{d_D, d_buf_offset, VK_WHOLE_SIZE}, vk_subbuffer{d_D, d_buf_offset, VK_WHOLE_SIZE},
}, },
pc2, { (uint32_t)ne1, 1, (uint32_t)ne3 }); pc2, { (uint32_t)ne1, HSV, (uint32_t)ne3 });
} else { } else {
ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, ggml_vk_dispatch_pipeline(ctx, subctx, pipeline,
{ {

View File

@ -2,9 +2,9 @@
#extension GL_EXT_control_flow_attributes : enable #extension GL_EXT_control_flow_attributes : enable
#define BLOCK_SIZE 32 layout(constant_id = 0) const uint BLOCK_SIZE = 32;
layout(local_size_x = BLOCK_SIZE, local_size_y = 1, local_size_z = 1) in; layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
layout (binding = 0) readonly buffer A {float data_a[];}; layout (binding = 0) readonly buffer A {float data_a[];};
layout (binding = 1) writeonly buffer D {float data_d[];}; layout (binding = 1) writeonly buffer D {float data_d[];};
@ -16,6 +16,8 @@ layout (push_constant) uniform parameter {
uint k_num; uint k_num;
} p; } p;
shared float tmpsh[BLOCK_SIZE];
void main() { void main() {
// Each workgroup handles a row // Each workgroup handles a row
const uint n = gl_WorkGroupID.x; const uint n = gl_WorkGroupID.x;
@ -32,23 +34,51 @@ void main() {
// Compute the max m value for the row // Compute the max m value for the row
float m_max = -1.0/0.0; float m_max = -1.0/0.0;
[[unroll]] for (uint k = 0; k < k_num; ++k) { for (uint k = 0; k + tid < k_num; k += BLOCK_SIZE) {
float m = data_a[m_offset + k * lm_stride]; float m = data_a[m_offset + (k + tid) * lm_stride];
m_max = max(m_max, m); m_max = max(m_max, m);
} }
// reduce across the workgroup
tmpsh[tid] = m_max;
barrier();
[[unroll]] for (uint s = BLOCK_SIZE/2; s > 0; s >>= 1) {
if (tid < s) {
m_max = max(m_max, tmpsh[tid + s]);
tmpsh[tid] = m_max;
}
barrier();
}
m_max = tmpsh[0];
barrier();
// Compute L based on m_max // Compute L based on m_max
float L = 0; float L = 0;
[[unroll]] for (uint k = 0; k < k_num; ++k) { for (uint k = 0; k + tid < k_num; k += BLOCK_SIZE) {
float l = data_a[l_offset + k * lm_stride]; float l = data_a[l_offset + (k + tid) * lm_stride];
float m = data_a[m_offset + k * lm_stride]; float m = data_a[m_offset + (k + tid) * lm_stride];
L += exp(m - m_max) * l; L += exp(m - m_max) * l;
} }
// reduce across the workgroup
tmpsh[tid] = L;
barrier();
[[unroll]] for (uint s = BLOCK_SIZE/2; s > 0; s >>= 1) {
if (tid < s) {
L += tmpsh[tid + s];
tmpsh[tid] = L;
}
barrier();
}
L = tmpsh[0];
L = 1.0 / L; L = 1.0 / L;
// D dimension is split across workgroups in the y dimension
uint d = tid + gl_WorkGroupID.y * BLOCK_SIZE;
// Scale and sum the O contributions based on m_max and store the result to memory // Scale and sum the O contributions based on m_max and store the result to memory
for (uint d = tid; d < D; d += BLOCK_SIZE) { if (d < D) {
float O = 0.0; float O = 0.0;
[[unroll]] for (uint k = 0; k < k_num; ++k) { [[unroll]] for (uint k = 0; k < k_num; ++k) {
uint o_offset = D * N * (k + iq3 * k_num) + D * n + d; uint o_offset = D * N * (k + iq3 * k_num) + D * n + d;

View File

@ -14,21 +14,19 @@ void main() {
const uint row_dst = gl_GlobalInvocationID.x; const uint row_dst = gl_GlobalInvocationID.x;
if (i0 >= p.n_dims) {
const uint i = row_dst*ne0 + i0;
data_d[i + 0] = data_a[i + 0];
data_d[i + 1] = data_a[i + 1];
return;
}
const uint row_x = row_dst % ne1; const uint row_x = row_dst % ne1;
const uint channel_x = row_dst / ne1; const uint channel_x = row_dst / ne1;
const uint idst = row_dst*ne0 + i0/2; const uint idst = row_dst*ne0 + i0/2;
const uint ix = channel_x*p.s2 + row_x*p.s1 + i0/2; const uint ix = channel_x*p.s2 + row_x*p.s1 + i0/2;
if (i0 >= p.n_dims) {
data_d[idst + i0/2 + 0] = data_a[ix + i0/2 + 0];
data_d[idst + i0/2 + 1] = data_a[ix + i0/2 + 1];
return;
}
const int sect_dims = p.sections[0] + p.sections[1] + p.sections[2] + p.sections[3]; const int sect_dims = p.sections[0] + p.sections[1] + p.sections[2] + p.sections[3];
const int sec_w = p.sections[1] + p.sections[0]; const int sec_w = p.sections[1] + p.sections[0];
const uint sector = (i0 / 2) % sect_dims; const uint sector = (i0 / 2) % sect_dims;

View File

@ -13,21 +13,19 @@ void main() {
const uint row_dst = gl_GlobalInvocationID.x; const uint row_dst = gl_GlobalInvocationID.x;
if (i0 >= p.n_dims) {
const uint i = row_dst*ne0 + i0;
data_d[i + 0] = data_a[i + 0];
data_d[i + 1] = data_a[i + 1];
return;
}
const uint row_x = row_dst % ne1; const uint row_x = row_dst % ne1;
const uint channel_x = row_dst / ne1; const uint channel_x = row_dst / ne1;
const uint idst = row_dst*ne0 + i0/2; const uint idst = row_dst*ne0 + i0/2;
const uint ix = channel_x*p.s2 + row_x*p.s1 + i0/2; const uint ix = channel_x*p.s2 + row_x*p.s1 + i0/2;
if (i0 >= p.n_dims) {
data_d[idst + i0/2 + 0] = data_a[ix + i0/2 + 0];
data_d[idst + i0/2 + 1] = data_a[ix + i0/2 + 1];
return;
}
const float theta_base = data_pos[channel_x] * pow(p.theta_scale, i0/2.0f); const float theta_base = data_pos[channel_x] * pow(p.theta_scale, i0/2.0f);
const float freq_factor = p.has_ff != 0 ? data_ff[i0/2] : 1.0f; const float freq_factor = p.has_ff != 0 ? data_ff[i0/2] : 1.0f;

View File

@ -13,21 +13,19 @@ void main() {
const uint row_dst = gl_GlobalInvocationID.x; const uint row_dst = gl_GlobalInvocationID.x;
if (i0 >= p.n_dims) {
const uint i = row_dst*ne0 + i0;
data_d[i + 0] = data_a[i + 0];
data_d[i + 1] = data_a[i + 1];
return;
}
const uint row_x = row_dst % ne1; const uint row_x = row_dst % ne1;
const uint channel_x = row_dst / ne1; const uint channel_x = row_dst / ne1;
const uint idst = row_dst*ne0 + i0; const uint idst = row_dst*ne0 + i0;
const uint ix = channel_x*p.s2 + row_x*p.s1 + i0; const uint ix = channel_x*p.s2 + row_x*p.s1 + i0;
if (i0 >= p.n_dims) {
data_d[idst + 0] = data_a[ix + 0];
data_d[idst + 1] = data_a[ix + 1];
return;
}
const float theta_base = data_pos[channel_x] * pow(p.theta_scale, i0/2.0f); const float theta_base = data_pos[channel_x] * pow(p.theta_scale, i0/2.0f);
const float freq_factor = p.has_ff != 0 ? data_ff[i0/2] : 1.0f; const float freq_factor = p.has_ff != 0 ? data_ff[i0/2] : 1.0f;

View File

@ -357,6 +357,8 @@ class MODEL_ARCH(IntEnum):
DOTS1 = auto() DOTS1 = auto()
ARCEE = auto() ARCEE = auto()
ERNIE4_5 = auto() ERNIE4_5 = auto()
HUNYUAN_MOE = auto()
SMOLLM3 = auto()
class VISION_PROJECTOR_TYPE(IntEnum): class VISION_PROJECTOR_TYPE(IntEnum):
@ -660,6 +662,8 @@ MODEL_ARCH_NAMES: dict[MODEL_ARCH, str] = {
MODEL_ARCH.DOTS1: "dots1", MODEL_ARCH.DOTS1: "dots1",
MODEL_ARCH.ARCEE: "arcee", MODEL_ARCH.ARCEE: "arcee",
MODEL_ARCH.ERNIE4_5: "ernie4_5", MODEL_ARCH.ERNIE4_5: "ernie4_5",
MODEL_ARCH.HUNYUAN_MOE: "hunyuan-moe",
MODEL_ARCH.SMOLLM3: "smollm3",
} }
VISION_PROJECTOR_TYPE_NAMES: dict[VISION_PROJECTOR_TYPE, str] = { VISION_PROJECTOR_TYPE_NAMES: dict[VISION_PROJECTOR_TYPE, str] = {
@ -2211,6 +2215,43 @@ MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = {
MODEL_TENSOR.FFN_DOWN, MODEL_TENSOR.FFN_DOWN,
MODEL_TENSOR.FFN_UP, MODEL_TENSOR.FFN_UP,
], ],
MODEL_ARCH.HUNYUAN_MOE: [
MODEL_TENSOR.TOKEN_EMBD,
MODEL_TENSOR.OUTPUT_NORM,
MODEL_TENSOR.OUTPUT,
MODEL_TENSOR.ROPE_FREQS,
MODEL_TENSOR.ATTN_NORM,
MODEL_TENSOR.ATTN_Q,
MODEL_TENSOR.ATTN_Q_NORM,
MODEL_TENSOR.ATTN_K,
MODEL_TENSOR.ATTN_K_NORM,
MODEL_TENSOR.ATTN_V,
MODEL_TENSOR.ATTN_OUT,
MODEL_TENSOR.FFN_GATE_INP,
MODEL_TENSOR.FFN_NORM,
MODEL_TENSOR.FFN_GATE_EXP,
MODEL_TENSOR.FFN_DOWN_EXP,
MODEL_TENSOR.FFN_UP_EXP,
MODEL_TENSOR.FFN_GATE_SHEXP,
MODEL_TENSOR.FFN_DOWN_SHEXP,
MODEL_TENSOR.FFN_UP_SHEXP,
],
MODEL_ARCH.SMOLLM3: [
MODEL_TENSOR.TOKEN_EMBD,
MODEL_TENSOR.OUTPUT_NORM,
MODEL_TENSOR.OUTPUT,
MODEL_TENSOR.ROPE_FREQS,
MODEL_TENSOR.ATTN_NORM,
MODEL_TENSOR.ATTN_Q,
MODEL_TENSOR.ATTN_K,
MODEL_TENSOR.ATTN_V,
MODEL_TENSOR.ATTN_OUT,
MODEL_TENSOR.ATTN_ROT_EMBD,
MODEL_TENSOR.FFN_NORM,
MODEL_TENSOR.FFN_GATE,
MODEL_TENSOR.FFN_DOWN,
MODEL_TENSOR.FFN_UP,
],
# TODO # TODO
} }

View File

@ -303,6 +303,7 @@ class TensorNameMap:
"model.layers.{bid}.block_sparse_moe.router.layer", # granitemoe "model.layers.{bid}.block_sparse_moe.router.layer", # granitemoe
"model.layers.{bid}.feed_forward.router", # llama4 "model.layers.{bid}.feed_forward.router", # llama4
"encoder.layers.{bid}.mlp.router.layer", # nomic-bert-moe "encoder.layers.{bid}.mlp.router.layer", # nomic-bert-moe
"model.layers.{bid}.mlp.gate.wg", # hunyuan
), ),
MODEL_TENSOR.FFN_GATE_INP_SHEXP: ( MODEL_TENSOR.FFN_GATE_INP_SHEXP: (
@ -362,6 +363,7 @@ class TensorNameMap:
"model.layers.{bid}.mlp.shared_expert.up_proj", # qwen2moe "model.layers.{bid}.mlp.shared_expert.up_proj", # qwen2moe
"model.layers.{bid}.mlp.shared_experts.up_proj", # deepseek deepseek2 "model.layers.{bid}.mlp.shared_experts.up_proj", # deepseek deepseek2
"model.layers.{bid}.feed_forward.shared_expert.up_proj", # llama4 "model.layers.{bid}.feed_forward.shared_expert.up_proj", # llama4
"model.layers.{bid}.mlp.shared_mlp.up_proj", # hunyuan
), ),
# AWQ-activation gate # AWQ-activation gate
@ -398,6 +400,7 @@ class TensorNameMap:
"model.layers.{bid}.mlp.shared_expert.gate_proj", # qwen2moe "model.layers.{bid}.mlp.shared_expert.gate_proj", # qwen2moe
"model.layers.{bid}.mlp.shared_experts.gate_proj", # deepseek deepseek2 "model.layers.{bid}.mlp.shared_experts.gate_proj", # deepseek deepseek2
"model.layers.{bid}.feed_forward.shared_expert.gate_proj", # llama4 "model.layers.{bid}.feed_forward.shared_expert.gate_proj", # llama4
"model.layers.{bid}.mlp.shared_mlp.gate_proj", # hunyuan
), ),
# Feed-forward down # Feed-forward down
@ -447,11 +450,13 @@ class TensorNameMap:
"model.layers.{bid}.mlp.shared_experts.down_proj", # deepseek deepseek2 "model.layers.{bid}.mlp.shared_experts.down_proj", # deepseek deepseek2
"model.layers.{bid}.feed_forward.shared_expert.down_proj", # llama4 "model.layers.{bid}.feed_forward.shared_expert.down_proj", # llama4
"model.layers.{bid}.shared_mlp.output_linear", # granitemoe "model.layers.{bid}.shared_mlp.output_linear", # granitemoe
"model.layers.{bid}.mlp.shared_mlp.down_proj", # hunyuan
), ),
MODEL_TENSOR.ATTN_Q_NORM: ( MODEL_TENSOR.ATTN_Q_NORM: (
"language_model.encoder.layers.{bid}.self_attention.q_layernorm", "language_model.encoder.layers.{bid}.self_attention.q_layernorm",
"model.layers.{bid}.self_attn.q_layernorm", # persimmon "model.layers.{bid}.self_attn.q_layernorm", # persimmon
"model.layers.{bid}.self_attn.query_layernorm", # hunyuan
"model.layers.{bid}.self_attn.q_norm", # cohere olmoe chameleon olmo2 "model.layers.{bid}.self_attn.q_norm", # cohere olmoe chameleon olmo2
"transformer.blocks.{bid}.attn.q_ln", # sea-lion "transformer.blocks.{bid}.attn.q_ln", # sea-lion
"encoder.layer.{bid}.attention.self.layer_norm_q", # jina-bert-v2 "encoder.layer.{bid}.attention.self.layer_norm_q", # jina-bert-v2
@ -461,6 +466,7 @@ class TensorNameMap:
MODEL_TENSOR.ATTN_K_NORM: ( MODEL_TENSOR.ATTN_K_NORM: (
"language_model.encoder.layers.{bid}.self_attention.k_layernorm", "language_model.encoder.layers.{bid}.self_attention.k_layernorm",
"model.layers.{bid}.self_attn.k_layernorm", # persimmon "model.layers.{bid}.self_attn.k_layernorm", # persimmon
"model.layers.{bid}.self_attn.key_layernorm", # hunyuan
"model.layers.{bid}.self_attn.k_norm", # cohere olmoe chameleon olmo2 "model.layers.{bid}.self_attn.k_norm", # cohere olmoe chameleon olmo2
"transformer.blocks.{bid}.attn.k_ln", # sea-lion "transformer.blocks.{bid}.attn.k_ln", # sea-lion
"encoder.layer.{bid}.attention.self.layer_norm_k", # jina-bert-v2 "encoder.layer.{bid}.attention.self.layer_norm_k", # jina-bert-v2

View File

@ -117,6 +117,7 @@ extern "C" {
LLAMA_VOCAB_PRE_TYPE_LLAMA4 = 33, LLAMA_VOCAB_PRE_TYPE_LLAMA4 = 33,
LLAMA_VOCAB_PRE_TYPE_PIXTRAL = 34, LLAMA_VOCAB_PRE_TYPE_PIXTRAL = 34,
LLAMA_VOCAB_PRE_TYPE_SEED_CODER = 35, LLAMA_VOCAB_PRE_TYPE_SEED_CODER = 35,
LLAMA_VOCAB_PRE_TYPE_HUNYUAN = 36,
}; };
enum llama_rope_type { enum llama_rope_type {

View File

@ -78,6 +78,8 @@ static const std::map<llm_arch, const char *> LLM_ARCH_NAMES = {
{ LLM_ARCH_DOTS1, "dots1" }, { LLM_ARCH_DOTS1, "dots1" },
{ LLM_ARCH_ARCEE, "arcee" }, { LLM_ARCH_ARCEE, "arcee" },
{ LLM_ARCH_ERNIE4_5, "ernie4_5" }, { LLM_ARCH_ERNIE4_5, "ernie4_5" },
{ LLM_ARCH_HUNYUAN_MOE, "hunyuan-moe" },
{ LLM_ARCH_SMOLLM3, "smollm3" },
{ LLM_ARCH_UNKNOWN, "(unknown)" }, { LLM_ARCH_UNKNOWN, "(unknown)" },
}; };
@ -1694,12 +1696,52 @@ static const std::map<llm_arch, std::map<llm_tensor, const char *>> LLM_TENSOR_N
{ LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" }, { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" },
}, },
}, },
{
LLM_ARCH_HUNYUAN_MOE,
{
{ LLM_TENSOR_TOKEN_EMBD, "token_embd" },
{ LLM_TENSOR_OUTPUT_NORM, "output_norm" },
{ LLM_TENSOR_OUTPUT, "output" },
{ LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" },
{ LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" },
{ LLM_TENSOR_ATTN_Q_NORM, "blk.%d.attn_q_norm" },
{ LLM_TENSOR_ATTN_K, "blk.%d.attn_k" },
{ LLM_TENSOR_ATTN_K_NORM, "blk.%d.attn_k_norm" },
{ LLM_TENSOR_ATTN_V, "blk.%d.attn_v" },
{ LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" },
{ LLM_TENSOR_FFN_GATE_INP, "blk.%d.ffn_gate_inp" },
{ LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" },
{ LLM_TENSOR_FFN_GATE_SHEXP, "blk.%d.ffn_gate_shexp" },
{ LLM_TENSOR_FFN_DOWN_SHEXP, "blk.%d.ffn_down_shexp" },
{ LLM_TENSOR_FFN_UP_SHEXP, "blk.%d.ffn_up_shexp" },
{ LLM_TENSOR_FFN_GATE_EXPS, "blk.%d.ffn_gate_exps" },
{ LLM_TENSOR_FFN_DOWN_EXPS, "blk.%d.ffn_down_exps" },
{ LLM_TENSOR_FFN_UP_EXPS, "blk.%d.ffn_up_exps" },
},
},
{ {
LLM_ARCH_UNKNOWN, LLM_ARCH_UNKNOWN,
{ {
{ LLM_TENSOR_TOKEN_EMBD, "token_embd" }, { LLM_TENSOR_TOKEN_EMBD, "token_embd" },
}, },
}, },
{
LLM_ARCH_SMOLLM3,
{
{ LLM_TENSOR_TOKEN_EMBD, "token_embd" },
{ LLM_TENSOR_OUTPUT_NORM, "output_norm" },
{ LLM_TENSOR_OUTPUT, "output" },
{ LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" },
{ LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" },
{ LLM_TENSOR_ATTN_K, "blk.%d.attn_k" },
{ LLM_TENSOR_ATTN_V, "blk.%d.attn_v" },
{ LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" },
{ LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" },
{ LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" },
{ LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" },
{ LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" },
},
},
}; };
static const std::map<llm_tensor, llm_tensor_info> LLM_TENSOR_INFOS = { static const std::map<llm_tensor, llm_tensor_info> LLM_TENSOR_INFOS = {

View File

@ -82,6 +82,8 @@ enum llm_arch {
LLM_ARCH_DOTS1, LLM_ARCH_DOTS1,
LLM_ARCH_ARCEE, LLM_ARCH_ARCEE,
LLM_ARCH_ERNIE4_5, LLM_ARCH_ERNIE4_5,
LLM_ARCH_HUNYUAN_MOE,
LLM_ARCH_SMOLLM3,
LLM_ARCH_UNKNOWN, LLM_ARCH_UNKNOWN,
}; };

View File

@ -64,6 +64,7 @@ static const std::map<std::string, llm_chat_template> LLM_CHAT_TEMPLATES = {
{ "bailing", LLM_CHAT_TEMPLATE_BAILING }, { "bailing", LLM_CHAT_TEMPLATE_BAILING },
{ "llama4", LLM_CHAT_TEMPLATE_LLAMA4 }, { "llama4", LLM_CHAT_TEMPLATE_LLAMA4 },
{ "smolvlm", LLM_CHAT_TEMPLATE_SMOLVLM }, { "smolvlm", LLM_CHAT_TEMPLATE_SMOLVLM },
{ "hunyuan-moe", LLM_CHAT_TEMPLATE_HUNYUAN_MOE },
}; };
llm_chat_template llm_chat_template_from_str(const std::string & name) { llm_chat_template llm_chat_template_from_str(const std::string & name) {
@ -185,6 +186,8 @@ llm_chat_template llm_chat_detect_template(const std::string & tmpl) {
return LLM_CHAT_TEMPLATE_LLAMA4; return LLM_CHAT_TEMPLATE_LLAMA4;
} else if (tmpl_contains("<|endofuserprompt|>")) { } else if (tmpl_contains("<|endofuserprompt|>")) {
return LLM_CHAT_TEMPLATE_DOTS1; return LLM_CHAT_TEMPLATE_DOTS1;
} else if (tmpl_contains("<|startoftext|>") && tmpl_contains("<|extra_4|>")) {
return LLM_CHAT_TEMPLATE_HUNYUAN_MOE;
} }
return LLM_CHAT_TEMPLATE_UNKNOWN; return LLM_CHAT_TEMPLATE_UNKNOWN;
} }
@ -665,6 +668,18 @@ int32_t llm_chat_apply_template(
if (add_ass) { if (add_ass) {
ss << "<|response|>"; ss << "<|response|>";
} }
} else if (tmpl == LLM_CHAT_TEMPLATE_HUNYUAN_MOE) {
// tencent/Hunyuan-A13B-Instruct
for (auto message : chat) {
std::string role(message->role);
if (role == "system") {
ss << "<|startoftext|>" << message->content << "<|extra_4|>";
} else if (role == "assistant") {
ss << "<|startoftext|>" << message->content << "<|eos|>";
} else {
ss << "<|startoftext|>" << message->content << "<|extra_0|>";
}
}
} else { } else {
// template not supported // template not supported
return -1; return -1;

View File

@ -44,6 +44,7 @@ enum llm_chat_template {
LLM_CHAT_TEMPLATE_LLAMA4, LLM_CHAT_TEMPLATE_LLAMA4,
LLM_CHAT_TEMPLATE_SMOLVLM, LLM_CHAT_TEMPLATE_SMOLVLM,
LLM_CHAT_TEMPLATE_DOTS1, LLM_CHAT_TEMPLATE_DOTS1,
LLM_CHAT_TEMPLATE_HUNYUAN_MOE,
LLM_CHAT_TEMPLATE_UNKNOWN, LLM_CHAT_TEMPLATE_UNKNOWN,
}; };

View File

@ -102,6 +102,7 @@ const char * llm_type_name(llm_type type) {
case LLM_TYPE_57B_A14B: return "57B.A14B"; case LLM_TYPE_57B_A14B: return "57B.A14B";
case LLM_TYPE_17B_16E: return "17Bx16E (Scout)"; case LLM_TYPE_17B_16E: return "17Bx16E (Scout)";
case LLM_TYPE_17B_128E: return "17Bx128E (Maverick)"; case LLM_TYPE_17B_128E: return "17Bx128E (Maverick)";
case LLM_TYPE_A13B: return "A13B";
case LLM_TYPE_30B_A3B: return "30B.A3B"; case LLM_TYPE_30B_A3B: return "30B.A3B";
case LLM_TYPE_235B_A22B: return "235B.A22B"; case LLM_TYPE_235B_A22B: return "235B.A22B";
case LLM_TYPE_E2B: return "E2B"; case LLM_TYPE_E2B: return "E2B";
@ -1551,6 +1552,27 @@ void llama_model::load_hparams(llama_model_loader & ml) {
default: type = LLM_TYPE_UNKNOWN; default: type = LLM_TYPE_UNKNOWN;
} }
} break; } break;
case LLM_ARCH_HUNYUAN_MOE:
{
ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
ml.get_key(LLM_KV_EXPERT_FEED_FORWARD_LENGTH, hparams.n_ff_exp);
ml.get_key(LLM_KV_EXPERT_SHARED_FEED_FORWARD_LENGTH, hparams.n_ff_shexp);
switch (hparams.n_layer) {
case 32: type = LLM_TYPE_A13B; break;
default: type = LLM_TYPE_UNKNOWN;
}
} break;
case LLM_ARCH_SMOLLM3:
{
ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
hparams.n_no_rope_layer_step = 4;
switch (hparams.n_layer) {
case 36: type = LLM_TYPE_3B; break;
default: type = LLM_TYPE_UNKNOWN;
}
} break;
default: throw std::runtime_error("unsupported model architecture"); default: throw std::runtime_error("unsupported model architecture");
} }
@ -4471,6 +4493,72 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
layer.bv = create_tensor(tn(LLM_TENSOR_ATTN_V, "bias", i), {n_embd_gqa}, TENSOR_NOT_REQUIRED); layer.bv = create_tensor(tn(LLM_TENSOR_ATTN_V, "bias", i), {n_embd_gqa}, TENSOR_NOT_REQUIRED);
layer.bo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd}, TENSOR_NOT_REQUIRED); layer.bo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd}, TENSOR_NOT_REQUIRED);
layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0);
layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0);
layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, 0);
layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0);
}
} break;
case LLM_ARCH_HUNYUAN_MOE:
{
tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
// output
output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0);
output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, TENSOR_NOT_REQUIRED);
// if output is NULL, init from the input tok embed
if (output == NULL) {
output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED);
}
for (int i = 0; i < n_layer; ++i) {
auto & layer = layers[i];
layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0);
layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd_head_k * n_head}, 0);
layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_k_gqa}, 0);
layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_v_gqa}, 0);
layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_k * n_head, n_embd}, 0);
layer.attn_k_norm = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), {n_embd_head_k}, 0);
layer.attn_q_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), {n_embd_head_k}, 0);
layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0);
layer.ffn_gate_inp = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), {n_embd, n_expert}, 0);
layer.ffn_gate_exps = create_tensor(tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i), {n_embd, n_ff, n_expert}, 0);
layer.ffn_down_exps = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), { n_ff, n_embd, n_expert}, 0);
layer.ffn_up_exps = create_tensor(tn(LLM_TENSOR_FFN_UP_EXPS, "weight", i), {n_embd, n_ff, n_expert}, 0);
layer.ffn_gate_shexp = create_tensor(tn(LLM_TENSOR_FFN_GATE_SHEXP, "weight", i), {n_embd, hparams.n_ff_shexp}, 0);
layer.ffn_up_shexp = create_tensor(tn(LLM_TENSOR_FFN_UP_SHEXP, "weight", i), {n_embd, hparams.n_ff_shexp}, 0);
layer.ffn_down_shexp = create_tensor(tn(LLM_TENSOR_FFN_DOWN_SHEXP, "weight", i), {hparams.n_ff_shexp, n_embd}, 0);
}
} break;
case LLM_ARCH_SMOLLM3:
{
tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
// output
output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0);
output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, TENSOR_NOT_REQUIRED);
// if output is NULL, init from the input tok embed
if (output == NULL) {
output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED);
}
for (int i = 0; i < n_layer; ++i) {
auto & layer = layers[i];
layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0);
layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd_head_k * n_head}, 0);
layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_k_gqa}, 0);
layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_v_gqa}, 0);
layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_k * n_head, n_embd}, 0);
layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0);
layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0); layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0);
layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, 0); layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, 0);
@ -14647,6 +14735,304 @@ struct llm_build_arcee : public llm_graph_context {
} }
}; };
struct llm_build_hunyuan_moe : public llm_graph_context {
llm_build_hunyuan_moe(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params) {
const int64_t n_embd_head = hparams.n_embd_head_v;
GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
GGML_ASSERT(n_embd_head == hparams.n_rot);
ggml_tensor * cur;
ggml_tensor * inpL;
inpL = build_inp_embd(model.tok_embd);
// inp_pos - contains the positions
ggml_tensor * inp_pos = build_inp_pos();
auto * inp_attn = build_attn_inp_kv_unified();
const float kq_scale = 1.0f / sqrtf(float(n_embd_head));
ggml_tensor * inp_out_ids = build_inp_out_ids();
for (int il = 0; il < n_layer; ++il) {
ggml_tensor * inpSA = inpL;
// norm
cur = build_norm(inpL,
model.layers[il].attn_norm, NULL,
LLM_NORM_RMS, il);
cb(cur, "attn_norm", il);
// self-attention
{
// rope freq factors for llama3; may return nullptr for llama2 and other models
ggml_tensor * rope_factors = model.get_rope_factors(cparams, il);
// compute Q and K and RoPE them
ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur);
cb(Qcur, "Qcur", il);
if (model.layers[il].bq) {
Qcur = ggml_add(ctx0, Qcur, model.layers[il].bq);
cb(Qcur, "Qcur", il);
}
ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur);
cb(Kcur, "Kcur", il);
if (model.layers[il].bk) {
Kcur = ggml_add(ctx0, Kcur, model.layers[il].bk);
cb(Kcur, "Kcur", il);
}
ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur);
cb(Vcur, "Vcur", il);
if (model.layers[il].bv) {
Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv);
cb(Vcur, "Vcur", il);
}
Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens);
Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens);
Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens);
Qcur = ggml_rope_ext(
ctx0, Qcur, inp_pos, rope_factors,
n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
ext_factor, attn_factor, beta_fast, beta_slow
);
cb(Qcur, "Qcur", il);
cb(Kcur, "Kcur", il);
cb(Vcur, "Vcur", il);
Kcur = ggml_rope_ext(
ctx0, Kcur, inp_pos, rope_factors,
n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
ext_factor, attn_factor, beta_fast, beta_slow
);
Kcur = build_norm(Kcur,
model.layers[il].attn_k_norm, nullptr,
LLM_NORM_RMS, il);
cb(Kcur, "Kcur_norm", il);
Qcur = build_norm(Qcur,
model.layers[il].attn_q_norm, nullptr,
LLM_NORM_RMS, il);
cb(Qcur, "Qcur_norm", il);
cur = build_attn(inp_attn, gf,
model.layers[il].wo, model.layers[il].bo,
Qcur, Kcur, Vcur, nullptr, nullptr, kq_scale, il);
cb(cur, "attn_out", il);
}
if (il == n_layer - 1 && inp_out_ids) {
cur = ggml_get_rows(ctx0, cur, inp_out_ids);
inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
}
ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA);
cb(ffn_inp, "ffn_inp", il);
cur = build_norm(ffn_inp,
model.layers[il].ffn_norm, NULL,
LLM_NORM_RMS, il);
cb(cur, "ffn_norm", il);
// feed-forward network (non-MoE)
ggml_tensor * cur_mlp = build_ffn(cur,
model.layers[il].ffn_up_shexp, NULL, NULL,
model.layers[il].ffn_gate_shexp, NULL, NULL,
model.layers[il].ffn_down_shexp, NULL, NULL,
NULL,
LLM_FFN_SILU, LLM_FFN_PAR, il);
cb(cur_mlp, "ffn_mlp", il);
// MoE branch
ggml_tensor * cur_moe = build_moe_ffn(cur,
model.layers[il].ffn_gate_inp,
model.layers[il].ffn_up_exps,
model.layers[il].ffn_gate_exps,
model.layers[il].ffn_down_exps,
nullptr,
n_expert, n_expert_used,
LLM_FFN_SILU,
true, // norm_topk_prob
false,
0.0,
LLAMA_EXPERT_GATING_FUNC_TYPE_SOFTMAX,
il);
cb(cur_moe, "ffn_moe_out", il);
ggml_tensor * ffn_out = ggml_add(ctx0, cur_moe, cur_mlp);
cb(ffn_out, "ffn_out", il);
cur = ggml_add(ctx0, ffn_out, ffn_inp);
cur = build_cvec(cur, il);
cb(cur, "l_out", il);
// input for next layer
inpL = cur;
}
cur = inpL;
cur = build_norm(cur,
model.output_norm, NULL,
LLM_NORM_RMS, -1);
cb(cur, "result_norm", -1);
res->t_embd = cur;
// lm_head
cur = build_lora_mm(model.output, cur);
cb(cur, "result_output", -1);
res->t_logits = cur;
ggml_build_forward_expand(gf, cur);
}
};
struct llm_build_smollm3 : public llm_graph_context {
llm_build_smollm3(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params) {
const int64_t n_embd_head = hparams.n_embd_head_v;
GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
GGML_ASSERT(n_embd_head == hparams.n_rot);
ggml_tensor * cur;
ggml_tensor * inpL;
inpL = build_inp_embd(model.tok_embd);
// inp_pos - contains the positions
ggml_tensor * inp_pos = build_inp_pos();
auto * inp_attn = build_attn_inp_kv_unified();
const float kq_scale = hparams.f_attention_scale == 0.0f ? 1.0f/sqrtf(float(n_embd_head)) : hparams.f_attention_scale;
ggml_tensor * inp_out_ids = build_inp_out_ids();
for (int il = 0; il < n_layer; ++il) {
ggml_tensor * inpSA = inpL;
const bool use_rope = (il + 1) % hparams.n_no_rope_layer_step != 0;
// norm
cur = build_norm(inpL,
model.layers[il].attn_norm, NULL,
LLM_NORM_RMS, il);
cb(cur, "attn_norm", il);
// self-attention
{
// compute Q and K and RoPE them
ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur);
cb(Qcur, "Qcur", il);
if (model.layers[il].bq) {
Qcur = ggml_add(ctx0, Qcur, model.layers[il].bq);
cb(Qcur, "Qcur", il);
}
ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur);
cb(Kcur, "Kcur", il);
if (model.layers[il].bk) {
Kcur = ggml_add(ctx0, Kcur, model.layers[il].bk);
cb(Kcur, "Kcur", il);
}
ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur);
cb(Vcur, "Vcur", il);
if (model.layers[il].bv) {
Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv);
cb(Vcur, "Vcur", il);
}
Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens);
Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens);
Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens);
if (use_rope) {
Qcur = ggml_rope_ext(
ctx0, Qcur, inp_pos, nullptr,
n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
ext_factor, attn_factor, beta_fast, beta_slow
);
Kcur = ggml_rope_ext(
ctx0, Kcur, inp_pos, nullptr,
n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
ext_factor, attn_factor, beta_fast, beta_slow
);
}
cb(Qcur, "Qcur", il);
cb(Kcur, "Kcur", il);
cb(Vcur, "Vcur", il);
cur = build_attn(inp_attn, gf,
model.layers[il].wo, model.layers[il].bo,
Qcur, Kcur, Vcur, nullptr, nullptr, kq_scale, il);
cb(cur, "attn_out", il);
}
if (il == n_layer - 1 && inp_out_ids) {
cur = ggml_get_rows(ctx0, cur, inp_out_ids);
inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
}
ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA);
cb(ffn_inp, "ffn_inp", il);
// feed-forward network
{
cur = build_norm(ffn_inp,
model.layers[il].ffn_norm, NULL,
LLM_NORM_RMS, il);
cb(cur, "ffn_norm", il);
cur = build_ffn(cur,
model.layers[il].ffn_up, model.layers[il].ffn_up_b, NULL,
model.layers[il].ffn_gate, model.layers[il].ffn_gate_b, NULL,
model.layers[il].ffn_down, model.layers[il].ffn_down_b, NULL,
NULL,
LLM_FFN_SILU, LLM_FFN_PAR, il);
cb(cur, "ffn_out", il);
}
cur = ggml_add(ctx0, cur, ffn_inp);
cb(cur, "ffn_out", il);
cur = build_cvec(cur, il);
cb(cur, "l_out", il);
// input for next layer
inpL = cur;
}
cur = inpL;
cur = build_norm(cur,
model.output_norm, NULL,
LLM_NORM_RMS, -1);
cb(cur, "result_norm", -1);
res->t_embd = cur;
// lm_head
cur = build_lora_mm(model.output, cur);
cb(cur, "result_output", -1);
res->t_logits = cur;
ggml_build_forward_expand(gf, cur);
}
};
llama_memory_i * llama_model::create_memory(const llama_memory_params & params, llama_cparams & cparams) const { llama_memory_i * llama_model::create_memory(const llama_memory_params & params, llama_cparams & cparams) const {
llama_memory_i * res; llama_memory_i * res;
@ -15027,6 +15413,14 @@ llm_graph_result_ptr llama_model::build_graph(
{ {
llm = std::make_unique<llm_build_ernie4_5>(*this, params, gf); llm = std::make_unique<llm_build_ernie4_5>(*this, params, gf);
} break; } break;
case LLM_ARCH_HUNYUAN_MOE:
{
llm = std::make_unique<llm_build_hunyuan_moe>(*this, params, gf);
} break;
case LLM_ARCH_SMOLLM3:
{
llm = std::make_unique<llm_build_smollm3>(*this, params, gf);
} break;
default: default:
GGML_ABORT("fatal error"); GGML_ABORT("fatal error");
} }
@ -15178,6 +15572,7 @@ llama_rope_type llama_model_rope_type(const llama_model * model) {
case LLM_ARCH_CHAMELEON: case LLM_ARCH_CHAMELEON:
case LLM_ARCH_BAILINGMOE: case LLM_ARCH_BAILINGMOE:
case LLM_ARCH_NEO_BERT: case LLM_ARCH_NEO_BERT:
case LLM_ARCH_SMOLLM3:
case LLM_ARCH_ARCEE: case LLM_ARCH_ARCEE:
case LLM_ARCH_ERNIE4_5: case LLM_ARCH_ERNIE4_5:
return LLAMA_ROPE_TYPE_NORM; return LLAMA_ROPE_TYPE_NORM;
@ -15215,6 +15610,7 @@ llama_rope_type llama_model_rope_type(const llama_model * model) {
case LLM_ARCH_EXAONE: case LLM_ARCH_EXAONE:
case LLM_ARCH_MINICPM3: case LLM_ARCH_MINICPM3:
case LLM_ARCH_DOTS1: case LLM_ARCH_DOTS1:
case LLM_ARCH_HUNYUAN_MOE:
return LLAMA_ROPE_TYPE_NEOX; return LLAMA_ROPE_TYPE_NEOX;
case LLM_ARCH_QWEN2VL: case LLM_ARCH_QWEN2VL:

View File

@ -94,6 +94,7 @@ enum llm_type {
LLM_TYPE_57B_A14B, LLM_TYPE_57B_A14B,
LLM_TYPE_17B_16E, // llama4 Scout LLM_TYPE_17B_16E, // llama4 Scout
LLM_TYPE_17B_128E, // llama4 Maverick LLM_TYPE_17B_128E, // llama4 Maverick
LLM_TYPE_A13B,
LLM_TYPE_30B_A3B, LLM_TYPE_30B_A3B,
LLM_TYPE_235B_A22B, LLM_TYPE_235B_A22B,
LLM_TYPE_E2B, LLM_TYPE_E2B,

View File

@ -351,6 +351,7 @@ struct llm_tokenizer_bpe : llm_tokenizer {
break; break;
case LLAMA_VOCAB_PRE_TYPE_STABLELM2: case LLAMA_VOCAB_PRE_TYPE_STABLELM2:
case LLAMA_VOCAB_PRE_TYPE_QWEN2: case LLAMA_VOCAB_PRE_TYPE_QWEN2:
case LLAMA_VOCAB_PRE_TYPE_HUNYUAN:
regex_exprs = { regex_exprs = {
// original regex from tokenizer.json // original regex from tokenizer.json
// "(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+" // "(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+"
@ -1656,6 +1657,10 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) {
tokenizer_pre == "seed-coder") { tokenizer_pre == "seed-coder") {
pre_type = LLAMA_VOCAB_PRE_TYPE_SEED_CODER; pre_type = LLAMA_VOCAB_PRE_TYPE_SEED_CODER;
clean_spaces = false; clean_spaces = false;
} else if (
tokenizer_pre == "hunyuan") {
pre_type = LLAMA_VOCAB_PRE_TYPE_HUNYUAN;
clean_spaces = false;
} else { } else {
throw std::runtime_error(format("unknown pre-tokenizer type: '%s'", tokenizer_pre.c_str())); throw std::runtime_error(format("unknown pre-tokenizer type: '%s'", tokenizer_pre.c_str()));
} }

View File

@ -5323,12 +5323,12 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
for (bool fw : {true, false}) { // fw == forward for (bool fw : {true, false}) { // fw == forward
bool all = true; bool all = true;
for (float v : { 0, 1 }) { for (float fs : { 1.0f, 1.4245f }) {
for (float fs : { 1.0f, 1.4245f }) { for (float ef : { 0.0f, 0.7465f }) {
for (float ef : { 0.0f, 0.7465f }) { for (float af : { 1.0f, 1.4245f }) {
for (float af : { 1.0f, 1.4245f }) { for (ggml_type type : {GGML_TYPE_F32, GGML_TYPE_F16}) {
for (ggml_type type : {GGML_TYPE_F32, GGML_TYPE_F16}) { for (bool ff : {false, true}) { // freq_factors
for (bool ff : {false, true}) { // freq_factors for (float v : { 0, 1 }) {
test_cases.emplace_back(new test_rope(type, {128, 32, 2, 1}, 128, 0, 512, fs, ef, af, ff, v, fw)); // llama 7B test_cases.emplace_back(new test_rope(type, {128, 32, 2, 1}, 128, 0, 512, fs, ef, af, ff, v, fw)); // llama 7B
if (all) { if (all) {
@ -5341,13 +5341,21 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
test_cases.emplace_back(new test_rope(type, { 64, 1, 2, 1}, 64, 2, 512, fs, ef, af, ff, v, fw)); // neox (falcon 7B) test_cases.emplace_back(new test_rope(type, { 64, 1, 2, 1}, 64, 2, 512, fs, ef, af, ff, v, fw)); // neox (falcon 7B)
test_cases.emplace_back(new test_rope(type, { 64, 71, 2, 1}, 64, 2, 512, fs, ef, af, ff, v, fw)); // neox (falcon 7B) test_cases.emplace_back(new test_rope(type, { 64, 71, 2, 1}, 64, 2, 512, fs, ef, af, ff, v, fw)); // neox (falcon 7B)
test_cases.emplace_back(new test_rope(type, { 64, 8, 2, 1}, 64, 2, 512, fs, ef, af, ff, v, fw)); // neox (falcon 40B) test_cases.emplace_back(new test_rope(type, { 64, 8, 2, 1}, 64, 2, 512, fs, ef, af, ff, v, fw)); // neox (falcon 40B)
test_cases.emplace_back(new test_rope(type, { 80, 32, 2, 1}, 20, 0, 512, fs, ef, af, ff, v, fw));
test_cases.emplace_back(new test_rope(type, { 80, 32, 2, 1}, 32, 0, 512, fs, ef, af, ff, v, fw));
test_cases.emplace_back(new test_rope(type, { 80, 32, 4, 1}, 32, 0, 512, fs, ef, af, ff, v, fw));
test_cases.emplace_back(new test_rope(type, { 80, 32, 2, 1}, 20, 2, 512, fs, ef, af, ff, v, fw)); // neox (stablelm) test_cases.emplace_back(new test_rope(type, { 80, 32, 2, 1}, 20, 2, 512, fs, ef, af, ff, v, fw)); // neox (stablelm)
test_cases.emplace_back(new test_rope(type, { 80, 32, 2, 1}, 32, 2, 512, fs, ef, af, ff, v, fw)); // neox (phi-2) test_cases.emplace_back(new test_rope(type, { 80, 32, 2, 1}, 32, 2, 512, fs, ef, af, ff, v, fw)); // neox (phi-2)
test_cases.emplace_back(new test_rope(type, { 80, 32, 4, 1}, 32, 2, 512, fs, ef, af, ff, v, fw)); // neox (phi-2)
} }
if (all) { if (all) {
test_cases.emplace_back(new test_rope(type, {128, 12, 2, 1}, 128, GGML_ROPE_TYPE_MROPE, 512, fs, ef, af, ff, v, fw)); // rope_multi,m-rope (qwen2vl 2B) test_cases.emplace_back(new test_rope(type, {128, 12, 2, 1}, 128, GGML_ROPE_TYPE_MROPE, 512, fs, ef, af, ff, v, fw)); // rope_multi,m-rope (qwen2vl 2B)
test_cases.emplace_back(new test_rope(type, {128, 28, 2, 1}, 128, GGML_ROPE_TYPE_MROPE, 512, fs, ef, af, ff, v, fw)); // rope_multi,m-rope (qwen2vl 7B) test_cases.emplace_back(new test_rope(type, {128, 28, 2, 1}, 128, GGML_ROPE_TYPE_MROPE, 512, fs, ef, af, ff, v, fw)); // rope_multi,m-rope (qwen2vl 7B)
test_cases.emplace_back(new test_rope(type, {128, 12, 2, 1}, 20, GGML_ROPE_TYPE_MROPE, 512, fs, ef, af, ff, v, fw));
test_cases.emplace_back(new test_rope(type, {128, 28, 2, 1}, 32, GGML_ROPE_TYPE_MROPE, 512, fs, ef, af, ff, v, fw));
test_cases.emplace_back(new test_rope(type, { 80, 16, 2, 1}, 80, GGML_ROPE_TYPE_VISION, 512, fs, ef, af, ff, v, fw)); // rope_multi,m-rope (qwen2vl ViT) test_cases.emplace_back(new test_rope(type, { 80, 16, 2, 1}, 80, GGML_ROPE_TYPE_VISION, 512, fs, ef, af, ff, v, fw)); // rope_multi,m-rope (qwen2vl ViT)
} }

View File

@ -4806,14 +4806,14 @@ int main(int argc, char ** argv) {
// register static assets routes // register static assets routes
if (!params.public_path.empty()) { if (!params.public_path.empty()) {
// Set the base directory for serving static files // Set the base directory for serving static files
bool is_found = svr->set_mount_point("/", params.public_path); bool is_found = svr->set_mount_point(params.api_prefix + "/", params.public_path);
if (!is_found) { if (!is_found) {
LOG_ERR("%s: static assets path not found: %s\n", __func__, params.public_path.c_str()); LOG_ERR("%s: static assets path not found: %s\n", __func__, params.public_path.c_str());
return 1; return 1;
} }
} else { } else {
// using embedded static index.html // using embedded static index.html
svr->Get("/", [](const httplib::Request & req, httplib::Response & res) { svr->Get(params.api_prefix + "/", [](const httplib::Request & req, httplib::Response & res) {
if (req.get_header_value("Accept-Encoding").find("gzip") == std::string::npos) { if (req.get_header_value("Accept-Encoding").find("gzip") == std::string::npos) {
res.set_content("Error: gzip is not supported by this browser", "text/plain"); res.set_content("Error: gzip is not supported by this browser", "text/plain");
} else { } else {
@ -4829,37 +4829,37 @@ int main(int argc, char ** argv) {
} }
// register API routes // register API routes
svr->Get ("/health", handle_health); // public endpoint (no API key check) svr->Get (params.api_prefix + "/health", handle_health); // public endpoint (no API key check)
svr->Get ("/metrics", handle_metrics); svr->Get (params.api_prefix + "/metrics", handle_metrics);
svr->Get ("/props", handle_props); svr->Get (params.api_prefix + "/props", handle_props);
svr->Post("/props", handle_props_change); svr->Post(params.api_prefix + "/props", handle_props_change);
svr->Post("/api/show", handle_api_show); svr->Post(params.api_prefix + "/api/show", handle_api_show);
svr->Get ("/models", handle_models); // public endpoint (no API key check) svr->Get (params.api_prefix + "/models", handle_models); // public endpoint (no API key check)
svr->Get ("/v1/models", handle_models); // public endpoint (no API key check) svr->Get (params.api_prefix + "/v1/models", handle_models); // public endpoint (no API key check)
svr->Get ("/api/tags", handle_models); // ollama specific endpoint. public endpoint (no API key check) svr->Get (params.api_prefix + "/api/tags", handle_models); // ollama specific endpoint. public endpoint (no API key check)
svr->Post("/completion", handle_completions); // legacy svr->Post(params.api_prefix + "/completion", handle_completions); // legacy
svr->Post("/completions", handle_completions); svr->Post(params.api_prefix + "/completions", handle_completions);
svr->Post("/v1/completions", handle_completions_oai); svr->Post(params.api_prefix + "/v1/completions", handle_completions_oai);
svr->Post("/chat/completions", handle_chat_completions); svr->Post(params.api_prefix + "/chat/completions", handle_chat_completions);
svr->Post("/v1/chat/completions", handle_chat_completions); svr->Post(params.api_prefix + "/v1/chat/completions", handle_chat_completions);
svr->Post("/api/chat", handle_chat_completions); // ollama specific endpoint svr->Post(params.api_prefix + "/api/chat", handle_chat_completions); // ollama specific endpoint
svr->Post("/infill", handle_infill); svr->Post(params.api_prefix + "/infill", handle_infill);
svr->Post("/embedding", handle_embeddings); // legacy svr->Post(params.api_prefix + "/embedding", handle_embeddings); // legacy
svr->Post("/embeddings", handle_embeddings); svr->Post(params.api_prefix + "/embeddings", handle_embeddings);
svr->Post("/v1/embeddings", handle_embeddings_oai); svr->Post(params.api_prefix + "/v1/embeddings", handle_embeddings_oai);
svr->Post("/rerank", handle_rerank); svr->Post(params.api_prefix + "/rerank", handle_rerank);
svr->Post("/reranking", handle_rerank); svr->Post(params.api_prefix + "/reranking", handle_rerank);
svr->Post("/v1/rerank", handle_rerank); svr->Post(params.api_prefix + "/v1/rerank", handle_rerank);
svr->Post("/v1/reranking", handle_rerank); svr->Post(params.api_prefix + "/v1/reranking", handle_rerank);
svr->Post("/tokenize", handle_tokenize); svr->Post(params.api_prefix + "/tokenize", handle_tokenize);
svr->Post("/detokenize", handle_detokenize); svr->Post(params.api_prefix + "/detokenize", handle_detokenize);
svr->Post("/apply-template", handle_apply_template); svr->Post(params.api_prefix + "/apply-template", handle_apply_template);
// LoRA adapters hotswap // LoRA adapters hotswap
svr->Get ("/lora-adapters", handle_lora_adapters_list); svr->Get (params.api_prefix + "/lora-adapters", handle_lora_adapters_list);
svr->Post("/lora-adapters", handle_lora_adapters_apply); svr->Post(params.api_prefix + "/lora-adapters", handle_lora_adapters_apply);
// Save & load slots // Save & load slots
svr->Get ("/slots", handle_slots); svr->Get (params.api_prefix + "/slots", handle_slots);
svr->Post("/slots/:id_slot", handle_slots_action); svr->Post(params.api_prefix + "/slots/:id_slot", handle_slots_action);
// //
// Start the server // Start the server