Merge branch 'master' into xsn/server_model_management_v1_2

commit bdaf44a13c
@@ -50,6 +50,7 @@ WORKDIR /app
 RUN apt-get update \
     && apt-get install -y \
+       build-essential \
        git \
        python3 \
        python3-pip \
@@ -699,6 +699,12 @@ static bool is_autoy(const std::string & value) {
 }
 
 common_params_context common_params_parser_init(common_params & params, llama_example ex, void(*print_usage)(int, char **)) {
+    // default values specific to example
+    // note: we place it here instead of inside server.cpp to allow llama-gen-docs to pick it up
+    if (ex == LLAMA_EXAMPLE_SERVER) {
+        params.use_jinja = true;
+    }
+
     // load dynamic backends
     ggml_backend_load_all();
@@ -2511,11 +2517,18 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
     ).set_examples({LLAMA_EXAMPLE_SERVER}).set_env("LLAMA_ARG_NO_MODELS_AUTOLOAD"));
     add_opt(common_arg(
         {"--jinja"},
-        "use jinja template for chat (default: disabled)",
+        string_format("use jinja template for chat (default: %s)\n", params.use_jinja ? "enabled" : "disabled"),
         [](common_params & params) {
             params.use_jinja = true;
         }
     ).set_examples({LLAMA_EXAMPLE_SERVER, LLAMA_EXAMPLE_MAIN, LLAMA_EXAMPLE_MTMD}).set_env("LLAMA_ARG_JINJA"));
+    add_opt(common_arg(
+        {"--no-jinja"},
+        string_format("disable jinja template for chat (default: %s)\n", params.use_jinja ? "enabled" : "disabled"),
+        [](common_params & params) {
+            params.use_jinja = false;
+        }
+    ).set_examples({LLAMA_EXAMPLE_SERVER, LLAMA_EXAMPLE_MAIN, LLAMA_EXAMPLE_MTMD}).set_env("LLAMA_ARG_NO_JINJA"));
     add_opt(common_arg(
         {"--reasoning-format"}, "FORMAT",
         "controls whether thought tags are allowed and/or extracted from the response, and in which format they're returned; one of:\n"
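The two hunks above interact: the server example now defaults `use_jinja` to true, and the new `--no-jinja` flag (or `LLAMA_ARG_NO_JINJA`) is the explicit opt-out. A minimal sketch of the resulting precedence, using stand-in types rather than the real `common_params` machinery:

```cpp
#include <cstdio>

// Stand-ins for llama.cpp's real types; only the precedence order is illustrated.
enum example_stub { EXAMPLE_MAIN, EXAMPLE_SERVER };
struct params_stub { bool use_jinja = false; };

int main() {
    params_stub params;                   // struct default: jinja off
    example_stub ex = EXAMPLE_SERVER;
    if (ex == EXAMPLE_SERVER) {
        params.use_jinja = true;          // example-specific default (first hunk)
    }
    bool user_passed_no_jinja = true;     // pretend --no-jinja was on the command line
    if (user_passed_no_jinja) {
        params.use_jinja = false;         // an explicit flag always wins (second hunk)
    }
    printf("use_jinja = %s\n", params.use_jinja ? "enabled" : "disabled");
    return 0;
}
```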
@@ -4183,6 +4183,36 @@ class Qwen3MoeModel(Qwen2MoeModel):
         super().set_vocab()
 
 
+@ModelBase.register("Qwen3NextForCausalLM")
+class Qwen3NextModel(Qwen2MoeModel):
+    model_arch = gguf.MODEL_ARCH.QWEN3NEXT
+
+    def set_gguf_parameters(self):
+        super().set_gguf_parameters()
+        self.gguf_writer.add_ssm_conv_kernel(self.hparams["linear_conv_kernel_dim"])
+        self.gguf_writer.add_ssm_state_size(self.hparams["linear_key_head_dim"])
+        self.gguf_writer.add_ssm_group_count(self.hparams["linear_num_key_heads"])
+        self.gguf_writer.add_ssm_time_step_rank(self.hparams["linear_num_value_heads"])
+        self.gguf_writer.add_ssm_inner_size(self.hparams["linear_value_head_dim"] * self.hparams["linear_num_value_heads"])
+        if (rope_dim := self.hparams.get("head_dim")) is None:
+            rope_dim = self.hparams["hidden_size"] // self.hparams["num_attention_heads"]
+        self.gguf_writer.add_rope_dimension_count(int(rope_dim * self.hparams.get("partial_rotary_factor", 0.25)))
+
+    def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
+        if name.startswith("mtp"):
+            return []  # ignore MTP layers for now
+        if name.endswith(".A_log"):
+            data_torch = -torch.exp(data_torch)
+        elif name.endswith(".dt_bias"):
+            name = name.rpartition(".dt_bias")[0] + ".dt_proj.bias"
+        elif "conv1d" in name:
+            data_torch = data_torch.squeeze()
+        elif name.endswith("norm.weight") and not name.endswith("linear_attn.norm.weight"):
+            data_torch = data_torch + 1
+
+        yield from super().modify_tensors(data_torch, name, bid)
+
+
 @ModelBase.register("RND1")
 class RND1Model(Qwen2MoeModel):
     model_arch = gguf.MODEL_ARCH.RND1
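Two of the tensor rewrites above follow parameterizations that are common in gated-SSM (Mamba-style) checkpoints; stated here as a reading aid, not something the patch itself spells out: the state-transition coefficient is stored as a log-magnitude and is negated on export so it stays negative, and the norm gains are stored zero-centered, so the effective weight adds one:

```latex
A = -\exp(A_{\log}) < 0, \qquad \gamma_{\text{eff}} = 1 + \gamma_{\text{stored}}
```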
@@ -4,6 +4,11 @@ set -e
 
 # First try command line argument, then environment variable, then file
 CONVERTED_MODEL="${1:-"$CONVERTED_MODEL"}"
+MODEL_TESTING_PROMPT="${2:-"$MODEL_TESTING_PROMPT"}"
+
+if [ -z "$MODEL_TESTING_PROMPT" ]; then
+    MODEL_TESTING_PROMPT="Hello, my name is"
+fi
 
 # Final check if we have a model path
 if [ -z "$CONVERTED_MODEL" ]; then
@@ -14,7 +19,8 @@ if [ -z "$CONVERTED_MODEL" ]; then
 fi
 
 echo $CONVERTED_MODEL
+echo $MODEL_TESTING_PROMPT
 
 cmake --build ../../build --target llama-logits -j8
 
-../../build/bin/llama-logits -m "$CONVERTED_MODEL" "Hello, my name is"
+../../build/bin/llama-logits -m "$CONVERTED_MODEL" "$MODEL_TESTING_PROMPT"
@@ -184,8 +184,12 @@ model_name = os.path.basename(model_path)
 # of using AutoModelForCausalLM.
 print(f"Model class: {model.__class__.__name__}")
 
-prompt = "Hello, my name is"
-input_ids = tokenizer(prompt, return_tensors="pt").input_ids
+device = next(model.parameters()).device
+if os.getenv("MODEL_TESTING_PROMPT"):
+    prompt = os.getenv("MODEL_TESTING_PROMPT")
+else:
+    prompt = "Hello, my name is"
+input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to(device)
 
 print(f"Input tokens: {input_ids}")
 print(f"Input text: {repr(prompt)}")
@@ -8,7 +8,7 @@ extern "C" {
 #endif
 
 #define RPC_PROTO_MAJOR_VERSION 3
-#define RPC_PROTO_MINOR_VERSION 0
+#define RPC_PROTO_MINOR_VERSION 5
 #define RPC_PROTO_PATCH_VERSION 0
 #define GGML_RPC_MAX_SERVERS    16
@@ -33,10 +33,12 @@
 // repack.cpp
 #define ggml_quantize_mat_q8_0_4x4_generic ggml_quantize_mat_q8_0_4x4
 #define ggml_quantize_mat_q8_0_4x8_generic ggml_quantize_mat_q8_0_4x8
+#define ggml_quantize_mat_q8_K_4x4_generic ggml_quantize_mat_q8_K_4x4
 #define ggml_quantize_mat_q8_K_4x8_generic ggml_quantize_mat_q8_K_4x8
 #define ggml_gemv_q4_0_4x4_q8_0_generic ggml_gemv_q4_0_4x4_q8_0
 #define ggml_gemv_q4_0_4x8_q8_0_generic ggml_gemv_q4_0_4x8_q8_0
 #define ggml_gemv_q4_0_8x8_q8_0_generic ggml_gemv_q4_0_8x8_q8_0
+#define ggml_gemv_q4_K_8x4_q8_K_generic ggml_gemv_q4_K_8x4_q8_K
 #define ggml_gemv_q4_K_8x8_q8_K_generic ggml_gemv_q4_K_8x8_q8_K
 #define ggml_gemv_q2_K_8x8_q8_K_generic ggml_gemv_q2_K_8x8_q8_K
 #define ggml_gemv_iq4_nl_4x4_q8_0_generic ggml_gemv_iq4_nl_4x4_q8_0
@@ -44,12 +46,14 @@
 #define ggml_gemm_q4_0_4x4_q8_0_generic ggml_gemm_q4_0_4x4_q8_0
 #define ggml_gemm_q4_0_4x8_q8_0_generic ggml_gemm_q4_0_4x8_q8_0
 #define ggml_gemm_q4_0_8x8_q8_0_generic ggml_gemm_q4_0_8x8_q8_0
+#define ggml_gemm_q4_K_8x4_q8_K_generic ggml_gemm_q4_K_8x4_q8_K
 #define ggml_gemm_q4_K_8x8_q8_K_generic ggml_gemm_q4_K_8x8_q8_K
 #define ggml_gemm_q2_K_8x8_q8_K_generic ggml_gemm_q2_K_8x8_q8_K
 #define ggml_gemm_iq4_nl_4x4_q8_0_generic ggml_gemm_iq4_nl_4x4_q8_0
 #define ggml_gemm_iq4_nl_8x8_q8_0_generic ggml_gemm_iq4_nl_8x8_q8_0
 #elif defined(__aarch64__) || defined(__arm__) || defined(_M_ARM) || defined(_M_ARM64)
 // repack.cpp
+#define ggml_quantize_mat_q8_K_4x4_generic ggml_quantize_mat_q8_K_4x4
 #define ggml_quantize_mat_q8_K_4x8_generic ggml_quantize_mat_q8_K_4x8
 #define ggml_gemv_iq4_nl_8x8_q8_0_generic ggml_gemv_iq4_nl_8x8_q8_0
 #define ggml_gemv_q2_K_8x8_q8_K_generic ggml_gemv_q2_K_8x8_q8_K
@@ -58,11 +62,14 @@
 #elif defined(__x86_64__) || defined(__i386__) || defined(_M_IX86) || defined(_M_X64)
 // repack.cpp
 #define ggml_quantize_mat_q8_0_4x4_generic ggml_quantize_mat_q8_0_4x4
+#define ggml_quantize_mat_q8_K_4x4_generic ggml_quantize_mat_q8_K_4x4
 #define ggml_gemv_q4_0_4x4_q8_0_generic ggml_gemv_q4_0_4x4_q8_0
 #define ggml_gemv_q4_0_4x8_q8_0_generic ggml_gemv_q4_0_4x8_q8_0
+#define ggml_gemv_q4_K_8x4_q8_K_generic ggml_gemv_q4_K_8x4_q8_K
 #define ggml_gemv_iq4_nl_4x4_q8_0_generic ggml_gemv_iq4_nl_4x4_q8_0
 #define ggml_gemm_q4_0_4x4_q8_0_generic ggml_gemm_q4_0_4x4_q8_0
 #define ggml_gemm_q4_0_4x8_q8_0_generic ggml_gemm_q4_0_4x8_q8_0
+#define ggml_gemm_q4_K_8x4_q8_K_generic ggml_gemm_q4_K_8x4_q8_K
 #define ggml_gemm_iq4_nl_4x4_q8_0_generic ggml_gemm_iq4_nl_4x4_q8_0
 #elif defined(__POWERPC__) || defined(__powerpc__)
 // ref: https://github.com/ggml-org/llama.cpp/pull/14146#issuecomment-2972561679
@@ -74,10 +81,12 @@
 // repack.cpp
 #define ggml_quantize_mat_q8_0_4x4_generic ggml_quantize_mat_q8_0_4x4
 #define ggml_quantize_mat_q8_0_4x8_generic ggml_quantize_mat_q8_0_4x8
+#define ggml_quantize_mat_q8_K_4x4_generic ggml_quantize_mat_q8_K_4x4
 #define ggml_quantize_mat_q8_K_4x8_generic ggml_quantize_mat_q8_K_4x8
 #define ggml_gemv_q4_0_4x4_q8_0_generic ggml_gemv_q4_0_4x4_q8_0
 #define ggml_gemv_q4_0_4x8_q8_0_generic ggml_gemv_q4_0_4x8_q8_0
 #define ggml_gemv_q4_0_8x8_q8_0_generic ggml_gemv_q4_0_8x8_q8_0
+#define ggml_gemv_q4_K_8x4_q8_K_generic ggml_gemv_q4_K_8x4_q8_K
 #define ggml_gemv_q4_K_8x8_q8_K_generic ggml_gemv_q4_K_8x8_q8_K
 #define ggml_gemv_q2_K_8x8_q8_K_generic ggml_gemv_q2_K_8x8_q8_K
 #define ggml_gemv_iq4_nl_4x4_q8_0_generic ggml_gemv_iq4_nl_4x4_q8_0
@@ -85,6 +94,7 @@
 #define ggml_gemm_q4_0_4x4_q8_0_generic ggml_gemm_q4_0_4x4_q8_0
 #define ggml_gemm_q4_0_4x8_q8_0_generic ggml_gemm_q4_0_4x8_q8_0
 #define ggml_gemm_q4_0_8x8_q8_0_generic ggml_gemm_q4_0_8x8_q8_0
+#define ggml_gemm_q4_K_8x4_q8_K_generic ggml_gemm_q4_K_8x4_q8_K
 #define ggml_gemm_q4_K_8x8_q8_K_generic ggml_gemm_q4_K_8x8_q8_K
 #define ggml_gemm_q2_K_8x8_q8_K_generic ggml_gemm_q2_K_8x8_q8_K
 #define ggml_gemm_iq4_nl_4x4_q8_0_generic ggml_gemm_iq4_nl_4x4_q8_0
@@ -99,10 +109,12 @@
 // repack.cpp
 #define ggml_quantize_mat_q8_0_4x4_generic ggml_quantize_mat_q8_0_4x4
 #define ggml_quantize_mat_q8_0_4x8_generic ggml_quantize_mat_q8_0_4x8
+#define ggml_quantize_mat_q8_K_4x4_generic ggml_quantize_mat_q8_K_4x4
 #define ggml_quantize_mat_q8_K_4x8_generic ggml_quantize_mat_q8_K_4x8
 #define ggml_gemv_q4_0_4x4_q8_0_generic ggml_gemv_q4_0_4x4_q8_0
 #define ggml_gemv_q4_0_4x8_q8_0_generic ggml_gemv_q4_0_4x8_q8_0
 #define ggml_gemv_q4_0_8x8_q8_0_generic ggml_gemv_q4_0_8x8_q8_0
+#define ggml_gemv_q4_K_8x4_q8_K_generic ggml_gemv_q4_K_8x4_q8_K
 #define ggml_gemv_q4_K_8x8_q8_K_generic ggml_gemv_q4_K_8x8_q8_K
 #define ggml_gemv_q2_K_8x8_q8_K_generic ggml_gemv_q2_K_8x8_q8_K
 #define ggml_gemv_iq4_nl_4x4_q8_0_generic ggml_gemv_iq4_nl_4x4_q8_0
@@ -110,6 +122,7 @@
 #define ggml_gemm_q4_0_4x4_q8_0_generic ggml_gemm_q4_0_4x4_q8_0
 #define ggml_gemm_q4_0_4x8_q8_0_generic ggml_gemm_q4_0_4x8_q8_0
 #define ggml_gemm_q4_0_8x8_q8_0_generic ggml_gemm_q4_0_8x8_q8_0
+#define ggml_gemm_q4_K_8x4_q8_K_generic ggml_gemm_q4_K_8x4_q8_K
 #define ggml_gemm_q4_K_8x8_q8_K_generic ggml_gemm_q4_K_8x8_q8_K
 #define ggml_gemm_q2_K_8x8_q8_K_generic ggml_gemm_q2_K_8x8_q8_K
 #define ggml_gemm_iq4_nl_4x4_q8_0_generic ggml_gemm_iq4_nl_4x4_q8_0
@@ -132,15 +145,18 @@
 // repack.cpp
 #define ggml_quantize_mat_q8_0_4x4_generic ggml_quantize_mat_q8_0_4x4
 #define ggml_quantize_mat_q8_0_4x8_generic ggml_quantize_mat_q8_0_4x8
+#define ggml_quantize_mat_q8_K_4x4_generic ggml_quantize_mat_q8_K_4x4
 #define ggml_quantize_mat_q8_K_4x8_generic ggml_quantize_mat_q8_K_4x8
 #define ggml_gemv_q4_0_4x4_q8_0_generic ggml_gemv_q4_0_4x4_q8_0
 #define ggml_gemv_q4_0_4x8_q8_0_generic ggml_gemv_q4_0_4x8_q8_0
+#define ggml_gemv_q4_K_8x4_q8_K_generic ggml_gemv_q4_K_8x4_q8_K
 #define ggml_gemv_q4_K_8x8_q8_K_generic ggml_gemv_q4_K_8x8_q8_K
 #define ggml_gemv_q2_K_8x8_q8_K_generic ggml_gemv_q2_K_8x8_q8_K
 #define ggml_gemv_iq4_nl_4x4_q8_0_generic ggml_gemv_iq4_nl_4x4_q8_0
 #define ggml_gemv_iq4_nl_8x8_q8_0_generic ggml_gemv_iq4_nl_8x8_q8_0
 #define ggml_gemm_q4_0_4x4_q8_0_generic ggml_gemm_q4_0_4x4_q8_0
 #define ggml_gemm_q4_0_4x8_q8_0_generic ggml_gemm_q4_0_4x8_q8_0
+#define ggml_gemm_q4_K_8x4_q8_K_generic ggml_gemm_q4_K_8x4_q8_K
 #define ggml_gemm_q4_K_8x8_q8_K_generic ggml_gemm_q4_K_8x8_q8_K
 #define ggml_gemm_q2_K_8x8_q8_K_generic ggml_gemm_q2_K_8x8_q8_K
 #define ggml_gemm_iq4_nl_4x4_q8_0_generic ggml_gemm_iq4_nl_4x4_q8_0
@@ -161,10 +177,12 @@
 // repack.cpp
 #define ggml_quantize_mat_q8_0_4x4_generic ggml_quantize_mat_q8_0_4x4
 #define ggml_quantize_mat_q8_0_4x8_generic ggml_quantize_mat_q8_0_4x8
+#define ggml_quantize_mat_q8_K_4x4_generic ggml_quantize_mat_q8_K_4x4
 #define ggml_quantize_mat_q8_K_4x8_generic ggml_quantize_mat_q8_K_4x8
 #define ggml_gemv_q4_0_4x4_q8_0_generic ggml_gemv_q4_0_4x4_q8_0
 #define ggml_gemv_q4_0_4x8_q8_0_generic ggml_gemv_q4_0_4x8_q8_0
 #define ggml_gemv_q4_0_8x8_q8_0_generic ggml_gemv_q4_0_8x8_q8_0
+#define ggml_gemv_q4_K_8x4_q8_K_generic ggml_gemv_q4_K_8x4_q8_K
 #define ggml_gemv_q4_K_8x8_q8_K_generic ggml_gemv_q4_K_8x8_q8_K
 #define ggml_gemv_q2_K_8x8_q8_K_generic ggml_gemv_q2_K_8x8_q8_K
 #define ggml_gemv_iq4_nl_4x4_q8_0_generic ggml_gemv_iq4_nl_4x4_q8_0
@@ -172,6 +190,7 @@
 #define ggml_gemm_q4_0_4x4_q8_0_generic ggml_gemm_q4_0_4x4_q8_0
 #define ggml_gemm_q4_0_4x8_q8_0_generic ggml_gemm_q4_0_4x8_q8_0
 #define ggml_gemm_q4_0_8x8_q8_0_generic ggml_gemm_q4_0_8x8_q8_0
+#define ggml_gemm_q4_K_8x4_q8_K_generic ggml_gemm_q4_K_8x4_q8_K
 #define ggml_gemm_q4_K_8x8_q8_K_generic ggml_gemm_q4_K_8x8_q8_K
 #define ggml_gemm_q2_K_8x8_q8_K_generic ggml_gemm_q2_K_8x8_q8_K
 #define ggml_gemm_iq4_nl_4x4_q8_0_generic ggml_gemm_iq4_nl_4x4_q8_0
@@ -194,10 +213,12 @@
 // repack.cpp
 #define ggml_quantize_mat_q8_0_4x4_generic ggml_quantize_mat_q8_0_4x4
 #define ggml_quantize_mat_q8_0_4x8_generic ggml_quantize_mat_q8_0_4x8
+#define ggml_quantize_mat_q8_K_4x4_generic ggml_quantize_mat_q8_K_4x4
 #define ggml_quantize_mat_q8_K_4x8_generic ggml_quantize_mat_q8_K_4x8
 #define ggml_gemv_q4_0_4x4_q8_0_generic ggml_gemv_q4_0_4x4_q8_0
 #define ggml_gemv_q4_0_4x8_q8_0_generic ggml_gemv_q4_0_4x8_q8_0
 #define ggml_gemv_q4_0_8x8_q8_0_generic ggml_gemv_q4_0_8x8_q8_0
+#define ggml_gemv_q4_K_8x4_q8_K_generic ggml_gemv_q4_K_8x4_q8_K
 #define ggml_gemv_q4_K_8x8_q8_K_generic ggml_gemv_q4_K_8x8_q8_K
 #define ggml_gemv_q2_K_8x8_q8_K_generic ggml_gemv_q2_K_8x8_q8_K
 #define ggml_gemv_iq4_nl_4x4_q8_0_generic ggml_gemv_iq4_nl_4x4_q8_0
@@ -205,6 +226,7 @@
 #define ggml_gemm_q4_0_4x4_q8_0_generic ggml_gemm_q4_0_4x4_q8_0
 #define ggml_gemm_q4_0_4x8_q8_0_generic ggml_gemm_q4_0_4x8_q8_0
 #define ggml_gemm_q4_0_8x8_q8_0_generic ggml_gemm_q4_0_8x8_q8_0
+#define ggml_gemm_q4_K_8x4_q8_K_generic ggml_gemm_q4_K_8x4_q8_K
 #define ggml_gemm_q4_K_8x8_q8_K_generic ggml_gemm_q4_K_8x8_q8_K
 #define ggml_gemm_q2_K_8x8_q8_K_generic ggml_gemm_q2_K_8x8_q8_K
 #define ggml_gemm_iq4_nl_4x4_q8_0_generic ggml_gemm_iq4_nl_4x4_q8_0
@@ -497,6 +497,140 @@ void ggml_gemv_iq4_nl_4x4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const
     ggml_gemv_iq4_nl_4x4_q8_0_generic(n, s, bs, vx, vy, nr, nc);
 }
 
+void ggml_gemv_q4_K_8x4_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
+    constexpr int qk = QK_K;
+    const int nb = n / qk;
+
+    constexpr int ncols_interleaved = 8;
+    constexpr int blocklen = 8;
+
+    assert(n % qk == 0);
+    assert(nr % 4 == 0);
+    assert(nc % ncols_interleaved == 0);
+
+    UNUSED(nb);
+    UNUSED(ncols_interleaved);
+    UNUSED(blocklen);
+
+#if defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_DOTPROD)
+    constexpr int col_groups = ncols_interleaved / 4; // 0123 and 4567
+    const uint8x16_t m4b = vdupq_n_u8(0x0f);
+
+    // 1x8 tile = 2 x 4
+    float32x4_t acc_f32[col_groups];
+
+    const block_q8_K * GGML_RESTRICT q8_ptr = (const block_q8_K *) vy;
+
+    for (int x = 0; x < nc / ncols_interleaved; x++) {
+        const block_q4_Kx8 * GGML_RESTRICT q4_ptr = (const block_q4_Kx8 *) vx + (x * nb);
+
+        for (int i = 0; i < col_groups; i++) {
+            acc_f32[i] = vdupq_n_f32(0);
+        }
+
+        for (int b = 0; b < nb; b++) {
+            float32x4_t q4_d_0 = vcvt_f32_f16(vld1_f16((const __fp16 *) q4_ptr[b].d));      // d0 d1 d2 d3
+            float32x4_t q4_d_1 = vcvt_f32_f16(vld1_f16((const __fp16 *) q4_ptr[b].d + 4));  // d4 d5 d6 d7
+            float32x4_t q8_d   = vdupq_n_f32(q8_ptr[b].d);
+            float32x4_t sb_scale_0123 = vmulq_f32(q4_d_0, q8_d);
+            float32x4_t sb_scale_4567 = vmulq_f32(q4_d_1, q8_d);
+            float32x4_t q4_dmin_0 = vcvt_f32_f16(vld1_f16((const __fp16 *) q4_ptr[b].dmin));      // dmin 0..3
+            float32x4_t q4_dmin_1 = vcvt_f32_f16(vld1_f16((const __fp16 *) q4_ptr[b].dmin + 4));  // dmin 4..7
+            float32x4_t sb_min_0123 = vmulq_f32(q4_dmin_0, q8_d);
+            float32x4_t sb_min_4567 = vmulq_f32(q4_dmin_1, q8_d);
+
+            // interleaved bias_acc: [0]->r0 0123, [1]->r0 4567
+            int32x4_t bias_acc[2] = { vdupq_n_s32(0), vdupq_n_s32(0) };
+            int32x4_t acc_lo[col_groups];
+            int32x4_t acc_hi[col_groups];
+
+            // Each bsum is 16 elements, pairwise add leaves us with the 8 bsums of the entire block
+            const int16x8_t bsums = vpaddq_s16(vld1q_s16(q8_ptr[b].bsums), vld1q_s16(q8_ptr[b].bsums + 8));
+            int16_t bsums_arr[8];
+            vst1q_s16(bsums_arr, bsums);
+            for (int sb = 0; sb < QK_K / 64; sb++) {
+                for (int i = 0; i < col_groups; i++) {
+                    acc_lo[i] = vdupq_n_s32(0);
+                    acc_hi[i] = vdupq_n_s32(0);
+                }
+                // Need scales for the low and high nibbles
+                // 2 * 12 = 24 bytes per subblock, 4 sbs -> 4 * 24 = 96 bytes total
+                int16x8_t q4sb_mins[2];
+                int16x8_t q4sb_scales[2];
+                for (int i = 0; i < 2; i++) {
+                    int8_t aux_q4sb[8];
+                    const int offset = sb * 24 + i * 12;
+                    decode_q4_Kx8_scales_mins(&q4_ptr[b].scales[offset], &q4sb_mins[i], aux_q4sb);
+                    q4sb_scales[i] = vmovl_s8(vld1_s8(aux_q4sb));
+                }
+
+                int8x16_t q8_qs[64 / 16];
+                for (int i = 0; i < 64 / 16; i++) {
+                    q8_qs[i] = vld1q_s8(q8_ptr[b].qs + sb * 64 + i * 16);
+                }
+
+                for (int c = 0; c < col_groups; c++) {
+                    uint8x16_t q4_cols[8];
+                    for (int i = 0; i < 8; i++) {
+                        q4_cols[i] = vld1q_u8(q4_ptr[b].qs + sb * QK_K + i * 32 + 16 * c);
+                    }
+
+                    acc_lo[c] = vdotq_laneq_s32(acc_lo[c], vreinterpretq_s8_u8(vandq_u8(q4_cols[0], m4b)), q8_qs[0], 0);
+                    acc_lo[c] = vdotq_laneq_s32(acc_lo[c], vreinterpretq_s8_u8(vandq_u8(q4_cols[1], m4b)), q8_qs[0], 1);
+                    acc_lo[c] = vdotq_laneq_s32(acc_lo[c], vreinterpretq_s8_u8(vandq_u8(q4_cols[2], m4b)), q8_qs[0], 2);
+                    acc_lo[c] = vdotq_laneq_s32(acc_lo[c], vreinterpretq_s8_u8(vandq_u8(q4_cols[3], m4b)), q8_qs[0], 3);
+                    acc_lo[c] = vdotq_laneq_s32(acc_lo[c], vreinterpretq_s8_u8(vandq_u8(q4_cols[4], m4b)), q8_qs[1], 0);
+                    acc_lo[c] = vdotq_laneq_s32(acc_lo[c], vreinterpretq_s8_u8(vandq_u8(q4_cols[5], m4b)), q8_qs[1], 1);
+                    acc_lo[c] = vdotq_laneq_s32(acc_lo[c], vreinterpretq_s8_u8(vandq_u8(q4_cols[6], m4b)), q8_qs[1], 2);
+                    acc_lo[c] = vdotq_laneq_s32(acc_lo[c], vreinterpretq_s8_u8(vandq_u8(q4_cols[7], m4b)), q8_qs[1], 3);
+
+                    acc_hi[c] = vdotq_laneq_s32(acc_hi[c], vreinterpretq_s8_u8(vshrq_n_u8(q4_cols[0], 4)), q8_qs[2], 0);
+                    acc_hi[c] = vdotq_laneq_s32(acc_hi[c], vreinterpretq_s8_u8(vshrq_n_u8(q4_cols[1], 4)), q8_qs[2], 1);
+                    acc_hi[c] = vdotq_laneq_s32(acc_hi[c], vreinterpretq_s8_u8(vshrq_n_u8(q4_cols[2], 4)), q8_qs[2], 2);
+                    acc_hi[c] = vdotq_laneq_s32(acc_hi[c], vreinterpretq_s8_u8(vshrq_n_u8(q4_cols[3], 4)), q8_qs[2], 3);
+                    acc_hi[c] = vdotq_laneq_s32(acc_hi[c], vreinterpretq_s8_u8(vshrq_n_u8(q4_cols[4], 4)), q8_qs[3], 0);
+                    acc_hi[c] = vdotq_laneq_s32(acc_hi[c], vreinterpretq_s8_u8(vshrq_n_u8(q4_cols[5], 4)), q8_qs[3], 1);
+                    acc_hi[c] = vdotq_laneq_s32(acc_hi[c], vreinterpretq_s8_u8(vshrq_n_u8(q4_cols[6], 4)), q8_qs[3], 2);
+                    acc_hi[c] = vdotq_laneq_s32(acc_hi[c], vreinterpretq_s8_u8(vshrq_n_u8(q4_cols[7], 4)), q8_qs[3], 3);
+                }
+
+                // Scales
+                // row c0123 blk0 and blk1
+                const int16x4_t sc_0123_lo = vget_low_s16(q4sb_scales[0]);
+                const int16x4_t sc_0123_hi = vget_low_s16(q4sb_scales[1]);
+                const float32x4_t sumf_0123 = vcvtq_f32_s32(vaddq_s32(vmulq_s32(vmovl_s16(sc_0123_lo), acc_lo[0]),
+                                                                      vmulq_s32(vmovl_s16(sc_0123_hi), acc_hi[0])));
+                acc_f32[0] = vfmaq_f32(acc_f32[0], sb_scale_0123, sumf_0123);
+                // row c4567 blk0 and blk1
+                const int16x4_t sc_4567_lo = vget_high_s16(q4sb_scales[0]);
+                const int16x4_t sc_4567_hi = vget_high_s16(q4sb_scales[1]);
+                const float32x4_t sumf_4567 = vcvtq_f32_s32(vaddq_s32(vmulq_s32(vmovl_s16(sc_4567_lo), acc_lo[1]),
+                                                                      vmulq_s32(vmovl_s16(sc_4567_hi), acc_hi[1])));
+                acc_f32[1] = vfmaq_f32(acc_f32[1], sb_scale_4567, sumf_4567);
+
+                // Bias Correction
+                const int16x4_t bsums_vec_lo = vdup_n_s16(bsums_arr[2 * sb + 0]);
+                const int16x4_t bsums_vec_hi = vdup_n_s16(bsums_arr[2 * sb + 1]);
+
+                bias_acc[0] = vmlal_s16(bias_acc[0], bsums_vec_lo, vget_low_s16(q4sb_mins[0]));
+                bias_acc[0] = vmlal_s16(bias_acc[0], bsums_vec_hi, vget_low_s16(q4sb_mins[1]));
+                bias_acc[1] = vmlal_s16(bias_acc[1], bsums_vec_lo, vget_high_s16(q4sb_mins[0]));
+                bias_acc[1] = vmlal_s16(bias_acc[1], bsums_vec_hi, vget_high_s16(q4sb_mins[1]));
+            } // for sb
+
+            acc_f32[0] = vmlsq_f32(acc_f32[0], vcvtq_f32_s32(bias_acc[0]), sb_min_0123);
+            acc_f32[1] = vmlsq_f32(acc_f32[1], vcvtq_f32_s32(bias_acc[1]), sb_min_4567);
+        } // for b
+
+        int base = x * ncols_interleaved;
+        vst1q_f32(s + base, acc_f32[0]);
+        vst1q_f32(s + base + 4, acc_f32[1]);
+    } // for x
+    return;
+#endif // #if defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_DOTPROD)
+    ggml_gemv_q4_K_8x4_q8_K_generic(n, s, bs, vx, vy, nr, nc);
+}
+
 void ggml_gemv_q4_K_8x8_q8_K(int n,
                              float * GGML_RESTRICT s,
                              size_t bs,
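The shape of the kernel above (integer dot products per sub-block, scaled by the decoded 6-bit scales, plus a `bias_acc` built from `bsums` and subtracted at the end with `vmlsq_f32`) follows from the q4_K dequantization identity. With super-block scales d and d_min, per-sub-block scale sc and min m, and q8_K activations dequantized as d_8·q8_j, the per-sub-block contribution factors as:

```latex
\sum_{j \in \mathrm{sb}} \bigl(d\,\mathrm{sc}\,q4_j - d_{\min}\,m\bigr)\, d_8\, q8_j
  = d\, d_8\, \mathrm{sc} \sum_{j \in \mathrm{sb}} q4_j\, q8_j
  \;-\; d_{\min}\, d_8\, m \sum_{j \in \mathrm{sb}} q8_j
```

The last sum is exactly the precomputed `bsums` entry for the sub-block, which is why the min correction never has to touch the individual quants.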
@@ -518,7 +652,7 @@ void ggml_gemv_q4_K_8x8_q8_K(int n,
     UNUSED(ncols_interleaved);
     UNUSED(blocklen);
 
-#if defined(__aarch64__) && defined(__ARM_NEON)
+#if defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_DOTPROD)
     constexpr int col_pairs = ncols_interleaved / 2;
     const uint8x16_t m4b = vdupq_n_u8(0x0f);
@@ -615,7 +749,6 @@ void ggml_gemv_q4_K_8x8_q8_K(int n,
                 float32x4_t sb_scale = p == 0 ? sb_scale_0 : sb_scale_1;
 
                 // 0123 or 4567
-                // TODO: Single superblock mul at the end of the superblock
                 float32x4_t sumf_0 =
                     vcvtq_f32_s32(vmulq_s32(vmovl_s16(group_scales_lo), vpaddq_s32(acc_lo[p], acc_lo[p + 1])));
                 acc_f32[i] = vfmaq_f32(acc_f32[i], sb_scale, sumf_0);
@@ -649,7 +782,7 @@ void ggml_gemv_q4_K_8x8_q8_K(int n,
         vst1q_f32(s + base + 4, acc_f32[1]);
     } // for x
     return;
-#endif // defined(__aarch64__) && defined(__ARM_NEON)
+#endif // defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_DOTPROD)
     ggml_gemv_q4_K_8x8_q8_K_generic(n, s, bs, vx, vy, nr, nc);
 }
@@ -2069,6 +2202,206 @@ void ggml_gemm_iq4_nl_4x4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const
     ggml_gemm_iq4_nl_4x4_q8_0_generic(n, s, bs, vx, vy, nr, nc);
 }
 
+void ggml_gemm_q4_K_8x4_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
+    constexpr int qk = QK_K;
+    const int nb = n / qk;
+
+    constexpr int ncols_interleaved = 8;
+    constexpr int blocklen = 4;
+
+    assert(n % qk == 0);
+    assert(nr % 4 == 0);
+    assert(nc % ncols_interleaved == 0);
+
+    UNUSED(nb);
+    UNUSED(ncols_interleaved);
+    UNUSED(blocklen);
+
+#if defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_DOTPROD)
+    constexpr int q8_k_blocklen = 4;
+    constexpr int acc_size = 2 * 4; // 2 row pairs × 4 col pairs
+    const uint8x16_t m4b = vdupq_n_u8(0x0f);
+
+    // 8 accumulators: 2 row pairs × 4 col pairs
+    float32x4_t acc_f32[acc_size];
+
+    for (int y = 0; y < nr / q8_k_blocklen; y++) {
+        const block_q8_Kx4 * GGML_RESTRICT q8_ptr = (const block_q8_Kx4 *) vy + (y * nb);
+
+        for (int x = 0; x < nc / ncols_interleaved; x++) {
+            const block_q4_Kx8 * GGML_RESTRICT q4_ptr = (const block_q4_Kx8 *) vx + (x * nb);
+
+            for (int i = 0; i < acc_size; i++) {
+                acc_f32[i] = vdupq_n_f32(0);
+            }
+
+            for (int b = 0; b < nb; b++) {
+                // d4 0 1 2 3, 4 5 6 7
+                float32x4_t q4_d_0123 = vcvt_f32_f16(vld1_f16((const __fp16 *) q4_ptr[b].d));
+                float32x4_t q4_d_4567 = vcvt_f32_f16(vld1_f16((const __fp16 *) q4_ptr[b].d + 4));
+                // d8 0 1 2 3
+                float32x4_t q8_d_0123 = vld1q_f32(q8_ptr[b].d);
+                // mins
+                float32x4_t q4_dmin_0123 = vcvt_f32_f16(vld1_f16((const __fp16 *) q4_ptr[b].dmin));
+                float32x4_t q4_dmin_4567 = vcvt_f32_f16(vld1_f16((const __fp16 *) q4_ptr[b].dmin + 4));
+
+                // Precomputation of scales and mins
+                float32x4_t sbd_scale_0123[q8_k_blocklen];
+                float32x4_t sbd_scale_4567[q8_k_blocklen];
+                float32x4_t sbd_min_0123[q8_k_blocklen];
+                float32x4_t sbd_min_4567[q8_k_blocklen];
+
+                sbd_scale_0123[0] = vmulq_laneq_f32(q4_d_0123, q8_d_0123, 0);
+                sbd_scale_4567[0] = vmulq_laneq_f32(q4_d_4567, q8_d_0123, 0);
+                sbd_min_0123[0] = vmulq_laneq_f32(q4_dmin_0123, q8_d_0123, 0);
+                sbd_min_4567[0] = vmulq_laneq_f32(q4_dmin_4567, q8_d_0123, 0);
+
+                sbd_scale_0123[1] = vmulq_laneq_f32(q4_d_0123, q8_d_0123, 1);
+                sbd_scale_4567[1] = vmulq_laneq_f32(q4_d_4567, q8_d_0123, 1);
+                sbd_min_0123[1] = vmulq_laneq_f32(q4_dmin_0123, q8_d_0123, 1);
+                sbd_min_4567[1] = vmulq_laneq_f32(q4_dmin_4567, q8_d_0123, 1);
+
+                sbd_scale_0123[2] = vmulq_laneq_f32(q4_d_0123, q8_d_0123, 2);
+                sbd_scale_4567[2] = vmulq_laneq_f32(q4_d_4567, q8_d_0123, 2);
+                sbd_min_0123[2] = vmulq_laneq_f32(q4_dmin_0123, q8_d_0123, 2);
+                sbd_min_4567[2] = vmulq_laneq_f32(q4_dmin_4567, q8_d_0123, 2);
+
+                sbd_scale_0123[3] = vmulq_laneq_f32(q4_d_0123, q8_d_0123, 3);
+                sbd_scale_4567[3] = vmulq_laneq_f32(q4_d_4567, q8_d_0123, 3);
+                sbd_min_0123[3] = vmulq_laneq_f32(q4_dmin_0123, q8_d_0123, 3);
+                sbd_min_4567[3] = vmulq_laneq_f32(q4_dmin_4567, q8_d_0123, 3);
+
+                // Precomputation of bsums, each vpaddq calcs all the bsums for each row
+                const int16x8_t bsums[q8_k_blocklen] = {
+                    vpaddq_s16(vld1q_s16(q8_ptr[b].bsums + 16 * 0), vld1q_s16(q8_ptr[b].bsums + 16 * 0 + 8)),
+                    vpaddq_s16(vld1q_s16(q8_ptr[b].bsums + 16 * 1), vld1q_s16(q8_ptr[b].bsums + 16 * 1 + 8)),
+                    vpaddq_s16(vld1q_s16(q8_ptr[b].bsums + 16 * 2), vld1q_s16(q8_ptr[b].bsums + 16 * 2 + 8)),
+                    vpaddq_s16(vld1q_s16(q8_ptr[b].bsums + 16 * 3), vld1q_s16(q8_ptr[b].bsums + 16 * 3 + 8)),
+                };
+                int16_t bsums_arr[QK_K / 64][8];
+                for (int q8_row = 0; q8_row < 4; q8_row++) {
+                    vst1q_s16(bsums_arr[q8_row], bsums[q8_row]);
+                }
+
+                // interleaved bias_acc: [0]->r0 0123, [1]->r1 0123, .., [4]->r0 4567, [5]->r1 4567 ..
+                int32x4_t bias_acc[acc_size];
+                for (int i = 0; i < acc_size; i++) {
+                    bias_acc[i] = vdupq_n_s32(0);
+                }
+
+                for (int sb = 0; sb < QK_K / 64; sb++) {
+                    // Int accumulators for qs vecdot (4 row x 2 col quartets)
+                    int32x4_t acc_lo[acc_size];
+                    int32x4_t acc_hi[acc_size];
+                    for (int i = 0; i < acc_size; i++) {
+                        acc_lo[i] = vdupq_n_s32(0);
+                        acc_hi[i] = vdupq_n_s32(0);
+                    }
+                    // Need scales for the low and high nibbles
+                    // 2 * 12 = 24 bytes per subblock, 4 sbs -> 4 * 24 = 96 bytes total
+                    int16x8_t q4sb_scales[2];
+                    int16x8_t q4sb_mins[2];
+                    for (int i = 0; i < 2; i++) {
+                        int8_t aux_q4sb[8];
+                        const int offset = sb * 24 + i * 12;
+                        decode_q4_Kx8_scales_mins(&q4_ptr[b].scales[offset], &q4sb_mins[i], aux_q4sb);
+                        q4sb_scales[i] = vmovl_s8(vld1_s8(aux_q4sb));
+                    }
+
+                    constexpr int reads_per_sb = 8; // 8 * 16 bytes each => 32 qs * 4 rows
+                    for (int k = 0; k < reads_per_sb; k++) {
+                        const int8x16_t q8_blk0 = vld1q_s8(q8_ptr[b].qs + sb * 256 + 16 * k);
+                        const int8x16_t q8_blk1 = vld1q_s8(q8_ptr[b].qs + sb * 256 + 16 * k + 128);
+
+                        // 0..3 & 32..35
+                        const uint8x16_t q4_0123 = vld1q_u8(q4_ptr[b].qs + sb * QK_K + 32 * k);
+                        const uint8x16_t q4_4567 = vld1q_u8(q4_ptr[b].qs + sb * QK_K + 32 * k + 16);
+
+                        const int8x16_t q4_0123_lo = vreinterpretq_s8_u8(vandq_u8(q4_0123, m4b));
+                        const int8x16_t q4_0123_hi = vreinterpretq_s8_u8(vshrq_n_u8(q4_0123, 4));
+
+                        acc_lo[0] = vdotq_laneq_s32(acc_lo[0], q4_0123_lo, q8_blk0, 0); // 0..3 r0 c0123
+                        acc_lo[1] = vdotq_laneq_s32(acc_lo[1], q4_0123_lo, q8_blk0, 1); // 0..3 r1 c0123
+                        acc_lo[2] = vdotq_laneq_s32(acc_lo[2], q4_0123_lo, q8_blk0, 2); // 0..3 r2 c0123
+                        acc_lo[3] = vdotq_laneq_s32(acc_lo[3], q4_0123_lo, q8_blk0, 3); // 0..3 r3 c0123
+
+                        acc_hi[0] = vdotq_laneq_s32(acc_hi[0], q4_0123_hi, q8_blk1, 0); // 32..35 r0 c0123
+                        acc_hi[1] = vdotq_laneq_s32(acc_hi[1], q4_0123_hi, q8_blk1, 1); // 32..35 r1 c0123
+                        acc_hi[2] = vdotq_laneq_s32(acc_hi[2], q4_0123_hi, q8_blk1, 2); // 32..35 r2 c0123
+                        acc_hi[3] = vdotq_laneq_s32(acc_hi[3], q4_0123_hi, q8_blk1, 3); // 32..35 r3 c0123
+
+                        const int8x16_t q4_4567_lo = vreinterpretq_s8_u8(vandq_u8(q4_4567, m4b));
+                        const int8x16_t q4_4567_hi = vreinterpretq_s8_u8(vshrq_n_u8(q4_4567, 4));
+
+                        acc_lo[4] = vdotq_laneq_s32(acc_lo[4], q4_4567_lo, q8_blk0, 0); // 0..3 r0 c4567
+                        acc_lo[5] = vdotq_laneq_s32(acc_lo[5], q4_4567_lo, q8_blk0, 1); // 0..3 r1 c4567
+                        acc_lo[6] = vdotq_laneq_s32(acc_lo[6], q4_4567_lo, q8_blk0, 2); // 0..3 r2 c4567
+                        acc_lo[7] = vdotq_laneq_s32(acc_lo[7], q4_4567_lo, q8_blk0, 3); // 0..3 r3 c4567
+
+                        acc_hi[4] = vdotq_laneq_s32(acc_hi[4], q4_4567_hi, q8_blk1, 0); // 32..35 r0 c4567
+                        acc_hi[5] = vdotq_laneq_s32(acc_hi[5], q4_4567_hi, q8_blk1, 1); // 32..35 r1 c4567
+                        acc_hi[6] = vdotq_laneq_s32(acc_hi[6], q4_4567_hi, q8_blk1, 2); // 32..35 r2 c4567
+                        acc_hi[7] = vdotq_laneq_s32(acc_hi[7], q4_4567_hi, q8_blk1, 3); // 32..35 r3 c4567
+                    }
+
+                    // Scale and bias application
+                    // acc is stored interleaved to match output layout
+                    const int16x4_t sc_0123_lo = vget_low_s16(q4sb_scales[0]);
+                    const int16x4_t sc_4567_lo = vget_high_s16(q4sb_scales[0]);
+                    const int16x4_t sc_0123_hi = vget_low_s16(q4sb_scales[1]);
+                    const int16x4_t sc_4567_hi = vget_high_s16(q4sb_scales[1]);
+                    for (int row = 0; row < q8_k_blocklen; row++) {
+                        // Bias correction
+                        // row c0123 blk0 and blk1
+                        const float32x4_t sumf_0123 =
+                            vcvtq_f32_s32(vaddq_s32(vmulq_s32(vmovl_s16(sc_0123_lo), acc_lo[row]),
                                                     vmulq_s32(vmovl_s16(sc_0123_hi), acc_hi[row])));
+                        acc_f32[2 * row] = vfmaq_f32(acc_f32[2 * row], sbd_scale_0123[row], sumf_0123);
+
+                        // row c4567 blk0 and blk1
+                        const float32x4_t sumf_4567 =
+                            vcvtq_f32_s32(vaddq_s32(vmulq_s32(vmovl_s16(sc_4567_lo), acc_lo[row + 4]),
                                                     vmulq_s32(vmovl_s16(sc_4567_hi), acc_hi[row + 4])));
+                        acc_f32[2 * row + 1] = vfmaq_f32(acc_f32[2 * row + 1], sbd_scale_4567[row], sumf_4567);
+
+                        // Bias
+                        const int16x4_t bsums_vec_lo = vdup_n_s16(bsums_arr[sb][row * 2]);
+                        const int16x4_t bsums_vec_hi = vdup_n_s16(bsums_arr[sb][row * 2 + 1]);
+
+                        // row c0123 blk0 and blk1
+                        bias_acc[2 * row] = vmlal_s16(bias_acc[2 * row], bsums_vec_lo, vget_low_s16(q4sb_mins[0]));
+                        bias_acc[2 * row] = vmlal_s16(bias_acc[2 * row], bsums_vec_hi, vget_low_s16(q4sb_mins[1]));
+
+                        // row c4567 blk0 and blk1
+                        bias_acc[2 * row + 1] =
+                            vmlal_s16(bias_acc[2 * row + 1], bsums_vec_lo, vget_high_s16(q4sb_mins[0]));
+                        bias_acc[2 * row + 1] =
+                            vmlal_s16(bias_acc[2 * row + 1], bsums_vec_hi, vget_high_s16(q4sb_mins[1]));
+                    }
+                } // for sb
+
+                for (int row = 0; row < q8_k_blocklen; row++) {
+                    acc_f32[2 * row] = vmlsq_f32(acc_f32[2 * row], vcvtq_f32_s32(bias_acc[2 * row]), sbd_min_0123[row]);
+                    acc_f32[2 * row + 1] =
+                        vmlsq_f32(acc_f32[2 * row + 1], vcvtq_f32_s32(bias_acc[2 * row + 1]), sbd_min_4567[row]);
+                }
+            } // for b
+
+            for (int i = 0; i < q8_k_blocklen; i++) {
+                int row = y * q8_k_blocklen + i;
+                for (int j = 0; j < 2; j++) {
+                    int col = x * ncols_interleaved + j * 4;
+                    int offset = row * bs + col;
+                    vst1q_f32(s + offset, acc_f32[2 * i + j]);
+                }
+            }
+        } // for x
+    } // for y
+    return;
+#endif // defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_DOTPROD)
+    ggml_gemm_q4_K_8x4_q8_K_generic(n, s, bs, vx, vy, nr, nc);
+}
+
 void ggml_gemm_q4_K_8x8_q8_K(int n,
                              float * GGML_RESTRICT s,
                              size_t bs,
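Both new q4_K kernels lean heavily on `vdotq_laneq_s32`, which may be unfamiliar. As a reading aid only (not part of the patch), a scalar model of its semantics:

```cpp
#include <cstdint>
#include <cstdio>

// Scalar model of vdotq_laneq_s32(acc, a, b, lane): each of acc's four
// 32-bit lanes i accumulates the dot product of bytes a[4i..4i+3] with
// the four bytes of b's selected 32-bit lane.
static void sdot_laneq(int32_t acc[4], const int8_t a[16], const int8_t b[16], int lane) {
    for (int i = 0; i < 4; i++) {
        int32_t sum = 0;
        for (int j = 0; j < 4; j++) {
            sum += (int32_t) a[4 * i + j] * (int32_t) b[4 * lane + j];
        }
        acc[i] += sum;
    }
}

int main() {
    int32_t acc[4] = {0, 0, 0, 0};
    int8_t  a[16], b[16];
    for (int i = 0; i < 16; i++) { a[i] = (int8_t) (i + 1); b[i] = 2; }
    sdot_laneq(acc, a, b, 0);  // every 4-byte lane of b holds {2,2,2,2} here
    printf("%d %d %d %d\n", acc[0], acc[1], acc[2], acc[3]);  // 20 52 84 116
    return 0;
}
```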
@@ -9766,7 +9766,8 @@ static void ggml_compute_forward_solve_tri_f32(const struct ggml_compute_params
                 }
 
                 const float diag = A_batch[i00 * n + i00];
-                GGML_ASSERT(diag != 0.0f && "Zero diagonal in triangular matrix");
+                assert(diag != 0.0f && "Zero diagonal in triangular matrix");
 
                 X_batch[i00 * k + i01] = (B_batch[i00 * k + i01] - sum) / diag;
             }
         }
@@ -124,6 +124,58 @@ void ggml_quantize_mat_q8_0_4x8_generic(const float * GGML_RESTRICT x, void * GG
     }
 }
 
+void ggml_quantize_mat_q8_K_4x4_generic(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k) {
+    assert(QK_K == 256);
+    assert(k % QK_K == 0);
+    const int nb = k / QK_K;
+
+    block_q8_Kx4 * GGML_RESTRICT y = (block_q8_Kx4 *) vy;
+
+    // scalar
+    const int blck_size_interleave = 4;
+    float srcv[4][QK_K];
+    float iscale[4];
+
+    for (int i = 0; i < nb; i++) {
+        for (int row_iter = 0; row_iter < 4; row_iter++) {
+            float amax = 0.0f; // absolute max
+            float max = 0;
+
+            for (int j = 0; j < QK_K; j++) {
+                srcv[row_iter][j] = x[row_iter * k + i * QK_K + j];
+                // Update the maximum value of the corresponding super block
+                if (amax < fabsf(srcv[row_iter][j])) {
+                    amax = fabsf(srcv[row_iter][j]);
+                    max = srcv[row_iter][j];
+                }
+            }
+
+            iscale[row_iter] = amax ? -127.f/max : 0;
+
+            y[i].d[row_iter] = amax ? 1/iscale[row_iter] : 0;
+        }
+
+        for (int j = 0; j < QK_K / 4; j++) {
+            y[i].bsums[j] = 0;
+        }
+
+        // Quants values are interleaved in sequence of four bytes from corresponding super blocks
+        // Bsums values are interleaved in sequence of four bsums from each super block taken for interleaving
+        // i.e. first four bsums from the first super block, followed by first four bsums from second super block and so on
+        for (int j = 0; j < QK_K * 4; j++) {
+            int src_offset = (j / (4 * blck_size_interleave)) * blck_size_interleave;
+            int src_id = (j % (4 * blck_size_interleave)) / blck_size_interleave;
+            src_offset += (j % blck_size_interleave);
+            int index = (((j & 15) >> 2) << 2) + ((j >> 8) << 4) + ((j >> 6) & 3);
+
+            float x0 = srcv[src_id][src_offset] * iscale[src_id];
+            y[i].qs[j] = nearest_int(x0);
+            y[i].bsums[index] += y[i].qs[j];
+        }
+    }
+}
+
 void ggml_quantize_mat_q8_K_4x8_generic(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k) {
     assert(QK_K == 256);
     assert(k % QK_K == 0);
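The index arithmetic in the interleave loop of `ggml_quantize_mat_q8_K_4x4_generic` above is dense. This standalone sketch reproduces the same two source-index formulas and prints where the first destination bytes come from; it assumes nothing beyond what the loop itself states:

```cpp
#include <cstdio>

int main() {
    const int blck_size_interleave = 4;  // same constant as the function above
    // Destination byte j pulls from source row src_id, element src_offset:
    // groups of four consecutive values, cycling through the four rows.
    for (int j = 0; j < 32; j++) {
        int src_offset = (j / (4 * blck_size_interleave)) * blck_size_interleave + (j % blck_size_interleave);
        int src_id     = (j % (4 * blck_size_interleave)) / blck_size_interleave;
        printf("qs[%2d] <- row %d, element %2d\n", j, src_id, src_offset);
    }
    return 0;
}
```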
@@ -192,6 +244,12 @@ template <> void ggml_quantize_mat_t<8, GGML_TYPE_Q8_0>(const float * GGML_RESTR
     ggml_quantize_mat_q8_0_4x8(x, vy, n_per_row);
 }
 
+template <> void ggml_quantize_mat_t<4, GGML_TYPE_Q8_K>(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t nrow, int64_t n_per_row) {
+    assert(nrow == 4);
+    UNUSED(nrow);
+    ggml_quantize_mat_q8_K_4x4(x, vy, n_per_row);
+}
+
 template <> void ggml_quantize_mat_t<8, GGML_TYPE_Q8_K>(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t nrow, int64_t n_per_row) {
     assert(nrow == 4);
     UNUSED(nrow);
@@ -333,6 +391,77 @@ void ggml_gemv_q4_0_8x8_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs,
     }
 }
 
+void ggml_gemv_q4_K_8x4_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
+    const int qk = QK_K;
+    const int nb = n / qk;
+    const int ncols_interleaved = 8;
+    const int blocklen = 4;
+    static const uint32_t kmask1 = 0x3f3f3f3f;
+    static const uint32_t kmask2 = 0x0f0f0f0f;
+    static const uint32_t kmask3 = 0x03030303;
+
+    assert (n % qk == 0);
+    assert (nc % ncols_interleaved == 0);
+
+    UNUSED(bs);
+    UNUSED(nr);
+
+    float sumf[8];
+    float sum_minf[8];
+    uint32_t utmp[32];
+    int sumi1;
+    int sumi2;
+    int sumi;
+
+    const block_q8_K * a_ptr = (const block_q8_K *) vy;
+    for (int x = 0; x < nc / ncols_interleaved; x++) {
+        const block_q4_Kx8 * b_ptr = (const block_q4_Kx8 *) vx + (x * nb);
+
+        for (int j = 0; j < ncols_interleaved; j++) {
+            sumf[j] = 0.0;
+            sum_minf[j] = 0.0;
+        }
+        for (int l = 0; l < nb; l++) {
+            for (int sb = 0; sb < 8; sb++) {
+                memcpy(utmp + sb * 4, b_ptr[l].scales + sb * 12, 12);
+                utmp[sb * 4 + 3] = ((utmp[sb * 4 + 2] >> 4) & kmask2) | (((utmp[sb * 4 + 1] >> 6) & kmask3) << 4);
+                const uint32_t uaux_0 = utmp[sb * 4 + 1] & kmask1;
+                utmp[sb * 4 + 1] = (utmp[sb * 4 + 2] & kmask2) | (((utmp[sb * 4 + 0] >> 6) & kmask3) << 4);
+                utmp[sb * 4 + 2] = uaux_0;
+                utmp[sb * 4 + 0] &= kmask1;
+            }
+            for (int k = 0; k < (qk / (2 * blocklen)); k++) {
+                uint8_t * scales_0 = (uint8_t *) utmp + (k / 8) * 32;
+                uint8_t * scales_1 = (uint8_t *) utmp + (k / 8) * 32 + 16;
+                for (int j = 0; j < ncols_interleaved; j++) {
+                    sumi1 = 0;
+                    sumi2 = 0;
+                    sumi = 0;
+                    for (int i = 0; i < blocklen; ++i) {
+                        const int v0 = (int8_t) (b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] & 0xF);
+                        const int v1 = (int8_t) (b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] >> 4);
+                        sumi1 = (v0 * a_ptr[l].qs[(k / 8) * 64 + (k % 8) * blocklen + i]);
+                        sumi2 = (v1 * a_ptr[l].qs[(k / 8) * 64 + (k % 8) * blocklen + i + 32]);
+                        sumi1 = sumi1 * scales_0[j];
+                        sumi2 = sumi2 * scales_1[j];
+                        sumi += sumi1 + sumi2;
+                    }
+                    sumf[j] += sumi * GGML_CPU_FP16_TO_FP32(b_ptr[l].d[j]) * a_ptr[l].d;
+                }
+            }
+            for (int sb = 0; sb < 8; sb++) {
+                uint8_t * mins = (uint8_t *) utmp + 8 + sb * 16;
+                for (int j = 0; j < ncols_interleaved; j++) {
+                    sum_minf[j] += mins[j] * (a_ptr[l].bsums[sb * 2] + a_ptr[l].bsums[sb * 2 + 1]) * GGML_CPU_FP16_TO_FP32(b_ptr[l].dmin[j]) * a_ptr[l].d;
+                }
+            }
+        }
+        for (int j = 0; j < ncols_interleaved; j++) {
+            s[x * ncols_interleaved + j] = sumf[j] - sum_minf[j];
+        }
+    }
+}
+
 void ggml_gemv_q4_K_8x8_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
     const int qk = QK_K;
     const int nb = n / qk;
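The `kmask1/2/3` shuffle above unpacks q4_K's 12-byte scales field, which stores eight 6-bit scales and eight 6-bit mins. For reference, the per-element unpacking it mirrors (the same logic as ggml's `get_scale_min_k4` helper, shown here as a standalone sketch):

```cpp
#include <cstdint>
#include <cstdio>

// Unpack (scale, min) for sub-block j (0..7) from a 12-byte q4_K scales field.
// The first 8 bytes hold the low 6 bits; the last 4 bytes hold the remaining
// nibbles and high bits for sub-blocks 4..7.
static void get_scale_min(int j, const uint8_t * q, uint8_t * d, uint8_t * m) {
    if (j < 4) {
        *d = q[j] & 63;
        *m = q[j + 4] & 63;
    } else {
        *d = (q[j + 4] & 0xF) | ((q[j - 4] >> 6) << 4);
        *m = (q[j + 4] >>  4) | ((q[j]     >> 6) << 4);
    }
}

int main() {
    const uint8_t scales[12] = {63, 1, 2, 3, 60, 5, 6, 7, 0xAB, 0xCD, 0x12, 0x34};
    for (int j = 0; j < 8; j++) {
        uint8_t d, m;
        get_scale_min(j, scales, &d, &m);
        printf("sub-block %d: scale %u, min %u\n", j, d, m);
    }
    return 0;
}
```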
@ -727,6 +856,89 @@ void ggml_gemm_q4_0_8x8_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void ggml_gemm_q4_K_8x4_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
|
||||||
|
const int qk = QK_K;
|
||||||
|
const int nb = n / qk;
|
||||||
|
const int ncols_interleaved = 8;
|
||||||
|
const int blocklen = 4;
|
||||||
|
static const uint32_t kmask1 = 0x3f3f3f3f;
|
||||||
|
static const uint32_t kmask2 = 0x0f0f0f0f;
|
||||||
|
static const uint32_t kmask3 = 0x03030303;
|
||||||
|
|
||||||
|
assert (n % qk == 0);
|
||||||
|
assert (nr % 4 == 0);
|
||||||
|
assert (nc % ncols_interleaved == 0);
|
||||||
|
|
||||||
|
UNUSED(nb);
|
||||||
|
UNUSED(ncols_interleaved);
|
||||||
|
UNUSED(blocklen);
|
||||||
|
|
||||||
|
float sumf[4][8];
|
||||||
|
float sum_minf[4][8];
|
||||||
|
uint32_t utmp[32];
|
||||||
|
int sumi1;
|
||||||
|
int sumi2;
|
||||||
|
int sumi;
|
||||||
|
|
||||||
|
for (int y = 0; y < nr / 4; y++) {
|
||||||
|
const block_q8_Kx4 * a_ptr = (const block_q8_Kx4 *) vy + (y * nb);
|
||||||
|
for (int x = 0; x < nc / ncols_interleaved; x++) {
|
||||||
|
const block_q4_Kx8 * b_ptr = (const block_q4_Kx8 *) vx + (x * nb);
|
||||||
|
for (int m = 0; m < 4; m++) {
|
||||||
|
for (int j = 0; j < ncols_interleaved; j++) {
|
||||||
|
sumf[m][j] = 0.0;
|
||||||
|
sum_minf[m][j] = 0.0;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
for (int l = 0; l < nb; l++) {
|
||||||
|
for (int sb = 0; sb < 8; sb++) {
|
||||||
|
memcpy(utmp + sb * 4, b_ptr[l].scales + sb * 12, 12);
|
||||||
|
utmp[sb * 4 + 3] = ((utmp[sb * 4 + 2] >> 4) & kmask2) | (((utmp[sb * 4 + 1] >> 6) & kmask3) << 4);
|
||||||
|
const uint32_t uaux_0 = utmp[sb * 4 + 1] & kmask1;
|
||||||
|
utmp[sb * 4 + 1] = (utmp[sb * 4 + 2] & kmask2) | (((utmp[sb * 4 + 0] >> 6) & kmask3) << 4);
|
||||||
|
utmp[sb * 4 + 2] = uaux_0;
|
||||||
|
utmp[sb * 4 + 0] &= kmask1;
|
||||||
|
}
|
||||||
|
for (int k = 0; k < (qk / (2 * blocklen)); k++) {
|
||||||
|
uint8_t * scales_0 = (uint8_t *) utmp + (k / 8) * 32;
|
||||||
|
uint8_t * scales_1 = (uint8_t *) utmp + (k / 8) * 32 + 16;
|
||||||
|
for (int m = 0; m < 4; m++) {
|
||||||
|
for (int j = 0; j < ncols_interleaved; j++) {
|
||||||
|
sumi1 = 0;
|
||||||
|
sumi2 = 0;
|
||||||
|
sumi = 0;
|
||||||
|
for (int i = 0; i < blocklen; ++i) {
|
||||||
|
const int v0 = (int8_t) (b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] & 0xF);
|
||||||
|
const int v1 = (int8_t) (b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] >> 4);
|
||||||
|
sumi1 = (v0 * a_ptr[l].qs[(k / 8) * 256 + (k % 8) * 4 * blocklen + m * blocklen + i]);
|
||||||
|
sumi2 = (v1 * a_ptr[l].qs[(k / 8) * 256 + (k % 8) * 4 * blocklen + m * blocklen + i + 128]);
|
||||||
|
sumi1 = sumi1 * scales_0[j];
|
||||||
|
sumi2 = sumi2 * scales_1[j];
|
||||||
|
sumi += sumi1 + sumi2;
|
||||||
|
}
|
||||||
|
sumf[m][j] += sumi * GGML_CPU_FP16_TO_FP32(b_ptr[l].d[j]) * a_ptr[l].d[m];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
for (int sb = 0; sb < 8; sb++) {
|
||||||
|
uint8_t * mins = (uint8_t *) utmp + 8 + sb * 16;
|
||||||
|
for(int m = 0; m < 4; m++) {
|
||||||
|
const int16_t * bsums = a_ptr[l].bsums + (sb * 8) + (m * 4) - ((sb % 2) * 6);
|
||||||
|
for(int j = 0; j < ncols_interleaved; j++) {
|
||||||
|
sum_minf[m][j] += mins[j] * (bsums[0] + bsums[1]) * GGML_CPU_FP16_TO_FP32(b_ptr[l].dmin[j]) * a_ptr[l].d[m];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
for (int m = 0; m < 4; m++) {
|
||||||
|
for (int j = 0; j < ncols_interleaved; j++) {
|
||||||
|
s[(y * 4 + m) * bs + x * ncols_interleaved + j] = sumf[m][j] - sum_minf[m][j];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
void ggml_gemm_q4_K_8x8_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
|
void ggml_gemm_q4_K_8x8_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
|
||||||
const int qk = QK_K;
|
const int qk = QK_K;
|
||||||
const int nb = n / qk;
|
const int nb = n / qk;
|
||||||
|
|
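Editor's note: the kmask shuffle in the new generic kernel reconstructs the 6-bit Q4_K scales and mins packed into 12 bytes per super-block. A byte-level sketch of the same unpacking, modelled on the get_scale_min_k4 helper used by ggml's scalar dequantization path (reproduced from memory, so treat the exact form as an assumption):

// Sketch: extract the 6-bit scale *d and min *m of sub-block j from the
// 12 packed scale bytes q of a Q4_K super-block. Assumed equivalent to the
// word-level kmask1/kmask2/kmask3 shuffle in the kernel above.
static inline void get_scale_min_k4(int j, const uint8_t * q, uint8_t * d, uint8_t * m) {
    if (j < 4) {
        *d = q[j] & 63;
        *m = q[j + 4] & 63;
    } else {
        *d = (q[j + 4] & 0xF) | ((q[j - 4] >> 6) << 4);
        *m = (q[j + 4] >>  4) | ((q[j - 0] >> 6) << 4);
    }
}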
@@ -1228,9 +1440,10 @@ static int repack_q4_0_to_q4_0_4_bl(struct ggml_tensor * t, int interleave_block
 
     GGML_UNUSED(data_size);
 }
 
 static int repack_q4_K_to_q4_K_8_bl(struct ggml_tensor * t, int interleave_block, const void * GGML_RESTRICT data, size_t data_size) {
     GGML_ASSERT(t->type == GGML_TYPE_Q4_K);
-    GGML_ASSERT(interleave_block == 8);
+    GGML_ASSERT(interleave_block == 8 || interleave_block == 4);
     constexpr int nrows_interleaved = 8;
 
     block_q4_Kx8 * dst = (block_q4_Kx8*)t->data;
@@ -1468,6 +1681,10 @@ template <> int repack<block_q4_K, 8, 8>(struct ggml_tensor * t, const void * da
     return repack_q4_K_to_q4_K_8_bl(t, 8, data, data_size);
 }
 
+template <> int repack<block_q4_K, 4, 8>(struct ggml_tensor * t, const void * data, size_t data_size) {
+    return repack_q4_K_to_q4_K_8_bl(t, 4, data, data_size);
+}
+
 template <> int repack<block_q2_K, 8, 8>(struct ggml_tensor * t, const void * data, size_t data_size) {
     return repack_q2_K_to_q2_K_8_bl(t, 8, data, data_size);
 }
@@ -1501,6 +1718,10 @@ template <> void gemv<block_q4_0, 8, 8, GGML_TYPE_Q8_0>(int n, float * s, size_t
     ggml_gemv_q4_0_8x8_q8_0(n, s, bs, vx, vy, nr, nc);
 }
 
+template <> void gemv<block_q4_K, 4, 8, GGML_TYPE_Q8_K>(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {
+    ggml_gemv_q4_K_8x4_q8_K(n, s, bs, vx, vy, nr, nc);
+}
+
 template <> void gemv<block_q4_K, 8, 8, GGML_TYPE_Q8_K>(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {
     ggml_gemv_q4_K_8x8_q8_K(n, s, bs, vx, vy, nr, nc);
 }
@@ -1529,6 +1750,10 @@ template <> void gemm<block_q4_0, 8, 4, GGML_TYPE_Q8_0>(int n, float * s, size_t
     ggml_gemm_q4_0_4x8_q8_0(n, s, bs, vx, vy, nr, nc);
 }
 
+template <> void gemm<block_q4_K, 4, 8, GGML_TYPE_Q8_K>(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {
+    ggml_gemm_q4_K_8x4_q8_K(n, s, bs, vx, vy, nr, nc);
+}
+
 template <> void gemm<block_q4_0, 8, 8, GGML_TYPE_Q8_0>(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {
    ggml_gemm_q4_0_8x8_q8_0(n, s, bs, vx, vy, nr, nc);
 }
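Editor's note: taken together, the specializations added in the three hunks above wire the new 8-row/4-byte-interleave Q4_K path into the existing repack dispatch. A condensed view of the mapping, simplified from the diff:

// repack<block_q4_K, 4, 8>(t, data, size)       -> repack_q4_K_to_q4_K_8_bl(t, 4, data, size)
// gemv<block_q4_K, 4, 8, GGML_TYPE_Q8_K>(...)   -> ggml_gemv_q4_K_8x4_q8_K(...)
// gemm<block_q4_K, 4, 8, GGML_TYPE_Q8_K>(...)   -> ggml_gemm_q4_K_8x4_q8_K(...)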
@@ -1731,12 +1956,13 @@ template <typename BLOC_TYPE, int64_t INTER_SIZE, int64_t NB_COLS, ggml_type PAR
         nchunk0 = (nr0 + min_chunk_size - 1) / min_chunk_size;
     }
 
-    if (nth == 1 || nchunk0 < nth || disable_chunking) {
+    int64_t dr0 = (nr0 + nchunk0 - 1) / nchunk0;
+    // Only increase nchunk0 to nth if it won't make chunks too small
+    if (nth == 1 || ((nchunk0 < nth || disable_chunking) && (nr0 + nth - 1) / nth >= min_chunk_size)) {
         nchunk0 = nth;
+        dr0 = (nr0 + nchunk0 - 1) / nchunk0;
     }
 
-    const int64_t dr0 = (nr0 + nchunk0 - 1) / nchunk0;
-
     // Ensure nchunk doesn't exceed the number of rows divided by minimum chunk size
     // This prevents creating too many tiny chunks that could overlap after alignment
     const int64_t max_nchunk = (nr0 + min_chunk_size - 1) / min_chunk_size;
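Editor's note: the reworked guard only widens nchunk0 to nth when the resulting per-chunk row count would still be at least min_chunk_size. A small worked example with illustrative numbers (not from the source):

// nr0 = 100 rows, nth = 32 threads, min_chunk_size = 16:
// (nr0 + nth - 1) / nth = 4 < 16, so the condition is false, nchunk0 is left
// alone, and dr0 keeps chunks of at least min_chunk_size rows.
int64_t nr0 = 100, nth = 32, min_chunk_size = 16;
bool widen = (nr0 + nth - 1) / nth >= min_chunk_size;   // false here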
@@ -1930,6 +2156,9 @@ static const ggml::cpu::tensor_traits * ggml_repack_get_optimal_repack_type(cons
     static const ggml::cpu::repack::tensor_traits<block_q4_0, 4, 4, GGML_TYPE_Q8_0> q4_0_4x4_q8_0;
     static const ggml::cpu::repack::tensor_traits<block_q4_0, 8, 4, GGML_TYPE_Q8_0> q4_0_4x8_q8_0;
     static const ggml::cpu::repack::tensor_traits<block_q4_0, 8, 8, GGML_TYPE_Q8_0> q4_0_8x8_q8_0;
 
+    // instance for Q4_K
+    static const ggml::cpu::repack::tensor_traits<block_q4_K, 4, 8, GGML_TYPE_Q8_K> q4_K_8x4_q8_K;
+
     static const ggml::cpu::repack::tensor_traits<block_q4_K, 8, 8, GGML_TYPE_Q8_K> q4_K_8x8_q8_K;
 
     // instance for Q2
 
@@ -1966,6 +2195,11 @@ static const ggml::cpu::tensor_traits * ggml_repack_get_optimal_repack_type(cons
                 return &q4_K_8x8_q8_K;
             }
         }
+        if (ggml_cpu_has_neon() && ggml_cpu_has_dotprod()) {
+            if (cur->ne[1] % 8 == 0) {
+                return &q4_K_8x4_q8_K;
+            }
+        }
     } else if (cur->type == GGML_TYPE_Q2_K) {
         if (ggml_cpu_has_avx512()) {
             if (cur->ne[1] % 8 == 0) {
@@ -80,10 +80,12 @@ extern "C" {
 
 void ggml_quantize_mat_q8_0_4x4(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k);
 void ggml_quantize_mat_q8_0_4x8(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k);
+void ggml_quantize_mat_q8_K_4x4(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k);
 void ggml_quantize_mat_q8_K_4x8(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k);
 void ggml_gemv_q4_0_4x4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
 void ggml_gemv_q4_0_4x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
 void ggml_gemv_q4_0_8x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
+void ggml_gemv_q4_K_8x4_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
 void ggml_gemv_q4_K_8x8_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
 void ggml_gemv_q2_K_8x8_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
 void ggml_gemv_iq4_nl_4x4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
@@ -91,6 +93,7 @@ void ggml_gemv_iq4_nl_8x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const
 void ggml_gemm_q4_0_4x4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
 void ggml_gemm_q4_0_4x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
 void ggml_gemm_q4_0_8x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
+void ggml_gemm_q4_K_8x4_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
 void ggml_gemm_q4_K_8x8_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
 void ggml_gemm_q2_K_8x8_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
 void ggml_gemm_iq4_nl_4x4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
@@ -99,10 +102,12 @@ void ggml_gemm_iq4_nl_8x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const
 // Native implementations
 void ggml_quantize_mat_q8_0_4x4_generic(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k);
 void ggml_quantize_mat_q8_0_4x8_generic(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k);
+void ggml_quantize_mat_q8_K_4x4_generic(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k);
 void ggml_quantize_mat_q8_K_4x8_generic(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k);
 void ggml_gemv_q4_0_4x4_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
 void ggml_gemv_q4_0_4x8_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
 void ggml_gemv_q4_0_8x8_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
+void ggml_gemv_q4_K_8x4_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
 void ggml_gemv_q4_K_8x8_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
 void ggml_gemv_q2_K_8x8_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
 void ggml_gemv_iq4_nl_4x4_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
@@ -110,6 +115,7 @@ void ggml_gemv_iq4_nl_8x8_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs
 void ggml_gemm_q4_0_4x4_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
 void ggml_gemm_q4_0_4x8_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
 void ggml_gemm_q4_0_8x8_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
+void ggml_gemm_q4_K_8x4_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
 void ggml_gemm_q4_K_8x8_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
 void ggml_gemm_q2_K_8x8_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
 void ggml_gemm_iq4_nl_4x4_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
@@ -558,8 +558,12 @@ static __device__ __forceinline__ void ggml_cuda_mad(float & acc, const float2 v
     acc += v.y*u.y;
 }
 
-static __device__ __forceinline__ void ggml_cuda_mad(float & acc, const half2 v, const half2 u) {
 #if defined(GGML_USE_HIP) && (defined(RDNA2) || defined(RDNA3) || defined(RDNA4) || defined(__gfx906__) || defined(CDNA))
+#define V_DOT2_F32_F16_AVAILABLE
+#endif // defined(GGML_USE_HIP) && (defined(RDNA2) || defined(RDNA3) || defined(RDNA4) || defined(__gfx906__) || defined(CDNA))
+
+static __device__ __forceinline__ void ggml_cuda_mad(float & acc, const half2 v, const half2 u) {
+#ifdef V_DOT2_F32_F16_AVAILABLE
     asm volatile("v_dot2_f32_f16 %0, %1, %2, %0" : "+v"(acc) : "v"(v), "v"(u));
 #else
 #ifdef FAST_FP16_AVAILABLE
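Editor's note: the new macro names the hardware capability rather than a code path; v_dot2_f32_f16 is a two-way f16 dot product accumulated in f32. A minimal CUDA sketch of the same semantics, matching the portable #else branch of the function above:

#include <cuda_fp16.h>

// Reference semantics of v_dot2_f32_f16 (sketch): widen both half2 operands
// to f32 and accumulate the two element-wise products into acc.
__device__ __forceinline__ void mad_half2_ref(float & acc, const half2 v, const half2 u) {
    const float2 tmpv = __half22float2(v);
    const float2 tmpu = __half22float2(u);
    acc += tmpv.x * tmpu.x;
    acc += tmpv.y * tmpu.y;
}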
@@ -571,7 +575,7 @@ static __device__ __forceinline__ void ggml_cuda_mad(float & acc, const half2 v,
     acc += tmpv.x * tmpu.x;
     acc += tmpv.y * tmpu.y;
 #endif // FAST_FP16_AVAILABLE
-#endif // defined(GGML_USE_HIP) && (defined(RDNA2) || defined(RDNA3) || defined(RDNA4) || defined(GCN5) || defined(CDNA))
+#endif // V_DOT2_F32_F16_AVAILABLE
 }
 
 static __device__ __forceinline__ void ggml_cuda_mad(half2 & acc, const half2 v, const half2 u) {
@@ -55,11 +55,11 @@ static __device__ __forceinline__ float vec_dot_fattn_vec_KQ_f16(
         ggml_cuda_memcpy_1<sizeof(tmp)>(tmp, K_h2 + k_KQ_0 + (threadIdx.x % nthreads)*cpy_ne);
 #pragma unroll
         for (int k_KQ_1 = 0; k_KQ_1 < cpy_ne; ++k_KQ_1) {
-#ifdef FAST_FP16_AVAILABLE
+#ifdef V_DOT2_F32_F16_AVAILABLE
             ggml_cuda_mad(sum, tmp[k_KQ_1] , ((const half2 *) Q_v)[k_KQ_0/nthreads + k_KQ_1]);
 #else
             ggml_cuda_mad(sum, __half22float2(tmp[k_KQ_1]), ((const float2 *) Q_v)[k_KQ_0/nthreads + k_KQ_1]);
-#endif // FP16_AVAILABLE
+#endif // V_DOT2_F32_F16_AVAILABLE
         }
     }
@@ -86,11 +86,11 @@ static __global__ void flash_attn_ext_vec(
 
     constexpr vec_dot_KQ_t vec_dot_KQ = get_vec_dot_KQ<type_K, D, nthreads_KQ>();
     constexpr bool Q_q8_1 = type_K != GGML_TYPE_F16;
-#ifdef FAST_FP16_AVAILABLE
+#ifdef V_DOT2_F32_F16_AVAILABLE
     constexpr dequantize_V_t dequantize_V = get_dequantize_V<type_V, half, V_rows_per_thread>();
 #else
     constexpr dequantize_V_t dequantize_V = get_dequantize_V<type_V, float, V_rows_per_thread>();
-#endif // FAST_FP16_AVAILABLE
+#endif // V_DOT2_F32_F16_AVAILABLE
 
     const int ic0 = blockIdx.x * ncols; // Index of the Q/QKV column to work on.
@@ -112,13 +112,13 @@ static __global__ void flash_attn_ext_vec(
 
     constexpr int ne_KQ = ncols*D;
     constexpr int ne_combine = nwarps*V_cols_per_iter*D;
-#ifdef FAST_FP16_AVAILABLE
+#ifdef V_DOT2_F32_F16_AVAILABLE
     half2 VKQ[ncols][(D/2)/nthreads_V] = {{{0.0f, 0.0f}}};
     __shared__ half KQ[ne_KQ > ne_combine ? ne_KQ : ne_combine];
 #else
     float2 VKQ[ncols][(D/2)/nthreads_V] = {{{0.0f, 0.0f}}};
     __shared__ float KQ[ne_KQ > ne_combine ? ne_KQ : ne_combine];
-#endif // FAST_FP16_AVAILABLE
+#endif // V_DOT2_F32_F16_AVAILABLE
 
     float KQ_max[ncols];
     float KQ_sum[ncols];
@@ -129,11 +129,11 @@ static __global__ void flash_attn_ext_vec(
     }
 
     // Convert Q to float2 (f16 K) or q8_1 (quantized K) and store in registers:
-#ifdef FAST_FP16_AVAILABLE
+#ifdef V_DOT2_F32_F16_AVAILABLE
     half2 Q_reg[ncols][(D/2)/nthreads_KQ]; // Will be initialized completely.
 #else
     float2 Q_reg[ncols][(D/2)/nthreads_KQ] = {{{0.0f, 0.0f}}}; // May be only partially initialized.
-#endif // FAST_FP16_AVAILABLE
+#endif // V_DOT2_F32_F16_AVAILABLE
     int Q_i32[ncols][1 > D/(sizeof(int)*nthreads_KQ) ? 1 : D/(sizeof(int)*nthreads_KQ)];
     float2 Q_ds[ncols][1 > D/(sizeof(int)*nthreads_KQ) ? 1 : D/(sizeof(int)*nthreads_KQ)];
     if constexpr (Q_q8_1) {
@@ -191,7 +191,7 @@ static __global__ void flash_attn_ext_vec(
 
         __syncthreads();
     } else {
-#ifdef FAST_FP16_AVAILABLE
+#ifdef V_DOT2_F32_F16_AVAILABLE
         const half2 scale_h2 = make_half2(scale, scale);
 #pragma unroll
         for (int j = 0; j < ncols; ++j) {
@@ -233,7 +233,7 @@ static __global__ void flash_attn_ext_vec(
                 Q_reg[j][k].y *= scale;
             }
         }
-#endif // FAST_FP16_AVAILABLE
+#endif // V_DOT2_F32_F16_AVAILABLE
     }
 
     const int k_VKQ_max = KV_max ? KV_max[sequence*gridDim.x + blockIdx.x] : ne11;
@@ -291,7 +291,7 @@ static __global__ void flash_attn_ext_vec(
             KQ_sum[j] = KQ_sum[j]*KQ_max_scale + KQ_reg[j];
             KQ[j*nthreads + tid] = KQ_reg[j];
 
-#ifdef FAST_FP16_AVAILABLE
+#ifdef V_DOT2_F32_F16_AVAILABLE
             const half2 KQ_max_scale_h2 = make_half2(KQ_max_scale, KQ_max_scale);
 #pragma unroll
             for (int i_VKQ_0 = 0; i_VKQ_0 < D/2; i_VKQ_0 += nthreads_V) {
@@ -303,7 +303,7 @@ static __global__ void flash_attn_ext_vec(
                 VKQ[j][i_VKQ_0/nthreads_V].x *= KQ_max_scale;
                 VKQ[j][i_VKQ_0/nthreads_V].y *= KQ_max_scale;
             }
-#endif // FAST_FP16_AVAILABLE
+#endif // V_DOT2_F32_F16_AVAILABLE
         }
 
 #ifndef GGML_USE_HIP
@@ -314,7 +314,7 @@ static __global__ void flash_attn_ext_vec(
         for (int k0 = 0; k0 < WARP_SIZE; k0 += V_cols_per_iter) {
             const int k = threadIdx.y*WARP_SIZE + k0 + (nthreads_V == WARP_SIZE ? 0 : threadIdx.x / nthreads_V);
 
-#ifdef FAST_FP16_AVAILABLE
+#ifdef V_DOT2_F32_F16_AVAILABLE
             half2 KQ_k[ncols];
 #pragma unroll
             for (int j = 0; j < ncols; ++j) {
@@ -353,7 +353,7 @@ static __global__ void flash_attn_ext_vec(
                 }
             }
         }
-#endif // FAST_FP16_AVAILABLE
+#endif // V_DOT2_F32_F16_AVAILABLE
         }
     }
 
@@ -374,7 +374,7 @@ static __global__ void flash_attn_ext_vec(
 
             KQ_sum[j] = KQ_sum[j]*KQ_max_scale + (threadIdx.x == 0 ? expf(sink - KQ_max[j]) : 0.0f);
 
-#ifdef FAST_FP16_AVAILABLE
+#ifdef V_DOT2_F32_F16_AVAILABLE
             const half2 KQ_max_scale_h2 = make_half2(KQ_max_scale, KQ_max_scale);
 #pragma unroll
             for (int i_VKQ_0 = 0; i_VKQ_0 < D/2; i_VKQ_0 += nthreads_V) {
@@ -386,7 +386,7 @@ static __global__ void flash_attn_ext_vec(
                 VKQ[j][i_VKQ_0/nthreads_V].x *= KQ_max_scale;
                 VKQ[j][i_VKQ_0/nthreads_V].y *= KQ_max_scale;
             }
-#endif // FAST_FP16_AVAILABLE
+#endif // V_DOT2_F32_F16_AVAILABLE
         }
     }
 
@@ -421,7 +421,7 @@ static __global__ void flash_attn_ext_vec(
         const float kqmax_scale = expf(KQ_max[j_VKQ] - kqmax_new);
         KQ_max[j_VKQ] = kqmax_new;
 
-#ifdef FAST_FP16_AVAILABLE
+#ifdef V_DOT2_F32_F16_AVAILABLE
         half2 * VKQ_tmp = (half2 *) KQ + threadIdx.y*(V_cols_per_iter*D/2)
             + (nthreads_V == WARP_SIZE ? 0 : threadIdx.x / nthreads_V)*(D/2);
 
@@ -452,7 +452,7 @@ static __global__ void flash_attn_ext_vec(
             ggml_cuda_memcpy_1<V_rows_per_thread/2*sizeof(float)>(VKQ_tmp + i_VKQ, &VKQ[j_VKQ][i_VKQ_0/nthreads_V]);
             ggml_cuda_memcpy_1<V_rows_per_thread/2*sizeof(float)>(VKQ_tmp + i_VKQ + V_rows_per_thread/4, &VKQ[j_VKQ][i_VKQ_0/nthreads_V + V_rows_per_thread/4]);
         }
-#endif // FAST_FP16_AVAILABLE
+#endif // V_DOT2_F32_F16_AVAILABLE
 
         KQ_sum[j_VKQ] *= kqmax_scale;
         KQ_sum[j_VKQ] = warp_reduce_sum(KQ_sum[j_VKQ]);
@@ -53,6 +53,7 @@
 #include "ggml-cuda/set.cuh"
 #include "ggml-cuda/set-rows.cuh"
 #include "ggml-cuda/pad_reflect_1d.cuh"
+#include "ggml-cuda/solve_tri.cuh"
 #include "ggml.h"
 
 #include <algorithm>
@@ -2717,6 +2718,9 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
         case GGML_OP_OPT_STEP_SGD:
             ggml_cuda_opt_step_sgd(ctx, dst);
             break;
+        case GGML_OP_SOLVE_TRI:
+            ggml_cuda_op_solve_tri(ctx, dst);
+            break;
         default:
             return false;
     }
@@ -3837,7 +3841,7 @@ static void ggml_backend_cuda_device_get_memory(ggml_backend_dev_t dev, size_t *
 
     // Check if UMA is explicitly enabled via environment variable
     bool uma_env = getenv("GGML_CUDA_ENABLE_UNIFIED_MEMORY") != nullptr;
-    bool is_uma = prop.unifiedAddressing > 0 || uma_env;
+    bool is_uma = prop.integrated > 0 || uma_env;
 
     if (is_uma) {
         // For UMA systems (like DGX Spark), use system memory info
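Editor's note: the switch from unifiedAddressing to integrated is the substantive fix here. unifiedAddressing reports a shared virtual address space and is set on most discrete GPUs too, while integrated is set only when the GPU physically shares memory with the host. A minimal query sketch (the CUDA_CHECK wrapper is assumed from the surrounding code):

cudaDeviceProp prop;
CUDA_CHECK(cudaGetDeviceProperties(&prop, device));
// prop.integrated == 1 only for GPUs sharing physical memory with the host;
// prop.unifiedAddressing == 1 on most discrete GPUs as well, so it
// over-reports UMA.
const bool is_uma = prop.integrated > 0;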
@@ -4255,6 +4259,8 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
         case GGML_OP_OPT_STEP_ADAMW:
         case GGML_OP_OPT_STEP_SGD:
             return true;
+        case GGML_OP_SOLVE_TRI:
+            return op->src[0]->ne[0] <= 64 && op->src[1]->ne[0] <= 32;
         default:
             return false;
     }
@@ -151,7 +151,7 @@ bool ggml_cuda_should_use_mmf(enum ggml_type type, int cc, int warp_size, const
             return false;
         }
     } else {
-        if (src1_ncols > 16 || GGML_CUDA_CC_IS_RDNA4(cc)) {
+        if (src1_ncols > 16) {
             return false;
         }
     }
@@ -0,0 +1,203 @@
+#include "common.cuh"
+#include "ggml.h"
+#include "solve_tri.cuh"
+
+#define MAX_N_FAST 64
+#define MAX_K_FAST 32
+
+// ======================
+// Fast Kernel (n <= 64, k <= 32) - Warp-based parallel reduction
+// ======================
+// When ncols_template == 0 the bounds for the loops in this function are not
+// known and can't be unrolled. As we want to keep pragma unroll for all other
+// cases we supress the clang transformation warning here.
+#ifdef __clang__
+# pragma clang diagnostic push
+# pragma clang diagnostic ignored "-Wpass-failed"
+#endif // __clang__
+template <int n_template, int k_template>
+static __global__ void solve_tri_f32_fast(const float * __restrict__ A,
+                                          const float * __restrict__ B,
+                                          float * __restrict__ X,
+                                          const uint3 ne02,
+                                          const size_t nb02,
+                                          const size_t nb03,
+                                          const size_t nb12,
+                                          const size_t nb13,
+                                          const size_t nb2,
+                                          const size_t nb3,
+                                          const int n_arg,
+                                          const int k_arg) {
+    const int n = n_template == 0 ? n_arg : n_template;
+    const int k = k_template == 0 ? k_arg : k_template;
+
+    const int batch_idx = blockIdx.x;
+    const int lane      = threadIdx.x;
+    const int col_idx   = threadIdx.y;
+
+    if (col_idx >= k) {
+        return;
+    }
+
+    const uint2 i02_i03 = fast_div_modulo(batch_idx, ne02);
+    const int64_t i02 = i02_i03.y;
+    const int64_t i03 = i02_i03.x;
+
+    const float * const A_batch = (const float *) (A + i02 * nb02 + i03 * nb03);
+    const float * const B_batch = (const float *) (B + i02 * nb12 + i03 * nb13);
+    float * X_batch = (float *) (X + i02 * nb2 + i03 * nb3);
+
+    __shared__ float sA[MAX_N_FAST * MAX_N_FAST];
+    __shared__ float sXt[MAX_N_FAST * (MAX_K_FAST + 1)];
+
+    const int offset = threadIdx.x + threadIdx.y * blockDim.x;
+
+#pragma unroll
+    for (int i = 0; i < n * n; i += k * WARP_SIZE) {
+        int i0 = i + offset;
+        if (i0 < n * n) {
+            sA[i0] = A_batch[i0];
+        }
+    }
+
+    const int rows_per_warp = (n + WARP_SIZE - 1) / WARP_SIZE;
+
+#pragma unroll
+    for (int i = 0; i < rows_per_warp; i++) {
+        const int i0 = lane + i * WARP_SIZE;
+        if (i0 < n) {
+            sXt[col_idx * n + i0] = B_batch[i0 * k + col_idx];
+        }
+    }
+
+    __syncthreads();
+
+#pragma unroll
+    for (int row = 0; row < n; ++row) {
+        float sum = 0.0f;
+
+        {
+            int j = lane;
+            if (j < row) {
+                sum += sA[row * n + j] * sXt[col_idx * n + j];
+            }
+        }
+        if (row >= WARP_SIZE) {
+            int j = WARP_SIZE + lane;
+            if (j < row) {
+                sum += sA[row * n + j] * sXt[col_idx * n + j];
+            }
+        }
+
+        sum = warp_reduce_sum(sum);
+
+        if (lane == 0) {
+            const float b_val = sXt[col_idx * n + row];
+            const float a_diag = sA[row * n + row];
+            // no safeguards for division by zero because that indicates corrupt
+            // data anyway
+            sXt[col_idx * n + row] = (b_val - sum) / a_diag;
+        }
+    }
+
+    __syncthreads();
+
+#pragma unroll
+    for (int i = 0; i < rows_per_warp; i++) {
+        const int i0 = lane + i * WARP_SIZE;
+        if (i0 < n) {
+            X_batch[i0 * k + col_idx] = sXt[col_idx * n + i0];
+        }
+    }
+}
+#ifdef __clang__
+# pragma clang diagnostic pop
+#endif // __clang__
+
+static void solve_tri_f32_cuda(const float * A,
+                               const float * B,
+                               float * X,
+                               int n,
+                               int k,
+                               int64_t ne02,
+                               int64_t ne03,
+                               size_t nb02,
+                               size_t nb03,
+                               size_t nb12,
+                               size_t nb13,
+                               size_t nb2,
+                               size_t nb3,
+                               cudaStream_t stream) {
+    const uint3 ne02_fd = init_fastdiv_values((uint32_t) ne02);
+    dim3 threads(WARP_SIZE, k);
+    dim3 grid(ne02 * ne03);
+    if (n == 64) {
+        switch (k) {
+            case 32:
+                solve_tri_f32_fast<64, 32>
+                    <<<grid, threads, 0, stream>>>(A, B, X, ne02_fd, nb02, nb03, nb12, nb13, nb2, nb3, 0, 0);
+                break;
+            case 16:
+                solve_tri_f32_fast<64, 16>
+                    <<<grid, threads, 0, stream>>>(A, B, X, ne02_fd, nb02, nb03, nb12, nb13, nb2, nb3, 0, 0);
+                break;
+            case 14:
+                solve_tri_f32_fast<64, 14>
+                    <<<grid, threads, 0, stream>>>(A, B, X, ne02_fd, nb02, nb03, nb12, nb13, nb2, nb3, 0, 0);
+                break;
+            case 12:
+                solve_tri_f32_fast<64, 12>
+                    <<<grid, threads, 0, stream>>>(A, B, X, ne02_fd, nb02, nb03, nb12, nb13, nb2, nb3, 0, 0);
+                break;
+            case 10:
+                solve_tri_f32_fast<64, 10>
+                    <<<grid, threads, 0, stream>>>(A, B, X, ne02_fd, nb02, nb03, nb12, nb13, nb2, nb3, 0, 0);
+                break;
+            case 8:
+                solve_tri_f32_fast<64, 8>
+                    <<<grid, threads, 0, stream>>>(A, B, X, ne02_fd, nb02, nb03, nb12, nb13, nb2, nb3, 0, 0);
+                break;
+            case 6:
+                solve_tri_f32_fast<64, 6>
+                    <<<grid, threads, 0, stream>>>(A, B, X, ne02_fd, nb02, nb03, nb12, nb13, nb2, nb3, 0, 0);
+                break;
+            case 4:
+                solve_tri_f32_fast<64, 4>
+                    <<<grid, threads, 0, stream>>>(A, B, X, ne02_fd, nb02, nb03, nb12, nb13, nb2, nb3, 0, 0);
+                break;
+            case 2:
+                solve_tri_f32_fast<64, 2>
+                    <<<grid, threads, 0, stream>>>(A, B, X, ne02_fd, nb02, nb03, nb12, nb13, nb2, nb3, 0, 0);
+                break;
+            case 1:
+                solve_tri_f32_fast<64, 1>
+                    <<<grid, threads, 0, stream>>>(A, B, X, ne02_fd, nb02, nb03, nb12, nb13, nb2, nb3, 0, 0);
+                break;
+            default:
+                solve_tri_f32_fast<0, 0>
+                    <<<grid, threads, 0, stream>>>(A, B, X, ne02_fd, nb02, nb03, nb12, nb13, nb2, nb3, n, k);
+        }
+    } else { // run general case
+        solve_tri_f32_fast<0, 0>
+            <<<grid, threads, 0, stream>>>(A, B, X, ne02_fd, nb02, nb03, nb12, nb13, nb2, nb3, n, k);
+    }
+}
+
+void ggml_cuda_op_solve_tri(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+    const ggml_tensor * src0 = dst->src[0]; // A (triangular n x x matrix)
+    const ggml_tensor * src1 = dst->src[1]; // B (right hand side of n x k equation columns)
+
+    ggml_is_contiguous(src0);
+    ggml_is_contiguous(src1);
+
+    const int64_t n = src0->ne[0];
+    const int64_t k = src1->ne[0];
+
+    GGML_ASSERT(n <= 64);
+    GGML_ASSERT(k <= 32);
+
+    solve_tri_f32_cuda((const float *) src0->data, (const float *) src1->data, (float *) dst->data, n, k, src0->ne[2],
+                       src0->ne[3], src0->nb[2] / sizeof(float), src0->nb[3] / sizeof(float),
+                       src1->nb[2] / sizeof(float), src1->nb[3] / sizeof(float), dst->nb[2] / sizeof(float),
+                       dst->nb[3] / sizeof(float), ctx.stream());
+}
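Editor's note: the kernel performs plain forward substitution per batch; each of the k columns solves x_i = (b_i - sum_{j<i} A_ij * x_j) / A_ii, with the row sum reduced across the warp. A CPU reference sketch with the same memory layout (row-major A; the column index c is the fastest dimension in B and X):

// CPU reference (sketch) for one batch of the fast kernel above.
static void solve_tri_ref(const float * A, const float * B, float * X, int n, int k) {
    for (int c = 0; c < k; ++c) {          // one right-hand-side column
        for (int i = 0; i < n; ++i) {      // forward substitution
            float sum = 0.0f;
            for (int j = 0; j < i; ++j) {
                sum += A[i*n + j] * X[j*k + c];
            }
            // no zero-diagonal guard, matching the kernel
            X[i*k + c] = (B[i*k + c] - sum) / A[i*n + i];
        }
    }
}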
@@ -0,0 +1,3 @@
+#include "common.cuh"
+
+void ggml_cuda_op_solve_tri(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
@@ -70,6 +70,7 @@ set(GGML_OPENCL_KERNELS
     group_norm
     im2col_f32
     im2col_f16
+    mean
     mul_mat_Ab_Bi_8x4
     mul_mv_f16_f16
     mul_mv_f16_f32_1row
@@ -109,6 +110,9 @@ set(GGML_OPENCL_KERNELS
     softmax_4_f16
     softmax_f32
     softmax_f16
+    sqr
+    sqrt
+    ssm_conv
     sub
     sum_rows
     transpose
@@ -449,6 +449,9 @@ struct ggml_backend_opencl_context {
     cl_kernel kernel_sub, kernel_sub_row, kernel_sub_f16, kernel_sub_row_f16;
     cl_kernel kernel_add_id;
     cl_kernel kernel_scale;
+    cl_kernel kernel_sqr_cont_f32, kernel_sqr_cont_f32_4, kernel_sqr_cont_f16, kernel_sqr_cont_f16_4;
+    cl_kernel kernel_sqrt_cont_f32, kernel_sqrt_cont_f32_4, kernel_sqrt_cont_f16, kernel_sqrt_cont_f16_4;
+    cl_kernel kernel_mean_f32;
     cl_kernel kernel_silu, kernel_silu_4;
     cl_kernel kernel_gelu, kernel_gelu_4;
     cl_kernel kernel_gelu_erf, kernel_gelu_erf_4;
@@ -509,6 +512,7 @@ struct ggml_backend_opencl_context {
     cl_kernel kernel_conv_2d_f16;
     cl_kernel kernel_conv_2d_f32;
     cl_kernel kernel_conv_2d_f16_f32;
+    cl_kernel kernel_ssm_conv_f32_f32, kernel_ssm_conv_f32_f32_4;
     cl_kernel kernel_timestep_embedding;
     cl_kernel kernel_gemv_moe_mxfp4_f32, kernel_gemm_moe_mxfp4_f32;
     cl_kernel kernel_mul_mv_id_q4_0_f32_8x_flat;
@@ -1552,6 +1556,66 @@ static void load_cl_kernels(ggml_backend_opencl_context *backend_ctx, ggml_cl_ve
         GGML_LOG_CONT(".");
     }
 
+    // sqr
+    {
+#ifdef GGML_OPENCL_EMBED_KERNELS
+        const std::string kernel_src {
+#include "sqr.cl.h"
+        };
+#else
+        const std::string kernel_src = read_file("sqr.cl");
+#endif
+        cl_program prog =
+            build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts);
+
+        CL_CHECK((backend_ctx->kernel_sqr_cont_f32 = clCreateKernel(prog, "kernel_sqr_cont_f32", &err), err));
+        CL_CHECK((backend_ctx->kernel_sqr_cont_f32_4 = clCreateKernel(prog, "kernel_sqr_cont_f32_4", &err), err));
+        CL_CHECK((backend_ctx->kernel_sqr_cont_f16 = clCreateKernel(prog, "kernel_sqr_cont_f16", &err), err));
+        CL_CHECK((backend_ctx->kernel_sqr_cont_f16_4 = clCreateKernel(prog, "kernel_sqr_cont_f16_4", &err), err));
+
+        CL_CHECK(clReleaseProgram(prog));
+        GGML_LOG_CONT(".");
+    }
+
+    // sqrt
+    {
+#ifdef GGML_OPENCL_EMBED_KERNELS
+        const std::string kernel_src {
+#include "sqrt.cl.h"
+        };
+#else
+        const std::string kernel_src = read_file("sqrt.cl");
+#endif
+        cl_program prog =
+            build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts);
+
+        CL_CHECK((backend_ctx->kernel_sqrt_cont_f32 = clCreateKernel(prog, "kernel_sqrt_cont_f32", &err), err));
+        CL_CHECK((backend_ctx->kernel_sqrt_cont_f32_4 = clCreateKernel(prog, "kernel_sqrt_cont_f32_4", &err), err));
+        CL_CHECK((backend_ctx->kernel_sqrt_cont_f16 = clCreateKernel(prog, "kernel_sqrt_cont_f16", &err), err));
+        CL_CHECK((backend_ctx->kernel_sqrt_cont_f16_4 = clCreateKernel(prog, "kernel_sqrt_cont_f16_4", &err), err));
+
+        CL_CHECK(clReleaseProgram(prog));
+        GGML_LOG_CONT(".");
+    }
+
+    // mean
+    {
+#ifdef GGML_OPENCL_EMBED_KERNELS
+        const std::string kernel_src {
+#include "mean.cl.h"
+        };
+#else
+        const std::string kernel_src = read_file("mean.cl");
+#endif
+        cl_program prog =
+            build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts);
+
+        CL_CHECK((backend_ctx->kernel_mean_f32 = clCreateKernel(prog, "kernel_mean_f32", &err), err));
+
+        CL_CHECK(clReleaseProgram(prog));
+        GGML_LOG_CONT(".");
+    }
+
     // sub
     {
 #ifdef GGML_OPENCL_EMBED_KERNELS
@@ -1825,6 +1889,24 @@ static void load_cl_kernels(ggml_backend_opencl_context *backend_ctx, ggml_cl_ve
         }
     }
 
+    // ssm_conv
+    {
+#ifdef GGML_OPENCL_EMBED_KERNELS
+        const std::string kernel_src {
+#include "ssm_conv.cl.h"
+        };
+#else
+        const std::string kernel_src = read_file("ssm_conv.cl");
+#endif
+        cl_program prog =
+            build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts);
+
+        CL_CHECK((backend_ctx->kernel_ssm_conv_f32_f32 = clCreateKernel(prog, "kernel_ssm_conv_f32_f32", &err), err));
+        CL_CHECK((backend_ctx->kernel_ssm_conv_f32_f32_4 = clCreateKernel(prog, "kernel_ssm_conv_f32_f32_4", &err), err));
+        CL_CHECK(clReleaseProgram(prog));
+        GGML_LOG_CONT(".");
+    }
+
     // mul_mv_id_q4_0_f32_8x_flat
     {
 #ifdef GGML_OPENCL_EMBED_KERNELS
@@ -2959,6 +3041,10 @@ static bool ggml_opencl_supports_op(ggml_backend_dev_t dev, const struct ggml_te
                 (op->src[0]->type == GGML_TYPE_F32 || op->src[0]->type == GGML_TYPE_F16);
         case GGML_OP_ADD_ID:
             return op->src[0]->type == GGML_TYPE_F32;
+        case GGML_OP_SQR:
+        case GGML_OP_SQRT:
+            return (op->src[0]->type == GGML_TYPE_F32 || op->src[0]->type == GGML_TYPE_F16) &&
+                   ggml_is_contiguous(op->src[0]);
         case GGML_OP_UNARY:
             switch (ggml_get_unary_op(op)) {
                 case GGML_UNARY_OP_GELU:
@@ -3007,6 +3093,8 @@ static bool ggml_opencl_supports_op(ggml_backend_dev_t dev, const struct ggml_te
             return (op->src[0]->type == GGML_TYPE_F16 && op->src[1]->type == GGML_TYPE_F16 && op->type == GGML_TYPE_F16) ||
                    (op->src[0]->type == GGML_TYPE_F32 && op->src[1]->type == GGML_TYPE_F32 && op->type == GGML_TYPE_F32) ||
                    (op->src[0]->type == GGML_TYPE_F16 && op->src[1]->type == GGML_TYPE_F32 && op->type == GGML_TYPE_F32);
+        case GGML_OP_SSM_CONV:
+            return (op->src[0]->type == GGML_TYPE_F32 && op->src[1]->type == GGML_TYPE_F32 && op->type == GGML_TYPE_F32);
         case GGML_OP_CONCAT:
             return op->src[0]->type == GGML_TYPE_F32 && op->src[1]->type == GGML_TYPE_F32 && op->type == GGML_TYPE_F32;
         case GGML_OP_TIMESTEP_EMBEDDING:
@@ -3075,6 +3163,7 @@ static bool ggml_opencl_supports_op(ggml_backend_dev_t dev, const struct ggml_te
                 return cols <= max_workgroup_size && op->src[0]->type == GGML_TYPE_F32;
             }
         case GGML_OP_SUM_ROWS:
+        case GGML_OP_MEAN:
             return op->src[0]->type == GGML_TYPE_F32 && ggml_is_contiguous(op->src[0]);
         case GGML_OP_FLASH_ATTN_EXT:
             {
@ -5193,6 +5282,224 @@ static void ggml_cl_sub(ggml_backend_t backend, const ggml_tensor * src0, const
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
static void ggml_cl_sqr(ggml_backend_t backend, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
|
||||||
|
GGML_ASSERT(src0);
|
||||||
|
GGML_ASSERT(src0->extra);
|
||||||
|
GGML_ASSERT(dst);
|
||||||
|
GGML_ASSERT(dst->extra);
|
||||||
|
UNUSED(src1);
|
||||||
|
|
||||||
|
ggml_backend_opencl_context *backend_ctx = (ggml_backend_opencl_context *)backend->context;
|
||||||
|
|
||||||
|
ggml_tensor_extra_cl * extra0 = (ggml_tensor_extra_cl *)src0->extra;
|
||||||
|
ggml_tensor_extra_cl * extrad = (ggml_tensor_extra_cl *)dst->extra;
|
||||||
|
|
||||||
|
cl_ulong offset0 = extra0->offset + src0->view_offs;
|
||||||
|
cl_ulong offsetd = extrad->offset + dst->view_offs;
|
||||||
|
|
||||||
|
cl_kernel kernel;
|
||||||
|
|
||||||
|
// Currently assumes src0 is contiguous
|
||||||
|
int n = ggml_nelements(dst);
|
||||||
|
if (n % 4 == 0) {
|
||||||
|
if (src0->type == GGML_TYPE_F32) {
|
||||||
|
kernel = backend_ctx->kernel_sqr_cont_f32_4;
|
||||||
|
} else {
|
||||||
|
kernel = backend_ctx->kernel_sqr_cont_f16_4;
|
||||||
|
}
|
||||||
|
n /= 4;
|
||||||
|
} else {
|
||||||
|
if (src0->type == GGML_TYPE_F32) {
|
||||||
|
kernel = backend_ctx->kernel_sqr_cont_f32;
|
||||||
|
} else {
|
||||||
|
kernel = backend_ctx->kernel_sqr_cont_f16;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra0->data_device));
|
||||||
|
CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_ulong), &offset0));
|
||||||
|
CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extrad->data_device));
|
||||||
|
CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_ulong), &offsetd));
|
||||||
|
|
||||||
|
size_t global_work_size[] = {(size_t)n, 1, 1};
|
||||||
|
size_t local_work_size[] = {64, 1, 1};
|
||||||
|
|
||||||
|
size_t * local_work_size_ptr = local_work_size;
|
||||||
|
if (n % 64 != 0 && !backend_ctx->non_uniform_workgroups) {
|
||||||
|
local_work_size_ptr = nullptr;
|
||||||
|
}
|
||||||
|
|
||||||
|
backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size_ptr, dst);
|
||||||
|
}
|
||||||
|
|
||||||
|
static void ggml_cl_sqrt(ggml_backend_t backend, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
|
||||||
|
GGML_ASSERT(src0);
|
||||||
|
GGML_ASSERT(src0->extra);
|
||||||
|
GGML_ASSERT(dst);
|
||||||
|
GGML_ASSERT(dst->extra);
|
||||||
|
UNUSED(src1);
|
||||||
|
|
||||||
|
ggml_backend_opencl_context *backend_ctx = (ggml_backend_opencl_context *)backend->context;
|
||||||
|
|
||||||
|
ggml_tensor_extra_cl * extra0 = (ggml_tensor_extra_cl *)src0->extra;
|
||||||
|
ggml_tensor_extra_cl * extrad = (ggml_tensor_extra_cl *)dst->extra;
|
||||||
|
|
||||||
|
cl_ulong offset0 = extra0->offset + src0->view_offs;
|
||||||
|
cl_ulong offsetd = extrad->offset + dst->view_offs;
|
||||||
|
|
||||||
|
cl_kernel kernel;
|
||||||
|
|
||||||
|
// Currently assumes src0 is contiguous
|
||||||
|
int n = ggml_nelements(dst);
|
||||||
|
if (n % 4 == 0) {
|
||||||
|
if (src0->type == GGML_TYPE_F32) {
|
||||||
|
kernel = backend_ctx->kernel_sqrt_cont_f32_4;
|
||||||
|
} else {
|
||||||
|
kernel = backend_ctx->kernel_sqrt_cont_f16_4;
|
||||||
|
}
|
||||||
|
n /= 4;
|
||||||
|
} else {
|
||||||
|
if (src0->type == GGML_TYPE_F32) {
|
||||||
|
kernel = backend_ctx->kernel_sqrt_cont_f32;
|
||||||
|
} else {
|
||||||
|
kernel = backend_ctx->kernel_sqrt_cont_f16;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra0->data_device));
|
||||||
|
CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_ulong), &offset0));
|
||||||
|
CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extrad->data_device));
|
||||||
|
CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_ulong), &offsetd));
|
||||||
|
|
||||||
|
size_t global_work_size[] = {(size_t)n, 1, 1};
|
||||||
|
size_t local_work_size[] = {64, 1, 1};
|
||||||
|
|
||||||
|
size_t * local_work_size_ptr = local_work_size;
|
||||||
|
if (n % 64 != 0 && !backend_ctx->non_uniform_workgroups) {
|
||||||
|
local_work_size_ptr = nullptr;
|
||||||
|
}
|
||||||
|
|
||||||
|
backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size_ptr, dst);
|
||||||
|
}
|
||||||
|
|
||||||
|
static void ggml_cl_mean(ggml_backend_t backend, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
|
||||||
|
GGML_ASSERT(src0);
|
||||||
|
GGML_ASSERT(src0->extra);
|
||||||
|
GGML_ASSERT(dst);
|
||||||
|
GGML_ASSERT(dst->extra);
|
||||||
|
GGML_UNUSED(src1);
|
||||||
|
|
||||||
|
GGML_ASSERT(src0->nb[0] == ggml_type_size(src0->type));
|
||||||
|
GGML_ASSERT(ggml_is_contiguous(src0));
|
||||||
|
|
||||||
|
ggml_backend_opencl_context *backend_ctx = (ggml_backend_opencl_context *)backend->context;
|
||||||
|
|
||||||
|
ggml_tensor_extra_cl * extra0 = (ggml_tensor_extra_cl *)src0->extra;
|
||||||
|
ggml_tensor_extra_cl * extrad = (ggml_tensor_extra_cl *)dst->extra;
|
||||||
|
|
||||||
|
cl_ulong offset0 = extra0->offset + src0->view_offs;
|
||||||
|
cl_ulong offsetd = extrad->offset + dst->view_offs;
|
||||||
|
|
||||||
|
const int ne00 = src0->ne[0];
|
||||||
|
const int ne01 = src0->ne[1];
|
||||||
|
const int ne02 = src0->ne[2];
|
||||||
|
const int ne03 = src0->ne[3];
|
||||||
|
|
||||||
|
const cl_ulong nb01 = src0->nb[1];
|
||||||
|
const cl_ulong nb02 = src0->nb[2];
|
||||||
|
const cl_ulong nb03 = src0->nb[3];
|
||||||
|
|
||||||
|
const cl_ulong nb1 = dst->nb[1];
|
||||||
|
const cl_ulong nb2 = dst->nb[2];
|
||||||
|
const cl_ulong nb3 = dst->nb[3];
|
||||||
|
|
||||||
|
cl_kernel kernel = backend_ctx->kernel_mean_f32;
|
||||||
|
|
||||||
|
CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra0->data_device));
|
||||||
|
CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_ulong), &offset0));
|
||||||
|
CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extrad->data_device));
|
||||||
|
CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_ulong), &offsetd));
|
||||||
|
CL_CHECK(clSetKernelArg(kernel, 4, sizeof(int), &ne00));
|
||||||
|
CL_CHECK(clSetKernelArg(kernel, 5, sizeof(int), &ne01));
|
||||||
|
CL_CHECK(clSetKernelArg(kernel, 6, sizeof(int), &ne02));
|
||||||
|
CL_CHECK(clSetKernelArg(kernel, 7, sizeof(int), &ne03));
|
||||||
|
CL_CHECK(clSetKernelArg(kernel, 8, sizeof(cl_ulong), &nb01));
|
||||||
|
CL_CHECK(clSetKernelArg(kernel, 9, sizeof(cl_ulong), &nb02));
|
||||||
|
CL_CHECK(clSetKernelArg(kernel, 10, sizeof(cl_ulong), &nb03));
|
||||||
|
CL_CHECK(clSetKernelArg(kernel, 11, sizeof(cl_ulong), &nb1));
|
||||||
|
CL_CHECK(clSetKernelArg(kernel, 12, sizeof(cl_ulong), &nb2));
|
||||||
|
CL_CHECK(clSetKernelArg(kernel, 13, sizeof(cl_ulong), &nb3));
|
||||||
|
|
||||||
|
size_t global_work_size[] = {(size_t)ne01, (size_t)ne02, (size_t)ne03};
|
||||||
|
size_t local_work_size[] = {(size_t)64, 1, 1};
|
||||||
|
|
||||||
|
backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size, dst);
|
||||||
|
}
|
||||||
|
|
||||||
|
+static void ggml_cl_ssm_conv(ggml_backend_t backend, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
+    GGML_ASSERT(src0);
+    GGML_ASSERT(src0->extra);
+    GGML_ASSERT(src1);
+    GGML_ASSERT(src1->extra);
+    GGML_ASSERT(dst);
+    GGML_ASSERT(dst->extra);
+
+    ggml_backend_opencl_context *backend_ctx = (ggml_backend_opencl_context *)backend->context;
+
+    ggml_tensor_extra_cl * extra0 = (ggml_tensor_extra_cl *)src0->extra;
+    ggml_tensor_extra_cl * extra1 = (ggml_tensor_extra_cl *)src1->extra;
+    ggml_tensor_extra_cl * extrad = (ggml_tensor_extra_cl *)dst->extra;
+
+    cl_ulong offset0 = extra0->offset + src0->view_offs;
+    cl_ulong offset1 = extra1->offset + src1->view_offs;
+    cl_ulong offsetd = extrad->offset + dst->view_offs;
+
+    int      ne01 = src0->ne[1];
+    cl_ulong nb00 = src0->nb[0];
+    cl_ulong nb01 = src0->nb[1];
+    cl_ulong nb02 = src0->nb[2];
+
+    int      ne10 = src1->ne[0];
+    cl_ulong nb11 = src1->nb[1];
+
+    int      ne1 = dst->ne[1];
+    int      ne2 = dst->ne[2];
+    cl_ulong nb0 = dst->nb[0];
+    cl_ulong nb1 = dst->nb[1];
+    cl_ulong nb2 = dst->nb[2];
+
+    cl_kernel kernel = backend_ctx->kernel_ssm_conv_f32_f32;
+
+    if (ne10 % 4 == 0) {
+        kernel = backend_ctx->kernel_ssm_conv_f32_f32_4;
+    }
+
+    CL_CHECK(clSetKernelArg(kernel,  0, sizeof(cl_mem),   &extra0->data_device));
+    CL_CHECK(clSetKernelArg(kernel,  1, sizeof(cl_ulong), &offset0));
+    CL_CHECK(clSetKernelArg(kernel,  2, sizeof(cl_mem),   &extra1->data_device));
+    CL_CHECK(clSetKernelArg(kernel,  3, sizeof(cl_ulong), &offset1));
+    CL_CHECK(clSetKernelArg(kernel,  4, sizeof(cl_mem),   &extrad->data_device));
+    CL_CHECK(clSetKernelArg(kernel,  5, sizeof(cl_ulong), &offsetd));
+    CL_CHECK(clSetKernelArg(kernel,  6, sizeof(cl_ulong), &nb00));
+    CL_CHECK(clSetKernelArg(kernel,  7, sizeof(cl_ulong), &nb01));
+    CL_CHECK(clSetKernelArg(kernel,  8, sizeof(cl_ulong), &nb02));
+    CL_CHECK(clSetKernelArg(kernel,  9, sizeof(int),      &ne10));
+    CL_CHECK(clSetKernelArg(kernel, 10, sizeof(cl_ulong), &nb11));
+    CL_CHECK(clSetKernelArg(kernel, 11, sizeof(cl_ulong), &nb0));
+    CL_CHECK(clSetKernelArg(kernel, 12, sizeof(cl_ulong), &nb1));
+    CL_CHECK(clSetKernelArg(kernel, 13, sizeof(cl_ulong), &nb2));
+
+    size_t global_work_size[] = {(size_t)ne01, (size_t)ne1, (size_t)ne2};
+    size_t local_work_size[]  = {64, 1, 1};
+
+    size_t * local_work_size_ptr = local_work_size;
+    if (ne01 % 64 != 0 && !backend_ctx->non_uniform_workgroups) {
+        local_work_size_ptr = nullptr;
+    }
+
+    backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size_ptr, dst);
+}
+
 static void ggml_cl_gelu(ggml_backend_t backend, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
     GGML_ASSERT(src0);
     GGML_ASSERT(src0->extra);

@@ -9091,6 +9398,24 @@ bool ggml_cl_compute_forward(ggml_backend_t backend, struct ggml_tensor * tensor
             }
             func = ggml_cl_sub;
             break;
+        case GGML_OP_SQR:
+            if (!any_on_device) {
+                return false;
+            }
+            func = ggml_cl_sqr;
+            break;
+        case GGML_OP_SQRT:
+            if (!any_on_device) {
+                return false;
+            }
+            func = ggml_cl_sqrt;
+            break;
+        case GGML_OP_MEAN:
+            if (!any_on_device) {
+                return false;
+            }
+            func = ggml_cl_mean;
+            break;
         case GGML_OP_UNARY:
             switch (ggml_get_unary_op(tensor)) {
                 case GGML_UNARY_OP_GELU:

@@ -9192,6 +9517,12 @@ bool ggml_cl_compute_forward(ggml_backend_t backend, struct ggml_tensor * tensor
             }
             func = ggml_cl_conv_2d;
             break;
+        case GGML_OP_SSM_CONV:
+            if (!any_on_device) {
+                return false;
+            }
+            func = ggml_cl_ssm_conv;
+            break;
         case GGML_OP_CONCAT:
             if (!any_on_device) {
                 return false;
@@ -0,0 +1,39 @@
+kernel void kernel_mean_f32(
+        global float * src0,
+        ulong offset0,
+        global float * dst,
+        ulong offsetd,
+        int ne00,
+        int ne01,
+        int ne02,
+        int ne03,
+        ulong nb01,
+        ulong nb02,
+        ulong nb03,
+        ulong nb1,
+        ulong nb2,
+        ulong nb3
+) {
+    src0 = (global float *)((global char *)src0 + offset0);
+    dst  = (global float *)((global char *)dst  + offsetd);
+
+    int i3 = get_global_id(2);
+    int i2 = get_global_id(1);
+    int i1 = get_global_id(0);
+
+    if (i3 >= ne03 || i2 >= ne02 || i1 >= ne01) {
+        return;
+    }
+
+    global float * src_row = (global float *) ((global char *) src0 + i1*nb01 + i2*nb02 + i3*nb03);
+    global float * dst_row = (global float *) ((global char *) dst  + i1*nb1  + i2*nb2  + i3*nb3);
+
+    float row_sum = 0;
+
+    for (int i0 = 0; i0 < ne00; i0++) {
+        row_sum += src_row[i0];
+    }
+
+    dst_row[0] = row_sum / ne00;
+}
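For reference, a minimal host-side sketch of what this kernel computes, one work-item per (i1, i2, i3) row (plain C++; the helper name mean_rows_ref is illustrative, not part of the patch; nb* are byte strides as in the kernel arguments):

#include <cstdint>

// CPU reference for kernel_mean_f32: each output element is the mean of
// ne00 contiguous floats from the corresponding source row.
static void mean_rows_ref(const char * src0, char * dst,
                          int ne00, int ne01, int ne02, int ne03,
                          uint64_t nb01, uint64_t nb02, uint64_t nb03,
                          uint64_t nb1, uint64_t nb2, uint64_t nb3) {
    for (int i3 = 0; i3 < ne03; i3++) {
        for (int i2 = 0; i2 < ne02; i2++) {
            for (int i1 = 0; i1 < ne01; i1++) {
                const float * src_row = (const float *)(src0 + i1*nb01 + i2*nb02 + i3*nb03);
                float       * dst_row = (float       *)(dst  + i1*nb1  + i2*nb2  + i3*nb3);
                float row_sum = 0.0f;
                for (int i0 = 0; i0 < ne00; i0++) {
                    row_sum += src_row[i0];
                }
                dst_row[0] = row_sum / ne00;
            }
        }
    }
}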
@@ -0,0 +1,53 @@
+#pragma OPENCL EXTENSION cl_khr_fp16 : enable
+
+kernel void kernel_sqr_cont_f32(
+        global float * src0,
+        ulong offset0,
+        global float * dst,
+        ulong offsetd
+) {
+    src0 = (global float*)((global char*)src0 + offset0);
+    dst  = (global float*)((global char*)dst  + offsetd);
+
+    uint gid = get_global_id(0);
+    dst[gid] = src0[gid] * src0[gid];
+}
+
+kernel void kernel_sqr_cont_f32_4(
+        global float4 * src0,
+        ulong offset0,
+        global float4 * dst,
+        ulong offsetd
+) {
+    src0 = (global float4*)((global char*)src0 + offset0);
+    dst  = (global float4*)((global char*)dst  + offsetd);
+
+    uint gid = get_global_id(0);
+    dst[gid] = src0[gid] * src0[gid];
+}
+
+kernel void kernel_sqr_cont_f16(
+        global half * src0,
+        ulong offset0,
+        global half * dst,
+        ulong offsetd
+) {
+    src0 = (global half*)((global char*)src0 + offset0);
+    dst  = (global half*)((global char*)dst  + offsetd);
+
+    uint gid = get_global_id(0);
+    dst[gid] = src0[gid] * src0[gid];
+}
+
+kernel void kernel_sqr_cont_f16_4(
+        global half4 * src0,
+        ulong offset0,
+        global half4 * dst,
+        ulong offsetd
+) {
+    src0 = (global half4*)((global char*)src0 + offset0);
+    dst  = (global half4*)((global char*)dst  + offsetd);
+
+    uint gid = get_global_id(0);
+    dst[gid] = src0[gid] * src0[gid];
+}
@@ -0,0 +1,53 @@
+#pragma OPENCL EXTENSION cl_khr_fp16 : enable
+
+kernel void kernel_sqrt_cont_f32(
+        global float * src0,
+        ulong offset0,
+        global float * dst,
+        ulong offsetd
+) {
+    src0 = (global float*)((global char*)src0 + offset0);
+    dst  = (global float*)((global char*)dst  + offsetd);
+
+    uint gid = get_global_id(0);
+    dst[gid] = sqrt(src0[gid]);
+}
+
+kernel void kernel_sqrt_cont_f32_4(
+        global float4 * src0,
+        ulong offset0,
+        global float4 * dst,
+        ulong offsetd
+) {
+    src0 = (global float4*)((global char*)src0 + offset0);
+    dst  = (global float4*)((global char*)dst  + offsetd);
+
+    uint gid = get_global_id(0);
+    dst[gid] = sqrt(src0[gid]);
+}
+
+kernel void kernel_sqrt_cont_f16(
+        global half * src0,
+        ulong offset0,
+        global half * dst,
+        ulong offsetd
+) {
+    src0 = (global half*)((global char*)src0 + offset0);
+    dst  = (global half*)((global char*)dst  + offsetd);
+
+    uint gid = get_global_id(0);
+    dst[gid] = convert_half(sqrt(convert_float(src0[gid])));
+}
+
+kernel void kernel_sqrt_cont_f16_4(
+        global half4 * src0,
+        ulong offset0,
+        global half4 * dst,
+        ulong offsetd
+) {
+    src0 = (global half4*)((global char*)src0 + offset0);
+    dst  = (global half4*)((global char*)dst  + offsetd);
+
+    uint gid = get_global_id(0);
+    dst[gid] = convert_half4(sqrt(convert_float4(src0[gid])));
+}
@@ -0,0 +1,77 @@
+kernel void kernel_ssm_conv_f32_f32(
+        global char * src0,
+        ulong offset0,
+        global char * src1,
+        ulong offset1,
+        global char * dst,
+        ulong offsetd,
+        ulong nb00,
+        ulong nb01,
+        ulong nb02,
+        int ne10,
+        ulong nb11,
+        ulong nb0,
+        ulong nb1,
+        ulong nb2
+) {
+    src0 = src0 + offset0;
+    src1 = src1 + offset1;
+    dst  = dst  + offsetd;
+
+    int ir = get_global_id(0);
+    int i2 = get_global_id(1);
+    int i3 = get_global_id(2);
+
+    int nc = ne10;
+
+    global float * s = (global float *) (src0 + ir*nb01 + i2*nb00 + i3*nb02);
+    global float * c = (global float *) (src1 + ir*nb11);
+    global float * d = (global float *) (dst  + ir*nb0  + i2*nb1  + i3*nb2);
+
+    float sumf = 0.0f;
+
+    for (int i0 = 0; i0 < nc; ++i0) {
+        sumf += s[i0] * c[i0];
+    }
+
+    d[0] = sumf;
+}
+
+kernel void kernel_ssm_conv_f32_f32_4(
+        global char * src0,
+        ulong offset0,
+        global char * src1,
+        ulong offset1,
+        global char * dst,
+        ulong offsetd,
+        ulong nb00,
+        ulong nb01,
+        ulong nb02,
+        int ne10,
+        ulong nb11,
+        ulong nb0,
+        ulong nb1,
+        ulong nb2
+) {
+    src0 = src0 + offset0;
+    src1 = src1 + offset1;
+    dst  = dst  + offsetd;
+
+    int ir = get_global_id(0);
+    int i2 = get_global_id(1);
+    int i3 = get_global_id(2);
+
+    int nc = ne10;
+
+    global float4 * s = (global float4 *) (src0 + ir*nb01 + i2*nb00 + i3*nb02);
+    global float4 * c = (global float4 *) (src1 + ir*nb11);
+    global float  * d = (global float  *) (dst  + ir*nb0  + i2*nb1  + i3*nb2);
+
+    float sumf = 0.0f;
+
+    for (int i0 = 0; i0 < nc/4; ++i0) {
+        sumf += dot(s[i0], c[i0]);
+    }
+
+    d[0] = sumf;
+}
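A rough CPU model of the per-element computation (a sketch, not the ggml CPU implementation): every work-item reduces one length-ne10 window with a dot product. The host code above only dispatches the float4 variant when ne10 % 4 == 0, so the vectorized loop always covers the whole row.

// Per-output-element reference for kernel_ssm_conv_f32_f32:
// d = dot(state window s, conv weights c) over nc = ne10 elements.
static float ssm_conv_ref(const float * s, const float * c, int nc) {
    float sumf = 0.0f;
    for (int i0 = 0; i0 < nc; ++i0) {
        sumf += s[i0] * c[i0];
    }
    return sumf;
}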
@@ -106,6 +106,7 @@ enum rpc_cmd {
     RPC_CMD_GET_ALLOC_SIZE,
     RPC_CMD_HELLO,
     RPC_CMD_DEVICE_COUNT,
+    RPC_CMD_GRAPH_RECOMPUTE,
     RPC_CMD_COUNT,
 };
@@ -205,10 +206,6 @@ struct rpc_msg_copy_tensor_rsp {
     uint8_t result;
 };
 
-struct rpc_msg_graph_compute_rsp {
-    uint8_t result;
-};
-
 struct rpc_msg_get_device_memory_req {
     uint32_t device;
 };
@@ -217,6 +214,11 @@ struct rpc_msg_get_device_memory_rsp {
     uint64_t free_mem;
     uint64_t total_mem;
 };
 
+struct rpc_msg_graph_recompute_req {
+    uint32_t device;
+};
+
 #pragma pack(pop)
 
 // RPC data structures
@@ -234,10 +236,35 @@ struct ggml_backend_rpc_buffer_type_context {
     size_t max_size;
 };
 
+struct graph_cache {
+    bool is_cached(const ggml_cgraph * cgraph) {
+        if ((int)last_graph.size() != cgraph->n_nodes) {
+            return false;
+        }
+        for (int i = 0; i < cgraph->n_nodes; i++) {
+            if (memcmp(&last_graph[i], cgraph->nodes[i], sizeof(ggml_tensor)) != 0) {
+                return false;
+            }
+        }
+        return true;
+    }
+
+    void add(const ggml_cgraph * cgraph) {
+        last_graph.resize(cgraph->n_nodes);
+        for (int i = 0; i < cgraph->n_nodes; i++) {
+            memcpy(&last_graph[i], cgraph->nodes[i], sizeof(ggml_tensor));
+        }
+    }
+
+    std::vector<ggml_tensor> last_graph;
+};
+
 struct ggml_backend_rpc_context {
     std::string endpoint;
     uint32_t    device;
     std::string name;
+    graph_cache gc;
 };
 
 struct ggml_backend_rpc_buffer_context {
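A minimal standalone model of the caching idea above (illustrative only: the generic NodeT stands in for ggml_tensor, and the real cache lives in ggml_backend_rpc_context::gc): two graphs count as identical when every node struct matches byte-for-byte, so any change to shapes, ops, params, or data pointers forces a full re-serialization.

#include <cstring>
#include <vector>

// Sketch of graph_cache with a POD node type standing in for ggml_tensor
// (an assumption for illustration; not part of the patch).
template <typename NodeT>
struct last_graph_cache {
    std::vector<NodeT> last_graph;  // byte-for-byte snapshot of the last graph

    bool is_cached(NodeT * const * nodes, int n_nodes) const {
        if ((int) last_graph.size() != n_nodes) {
            return false;
        }
        for (int i = 0; i < n_nodes; i++) {
            if (memcmp(&last_graph[i], nodes[i], sizeof(NodeT)) != 0) {
                return false;
            }
        }
        return true;
    }

    void add(NodeT * const * nodes, int n_nodes) {
        last_graph.resize(n_nodes);
        for (int i = 0; i < n_nodes; i++) {
            memcpy(&last_graph[i], nodes[i], sizeof(NodeT));
        }
    }
};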
@@ -815,13 +842,24 @@ static void serialize_graph(uint32_t device, const ggml_cgraph * cgraph, std::ve
 static enum ggml_status ggml_backend_rpc_graph_compute(ggml_backend_t backend, ggml_cgraph * cgraph) {
     ggml_backend_rpc_context * rpc_ctx = (ggml_backend_rpc_context *)backend->context;
-    std::vector<uint8_t> input;
-    serialize_graph(rpc_ctx->device, cgraph, input);
-    rpc_msg_graph_compute_rsp response;
-    auto sock = get_socket(rpc_ctx->endpoint);
-    bool status = send_rpc_cmd(sock, RPC_CMD_GRAPH_COMPUTE, input.data(), input.size(), &response, sizeof(response));
-    RPC_STATUS_ASSERT(status);
-    return (enum ggml_status)response.result;
+    GGML_ASSERT(cgraph->n_nodes > 0);
+    bool reuse = rpc_ctx->gc.is_cached(cgraph);
+    if (reuse) {
+        rpc_msg_graph_recompute_req request;
+        request.device = rpc_ctx->device;
+        auto sock = get_socket(rpc_ctx->endpoint);
+        bool status = send_rpc_cmd(sock, RPC_CMD_GRAPH_RECOMPUTE, &request, sizeof(request));
+        RPC_STATUS_ASSERT(status);
+    } else {
+        rpc_ctx->gc.add(cgraph);
+        std::vector<uint8_t> input;
+        serialize_graph(rpc_ctx->device, cgraph, input);
+        auto sock = get_socket(rpc_ctx->endpoint);
+        bool status = send_rpc_cmd(sock, RPC_CMD_GRAPH_COMPUTE, input.data(), input.size());
+        RPC_STATUS_ASSERT(status);
+    }
+    return GGML_STATUS_SUCCESS;
 }
 
 static ggml_backend_i ggml_backend_rpc_interface = {
@@ -880,7 +918,8 @@ ggml_backend_t ggml_backend_rpc_init(const char * endpoint, uint32_t device) {
     ggml_backend_rpc_context * ctx = new ggml_backend_rpc_context {
         /* .endpoint = */ endpoint,
         /* .device   = */ device,
-        /* .name     = */ dev_name
+        /* .name     = */ dev_name,
+        /* .gc       = */ {},
     };
     auto reg = ggml_backend_rpc_add_server(endpoint);
     ggml_backend_t backend = new ggml_backend {
@@ -920,8 +959,9 @@ void ggml_backend_rpc_get_device_memory(const char * endpoint, uint32_t device,
 class rpc_server {
 public:
-    rpc_server(std::vector<ggml_backend_t> backends, const char * cache_dir)
-        : backends(std::move(backends)), cache_dir(cache_dir) {
+    rpc_server(std::vector<ggml_backend_t> all_backends, const char * cache_dir)
+        : backends(std::move(all_backends)), cache_dir(cache_dir) {
+        stored_graphs.resize(backends.size());
     }
     ~rpc_server();
@@ -936,11 +976,17 @@ public:
     bool set_tensor_hash(const rpc_msg_set_tensor_hash_req & request, rpc_msg_set_tensor_hash_rsp & response);
     bool get_tensor(const rpc_msg_get_tensor_req & request, std::vector<uint8_t> & response);
     bool copy_tensor(const rpc_msg_copy_tensor_req & request, rpc_msg_copy_tensor_rsp & response);
-    bool graph_compute(const std::vector<uint8_t> & input, rpc_msg_graph_compute_rsp & response);
+    bool graph_compute(const std::vector<uint8_t> & input);
+    bool graph_recompute(const rpc_msg_graph_recompute_req & request);
     bool init_tensor(const rpc_msg_init_tensor_req & request);
     bool get_alloc_size(const rpc_msg_get_alloc_size_req & request, rpc_msg_get_alloc_size_rsp & response);
     bool get_device_memory(const rpc_msg_get_device_memory_req & request, rpc_msg_get_device_memory_rsp & response);
 
+    struct stored_graph {
+        ggml_context_ptr ctx_ptr;
+        ggml_cgraph * graph;
+    };
+
 private:
     bool get_cached_file(uint64_t hash, std::vector<uint8_t> & data);
     ggml_tensor * deserialize_tensor(struct ggml_context * ctx, const rpc_tensor * tensor);
@@ -953,6 +999,8 @@ private:
     std::vector<ggml_backend_t> backends;
     const char * cache_dir;
    std::unordered_set<ggml_backend_buffer_t> buffers;
+    // store the last computed graph for each backend
+    std::vector<stored_graph> stored_graphs;
 };
 
 void rpc_server::hello(rpc_msg_hello_rsp & response) {
@@ -1394,7 +1442,7 @@ ggml_tensor * rpc_server::create_node(uint64_t id,
     return result;
 }
 
-bool rpc_server::graph_compute(const std::vector<uint8_t> & input, rpc_msg_graph_compute_rsp & response) {
+bool rpc_server::graph_compute(const std::vector<uint8_t> & input) {
     // serialization format:
     // | device (4 bytes) | n_nodes (4 bytes) | nodes (n_nodes * sizeof(uint64_t) | n_tensors (4 bytes) | tensors (n_tensors * sizeof(rpc_tensor)) |
     if (input.size() < 2*sizeof(uint32_t)) {

@@ -1455,7 +1503,24 @@ bool rpc_server::graph_compute(const std::vector<uint8_t> & input, rpc_msg_graph
         }
     }
     ggml_status status = ggml_backend_graph_compute(backends[device], graph);
-    response.result = status;
+    GGML_ASSERT(status == GGML_STATUS_SUCCESS && "Unsuccessful graph computations are not supported with RPC");
+    stored_graphs[device].ctx_ptr.swap(ctx_ptr);
+    stored_graphs[device].graph = graph;
+    return true;
+}
+
+bool rpc_server::graph_recompute(const rpc_msg_graph_recompute_req & request) {
+    uint32_t device = request.device;
+    if (device >= backends.size()) {
+        return false;
+    }
+    if (stored_graphs[device].graph == nullptr) {
+        return false;
+    }
+    ggml_cgraph * graph = stored_graphs[device].graph;
+    LOG_DBG("[%s] device: %u\n", __func__, device);
+    ggml_status status = ggml_backend_graph_compute(backends[device], graph);
+    GGML_ASSERT(status == GGML_STATUS_SUCCESS && "Unsuccessful graph computations are not supported with RPC");
     return true;
 }
@@ -1690,11 +1755,17 @@ static void rpc_serve_client(const std::vector<ggml_backend_t> & backends, const
                 if (!recv_msg(sockfd, input)) {
                     return;
                 }
-                rpc_msg_graph_compute_rsp response;
-                if (!server.graph_compute(input, response)) {
-                    return;
-                }
-                if (!send_msg(sockfd, &response, sizeof(response))) {
+                if (!server.graph_compute(input)) {
+                    return;
+                }
+                break;
+            }
+            case RPC_CMD_GRAPH_RECOMPUTE: {
+                rpc_msg_graph_recompute_req request;
+                if (!recv_msg(sockfd, &request, sizeof(request))) {
+                    return;
+                }
+                if (!server.graph_recompute(request)) {
                     return;
                 }
                 break;
             }
@@ -617,4 +617,30 @@ static __dpct_inline__ float get_alibi_slope(const float max_bias,
     return dpct::pow(base, exph);
 }
 
+static const sycl::uint3 init_fastdiv_values(uint32_t d) {
+    GGML_ASSERT(d != 0);
+
+    uint32_t L = 0;
+    while (L < 32 && (uint32_t{ 1 } << L) < d) {
+        L++;
+    }
+
+    uint32_t mp = (uint32_t) ((uint64_t{ 1 } << 32) * ((uint64_t{ 1 } << L) - d) / d + 1);
+    return sycl::uint3(mp, L, d);
+}
+
+static __dpct_inline__ uint32_t fastdiv(uint32_t n, const sycl::uint3 fastdiv_values) {
+    const uint32_t hi = sycl::mul_hi<unsigned>(n, fastdiv_values.x());
+    return (hi + n) >> fastdiv_values.y();
+}
+
+static __dpct_inline__ sycl::uint2 fast_div_modulo(uint32_t n, const sycl::uint3 fastdiv_values) {
+    const uint32_t div_val = fastdiv(n, fastdiv_values);
+    const uint32_t mod_val = n - div_val * fastdiv_values.z();
+    return sycl::uint2(div_val, mod_val);
+}
+
 #endif // GGML_SYCL_COMMON_HPP
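A minimal host-side sanity check of the same magic-number division (Granlund-Montgomery style; plain C++, the helper name check_fastdiv is illustrative, not part of the patch). It assumes n < 2^31 so that hi + n cannot wrap in 32-bit arithmetic, which holds for the tensor indices this is used on.

#include <cassert>
#include <cstdint>

// Precompute (mp, L) for divisor d, then verify
// n / d == (mul_hi(n, mp) + n) >> L, mirroring fastdiv()/fast_div_modulo().
static void check_fastdiv(uint32_t d, uint32_t n) {
    uint32_t L = 0;
    while (L < 32 && (uint32_t{1} << L) < d) {
        L++;
    }
    const uint32_t mp = (uint32_t)(((uint64_t{1} << 32) * ((uint64_t{1} << L) - d)) / d + 1);

    const uint32_t hi  = (uint32_t)(((uint64_t)n * mp) >> 32);  // sycl::mul_hi
    const uint32_t div = (hi + n) >> L;
    const uint32_t mod = n - div * d;

    assert(div == n / d && mod == n % d);
}

int main() {
    check_fastdiv(3, 100);     // div = 33, mod = 1
    check_fastdiv(320, 12345); // div = 38, mod = 185
}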
@@ -1,72 +1,100 @@
 #include "pad_reflect_1d.hpp"
 
-void pad_reflect_1d_f32(const float* src,float* dst,
-    const int64_t ne0, const int64_t ne02, const int p0, const int p1,
-    const int64_t nb0, const int64_t nb1, const int64_t nb2, const int64_t nb3,
-    const int64_t nb00, const int64_t nb01, const int64_t nb02, const int64_t nb03,
-    const sycl::nd_item<3> &item_ct1){
-
-    const int i0 = item_ct1.get_group(0) * SYCL_CONCAT_BLOCK_SIZE + item_ct1.get_local_id(0);
-    const int i1 = item_ct1.get_group(1);
-    const int g2 = item_ct1.get_group(2);
-    const int i2 = g2 % ne02;
-    const int i3 = g2 / ne02;
-
-    if (i0 >= p0 + ne0 + p1) return;
-
-    int t = i0 - p0;
-    int period = 2 * ne0 -2;
-    int m = t % period;
-    m += (m < 0) * period;
-    int center = ne0 -1;
-    int srci0 = center - abs(center - m);
-
-    int offest_src = i3*nb3 + i2*nb2 + i1*nb1 + srci0*nb0;
-    int offest_dst = i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00;
-    dst[offest_dst] = src[offest_src];
-}
+static void pad_reflect_1d_kernel_f32(
+    const void *__restrict__ src0, void *__restrict__ dst, const int64_t ne0,
+    const int64_t ne00, const sycl::uint3 ne01, const int64_t ne02,
+    const int64_t ne03, const int64_t nb00, const int64_t nb01,
+    const int64_t nb02, const int64_t nb03, const int64_t nb0,
+    const int64_t nb1, const int64_t nb2, const int64_t nb3, const int p0,
+    const int p1, sycl::nd_item<3> item_ct1) {
+
+    const int64_t i3 = item_ct1.get_group(0);
+    const int64_t i2 = item_ct1.get_group(1);
+
+    const sycl::uint2 div_mod_packed =
+        fast_div_modulo(item_ct1.get_group(2), ne01);
+    const int64_t tile1 = div_mod_packed.y();
+    const int64_t tile0 = div_mod_packed.x();
+    const int64_t i1 = tile1;
+    const int64_t i0 =
+        item_ct1.get_local_id(2) + tile0 * item_ct1.get_local_range(2);
+
+    if (i0 >= ne0 || i1 >= ne01.z() || i2 >= ne02 || i3 >= ne03) {
+        return;
+    }
+
+    const char *src0_ptr =
+        (const char *)src0 + i3 * nb03 + i2 * nb02 + i1 * nb01;
+    char *dst_ptr = (char *)dst + i3 * nb3 + i2 * nb2 + i1 * nb1;
+
+    const int64_t rel_i0 = i0 - p0; // relative i0 in src0
+    int64_t src_idx;
+
+    if (rel_i0 < 0) {
+        // Left padding - reflect
+        src_idx = -rel_i0;
+    } else if (rel_i0 < ne00) {
+        // Middle - copy
+        src_idx = rel_i0;
+    } else {
+        // Right padding - reflect
+        src_idx = 2 * ne00 - 2 - rel_i0;
+    }
+    const float value = *(const float *)(src0_ptr + src_idx * nb00);
+    *(float *)(dst_ptr + i0 * nb0) = value;
+
+    GGML_UNUSED(p1);
+}
 
-void ggml_sycl_op_pad_reflect_1d(ggml_backend_sycl_context& ctx, ggml_tensor* dst){
+void ggml_sycl_op_pad_reflect_1d(ggml_backend_sycl_context &ctx,
+                                 ggml_tensor *dst) {
 
-    const ggml_tensor * src0 = dst->src[0];
-    queue_ptr stream = ctx.stream();
+    const ggml_tensor *src0 = dst->src[0];
+    dpct::queue_ptr stream = ctx.stream();
 
     GGML_ASSERT(src0->type == GGML_TYPE_F32);
-    GGML_ASSERT( dst->type == GGML_TYPE_F32);
+    GGML_ASSERT(dst->type == GGML_TYPE_F32);
 
-    const int32_t * opts = (const int32_t *) dst->op_params;
+    const int32_t *opts = (const int32_t *)dst->op_params;
     const int p0 = opts[0];
     const int p1 = opts[1];
 
-    const int64_t ne0 = src0->ne[0];
+    const int64_t ne00 = src0->ne[0];
+    const int64_t ne01 = src0->ne[1];
+    const sycl::uint3 ne01_packed = init_fastdiv_values(ne01);
+    const int64_t ne02 = src0->ne[2];
+    const int64_t ne03 = src0->ne[3];
 
-    const int64_t ne00 = dst->ne[0];
-    const int64_t ne01 = dst->ne[1];
-    const int64_t ne02 = dst->ne[2];
-    const int64_t ne03 = dst->ne[3];
+    const int64_t ne0 = dst->ne[0];
 
-    const int64_t nb00 = dst->nb[0];
-    const int64_t nb01 = dst->nb[1];
-    const int64_t nb02 = dst->nb[2];
-    const int64_t nb03 = dst->nb[3];
-    const int64_t nb0 = src0->nb[0];
-    const int64_t nb1 = src0->nb[1];
-    const int64_t nb2 = src0->nb[2];
-    const int64_t nb3 = src0->nb[3];
+    GGML_ASSERT(ne0 == ne00 + p0 + p1);
 
-    int num_blocks = (ne00 + SYCL_CONCAT_BLOCK_SIZE - 1) / SYCL_CONCAT_BLOCK_SIZE;
-    sycl::range<3> global(num_blocks * SYCL_CONCAT_BLOCK_SIZE, ne01, ne02*ne03);
-    sycl::range<3> local(SYCL_CONCAT_BLOCK_SIZE, 1, 1);
+    constexpr int64_t bx = SYCL_PAD_REFLECT_1D_BLOCK_SIZE;
+    const int64_t tiles0 = (ne0 + bx - 1) / bx;
+    const dpct::dim3 grid_dims((unsigned)(ne01 * tiles0), (unsigned)ne02,
+                               (unsigned)ne03);
+    const dpct::dim3 block_dims((unsigned)bx, 1, 1);
 
-    stream->parallel_for(
-        sycl::nd_range<3>(global,
-                          local),
-        [=](sycl::nd_item<3> item_ct1) { pad_reflect_1d_f32(
-            (const float *) src0->data, (float *) dst->data,
-            ne0, ne02, p0, p1,
-            nb0, nb1, nb2, nb3,
-            nb00, nb01, nb02, nb03
-            , item_ct1);
-        });
+    stream->submit([&](sycl::handler &cgh) {
+        auto src0_data_ct0 = src0->data;
+        auto dst_data_ct1 = dst->data;
+        auto src0_nb_ct7 = src0->nb[0];
+        auto src0_nb_ct8 = src0->nb[1];
+        auto src0_nb_ct9 = src0->nb[2];
+        auto src0_nb_ct10 = src0->nb[3];
+        auto dst_nb_ct11 = dst->nb[0];
+        auto dst_nb_ct12 = dst->nb[1];
+        auto dst_nb_ct13 = dst->nb[2];
+        auto dst_nb_ct14 = dst->nb[3];
+
+        cgh.parallel_for(sycl::nd_range<3>(grid_dims * block_dims, block_dims),
+                         [=](sycl::nd_item<3> item_ct1) {
+                             pad_reflect_1d_kernel_f32(
+                                 src0_data_ct0, dst_data_ct1, ne0, ne00,
+                                 ne01_packed, ne02, ne03, src0_nb_ct7,
+                                 src0_nb_ct8, src0_nb_ct9, src0_nb_ct10,
+                                 dst_nb_ct11, dst_nb_ct12, dst_nb_ct13,
+                                 dst_nb_ct14, p0, p1, item_ct1);
+                         });
+    });
 }
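The index mapping in the new kernel is easy to check on the host. A sketch (plain C++; reflect_src_idx is an illustrative helper): for ne00 = 5 and p0 = 2, output columns 0..8 read source columns 2, 1, 0, 1, 2, 3, 4, 3, 2 - reflection that does not repeat the edge sample.

#include <cstdint>

// Mirrors the branch in pad_reflect_1d_kernel_f32: map an output column i0
// to the source column it reads, given left padding p0 and source width ne00.
static int64_t reflect_src_idx(int64_t i0, int64_t p0, int64_t ne00) {
    const int64_t rel_i0 = i0 - p0;
    if (rel_i0 < 0)    return -rel_i0;            // left padding - reflect
    if (rel_i0 < ne00) return rel_i0;             // middle - copy
    return 2 * ne00 - 2 - rel_i0;                 // right padding - reflect
}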
@@ -3,6 +3,8 @@
 
 #include "common.hpp"
 
+#define SYCL_PAD_REFLECT_1D_BLOCK_SIZE 256
+
 void ggml_sycl_op_pad_reflect_1d(ggml_backend_sycl_context& ctx, ggml_tensor* dst);
 
 #endif // GGML_SYCL_PAD_REFLECT_1D_HPP
@@ -399,6 +399,18 @@ struct vk_conv2d_pipeline_state {
     }
 };
 
+struct vk_solve_tri_pipeline_state {
+    vk_solve_tri_pipeline_state(uint32_t N, uint32_t K)
+        : N(N), K(K) {}
+
+    uint32_t N, K;
+
+    bool operator<(const vk_solve_tri_pipeline_state &b) const {
+        return std::tie(N, K) <
+               std::tie(b.N, b.K);
+    }
+};
+
 enum shader_reduction_mode {
     SHADER_REDUCTION_MODE_SHMEM,
     SHADER_REDUCTION_MODE_HYBRID,
@@ -409,6 +421,7 @@ enum shader_reduction_mode {
 // argsort pipelines for up to 1<<10 invocations per workgroup
 static constexpr uint32_t num_argsort_pipelines = 11;
 static constexpr uint32_t num_topk_moe_pipelines = 10;
+static constexpr uint32_t num_topk_pipelines = 11;
 
 static constexpr std::initializer_list<ggml_op> topk_moe_early_softmax_norm{ GGML_OP_SOFT_MAX, GGML_OP_RESHAPE, GGML_OP_ARGSORT,
                                                                              GGML_OP_VIEW, GGML_OP_GET_ROWS, GGML_OP_RESHAPE,
@@ -515,6 +528,7 @@ struct vk_device_struct {
     bool single_queue;
     bool support_async;
     uint32_t subgroup_size;
+    uint32_t subgroup_size_log2;
     uint32_t shader_core_count;
     bool uma;
     bool prefer_host_memory;
@@ -635,6 +649,7 @@ struct vk_device_struct {
     vk_pipeline pipeline_sin_f32;
     vk_pipeline pipeline_cos_f32;
     vk_pipeline pipeline_log[2];
+    vk_pipeline pipeline_tri[2];
     vk_pipeline pipeline_clamp_f32;
     vk_pipeline pipeline_pad_f32;
     vk_pipeline pipeline_roll_f32;
@@ -704,10 +719,12 @@ struct vk_device_struct {
     vk_pipeline pipeline_rope_vision_f32, pipeline_rope_vision_f16;
     vk_pipeline pipeline_argsort_f32[num_argsort_pipelines];
     vk_pipeline pipeline_argsort_large_f32[num_argsort_pipelines];
+    vk_pipeline pipeline_topk_f32[num_topk_pipelines];
     vk_pipeline pipeline_sum_rows_f32;
     vk_pipeline pipeline_cumsum_f32;
     vk_pipeline pipeline_argmax_f32;
     vk_pipeline pipeline_count_equal_i32;
+    std::map<vk_solve_tri_pipeline_state, vk_pipeline> pipeline_solve_tri_f32;
     vk_pipeline pipeline_im2col_f32, pipeline_im2col_f32_f16;
     vk_pipeline pipeline_im2col_3d_f32, pipeline_im2col_3d_f32_f16;
     vk_pipeline pipeline_timestep_embedding_f32;
@@ -1205,6 +1222,15 @@ struct vk_op_argsort_push_constants {
     uint32_t inner_end;
 };
 
+struct vk_op_topk_push_constants {
+    uint32_t orig_ncols;
+    uint32_t ncols_input;
+    uint32_t ncols_output;
+    uint32_t nrows;
+    uint32_t first_pass;
+    uint32_t last_pass;
+};
+
 struct vk_op_im2col_push_constants {
     uint64_t dst_addr;
     uint32_t batch_offset; uint32_t offset_delta;
@@ -3851,6 +3877,9 @@ static void ggml_vk_load_shaders(vk_device& device) {
         ggml_vk_create_pipeline(device, device->pipeline_log[1], "log_f16", log_f16_len, log_f16_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
     }
 
+    ggml_vk_create_pipeline(device, device->pipeline_tri[0], "tri_f32", tri_f32_len, tri_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
+    ggml_vk_create_pipeline(device, device->pipeline_tri[1], "tri_f16", tri_f16_len, tri_f16_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
+
     ggml_vk_create_pipeline(device, device->pipeline_clamp_f32, "clamp_f32", clamp_f32_len, clamp_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
 
     ggml_vk_create_pipeline(device, device->pipeline_pad_f32, "pad_f32", pad_f32_len, pad_f32_data, "main", 2, sizeof(vk_op_pad_push_constants), {512, 1, 1}, {}, 1);
@@ -3965,6 +3994,23 @@ static void ggml_vk_load_shaders(vk_device& device) {
         ggml_vk_create_pipeline2(device, device->pipeline_argsort_large_f32[i], "argsort_large_f32_"+std::to_string(i), argsort_large_f32_len, argsort_large_f32_data, "main", 3, sizeof(vk_op_argsort_push_constants), {BLOCK_SIZE * WG_UNROLL_FACTOR, 1, 1}, {BLOCK_SIZE, WG_UNROLL_FACTOR}, 1, true);
     }
 
+    for (uint32_t i = 0; i < num_topk_pipelines; ++i) {
+        const uint32_t BLOCK_SIZE = 1u << i;
+        const uint32_t NCOLS_PADDED_LOG2 = i;
+        if (i <= device->max_workgroup_size_log2) {
+            uint32_t nary_shmem = 2 * sizeof(int) * BLOCK_SIZE +
+                                  sizeof(int) * device->subgroup_size +
+                                  2 * sizeof(int) +
+                                  (BLOCK_SIZE / device->subgroup_size) * sizeof(int);
+            if (device->subgroup_arithmetic && device->subgroup_require_full_support && device->subgroup_shuffle && device->subgroup_ballot &&
+                nary_shmem <= device->properties.limits.maxComputeSharedMemorySize) {
+                ggml_vk_create_pipeline2(device, device->pipeline_topk_f32[i], "topk_f32_"+std::to_string(i), topk_nary_search_f32_len, topk_nary_search_f32_data, "main", 2, sizeof(vk_op_topk_push_constants), {BLOCK_SIZE, 1, 1}, {BLOCK_SIZE, device->subgroup_size, device->subgroup_size_log2}, 1, true, true, device->subgroup_size);
+            } else if (2 * sizeof(int) * BLOCK_SIZE <= device->properties.limits.maxComputeSharedMemorySize) {
+                ggml_vk_create_pipeline2(device, device->pipeline_topk_f32[i], "topk_f32_"+std::to_string(i), topk_argsort_f32_len, topk_argsort_f32_data, "main", 2, sizeof(vk_op_topk_push_constants), {BLOCK_SIZE, 1, 1}, {BLOCK_SIZE, NCOLS_PADDED_LOG2}, 1, true);
+            }
+        }
+    }
+
     ggml_vk_create_pipeline(device, device->pipeline_argmax_f32, "argmax_f32", argmax_f32_len, argmax_f32_data, "main", 2, sizeof(vk_op_push_constants), {1, 1, 1}, { device->subgroup_size }, 1);
 
     ggml_vk_create_pipeline(device, device->pipeline_sum_rows_f32, "sum_rows_f32", sum_rows_f32_len, sum_rows_f32_data, "main", 2, sizeof(vk_op_sum_rows_push_constants), {1, 1, 1}, { device->subgroup_size }, 1);
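As a concrete instance of the shared-memory sizing above (a worked example, assuming sizeof(int) == 4, the largest workgroup of 1024, and a subgroup size of 32):

#include <cstdint>

// Worked instance of the nary_shmem formula from ggml_vk_load_shaders.
constexpr uint32_t BLOCK_SIZE    = 1024;  // 1u << 10
constexpr uint32_t subgroup_size = 32;

constexpr uint32_t nary_shmem =
    2 * 4 * BLOCK_SIZE +                  // two int arrays of BLOCK_SIZE
    4 * subgroup_size +                   // one int per subgroup lane
    2 * 4 +                               // two scalar ints
    (BLOCK_SIZE / subgroup_size) * 4;     // one int per subgroup

static_assert(nary_shmem == 8456, "about 8 KiB, under the 16 KiB Vulkan minimum");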
@@ -3973,6 +4019,14 @@ static void ggml_vk_load_shaders(vk_device& device) {
 
     ggml_vk_create_pipeline(device, device->pipeline_count_equal_i32, "count_equal_i32", count_equal_i32_len, count_equal_i32_data, "main", 3, sizeof(vk_op_push_constants), {512, 1, 1}, { device->subgroup_size }, 1);
 
+    for (auto &s : device->pipeline_solve_tri_f32) {
+        const vk_solve_tri_pipeline_state &state = s.first;
+        ggml_vk_create_pipeline(
+            device, s.second, "solve_tri_f32",
+            solve_tri_f32_len, solve_tri_f32_data, "main", 3,
+            sizeof(vk_op_binary_push_constants), {1, 1, 1}, { 0, state.N, state.K }, 1, true);
+    }
+
 #define IM2COL(bda) \
     ggml_vk_create_pipeline(device, device->pipeline_im2col_f32, "im2col_f32", im2col_f32 ## bda ## _len, im2col_f32 ## bda ## _data, "main", 2, sizeof(vk_op_im2col_push_constants), {512, 1, 1}, { device->subgroup_size }, 1, true); \
     ggml_vk_create_pipeline(device, device->pipeline_im2col_3d_f32, "im2col_3d_f32", im2col_3d_f32 ## bda ## _len, im2col_3d_f32 ## bda ## _data, "main", 2, sizeof(vk_op_im2col_3d_push_constants), {512, 1, 1}, { 512 }, 1, true); \
@@ -4336,6 +4390,7 @@ static vk_device ggml_vk_get_device(size_t idx) {
     device->suballocation_block_size = std::min(device->suballocation_block_size, device->max_memory_allocation_size);
 
     device->subgroup_size = subgroup_props.subgroupSize;
+    device->subgroup_size_log2 = uint32_t(log2f(float(device->subgroup_size)));
     device->uma = device->properties.deviceType == vk::PhysicalDeviceType::eIntegratedGpu;
     if (sm_builtins) {
         device->shader_core_count = sm_props.shaderSMCount;
@@ -5259,7 +5314,8 @@ static void ggml_vk_init(ggml_backend_vk_context * ctx, size_t idx) {
     ctx->prealloc_size_x = 0;
     ctx->prealloc_size_y = 0;
     ctx->prealloc_size_split_k = 0;
-    ctx->prealloc_size_add_rms_partials = 0;
+    // Fixed size of 1KB, for deterministic behavior
+    ctx->prealloc_size_add_rms_partials = 1024;
 
     ctx->fence = ctx->device->device.createFence({});
     ctx->almost_ready_fence = ctx->device->device.createFence({});
@@ -8238,6 +8294,12 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
             return ctx->device->pipeline_log[dst->type == GGML_TYPE_F16];
         }
         return nullptr;
+    case GGML_OP_TRI:
+        if (src0->type == dst->type &&
+            (src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16)) {
+            return ctx->device->pipeline_tri[dst->type == GGML_TYPE_F16];
+        }
+        return nullptr;
     case GGML_OP_CLAMP:
         if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
             return ctx->device->pipeline_clamp_f32;
@@ -8465,6 +8527,26 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
             return ctx->device->pipeline_cumsum_f32;
         }
         return nullptr;
+    case GGML_OP_SOLVE_TRI:
+        if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
+
+            vk_solve_tri_pipeline_state solve_tri_pipeline_state(src0->ne[0], src1->ne[0]);
+
+            vk_pipeline pipeline = nullptr;
+
+            {
+                std::lock_guard<std::recursive_mutex> guard(ctx->device->mutex);
+                auto it = ctx->device->pipeline_solve_tri_f32.find(solve_tri_pipeline_state);
+                if (it != ctx->device->pipeline_solve_tri_f32.end()) {
+                    pipeline = it->second;
+                } else {
+                    ctx->device->pipeline_solve_tri_f32[solve_tri_pipeline_state] = pipeline = std::make_shared<vk_pipeline_struct>();
+                }
+            }
+
+            return pipeline;
+        }
+        return nullptr;
     case GGML_OP_ARGMAX:
         if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_I32) {
             return ctx->device->pipeline_argmax_f32;
@@ -8656,41 +8738,6 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
     GGML_UNUSED(src2);
 }
 
-static bool ggml_vk_op_supports_incontiguous(ggml_op op) {
-    switch (op) {
-    case GGML_OP_CPY:
-    case GGML_OP_GET_ROWS:
-    case GGML_OP_ADD:
-    case GGML_OP_SUB:
-    case GGML_OP_MUL:
-    case GGML_OP_DIV:
-    case GGML_OP_ADD_ID:
-    case GGML_OP_CONCAT:
-    case GGML_OP_UPSCALE:
-    case GGML_OP_SQR:
-    case GGML_OP_SQRT:
-    case GGML_OP_SIN:
-    case GGML_OP_COS:
-    case GGML_OP_LOG:
-    case GGML_OP_CLAMP:
-    case GGML_OP_PAD:
-    case GGML_OP_REPEAT:
-    case GGML_OP_REPEAT_BACK:
-    case GGML_OP_ROPE:
-    case GGML_OP_RMS_NORM:
-    case GGML_OP_CONV_2D_DW:
-    case GGML_OP_IM2COL:
-    case GGML_OP_IM2COL_3D:
-    case GGML_OP_SET_ROWS:
-    case GGML_OP_SUM:
-    case GGML_OP_SUM_ROWS:
-    case GGML_OP_MEAN:
-        return true;
-    default:
-        return false;
-    }
-}
-
 template <> void init_pushconst_tensor_offsets(ggml_backend_vk_context * ctx, vk_op_unary_push_constants &p, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * src2, const ggml_tensor * src3, ggml_tensor * dst) {
     const uint32_t a_offset = get_misalign_bytes(ctx, src0) / ggml_type_size(src0->type);
     const uint32_t d_offset = get_misalign_bytes(ctx, dst) / ggml_type_size(dst->type);
@@ -8775,7 +8822,6 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co
     std::cerr << "), (" << dst << ", name=" << dst->name << ", type=" << dst->type << ", ne0=" << dst->ne[0] << ", ne1=" << dst->ne[1] << ", ne2=" << dst->ne[2] << ", ne3=" << dst->ne[3] << ", nb0=" << dst->nb[0] << ", nb1=" << dst->nb[1] << ", nb2=" << dst->nb[2] << ", nb3=" << dst->nb[3];
     std::cerr << "), " << ggml_op_name(op) << ")");
     GGML_ASSERT(op == GGML_OP_GET_ROWS || op == GGML_OP_CPY || (!ggml_is_quantized(src0->type) && (src1 == nullptr || !ggml_is_quantized(src1->type)))); // NOLINT
-    GGML_ASSERT(ggml_vk_op_supports_incontiguous(op) || ggml_vk_dim01_contiguous(src0)); // NOLINT
     GGML_ASSERT(dst->buffer != nullptr);
     const uint64_t ne00 = src0->ne[0];
     const uint64_t ne01 = src0->ne[1];

@@ -8806,22 +8852,17 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co
 
     ggml_pipeline_request_descriptor_sets(ctx, pipeline, 1);
 
-    const bool op_supports_incontiguous = ggml_vk_op_supports_incontiguous(op);
-
-    vk_subbuffer src0_buf = ggml_vk_tensor_subbuffer(ctx, src0, op_supports_incontiguous);
-    vk_subbuffer src1_buf = use_src1 ? ggml_vk_tensor_subbuffer(ctx, src1, op_supports_incontiguous) : vk_subbuffer{};
-    vk_subbuffer src2_buf = use_src2 ? ggml_vk_tensor_subbuffer(ctx, src2, op_supports_incontiguous) : vk_subbuffer{};
-    vk_subbuffer src3_buf = use_src3 ? ggml_vk_tensor_subbuffer(ctx, src3, op_supports_incontiguous) : vk_subbuffer{};
-    vk_subbuffer dst_buf = ggml_vk_tensor_subbuffer(ctx, dst, op_supports_incontiguous);
+    vk_subbuffer src0_buf = ggml_vk_tensor_subbuffer(ctx, src0, true);
+    vk_subbuffer src1_buf = use_src1 ? ggml_vk_tensor_subbuffer(ctx, src1, true) : vk_subbuffer{};
+    vk_subbuffer src2_buf = use_src2 ? ggml_vk_tensor_subbuffer(ctx, src2, true) : vk_subbuffer{};
+    vk_subbuffer src3_buf = use_src3 ? ggml_vk_tensor_subbuffer(ctx, src3, true) : vk_subbuffer{};
+    vk_subbuffer dst_buf = ggml_vk_tensor_subbuffer(ctx, dst, true);
 
     // Compute misalignment offset for descriptors and store it in in push constants.
     init_pushconst_tensor_offsets(ctx, pc, src0, src1, src2, src3, dst);
 
     std::array<uint32_t, 3> elements;
 
-    // Single call if dimension 2 is contiguous
-    GGML_ASSERT(op_supports_incontiguous || (ggml_is_contiguous(src0) && (src1 == nullptr || ggml_is_contiguous(src1))));
-
     switch (op) {
     case GGML_OP_NORM:
     case GGML_OP_RMS_NORM_BACK:
@@ -8842,6 +8883,18 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co
             elements = { nr, 1, 1 };
         }
     } break;
+    case GGML_OP_SOLVE_TRI:
+        {
+            uint32_t nr = (uint32_t)(ne02 * ne03);
+            if (nr > 262144) {
+                elements = { 512, 512, CEIL_DIV(nr, 262144) };
+            } else if (nr > 512) {
+                elements = { 512, CEIL_DIV(nr, 512), 1 };
+            } else {
+                elements = { nr, 1, 1 };
+            }
+        }
+        break;
     case GGML_OP_RMS_NORM:
         if (ctx->do_add_rms_partials) {
             // Run one element per thread, 128 threads per workgroup
@@ -8948,6 +9001,7 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co
     case GGML_OP_SIN:
     case GGML_OP_COS:
     case GGML_OP_LOG:
+    case GGML_OP_TRI:
     case GGML_OP_CLAMP:
     case GGML_OP_PAD:
     case GGML_OP_ROLL:
@@ -9628,6 +9682,13 @@ static void ggml_vk_log(ggml_backend_vk_context * ctx, vk_context& subctx, const
     ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, nullptr, dst, GGML_OP_LOG, vk_op_unary_push_constants_init(src0, dst));
 }
 
+static void ggml_vk_tri(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst) {
+    vk_op_unary_push_constants p = vk_op_unary_push_constants_init(src0, dst);
+    p.param1 = ggml_get_op_params_f32(dst, 0);
+
+    ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, nullptr, dst, GGML_OP_TRI, std::move(p));
+}
+
 static void ggml_vk_clamp(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst) {
     vk_op_unary_push_constants p = vk_op_unary_push_constants_init(src0, dst);
     p.param1 = ggml_get_op_params_f32(dst, 0);
@ -10143,6 +10204,104 @@ static void ggml_vk_argsort(ggml_backend_vk_context * ctx, vk_context& subctx, c
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
static void ggml_vk_topk(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst) {
|
||||||
|
uint32_t ncols = src0->ne[0];
|
||||||
|
uint32_t nrows = ggml_nrows(src0);
|
||||||
|
uint32_t k = dst->ne[0];
|
||||||
|
|
||||||
|
    vk_op_topk_push_constants pc { ncols, ncols, k, nrows, 0, 0 };

    // Reserve space for ivec2 per element, double buffered
    const size_t dbl_buf_size = size_t{ncols} * nrows * 2 * sizeof(int);
    const size_t x_sz = dbl_buf_size * 2;
    uint32_t dbl_buf_index = 0;

    if (ctx->prealloc_size_x < x_sz) {
        ctx->prealloc_size_x = x_sz;
        ggml_vk_preallocate_buffers(ctx, subctx);
    }
    if (ctx->prealloc_x_need_sync) {
        ggml_vk_sync_buffers(ctx, subctx);
    }

    std::array<uint32_t, 3> elements;
    elements[1] = std::min(nrows, ctx->device->properties.limits.maxComputeWorkGroupCount[1]);
    elements[2] = 1;

    uint32_t num_elements = ncols;

    // Each iteration reduces a workgroup's worth of elements down to the K
    // largest elements. Repeat until we have the top K elements.
    // Need to do at least one iteration to write out the results.
    bool done_one_iter = false;
    while (num_elements > k || !done_one_iter) {
        done_one_iter = true;

        // Prefer going as small as num_topk_pipelines - 3 for perf reasons.
        // But if K is larger, then we need a larger workgroup
        uint32_t max_pipeline = num_topk_pipelines - 3;
        uint32_t min_pipeline = (uint32_t)log2f(float(k)) + 1;
        // require full subgroup
        min_pipeline = std::max(min_pipeline, ctx->device->subgroup_size_log2);

        uint32_t pipeline_idx = (uint32_t)ceilf(log2f(float(num_elements)));
        pipeline_idx = std::min(pipeline_idx, max_pipeline);
        pipeline_idx = std::max(pipeline_idx, min_pipeline);

        if (num_elements > (1u << pipeline_idx)) {
            // If we could finish on this loop iteration (i.e. a single workgroup)
            // then do so. It's better than the overhead of another pass.
            for (uint32_t i = pipeline_idx; i < num_topk_pipelines; ++i) {
                if (num_elements <= (1u << i)) {
                    pipeline_idx = i;
                    break;
                }
            }
        }

        vk_pipeline pipeline = ctx->device->pipeline_topk_f32[pipeline_idx];
        // If the device doesn't support a pipeline this large, use smaller
        while (!pipeline) {
            pipeline_idx--;
            GGML_ASSERT(pipeline_idx >= min_pipeline);
            pipeline = ctx->device->pipeline_topk_f32[pipeline_idx];
        }

        vk_op_topk_push_constants pc2 = pc;
        pc2.ncols_input = num_elements;

        // Number of elements remaining after this pass
        uint32_t num_dst_elements = (num_elements / pipeline->wg_denoms[0]) * k + std::min(k, num_elements % pipeline->wg_denoms[0]);

        vk_subbuffer src_buf;
        vk_subbuffer dst_buf;

        if (num_elements == ncols) {
            pc2.first_pass = 1;
            src_buf = ggml_vk_tensor_subbuffer(ctx, src0);
        } else {
            src_buf = { ctx->prealloc_x, dbl_buf_index * dbl_buf_size, dbl_buf_size };
        }
        if (num_dst_elements == k) {
            pc2.last_pass = 1;
            dst_buf = ggml_vk_tensor_subbuffer(ctx, dst);
        } else {
            dst_buf = { ctx->prealloc_x, (dbl_buf_index ^ 1) * dbl_buf_size, dbl_buf_size };
        }

        elements[0] = num_elements;

        ggml_pipeline_request_descriptor_sets(ctx, pipeline, 1);
        ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { src_buf, dst_buf }, pc2, elements);
        num_elements = num_dst_elements;
        dbl_buf_index ^= 1;
        if (num_elements > k) {
            ggml_vk_sync_buffers(ctx, subctx);
        }
    }
    ctx->prealloc_x_need_sync = true;
}
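
Not part of the patch, but as a rough illustration of how quickly the multi-pass reduction above converges, here is a standalone sketch (example sizes assumed; `wg` stands in for pipeline->wg_denoms[0]) that applies the same num_dst_elements formula as the loop:

#include <algorithm>
#include <cstdint>
#include <cstdio>

// Minimal sketch: each pass keeps the top k of every workgroup-sized chunk,
// mirroring num_dst_elements = (n / wg) * k + min(k, n % wg) from the loop above.
int main() {
    uint32_t n = 1u << 20, k = 64, wg = 1024; // assumed example sizes
    int pass = 0;
    while (n > k) {
        n = (n / wg) * k + std::min(k, n % wg);
        printf("after pass %d: %u elements\n", ++pass, n); // converges in 4 passes here
    }
    return 0;
}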

static void ggml_vk_sum(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst) {
    vk_op_sum_rows_push_constants p = vk_op_sum_rows_push_constants_init(src0, dst, ggml_nelements(src0));
    ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, nullptr, dst, GGML_OP_SUM, p);
@@ -10172,6 +10331,21 @@ static void ggml_vk_count_equal(ggml_backend_vk_context * ctx, vk_context& subct
    ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, src1, nullptr, nullptr, dst, GGML_OP_COUNT_EQUAL, { (uint32_t)ggml_nelements(src0), 0, 0.0f, 0.0f });
}

static void ggml_vk_solve_tri(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
    const uint32_t src0_type_size = ggml_type_size(src0->type);
    const uint32_t src1_type_size = ggml_type_size(src1->type);
    const uint32_t dst_type_size = ggml_type_size(dst->type);

    ggml_vk_op_f32<vk_op_binary_push_constants>(ctx, subctx, src0, src1, nullptr, nullptr, dst, GGML_OP_SOLVE_TRI, {
        (uint32_t)ggml_nelements(src0),
        (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], (uint32_t)src0->ne[2], (uint32_t)src0->ne[3], (uint32_t)src0->nb[0] / src0_type_size, (uint32_t)src0->nb[1] / src0_type_size, (uint32_t)src0->nb[2] / src0_type_size, (uint32_t)src0->nb[3] / src0_type_size,
        (uint32_t)src1->ne[0], (uint32_t)src1->ne[1], (uint32_t)src1->ne[2], (uint32_t)src1->ne[3], (uint32_t)src1->nb[0] / src1_type_size, (uint32_t)src1->nb[1] / src1_type_size, (uint32_t)src1->nb[2] / src1_type_size, (uint32_t)src1->nb[3] / src1_type_size,
        (uint32_t) dst->ne[0], (uint32_t) dst->ne[1], (uint32_t) dst->ne[2], (uint32_t) dst->ne[3], (uint32_t) dst->nb[0] / dst_type_size, (uint32_t) dst->nb[1] / dst_type_size, (uint32_t) dst->nb[2] / dst_type_size, (uint32_t) dst->nb[3] / dst_type_size,
        0,
        0.0f, 0.0f, 0,
    });
}

static void ggml_vk_im2col(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
    const int32_t s0 = dst->op_params[0];
    const int32_t s1 = dst->op_params[1];
@@ -11638,6 +11812,10 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr
        case GGML_OP_LOG:
            ggml_vk_log(ctx, compute_ctx, src0, node);

            break;
        case GGML_OP_TRI:
            ggml_vk_tri(ctx, compute_ctx, src0, node);

            break;
        case GGML_OP_CLAMP:
            ggml_vk_clamp(ctx, compute_ctx, src0, node);
@@ -11755,6 +11933,10 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr
                ggml_vk_argsort(ctx, compute_ctx, src0, node);
            }

            break;
        case GGML_OP_TOP_K:
            ggml_vk_topk(ctx, compute_ctx, src0, node);

            break;
        case GGML_OP_SUM:
            ggml_vk_sum(ctx, compute_ctx, src0, node);
@@ -11779,6 +11961,10 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr
        case GGML_OP_COUNT_EQUAL:
            ggml_vk_count_equal(ctx, compute_ctx, src0, src1, node);

            break;
        case GGML_OP_SOLVE_TRI:
            ggml_vk_solve_tri(ctx, compute_ctx, src0, src1, node);

            break;
        case GGML_OP_IM2COL:
            ggml_vk_im2col(ctx, compute_ctx, src0, src1, node);
@@ -12963,7 +13149,6 @@ static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cg
        ctx->fused_ops_write_mask = 0;
    }

    ctx->prealloc_size_add_rms_partials = std::max(ctx->prealloc_size_add_rms_partials, ctx->prealloc_size_add_rms_partials_offset);
    ctx->last_total_mul_mat_bytes = total_mul_mat_bytes;

    if (vk_perf_logger_enabled) {
@@ -13026,24 +13211,6 @@ static void ggml_vk_graph_optimize(ggml_backend_t backend, struct ggml_cgraph *
        return false;
    };

    // This function tries to reorder the graph to allow nodes to run in parallel.
    // This helps with small batches, but for large batches it's a slowdown, probably
    // due to cache contention. So only reorder if the majority of nodes have few rows.
    int num_small_nodes = 0;
    int num_counted_nodes = 0;
    for (int i = 0; i < graph->n_nodes; ++i) {
        if (!is_empty(graph->nodes[i]) &&
            graph->nodes[i]->op != GGML_OP_SET_ROWS) {
            if (ggml_nrows(graph->nodes[i]) <= 8) {
                num_small_nodes++;
            }
            num_counted_nodes++;
        }
    }
    if (num_small_nodes < num_counted_nodes / 2) {
        return;
    }

    std::vector<ggml_tensor *> new_order;
    std::vector<bool> used(graph->n_nodes, false);
    std::set<ggml_tensor *> used_node_set;
@@ -13762,17 +13929,21 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
               op->type == GGML_TYPE_F32;
        case GGML_OP_SILU_BACK:
        case GGML_OP_RMS_NORM_BACK:
            return ggml_is_contiguous(op->src[0]) && op->src[0]->type == GGML_TYPE_F32;
        case GGML_OP_SQR:
        case GGML_OP_SQRT:
        case GGML_OP_SIN:
        case GGML_OP_COS:
        case GGML_OP_CLAMP:
            return op->src[0]->type == GGML_TYPE_F32;
        case GGML_OP_LEAKY_RELU:
        case GGML_OP_OPT_STEP_ADAMW:
        case GGML_OP_OPT_STEP_SGD:
            return op->src[0]->type == GGML_TYPE_F32;
            return ggml_is_contiguous(op->src[0]) && op->src[0]->type == GGML_TYPE_F32;
        case GGML_OP_LOG:
            return op->src[0]->type == GGML_TYPE_F32 || op->src[0]->type == GGML_TYPE_F16;
        case GGML_OP_TRI:
            return (op->src[0]->type == GGML_TYPE_F32 || op->src[0]->type == GGML_TYPE_F16) &&
                   op->type == op->src[0]->type;
        case GGML_OP_ARGSORT:
        {
            if (!ggml_is_contiguous(op) || !ggml_is_contiguous(op->src[0])) {
@@ -13787,19 +13958,47 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
                return op->ne[0] <= (1 << device->max_workgroup_size_log2);
            }
        }
        case GGML_OP_TOP_K:
        {
            if (!ggml_is_contiguous(op) || !ggml_is_contiguous(op->src[0])) {
                return false;
            }
            ggml_backend_vk_device_context * ctx = (ggml_backend_vk_device_context *)dev->context;
            auto device = ggml_vk_get_device(ctx->device);
            // We could potentially support larger, using argsort to sort the
            // whole thing. Not clear if this is needed.
            uint32_t min_pipeline = (uint32_t)log2f(float(op->ne[0])) + 1;
            if (min_pipeline >= num_topk_pipelines ||
                !device->pipeline_topk_f32[min_pipeline]) {
                return false;
            }
        }
        return true;
        case GGML_OP_UPSCALE:
        case GGML_OP_ACC:
            return op->src[0]->type == GGML_TYPE_F32;
        case GGML_OP_CONCAT:
            return ggml_type_size(op->src[0]->type) == ggml_type_size(GGML_TYPE_F32);
        case GGML_OP_ADD1:
            return (op->src[0]->type == GGML_TYPE_F32 && op->src[1]->type == GGML_TYPE_F32)
                || (op->src[0]->type == GGML_TYPE_F16 && op->src[1]->type == GGML_TYPE_F32)
                || (op->src[0]->type == GGML_TYPE_F16 && op->src[1]->type == GGML_TYPE_F16);
        case GGML_OP_ARANGE:
        case GGML_OP_FILL:
            return op->type == GGML_TYPE_F32;
        case GGML_OP_SCALE:
            return ggml_is_contiguous(op->src[0]) && op->src[0]->type == GGML_TYPE_F32;
        case GGML_OP_PAD:
        case GGML_OP_ROLL:
            return op->src[0]->type == GGML_TYPE_F32;
        case GGML_OP_DIAG_MASK_INF:
            return ggml_is_contiguous(op->src[0]) && op->src[0]->type == GGML_TYPE_F32;
        case GGML_OP_SOFT_MAX:
            return ggml_is_contiguous(op->src[0]) && op->src[0]->type == GGML_TYPE_F32
                && (!op->src[1] || (op->src[1]->type == GGML_TYPE_F32 || op->src[1]->type == GGML_TYPE_F16));
        case GGML_OP_SOFT_MAX_BACK:
            return true;
            return ggml_is_contiguous(op->src[0]) && op->src[0]->type == GGML_TYPE_F32
                && ggml_is_contiguous(op->src[1]) && op->src[1]->type == GGML_TYPE_F32;
        case GGML_OP_SUM:
        case GGML_OP_SUM_ROWS:
        case GGML_OP_MEAN:
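
As a worked example of the min_pipeline bound in the GGML_OP_TOP_K check above (not part of the patch; op->ne[0], the number of top-k outputs, assumed to be 100): floor(log2(100)) + 1 = 7, so the smallest usable workgroup is 2^7 = 128, the first power of two that can hold all 100 outputs.

#include <cmath>
#include <cstdio>

// Illustrative sketch of the min_pipeline bound: the smallest power-of-two
// workgroup that can hold all k outputs of one reduction pass.
int main() {
    const float k = 100.0f;                               // assumed example value of op->ne[0]
    const unsigned min_pipeline = (unsigned)log2f(k) + 1; // 6 + 1 = 7
    printf("min_pipeline = %u -> workgroup size %u\n", min_pipeline, 1u << min_pipeline);
    return 0;
}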
@@ -13813,16 +14012,47 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
            }
            return false;
        }
        case GGML_OP_SOLVE_TRI:
        {
            ggml_backend_vk_device_context * ctx = (ggml_backend_vk_device_context *)dev->context;
            const vk_device& device = ggml_vk_get_device(ctx->device);

            if (op->type != GGML_TYPE_F32 || op->src[0]->type != GGML_TYPE_F32) {
                return false;
            }
            const uint32_t N = op->src[0]->ne[0];
            const uint32_t K = op->src[1]->ne[0];
            // K dimension limited to workgroup size
            if (K > 128) {
                return false;
            }
            if (N * N * sizeof(float) + N * K * sizeof(float) > device->properties.limits.maxComputeSharedMemorySize) {
                return false;
            }
            return true;
        }
        case GGML_OP_ARGMAX:
            return ggml_is_contiguous(op->src[0]) && op->src[0]->type == GGML_TYPE_F32;
        case GGML_OP_COUNT_EQUAL:
            return ggml_is_contiguous(op->src[0]) && op->src[0]->type == GGML_TYPE_I32
                && ggml_is_contiguous(op->src[1]) && op->src[1]->type == GGML_TYPE_I32;
        case GGML_OP_IM2COL:
            return ggml_is_contiguous(op->src[1])
                && op->src[1]->type == GGML_TYPE_F32
                && (op->type == GGML_TYPE_F32 || op->type == GGML_TYPE_F16);
        case GGML_OP_IM2COL_3D:
            return op->src[1]->type == GGML_TYPE_F32
                && (op->type == GGML_TYPE_F32 || op->type == GGML_TYPE_F16);
        case GGML_OP_TIMESTEP_EMBEDDING:
            return op->src[0]->type == GGML_TYPE_F32;
        case GGML_OP_CONV_2D_DW:
            return (op->src[0]->type == GGML_TYPE_F32 || op->src[0]->type == GGML_TYPE_F16)
                && op->src[1]->type == GGML_TYPE_F32;
        case GGML_OP_POOL_2D:
            return ggml_is_contiguous(op->src[0]) && op->src[0]->type == GGML_TYPE_F32;
        case GGML_OP_RWKV_WKV6:
        case GGML_OP_RWKV_WKV7:
            return true;
            return true; // all inputs are contiguous, see ggml.c
        case GGML_OP_SSM_SCAN:
        {
            for (int i = 0; i < 6; i++) {
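
For intuition on the GGML_OP_SOLVE_TRI shared-memory check above, a quick back-of-the-envelope sketch (not part of the patch; N = 64 and K = 32 assumed, matching the default specialization constants of the solve_tri shader introduced below):

#include <cstdio>

// Minimal sketch of the shared-memory budget used by the SOLVE_TRI check:
// shA holds N*N floats, shB holds N*K floats.
int main() {
    const unsigned N = 64, K = 32;                        // assumed example sizes
    const unsigned bytes = N * N * 4 + N * K * 4;         // 16384 + 8192 = 24576
    printf("%u bytes of shared memory needed\n", bytes);  // fits a typical 32 KiB limit
    return 0;
}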
@@ -13863,7 +14093,7 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
            return true;
        }
        case GGML_OP_SSM_CONV:
            return true;
            return op->src[0]->type == GGML_TYPE_F32;
        case GGML_OP_CONV_TRANSPOSE_1D:
            return op->src[0]->type == GGML_TYPE_F32 && op->src[1]->type == GGML_TYPE_F32;
        case GGML_OP_CONV_2D:
@@ -14304,6 +14534,8 @@ static void ggml_vk_check_results_0(ggml_backend_vk_context * ctx, ggml_cgraph *
        tensor_clone = ggml_cos(ggml_ctx, src_clone[0]);
    } else if (tensor->op == GGML_OP_LOG) {
        tensor_clone = ggml_log(ggml_ctx, src_clone[0]);
    } else if (tensor->op == GGML_OP_TRI) {
        tensor_clone = ggml_tri(ggml_ctx, src_clone[0], ggml_get_op_params_i32(tensor, 0));
    } else if (tensor->op == GGML_OP_CLAMP) {
        const float * params = (const float *)tensor->op_params;
        tensor_clone = ggml_clamp(ggml_ctx, src_clone[0], params[0], params[1]);
@@ -14459,6 +14691,8 @@ static void ggml_vk_check_results_0(ggml_backend_vk_context * ctx, ggml_cgraph *
        tensor_clone = ggml_get_rows(ggml_ctx, src_clone[0], src_clone[1]);
    } else if (tensor->op == GGML_OP_ARGSORT) {
        tensor_clone = ggml_argsort(ggml_ctx, src_clone[0], (ggml_sort_order) *(int *)tensor->op_params);
    } else if (tensor->op == GGML_OP_TOP_K) {
        tensor_clone = ggml_top_k(ggml_ctx, src_clone[0], tensor->ne[0]);
    } else if (tensor->op == GGML_OP_SUM) {
        tensor_clone = ggml_sum(ggml_ctx, src_clone[0]);
    } else if (tensor->op == GGML_OP_SUM_ROWS) {
@@ -14471,6 +14705,8 @@ static void ggml_vk_check_results_0(ggml_backend_vk_context * ctx, ggml_cgraph *
        tensor_clone = ggml_argmax(ggml_ctx, src_clone[0]);
    } else if (tensor->op == GGML_OP_COUNT_EQUAL) {
        tensor_clone = ggml_count_equal(ggml_ctx, src_clone[0], src_clone[1]);
    } else if (tensor->op == GGML_OP_SOLVE_TRI) {
        tensor_clone = ggml_solve_tri(ggml_ctx, src_clone[0], src_clone[1], true, true, false);
    } else if (tensor->op == GGML_OP_IM2COL) {
        const int32_t s0 = tensor->op_params[0];
        const int32_t s1 = tensor->op_params[1];
@@ -0,0 +1,72 @@
#version 450

#include "types.glsl"
#include "generic_binary_head.glsl"

layout (constant_id = 1) const uint N = 64;
layout (constant_id = 2) const uint K = 32;

layout(local_size_x = 128, local_size_y = 1, local_size_z = 1) in;

uint a_base, b_base, x_base;

FLOAT_TYPE get_a(uint r, uint c) {
    return FLOAT_TYPE(data_a[a_base + r * p.nb01 + c * p.nb00]);
}

FLOAT_TYPE get_b(uint r, uint c) {
    return FLOAT_TYPE(data_b[b_base + r * p.nb11 + c * p.nb10]);
}

void store_x(uint r, uint c, FLOAT_TYPE v) {
    data_d[x_base + r * p.nb21 + c * p.nb20] = D_TYPE(v);
}

shared FLOAT_TYPE shA[N * N];
shared FLOAT_TYPE shB[N * K];

void main() {
    const uint batch = gl_WorkGroupID.z * 262144 + gl_WorkGroupID.y * 512 + gl_WorkGroupID.x;
    const uint tid = gl_LocalInvocationID.x;

    if (batch >= p.ne02 * p.ne03) {
        return;
    }

    const uint i3 = batch / p.ne22;
    const uint i2 = batch % p.ne22;
    a_base = get_aoffset() + i2 * p.nb02 + i3 * p.nb03;
    b_base = get_boffset() + i2 * p.nb12 + i3 * p.nb13;
    x_base = get_doffset() + i2 * p.nb22 + i3 * p.nb23;

    // Load the A matrix into shA
    [[unroll]] for (uint i = 0; i < N * N; i += gl_WorkGroupSize.x) {
        uint idx = i + tid;
        if (((N * N) % gl_WorkGroupSize.x == 0) || idx < N * N) {
            shA[idx] = get_a(idx / N, idx % N);
        }
    }
    // Load the B matrix into shB
    [[unroll]] for (uint i = 0; i < N * K; i += gl_WorkGroupSize.x) {
        uint idx = i + tid;
        if (((N * K) % gl_WorkGroupSize.x == 0) || idx < N * K) {
            shB[idx] = get_b(idx / K, idx % K);
        }
    }
    barrier();

    FLOAT_TYPE X[N];
    // Each thread solves one column
    if (tid < K) {
        [[unroll]] for (int r = 0; r < N; ++r) {
            FLOAT_TYPE b = shB[r * K + tid];
            // Compute x[r,c] = (b[r,c] - sum(a[r,c]*x[c])) / a[r,r]
            [[unroll]] for (int c = 0; c < r; ++c) {
                b -= shA[r * N + c] * X[c];
            }
            FLOAT_TYPE x = b / shA[r * N + r];
            X[r] = x;
            store_x(r, tid, x);
        }
    }
}
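
The per-column loop above is plain forward substitution for a lower-triangular system. As a cross-check, here is a minimal host-side sketch of the same recurrence (my own illustration, not part of the patch):

#include <vector>

// Minimal sketch: solve L * X = B column-by-column, where L is an N x N
// lower-triangular matrix (row-major) and B is N x K. Mirrors the shader's
// x[r,c] = (b[r,c] - sum_{c'<r} a[r,c'] * x[c',c]) / a[r,r] recurrence.
static std::vector<float> solve_tri_ref(const std::vector<float> & L,
                                        const std::vector<float> & B,
                                        int N, int K) {
    std::vector<float> X(N * K);
    for (int col = 0; col < K; ++col) {
        for (int r = 0; r < N; ++r) {
            float b = B[r * K + col];
            for (int c = 0; c < r; ++c) {
                b -= L[r * N + c] * X[c * K + col];
            }
            X[r * K + col] = b / L[r * N + r];
        }
    }
    return X;
}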
@@ -0,0 +1,113 @@
#version 450
#extension GL_EXT_control_flow_attributes : enable

#include "types.glsl"

layout(constant_id = 0) const int BLOCK_SIZE = 1024;
layout(constant_id = 1) const int NCOLS_PADDED_LOG2 = 10;

layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;

// Input can either be the source (A) or intermediate values (S).
// Similarly, output can be either destination (D) or intermediate values (S).
layout (binding = 0) readonly buffer A {A_TYPE data_a[];};
layout (binding = 0) readonly buffer S {ivec2 data_s[];};
layout (binding = 1) writeonly buffer D {int data_d[];};
layout (binding = 1) writeonly buffer T {ivec2 data_t[];};

layout (push_constant) uniform parameter {
    uint orig_ncols;
    uint ncols_input;
    uint ncols_output;
    uint nrows;
    uint first_pass;
    uint last_pass;
} p;

// pairs of (gid, value)
shared ivec2 dst_row[BLOCK_SIZE];

void topk(bool needs_bounds_check, const uint row) {
    const int col = int(gl_LocalInvocationID.x);

    // initialize indices
    if (gl_GlobalInvocationID.x < p.ncols_input) {
        if (p.first_pass != 0) {
            const uint row_offset = row * p.ncols_input;
            dst_row[col] = ivec2(gl_GlobalInvocationID.x, floatBitsToInt(data_a[row_offset + gl_GlobalInvocationID.x]));
        } else {
            const uint row_offset = row * p.orig_ncols;
            dst_row[col] = data_s[row_offset + gl_GlobalInvocationID.x];
        }
    } else {
        dst_row[col] = ivec2(p.orig_ncols, 0);
    }
    barrier();

    if (p.ncols_output == 1) {
        // Fast path for single output - just do a max reduction
        [[unroll]] for (int s = BLOCK_SIZE / 2; s >= 1; s /= 2) {
            if (col < s) {
                ivec2 a = dst_row[col];
                ivec2 b = dst_row[col + s];
                if (a.x >= p.orig_ncols ||
                    b.x < p.orig_ncols && b.y > a.y) {
                    dst_row[col] = b;
                }
            }
            barrier();
        }
    } else {
        // bitonic sort on this group of elements
        uint num_outer_loop_iters = NCOLS_PADDED_LOG2;
        for (uint k = 2, outer_idx = 0; outer_idx < num_outer_loop_iters; k *= 2, outer_idx++) {
            uint num_inner_loop_iters = outer_idx + 1;
            for (uint j = k / 2, inner_idx = 0; inner_idx < num_inner_loop_iters; j /= 2, inner_idx++) {
                const int ixj = int(col ^ j);

                int idx_0 = (col & k) == 0 ? col : ixj;
                int idx_1 = (col & k) == 0 ? ixj : col;

                ivec2 sh_idx_0 = dst_row[idx_0];
                ivec2 sh_idx_1 = dst_row[idx_1];
                bool idx_0_oob = needs_bounds_check ? sh_idx_0.x >= p.orig_ncols : false;
                bool idx_1_oob = needs_bounds_check ? sh_idx_1.x >= p.orig_ncols : false;

                if ((idx_0_oob ||
                    (!idx_1_oob && intBitsToFloat(sh_idx_0.y) < intBitsToFloat(sh_idx_1.y))) && (ixj > col)) {
                    dst_row[idx_0] = sh_idx_1;
                    dst_row[idx_1] = sh_idx_0;
                }

                barrier();
            }
        }
    }

    if (col < p.ncols_output && gl_GlobalInvocationID.x < p.orig_ncols) {
        if (p.last_pass != 0) {
            const uint row_offset = row * p.ncols_output;
            data_d[row_offset + col] = dst_row[col].x;
        } else {
            const uint row_offset = row * p.orig_ncols + gl_WorkGroupID.x * p.ncols_output;
            data_t[row_offset + col] = dst_row[col];
        }
    }
}

void main() {
    // Fast path for fully occupied workgroups
    if ((p.ncols_input % BLOCK_SIZE) == 0) {
        uint row = gl_WorkGroupID.y;
        while (row < p.nrows) {
            topk(false, row);
            row += gl_WorkGroupSize.y * gl_NumWorkGroups.y;
        }
    } else {
        uint row = gl_WorkGroupID.y;
        while (row < p.nrows) {
            topk(true, row);
            row += gl_WorkGroupSize.y * gl_NumWorkGroups.y;
        }
    }
}
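
The k/j loop nest in the shader above is the textbook in-place bitonic network. A compact serial sketch of the same (k, j, i ^ j) compare-exchange pattern (illustrative only, not part of the patch):

#include <algorithm>
#include <cstdio>

// Illustrative only: serial bitonic sort using the same compare-exchange
// pattern as the shader's loop nest, for a power-of-two array size.
int main() {
    int a[8] = { 5, 1, 7, 3, 0, 6, 2, 4 };
    const int n = 8;
    for (int k = 2; k <= n; k *= 2) {           // outer: merge size
        for (int j = k / 2; j >= 1; j /= 2) {   // inner: compare distance
            for (int i = 0; i < n; ++i) {
                const int ixj = i ^ j;
                if (ixj > i &&
                    (((i & k) == 0 && a[i] > a[ixj]) ||
                     ((i & k) != 0 && a[i] < a[ixj]))) {
                    std::swap(a[i], a[ixj]);
                }
            }
        }
    }
    for (int i = 0; i < n; ++i) printf("%d ", a[i]); // 0 1 2 3 4 5 6 7
    printf("\n");
    return 0;
}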
@@ -0,0 +1,199 @@
#version 450
#extension GL_EXT_control_flow_attributes : enable
#extension GL_EXT_debug_printf : enable
#extension GL_KHR_shader_subgroup_basic : enable
#extension GL_KHR_shader_subgroup_ballot : enable
#extension GL_KHR_shader_subgroup_arithmetic : enable
#extension GL_KHR_shader_subgroup_shuffle : enable

#include "types.glsl"

layout(constant_id = 0) const int BLOCK_SIZE = 1024;
layout(constant_id = 1) const int SUBGROUP_SIZE = 32;
layout(constant_id = 2) const int SUBGROUP_SIZE_LOG2 = 5;

layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;

// Input can either be the source (A) or intermediate values (S).
// Similarly, output can be either destination (D) or intermediate values (S).
layout (binding = 0) readonly buffer A {A_TYPE data_a[];};
layout (binding = 0) readonly buffer S {ivec2 data_s[];};
layout (binding = 1) writeonly buffer D {int data_d[];};
layout (binding = 1) writeonly buffer T {ivec2 data_t[];};

layout (push_constant) uniform parameter {
    uint orig_ncols;
    uint ncols_input;
    uint ncols_output;
    uint nrows;
    uint first_pass;
    uint last_pass;
} p;

// pairs of (gid, value)
shared ivec2 dst_row[BLOCK_SIZE];

shared int counts[SUBGROUP_SIZE];
shared int sh_min_idx;
shared uint sh_total;
shared uint offset_partials[BLOCK_SIZE / SUBGROUP_SIZE];

// Map float values to uint such that comparisons still work.
// Positive values set the high bit, negative values are inverted.
// +0.0 -> 0x80000000, -0.0 -> 0x7FFFFFFF are in the correct places.
uint f2ui(float x) {
    uint y = floatBitsToUint(x);
    if ((y & 0x80000000) != 0) {
        y ^= ~0;
    } else {
        y |= 0x80000000;
    }
    return y;
}

void topk(const uint row) {
    const int tid = int(gl_LocalInvocationID.x);

    // initialize indices
    if (gl_GlobalInvocationID.x < p.ncols_input) {
        if (p.first_pass != 0) {
            const uint row_offset = row * p.ncols_input;
            dst_row[tid] = ivec2(gl_GlobalInvocationID.x, floatBitsToInt(data_a[row_offset + gl_GlobalInvocationID.x]));
        } else {
            const uint row_offset = row * p.orig_ncols;
            dst_row[tid] = data_s[row_offset + gl_GlobalInvocationID.x];
        }
    } else {
        dst_row[tid] = ivec2(p.orig_ncols, 0xFF800000); // -inf
    }
    barrier();

    if (p.ncols_output == 1) {
        // Fast path for single output - just do a max reduction
        [[unroll]] for (int s = BLOCK_SIZE / 2; s >= 1; s /= 2) {
            if (tid < s) {
                ivec2 a = dst_row[tid];
                ivec2 b = dst_row[tid + s];
                if (a.x >= p.orig_ncols ||
                    b.x < p.orig_ncols && b.y > a.y) {
                    dst_row[tid] = b;
                }
            }
            barrier();
        }
    } else {
        // Do an N-ary search to find the K-th largest value.
        // We remap the float values to be comparable as unsigned integers,
        // and split the range into N smaller buckets, where N is the
        // subgroup size. Count how many values are in each bucket; if the
        // K-th largest value is in the middle of one of these buckets then
        // repeat and split again.

        // Mask is the current set of bits we're searching. Shift is the LSB index.
        int shift = 32 - SUBGROUP_SIZE_LOG2;
        uint mask = ((1 << SUBGROUP_SIZE_LOG2) - 1) << shift;

        // The current range.
        uint range_min = 0;
        uint range_max = 0xFF800000;
        // How many are above the current range, and how many we need to find.
        uint total = 0;
        uint limit = min(p.ncols_output, p.ncols_input - gl_WorkGroupID.x * BLOCK_SIZE);

        while (mask != 0) {
            barrier();
            // Initialize bucket counts to zero.
            if (tid < SUBGROUP_SIZE) {
                counts[tid] = 0;
            }
            barrier();
            // Count how many values are in each bucket.
            if (tid < p.ncols_input) {
                float y = intBitsToFloat(dst_row[tid].y);
                uint fy = f2ui(y);
                if (fy >= range_min && fy < range_max) {
                    uint bucket = (fy & mask) >> shift;
                    atomicAdd(counts[bucket], 1);
                }
            }
            barrier();

            // On the first subgroup, do a scan to count (from the top down) how
            // many elements are in the top N buckets. Find the index of the first
            // that is over the limit. Copy it to the other invocations through
            // shared memory.
            if (tid < SUBGROUP_SIZE) {
                uint partial_sum = counts[SUBGROUP_SIZE - 1 - tid];
                partial_sum = subgroupInclusiveAdd(partial_sum) + total;
                uint t = subgroupBallotFindLSB(subgroupBallot(partial_sum >= limit));
                if (tid == t) {
                    sh_min_idx = int(SUBGROUP_SIZE - 1 - t);
                    sh_total = partial_sum;
                }
            }
            barrier();
            int min_idx = sh_min_idx;
            total = sh_total;

            // Update the range, and break if we've found the K-th largest.
            range_max = range_min + ((min_idx + 1) << shift);
            range_min = range_min + (min_idx << shift);

            if (total == p.ncols_output) {
                break;
            }
            total -= counts[min_idx];
            mask >>= SUBGROUP_SIZE_LOG2;
            shift -= SUBGROUP_SIZE_LOG2;
            if (shift < 0) {
                shift = 0;
            }
        }

        ivec2 v = dst_row[tid];

        // We need to compact these values to the start of the dst_row array.
        // Have each subgroup count how many items it'll store, so other
        // subgroups can compute their base offset.
        bool top = f2ui(intBitsToFloat(v.y)) >= range_min;
        uvec4 b = subgroupBallot(top);
        uint bit_count = subgroupBallotBitCount(b);
        if ((tid % SUBGROUP_SIZE) == 0) {
            offset_partials[tid / SUBGROUP_SIZE] = bit_count;
        }
        barrier();

        uint out_idx = 0;
        [[unroll]] for (int i = 0; i < BLOCK_SIZE / SUBGROUP_SIZE; ++i) {
            if (i < tid / SUBGROUP_SIZE) {
                out_idx += offset_partials[i];
            }
        }

        uint bit_count_ex = subgroupBallotExclusiveBitCount(b);
        if (top) {
            // TODO: Copy directly to the output?
            dst_row[out_idx + bit_count_ex] = v;
        }

        barrier();
    }

    if (tid < p.ncols_output && gl_GlobalInvocationID.x < p.orig_ncols) {
        if (p.last_pass != 0) {
            const uint row_offset = row * p.ncols_output;
            data_d[row_offset + tid] = dst_row[tid].x;
        } else {
            const uint row_offset = row * p.orig_ncols + gl_WorkGroupID.x * p.ncols_output;
            data_t[row_offset + tid] = dst_row[tid];
        }
    }
}

void main() {
    uint row = gl_WorkGroupID.y;
    while (row < p.nrows) {
        topk(row);
        row += gl_WorkGroupSize.y * gl_NumWorkGroups.y;
    }
}
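
The f2ui remapping above relies on the usual order-preserving IEEE-754 trick: flip all bits of negatives, set the sign bit of non-negatives, so unsigned comparisons agree with float comparisons. A tiny host-side mirror of it for spot-checking monotonicity (my own check, not part of the patch):

#include <cassert>
#include <cstdint>
#include <cstring>

// Mirror of the shader's f2ui: flip all bits of negatives, set the sign bit
// of non-negatives, so unsigned compares agree with float compares.
static uint32_t f2ui(float x) {
    uint32_t y;
    std::memcpy(&y, &x, sizeof(y));
    return (y & 0x80000000u) ? ~y : (y | 0x80000000u);
}

int main() {
    const float v[] = { -3.5f, -0.0f, 0.0f, 1.0f, 2.25f };
    for (int i = 0; i + 1 < 5; ++i) {
        assert(f2ui(v[i]) <= f2ui(v[i + 1])); // ordering is preserved
    }
    return 0;
}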
@@ -0,0 +1,43 @@
#version 450

#include "rte.glsl"
#include "types.glsl"
#include "generic_unary_head.glsl"

#define GGML_TRI_TYPE_UPPER_DIAG 0
#define GGML_TRI_TYPE_UPPER 1
#define GGML_TRI_TYPE_LOWER_DIAG 2
#define GGML_TRI_TYPE_LOWER 3

layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in;

void main() {
    const uint idx = get_idx();

    if (idx >= p.ne) {
        return;
    }

    const uint i03 = fastdiv(idx, p.ne0_012mp, p.ne0_012L);
    const uint i03_offset = i03 * p.ne02*p.ne01*p.ne00;
    const uint i02 = fastdiv(idx - i03_offset, p.ne0_01mp, p.ne0_01L);
    const uint i02_offset = i02*p.ne01*p.ne00;
    const uint i01 = fastdiv(idx - i03_offset - i02_offset, p.ne0_0mp, p.ne0_0L);
    const uint i00 = idx - i03_offset - i02_offset - i01*p.ne00;

    int param = floatBitsToInt(p.param1);
    bool pass = false;
    switch (param) {
    case GGML_TRI_TYPE_UPPER_DIAG: pass = i00 >= i01; break;
    case GGML_TRI_TYPE_UPPER:      pass = i00 > i01;  break;
    case GGML_TRI_TYPE_LOWER_DIAG: pass = i00 <= i01; break;
    case GGML_TRI_TYPE_LOWER:      pass = i00 < i01;  break;
    }

    if (pass) {
        const float val = float(data_a[get_aoffset() + src0_idx(idx)]);
        data_d[get_doffset() + dst_idx(idx)] = D_TYPE(val);
    } else {
        data_d[get_doffset() + dst_idx(idx)] = D_TYPE(0);
    }
}
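
For reference, the four predicates above select the standard triangular regions, with i00 the column index and i01 the row index. A minimal host-side sketch of one of them (illustrative only, not part of the patch):

#include <cstdio>

// Illustrative only: print a 4x4 lower-triangular (with diagonal) mask,
// matching the GGML_TRI_TYPE_LOWER_DIAG predicate i00 <= i01.
int main() {
    for (int row = 0; row < 4; ++row) {
        for (int col = 0; col < 4; ++col) {
            printf("%d ", col <= row ? 1 : 0);
        }
        printf("\n");
    }
    return 0;
}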
@@ -846,6 +846,9 @@ void process_shaders() {
    string_to_spv("abs_f16", "abs.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}});
    string_to_spv("abs_f32", "abs.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});

    string_to_spv("tri_f16", "tri.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}});
    string_to_spv("tri_f32", "tri.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});

    string_to_spv("softplus_f16", "softplus.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}});
    string_to_spv("softplus_f32", "softplus.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
@@ -913,6 +916,9 @@ void process_shaders() {
    string_to_spv("argsort_f32", "argsort.comp", {{"A_TYPE", "float"}});
    string_to_spv("argsort_large_f32", "argsort_large.comp", {{"A_TYPE", "float"}});

    string_to_spv("topk_argsort_f32", "topk_argsort.comp", {{"A_TYPE", "float"}});
    string_to_spv("topk_nary_search_f32", "topk_nary_search.comp", {{"A_TYPE", "float"}});

    string_to_spv("argmax_f32", "argmax.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "int"}}));
    string_to_spv("sum_rows_f32", "sum_rows.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}}));
    string_to_spv("count_equal_i32", "count_equal.comp", merge_maps(base_dict, {{"A_TYPE", "int"}, {"B_TYPE", "int"}, {"D_TYPE", "int"}}));
@@ -941,6 +947,8 @@ void process_shaders() {
    string_to_spv("opt_step_adamw_f32", "opt_step_adamw.comp", merge_maps(base_dict, {{"A_TYPE", "float"}}));
    string_to_spv("opt_step_sgd_f32", "opt_step_sgd.comp", merge_maps(base_dict, {{"A_TYPE", "float"}}));

    string_to_spv("solve_tri_f32", "solve_tri.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}}));

    for (auto transpose : {false, true}) {
        for (auto unroll : {false, true}) {
            for (auto a_f16 : {false, true}) {
@@ -366,6 +366,7 @@ class MODEL_ARCH(IntEnum):
    QWEN2VL = auto()
    QWEN3 = auto()
    QWEN3MOE = auto()
    QWEN3NEXT = auto()
    QWEN3VL = auto()
    QWEN3VLMOE = auto()
    PHI2 = auto()
@@ -531,6 +532,7 @@ class MODEL_TENSOR(IntEnum):
    SSM_D = auto()
    SSM_NORM = auto()
    SSM_OUT = auto()
    SSM_BETA_ALPHA = auto() # qwen3next
    TIME_MIX_W0 = auto()
    TIME_MIX_W1 = auto()
    TIME_MIX_W2 = auto()
@@ -736,6 +738,7 @@ MODEL_ARCH_NAMES: dict[MODEL_ARCH, str] = {
    MODEL_ARCH.QWEN2VL: "qwen2vl",
    MODEL_ARCH.QWEN3: "qwen3",
    MODEL_ARCH.QWEN3MOE: "qwen3moe",
    MODEL_ARCH.QWEN3NEXT: "qwen3next",
    MODEL_ARCH.QWEN3VL: "qwen3vl",
    MODEL_ARCH.QWEN3VLMOE: "qwen3vlmoe",
    MODEL_ARCH.PHI2: "phi2",
@@ -900,6 +903,7 @@ TENSOR_NAMES: dict[MODEL_TENSOR, str] = {
    MODEL_TENSOR.SSM_D: "blk.{bid}.ssm_d",
    MODEL_TENSOR.SSM_NORM: "blk.{bid}.ssm_norm",
    MODEL_TENSOR.SSM_OUT: "blk.{bid}.ssm_out",
    MODEL_TENSOR.SSM_BETA_ALPHA: "blk.{bid}.ssm_ba",
    MODEL_TENSOR.TIME_MIX_W0: "blk.{bid}.time_mix_w0",
    MODEL_TENSOR.TIME_MIX_W1: "blk.{bid}.time_mix_w1",
    MODEL_TENSOR.TIME_MIX_W2: "blk.{bid}.time_mix_w2",
@@ -1569,6 +1573,35 @@ MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = {
        MODEL_TENSOR.FFN_DOWN_EXP,
        MODEL_TENSOR.FFN_UP_EXP,
    ],
    MODEL_ARCH.QWEN3NEXT: [
        MODEL_TENSOR.TOKEN_EMBD,
        MODEL_TENSOR.OUTPUT_NORM,
        MODEL_TENSOR.OUTPUT,
        MODEL_TENSOR.ATTN_NORM,
        MODEL_TENSOR.ATTN_Q,
        MODEL_TENSOR.ATTN_Q_NORM,
        MODEL_TENSOR.ATTN_K,
        MODEL_TENSOR.ATTN_K_NORM,
        MODEL_TENSOR.ATTN_V,
        MODEL_TENSOR.ATTN_OUT,
        MODEL_TENSOR.ATTN_POST_NORM,
        MODEL_TENSOR.ATTN_GATE,
        MODEL_TENSOR.FFN_GATE_INP,
        MODEL_TENSOR.FFN_GATE_INP_SHEXP,
        MODEL_TENSOR.FFN_UP_SHEXP,
        MODEL_TENSOR.FFN_DOWN_SHEXP,
        MODEL_TENSOR.FFN_GATE_SHEXP,
        MODEL_TENSOR.FFN_DOWN_EXP,
        MODEL_TENSOR.FFN_UP_EXP,
        MODEL_TENSOR.FFN_GATE_EXP,
        MODEL_TENSOR.SSM_A,
        MODEL_TENSOR.SSM_CONV1D,
        MODEL_TENSOR.SSM_DT,
        MODEL_TENSOR.SSM_NORM,
        MODEL_TENSOR.SSM_IN,
        MODEL_TENSOR.SSM_BETA_ALPHA,
        MODEL_TENSOR.SSM_OUT
    ],
    MODEL_ARCH.QWEN3VL: [
        MODEL_TENSOR.TOKEN_EMBD,
        MODEL_TENSOR.OUTPUT_NORM,
@@ -19,6 +19,11 @@ import gguf
logger = logging.getLogger("gguf-convert-endian")


def byteswap_noop(tensor, block_offs):
    # this function is used when byteswapping is not needed
    pass


def byteswap_q4_0(tensor, block_offs):
    # Each block_q4_0 consists of an f16 delta (scaling factor) followed by 16 int8 quantizations.
@@ -55,22 +60,11 @@ def byteswap_q6_k(tensor, block_offs):


byteswap_tensors = {
    gguf.GGMLQuantizationType.Q4_0: {
        "block_size": 18, # 18 bytes = <f16 delta scaling factor> + 16 * <int8 quant>
        "byteswap_func": byteswap_q4_0,
    },
    gguf.GGMLQuantizationType.Q8_0: {
        "block_size": 34, # 34 bytes = <f16 delta scaling factor> + 32 * <int8 quant>
        "byteswap_func": byteswap_q8_0,
    },
    gguf.GGMLQuantizationType.Q4_K: {
        "block_size": 144, # 144 bytes = 2 * <f16 delta scaling factor> + 140 * <int8 quant>
        "byteswap_func": byteswap_q4_k,
    },
    gguf.GGMLQuantizationType.Q6_K: {
        "block_size": 210, # 210 bytes = <f16 delta scaling factor> + 208 * <int8 quant>
        "byteswap_func": byteswap_q6_k,
    },
    gguf.GGMLQuantizationType.Q4_0: byteswap_q4_0,
    gguf.GGMLQuantizationType.Q8_0: byteswap_q8_0,
    gguf.GGMLQuantizationType.Q4_K: byteswap_q4_k,
    gguf.GGMLQuantizationType.Q6_K: byteswap_q6_k,
    gguf.GGMLQuantizationType.MXFP4: byteswap_noop,
}
@@ -135,8 +129,8 @@ def convert_byteorder(reader: gguf.GGUFReader, args: argparse.Namespace) -> None
    tensor.data.resize(newshape)

    block_size = byteswap_tensors[tensor.tensor_type]["block_size"]
    block_size = gguf.constants.GGML_QUANT_SIZES[tensor.tensor_type][1]
    byteswap_func = byteswap_tensors[tensor.tensor_type]["byteswap_func"]
    byteswap_func = byteswap_tensors[tensor.tensor_type]

    n_blocks = len(tensor.data) // block_size
    for block_num in (inner_pbar := tqdm(range(n_blocks), desc="Byte-swapping Blocks", leave=False)):
@@ -672,10 +672,11 @@ class TensorNameMap:
        ),

        MODEL_TENSOR.SSM_IN: (
            "model.layers.{bid}.in_proj", # mamba-hf
            "backbone.layers.{bid}.mixer.in_proj", # mamba
            "model.layers.{bid}.mamba.in_proj", # jamba falcon-h1 granite-hybrid
            "model.layers.layers.{bid}.mixer.in_proj", # plamo2
            "model.layers.{bid}.linear_attn.in_proj_qkvz", # qwen3next
        ),

        MODEL_TENSOR.SSM_CONV1D: (
@@ -683,6 +684,7 @@ class TensorNameMap:
            "backbone.layers.{bid}.mixer.conv1d", # mamba
            "model.layers.{bid}.mamba.conv1d", # jamba falcon-h1 granite-hybrid
            "model.layers.layers.{bid}.mixer.conv1d", # plamo2
            "model.layers.{bid}.linear_attn.conv1d", # qwen3next
        ),

        MODEL_TENSOR.SSM_X: (
@@ -697,6 +699,7 @@ class TensorNameMap:
            "backbone.layers.{bid}.mixer.dt_proj", # mamba
            "model.layers.{bid}.mamba.dt_proj", # jamba falcon-h1 granite-hybrid
            "model.layers.layers.{bid}.mixer.dt_proj", # plamo2
            "model.layers.{bid}.linear_attn.dt_proj", # qwen3next
        ),

        MODEL_TENSOR.SSM_DT_NORM: (
@@ -709,6 +712,7 @@ class TensorNameMap:
            "backbone.layers.{bid}.mixer.A_log", # mamba
            "model.layers.{bid}.mamba.A_log", # jamba falcon-h1 granite-hybrid
            "model.layers.layers.{bid}.mixer.A_log", # plamo2
            "model.layers.{bid}.linear_attn.A_log", # qwen3next
        ),

        MODEL_TENSOR.SSM_B_NORM: (
@@ -731,17 +735,23 @@ class TensorNameMap:
        ),

        MODEL_TENSOR.SSM_NORM: (
            "model.layers.{bid}.mamba.norm", # falcon-h1 granite-hybrid
            "model.layers.{bid}.linear_attn.norm", # qwen3next
            "backbone.layers.{bid}.mixer.norm", # mamba2
        ),

        MODEL_TENSOR.SSM_OUT: (
            "model.layers.{bid}.out_proj", # mamba-hf
            "backbone.layers.{bid}.mixer.out_proj", # mamba
            "model.layers.{bid}.mamba.out_proj", # jamba falcon-h1 granite-hybrid
            "model.layers.{bid}.linear_attn.out_proj", # qwen3next
            "model.layers.layers.{bid}.mixer.out_proj", # plamo2
        ),

        MODEL_TENSOR.SSM_BETA_ALPHA: (
            "model.layers.{bid}.linear_attn.in_proj_ba", # qwen3next
        ),

        MODEL_TENSOR.TIME_MIX_W0: (
            "model.layers.{bid}.attention.w0", # rwkv7
        ),
@@ -114,6 +114,7 @@ add_library(llama
    models/qwen3vl.cpp
    models/qwen3vl-moe.cpp
    models/qwen3moe.cpp
    models/qwen3next.cpp
    models/refact.cpp
    models/rnd1.cpp
    models/rwkv6-base.cpp
@@ -32,6 +32,7 @@ static const std::map<llm_arch, const char *> LLM_ARCH_NAMES = {
    { LLM_ARCH_QWEN2VL, "qwen2vl" },
    { LLM_ARCH_QWEN3, "qwen3" },
    { LLM_ARCH_QWEN3MOE, "qwen3moe" },
    { LLM_ARCH_QWEN3NEXT, "qwen3next" },
    { LLM_ARCH_QWEN3VL, "qwen3vl" },
    { LLM_ARCH_QWEN3VLMOE, "qwen3vlmoe" },
    { LLM_ARCH_PHI2, "phi2" },
@@ -829,6 +830,38 @@ static const std::map<llm_arch, std::map<llm_tensor, const char *>> LLM_TENSOR_N
            { LLM_TENSOR_FFN_UP_EXPS, "blk.%d.ffn_up_exps" },
        },
    },
    {
        LLM_ARCH_QWEN3NEXT,
        {
            { LLM_TENSOR_TOKEN_EMBD, "token_embd" },
            { LLM_TENSOR_OUTPUT_NORM, "output_norm" },
            { LLM_TENSOR_OUTPUT, "output" },
            { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" },
            { LLM_TENSOR_ATTN_POST_NORM, "blk.%d.post_attention_norm" },
            { LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" },
            { LLM_TENSOR_ATTN_Q_NORM, "blk.%d.attn_q_norm" },
            { LLM_TENSOR_ATTN_K, "blk.%d.attn_k" },
            { LLM_TENSOR_ATTN_K_NORM, "blk.%d.attn_k_norm" },
            { LLM_TENSOR_ATTN_V, "blk.%d.attn_v" },
            { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" },
            { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" },
            { LLM_TENSOR_FFN_GATE_INP, "blk.%d.ffn_gate_inp" },
            { LLM_TENSOR_FFN_GATE_EXPS, "blk.%d.ffn_gate_exps" },
            { LLM_TENSOR_FFN_DOWN_EXPS, "blk.%d.ffn_down_exps" },
            { LLM_TENSOR_FFN_UP_EXPS, "blk.%d.ffn_up_exps" },
            { LLM_TENSOR_FFN_GATE_INP_SHEXP, "blk.%d.ffn_gate_inp_shexp" },
            { LLM_TENSOR_FFN_GATE_SHEXP, "blk.%d.ffn_gate_shexp" },
            { LLM_TENSOR_FFN_DOWN_SHEXP, "blk.%d.ffn_down_shexp" },
            { LLM_TENSOR_FFN_UP_SHEXP, "blk.%d.ffn_up_shexp" },
            { LLM_TENSOR_SSM_A, "blk.%d.ssm_a" },
            { LLM_TENSOR_SSM_CONV1D, "blk.%d.ssm_conv1d" },
            { LLM_TENSOR_SSM_DT, "blk.%d.ssm_dt" },
            { LLM_TENSOR_SSM_BETA_ALPHA, "blk.%d.ssm_ba" },
            { LLM_TENSOR_SSM_IN, "blk.%d.ssm_in" },
            { LLM_TENSOR_SSM_NORM, "blk.%d.ssm_norm" },
            { LLM_TENSOR_SSM_OUT, "blk.%d.ssm_out" },
        },
    },
    {
        LLM_ARCH_QWEN3VL,
        {
@@ -2237,7 +2270,7 @@ static const std::map<llm_arch, std::map<llm_tensor, const char *>> LLM_TENSOR_N
            { LLM_TENSOR_SHORTCONV_INPROJ, "blk.%d.shortconv.in_proj" },
            { LLM_TENSOR_SHORTCONV_OUTPROJ, "blk.%d.shortconv.out_proj" },
            { LLM_TENSOR_TOKEN_EMBD, "token_embd" },
            { LLM_TENSOR_TOKEN_EMBD_NORM, "token_embd_norm" },
            { LLM_TENSOR_OUTPUT_NORM, "token_embd_norm" }, // note: wrong tensor name
            { LLM_TENSOR_OUTPUT, "output" },
        }
    },
@@ -2259,7 +2292,7 @@ static const std::map<llm_arch, std::map<llm_tensor, const char *>> LLM_TENSOR_N
            { LLM_TENSOR_SHORTCONV_INPROJ, "blk.%d.shortconv.in_proj" },
            { LLM_TENSOR_SHORTCONV_OUTPROJ, "blk.%d.shortconv.out_proj" },
            { LLM_TENSOR_TOKEN_EMBD, "token_embd" },
            { LLM_TENSOR_TOKEN_EMBD_NORM, "token_embd_norm" },
            { LLM_TENSOR_OUTPUT_NORM, "token_embd_norm" }, // note: wrong tensor name
            { LLM_TENSOR_FFN_GATE_INP, "blk.%d.ffn_gate_inp" },
            { LLM_TENSOR_FFN_GATE_EXPS, "blk.%d.ffn_gate_exps" },
            { LLM_TENSOR_FFN_DOWN_EXPS, "blk.%d.ffn_down_exps" },
@@ -2487,11 +2520,21 @@ static const std::map<llm_arch, std::map<llm_tensor, const char *>> LLM_TENSOR_N
    },
};

// declare information about the model weight tensors:
// - the layer in which the tensor is going to be used. this is needed in order to assign the correct buffer type for the weight
// - the operator which is going to use the weight. this is needed to determine if the respective backend supports the operator
//
// for example, input layers are usually assigned to CPU/host buffer types
//
// a mismatch between the declared information and the actual layer/op in which the tensor is used can lead to sub-optimal
// assignment of the buffer types and extra overhead during computation
// example: https://github.com/ggml-org/llama.cpp/pull/17548
//
static const std::map<llm_tensor, llm_tensor_info> LLM_TENSOR_INFOS = {
    {LLM_TENSOR_TOKEN_EMBD, {LLM_TENSOR_LAYER_INPUT, GGML_OP_GET_ROWS}},
    {LLM_TENSOR_POS_EMBD, {LLM_TENSOR_LAYER_INPUT, GGML_OP_GET_ROWS}},
    {LLM_TENSOR_TOKEN_EMBD_NORM, {LLM_TENSOR_LAYER_INPUT, GGML_OP_GET_ROWS}},
    {LLM_TENSOR_TOKEN_TYPES, {LLM_TENSOR_LAYER_INPUT, GGML_OP_GET_ROWS}},
    {LLM_TENSOR_TOKEN_EMBD_NORM, {LLM_TENSOR_LAYER_INPUT, GGML_OP_MUL}},
    {LLM_TENSOR_OUTPUT, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL_MAT}},
    {LLM_TENSOR_CLS, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL_MAT}},
    {LLM_TENSOR_CLS_OUT, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL_MAT}},
@@ -2546,6 +2589,7 @@ static const std::map<llm_tensor, llm_tensor_info> LLM_TENSOR_INFOS = {
    {LLM_TENSOR_SSM_X, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
    {LLM_TENSOR_SSM_DT, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
    {LLM_TENSOR_SSM_OUT, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
    {LLM_TENSOR_SSM_BETA_ALPHA, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
    {LLM_TENSOR_TIME_MIX_W1, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
    {LLM_TENSOR_TIME_MIX_W2, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
    {LLM_TENSOR_TIME_MIX_A1, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
@@ -2744,6 +2788,7 @@ bool llm_arch_is_hybrid(const llm_arch & arch) {
        case LLM_ARCH_LFM2:
        case LLM_ARCH_LFM2MOE:
        case LLM_ARCH_NEMOTRON_H:
        case LLM_ARCH_QWEN3NEXT:
            return true;
        default:
            return false;
@@ -36,6 +36,7 @@ enum llm_arch {
    LLM_ARCH_QWEN2VL,
    LLM_ARCH_QWEN3,
    LLM_ARCH_QWEN3MOE,
    LLM_ARCH_QWEN3NEXT,
    LLM_ARCH_QWEN3VL,
    LLM_ARCH_QWEN3VLMOE,
    LLM_ARCH_PHI2,
@@ -381,6 +382,7 @@ enum llm_tensor {
    LLM_TENSOR_SSM_D,
    LLM_TENSOR_SSM_NORM,
    LLM_TENSOR_SSM_OUT,
    LLM_TENSOR_SSM_BETA_ALPHA, // qwen3next
    LLM_TENSOR_TIME_MIX_W0,
    LLM_TENSOR_TIME_MIX_W1,
    LLM_TENSOR_TIME_MIX_W2,
@@ -1,5 +1,6 @@
#include "llama-context.h"

#include "llama-arch.h"
#include "llama-impl.h"
#include "llama-batch.h"
#include "llama-io.h"
@@ -1386,6 +1387,9 @@ void llama_context::output_reorder() {
 //

 uint32_t llama_context::graph_max_nodes() const {
+    if (model.arch == LLM_ARCH_QWEN3NEXT) {
+        return std::max<uint32_t>(8192u, 32u*model.n_tensors());
+    }
     return std::max<uint32_t>(1024u, 8u*model.n_tensors());
 }
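Why the bump: Qwen3 Next's delta-net layers expand into far more graph nodes per tensor than regular attention, so the per-tensor budget is quadrupled. A standalone sketch of the two budget formulas from the hunk above (the tensor count of 1000 is an assumed illustrative value, not a real model's):

```cpp
#include <algorithm>
#include <cstdint>
#include <cstdio>

int main() {
    const uint32_t n_tensors = 1000; // assumed illustrative tensor count

    // default budget vs. the enlarged Qwen3 Next budget from the hunk above
    const uint32_t dflt = std::max<uint32_t>(1024u, 8u  * n_tensors);
    const uint32_t q3n  = std::max<uint32_t>(8192u, 32u * n_tensors);

    std::printf("default: %u nodes, qwen3next: %u nodes\n", dflt, q3n); // 8000 vs 32000
    return 0;
}
```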
@@ -6,7 +6,7 @@

 // bump if necessary
 #define LLAMA_MAX_LAYERS 512
-#define LLAMA_MAX_EXPERTS 384 // Kimi-K2
+#define LLAMA_MAX_EXPERTS 512 // Qwen3 Next

 enum llama_expert_gating_func_type {
     LLAMA_EXPERT_GATING_FUNC_TYPE_NONE = 0,
@@ -2,7 +2,6 @@

 #include "llama-impl.h"
 #include "llama-mmap.h"
-#include "llama-batch.h"
 #include "llama-cparams.h"
 #include "llama-model-loader.h"
@@ -2225,6 +2224,29 @@ void llama_model::load_hparams(llama_model_loader & ml) {
                     default: type = LLM_TYPE_UNKNOWN;
                 }
             } break;
+        case LLM_ARCH_QWEN3NEXT:
+            {
+                ml.get_key(LLM_KV_EXPERT_FEED_FORWARD_LENGTH, hparams.n_ff_exp, false);
+                ml.get_key(LLM_KV_EXPERT_SHARED_FEED_FORWARD_LENGTH, hparams.n_ff_shexp, false);
+                ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
+
+                // Load linear attention (gated delta net) parameters
+                ml.get_key(LLM_KV_SSM_CONV_KERNEL, hparams.ssm_d_conv);
+                ml.get_key(LLM_KV_SSM_INNER_SIZE, hparams.ssm_d_inner);
+                ml.get_key(LLM_KV_SSM_STATE_SIZE, hparams.ssm_d_state);
+                ml.get_key(LLM_KV_SSM_TIME_STEP_RANK, hparams.ssm_dt_rank);
+                ml.get_key(LLM_KV_SSM_GROUP_COUNT, hparams.ssm_n_group);
+
+                // Mark recurrent layers (linear attention layers)
+                for (uint32_t i = 0; i < hparams.n_layer; ++i) {
+                    hparams.recurrent_layer_arr[i] = ((i + 1) % 4 != 0); // TODO: extract the magic 4 from "full_attention_interval"
+                }
+
+                switch (hparams.n_layer) {
+                    case 80: type = LLM_TYPE_80B_A3B; break;
+                    default: type = LLM_TYPE_UNKNOWN;
+                }
+            } break;
         default: throw std::runtime_error("unsupported model architecture");
     }
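The recurrent-layer loop above hard-codes Qwen3 Next's layout: every fourth layer uses full attention, the rest are linear-attention (gated delta net) layers. A standalone sketch of the resulting pattern, with an assumed 12-layer toy model in place of the real 80 layers:

```cpp
#include <cstdint>
#include <cstdio>

int main() {
    const uint32_t n_layer = 12; // assumed toy value; Qwen3 Next 80B-A3B has 80

    for (uint32_t i = 0; i < n_layer; ++i) {
        // same predicate as the hunk above: layers 3, 7, 11, ... are full attention
        const bool recurrent = ((i + 1) % 4 != 0);
        std::printf("layer %2u: %s\n", i,
                    recurrent ? "linear attention (recurrent)" : "full attention");
    }
    return 0;
}
```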
@@ -6133,9 +6155,10 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
         case LLM_ARCH_LFM2:
         case LLM_ARCH_LFM2MOE:
             {
                 tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
-                tok_norm = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD_NORM, "weight"), {n_embd}, 0);
-                output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, TENSOR_NOT_REQUIRED);
+                output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0);
+                output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, TENSOR_NOT_REQUIRED);

                 if (output == NULL) {
                     output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED);
@@ -6414,6 +6437,74 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
                     layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0);
                 }
             } break;
+        case LLM_ARCH_QWEN3NEXT:
+            {
+                tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), { n_embd, n_vocab }, 0);
+
+                // output
+                output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), { n_embd }, 0);
+                output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), { n_embd, n_vocab }, TENSOR_NOT_REQUIRED);
+
+                // if output is NULL, init from the input tok embed
+                if (output == NULL) {
+                    output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), { n_embd, n_vocab }, TENSOR_DUPLICATED);
+                }
+
+                const int64_t n_ff_exp = hparams.n_ff_exp ? hparams.n_ff_exp : n_ff / n_expert_used;
+
+                // Calculate dimensions from hyperparameters
+                const int64_t head_k_dim = hparams.ssm_d_state;
+                const int64_t head_v_dim = hparams.ssm_d_state;
+                const int64_t n_k_heads = hparams.ssm_n_group;
+                const int64_t n_v_heads = hparams.ssm_dt_rank;
+                const int64_t key_dim = head_k_dim * n_k_heads;
+                const int64_t value_dim = head_v_dim * n_v_heads;
+                const int64_t conv_dim = key_dim * 2 + value_dim;
+
+                // Calculate projection sizes
+                const int64_t qkvz_dim = key_dim * 2 + value_dim * 2;
+                const int64_t ba_dim = n_v_heads * 2;
+
+                for (int i = 0; i < n_layer; ++i) {
+                    auto & layer = layers[i];
+
+                    layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), { n_embd }, 0);
+                    layer.attn_post_norm = create_tensor(tn(LLM_TENSOR_ATTN_POST_NORM, "weight", i), { n_embd }, 0);
+
+                    if (!hparams.is_recurrent(i)) {
+                        // Attention layers
+                        layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), { n_embd, n_embd_head_k * n_head * 2 }, 0);
+                        layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), { n_embd, n_embd_k_gqa }, 0);
+                        layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), { n_embd, n_embd_v_gqa }, 0);
+                        layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), { n_embd_head_k * n_head, n_embd }, 0);
+
+                        // Q/K normalization for attention layers
+                        layer.attn_q_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), { n_embd_head_k }, 0);
+                        layer.attn_k_norm = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), { n_embd_head_k }, 0);
+                    } else {
+                        // Linear attention (gated delta net) specific tensors
+                        // Create tensors with calculated dimensions
+                        layer.ssm_in = create_tensor(tn(LLM_TENSOR_SSM_IN, "weight", i), { n_embd, qkvz_dim }, 0);
+                        layer.ssm_conv1d = create_tensor(tn(LLM_TENSOR_SSM_CONV1D, "weight", i), { hparams.ssm_d_conv, conv_dim }, 0);
+                        layer.ssm_dt = create_tensor(tn(LLM_TENSOR_SSM_DT, "bias", i), { hparams.ssm_dt_rank }, 0);
+                        layer.ssm_a = create_tensor(tn(LLM_TENSOR_SSM_A, i), { hparams.ssm_dt_rank }, 0);
+                        layer.ssm_beta_alpha = create_tensor(tn(LLM_TENSOR_SSM_BETA_ALPHA, "weight", i), { n_embd, ba_dim }, 0);
+                        layer.ssm_norm = create_tensor(tn(LLM_TENSOR_SSM_NORM, "weight", i), { head_v_dim }, 0);
+                        layer.ssm_out = create_tensor(tn(LLM_TENSOR_SSM_OUT, "weight", i), { value_dim, n_embd }, 0);
+                    }
+
+                    layer.ffn_gate_inp = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), { n_embd, n_expert }, 0);
+                    layer.ffn_gate_exps = create_tensor(tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i), { n_embd, n_ff_exp, n_expert }, 0);
+                    layer.ffn_down_exps = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), { n_ff_exp, n_embd, n_expert }, 0);
+                    layer.ffn_up_exps = create_tensor(tn(LLM_TENSOR_FFN_UP_EXPS, "weight", i), { n_embd, n_ff_exp, n_expert }, 0);
+
+                    // Shared experts
+                    layer.ffn_gate_inp_shexp = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP_SHEXP, "weight", i), { n_embd }, 0);
+                    layer.ffn_gate_shexp = create_tensor(tn(LLM_TENSOR_FFN_GATE_SHEXP, "weight", i), { n_embd, hparams.n_ff_shexp }, 0);
+                    layer.ffn_up_shexp = create_tensor(tn(LLM_TENSOR_FFN_UP_SHEXP, "weight", i), { n_embd, hparams.n_ff_shexp }, 0);
+                    layer.ffn_down_shexp = create_tensor(tn(LLM_TENSOR_FFN_DOWN_SHEXP, "weight", i), { hparams.n_ff_shexp, n_embd }, 0);
+                }
+            } break;
         default:
             throw std::runtime_error("unknown architecture");
     }
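All delta-net tensor shapes above derive from four GGUF keys. A standalone sketch of the arithmetic; the head dimension and head counts are assumptions based on the published Qwen3-Next-80B-A3B config (head dim 128, 16 key heads, 32 value heads), not values read from a model file:

```cpp
#include <cstdint>
#include <cstdio>

int main() {
    // assumed hparams, mirroring what convert_hf_to_gguf.py maps:
    // ssm_d_state <- linear_key_head_dim, ssm_n_group <- linear_num_key_heads,
    // ssm_dt_rank <- linear_num_value_heads, ssm_d_conv <- linear_conv_kernel_dim
    const int64_t head_k_dim = 128; // ssm_d_state
    const int64_t head_v_dim = 128; // ssm_d_state
    const int64_t n_k_heads  = 16;  // ssm_n_group
    const int64_t n_v_heads  = 32;  // ssm_dt_rank

    const int64_t key_dim   = head_k_dim * n_k_heads;      // 2048
    const int64_t value_dim = head_v_dim * n_v_heads;      // 4096
    const int64_t conv_dim  = key_dim * 2 + value_dim;     // 8192 rows in ssm_conv1d
    const int64_t qkvz_dim  = key_dim * 2 + value_dim * 2; // 12288 cols in ssm_in
    const int64_t ba_dim    = n_v_heads * 2;               // 64 cols in ssm_beta_alpha

    std::printf("key=%lld value=%lld conv=%lld qkvz=%lld ba=%lld\n",
                (long long) key_dim, (long long) value_dim, (long long) conv_dim,
                (long long) qkvz_dim, (long long) ba_dim);
    return 0;
}
```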
@@ -6684,6 +6775,7 @@ void llama_model::print_info() const {
         arch == LLM_ARCH_FALCON_H1 ||
         arch == LLM_ARCH_PLAMO2 ||
         arch == LLM_ARCH_GRANITE_HYBRID ||
+        arch == LLM_ARCH_QWEN3NEXT ||
         arch == LLM_ARCH_NEMOTRON_H) {
         LLAMA_LOG_INFO("%s: ssm_d_conv = %u\n", __func__, hparams.ssm_d_conv);
         LLAMA_LOG_INFO("%s: ssm_d_inner = %u\n", __func__, hparams.ssm_d_inner);
@@ -7425,7 +7517,11 @@ ggml_cgraph * llama_model::build_graph(const llm_graph_params & params) const {
         case LLM_ARCH_PANGU_EMBED:
             {
                 llm = std::make_unique<llm_build_pangu_embedded>(*this, params);
-            }break;
+            } break;
+        case LLM_ARCH_QWEN3NEXT:
+            {
+                llm = std::make_unique<llm_build_qwen3next>(*this, params);
+            } break;
         default:
             GGML_ABORT("fatal error");
     }

@@ -7652,6 +7748,7 @@ llama_rope_type llama_model_rope_type(const llama_model * model) {
         case LLM_ARCH_COGVLM:
         case LLM_ARCH_PANGU_EMBED:
         case LLM_ARCH_AFMOE:
+        case LLM_ARCH_QWEN3NEXT:
             return LLAMA_ROPE_TYPE_NEOX;

         case LLM_ARCH_QWEN2VL:
@@ -113,6 +113,7 @@ enum llm_type {
     LLM_TYPE_16B_A1B,
     LLM_TYPE_21B_A3B, // Ernie MoE small
     LLM_TYPE_30B_A3B,
+    LLM_TYPE_80B_A3B, // Qwen3 Next
     LLM_TYPE_100B_A6B,
     LLM_TYPE_106B_A12B, // GLM-4.5-Air
     LLM_TYPE_230B_A10B, // Minimax M2

@@ -309,6 +310,9 @@ struct llama_layer {
     struct ggml_tensor * ssm_conv1d_b = nullptr;
     struct ggml_tensor * ssm_dt_b = nullptr;

+    // qwen3next
+    struct ggml_tensor * ssm_beta_alpha = nullptr;
+
     // rwkv
     struct ggml_tensor * time_mix_w1 = nullptr;
     struct ggml_tensor * time_mix_w2 = nullptr;
@@ -681,7 +681,9 @@ static void llama_model_quantize_impl(const std::string & fname_inp, const std::
            }
            LLAMA_LOG_DEBUG("%s: pruning tensor %s\n", __func__, it.first.c_str());
            continue;
-        } else if (remapped_name != it.first) {
+        }
+
+        if (remapped_name != it.first) {
            ggml_set_name(it.second.tensor, remapped_name.c_str());
            LLAMA_LOG_DEBUG("%s: tensor %s remapped to %s\n", __func__, it.first.c_str(), ggml_get_name(it.second.tensor));
        }
@@ -726,13 +728,19 @@ static void llama_model_quantize_impl(const std::string & fname_inp, const std::
    {
        const auto & n_head_kv_iter = model.hparams.n_head_kv_arr.begin();
        // attention layers have a non-zero number of kv heads
-        int32_t n_attn_layer = model.hparams.n_layer - std::count(n_head_kv_iter, n_head_kv_iter + model.hparams.n_layer, 0);
+        int32_t n_layer_attn = model.hparams.n_layer - std::count(n_head_kv_iter, n_head_kv_iter + model.hparams.n_layer, 0);
        if (llama_model_has_encoder(&model)) {
-            // now n_attn_layer is the number of attention layers in the encoder
+            // now n_layer_attn is the number of attention layers in the encoder
            // for each decoder block, there are 2 attention layers
-            n_attn_layer += 2 * model.hparams.dec_n_layer;
+            n_layer_attn += 2 * model.hparams.dec_n_layer;
        }
-        GGML_ASSERT((qs.n_attention_wv == n_attn_layer - pruned_attention_w) && "n_attention_wv is unexpected");
+
+        // note: for linear-attention models (such as Qwen3 Next) this is the number of linear layers
+        const int32_t n_layer_recr = std::count(model.hparams.recurrent_layer_arr.begin(), model.hparams.recurrent_layer_arr.end(), true);
+
+        LLAMA_LOG_INFO("%s: n_layer_attn = %d, n_layer_recr = %d, pruned_attention_w = %d\n", __func__, n_layer_attn, n_layer_recr, pruned_attention_w);
+
+        GGML_ASSERT((qs.n_attention_wv == n_layer_attn - pruned_attention_w - n_layer_recr) && "n_attention_wv is unexpected");
    }

    size_t total_size_org = 0;
@@ -9,6 +9,8 @@ llm_build_lfm2::llm_build_lfm2(const llama_model & model, const llm_graph_params
    ggml_tensor * cur = build_inp_embd(model.tok_embd);
    cb(cur, "model.embed_tokens", -1);

+    ggml_build_forward_expand(gf, cur);
+
    ggml_tensor * inp_pos = build_inp_pos();
    auto * inp_hybrid = build_inp_mem_hybrid();
    ggml_tensor * inp_out_ids = build_inp_out_ids();

@@ -40,12 +42,12 @@ llm_build_lfm2::llm_build_lfm2(const llama_model & model, const llm_graph_params
        cur = ggml_add(ctx0, cur, ffn_out);
    }

-    cur = build_norm(cur, model.tok_norm, NULL, LLM_NORM_RMS, -1);
-    cb(cur, "model.embedding_norm", -1);
+    cur = build_norm(cur, model.output_norm, NULL, LLM_NORM_RMS, -1);
+    cb(cur, "result_norm", -1);
    res->t_embd = cur;

    cur = build_lora_mm(model.output, cur);
-    cb(cur, "lm_head", -1);
+    cb(cur, "result_output", -1);

    res->t_logits = cur;
@@ -2,8 +2,9 @@

 #include "../llama-model.h"
 #include "../llama-graph.h"
-#include "../llama-memory-recurrent.h"

+// TODO: remove in follow-up PR - move to .cpp files
+#include "../llama-memory-recurrent.h"
 #include <cmath>

 struct llm_graph_context_mamba : public llm_graph_context {
@@ -421,7 +422,56 @@ struct llm_build_qwen3vl : public llm_graph_context {
 struct llm_build_qwen3vlmoe : public llm_graph_context {
     llm_build_qwen3vlmoe(const llama_model & model, const llm_graph_params & params);
 };
+
+struct llm_build_qwen3next : public llm_graph_context_mamba {
+    llm_build_qwen3next(const llama_model & model, const llm_graph_params & params);
+private:
+    ggml_tensor * build_layer_attn(
+            llm_graph_input_attn_kv * inp_attn,
+            ggml_tensor * cur,
+            ggml_tensor * inp_pos,
+            int il);
+
+    ggml_tensor * build_layer_attn_linear(
+            llm_graph_input_rs * inp,
+            ggml_tensor * cur,
+            ggml_tensor * causal_mask,
+            ggml_tensor * identity,
+            int il);
+
+    ggml_tensor * build_layer_ffn(
+            ggml_tensor * cur,
+            int il);
+
+    ggml_tensor * build_delta_net_recurrent(
+            ggml_tensor * q,
+            ggml_tensor * k,
+            ggml_tensor * v,
+            ggml_tensor * g,
+            ggml_tensor * beta,
+            ggml_tensor * state,
+            ggml_tensor * causal_mask,
+            ggml_tensor * identity,
+            int il);
+
+    ggml_tensor * build_delta_net_chunking(
+            ggml_tensor * q,
+            ggml_tensor * k,
+            ggml_tensor * v,
+            ggml_tensor * g,
+            ggml_tensor * beta,
+            ggml_tensor * state,
+            ggml_tensor * causal_mask,
+            ggml_tensor * identity,
+            int il);
+
+    ggml_tensor * build_norm_gated(
+            ggml_tensor * input,
+            ggml_tensor * weights,
+            ggml_tensor * gate,
+            int layer);
+
+    const llama_model & model;
+};

 struct llm_build_qwen : public llm_graph_context {
     llm_build_qwen(const llama_model & model, const llm_graph_params & params);
File diff suppressed because it is too large
@@ -7635,6 +7635,14 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
        test_cases.emplace_back(new test_argsort(GGML_TYPE_F32, {2, 8, 8192, 1}, order)); // bailingmoe2 (group selection)
    }

+    for (int i = 0; i < 20; ++i) {
+        for (int k : {1, 2, 3, 7, 15, 100, 500, 1023, 9999}) {
+            if (k <= 1<<i) {
+                test_cases.emplace_back(new test_top_k(GGML_TYPE_F32, {(1<<i), 1, 1, 1}, k));
+                test_cases.emplace_back(new test_top_k(GGML_TYPE_F32, {(1<<i) + 11, 1, 2, 1}, k));
+            }
+        }
+    }
    for (int k : {1, 2, 3, 7, 15}) {
        test_cases.emplace_back(new test_top_k(GGML_TYPE_F32, {16, 10, 10, 10}, k));
        test_cases.emplace_back(new test_top_k(GGML_TYPE_F32, {60, 10, 10, 10}, k));
@@ -7927,6 +7935,9 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_perf() {
    test_cases.emplace_back(new test_mul_mat(GGML_TYPE_F16, GGML_TYPE_F32, 16416, 1, 128, {8, 1}, {4, 1}, {0, 2, 1, 3}));
    test_cases.emplace_back(new test_mul_mat(GGML_TYPE_F16, GGML_TYPE_F32, 128, 1, 16416, {8, 1}, {4, 1}, {0, 1, 2, 3}, 2*16416));

+    test_cases.emplace_back(new test_solve_tri(GGML_TYPE_F32, { 64, 64, 4, 2 }, { 6, 64, 4, 2 }));
+    test_cases.emplace_back(new test_solve_tri(GGML_TYPE_F32, { 128, 128, 4, 1 }, { 8, 128, 4, 1 }));
+
    for (int bs : {1, 2, 3, 4, 5, 8, 512}) {
        for (ggml_type type_a : all_types) {
            for (ggml_type type_b : {GGML_TYPE_F32}) {
@@ -8032,7 +8043,13 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_perf() {
    }

    test_cases.emplace_back(new test_argsort(GGML_TYPE_F32, {65000, 16, 1, 1}));
-    test_cases.emplace_back(new test_top_k(GGML_TYPE_F32, {65000, 16, 1, 1}, 40));
+    for (auto k : {1, 10, 40}) {
+        for (auto nrows : {1, 16}) {
+            for (auto cols : {k, 1000, 65000, 200000}) {
+                test_cases.emplace_back(new test_top_k(GGML_TYPE_F32, {cols, nrows, 1, 1}, k));
+            }
+        }
+    }

    return test_cases;
}
@@ -1175,10 +1175,11 @@ struct clip_graph {
        cb(K, "resampler_K", -1);
        cb(V, "resampler_V", -1);

+        float resampler_kq_scale = 1.0f/ sqrtf(float(d_head));
        embeddings = build_attn(
            model.mm_model_attn_o_w,
            model.mm_model_attn_o_b,
-            Q, K, V, nullptr, kq_scale, -1);
+            Q, K, V, nullptr, resampler_kq_scale, -1);
        cb(embeddings, "resampler_attn_out", -1);
    }
    // layernorm
@@ -7,6 +7,7 @@ Set of LLM REST APIs and a simple web front end to interact with llama.cpp.
 **Features:**
 * LLM inference of F16 and quantized models on GPU and CPU
 * [OpenAI API](https://github.com/openai/openai-openapi) compatible chat completions and embeddings routes
+* [Anthropic Messages API](https://docs.anthropic.com/en/api/messages) compatible chat completions
 * Reranking endpoint (https://github.com/ggml-org/llama.cpp/pull/9510)
 * Parallel decoding with multi-user support
 * Continuous batching
@@ -200,6 +201,8 @@ The project is under active development, and we are [looking for feedback and co
 | `--models-allow-extra-args` | for router server, allow extra arguments for models; important: some arguments can allow users to access local file system, use with caution (default: disabled)<br/>(env: LLAMA_ARG_MODELS_ALLOW_EXTRA_ARGS) |
 | `--no-models-autoload` | disables automatic loading of models (default: enabled)<br/>(env: LLAMA_ARG_NO_MODELS_AUTOLOAD) |
 | `--jinja` | use jinja template for chat (default: disabled)<br/>(env: LLAMA_ARG_JINJA) |
+| `--jinja` | use jinja template for chat (default: enabled)<br/><br/>(env: LLAMA_ARG_JINJA) |
+| `--no-jinja` | disable jinja template for chat (default: enabled)<br/><br/>(env: LLAMA_ARG_NO_JINJA) |
 | `--reasoning-format FORMAT` | controls whether thought tags are allowed and/or extracted from the response, and in which format they're returned; one of:<br/>- none: leaves thoughts unparsed in `message.content`<br/>- deepseek: puts thoughts in `message.reasoning_content`<br/>- deepseek-legacy: keeps `<think>` tags in `message.content` while also populating `message.reasoning_content`<br/>(default: auto)<br/>(env: LLAMA_ARG_THINK) |
 | `--reasoning-budget N` | controls the amount of thinking allowed; currently only one of: -1 for unrestricted thinking budget, or 0 to disable thinking (default: -1)<br/>(env: LLAMA_ARG_THINK_BUDGET) |
 | `--chat-template JINJA_TEMPLATE` | set custom jinja chat template (default: template taken from model's metadata)<br/>if suffix/prefix are specified, template will be disabled<br/>only commonly used templates are accepted (unless --jinja is set before this flag):<br/>list of built-in templates:<br/>bailing, bailing-think, bailing2, chatglm3, chatglm4, chatml, command-r, deepseek, deepseek2, deepseek3, exaone3, exaone4, falcon3, gemma, gigachat, glmedge, gpt-oss, granite, grok-2, hunyuan-dense, hunyuan-moe, kimi-k2, llama2, llama2-sys, llama2-sys-bos, llama2-sys-strip, llama3, llama4, megrez, minicpm, mistral-v1, mistral-v3, mistral-v3-tekken, mistral-v7, mistral-v7-tekken, monarch, openchat, orion, pangu-embedded, phi3, phi4, rwkv-world, seed_oss, smolvlm, vicuna, vicuna-orca, yandex, zephyr<br/>(env: LLAMA_ARG_CHAT_TEMPLATE) |
@@ -1355,6 +1358,76 @@ See [OpenAI Embeddings API documentation](https://platform.openai.com/docs/api-r
     }'
 ```

+### POST `/v1/messages`: Anthropic-compatible Messages API
+
+Given a list of `messages`, returns the assistant's response. Streaming is supported via Server-Sent Events. While no strong claims of compatibility with the Anthropic API spec are made, in our experience it suffices to support many apps.
+
+*Options:*
+
+See [Anthropic Messages API documentation](https://docs.anthropic.com/en/api/messages). Tool use requires `--jinja` flag.
+
+`model`: Model identifier (required)
+
+`messages`: Array of message objects with `role` and `content` (required)
+
+`max_tokens`: Maximum tokens to generate (default: 4096)
+
+`system`: System prompt as string or array of content blocks
+
+`temperature`: Sampling temperature 0-1 (default: 1.0)
+
+`top_p`: Nucleus sampling (default: 1.0)
+
+`top_k`: Top-k sampling
+
+`stop_sequences`: Array of stop sequences
+
+`stream`: Enable streaming (default: false)
+
+`tools`: Array of tool definitions (requires `--jinja`)
+
+`tool_choice`: Tool selection mode (`{"type": "auto"}`, `{"type": "any"}`, or `{"type": "tool", "name": "..."}`)
+
+*Examples:*
+
+```shell
+curl http://localhost:8080/v1/messages \
+    -H "Content-Type: application/json" \
+    -H "x-api-key: your-api-key" \
+    -d '{
+        "model": "gpt-4",
+        "max_tokens": 1024,
+        "system": "You are a helpful assistant.",
+        "messages": [
+            {"role": "user", "content": "Hello!"}
+        ]
+    }'
+```
+
+### POST `/v1/messages/count_tokens`: Token Counting
+
+Counts the number of tokens in a request without generating a response.
+
+Accepts the same parameters as `/v1/messages`. The `max_tokens` parameter is not required.
+
+*Example:*
+
+```shell
+curl http://localhost:8080/v1/messages/count_tokens \
+    -H "Content-Type: application/json" \
+    -d '{
+        "model": "gpt-4",
+        "messages": [
+            {"role": "user", "content": "Hello!"}
+        ]
+    }'
+```
+
+*Response:*
+
+```json
+{"input_tokens": 10}
+```
+
 ## Using multiple models
|
||||||
}
|
}
|
||||||
```
|
```
|
||||||
|
|
||||||
|
|
||||||
### POST `/models/unload`: Unload a model
|
### POST `/models/unload`: Unload a model
|
||||||
|
|
||||||
Unload a model
|
Unload a model
|
||||||
|
|
|
||||||
|
|
@@ -725,7 +725,6 @@ std::vector<server_tokens> tokenize_input_prompts(const llama_vocab * vocab, mtm
    return result;
}

-
//
// OAI utils
//
@@ -1048,6 +1047,222 @@ json oaicompat_chat_params_parse(
    return llama_params;
}

+json convert_anthropic_to_oai(const json & body) {
+    json oai_body;
+
+    // Convert system prompt
+    json oai_messages = json::array();
+    auto system_param = json_value(body, "system", json());
+    if (!system_param.is_null()) {
+        std::string system_content;
+
+        if (system_param.is_string()) {
+            system_content = system_param.get<std::string>();
+        } else if (system_param.is_array()) {
+            for (const auto & block : system_param) {
+                if (json_value(block, "type", std::string()) == "text") {
+                    system_content += json_value(block, "text", std::string());
+                }
+            }
+        }
+
+        oai_messages.push_back({
+            {"role", "system"},
+            {"content", system_content}
+        });
+    }
+
+    // Convert messages
+    if (!body.contains("messages")) {
+        throw std::runtime_error("'messages' is required");
+    }
+    const json & messages = body.at("messages");
+    if (messages.is_array()) {
+        for (const auto & msg : messages) {
+            std::string role = json_value(msg, "role", std::string());
+
+            if (!msg.contains("content")) {
+                if (role == "assistant") {
+                    continue;
+                }
+                oai_messages.push_back(msg);
+                continue;
+            }
+
+            const json & content = msg.at("content");
+
+            if (content.is_string()) {
+                oai_messages.push_back(msg);
+                continue;
+            }
+
+            if (!content.is_array()) {
+                oai_messages.push_back(msg);
+                continue;
+            }
+
+            json tool_calls = json::array();
+            json converted_content = json::array();
+            json tool_results = json::array();
+            bool has_tool_calls = false;
+
+            for (const auto & block : content) {
+                std::string type = json_value(block, "type", std::string());
+
+                if (type == "text") {
+                    converted_content.push_back(block);
+                } else if (type == "image") {
+                    json source = json_value(block, "source", json::object());
+                    std::string source_type = json_value(source, "type", std::string());
+
+                    if (source_type == "base64") {
+                        std::string media_type = json_value(source, "media_type", std::string("image/jpeg"));
+                        std::string data = json_value(source, "data", std::string());
+                        std::ostringstream ss;
+                        ss << "data:" << media_type << ";base64," << data;
+
+                        converted_content.push_back({
+                            {"type", "image_url"},
+                            {"image_url", {
+                                {"url", ss.str()}
+                            }}
+                        });
+                    } else if (source_type == "url") {
+                        std::string url = json_value(source, "url", std::string());
+                        converted_content.push_back({
+                            {"type", "image_url"},
+                            {"image_url", {
+                                {"url", url}
+                            }}
+                        });
+                    }
+                } else if (type == "tool_use") {
+                    tool_calls.push_back({
+                        {"id", json_value(block, "id", std::string())},
+                        {"type", "function"},
+                        {"function", {
+                            {"name", json_value(block, "name", std::string())},
+                            {"arguments", json_value(block, "input", json::object()).dump()}
+                        }}
+                    });
+                    has_tool_calls = true;
+                } else if (type == "tool_result") {
+                    std::string tool_use_id = json_value(block, "tool_use_id", std::string());
+
+                    auto result_content = json_value(block, "content", json());
+                    std::string result_text;
+                    if (result_content.is_string()) {
+                        result_text = result_content.get<std::string>();
+                    } else if (result_content.is_array()) {
+                        for (const auto & c : result_content) {
+                            if (json_value(c, "type", std::string()) == "text") {
+                                result_text += json_value(c, "text", std::string());
+                            }
+                        }
+                    }
+
+                    tool_results.push_back({
+                        {"role", "tool"},
+                        {"tool_call_id", tool_use_id},
+                        {"content", result_text}
+                    });
+                }
+            }
+
+            if (!converted_content.empty() || has_tool_calls) {
+                json new_msg = {{"role", role}};
+                if (!converted_content.empty()) {
+                    new_msg["content"] = converted_content;
+                } else if (has_tool_calls) {
+                    new_msg["content"] = "";
+                }
+                if (!tool_calls.empty()) {
+                    new_msg["tool_calls"] = tool_calls;
+                }
+                oai_messages.push_back(new_msg);
+            }
+
+            for (const auto & tool_msg : tool_results) {
+                oai_messages.push_back(tool_msg);
+            }
+        }
+    }
+
+    oai_body["messages"] = oai_messages;
+
+    // Convert tools
+    if (body.contains("tools")) {
+        const json & tools = body.at("tools");
+        if (tools.is_array()) {
+            json oai_tools = json::array();
+            for (const auto & tool : tools) {
+                oai_tools.push_back({
+                    {"type", "function"},
+                    {"function", {
+                        {"name", json_value(tool, "name", std::string())},
+                        {"description", json_value(tool, "description", std::string())},
+                        {"parameters", tool.contains("input_schema") ? tool.at("input_schema") : json::object()}
+                    }}
+                });
+            }
+            oai_body["tools"] = oai_tools;
+        }
+    }
+
+    // Convert tool_choice
+    if (body.contains("tool_choice")) {
+        const json & tc = body.at("tool_choice");
+        if (tc.is_object()) {
+            std::string type = json_value(tc, "type", std::string());
+            if (type == "auto") {
+                oai_body["tool_choice"] = "auto";
+            } else if (type == "any" || type == "tool") {
+                oai_body["tool_choice"] = "required";
+            }
+        }
+    }
+
+    // Convert stop_sequences to stop
+    if (body.contains("stop_sequences")) {
+        oai_body["stop"] = body.at("stop_sequences");
+    }
+
+    // Handle max_tokens (required in Anthropic, but we're permissive)
+    if (body.contains("max_tokens")) {
+        oai_body["max_tokens"] = body.at("max_tokens");
+    } else {
+        oai_body["max_tokens"] = 4096;
+    }
+
+    // Pass through common params
+    for (const auto & key : {"temperature", "top_p", "top_k", "stream"}) {
+        if (body.contains(key)) {
+            oai_body[key] = body.at(key);
+        }
+    }
+
+    // Handle Anthropic-specific thinking param
+    if (body.contains("thinking")) {
+        json thinking = json_value(body, "thinking", json::object());
+        std::string thinking_type = json_value(thinking, "type", std::string());
+        if (thinking_type == "enabled") {
+            int budget_tokens = json_value(thinking, "budget_tokens", 10000);
+            oai_body["thinking_budget_tokens"] = budget_tokens;
+        }
+    }
+
+    // Handle Anthropic-specific metadata param
+    if (body.contains("metadata")) {
+        json metadata = json_value(body, "metadata", json::object());
+        std::string user_id = json_value(metadata, "user_id", std::string());
+        if (!user_id.empty()) {
+            oai_body["__metadata_user_id"] = user_id;
+        }
+    }
+
+    return oai_body;
+}
+
json format_embeddings_response_oaicompat(const json & request, const json & embeddings, bool use_base64) {
    json data = json::array();
    int32_t n_tokens = 0;
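To make the conversion concrete, a standalone sketch of one mapping performed by convert_anthropic_to_oai above: an Anthropic tool_result block becomes an OpenAI-style "tool" message. The JSON here is hand-written to mirror the function's behavior rather than produced by the server code:

```cpp
#include <nlohmann/json.hpp>
#include <iostream>

using json = nlohmann::json;

int main() {
    // Anthropic-style tool_result content block inside a user turn
    json block = {
        {"type", "tool_result"},
        {"tool_use_id", "toolu_123"},          // hypothetical id for illustration
        {"content", "22 degrees, sunny"}
    };
    json anthropic_msg = {
        {"role", "user"},
        {"content", json::array({block})}
    };

    // ...which the helper above flattens into an OpenAI-style "tool" message
    json oai_msg = {
        {"role", "tool"},
        {"tool_call_id", "toolu_123"},
        {"content", "22 degrees, sunny"}
    };

    std::cout << anthropic_msg.dump(2) << "\n->\n" << oai_msg.dump(2) << std::endl;
    return 0;
}
```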
@@ -1211,7 +1426,7 @@ std::string tokens_to_output_formatted_string(const llama_context * ctx, const l

// format server-sent event (SSE), return the formatted string to send
// note: if data is a json array, it will be sent as multiple events, one per item
-std::string format_sse(const json & data) {
+std::string format_oai_sse(const json & data) {
    std::ostringstream ss;
    auto send_single = [&ss](const json & data) {
        ss << "data: " <<
@@ -1230,6 +1445,29 @@ std::string format_sse(const json & data) {
    return ss.str();
}

+std::string format_anthropic_sse(const json & data) {
+    std::ostringstream ss;
+
+    auto send_event = [&ss](const json & event_obj) {
+        if (event_obj.contains("event") && event_obj.contains("data")) {
+            ss << "event: " << event_obj.at("event").get<std::string>() << "\n";
+            ss << "data: " << safe_json_to_str(event_obj.at("data")) << "\n\n";
+        } else {
+            ss << "data: " << safe_json_to_str(event_obj) << "\n\n";
+        }
+    };
+
+    if (data.is_array()) {
+        for (const auto & event : data) {
+            send_event(event);
+        }
+    } else {
+        send_event(data);
+    }
+
+    return ss.str();
+}
+
bool is_valid_utf8(const std::string & str) {
    const unsigned char* bytes = reinterpret_cast<const unsigned char*>(str.data());
    const unsigned char* end = bytes + str.length();
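For reference, a standalone sketch of the wire format format_anthropic_sse emits for a single streaming event; the SSE assembly is inlined here and plain dump() stands in for the server's safe_json_to_str:

```cpp
#include <nlohmann/json.hpp>
#include <iostream>
#include <sstream>

using json = nlohmann::json;

int main() {
    // one Anthropic-style streaming event, as produced by the stream serializers
    json event = {
        {"event", "content_block_delta"},
        {"data", {
            {"type", "content_block_delta"},
            {"index", 0},
            {"delta", {{"type", "text_delta"}, {"text", "Hello"}}}
        }}
    };

    // same layout as format_anthropic_sse: an "event:" line, then a "data:" line
    std::ostringstream ss;
    ss << "event: " << event.at("event").get<std::string>() << "\n";
    ss << "data: "  << event.at("data").dump() << "\n\n";

    std::cout << ss.str();
    return 0;
}
```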
@@ -294,6 +294,9 @@ json oaicompat_chat_params_parse(
    const oaicompat_parser_options & opt,
    std::vector<raw_buffer> & out_files);

+// convert Anthropic Messages API format to OpenAI Chat Completions API format
+json convert_anthropic_to_oai(const json & body);
+
// TODO: move it to server-task.cpp
json format_embeddings_response_oaicompat(const json & request, const json & embeddings, bool use_base64 = false);

@@ -320,7 +323,10 @@ std::string tokens_to_output_formatted_string(const llama_context * ctx, const l

// format server-sent event (SSE), return the formatted string to send
// note: if data is a json array, it will be sent as multiple events, one per item
-std::string format_sse(const json & data);
+std::string format_oai_sse(const json & data);
+
+// format Anthropic-style SSE with event types
+std::string format_anthropic_sse(const json & data);

bool is_valid_utf8(const std::string & str);
@@ -136,15 +136,22 @@ bool server_http_context::init(const common_params & params) {
        return true;
    }

-    // Check for API key in the header
-    auto auth_header = req.get_header_value("Authorization");
+    // Check for API key in the Authorization header
+    std::string req_api_key = req.get_header_value("Authorization");
+    if (req_api_key.empty()) {
+        // retry with anthropic header
+        req_api_key = req.get_header_value("X-Api-Key");
+    }
+
+    // remove the "Bearer " prefix if needed
    std::string prefix = "Bearer ";
-    if (auth_header.substr(0, prefix.size()) == prefix) {
-        std::string received_api_key = auth_header.substr(prefix.size());
-        if (std::find(api_keys.begin(), api_keys.end(), received_api_key) != api_keys.end()) {
-            return true; // API key is valid
-        }
+    if (req_api_key.substr(0, prefix.size()) == prefix) {
+        req_api_key = req_api_key.substr(prefix.size());
+    }
+
+    // validate the API key
+    if (std::find(api_keys.begin(), api_keys.end(), req_api_key) != api_keys.end()) {
+        return true; // API key is valid
    }

    // API key is invalid or not provided
@@ -565,15 +565,17 @@ std::vector<unsigned char> completion_token_output::str_to_bytes(const std::stri
// server_task_result_cmpl_final
//
json server_task_result_cmpl_final::to_json() {
-    switch (oaicompat) {
-        case OAICOMPAT_TYPE_NONE:
+    switch (res_type) {
+        case TASK_RESPONSE_TYPE_NONE:
            return to_json_non_oaicompat();
-        case OAICOMPAT_TYPE_COMPLETION:
+        case TASK_RESPONSE_TYPE_OAI_CMPL:
            return to_json_oaicompat();
-        case OAICOMPAT_TYPE_CHAT:
+        case TASK_RESPONSE_TYPE_OAI_CHAT:
            return stream ? to_json_oaicompat_chat_stream() : to_json_oaicompat_chat();
+        case TASK_RESPONSE_TYPE_ANTHROPIC:
+            return stream ? to_json_anthropic_stream() : to_json_anthropic();
        default:
-            GGML_ASSERT(false && "Invalid oaicompat_type");
+            GGML_ASSERT(false && "Invalid task_response_type");
    }
}
||||||
|
|
@ -768,19 +770,203 @@ json server_task_result_cmpl_final::to_json_oaicompat_chat_stream() {
|
||||||
return deltas;
|
return deltas;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
json server_task_result_cmpl_final::to_json_anthropic() {
|
||||||
|
std::string stop_reason = "max_tokens";
|
||||||
|
if (stop == STOP_TYPE_WORD || stop == STOP_TYPE_EOS) {
|
||||||
|
stop_reason = oaicompat_msg.tool_calls.empty() ? "end_turn" : "tool_use";
|
||||||
|
}
|
||||||
|
|
||||||
|
json content_blocks = json::array();
|
||||||
|
|
||||||
|
common_chat_msg msg;
|
||||||
|
if (!oaicompat_msg.empty()) {
|
||||||
|
msg = oaicompat_msg;
|
||||||
|
} else {
|
||||||
|
msg.role = "assistant";
|
||||||
|
msg.content = content;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (!msg.content.empty()) {
|
||||||
|
content_blocks.push_back({
|
||||||
|
{"type", "text"},
|
||||||
|
{"text", msg.content}
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
for (const auto & tool_call : msg.tool_calls) {
|
||||||
|
json tool_use_block = {
|
||||||
|
{"type", "tool_use"},
|
||||||
|
{"id", tool_call.id},
|
||||||
|
{"name", tool_call.name}
|
||||||
|
};
|
||||||
|
|
||||||
|
try {
|
||||||
|
tool_use_block["input"] = json::parse(tool_call.arguments);
|
||||||
|
} catch (const std::exception &) {
|
||||||
|
tool_use_block["input"] = json::object();
|
||||||
|
}
|
||||||
|
|
||||||
|
content_blocks.push_back(tool_use_block);
|
||||||
|
}
|
||||||
|
|
||||||
|
json res = {
|
||||||
|
{"id", oaicompat_cmpl_id},
|
||||||
|
{"type", "message"},
|
||||||
|
{"role", "assistant"},
|
||||||
|
{"content", content_blocks},
|
||||||
|
{"model", oaicompat_model},
|
||||||
|
{"stop_reason", stop_reason},
|
||||||
|
{"stop_sequence", stopping_word.empty() ? nullptr : json(stopping_word)},
|
||||||
|
{"usage", {
|
||||||
|
{"input_tokens", n_prompt_tokens},
|
||||||
|
{"output_tokens", n_decoded}
|
||||||
|
}}
|
||||||
|
};
|
||||||
|
|
||||||
|
return res;
|
||||||
|
}
|
||||||
|
|
||||||
|
json server_task_result_cmpl_final::to_json_anthropic_stream() {
|
||||||
|
json events = json::array();
|
||||||
|
|
||||||
|
std::string stop_reason = "max_tokens";
|
||||||
|
if (stop == STOP_TYPE_WORD || stop == STOP_TYPE_EOS) {
|
||||||
|
stop_reason = oaicompat_msg.tool_calls.empty() ? "end_turn" : "tool_use";
|
||||||
|
}
|
||||||
|
|
||||||
|
bool has_text = !oaicompat_msg.content.empty();
|
||||||
|
size_t num_tool_calls = oaicompat_msg.tool_calls.size();
|
||||||
|
|
||||||
|
bool text_block_started = false;
|
||||||
|
std::unordered_set<size_t> tool_calls_started;
|
||||||
|
|
||||||
|
for (const auto & diff : oaicompat_msg_diffs) {
|
||||||
|
if (!diff.content_delta.empty()) {
|
||||||
|
if (!text_block_started) {
|
||||||
|
events.push_back({
|
||||||
|
{"event", "content_block_start"},
|
||||||
|
{"data", {
|
||||||
|
{"type", "content_block_start"},
|
||||||
|
{"index", 0},
|
||||||
|
{"content_block", {
|
||||||
|
{"type", "text"},
|
||||||
|
{"text", ""}
|
||||||
|
}}
|
||||||
|
}}
|
||||||
|
});
|
||||||
|
text_block_started = true;
|
||||||
|
}
|
||||||
|
|
||||||
|
events.push_back({
|
||||||
|
{"event", "content_block_delta"},
|
||||||
|
{"data", {
|
||||||
|
{"type", "content_block_delta"},
|
||||||
|
{"index", 0},
|
||||||
|
{"delta", {
|
||||||
|
{"type", "text_delta"},
|
||||||
|
{"text", diff.content_delta}
|
||||||
|
}}
|
||||||
|
}}
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
if (diff.tool_call_index != std::string::npos) {
|
||||||
|
size_t content_block_index = (has_text ? 1 : 0) + diff.tool_call_index;
|
||||||
|
|
||||||
|
if (tool_calls_started.find(diff.tool_call_index) == tool_calls_started.end()) {
|
||||||
|
const auto & full_tool_call = oaicompat_msg.tool_calls[diff.tool_call_index];
|
||||||
|
|
||||||
|
events.push_back({
|
||||||
|
{"event", "content_block_start"},
|
||||||
|
{"data", {
|
||||||
|
{"type", "content_block_start"},
|
||||||
|
{"index", content_block_index},
|
||||||
|
{"content_block", {
|
||||||
|
{"type", "tool_use"},
|
||||||
|
{"id", full_tool_call.id},
|
||||||
|
{"name", full_tool_call.name}
|
||||||
|
}}
|
||||||
|
}}
|
||||||
|
});
|
||||||
|
tool_calls_started.insert(diff.tool_call_index);
|
||||||
|
}
|
||||||
|
|
||||||
|
if (!diff.tool_call_delta.arguments.empty()) {
|
||||||
|
events.push_back({
|
||||||
|
{"event", "content_block_delta"},
|
||||||
|
{"data", {
|
||||||
|
{"type", "content_block_delta"},
|
||||||
|
{"index", content_block_index},
|
||||||
|
{"delta", {
|
||||||
|
{"type", "input_json_delta"},
|
||||||
|
{"partial_json", diff.tool_call_delta.arguments}
|
||||||
|
}}
|
||||||
|
}}
|
||||||
|
});
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if (has_text) {
|
||||||
|
events.push_back({
|
||||||
|
{"event", "content_block_stop"},
|
||||||
|
{"data", {
|
||||||
|
{"type", "content_block_stop"},
|
||||||
|
{"index", 0}
|
||||||
|
}}
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
for (size_t i = 0; i < num_tool_calls; i++) {
|
||||||
|
size_t content_block_index = (has_text ? 1 : 0) + i;
|
||||||
|
events.push_back({
|
||||||
|
{"event", "content_block_stop"},
|
||||||
|
{"data", {
|
||||||
|
{"type", "content_block_stop"},
|
||||||
|
{"index", content_block_index}
|
||||||
|
}}
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
events.push_back({
|
||||||
|
{"event", "message_delta"},
|
||||||
|
{"data", {
|
||||||
|
{"type", "message_delta"},
|
||||||
|
{"delta", {
|
||||||
|
{"stop_reason", stop_reason},
|
||||||
|
{"stop_sequence", stopping_word.empty() ? nullptr : json(stopping_word)}
|
||||||
|
}},
|
||||||
|
{"usage", {
|
||||||
|
{"output_tokens", n_decoded}
|
||||||
|
}}
|
||||||
|
}}
|
||||||
|
});
|
||||||
|
|
||||||
|
events.push_back({
|
||||||
|
{"event", "message_stop"},
|
||||||
|
{"data", {
|
||||||
|
{"type", "message_stop"}
|
||||||
|
}}
|
||||||
|
});
|
||||||
|
|
||||||
|
return events;
|
||||||
|
}
|
||||||
|
|
||||||
//
|
//
|
||||||
// server_task_result_cmpl_partial
|
// server_task_result_cmpl_partial
|
||||||
//
|
//
|
||||||
json server_task_result_cmpl_partial::to_json() {
|
json server_task_result_cmpl_partial::to_json() {
|
||||||
switch (oaicompat) {
|
switch (res_type) {
|
||||||
case OAICOMPAT_TYPE_NONE:
|
case TASK_RESPONSE_TYPE_NONE:
|
||||||
return to_json_non_oaicompat();
|
return to_json_non_oaicompat();
|
||||||
case OAICOMPAT_TYPE_COMPLETION:
|
case TASK_RESPONSE_TYPE_OAI_CMPL:
|
||||||
return to_json_oaicompat();
|
return to_json_oaicompat();
|
||||||
case OAICOMPAT_TYPE_CHAT:
|
case TASK_RESPONSE_TYPE_OAI_CHAT:
|
||||||
return to_json_oaicompat_chat();
|
return to_json_oaicompat_chat();
|
||||||
|
case TASK_RESPONSE_TYPE_ANTHROPIC:
|
||||||
|
return to_json_anthropic();
|
||||||
default:
|
default:
|
||||||
GGML_ASSERT(false && "Invalid oaicompat_type");
|
GGML_ASSERT(false && "Invalid task_response_type");
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@@ -905,7 +1091,7 @@ json server_task_result_cmpl_partial::to_json_oaicompat_chat() {
// server_task_result_embd
//
json server_task_result_embd::to_json() {
-    return oaicompat == OAICOMPAT_TYPE_EMBEDDING
+    return res_type == TASK_RESPONSE_TYPE_OAI_EMBD
        ? to_json_oaicompat()
        : to_json_non_oaicompat();
}
@@ -936,6 +1122,102 @@ json server_task_result_rerank::to_json() {
    };
}

+json server_task_result_cmpl_partial::to_json_anthropic() {
+    json events = json::array();
+    bool first = (n_decoded == 1);
+    static bool text_block_started = false;
+
+    if (first) {
+        text_block_started = false;
+
+        events.push_back({
+            {"event", "message_start"},
+            {"data", {
+                {"type", "message_start"},
+                {"message", {
+                    {"id", oaicompat_cmpl_id},
+                    {"type", "message"},
+                    {"role", "assistant"},
+                    {"content", json::array()},
+                    {"model", oaicompat_model},
+                    {"stop_reason", nullptr},
+                    {"stop_sequence", nullptr},
+                    {"usage", {
+                        {"input_tokens", n_prompt_tokens},
+                        {"output_tokens", 0}
+                    }}
+                }}
+            }}
+        });
+    }
+
+    for (const auto & diff : oaicompat_msg_diffs) {
+        if (!diff.content_delta.empty()) {
+            if (!text_block_started) {
+                events.push_back({
+                    {"event", "content_block_start"},
+                    {"data", {
+                        {"type", "content_block_start"},
+                        {"index", 0},
+                        {"content_block", {
+                            {"type", "text"},
+                            {"text", ""}
+                        }}
+                    }}
+                });
+                text_block_started = true;
+            }
+
+            events.push_back({
+                {"event", "content_block_delta"},
+                {"data", {
+                    {"type", "content_block_delta"},
+                    {"index", 0},
+                    {"delta", {
+                        {"type", "text_delta"},
+                        {"text", diff.content_delta}
+                    }}
+                }}
+            });
+        }
+
+        if (diff.tool_call_index != std::string::npos) {
+            size_t content_block_index = (text_block_started ? 1 : 0) + diff.tool_call_index;
+
+            if (!diff.tool_call_delta.name.empty()) {
+                events.push_back({
+                    {"event", "content_block_start"},
+                    {"data", {
+                        {"type", "content_block_start"},
+                        {"index", content_block_index},
+                        {"content_block", {
+                            {"type", "tool_use"},
+                            {"id", diff.tool_call_delta.id},
+                            {"name", diff.tool_call_delta.name}
+                        }}
+                    }}
+                });
+            }
+
+            if (!diff.tool_call_delta.arguments.empty()) {
+                events.push_back({
+                    {"event", "content_block_delta"},
+                    {"data", {
+                        {"type", "content_block_delta"},
+                        {"index", content_block_index},
+                        {"delta", {
+                            {"type", "input_json_delta"},
+                            {"partial_json", diff.tool_call_delta.arguments}
+                        }}
+                    }}
+                });
+            }
+        }
+    }
+
+    return events;
+}
+
//
// server_task_result_error
//
@@ -27,11 +27,12 @@ enum server_task_type {
};

// TODO: change this to more generic "response_format" to replace the "format_response_*" in server-common
-enum oaicompat_type {
-    OAICOMPAT_TYPE_NONE,
-    OAICOMPAT_TYPE_CHAT,
-    OAICOMPAT_TYPE_COMPLETION,
-    OAICOMPAT_TYPE_EMBEDDING,
+enum task_response_type {
+    TASK_RESPONSE_TYPE_NONE, // llama.cpp native format
+    TASK_RESPONSE_TYPE_OAI_CHAT,
+    TASK_RESPONSE_TYPE_OAI_CMPL,
+    TASK_RESPONSE_TYPE_OAI_EMBD,
+    TASK_RESPONSE_TYPE_ANTHROPIC,
};

enum stop_type {
@@ -66,9 +67,9 @@ struct task_params {
    struct common_params_sampling sampling;
    struct common_params_speculative speculative;

-    // OAI-compat fields
+    // response formatting
    bool verbose = false;
-    oaicompat_type oaicompat = OAICOMPAT_TYPE_NONE;
+    task_response_type res_type = TASK_RESPONSE_TYPE_NONE;
    std::string oaicompat_model;
    std::string oaicompat_cmpl_id;
    common_chat_syntax oaicompat_chat_syntax;
@@ -227,12 +228,12 @@ struct server_task_result_cmpl_final : server_task_result {

    task_params generation_params;

-    // OAI-compat fields
+    // response formatting
    bool verbose = false;
-    oaicompat_type oaicompat = OAICOMPAT_TYPE_NONE;
+    task_response_type res_type = TASK_RESPONSE_TYPE_NONE;
    std::string oaicompat_model;
    std::string oaicompat_cmpl_id;
    common_chat_msg oaicompat_msg;

    std::vector<common_chat_msg_diff> oaicompat_msg_diffs;
@@ -253,6 +254,10 @@ struct server_task_result_cmpl_final : server_task_result {
    json to_json_oaicompat_chat();

    json to_json_oaicompat_chat_stream();
+
+    json to_json_anthropic();
+
+    json to_json_anthropic_stream();
};

struct server_task_result_cmpl_partial : server_task_result {
@@ -270,11 +275,11 @@ struct server_task_result_cmpl_partial : server_task_result {
    result_timings timings;
    result_prompt_progress progress;

-    // OAI-compat fields
+    // response formatting
    bool verbose = false;
-    oaicompat_type oaicompat = OAICOMPAT_TYPE_NONE;
+    task_response_type res_type = TASK_RESPONSE_TYPE_NONE;
    std::string oaicompat_model;
    std::string oaicompat_cmpl_id;
    std::vector<common_chat_msg_diff> oaicompat_msg_diffs;

    virtual int get_index() override {
@@ -292,6 +297,8 @@ struct server_task_result_cmpl_partial : server_task_result {
    json to_json_oaicompat();

    json to_json_oaicompat_chat();
+
+    json to_json_anthropic();
};

struct server_task_result_embd : server_task_result {
@@ -300,8 +307,8 @@ struct server_task_result_embd : server_task_result {

    int32_t n_tokens;

-    // OAI-compat fields
-    oaicompat_type oaicompat = OAICOMPAT_TYPE_NONE;
+    // response formatting
+    task_response_type res_type = TASK_RESPONSE_TYPE_NONE;

    virtual int get_index() override {
        return index;
@@ -1256,7 +1256,7 @@ struct server_context {
        res->post_sampling_probs = slot.task->params.post_sampling_probs;

        res->verbose = slot.task->params.verbose;
-        res->oaicompat = slot.task->params.oaicompat;
+        res->res_type = slot.task->params.res_type;
        res->oaicompat_model = slot.task->params.oaicompat_model;
        res->oaicompat_cmpl_id = slot.task->params.oaicompat_cmpl_id;
@@ -1298,7 +1298,7 @@ struct server_context {
        res->verbose = slot.task->params.verbose;
        res->stream = slot.task->params.stream;
        res->include_usage = slot.task->params.include_usage;
-        res->oaicompat = slot.task->params.oaicompat;
+        res->res_type = slot.task->params.res_type;
        res->oaicompat_model = slot.task->params.oaicompat_model;
        res->oaicompat_cmpl_id = slot.task->params.oaicompat_cmpl_id;
        res->oaicompat_msg = slot.update_chat_msg(res->oaicompat_msg_diffs);
@@ -1329,7 +1329,7 @@ struct server_context {
        res->id = slot.task->id;
        res->index = slot.task->index;
        res->n_tokens = slot.task->n_tokens();
-        res->oaicompat = slot.task->params.oaicompat;
+        res->res_type = slot.task->params.res_type;

        const int n_embd = llama_model_n_embd(model);
@@ -2954,7 +2954,7 @@ public:
            data,
            files,
            req.should_stop,
-            OAICOMPAT_TYPE_NONE); // infill is not OAI compatible
+            TASK_RESPONSE_TYPE_NONE); // infill is not OAI compatible
    };

    server_http_context::handler_t post_completions = [this](const server_http_req & req) {
@@ -2965,7 +2965,7 @@ public:
            body,
            files,
            req.should_stop,
-            OAICOMPAT_TYPE_NONE);
+            TASK_RESPONSE_TYPE_NONE);
    };

    server_http_context::handler_t post_completions_oai = [this](const server_http_req & req) {
@@ -2976,7 +2976,7 @@ public:
            body,
            files,
            req.should_stop,
-            OAICOMPAT_TYPE_COMPLETION);
+            TASK_RESPONSE_TYPE_OAI_CMPL);
    };

    server_http_context::handler_t post_chat_completions = [this](const server_http_req & req) {
@@ -2991,7 +2991,38 @@ public:
            body_parsed,
            files,
            req.should_stop,
-            OAICOMPAT_TYPE_CHAT);
+            TASK_RESPONSE_TYPE_OAI_CHAT);
    };

+    server_http_context::handler_t post_anthropic_messages = [this](const server_http_req & req) {
+        std::vector<raw_buffer> files;
+        json body = convert_anthropic_to_oai(json::parse(req.body));
+        json body_parsed = oaicompat_chat_params_parse(
+            body,
+            ctx_server.oai_parser_opt,
+            files);
+        return handle_completions_impl(
+            SERVER_TASK_TYPE_COMPLETION,
+            body_parsed,
+            files,
+            req.should_stop,
+            TASK_RESPONSE_TYPE_ANTHROPIC);
+    };
+
+    server_http_context::handler_t post_anthropic_count_tokens = [this](const server_http_req & req) {
+        auto res = std::make_unique<server_res_generator>(ctx_server);
+        std::vector<raw_buffer> files;
+        json body = convert_anthropic_to_oai(json::parse(req.body));
+        json body_parsed = oaicompat_chat_params_parse(
+            body,
+            ctx_server.oai_parser_opt,
+            files);
+
+        json prompt = body_parsed.at("prompt");
+        llama_tokens tokens = tokenize_mixed(ctx_server.vocab, prompt, true, true);
+
+        res->ok({{"input_tokens", static_cast<int>(tokens.size())}});
+        return res;
+    };

    // same with handle_chat_completions, but without inference part
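A usage sketch for the token-counting handler added above. It assumes a llama-server instance on a placeholder host and port; the request and response shapes follow test_anthropic_count_tokens later in this commit.

import requests

# Port and base URL are placeholders for a locally running llama-server.
res = requests.post("http://localhost:8080/v1/messages/count_tokens", json={
    "model": "test",
    "messages": [{"role": "user", "content": "Hello world"}],
})
res.raise_for_status()
print(res.json())  # expected shape: {"input_tokens": <int>} and nothing else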
@@ -3110,11 +3141,11 @@ public:
    };

    server_http_context::handler_t post_embeddings = [this](const server_http_req & req) {
-        return handle_embeddings_impl(req, OAICOMPAT_TYPE_NONE);
+        return handle_embeddings_impl(req, TASK_RESPONSE_TYPE_NONE);
    };

    server_http_context::handler_t post_embeddings_oai = [this](const server_http_req & req) {
-        return handle_embeddings_impl(req, OAICOMPAT_TYPE_EMBEDDING);
+        return handle_embeddings_impl(req, TASK_RESPONSE_TYPE_OAI_EMBD);
    };

    server_http_context::handler_t post_rerank = [this](const server_http_req & req) {
@@ -3394,7 +3425,7 @@ private:
            const json & data,
            const std::vector<raw_buffer> & files,
            const std::function<bool()> & should_stop,
-            oaicompat_type oaicompat) {
+            task_response_type res_type) {
        GGML_ASSERT(type == SERVER_TASK_TYPE_COMPLETION || type == SERVER_TASK_TYPE_INFILL);

        auto res = std::make_unique<server_res_generator>(ctx_server);
@@ -3411,7 +3442,7 @@ private:
        // process prompt
        std::vector<server_tokens> inputs;

-        if (oaicompat && ctx_server.mctx != nullptr) {
+        if (res_type != TASK_RESPONSE_TYPE_NONE && ctx_server.mctx != nullptr) {
            // This is the case used by OAI compatible chat path with MTMD. TODO It can be moved to the path below.
            inputs.push_back(process_mtmd_prompt(ctx_server.mctx, prompt.get<std::string>(), files));
        } else {
@@ -3433,8 +3464,8 @@ private:
        task.id_slot = json_value(data, "id_slot", -1);

        // OAI-compat
-        task.params.oaicompat = oaicompat;
+        task.params.res_type = res_type;
        task.params.oaicompat_cmpl_id = completion_id;
        // oaicompat_model is already populated by params_from_json_cmpl

        tasks.push_back(std::move(task));
@@ -3484,10 +3515,14 @@ private:
        }

        // next responses are streamed
-        res->data = format_sse(first_result->to_json()); // to be sent immediately
+        if (res_type == TASK_RESPONSE_TYPE_ANTHROPIC) {
+            res->data = format_anthropic_sse(first_result->to_json());
+        } else {
+            res->data = format_oai_sse(first_result->to_json()); // to be sent immediately
+        }
        res->status = 200;
        res->content_type = "text/event-stream";
-        res->next = [res_this = res.get(), oaicompat, &should_stop](std::string & output) -> bool {
+        res->next = [res_this = res.get(), res_type, &should_stop](std::string & output) -> bool {
            if (should_stop()) {
                SRV_DBG("%s", "stopping streaming due to should_stop condition\n");
                return false; // should_stop condition met
@@ -3504,7 +3539,10 @@ private:

            // check if there is more data
            if (!rd.has_next()) {
-                if (oaicompat != OAICOMPAT_TYPE_NONE) {
+                if (res_type == TASK_RESPONSE_TYPE_ANTHROPIC) {
+                    // Anthropic doesn't send [DONE], message_stop was already sent
+                    output = "";
+                } else if (res_type != TASK_RESPONSE_TYPE_NONE) {
                    output = "data: [DONE]\n\n";
                } else {
                    output = "";
@@ -3523,7 +3561,14 @@ private:
            // send the results
            json res_json = result->to_json();
            if (result->is_error()) {
-                output = format_sse(json {{ "error", res_json }});
+                if (res_type == TASK_RESPONSE_TYPE_ANTHROPIC) {
+                    output = format_anthropic_sse({
+                        {"event", "error"},
+                        {"data", res_json},
+                    });
+                } else {
+                    output = format_oai_sse(json {{ "error", res_json }});
+                }
                SRV_DBG("%s", "error received during streaming, terminating stream\n");
                return false; // terminate on error
            } else {
@@ -3531,7 +3576,11 @@ private:
                    dynamic_cast<server_task_result_cmpl_partial*>(result.get()) != nullptr
                    || dynamic_cast<server_task_result_cmpl_final*>(result.get()) != nullptr
                );
-                output = format_sse(res_json);
+                if (res_type == TASK_RESPONSE_TYPE_ANTHROPIC) {
+                    output = format_anthropic_sse(res_json);
+                } else {
+                    output = format_oai_sse(res_json);
+                }
            }

            // has next data, continue
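The split between format_oai_sse and format_anthropic_sse in the hunks above comes down to the SSE envelope. A rough sketch of the two framings, assuming format_anthropic_sse emits the named event line the Anthropic API requires (as the {"event", ...}/{"data", ...} pairs suggest):

import json

def oai_sse(payload: dict) -> str:
    # OpenAI-style framing: bare data lines; the stream is terminated
    # by a literal "data: [DONE]\n\n" sentinel
    return "data: " + json.dumps(payload) + "\n\n"

def anthropic_sse(event: dict) -> str:
    # Anthropic-style framing: a named event line plus its data line;
    # there is no [DONE] sentinel, the stream simply ends after message_stop
    return "event: " + event["event"] + "\ndata: " + json.dumps(event["data"]) + "\n\n"

print(oai_sse({"object": "chat.completion.chunk"}), end="")
print(anthropic_sse({"event": "message_stop", "data": {"type": "message_stop"}}), end="")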
@@ -3639,14 +3688,14 @@ private:
        return res;
    }

-    std::unique_ptr<server_res_generator> handle_embeddings_impl(const server_http_req & req, oaicompat_type oaicompat) {
+    std::unique_ptr<server_res_generator> handle_embeddings_impl(const server_http_req & req, task_response_type res_type) {
        auto res = std::make_unique<server_res_generator>(ctx_server);
        if (!ctx_server.params_base.embedding) {
            res->error(format_error_response("This server does not support embeddings. Start it with `--embeddings`", ERROR_TYPE_NOT_SUPPORTED));
            return res;
        }

-        if (oaicompat != OAICOMPAT_TYPE_NONE && llama_pooling_type(ctx_server.ctx) == LLAMA_POOLING_TYPE_NONE) {
+        if (res_type != TASK_RESPONSE_TYPE_NONE && llama_pooling_type(ctx_server.ctx) == LLAMA_POOLING_TYPE_NONE) {
            res->error(format_error_response("Pooling type 'none' is not OAI compatible. Please use a different pooling type", ERROR_TYPE_INVALID_REQUEST));
            return res;
        }
@@ -3658,7 +3707,7 @@ private:
        if (body.count("input") != 0) {
            prompt = body.at("input");
        } else if (body.contains("content")) {
-            oaicompat = OAICOMPAT_TYPE_NONE; // "content" field is not OAI compatible
+            res_type = TASK_RESPONSE_TYPE_NONE; // "content" field is not OAI compatible
            prompt = body.at("content");
        } else {
            res->error(format_error_response("\"input\" or \"content\" must be provided", ERROR_TYPE_INVALID_REQUEST));
@@ -3706,7 +3755,7 @@ private:
            task.tokens = std::move(tokenized_prompts[i]);

            // OAI-compat
-            task.params.oaicompat = oaicompat;
+            task.params.res_type = res_type;
            task.params.embd_normalize = embd_normalize;

            tasks.push_back(std::move(task));
@@ -3731,7 +3780,7 @@ private:
        }

        // write JSON response
-        json root = oaicompat == OAICOMPAT_TYPE_EMBEDDING
+        json root = res_type == TASK_RESPONSE_TYPE_OAI_EMBD
            ? format_embeddings_response_oaicompat(body, responses, use_base64)
            : json(responses);
        res->ok(root);
@@ -3903,6 +3952,8 @@ int main(int argc, char ** argv, char ** envp) {
    ctx_http.post("/chat/completions", ex_wrapper(routes.post_chat_completions));
    ctx_http.post("/v1/chat/completions", ex_wrapper(routes.post_chat_completions));
    ctx_http.post("/api/chat", ex_wrapper(routes.post_chat_completions)); // ollama specific endpoint
+    ctx_http.post("/v1/messages", ex_wrapper(routes.post_anthropic_messages)); // anthropic messages API
+    ctx_http.post("/v1/messages/count_tokens", ex_wrapper(routes.post_anthropic_count_tokens)); // anthropic token counting
    ctx_http.post("/infill", ex_wrapper(routes.post_infill));
    ctx_http.post("/embedding", ex_wrapper(routes.post_embeddings)); // legacy
    ctx_http.post("/embeddings", ex_wrapper(routes.post_embeddings));
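With the two routes registered above, a minimal non-streaming call against /v1/messages looks roughly like this; the host and port are placeholders, and the response fields follow the tests added below.

import requests

# Host and port are placeholders for a locally running llama-server.
res = requests.post("http://localhost:8080/v1/messages", json={
    "model": "test",
    "max_tokens": 50,
    "messages": [{"role": "user", "content": "Say hello"}],
})
body = res.json()
assert body["type"] == "message" and body["role"] == "assistant"
print(body["content"][0]["text"], body["usage"])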
@@ -13,3 +13,9 @@ def stop_server_after_each_test():
    ) # copy the set to prevent 'Set changed size during iteration'
    for server in instances:
        server.stop()
+
+
+@pytest.fixture(scope="module", autouse=True)
+def do_something():
+    # this will be run once per test module, before any tests
+    ServerPreset.load_all()
@@ -5,12 +5,6 @@ from utils import *
server = ServerPreset.tinyllama2()


-@pytest.fixture(scope="session", autouse=True)
-def do_something():
-    # this will be run once per test session, before any tests
-    ServerPreset.load_all()
-
-
@pytest.fixture(autouse=True)
def create_server():
    global server
@@ -0,0 +1,807 @@
#!/usr/bin/env python3
import pytest
import base64
import requests

from utils import *

server: ServerProcess


def get_test_image_base64() -> str:
    """Get a test image in base64 format"""
    # Use the same test image as test_vision_api.py
    IMG_URL = "https://huggingface.co/ggml-org/tinygemma3-GGUF/resolve/main/test/11_truck.png"
    response = requests.get(IMG_URL)
    response.raise_for_status()
    return base64.b64encode(response.content).decode("utf-8")


@pytest.fixture(autouse=True)
def create_server():
    global server
    server = ServerPreset.tinyllama2()
    server.model_alias = "tinyllama-2-anthropic"
    server.server_port = 8082
    server.n_slots = 1
    server.n_ctx = 8192
    server.n_batch = 2048


@pytest.fixture
def vision_server():
    """Separate fixture for vision tests that require multimodal support"""
    global server
    server = ServerPreset.tinygemma3()
    server.offline = False  # Allow downloading the model
    server.model_alias = "tinygemma3-anthropic"
    server.server_port = 8083  # Different port to avoid conflicts
    server.n_slots = 1
    return server


# Basic message tests

def test_anthropic_messages_basic():
    """Test basic Anthropic messages endpoint"""
    server.start()

    res = server.make_request("POST", "/v1/messages", data={
        "model": "test",
        "max_tokens": 50,
        "messages": [
            {"role": "user", "content": "Say hello"}
        ]
    })

    assert res.status_code == 200, f"Expected 200, got {res.status_code}"
    assert res.body["type"] == "message", f"Expected type 'message', got {res.body.get('type')}"
    assert res.body["role"] == "assistant", f"Expected role 'assistant', got {res.body.get('role')}"
    assert "content" in res.body, "Missing 'content' field"
    assert isinstance(res.body["content"], list), "Content should be an array"
    assert len(res.body["content"]) > 0, "Content array should not be empty"
    assert res.body["content"][0]["type"] == "text", "First content block should be text"
    assert "text" in res.body["content"][0], "Text content block missing 'text' field"
    assert res.body["stop_reason"] in ["end_turn", "max_tokens"], f"Invalid stop_reason: {res.body.get('stop_reason')}"
    assert "usage" in res.body, "Missing 'usage' field"
    assert "input_tokens" in res.body["usage"], "Missing usage.input_tokens"
    assert "output_tokens" in res.body["usage"], "Missing usage.output_tokens"
    assert isinstance(res.body["usage"]["input_tokens"], int), "input_tokens should be integer"
    assert isinstance(res.body["usage"]["output_tokens"], int), "output_tokens should be integer"
    assert res.body["usage"]["output_tokens"] > 0, "Should have generated some tokens"
    # Anthropic API should NOT include timings
    assert "timings" not in res.body, "Anthropic API should not include timings field"


def test_anthropic_messages_with_system():
    """Test messages with system prompt"""
    server.start()

    res = server.make_request("POST", "/v1/messages", data={
        "model": "test",
        "max_tokens": 50,
        "system": "You are a helpful assistant.",
        "messages": [
            {"role": "user", "content": "Hello"}
        ]
    })

    assert res.status_code == 200
    assert res.body["type"] == "message"
    assert len(res.body["content"]) > 0


def test_anthropic_messages_multipart_content():
    """Test messages with multipart content blocks"""
    server.start()

    res = server.make_request("POST", "/v1/messages", data={
        "model": "test",
        "max_tokens": 50,
        "messages": [
            {
                "role": "user",
                "content": [
                    {"type": "text", "text": "What is"},
                    {"type": "text", "text": " the answer?"}
                ]
            }
        ]
    })

    assert res.status_code == 200
    assert res.body["type"] == "message"


def test_anthropic_messages_conversation():
    """Test multi-turn conversation"""
    server.start()

    res = server.make_request("POST", "/v1/messages", data={
        "model": "test",
        "max_tokens": 50,
        "messages": [
            {"role": "user", "content": "Hello"},
            {"role": "assistant", "content": "Hi there!"},
            {"role": "user", "content": "How are you?"}
        ]
    })

    assert res.status_code == 200
    assert res.body["type"] == "message"


# Streaming tests

def test_anthropic_messages_streaming():
    """Test streaming messages"""
    server.start()

    res = server.make_stream_request("POST", "/v1/messages", data={
        "model": "test",
        "max_tokens": 30,
        "messages": [
            {"role": "user", "content": "Say hello"}
        ],
        "stream": True
    })

    events = []
    for data in res:
        # Each event should have type and other fields
        assert "type" in data, f"Missing 'type' in event: {data}"
        events.append(data)

    # Verify event sequence
    event_types = [e["type"] for e in events]
    assert "message_start" in event_types, "Missing message_start event"
    assert "content_block_start" in event_types, "Missing content_block_start event"
    assert "content_block_delta" in event_types, "Missing content_block_delta event"
    assert "content_block_stop" in event_types, "Missing content_block_stop event"
    assert "message_delta" in event_types, "Missing message_delta event"
    assert "message_stop" in event_types, "Missing message_stop event"

    # Check message_start structure
    message_start = next(e for e in events if e["type"] == "message_start")
    assert "message" in message_start, "message_start missing 'message' field"
    assert message_start["message"]["type"] == "message"
    assert message_start["message"]["role"] == "assistant"
    assert message_start["message"]["content"] == []
    assert "usage" in message_start["message"]
    assert message_start["message"]["usage"]["input_tokens"] > 0

    # Check content_block_start
    block_start = next(e for e in events if e["type"] == "content_block_start")
    assert "index" in block_start, "content_block_start missing 'index'"
    assert block_start["index"] == 0, "First content block should be at index 0"
    assert "content_block" in block_start
    assert block_start["content_block"]["type"] == "text"

    # Check content_block_delta
    deltas = [e for e in events if e["type"] == "content_block_delta"]
    assert len(deltas) > 0, "Should have at least one content_block_delta"
    for delta in deltas:
        assert "index" in delta
        assert "delta" in delta
        assert delta["delta"]["type"] == "text_delta"
        assert "text" in delta["delta"]

    # Check content_block_stop
    block_stop = next(e for e in events if e["type"] == "content_block_stop")
    assert "index" in block_stop
    assert block_stop["index"] == 0

    # Check message_delta
    message_delta = next(e for e in events if e["type"] == "message_delta")
    assert "delta" in message_delta
    assert "stop_reason" in message_delta["delta"]
    assert message_delta["delta"]["stop_reason"] in ["end_turn", "max_tokens"]
    assert "usage" in message_delta
    assert message_delta["usage"]["output_tokens"] > 0

    # Check message_stop
    message_stop = next(e for e in events if e["type"] == "message_stop")
    # message_stop should NOT have timings for Anthropic API
    assert "timings" not in message_stop, "Anthropic streaming should not include timings"

|
||||||
|
|
||||||
|
def test_anthropic_count_tokens():
|
||||||
|
"""Test token counting endpoint"""
|
||||||
|
server.start()
|
||||||
|
|
||||||
|
res = server.make_request("POST", "/v1/messages/count_tokens", data={
|
||||||
|
"model": "test",
|
||||||
|
"messages": [
|
||||||
|
{"role": "user", "content": "Hello world"}
|
||||||
|
]
|
||||||
|
})
|
||||||
|
|
||||||
|
assert res.status_code == 200
|
||||||
|
assert "input_tokens" in res.body
|
||||||
|
assert isinstance(res.body["input_tokens"], int)
|
||||||
|
assert res.body["input_tokens"] > 0
|
||||||
|
# Should only have input_tokens, no other fields
|
||||||
|
assert "output_tokens" not in res.body
|
||||||
|
|
||||||
|
|
||||||
|
def test_anthropic_count_tokens_with_system():
|
||||||
|
"""Test token counting with system prompt"""
|
||||||
|
server.start()
|
||||||
|
|
||||||
|
res = server.make_request("POST", "/v1/messages/count_tokens", data={
|
||||||
|
"model": "test",
|
||||||
|
"system": "You are a helpful assistant.",
|
||||||
|
"messages": [
|
||||||
|
{"role": "user", "content": "Hello"}
|
||||||
|
]
|
||||||
|
})
|
||||||
|
|
||||||
|
assert res.status_code == 200
|
||||||
|
assert res.body["input_tokens"] > 0
|
||||||
|
|
||||||
|
|
||||||
|
def test_anthropic_count_tokens_no_max_tokens():
|
||||||
|
"""Test that count_tokens doesn't require max_tokens"""
|
||||||
|
server.start()
|
||||||
|
|
||||||
|
# max_tokens is NOT required for count_tokens
|
||||||
|
res = server.make_request("POST", "/v1/messages/count_tokens", data={
|
||||||
|
"model": "test",
|
||||||
|
"messages": [
|
||||||
|
{"role": "user", "content": "Hello"}
|
||||||
|
]
|
||||||
|
})
|
||||||
|
|
||||||
|
assert res.status_code == 200
|
||||||
|
assert "input_tokens" in res.body
|
||||||
|
|
||||||
|
|
||||||
|
# Tool use tests
|
||||||
|
|
||||||
|
def test_anthropic_tool_use_basic():
|
||||||
|
"""Test basic tool use"""
|
||||||
|
server.jinja = True
|
||||||
|
server.start()
|
||||||
|
|
||||||
|
res = server.make_request("POST", "/v1/messages", data={
|
||||||
|
"model": "test",
|
||||||
|
"max_tokens": 200,
|
||||||
|
"tools": [{
|
||||||
|
"name": "get_weather",
|
||||||
|
"description": "Get the current weather in a location",
|
||||||
|
"input_schema": {
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"location": {
|
||||||
|
"type": "string",
|
||||||
|
"description": "City name"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"required": ["location"]
|
||||||
|
}
|
||||||
|
}],
|
||||||
|
"messages": [
|
||||||
|
{"role": "user", "content": "What's the weather in Paris?"}
|
||||||
|
]
|
||||||
|
})
|
||||||
|
|
||||||
|
assert res.status_code == 200
|
||||||
|
assert res.body["type"] == "message"
|
||||||
|
assert len(res.body["content"]) > 0
|
||||||
|
|
||||||
|
# Check if model used the tool (it might not always, depending on the model)
|
||||||
|
content_types = [block.get("type") for block in res.body["content"]]
|
||||||
|
|
||||||
|
if "tool_use" in content_types:
|
||||||
|
# Model used the tool
|
||||||
|
assert res.body["stop_reason"] == "tool_use"
|
||||||
|
|
||||||
|
# Find the tool_use block
|
||||||
|
tool_block = next(b for b in res.body["content"] if b.get("type") == "tool_use")
|
||||||
|
assert "id" in tool_block
|
||||||
|
assert "name" in tool_block
|
||||||
|
assert tool_block["name"] == "get_weather"
|
||||||
|
assert "input" in tool_block
|
||||||
|
assert isinstance(tool_block["input"], dict)
|
||||||
|
|
||||||
|
|
||||||
|
def test_anthropic_tool_result():
|
||||||
|
"""Test sending tool results back
|
||||||
|
|
||||||
|
This test verifies that tool_result blocks are properly converted to
|
||||||
|
role="tool" messages internally. Without proper conversion, this would
|
||||||
|
fail with a 500 error: "unsupported content[].type" because tool_result
|
||||||
|
blocks would remain in the user message content array.
|
||||||
|
"""
|
||||||
|
server.jinja = True
|
||||||
|
server.start()
|
||||||
|
|
||||||
|
res = server.make_request("POST", "/v1/messages", data={
|
||||||
|
"model": "test",
|
||||||
|
"max_tokens": 100,
|
||||||
|
"messages": [
|
||||||
|
{"role": "user", "content": "What's the weather?"},
|
||||||
|
{
|
||||||
|
"role": "assistant",
|
||||||
|
"content": [
|
||||||
|
{
|
||||||
|
"type": "tool_use",
|
||||||
|
"id": "test123",
|
||||||
|
"name": "get_weather",
|
||||||
|
"input": {"location": "Paris"}
|
||||||
|
}
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"role": "user",
|
||||||
|
"content": [
|
||||||
|
{
|
||||||
|
"type": "tool_result",
|
||||||
|
"tool_use_id": "test123",
|
||||||
|
"content": "The weather is sunny, 25°C"
|
||||||
|
}
|
||||||
|
]
|
||||||
|
}
|
||||||
|
]
|
||||||
|
})
|
||||||
|
|
||||||
|
# This would be 500 with the old bug where tool_result blocks weren't converted
|
||||||
|
assert res.status_code == 200
|
||||||
|
assert res.body["type"] == "message"
|
||||||
|
# Model should respond to the tool result
|
||||||
|
assert len(res.body["content"]) > 0
|
||||||
|
assert res.body["content"][0]["type"] == "text"
|
||||||
|
|
||||||
|
|
||||||
|
def test_anthropic_tool_result_with_text():
|
||||||
|
"""Test tool result mixed with text content
|
||||||
|
|
||||||
|
This tests the edge case where a user message contains both text and
|
||||||
|
tool_result blocks. The server must properly split these into separate
|
||||||
|
messages: a user message with text, followed by tool messages.
|
||||||
|
Without proper handling, this would fail with 500: "unsupported content[].type"
|
||||||
|
"""
|
||||||
|
server.jinja = True
|
||||||
|
server.start()
|
||||||
|
|
||||||
|
res = server.make_request("POST", "/v1/messages", data={
|
||||||
|
"model": "test",
|
||||||
|
"max_tokens": 100,
|
||||||
|
"messages": [
|
||||||
|
{"role": "user", "content": "What's the weather?"},
|
||||||
|
{
|
||||||
|
"role": "assistant",
|
||||||
|
"content": [
|
||||||
|
{
|
||||||
|
"type": "tool_use",
|
||||||
|
"id": "tool_1",
|
||||||
|
"name": "get_weather",
|
||||||
|
"input": {"location": "Paris"}
|
||||||
|
}
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"role": "user",
|
||||||
|
"content": [
|
||||||
|
{"type": "text", "text": "Here are the results:"},
|
||||||
|
{
|
||||||
|
"type": "tool_result",
|
||||||
|
"tool_use_id": "tool_1",
|
||||||
|
"content": "Sunny, 25°C"
|
||||||
|
}
|
||||||
|
]
|
||||||
|
}
|
||||||
|
]
|
||||||
|
})
|
||||||
|
|
||||||
|
assert res.status_code == 200
|
||||||
|
assert res.body["type"] == "message"
|
||||||
|
assert len(res.body["content"]) > 0
|
||||||
|
|
||||||
|
|
||||||
|
def test_anthropic_tool_result_error():
|
||||||
|
"""Test tool result with error flag"""
|
||||||
|
server.jinja = True
|
||||||
|
server.start()
|
||||||
|
|
||||||
|
res = server.make_request("POST", "/v1/messages", data={
|
||||||
|
"model": "test",
|
||||||
|
"max_tokens": 100,
|
||||||
|
"messages": [
|
||||||
|
{"role": "user", "content": "Get the weather"},
|
||||||
|
{
|
||||||
|
"role": "assistant",
|
||||||
|
"content": [
|
||||||
|
{
|
||||||
|
"type": "tool_use",
|
||||||
|
"id": "test123",
|
||||||
|
"name": "get_weather",
|
||||||
|
"input": {"location": "InvalidCity"}
|
||||||
|
}
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"role": "user",
|
||||||
|
"content": [
|
||||||
|
{
|
||||||
|
"type": "tool_result",
|
||||||
|
"tool_use_id": "test123",
|
||||||
|
"is_error": True,
|
||||||
|
"content": "City not found"
|
||||||
|
}
|
||||||
|
]
|
||||||
|
}
|
||||||
|
]
|
||||||
|
})
|
||||||
|
|
||||||
|
assert res.status_code == 200
|
||||||
|
assert res.body["type"] == "message"
|
||||||
|
|
||||||
|
|
||||||
|
def test_anthropic_tool_streaming():
|
||||||
|
"""Test streaming with tool use"""
|
||||||
|
server.jinja = True
|
||||||
|
server.start()
|
||||||
|
|
||||||
|
res = server.make_stream_request("POST", "/v1/messages", data={
|
||||||
|
"model": "test",
|
||||||
|
"max_tokens": 200,
|
||||||
|
"stream": True,
|
||||||
|
"tools": [{
|
||||||
|
"name": "calculator",
|
||||||
|
"description": "Calculate math",
|
||||||
|
"input_schema": {
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"expression": {"type": "string"}
|
||||||
|
},
|
||||||
|
"required": ["expression"]
|
||||||
|
}
|
||||||
|
}],
|
||||||
|
"messages": [
|
||||||
|
{"role": "user", "content": "Calculate 2+2"}
|
||||||
|
]
|
||||||
|
})
|
||||||
|
|
||||||
|
events = []
|
||||||
|
for data in res:
|
||||||
|
events.append(data)
|
||||||
|
|
||||||
|
event_types = [e["type"] for e in events]
|
||||||
|
|
||||||
|
# Should have basic events
|
||||||
|
assert "message_start" in event_types
|
||||||
|
assert "message_stop" in event_types
|
||||||
|
|
||||||
|
# If tool was used, check for proper tool streaming
|
||||||
|
if any(e.get("type") == "content_block_start" and
|
||||||
|
e.get("content_block", {}).get("type") == "tool_use"
|
||||||
|
for e in events):
|
||||||
|
# Find tool use block start
|
||||||
|
tool_starts = [e for e in events if
|
||||||
|
e.get("type") == "content_block_start" and
|
||||||
|
e.get("content_block", {}).get("type") == "tool_use"]
|
||||||
|
|
||||||
|
assert len(tool_starts) > 0, "Should have tool_use content_block_start"
|
||||||
|
|
||||||
|
# Check index is correct (should be 0 if no text, 1 if there's text)
|
||||||
|
tool_start = tool_starts[0]
|
||||||
|
assert "index" in tool_start
|
||||||
|
assert tool_start["content_block"]["type"] == "tool_use"
|
||||||
|
assert "name" in tool_start["content_block"]
|
||||||
|
|
||||||
|
|
||||||
|
# Vision/multimodal tests

def test_anthropic_vision_format_accepted():
    """Test that Anthropic vision format is accepted (format validation only)"""
    server.start()

    # Small 1x1 red PNG image in base64
    red_pixel_png = "iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR42mP8z8DwHwAFBQIAX8jx0gAAAABJRU5ErkJggg=="

    res = server.make_request("POST", "/v1/messages", data={
        "model": "test",
        "max_tokens": 10,
        "messages": [
            {
                "role": "user",
                "content": [
                    {
                        "type": "image",
                        "source": {
                            "type": "base64",
                            "media_type": "image/png",
                            "data": red_pixel_png
                        }
                    },
                    {
                        "type": "text",
                        "text": "What is this?"
                    }
                ]
            }
        ]
    })

    # Server accepts the format but tinyllama doesn't support images
    # So it should return 500 with clear error message about missing mmproj
    assert res.status_code == 500
    assert "image input is not supported" in res.body.get("error", {}).get("message", "").lower()


def test_anthropic_vision_base64_with_multimodal_model(vision_server):
    """Test vision with base64 image using Anthropic format with multimodal model"""
    global server
    server = vision_server
    server.start()

    # Get test image in base64 format
    image_base64 = get_test_image_base64()

    res = server.make_request("POST", "/v1/messages", data={
        "model": "test",
        "max_tokens": 10,
        "messages": [
            {
                "role": "user",
                "content": [
                    {
                        "type": "image",
                        "source": {
                            "type": "base64",
                            "media_type": "image/png",
                            "data": image_base64
                        }
                    },
                    {
                        "type": "text",
                        "text": "What is this:\n"
                    }
                ]
            }
        ]
    })

    assert res.status_code == 200, f"Expected 200, got {res.status_code}: {res.body}"
    assert res.body["type"] == "message"
    assert len(res.body["content"]) > 0
    assert res.body["content"][0]["type"] == "text"
    # The model should generate some response about the image
    assert len(res.body["content"][0]["text"]) > 0


# Parameter tests

def test_anthropic_stop_sequences():
    """Test stop_sequences parameter"""
    server.start()

    res = server.make_request("POST", "/v1/messages", data={
        "model": "test",
        "max_tokens": 100,
        "stop_sequences": ["\n", "END"],
        "messages": [
            {"role": "user", "content": "Count to 10"}
        ]
    })

    assert res.status_code == 200
    assert res.body["type"] == "message"


def test_anthropic_temperature():
    """Test temperature parameter"""
    server.start()

    res = server.make_request("POST", "/v1/messages", data={
        "model": "test",
        "max_tokens": 50,
        "temperature": 0.5,
        "messages": [
            {"role": "user", "content": "Hello"}
        ]
    })

    assert res.status_code == 200
    assert res.body["type"] == "message"


def test_anthropic_top_p():
    """Test top_p parameter"""
    server.start()

    res = server.make_request("POST", "/v1/messages", data={
        "model": "test",
        "max_tokens": 50,
        "top_p": 0.9,
        "messages": [
            {"role": "user", "content": "Hello"}
        ]
    })

    assert res.status_code == 200
    assert res.body["type"] == "message"


def test_anthropic_top_k():
    """Test top_k parameter (llama.cpp specific)"""
    server.start()

    res = server.make_request("POST", "/v1/messages", data={
        "model": "test",
        "max_tokens": 50,
        "top_k": 40,
        "messages": [
            {"role": "user", "content": "Hello"}
        ]
    })

    assert res.status_code == 200
    assert res.body["type"] == "message"


# Error handling tests

def test_anthropic_missing_messages():
    """Test error when messages are missing"""
    server.start()

    res = server.make_request("POST", "/v1/messages", data={
        "model": "test",
        "max_tokens": 50
        # missing "messages" field
    })

    # Should return an error (400 or 500)
    assert res.status_code >= 400


def test_anthropic_empty_messages():
    """Test permissive handling of empty messages array"""
    server.start()

    res = server.make_request("POST", "/v1/messages", data={
        "model": "test",
        "max_tokens": 50,
        "messages": []
    })

    # Server is permissive and accepts empty messages (provides defaults)
    # This matches the permissive validation design choice
    assert res.status_code == 200
    assert res.body["type"] == "message"


# Content block index tests

def test_anthropic_streaming_content_block_indices():
    """Test that content block indices are correct in streaming"""
    server.jinja = True
    server.start()

    # Request that might produce both text and tool use
    res = server.make_stream_request("POST", "/v1/messages", data={
        "model": "test",
        "max_tokens": 200,
        "stream": True,
        "tools": [{
            "name": "test_tool",
            "description": "A test tool",
            "input_schema": {
                "type": "object",
                "properties": {
                    "param": {"type": "string"}
                },
                "required": ["param"]
            }
        }],
        "messages": [
            {"role": "user", "content": "Use the test tool"}
        ]
    })

    events = []
    for data in res:
        events.append(data)

    # Check content_block_start events have sequential indices
    block_starts = [e for e in events if e.get("type") == "content_block_start"]
    if len(block_starts) > 1:
        # If there are multiple blocks, indices should be sequential
        indices = [e["index"] for e in block_starts]
        expected_indices = list(range(len(block_starts)))
        assert indices == expected_indices, f"Expected indices {expected_indices}, got {indices}"

    # Check content_block_stop events match the starts
    block_stops = [e for e in events if e.get("type") == "content_block_stop"]
    start_indices = set(e["index"] for e in block_starts)
    stop_indices = set(e["index"] for e in block_stops)
    assert start_indices == stop_indices, "content_block_stop indices should match content_block_start indices"


# Extended features tests

def test_anthropic_thinking():
    """Test extended thinking parameter"""
    server.jinja = True
    server.start()

    res = server.make_request("POST", "/v1/messages", data={
        "model": "test",
        "max_tokens": 100,
        "thinking": {
            "type": "enabled",
            "budget_tokens": 50
        },
        "messages": [
            {"role": "user", "content": "What is 2+2?"}
        ]
    })

    assert res.status_code == 200
    assert res.body["type"] == "message"


def test_anthropic_metadata():
    """Test metadata parameter"""
    server.start()

    res = server.make_request("POST", "/v1/messages", data={
        "model": "test",
        "max_tokens": 50,
        "metadata": {
            "user_id": "test_user_123"
        },
        "messages": [
            {"role": "user", "content": "Hello"}
        ]
    })

    assert res.status_code == 200
    assert res.body["type"] == "message"


# Compatibility tests

def test_anthropic_vs_openai_different_response_format():
    """Verify Anthropic format is different from OpenAI format"""
    server.start()

    # Make OpenAI request
    openai_res = server.make_request("POST", "/v1/chat/completions", data={
        "model": "test",
        "max_tokens": 50,
        "messages": [
            {"role": "user", "content": "Hello"}
        ]
    })

    # Make Anthropic request
    anthropic_res = server.make_request("POST", "/v1/messages", data={
        "model": "test",
        "max_tokens": 50,
        "messages": [
            {"role": "user", "content": "Hello"}
        ]
    })

    assert openai_res.status_code == 200
    assert anthropic_res.status_code == 200

    # OpenAI has "object", Anthropic has "type"
    assert "object" in openai_res.body
    assert "type" in anthropic_res.body
    assert openai_res.body["object"] == "chat.completion"
    assert anthropic_res.body["type"] == "message"

    # OpenAI has "choices", Anthropic has "content"
    assert "choices" in openai_res.body
    assert "content" in anthropic_res.body

    # Different usage field names
    assert "prompt_tokens" in openai_res.body["usage"]
    assert "input_tokens" in anthropic_res.body["usage"]
    assert "completion_tokens" in openai_res.body["usage"]
    assert "output_tokens" in anthropic_res.body["usage"]
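The tool_result tests above pin down the behavior of convert_anthropic_to_oai for mixed user content: text stays in a user message and each tool_result becomes a separate tool message. A rough Python sketch of that splitting step, as the tests describe it rather than as the patch actually implements it:

def split_tool_results(message: dict) -> list:
    """Split one Anthropic user message into OAI-style user/tool messages (sketch)."""
    other_blocks, tool_msgs = [], []
    for block in message["content"]:
        if block["type"] == "tool_result":
            tool_msgs.append({
                "role": "tool",
                "tool_call_id": block["tool_use_id"],
                "content": block["content"],
            })
        else:
            other_blocks.append(block)
    out = []
    if other_blocks:
        out.append({"role": "user", "content": other_blocks})
    # text first, then the tool results as their own messages
    return out + tool_msgs

msgs = split_tool_results({
    "role": "user",
    "content": [
        {"type": "text", "text": "Here are the results:"},
        {"type": "tool_result", "tool_use_id": "tool_1", "content": "Sunny, 25°C"},
    ],
})
assert [m["role"] for m in msgs] == ["user", "tool"]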
@@ -49,6 +49,19 @@ def test_correct_api_key():
    assert "content" in res.body


+def test_correct_api_key_anthropic_header():
+    global server
+    server.start()
+    res = server.make_request("POST", "/completions", data={
+        "prompt": "I believe the meaning of life is",
+    }, headers={
+        "X-Api-Key": TEST_API_KEY,
+    })
+    assert res.status_code == 200
+    assert "error" not in res.body
+    assert "content" in res.body
+
+
def test_openai_library_correct_api_key():
    global server
    server.start()
@@ -205,6 +205,8 @@ class ServerProcess:
            server_args.append("--no-webui")
        if self.jinja:
            server_args.append("--jinja")
+        else:
+            server_args.append("--no-jinja")
        if self.reasoning_format is not None:
            server_args.extend(("--reasoning-format", self.reasoning_format))
        if self.reasoning_budget is not None: