Merge branch 'master' into compilade/test-model-random

commit 18d2055124
@@ -342,7 +342,7 @@ jobs:
           cd build
           export GGML_VK_VISIBLE_DEVICES=0
           # This is using llvmpipe and runs slower than other backends
-          ctest -L main --verbose --timeout 3600
+          ctest -L main --verbose --timeout 4200

   ubuntu-22-cmake-hip:
     runs-on: ubuntu-22.04
@@ -2734,6 +2734,13 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
             params.public_path = value;
         }
     ).set_examples({LLAMA_EXAMPLE_SERVER}).set_env("LLAMA_ARG_STATIC_PATH"));
+    add_opt(common_arg(
+        {"--api-prefix"}, "PREFIX",
+        string_format("prefix path the server serves from, without the trailing slash (default: %s)", params.api_prefix.c_str()),
+        [](common_params & params, const std::string & value) {
+            params.api_prefix = value;
+        }
+    ).set_examples({LLAMA_EXAMPLE_SERVER}).set_env("LLAMA_ARG_API_PREFIX"));
     add_opt(common_arg(
         {"--no-webui"},
         string_format("Disable the Web UI (default: %s)", params.webui ? "enabled" : "disabled"),

@@ -370,6 +370,7 @@ struct common_params {

     std::string hostname = "127.0.0.1";
     std::string public_path = ""; // NOLINT
+    std::string api_prefix = ""; // NOLINT
     std::string chat_template = ""; // NOLINT
     bool use_jinja = false; // NOLINT
     bool enable_chat_template = true;
@@ -815,6 +815,9 @@ class TextModel(ModelBase):
        if chkhsh == "1431a23e583c97432bc230bff598d103ddb5a1f89960c8f1d1051aaa944d0b35":
            # ref: https://huggingface.co/sapienzanlp/Minerva-7B-base-v1.0
            res = "minerva-7b"
+       if chkhsh == "7e57df22b1fe23a7b1e1c7f3dc4e3f96d43a4eb0836d0c6bdc3436d7b2f1c664":
+           # ref: https://huggingface.co/tencent/Hunyuan-A13B-Instruct
+           res = "hunyuan"

        if res is None:
            logger.warning("\n")
@@ -6535,6 +6538,160 @@ class UltravoxWhisperEncoderModel(WhisperEncoderModel):
         super().set_gguf_parameters()
         self.gguf_writer.add_audio_stack_factor(self.global_config["stack_factor"])


+@ModelBase.register("HunYuanMoEV1ForCausalLM")
+class HunYuanMoEModel(TextModel):
+    model_arch = gguf.MODEL_ARCH.HUNYUAN_MOE
+
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        # For handling tied embeddings
+        self._tok_embd = None
+
+    def set_vocab(self):
+        from transformers import AutoTokenizer
+        tokenizer = AutoTokenizer.from_pretrained(self.dir_model, trust_remote_code=True)
+
+        # 1. Get the pre-tokenizer identifier hash
+        tokpre = self.get_vocab_base_pre(tokenizer)
+
+        # 2. Reverse-engineer the merges list from mergeable_ranks
+        merges = []
+        vocab = {}
+        mergeable_ranks = tokenizer.mergeable_ranks
+        for token, rank in mergeable_ranks.items():
+            vocab[QwenModel.token_bytes_to_string(token)] = rank
+            if len(token) == 1:
+                continue
+            merged = QwenModel.bpe(mergeable_ranks, token, max_rank=rank)
+            if len(merged) == 2: # todo this is an assert in Qwen, why?
+                merges.append(' '.join(map(QwenModel.token_bytes_to_string, merged)))
+
+        # 3. Generate the tokens and toktypes lists
+        vocab_size = self.hparams["vocab_size"]
+        assert tokenizer.vocab_size == vocab_size
+        special_tokens = tokenizer.special_tokens
+        reverse_vocab = {id_ : encoded_tok for encoded_tok, id_ in {**vocab, **special_tokens}.items()}
+        tokens: list[str] = []
+        toktypes: list[int] = []
+        for i in range(vocab_size):
+            if i not in reverse_vocab:
+                tokens.append(f"[PAD{i}]")
+                toktypes.append(gguf.TokenType.UNUSED)
+            else:
+                token = reverse_vocab[i]
+                tokens.append(token)
+                if i in special_tokens.values():
+                    toktypes.append(gguf.TokenType.CONTROL)
+                else:
+                    toktypes.append(gguf.TokenType.NORMAL)
+
+        # 4. Write all vocab-related fields to the GGUF writer
+        self.gguf_writer.add_tokenizer_model("gpt2")
+        self.gguf_writer.add_tokenizer_pre(tokpre)
+        self.gguf_writer.add_token_list(tokens)
+        self.gguf_writer.add_token_types(toktypes)
+        self.gguf_writer.add_token_merges(merges)
+
+        # 5. Add special tokens and chat templates
+        special_vocab = gguf.SpecialVocab(self.dir_model, load_merges=False)
+        special_vocab.add_to_gguf(self.gguf_writer)
+        # FIX for BOS token: Overwrite incorrect id read from config.json
+        self.gguf_writer.add_bos_token_id(127959) # <|bos|>
+
+    def set_gguf_parameters(self):
+        super().set_gguf_parameters()
+        hparams = self.hparams
+
+        self.gguf_writer.add_expert_count(hparams["num_experts"])
+        self.gguf_writer.add_expert_shared_feed_forward_length(hparams["intermediate_size"])
+
+        moe_intermediate_size = hparams["moe_intermediate_size"]
+        assert all(n == moe_intermediate_size[0] for n in moe_intermediate_size)
+        self.gguf_writer.add_expert_feed_forward_length(moe_intermediate_size[0])
+
+        moe_topk = hparams["moe_topk"]
+        assert all(topk == moe_topk[0] for topk in moe_topk)
+        self.gguf_writer.add_expert_used_count(moe_topk[0])
+
+        moe_shared_expert = hparams["num_shared_expert"]
+        assert all(n == moe_shared_expert[0] for n in moe_shared_expert)
+        self.gguf_writer.add_expert_shared_count(moe_shared_expert[0])
+
+        # Rope
+        rope_scaling = hparams.get("rope_scaling", {})
+        if rope_scaling.get("type") == "dynamic":
+            # HunYuan uses NTK Aware Alpha based scaling. Original implementation: https://www.reddit.com/r/LocalLLaMA/comments/14lz7j5/ntkaware_scaled_rope_allows_llama_models_to_have/
+            # 1000 corresponds to a usable context length of 256k (https://github.com/Tencent-Hunyuan/Hunyuan-A13B/blob/main/report/Hunyuan_A13B_Technical_Report.pdf)
+            alpha = rope_scaling.get("alpha", 1000)
+            base = hparams.get("rope_theta", 10000.0)
+            dim = (hparams["hidden_size"] // hparams["num_attention_heads"]) # 128
+            scaled_base = base * (alpha ** (dim / (dim - 2))) # 10000 * (1000 ** (128 / 126)) = 11158839.9251
+            self.gguf_writer.add_rope_freq_base(scaled_base)
+            self.gguf_writer.add_rope_scaling_type(gguf.RopeScalingType.NONE)
+            self.gguf_writer.add_rope_scaling_factor(1)
+            # There is no consistent way to calculate ctx from alpha, and the config is incorrectly set to 32k
+            self.gguf_writer.add_rope_scaling_orig_ctx_len(256 * 1024) # 256k context length
+            self.gguf_writer.add_context_length(256 * 1024) # 256k context length
+
+            # if any of our assumptions about the values are wrong, something has changed and this may need to be updated
+            assert alpha == 1000 and base == 10000.0 and dim == 128 and self.hparams["max_position_embeddings"] in [32 * 1024, 256 * 1024] , \
+                "HunYuan dynamic RoPE scaling assumptions changed, please update the logic or context length manually"

+    _experts: list[dict[str, Tensor]] | None = None
+
+    def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
+        if name == "model.embed_tokens.weight":
+            self._tok_embd = data_torch.clone()
+
+        if name == "lm_head.weight":
+            if self.hparams.get("tie_word_embeddings", False):
+                logger.info("Skipping tied output layer 'lm_head.weight'")
+                return []
+
+        if name.find("mlp.experts") != -1:
+            n_experts = self.hparams["num_experts"]
+            assert bid is not None
+
+            if self._experts is None:
+                self._experts = [{} for _ in range(self.block_count)]
+
+            self._experts[bid][name] = data_torch
+
+            if len(self._experts[bid]) >= n_experts * 3:
+                # merge the experts into a single 3d tensor
+                tensors: list[tuple[str, Tensor]] = []
+                for w_name in ["down_proj", "gate_proj", "up_proj"]:
+                    datas: list[Tensor] = []
+
+                    for xid in range(n_experts):
+                        ename = f"model.layers.{bid}.mlp.experts.{xid}.{w_name}.weight"
+                        datas.append(self._experts[bid][ename])
+                        del self._experts[bid][ename]
+
+                    data_torch = torch.stack(datas, dim=0)
+                    merged_name = f"model.layers.{bid}.mlp.experts.{w_name}.weight"
+                    new_name = self.map_tensor_name(merged_name)
+                    tensors.append((new_name, data_torch))
+
+                return tensors
+            else:
+                return []
+
+        return [(self.map_tensor_name(name), data_torch)]
+
+    def prepare_tensors(self):
+        super().prepare_tensors()
+        if self._experts is not None:
+            experts = [k for d in self._experts for k in d.keys()]
+            if len(experts) > 0:
+                raise ValueError(f"Unprocessed experts: {experts}")
+
+
+@ModelBase.register("SmolLM3ForCausalLM")
+class SmolLM3Model(LlamaModel):
+    model_arch = gguf.MODEL_ARCH.SMOLLM3
+
 ###### CONVERSION LOGIC ######
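The RoPE handling in `set_gguf_parameters` above bakes HunYuan's NTK-aware alpha scaling into a single rotary frequency base, so no runtime scaling is needed. A standalone sketch of that arithmetic (plain Python; the constants are the ones quoted in the comments above):

```python
# Sketch: NTK-aware alpha scaling folded into the RoPE base, as done above.
alpha = 1000      # rope_scaling["alpha"] from the HunYuan config
base  = 10000.0   # rope_theta
dim   = 128       # head dimension = hidden_size // num_attention_heads

scaled_base = base * (alpha ** (dim / (dim - 2)))
print(f"{scaled_base:.4f}")  # ~11158839.9251, matching the inline comment
```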
@@ -137,6 +137,7 @@ pre_computed_hashes = [
     {"name": "chatglm-bpe", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/THUDM/glm-4-9b-chat", "chkhsh": "81d72c7348a9f0ebe86f23298d37debe0a5e71149e29bd283904c02262b27516"},
     {"name": "glm4", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/THUDM/glm-4-9b-hf", "chkhsh": "a1336059768a55c99a734006ffb02203cd450fed003e9a71886c88acf24fdbc2"},
     {"name": "minerva-7b", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/sapienzanlp/Minerva-7B-base-v1.0", "chkhsh": "1431a23e583c97432bc230bff598d103ddb5a1f89960c8f1d1051aaa944d0b35"},
+    {"name": "hunyuan", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/tencent/Hunyuan-A13B-Instruct", "chkhsh": "7e57df22b1fe23a7b1e1c7f3dc4e3f96d43a4eb0836d0c6bdc3436d7b2f1c664"},
 ]
@@ -83,20 +83,22 @@ NOTE: Tensor names must end with `.weight` or `.bias` suffixes, that is the conv

 ### 2. Define the model architecture in `llama.cpp`

-The model params and tensors layout must be defined in `llama.cpp`:
-1. Define a new `llm_arch`
-2. Define the tensors layout in `LLM_TENSOR_NAMES`
-3. Add any non-standard metadata in `llm_load_hparams`
-4. Create the tensors for inference in `llm_load_tensors`
-5. If the model has a RoPE operation, add the rope type in `llama_rope_type`
+The model params and tensors layout must be defined in `llama.cpp` source files:
+1. Define a new `llm_arch` enum value in `src/llama-arch.h`.
+2. In `src/llama-arch.cpp`:
+    - Add the architecture name to the `LLM_ARCH_NAMES` map.
+    - Add the tensor mappings to the `LLM_TENSOR_NAMES` map.
+3. Add any non-standard metadata loading in the `llama_model_loader` constructor in `src/llama-model-loader.cpp`.
+4. If the model has a RoPE operation, add a case for the architecture in `llama_model_rope_type` function in `src/llama-model.cpp`.

 NOTE: The dimensions in `ggml` are typically in the reverse order of the `pytorch` dimensions.

 ### 3. Build the GGML graph implementation

-This is the funniest part, you have to provide the inference graph implementation of the new model architecture in `llama_build_graph`.
-Have a look at existing implementations like `build_llama`, `build_dbrx` or `build_bert`.
+This is the funniest part, you have to provide the inference graph implementation of the new model architecture in `src/llama-model.cpp`.
+Create a new struct that inherits from `llm_graph_context` and implement the graph-building logic in its constructor.
+Have a look at existing implementations like `llm_build_llama`, `llm_build_dbrx` or `llm_build_bert`.
+Then, in the `llama_model::build_graph` method, add a case for your architecture to instantiate your new graph-building struct.

 Some `ggml` backends do not support all operations. Backend implementations can be added in a separate PR.
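To illustrate the dimension-order note in the docs hunk above: a PyTorch weight of shape `(n_out, n_in)` shows up in `ggml` with the axes reversed, i.e. `ne[0] = n_in` and `ne[1] = n_out`. A small sketch (the shape values are made up for illustration):

```python
# Sketch: ggml lists dimensions in the reverse order of the PyTorch shape.
import torch

w = torch.zeros(4096, 11008)       # PyTorch shape: (n_out, n_in)
ggml_ne = list(reversed(w.shape))  # ggml ne:       [11008, 4096]
print(ggml_ne)                     # ne[0] = n_in, ne[1] = n_out
```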
@@ -495,7 +495,7 @@ extern "C" {
         GGML_OP_POOL_1D,
         GGML_OP_POOL_2D,
         GGML_OP_POOL_2D_BACK,
-        GGML_OP_UPSCALE, // nearest interpolate
+        GGML_OP_UPSCALE,
         GGML_OP_PAD,
         GGML_OP_PAD_REFLECT_1D,
         GGML_OP_ROLL,
@@ -176,17 +176,20 @@ static const char * cu_get_error_str(CUresult err) {
 #endif

 #if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && !defined(GGML_USE_MUSA)
-#define CUDA_SET_SHARED_MEMORY_LIMIT(kernel, nbytes) \
+# define CUDA_SET_SHARED_MEMORY_LIMIT(kernel, nbytes) \
     do { \
-        static bool shared_memory_limit_raised[GGML_CUDA_MAX_DEVICES] = {false}; \
+        static bool shared_memory_limit_raised[GGML_CUDA_MAX_DEVICES] = { false }; \
         const int id = ggml_cuda_get_device(); \
         if (!shared_memory_limit_raised[id]) { \
            CUDA_CHECK(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, nbytes)); \
            shared_memory_limit_raised[id] = true; \
         } \
     } while (0)
 #else
-#define CUDA_SET_SHARED_MEMORY_LIMIT(kernel, nbytes) do {} while (0)
+# define CUDA_SET_SHARED_MEMORY_LIMIT(kernel, nbytes) \
+    do { \
+        GGML_UNUSED(nbytes); \
+    } while (0)
 #endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && !defined(GGML_USE_MUSA)

 #if CUDART_VERSION >= 11010 || defined(GGML_USE_MUSA)
@@ -299,14 +299,14 @@ static __global__ void flash_attn_tile_ext_f32(
    GGML_UNUSED(dst); GGML_UNUSED(dst_meta); GGML_UNUSED(scale);
    GGML_UNUSED(max_bias); GGML_UNUSED(m0); GGML_UNUSED(m1);
    GGML_UNUSED(n_head_log2); GGML_UNUSED(logit_softcap);
-   GGML_UNUSED(ne00); GGML_UNUSED(ne01); GGML_UNUSED(ne02);
-   GGML_UNUSED(ne03); GGML_UNUSED(ne10); GGML_UNUSED(ne11);
-   GGML_UNUSED(ne12); GGML_UNUSED(ne13); GGML_UNUSED(ne31);
-   GGML_UNUSED(nb31); GGML_UNUSED(nb01); GGML_UNUSED(nb02);
-   GGML_UNUSED(nb03); GGML_UNUSED(nb11); GGML_UNUSED(nb12);
-   GGML_UNUSED(nb13); GGML_UNUSED(nb21); GGML_UNUSED(nb22);
-   GGML_UNUSED(nb23); GGML_UNUSED(ne0); GGML_UNUSED(ne1);
-   GGML_UNUSED(ne2); GGML_UNUSED(ne3);
+   GGML_UNUSED(ne00); GGML_UNUSED(ne01); GGML_UNUSED(ne02); GGML_UNUSED(ne03);
+   GGML_UNUSED(ne10); GGML_UNUSED(ne11); GGML_UNUSED(ne12); GGML_UNUSED(ne13);
+   GGML_UNUSED(ne31); GGML_UNUSED(ne32);
+   GGML_UNUSED(nb31); GGML_UNUSED(nb32);
+   GGML_UNUSED(nb01); GGML_UNUSED(nb02); GGML_UNUSED(nb03);
+   GGML_UNUSED(nb11); GGML_UNUSED(nb12); GGML_UNUSED(nb13);
+   GGML_UNUSED(nb21); GGML_UNUSED(nb22); GGML_UNUSED(nb23);
+   GGML_UNUSED(ne0); GGML_UNUSED(ne1); GGML_UNUSED(ne2); GGML_UNUSED(ne3);
    NO_DEVICE_CODE;
 #endif // FLASH_ATTN_AVAILABLE
 }

@@ -337,13 +337,15 @@ static __global__ void flash_attn_vec_ext_f32(
    GGML_UNUSED(Q); GGML_UNUSED(K); GGML_UNUSED(V); GGML_UNUSED(mask);
    GGML_UNUSED(dst); GGML_UNUSED(dst_meta); GGML_UNUSED(scale);
    GGML_UNUSED(max_bias); GGML_UNUSED(m0); GGML_UNUSED(m1);
-   GGML_UNUSED(n_head_log2); GGML_UNUSED(logit_softcap); GGML_UNUSED(ne00);
-   GGML_UNUSED(ne01); GGML_UNUSED(ne02); GGML_UNUSED(ne03); GGML_UNUSED(ne10);
-   GGML_UNUSED(ne11); GGML_UNUSED(ne12); GGML_UNUSED(ne13); GGML_UNUSED(ne31);
-   GGML_UNUSED(nb31); GGML_UNUSED(nb01); GGML_UNUSED(nb02); GGML_UNUSED(nb03);
-   GGML_UNUSED(nb11); GGML_UNUSED(nb12); GGML_UNUSED(nb13); GGML_UNUSED(nb21);
-   GGML_UNUSED(nb22); GGML_UNUSED(nb23); GGML_UNUSED(ne0); GGML_UNUSED(ne1);
-   GGML_UNUSED(ne2); GGML_UNUSED(ne3);
+   GGML_UNUSED(n_head_log2); GGML_UNUSED(logit_softcap);
+   GGML_UNUSED(ne00); GGML_UNUSED(ne01); GGML_UNUSED(ne02); GGML_UNUSED(ne03);
+   GGML_UNUSED(ne10); GGML_UNUSED(ne11); GGML_UNUSED(ne12); GGML_UNUSED(ne13);
+   GGML_UNUSED(ne31); GGML_UNUSED(ne32);
+   GGML_UNUSED(nb31); GGML_UNUSED(nb32);
+   GGML_UNUSED(nb01); GGML_UNUSED(nb02); GGML_UNUSED(nb03);
+   GGML_UNUSED(nb11); GGML_UNUSED(nb12); GGML_UNUSED(nb13);
+   GGML_UNUSED(nb21); GGML_UNUSED(nb22); GGML_UNUSED(nb23);
+   GGML_UNUSED(ne0); GGML_UNUSED(ne1); GGML_UNUSED(ne2); GGML_UNUSED(ne3);
    NO_DEVICE_CODE;
 #endif // FLASH_ATTN_AVAILABLE
 }
@@ -3375,7 +3375,6 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
         case GGML_OP_GROUP_NORM:
             return ggml_is_contiguous(op->src[0]);
         case GGML_OP_UPSCALE:
-            return op->src[0]->type == GGML_TYPE_F32 && op->op_params[0] == GGML_SCALE_MODE_NEAREST;
         case GGML_OP_PAD:
         case GGML_OP_ARANGE:
         case GGML_OP_TIMESTEP_EMBEDDING:
@@ -50,21 +50,19 @@ static __global__ void rope_norm(

     const int row_dst = blockDim.x*blockIdx.x + threadIdx.x;

-    if (i0 >= n_dims) {
-        const int i = row_dst*ne0 + i0;
-
-        dst[i + 0] = x[i + 0];
-        dst[i + 1] = x[i + 1];
-
-        return;
-    }
-
     const int row_x = row_dst % ne1;
     const int channel_x = row_dst / ne1;

     const int idst = row_dst*ne0 + i0;
     const int ix = channel_x*s2 + row_x*s1 + i0;

+    if (i0 >= n_dims) {
+        dst[idst + 0] = x[ix + 0];
+        dst[idst + 1] = x[ix + 1];
+
+        return;
+    }
+
     const float theta_base = pos[channel_x]*powf(theta_scale, i0/2.0f);

     const float freq_factor = has_ff ? freq_factors[i0/2] : 1.0f;

@@ -94,21 +92,19 @@ static __global__ void rope_neox(

     const int row_dst = blockDim.x*blockIdx.x + threadIdx.x;

-    if (i0 >= n_dims) {
-        const int i = row_dst*ne0 + i0;
-
-        dst[i + 0] = x[i + 0];
-        dst[i + 1] = x[i + 1];
-
-        return;
-    }
-
     const int row_x = row_dst % ne1;
     const int channel_x = row_dst / ne1;

     const int idst = row_dst*ne0 + i0/2;
     const int ix = channel_x*s2 + row_x*s1 + i0/2;

+    if (i0 >= n_dims) {
+        dst[idst + i0/2 + 0] = x[ix + i0/2 + 0];
+        dst[idst + i0/2 + 1] = x[ix + i0/2 + 1];
+
+        return;
+    }
+
     const float theta_base = pos[channel_x]*powf(theta_scale, i0/2.0f);

     const float freq_factor = has_ff ? freq_factors[i0/2] : 1.0f;

@@ -138,21 +134,19 @@ static __global__ void rope_multi(

     const int row_dst = blockDim.x*blockIdx.x + threadIdx.x;

-    if (i0 >= n_dims) {
-        const int i = row_dst*ne0 + i0;
-
-        dst[i + 0] = x[i + 0];
-        dst[i + 1] = x[i + 1];
-
-        return;
-    }
-
     const int row_x = row_dst % ne1;
     const int channel_x = row_dst / ne1;

     const int idst = row_dst*ne0 + i0/2;
     const int ix = channel_x*s2 + row_x*s1 + i0/2;

+    if (i0 >= n_dims) {
+        dst[idst + i0/2 + 0] = x[ix + i0/2 + 0];
+        dst[idst + i0/2 + 1] = x[ix + i0/2 + 1];
+
+        return;
+    }
+
     const int sect_dims = sections.v[0] + sections.v[1] + sections.v[2] + sections.v[3];
     const int sec_w = sections.v[1] + sections.v[0];
     const int sector = (i0 / 2) % sect_dims;
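The change in the three CUDA RoPE kernels above is the same each time: the `i0 >= n_dims` early-out now runs after `idst`/`ix` are computed, so the untouched tail of each row is copied from the source using the source strides instead of reusing the destination index. A rough Python sketch of the two index expressions (stride values are hypothetical; only the shape of the formulas is taken from the kernels):

```python
# Sketch of the pass-through indexing after the change (values are made up).
ne0, ne1 = 128, 32        # destination row length and rows per channel
s1, s2 = 160, 160 * 32    # source strides in elements; s1 != ne0 when the source is padded

def pass_through_index(row_dst: int, i0: int) -> tuple[int, int]:
    row_x = row_dst % ne1
    channel_x = row_dst // ne1
    idst = row_dst * ne0 + i0              # destination is densely packed
    ix = channel_x * s2 + row_x * s1 + i0  # source is addressed through its own strides
    return idst, ix                        # for i0 >= n_dims: dst[idst] = x[ix]

print(pass_through_index(row_dst=40, i0=100))
```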
@@ -22,17 +22,88 @@ static __global__ void upscale_f32(const float * x, float * dst,
     dst[index] = *( (const float *)((const char *)x + i03 * nb03 + i02 * nb02 + i01 * nb01 + i00 * nb00) );
 }

+static __global__ void upscale_f32_bilinear(const float * x, float * dst,
+        const int nb00, const int nb01, const int nb02, const int nb03,
+        const int ne00_src, const int ne01_src,
+        const int ne10_dst, const int ne11_dst, const int ne12_dst, const int ne13_dst,
+        const float sf0, const float sf1, const float sf2, const float sf3,
+        const float pixel_offset) {
+    const int64_t index = threadIdx.x + blockIdx.x * blockDim.x;
+    const int64_t dst_total_elements = ne10_dst * ne11_dst * ne12_dst * ne13_dst;
+
+    if (index >= dst_total_elements) {
+        return;
+    }
+
+    const int i10_dst = index % ne10_dst;
+    const int i11_dst = (index / ne10_dst) % ne11_dst;
+    const int i12_dst = (index / (ne10_dst * ne11_dst)) % ne12_dst;
+    const int i13_dst = index / (ne10_dst * ne11_dst * ne12_dst);
+
+    const int i02_src = (int)(i12_dst / sf2);
+    const int i03_src = (int)(i13_dst / sf3);
+
+    const float y_src_f = ((float)i11_dst + pixel_offset) / sf1 - pixel_offset;
+    int y0_src = (int)floorf(y_src_f);
+    int y1_src = y0_src + 1;
+
+    y0_src = max(0, min(y0_src, ne01_src - 1));
+    y1_src = max(0, min(y1_src, ne01_src - 1));
+
+    float dy = y_src_f - (float)y0_src;
+    dy = max(0.0f, min(dy, 1.0f));
+
+    float x_src_f = ((float)i10_dst + pixel_offset) / sf0 - pixel_offset;
+    int x0_src = (int)floorf(x_src_f);
+    int x1_src = x0_src + 1;
+
+    x0_src = max(0, min(x0_src, ne00_src - 1));
+    x1_src = max(0, min(x1_src, ne00_src - 1));
+
+    float dx = x_src_f - (float)x0_src;
+    dx = max(0.0f, min(dx, 1.0f));
+
+    const float * p_a = (const float *)((const char *)x + (int64_t)x0_src * nb00 + (int64_t)y0_src * nb01 + (int64_t)i02_src * nb02 + (int64_t)i03_src * nb03);
+    const float * p_b = (const float *)((const char *)x + (int64_t)x1_src * nb00 + (int64_t)y0_src * nb01 + (int64_t)i02_src * nb02 + (int64_t)i03_src * nb03);
+    const float * p_c = (const float *)((const char *)x + (int64_t)x0_src * nb00 + (int64_t)y1_src * nb01 + (int64_t)i02_src * nb02 + (int64_t)i03_src * nb03);
+    const float * p_d = (const float *)((const char *)x + (int64_t)x1_src * nb00 + (int64_t)y1_src * nb01 + (int64_t)i02_src * nb02 + (int64_t)i03_src * nb03);
+
+    const float val_a = *p_a;
+    const float val_b = *p_b;
+    const float val_c = *p_c;
+    const float val_d = *p_d;
+
+    float result = val_a * (1.0f - dx) * (1.0f - dy) +
+                   val_b * dx * (1.0f - dy) +
+                   val_c * (1.0f - dx) * dy +
+                   val_d * dx * dy;
+
+    dst[index] = result;
+}
+
 static void upscale_f32_cuda(const float * x, float * dst,
         const int nb00, const int nb01, const int nb02, const int nb03,
         const int ne10, const int ne11, const int ne12, const int ne13,
         const float sf0, const float sf1, const float sf2, const float sf3,
         cudaStream_t stream) {
-    int dst_size = ne10 * ne11 * ne12 * ne13;
-    int num_blocks = (dst_size + CUDA_UPSCALE_BLOCK_SIZE - 1) / CUDA_UPSCALE_BLOCK_SIZE;
+    const int64_t dst_size = ne10 * ne11 * ne12 * ne13;
+    const int64_t num_blocks = (dst_size + CUDA_UPSCALE_BLOCK_SIZE - 1) / CUDA_UPSCALE_BLOCK_SIZE;

     upscale_f32<<<num_blocks, CUDA_UPSCALE_BLOCK_SIZE,0,stream>>>(x, dst, nb00, nb01, nb02, nb03, ne10, ne11, ne12, ne13, sf0, sf1, sf2, sf3);
 }

+static void upscale_f32_bilinear_cuda(const float * x, float * dst,
+        const int nb00, const int nb01, const int nb02, const int nb03,
+        const int ne00_src, const int ne01_src,
+        const int ne10_dst, const int ne11_dst, const int ne12_dst, const int ne13_dst,
+        const float sf0, const float sf1, const float sf2, const float sf3,
+        const float pixel_offset, cudaStream_t stream) {
+    const int64_t dst_size = ne10_dst * ne11_dst * ne12_dst * ne13_dst;
+    const int64_t num_blocks = (dst_size + CUDA_UPSCALE_BLOCK_SIZE - 1) / CUDA_UPSCALE_BLOCK_SIZE;
+
+    upscale_f32_bilinear<<<num_blocks, CUDA_UPSCALE_BLOCK_SIZE,0,stream>>>(x, dst, nb00, nb01, nb02, nb03, ne00_src, ne01_src, ne10_dst, ne11_dst, ne12_dst, ne13_dst, sf0, sf1, sf2, sf3, pixel_offset);
+}
+
 void ggml_cuda_op_upscale(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
     const ggml_tensor * src0 = dst->src[0];
     const float * src0_d = (const float *)src0->data;

@@ -42,10 +113,25 @@ void ggml_cuda_op_upscale(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
     GGML_ASSERT(src0->type == GGML_TYPE_F32);
     GGML_ASSERT( dst->type == GGML_TYPE_F32);

-    const float sf0 = (float)dst->ne[0]/src0->ne[0];
-    const float sf1 = (float)dst->ne[1]/src0->ne[1];
-    const float sf2 = (float)dst->ne[2]/src0->ne[2];
+    const int mode_flags = dst->op_params[0];
+    const ggml_scale_mode mode = (ggml_scale_mode)(mode_flags & 0xFF);
+
+    float sf0 = (float)dst->ne[0]/src0->ne[0];
+    float sf1 = (float)dst->ne[1]/src0->ne[1];
+    float sf2 = (float)dst->ne[2]/src0->ne[2];
     const float sf3 = (float)dst->ne[3]/src0->ne[3];

-    upscale_f32_cuda(src0_d, dst_d, src0->nb[0], src0->nb[1], src0->nb[2], src0->nb[3], dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3], sf0, sf1, sf2, sf3, stream);
+    if (mode == GGML_SCALE_MODE_NEAREST) {
+        upscale_f32_cuda(src0_d, dst_d, src0->nb[0], src0->nb[1], src0->nb[2], src0->nb[3], dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3], sf0, sf1, sf2, sf3, stream);
+    } else if (mode == GGML_SCALE_MODE_BILINEAR) {
+        float pixel_offset = 0.5f;
+        if (mode_flags & GGML_SCALE_FLAG_ALIGN_CORNERS) {
+            sf0 = (float)(dst->ne[0] - 1) / (src0->ne[0] - 1);
+            sf1 = (float)(dst->ne[1] - 1) / (src0->ne[1] - 1);
+            pixel_offset = 0.0f;
+        }
+        upscale_f32_bilinear_cuda(src0_d, dst_d, src0->nb[0], src0->nb[1], src0->nb[2], src0->nb[3],
+                                  src0->ne[0], src0->ne[1], dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3],
+                                  sf0, sf1, sf2, sf3, pixel_offset, stream);
+    }
 }
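A NumPy reference of the bilinear sampling added above can make the half-pixel offset and the align-corners branch easier to check. This is only a sketch of one (height, width) plane; the function name and shapes are mine, not part of the kernel:

```python
import numpy as np

def upscale_bilinear_2d(src: np.ndarray, out_h: int, out_w: int, align_corners: bool = False) -> np.ndarray:
    # Mirrors the sampling of upscale_f32_bilinear for a single 2D plane (sketch).
    in_h, in_w = src.shape
    if align_corners:
        sf_y, sf_x, off = (out_h - 1) / (in_h - 1), (out_w - 1) / (in_w - 1), 0.0
    else:
        sf_y, sf_x, off = out_h / in_h, out_w / in_w, 0.5
    dst = np.empty((out_h, out_w), dtype=src.dtype)
    for iy in range(out_h):
        for ix in range(out_w):
            y = (iy + off) / sf_y - off
            x = (ix + off) / sf_x - off
            y0 = min(max(int(np.floor(y)), 0), in_h - 1)
            y1 = min(y0 + 1, in_h - 1)
            x0 = min(max(int(np.floor(x)), 0), in_w - 1)
            x1 = min(x0 + 1, in_w - 1)
            dy = min(max(y - y0, 0.0), 1.0)
            dx = min(max(x - x0, 0.0), 1.0)
            dst[iy, ix] = (src[y0, x0] * (1 - dx) * (1 - dy) + src[y0, x1] * dx * (1 - dy) +
                           src[y1, x0] * (1 - dx) * dy + src[y1, x1] * dx * dy)
    return dst

print(upscale_bilinear_2d(np.arange(16, dtype=np.float32).reshape(4, 4), 8, 8).shape)  # (8, 8)
```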
@@ -47,18 +47,17 @@ static void rope_norm(const T * x, T * dst, const int ne0, const int ne1, const

     const int row = item_ct1.get_local_range(2) * item_ct1.get_group(2) + item_ct1.get_local_id(2);

-    if (i0 >= n_dims) {
-        const int i = row * ne0 + i0;
-        *reinterpret_cast<sycl::vec<T, 2> *>(dst + i) = *reinterpret_cast<const sycl::vec<T, 2> *>(x + i);
-        return;
-    }
-
     const int row0 = row % ne1;
     const int channel0 = row / ne1;

     const int i = row * ne0 + i0;
     const int i2 = channel0 * s2 + row0 * s1 + i0;

+    if (i0 >= n_dims) {
+        *reinterpret_cast<sycl::vec<T, 2> *>(dst + i) = *reinterpret_cast<const sycl::vec<T, 2> *>(x + i2);
+        return;
+    }
+
     const float theta_base = pos[channel0] * sycl::pow(theta_scale, i0 / 2.0f);

     const float freq_factor = has_ff ? freq_factors[i0 / 2] : 1.0f;

@@ -88,18 +87,17 @@ static void rope_neox(const T * x, T * dst, const int ne0, const int ne1, const

     const int row = item_ct1.get_local_range(2) * item_ct1.get_group(2) + item_ct1.get_local_id(2);

-    if (i0 >= n_dims) {
-        const int i = row * ne0 + i0;
-        *reinterpret_cast<sycl::vec<T, 2> *>(dst + i) = *reinterpret_cast<const sycl::vec<T, 2> *>(x + i);
-        return;
-    }
-
     const int row0 = row % ne1;
     const int channel0 = row / ne1;

     const int i = row * ne0 + i0 / 2;
     const int i2 = channel0 * s2 + row0 * s1 + i0 / 2;

+    if (i0 >= n_dims) {
+        *reinterpret_cast<sycl::vec<T, 2> *>(dst + i + i0 / 2) = *reinterpret_cast<const sycl::vec<T, 2> *>(x + i2 + i0 / 2);
+        return;
+    }
+
     const float theta_base = pos[channel0] * sycl::pow(theta_scale, i0 / 2.0f);

     const float freq_factor = has_ff ? freq_factors[i0 / 2] : 1.0f;

@@ -129,17 +127,16 @@ static void rope_multi(const T * x, T * dst, const int ne0, const int ne1, const
     }
     const int row_dst = (item_ct1.get_group(2) * item_ct1.get_local_range(2)) + item_ct1.get_local_id(2);

-    if (i0 >= n_dims) {
-        const int i = row_dst*ne0 + i0;
-        *reinterpret_cast<sycl::vec<T, 2> *>(dst + i) = *reinterpret_cast<const sycl::vec<T, 2> *>(x + i);
-        return;
-    }
-
     const int row_x = row_dst % ne1;
     const int channel_x = row_dst / ne1;
     const int idst = (row_dst * ne0) + (i0 / 2);
     const size_t ix = ((size_t) channel_x * s2) + ((size_t) row_x * s1) + (i0 / 2);

+    if (i0 >= n_dims) {
+        *reinterpret_cast<sycl::vec<T, 2> *>(dst + idst + i0 / 2) = *reinterpret_cast<const sycl::vec<T, 2> *>(x + i0 / 2 + ix);
+        return;
+    }
+
     const int sect_dims = sections.v[0] + sections.v[1] + sections.v[2] + sections.v[3];
     const int sec_w = sections.v[1] + sections.v[0];
     const int sector = (i0 / 2) % sect_dims;
@@ -2706,7 +2706,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
     ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_IQ4_NL], "get_rows_iq4_nl_f32", get_rows_iq4_nl_f32_len, get_rows_iq4_nl_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);

     ggml_vk_create_pipeline(device, device->pipeline_matmul_split_k_reduce, "split_k_reduce", split_k_reduce_len, split_k_reduce_data, "main", 2, 2 * sizeof(uint32_t), {256 * 4, 1, 1}, {}, 1);
-    ggml_vk_create_pipeline(device, device->pipeline_flash_attn_split_k_reduce, "fa_split_k_reduce", fa_split_k_reduce_len, fa_split_k_reduce_data, "main", 2, 3 * sizeof(uint32_t), {1, 1, 1}, {}, 1, true);
+    ggml_vk_create_pipeline(device, device->pipeline_flash_attn_split_k_reduce, "fa_split_k_reduce", fa_split_k_reduce_len, fa_split_k_reduce_data, "main", 2, 4 * sizeof(uint32_t), {1, device->subgroup_size, 1}, {device->subgroup_size}, 1, true);
     ggml_vk_create_pipeline(device, device->pipeline_quantize_q8_1, "quantize_q8_1", quantize_q8_1_len, quantize_q8_1_data, "main", 2, 1 * sizeof(uint32_t), {32 * device->subgroup_size / 8, 1, 1}, { device->subgroup_size }, 1);

     for (uint32_t i = 0; i < p021_max_gqa_ratio; ++i) {
@@ -6252,13 +6252,13 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
     const uint32_t shader_core_count = ctx->device->shader_core_count ? ctx->device->shader_core_count : 16;

     // Try to use split_k when KV is large enough to be worth the overhead
-    if (workgroups_x == 1 && shader_core_count > 0 && KV >= 512) {
+    if (workgroups_x == 1 && shader_core_count > 0) {
         // Try to run two workgroups per SM.
         split_k = shader_core_count * 2 / (workgroups_y * workgroups_z);
         if (split_k > 1) {
             // Try to evenly split KV into split_k chunks, but it needs to be a multiple
             // of "align", so recompute split_k based on that.
-            split_kv = ROUNDUP_POW2(KV / split_k, pipelines[1]->align);
+            split_kv = ROUNDUP_POW2(std::max(1u, KV / split_k), pipelines[1]->align);
             split_k = CEIL_DIV(KV, split_kv);
             workgroups_x = split_k;
         }

@@ -6392,7 +6392,7 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
             vk_subbuffer{ctx->prealloc_split_k, 0, VK_WHOLE_SIZE},
             vk_subbuffer{d_D, d_buf_offset, VK_WHOLE_SIZE},
         },
-        pc2, { (uint32_t)ne1, 1, (uint32_t)ne3 });
+        pc2, { (uint32_t)ne1, HSV, (uint32_t)ne3 });
     } else {
         ggml_vk_dispatch_pipeline(ctx, subctx, pipeline,
         {
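The split_k sizing above is plain integer arithmetic; a small Python sketch of it, with ROUNDUP_POW2 and CEIL_DIV spelled out (the device numbers in the call are made up):

```python
def ceil_div(a: int, b: int) -> int:
    return (a + b - 1) // b

def round_up(x: int, align: int) -> int:   # ROUNDUP_POW2: align is a power of two
    return ceil_div(x, align) * align

def plan_split_k(KV: int, shader_core_count: int, workgroups_y: int, workgroups_z: int, align: int):
    split_k = shader_core_count * 2 // (workgroups_y * workgroups_z)  # aim for two workgroups per SM
    split_kv = KV
    if split_k > 1:
        # split KV into chunks that are multiples of the pipeline alignment,
        # then recompute split_k from the chunk size; max(1, ...) guards tiny KV
        split_kv = round_up(max(1, KV // split_k), align)
        split_k = ceil_div(KV, split_kv)
    return split_k, split_kv

print(plan_split_k(KV=4096, shader_core_count=16, workgroups_y=8, workgroups_z=1, align=64))  # (4, 1024)
```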
@@ -2,9 +2,9 @@

 #extension GL_EXT_control_flow_attributes : enable

-#define BLOCK_SIZE 32
+layout(constant_id = 0) const uint BLOCK_SIZE = 32;

-layout(local_size_x = BLOCK_SIZE, local_size_y = 1, local_size_z = 1) in;
+layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;

 layout (binding = 0) readonly buffer A {float data_a[];};
 layout (binding = 1) writeonly buffer D {float data_d[];};

@@ -16,6 +16,8 @@ layout (push_constant) uniform parameter {
     uint k_num;
 } p;

+shared float tmpsh[BLOCK_SIZE];
+
 void main() {
     // Each workgroup handles a row
     const uint n = gl_WorkGroupID.x;

@@ -32,23 +34,51 @@ void main() {

     // Compute the max m value for the row
     float m_max = -1.0/0.0;
-    [[unroll]] for (uint k = 0; k < k_num; ++k) {
-        float m = data_a[m_offset + k * lm_stride];
+    for (uint k = 0; k + tid < k_num; k += BLOCK_SIZE) {
+        float m = data_a[m_offset + (k + tid) * lm_stride];
         m_max = max(m_max, m);
     }

+    // reduce across the workgroup
+    tmpsh[tid] = m_max;
+    barrier();
+    [[unroll]] for (uint s = BLOCK_SIZE/2; s > 0; s >>= 1) {
+        if (tid < s) {
+            m_max = max(m_max, tmpsh[tid + s]);
+            tmpsh[tid] = m_max;
+        }
+        barrier();
+    }
+    m_max = tmpsh[0];
+
+    barrier();
+
     // Compute L based on m_max
     float L = 0;
-    [[unroll]] for (uint k = 0; k < k_num; ++k) {
-        float l = data_a[l_offset + k * lm_stride];
-        float m = data_a[m_offset + k * lm_stride];
+    for (uint k = 0; k + tid < k_num; k += BLOCK_SIZE) {
+        float l = data_a[l_offset + (k + tid) * lm_stride];
+        float m = data_a[m_offset + (k + tid) * lm_stride];
         L += exp(m - m_max) * l;
     }

+    // reduce across the workgroup
+    tmpsh[tid] = L;
+    barrier();
+    [[unroll]] for (uint s = BLOCK_SIZE/2; s > 0; s >>= 1) {
+        if (tid < s) {
+            L += tmpsh[tid + s];
+            tmpsh[tid] = L;
+        }
+        barrier();
+    }
+    L = tmpsh[0];
+
     L = 1.0 / L;

+    // D dimension is split across workgroups in the y dimension
+    uint d = tid + gl_WorkGroupID.y * BLOCK_SIZE;
     // Scale and sum the O contributions based on m_max and store the result to memory
-    for (uint d = tid; d < D; d += BLOCK_SIZE) {
+    if (d < D) {
         float O = 0.0;
         [[unroll]] for (uint k = 0; k < k_num; ++k) {
             uint o_offset = D * N * (k + iq3 * k_num) + D * n + d;
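Per output element, the reworked reduction shader above performs a standard log-sum-exp merge of the split-k partial results, now spread across the workgroup instead of a single thread. A NumPy sketch of the math it computes (array names are mine; the per-chunk partials come from the flash-attention pass):

```python
import numpy as np

def merge_split_k(O_parts: np.ndarray, l_parts: np.ndarray, m_parts: np.ndarray) -> np.ndarray:
    # O_parts[k]: unnormalized output of KV chunk k (shape [k_num, D])
    # l_parts[k]: softmax denominator of chunk k, m_parts[k]: running max of chunk k
    m_max = np.max(m_parts)                     # first workgroup reduction in the shader
    scale = np.exp(m_parts - m_max)
    L = np.sum(scale * l_parts)                 # second workgroup reduction
    return np.sum(O_parts * scale[:, None], axis=0) / L  # scaled sum of the O contributions

rng = np.random.default_rng(0)
k_num, D = 4, 8
print(merge_split_k(rng.normal(size=(k_num, D)), rng.random(k_num) + 1.0, rng.normal(size=k_num)))
```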
@@ -14,21 +14,19 @@ void main() {

     const uint row_dst = gl_GlobalInvocationID.x;

-    if (i0 >= p.n_dims) {
-        const uint i = row_dst*ne0 + i0;
-
-        data_d[i + 0] = data_a[i + 0];
-        data_d[i + 1] = data_a[i + 1];
-
-        return;
-    }
-
     const uint row_x = row_dst % ne1;
     const uint channel_x = row_dst / ne1;

     const uint idst = row_dst*ne0 + i0/2;
     const uint ix = channel_x*p.s2 + row_x*p.s1 + i0/2;

+    if (i0 >= p.n_dims) {
+        data_d[idst + i0/2 + 0] = data_a[ix + i0/2 + 0];
+        data_d[idst + i0/2 + 1] = data_a[ix + i0/2 + 1];
+
+        return;
+    }
+
     const int sect_dims = p.sections[0] + p.sections[1] + p.sections[2] + p.sections[3];
     const int sec_w = p.sections[1] + p.sections[0];
     const uint sector = (i0 / 2) % sect_dims;

@@ -13,21 +13,19 @@ void main() {

     const uint row_dst = gl_GlobalInvocationID.x;

-    if (i0 >= p.n_dims) {
-        const uint i = row_dst*ne0 + i0;
-
-        data_d[i + 0] = data_a[i + 0];
-        data_d[i + 1] = data_a[i + 1];
-
-        return;
-    }
-
     const uint row_x = row_dst % ne1;
     const uint channel_x = row_dst / ne1;

     const uint idst = row_dst*ne0 + i0/2;
     const uint ix = channel_x*p.s2 + row_x*p.s1 + i0/2;

+    if (i0 >= p.n_dims) {
+        data_d[idst + i0/2 + 0] = data_a[ix + i0/2 + 0];
+        data_d[idst + i0/2 + 1] = data_a[ix + i0/2 + 1];
+
+        return;
+    }
+
     const float theta_base = data_pos[channel_x] * pow(p.theta_scale, i0/2.0f);

     const float freq_factor = p.has_ff != 0 ? data_ff[i0/2] : 1.0f;

@@ -13,21 +13,19 @@ void main() {

     const uint row_dst = gl_GlobalInvocationID.x;

-    if (i0 >= p.n_dims) {
-        const uint i = row_dst*ne0 + i0;
-
-        data_d[i + 0] = data_a[i + 0];
-        data_d[i + 1] = data_a[i + 1];
-
-        return;
-    }
-
     const uint row_x = row_dst % ne1;
     const uint channel_x = row_dst / ne1;

     const uint idst = row_dst*ne0 + i0;
     const uint ix = channel_x*p.s2 + row_x*p.s1 + i0;

+    if (i0 >= p.n_dims) {
+        data_d[idst + 0] = data_a[ix + 0];
+        data_d[idst + 1] = data_a[ix + 1];
+
+        return;
+    }
+
     const float theta_base = data_pos[channel_x] * pow(p.theta_scale, i0/2.0f);

     const float freq_factor = p.has_ff != 0 ? data_ff[i0/2] : 1.0f;
@@ -357,6 +357,8 @@ class MODEL_ARCH(IntEnum):
    DOTS1 = auto()
    ARCEE = auto()
    ERNIE4_5 = auto()
+   HUNYUAN_MOE = auto()
+   SMOLLM3 = auto()


 class VISION_PROJECTOR_TYPE(IntEnum):

@@ -660,6 +662,8 @@ MODEL_ARCH_NAMES: dict[MODEL_ARCH, str] = {
    MODEL_ARCH.DOTS1: "dots1",
    MODEL_ARCH.ARCEE: "arcee",
    MODEL_ARCH.ERNIE4_5: "ernie4_5",
+   MODEL_ARCH.HUNYUAN_MOE: "hunyuan-moe",
+   MODEL_ARCH.SMOLLM3: "smollm3",
 }

 VISION_PROJECTOR_TYPE_NAMES: dict[VISION_PROJECTOR_TYPE, str] = {

@@ -2211,6 +2215,43 @@ MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = {
        MODEL_TENSOR.FFN_DOWN,
        MODEL_TENSOR.FFN_UP,
    ],
+   MODEL_ARCH.HUNYUAN_MOE: [
+       MODEL_TENSOR.TOKEN_EMBD,
+       MODEL_TENSOR.OUTPUT_NORM,
+       MODEL_TENSOR.OUTPUT,
+       MODEL_TENSOR.ROPE_FREQS,
+       MODEL_TENSOR.ATTN_NORM,
+       MODEL_TENSOR.ATTN_Q,
+       MODEL_TENSOR.ATTN_Q_NORM,
+       MODEL_TENSOR.ATTN_K,
+       MODEL_TENSOR.ATTN_K_NORM,
+       MODEL_TENSOR.ATTN_V,
+       MODEL_TENSOR.ATTN_OUT,
+       MODEL_TENSOR.FFN_GATE_INP,
+       MODEL_TENSOR.FFN_NORM,
+       MODEL_TENSOR.FFN_GATE_EXP,
+       MODEL_TENSOR.FFN_DOWN_EXP,
+       MODEL_TENSOR.FFN_UP_EXP,
+       MODEL_TENSOR.FFN_GATE_SHEXP,
+       MODEL_TENSOR.FFN_DOWN_SHEXP,
+       MODEL_TENSOR.FFN_UP_SHEXP,
+   ],
+   MODEL_ARCH.SMOLLM3: [
+       MODEL_TENSOR.TOKEN_EMBD,
+       MODEL_TENSOR.OUTPUT_NORM,
+       MODEL_TENSOR.OUTPUT,
+       MODEL_TENSOR.ROPE_FREQS,
+       MODEL_TENSOR.ATTN_NORM,
+       MODEL_TENSOR.ATTN_Q,
+       MODEL_TENSOR.ATTN_K,
+       MODEL_TENSOR.ATTN_V,
+       MODEL_TENSOR.ATTN_OUT,
+       MODEL_TENSOR.ATTN_ROT_EMBD,
+       MODEL_TENSOR.FFN_NORM,
+       MODEL_TENSOR.FFN_GATE,
+       MODEL_TENSOR.FFN_DOWN,
+       MODEL_TENSOR.FFN_UP,
+   ],
    # TODO
 }
@@ -303,6 +303,7 @@ class TensorNameMap:
        "model.layers.{bid}.block_sparse_moe.router.layer", # granitemoe
        "model.layers.{bid}.feed_forward.router", # llama4
        "encoder.layers.{bid}.mlp.router.layer", # nomic-bert-moe
+       "model.layers.{bid}.mlp.gate.wg", # hunyuan
    ),

    MODEL_TENSOR.FFN_GATE_INP_SHEXP: (

@@ -362,6 +363,7 @@ class TensorNameMap:
        "model.layers.{bid}.mlp.shared_expert.up_proj", # qwen2moe
        "model.layers.{bid}.mlp.shared_experts.up_proj", # deepseek deepseek2
        "model.layers.{bid}.feed_forward.shared_expert.up_proj", # llama4
+       "model.layers.{bid}.mlp.shared_mlp.up_proj", # hunyuan
    ),

    # AWQ-activation gate

@@ -398,6 +400,7 @@ class TensorNameMap:
        "model.layers.{bid}.mlp.shared_expert.gate_proj", # qwen2moe
        "model.layers.{bid}.mlp.shared_experts.gate_proj", # deepseek deepseek2
        "model.layers.{bid}.feed_forward.shared_expert.gate_proj", # llama4
+       "model.layers.{bid}.mlp.shared_mlp.gate_proj", # hunyuan
    ),

    # Feed-forward down

@@ -447,11 +450,13 @@ class TensorNameMap:
        "model.layers.{bid}.mlp.shared_experts.down_proj", # deepseek deepseek2
        "model.layers.{bid}.feed_forward.shared_expert.down_proj", # llama4
        "model.layers.{bid}.shared_mlp.output_linear", # granitemoe
+       "model.layers.{bid}.mlp.shared_mlp.down_proj", # hunyuan
    ),

    MODEL_TENSOR.ATTN_Q_NORM: (
        "language_model.encoder.layers.{bid}.self_attention.q_layernorm",
        "model.layers.{bid}.self_attn.q_layernorm", # persimmon
+       "model.layers.{bid}.self_attn.query_layernorm", # hunyuan
        "model.layers.{bid}.self_attn.q_norm", # cohere olmoe chameleon olmo2
        "transformer.blocks.{bid}.attn.q_ln", # sea-lion
        "encoder.layer.{bid}.attention.self.layer_norm_q", # jina-bert-v2

@@ -461,6 +466,7 @@ class TensorNameMap:
    MODEL_TENSOR.ATTN_K_NORM: (
        "language_model.encoder.layers.{bid}.self_attention.k_layernorm",
        "model.layers.{bid}.self_attn.k_layernorm", # persimmon
+       "model.layers.{bid}.self_attn.key_layernorm", # hunyuan
        "model.layers.{bid}.self_attn.k_norm", # cohere olmoe chameleon olmo2
        "transformer.blocks.{bid}.attn.k_ln", # sea-lion
        "encoder.layer.{bid}.attention.self.layer_norm_k", # jina-bert-v2
@@ -117,6 +117,7 @@ extern "C" {
        LLAMA_VOCAB_PRE_TYPE_LLAMA4 = 33,
        LLAMA_VOCAB_PRE_TYPE_PIXTRAL = 34,
        LLAMA_VOCAB_PRE_TYPE_SEED_CODER = 35,
+       LLAMA_VOCAB_PRE_TYPE_HUNYUAN = 36,
    };

    enum llama_rope_type {
@@ -78,6 +78,8 @@ static const std::map<llm_arch, const char *> LLM_ARCH_NAMES = {
    { LLM_ARCH_DOTS1, "dots1" },
    { LLM_ARCH_ARCEE, "arcee" },
    { LLM_ARCH_ERNIE4_5, "ernie4_5" },
+   { LLM_ARCH_HUNYUAN_MOE, "hunyuan-moe" },
+   { LLM_ARCH_SMOLLM3, "smollm3" },
    { LLM_ARCH_UNKNOWN, "(unknown)" },
 };

@@ -1694,12 +1696,52 @@ static const std::map<llm_arch, std::map<llm_tensor, const char *>> LLM_TENSOR_N
            { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" },
        },
    },
+   {
+       LLM_ARCH_HUNYUAN_MOE,
+       {
+           { LLM_TENSOR_TOKEN_EMBD, "token_embd" },
+           { LLM_TENSOR_OUTPUT_NORM, "output_norm" },
+           { LLM_TENSOR_OUTPUT, "output" },
+           { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" },
+           { LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" },
+           { LLM_TENSOR_ATTN_Q_NORM, "blk.%d.attn_q_norm" },
+           { LLM_TENSOR_ATTN_K, "blk.%d.attn_k" },
+           { LLM_TENSOR_ATTN_K_NORM, "blk.%d.attn_k_norm" },
+           { LLM_TENSOR_ATTN_V, "blk.%d.attn_v" },
+           { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" },
+           { LLM_TENSOR_FFN_GATE_INP, "blk.%d.ffn_gate_inp" },
+           { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" },
+           { LLM_TENSOR_FFN_GATE_SHEXP, "blk.%d.ffn_gate_shexp" },
+           { LLM_TENSOR_FFN_DOWN_SHEXP, "blk.%d.ffn_down_shexp" },
+           { LLM_TENSOR_FFN_UP_SHEXP, "blk.%d.ffn_up_shexp" },
+           { LLM_TENSOR_FFN_GATE_EXPS, "blk.%d.ffn_gate_exps" },
+           { LLM_TENSOR_FFN_DOWN_EXPS, "blk.%d.ffn_down_exps" },
+           { LLM_TENSOR_FFN_UP_EXPS, "blk.%d.ffn_up_exps" },
+       },
+   },
    {
        LLM_ARCH_UNKNOWN,
        {
            { LLM_TENSOR_TOKEN_EMBD, "token_embd" },
        },
    },
+   {
+       LLM_ARCH_SMOLLM3,
+       {
+           { LLM_TENSOR_TOKEN_EMBD, "token_embd" },
+           { LLM_TENSOR_OUTPUT_NORM, "output_norm" },
+           { LLM_TENSOR_OUTPUT, "output" },
+           { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" },
+           { LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" },
+           { LLM_TENSOR_ATTN_K, "blk.%d.attn_k" },
+           { LLM_TENSOR_ATTN_V, "blk.%d.attn_v" },
+           { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" },
+           { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" },
+           { LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" },
+           { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" },
+           { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" },
+       },
+   },
 };

 static const std::map<llm_tensor, llm_tensor_info> LLM_TENSOR_INFOS = {
@@ -82,6 +82,8 @@ enum llm_arch {
    LLM_ARCH_DOTS1,
    LLM_ARCH_ARCEE,
    LLM_ARCH_ERNIE4_5,
+   LLM_ARCH_HUNYUAN_MOE,
+   LLM_ARCH_SMOLLM3,
    LLM_ARCH_UNKNOWN,
};
@@ -64,6 +64,7 @@ static const std::map<std::string, llm_chat_template> LLM_CHAT_TEMPLATES = {
    { "bailing", LLM_CHAT_TEMPLATE_BAILING },
    { "llama4", LLM_CHAT_TEMPLATE_LLAMA4 },
    { "smolvlm", LLM_CHAT_TEMPLATE_SMOLVLM },
+   { "hunyuan-moe", LLM_CHAT_TEMPLATE_HUNYUAN_MOE },
};

llm_chat_template llm_chat_template_from_str(const std::string & name) {
@@ -185,6 +186,8 @@ llm_chat_template llm_chat_detect_template(const std::string & tmpl) {
        return LLM_CHAT_TEMPLATE_LLAMA4;
    } else if (tmpl_contains("<|endofuserprompt|>")) {
        return LLM_CHAT_TEMPLATE_DOTS1;
+   } else if (tmpl_contains("<|startoftext|>") && tmpl_contains("<|extra_4|>")) {
+       return LLM_CHAT_TEMPLATE_HUNYUAN_MOE;
    }
    return LLM_CHAT_TEMPLATE_UNKNOWN;
}
@@ -665,6 +668,18 @@ int32_t llm_chat_apply_template(
        if (add_ass) {
            ss << "<|response|>";
        }
+   } else if (tmpl == LLM_CHAT_TEMPLATE_HUNYUAN_MOE) {
+       // tencent/Hunyuan-A13B-Instruct
+       for (auto message : chat) {
+           std::string role(message->role);
+           if (role == "system") {
+               ss << "<|startoftext|>" << message->content << "<|extra_4|>";
+           } else if (role == "assistant") {
+               ss << "<|startoftext|>" << message->content << "<|eos|>";
+           } else {
+               ss << "<|startoftext|>" << message->content << "<|extra_0|>";
+           }
+       }
    } else {
        // template not supported
        return -1;
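
Note: from the branch added above, the Hunyuan-A13B format wraps every message in <|startoftext|> and closes it with a role-dependent sentinel (<|extra_4|> for system, <|extra_0|> for user and other roles, <|eos|> for assistant). A minimal standalone sketch of that formatting is below, for reference only; the helper name and the sample messages are illustrative and not part of the patch.

#include <iostream>
#include <string>
#include <vector>

// Illustrative helper mirroring the LLM_CHAT_TEMPLATE_HUNYUAN_MOE branch above.
struct chat_msg { std::string role, content; };

static std::string format_hunyuan_moe(const std::vector<chat_msg> & chat) {
    std::string out;
    for (const auto & m : chat) {
        if (m.role == "system") {
            out += "<|startoftext|>" + m.content + "<|extra_4|>";
        } else if (m.role == "assistant") {
            out += "<|startoftext|>" + m.content + "<|eos|>";
        } else { // user and any other role
            out += "<|startoftext|>" + m.content + "<|extra_0|>";
        }
    }
    return out;
}

int main() {
    std::cout << format_hunyuan_moe({
        {"system", "You are a helpful assistant."},
        {"user",   "Hello!"},
    }) << "\n";
    // Expected output:
    // <|startoftext|>You are a helpful assistant.<|extra_4|><|startoftext|>Hello!<|extra_0|>
}
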
@@ -44,6 +44,7 @@ enum llm_chat_template {
    LLM_CHAT_TEMPLATE_LLAMA4,
    LLM_CHAT_TEMPLATE_SMOLVLM,
    LLM_CHAT_TEMPLATE_DOTS1,
+   LLM_CHAT_TEMPLATE_HUNYUAN_MOE,
    LLM_CHAT_TEMPLATE_UNKNOWN,
};
@@ -102,6 +102,7 @@ const char * llm_type_name(llm_type type) {
        case LLM_TYPE_57B_A14B: return "57B.A14B";
        case LLM_TYPE_17B_16E: return "17Bx16E (Scout)";
        case LLM_TYPE_17B_128E: return "17Bx128E (Maverick)";
+       case LLM_TYPE_A13B: return "A13B";
        case LLM_TYPE_30B_A3B: return "30B.A3B";
        case LLM_TYPE_235B_A22B: return "235B.A22B";
        case LLM_TYPE_E2B: return "E2B";
@@ -1551,6 +1552,27 @@ void llama_model::load_hparams(llama_model_loader & ml) {
                    default: type = LLM_TYPE_UNKNOWN;
                }
            } break;
+       case LLM_ARCH_HUNYUAN_MOE:
+           {
+               ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
+               ml.get_key(LLM_KV_EXPERT_FEED_FORWARD_LENGTH, hparams.n_ff_exp);
+               ml.get_key(LLM_KV_EXPERT_SHARED_FEED_FORWARD_LENGTH, hparams.n_ff_shexp);
+
+               switch (hparams.n_layer) {
+                   case 32: type = LLM_TYPE_A13B; break;
+                   default: type = LLM_TYPE_UNKNOWN;
+               }
+           } break;
+       case LLM_ARCH_SMOLLM3:
+           {
+               ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
+               hparams.n_no_rope_layer_step = 4;
+
+               switch (hparams.n_layer) {
+                   case 36: type = LLM_TYPE_3B; break;
+                   default: type = LLM_TYPE_UNKNOWN;
+               }
+           } break;
        default: throw std::runtime_error("unsupported model architecture");
    }
@@ -4471,6 +4493,72 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
                    layer.bv = create_tensor(tn(LLM_TENSOR_ATTN_V, "bias", i), {n_embd_gqa}, TENSOR_NOT_REQUIRED);
                    layer.bo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd}, TENSOR_NOT_REQUIRED);

+                   layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0);
+                   layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0);
+                   layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, 0);
+                   layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0);
+               }
+           } break;
+       case LLM_ARCH_HUNYUAN_MOE:
+           {
+               tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
+
+               // output
+               output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0);
+               output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, TENSOR_NOT_REQUIRED);
+               // if output is NULL, init from the input tok embed
+               if (output == NULL) {
+                   output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED);
+               }
+
+               for (int i = 0; i < n_layer; ++i) {
+                   auto & layer = layers[i];
+
+                   layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0);
+
+                   layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd_head_k * n_head}, 0);
+                   layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_k_gqa}, 0);
+                   layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_v_gqa}, 0);
+                   layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_k * n_head, n_embd}, 0);
+
+                   layer.attn_k_norm = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), {n_embd_head_k}, 0);
+                   layer.attn_q_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), {n_embd_head_k}, 0);
+
+                   layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0);
+
+                   layer.ffn_gate_inp = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), {n_embd, n_expert}, 0);
+                   layer.ffn_gate_exps = create_tensor(tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i), {n_embd, n_ff, n_expert}, 0);
+                   layer.ffn_down_exps = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), { n_ff, n_embd, n_expert}, 0);
+                   layer.ffn_up_exps = create_tensor(tn(LLM_TENSOR_FFN_UP_EXPS, "weight", i), {n_embd, n_ff, n_expert}, 0);
+
+                   layer.ffn_gate_shexp = create_tensor(tn(LLM_TENSOR_FFN_GATE_SHEXP, "weight", i), {n_embd, hparams.n_ff_shexp}, 0);
+                   layer.ffn_up_shexp = create_tensor(tn(LLM_TENSOR_FFN_UP_SHEXP, "weight", i), {n_embd, hparams.n_ff_shexp}, 0);
+                   layer.ffn_down_shexp = create_tensor(tn(LLM_TENSOR_FFN_DOWN_SHEXP, "weight", i), {hparams.n_ff_shexp, n_embd}, 0);
+               }
+           } break;
+       case LLM_ARCH_SMOLLM3:
+           {
+               tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
+
+               // output
+               output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0);
+               output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, TENSOR_NOT_REQUIRED);
+
+               // if output is NULL, init from the input tok embed
+               if (output == NULL) {
+                   output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED);
+               }
+
+               for (int i = 0; i < n_layer; ++i) {
+                   auto & layer = layers[i];
+
+                   layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0);
+
+                   layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd_head_k * n_head}, 0);
+                   layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_k_gqa}, 0);
+                   layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_v_gqa}, 0);
+                   layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_k * n_head, n_embd}, 0);
+
                    layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0);
                    layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0);
                    layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, 0);
@@ -14647,6 +14735,304 @@ struct llm_build_arcee : public llm_graph_context {
    }
};

+struct llm_build_hunyuan_moe : public llm_graph_context {
+    llm_build_hunyuan_moe(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params) {
+        const int64_t n_embd_head = hparams.n_embd_head_v;
+
+        GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
+        GGML_ASSERT(n_embd_head == hparams.n_rot);
+
+        ggml_tensor * cur;
+        ggml_tensor * inpL;
+
+        inpL = build_inp_embd(model.tok_embd);
+
+        // inp_pos - contains the positions
+        ggml_tensor * inp_pos = build_inp_pos();
+
+        auto * inp_attn = build_attn_inp_kv_unified();
+
+        const float kq_scale = 1.0f / sqrtf(float(n_embd_head));
+
+        ggml_tensor * inp_out_ids = build_inp_out_ids();
+
+        for (int il = 0; il < n_layer; ++il) {
+            ggml_tensor * inpSA = inpL;
+
+            // norm
+            cur = build_norm(inpL,
+                    model.layers[il].attn_norm, NULL,
+                    LLM_NORM_RMS, il);
+            cb(cur, "attn_norm", il);
+
+            // self-attention
+            {
+                // rope freq factors for llama3; may return nullptr for llama2 and other models
+                ggml_tensor * rope_factors = model.get_rope_factors(cparams, il);
+
+                // compute Q and K and RoPE them
+                ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur);
+                cb(Qcur, "Qcur", il);
+                if (model.layers[il].bq) {
+                    Qcur = ggml_add(ctx0, Qcur, model.layers[il].bq);
+                    cb(Qcur, "Qcur", il);
+                }
+
+                ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur);
+                cb(Kcur, "Kcur", il);
+                if (model.layers[il].bk) {
+                    Kcur = ggml_add(ctx0, Kcur, model.layers[il].bk);
+                    cb(Kcur, "Kcur", il);
+                }
+
+                ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur);
+                cb(Vcur, "Vcur", il);
+                if (model.layers[il].bv) {
+                    Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv);
+                    cb(Vcur, "Vcur", il);
+                }
+
+                Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens);
+                Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens);
+                Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens);
+
+                Qcur = ggml_rope_ext(
+                        ctx0, Qcur, inp_pos, rope_factors,
+                        n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
+                        ext_factor, attn_factor, beta_fast, beta_slow
+                        );
+
+                cb(Qcur, "Qcur", il);
+                cb(Kcur, "Kcur", il);
+                cb(Vcur, "Vcur", il);
+
+                Kcur = ggml_rope_ext(
+                        ctx0, Kcur, inp_pos, rope_factors,
+                        n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
+                        ext_factor, attn_factor, beta_fast, beta_slow
+                        );
+
+                Kcur = build_norm(Kcur,
+                        model.layers[il].attn_k_norm, nullptr,
+                        LLM_NORM_RMS, il);
+                cb(Kcur, "Kcur_norm", il);
+
+                Qcur = build_norm(Qcur,
+                        model.layers[il].attn_q_norm, nullptr,
+                        LLM_NORM_RMS, il);
+                cb(Qcur, "Qcur_norm", il);
+
+                cur = build_attn(inp_attn, gf,
+                        model.layers[il].wo, model.layers[il].bo,
+                        Qcur, Kcur, Vcur, nullptr, nullptr, kq_scale, il);
+                cb(cur, "attn_out", il);
+            }
+
+            if (il == n_layer - 1 && inp_out_ids) {
+                cur = ggml_get_rows(ctx0, cur, inp_out_ids);
+                inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
+            }
+
+            ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA);
+            cb(ffn_inp, "ffn_inp", il);
+
+            cur = build_norm(ffn_inp,
+                    model.layers[il].ffn_norm, NULL,
+                    LLM_NORM_RMS, il);
+            cb(cur, "ffn_norm", il);
+
+            // feed-forward network (non-MoE)
+            ggml_tensor * cur_mlp = build_ffn(cur,
+                    model.layers[il].ffn_up_shexp, NULL, NULL,
+                    model.layers[il].ffn_gate_shexp, NULL, NULL,
+                    model.layers[il].ffn_down_shexp, NULL, NULL,
+                    NULL,
+                    LLM_FFN_SILU, LLM_FFN_PAR, il);
+            cb(cur_mlp, "ffn_mlp", il);
+
+            // MoE branch
+            ggml_tensor * cur_moe = build_moe_ffn(cur,
+                    model.layers[il].ffn_gate_inp,
+                    model.layers[il].ffn_up_exps,
+                    model.layers[il].ffn_gate_exps,
+                    model.layers[il].ffn_down_exps,
+                    nullptr,
+                    n_expert, n_expert_used,
+                    LLM_FFN_SILU,
+                    true, // norm_topk_prob
+                    false,
+                    0.0,
+                    LLAMA_EXPERT_GATING_FUNC_TYPE_SOFTMAX,
+                    il);
+            cb(cur_moe, "ffn_moe_out", il);
+
+            ggml_tensor * ffn_out = ggml_add(ctx0, cur_moe, cur_mlp);
+            cb(ffn_out, "ffn_out", il);
+
+            cur = ggml_add(ctx0, ffn_out, ffn_inp);
+
+            cur = build_cvec(cur, il);
+            cb(cur, "l_out", il);
+
+            // input for next layer
+            inpL = cur;
+        }
+
+        cur = inpL;
+
+        cur = build_norm(cur,
+                model.output_norm, NULL,
+                LLM_NORM_RMS, -1);
+
+        cb(cur, "result_norm", -1);
+        res->t_embd = cur;
+
+        // lm_head
+        cur = build_lora_mm(model.output, cur);
+        cb(cur, "result_output", -1);
+        res->t_logits = cur;
+
+        ggml_build_forward_expand(gf, cur);
+    }
+};
+
+struct llm_build_smollm3 : public llm_graph_context {
+    llm_build_smollm3(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params) {
+        const int64_t n_embd_head = hparams.n_embd_head_v;
+
+        GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
+        GGML_ASSERT(n_embd_head == hparams.n_rot);
+
+        ggml_tensor * cur;
+        ggml_tensor * inpL;
+
+        inpL = build_inp_embd(model.tok_embd);
+
+        // inp_pos - contains the positions
+        ggml_tensor * inp_pos = build_inp_pos();
+
+        auto * inp_attn = build_attn_inp_kv_unified();
+
+        const float kq_scale = hparams.f_attention_scale == 0.0f ? 1.0f/sqrtf(float(n_embd_head)) : hparams.f_attention_scale;
+
+        ggml_tensor * inp_out_ids = build_inp_out_ids();
+
+        for (int il = 0; il < n_layer; ++il) {
+            ggml_tensor * inpSA = inpL;
+
+            const bool use_rope = (il + 1) % hparams.n_no_rope_layer_step != 0;
+
+            // norm
+            cur = build_norm(inpL,
+                    model.layers[il].attn_norm, NULL,
+                    LLM_NORM_RMS, il);
+            cb(cur, "attn_norm", il);
+
+            // self-attention
+            {
+                // compute Q and K and RoPE them
+                ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur);
+                cb(Qcur, "Qcur", il);
+                if (model.layers[il].bq) {
+                    Qcur = ggml_add(ctx0, Qcur, model.layers[il].bq);
+                    cb(Qcur, "Qcur", il);
+                }
+
+                ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur);
+                cb(Kcur, "Kcur", il);
+                if (model.layers[il].bk) {
+                    Kcur = ggml_add(ctx0, Kcur, model.layers[il].bk);
+                    cb(Kcur, "Kcur", il);
+                }
+
+                ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur);
+                cb(Vcur, "Vcur", il);
+                if (model.layers[il].bv) {
+                    Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv);
+                    cb(Vcur, "Vcur", il);
+                }
+
+                Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens);
+                Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens);
+                Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens);
+
+                if (use_rope) {
+                    Qcur = ggml_rope_ext(
+                        ctx0, Qcur, inp_pos, nullptr,
+                        n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
+                        ext_factor, attn_factor, beta_fast, beta_slow
+                    );
+
+                    Kcur = ggml_rope_ext(
+                        ctx0, Kcur, inp_pos, nullptr,
+                        n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
+                        ext_factor, attn_factor, beta_fast, beta_slow
+                    );
+                }
+
+                cb(Qcur, "Qcur", il);
+                cb(Kcur, "Kcur", il);
+                cb(Vcur, "Vcur", il);
+
+                cur = build_attn(inp_attn, gf,
+                        model.layers[il].wo, model.layers[il].bo,
+                        Qcur, Kcur, Vcur, nullptr, nullptr, kq_scale, il);
+                cb(cur, "attn_out", il);
+            }
+
+            if (il == n_layer - 1 && inp_out_ids) {
+                cur = ggml_get_rows(ctx0, cur, inp_out_ids);
+                inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
+            }
+
+            ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA);
+            cb(ffn_inp, "ffn_inp", il);
+
+            // feed-forward network
+            {
+                cur = build_norm(ffn_inp,
+                        model.layers[il].ffn_norm, NULL,
+                        LLM_NORM_RMS, il);
+                cb(cur, "ffn_norm", il);
+
+                cur = build_ffn(cur,
+                        model.layers[il].ffn_up, model.layers[il].ffn_up_b, NULL,
+                        model.layers[il].ffn_gate, model.layers[il].ffn_gate_b, NULL,
+                        model.layers[il].ffn_down, model.layers[il].ffn_down_b, NULL,
+                        NULL,
+                        LLM_FFN_SILU, LLM_FFN_PAR, il);
+                cb(cur, "ffn_out", il);
+            }
+
+            cur = ggml_add(ctx0, cur, ffn_inp);
+            cb(cur, "ffn_out", il);
+
+            cur = build_cvec(cur, il);
+            cb(cur, "l_out", il);
+
+            // input for next layer
+            inpL = cur;
+        }
+
+        cur = inpL;
+
+        cur = build_norm(cur,
+                model.output_norm, NULL,
+                LLM_NORM_RMS, -1);
+
+        cb(cur, "result_norm", -1);
+        res->t_embd = cur;
+
+        // lm_head
+        cur = build_lora_mm(model.output, cur);
+
+        cb(cur, "result_output", -1);
+        res->t_logits = cur;
+
+        ggml_build_forward_expand(gf, cur);
+    }
+};
+
llama_memory_i * llama_model::create_memory(const llama_memory_params & params, llama_cparams & cparams) const {
    llama_memory_i * res;
@@ -15027,6 +15413,14 @@ llm_graph_result_ptr llama_model::build_graph(
            {
                llm = std::make_unique<llm_build_ernie4_5>(*this, params, gf);
            } break;
+       case LLM_ARCH_HUNYUAN_MOE:
+           {
+               llm = std::make_unique<llm_build_hunyuan_moe>(*this, params, gf);
+           } break;
+       case LLM_ARCH_SMOLLM3:
+           {
+               llm = std::make_unique<llm_build_smollm3>(*this, params, gf);
+           } break;
        default:
            GGML_ABORT("fatal error");
    }
@@ -15178,6 +15572,7 @@ llama_rope_type llama_model_rope_type(const llama_model * model) {
        case LLM_ARCH_CHAMELEON:
        case LLM_ARCH_BAILINGMOE:
        case LLM_ARCH_NEO_BERT:
+       case LLM_ARCH_SMOLLM3:
        case LLM_ARCH_ARCEE:
        case LLM_ARCH_ERNIE4_5:
            return LLAMA_ROPE_TYPE_NORM;
@@ -15215,6 +15610,7 @@ llama_rope_type llama_model_rope_type(const llama_model * model) {
        case LLM_ARCH_EXAONE:
        case LLM_ARCH_MINICPM3:
        case LLM_ARCH_DOTS1:
+       case LLM_ARCH_HUNYUAN_MOE:
            return LLAMA_ROPE_TYPE_NEOX;

        case LLM_ARCH_QWEN2VL:
@@ -94,6 +94,7 @@ enum llm_type {
    LLM_TYPE_57B_A14B,
    LLM_TYPE_17B_16E, // llama4 Scout
    LLM_TYPE_17B_128E, // llama4 Maverick
+   LLM_TYPE_A13B,
    LLM_TYPE_30B_A3B,
    LLM_TYPE_235B_A22B,
    LLM_TYPE_E2B,
@@ -351,6 +351,7 @@ struct llm_tokenizer_bpe : llm_tokenizer {
                break;
            case LLAMA_VOCAB_PRE_TYPE_STABLELM2:
            case LLAMA_VOCAB_PRE_TYPE_QWEN2:
+           case LLAMA_VOCAB_PRE_TYPE_HUNYUAN:
                regex_exprs = {
                    // original regex from tokenizer.json
                    // "(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+"
@@ -1656,6 +1657,10 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) {
                tokenizer_pre == "seed-coder") {
                pre_type = LLAMA_VOCAB_PRE_TYPE_SEED_CODER;
                clean_spaces = false;
+           } else if (
+               tokenizer_pre == "hunyuan") {
+               pre_type = LLAMA_VOCAB_PRE_TYPE_HUNYUAN;
+               clean_spaces = false;
            } else {
                throw std::runtime_error(format("unknown pre-tokenizer type: '%s'", tokenizer_pre.c_str()));
            }
@@ -5323,12 +5323,12 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
    for (bool fw : {true, false}) { // fw == forward
        bool all = true;

-       for (float v : { 0, 1 }) {
-           for (float fs : { 1.0f, 1.4245f }) {
-               for (float ef : { 0.0f, 0.7465f }) {
-                   for (float af : { 1.0f, 1.4245f }) {
-                       for (ggml_type type : {GGML_TYPE_F32, GGML_TYPE_F16}) {
-                           for (bool ff : {false, true}) { // freq_factors
+       for (float fs : { 1.0f, 1.4245f }) {
+           for (float ef : { 0.0f, 0.7465f }) {
+               for (float af : { 1.0f, 1.4245f }) {
+                   for (ggml_type type : {GGML_TYPE_F32, GGML_TYPE_F16}) {
+                       for (bool ff : {false, true}) { // freq_factors
+                           for (float v : { 0, 1 }) {
                                test_cases.emplace_back(new test_rope(type, {128, 32, 2, 1}, 128, 0, 512, fs, ef, af, ff, v, fw)); // llama 7B

                                if (all) {
@@ -5341,13 +5341,21 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
                                    test_cases.emplace_back(new test_rope(type, { 64, 1, 2, 1}, 64, 2, 512, fs, ef, af, ff, v, fw)); // neox (falcon 7B)
                                    test_cases.emplace_back(new test_rope(type, { 64, 71, 2, 1}, 64, 2, 512, fs, ef, af, ff, v, fw)); // neox (falcon 7B)
                                    test_cases.emplace_back(new test_rope(type, { 64, 8, 2, 1}, 64, 2, 512, fs, ef, af, ff, v, fw)); // neox (falcon 40B)
+
+                                   test_cases.emplace_back(new test_rope(type, { 80, 32, 2, 1}, 20, 0, 512, fs, ef, af, ff, v, fw));
+                                   test_cases.emplace_back(new test_rope(type, { 80, 32, 2, 1}, 32, 0, 512, fs, ef, af, ff, v, fw));
+                                   test_cases.emplace_back(new test_rope(type, { 80, 32, 4, 1}, 32, 0, 512, fs, ef, af, ff, v, fw));
+
                                    test_cases.emplace_back(new test_rope(type, { 80, 32, 2, 1}, 20, 2, 512, fs, ef, af, ff, v, fw)); // neox (stablelm)
                                    test_cases.emplace_back(new test_rope(type, { 80, 32, 2, 1}, 32, 2, 512, fs, ef, af, ff, v, fw)); // neox (phi-2)
+                                   test_cases.emplace_back(new test_rope(type, { 80, 32, 4, 1}, 32, 2, 512, fs, ef, af, ff, v, fw)); // neox (phi-2)
                                }

                                if (all) {
                                    test_cases.emplace_back(new test_rope(type, {128, 12, 2, 1}, 128, GGML_ROPE_TYPE_MROPE, 512, fs, ef, af, ff, v, fw)); // rope_multi,m-rope (qwen2vl 2B)
                                    test_cases.emplace_back(new test_rope(type, {128, 28, 2, 1}, 128, GGML_ROPE_TYPE_MROPE, 512, fs, ef, af, ff, v, fw)); // rope_multi,m-rope (qwen2vl 7B)
+                                   test_cases.emplace_back(new test_rope(type, {128, 12, 2, 1}, 20, GGML_ROPE_TYPE_MROPE, 512, fs, ef, af, ff, v, fw));
+                                   test_cases.emplace_back(new test_rope(type, {128, 28, 2, 1}, 32, GGML_ROPE_TYPE_MROPE, 512, fs, ef, af, ff, v, fw));
                                    test_cases.emplace_back(new test_rope(type, { 80, 16, 2, 1}, 80, GGML_ROPE_TYPE_VISION, 512, fs, ef, af, ff, v, fw)); // rope_multi,m-rope (qwen2vl ViT)
                                }
@@ -4806,14 +4806,14 @@ int main(int argc, char ** argv) {
    // register static assets routes
    if (!params.public_path.empty()) {
        // Set the base directory for serving static files
-       bool is_found = svr->set_mount_point("/", params.public_path);
+       bool is_found = svr->set_mount_point(params.api_prefix + "/", params.public_path);
        if (!is_found) {
            LOG_ERR("%s: static assets path not found: %s\n", __func__, params.public_path.c_str());
            return 1;
        }
    } else {
        // using embedded static index.html
-       svr->Get("/", [](const httplib::Request & req, httplib::Response & res) {
+       svr->Get(params.api_prefix + "/", [](const httplib::Request & req, httplib::Response & res) {
            if (req.get_header_value("Accept-Encoding").find("gzip") == std::string::npos) {
                res.set_content("Error: gzip is not supported by this browser", "text/plain");
            } else {
@@ -4829,37 +4829,37 @@ int main(int argc, char ** argv) {
    }

    // register API routes
-   svr->Get ("/health", handle_health); // public endpoint (no API key check)
-   svr->Get ("/metrics", handle_metrics);
-   svr->Get ("/props", handle_props);
-   svr->Post("/props", handle_props_change);
-   svr->Post("/api/show", handle_api_show);
-   svr->Get ("/models", handle_models); // public endpoint (no API key check)
-   svr->Get ("/v1/models", handle_models); // public endpoint (no API key check)
-   svr->Get ("/api/tags", handle_models); // ollama specific endpoint. public endpoint (no API key check)
-   svr->Post("/completion", handle_completions); // legacy
-   svr->Post("/completions", handle_completions);
-   svr->Post("/v1/completions", handle_completions_oai);
-   svr->Post("/chat/completions", handle_chat_completions);
-   svr->Post("/v1/chat/completions", handle_chat_completions);
-   svr->Post("/api/chat", handle_chat_completions); // ollama specific endpoint
-   svr->Post("/infill", handle_infill);
-   svr->Post("/embedding", handle_embeddings); // legacy
-   svr->Post("/embeddings", handle_embeddings);
-   svr->Post("/v1/embeddings", handle_embeddings_oai);
-   svr->Post("/rerank", handle_rerank);
-   svr->Post("/reranking", handle_rerank);
-   svr->Post("/v1/rerank", handle_rerank);
-   svr->Post("/v1/reranking", handle_rerank);
-   svr->Post("/tokenize", handle_tokenize);
-   svr->Post("/detokenize", handle_detokenize);
-   svr->Post("/apply-template", handle_apply_template);
+   svr->Get (params.api_prefix + "/health", handle_health); // public endpoint (no API key check)
+   svr->Get (params.api_prefix + "/metrics", handle_metrics);
+   svr->Get (params.api_prefix + "/props", handle_props);
+   svr->Post(params.api_prefix + "/props", handle_props_change);
+   svr->Post(params.api_prefix + "/api/show", handle_api_show);
+   svr->Get (params.api_prefix + "/models", handle_models); // public endpoint (no API key check)
+   svr->Get (params.api_prefix + "/v1/models", handle_models); // public endpoint (no API key check)
+   svr->Get (params.api_prefix + "/api/tags", handle_models); // ollama specific endpoint. public endpoint (no API key check)
+   svr->Post(params.api_prefix + "/completion", handle_completions); // legacy
+   svr->Post(params.api_prefix + "/completions", handle_completions);
+   svr->Post(params.api_prefix + "/v1/completions", handle_completions_oai);
+   svr->Post(params.api_prefix + "/chat/completions", handle_chat_completions);
+   svr->Post(params.api_prefix + "/v1/chat/completions", handle_chat_completions);
+   svr->Post(params.api_prefix + "/api/chat", handle_chat_completions); // ollama specific endpoint
+   svr->Post(params.api_prefix + "/infill", handle_infill);
+   svr->Post(params.api_prefix + "/embedding", handle_embeddings); // legacy
+   svr->Post(params.api_prefix + "/embeddings", handle_embeddings);
+   svr->Post(params.api_prefix + "/v1/embeddings", handle_embeddings_oai);
+   svr->Post(params.api_prefix + "/rerank", handle_rerank);
+   svr->Post(params.api_prefix + "/reranking", handle_rerank);
+   svr->Post(params.api_prefix + "/v1/rerank", handle_rerank);
+   svr->Post(params.api_prefix + "/v1/reranking", handle_rerank);
+   svr->Post(params.api_prefix + "/tokenize", handle_tokenize);
+   svr->Post(params.api_prefix + "/detokenize", handle_detokenize);
+   svr->Post(params.api_prefix + "/apply-template", handle_apply_template);
    // LoRA adapters hotswap
-   svr->Get ("/lora-adapters", handle_lora_adapters_list);
-   svr->Post("/lora-adapters", handle_lora_adapters_apply);
+   svr->Get (params.api_prefix + "/lora-adapters", handle_lora_adapters_list);
+   svr->Post(params.api_prefix + "/lora-adapters", handle_lora_adapters_apply);
    // Save & load slots
-   svr->Get ("/slots", handle_slots);
-   svr->Post("/slots/:id_slot", handle_slots_action);
+   svr->Get (params.api_prefix + "/slots", handle_slots);
+   svr->Post(params.api_prefix + "/slots/:id_slot", handle_slots_action);

    //
    // Start the server
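
Note: with this change every server route is registered under params.api_prefix, so a non-empty prefix shifts the whole API (an empty prefix leaves the paths exactly as before). A rough standalone sketch of the same idea using cpp-httplib directly is below; the "/llama" prefix value and the JSON body are illustrative only, not taken from the patch.

#include <string>
#include "httplib.h" // cpp-httplib, the HTTP library used by the server

int main() {
    httplib::Server svr;
    const std::string api_prefix = "/llama"; // illustrative; an empty prefix keeps the old paths

    // With this prefix, the endpoint answers at GET /llama/health instead of GET /health.
    svr.Get(api_prefix + "/health", [](const httplib::Request &, httplib::Response & res) {
        res.set_content("{\"status\":\"ok\"}", "application/json");
    });

    return svr.listen("127.0.0.1", 8080) ? 0 : 1;
}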