Merge branch 'master' into imatrix
commit c81f7cdf9f
@ -134,6 +134,8 @@ jobs:
       include:
         - build: 'x64'
           os: ubuntu-22.04
+        - build: 's390x-z15' # z15 because our CI runners are on z15
+          os: ubuntu-22.04-s390x
         # GGML_BACKEND_DL and GGML_CPU_ALL_VARIANTS are not currently supported on arm
         # - build: 'arm64'
         #   os: ubuntu-22.04-arm
@ -3,10 +3,12 @@ name: Update Operations Documentation
 on:
   push:
     paths:
       - 'docs/ops.md'
+      - 'docs/ops/**'
       - 'scripts/create_ops_docs.py'
   pull_request:
     paths:
       - 'docs/ops.md'
+      - 'docs/ops/**'
       - 'scripts/create_ops_docs.py'
@ -55,7 +55,7 @@
 /ggml/src/ggml-cuda/common.cuh @slaren
 /ggml/src/ggml-cuda/fattn* @JohannesGaessler
 /ggml/src/ggml-cuda/ggml-cuda.cu @slaren
-/ggml/src/ggml-cuda/mmf.* @JohannesGaessler
+/ggml/src/ggml-cuda/mmf.* @JohannesGaessler @am17an
 /ggml/src/ggml-cuda/mmq.* @JohannesGaessler
 /ggml/src/ggml-cuda/mmvf.* @JohannesGaessler
 /ggml/src/ggml-cuda/mmvq.* @JohannesGaessler
@ -138,6 +138,7 @@ Instructions for adding support for new models: [HOWTO-add-model.md](docs/develo
 - [x] [Ling models](https://huggingface.co/collections/inclusionAI/ling-67c51c85b34a7ea0aba94c32)
 - [x] [LFM2 models](https://huggingface.co/collections/LiquidAI/lfm2-686d721927015b2ad73eaa38)
 - [x] [Hunyuan models](https://huggingface.co/collections/tencent/hunyuan-dense-model-6890632cda26b19119c9c5e7)
+- [x] [BailingMoeV2 (Ring/Ling 2.0) models](https://huggingface.co/collections/inclusionAI/ling-v2-68bf1dd2fc34c306c1fa6f86)

 #### Multimodal
@ -187,6 +188,7 @@ Instructions for adding support for new models: [HOWTO-add-model.md](docs/develo
 - Swift [srgtuszy/llama-cpp-swift](https://github.com/srgtuszy/llama-cpp-swift)
 - Swift [ShenghaiWang/SwiftLlama](https://github.com/ShenghaiWang/SwiftLlama)
 - Delphi [Embarcadero/llama-cpp-delphi](https://github.com/Embarcadero/llama-cpp-delphi)
+- Go (no CGo needed): [hybridgroup/yzma](https://github.com/hybridgroup/yzma)

 </details>
@ -75,7 +75,7 @@ if [ ! -z ${GG_BUILD_ROCM} ]; then
         exit 1
     fi

-    CMAKE_EXTRA="${CMAKE_EXTRA} -DAMDGPU_TARGETS=${GG_BUILD_AMDGPU_TARGETS}"
+    CMAKE_EXTRA="${CMAKE_EXTRA} -DGPU_TARGETS=${GG_BUILD_AMDGPU_TARGETS}"
 fi

 if [ ! -z ${GG_BUILD_SYCL} ]; then
@ -41,9 +41,9 @@ static std::string build_repetition(const std::string & item_rule, int min_items
     return result;
 }

-static void _build_min_max_int(int min_value, int max_value, std::stringstream & out, int decimals_left = 16, bool top_level = true) {
-    auto has_min = min_value != std::numeric_limits<int>::min();
-    auto has_max = max_value != std::numeric_limits<int>::max();
+static void _build_min_max_int(int64_t min_value, int64_t max_value, std::stringstream & out, int decimals_left = 16, bool top_level = true) {
+    auto has_min = min_value != std::numeric_limits<int64_t>::min();
+    auto has_max = max_value != std::numeric_limits<int64_t>::max();

     auto digit_range = [&](char from, char to) {
         out << "[";
@ -159,7 +159,7 @@ static void _build_min_max_int(int min_value, int max_value, std::stringstream &
     if (has_min) {
         if (min_value < 0) {
             out << "\"-\" (";
-            _build_min_max_int(std::numeric_limits<int>::min(), -min_value, out, decimals_left, /* top_level= */ false);
+            _build_min_max_int(std::numeric_limits<int64_t>::min(), -min_value, out, decimals_left, /* top_level= */ false);
             out << ") | [0] | [1-9] ";
             more_digits(0, decimals_left - 1);
         } else if (min_value == 0) {
@ -194,7 +194,7 @@ static void _build_min_max_int(int min_value, int max_value, std::stringstream &
             }
             digit_range(c, c);
             out << " (";
-            _build_min_max_int(std::stoi(min_s.substr(1)), std::numeric_limits<int>::max(), out, less_decimals, /* top_level= */ false);
+            _build_min_max_int(std::stoll(min_s.substr(1)), std::numeric_limits<int64_t>::max(), out, less_decimals, /* top_level= */ false);
             out << ")";
             if (c < '9') {
                 out << " | ";
@ -216,7 +216,7 @@ static void _build_min_max_int(int min_value, int max_value, std::stringstream &
             _build_min_max_int(0, max_value, out, decimals_left, /* top_level= */ true);
         } else {
             out << "\"-\" (";
-            _build_min_max_int(-max_value, std::numeric_limits<int>::max(), out, decimals_left, /* top_level= */ false);
+            _build_min_max_int(-max_value, std::numeric_limits<int64_t>::max(), out, decimals_left, /* top_level= */ false);
             out << ")";
         }
         return;
@ -925,17 +925,17 @@ public:
             int max_len = schema.contains("maxLength") ? schema["maxLength"].get<int>() : std::numeric_limits<int>::max();
             return _add_rule(rule_name, "\"\\\"\" " + build_repetition(char_rule, min_len, max_len) + " \"\\\"\" space");
         } else if (schema_type == "integer" && (schema.contains("minimum") || schema.contains("exclusiveMinimum") || schema.contains("maximum") || schema.contains("exclusiveMaximum"))) {
-            int min_value = std::numeric_limits<int>::min();
-            int max_value = std::numeric_limits<int>::max();
+            int64_t min_value = std::numeric_limits<int64_t>::min();
+            int64_t max_value = std::numeric_limits<int64_t>::max();
             if (schema.contains("minimum")) {
-                min_value = schema["minimum"].get<int>();
+                min_value = schema["minimum"].get<int64_t>();
             } else if (schema.contains("exclusiveMinimum")) {
-                min_value = schema["exclusiveMinimum"].get<int>() + 1;
+                min_value = schema["exclusiveMinimum"].get<int64_t>() + 1;
             }
             if (schema.contains("maximum")) {
-                max_value = schema["maximum"].get<int>();
+                max_value = schema["maximum"].get<int64_t>();
             } else if (schema.contains("exclusiveMaximum")) {
-                max_value = schema["exclusiveMaximum"].get<int>() - 1;
+                max_value = schema["exclusiveMaximum"].get<int64_t>() - 1;
             }
             std::stringstream out;
             out << "(";
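Note on the hunks above: JSON schemas can carry integer bounds outside the 32-bit range, where std::stoi overflows but std::stoll does not; hence the move from int to int64_t throughout the grammar builder. A standalone sketch of the difference (illustrative, not part of the patch):

    #include <cstdint>
    #include <iostream>
    #include <stdexcept>
    #include <string>

    int main() {
        const std::string big = "9223372036854775806"; // fits in int64_t, not in 32-bit int
        try {
            std::cout << "stoi: " << std::stoi(big) << "\n";  // 32-bit parse
        } catch (const std::out_of_range &) {
            std::cout << "stoi: out_of_range\n";              // thrown on typical platforms
        }
        std::cout << "stoll: " << std::stoll(big) << "\n";    // 64-bit parse succeeds
    }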
@ -892,8 +892,8 @@ class TextModel(ModelBase):
             # ref: https://huggingface.co/JetBrains/Mellum-4b-base
             res = "mellum"
         if chkhsh == "9b1be57e70d20d9501b2b3186e792d81181ae36ada3903c26f9fea418cf87206":
-            # ref: https://huggingface.co/inclusionAI/LLaDA-MoE-7B-A1B-Base
-            res = "llada-moe"
+            # ref: https://huggingface.co/inclusionAI/Ling-mini-base-2.0
+            res = "bailingmoe2"
         if chkhsh == "53e325976a6e142379c19b09afcae354f2f496f147afa8f9e189a33fe4e3024e":
             # ref: https://huggingface.co/ibm-granite/granite-docling-258M
             res = "granite-docling"
@ -8055,6 +8055,103 @@ class BailingMoeModel(TextModel):
                 raise ValueError(f"Unprocessed experts: {experts}")


+@ModelBase.register("BailingMoeV2ForCausalLM")
+class BailingMoeV2Model(TextModel):
+    model_arch = gguf.MODEL_ARCH.BAILINGMOE2
+
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        if nextn_layers := self.hparams.get("num_nextn_predict_layers", 0):
+            self.block_count = self.hparams["num_hidden_layers"] + nextn_layers
+            self.tensor_map = gguf.get_tensor_name_map(self.model_arch, self.block_count)
+
+    def set_vocab(self):
+        self._set_vocab_gpt2()
+
+    def set_gguf_parameters(self):
+        super().set_gguf_parameters()
+        hparams = self.hparams
+        if (rope_dim := hparams.get("head_dim")) is None:
+            rope_dim = hparams["hidden_size"] // hparams["num_attention_heads"]
+
+        self.gguf_writer.add_rope_dimension_count(int(rope_dim * self.hparams.get("partial_rotary_factor", 0.5)))
+        rope_scaling = self.hparams.get("rope_scaling") or {}
+        if rope_scaling.get("rope_type", rope_scaling.get("type")) == "yarn" and "factor" in rope_scaling:
+            self.gguf_writer.add_rope_scaling_type(gguf.RopeScalingType.YARN)
+            self.gguf_writer.add_rope_scaling_factor(rope_scaling["factor"])
+            self.gguf_writer.add_rope_scaling_orig_ctx_len(rope_scaling["original_max_position_embeddings"])
+        else:
+            self.gguf_writer.add_rope_scaling_type(gguf.RopeScalingType.NONE)
+        self.gguf_writer.add_leading_dense_block_count(hparams["first_k_dense_replace"])
+        self.gguf_writer.add_vocab_size(hparams["vocab_size"])
+        self.gguf_writer.add_expert_feed_forward_length(hparams["moe_intermediate_size"])
+        self.gguf_writer.add_expert_shared_feed_forward_length(hparams.get("moe_shared_expert_intermediate_size", hparams["moe_intermediate_size"] * hparams["num_shared_experts"]))
+        self.gguf_writer.add_expert_weights_scale(hparams["routed_scaling_factor"])
+        self.gguf_writer.add_expert_count(hparams["num_experts"])
+        self.gguf_writer.add_expert_shared_count(hparams["num_shared_experts"])
+        self.gguf_writer.add_expert_group_count(hparams["n_group"])
+        self.gguf_writer.add_expert_group_used_count(hparams["topk_group"])
+        self.gguf_writer.add_expert_weights_norm(hparams["norm_topk_prob"])
+
+        if hparams["score_function"] == "sigmoid":
+            self.gguf_writer.add_expert_gating_func(gguf.ExpertGatingFuncType.SIGMOID)
+        elif hparams["score_function"] == "softmax":
+            self.gguf_writer.add_expert_gating_func(gguf.ExpertGatingFuncType.SOFTMAX)
+        else:
+            raise ValueError(f"Unsupported score_function value: {hparams['score_function']}")
+
+        if (nextn_layers := self.hparams.get("num_nextn_predict_layers")) is not None:
+            self.gguf_writer.add_nextn_predict_layers(nextn_layers)
+
+    _experts: list[dict[str, Tensor]] | None = None
+
+    def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
+        if "mlp.experts" in name:
+            n_experts = self.hparams["num_experts"]
+            assert bid is not None
+
+            tensors: list[tuple[str, Tensor]] = []
+
+            if self._experts is None:
+                self._experts = [{} for _ in range(self.block_count)]
+
+            self._experts[bid][name] = data_torch
+
+            if len(self._experts[bid]) >= n_experts * 3:
+                # merge the experts into a single 3d tensor
+                for w_name in ["down_proj", "gate_proj", "up_proj"]:
+                    datas: list[Tensor] = []
+
+                    for xid in range(n_experts):
+                        ename = f"model.layers.{bid}.mlp.experts.{xid}.{w_name}.weight"
+                        datas.append(self._experts[bid][ename])
+                        del self._experts[bid][ename]
+
+                    data_torch = torch.stack(datas, dim=0)
+
+                    merged_name = f"model.layers.{bid}.mlp.experts.{w_name}.weight"
+
+                    new_name = self.map_tensor_name(merged_name)
+
+                    tensors.append((new_name, data_torch))
+
+            return tensors
+
+        if name.endswith(".expert_bias"):
+            name = name.replace(".expert_bias", ".expert_bias.bias")
+
+        return [(self.map_tensor_name(name), data_torch)]
+
+    def prepare_tensors(self):
+        super().prepare_tensors()
+
+        if self._experts is not None:
+            # flatten `list[dict[str, Tensor]]` into `list[str]`
+            experts = [k for d in self._experts for k in d.keys()]
+            if len(experts) > 0:
+                raise ValueError(f"Unprocessed experts: {experts}")
+
+
 @ModelBase.register("GroveMoeForCausalLM", "modeling_grove_moe.GroveMoeForCausalLM")
 class GroveMoeModel(TextModel):
     model_arch = gguf.MODEL_ARCH.GROVEMOE
@ -139,7 +139,7 @@ models = [
     {"name": "lfm2", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/LiquidAI/LFM2-Tokenizer"},
     {"name": "exaone4", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/LGAI-EXAONE/EXAONE-4.0-32B", },
     {"name": "mellum", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/JetBrains/Mellum-4b-base", },
-    {"name": "llada-moe", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/inclusionAI/LLaDA-MoE-7B-A1B-Base", },
+    {"name": "bailingmoe2", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/inclusionAI/Ling-mini-base-2.0", },
     {"name": "granite-docling", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/ibm-granite/granite-docling-258M", },
 ]
docs/ops.md
@ -22,7 +22,7 @@ Legend:
 | ARANGE | ❌ | ✅ | ✅ | ✅ | ✅ | ❌ | ❌ | ❌ | ❌ |
 | ARGMAX | ❌ | ✅ | ✅ | ✅ | ✅ | ❌ | ✅ | ✅ | ❌ |
 | ARGSORT | ❌ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ |
-| CEIL | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ |
+| CEIL | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ | ✅ | ❌ | ❌ |
 | CLAMP | ❌ | ✅ | ✅ | ✅ | 🟡 | 🟡 | ✅ | 🟡 | ❌ |
 | CONCAT | ❌ | ✅ | ✅ | 🟡 | ✅ | 🟡 | 🟡 | ✅ | ❌ |
 | CONT | ❌ | 🟡 | ✅ | ✅ | ✅ | 🟡 | 🟡 | 🟡 | ❌ |
@ -42,7 +42,7 @@ Legend:
 | ELU | ❌ | ✅ | ✅ | 🟡 | 🟡 | ❌ | 🟡 | ❌ | ❌ |
 | EXP | ❌ | ✅ | ✅ | 🟡 | 🟡 | ❌ | 🟡 | ❌ | ❌ |
 | FLASH_ATTN_EXT | ❌ | 🟡 | ✅ | 🟡 | 🟡 | ❌ | ❌ | 🟡 | ❌ |
-| FLOOR | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ |
+| FLOOR | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ | ✅ | ❌ | ❌ |
 | GATED_LINEAR_ATTN | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ | ✅ | ❌ | ❌ |
 | GEGLU | ❌ | ✅ | ✅ | ✅ | 🟡 | ✅ | ✅ | 🟡 | ❌ |
 | GEGLU_ERF | ❌ | ✅ | ✅ | ✅ | 🟡 | ✅ | ✅ | 🟡 | ❌ |
@ -84,7 +84,7 @@ Legend:
 | ROLL | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ | ❌ | ✅ | ❌ |
 | ROPE | ❌ | 🟡 | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ |
 | ROPE_BACK | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ | ❌ | ✅ | ❌ |
-| ROUND | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ |
+| ROUND | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ | ✅ | ❌ | ❌ |
 | RWKV_WKV6 | ❌ | ❌ | ✅ | ✅ | ✅ | ❌ | ✅ | ✅ | ❌ |
 | RWKV_WKV7 | ❌ | ❌ | ✅ | ✅ | ✅ | ❌ | ✅ | ✅ | ❌ |
 | SCALE | ❌ | 🟡 | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ |
@ -100,8 +100,8 @@ Legend:
 | SOFT_MAX_BACK | ❌ | ❌ | 🟡 | 🟡 | ❌ | ❌ | 🟡 | ✅ | ❌ |
 | SQR | ❌ | ✅ | ✅ | ✅ | 🟡 | ❌ | ✅ | 🟡 | ❌ |
 | SQRT | ❌ | ✅ | ✅ | ✅ | 🟡 | ❌ | ✅ | ❌ | ❌ |
-| SSM_CONV | ❌ | ❌ | ✅ | ✅ | ✅ | ❌ | ❌ | ❌ | ❌ |
-| SSM_SCAN | ❌ | ❌ | ✅ | ✅ | ✅ | ❌ | ❌ | ❌ | ❌ |
+| SSM_CONV | ❌ | ❌ | ✅ | ✅ | ✅ | ❌ | ❌ | ✅ | ❌ |
+| SSM_SCAN | ❌ | ❌ | ✅ | ✅ | ✅ | ❌ | ❌ | ✅ | ❌ |
 | STEP | ❌ | ✅ | ✅ | 🟡 | 🟡 | ❌ | 🟡 | ❌ | ❌ |
 | SUB | ❌ | ✅ | ✅ | ✅ | 🟡 | 🟡 | ✅ | ✅ | ❌ |
 | SUM | ❌ | ✅ | ✅ | ✅ | ❌ | ❌ | ✅ | ✅ | ❌ |
@ -111,6 +111,6 @@ Legend:
 | TANH | ❌ | ✅ | ✅ | 🟡 | 🟡 | ✅ | 🟡 | 🟡 | ❌ |
 | TIMESTEP_EMBEDDING | ❌ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ |
 | TOPK_MOE | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ✅ | ❌ | ❌ |
-| TRUNC | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ |
+| TRUNC | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ | ✅ | ❌ | ❌ |
 | UPSCALE | ❌ | 🟡 | ✅ | ✅ | 🟡 | ✅ | 🟡 | ✅ | ❌ |
 | XIELU | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ |
@ -31,6 +31,14 @@
 "SYCL0","GELU_ERF","type=f16,ne_a=[5,7,11,13],v=0","support","1","yes","SYCL"
 "SYCL0","XIELU","type=f16,ne_a=[128,2,2,2],v=0","support","0","no","SYCL"
 "SYCL0","XIELU","type=f16,ne_a=[5,7,11,13],v=0","support","0","no","SYCL"
+"SYCL0","FLOOR","type=f16,ne_a=[128,2,2,2],v=0","support","1","yes","SYCL"
+"SYCL0","FLOOR","type=f16,ne_a=[5,7,11,13],v=0","support","1","yes","SYCL"
+"SYCL0","CEIL","type=f16,ne_a=[128,2,2,2],v=0","support","1","yes","SYCL"
+"SYCL0","CEIL","type=f16,ne_a=[5,7,11,13],v=0","support","1","yes","SYCL"
+"SYCL0","ROUND","type=f16,ne_a=[128,2,2,2],v=0","support","1","yes","SYCL"
+"SYCL0","ROUND","type=f16,ne_a=[5,7,11,13],v=0","support","1","yes","SYCL"
+"SYCL0","TRUNC","type=f16,ne_a=[128,2,2,2],v=0","support","1","yes","SYCL"
+"SYCL0","TRUNC","type=f16,ne_a=[5,7,11,13],v=0","support","1","yes","SYCL"
 "SYCL0","ABS","type=f16,ne_a=[128,2,2,2],v=1","support","0","no","SYCL"
 "SYCL0","ABS","type=f16,ne_a=[5,7,11,13],v=1","support","0","no","SYCL"
 "SYCL0","SGN","type=f16,ne_a=[128,2,2,2],v=1","support","0","no","SYCL"
@ -95,6 +103,14 @@
 "SYCL0","GELU_ERF","type=f32,ne_a=[5,7,11,13],v=0","support","1","yes","SYCL"
 "SYCL0","XIELU","type=f32,ne_a=[128,2,2,2],v=0","support","0","no","SYCL"
 "SYCL0","XIELU","type=f32,ne_a=[5,7,11,13],v=0","support","0","no","SYCL"
+"SYCL0","FLOOR","type=f32,ne_a=[128,2,2,2],v=0","support","1","yes","SYCL"
+"SYCL0","FLOOR","type=f32,ne_a=[5,7,11,13],v=0","support","1","yes","SYCL"
+"SYCL0","CEIL","type=f32,ne_a=[128,2,2,2],v=0","support","1","yes","SYCL"
+"SYCL0","CEIL","type=f32,ne_a=[5,7,11,13],v=0","support","1","yes","SYCL"
+"SYCL0","ROUND","type=f32,ne_a=[128,2,2,2],v=0","support","1","yes","SYCL"
+"SYCL0","ROUND","type=f32,ne_a=[5,7,11,13],v=0","support","1","yes","SYCL"
+"SYCL0","TRUNC","type=f32,ne_a=[128,2,2,2],v=0","support","1","yes","SYCL"
+"SYCL0","TRUNC","type=f32,ne_a=[5,7,11,13],v=0","support","1","yes","SYCL"
 "SYCL0","ABS","type=f32,ne_a=[128,2,2,2],v=1","support","0","no","SYCL"
 "SYCL0","ABS","type=f32,ne_a=[5,7,11,13],v=1","support","0","no","SYCL"
 "SYCL0","SGN","type=f32,ne_a=[128,2,2,2],v=1","support","0","no","SYCL"
@ -3263,27 +3263,27 @@
 "Vulkan0","RMS_NORM_MUL_ADD","type=f32,ne=[64,5,4,3],eps=1.000000,broadcast=0","support","1","yes","Vulkan"
 "Vulkan0","RMS_NORM_MUL_ADD","type=f32,ne=[64,5,4,3],eps=1.000000,broadcast=1","support","1","yes","Vulkan"
 "Vulkan0","L2_NORM","type=f32,ne=[64,5,4,3]","support","1","yes","Vulkan"
-"Vulkan0","SSM_CONV","type=f32,ne_a=[4,1024,1,1],ne_b=[3,1024,1,1]","support","0","no","Vulkan"
-"Vulkan0","SSM_CONV","type=f32,ne_a=[8,1024,1,1],ne_b=[3,1024,1,1]","support","0","no","Vulkan"
-"Vulkan0","SSM_CONV","type=f32,ne_a=[4,1024,4,1],ne_b=[3,1024,1,1]","support","0","no","Vulkan"
-"Vulkan0","SSM_CONV","type=f32,ne_a=[4,1536,1,1],ne_b=[3,1536,1,1]","support","0","no","Vulkan"
-"Vulkan0","SSM_CONV","type=f32,ne_a=[8,1536,1,1],ne_b=[3,1536,1,1]","support","0","no","Vulkan"
-"Vulkan0","SSM_CONV","type=f32,ne_a=[4,1536,4,1],ne_b=[3,1536,1,1]","support","0","no","Vulkan"
-"Vulkan0","SSM_CONV","type=f32,ne_a=[4,2048,1,1],ne_b=[3,2048,1,1]","support","0","no","Vulkan"
-"Vulkan0","SSM_CONV","type=f32,ne_a=[8,2048,1,1],ne_b=[3,2048,1,1]","support","0","no","Vulkan"
-"Vulkan0","SSM_CONV","type=f32,ne_a=[4,2048,4,1],ne_b=[3,2048,1,1]","support","0","no","Vulkan"
-"Vulkan0","SSM_CONV","type=f32,ne_a=[4,1024,1,1],ne_b=[4,1024,1,1]","support","0","no","Vulkan"
-"Vulkan0","SSM_CONV","type=f32,ne_a=[8,1024,1,1],ne_b=[4,1024,1,1]","support","0","no","Vulkan"
-"Vulkan0","SSM_CONV","type=f32,ne_a=[4,1024,4,1],ne_b=[4,1024,1,1]","support","0","no","Vulkan"
-"Vulkan0","SSM_CONV","type=f32,ne_a=[4,1536,1,1],ne_b=[4,1536,1,1]","support","0","no","Vulkan"
-"Vulkan0","SSM_CONV","type=f32,ne_a=[8,1536,1,1],ne_b=[4,1536,1,1]","support","0","no","Vulkan"
-"Vulkan0","SSM_CONV","type=f32,ne_a=[4,1536,4,1],ne_b=[4,1536,1,1]","support","0","no","Vulkan"
-"Vulkan0","SSM_CONV","type=f32,ne_a=[4,2048,1,1],ne_b=[4,2048,1,1]","support","0","no","Vulkan"
-"Vulkan0","SSM_CONV","type=f32,ne_a=[8,2048,1,1],ne_b=[4,2048,1,1]","support","0","no","Vulkan"
-"Vulkan0","SSM_CONV","type=f32,ne_a=[4,2048,4,1],ne_b=[4,2048,1,1]","support","0","no","Vulkan"
-"Vulkan0","SSM_SCAN","type=f32,d_state=16,head_dim=1,n_head=1024,n_group=1,n_seq_tokens=32,n_seqs=4","support","0","no","Vulkan"
-"Vulkan0","SSM_SCAN","type=f32,d_state=128,head_dim=64,n_head=16,n_group=2,n_seq_tokens=32,n_seqs=4","support","0","no","Vulkan"
-"Vulkan0","SSM_SCAN","type=f32,d_state=256,head_dim=64,n_head=8,n_group=2,n_seq_tokens=32,n_seqs=4","support","0","no","Vulkan"
+"Vulkan0","SSM_CONV","type=f32,ne_a=[4,1024,1,1],ne_b=[3,1024,1,1]","support","1","yes","Vulkan"
+"Vulkan0","SSM_CONV","type=f32,ne_a=[8,1024,1,1],ne_b=[3,1024,1,1]","support","1","yes","Vulkan"
+"Vulkan0","SSM_CONV","type=f32,ne_a=[4,1024,4,1],ne_b=[3,1024,1,1]","support","1","yes","Vulkan"
+"Vulkan0","SSM_CONV","type=f32,ne_a=[4,1536,1,1],ne_b=[3,1536,1,1]","support","1","yes","Vulkan"
+"Vulkan0","SSM_CONV","type=f32,ne_a=[8,1536,1,1],ne_b=[3,1536,1,1]","support","1","yes","Vulkan"
+"Vulkan0","SSM_CONV","type=f32,ne_a=[4,1536,4,1],ne_b=[3,1536,1,1]","support","1","yes","Vulkan"
+"Vulkan0","SSM_CONV","type=f32,ne_a=[4,2048,1,1],ne_b=[3,2048,1,1]","support","1","yes","Vulkan"
+"Vulkan0","SSM_CONV","type=f32,ne_a=[8,2048,1,1],ne_b=[3,2048,1,1]","support","1","yes","Vulkan"
+"Vulkan0","SSM_CONV","type=f32,ne_a=[4,2048,4,1],ne_b=[3,2048,1,1]","support","1","yes","Vulkan"
+"Vulkan0","SSM_CONV","type=f32,ne_a=[4,1024,1,1],ne_b=[4,1024,1,1]","support","1","yes","Vulkan"
+"Vulkan0","SSM_CONV","type=f32,ne_a=[8,1024,1,1],ne_b=[4,1024,1,1]","support","1","yes","Vulkan"
+"Vulkan0","SSM_CONV","type=f32,ne_a=[4,1024,4,1],ne_b=[4,1024,1,1]","support","1","yes","Vulkan"
+"Vulkan0","SSM_CONV","type=f32,ne_a=[4,1536,1,1],ne_b=[4,1536,1,1]","support","1","yes","Vulkan"
+"Vulkan0","SSM_CONV","type=f32,ne_a=[8,1536,1,1],ne_b=[4,1536,1,1]","support","1","yes","Vulkan"
+"Vulkan0","SSM_CONV","type=f32,ne_a=[4,1536,4,1],ne_b=[4,1536,1,1]","support","1","yes","Vulkan"
+"Vulkan0","SSM_CONV","type=f32,ne_a=[4,2048,1,1],ne_b=[4,2048,1,1]","support","1","yes","Vulkan"
+"Vulkan0","SSM_CONV","type=f32,ne_a=[8,2048,1,1],ne_b=[4,2048,1,1]","support","1","yes","Vulkan"
+"Vulkan0","SSM_CONV","type=f32,ne_a=[4,2048,4,1],ne_b=[4,2048,1,1]","support","1","yes","Vulkan"
+"Vulkan0","SSM_SCAN","type=f32,d_state=16,head_dim=1,n_head=1024,n_group=1,n_seq_tokens=32,n_seqs=4","support","1","yes","Vulkan"
+"Vulkan0","SSM_SCAN","type=f32,d_state=128,head_dim=64,n_head=16,n_group=2,n_seq_tokens=32,n_seqs=4","support","1","yes","Vulkan"
+"Vulkan0","SSM_SCAN","type=f32,d_state=256,head_dim=64,n_head=8,n_group=2,n_seq_tokens=32,n_seqs=4","support","1","yes","Vulkan"
 "Vulkan0","RWKV_WKV6","type=f32,head_count=32,head_size=64,n_seq_tokens=1,n_seqs=1","support","1","yes","Vulkan"
 "Vulkan0","RWKV_WKV6","type=f32,head_count=32,head_size=64,n_seq_tokens=32,n_seqs=1","support","1","yes","Vulkan"
 "Vulkan0","RWKV_WKV6","type=f32,head_count=32,head_size=64,n_seq_tokens=32,n_seqs=4","support","1","yes","Vulkan"
@ -21,8 +21,7 @@ GGML_BACKEND_API ggml_backend_buffer_type_t ggml_backend_rpc_buffer_type(const c
 GGML_BACKEND_API void ggml_backend_rpc_get_device_memory(const char * endpoint, uint32_t device, size_t * free, size_t * total);

 GGML_BACKEND_API void ggml_backend_rpc_start_server(const char * endpoint, const char * cache_dir,
-                                                    size_t n_threads, size_t n_devices,
-                                                    ggml_backend_dev_t * devices, size_t * free_mem, size_t * total_mem);
+                                                    size_t n_threads, size_t n_devices, ggml_backend_dev_t * devices);

 GGML_BACKEND_API ggml_backend_reg_t ggml_backend_rpc_reg(void);
 GGML_BACKEND_API ggml_backend_reg_t ggml_backend_rpc_add_server(const char * endpoint);
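For orientation, a minimal sketch of calling the reduced server entry point; the endpoint string, cache path, and thread count below are illustrative placeholders, not from the patch:

    #include "ggml-rpc.h"

    // the free_mem/total_mem output parameters were dropped from the signature,
    // so the caller now passes only the devices to serve
    void start_example_server(ggml_backend_dev_t * devs, size_t n_devs) {
        ggml_backend_rpc_start_server("0.0.0.0:50052", /*cache_dir=*/NULL,
                                      /*n_threads=*/8, n_devs, devs);
    }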
@ -307,6 +307,10 @@ function(ggml_add_cpu_backend_variant tag_name)
         foreach (feat ${ARGN})
             set(GGML_INTERNAL_${feat} ON)
         endforeach()
+    elseif (GGML_SYSTEM_ARCH STREQUAL "s390x")
+        foreach (feat ${ARGN})
+            set(GGML_INTERNAL_${feat} ON)
+        endforeach()
     endif()

     ggml_add_cpu_backend_variant_impl(${tag_name})
@ -371,6 +375,14 @@ if (GGML_CPU_ALL_VARIANTS)
         else()
             message(FATAL_ERROR "Unsupported PowerPC target OS: ${CMAKE_SYSTEM_NAME}")
         endif()
+    elseif (GGML_SYSTEM_ARCH STREQUAL "s390x")
+        if (CMAKE_SYSTEM_NAME MATCHES "Linux")
+            ggml_add_cpu_backend_variant(s390x_z15 Z15 VXE)
+            # ggml_add_cpu_backend_variant(s390x_z16 Z16 VXE)
+            # ggml_add_cpu_backend_variant(s390x_z17 Z17 VXE)
+        else()
+            message(FATAL_ERROR "Unsupported s390x target OS: ${CMAKE_SYSTEM_NAME}")
+        endif()
     else()
         message(FATAL_ERROR "GGML_CPU_ALL_VARIANTS not yet supported with ${GGML_SYSTEM_ARCH} on ${CMAKE_SYSTEM_NAME}")
     endif()
@ -598,6 +598,26 @@ static bool ggml_gallocr_is_allocated(ggml_gallocr_t galloc, struct ggml_tensor
     return t->data != NULL || ggml_gallocr_hash_get(galloc, t)->allocated;
 }

+// free the extra space at the end if the new tensor is smaller
+static void ggml_gallocr_free_extra_space(ggml_gallocr_t galloc, struct ggml_tensor * node, struct ggml_tensor * parent) {
+    struct hash_node * hn   = ggml_gallocr_hash_get(galloc, node);
+    struct hash_node * p_hn = ggml_gallocr_hash_get(galloc, parent);
+
+    size_t parent_size = ggml_backend_buft_get_alloc_size(galloc->bufts[p_hn->buffer_id], parent);
+    size_t node_size   = ggml_backend_buft_get_alloc_size(galloc->bufts[hn->buffer_id], node);
+
+    GGML_ASSERT(parent_size >= node_size);
+
+    if (parent_size > node_size) {
+        struct ggml_dyn_tallocr * p_alloc = galloc->buf_tallocs[p_hn->buffer_id];
+        struct buffer_address p_addr = p_hn->addr;
+        p_addr.offset += node_size;
+        size_t extra_size = parent_size - node_size;
+        AT_PRINTF("freeing extra %zu bytes from parent %s for %s\n", extra_size, parent->name, node->name);
+        ggml_dyn_tallocr_free_tensor(p_alloc, p_addr, extra_size, parent);
+    }
+}
+
 static void ggml_gallocr_allocate_node(ggml_gallocr_t galloc, struct ggml_tensor * node, int buffer_id) {
     GGML_ASSERT(buffer_id >= 0);
     struct hash_node * hn = ggml_gallocr_hash_get(galloc, node);
@ -643,6 +663,7 @@ static void ggml_gallocr_allocate_node(ggml_gallocr_t galloc, struct ggml_tensor
                 hn->addr = p_hn->addr;
                 p_hn->allocated = false; // avoid freeing the parent
                 view_src_hn->allocated = false;
+                ggml_gallocr_free_extra_space(galloc, node, view_src);
                 return;
             }
         } else {
@ -650,6 +671,7 @@ static void ggml_gallocr_allocate_node(ggml_gallocr_t galloc, struct ggml_tensor
             hn->buffer_id = p_hn->buffer_id;
             hn->addr = p_hn->addr;
             p_hn->allocated = false; // avoid freeing the parent
+            ggml_gallocr_free_extra_space(galloc, node, parent);
             return;
         }
     }
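The new helper returns the unused tail of a reused parent allocation to the dynamic allocator. A standalone sketch of the offset arithmetic (names are illustrative, not the ggml API):

    #include <cassert>
    #include <cstddef>
    #include <cstdio>

    // illustrative only: when a node reuses its parent's block in place and is
    // smaller, the byte range [offset + node_size, offset + parent_size) can be
    // handed back to the allocator as a free region
    struct free_range { size_t offset, size; };

    free_range tail_after_inplace_reuse(size_t offset, size_t parent_size, size_t node_size) {
        assert(parent_size >= node_size);
        return { offset + node_size, parent_size - node_size };
    }

    int main() {
        free_range r = tail_after_inplace_reuse(/*offset=*/4096, /*parent=*/1024, /*node=*/640);
        std::printf("free %zu bytes at offset %zu\n", r.size, r.offset); // free 384 bytes at offset 4736
    }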
@ -466,29 +466,45 @@ function(ggml_add_cpu_backend_variant_impl tag_name)
         list(APPEND ARCH_FLAGS "-march=${MARCH_STR}" -mabi=lp64d)
     elseif (GGML_SYSTEM_ARCH STREQUAL "s390x")
         message(STATUS "s390x detected")
-        list(APPEND GGML_CPU_SOURCES ggml-cpu/arch/s390/quants.c)
-        file(READ "/proc/cpuinfo" CPUINFO_CONTENTS)
-        string(REGEX REPLACE "machine[ \t\r\n]*=[ \t\r\n]*([0-9]+)" "\\1" S390X_M ${CPUINFO_CONTENTS})
+        list(APPEND GGML_CPU_SOURCES
+            ggml-cpu/arch/s390/quants.c)

-        # TODO: Separation to determine activation of VX/VXE/VXE2
-        if (${S390X_M} MATCHES "8561|8562")
-            message(STATUS "z15 target")
-            list(APPEND ARCH_FLAGS -march=z15)
-        elseif (${S390X_M} MATCHES "3931")
-            message(STATUS "z16 target")
-            list(APPEND ARCH_FLAGS -march=z16)
-        elseif (${S390X_M} MATCHES "9175|9176")
-            # NOTE: Only available from GCC 15.1.0 onwards. Any z17 machine with compile issues must first verify their GCC version.
-            # binutils must also be updated to the latest for the -march=z17 flag to work. Otherwise, use -march=arch15.
-            message(STATUS "z17 target")
-            list(APPEND ARCH_FLAGS -march=arch15)
-        else()
-            message(STATUS "Unknown target")
-            message(WARNING "Unknown target. If you are compiling for z14 and earlier, you might have to add -DGGML_VXE=OFF.")
-            list(APPEND ARCH_FLAGS -march=native -mtune=native)
+        # for native compilation
+        if (GGML_NATIVE)
+            # check machine level to determine target
+            file(READ "/proc/cpuinfo" CPUINFO_CONTENTS)
+            string(REGEX REPLACE "machine[ \t\r\n]*=[ \t\r\n]*([0-9]+)" "\\1" S390X_M ${CPUINFO_CONTENTS})
+
+            # TODO: Separation to determine activation of VX/VXE/VXE2
+            if (${S390X_M} MATCHES "8561|8562")
+                message(STATUS "z15 target")
+                list(APPEND ARCH_FLAGS -march=z15)
+            elseif (${S390X_M} MATCHES "3931")
+                message(STATUS "z16 target")
+                list(APPEND ARCH_FLAGS -march=z16)
+            elseif (${S390X_M} MATCHES "9175|9176")
+                # NOTE: Only available from GCC 15.1.0 onwards. Any z17 machine with compile issues must first verify their GCC version.
+                # binutils must also be updated to the latest for the -march=z17 flag to work. Otherwise, use -march=arch15.
+                message(STATUS "z17 target")
+                list(APPEND ARCH_FLAGS -march=arch15)
+            else()
+                message(STATUS "Unknown target")
+                message(WARNING "Unknown target. If you are compiling for z14 and earlier, you might have to add -DGGML_VXE=OFF.")
+                list(APPEND ARCH_FLAGS -march=native -mtune=native)
+            endif()
+        # for cross-compilation
+        elseif(GGML_CPU_ALL_VARIANTS)
+            # range through IBM z15 to z17
+            # NOTE: update when a new hardware level is released
+            foreach (ZHW RANGE 15 17)
+                if(DEFINED GGML_INTERNAL_Z${ZHW})
+                    message(STATUS "z${ZHW} cross-compile target")
+                    list(APPEND ARCH_FLAGS -march=z${ZHW})
+                endif()
+            endforeach()
         endif()

-        if (GGML_VXE)
+        if (GGML_VXE OR GGML_INTERNAL_VXE)
             message(STATUS "VX/VXE/VXE2 enabled")
             list(APPEND ARCH_FLAGS -mvx -mzvector)
             list(APPEND ARCH_DEFINITIONS GGML_VXE)
@ -485,8 +485,9 @@ template <typename BLOC_TYPE, int64_t INTER_SIZE, int64_t NB_COLS> class tensor_
     int32_t start = ith * task_per_thread;
     int32_t end = std::min((ith + 1) * task_per_thread, task_count);
     for (int32_t compute_idx = start; compute_idx < end; compute_idx++) {
-        int32_t gemm_idx = compute_idx / block_size_m;
-        int32_t m_idx = compute_idx % block_size_m * block_size_m;
+        int32_t gemm_idx = compute_idx / per_gemm_block_count_m;
+        int32_t block_idx_in_gemm = compute_idx % per_gemm_block_count_m;
+        int32_t m_idx = block_idx_in_gemm * block_size_m;
         const qnbitgemm_spacemit_ime_args & data = qnbitgemm_args[gemm_idx];
         int32_t rows_tobe_handled = (gemm_m - m_idx) > block_size_m ? block_size_m : (gemm_m - m_idx);
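The fix decomposes the flattened work index by the number of row blocks per GEMM rather than by the block size. A standalone sketch of the corrected decomposition (the loop harness is illustrative; only the three index lines mirror the patch):

    #include <cstdio>

    int main() {
        const int block_size_m           = 4; // rows handled per block
        const int per_gemm_block_count_m = 3; // row blocks in each GEMM
        for (int compute_idx = 0; compute_idx < 6; compute_idx++) {
            int gemm_idx          = compute_idx / per_gemm_block_count_m; // which GEMM
            int block_idx_in_gemm = compute_idx % per_gemm_block_count_m; // block within it
            int m_idx             = block_idx_in_gemm * block_size_m;     // starting row
            std::printf("idx=%d -> gemm=%d row=%d\n", compute_idx, gemm_idx, m_idx);
        }
    }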
@ -73,8 +73,7 @@ __launch_bounds__(4 * WARP_SIZE, 1) __global__ void topk_moe_cuda(const float *

     float wt_sum = 0.f;

-    extern __shared__ float data_topk_shared[];
-    float * wt_shared_ptr = data_topk_shared + threadIdx.y * n_expert_used;
+    float output_weights[experts_per_thread];

     for (int k = 0; k < n_expert_used; k++) {
         float max_val = wt[0];
@ -99,11 +98,14 @@ __launch_bounds__(4 * WARP_SIZE, 1) __global__ void topk_moe_cuda(const float *
             }
         }

+        if ((k & (WARP_SIZE - 1)) == threadIdx.x) {
+            output_weights[k / WARP_SIZE] = max_val;
+        }
+
         if ((max_expert & (WARP_SIZE - 1)) == threadIdx.x) {
             wt[max_expert / WARP_SIZE] = -INFINITY;

-            wt_shared_ptr[k] = max_val;
-            ids[k] = max_expert;
+            ids[k] = max_expert;
             if constexpr (with_norm) {
                 wt_sum += max_val;
             }
@ -115,12 +117,16 @@ __launch_bounds__(4 * WARP_SIZE, 1) __global__ void topk_moe_cuda(const float *
         const float inv_sum = 1.0f / wt_sum;

         for (int i = threadIdx.x; i < n_expert_used; i += WARP_SIZE) {
-            wt_shared_ptr[i] = wt_shared_ptr[i] * inv_sum;
+            output_weights[i] *= inv_sum;
         }
     }

-    for (int i = threadIdx.x; i < n_expert_used; i += WARP_SIZE) {
-        weights[i] = wt_shared_ptr[i];
+#pragma unroll
+    for (int i = 0; i < experts_per_thread; i++) {
+        const int idx = i * WARP_SIZE + threadIdx.x;
+        if (idx < n_expert_used) {
+            weights[idx] = output_weights[i];
+        }
     }
 }
@ -137,48 +143,46 @@ static void launch_topk_moe_cuda(ggml_backend_cuda_context & ctx,
     dim3 block_dims(WARP_SIZE, rows_per_block, 1);
     cudaStream_t stream = ctx.stream();

-    const int nbytes_shared = n_expert_used * rows_per_block * sizeof(float);
-
     switch (n_expert) {
         case 1:
             topk_moe_cuda<1, with_norm>
-                <<<grid_dims, block_dims, nbytes_shared, stream>>>(logits, weights, ids, n_rows, n_expert_used);
+                <<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used);
             break;
         case 2:
             topk_moe_cuda<2, with_norm>
-                <<<grid_dims, block_dims, nbytes_shared, stream>>>(logits, weights, ids, n_rows, n_expert_used);
+                <<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used);
             break;
         case 4:
             topk_moe_cuda<4, with_norm>
-                <<<grid_dims, block_dims, nbytes_shared, stream>>>(logits, weights, ids, n_rows, n_expert_used);
+                <<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used);
             break;
         case 8:
             topk_moe_cuda<8, with_norm>
-                <<<grid_dims, block_dims, nbytes_shared, stream>>>(logits, weights, ids, n_rows, n_expert_used);
+                <<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used);
             break;
         case 16:
             topk_moe_cuda<16, with_norm>
-                <<<grid_dims, block_dims, nbytes_shared, stream>>>(logits, weights, ids, n_rows, n_expert_used);
+                <<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used);
             break;
         case 32:
             topk_moe_cuda<32, with_norm>
-                <<<grid_dims, block_dims, nbytes_shared, stream>>>(logits, weights, ids, n_rows, n_expert_used);
+                <<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used);
            break;
         case 64:
             topk_moe_cuda<64, with_norm>
-                <<<grid_dims, block_dims, nbytes_shared, stream>>>(logits, weights, ids, n_rows, n_expert_used);
+                <<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used);
             break;
         case 128:
             topk_moe_cuda<128, with_norm>
-                <<<grid_dims, block_dims, nbytes_shared, stream>>>(logits, weights, ids, n_rows, n_expert_used);
+                <<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used);
             break;
         case 256:
             topk_moe_cuda<256, with_norm>
-                <<<grid_dims, block_dims, nbytes_shared, stream>>>(logits, weights, ids, n_rows, n_expert_used);
+                <<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used);
             break;
         case 512:
             topk_moe_cuda<512, with_norm>
-                <<<grid_dims, block_dims, nbytes_shared, stream>>>(logits, weights, ids, n_rows, n_expert_used);
+                <<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used);
             break;
         default:
             GGML_ASSERT(false && "fatal error");
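These hunks move the top-k weights out of dynamic shared memory (hence the launch argument dropping from nbytes_shared to 0) and into per-thread registers: weight k is held in register slot k / WARP_SIZE of lane k % WARP_SIZE. A host-side sketch of that mapping (plain C++, illustrative only):

    #include <cstdio>

    constexpr int WARP_SIZE = 32;

    int main() {
        const int n_expert_used = 70; // spans three register slots per lane
        for (int k = 0; k < n_expert_used; k += 33) {
            int lane = k & (WARP_SIZE - 1); // which thread holds weight k
            int slot = k / WARP_SIZE;       // which register slot in that thread
            std::printf("weight %2d -> lane %2d, slot %d\n", k, lane, slot);
        }
    }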
@ -28,8 +28,10 @@ if (CXX_IS_HIPCC)
                 " Prefer setting the HIP compiler directly. See README for details.")
     endif()
 else()
-    # Forward AMDGPU_TARGETS to CMAKE_HIP_ARCHITECTURES.
-    if (AMDGPU_TARGETS AND NOT CMAKE_HIP_ARCHITECTURES)
+    # Forward (AMD)GPU_TARGETS to CMAKE_HIP_ARCHITECTURES.
+    if(GPU_TARGETS AND NOT CMAKE_HIP_ARCHITECTURES)
+        set(CMAKE_HIP_ARCHITECTURES ${GPU_TARGETS})
+    elseif(AMDGPU_TARGETS AND NOT CMAKE_HIP_ARCHITECTURES)
         set(CMAKE_HIP_ARCHITECTURES ${AMDGPU_TARGETS})
     endif()
     cmake_minimum_required(VERSION 3.21)
@ -565,14 +565,23 @@ static inline ggml_bf16_t ggml_compute_fp32_to_bf16(float s) {
 #define GGML_FP32_TO_BF16(x) ggml_compute_fp32_to_bf16(x)
 #define GGML_BF16_TO_FP32(x) ggml_compute_bf16_to_fp32(x)

+static inline int32_t ggml_node_get_use_count(const struct ggml_cgraph * cgraph, int node_idx) {
+    const struct ggml_tensor * node = cgraph->nodes[node_idx];
+
+    size_t hash_pos = ggml_hash_find(&cgraph->visited_hash_set, node);
+    if (!ggml_bitset_get(cgraph->visited_hash_set.used, hash_pos)) {
+        return 0;
+    }
+    return cgraph->use_counts[hash_pos];
+}
+
 // return true if the node's results are only used by N other nodes
 // and can be fused into their calculations.
 static inline bool ggml_node_has_n_uses(const struct ggml_cgraph * cgraph, int node_idx, int32_t n_uses) {
     const struct ggml_tensor * node = cgraph->nodes[node_idx];

-    // check the use count against how many we're replacing
-    size_t hash_pos = ggml_hash_find(&cgraph->visited_hash_set, node);
-    if (!ggml_bitset_get(cgraph->visited_hash_set.used, hash_pos) || cgraph->use_counts[hash_pos] != n_uses) {
+    if (ggml_node_get_use_count(cgraph, node_idx) != n_uses) {
         return false;
     }
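Extracting the hash lookup into ggml_node_get_use_count lets other call sites query a node's use count directly instead of duplicating the lookup inside the fusion predicate. A toy illustration of the refactor pattern (plain C++ stand-ins, not the ggml structures):

    #include <cstdio>
    #include <unordered_map>

    // toy stand-in for the graph's visited hash set: node id -> use count
    static int get_use_count(const std::unordered_map<int, int> & uses, int node) {
        auto it = uses.find(node);
        return it == uses.end() ? 0 : it->second; // unseen nodes have zero uses
    }

    // the fusion predicate now delegates instead of repeating the lookup
    static bool has_n_uses(const std::unordered_map<int, int> & uses, int node, int n) {
        return get_use_count(uses, node) == n;
    }

    int main() {
        std::unordered_map<int, int> uses = {{7, 1}, {9, 2}};
        std::printf("%d %d\n", has_n_uses(uses, 7, 1), has_n_uses(uses, 3, 1)); // 1 0
    }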
@ -1406,6 +1406,31 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_conv_transpose_1d(ggml_met
     return res;
 }

+ggml_metal_pipeline_t ggml_metal_library_get_pipeline_conv_transpose_2d(ggml_metal_library_t lib, const ggml_tensor * op) {
+    assert(op->op == GGML_OP_CONV_TRANSPOSE_2D);
+
+    GGML_ASSERT(ggml_is_contiguous(op->src[0]));
+    GGML_ASSERT(ggml_is_contiguous(op->src[1]));
+    GGML_ASSERT(op->src[0]->type == GGML_TYPE_F16 || op->src[0]->type == GGML_TYPE_F32);
+    GGML_ASSERT(op->src[1]->type == GGML_TYPE_F32);
+    GGML_ASSERT(op->type == GGML_TYPE_F32);
+
+    char base[256];
+    char name[256];
+
+    snprintf(base, 256, "kernel_conv_transpose_2d_%s_%s", ggml_type_name(op->src[0]->type), ggml_type_name(op->src[1]->type));
+    snprintf(name, 256, "%s", base);
+
+    ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline(lib, name);
+    if (res) {
+        return res;
+    }
+
+    res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
+
+    return res;
+}
+
 ggml_metal_pipeline_t ggml_metal_library_get_pipeline_upscale(ggml_metal_library_t lib, const ggml_tensor * op) {
     assert(op->op == GGML_OP_UPSCALE);
@ -130,6 +130,7 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_norm (ggml_me
 ggml_metal_pipeline_t ggml_metal_library_get_pipeline_rope (ggml_metal_library_t lib, const struct ggml_tensor * op);
 ggml_metal_pipeline_t ggml_metal_library_get_pipeline_im2col (ggml_metal_library_t lib, const struct ggml_tensor * op);
 ggml_metal_pipeline_t ggml_metal_library_get_pipeline_conv_transpose_1d (ggml_metal_library_t lib, const struct ggml_tensor * op);
+ggml_metal_pipeline_t ggml_metal_library_get_pipeline_conv_transpose_2d (ggml_metal_library_t lib, const struct ggml_tensor * op);
 ggml_metal_pipeline_t ggml_metal_library_get_pipeline_upscale (ggml_metal_library_t lib, const struct ggml_tensor * op);
 ggml_metal_pipeline_t ggml_metal_library_get_pipeline_pad (ggml_metal_library_t lib, const struct ggml_tensor * op);
 ggml_metal_pipeline_t ggml_metal_library_get_pipeline_pad_reflect_1d (ggml_metal_library_t lib, const struct ggml_tensor * op);
@ -653,6 +653,11 @@ bool ggml_metal_device_supports_op(ggml_metal_device_t dev, const struct ggml_te
         case GGML_OP_SCALE:
         case GGML_OP_CONV_TRANSPOSE_1D:
             return true;
+        case GGML_OP_CONV_TRANSPOSE_2D:
+            return ggml_is_contiguous(op->src[0]) && ggml_is_contiguous(op->src[1]) &&
+                (op->src[0]->type == GGML_TYPE_F16 || op->src[0]->type == GGML_TYPE_F32) &&
+                op->src[1]->type == GGML_TYPE_F32 &&
+                op->type == GGML_TYPE_F32;
         case GGML_OP_CLAMP:
             return op->src[0]->type == GGML_TYPE_F32;
         case GGML_OP_SQR:
@ -514,6 +514,19 @@ typedef struct {
     uint64_t nb1;
 } ggml_metal_kargs_conv_transpose_1d;

+typedef struct {
+    int32_t IC;
+    int32_t IH;
+    int32_t IW;
+    int32_t KH;
+    int32_t KW;
+    int32_t OC;
+    int32_t s0;
+    uint64_t nb0;
+    uint64_t nb1;
+    uint64_t nb2;
+} ggml_metal_kargs_conv_transpose_2d;
+
 typedef struct {
     uint64_t ofs0;
     uint64_t ofs1;
@ -368,6 +368,10 @@ static int ggml_metal_op_encode_impl(ggml_metal_op_t ctx, int idx) {
             {
                 n_fuse = ggml_metal_op_conv_transpose_1d(ctx, idx);
             } break;
+        case GGML_OP_CONV_TRANSPOSE_2D:
+            {
+                n_fuse = ggml_metal_op_conv_transpose_2d(ctx, idx);
+            } break;
         case GGML_OP_UPSCALE:
             {
                 n_fuse = ggml_metal_op_upscale(ctx, idx);
@ -3118,6 +3122,62 @@ int ggml_metal_op_conv_transpose_1d(ggml_metal_op_t ctx, int idx) {
     return 1;
 }

+int ggml_metal_op_conv_transpose_2d(ggml_metal_op_t ctx, int idx) {
+    ggml_tensor * op = ctx->node(idx);
+
+    ggml_metal_library_t lib = ctx->lib;
+    ggml_metal_encoder_t enc = ctx->enc;
+
+    GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
+    GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
+    GGML_TENSOR_LOCALS( int32_t, ne1, op->src[1], ne);
+    GGML_TENSOR_LOCALS(uint64_t, nb1, op->src[1], nb);
+    GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
+    GGML_TENSOR_LOCALS(uint32_t, nb, op, nb);
+
+    const int32_t s0 = ((const int32_t *)(op->op_params))[0];
+
+    const int32_t IC = op->src[1]->ne[2];
+    const int32_t IH = op->src[1]->ne[1];
+    const int32_t IW = op->src[1]->ne[0];
+
+    const int32_t KH = op->src[0]->ne[1];
+    const int32_t KW = op->src[0]->ne[0];
+
+    const int32_t OW = op->ne[0];
+    const int32_t OH = op->ne[1];
+    const int32_t OC = op->ne[2];
+
+    ggml_metal_kargs_conv_transpose_2d args = {
+        /*.IC =*/ IC,
+        /*.IH =*/ IH,
+        /*.IW =*/ IW,
+        /*.KH =*/ KH,
+        /*.KW =*/ KW,
+        /*.OC =*/ OC,
+        /*.s0 =*/ s0,
+        /*.nb0 =*/ nb0,
+        /*.nb1 =*/ nb1,
+        /*.nb2 =*/ nb2,
+    };
+
+    ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_conv_transpose_2d(lib, op);
+
+    ggml_metal_encoder_set_pipeline(enc, pipeline);
+    ggml_metal_encoder_set_bytes   (enc, &args, sizeof(args), 0);
+    ggml_metal_encoder_set_buffer  (enc, ggml_metal_get_buffer_id(op->src[0]), 1);
+    ggml_metal_encoder_set_buffer  (enc, ggml_metal_get_buffer_id(op->src[1]), 2);
+    ggml_metal_encoder_set_buffer  (enc, ggml_metal_get_buffer_id(op), 3);
+
+    // Metal requires buffer size to be multiple of 16 bytes
+    const size_t smem = GGML_PAD(KW * KH * sizeof(float), 16);
+    ggml_metal_encoder_set_threadgroup_memory_size(enc, smem, 0);
+
+    ggml_metal_encoder_dispatch_threadgroups(enc, OW, OH, OC, KW, KH, 1);
+
+    return 1;
+}
+
 int ggml_metal_op_upscale(ggml_metal_op_t ctx, int idx) {
     ggml_tensor * op = ctx->node(idx);
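On the threadgroup memory size above: Metal requires the size to be a multiple of 16 bytes, so the KW * KH partial sums are padded up before being handed to the encoder. A quick sketch of that round-up arithmetic (plain C++; the bitmask variant shown assumes a power-of-two alignment):

    #include <cstdio>

    // round x up to the next multiple of n (n must be a power of two here)
    static size_t pad_to(size_t x, size_t n) {
        return (x + n - 1) & ~(n - 1);
    }

    int main() {
        // e.g. a 3x3 f32 kernel needs 36 bytes of shared sums -> padded to 48
        std::printf("%zu\n", pad_to(3 * 3 * sizeof(float), 16)); // prints 48
    }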
@ -71,6 +71,7 @@ int ggml_metal_op_norm (ggml_metal_op_t ctx, int idx);
 int ggml_metal_op_rope (ggml_metal_op_t ctx, int idx);
 int ggml_metal_op_im2col (ggml_metal_op_t ctx, int idx);
 int ggml_metal_op_conv_transpose_1d (ggml_metal_op_t ctx, int idx);
+int ggml_metal_op_conv_transpose_2d (ggml_metal_op_t ctx, int idx);
 int ggml_metal_op_upscale (ggml_metal_op_t ctx, int idx);
 int ggml_metal_op_pad (ggml_metal_op_t ctx, int idx);
 int ggml_metal_op_pad_reflect_1d (ggml_metal_op_t ctx, int idx);
@ -4179,6 +4179,97 @@ kernel void kernel_conv_transpose_1d<half>(
         uint3 tgpig[[threadgroup_position_in_grid]],
         uint3 tgpg[[threadgroups_per_grid]]);

+typedef void (conv_transpose_2d_t)(
+        constant ggml_metal_kargs_conv_transpose_2d & args,
+        device const float * src0,
+        device const float * src1,
+        device char * dst,
+        uint3 tgpig[[threadgroup_position_in_grid]],
+        uint3 tgpg[[threadgroups_per_grid]]);
+
+template <typename T>
+kernel void kernel_conv_transpose_2d(
+        constant ggml_metal_kargs_conv_transpose_2d & args,
+        device const T * src0,
+        device const float * src1,
+        device char * dst,
+        threadgroup float * shared_sum [[threadgroup(0)]],
+        uint3 tgpig[[threadgroup_position_in_grid]],
+        uint3 tpitg[[thread_position_in_threadgroup]],
+        uint3 ntg[[threads_per_threadgroup]]) {
+
+    const int64_t out_x = tgpig[0];
+    const int64_t out_y = tgpig[1];
+    const int64_t out_c = tgpig[2];
+
+    const int64_t kw = tpitg[0];
+    const int64_t kh = tpitg[1];
+
+    float v = 0.0f;
+
+    for (int64_t in_c = 0; in_c < args.IC; in_c++) {
+        int64_t in_y = out_y - kh;
+
+        if (in_y < 0 || in_y % args.s0) continue;
+
+        in_y /= args.s0;
+
+        if (in_y >= args.IH) continue;
+
+        int64_t in_x = out_x - kw;
+
+        if (in_x < 0 || in_x % args.s0) continue;
+
+        in_x /= args.s0;
+
+        if (in_x >= args.IW) continue;
+
+        const int64_t input_idx = (args.IW * args.IH) * in_c + (args.IW) * in_y + in_x;
+        const int64_t kernel_idx = (args.KH * args.KW * args.OC) * in_c + (args.KH * args.KW) * out_c + (args.KW) * kh + kw;
+
+        v += (float)src0[kernel_idx] * src1[input_idx];
+    }
+
+    const uint tid = tpitg.y * ntg.x + tpitg.x;
+    shared_sum[tid] = v;
+
+    threadgroup_barrier(mem_flags::mem_threadgroup);
+
+    if (tid == 0) {
+        float total = 0.0f;
+        const uint num_threads = ntg.x * ntg.y;
+        for (uint i = 0; i < num_threads; i++) {
+            total += shared_sum[i];
+        }
+
+        device float * dst_ptr = (device float *) (dst + out_x*args.nb0 + out_y * args.nb1 + out_c*args.nb2);
+        dst_ptr[0] = total;
+    }
+}
+
+template [[host_name("kernel_conv_transpose_2d_f32_f32")]]
+kernel void kernel_conv_transpose_2d<float>(
+        constant ggml_metal_kargs_conv_transpose_2d & args,
+        device const float * src0,
+        device const float * src1,
+        device char * dst,
+        threadgroup float * shared_sum [[threadgroup(0)]],
+        uint3 tgpig[[threadgroup_position_in_grid]],
+        uint3 tpitg[[thread_position_in_threadgroup]],
+        uint3 ntg[[threads_per_threadgroup]]);
+
+template [[host_name("kernel_conv_transpose_2d_f16_f32")]]
+kernel void kernel_conv_transpose_2d<half>(
+        constant ggml_metal_kargs_conv_transpose_2d & args,
+        device const half * src0,
+        device const float * src1,
+        device char * dst,
+        threadgroup float * shared_sum [[threadgroup(0)]],
+        uint3 tgpig[[threadgroup_position_in_grid]],
+        uint3 tpitg[[thread_position_in_threadgroup]],
+        uint3 ntg[[threads_per_threadgroup]]);
+
 kernel void kernel_upscale_f32(
         constant ggml_metal_kargs_upscale & args,
         device const char * src0,
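For intuition, the kernel inverts the usual convolution indexing: output (out_x, out_y, out_c) gathers from input (in_x, in_y, in_c) wherever out_x = in_x*s0 + kw and out_y = in_y*s0 + kh. A minimal CPU reference in the same spirit (plain C++; single stride, no padding; the harness is illustrative, but the index math mirrors the kernel above):

    #include <vector>

    // naive conv_transpose_2d reference: OW = (IW-1)*s + KW, OH = (IH-1)*s + KH
    void conv_transpose_2d(const std::vector<float> & k,  // [IC][OC][KH][KW]
                           const std::vector<float> & in, // [IC][IH][IW]
                           std::vector<float> & out,      // [OC][OH][OW]
                           int IC, int IH, int IW, int OC, int KH, int KW, int s) {
        const int OH = (IH - 1) * s + KH, OW = (IW - 1) * s + KW;
        for (int oc = 0; oc < OC; oc++)
        for (int oy = 0; oy < OH; oy++)
        for (int ox = 0; ox < OW; ox++) {
            float v = 0.0f;
            for (int ic = 0; ic < IC; ic++)
            for (int kh = 0; kh < KH; kh++)
            for (int kw = 0; kw < KW; kw++) {
                const int iy = oy - kh, ix = ox - kw;
                if (iy < 0 || ix < 0 || iy % s || ix % s) continue; // stride gaps
                if (iy / s >= IH || ix / s >= IW) continue;
                // kernel layout matches the Metal kernel: (KH*KW*OC)*ic + (KH*KW)*oc + KW*kh + kw
                v += k[((ic * OC + oc) * KH + kh) * KW + kw] * in[(ic * IH + iy / s) * IW + ix / s];
            }
            out[(oc * OH + oy) * OW + ox] = v;
        }
    }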
@ -91,6 +91,8 @@ set(GGML_OPENCL_KERNELS
     mul_mv_id_q8_0_f32_flat
     mul_mv_id_mxfp4_f32
     mul_mv_id_mxfp4_f32_flat
+    gemm_moe_mxfp4_f32
+    gemv_moe_mxfp4_f32
     mul_mm_f32_f32_l4_lm
     mul_mm_f16_f32_l4_lm
     mul_mm_q8_0_f32_l4_lm
@ -402,6 +402,7 @@ struct ggml_backend_opencl_context {
     cl_program program_conv_2d_f32;
     cl_program program_conv_2d_f16_f32;
     cl_program program_tsembd;
+    cl_program program_gemv_moe_mxfp4_f32, program_gemm_moe_mxfp4_f32;
     cl_program program_mul_mv_id_q4_0_f32_8x_flat;
     cl_program program_mul_mv_id_q8_0_f32, program_mul_mv_id_q8_0_f32_flat;
     cl_program program_mul_mv_id_mxfp4_f32;
@ -452,7 +453,7 @@ struct ggml_backend_opencl_context {
     cl_kernel kernel_mul_mat_f16_f32_tiled;
     cl_kernel kernel_mul_mat_q4_0_f32, kernel_mul_mat_q4_0_f32_v;
     cl_kernel kernel_convert_block_q4_0, kernel_restore_block_q4_0;
-    cl_kernel kernel_convert_block_mxfp4, kernel_restore_block_mxfp4;
+    cl_kernel kernel_convert_block_mxfp4, kernel_convert_block_mxfp4_trans, kernel_restore_block_mxfp4, kernel_restore_block_mxfp4_trans;
     cl_kernel kernel_convert_block_q8_0, kernel_restore_block_q8_0;
     cl_kernel kernel_mul_mat_q4_0_f32_8x_flat;
     cl_kernel kernel_convert_block_q4_0_noshuffle;
@ -475,6 +476,7 @@ struct ggml_backend_opencl_context {
     cl_kernel kernel_conv_2d_f32;
     cl_kernel kernel_conv_2d_f16_f32;
     cl_kernel kernel_timestep_embedding;
+    cl_kernel kernel_gemv_moe_mxfp4_f32, kernel_gemm_moe_mxfp4_f32;
     cl_kernel kernel_mul_mv_id_q4_0_f32_8x_flat;
     cl_kernel kernel_mul_mv_id_q8_0_f32, kernel_mul_mv_id_q8_0_f32_flat;
     cl_kernel kernel_mul_mv_id_mxfp4_f32;
@ -559,14 +561,14 @@ struct ggml_backend_opencl_context {

         fprintf(ftrace, "[\n");
         for (const ProfilingInfo & info : profiling_info) {
-            fprintf(ftrace, "{\"name\": \"%s\", \"cat\": \"OpenCL\", \"ph\": \"B\", \"ts\": %lu, \"pid\": \"\", \"tid\": \"Host\"},\n",
+            fprintf(ftrace, "{\"name\": \"%s\", \"cat\": \"OpenCL\", \"ph\": \"B\", \"ts\": %llu, \"pid\": \"\", \"tid\": \"Host\"},\n",
                 info.kernel_name.c_str(), info.cmd_queued/1000);
-            fprintf(ftrace, "{\"name\": \"%s\", \"cat\": \"OpenCL\", \"ph\": \"E\", \"ts\": %lu, \"pid\": \"\", \"tid\": \"Host\"},\n",
+            fprintf(ftrace, "{\"name\": \"%s\", \"cat\": \"OpenCL\", \"ph\": \"E\", \"ts\": %llu, \"pid\": \"\", \"tid\": \"Host\"},\n",
                 info.kernel_name.c_str(), info.cmd_submit/1000);

-            fprintf(ftrace, "{\"name\": \"%s\", \"cat\": \"OpenCL\", \"ph\": \"B\", \"ts\": %lu, \"pid\": \"\", \"tid\": \"Device\"},\n",
+            fprintf(ftrace, "{\"name\": \"%s\", \"cat\": \"OpenCL\", \"ph\": \"B\", \"ts\": %llu, \"pid\": \"\", \"tid\": \"Device\"},\n",
                 info.kernel_name.c_str(), info.cmd_start/1000);
-            fprintf(ftrace, "{\"name\": \"%s\", \"cat\": \"OpenCL\", \"ph\": \"E\", \"ts\": %lu, \"pid\": \"\", \"tid\": \"Device\"},\n",
+            fprintf(ftrace, "{\"name\": \"%s\", \"cat\": \"OpenCL\", \"ph\": \"E\", \"ts\": %llu, \"pid\": \"\", \"tid\": \"Device\"},\n",
                 info.kernel_name.c_str(), info.cmd_end/1000);
         }
         fclose(ftrace);
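Background on this fix: cl_ulong is a 64-bit unsigned integer on all OpenCL platforms, while %lu means unsigned long, which is only 32 bits on LLP64 targets such as Windows; %llu matches unsigned long long everywhere. A portable alternative is the <cinttypes> macros (a sketch, not from the patch):

    #include <cinttypes>
    #include <cstdio>

    int main() {
        uint64_t ts_us = 1234567890123ull; // e.g. an event timestamp divided by 1000
        std::printf("ts=%llu\n", (unsigned long long) ts_us); // the %llu approach
        std::printf("ts=%" PRIu64 "\n", ts_us);               // <cinttypes> alternative
    }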
@ -777,6 +779,8 @@ static void load_cl_kernels(ggml_backend_opencl_context *backend_ctx, ggml_cl_ve
         CL_CHECK((backend_ctx->kernel_convert_block_q4_0 = clCreateKernel(backend_ctx->program_cvt, "kernel_convert_block_q4_0", &err), err));
         CL_CHECK((backend_ctx->kernel_restore_block_q4_0 = clCreateKernel(backend_ctx->program_cvt, "kernel_restore_block_q4_0", &err), err));
         CL_CHECK((backend_ctx->kernel_convert_block_mxfp4 = clCreateKernel(backend_ctx->program_cvt, "kernel_convert_block_mxfp4", &err), err));
+        CL_CHECK((backend_ctx->kernel_convert_block_mxfp4_trans = clCreateKernel(backend_ctx->program_cvt, "kernel_convert_block_mxfp4_trans", &err), err));
+        CL_CHECK((backend_ctx->kernel_restore_block_mxfp4_trans = clCreateKernel(backend_ctx->program_cvt, "kernel_restore_block_mxfp4_trans", &err), err));
         CL_CHECK((backend_ctx->kernel_restore_block_mxfp4 = clCreateKernel(backend_ctx->program_cvt, "kernel_restore_block_mxfp4", &err), err));
         CL_CHECK((backend_ctx->kernel_convert_block_q8_0 = clCreateKernel(backend_ctx->program_cvt, "kernel_convert_block_q8_0", &err), err));
         CL_CHECK((backend_ctx->kernel_restore_block_q8_0 = clCreateKernel(backend_ctx->program_cvt, "kernel_restore_block_q8_0", &err), err));
@ -1991,6 +1995,42 @@ static void load_cl_kernels(ggml_backend_opencl_context *backend_ctx, ggml_cl_ve
         CL_CHECK((backend_ctx->CL_mul_mat_Ab_Bi_8x4 = clCreateKernel(backend_ctx->program_CL_gemm, "kernel_mul_mat_Ab_Bi_8x4", &err), err));
         GGML_LOG_CONT(".");
     }
+
+    std::string CL_moe_compile_opts = std::string("-cl-std=") + opencl_c_std +
+        " -cl-mad-enable "
+        " -cl-fast-relaxed-math";
+
+    // gemv_moe_mxfp4_f32
+    {
+#ifdef GGML_OPENCL_EMBED_KERNELS
+        const std::string kernel_src {
+            #include "gemv_moe_mxfp4_f32.cl.h"
+        };
+#else
+        const std::string kernel_src = read_file("gemv_moe_mxfp4_f32.cl");
+#endif
+        backend_ctx->program_gemv_moe_mxfp4_f32 =
+            build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), CL_moe_compile_opts);
+
+        CL_CHECK((backend_ctx->kernel_gemv_moe_mxfp4_f32 = clCreateKernel(backend_ctx->program_gemv_moe_mxfp4_f32, "kernel_gemv_moe_mxfp4_f32", &err), err));
+        GGML_LOG_CONT(".");
+    }
+
+    // gemm_moe_mxfp4_f32
+    {
+#ifdef GGML_OPENCL_EMBED_KERNELS
+        const std::string kernel_src {
+            #include "gemm_moe_mxfp4_f32.cl.h"
+        };
+#else
+        const std::string kernel_src = read_file("gemm_moe_mxfp4_f32.cl");
+#endif
+        backend_ctx->program_gemm_moe_mxfp4_f32 =
+            build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), CL_moe_compile_opts);
+
+        CL_CHECK((backend_ctx->kernel_gemm_moe_mxfp4_f32 = clCreateKernel(backend_ctx->program_gemm_moe_mxfp4_f32, "kernel_gemm_moe_mxfp4_f32", &err), err));
+        GGML_LOG_CONT(".");
+    }
 #endif // GGML_OPENCL_USE_ADRENO_KERNELS
     GGML_LOG_CONT("\n");
 }
@ -3299,6 +3339,12 @@ inline bool use_adreno_kernels(const ggml_backend_opencl_context *backend_ctx, c
         tensor->ne[2] == 1 && tensor->ne[3] == 1;
 }

+inline bool use_adreno_moe_kernels(const ggml_backend_opencl_context *backend_ctx, const ggml_tensor *tensor) {
+    GGML_UNUSED(backend_ctx);
+    int ne01 = tensor->ne[1];
+    return ((strstr(tensor->name, "ffn") != NULL) || (strstr(tensor->name, "as") != NULL)) && (ne01 % 64 == 0);
+}
+
 static void ggml_backend_opencl_buffer_set_tensor(ggml_backend_buffer_t buffer, ggml_tensor * tensor, const void * data, size_t offset, size_t size) {
     ggml_backend_opencl_context *backend_ctx = ggml_cl2_init(buffer->buft->device);
@ -3601,14 +3647,39 @@ static void ggml_backend_opencl_buffer_set_tensor(ggml_backend_buffer_t buffer,
                                        CL_BUFFER_CREATE_TYPE_REGION, &region, &err);
         CL_CHECK(err);

+#ifdef GGML_OPENCL_USE_ADRENO_KERNELS
+        if (use_adreno_moe_kernels(backend_ctx, tensor)) {
+            cl_kernel kernel = backend_ctx->kernel_convert_block_mxfp4_trans;
+
+            int ne00 = tensor->ne[0];
+            int ne01 = tensor->ne[1];
+            int ne02 = tensor->ne[2];
+            CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &data_device));
+            CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &extra->q));
+            CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extra->e));
+            CL_CHECK(clSetKernelArg(kernel, 3, sizeof(int), &ne00));
+            CL_CHECK(clSetKernelArg(kernel, 4, sizeof(int), &ne01));
+
+            size_t global_work_size[3] = {static_cast<size_t>(((ne01 + 63) / 64) * 64), static_cast<size_t>(ne00 / 32), static_cast<size_t>(ne02)};
+            size_t local_work_size[3] = {64, 2, 1};
+
+            cl_event evt;
+            CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, global_work_size, local_work_size, 0, NULL, &evt));
+            CL_CHECK(clWaitForEvents(1, &evt));
+            CL_CHECK(clReleaseMemObject(data_device));
+            tensor->extra = extra;
+
+            return;
+        }
+#endif
         cl_kernel kernel = backend_ctx->kernel_convert_block_mxfp4;

         CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &data_device));
         CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &extra->q));
         CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extra->e));

-        size_t global_work_size[] = {(size_t)ggml_nelements(tensor)/ggml_blck_size(tensor->type), 1, 1};
-        size_t local_work_size[] = {64, 1, 1};
+        size_t global_work_size[3] = {(size_t)ggml_nelements(tensor)/ggml_blck_size(tensor->type), 1, 1};
+        size_t local_work_size[3] = {64, 1, 1};

         cl_event evt;
         CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, global_work_size, local_work_size, 0, NULL, &evt));
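The transposed-layout path launches one work-item per row, rounded up to the 64-wide work-group: ((ne01 + 63) / 64) * 64 with a local size of 64. A tiny sketch of that round-up (plain C++, illustrative only):

    #include <cstdio>

    // smallest multiple of group >= n, so every row gets a work-item
    static size_t round_up(size_t n, size_t group) {
        return ((n + group - 1) / group) * group;
    }

    int main() {
        std::printf("%zu\n", round_up(100, 64)); // 128: two work-groups of 64
    }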
@ -3624,7 +3695,6 @@ static void ggml_backend_opencl_buffer_set_tensor(ggml_backend_buffer_t buffer,
             { extra->q }
         };
         extra->q_img = clCreateImage(context, CL_MEM_READ_ONLY, &img_format_q, &img_desc_q, NULL, &err);
         tensor->extra = extra;

         return;
@@ -3751,6 +3821,33 @@ static void ggml_backend_opencl_buffer_get_tensor(ggml_backend_buffer_t buffer,
            ggml_nbytes(tensor), NULL, &err);
        CL_CHECK(err);

#ifdef GGML_OPENCL_USE_ADRENO_KERNELS
        if (use_adreno_moe_kernels(backend_ctx, tensor)) {
            cl_kernel kernel = backend_ctx->kernel_restore_block_mxfp4_trans;

            int ne00 = tensor->ne[0];
            int ne01 = tensor->ne[1];
            int ne02 = tensor->ne[2];
            CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra->q));
            CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &extra->e));
            CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &data_device));
            CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_int), &ne00));
            CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_int), &ne01));

            size_t global_work_size[3] = {static_cast<size_t>(((ne01 + 63) / 64) * 64), static_cast<size_t>(ne00 / 32), static_cast<size_t>(ne02)};
            size_t local_work_size[3] = {64, 2, 1};

            cl_event evt;
            CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL,
                global_work_size, local_work_size, 0, NULL, &evt));
            CL_CHECK(clWaitForEvents(1, &evt));
            CL_CHECK(clEnqueueReadBuffer(
                queue, data_device, CL_TRUE, offset,
                size, data, 0, NULL, NULL));
            CL_CHECK(clReleaseMemObject(data_device));
            return;
        }
#endif
        cl_kernel kernel = backend_ctx->kernel_restore_block_mxfp4;
        CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra->q));
        CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &extra->e));
@@ -7553,6 +7650,7 @@ static void ggml_cl_mul_mat_id(ggml_backend_t backend, const ggml_tensor * src0,
    const int ne21 = src2->ne[1];

    const cl_ulong nb21 = src2->nb[1];
    const cl_ulong nb20 = src2->nb[0];

    const int ne0 = dst->ne[0];
    const int ne1 = dst->ne[1];
@@ -7692,6 +7790,105 @@ static void ggml_cl_mul_mat_id(ggml_backend_t backend, const ggml_tensor * src0,
            break;
        }
        case GGML_TYPE_MXFP4: {
#ifdef GGML_OPENCL_USE_ADRENO_KERNELS
            if (use_adreno_moe_kernels(backend_ctx, src0)) {
                cl_int status;

                size_t local_size[3] = {64, 2, 1};
                size_t global_size[3] = {64, 2, 1};

                cl_mem src1_sub_buffer, buf_src1_image, buf_src2;

                int tile_size = 320;
                if (ne12 == 1) { // for gemv
                    kernel = backend_ctx->kernel_gemv_moe_mxfp4_f32;

                    // create a sub_buffer for src2
                    cl_buffer_region region;
                    region.origin = offset2;
                    region.size = ne20 * ne21 * sizeof(int);
                    buf_src2 = clCreateSubBuffer(extra2->data_device, 0, CL_BUFFER_CREATE_TYPE_REGION, &region, &status);
                    CL_CHECK(status);

                    // set thread grid
                    global_size[0] = static_cast<size_t>(ne01);
                    global_size[1] = 4;
                    global_size[2] = static_cast<size_t>(ne20);
                    local_size[1] = 4;
                } else { // for gemm
                    kernel = backend_ctx->kernel_gemm_moe_mxfp4_f32;

                    // preprocess router table
                    int num_tiles_per_expert = (ne01 + tile_size - 1) / tile_size;
                    void * host_src2_reorder = malloc(ne20 * ne21 * 4 * num_tiles_per_expert * sizeof(short));
                    void * host_src2 = malloc(ne21 * nb21);
                    CL_CHECK(clEnqueueReadBuffer(backend_ctx->queue, extra2->data_device, CL_TRUE, offset2, ne21 * nb21, host_src2, 0, NULL, NULL));
                    int total_experts = nb21 / nb20;
                    int out_idx = 0;
                    for (int i_expert = 0; i_expert < ne02; i_expert++) {
                        for (int i_tile = 0; i_tile < num_tiles_per_expert; i_tile++) {
                            for (int j = 0; j < ne21; j++) {
                                for (int i = 0; i < ne20; i++) {
                                    int expert = ((int *)host_src2)[j * total_experts + i];
                                    if (i_expert == expert) {
                                        ((short *)host_src2_reorder)[out_idx] = static_cast<short>(expert);
                                        ((short *)host_src2_reorder)[out_idx + 1] = static_cast<short>(j * ne11 + (i % ne11));
                                        ((short *)host_src2_reorder)[out_idx + 2] = static_cast<short>(j * ne20 + i);
                                        ((short *)host_src2_reorder)[out_idx + 3] = static_cast<short>(i_tile);
                                        out_idx += 4;
                                    }
                                }
                            }
                        }
                    }
                    buf_src2 = clCreateBuffer(backend_ctx->context, CL_MEM_READ_ONLY | CL_MEM_COPY_HOST_PTR, ne20 * ne21 * 4 * num_tiles_per_expert * sizeof(short), host_src2_reorder, &status);
                    CL_CHECK(status);

                    // set thread grid
                    global_size[0] = static_cast<size_t>(tile_size);
                    global_size[2] = static_cast<size_t>(ne20 * ne21 * num_tiles_per_expert);
                }

                // create a sub_buffer for src1
                cl_buffer_region region;
                region.origin = offset1;
                region.size = ne10 * ne11 * ne12 * sizeof(float);
                src1_sub_buffer = clCreateSubBuffer(extra1->data_device, 0, CL_BUFFER_CREATE_TYPE_REGION, &region, &status);
                CL_CHECK(status);

                // create image for src1
                cl_image_format image_format_buf_src1 = {CL_RGBA, CL_FLOAT};
                cl_image_desc image_desc_buf_src1 = {CL_MEM_OBJECT_IMAGE1D_BUFFER, static_cast<size_t>(ne10 * ne11 * ne12 / 4), 0,0,0,0,0,0,0, {src1_sub_buffer}};
                buf_src1_image = clCreateImage(backend_ctx->context, CL_MEM_READ_ONLY, &image_format_buf_src1, &image_desc_buf_src1, NULL, &status);
                CL_CHECK(status);

                // Set kernel args
                int arg_idx = 0;
                CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(cl_mem), &extra0_mxfp4->q));
                CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(cl_mem), &extra0_mxfp4->e));
                CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(cl_mem), &buf_src1_image));
                CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(cl_mem), &buf_src2));
                CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(cl_mem), &extrad->data_device));
                CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(cl_ulong), &offsetd));
                CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(int), &ne00));
                CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(int), &ne01));
                if (ne12 == 1) {
                    CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(int), &ne11));
                } else {
                    CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(int), &tile_size));
                }

                // launch kernel
                backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_size, local_size, dst);

                // deallocate sub buffers and images
                CL_CHECK(clReleaseMemObject(src1_sub_buffer));
                CL_CHECK(clReleaseMemObject(buf_src1_image));
                CL_CHECK(clReleaseMemObject(buf_src2));
                return;
            } // else fallback to generic kernel
#endif // GGML_OPENCL_USE_ADRENO_KERNELS

#ifdef GGML_OPENCL_SOA_Q
            kernel = backend_ctx->kernel_mul_mv_id_mxfp4_f32_flat;
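Before launching the MoE GEMM, the host flattens the (token, slot) expert-id matrix into one 4-short record per routed pair, grouped by expert and tile, so each workgroup can look up its work with a single read. A compact C++ sketch of the same reordering under simplified assumptions (expert ids stored densely, ne11 == 1; the names here are illustrative, not from the diff):

#include <cstdio>
#include <vector>

struct route { short expert, src_row, dst_row, tile; };

int main() {
    const int ne20 = 2, ne21 = 3, ne11 = 1, n_expert = 4, n_tiles = 2;
    // ne21 tokens x ne20 expert slots, as read back from src2
    const std::vector<int> ids = { 0, 2,   1, 0,   3, 1 };
    std::vector<route> out;
    for (int e = 0; e < n_expert; e++)
        for (int t = 0; t < n_tiles; t++)
            for (int j = 0; j < ne21; j++)
                for (int i = 0; i < ne20; i++)
                    if (ids[j * ne20 + i] == e)
                        out.push_back({ (short) e,
                                        (short) (j * ne11 + i % ne11), // row of src1 to read
                                        (short) (j * ne20 + i),        // row of dst to write
                                        (short) t });
    // every (token, slot) pair appears once per tile
    std::printf("%zu records\n", out.size()); // prints 12 = ne20 * ne21 * n_tiles
}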
@@ -147,6 +147,27 @@ kernel void kernel_convert_block_mxfp4(
    }
}

kernel void kernel_convert_block_mxfp4_trans(
    global struct block_mxfp4 * src0,
    __global uint4 * dst_q,
    __global uchar * dst_e,
    uint ne00,
    uint ne01
) {
    int i00 = get_global_id(1);
    uint i01 = get_global_id(0);
    uint i02 = get_global_id(2);

    uint ne00_blk = ne00 / QK_MXFP4;
    uint src_blk_offset = i00 + i01 * ne00_blk + i02 * ne00_blk * ne01;
    uint dst_blk_offset = i01 + i00 * ne01 + i02 * ne00_blk * ne01;

    global struct block_mxfp4 * b = src0 + src_blk_offset;

    dst_q[dst_blk_offset] = ((global uint4 *)(&(b->qs[0])))[0];
    dst_e[dst_blk_offset] = b->e;
}

kernel void kernel_restore_block_mxfp4(
    global uchar * src_q,
    global half * src_e,
@@ -162,6 +183,27 @@ kernel void kernel_restore_block_mxfp4(
    }
}

kernel void kernel_restore_block_mxfp4_trans(
    __global uint4 * src_q,
    __global uchar * src_e,
    global struct block_mxfp4 * dst,
    uint ne00,
    uint ne01
) {
    int i00 = get_global_id(1);
    uint i01 = get_global_id(0);
    uint i02 = get_global_id(2);

    uint ne00_blk = ne00 / QK_MXFP4;
    uint src_blk_offset = i01 + i00 * ne01 + i02 * ne00_blk * ne01;
    uint dst_blk_offset = i00 + i01 * ne00_blk + i02 * ne00_blk * ne01;

    global struct block_mxfp4 * b = dst + dst_blk_offset;

    ((global uint4 *)(&(b->qs[0])))[0] = src_q[src_blk_offset];
    b->e = src_e[src_blk_offset];
}

//------------------------------------------------------------------------------
// block_q8_0
//------------------------------------------------------------------------------
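Taken together, kernel_convert_block_mxfp4_trans and kernel_restore_block_mxfp4_trans are inverse permutations: convert stores block (i00, i01) at the transposed offset i01 + i00 * ne01, and restore reads it back from the same place. A standalone C++ check that the mapping round-trips:

#include <cassert>
#include <vector>

int main() {
    const unsigned ne00_blk = 8, ne01 = 64, ne02 = 2; // blocks per row, rows, experts
    std::vector<unsigned> store(ne00_blk * ne01 * ne02);
    for (unsigned i02 = 0; i02 < ne02; i02++)
        for (unsigned i01 = 0; i01 < ne01; i01++)
            for (unsigned i00 = 0; i00 < ne00_blk; i00++) {
                const unsigned src = i00 + i01 * ne00_blk + i02 * ne00_blk * ne01;
                const unsigned dst = i01 + i00 * ne01    + i02 * ne00_blk * ne01;
                store[dst] = src;                 // what convert writes
            }
    for (unsigned i02 = 0; i02 < ne02; i02++)
        for (unsigned i01 = 0; i01 < ne01; i01++)
            for (unsigned i00 = 0; i00 < ne00_blk; i00++) {
                const unsigned orig  = i00 + i01 * ne00_blk + i02 * ne00_blk * ne01;
                const unsigned trans = i01 + i00 * ne01    + i02 * ne00_blk * ne01;
                assert(store[trans] == orig);     // what restore reads back
            }
}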
@@ -0,0 +1,162 @@
#pragma OPENCL EXTENSION cl_khr_fp16 : enable
#pragma OPENCL EXTENSION cl_khr_subgroups : enable
#pragma OPENCL EXTENSION cl_qcom_reqd_sub_group_size : enable

#define QK_MXFP4 32
#define N_SIMDGROUP 2
#define SIMDGROUP_WIDTH 64

static inline half8 mxfp4_to_fp16_packed8(ushort2 fp4x8) { //, ushort 0x0E00, ushort 0x8000) {
    ushort2 fp16_packed_a_0, fp16_packed_b_0, bias_a, bias_b, sign_a, sign_b;
    fp16_packed_a_0.lo = (fp4x8.s0 << 9) & 0x0E00;
    fp16_packed_a_0.hi = (fp4x8.s0 << 5) & 0x0E00;
    fp16_packed_b_0.lo = (fp4x8.s0 << 1) & 0x0E00;
    fp16_packed_b_0.hi = (fp4x8.s0 >> 3) & 0x0E00;

    bias_a.lo = (fp16_packed_a_0.lo != 0) ? 0x3800 : 0x0;
    bias_a.hi = (fp16_packed_a_0.hi != 0) ? 0x3800 : 0x0;
    bias_b.lo = (fp16_packed_b_0.lo != 0) ? 0x3800 : 0x0;
    bias_b.hi = (fp16_packed_b_0.hi != 0) ? 0x3800 : 0x0;

    fp16_packed_a_0.lo = (fp16_packed_a_0.lo != 0x0200) ? fp16_packed_a_0.lo : 0x0;
    fp16_packed_a_0.hi = (fp16_packed_a_0.hi != 0x0200) ? fp16_packed_a_0.hi : 0x0;
    fp16_packed_b_0.lo = (fp16_packed_b_0.lo != 0x0200) ? fp16_packed_b_0.lo : 0x0;
    fp16_packed_b_0.hi = (fp16_packed_b_0.hi != 0x0200) ? fp16_packed_b_0.hi : 0x0;

    sign_a.lo = (fp4x8.s0 << 12) & 0x8000;
    sign_a.hi = (fp4x8.s0 << 8) & 0x8000;
    sign_b.lo = (fp4x8.s0 << 4) & 0x8000;
    sign_b.hi = fp4x8.s0 & 0x8000;

    fp16_packed_a_0 = sign_a + bias_a + fp16_packed_a_0;
    fp16_packed_b_0 = sign_b + bias_b + fp16_packed_b_0;

    ushort2 fp16_packed_a_1, fp16_packed_b_1;
    fp16_packed_a_1.lo = (fp4x8.s1 << 9) & 0x0E00;
    fp16_packed_a_1.hi = (fp4x8.s1 << 5) & 0x0E00;
    fp16_packed_b_1.lo = (fp4x8.s1 << 1) & 0x0E00;
    fp16_packed_b_1.hi = (fp4x8.s1 >> 3) & 0x0E00;

    bias_a.lo = (fp16_packed_a_1.lo != 0) ? 0x3800 : 0x0;
    bias_a.hi = (fp16_packed_a_1.hi != 0) ? 0x3800 : 0x0;
    bias_b.lo = (fp16_packed_b_1.lo != 0) ? 0x3800 : 0x0;
    bias_b.hi = (fp16_packed_b_1.hi != 0) ? 0x3800 : 0x0;

    fp16_packed_a_1.lo = (fp16_packed_a_1.lo != 0x0200) ? fp16_packed_a_1.lo : 0x0;
    fp16_packed_a_1.hi = (fp16_packed_a_1.hi != 0x0200) ? fp16_packed_a_1.hi : 0x0;
    fp16_packed_b_1.lo = (fp16_packed_b_1.lo != 0x0200) ? fp16_packed_b_1.lo : 0x0;
    fp16_packed_b_1.hi = (fp16_packed_b_1.hi != 0x0200) ? fp16_packed_b_1.hi : 0x0;

    sign_a.lo = (fp4x8.s1 << 12) & 0x8000;
    sign_a.hi = (fp4x8.s1 << 8) & 0x8000;
    sign_b.lo = (fp4x8.s1 << 4) & 0x8000;
    sign_b.hi = fp4x8.s1 & 0x8000;

    fp16_packed_a_1 = sign_a + bias_a + fp16_packed_a_1;
    fp16_packed_b_1 = sign_b + bias_b + fp16_packed_b_1;

    return as_half8((ushort8)(fp16_packed_a_0, fp16_packed_b_0, fp16_packed_a_1, fp16_packed_b_1));
}

static inline float e8m0_to_fp32(uchar x) {
    int bits;
    bits = (x == 0) ? 0x00400000 : ((uint) x << 23);
    return as_float(bits);
}


__attribute__((qcom_reqd_sub_group_size("half")))
__kernel void kernel_gemm_moe_mxfp4_f32(
    __global uint4 * src0_q,
    __global uchar * src0_e,
    __read_only image1d_buffer_t src1,
    __global ushort4 * src2,
    __global float * dst,
    ulong offsetd,
    int ne00,
    int ne01,
    int tile_size
) {
    uint i01 = get_global_id(0);
    uint i20 = get_global_id(2);
    uint sgid = get_local_id(1);
    uint slid = get_sub_group_local_id();

    ushort4 router = src2[i20];
    ushort expert_id = router.x;
    ushort i11 = router.y;
    ushort i1 = router.z;
    ushort tile_id = router.w;

    if (tile_id * tile_size + i01 >= ne01) { // handle edge case when ne01 is not a multiple of tile_size
        return;
    }

    uint expert_offset = expert_id * ne00 * ne01 / 32;
    uint tile_offset = expert_offset + tile_id * tile_size + i01;

    __private float sum = 0.0f; // each thread calculates the partial sum of one output

    // loop along ne00 in block granularity, skip N_SIMDGROUP blocks every iter
    for (uint ib00 = sgid; ib00 < (ne00 / QK_MXFP4); ib00 += N_SIMDGROUP) {
        // load one block of q
        uint4 regQ = src0_q[tile_offset + ib00 * ne01];
        // convert 8 fp4 to fp16
        half8 fp16x8 = mxfp4_to_fp16_packed8(as_ushort2(regQ.s0));

        uint offset = i11 * ne00 / 4 + ib00 * 8;
        float4 shared_y4;
        shared_y4 = read_imagef(src1, (offset + 0));
        float4 acc = shared_y4 * (float4)(fp16x8.s0, fp16x8.s2, fp16x8.s4, fp16x8.s6);

        shared_y4 = read_imagef(src1, (offset + 4));
        acc += shared_y4 * (float4)(fp16x8.s1, fp16x8.s3, fp16x8.s5, fp16x8.s7);


        fp16x8 = mxfp4_to_fp16_packed8(as_ushort2(regQ.s1));

        shared_y4 = read_imagef(src1, (offset + 1));
        acc += shared_y4 * (float4)(fp16x8.s0, fp16x8.s2, fp16x8.s4, fp16x8.s6);

        shared_y4 = read_imagef(src1, (offset + 5));
        acc += shared_y4 * (float4)(fp16x8.s1, fp16x8.s3, fp16x8.s5, fp16x8.s7);


        fp16x8 = mxfp4_to_fp16_packed8(as_ushort2(regQ.s2));

        shared_y4 = read_imagef(src1, (offset + 2));
        acc += shared_y4 * (float4)(fp16x8.s0, fp16x8.s2, fp16x8.s4, fp16x8.s6);

        shared_y4 = read_imagef(src1, (offset + 6));
        acc += shared_y4 * (float4)(fp16x8.s1, fp16x8.s3, fp16x8.s5, fp16x8.s7);


        fp16x8 = mxfp4_to_fp16_packed8(as_ushort2(regQ.s3));

        shared_y4 = read_imagef(src1, (offset + 3));
        acc += shared_y4 * (float4)(fp16x8.s0, fp16x8.s2, fp16x8.s4, fp16x8.s6);

        shared_y4 = read_imagef(src1, (offset + 7));
        acc += shared_y4 * (float4)(fp16x8.s1, fp16x8.s3, fp16x8.s5, fp16x8.s7);

        uchar regE = src0_e[tile_offset + ib00 * ne01];
        sum += e8m0_to_fp32(regE) * ((acc.s0 + acc.s1) + (acc.s2 + acc.s3));
    }

    // reduction in local memory across N_SIMDGROUP subgroups (2 here; extra stages kept commented out)
    __local float reduceLM[SIMDGROUP_WIDTH * (N_SIMDGROUP - 1)];
    if (sgid == 1) reduceLM[SIMDGROUP_WIDTH * 0 + slid] = sum;
    // if (sgid == 2) reduceLM[SIMDGROUP_WIDTH * 1 + slid] = sum;
    // if (sgid == 3) reduceLM[SIMDGROUP_WIDTH * 2 + slid] = sum;
    barrier(CLK_LOCAL_MEM_FENCE);
    if (sgid == 0) sum += reduceLM[SIMDGROUP_WIDTH * 0 + slid];
    // if (sgid == 0) sum += reduceLM[SIMDGROUP_WIDTH * 1 + slid];
    // if (sgid == 0) sum += reduceLM[SIMDGROUP_WIDTH * 2 + slid];

    // 1 output per thread in subgroup 0
    if (sgid == 0) {
        dst = dst + (offsetd >> 2);
        dst[i01 + tile_id * tile_size + i1 * ne01] = sum;
    }

}
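Two details of this new kernel file are easy to verify on the host. The e8m0 scale byte is just a biased power-of-two exponent, and the fp4 payload decodes to the usual MXFP4 value set {0, +-0.5, +-1, +-1.5, +-2, +-3, +-4, +-6} (stated here for orientation, not taken from this diff). A C++ reference for the scale decode:

#include <cassert>
#include <cmath>
#include <cstdint>
#include <cstring>

// Same trick as e8m0_to_fp32: place x in the float exponent field,
// so the result is 2^(x - 127); x == 0 maps to the subnormal 2^-127.
static float e8m0_to_fp32_ref(uint8_t x) {
    const uint32_t bits = (x == 0) ? 0x00400000u : (uint32_t) x << 23;
    float f;
    std::memcpy(&f, &bits, sizeof(f));
    return f;
}

int main() {
    assert(e8m0_to_fp32_ref(127) == 1.0f);
    assert(e8m0_to_fp32_ref(128) == 2.0f);
    assert(e8m0_to_fp32_ref(0) == std::ldexp(1.0f, -127));
}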
@@ -0,0 +1,156 @@
#pragma OPENCL EXTENSION cl_khr_fp16 : enable
#pragma OPENCL EXTENSION cl_khr_subgroups : enable
#pragma OPENCL EXTENSION cl_qcom_reqd_sub_group_size : enable

#define QK_MXFP4 32
#define N_SIMDGROUP 4
#define SIMDGROUP_WIDTH 64

static inline half8 mxfp4_to_fp16_packed8(ushort2 fp4x8) { //, ushort 0x0E00, ushort 0x8000) {
    ushort2 fp16_packed_a_0, fp16_packed_b_0, bias_a, bias_b, sign_a, sign_b;
    fp16_packed_a_0.lo = (fp4x8.s0 << 9) & 0x0E00;
    fp16_packed_a_0.hi = (fp4x8.s0 << 5) & 0x0E00;
    fp16_packed_b_0.lo = (fp4x8.s0 << 1) & 0x0E00;
    fp16_packed_b_0.hi = (fp4x8.s0 >> 3) & 0x0E00;

    bias_a.lo = (fp16_packed_a_0.lo != 0) ? 0x3800 : 0x0;
    bias_a.hi = (fp16_packed_a_0.hi != 0) ? 0x3800 : 0x0;
    bias_b.lo = (fp16_packed_b_0.lo != 0) ? 0x3800 : 0x0;
    bias_b.hi = (fp16_packed_b_0.hi != 0) ? 0x3800 : 0x0;

    fp16_packed_a_0.lo = (fp16_packed_a_0.lo != 0x0200) ? fp16_packed_a_0.lo : 0x0;
    fp16_packed_a_0.hi = (fp16_packed_a_0.hi != 0x0200) ? fp16_packed_a_0.hi : 0x0;
    fp16_packed_b_0.lo = (fp16_packed_b_0.lo != 0x0200) ? fp16_packed_b_0.lo : 0x0;
    fp16_packed_b_0.hi = (fp16_packed_b_0.hi != 0x0200) ? fp16_packed_b_0.hi : 0x0;

    sign_a.lo = (fp4x8.s0 << 12) & 0x8000;
    sign_a.hi = (fp4x8.s0 << 8) & 0x8000;
    sign_b.lo = (fp4x8.s0 << 4) & 0x8000;
    sign_b.hi = fp4x8.s0 & 0x8000;

    fp16_packed_a_0 = sign_a + bias_a + fp16_packed_a_0;
    fp16_packed_b_0 = sign_b + bias_b + fp16_packed_b_0;

    ushort2 fp16_packed_a_1, fp16_packed_b_1;
    fp16_packed_a_1.lo = (fp4x8.s1 << 9) & 0x0E00;
    fp16_packed_a_1.hi = (fp4x8.s1 << 5) & 0x0E00;
    fp16_packed_b_1.lo = (fp4x8.s1 << 1) & 0x0E00;
    fp16_packed_b_1.hi = (fp4x8.s1 >> 3) & 0x0E00;

    bias_a.lo = (fp16_packed_a_1.lo != 0) ? 0x3800 : 0x0;
    bias_a.hi = (fp16_packed_a_1.hi != 0) ? 0x3800 : 0x0;
    bias_b.lo = (fp16_packed_b_1.lo != 0) ? 0x3800 : 0x0;
    bias_b.hi = (fp16_packed_b_1.hi != 0) ? 0x3800 : 0x0;

    fp16_packed_a_1.lo = (fp16_packed_a_1.lo != 0x0200) ? fp16_packed_a_1.lo : 0x0;
    fp16_packed_a_1.hi = (fp16_packed_a_1.hi != 0x0200) ? fp16_packed_a_1.hi : 0x0;
    fp16_packed_b_1.lo = (fp16_packed_b_1.lo != 0x0200) ? fp16_packed_b_1.lo : 0x0;
    fp16_packed_b_1.hi = (fp16_packed_b_1.hi != 0x0200) ? fp16_packed_b_1.hi : 0x0;

    sign_a.lo = (fp4x8.s1 << 12) & 0x8000;
    sign_a.hi = (fp4x8.s1 << 8) & 0x8000;
    sign_b.lo = (fp4x8.s1 << 4) & 0x8000;
    sign_b.hi = fp4x8.s1 & 0x8000;

    fp16_packed_a_1 = sign_a + bias_a + fp16_packed_a_1;
    fp16_packed_b_1 = sign_b + bias_b + fp16_packed_b_1;

    return as_half8((ushort8)(fp16_packed_a_0, fp16_packed_b_0, fp16_packed_a_1, fp16_packed_b_1));
}

static inline float e8m0_to_fp32(uchar x) {
    int bits;
    bits = (x == 0) ? 0x00400000 : ((uint) x << 23);
    return as_float(bits);
}


__attribute__((qcom_reqd_sub_group_size("half")))
__kernel void kernel_gemv_moe_mxfp4_f32(
    __global uint4 * src0_q,
    __global uchar * src0_e,
    __read_only image1d_buffer_t src1,
    __global uint * src2,
    __global float * dst,
    ulong offsetd,
    int ne00,
    int ne01,
    int ne11
) {
    uint i01 = get_global_id(0);
    uint i20 = get_global_id(2);
    uint sgid = get_local_id(1);
    uint slid = get_sub_group_local_id();

    uint i11 = i20 % ne11;

    uint expert_id = src2[i20];
    uint expert_offset = expert_id * ne00 * ne01 / 32;

    __private float sum = 0.0f; // each thread calculates the partial sum of one output

    // loop along ne00 in block granularity, skip 4 blocks every iter
    for (uint ib00 = sgid; ib00 < (ne00 / QK_MXFP4); ib00 += N_SIMDGROUP) {

        // load one block of q
        uint4 regQ = src0_q[expert_offset + ib00 * ne01 + i01];

        uint offset = i11 * ne00 / 4 + ib00 * 8;

        half8 fp16x8 = mxfp4_to_fp16_packed8(as_ushort2(regQ.s0));

        float4 shared_y4;
        shared_y4 = read_imagef(src1, (offset + 0));
        float4 acc = shared_y4 * (float4)(fp16x8.s0, fp16x8.s2, fp16x8.s4, fp16x8.s6);

        shared_y4 = read_imagef(src1, (offset + 4));
        acc += shared_y4 * (float4)(fp16x8.s1, fp16x8.s3, fp16x8.s5, fp16x8.s7);


        fp16x8 = mxfp4_to_fp16_packed8(as_ushort2(regQ.s1));

        shared_y4 = read_imagef(src1, (offset + 1));
        acc += shared_y4 * (float4)(fp16x8.s0, fp16x8.s2, fp16x8.s4, fp16x8.s6);

        shared_y4 = read_imagef(src1, (offset + 5));
        acc += shared_y4 * (float4)(fp16x8.s1, fp16x8.s3, fp16x8.s5, fp16x8.s7);


        fp16x8 = mxfp4_to_fp16_packed8(as_ushort2(regQ.s2));

        shared_y4 = read_imagef(src1, (offset + 2));
        acc += shared_y4 * (float4)(fp16x8.s0, fp16x8.s2, fp16x8.s4, fp16x8.s6);

        shared_y4 = read_imagef(src1, (offset + 6));
        acc += shared_y4 * (float4)(fp16x8.s1, fp16x8.s3, fp16x8.s5, fp16x8.s7);


        fp16x8 = mxfp4_to_fp16_packed8(as_ushort2(regQ.s3));

        shared_y4 = read_imagef(src1, (offset + 3));
        acc += shared_y4 * (float4)(fp16x8.s0, fp16x8.s2, fp16x8.s4, fp16x8.s6);

        shared_y4 = read_imagef(src1, (offset + 7));
        acc += shared_y4 * (float4)(fp16x8.s1, fp16x8.s3, fp16x8.s5, fp16x8.s7);

        uchar regE = src0_e[ib00 * ne01 + i01 + expert_offset];
        sum += e8m0_to_fp32(regE) * ((acc.s0 + acc.s1) + (acc.s2 + acc.s3));
    }

    // reduction in local memory, assumes #subgroups=4
    __local float reduceLM[SIMDGROUP_WIDTH * (N_SIMDGROUP - 1)];
    if (sgid == 1) reduceLM[SIMDGROUP_WIDTH * 0 + slid] = sum;
    if (sgid == 2) reduceLM[SIMDGROUP_WIDTH * 1 + slid] = sum;
    if (sgid == 3) reduceLM[SIMDGROUP_WIDTH * 2 + slid] = sum;
    barrier(CLK_LOCAL_MEM_FENCE);
    if (sgid == 0) sum += reduceLM[SIMDGROUP_WIDTH * 0 + slid];
    if (sgid == 0) sum += reduceLM[SIMDGROUP_WIDTH * 1 + slid];
    if (sgid == 0) sum += reduceLM[SIMDGROUP_WIDTH * 2 + slid];

    // 1 output per thread in subgroup 0
    if (sgid == 0) {
        dst = dst + (offsetd >> 2);
        dst[i01 + i20 * ne01] = sum;
    }

}
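The gemv kernel splits the K dimension across four subgroups, each taking every fourth MXFP4 block, then folds the three staged partial sums into subgroup 0 through local memory. A scalar C++ analogue of that coverage argument (illustrative only):

#include <cassert>

int main() {
    const int N_SIMDGROUP = 4, QK_MXFP4 = 32, ne00 = 256;
    float partial[N_SIMDGROUP] = {};
    for (int sgid = 0; sgid < N_SIMDGROUP; sgid++)
        for (int ib = sgid; ib < ne00 / QK_MXFP4; ib += N_SIMDGROUP)
            partial[sgid] += 1.0f; // stand-in for one block's dot product
    const float sum = partial[0] + partial[1] + partial[2] + partial[3];
    assert(sum == (float) (ne00 / QK_MXFP4)); // every block counted exactly once
}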
@@ -939,6 +939,7 @@ public:
    bool graph_compute(const std::vector<uint8_t> & input, rpc_msg_graph_compute_rsp & response);
    bool init_tensor(const rpc_msg_init_tensor_req & request);
    bool get_alloc_size(const rpc_msg_get_alloc_size_req & request, rpc_msg_get_alloc_size_rsp & response);
    bool get_device_memory(const rpc_msg_get_device_memory_req & request, rpc_msg_get_device_memory_rsp & response);

private:
    bool get_cached_file(uint64_t hash, std::vector<uint8_t> & data);

@@ -1458,6 +1459,20 @@ bool rpc_server::graph_compute(const std::vector<uint8_t> & input, rpc_msg_graph
    return true;
}

bool rpc_server::get_device_memory(const rpc_msg_get_device_memory_req & request, rpc_msg_get_device_memory_rsp & response) {
    uint32_t dev_id = request.device;
    if (dev_id >= backends.size()) {
        return false;
    }
    size_t free, total;
    ggml_backend_dev_t dev = ggml_backend_get_device(backends[dev_id]);
    ggml_backend_dev_memory(dev, &free, &total);
    response.free_mem = free;
    response.total_mem = total;
    LOG_DBG("[%s] device: %u, free_mem: %" PRIu64 ", total_mem: %" PRIu64 "\n", __func__, dev_id, response.free_mem, response.total_mem);
    return true;
}

rpc_server::~rpc_server() {
    for (auto buffer : buffers) {
        ggml_backend_buffer_free(buffer);

@@ -1465,7 +1480,7 @@ rpc_server::~rpc_server() {
}

static void rpc_serve_client(const std::vector<ggml_backend_t> & backends, const char * cache_dir,
                             sockfd_t sockfd, const std::vector<size_t> & free_mem, const std::vector<size_t> & total_mem) {
                             sockfd_t sockfd) {
    rpc_server server(backends, cache_dir);
    uint8_t cmd;
    if (!recv_data(sockfd, &cmd, 1)) {

@@ -1689,15 +1704,10 @@ static void rpc_serve_client(const std::vector<ggml_backend_t> & backends, const
                if (!recv_msg(sockfd, &request, sizeof(request))) {
                    return;
                }
                auto dev_id = request.device;
                if (dev_id >= backends.size()) {
                rpc_msg_get_device_memory_rsp response;
                if (!server.get_device_memory(request, response)) {
                    return;
                }
                rpc_msg_get_device_memory_rsp response;
                response.free_mem = free_mem[dev_id];
                response.total_mem = total_mem[dev_id];
                LOG_DBG("[get_device_mem] device: %u, free_mem: %" PRIu64 ", total_mem: %" PRIu64 "\n", dev_id,
                        response.free_mem, response.total_mem);
                if (!send_msg(sockfd, &response, sizeof(response))) {
                    return;
                }

@@ -1712,15 +1722,12 @@ static void rpc_serve_client(const std::vector<ggml_backend_t> & backends, const
}

void ggml_backend_rpc_start_server(const char * endpoint, const char * cache_dir,
                                   size_t n_threads, size_t n_devices,
                                   ggml_backend_dev_t * devices, size_t * free_mem, size_t * total_mem) {
    if (n_devices == 0 || devices == nullptr || free_mem == nullptr || total_mem == nullptr) {
                                   size_t n_threads, size_t n_devices, ggml_backend_dev_t * devices) {
    if (n_devices == 0 || devices == nullptr) {
        fprintf(stderr, "Invalid arguments to ggml_backend_rpc_start_server\n");
        return;
    }
    std::vector<ggml_backend_t> backends;
    std::vector<size_t> free_mem_vec(free_mem, free_mem + n_devices);
    std::vector<size_t> total_mem_vec(total_mem, total_mem + n_devices);
    printf("Starting RPC server v%d.%d.%d\n",
           RPC_PROTO_MAJOR_VERSION,
           RPC_PROTO_MINOR_VERSION,

@@ -1730,8 +1737,10 @@ void ggml_backend_rpc_start_server(const char * endpoint, const char * cache_dir
    printf("Devices:\n");
    for (size_t i = 0; i < n_devices; i++) {
        auto dev = devices[i];
        size_t free, total;
        ggml_backend_dev_memory(dev, &free, &total);
        printf("  %s: %s (%zu MiB, %zu MiB free)\n", ggml_backend_dev_name(dev), ggml_backend_dev_description(dev),
               total_mem[i] / 1024 / 1024, free_mem[i] / 1024 / 1024);
               total / 1024 / 1024, free / 1024 / 1024);
        auto backend = ggml_backend_dev_init(dev, nullptr);
        if (!backend) {
            fprintf(stderr, "Failed to create backend for device %s\n", dev->iface.get_name(dev));

@@ -1775,7 +1784,7 @@ void ggml_backend_rpc_start_server(const char * endpoint, const char * cache_dir
        }
        printf("Accepted client connection\n");
        fflush(stdout);
        rpc_serve_client(backends, cache_dir, client_socket->fd, free_mem_vec, total_mem_vec);
        rpc_serve_client(backends, cache_dir, client_socket->fd);
        printf("Client connection closed\n");
        fflush(stdout);
    }
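With this change the RPC server answers GET_DEVICE_MEMORY by querying the backend device live instead of echoing values cached at startup, so clients see current free memory. A hedged sketch of a handler round-trip in C++ (request/response field names taken from the diff; the surrounding setup is hypothetical):

#include <cstdio>

// assume `server` is an rpc_server wrapping one or more initialized backends
// rpc_msg_get_device_memory_req req{};
// req.device = 0; // first exposed device
// rpc_msg_get_device_memory_rsp rsp{};
// if (server.get_device_memory(req, rsp)) {
//     // free/total reflect the device state at call time, not at server start
//     std::printf("free: %llu, total: %llu\n",
//                 (unsigned long long) rsp.free_mem, (unsigned long long) rsp.total_mem);
// }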
@@ -150,6 +150,26 @@ static __dpct_inline__ T op_clamp(T x, float min_val, float max_val) {
    return x < static_cast<T>(min_val) ? static_cast<T>(min_val) : (x > static_cast<T>(max_val) ? static_cast<T>(max_val) : x);
}

template<typename T>
static __dpct_inline__ T op_floor(T x) {
    return sycl::floor(x);
}

template<typename T>
static __dpct_inline__ T op_ceil(T x) {
    return sycl::ceil(x);
}

template<typename T>
static __dpct_inline__ T op_round(T x) {
    return sycl::round(x);
}

template<typename T>
static __dpct_inline__ T op_trunc(T x) {
    return sycl::trunc(x);
}

template<typename T>
static void unary_op_sgn_kernel(const T * x, T * dst, const int k, const sycl::nd_item<1> &item_ct1) {
    SYCL_GLOBAL_ID_LOOP(k, item_ct1) {
@@ -304,6 +324,34 @@ static void unary_op_clamp_kernel(const T * x, T * dst, const int k, const sycl:
    }
}

template<typename T>
static void unary_op_floor_kernel(const T * x, T * dst, const int k, const sycl::nd_item<1> &item_ct1) {
    SYCL_GLOBAL_ID_LOOP(k, item_ct1) {
        dst[i] = op_floor(x[i]);
    }
}

template<typename T>
static void unary_op_ceil_kernel(const T * x, T * dst, const int k, const sycl::nd_item<1> &item_ct1) {
    SYCL_GLOBAL_ID_LOOP(k, item_ct1) {
        dst[i] = op_ceil(x[i]);
    }
}

template<typename T>
static void unary_op_round_kernel(const T * x, T * dst, const int k, const sycl::nd_item<1> &item_ct1) {
    SYCL_GLOBAL_ID_LOOP(k, item_ct1) {
        dst[i] = op_round(x[i]);
    }
}

template<typename T>
static void unary_op_trunc_kernel(const T * x, T * dst, const int k, const sycl::nd_item<1> &item_ct1) {
    SYCL_GLOBAL_ID_LOOP(k, item_ct1) {
        dst[i] = op_trunc(x[i]);
    }
}

template<typename T>
static void upscale(const T *x, T *dst, const int nb00, const int nb01,
                    const int nb02, const int nb03, const int ne10, const int ne11,
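The four new unary ops differ only in how the fractional part is treated, which matters mostly for negative and halfway values; sycl::floor/ceil/round/trunc follow the same conventions as their std counterparts (round: halfway away from zero; trunc: toward zero). A quick C++ reminder:

#include <cassert>
#include <cmath>

int main() {
    assert(std::floor(-1.5f) == -2.0f);
    assert(std::ceil(-1.5f)  == -1.0f);
    assert(std::round(-1.5f) == -2.0f); // halfway rounds away from zero
    assert(std::trunc(-1.5f) == -1.0f); // truncation goes toward zero
}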
@@ -397,6 +445,14 @@ static void acc_f32_sycl(const float *x, const float *y, float *dst,
        });
}

template<typename T>
static void arange_kernel(T * dst, const int k, T start, T step,
                          const sycl::nd_item<1> &item_ct1) {
    SYCL_GLOBAL_ID_LOOP(k, item_ct1) {
        dst[i] = start + static_cast<T>(i) * step;
    }
}

template<typename T>
static void upscale_sycl(const T *x, T *dst, const int nb00, const int nb01,
                         const int nb02, const int nb03, const int ne10, const int ne11,
@@ -565,6 +621,25 @@ static inline void dispatch_ggml_sycl_op_upscale(ggml_backend_sycl_context & ctx
}


static inline void ggml_sycl_op_arange(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
    GGML_ASSERT(dst->type == GGML_TYPE_F32);
    float start, stop, step;
    memcpy(&start, dst->op_params, sizeof(float));
    memcpy(&stop, (float *) dst->op_params + 1, sizeof(float));
    memcpy(&step, (float *) dst->op_params + 2, sizeof(float));
    dpct::queue_ptr stream = ctx.stream();
    SYCL_CHECK(ggml_sycl_set_device(ctx.device));
    float * dst_ptr = (float *)dst->data;
    const int k = (int)ggml_nelements(dst);
    const int num_blocks = ceil_div(k, SYCL_ARANGE_BLOCK_SIZE);
    stream->parallel_for(
        sycl::nd_range<1>(sycl::range<1>(num_blocks) * sycl::range<1>(SYCL_ARANGE_BLOCK_SIZE),
                          sycl::range<1>(SYCL_ARANGE_BLOCK_SIZE)),
        [=](sycl::nd_item<1> item_ct1) {
            arange_kernel(dst_ptr, k, start, step, item_ct1);
        });
}

} // namespace ggml_sycl_detail
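Note that ggml_sycl_op_arange reads `stop` out of op_params but sizes the launch from ggml_nelements(dst); the element count was already fixed when the ARANGE tensor was created, as ceil((stop - start) / step) in ggml proper (that ceil rule is restated here from memory, not from this diff). A C++ sketch of the relationship:

#include <cassert>
#include <cmath>
#include <vector>

int main() {
    const float start = 0.0f, stop = 5.0f, step = 2.0f;
    const int k = (int) std::ceil((stop - start) / step); // 3 elements: 0, 2, 4
    std::vector<float> dst(k);
    for (int i = 0; i < k; i++)
        dst[i] = start + (float) i * step; // arange_kernel's per-thread body
    assert(k == 3 && dst[2] == 4.0f);
}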
@@ -870,6 +945,58 @@ static inline void ggml_sycl_op_clamp(ggml_backend_sycl_context & ctx, ggml_tens
        }, min_val, max_val);
}

static inline void ggml_sycl_op_floor(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
    ggml_sycl_detail::dispatch_ggml_sycl_op_unary(ctx, dst,
        [](const auto* src, auto* dst_ptr, int k_elements, queue_ptr stream) {
            const int num_blocks = ceil_div(k_elements, 256);
            stream->parallel_for(
                sycl::nd_range<1>(sycl::range<1>(num_blocks) * sycl::range<1>(256),
                                  sycl::range<1>(256)),
                [=](sycl::nd_item<1> item_ct1) {
                    unary_op_floor_kernel(src, dst_ptr, k_elements, item_ct1);
                });
        });
}

static inline void ggml_sycl_op_ceil(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
    ggml_sycl_detail::dispatch_ggml_sycl_op_unary(ctx, dst,
        [](const auto* src, auto* dst_ptr, int k_elements, queue_ptr stream) {
            const int num_blocks = ceil_div(k_elements, 256);
            stream->parallel_for(
                sycl::nd_range<1>(sycl::range<1>(num_blocks) * sycl::range<1>(256),
                                  sycl::range<1>(256)),
                [=](sycl::nd_item<1> item_ct1) {
                    unary_op_ceil_kernel(src, dst_ptr, k_elements, item_ct1);
                });
        });
}

static inline void ggml_sycl_op_round(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
    ggml_sycl_detail::dispatch_ggml_sycl_op_unary(ctx, dst,
        [](const auto* src, auto* dst_ptr, int k_elements, queue_ptr stream) {
            const int num_blocks = ceil_div(k_elements, 256);
            stream->parallel_for(
                sycl::nd_range<1>(sycl::range<1>(num_blocks) * sycl::range<1>(256),
                                  sycl::range<1>(256)),
                [=](sycl::nd_item<1> item_ct1) {
                    unary_op_round_kernel(src, dst_ptr, k_elements, item_ct1);
                });
        });
}

static inline void ggml_sycl_op_trunc(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
    ggml_sycl_detail::dispatch_ggml_sycl_op_unary(ctx, dst,
        [](const auto* src, auto* dst_ptr, int k_elements, queue_ptr stream) {
            const int num_blocks = ceil_div(k_elements, 256);
            stream->parallel_for(
                sycl::nd_range<1>(sycl::range<1>(num_blocks) * sycl::range<1>(256),
                                  sycl::range<1>(256)),
                [=](sycl::nd_item<1> item_ct1) {
                    unary_op_trunc_kernel(src, dst_ptr, k_elements, item_ct1);
                });
        });
}

static inline void ggml_sycl_op_acc(ggml_backend_sycl_context & ctx, ggml_tensor *dst) {
    GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32);
    GGML_ASSERT(dst->src[1]->type == GGML_TYPE_F32);
@@ -1090,3 +1217,28 @@ void ggml_sycl_geglu_quick(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
    scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/1);
    ggml_sycl_op_geglu_quick(ctx, dst);
}

void ggml_sycl_arange(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
    scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/0);
    ggml_sycl_detail::ggml_sycl_op_arange(ctx, dst);
}

void ggml_sycl_floor(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
    scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/1);
    ggml_sycl_op_floor(ctx, dst);
}

void ggml_sycl_ceil(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
    scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/1);
    ggml_sycl_op_ceil(ctx, dst);
}

void ggml_sycl_round(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
    scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/1);
    ggml_sycl_op_round(ctx, dst);
}

void ggml_sycl_trunc(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
    scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/1);
    ggml_sycl_op_trunc(ctx, dst);
}
@@ -80,5 +80,11 @@ void ggml_sycl_reglu(ggml_backend_sycl_context & ctx, ggml_tensor * dst);
void ggml_sycl_swiglu(ggml_backend_sycl_context & ctx, ggml_tensor * dst);
void ggml_sycl_geglu_erf(ggml_backend_sycl_context & ctx, ggml_tensor * dst);
void ggml_sycl_geglu_quick(ggml_backend_sycl_context & ctx, ggml_tensor * dst);
void ggml_sycl_floor(ggml_backend_sycl_context & ctx, ggml_tensor * dst);
void ggml_sycl_ceil(ggml_backend_sycl_context & ctx, ggml_tensor * dst);
void ggml_sycl_round(ggml_backend_sycl_context & ctx, ggml_tensor * dst);
void ggml_sycl_trunc(ggml_backend_sycl_context & ctx, ggml_tensor * dst);

void ggml_sycl_arange(ggml_backend_sycl_context & ctx, ggml_tensor * dst);

#endif // GGML_SYCL_ELEMENTWISE_HPP
@@ -42,6 +42,7 @@
#include "ggml-sycl/presets.hpp"
#include "ggml-sycl/gemm.hpp"
#include "ggml-sycl/set_rows.hpp"
#include "ggml-sycl/set.hpp"
#include "ggml-sycl/sycl_hw.hpp"
#include "ggml-sycl/getrows.hpp"
#include "ggml-sycl/quantize.hpp"
@@ -3619,6 +3620,9 @@ static bool ggml_sycl_compute_forward(ggml_backend_sycl_context & ctx, struct gg
        case GGML_OP_GET_ROWS:
            ggml_sycl_get_rows(ctx, dst);
            break;
        case GGML_OP_SET:
            ggml_sycl_op_set(ctx, dst);
            break;
        case GGML_OP_SET_ROWS:
            ggml_sycl_op_set_rows(ctx, dst);
            break;
@@ -3694,6 +3698,18 @@ static bool ggml_sycl_compute_forward(ggml_backend_sycl_context & ctx, struct gg
                case GGML_UNARY_OP_ELU:
                    ggml_sycl_elu(ctx, dst);
                    break;
                case GGML_UNARY_OP_FLOOR:
                    ggml_sycl_floor(ctx, dst);
                    break;
                case GGML_UNARY_OP_CEIL:
                    ggml_sycl_ceil(ctx, dst);
                    break;
                case GGML_UNARY_OP_ROUND:
                    ggml_sycl_round(ctx, dst);
                    break;
                case GGML_UNARY_OP_TRUNC:
                    ggml_sycl_trunc(ctx, dst);
                    break;
                default:
                    return false;
            }
@@ -3832,6 +3848,9 @@ static bool ggml_sycl_compute_forward(ggml_backend_sycl_context & ctx, struct gg
        case GGML_OP_GATED_LINEAR_ATTN:
            ggml_sycl_op_gated_linear_attn(ctx, dst);
            break;
        case GGML_OP_ARANGE:
            ggml_sycl_arange(ctx, dst);
            break;
        default:
            return false;
    }
@@ -4255,6 +4274,10 @@ static bool ggml_backend_sycl_device_supports_op(ggml_backend_dev_t dev, const g
                case GGML_UNARY_OP_SGN:
                case GGML_UNARY_OP_ABS:
                case GGML_UNARY_OP_ELU:
                case GGML_UNARY_OP_FLOOR:
                case GGML_UNARY_OP_CEIL:
                case GGML_UNARY_OP_ROUND:
                case GGML_UNARY_OP_TRUNC:
#if defined (GGML_SYCL_F16)
                    return ggml_is_contiguous(op->src[0]) && (op->type == op->src[0]->type);
#else
@@ -4328,6 +4351,12 @@ static bool ggml_backend_sycl_device_supports_op(ggml_backend_dev_t dev, const g
                    return false;
                }
            }
        case GGML_OP_SET:
            return (op->type == GGML_TYPE_F32) &&
                   (op->src[0] && op->src[1]) &&
                   (op->src[0]->type == GGML_TYPE_F32) &&
                   (op->src[1]->type == GGML_TYPE_F32);

        case GGML_OP_SET_ROWS:
            {
                return ((op->type == GGML_TYPE_F32 || op->type == GGML_TYPE_F16 || op->type == GGML_TYPE_BF16 ||
@@ -4478,6 +4507,8 @@ static bool ggml_backend_sycl_device_supports_op(ggml_backend_dev_t dev, const g
        case GGML_OP_RWKV_WKV7:
        case GGML_OP_GATED_LINEAR_ATTN:
            return true;
        case GGML_OP_ARANGE:
            return op->type == GGML_TYPE_F32;
        default:
            return false;
    }
@@ -31,6 +31,7 @@
#define SYCL_SQRT_BLOCK_SIZE 256
#define SYCL_SIN_BLOCK_SIZE 256
#define SYCL_SQR_BLOCK_SIZE 256
#define SYCL_SET_BLOCK_SIZE 256
#define SYCL_CPY_BLOCK_SIZE 32
#define SYCL_SCALE_BLOCK_SIZE 256
#define SYCL_CLAMP_BLOCK_SIZE 256

@@ -49,6 +50,7 @@
#define SYCL_ARGMAX_BLOCK_SIZE 256
#define SYCL_CONV_TRANPOSE_1D_BLOCK_SIZE 256
#define SYCL_TIMESTEP_EMBEDDING_BLOCK_SIZE 256
#define SYCL_ARANGE_BLOCK_SIZE 256

// dmmv = dequantize_mul_mat_vec
#ifndef GGML_SYCL_DMMV_X
@@ -0,0 +1,73 @@
#include "presets.hpp"
#include "common.hpp"
#include "ggml.h"
#include "set.hpp"
#include <cstdint>
#include <sycl/sycl.hpp>
using namespace sycl;

// Internal function: perform the element-wise set operation for each thread
inline void set_f32(const float* src, float* dst,
                    const int64_t ne0, const int64_t ne1,
                    const int64_t ne2, const int64_t ne3,
                    const int64_t nb[3], const int64_t src_nb[3],
                    const int64_t offset_elem,
                    const nd_item<1>& item)
{
    const size_t idx = item.get_global_id(0);
    const size_t total = ne0 * ne1 * ne2 * ne3;
    if (idx >= total) return;

    // Convert linear index to 4D indices
    const size_t i3 = idx / (ne2 * ne1 * ne0);
    const size_t rem = idx % (ne2 * ne1 * ne0);
    const size_t i2 = rem / (ne1 * ne0);
    const size_t rem2 = rem % (ne1 * ne0);
    const size_t i1 = rem2 / ne0;
    const size_t i0 = rem2 % ne0;

    // Compute source and destination indices and copy
    dst[i0 + i1*nb[0] + i2*nb[1] + i3*nb[2] + offset_elem] =
        src[i0 + i1*src_nb[0] + i2*src_nb[1] + i3*src_nb[2]];
}

// Main function: prepare the GPU queue and launch the parallel_for
void ggml_sycl_op_set(ggml_backend_sycl_context& ctx, ggml_tensor* dst) {
    const ggml_tensor* src0 = dst->src[0];
    const ggml_tensor* src1 = dst->src[1];

    // Ensure shapes and types are compatible
    GGML_ASSERT(ggml_are_same_shape(src0, dst));
    GGML_ASSERT(ggml_is_contiguous(dst) && ggml_is_contiguous(src0));
    GGML_ASSERT(dst->type == src0->type && src0->type == src1->type && dst->type == GGML_TYPE_F32);

    const int32_t* opts = (const int32_t*) dst->op_params;
    const int64_t nb[3] = {opts[0]/sizeof(float), opts[1]/sizeof(float), opts[2]/sizeof(float)};
    const int64_t offset_elem = opts[3] / sizeof(float);
    const bool inplace = opts[4];

    float* dst_ptr = (float*) dst->data;
    const float* src0_ptr = (const float*) src0->data;
    const float* src1_ptr = (const float*) src1->data;

    queue_ptr stream = ctx.stream();

    // Copy src0 to dst if not inplace
    if (!inplace)
        stream->memcpy(dst_ptr, src0_ptr, ggml_nbytes(dst));

    const int64_t ne[4] = {src1->ne[0], src1->ne[1], src1->ne[2], src1->ne[3]};
    const int64_t src_nb[3] = {src1->nb[1]/sizeof(float), src1->nb[2]/sizeof(float), src1->nb[3]/sizeof(float)};

    const size_t total_threads = ne[0]*ne[1]*ne[2]*ne[3];
    const size_t grid_size = ((total_threads + SYCL_SET_BLOCK_SIZE - 1) / SYCL_SET_BLOCK_SIZE) * SYCL_SET_BLOCK_SIZE;

    // Write src1 into the strided view of dst
    stream->parallel_for(
        nd_range<1>(range<1>(grid_size), range<1>(SYCL_SET_BLOCK_SIZE)),
        [=](nd_item<1> item) {
            set_f32(src1_ptr, dst_ptr,
                    ne[0], ne[1], ne[2], ne[3],
                    nb, src_nb, offset_elem, item); }
    );
}
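GGML_OP_SET writes src1 into a strided window of dst: op_params carries three byte strides and a byte offset, which set.cpp converts to element units before indexing. A host-side C++ reference of the same placement (4x4 destination, 2x2 source, hypothetical numbers):

#include <cassert>
#include <vector>

int main() {
    std::vector<float> dst(16, 0.0f);            // contiguous 4x4, row stride 4
    const std::vector<float> src = {1, 2, 3, 4}; // 2x2 block
    const long nb1 = 4, offset_elem = 5;         // strides/offset already in elements
    for (long i1 = 0; i1 < 2; i1++)
        for (long i0 = 0; i0 < 2; i0++)
            dst[i0 + i1 * nb1 + offset_elem] = src[i0 + i1 * 2];
    assert(dst[5] == 1 && dst[6] == 2 && dst[9] == 3 && dst[10] == 4);
}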
@@ -0,0 +1,5 @@
#pragma once
#include "backend.hpp"
#include "ggml.h"

void ggml_sycl_op_set(ggml_backend_sycl_context & ctx, ggml_tensor * dst);
@@ -385,6 +385,14 @@ enum shader_reduction_mode {

static constexpr uint32_t num_argsort_pipelines = 11;
static constexpr uint32_t max_argsort_cols = 1 << (num_argsort_pipelines-1);
static constexpr uint32_t num_topk_moe_pipelines = 10;

static constexpr std::array topk_moe_norm{ GGML_OP_SOFT_MAX, GGML_OP_RESHAPE, GGML_OP_ARGSORT,
                                           GGML_OP_VIEW, GGML_OP_GET_ROWS, GGML_OP_RESHAPE,
                                           GGML_OP_SUM_ROWS, GGML_OP_DIV, GGML_OP_RESHAPE };
static constexpr std::array topk_moe { GGML_OP_SOFT_MAX, GGML_OP_RESHAPE, GGML_OP_ARGSORT,
                                       GGML_OP_VIEW, GGML_OP_GET_ROWS };


struct vk_device_struct {
    std::recursive_mutex mutex;
@@ -582,6 +590,9 @@ struct vk_device_struct {
    vk_pipeline pipeline_pool2d_f32;
    vk_pipeline pipeline_rwkv_wkv6_f32;
    vk_pipeline pipeline_rwkv_wkv7_f32;
    vk_pipeline pipeline_ssm_scan_f32_d128;
    vk_pipeline pipeline_ssm_scan_f32_d256;
    vk_pipeline pipeline_ssm_conv_f32;
    vk_pipeline pipeline_opt_step_adamw_f32;
    vk_pipeline pipeline_opt_step_sgd_f32;
    vk_pipeline pipeline_conv2d_f32[CONV_SHAPE_COUNT];

@@ -595,6 +606,9 @@ struct vk_device_struct {

    vk_pipeline pipeline_flash_attn_split_k_reduce;

    // [2] is {!norm, norm}
    vk_pipeline pipeline_topk_moe[num_topk_moe_pipelines][2];

    std::vector<vk_pipeline_ref> all_pipelines;

    std::vector<std::tuple<void*, size_t, vk_buffer>> pinned_memory;
@@ -938,6 +952,11 @@ struct vk_op_multi_add_push_constants {
static_assert(MAX_PARAMETER_COUNT == 12);
static_assert(sizeof(vk_op_multi_add_push_constants) <= 256);

struct vk_op_topk_moe_push_constants {
    uint32_t n_rows;
    uint32_t n_expert_used;
};

struct vk_op_add_id_push_constants {
    uint32_t ne0;
    uint32_t ne1;

@@ -1087,6 +1106,19 @@ struct vk_op_rwkv_wkv7_push_constants {
    uint32_t C;
    uint32_t H;
};
struct vk_op_ssm_scan_push_constants {
    uint32_t nb02, nb03, nb12, nb13;
    uint32_t nb21, nb22, nb31;
    uint32_t nb42, nb43, nb52, nb53;
    uint32_t s_off;
    uint32_t n_head, d_head, n_group, n_tok;
};
struct vk_op_ssm_conv_push_constants {
    uint32_t nb01, nb02;
    uint32_t nb11;
    uint32_t dst_nb0, dst_nb1, dst_nb2;
    uint32_t nc, ncs, nr, n_t, n_s;
};

struct vk_op_conv2d_push_constants {
    uint32_t Cout;
@@ -3591,6 +3623,11 @@ static void ggml_vk_load_shaders(vk_device& device) {

    ggml_vk_create_pipeline(device, device->pipeline_rwkv_wkv7_f32, "rwkv_wkv7_f32", rwkv_wkv7_f32_len, rwkv_wkv7_f32_data, "main", 8, sizeof(vk_op_rwkv_wkv7_push_constants), {1, 1, 1}, {device->subgroup_size}, 1);

    ggml_vk_create_pipeline(device, device->pipeline_ssm_scan_f32_d128, "ssm_scan_f32", ssm_scan_f32_len, ssm_scan_f32_data, "main", 8, sizeof(vk_op_ssm_scan_push_constants), {1, 1, 1}, {128, device->subgroup_size, 16}, 1);
    ggml_vk_create_pipeline(device, device->pipeline_ssm_scan_f32_d256, "ssm_scan_f32", ssm_scan_f32_len, ssm_scan_f32_data, "main", 8, sizeof(vk_op_ssm_scan_push_constants), {1, 1, 1}, {256, device->subgroup_size, 16}, 1);

    ggml_vk_create_pipeline(device, device->pipeline_ssm_conv_f32, "ssm_conv_f32", ssm_conv_f32_len, ssm_conv_f32_data, "main", 3, sizeof(vk_op_ssm_conv_push_constants), {32, 1, 1}, {32}, 1);

    ggml_vk_create_pipeline(device, device->pipeline_opt_step_adamw_f32, "opt_step_adamw_f32", opt_step_adamw_f32_len, opt_step_adamw_f32_data, "main", 5, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1);

    ggml_vk_create_pipeline(device, device->pipeline_opt_step_sgd_f32, "opt_step_sgd_f32", opt_step_sgd_f32_len, opt_step_sgd_f32_data, "main", 3, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1);

@@ -3701,6 +3738,11 @@ static void ggml_vk_load_shaders(vk_device& device) {
    ggml_vk_create_pipeline(device, device->pipeline_conv2d_dw_whcn_f16_f32, "conv2d_dw_whcn_f16_f32", conv2d_dw_whcn_f16_f32_len, conv2d_dw_whcn_f16_f32_data, "main", 3, sizeof(vk_op_conv2d_dw_push_constants), {512, 1, 1}, {}, 1);
    ggml_vk_create_pipeline(device, device->pipeline_conv2d_dw_cwhn_f16_f32, "conv2d_dw_cwhn_f16_f32", conv2d_dw_cwhn_f16_f32_len, conv2d_dw_cwhn_f16_f32_data, "main", 3, sizeof(vk_op_conv2d_dw_push_constants), {512, 1, 1}, {}, 1);

    for (uint32_t i = 0; i < num_topk_moe_pipelines; ++i) {
        ggml_vk_create_pipeline2(device, device->pipeline_topk_moe[i][0], "topk_moe_f32_"+std::to_string(i), topk_moe_f32_len, topk_moe_f32_data, "main", 3, sizeof(vk_op_topk_moe_push_constants), {1, 1, 1}, {device->subgroup_size, 1u<<i, 0}, 1, true, true);
        ggml_vk_create_pipeline2(device, device->pipeline_topk_moe[i][1], "topk_moe_f32_"+std::to_string(i), topk_moe_f32_len, topk_moe_f32_data, "main", 3, sizeof(vk_op_topk_moe_push_constants), {1, 1, 1}, {device->subgroup_size, 1u<<i, 1}, 1, true, true);
    }

    for (auto &c : compiles) {
        c.wait();
    }
@@ -7983,6 +8025,13 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
        GGML_ASSERT(!src1 || src1->type == GGML_TYPE_F32 || src1->type == GGML_TYPE_F16);
        GGML_ASSERT(!src2 || src2->type == GGML_TYPE_F32);

        if (ctx->num_additional_fused_ops) {
            uint32_t idx = (uint32_t)ceilf(log2f(float(dst->ne[0])));
            GGML_ASSERT(idx < num_topk_moe_pipelines);
            bool with_norm = ctx->num_additional_fused_ops == topk_moe_norm.size() - 1;
            return ctx->device->pipeline_topk_moe[idx][with_norm];
        }

        if (src0->type == GGML_TYPE_F32 && (src1 == nullptr || src1->type == GGML_TYPE_F32) && dst->type == GGML_TYPE_F32) {
            return src0->ne[0] > 1024 ? ctx->device->pipeline_soft_max_f32_wg512 : ctx->device->pipeline_soft_max_f32;
        }
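Pipeline selection rounds the expert count up to a power of two: idx = ceil(log2(ne[0])) picks the specialization whose 1<<idx spec constant covers the row, and num_topk_moe_pipelines = 10 therefore caps support at 512 experts. A C++ check of that mapping (sketch):

#include <cassert>
#include <cmath>
#include <initializer_list>

int main() {
    const unsigned num_topk_moe_pipelines = 10;
    for (unsigned n_experts : {1u, 2u, 8u, 60u, 512u}) {
        const unsigned idx = (unsigned) std::ceil(std::log2((double) n_experts));
        assert(idx < num_topk_moe_pipelines);
        assert((1u << idx) >= n_experts); // the specialized width covers the row
    }
}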
@@ -8098,6 +8147,21 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
            return ctx->device->pipeline_rwkv_wkv7_f32;
        }
        return nullptr;
    case GGML_OP_SSM_SCAN:
        if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
            const uint32_t d_state = src0->ne[0];
            if (d_state == 128) {
                return ctx->device->pipeline_ssm_scan_f32_d128;
            } else if (d_state == 256) {
                return ctx->device->pipeline_ssm_scan_f32_d256;
            }
        }
        return nullptr;
    case GGML_OP_SSM_CONV:
        if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
            return ctx->device->pipeline_ssm_conv_f32;
        }
        return nullptr;
    case GGML_OP_OPT_STEP_ADAMW:
        if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
            return ctx->device->pipeline_opt_step_adamw_f32;

@@ -8592,6 +8656,14 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co
            }
        }
        break;
    case GGML_OP_SSM_CONV:
        {
            const uint32_t nr = src0->ne[1];
            const uint32_t n_t = dst->ne[1];
            const uint32_t n_s = dst->ne[2];
            elements = { nr, n_t, n_s };
        }
        break;
    default:
        elements = { (uint32_t)ggml_nelements(src0), 1, 1 };
        break;
@@ -9038,6 +9110,117 @@ static void ggml_vk_rwkv_wkv7(ggml_backend_vk_context * ctx, vk_context& subctx,
    );
}

static void ggml_vk_ssm_scan(ggml_backend_vk_context * ctx, vk_context& subctx, ggml_tensor * dst, bool dryrun = false) {
    const ggml_tensor * src0 = dst->src[0];
    const ggml_tensor * src1 = dst->src[1];
    const ggml_tensor * src2 = dst->src[2];
    const ggml_tensor * src3 = dst->src[3];
    const ggml_tensor * src4 = dst->src[4];
    const ggml_tensor * src5 = dst->src[5];

    GGML_ASSERT(dst->buffer != nullptr);

    const uint32_t head_dim = src0->ne[1];
    const uint32_t n_head = src1->ne[1];
    const uint32_t n_group = src4->ne[1];
    const uint32_t n_tok = src1->ne[2];
    const uint32_t n_seq = src1->ne[3];

    bool is_mamba2 = (src3->nb[1] == sizeof(float));
    GGML_ASSERT(is_mamba2);

    vk_pipeline pipeline = ggml_vk_op_get_pipeline(ctx, src0, src1, src2, dst, dst->op);
    GGML_ASSERT(pipeline != nullptr);

    if (dryrun) {
        ggml_pipeline_request_descriptor_sets(ctx, pipeline, 1);
        return;
    }

    const int64_t s_off = ggml_nelements(src1) * sizeof(float);

    const vk_op_ssm_scan_push_constants pc = {
        (uint32_t)src0->nb[2], (uint32_t)src0->nb[3],
        (uint32_t)src1->nb[2], (uint32_t)src1->nb[3],
        (uint32_t)src2->nb[1], (uint32_t)src2->nb[2],
        (uint32_t)src3->nb[1],
        (uint32_t)src4->nb[2], (uint32_t)src4->nb[3],
        (uint32_t)src5->nb[2], (uint32_t)src5->nb[3],
        (uint32_t)s_off,
        n_head, head_dim, n_group, n_tok
    };

    ggml_backend_vk_buffer_context * dst_buf_ctx = (ggml_backend_vk_buffer_context *)dst->buffer->context;
    ggml_backend_vk_buffer_context * src_buf_ctxs[GGML_MAX_SRC];
    for (int i = 0; i < GGML_MAX_SRC && dst->src[i] != nullptr; i++) {
        src_buf_ctxs[i] = (ggml_backend_vk_buffer_context *)dst->src[i]->buffer->context;
    }

    vk_buffer d_D = nullptr, d_srcs[GGML_MAX_SRC] = { nullptr };
    size_t dst_offset = 0, src_offsets[GGML_MAX_SRC] = { 0 };
    bool dst_uma = false, srcs_uma[GGML_MAX_SRC] = { false };

    if (ctx->device->uma) {
        for (int i = 0; i < GGML_MAX_SRC && dst->src[i] != nullptr; i++) {
            ggml_vk_host_get(ctx->device, dst->src[i]->data, d_srcs[i], src_offsets[i]);
            srcs_uma[i] = d_srcs[i] != nullptr;
        }
        ggml_vk_host_get(ctx->device, dst->data, d_D, dst_offset);
        dst_uma = d_D != nullptr;
    }

    if (!dst_uma) {
        d_D = dst_buf_ctx->dev_buffer;
        dst_offset = vk_tensor_offset(dst) + dst->view_offs;
    }
    for (int i = 0; i < GGML_MAX_SRC && dst->src[i] != nullptr; i++) {
        if (!srcs_uma[i]) {
            d_srcs[i] = src_buf_ctxs[i]->dev_buffer;
            src_offsets[i] = vk_tensor_offset(dst->src[i]) + dst->src[i]->view_offs;
        }
    }

    size_t dst_size = ggml_nbytes(dst);
    size_t src_sizes[GGML_MAX_SRC];
    for (int i = 0; i < GGML_MAX_SRC && dst->src[i] != nullptr; i++) {
        src_sizes[i] = ggml_nbytes(dst->src[i]);
    }

    std::array<uint32_t, 3> elements;

    const int splitH = 16;
    const uint32_t num_workgroups_x = CEIL_DIV(n_head * head_dim, splitH);
    const uint32_t num_workgroups_y = n_seq;
    elements = { num_workgroups_x, num_workgroups_y, 1 };

    ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, {
        vk_subbuffer{ d_srcs[0], src_offsets[0], src_sizes[0] },
        vk_subbuffer{ d_srcs[1], src_offsets[1], src_sizes[1] },
        vk_subbuffer{ d_srcs[2], src_offsets[2], src_sizes[2] },
        vk_subbuffer{ d_srcs[3], src_offsets[3], src_sizes[3] },
        vk_subbuffer{ d_srcs[4], src_offsets[4], src_sizes[4] },
        vk_subbuffer{ d_srcs[5], src_offsets[5], src_sizes[5] },
        vk_subbuffer{ d_srcs[6], src_offsets[6], src_sizes[6] },
        vk_subbuffer{ d_D, dst_offset, dst_size }
    }, pc, elements);
}

static void ggml_vk_ssm_conv(ggml_backend_vk_context * ctx, vk_context& subctx, ggml_tensor * dst, bool dryrun = false) {
    const ggml_tensor * src0 = dst->src[0];
    const ggml_tensor * src1 = dst->src[1];

    ggml_vk_op_f32<vk_op_ssm_conv_push_constants>(ctx, subctx, src0, src1, nullptr, dst, GGML_OP_SSM_CONV, {
        (uint32_t)src0->nb[1], (uint32_t)src0->nb[2],
        (uint32_t)src1->nb[1],
        (uint32_t)dst->nb[0], (uint32_t)dst->nb[1], (uint32_t)dst->nb[2],
        (uint32_t)src1->ne[0],
        (uint32_t)src0->ne[0],
        (uint32_t)src0->ne[1],
        (uint32_t)dst->ne[1],
        (uint32_t)dst->ne[2],
    }, dryrun);
}

static void ggml_vk_op_f32_opt_step_adamw(ggml_backend_vk_context * ctx, vk_context& subctx, ggml_tensor * dst, const vk_op_push_constants&& pc, bool dryrun = false) {
    const ggml_tensor * x = dst->src[0];
    const ggml_tensor * g = dst->src[1];
@ -9434,6 +9617,87 @@ static void ggml_vk_soft_max_back(ggml_backend_vk_context * ctx, vk_context& sub
|
|||
ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, src1, nullptr, dst, GGML_OP_SOFT_MAX_BACK, { (uint32_t)src0->ne[0], (uint32_t)ggml_nrows(src0), op_params[0], op_params[1] }, dryrun);
|
||||
}
|
||||
|
||||
static void ggml_vk_topk_moe(ggml_backend_vk_context * ctx, vk_context& subctx, ggml_cgraph * cgraph, int node_idx, bool dryrun = false) {
|
||||
|
||||
bool with_norm = ctx->num_additional_fused_ops == topk_moe_norm.size() - 1;
|
||||
ggml_tensor * logits = cgraph->nodes[node_idx + 0]->src[0];
|
||||
ggml_tensor * weights = with_norm ? cgraph->nodes[node_idx + 8] : cgraph->nodes[node_idx + 4];
|
||||
ggml_tensor * ids = cgraph->nodes[node_idx + 3];
|
||||
|
||||
GGML_ASSERT(logits->type == GGML_TYPE_F32);
|
||||
GGML_ASSERT(weights->type == GGML_TYPE_F32);
|
||||
GGML_ASSERT(ids->type == GGML_TYPE_I32);
|
||||
|
||||
const int n_experts = logits->ne[0];
|
||||
const int n_rows = logits->ne[1];
|
||||
const int n_expert_used = weights->ne[1];
|
||||
|
||||
GGML_ASSERT(ids->nb[1] / ggml_type_size(ids->type) == (size_t) n_experts);
|
||||
|
||||
vk_pipeline pipeline = ggml_vk_op_get_pipeline(ctx, nullptr, nullptr, nullptr, cgraph->nodes[node_idx], GGML_OP_SOFT_MAX);
|
||||
|
||||
if (dryrun) {
|
||||
ggml_pipeline_request_descriptor_sets(ctx, pipeline, 1);
|
||||
return;
|
||||
}
|
||||
|
||||
ggml_backend_vk_buffer_context * logits_buf_ctx = (ggml_backend_vk_buffer_context *)logits->buffer->context;
|
||||
ggml_backend_vk_buffer_context * weights_buf_ctx = (ggml_backend_vk_buffer_context *)weights->buffer->context;
|
||||
ggml_backend_vk_buffer_context * ids_buf_ctx = (ggml_backend_vk_buffer_context *)ids->buffer->context;
|
||||
|
||||
vk_buffer d_logits = nullptr;
|
||||
size_t logits_buf_offset = 0;
|
||||
vk_buffer d_weights = nullptr;
|
||||
size_t weights_buf_offset = 0;
|
||||
vk_buffer d_ids = nullptr;
|
||||
size_t ids_buf_offset = 0;
|
||||
|
||||
bool logits_uma = false;
|
||||
bool weights_uma = false;
|
||||
bool ids_uma = false;
|
||||
|
||||
if (ctx->device->uma) {
|
||||
ggml_vk_host_get(ctx->device, logits->data, d_logits, logits_buf_offset);
|
||||
ggml_vk_host_get(ctx->device, weights->data, d_weights, weights_buf_offset);
|
||||
ggml_vk_host_get(ctx->device, ids->data, d_ids, ids_buf_offset);
|
||||
logits_uma = d_logits != nullptr;
|
||||
weights_uma = d_weights != nullptr;
|
||||
ids_uma = d_ids != nullptr;
|
||||
}
|
||||
|
||||
if (!logits_uma) {
|
||||
d_logits = logits_buf_ctx->dev_buffer;
|
||||
logits_buf_offset = vk_tensor_offset(logits) + logits->view_offs;
|
||||
GGML_ASSERT(d_logits != nullptr);
|
||||
}
|
||||
if (!weights_uma) {
|
||||
d_weights = weights_buf_ctx->dev_buffer;
|
||||
weights_buf_offset = vk_tensor_offset(weights) + weights->view_offs;
|
||||
GGML_ASSERT(d_weights != nullptr);
|
||||
}
|
||||
if (!ids_uma) {
|
||||
d_ids = ids_buf_ctx->dev_buffer;
|
||||
ids_buf_offset = vk_tensor_offset(ids) + ids->view_offs;
|
||||
GGML_ASSERT(d_ids != nullptr);
|
||||
}
|
||||
|
||||
vk_op_topk_moe_push_constants pc;
|
||||
pc.n_rows = n_rows;
|
||||
pc.n_expert_used = n_expert_used;
|
||||
|
||||
GGML_ASSERT(n_expert_used <= n_experts);
|
||||
|
||||
const uint32_t rows_per_block = 4;
|
||||
std::array<uint32_t, 3> elements = { CEIL_DIV(n_rows, rows_per_block), 1, 1 };
|
||||
|
||||
ggml_vk_dispatch_pipeline(ctx, subctx, pipeline,
|
||||
{
|
||||
ggml_vk_subbuffer(ctx, d_logits, logits_buf_offset),
|
||||
ggml_vk_subbuffer(ctx, d_weights, weights_buf_offset),
|
||||
ggml_vk_subbuffer(ctx, d_ids, ids_buf_offset),
|
||||
}, pc, elements);
|
||||
}
|
||||
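
// topk_moe and topk_moe_norm themselves are not shown in this diff. From the node-index
// arithmetic above (weights at node_idx + 4 vs node_idx + 8) and the per-node op checks
// in ggml_vk_can_fuse_topk_moe below, they presumably describe these fused patterns:
//
//   topk_moe      = { GGML_OP_SOFT_MAX, GGML_OP_RESHAPE, GGML_OP_ARGSORT, GGML_OP_VIEW,
//                     GGML_OP_GET_ROWS };
//   topk_moe_norm = { GGML_OP_SOFT_MAX, GGML_OP_RESHAPE, GGML_OP_ARGSORT, GGML_OP_VIEW,
//                     GGML_OP_GET_ROWS, GGML_OP_RESHAPE, GGML_OP_SUM_ROWS, GGML_OP_DIV,
//                     GGML_OP_RESHAPE };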

static void ggml_vk_rope(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * src2, ggml_tensor * dst, bool backprop, bool dryrun = false) {
    const int n_dims = ((int32_t *) dst->op_params)[1];
    const int mode = ((int32_t *) dst->op_params)[2];

@@ -10870,6 +11134,8 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgraph
    case GGML_OP_CONV_2D_DW:
    case GGML_OP_RWKV_WKV6:
    case GGML_OP_RWKV_WKV7:
    case GGML_OP_SSM_SCAN:
    case GGML_OP_SSM_CONV:
    case GGML_OP_LEAKY_RELU:
    case GGML_OP_FLASH_ATTN_EXT:
    case GGML_OP_OPT_STEP_ADAMW:

@@ -11017,11 +11283,11 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgraph
        ctx->unsynced_nodes_read.clear();
        ggml_vk_sync_buffers(ctx, compute_ctx);
    }
    // Add the last fused node and all fused source nodes to the unsynchronized list.
    const ggml_tensor * last_node = cgraph->nodes[node_idx + ctx->num_additional_fused_ops];
    ctx->unsynced_nodes_written.push_back(last_node);
    // Add all fused nodes to the unsynchronized lists.
    for (int32_t i = 0; i < ctx->num_additional_fused_ops + 1; ++i) {
        const ggml_tensor *cur_node = cgraph->nodes[node_idx + i];
        // Multiple outputs could be written, e.g. in topk_moe. Add them all to the list.
        ctx->unsynced_nodes_written.push_back(cur_node);
        for (uint32_t j = 0; j < GGML_MAX_SRC; ++j) {
            if (!cur_node->src[j]) {
                continue;

@@ -11188,7 +11454,11 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgraph

        break;
    case GGML_OP_SOFT_MAX:
        ggml_vk_soft_max(ctx, compute_ctx, src0, src1, src2, node, dryrun);
        if (ctx->num_additional_fused_ops) {
            ggml_vk_topk_moe(ctx, compute_ctx, cgraph, node_idx, dryrun);
        } else {
            ggml_vk_soft_max(ctx, compute_ctx, src0, src1, src2, node, dryrun);
        }

        break;
    case GGML_OP_SOFT_MAX_BACK:

@@ -11287,6 +11557,16 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgraph

        break;

    case GGML_OP_SSM_SCAN:
        ggml_vk_ssm_scan(ctx, compute_ctx, node, dryrun);

        break;

    case GGML_OP_SSM_CONV:
        ggml_vk_ssm_conv(ctx, compute_ctx, node, dryrun);

        break;

    case GGML_OP_OPT_STEP_ADAMW:
        ggml_vk_opt_step_adamw(ctx, compute_ctx, node, dryrun);

@@ -11398,6 +11678,8 @@ static bool ggml_vk_compute_forward(ggml_backend_vk_context * ctx, ggml_cgraph * cgraph
    case GGML_OP_CONV_2D_DW:
    case GGML_OP_RWKV_WKV6:
    case GGML_OP_RWKV_WKV7:
    case GGML_OP_SSM_SCAN:
    case GGML_OP_SSM_CONV:
    case GGML_OP_LEAKY_RELU:
    case GGML_OP_REPEAT:
    case GGML_OP_REPEAT_BACK:

@@ -11972,6 +12254,120 @@ static bool ggml_vk_can_fuse(const struct ggml_cgraph * cgraph, int node_idx, std::initializer_list<enum ggml_op> ops)
    return true;
}

static bool ggml_vk_can_fuse_topk_moe(ggml_backend_vk_context * ctx, const struct ggml_cgraph * cgraph,
                                      int node_idx, bool with_norm) {

    if (with_norm) {
        if (node_idx + (int)topk_moe_norm.size() > cgraph->n_nodes) {
            return false;
        }
        for (size_t i = 0; i < topk_moe_norm.size(); ++i) {
            if (cgraph->nodes[node_idx + i]->op != topk_moe_norm[i]) {
                return false;
            }
        }
    } else {
        if (node_idx + (int)topk_moe.size() > cgraph->n_nodes) {
            return false;
        }
        for (size_t i = 0; i < topk_moe.size(); ++i) {
            if (cgraph->nodes[node_idx + i]->op != topk_moe[i]) {
                return false;
            }
        }
    }

    const ggml_tensor * softmax = cgraph->nodes[node_idx + 0];
    const ggml_tensor * weights = with_norm ? cgraph->nodes[node_idx + 8] : cgraph->nodes[node_idx + 4];

    const float * op_params = (const float *)softmax->op_params;

    float scale = op_params[0];
    float max_bias = op_params[1];

    if (!ggml_is_contiguous(softmax->src[0]) || !ggml_is_contiguous(weights)) {
        return false;
    }

    if (scale != 1.0f || max_bias != 0.0f) {
        return false;
    }

    // don't fuse when masks or sinks are present
    if (softmax->src[1] || softmax->src[2]) {
        return false;
    }

    const int n_expert = softmax->ne[0];
    // n_expert must be a power of 2
    if (!is_pow2(n_expert) || n_expert > (1 << (num_topk_moe_pipelines-1))) {
        return false;
    }

    // Check that the nodes don't have any unexpected uses
    const ggml_tensor * reshape1 = cgraph->nodes[node_idx + 1];
    const ggml_tensor * argsort = cgraph->nodes[node_idx + 2];
    const ggml_tensor * view = cgraph->nodes[node_idx + 3];
    const ggml_tensor * get_rows = cgraph->nodes[node_idx + 4];
    const ggml_tensor * reshape5 = with_norm ? cgraph->nodes[node_idx + 5] : nullptr;
    const ggml_tensor * sum_rows = with_norm ? cgraph->nodes[node_idx + 6] : nullptr;
    const ggml_tensor * div = with_norm ? cgraph->nodes[node_idx + 7] : nullptr;
    const ggml_tensor * reshape8 = with_norm ? cgraph->nodes[node_idx + 8] : nullptr;

    // softmax is used by reshape and argsort
    if (ggml_node_get_use_count(cgraph, node_idx) != 2 ||
        reshape1->src[0] != softmax ||
        argsort->src[0] != softmax) {
        return false;
    }
    // reshape is used by get_rows
    if (ggml_node_get_use_count(cgraph, node_idx + 1) != 1 ||
        get_rows->src[0] != reshape1) {
        return false;
    }
    // argsort is used by view
    if (ggml_node_get_use_count(cgraph, node_idx + 2) != 1 ||
        view->src[0] != argsort) {
        return false;
    }
    // view is written (via argsort), we can skip checking it

    if (with_norm) {
        // get_rows is used by reshape
        if (ggml_node_get_use_count(cgraph, node_idx + 4) != 1 ||
            reshape5->src[0] != get_rows) {
            return false;
        }

        // reshape is used by sum_rows and div
        if (ggml_node_get_use_count(cgraph, node_idx + 5) != 2 ||
            sum_rows->src[0] != reshape5 ||
            div->src[0] != reshape5) {
            return false;
        }

        // sum_rows is used by div
        if (ggml_node_get_use_count(cgraph, node_idx + 6) != 1 ||
            div->src[1] != sum_rows) {
            return false;
        }

        // div/reshape are written
        if (reshape8->src[0] != div) {
            return false;
        }
    }

    if (!ctx->device->subgroup_arithmetic ||
        !ctx->device->subgroup_shuffle ||
        !ctx->device->subgroup_require_full_support ||
        ctx->device->disable_fusion) {
        return false;
    }

    return true;
}

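// The subgroup_arithmetic / subgroup_shuffle requirements above line up with the GLSL
// extensions that the new topk_moe.comp shader (later in this diff) enables:
// GL_KHR_shader_subgroup_arithmetic (subgroupAdd/subgroupMax) and
// GL_KHR_shader_subgroup_shuffle (subgroupShuffleXor for the argmax reduction).
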
static uint32_t ggml_vk_fuse_multi_add(ggml_backend_vk_context * ctx, const struct ggml_cgraph * cgraph, int node_idx) {

    const ggml_tensor *first_node = cgraph->nodes[node_idx];

@@ -12047,6 +12443,10 @@ static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cgraph * cgraph)
                ctx->num_additional_fused_ops = num_adds - 1;
            } else if (ggml_vk_can_fuse(cgraph, i, { GGML_OP_RMS_NORM, GGML_OP_MUL })) {
                ctx->num_additional_fused_ops = 1;
            } else if (ggml_vk_can_fuse_topk_moe(ctx, cgraph, i, true)) {
                ctx->num_additional_fused_ops = topk_moe_norm.size() - 1;
            } else if (ggml_vk_can_fuse_topk_moe(ctx, cgraph, i, false)) {
                ctx->num_additional_fused_ops = topk_moe.size() - 1;
            }
        }
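        // Note: the else-if chain tries larger fusions first; in particular the 9-node
        // topk_moe_norm pattern is tested before the 5-node topk_moe pattern, so the
        // normalized variant wins whenever both would match at the same node.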
        ggml_vk_build_graph(ctx, cgraph, i, nullptr, 0, true, false, false, false);

@@ -12144,6 +12544,10 @@ static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cgraph * cgraph)
                ctx->num_additional_fused_ops = num_adds - 1;
            } else if (ggml_vk_can_fuse(cgraph, i, { GGML_OP_RMS_NORM, GGML_OP_MUL })) {
                ctx->num_additional_fused_ops = 1;
            } else if (ggml_vk_can_fuse_topk_moe(ctx, cgraph, i, true)) {
                ctx->num_additional_fused_ops = topk_moe_norm.size() - 1;
            } else if (ggml_vk_can_fuse_topk_moe(ctx, cgraph, i, false)) {
                ctx->num_additional_fused_ops = topk_moe.size() - 1;
            }
        }

@@ -12151,10 +12555,10 @@ static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cgraph * cgraph)
        bool almost_ready = (cgraph->n_nodes - i) < cgraph->n_nodes / 5;
        bool submit = (submitted_nodes >= nodes_per_submit) ||
                      (mul_mat_bytes >= mul_mat_bytes_per_submit) ||
                      (i + ctx->num_additional_fused_ops == last_node) ||
                      (i + ctx->num_additional_fused_ops >= last_node) ||
                      (almost_ready && !ctx->almost_ready_fence_pending);

        bool enqueued = ggml_vk_build_graph(ctx, cgraph, i, cgraph->nodes[submit_node_idx], submit_node_idx, false, i + ctx->num_additional_fused_ops == last_node, almost_ready, submit);
        bool enqueued = ggml_vk_build_graph(ctx, cgraph, i, cgraph->nodes[submit_node_idx], submit_node_idx, false, i + ctx->num_additional_fused_ops >= last_node, almost_ready, submit);

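        // Note on the == -> >= change above: with fused ops, i advances by
        // num_additional_fused_ops + 1 per iteration and can step past last_node in one
        // go, so >= keeps the final submit from being skipped.
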
        if (vk_perf_logger_enabled) {
            if (ctx->compute_ctx.expired()) {

@@ -12275,6 +12679,25 @@ static void ggml_vk_graph_optimize(ggml_backend_t backend, struct ggml_cgraph * graph)
    while (first_unused < graph->n_nodes) {
        std::vector<int> current_set;

        // Avoid reordering topk_moe_norm
        if (first_unused + (int)topk_moe_norm.size() <= graph->n_nodes) {
            bool is_topk_moe_norm = true;
            for (size_t j = 0; j < topk_moe_norm.size(); ++j) {
                if (graph->nodes[first_unused + j]->op != topk_moe_norm[j] || used[first_unused + j]) {
                    is_topk_moe_norm = false;
                }
            }
            if (is_topk_moe_norm) {
                for (size_t j = 0; j < topk_moe_norm.size(); ++j) {
                    new_order.push_back(graph->nodes[first_unused + j]);
                    used[first_unused + j] = true;
                }
                while (first_unused < graph->n_nodes && used[first_unused]) {
                    first_unused++;
                }
                continue;
            }
        }
        // First, grab the next unused node.
        current_set.push_back(first_unused);

@@ -12879,6 +13302,47 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggml_tensor * op)
        case GGML_OP_RWKV_WKV6:
        case GGML_OP_RWKV_WKV7:
            return true;
        case GGML_OP_SSM_SCAN:
            {
                for (int i = 0; i < 6; i++) {
                    if (op->src[i] && ggml_is_quantized(op->src[i]->type)) {
                        return false;
                    }
                }
                if (op->src[6] && op->src[6]->type != GGML_TYPE_I32) {
                    return false;
                }
                if (op->src[0]->type != GGML_TYPE_F32 || op->type != GGML_TYPE_F32) {
                    return false;
                }

                const uint32_t d_state = op->src[0]->ne[0];
                const uint32_t head_dim = op->src[0]->ne[1];

                bool is_mamba2 = (op->src[3] && op->src[3]->nb[1] == sizeof(float));
                if (!is_mamba2) {
                    return false;
                }

                if ((d_state != 128 && d_state != 256) || head_dim % 16 != 0) {
                    return false;
                }

                ggml_backend_vk_device_context * ctx = (ggml_backend_vk_device_context *)dev->context;
                const vk_device& device = ggml_vk_get_device(ctx->device);

                const uint32_t SPLIT_H = 16;

                size_t stateC_size = SPLIT_H * d_state * sizeof(float);

                if (stateC_size > device->properties.limits.maxComputeSharedMemorySize) {
                    return false;
                }

                return true;
            }
        case GGML_OP_SSM_CONV:
            return true;
        case GGML_OP_CONV_TRANSPOSE_1D:
            return op->src[0]->type == GGML_TYPE_F32 && op->src[1]->type == GGML_TYPE_F32;
        case GGML_OP_CONV_2D:

@@ -13223,14 +13687,14 @@ static void ggml_vk_check_results_0(ggml_backend_vk_context * ctx, ggml_cgraph * cgraph
    struct ggml_context * ggml_ctx = ggml_init(iparams);

    std::array<struct ggml_tensor *, 6> src_clone = {nullptr, nullptr, nullptr, nullptr, nullptr, nullptr};
    std::array<size_t, 6> src_size = {0, 0, 0, 0, 0, 0};
    std::array<void *, 6> src_buffer = {nullptr, nullptr, nullptr, nullptr, nullptr, nullptr};
    const char * srci_name[6] = {"src0", "src1", "src2", "src3", "src4", "src5"};
    std::array<struct ggml_tensor *, GGML_MAX_SRC> src_clone = {nullptr, nullptr, nullptr, nullptr, nullptr, nullptr, nullptr, nullptr, nullptr, nullptr};
    std::array<size_t, GGML_MAX_SRC> src_size = {};
    std::array<void *, GGML_MAX_SRC> src_buffer = {};
    const char * srci_name[GGML_MAX_SRC] = {"src0", "src1", "src2", "src3", "src4", "src5", "src6", "src7", "src8", "src9"};

    struct ggml_tensor * tensor_clone = nullptr;

    for (int i = 0; i < 6; i++) {
    for (int i = 0; i < GGML_MAX_SRC; i++) {
        ggml_tensor * srci = tensor->src[i];
        if (fused_rms_norm_mul) {
            rms_norm_idx = tensor->src[0]->op == GGML_OP_RMS_NORM ? 0 : 1;

@@ -13537,6 +14001,11 @@ static void ggml_vk_check_results_0(ggml_backend_vk_context * ctx, ggml_cgraph * cgraph
            src_clone[2]);
    } else if (tensor->op == GGML_OP_ADD_ID) {
        tensor_clone = ggml_add_id(ggml_ctx, src_clone[0], src_clone[1], src_clone[2]);
    } else if (tensor->op == GGML_OP_SSM_SCAN) {
        tensor_clone = ggml_ssm_scan(ggml_ctx, src_clone[0], src_clone[1], src_clone[2],
                                     src_clone[3], src_clone[4], src_clone[5], src_clone[6]);
    } else if (tensor->op == GGML_OP_SSM_CONV) {
        tensor_clone = ggml_ssm_conv(ggml_ctx, src_clone[0], src_clone[1]);
    }
    else {
        std::cerr << "Missing vk_check_results OP: " << ggml_op_name(tensor->op) << std::endl;

@@ -13558,7 +14027,7 @@ static void ggml_vk_check_results_0(ggml_backend_vk_context * ctx, ggml_cgraph * cgraph
    memcpy(comp_result, tensor_clone->data, comp_size);
    memcpy(comp_nb, tensor_clone->nb, sizeof(size_t) * GGML_MAX_DIMS);

    for (int i = 0; i < 6; i++) {
    for (int i = 0; i < GGML_MAX_SRC; i++) {
        if (src_buffer[i] != nullptr) {
            free(src_buffer[i]);
        }

@@ -0,0 +1,44 @@
#version 450

#extension GL_EXT_control_flow_attributes : require

#include "types.glsl"

layout(constant_id = 0) const uint BLOCK_SIZE = 32;

layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;

layout(binding = 0) readonly buffer Src0 { float src0[]; };
layout(binding = 1) readonly buffer Src1 { float src1[]; };
layout(binding = 2) buffer Dst { float dst[]; };

layout(push_constant) uniform PushConstants {
    uint nb01; uint nb02;
    uint nb11;
    uint dst_nb0; uint dst_nb1; uint dst_nb2;
    uint nc; uint ncs; uint nr; uint n_t; uint n_s;
};

void main() {
    const uint global_thread_id = gl_GlobalInvocationID.x;
    const uint i2 = gl_WorkGroupID.y;
    const uint i3 = gl_WorkGroupID.z;

    if (global_thread_id >= nr || i2 >= n_t || i3 >= n_s) {
        return;
    }

    const uint i1 = global_thread_id;
    const uint src0_base = i3 * (nb02 / 4) + i2 + i1 * (nb01 / 4);
    const uint src1_base = i1 * (nb11 / 4);
    const uint dst_idx = i3 * (dst_nb2 / 4) + i2 * (dst_nb1 / 4) + i1;

    float sum = 0.0;
    [[unroll]] for (uint i0 = 0; i0 < nc; i0++) {
        const uint src0_idx = src0_base + i0;
        const uint src1_idx = src1_base + i0;
        sum += src0[src0_idx] * src1[src1_idx];
    }

    dst[dst_idx] = sum;
}
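// Note: the recurring "/ 4" on the nb* push constants converts ggml byte strides into
// float-element indices for the shader's float[] buffer views (sizeof(float) == 4).
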
@@ -0,0 +1,125 @@
#version 450

#extension GL_EXT_control_flow_attributes : require

#include "types.glsl"

layout(constant_id = 0) const uint D_STATE = 128;
layout(constant_id = 1) const uint SUBGROUP_SIZE = 32;
layout(constant_id = 2) const uint SPLIT_H = 16;

layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;

layout(binding = 0) readonly buffer Src0 { float s0[]; };
layout(binding = 1) readonly buffer Src1 { float x[]; };
layout(binding = 2) readonly buffer Src2 { float dt[]; };
layout(binding = 3) readonly buffer Src3 { float A[]; };
layout(binding = 4) readonly buffer Src4 { float B[]; };
layout(binding = 5) readonly buffer Src5 { float C[]; };
layout(binding = 6) readonly buffer Src6 { int ids[]; };
layout(binding = 7) buffer Dst { float d[]; };

layout(push_constant) uniform PushConstants {
    uint nb02; uint nb03; uint nb12; uint nb13;
    uint nb21; uint nb22; uint nb31;
    uint nb42; uint nb43; uint nb52; uint nb53;
    uint s_off;
    uint n_head;
    uint d_head;
    uint n_group;
    uint n_tok;
};

float softplus(float x) {
    if (x <= 20.0) {
        return log(1.0 + exp(x));
    } else {
        return x;
    }
}

shared float stateC[SPLIT_H * D_STATE];

void main() {
    const uint tid = gl_LocalInvocationID.x;
    const uint head_idx = (gl_WorkGroupID.x * SPLIT_H) / d_head;
    const uint head_off = ((gl_WorkGroupID.x * SPLIT_H) % d_head) * 4;
    const uint seq_idx = gl_WorkGroupID.y;

    const uint group_off = (head_idx / (n_head / n_group)) * D_STATE * 4;
    const uint s0_base_idx = (uint(ids[seq_idx]) * nb03 + head_idx * nb02 + head_off * D_STATE) / 4;
    const uint x_base_idx = (seq_idx * nb13 + gl_WorkGroupID.x * SPLIT_H * 4) / 4;
    const uint dt_base_idx = (seq_idx * nb22 + head_idx * 4) / 4;
    const uint A_base_idx = (head_idx * nb31) / 4;
    const uint B_base_idx = (seq_idx * nb43 + group_off) / 4;
    const uint C_base_idx = (seq_idx * nb53 + group_off) / 4;
    const uint y_base_idx = seq_idx * n_tok * n_head * d_head + gl_WorkGroupID.x * SPLIT_H;
    const uint s_base_idx = (s_off + seq_idx * nb03 + head_idx * nb02 + head_off * D_STATE) / 4;

    const uint stride_x = nb12 / 4;
    const uint stride_dt = nb21 / 4;
    const uint stride_B = nb42 / 4;
    const uint stride_C = nb52 / 4;
    const uint stride_y = n_head * d_head;

    float state[SPLIT_H];
    [[unroll]] for (uint j = 0; j < SPLIT_H; j++) {
        state[j] = s0[s0_base_idx + j * D_STATE + tid];
    }

    for (uint i = 0; i < n_tok; i++) {
        const float dt_soft_plus = softplus(dt[dt_base_idx + i * stride_dt]);

        const float dA = exp(dt_soft_plus * A[A_base_idx]);

        const float B_val = B[B_base_idx + i * stride_B + tid];
        const float C_val = C[C_base_idx + i * stride_C + tid];

        [[unroll]] for (uint j = 0; j < SPLIT_H; j++) {
            const float x_dt = x[x_base_idx + i * stride_x + j] * dt_soft_plus;

            state[j] = (state[j] * dA) + (B_val * x_dt);

            stateC[j * D_STATE + tid] = state[j] * C_val;
        }

        barrier();
        for (uint w = D_STATE; w > SUBGROUP_SIZE; w >>= 1) {
            [[unroll]] for (uint j = 0; j < ((w >> 1) * SPLIT_H + D_STATE - 1) / D_STATE; j++) {
                const uint k = (tid % (w >> 1)) +
                               (D_STATE * (tid / (w >> 1))) +
                               j * D_STATE * (D_STATE / (w >> 1));
                if (k < SPLIT_H * D_STATE && (k + (w >> 1)) < SPLIT_H * D_STATE) {
                    stateC[k] += stateC[k + (w >> 1)];
                }
            }
            barrier();
        }

        [[unroll]] for (uint j = 0; j <= SPLIT_H / (D_STATE / SUBGROUP_SIZE); j++) {
            const uint idx = (tid % SUBGROUP_SIZE) +
                             D_STATE * (tid / SUBGROUP_SIZE) +
                             j * D_STATE * (D_STATE / SUBGROUP_SIZE);

            uint lane = tid % SUBGROUP_SIZE;

            [[unroll]] for (uint offset = SUBGROUP_SIZE / 2; offset > 0; offset >>= 1) {
                if (idx + offset < SPLIT_H * D_STATE) {
                    stateC[idx] += stateC[idx + offset];
                }
                barrier();
            }

            if (idx < SPLIT_H * D_STATE && tid % SUBGROUP_SIZE == 0) {
                const uint k = tid / SUBGROUP_SIZE + j * (D_STATE / SUBGROUP_SIZE);
                d[y_base_idx + i * stride_y + k] = stateC[idx];
            }
        }

        barrier();
    }

    [[unroll]] for (uint j = 0; j < SPLIT_H; j++) {
        d[s_base_idx + j * D_STATE + tid] = state[j];
    }
}
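// A host-side analogue of the reduction above, for intuition only (the shader splits it
// into a shared-memory tree phase down to the subgroup width, then finishes per
// subgroup). Assumes a power-of-two n; the helper name is illustrative.
static float reduce_sum(float * v, unsigned n) {
    for (unsigned w = n; w > 1; w >>= 1) {      // halve the active width each pass
        for (unsigned k = 0; k < (w >> 1); k++) {
            v[k] += v[k + (w >> 1)];            // fold the upper half into the lower half
        }
    }
    return v[0];
}
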
@@ -0,0 +1,139 @@
#version 450

#extension GL_EXT_control_flow_attributes : require
#extension GL_KHR_shader_subgroup_basic : enable
#extension GL_KHR_shader_subgroup_arithmetic : enable
#extension GL_KHR_shader_subgroup_shuffle : enable

#include "types.glsl"

layout (push_constant) uniform parameter
{
    uint n_rows;
    uint n_expert_used;
};

layout(local_size_x_id = 0, local_size_y = 4, local_size_z = 1) in;

layout(constant_id = 0) const uint WARP_SIZE = 32;
layout(constant_id = 1) const uint n_experts = 512;
layout(constant_id = 2) const bool with_norm = true;

const uint experts_per_thread = (n_experts > WARP_SIZE) ? n_experts / WARP_SIZE : 1;

layout (binding = 0, std430) readonly buffer Logits {float logits[];};
layout (binding = 1, std430) writeonly buffer Weights {float weights[];};
layout (binding = 2, std430) writeonly buffer Ids {uint ids[];};

void main() {
    const uint row = gl_WorkGroupID.x * gl_WorkGroupSize.y + gl_LocalInvocationID.y;
    if (row >= n_rows) {
        return;
    }

    const uint logits_offset = n_experts * row;
    const uint weights_offset = n_expert_used * row;
    const uint ids_offset = n_experts * row;

    float logits_r[experts_per_thread];

    const float INFINITY = 1.0 / 0.0;

    [[unroll]]
    for (uint i = 0; i < n_experts; i += WARP_SIZE) {
        const uint expert = i + gl_LocalInvocationID.x;
        logits_r[i / WARP_SIZE] = n_experts % WARP_SIZE == 0 || expert < n_experts ? logits[logits_offset + expert] : -INFINITY;
    }

    float max_val = logits_r[0];

    [[unroll]]
    for (int i = 1; i < experts_per_thread; i++) {
        const float val = logits_r[i];
        max_val = max(val, max_val);
    }

    max_val = subgroupMax(max_val);

    float wt[experts_per_thread];
    float tmp = 0.f;

    [[unroll]]
    for (int i = 0; i < experts_per_thread; i++) {
        const float val = logits_r[i];
        wt[i] = exp(val - max_val);
        tmp += wt[i];
    }

    tmp = subgroupAdd(tmp);

    const float inv_sum = 1.0f / tmp;

    [[unroll]]
    for (int i = 0; i < experts_per_thread; i++) {
        wt[i] = wt[i] * inv_sum;
    }

    // at this point, each thread holds a portion of softmax,
    // we do the argmax reduce over n_expert_used, each time marking
    // the expert weight as -inf to exclude from the next iteration

    float wt_sum = 0.f;

    float output_weights[experts_per_thread];

    for (int k = 0; k < n_expert_used; k++) {
        float max_val = wt[0];
        uint max_expert = gl_LocalInvocationID.x;

        [[unroll]]
        for (int i = 1; i < experts_per_thread; i++) {
            const uint expert = gl_LocalInvocationID.x + i * WARP_SIZE;
            if ((n_experts % WARP_SIZE == 0 || expert < n_experts) && wt[i] > max_val) {
                max_val = wt[i];
                max_expert = expert;
            }
        }

        [[unroll]]
        for (uint mask = WARP_SIZE / 2; mask > 0; mask /= 2) {
            const float val = subgroupShuffleXor(max_val, mask);
            const uint expert = subgroupShuffleXor(max_expert, mask);
            if (val > max_val || (val == max_val && expert < max_expert)) {
                max_val = val;
                max_expert = expert;
            }
        }

        if ((k & (WARP_SIZE - 1)) == gl_LocalInvocationID.x) {
            output_weights[k / WARP_SIZE] = max_val;
        }

        if ((max_expert & (WARP_SIZE - 1)) == gl_LocalInvocationID.x) {
            wt[max_expert / WARP_SIZE] = -INFINITY;

            ids[ids_offset + k] = max_expert;
            if (with_norm) {
                wt_sum += max_val;
            }
        }
    }

    if (with_norm) {
        wt_sum = subgroupAdd(wt_sum);
        const float inv_sum = 1.0f / wt_sum;

        [[unroll]]
        for (uint i = 0; i < experts_per_thread; ++i) {
            output_weights[i] *= inv_sum;
        }
    }

    [[unroll]]
    for (uint i = 0; i < experts_per_thread; ++i) {
        uint idx = i * WARP_SIZE + gl_LocalInvocationID.x;
        if (idx < n_expert_used) {
            weights[weights_offset + idx] = output_weights[i];
        }
    }
}
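// A scalar reference for the kernel above, for checking intuition: softmax over the
// logits, then n_expert_used rounds of argmax with -inf masking, optionally
// renormalizing the selected weights. Plain C++ for illustration; not part of the commit.
#include <algorithm>
#include <cmath>
#include <cstdint>
#include <vector>

static void topk_moe_reference(const float * logits, float * weights, uint32_t * ids,
                               int n_experts, int n_expert_used, bool with_norm) {
    std::vector<float> wt(logits, logits + n_experts);
    const float max_val = *std::max_element(wt.begin(), wt.end());
    float sum = 0.0f;
    for (auto & w : wt) { w = std::exp(w - max_val); sum += w; }
    for (auto & w : wt) { w /= sum; }                  // softmax

    float wt_sum = 0.0f;
    for (int k = 0; k < n_expert_used; k++) {
        int best = 0;                                  // ties resolve to the lowest index,
        for (int e = 1; e < n_experts; e++) {          // matching the shader's tie-break
            if (wt[e] > wt[best]) best = e;
        }
        weights[k] = wt[best];
        ids[k] = best;
        wt_sum += wt[best];
        wt[best] = -INFINITY;                          // exclude from the next round
    }
    if (with_norm) {
        for (int k = 0; k < n_expert_used; k++) weights[k] /= wt_sum;
    }
}
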
@@ -916,6 +916,12 @@ void process_shaders() {
    string_to_spv("multi_add_f32", "multi_add.comp", {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}, {"RTE16", "1"}, {"ADD_RMS" , "0"}});
    string_to_spv("multi_add_rms_f32", "multi_add.comp", {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}, {"RTE16", "1"}, {"ADD_RMS" , "1"}});

    string_to_spv("ssm_scan_f32", "ssm_scan.comp", {{"A_TYPE", "float"}});

    string_to_spv("ssm_conv_f32", "ssm_conv.comp", {{"A_TYPE", "float"}});

    string_to_spv("topk_moe_f32", "topk_moe.comp", {});

    for (auto &c : compiles) {
        c.wait();
    }

@@ -959,7 +965,7 @@ void write_output_files() {
    }

    std::string suffixes[2] = {"_f32", "_f16"};
    for (auto op : {"add", "sub", "mul", "div", "add_rms"}) {
    for (std::string op : {"add", "sub", "mul", "div", "add_rms"}) {
        hdr << "extern const void * " << op << "_data[2][2][2][2];\n";
        hdr << "extern const uint64_t " << op << "_len[2][2][2][2];\n";

@@ -102,6 +102,8 @@ class Keys:
        EXPERT_COUNT = "{arch}.expert_count"
        EXPERT_USED_COUNT = "{arch}.expert_used_count"
        EXPERT_SHARED_COUNT = "{arch}.expert_shared_count"
        EXPERT_GROUP_COUNT = "{arch}.expert_group_count"
        EXPERT_GROUP_USED_COUNT = "{arch}.expert_group_used_count"
        EXPERT_WEIGHTS_SCALE = "{arch}.expert_weights_scale"
        EXPERT_WEIGHTS_NORM = "{arch}.expert_weights_norm"
        EXPERT_GATING_FUNC = "{arch}.expert_gating_func"

@@ -400,6 +402,7 @@ class MODEL_ARCH(IntEnum):
    WAVTOKENIZER_DEC = auto()
    PLM = auto()
    BAILINGMOE = auto()
    BAILINGMOE2 = auto()
    DOTS1 = auto()
    ARCEE = auto()
    ERNIE4_5 = auto()

@@ -744,6 +747,7 @@ MODEL_ARCH_NAMES: dict[MODEL_ARCH, str] = {
    MODEL_ARCH.WAVTOKENIZER_DEC: "wavtokenizer-dec",
    MODEL_ARCH.PLM: "plm",
    MODEL_ARCH.BAILINGMOE: "bailingmoe",
    MODEL_ARCH.BAILINGMOE2: "bailingmoe2",
    MODEL_ARCH.DOTS1: "dots1",
    MODEL_ARCH.ARCEE: "arcee",
    MODEL_ARCH.ERNIE4_5: "ernie4_5",

@@ -2533,6 +2537,35 @@ MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = {
        MODEL_TENSOR.FFN_DOWN_SHEXP,
        MODEL_TENSOR.FFN_UP_SHEXP,
    ],
    MODEL_ARCH.BAILINGMOE2: [
        MODEL_TENSOR.TOKEN_EMBD,
        MODEL_TENSOR.OUTPUT_NORM,
        MODEL_TENSOR.OUTPUT,
        MODEL_TENSOR.ATTN_NORM,
        MODEL_TENSOR.ATTN_Q_NORM,
        MODEL_TENSOR.ATTN_K_NORM,
        MODEL_TENSOR.ATTN_QKV,
        MODEL_TENSOR.ATTN_OUT,
        MODEL_TENSOR.FFN_GATE_INP,
        MODEL_TENSOR.FFN_EXP_PROBS_B,
        MODEL_TENSOR.FFN_NORM,
        MODEL_TENSOR.FFN_GATE,
        MODEL_TENSOR.FFN_DOWN,
        MODEL_TENSOR.FFN_UP,
        MODEL_TENSOR.FFN_GATE_EXP,
        MODEL_TENSOR.FFN_DOWN_EXP,
        MODEL_TENSOR.FFN_UP_EXP,
        MODEL_TENSOR.FFN_GATE_SHEXP,
        MODEL_TENSOR.FFN_DOWN_SHEXP,
        MODEL_TENSOR.FFN_UP_SHEXP,
        MODEL_TENSOR.NEXTN_EH_PROJ,
        MODEL_TENSOR.NEXTN_EMBED_TOKENS,
        MODEL_TENSOR.NEXTN_ENORM,
        MODEL_TENSOR.NEXTN_HNORM,
        MODEL_TENSOR.NEXTN_SHARED_HEAD_HEAD,
        MODEL_TENSOR.NEXTN_SHARED_HEAD_NORM,
        MODEL_TENSOR.LAYER_OUT_NORM,
    ],
    MODEL_ARCH.DOTS1: [
        MODEL_TENSOR.TOKEN_EMBD,
        MODEL_TENSOR.OUTPUT_NORM,

@@ -755,6 +755,12 @@ class GGUFWriter:
    def add_expert_shared_count(self, count: int) -> None:
        self.add_uint32(Keys.LLM.EXPERT_SHARED_COUNT.format(arch=self.arch), count)

    def add_expert_group_count(self, count: int) -> None:
        self.add_uint32(Keys.LLM.EXPERT_GROUP_COUNT.format(arch=self.arch), count)

    def add_expert_group_used_count(self, count: int) -> None:
        self.add_uint32(Keys.LLM.EXPERT_GROUP_USED_COUNT.format(arch=self.arch), count)

    def add_expert_weights_scale(self, value: float) -> None:
        self.add_float32(Keys.LLM.EXPERT_WEIGHTS_SCALE.format(arch=self.arch), value)

@@ -174,6 +174,7 @@ class TensorNameMap:
            "h.{bid}.self_attention.query_key_value", # bloom
            "language_model.encoder.layers.{bid}.self_attention.query_key_value", # persimmon
            "model.layers.{bid}.self_attn.query_key_value", # persimmon
            "model.layers.{bid}.attention.query_key_value", # bailingmoe2
            "h.{bid}.attn.c_attn", # gpt2
            "transformer.h.{bid}.mixer.Wqkv", # phi2
            "encoder.layers.{bid}.attn.Wqkv", # nomic-bert

@@ -260,6 +261,7 @@ class TensorNameMap:
            "transformer.h.{bid}.attn.out_proj", # gpt-j
            "language_model.encoder.layers.{bid}.self_attention.dense", # persimmon
            "model.layers.{bid}.self_attn.dense", # persimmon
            "model.layers.{bid}.attention.dense", # bailingmoe2
            "h.{bid}.attn.c_proj", # gpt2
            "transformer.h.{bid}.mixer.out_proj", # phi2
            "model.layers.layers.{bid}.self_attn.o_proj", # plamo

@@ -373,6 +375,7 @@ class TensorNameMap:
        MODEL_TENSOR.FFN_EXP_PROBS_B: (
            "model.layers.{bid}.mlp.gate.e_score_correction", # deepseek-v3 dots1
            "model.layers.{bid}.mlp.moe_statics.e_score_correction", # ernie4.5-moe
            "model.layers.{bid}.mlp.gate.expert_bias", # bailingmoe2
            "model.layers.{bid}.feed_forward.expert_bias", # lfm2moe
        ),

@@ -549,6 +552,7 @@ class TensorNameMap:
            "language_model.encoder.layers.{bid}.self_attention.q_layernorm",
            "model.layers.{bid}.self_attn.q_layernorm", # persimmon
            "model.layers.{bid}.self_attn.query_layernorm", # hunyuan
            "model.layers.{bid}.attention.query_layernorm", # bailingmoe2
            "model.layers.{bid}.self_attn.q_norm", # cohere olmoe chameleon olmo2
            "layers.{bid}.self_attn.q_norm", # embeddinggemma
            "transformer.blocks.{bid}.attn.q_ln", # sea-lion

@@ -563,6 +567,7 @@ class TensorNameMap:
            "language_model.encoder.layers.{bid}.self_attention.k_layernorm",
            "model.layers.{bid}.self_attn.k_layernorm", # persimmon
            "model.layers.{bid}.self_attn.key_layernorm", # hunyuan
            "model.layers.{bid}.attention.key_layernorm", # bailingmoe2
            "model.layers.{bid}.self_attn.k_norm", # cohere olmoe chameleon olmo2
            "layers.{bid}.self_attn.k_norm", # embeddinggemma
            "transformer.blocks.{bid}.attn.k_ln", # sea-lion

@@ -584,6 +589,7 @@ class TensorNameMap:
            "transformer.decoder_layer.{bid}.rms_norm_3", # Grok
            "encoder.layer.{bid}.mlp.layernorm", # jina-bert-v2
            "encoder.layer.{bid}.layer_norm_2", # jina-v2-code
            "model.layers.{bid}.final_layernorm", # bailingmoe2
        ),

        MODEL_TENSOR.PER_LAYER_TOKEN_EMBD: (

@@ -85,6 +85,7 @@ static const std::map<llm_arch, const char *> LLM_ARCH_NAMES = {
    { LLM_ARCH_WAVTOKENIZER_DEC, "wavtokenizer-dec" },
    { LLM_ARCH_PLM, "plm" },
    { LLM_ARCH_BAILINGMOE, "bailingmoe" },
    { LLM_ARCH_BAILINGMOE2, "bailingmoe2" },
    { LLM_ARCH_DOTS1, "dots1" },
    { LLM_ARCH_ARCEE, "arcee" },
    { LLM_ARCH_ERNIE4_5, "ernie4_5" },

@@ -135,6 +136,8 @@ static const std::map<llm_kv, const char *> LLM_KV_NAMES = {
    { LLM_KV_EXPERT_COUNT, "%s.expert_count" },
    { LLM_KV_EXPERT_USED_COUNT, "%s.expert_used_count" },
    { LLM_KV_EXPERT_SHARED_COUNT, "%s.expert_shared_count" },
    { LLM_KV_EXPERT_GROUP_COUNT, "%s.expert_group_count" },
    { LLM_KV_EXPERT_GROUP_USED_COUNT, "%s.expert_group_used_count" },
    { LLM_KV_EXPERT_WEIGHTS_SCALE, "%s.expert_weights_scale" },
    { LLM_KV_EXPERT_WEIGHTS_NORM, "%s.expert_weights_norm" },
    { LLM_KV_EXPERT_GATING_FUNC, "%s.expert_gating_func" },

@@ -1946,6 +1949,38 @@ static const std::map<llm_arch, std::map<llm_tensor, const char *>> LLM_TENSOR_NAMES = {
            { LLM_TENSOR_FFN_UP_SHEXP, "blk.%d.ffn_up_shexp" },
        },
    },
    {
        LLM_ARCH_BAILINGMOE2,
        {
            { LLM_TENSOR_TOKEN_EMBD, "token_embd" },
            { LLM_TENSOR_OUTPUT_NORM, "output_norm" },
            { LLM_TENSOR_OUTPUT, "output" },
            { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" },
            { LLM_TENSOR_ATTN_Q_NORM, "blk.%d.attn_q_norm" },
            { LLM_TENSOR_ATTN_K_NORM, "blk.%d.attn_k_norm" },
            { LLM_TENSOR_ATTN_QKV, "blk.%d.attn_qkv" },
            { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" },
            { LLM_TENSOR_FFN_GATE_INP, "blk.%d.ffn_gate_inp" },
            { LLM_TENSOR_FFN_EXP_PROBS_B, "blk.%d.exp_probs_b" },
            { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" },
            { LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" },
            { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" },
            { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" },
            { LLM_TENSOR_FFN_GATE_EXPS, "blk.%d.ffn_gate_exps" },
            { LLM_TENSOR_FFN_DOWN_EXPS, "blk.%d.ffn_down_exps" },
            { LLM_TENSOR_FFN_UP_EXPS, "blk.%d.ffn_up_exps" },
            { LLM_TENSOR_FFN_GATE_SHEXP, "blk.%d.ffn_gate_shexp" },
            { LLM_TENSOR_FFN_DOWN_SHEXP, "blk.%d.ffn_down_shexp" },
            { LLM_TENSOR_FFN_UP_SHEXP, "blk.%d.ffn_up_shexp" },
            { LLM_TENSOR_NEXTN_EH_PROJ, "blk.%d.nextn.eh_proj" },
            { LLM_TENSOR_NEXTN_EMBED_TOKENS, "blk.%d.nextn.embed_tokens" },
            { LLM_TENSOR_NEXTN_ENORM, "blk.%d.nextn.enorm" },
            { LLM_TENSOR_NEXTN_HNORM, "blk.%d.nextn.hnorm" },
            { LLM_TENSOR_NEXTN_SHARED_HEAD_HEAD, "blk.%d.nextn.shared_head_head" },
            { LLM_TENSOR_NEXTN_SHARED_HEAD_NORM, "blk.%d.nextn.shared_head_norm" },
            { LLM_TENSOR_LAYER_OUT_NORM, "blk.%d.layer_output_norm" },
        },
    },
    {
        LLM_ARCH_DOTS1,
        {

@@ -89,6 +89,7 @@ enum llm_arch {
    LLM_ARCH_WAVTOKENIZER_DEC,
    LLM_ARCH_PLM,
    LLM_ARCH_BAILINGMOE,
    LLM_ARCH_BAILINGMOE2,
    LLM_ARCH_DOTS1,
    LLM_ARCH_ARCEE,
    LLM_ARCH_ERNIE4_5,

@@ -139,6 +140,8 @@ enum llm_kv {
    LLM_KV_EXPERT_COUNT,
    LLM_KV_EXPERT_USED_COUNT,
    LLM_KV_EXPERT_SHARED_COUNT,
    LLM_KV_EXPERT_GROUP_COUNT,
    LLM_KV_EXPERT_GROUP_USED_COUNT,
    LLM_KV_EXPERT_WEIGHTS_SCALE,
    LLM_KV_EXPERT_WEIGHTS_NORM,
    LLM_KV_EXPERT_GATING_FUNC,

@@ -123,7 +123,7 @@ private:
    uint32_t n_seq_max;
    uint32_t n_outputs;

    std::array<llama_seq_id, 1> seq_id_0 = { 0 }; // default sequence id
    std::array<llama_seq_id, 1> seq_id_0 = {{ 0 }}; // default sequence id

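    // Note on the change above: the doubled braces initialize the std::array's inner
    // aggregate explicitly; a single pair can trigger -Wmissing-braces on some compilers.
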
    std::vector<llama_pos> pos;
    std::vector<int32_t> n_seq_id;

@@ -63,6 +63,8 @@ static const std::map<std::string, llm_chat_template> LLM_CHAT_TEMPLATES = {
    { "megrez", LLM_CHAT_TEMPLATE_MEGREZ },
    { "yandex", LLM_CHAT_TEMPLATE_YANDEX },
    { "bailing", LLM_CHAT_TEMPLATE_BAILING },
    { "bailing-think", LLM_CHAT_TEMPLATE_BAILING_THINK },
    { "bailing2", LLM_CHAT_TEMPLATE_BAILING2 },
    { "llama4", LLM_CHAT_TEMPLATE_LLAMA4 },
    { "smolvlm", LLM_CHAT_TEMPLATE_SMOLVLM },
    { "hunyuan-moe", LLM_CHAT_TEMPLATE_HUNYUAN_MOE },

@@ -191,6 +193,10 @@ llm_chat_template llm_chat_detect_template(const std::string & tmpl) {
        return LLM_CHAT_TEMPLATE_YANDEX;
    } else if (tmpl_contains("<role>ASSISTANT</role>") && tmpl_contains("'HUMAN'")) {
        return LLM_CHAT_TEMPLATE_BAILING;
    } else if (tmpl_contains("<role>ASSISTANT</role>") && tmpl_contains("\"HUMAN\"") && tmpl_contains("<think>")) {
        return LLM_CHAT_TEMPLATE_BAILING_THINK;
    } else if (tmpl_contains("<role>ASSISTANT</role>") && tmpl_contains("<role>HUMAN</role>") && tmpl_contains("<|role_end|>")) {
        return LLM_CHAT_TEMPLATE_BAILING2;
    } else if (tmpl_contains("<|header_start|>") && tmpl_contains("<|header_end|>")) {
        return LLM_CHAT_TEMPLATE_LLAMA4;
    } else if (tmpl_contains("<|endofuserprompt|>")) {

@@ -644,8 +650,8 @@ int32_t llm_chat_apply_template(
        if (add_ass) {
            ss << " Ассистент:[SEP]";
        }
    } else if (tmpl == LLM_CHAT_TEMPLATE_BAILING) {
        // Bailing (Ling) template
    } else if (tmpl == LLM_CHAT_TEMPLATE_BAILING || tmpl == LLM_CHAT_TEMPLATE_BAILING_THINK) {
        // Bailing (Ling/Ring) template
        for (auto message : chat) {
            std::string role(message->role);

@@ -658,6 +664,33 @@ int32_t llm_chat_apply_template(
            ss << "<role>" << role << "</role>" << message->content;
        }

        if (add_ass) {
            ss << "<role>ASSISTANT</role>";

            if (tmpl == LLM_CHAT_TEMPLATE_BAILING_THINK) {
                ss << "<think>";
            }
        }
    } else if (tmpl == LLM_CHAT_TEMPLATE_BAILING2) {
        // Bailing2 (Ling 2.0) template
        bool has_system = !chat.empty() && std::string(chat[0]->role) == "system";

        if (!has_system) {
            ss << "<role>SYSTEM</role>detailed thinking off<|role_end|>";
        }

        for (auto message : chat) {
            std::string role(message->role);

            if (role == "user") {
                role = "HUMAN";
            } else {
                std::transform(role.begin(), role.end(), role.begin(), ::toupper);
            }

            ss << "<role>" << role << "</role>" << message->content << "<|role_end|>";
        }

        if (add_ass) {
            ss << "<role>ASSISTANT</role>";
        }
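
        // Worked example: for chat = [user: "Hi"] with add_ass = true, the Bailing2
        // branch above renders (no system message supplied, so the default is injected):
        //
        //   <role>SYSTEM</role>detailed thinking off<|role_end|><role>HUMAN</role>Hi<|role_end|><role>ASSISTANT</role>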

@@ -42,6 +42,8 @@ enum llm_chat_template {
    LLM_CHAT_TEMPLATE_MEGREZ,
    LLM_CHAT_TEMPLATE_YANDEX,
    LLM_CHAT_TEMPLATE_BAILING,
    LLM_CHAT_TEMPLATE_BAILING_THINK,
    LLM_CHAT_TEMPLATE_BAILING2,
    LLM_CHAT_TEMPLATE_LLAMA4,
    LLM_CHAT_TEMPLATE_SMOLVLM,
    LLM_CHAT_TEMPLATE_DOTS1,

@@ -2346,7 +2346,8 @@ llama_context * llama_init_from_model(
        return nullptr;
    }

    if (params.pooling_type != model->hparams.pooling_type) {
    if (params.pooling_type != LLAMA_POOLING_TYPE_UNSPECIFIED &&
        params.pooling_type != model->hparams.pooling_type) {
        //user-specified pooling-type is different from the model default
        LLAMA_LOG_WARN("%s: model default pooling_type is [%d], but [%d] was specified\n", __func__,
                model->hparams.pooling_type, params.pooling_type);

@@ -950,6 +950,31 @@ ggml_tensor * llm_graph_context::build_moe_ffn(
        cb(selection_probs, "ffn_moe_probs_biased", il);
    }

    // select top n_group_used expert groups
    // https://huggingface.co/deepseek-ai/DeepSeek-V3/blob/e815299b0bcbac849fa540c768ef21845365c9eb/modeling_deepseek.py#L440-L457
    if (hparams.n_expert_groups > 1 && n_tokens > 0) {
        const int64_t n_exp_per_group = n_expert / hparams.n_expert_groups;

        // organize experts into n_expert_groups
        ggml_tensor * selection_groups = ggml_reshape_3d(ctx0, selection_probs, n_exp_per_group, hparams.n_expert_groups, n_tokens); // [n_exp_per_group, n_expert_groups, n_tokens]

        ggml_tensor * group_scores = ggml_top_k(ctx0, selection_groups, 2); // [2, n_expert_groups, n_tokens]
        group_scores = ggml_get_rows(ctx0, ggml_reshape_4d(ctx0, selection_groups, 1, selection_groups->ne[0], selection_groups->ne[1], selection_groups->ne[2]), group_scores); // [1, 2, n_expert_groups, n_tokens]

        // get top n_group_used expert groups
        group_scores = ggml_sum_rows(ctx0, ggml_reshape_3d(ctx0, group_scores, group_scores->ne[1], group_scores->ne[2], group_scores->ne[3])); // [1, n_expert_groups, n_tokens]
        group_scores = ggml_reshape_2d(ctx0, group_scores, group_scores->ne[1], group_scores->ne[2]); // [n_expert_groups, n_tokens]

        ggml_tensor * expert_groups = ggml_top_k(ctx0, group_scores, hparams.n_group_used); // [n_group_used, n_tokens]
        cb(expert_groups, "ffn_moe_group_topk", il);

        // mask out the other groups
        selection_probs = ggml_get_rows(ctx0, selection_groups, expert_groups); // [n_exp_per_group, n_group_used, n_tokens]
        selection_probs = ggml_set_rows(ctx0, ggml_scale_bias(ctx0, selection_groups, 0.0f, -INFINITY), selection_probs, expert_groups); // [n_exp_per_group, n_expert_groups, n_tokens]
        selection_probs = ggml_reshape_2d(ctx0, selection_probs, n_expert, n_tokens); // [n_expert, n_tokens]
        cb(selection_probs, "ffn_moe_probs_masked", il);
    }

    // select experts
    ggml_tensor * selected_experts = ggml_top_k(ctx0, selection_probs, n_expert_used); // [n_expert_used, n_tokens]
    cb(selected_experts->src[0], "ffn_moe_argsort", il);
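
    // The same group-limited routing written as plain per-token loops; this is an
    // illustrative reading of the ggml graph above (score each group by the sum of its
    // top-2 expert probabilities, keep the best n_group_used groups, mask the rest to
    // -inf), not code from the commit.
    //
    // #include <algorithm>
    // #include <cmath>
    // #include <utility>
    // #include <vector>
    //
    // static void mask_expert_groups(float * probs /* [n_expert] */, int n_expert,
    //                                int n_expert_groups, int n_group_used) {
    //     const int n_per_group = n_expert / n_expert_groups;
    //     std::vector<std::pair<float, int>> scores(n_expert_groups);
    //     for (int g = 0; g < n_expert_groups; g++) {
    //         float top1 = -INFINITY, top2 = -INFINITY;
    //         for (int e = 0; e < n_per_group; e++) {
    //             const float p = probs[g*n_per_group + e];
    //             if (p > top1) { top2 = top1; top1 = p; } else if (p > top2) { top2 = p; }
    //         }
    //         scores[g] = { top1 + top2, g };
    //     }
    //     std::sort(scores.rbegin(), scores.rend());               // best groups first
    //     for (int r = n_group_used; r < n_expert_groups; r++) {   // mask the losers
    //         const int g = scores[r].second;
    //         for (int e = 0; e < n_per_group; e++) {
    //             probs[g*n_per_group + e] = -INFINITY;
    //         }
    //     }
    // }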

@@ -981,6 +1006,11 @@ ggml_tensor * llm_graph_context::build_moe_ffn(
        ggml_tensor * weights_sum = ggml_sum_rows(ctx0, weights); // [1, n_tokens]
        cb(weights_sum, "ffn_moe_weights_sum", il);

        if (arch == LLM_ARCH_BAILINGMOE2) {
            weights_sum = ggml_scale_bias(ctx0, weights_sum, 1.0, 1e-20);
            cb(weights_sum, "ffn_moe_weights_sum_biased", il);
        }

        weights = ggml_div(ctx0, weights, weights_sum); // [n_expert_used, n_tokens]
        cb(weights, "ffn_moe_weights_norm", il);
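
        // Note: the 1e-20 bias keeps the division above well-defined if the selected
        // expert weights ever sum to zero, while being far too small to measurably
        // change normal results.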

@@ -72,6 +72,8 @@ struct llama_hparams {
    uint32_t n_ff_chexp = 0;
    uint32_t n_expert_shared = 0;
    uint32_t n_norm_groups = 0;
    uint32_t n_expert_groups = 0;
    uint32_t n_group_used = 0;
    uint32_t n_group_experts = 0;

    float expert_group_scale = 0.05f;

@@ -114,9 +114,12 @@ const char * llm_type_name(llm_type type) {
        case LLM_TYPE_17B_16E: return "17Bx16E (Scout)";
        case LLM_TYPE_17B_128E: return "17Bx128E (Maverick)";
        case LLM_TYPE_A13B: return "A13B";
        case LLM_TYPE_7B_A1B: return "7B.A1B";
        case LLM_TYPE_8B_A1B: return "8B.A1B";
        case LLM_TYPE_16B_A1B: return "16B.A1B";
        case LLM_TYPE_21B_A3B: return "21B.A3B";
        case LLM_TYPE_30B_A3B: return "30B.A3B";
        case LLM_TYPE_100B_A6B: return "100B.A6B";
        case LLM_TYPE_106B_A12B: return "106B.A12B";
        case LLM_TYPE_235B_A22B: return "235B.A22B";
        case LLM_TYPE_300B_A47B: return "300B.A47B";

@@ -421,11 +424,8 @@ struct llama_model::impl {
    llama_mlocks mlock_bufs;
    llama_mlocks mlock_mmaps;

    // contexts where the model tensors metadata is stored
    std::vector<ggml_context_ptr> ctxs;

    // the model memory buffers for the tensor data
    std::vector<ggml_backend_buffer_ptr> bufs;
    // contexts where the model tensors metadata is stored as well as the corresponding buffers:
std::vector<std::pair<ggml_context_ptr, ggml_backend_buffer_ptr>> ctxs_bufs;
|
||||
|
||||
buft_list_t cpu_buft_list;
|
||||
std::map<ggml_backend_dev_t, buft_list_t> gpu_buft_list;
|
||||
|
|
@ -483,11 +483,13 @@ void llama_model::load_hparams(llama_model_loader & ml) {
|
|||
return;
|
||||
}
|
||||
|
||||
ml.get_key(LLM_KV_CONTEXT_LENGTH, hparams.n_ctx_train);
|
||||
ml.get_key(LLM_KV_EMBEDDING_LENGTH, hparams.n_embd);
|
||||
ml.get_key(LLM_KV_BLOCK_COUNT, hparams.n_layer);
|
||||
ml.get_key(LLM_KV_EXPERT_COUNT, hparams.n_expert, false);
|
||||
ml.get_key(LLM_KV_EXPERT_USED_COUNT, hparams.n_expert_used, false);
|
||||
ml.get_key(LLM_KV_CONTEXT_LENGTH, hparams.n_ctx_train);
|
||||
ml.get_key(LLM_KV_EMBEDDING_LENGTH, hparams.n_embd);
|
||||
ml.get_key(LLM_KV_BLOCK_COUNT, hparams.n_layer);
|
||||
ml.get_key(LLM_KV_EXPERT_COUNT, hparams.n_expert, false);
|
||||
ml.get_key(LLM_KV_EXPERT_USED_COUNT, hparams.n_expert_used, false);
|
||||
ml.get_key(LLM_KV_EXPERT_GROUP_COUNT, hparams.n_expert_groups, false);
|
||||
ml.get_key(LLM_KV_EXPERT_GROUP_USED_COUNT, hparams.n_group_used, false);
|
||||
|
||||
if (arch == LLM_ARCH_WAVTOKENIZER_DEC) {
|
||||
ml.get_key(LLM_KV_FEATURES_LENGTH, hparams.n_embd_features);
|
||||
|
|
@ -503,8 +505,15 @@ void llama_model::load_hparams(llama_model_loader & ml) {
|
|||
GGML_ASSERT(hparams.n_expert_used <= hparams.n_expert);
|
||||
if (hparams.n_expert > 0) {
|
||||
GGML_ASSERT(hparams.n_expert_used > 0);
|
||||
GGML_ASSERT(hparams.n_expert_groups < hparams.n_expert);
|
||||
if (hparams.n_expert_groups > 1) {
|
||||
GGML_ASSERT(hparams.n_expert % hparams.n_expert_groups == 0);
|
||||
GGML_ASSERT(hparams.n_group_used > 0);
|
||||
GGML_ASSERT(hparams.n_group_used < hparams.n_expert_groups);
|
||||
}
|
||||
} else {
|
||||
GGML_ASSERT(hparams.n_expert_used == 0);
|
||||
GGML_ASSERT(hparams.n_expert_groups == 0);
|
||||
}
|
||||
|
||||
std::fill(hparams.n_head_arr.begin(), hparams.n_head_arr.end(), 0);
|
||||
|
|
@ -1846,8 +1855,10 @@ void llama_model::load_hparams(llama_model_loader & ml) {
|
|||
|
||||
ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
|
||||
|
||||
switch (hparams.n_layer) {
|
||||
// TODO: Add llm type label (not sure this is useful)
|
||||
switch (hparams.n_embd) {
|
||||
case 1536: type = LLM_TYPE_7B_A1B; break;
|
||||
case 2048: case 2560: type = LLM_TYPE_3B; break;
|
||||
case 4096: type = LLM_TYPE_32B; break;
|
||||
default: type = LLM_TYPE_UNKNOWN;
|
||||
}
|
||||
|
||||
|
|
@ -1888,6 +1899,29 @@ void llama_model::load_hparams(llama_model_loader & ml) {
|
|||
default: type = LLM_TYPE_UNKNOWN;
|
||||
}
|
||||
} break;
|
||||
case LLM_ARCH_BAILINGMOE2:
|
||||
{
|
||||
ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
|
||||
ml.get_key(LLM_KV_LEADING_DENSE_BLOCK_COUNT, hparams.n_layer_dense_lead);
|
||||
ml.get_key(LLM_KV_EXPERT_FEED_FORWARD_LENGTH, hparams.n_ff_exp);
|
||||
ml.get_key(LLM_KV_EXPERT_SHARED_FEED_FORWARD_LENGTH, hparams.n_ff_shexp);
|
||||
ml.get_key(LLM_KV_EXPERT_SHARED_COUNT, hparams.n_expert_shared);
|
||||
ml.get_key(LLM_KV_EXPERT_WEIGHTS_SCALE, hparams.expert_weights_scale);
|
||||
ml.get_key(LLM_KV_EXPERT_WEIGHTS_NORM, hparams.expert_weights_norm, false);
|
||||
ml.get_key(LLM_KV_EXPERT_GATING_FUNC, hparams.expert_gating_func);
|
||||
ml.get_key(LLM_KV_NEXTN_PREDICT_LAYERS, hparams.nextn_predict_layers, false);
|
||||
|
||||
// TODO: when MTP is implemented, this should probably be updated if needed
|
||||
hparams.n_layer_kv_from_start = hparams.n_layer - hparams.nextn_predict_layers;
|
||||
|
||||
switch (hparams.n_layer) {
|
||||
case 20: type = LLM_TYPE_16B_A1B; break;
|
||||
case 21: type = LLM_TYPE_16B_A1B; break;
|
||||
case 32: type = LLM_TYPE_100B_A6B; break;
|
||||
case 33: type = LLM_TYPE_100B_A6B; break;
|
||||
default: type = LLM_TYPE_UNKNOWN;
|
||||
}
|
||||
} break;
|
||||
case LLM_ARCH_DOTS1:
|
||||
{
|
||||
ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
|
||||
|
|
@ -2182,7 +2216,14 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
|
|||
max_n_tensors += n_layer*2; // duplicated rope freq tensors
|
||||
const size_t ctx_size = ggml_tensor_overhead()*max_n_tensors;
|
||||
|
||||
std::map<ggml_backend_buffer_type_t, ggml_context *> ctx_map;
|
||||
// define a comparator for the buft -> ctx map to ensure that the order is well-defined:
|
||||
struct ggml_backend_buft_comparator {
|
||||
bool operator()(const ggml_backend_buffer_type_t & lhs, const ggml_backend_buffer_type_t & rhs) const {
|
||||
return ggml_backend_buft_name(lhs) < ggml_backend_buft_name(rhs);
|
||||
}
|
||||
};
|
||||
std::map<ggml_backend_buffer_type_t, ggml_context_ptr, ggml_backend_buft_comparator> ctx_map;
|
||||
|
||||
auto ctx_for_buft = [&](ggml_backend_buffer_type_t buft) -> ggml_context * {
|
||||
auto it = ctx_map.find(buft);
|
||||
if (it == ctx_map.end()) {
|
||||
|
|
@ -2197,12 +2238,11 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
|
|||
throw std::runtime_error(format("failed to create ggml context"));
|
||||
}
|
||||
|
||||
ctx_map[buft] = ctx;
|
||||
pimpl->ctxs.emplace_back(ctx);
|
||||
ctx_map.emplace(buft, ctx);
|
||||
|
||||
return ctx;
|
||||
}
|
||||
return it->second;
|
||||
return it->second.get();
|
||||
};
|
||||
|
||||
const auto TENSOR_DUPLICATED = llama_model_loader::TENSOR_DUPLICATED;
|
||||
|
|
@ -5492,6 +5532,70 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
|
|||
layer.ffn_up_shexp = create_tensor(tn(LLM_TENSOR_FFN_UP_SHEXP, "weight", i), {n_embd, n_ff_exp * n_expert_shared}, 0);
|
||||
}
|
||||
} break;
|
||||
case LLM_ARCH_BAILINGMOE2:
|
||||
{
|
||||
const int64_t n_ff_exp = hparams.n_ff_exp;
|
||||
const int64_t n_expert_shared = hparams.n_expert_shared;
|
||||
|
||||
tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
|
||||
|
||||
// output
|
||||
output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0);
|
||||
output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, 0);
|
||||
|
||||
GGML_ASSERT(n_expert > 0 && "n_expert must be > 0 for bailingmoe2");
|
||||
GGML_ASSERT(n_expert_used > 0 && "n_expert_used must be > 0 for bailingmoe2");
|
||||
|
||||
for (int i = 0; i < n_layer; ++i) {
|
||||
int flags = 0;
|
||||
if (hparams.nextn_predict_layers > 0 && static_cast<uint32_t>(i) >= n_layer - hparams.nextn_predict_layers) {
|
||||
// skip all tensors in the NextN layers
|
||||
flags |= TENSOR_SKIP;
|
||||
}
|
||||
|
||||
auto & layer = layers[i];
|
||||
|
||||
layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, flags);
|
||||
|
||||
layer.wqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "weight", i), {n_embd, n_embd + 2*n_embd_gqa}, flags);
|
||||
layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_k * n_head, n_embd}, flags);
|
||||
|
||||
layer.attn_q_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), {n_embd_head_k}, flags);
|
||||
layer.attn_k_norm = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), {n_embd_head_k}, flags);
|
||||
|
||||
layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, flags);
|
||||
|
||||
if (static_cast<uint32_t>(i) >= hparams.n_layer_dense_lead) { // MoE layers
|
||||
const int64_t n_ff_shexp = (hparams.n_ff_shexp ? hparams.n_ff_shexp : n_ff_exp) * n_expert_shared;
|
||||
|
||||
layer.ffn_gate_inp = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), {n_embd, n_expert}, flags);
|
||||
layer.ffn_exp_probs_b = create_tensor(tn(LLM_TENSOR_FFN_EXP_PROBS_B, "bias", i), {n_expert}, TENSOR_NOT_REQUIRED | flags);
|
||||
|
||||
layer.ffn_gate_exps = create_tensor(tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i), { n_embd, n_ff_exp, n_expert}, flags);
|
||||
layer.ffn_down_exps = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), {n_ff_exp, n_embd, n_expert}, flags);
|
||||
layer.ffn_up_exps = create_tensor(tn(LLM_TENSOR_FFN_UP_EXPS, "weight", i), { n_embd, n_ff_exp, n_expert}, flags);
|
||||
|
||||
layer.ffn_gate_shexp = create_tensor(tn(LLM_TENSOR_FFN_GATE_SHEXP, "weight", i), {n_embd, n_ff_shexp}, flags);
|
||||
layer.ffn_down_shexp = create_tensor(tn(LLM_TENSOR_FFN_DOWN_SHEXP, "weight", i), {n_ff_shexp, n_embd}, flags);
|
||||
layer.ffn_up_shexp = create_tensor(tn(LLM_TENSOR_FFN_UP_SHEXP, "weight", i), {n_embd, n_ff_shexp}, flags);
|
||||
} else { // Dense layers
|
||||
layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, flags);
|
||||
layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, flags);
|
||||
layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, flags);
|
||||
}
|
||||
|
||||
// NextN/MTP tensors (preserved but unused) - conditionally load for last nextn_predict_layers
|
||||
if (hparams.nextn_predict_layers > 0 && static_cast<uint32_t>(i) >= n_layer - hparams.nextn_predict_layers) {
|
||||
layer.nextn.eh_proj = create_tensor(tn(LLM_TENSOR_NEXTN_EH_PROJ, "weight", i), { 2 * n_embd, n_embd }, flags);
|
||||
layer.nextn.embed_tokens = create_tensor(tn(LLM_TENSOR_NEXTN_EMBED_TOKENS, "weight", i), { n_embd, n_vocab }, TENSOR_NOT_REQUIRED | flags);
|
||||
layer.nextn.enorm = create_tensor(tn(LLM_TENSOR_NEXTN_ENORM, "weight", i), { n_embd }, flags);
|
||||
layer.nextn.hnorm = create_tensor(tn(LLM_TENSOR_NEXTN_HNORM, "weight", i), { n_embd }, flags);
|
||||
layer.nextn.shared_head_head = create_tensor(tn(LLM_TENSOR_NEXTN_SHARED_HEAD_HEAD, "weight", i), { n_embd, n_vocab }, TENSOR_NOT_REQUIRED | flags);
|
||||
layer.nextn.shared_head_norm = create_tensor(tn(LLM_TENSOR_NEXTN_SHARED_HEAD_NORM, "weight", i), { n_embd }, TENSOR_NOT_REQUIRED | flags);
|
||||
layer.layer_out_norm = create_tensor(tn(LLM_TENSOR_LAYER_OUT_NORM, "weight", i), {n_embd}, flags);
|
||||
}
|
||||
}
|
||||
} break;
|
||||
case LLM_ARCH_DOTS1:
|
||||
{
|
||||
const int64_t n_ff_exp = hparams.n_ff_exp;
|
||||
|
@ -6037,16 +6141,15 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
pimpl->mappings.reserve(ml.mappings.size());

// create the backend buffers
std::vector<std::pair<ggml_context *, llama_buf_map>> ctx_bufs;
ctx_bufs.reserve(ctx_map.size());
std::vector<std::pair<ggml_context *, llama_buf_map>> ctx_buf_maps;
ctx_buf_maps.reserve(ctx_map.size());

// Ensure we have enough capacity for the maximum backend buffer we will potentially create
const size_t n_max_backend_buffer = ctx_map.size() * ml.files.size();
pimpl->bufs.reserve(n_max_backend_buffer);
pimpl->ctxs_bufs.reserve(n_max_backend_buffer);

for (auto & it : ctx_map) {
ggml_backend_buffer_type_t buft = it.first;
ggml_context * ctx = it.second;
for (auto & [buft, ctx_ptr] : ctx_map) {
ggml_context * ctx = ctx_ptr.get();

// skip contexts without tensors
if (ggml_get_first_tensor(ctx) == nullptr) {

@ -6070,6 +6173,7 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
bool buffer_from_host_ptr_supported = props.caps.buffer_from_host_ptr;
bool is_default_buft = buft == ggml_backend_dev_buffer_type(dev);

ggml_backend_buffer_t buf = nullptr;
if (ml.use_mmap && use_mmap_buffer && buffer_from_host_ptr_supported && is_default_buft) {
for (uint32_t idx = 0; idx < ml.files.size(); idx++) {
// only the mmap region containing the tensors in the model is mapped to the backend buffer

@ -6082,20 +6186,18 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
continue;
}
const size_t max_size = ggml_get_max_tensor_size(ctx);
ggml_backend_buffer_t buf = ggml_backend_dev_buffer_from_host_ptr(dev, (char *) addr + first, last - first, max_size);
buf = ggml_backend_dev_buffer_from_host_ptr(dev, (char *) addr + first, last - first, max_size);
if (buf == nullptr) {
throw std::runtime_error(format("unable to allocate %s buffer", ggml_backend_buft_name(buft)));
}
pimpl->bufs.emplace_back(buf);
buf_map.emplace(idx, buf);
}
}
else {
ggml_backend_buffer_t buf = ggml_backend_alloc_ctx_tensors_from_buft(ctx, buft);
buf = ggml_backend_alloc_ctx_tensors_from_buft(ctx, buft);
if (buf == nullptr) {
throw std::runtime_error(format("unable to allocate %s buffer", ggml_backend_buft_name(buft)));
}
pimpl->bufs.emplace_back(buf);
if (use_mlock && ggml_backend_buffer_is_host(buf)) {
pimpl->mlock_bufs.emplace_back(new llama_mlock);
auto & mlock_buf = pimpl->mlock_bufs.back();

@ -6106,10 +6208,7 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
buf_map.emplace(idx, buf);
}
}

if (pimpl->bufs.empty()) {
throw std::runtime_error("failed to allocate buffer");
}
pimpl->ctxs_bufs.emplace_back(std::move(ctx_ptr), buf);

for (auto & buf : buf_map) {
// indicate that this buffer contains weights

@ -6117,7 +6216,7 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
ggml_backend_buffer_set_usage(buf.second, GGML_BACKEND_BUFFER_USAGE_WEIGHTS);
}

ctx_bufs.emplace_back(ctx, buf_map);
ctx_buf_maps.emplace_back(ctx, buf_map);
}

if (llama_supports_gpu_offload()) {

@ -6135,22 +6234,20 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
}

// print memory requirements per buffer type
for (auto & buf : pimpl->bufs) {
for (auto & [_, buf] : pimpl->ctxs_bufs) {
LLAMA_LOG_INFO("%s: %12s model buffer size = %8.2f MiB\n", __func__, ggml_backend_buffer_name(buf.get()), ggml_backend_buffer_get_size(buf.get()) / 1024.0 / 1024.0);
}

// populate tensors_by_name
for (auto & ctx : pimpl->ctxs) {
for (auto & [ctx, _] : pimpl->ctxs_bufs) {
for (auto * cur = ggml_get_first_tensor(ctx.get()); cur != NULL; cur = ggml_get_next_tensor(ctx.get(), cur)) {
tensors_by_name.emplace_back(ggml_get_name(cur), cur);
}
}

// load tensor data
for (auto & it : ctx_bufs) {
ggml_context * ctx = it.first;
auto & bufs = it.second;
if (!ml.load_all_data(ctx, bufs, use_mlock ? &pimpl->mlock_mmaps : NULL, params.progress_callback, params.progress_callback_user_data)) {
for (auto & [ctx, buf_map] : ctx_buf_maps) {
if (!ml.load_all_data(ctx, buf_map, use_mlock ? &pimpl->mlock_mmaps : NULL, params.progress_callback, params.progress_callback_user_data)) {
return false;
}
}
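Note on the refactor in this hunk: the separate pimpl->bufs bookkeeping is replaced by a single ctxs_bufs vector that pairs each ggml context with the backend buffer allocated for it, so the two are created, iterated, and freed together. A minimal sketch of the pattern (assuming the ggml_context_ptr / ggml_backend_buffer_ptr smart-pointer aliases used elsewhere in llama.cpp):

    // sketch only: one entry ties a context to its buffer
    std::vector<std::pair<ggml_context_ptr, ggml_backend_buffer_ptr>> ctxs_bufs;
    ctxs_bufs.emplace_back(std::move(ctx_ptr), buf);
    for (auto & [ctx, buf] : ctxs_bufs) {
        // ctx.get() / buf.get() yield the raw handles
    }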
@ -6190,8 +6287,8 @@ size_t llama_model::n_devices() const {

std::map<ggml_backend_buffer_type_t, size_t> llama_model::memory_breakdown() const {
std::map<ggml_backend_buffer_type_t, size_t> ret;
for (const ggml_backend_buffer_ptr & buf_ptr : pimpl->bufs) {
ret[ggml_backend_buffer_get_type(buf_ptr.get())] += ggml_backend_buffer_get_size(buf_ptr.get());
for (const auto & [_, buf] : pimpl->ctxs_bufs) {
ret[ggml_backend_buffer_get_type(buf.get())] += ggml_backend_buffer_get_size(buf.get());
}
return ret;
}
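A possible caller-side use of memory_breakdown(); the logging shape below is an assumption for illustration, not code from this commit:

    // report the per-buffer-type totals in MiB
    for (const auto & [buft, size] : model.memory_breakdown()) {
        printf("%s: %.2f MiB\n", ggml_backend_buft_name(buft), size / 1024.0 / 1024.0);
    }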
@ -6354,6 +6451,19 @@ void llama_model::print_info() const {
LLAMA_LOG_INFO("%s: expert_weights_norm = %d\n", __func__, hparams.expert_weights_norm);
}

if (arch == LLM_ARCH_BAILINGMOE2) {
LLAMA_LOG_INFO("%s: n_layer_dense_lead = %d\n", __func__, hparams.n_layer_dense_lead);
LLAMA_LOG_INFO("%s: n_ff_exp = %d\n", __func__, hparams.n_ff_exp);
LLAMA_LOG_INFO("%s: n_ff_shexp = %d\n", __func__, hparams.n_ff_shexp);
LLAMA_LOG_INFO("%s: n_expert_shared = %d\n", __func__, hparams.n_expert_shared);
LLAMA_LOG_INFO("%s: n_expert_groups = %d\n", __func__, hparams.n_expert_groups);
LLAMA_LOG_INFO("%s: n_group_used = %d\n", __func__, hparams.n_group_used);
LLAMA_LOG_INFO("%s: expert_weights_scale = %.1f\n", __func__, hparams.expert_weights_scale);
LLAMA_LOG_INFO("%s: expert_weights_norm = %d\n", __func__, hparams.expert_weights_norm);
LLAMA_LOG_INFO("%s: expert_gating_func = %s\n", __func__, llama_expert_gating_func_name((llama_expert_gating_func_type) hparams.expert_gating_func));
LLAMA_LOG_INFO("%s: nextn_predict_layers = %d\n", __func__, hparams.nextn_predict_layers);
}

if (arch == LLM_ARCH_SMALLTHINKER || arch == LLM_ARCH_LFM2MOE) {
LLAMA_LOG_INFO("%s: n_ff_exp = %d\n", __func__, hparams.n_ff_exp);
LLAMA_LOG_INFO("%s: expert_gating_func = %s\n", __func__, llama_expert_gating_func_name((llama_expert_gating_func_type) hparams.expert_gating_func));
@ -17043,6 +17153,150 @@ struct llm_build_bailingmoe : public llm_graph_context {
}
};

struct llm_build_bailingmoe2 : public llm_graph_context {
llm_build_bailingmoe2(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) {
const int64_t n_embd_head = hparams.n_embd_head_v;
const int64_t n_embd_gqa = hparams.n_embd_v_gqa();

GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);

ggml_tensor * cur;
ggml_tensor * inpL;

inpL = build_inp_embd(model.tok_embd);

// inp_pos - contains the positions
ggml_tensor * inp_pos = build_inp_pos();

auto * inp_attn = build_attn_inp_kv();

ggml_tensor * inp_out_ids = build_inp_out_ids();

const int n_transformer_layers = n_layer - hparams.nextn_predict_layers;
for (int il = 0; il < n_transformer_layers; ++il) {
ggml_tensor * inpSA = inpL;

// norm
cur = build_norm(inpL,
model.layers[il].attn_norm, NULL,
LLM_NORM_RMS, il);
cb(cur, "attn_norm", il);

// self_attention
{
cur = build_lora_mm(model.layers[il].wqkv, cur);
cb(cur, "wqkv", il);

ggml_tensor * Qcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head, n_tokens, n_embd_head*sizeof(float), cur->nb[1], 0*sizeof(float)*(n_embd));
ggml_tensor * Kcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head_kv, n_tokens, n_embd_head*sizeof(float), cur->nb[1], 1*sizeof(float)*(n_embd));
ggml_tensor * Vcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head_kv, n_tokens, n_embd_head*sizeof(float), cur->nb[1], 1*sizeof(float)*(n_embd + n_embd_gqa));

Qcur = build_norm(Qcur, model.layers[il].attn_q_norm, NULL, LLM_NORM_RMS, il);
cb(Qcur, "Qcur_normed", il);

Qcur = ggml_rope_ext(
ctx0, Qcur, inp_pos, nullptr,
n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
ext_factor, attn_factor, beta_fast, beta_slow
);

Kcur = build_norm(Kcur, model.layers[il].attn_k_norm, NULL, LLM_NORM_RMS, il);
cb(Kcur, "Kcur_normed", il);

Kcur = ggml_rope_ext(
ctx0, Kcur, inp_pos, nullptr,
n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
ext_factor, attn_factor, beta_fast, beta_slow
);

cb(Qcur, "Qcur", il);
cb(Kcur, "Kcur", il);
cb(Vcur, "Vcur", il);

cur = build_attn(inp_attn,
model.layers[il].wo, model.layers[il].bo,
Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
}

if (il == n_transformer_layers - 1 && inp_out_ids) {
cur = ggml_get_rows(ctx0, cur, inp_out_ids);
inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
}

ggml_tensor * sa_out = ggml_add(ctx0, cur, inpSA);
cb(sa_out, "sa_out", il);

// MoE branch
cur = build_norm(sa_out,
model.layers[il].ffn_norm, NULL,
LLM_NORM_RMS, il);
cb(cur, "ffn_norm", il);

if (static_cast<uint32_t>(il) < hparams.n_layer_dense_lead) {
cur = build_ffn(cur,
model.layers[il].ffn_up, NULL, NULL,
model.layers[il].ffn_gate, NULL, NULL,
model.layers[il].ffn_down, NULL, NULL,
NULL,
LLM_FFN_SILU, LLM_FFN_PAR, il);
cb(cur, "ffn_out", il);
} else {
ggml_tensor * moe_out =
build_moe_ffn(cur,
model.layers[il].ffn_gate_inp,
model.layers[il].ffn_up_exps,
model.layers[il].ffn_gate_exps,
model.layers[il].ffn_down_exps,
model.layers[il].ffn_exp_probs_b,
n_expert, n_expert_used,
LLM_FFN_SILU, hparams.expert_weights_norm,
true, hparams.expert_weights_scale,
(llama_expert_gating_func_type) hparams.expert_gating_func,
il);
cb(moe_out, "ffn_moe_out", il);

{
ggml_tensor * ffn_shexp = build_ffn(cur,
model.layers[il].ffn_up_shexp, NULL, NULL,
model.layers[il].ffn_gate_shexp, NULL, NULL,
model.layers[il].ffn_down_shexp, NULL, NULL,
NULL,
LLM_FFN_SILU, LLM_FFN_PAR, il);
cb(ffn_shexp, "ffn_shexp", il);

cur = ggml_add(ctx0, moe_out, ffn_shexp);
cb(cur, "ffn_out", il);
}
}

cur = ggml_add(ctx0, cur, sa_out);

cur = build_cvec(cur, il);
cb(cur, "l_out", il);

// input for next layer
inpL = cur;
}

cur = inpL;

cur = build_norm(cur,
model.output_norm, NULL,
LLM_NORM_RMS, -1);

cb(cur, "result_norm", -1);
res->t_embd = cur;

// lm_head
cur = build_lora_mm(model.output, cur);

cb(cur, "result_output", -1);
res->t_logits = cur;

ggml_build_forward_expand(gf, cur);
}
};

struct llm_build_dots1 : public llm_graph_context {
llm_build_dots1(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) {
const int64_t n_embd_head = hparams.n_embd_head_v;

@ -19839,6 +20093,10 @@ ggml_cgraph * llama_model::build_graph(const llm_graph_params & params) const {
{
llm = std::make_unique<llm_build_bailingmoe>(*this, params);
} break;
case LLM_ARCH_BAILINGMOE2:
{
llm = std::make_unique<llm_build_bailingmoe2>(*this, params);
} break;
case LLM_ARCH_SEED_OSS:
{
llm = std::make_unique<llm_build_seed_oss>(*this, params);

@ -20105,6 +20363,7 @@ llama_rope_type llama_model_rope_type(const llama_model * model) {
case LLM_ARCH_EXAONE:
case LLM_ARCH_EXAONE4:
case LLM_ARCH_MINICPM3:
case LLM_ARCH_BAILINGMOE2:
case LLM_ARCH_DOTS1:
case LLM_ARCH_HUNYUAN_MOE:
case LLM_ARCH_OPENAI_MOE:
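A note on the fused QKV split in llm_build_bailingmoe2 above: the output of the single wqkv matmul is sliced with ggml_view_3d, and the byte offsets follow from the packed [Q | K | V] row layout (the sizeof(float) factors assume F32 activations, as in the views themselves):

    // per-row byte offsets into the fused wqkv output
    size_t q_off = 0;                                     // Q: first n_embd floats
    size_t k_off = sizeof(float) *  n_embd;               // K: next n_embd_gqa floats
    size_t v_off = sizeof(float) * (n_embd + n_embd_gqa); // V: after Q and K

These are exactly the 0*sizeof(float)*(n_embd), 1*sizeof(float)*(n_embd) and 1*sizeof(float)*(n_embd + n_embd_gqa) offsets passed to the three views.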
@ -107,9 +107,12 @@ enum llm_type {
LLM_TYPE_17B_16E, // llama4 Scout
LLM_TYPE_17B_128E, // llama4 Maverick
LLM_TYPE_A13B,
LLM_TYPE_7B_A1B,
LLM_TYPE_8B_A1B, // lfm2moe
LLM_TYPE_16B_A1B,
LLM_TYPE_21B_A3B, // Ernie MoE small
LLM_TYPE_30B_A3B,
LLM_TYPE_100B_A6B,
LLM_TYPE_106B_A12B, // GLM-4.5-Air
LLM_TYPE_235B_A22B,
LLM_TYPE_300B_A47B, // Ernie MoE big

@ -1968,6 +1968,7 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) {
clean_spaces = false;
} else if (
tokenizer_pre == "bailingmoe" ||
tokenizer_pre == "bailingmoe2" ||
tokenizer_pre == "llada-moe") {
pre_type = LLAMA_VOCAB_PRE_TYPE_BAILINGMOE;
clean_spaces = false;

@ -3759,6 +3759,130 @@ struct test_clamp : public test_case {
}
};

// GGML_OP_FLOOR
struct test_floor : public test_case {
const ggml_type type;
const std::array<int64_t, 4> ne;

std::string vars() override {
return VARS_TO_STR2(type, ne);
}

test_floor(ggml_type type = GGML_TYPE_F32,
std::array<int64_t, 4> ne = {10, 2, 2, 2})
: type(type), ne(ne) {}

ggml_tensor * build_graph(ggml_context * ctx) override {
ggml_tensor * a = ggml_new_tensor(ctx, type, 4, ne.data());
ggml_set_param(a);
ggml_set_name(a, "a");

ggml_tensor * out = ggml_floor(ctx, a);
ggml_set_name(out, "out");

return out;
}

void initialize_tensors(ggml_context * ctx) override {
for (ggml_tensor * t = ggml_get_first_tensor(ctx); t != NULL; t = ggml_get_next_tensor(ctx, t)) {
init_tensor_uniform(t, -10.0f, 10.0f);
}
}
};

// GGML_OP_CEIL
struct test_ceil : public test_case {
const ggml_type type;
const std::array<int64_t, 4> ne;

std::string vars() override {
return VARS_TO_STR2(type, ne);
}

test_ceil(ggml_type type = GGML_TYPE_F32,
std::array<int64_t, 4> ne = {10, 2, 2, 2})
: type(type), ne(ne) {}

ggml_tensor * build_graph(ggml_context * ctx) override {
ggml_tensor * a = ggml_new_tensor(ctx, type, 4, ne.data());
ggml_set_param(a);
ggml_set_name(a, "a");

ggml_tensor * out = ggml_ceil(ctx, a);
ggml_set_name(out, "out");

return out;
}

void initialize_tensors(ggml_context * ctx) override {
for (ggml_tensor * t = ggml_get_first_tensor(ctx); t != NULL; t = ggml_get_next_tensor(ctx, t)) {
init_tensor_uniform(t, -10.0f, 10.0f);
}
}
};

// GGML_OP_ROUND
struct test_round : public test_case {
const ggml_type type;
const std::array<int64_t, 4> ne;

std::string vars() override {
return VARS_TO_STR2(type, ne);
}

test_round(ggml_type type = GGML_TYPE_F32,
std::array<int64_t, 4> ne = {10, 2, 2, 2})
: type(type), ne(ne) {}

ggml_tensor * build_graph(ggml_context * ctx) override {
ggml_tensor * a = ggml_new_tensor(ctx, type, 4, ne.data());
ggml_set_param(a);
ggml_set_name(a, "a");

ggml_tensor * out = ggml_round(ctx, a);
ggml_set_name(out, "out");

return out;
}

void initialize_tensors(ggml_context * ctx) override {
for (ggml_tensor * t = ggml_get_first_tensor(ctx); t != NULL; t = ggml_get_next_tensor(ctx, t)) {
init_tensor_uniform(t, -10.0f, 10.0f);
}
}
};

// GGML_OP_TRUNC
struct test_trunc : public test_case {
const ggml_type type;
const std::array<int64_t, 4> ne;

std::string vars() override {
return VARS_TO_STR2(type, ne);
}

test_trunc(ggml_type type = GGML_TYPE_F32,
std::array<int64_t, 4> ne = {10, 2, 2, 2})
: type(type), ne(ne) {}

ggml_tensor * build_graph(ggml_context * ctx) override {
ggml_tensor * a = ggml_new_tensor(ctx, type, 4, ne.data());
ggml_set_param(a);
ggml_set_name(a, "a");

ggml_tensor * out = ggml_trunc(ctx, a);
ggml_set_name(out, "out");

return out;
}

void initialize_tensors(ggml_context * ctx) override {
for (ggml_tensor * t = ggml_get_first_tensor(ctx); t != NULL; t = ggml_get_next_tensor(ctx, t)) {
init_tensor_uniform(t, -10.0f, 10.0f);
}
}
};
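The four structs above all test element-wise rounding ops with the same shape and initialization. For reference, a minimal standalone sketch of evaluating one of these ops on the CPU backend, assuming the public ggml API (ggml_floor is among the ops covered by this commit's tests; ggml_graph_compute_with_ctx is the usual single-context compute helper):

    #include "ggml.h"
    #include "ggml-cpu.h"

    int main() {
        struct ggml_init_params ip = { /*mem_size*/ 16*1024*1024, /*mem_buffer*/ NULL, /*no_alloc*/ false };
        struct ggml_context * ctx = ggml_init(ip);

        struct ggml_tensor * a = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 3);
        float * d = (float *) a->data;
        d[0] = -1.5f; d[1] = 0.4f; d[2] = 2.6f;

        struct ggml_tensor * out = ggml_floor(ctx, a); // expect -2, 0, 2
        struct ggml_cgraph * gf = ggml_new_graph(ctx);
        ggml_build_forward_expand(gf, out);
        ggml_graph_compute_with_ctx(ctx, gf, /*n_threads*/ 1);

        ggml_free(ctx);
        return 0;
    }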

// GGML_OP_DIAG_MASK_INF
struct test_diag_mask_inf : public test_case {
const ggml_type type;

@ -6585,6 +6709,10 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
test_cases.emplace_back(new test_cos (type));
test_cases.emplace_back(new test_clamp (type));
test_cases.emplace_back(new test_leaky_relu(type));
test_cases.emplace_back(new test_floor (type));
test_cases.emplace_back(new test_ceil (type));
test_cases.emplace_back(new test_round (type));
test_cases.emplace_back(new test_trunc (type));
test_cases.emplace_back(new test_sqr (type, {7, 1, 5, 3}));
test_cases.emplace_back(new test_sqrt (type, {7, 1, 5, 3}));
test_cases.emplace_back(new test_log (type, {7, 1, 5, 3}));

@ -6592,6 +6720,10 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
test_cases.emplace_back(new test_cos (type, {7, 1, 5, 3}));
test_cases.emplace_back(new test_clamp (type, {7, 1, 5, 3}));
test_cases.emplace_back(new test_leaky_relu(type, {7, 1, 5, 3}));
test_cases.emplace_back(new test_floor (type, {7, 1, 5, 3}));
test_cases.emplace_back(new test_ceil (type, {7, 1, 5, 3}));
test_cases.emplace_back(new test_round (type, {7, 1, 5, 3}));
test_cases.emplace_back(new test_trunc (type, {7, 1, 5, 3}));
}

test_cases.emplace_back(new test_diag_mask_inf(GGML_TYPE_F32, {10, 10, 1, 1}, 5));

@ -6989,6 +7121,8 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_perf() {
test_cases.emplace_back(new test_conv_2d_dw({512, 512, 256, 1}, {3, 3, 1, 256}, 1, 1, 1, true));

test_cases.emplace_back(new test_conv_transpose_2d({256, 256, 256, 1}, {3, 3, 16, 256}, 1));
test_cases.emplace_back(new test_conv_transpose_2d({16, 16, 16, 1}, {3, 3, 8, 16}, 1));
test_cases.emplace_back(new test_conv_transpose_2d({10, 10, 9, 1}, {3, 3, 1, 9}, 2));

test_cases.emplace_back(new test_mean(GGML_TYPE_F32, {256, 256, 3, 1}));

@ -301,6 +301,30 @@ static void test_simple_grammar() {
"0123",
}
);
test_schema(
"min 1 max 900719925474091",
// Schema
R"""({
"type": "integer",
"exclusiveMinimum": 0,
"maximum": 900719925474091
})""",
// Passing strings
{
"1",
"2",
"10",
"900719925474090",
"900719925474091",
},
// Failing strings
{
"0",
"01",
"900719925474092",
"9007199254740910",
}
);
test_schema(
"min -1 max 1",
R"""({
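Worth noting about the new min/max schema test above: the bound 900719925474091 does not fit in a 32-bit integer, so the conversion only works once the schema-to-grammar integer-range handling is done in 64 bits. A one-line compile-time check of that claim:

    #include <cstdint>
    static_assert(900719925474091LL > INT32_MAX, "bound exceeds 32-bit range");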

@ -30,6 +30,7 @@
#define KEY_LAYER_NORM_EPS "clip.%s.attention.layer_norm_epsilon"

// vision-specific
#define KEY_VISION_PROJ_TYPE "clip.vision.projector_type" // for models with mixed modalities
#define KEY_IMAGE_SIZE "clip.vision.image_size"
#define KEY_PREPROC_IMAGE_SIZE "clip.vision.preproc_image_size"
#define KEY_PATCH_SIZE "clip.vision.patch_size"

@ -48,6 +49,7 @@
#define KEY_MINICPMV_QUERY_NUM "clip.minicpmv_query_num"

// audio-specific
#define KEY_AUDIO_PROJ_TYPE "clip.audio.projector_type" // for models with mixed modalities
#define KEY_A_NUM_MEL_BINS "clip.audio.num_mel_bins"
#define KEY_A_PROJ_STACK_FACTOR "clip.audio.projector.stack_factor"

@ -2221,15 +2221,27 @@ struct clip_model_loader {
// projector type
std::string proj_type;
{
// default key
get_string(KEY_PROJ_TYPE, proj_type, false);
if (!proj_type.empty()) {
model.proj_type = clip_projector_type_from_string(proj_type);

// for models with mixed modalities
if (proj_type.empty()) {
if (modality == CLIP_MODALITY_VISION) {
get_string(KEY_VISION_PROJ_TYPE, proj_type, false);
} else if (modality == CLIP_MODALITY_AUDIO) {
get_string(KEY_AUDIO_PROJ_TYPE, proj_type, false);
} else {
GGML_ABORT("unknown modality");
}
}

model.proj_type = clip_projector_type_from_string(proj_type);

if (model.proj_type == PROJECTOR_TYPE_UNKNOWN) {
throw std::runtime_error(string_format("%s: unknown projector type: %s\n", __func__, proj_type.c_str()));
}

// correct arch for multimodal models
// correct arch for multimodal models (legacy method)
if (model.proj_type == PROJECTOR_TYPE_QWEN25O) {
model.proj_type = modality == CLIP_MODALITY_VISION
? PROJECTOR_TYPE_QWEN25VL

@ -137,7 +137,6 @@ struct rpc_server_params {
bool use_cache = false;
int n_threads = std::max(1U, std::thread::hardware_concurrency()/2);
std::vector<std::string> devices;
std::vector<size_t> dev_mem;
};

static void print_usage(int /*argc*/, char ** argv, rpc_server_params params) {

@ -148,7 +147,6 @@ static void print_usage(int /*argc*/, char ** argv, rpc_server_params params) {
fprintf(stderr, " -d, --device <dev1,dev2,...> comma-separated list of devices\n");
fprintf(stderr, " -H, --host HOST host to bind to (default: %s)\n", params.host.c_str());
fprintf(stderr, " -p, --port PORT port to bind to (default: %d)\n", params.port);
fprintf(stderr, " -m, --mem <M1,M2,...> memory size for each device (in MB)\n");
fprintf(stderr, " -c, --cache enable local file cache\n");
fprintf(stderr, "\n");
}

@ -197,23 +195,6 @@ static bool rpc_server_params_parse(int argc, char ** argv, rpc_server_params &
}
} else if (arg == "-c" || arg == "--cache") {
params.use_cache = true;
} else if (arg == "-m" || arg == "--mem") {
if (++i >= argc) {
return false;
}
const std::regex regex{ R"([,/]+)" };
std::string mem_str = argv[i];
std::sregex_token_iterator iter(mem_str.begin(), mem_str.end(), regex, -1);
std::sregex_token_iterator end;
for ( ; iter != end; ++iter) {
try {
size_t mem = std::stoul(*iter) * 1024 * 1024;
params.dev_mem.push_back(mem);
} catch (const std::exception & ) {
fprintf(stderr, "error: invalid memory size: %s\n", iter->str().c_str());
return false;
}
}
} else if (arg == "-h" || arg == "--help") {
print_usage(argc, argv, params);
exit(0);

@ -293,18 +274,6 @@ int main(int argc, char * argv[]) {
return 1;
}
std::string endpoint = params.host + ":" + std::to_string(params.port);
std::vector<size_t> free_mem, total_mem;
for (size_t i = 0; i < devices.size(); i++) {
if (i < params.dev_mem.size()) {
free_mem.push_back(params.dev_mem[i]);
total_mem.push_back(params.dev_mem[i]);
} else {
size_t free, total;
ggml_backend_dev_memory(devices[i], &free, &total);
free_mem.push_back(free);
total_mem.push_back(total);
}
}
const char * cache_dir = nullptr;
std::string cache_dir_str;
if (params.use_cache) {

@ -328,7 +297,6 @@ int main(int argc, char * argv[]) {
return 1;
}

start_server_fn(endpoint.c_str(), cache_dir, params.n_threads, devices.size(),
devices.data(), free_mem.data(), total_mem.data());
start_server_fn(endpoint.c_str(), cache_dir, params.n_threads, devices.size(), devices.data());
return 0;
}
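Context for the rpc-server deletions above: the --mem override and the CLI-side free/total computation are removed, and start_server_fn now receives only the device list, so per-device memory is presumably queried on the server side instead. The removed loop was built on the standard device-memory query:

    size_t free, total;
    ggml_backend_dev_memory(dev, &free, &total); // free/total bytes for one device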

Binary file not shown.

@ -50,6 +50,7 @@
"eslint-plugin-svelte": "^3.0.0",
"fflate": "^0.8.2",
"globals": "^16.0.0",
"http-server": "^14.1.1",
"mdast": "^3.0.0",
"mdsvex": "^0.12.3",
"playwright": "^1.53.0",

@ -2979,6 +2980,13 @@
"node": ">=4"
}
},
"node_modules/async": {
"version": "3.2.6",
"resolved": "https://registry.npmjs.org/async/-/async-3.2.6.tgz",
"integrity": "sha512-htCUDlxyyCLMgaM3xXg0C0LW2xqfuQ6p05pCEIsXuyQ+a1koYKTuBMzRNwmybfLgvJDMd0r1LTn4+E0Ti6C2AA==",
"dev": true,
"license": "MIT"
},
"node_modules/axe-core": {
"version": "4.10.3",
"resolved": "https://registry.npmjs.org/axe-core/-/axe-core-4.10.3.tgz",

@ -3015,6 +3023,19 @@
"dev": true,
"license": "MIT"
},
"node_modules/basic-auth": {
"version": "2.0.1",
"resolved": "https://registry.npmjs.org/basic-auth/-/basic-auth-2.0.1.tgz",
"integrity": "sha512-NF+epuEdnUYVlGuhaxbbq+dvJttwLnGY+YixlXlME5KpQ5W3CnXA5cVTneY3SPbPDRkcjMbifrwmFYcClgOZeg==",
"dev": true,
"license": "MIT",
"dependencies": {
"safe-buffer": "5.1.2"
},
"engines": {
"node": ">= 0.8"
}
},
"node_modules/better-opn": {
"version": "3.0.2",
"resolved": "https://registry.npmjs.org/better-opn/-/better-opn-3.0.2.tgz",

@ -3125,6 +3146,37 @@
"node": ">=8"
}
},
"node_modules/call-bind-apply-helpers": {
"version": "1.0.2",
"resolved": "https://registry.npmjs.org/call-bind-apply-helpers/-/call-bind-apply-helpers-1.0.2.tgz",
"integrity": "sha512-Sp1ablJ0ivDkSzjcaJdxEunN5/XvksFJ2sMBFfq6x0ryhQV/2b/KwFe21cMpmHtPOSij8K99/wSfoEuTObmuMQ==",
"dev": true,
"license": "MIT",
"dependencies": {
"es-errors": "^1.3.0",
"function-bind": "^1.1.2"
},
"engines": {
"node": ">= 0.4"
}
},
"node_modules/call-bound": {
"version": "1.0.4",
"resolved": "https://registry.npmjs.org/call-bound/-/call-bound-1.0.4.tgz",
"integrity": "sha512-+ys997U96po4Kx/ABpBCqhA9EuxJaQWDQg7295H4hBphv3IZg0boBKuwYpt4YXp6MZ5AmZQnU/tyMTlRpaSejg==",
"dev": true,
"license": "MIT",
"dependencies": {
"call-bind-apply-helpers": "^1.0.2",
"get-intrinsic": "^1.3.0"
},
"engines": {
"node": ">= 0.4"
},
"funding": {
"url": "https://github.com/sponsors/ljharb"
}
},
"node_modules/callsites": {
"version": "3.1.0",
"resolved": "https://registry.npmjs.org/callsites/-/callsites-3.1.0.tgz",

@ -3335,6 +3387,16 @@
"node": ">= 0.6"
}
},
"node_modules/corser": {
"version": "2.0.1",
"resolved": "https://registry.npmjs.org/corser/-/corser-2.0.1.tgz",
"integrity": "sha512-utCYNzRSQIZNPIcGZdQc92UVJYAhtGAteCFg0yRaFm8f0P+CPtyGyHXJcGXnffjCybUCEx3FQ2G7U3/o9eIkVQ==",
"dev": true,
"license": "MIT",
"engines": {
"node": ">= 0.4.0"
}
},
"node_modules/cross-spawn": {
"version": "7.0.6",
"resolved": "https://registry.npmjs.org/cross-spawn/-/cross-spawn-7.0.6.tgz",

@ -3520,6 +3582,21 @@
"dev": true,
"license": "MIT"
},
"node_modules/dunder-proto": {
"version": "1.0.1",
"resolved": "https://registry.npmjs.org/dunder-proto/-/dunder-proto-1.0.1.tgz",
"integrity": "sha512-KIN/nDJBQRcXw0MLVhZE9iQHmG68qAVIBg9CqmUYjmQIhgij9U5MFvrqkUL5FbtyyzZuOeOt0zdeRe4UY7ct+A==",
"dev": true,
"license": "MIT",
"dependencies": {
"call-bind-apply-helpers": "^1.0.1",
"es-errors": "^1.3.0",
"gopd": "^1.2.0"
},
"engines": {
"node": ">= 0.4"
}
},
"node_modules/enhanced-resolve": {
"version": "5.18.2",
"resolved": "https://registry.npmjs.org/enhanced-resolve/-/enhanced-resolve-5.18.2.tgz",

@ -3547,6 +3624,26 @@
"url": "https://github.com/fb55/entities?sponsor=1"
}
},
"node_modules/es-define-property": {
"version": "1.0.1",
"resolved": "https://registry.npmjs.org/es-define-property/-/es-define-property-1.0.1.tgz",
"integrity": "sha512-e3nRfgfUZ4rNGL232gUgX06QNyyez04KdjFrF+LTRoOXmrOgFKDg4BCdsjW8EnT69eqdYGmRpJwiPVYNrCaW3g==",
"dev": true,
"license": "MIT",
"engines": {
"node": ">= 0.4"
}
},
"node_modules/es-errors": {
"version": "1.3.0",
"resolved": "https://registry.npmjs.org/es-errors/-/es-errors-1.3.0.tgz",
"integrity": "sha512-Zf5H2Kxt2xjTvbJvP2ZWLEICxA6j+hAmMzIlypy4xcBg1vKVnx89Wy0GbS+kf5cwCVFFzdCFh2XSCFNULS6csw==",
"dev": true,
"license": "MIT",
"engines": {
"node": ">= 0.4"
}
},
"node_modules/es-module-lexer": {
"version": "1.7.0",
"resolved": "https://registry.npmjs.org/es-module-lexer/-/es-module-lexer-1.7.0.tgz",

@ -3554,6 +3651,19 @@
"dev": true,
"license": "MIT"
},
"node_modules/es-object-atoms": {
"version": "1.1.1",
"resolved": "https://registry.npmjs.org/es-object-atoms/-/es-object-atoms-1.1.1.tgz",
"integrity": "sha512-FGgH2h8zKNim9ljj7dankFPcICIK9Cp5bm+c2gQSYePhpaG5+esrLODihIorn+Pe6FGJzWhXQotPv73jTaldXA==",
"dev": true,
"license": "MIT",
"dependencies": {
"es-errors": "^1.3.0"
},
"engines": {
"node": ">= 0.4"
}
},
"node_modules/es-toolkit": {
"version": "1.39.7",
"resolved": "https://registry.npmjs.org/es-toolkit/-/es-toolkit-1.39.7.tgz",

@ -3885,6 +3995,13 @@
"node": ">=0.10.0"
}
},
"node_modules/eventemitter3": {
"version": "4.0.7",
"resolved": "https://registry.npmjs.org/eventemitter3/-/eventemitter3-4.0.7.tgz",
"integrity": "sha512-8guHBZCwKnFhYdHr2ysuRWErTwhoN2X8XELRlrRwpmfeY2jjuUN4taQMsULKUVo1K4DvZl+0pgfyoysHxvmvEw==",
"dev": true,
"license": "MIT"
},
"node_modules/expect-type": {
"version": "1.2.2",
"resolved": "https://registry.npmjs.org/expect-type/-/expect-type-1.2.2.tgz",

@ -4058,6 +4175,27 @@
"dev": true,
"license": "ISC"
},
"node_modules/follow-redirects": {
"version": "1.15.11",
"resolved": "https://registry.npmjs.org/follow-redirects/-/follow-redirects-1.15.11.tgz",
"integrity": "sha512-deG2P0JfjrTxl50XGCDyfI97ZGVCxIpfKYmfyrQ54n5FO/0gfIES8C/Psl6kWVDolizcaaxZJnTS0QSMxvnsBQ==",
"dev": true,
"funding": [
{
"type": "individual",
"url": "https://github.com/sponsors/RubenVerborgh"
}
],
"license": "MIT",
"engines": {
"node": ">=4.0"
},
"peerDependenciesMeta": {
"debug": {
"optional": true
}
}
},
"node_modules/fsevents": {
"version": "2.3.2",
"resolved": "https://registry.npmjs.org/fsevents/-/fsevents-2.3.2.tgz",

@ -4073,6 +4211,55 @@
"node": "^8.16.0 || ^10.6.0 || >=11.0.0"
}
},
"node_modules/function-bind": {
"version": "1.1.2",
"resolved": "https://registry.npmjs.org/function-bind/-/function-bind-1.1.2.tgz",
"integrity": "sha512-7XHNxH7qX9xG5mIwxkhumTox/MIRNcOgDrxWsMt2pAr23WHp6MrRlN7FBSFpCpr+oVO0F744iUgR82nJMfG2SA==",
"dev": true,
"license": "MIT",
"funding": {
"url": "https://github.com/sponsors/ljharb"
}
},
"node_modules/get-intrinsic": {
"version": "1.3.0",
"resolved": "https://registry.npmjs.org/get-intrinsic/-/get-intrinsic-1.3.0.tgz",
"integrity": "sha512-9fSjSaos/fRIVIp+xSJlE6lfwhES7LNtKaCBIamHsjr2na1BiABJPo0mOjjz8GJDURarmCPGqaiVg5mfjb98CQ==",
"dev": true,
"license": "MIT",
"dependencies": {
"call-bind-apply-helpers": "^1.0.2",
"es-define-property": "^1.0.1",
"es-errors": "^1.3.0",
"es-object-atoms": "^1.1.1",
"function-bind": "^1.1.2",
"get-proto": "^1.0.1",
"gopd": "^1.2.0",
"has-symbols": "^1.1.0",
"hasown": "^2.0.2",
"math-intrinsics": "^1.1.0"
},
"engines": {
"node": ">= 0.4"
},
"funding": {
"url": "https://github.com/sponsors/ljharb"
}
},
"node_modules/get-proto": {
"version": "1.0.1",
"resolved": "https://registry.npmjs.org/get-proto/-/get-proto-1.0.1.tgz",
"integrity": "sha512-sTSfBjoXBp89JvIKIefqw7U2CCebsc74kiY6awiGogKtoSGbgjYE/G/+l9sF3MWFPNc9IcoOC4ODfKHfxFmp0g==",
"dev": true,
"license": "MIT",
"dependencies": {
"dunder-proto": "^1.0.1",
"es-object-atoms": "^1.0.0"
},
"engines": {
"node": ">= 0.4"
}
},
"node_modules/glob-parent": {
"version": "6.0.2",
"resolved": "https://registry.npmjs.org/glob-parent/-/glob-parent-6.0.2.tgz",

@ -4099,6 +4286,19 @@
"url": "https://github.com/sponsors/sindresorhus"
}
},
"node_modules/gopd": {
"version": "1.2.0",
"resolved": "https://registry.npmjs.org/gopd/-/gopd-1.2.0.tgz",
"integrity": "sha512-ZUKRh6/kUFoAiTAtTYPZJ3hw9wNxx+BIBOijnlG9PnrJsCcSjs1wyyD6vJpaYtgnzDrKYRSqf3OO6Rfa93xsRg==",
"dev": true,
"license": "MIT",
"engines": {
"node": ">= 0.4"
},
"funding": {
"url": "https://github.com/sponsors/ljharb"
}
},
"node_modules/graceful-fs": {
"version": "4.2.11",
"resolved": "https://registry.npmjs.org/graceful-fs/-/graceful-fs-4.2.11.tgz",

@ -4123,6 +4323,32 @@
"node": ">=8"
}
},
"node_modules/has-symbols": {
"version": "1.1.0",
"resolved": "https://registry.npmjs.org/has-symbols/-/has-symbols-1.1.0.tgz",
"integrity": "sha512-1cDNdwJ2Jaohmb3sg4OmKaMBwuC48sYni5HUw2DvsC8LjGTLK9h+eb1X6RyuOHe4hT0ULCW68iomhjUoKUqlPQ==",
"dev": true,
"license": "MIT",
"engines": {
"node": ">= 0.4"
},
"funding": {
"url": "https://github.com/sponsors/ljharb"
}
},
"node_modules/hasown": {
"version": "2.0.2",
"resolved": "https://registry.npmjs.org/hasown/-/hasown-2.0.2.tgz",
"integrity": "sha512-0hJU9SCPvmMzIBdZFqNPXWa6dqh7WdH0cII9y+CyS8rG3nL48Bclra9HmKhVVUHyPWNH5Y7xDwAB7bfgSjkUMQ==",
"dev": true,
"license": "MIT",
"dependencies": {
"function-bind": "^1.1.2"
},
"engines": {
"node": ">= 0.4"
}
},
"node_modules/hast-util-from-dom": {
"version": "5.0.1",
"resolved": "https://registry.npmjs.org/hast-util-from-dom/-/hast-util-from-dom-5.0.1.tgz",

@ -4363,6 +4589,16 @@
"url": "https://opencollective.com/unified"
}
},
"node_modules/he": {
"version": "1.2.0",
"resolved": "https://registry.npmjs.org/he/-/he-1.2.0.tgz",
"integrity": "sha512-F/1DnUGPopORZi0ni+CvrCgHQ5FyEAHRLSApuYWMmrbSwoN2Mn/7k+Gl38gJnR7yyDZk6WLXwiGod1JOWNDKGw==",
"dev": true,
"license": "MIT",
"bin": {
"he": "bin/he"
}
},
"node_modules/highlight.js": {
"version": "11.11.1",
"resolved": "https://registry.npmjs.org/highlight.js/-/highlight.js-11.11.1.tgz",

@ -4372,6 +4608,19 @@
"node": ">=12.0.0"
}
},
"node_modules/html-encoding-sniffer": {
"version": "3.0.0",
"resolved": "https://registry.npmjs.org/html-encoding-sniffer/-/html-encoding-sniffer-3.0.0.tgz",
"integrity": "sha512-oWv4T4yJ52iKrufjnyZPkrN0CH3QnrUqdB6In1g5Fe1mia8GmF36gnfNySxoZtxD5+NmYw1EElVXiBk93UeskA==",
"dev": true,
"license": "MIT",
"dependencies": {
"whatwg-encoding": "^2.0.0"
},
"engines": {
"node": ">=12"
}
},
"node_modules/html-void-elements": {
"version": "3.0.0",
"resolved": "https://registry.npmjs.org/html-void-elements/-/html-void-elements-3.0.0.tgz",

@ -4382,6 +4631,62 @@
"url": "https://github.com/sponsors/wooorm"
}
},
"node_modules/http-proxy": {
"version": "1.18.1",
"resolved": "https://registry.npmjs.org/http-proxy/-/http-proxy-1.18.1.tgz",
"integrity": "sha512-7mz/721AbnJwIVbnaSv1Cz3Am0ZLT/UBwkC92VlxhXv/k/BBQfM2fXElQNC27BVGr0uwUpplYPQM9LnaBMR5NQ==",
"dev": true,
"license": "MIT",
"dependencies": {
"eventemitter3": "^4.0.0",
"follow-redirects": "^1.0.0",
"requires-port": "^1.0.0"
},
"engines": {
"node": ">=8.0.0"
}
},
"node_modules/http-server": {
"version": "14.1.1",
"resolved": "https://registry.npmjs.org/http-server/-/http-server-14.1.1.tgz",
"integrity": "sha512-+cbxadF40UXd9T01zUHgA+rlo2Bg1Srer4+B4NwIHdaGxAGGv59nYRnGGDJ9LBk7alpS0US+J+bLLdQOOkJq4A==",
"dev": true,
"license": "MIT",
"dependencies": {
"basic-auth": "^2.0.1",
"chalk": "^4.1.2",
"corser": "^2.0.1",
"he": "^1.2.0",
"html-encoding-sniffer": "^3.0.0",
"http-proxy": "^1.18.1",
"mime": "^1.6.0",
"minimist": "^1.2.6",
"opener": "^1.5.1",
"portfinder": "^1.0.28",
"secure-compare": "3.0.1",
"union": "~0.5.0",
"url-join": "^4.0.1"
},
"bin": {
"http-server": "bin/http-server"
},
"engines": {
"node": ">=12"
}
},
"node_modules/iconv-lite": {
"version": "0.6.3",
"resolved": "https://registry.npmjs.org/iconv-lite/-/iconv-lite-0.6.3.tgz",
"integrity": "sha512-4fCk79wshMdzMp2rH06qWrJE4iolqLhCUH+OiuIgU++RB0+94NlDL81atO7GX55uUKueo0txHNtvEyI6D7WdMw==",
"dev": true,
"license": "MIT",
"dependencies": {
"safer-buffer": ">= 2.1.2 < 3.0.0"
},
"engines": {
"node": ">=0.10.0"
}
},
"node_modules/ignore": {
"version": "5.3.2",
"resolved": "https://registry.npmjs.org/ignore/-/ignore-5.3.2.tgz",

@ -5008,6 +5313,16 @@
"url": "https://github.com/sponsors/wooorm"
}
},
"node_modules/math-intrinsics": {
"version": "1.1.0",
"resolved": "https://registry.npmjs.org/math-intrinsics/-/math-intrinsics-1.1.0.tgz",
"integrity": "sha512-/IXtbwEk5HTPyEwyKX6hGkYXxM9nbj64B+ilVJnC/R6B0pH5G4V3b0pVbL7DBj4tkhBAppbQUlf6F6Xl9LHu1g==",
"dev": true,
"license": "MIT",
"engines": {
"node": ">= 0.4"
}
},
"node_modules/mdast": {
"version": "3.0.0",
"resolved": "https://registry.npmjs.org/mdast/-/mdast-3.0.0.tgz",

@ -5976,6 +6291,19 @@
"url": "https://github.com/sponsors/jonschlinkert"
}
},
"node_modules/mime": {
"version": "1.6.0",
"resolved": "https://registry.npmjs.org/mime/-/mime-1.6.0.tgz",
"integrity": "sha512-x0Vn8spI+wuJ1O6S7gnbaQg8Pxh4NNHb7KSINmEWKiPE4RKOplvijn+NkmYmmRgP68mc70j2EbeTFRsrswaQeg==",
"dev": true,
"license": "MIT",
"bin": {
"mime": "cli.js"
},
"engines": {
"node": ">=4"
}
},
"node_modules/min-indent": {
"version": "1.0.1",
"resolved": "https://registry.npmjs.org/min-indent/-/min-indent-1.0.1.tgz",

@ -6009,6 +6337,16 @@
"node": "*"
}
},
"node_modules/minimist": {
"version": "1.2.8",
"resolved": "https://registry.npmjs.org/minimist/-/minimist-1.2.8.tgz",
"integrity": "sha512-2yyAR8qBkN3YuheJanUpWC5U3bb5osDywNB8RzDVlDwDHbocAJveqqj1u8+SVD7jkWT4yvsHCpWqqWqAxb0zCA==",
"dev": true,
"license": "MIT",
"funding": {
"url": "https://github.com/sponsors/ljharb"
}
},
"node_modules/minipass": {
"version": "7.1.2",
"resolved": "https://registry.npmjs.org/minipass/-/minipass-7.1.2.tgz",

@ -6124,6 +6462,19 @@
"tslib": "^2.0.3"
}
},
"node_modules/object-inspect": {
"version": "1.13.4",
"resolved": "https://registry.npmjs.org/object-inspect/-/object-inspect-1.13.4.tgz",
"integrity": "sha512-W67iLl4J2EXEGTbfeHCffrjDfitvLANg0UlX3wFUUSTx92KXRFegMHUVgSqE+wvhAbi4WqjGg9czysTV2Epbew==",
"dev": true,
"license": "MIT",
"engines": {
"node": ">= 0.4"
},
"funding": {
"url": "https://github.com/sponsors/ljharb"
}
},
"node_modules/open": {
"version": "8.4.2",
"resolved": "https://registry.npmjs.org/open/-/open-8.4.2.tgz",

@ -6142,6 +6493,16 @@
"url": "https://github.com/sponsors/sindresorhus"
}
},
"node_modules/opener": {
"version": "1.5.2",
"resolved": "https://registry.npmjs.org/opener/-/opener-1.5.2.tgz",
"integrity": "sha512-ur5UIdyw5Y7yEj9wLzhqXiy6GZ3Mwx0yGI+5sMn2r0N0v3cKJvUmFH5yPP+WXh9e0xfyzyJX95D8l088DNFj7A==",
"dev": true,
"license": "(WTFPL OR MIT)",
"bin": {
"opener": "bin/opener-bin.js"
}
},
"node_modules/optionator": {
"version": "0.9.4",
"resolved": "https://registry.npmjs.org/optionator/-/optionator-0.9.4.tgz",

@ -6330,6 +6691,20 @@
"node": ">=18"
}
},
"node_modules/portfinder": {
"version": "1.0.38",
"resolved": "https://registry.npmjs.org/portfinder/-/portfinder-1.0.38.tgz",
"integrity": "sha512-rEwq/ZHlJIKw++XtLAO8PPuOQA/zaPJOZJ37BVuN97nLpMJeuDVLVGRwbFoBgLudgdTMP2hdRJP++H+8QOA3vg==",
"dev": true,
"license": "MIT",
"dependencies": {
"async": "^3.2.6",
"debug": "^4.3.6"
},
"engines": {
"node": ">= 10.12"
}
},
"node_modules/postcss": {
"version": "8.5.6",
"resolved": "https://registry.npmjs.org/postcss/-/postcss-8.5.6.tgz",

@ -6680,6 +7055,22 @@
"node": ">=6"
}
},
"node_modules/qs": {
"version": "6.14.0",
"resolved": "https://registry.npmjs.org/qs/-/qs-6.14.0.tgz",
"integrity": "sha512-YWWTjgABSKcvs/nWBi9PycY/JiPJqOD4JA6o9Sej2AtvSGarXxKC3OQSk4pAarbdQlKAh5D4FCQkJNkW+GAn3w==",
"dev": true,
"license": "BSD-3-Clause",
"dependencies": {
"side-channel": "^1.1.0"
},
"engines": {
"node": ">=0.6"
},
"funding": {
"url": "https://github.com/sponsors/ljharb"
}
},
"node_modules/queue-microtask": {
"version": "1.2.3",
"resolved": "https://registry.npmjs.org/queue-microtask/-/queue-microtask-1.2.3.tgz",

@ -6959,6 +7350,13 @@
"url": "https://opencollective.com/unified"
}
},
"node_modules/requires-port": {
"version": "1.0.0",
"resolved": "https://registry.npmjs.org/requires-port/-/requires-port-1.0.0.tgz",
"integrity": "sha512-KigOCHcocU3XODJxsu8i/j8T9tzT4adHiecwORRQ0ZZFcp7ahwXuRU1m+yuO90C5ZUyGeGfocHDI14M3L3yDAQ==",
"dev": true,
"license": "MIT"
},
"node_modules/resolve-from": {
"version": "4.0.0",
"resolved": "https://registry.npmjs.org/resolve-from/-/resolve-from-4.0.0.tgz",

@ -7072,6 +7470,20 @@
"node": ">=6"
}
},
"node_modules/safe-buffer": {
"version": "5.1.2",
"resolved": "https://registry.npmjs.org/safe-buffer/-/safe-buffer-5.1.2.tgz",
"integrity": "sha512-Gd2UZBJDkXlY7GbJxfsE8/nvKkUEU1G38c1siN6QP6a9PT9MmHB8GnpscSmMJSoF8LOIrt8ud/wPtojys4G6+g==",
"dev": true,
"license": "MIT"
},
"node_modules/safer-buffer": {
"version": "2.1.2",
"resolved": "https://registry.npmjs.org/safer-buffer/-/safer-buffer-2.1.2.tgz",
"integrity": "sha512-YZo3K82SD7Riyi0E1EQPojLz7kpepnSQI9IyPbHHg1XXXevb5dJI7tpyN2ADxGcQbHG7vcyRHk0cbwqcQriUtg==",
"dev": true,
"license": "MIT"
},
"node_modules/scheduler": {
"version": "0.26.0",
"resolved": "https://registry.npmjs.org/scheduler/-/scheduler-0.26.0.tgz",

@ -7079,6 +7491,13 @@
"dev": true,
"license": "MIT"
},
"node_modules/secure-compare": {
"version": "3.0.1",
"resolved": "https://registry.npmjs.org/secure-compare/-/secure-compare-3.0.1.tgz",
"integrity": "sha512-AckIIV90rPDcBcglUwXPF3kg0P0qmPsPXAj6BBEENQE1p5yA1xfmDJzfi1Tappj37Pv2mVbKpL3Z1T+Nn7k1Qw==",
"dev": true,
"license": "MIT"
},
"node_modules/semver": {
"version": "7.7.2",
"resolved": "https://registry.npmjs.org/semver/-/semver-7.7.2.tgz",

@ -7122,6 +7541,82 @@
"node": ">=8"
}
},
"node_modules/side-channel": {
"version": "1.1.0",
"resolved": "https://registry.npmjs.org/side-channel/-/side-channel-1.1.0.tgz",
"integrity": "sha512-ZX99e6tRweoUXqR+VBrslhda51Nh5MTQwou5tnUDgbtyM0dBgmhEDtWGP/xbKn6hqfPRHujUNwz5fy/wbbhnpw==",
"dev": true,
"license": "MIT",
"dependencies": {
"es-errors": "^1.3.0",
"object-inspect": "^1.13.3",
"side-channel-list": "^1.0.0",
"side-channel-map": "^1.0.1",
"side-channel-weakmap": "^1.0.2"
},
"engines": {
"node": ">= 0.4"
},
"funding": {
"url": "https://github.com/sponsors/ljharb"
}
},
"node_modules/side-channel-list": {
"version": "1.0.0",
"resolved": "https://registry.npmjs.org/side-channel-list/-/side-channel-list-1.0.0.tgz",
"integrity": "sha512-FCLHtRD/gnpCiCHEiJLOwdmFP+wzCmDEkc9y7NsYxeF4u7Btsn1ZuwgwJGxImImHicJArLP4R0yX4c2KCrMrTA==",
"dev": true,
"license": "MIT",
"dependencies": {
"es-errors": "^1.3.0",
"object-inspect": "^1.13.3"
},
"engines": {
"node": ">= 0.4"
},
"funding": {
"url": "https://github.com/sponsors/ljharb"
}
},
"node_modules/side-channel-map": {
"version": "1.0.1",
"resolved": "https://registry.npmjs.org/side-channel-map/-/side-channel-map-1.0.1.tgz",
"integrity": "sha512-VCjCNfgMsby3tTdo02nbjtM/ewra6jPHmpThenkTYh8pG9ucZ/1P8So4u4FGBek/BjpOVsDCMoLA/iuBKIFXRA==",
"dev": true,
"license": "MIT",
"dependencies": {
"call-bound": "^1.0.2",
"es-errors": "^1.3.0",
"get-intrinsic": "^1.2.5",
"object-inspect": "^1.13.3"
},
"engines": {
"node": ">= 0.4"
},
"funding": {
"url": "https://github.com/sponsors/ljharb"
}
},
"node_modules/side-channel-weakmap": {
"version": "1.0.2",
"resolved": "https://registry.npmjs.org/side-channel-weakmap/-/side-channel-weakmap-1.0.2.tgz",
"integrity": "sha512-WPS/HvHQTYnHisLo9McqBHOJk2FkHO/tlpvldyrnem4aeQp4hai3gythswg6p01oSoTl58rcpiFAjF2br2Ak2A==",
"dev": true,
"license": "MIT",
"dependencies": {
"call-bound": "^1.0.2",
"es-errors": "^1.3.0",
"get-intrinsic": "^1.2.5",
"object-inspect": "^1.13.3",
"side-channel-map": "^1.0.1"
},
"engines": {
"node": ">= 0.4"
},
"funding": {
"url": "https://github.com/sponsors/ljharb"
}
},
"node_modules/siginfo": {
"version": "2.0.0",
"resolved": "https://registry.npmjs.org/siginfo/-/siginfo-2.0.0.tgz",

@ -7904,6 +8399,18 @@
"integrity": "sha512-ko/gIFJRv177XgZsZcBwnqJN5x/Gien8qNOn0D5bQU/zAzVf9Zt3BlcUiLqhV9y4ARk0GbT3tnUiPNgnTXzc/Q==",
"license": "MIT"
},
"node_modules/union": {
"version": "0.5.0",
"resolved": "https://registry.npmjs.org/union/-/union-0.5.0.tgz",
"integrity": "sha512-N6uOhuW6zO95P3Mel2I2zMsbsanvvtgn6jVqJv4vbVcz/JN0OkL9suomjQGmWtxJQXOCqUJvquc1sMeNz/IwlA==",
"dev": true,
"dependencies": {
"qs": "^6.4.0"
},
"engines": {
"node": ">= 0.8.0"
}
},
"node_modules/unist-util-find-after": {
"version": "5.0.0",
"resolved": "https://registry.npmjs.org/unist-util-find-after/-/unist-util-find-after-5.0.0.tgz",

@ -8073,6 +8580,13 @@
"punycode": "^2.1.0"
}
},
"node_modules/url-join": {
"version": "4.0.1",
"resolved": "https://registry.npmjs.org/url-join/-/url-join-4.0.1.tgz",
"integrity": "sha512-jk1+QP6ZJqyOiuEI9AEWQfju/nB2Pw466kbA0LEZljHwKeMgd9WrAEgEGxjPDD2+TNbbb37rTyhEfrCXfuKXnA==",
"dev": true,
"license": "MIT"
},
"node_modules/util-deprecate": {
"version": "1.0.2",
"resolved": "https://registry.npmjs.org/util-deprecate/-/util-deprecate-1.0.2.tgz",

@ -8447,6 +8961,19 @@
"dev": true,
"license": "MIT"
},
"node_modules/whatwg-encoding": {
"version": "2.0.0",
"resolved": "https://registry.npmjs.org/whatwg-encoding/-/whatwg-encoding-2.0.0.tgz",
"integrity": "sha512-p41ogyeMUrw3jWclHWTQg1k05DSVXPLcVxRTYsXUk+ZooOCZLcoYgPZ/HL/D/N+uQPOtcp1me1WhBEaX02mhWg==",
"dev": true,
"license": "MIT",
"dependencies": {
"iconv-lite": "0.6.3"
},
"engines": {
"node": ">=12"
}
},
"node_modules/which": {
"version": "2.0.2",
"resolved": "https://registry.npmjs.org/which/-/which-2.0.2.tgz",

@ -52,6 +52,7 @@
"eslint-plugin-svelte": "^3.0.0",
"fflate": "^0.8.2",
"globals": "^16.0.0",
"http-server": "^14.1.1",
"mdast": "^3.0.0",
"mdsvex": "^0.12.3",
"playwright": "^1.53.0",

@ -2,8 +2,10 @@ import { defineConfig } from '@playwright/test';

export default defineConfig({
webServer: {
command: 'npm run build && npx http-server ../public -p 8181',
port: 8181
command: 'npm run build && http-server ../public -p 8181',
port: 8181,
timeout: 120000,
reuseExistingServer: false
},
testDir: 'e2e'
});

@ -31,7 +31,8 @@ import type {
DatabaseMessageExtraAudioFile,
DatabaseMessageExtraImageFile,
DatabaseMessageExtraTextFile,
DatabaseMessageExtraPdfFile
DatabaseMessageExtraPdfFile,
DatabaseMessageExtraLegacyContext
} from '$lib/types/database';

import type {

@ -73,6 +74,7 @@ declare global {
DatabaseMessageExtraImageFile,
DatabaseMessageExtraTextFile,
DatabaseMessageExtraPdfFile,
DatabaseMessageExtraLegacyContext,
SettingsConfigValue,
SettingsFieldConfig,
SettingsConfigType,

@ -94,6 +94,17 @@
attachmentIndex: index,
textContent: attachment.content
});
} else if (attachment.type === 'context') {
// Legacy format from old webui - treat as text file
items.push({
id: `attachment-${index}`,
name: attachment.name,
type: 'text',
isImage: false,
attachment,
attachmentIndex: index,
textContent: attachment.content
});
} else if (attachment.type === 'audioFile') {
items.push({
id: `attachment-${index}`,

@ -26,6 +26,7 @@
MimeTypeImage,
MimeTypeText
} from '$lib/enums/files';
import { isIMEComposing } from '$lib/utils/is-ime-composing';

interface Props {
class?: string;

@ -97,7 +98,7 @@
}

async function handleKeydown(event: KeyboardEvent) {
if (event.key === 'Enter' && !event.shiftKey) {
if (event.key === 'Enter' && !event.shiftKey && !isIMEComposing(event)) {
event.preventDefault();

if ((!message.trim() && uploadedFiles.length === 0) || disabled || isLoading) return;

@ -1,6 +1,7 @@
<script lang="ts">
import { getDeletionInfo } from '$lib/stores/chat.svelte';
import { copyToClipboard } from '$lib/utils/copy';
import { isIMEComposing } from '$lib/utils/is-ime-composing';
import ChatMessageAssistant from './ChatMessageAssistant.svelte';
import ChatMessageUser from './ChatMessageUser.svelte';

@ -93,7 +94,9 @@
}

function handleEditKeydown(event: KeyboardEvent) {
if (event.key === 'Enter' && !event.shiftKey) {
// Check for IME composition using isComposing property and keyCode 229 (specifically for IME composition on Safari)
// This prevents saving edit when confirming IME word selection (e.g., Japanese/Chinese input)
if (event.key === 'Enter' && !event.shiftKey && !isIMEComposing(event)) {
event.preventDefault();
handleSaveEdit();
} else if (event.key === 'Escape') {

@ -7,18 +7,19 @@

const processingState = useProcessingState();

let isCurrentConversationLoading = $derived(isLoading());
let processingDetails = $derived(processingState.getProcessingDetails());
let showSlotsInfo = $derived(isCurrentConversationLoading || config().keepStatsVisible);

let showSlotsInfo = $derived(isLoading() || config().keepStatsVisible);

// Track loading state reactively by checking if conversation ID is in loading conversations array
$effect(() => {
const keepStatsVisible = config().keepStatsVisible;

if (keepStatsVisible || isLoading()) {
if (keepStatsVisible || isCurrentConversationLoading) {
processingState.startMonitoring();
}

if (!isLoading() && !keepStatsVisible) {
if (!isCurrentConversationLoading && !keepStatsVisible) {
setTimeout(() => {
if (!config().keepStatsVisible) {
processingState.stopMonitoring();

@ -27,18 +28,20 @@
}
});

// Update processing state from stored timings
$effect(() => {
activeConversation();

const conversation = activeConversation();
const messages = activeMessages() as DatabaseMessage[];
const keepStatsVisible = config().keepStatsVisible;

if (keepStatsVisible) {
if (keepStatsVisible && conversation) {
if (messages.length === 0) {
slotsService.clearState();
slotsService.clearConversationState(conversation.id);
return;
}

// Search backwards through messages to find most recent assistant message with timing data
// Using reverse iteration for performance - avoids array copy and stops at first match
let foundTimingData = false;

for (let i = messages.length - 1; i >= 0; i--) {

@ -47,15 +50,18 @@
foundTimingData = true;

slotsService
.updateFromTimingData({
prompt_n: message.timings.prompt_n || 0,
predicted_n: message.timings.predicted_n || 0,
predicted_per_second:
message.timings.predicted_n && message.timings.predicted_ms
? (message.timings.predicted_n / message.timings.predicted_ms) * 1000
: 0,
cache_n: message.timings.cache_n || 0
})
.updateFromTimingData(
{
prompt_n: message.timings.prompt_n || 0,
predicted_n: message.timings.predicted_n || 0,
predicted_per_second:
message.timings.predicted_n && message.timings.predicted_ms
? (message.timings.predicted_n / message.timings.predicted_ms) * 1000
: 0,
cache_n: message.timings.cache_n || 0
},
conversation.id
)
.catch((error) => {
console.warn('Failed to update processing state from stored timings:', error);
});

@ -64,7 +70,7 @@
}

if (!foundTimingData) {
slotsService.clearState();
slotsService.clearConversationState(conversation.id);
}
}
});

@ -83,6 +83,8 @@
let activeErrorDialog = $derived(errorDialog());
let isServerLoading = $derived(serverLoading());

let isCurrentConversationLoading = $derived(isLoading());

async function handleDeleteConfirm() {
const conversation = activeConversation();
if (conversation) {

@ -254,7 +256,7 @@
});

$effect(() => {
if (isLoading() && autoScrollEnabled) {
if (isCurrentConversationLoading && autoScrollEnabled) {
scrollInterval = setInterval(scrollChatToBottom, AUTO_SCROLL_INTERVAL);
} else if (scrollInterval) {
clearInterval(scrollInterval);

@ -305,7 +307,7 @@

<div class="conversation-chat-form pointer-events-auto rounded-t-3xl pb-4">
<ChatForm
isLoading={isLoading()}
isLoading={isCurrentConversationLoading}
onFileRemove={handleFileRemove}
onFileUpload={handleFileUpload}
onSend={handleSendMessage}

@ -348,7 +350,7 @@

<div in:fly={{ y: 10, duration: 250, delay: 300 }}>
<ChatForm
isLoading={isLoading()}
isLoading={isCurrentConversationLoading}
onFileRemove={handleFileRemove}
onFileUpload={handleFileUpload}
onSend={handleSendMessage}
|
|||
|
|
@ -4,14 +4,16 @@
	Funnel,
	AlertTriangle,
	Brain,
	Cog,
	Code,
	Monitor,
	Sun,
	Moon,
	ChevronLeft,
	ChevronRight
	ChevronRight,
	Database
} from '@lucide/svelte';
import { ChatSettingsFooter, ChatSettingsFields } from '$lib/components/app';
import ImportExportTab from './ImportExportTab.svelte';
import * as Dialog from '$lib/components/ui/dialog';
import { ScrollArea } from '$lib/components/ui/scroll-area';
import { config, updateMultipleConfig } from '$lib/stores/settings.svelte';

@ -88,9 +90,59 @@
	]
},
{
	title: 'Samplers',
	title: 'Sampling',
	icon: Funnel,
	fields: [
		{
			key: 'temperature',
			label: 'Temperature',
			type: 'input'
		},
		{
			key: 'dynatemp_range',
			label: 'Dynamic temperature range',
			type: 'input'
		},
		{
			key: 'dynatemp_exponent',
			label: 'Dynamic temperature exponent',
			type: 'input'
		},
		{
			key: 'top_k',
			label: 'Top K',
			type: 'input'
		},
		{
			key: 'top_p',
			label: 'Top P',
			type: 'input'
		},
		{
			key: 'min_p',
			label: 'Min P',
			type: 'input'
		},
		{
			key: 'xtc_probability',
			label: 'XTC probability',
			type: 'input'
		},
		{
			key: 'xtc_threshold',
			label: 'XTC threshold',
			type: 'input'
		},
		{
			key: 'typ_p',
			label: 'Typical P',
			type: 'input'
		},
		{
			key: 'max_tokens',
			label: 'Max tokens',
			type: 'input'
		},
		{
			key: 'samplers',
			label: 'Samplers',
@ -152,68 +204,22 @@
		key: 'showThoughtInProgress',
		label: 'Show thought in progress',
		type: 'checkbox'
	},
	{
		key: 'disableReasoningFormat',
		label:
			'Show raw LLM output without backend parsing and frontend Markdown rendering to inspect streaming across different models.',
		type: 'checkbox'
	}
	]
},
{
	title: 'Advanced',
	icon: Cog,
	title: 'Import/Export',
	icon: Database,
	fields: []
},
{
	title: 'Developer',
	icon: Code,
	fields: [
		{
			key: 'temperature',
			label: 'Temperature',
			type: 'input'
		},
		{
			key: 'dynatemp_range',
			label: 'Dynamic temperature range',
			type: 'input'
		},
		{
			key: 'dynatemp_exponent',
			label: 'Dynamic temperature exponent',
			type: 'input'
		},
		{
			key: 'top_k',
			label: 'Top K',
			type: 'input'
		},
		{
			key: 'top_p',
			label: 'Top P',
			type: 'input'
		},
		{
			key: 'min_p',
			label: 'Min P',
			type: 'input'
		},
		{
			key: 'xtc_probability',
			label: 'XTC probability',
			type: 'input'
		},
		{
			key: 'xtc_threshold',
			label: 'XTC threshold',
			type: 'input'
		},
		{
			key: 'typ_p',
			label: 'Typical P',
			type: 'input'
		},
		{
			key: 'max_tokens',
			label: 'Max tokens',
			type: 'input'
		key: 'disableReasoningFormat',
		label: 'Show raw LLM output',
		type: 'checkbox'
		},
		{
			key: 'custom',
@ -456,21 +462,25 @@

<ScrollArea class="max-h-[calc(100dvh-13.5rem)] flex-1 md:max-h-[calc(100vh-13.5rem)]">
	<div class="space-y-6 p-4 md:p-6">
		<div>
		<div class="grid">
			<div class="mb-6 flex hidden items-center gap-2 border-b border-border/30 pb-6 md:flex">
				<currentSection.icon class="h-5 w-5" />

				<h3 class="text-lg font-semibold">{currentSection.title}</h3>
			</div>

			<div class="space-y-6">
				<ChatSettingsFields
					fields={currentSection.fields}
					{localConfig}
					onConfigChange={handleConfigChange}
					onThemeChange={handleThemeChange}
				/>
			</div>
			{#if currentSection.title === 'Import/Export'}
				<ImportExportTab />
			{:else}
				<div class="space-y-6">
					<ChatSettingsFields
						fields={currentSection.fields}
						{localConfig}
						onConfigChange={handleConfigChange}
						onThemeChange={handleThemeChange}
					/>
				</div>
			{/if}
		</div>

		<div class="mt-8 border-t pt-6">
@ -0,0 +1,249 @@
<script lang="ts">
	import { Search, X } from '@lucide/svelte';
	import * as Dialog from '$lib/components/ui/dialog';
	import { Button } from '$lib/components/ui/button';
	import { Input } from '$lib/components/ui/input';
	import { Checkbox } from '$lib/components/ui/checkbox';
	import { ScrollArea } from '$lib/components/ui/scroll-area';
	import { SvelteSet } from 'svelte/reactivity';

	interface Props {
		conversations: DatabaseConversation[];
		messageCountMap?: Map<string, number>;
		mode: 'export' | 'import';
		onCancel: () => void;
		onConfirm: (selectedConversations: DatabaseConversation[]) => void;
		open?: boolean;
	}

	let {
		conversations,
		messageCountMap = new Map(),
		mode,
		onCancel,
		onConfirm,
		open = $bindable(false)
	}: Props = $props();

	let searchQuery = $state('');
	let selectedIds = $state.raw<SvelteSet<string>>(new SvelteSet(conversations.map((c) => c.id)));
	let lastClickedId = $state<string | null>(null);

	let filteredConversations = $derived(
		conversations.filter((conv) => {
			const name = conv.name || 'Untitled conversation';
			return name.toLowerCase().includes(searchQuery.toLowerCase());
		})
	);

	let allSelected = $derived(
		filteredConversations.length > 0 &&
			filteredConversations.every((conv) => selectedIds.has(conv.id))
	);

	let someSelected = $derived(
		filteredConversations.some((conv) => selectedIds.has(conv.id)) && !allSelected
	);

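	// Toggle a single row; a shift-click extends the toggle across the whole
	// range between the last clicked row and this one.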
	function toggleConversation(id: string, shiftKey: boolean = false) {
		const newSet = new SvelteSet(selectedIds);

		if (shiftKey && lastClickedId !== null) {
			const lastIndex = filteredConversations.findIndex((c) => c.id === lastClickedId);
			const currentIndex = filteredConversations.findIndex((c) => c.id === id);

			if (lastIndex !== -1 && currentIndex !== -1) {
				const start = Math.min(lastIndex, currentIndex);
				const end = Math.max(lastIndex, currentIndex);

				const shouldSelect = !newSet.has(id);

				for (let i = start; i <= end; i++) {
					if (shouldSelect) {
						newSet.add(filteredConversations[i].id);
					} else {
						newSet.delete(filteredConversations[i].id);
					}
				}

				selectedIds = newSet;
				return;
			}
		}

		if (newSet.has(id)) {
			newSet.delete(id);
		} else {
			newSet.add(id);
		}

		selectedIds = newSet;
		lastClickedId = id;
	}

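	// Header checkbox: select or clear every conversation currently visible in the filtered list.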
	function toggleAll() {
		if (allSelected) {
			const newSet = new SvelteSet(selectedIds);

			filteredConversations.forEach((conv) => newSet.delete(conv.id));
			selectedIds = newSet;
		} else {
			const newSet = new SvelteSet(selectedIds);

			filteredConversations.forEach((conv) => newSet.add(conv.id));
			selectedIds = newSet;
		}
	}

	function handleConfirm() {
		const selected = conversations.filter((conv) => selectedIds.has(conv.id));
		onConfirm(selected);
	}

	function handleCancel() {
		selectedIds = new SvelteSet(conversations.map((c) => c.id));
		searchQuery = '';
		lastClickedId = null;

		onCancel();
	}

	let previousOpen = $state(false);

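	// Reset selection and search each time the dialog is (re)opened; closing it
	// any other way is treated as a cancel.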
	$effect(() => {
		if (open && !previousOpen) {
			selectedIds = new SvelteSet(conversations.map((c) => c.id));
			searchQuery = '';
			lastClickedId = null;
		} else if (!open && previousOpen) {
			onCancel();
		}

		previousOpen = open;
	});
</script>

<Dialog.Root bind:open>
	<Dialog.Portal>
		<Dialog.Overlay class="z-[1000000]" />

		<Dialog.Content class="z-[1000001] max-w-2xl">
			<Dialog.Header>
				<Dialog.Title>
					Select Conversations to {mode === 'export' ? 'Export' : 'Import'}
				</Dialog.Title>

				<Dialog.Description>
					{#if mode === 'export'}
						Choose which conversations you want to export. Selected conversations will be downloaded
						as a JSON file.
					{:else}
						Choose which conversations you want to import. Selected conversations will be merged
						with your existing conversations.
					{/if}
				</Dialog.Description>
			</Dialog.Header>

			<div class="space-y-4">
				<div class="relative">
					<Search class="absolute top-1/2 left-3 h-4 w-4 -translate-y-1/2 text-muted-foreground" />

					<Input bind:value={searchQuery} placeholder="Search conversations..." class="pr-9 pl-9" />

					{#if searchQuery}
						<button
							class="absolute top-1/2 right-3 -translate-y-1/2 text-muted-foreground hover:text-foreground"
							onclick={() => (searchQuery = '')}
							type="button"
						>
							<X class="h-4 w-4" />
						</button>
					{/if}
				</div>

				<div class="flex items-center justify-between text-sm text-muted-foreground">
					<span>
						{selectedIds.size} of {conversations.length} selected
						{#if searchQuery}
							({filteredConversations.length} shown)
						{/if}
					</span>
				</div>

				<div class="overflow-hidden rounded-md border">
					<ScrollArea class="h-[400px]">
						<table class="w-full">
							<thead class="sticky top-0 z-10 bg-muted">
								<tr class="border-b">
									<th class="w-12 p-3 text-left">
										<Checkbox
											checked={allSelected}
											indeterminate={someSelected}
											onCheckedChange={toggleAll}
										/>
									</th>

									<th class="p-3 text-left text-sm font-medium">Conversation Name</th>

									<th class="w-32 p-3 text-left text-sm font-medium">Messages</th>
								</tr>
							</thead>
							<tbody>
								{#if filteredConversations.length === 0}
									<tr>
										<td colspan="3" class="p-8 text-center text-sm text-muted-foreground">
											{#if searchQuery}
												No conversations found matching "{searchQuery}"
											{:else}
												No conversations available
											{/if}
										</td>
									</tr>
								{:else}
									{#each filteredConversations as conv (conv.id)}
										<tr
											class="cursor-pointer border-b transition-colors hover:bg-muted/50"
											onclick={(e) => toggleConversation(conv.id, e.shiftKey)}
										>
											<td class="p-3">
												<Checkbox
													checked={selectedIds.has(conv.id)}
													onclick={(e) => {
														e.preventDefault();
														e.stopPropagation();
														toggleConversation(conv.id, e.shiftKey);
													}}
												/>
											</td>

											<td class="p-3 text-sm">
												<div
													class="max-w-[17rem] truncate"
													title={conv.name || 'Untitled conversation'}
												>
													{conv.name || 'Untitled conversation'}
												</div>
											</td>

											<td class="p-3 text-sm text-muted-foreground">
												{messageCountMap.get(conv.id) ?? 0}
											</td>
										</tr>
									{/each}
								{/if}
							</tbody>
						</table>
					</ScrollArea>
				</div>
			</div>

			<Dialog.Footer>
				<Button variant="outline" onclick={handleCancel}>Cancel</Button>

				<Button onclick={handleConfirm} disabled={selectedIds.size === 0}>
					{mode === 'export' ? 'Export' : 'Import'} ({selectedIds.size})
				</Button>
			</Dialog.Footer>
		</Dialog.Content>
	</Dialog.Portal>
</Dialog.Root>
@ -0,0 +1,255 @@
<script lang="ts">
	import { Download, Upload } from '@lucide/svelte';
	import { Button } from '$lib/components/ui/button';
	import ConversationSelectionDialog from './ConversationSelectionDialog.svelte';
	import { DatabaseStore } from '$lib/stores/database';
	import type { ExportedConversations } from '$lib/types/database';
	import { createMessageCountMap } from '$lib/utils/conversation-utils';
	import { chatStore } from '$lib/stores/chat.svelte';

	let exportedConversations = $state<DatabaseConversation[]>([]);
	let importedConversations = $state<DatabaseConversation[]>([]);
	let showExportSummary = $state(false);
	let showImportSummary = $state(false);

	let showExportDialog = $state(false);
	let showImportDialog = $state(false);
	let availableConversations = $state<DatabaseConversation[]>([]);
	let messageCountMap = $state<Map<string, number>>(new Map());
	let fullImportData = $state<Array<{ conv: DatabaseConversation; messages: DatabaseMessage[] }>>(
		[]
	);

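	// Gather every conversation plus its message count, then open the export selection dialog.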
	async function handleExportClick() {
		try {
			const allConversations = await DatabaseStore.getAllConversations();
			if (allConversations.length === 0) {
				alert('No conversations to export');
				return;
			}

			const conversationsWithMessages = await Promise.all(
				allConversations.map(async (conv) => {
					const messages = await DatabaseStore.getConversationMessages(conv.id);
					return { conv, messages };
				})
			);

			messageCountMap = createMessageCountMap(conversationsWithMessages);
			availableConversations = allConversations;
			showExportDialog = true;
		} catch (err) {
			console.error('Failed to load conversations:', err);
			alert('Failed to load conversations');
		}
	}

	async function handleExportConfirm(selectedConversations: DatabaseConversation[]) {
		try {
			const allData: ExportedConversations = await Promise.all(
				selectedConversations.map(async (conv) => {
					const messages = await DatabaseStore.getConversationMessages(conv.id);
					return { conv: $state.snapshot(conv), messages: $state.snapshot(messages) };
				})
			);

			const blob = new Blob([JSON.stringify(allData, null, 2)], {
				type: 'application/json'
			});
			const url = URL.createObjectURL(blob);
			const a = document.createElement('a');

			a.href = url;
			a.download = `conversations_${new Date().toISOString().split('T')[0]}.json`;
			document.body.appendChild(a);
			a.click();
			document.body.removeChild(a);
			URL.revokeObjectURL(url);

			exportedConversations = selectedConversations;
			showExportSummary = true;
			showImportSummary = false;
			showExportDialog = false;
		} catch (err) {
			console.error('Export failed:', err);
			alert('Failed to export conversations');
		}
	}

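	// Open a file picker; the chosen JSON may be an exported array or a single
	// { conv, messages } object from the old webui.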
	async function handleImportClick() {
		try {
			const input = document.createElement('input');

			input.type = 'file';
			input.accept = '.json';

			input.onchange = async (e) => {
				const file = (e.target as HTMLInputElement)?.files?.[0];
				if (!file) return;

				try {
					const text = await file.text();
					const parsedData = JSON.parse(text);
					let importedData: ExportedConversations;

					if (Array.isArray(parsedData)) {
						importedData = parsedData;
					} else if (
						parsedData &&
						typeof parsedData === 'object' &&
						'conv' in parsedData &&
						'messages' in parsedData
					) {
						// Single conversation object
						importedData = [parsedData];
					} else {
						throw new Error(
							'Invalid file format: expected array of conversations or single conversation object'
						);
					}

					fullImportData = importedData;
					availableConversations = importedData.map(
						(item: { conv: DatabaseConversation; messages: DatabaseMessage[] }) => item.conv
					);
					messageCountMap = createMessageCountMap(importedData);
					showImportDialog = true;
				} catch (err: unknown) {
					const message = err instanceof Error ? err.message : 'Unknown error';

					console.error('Failed to parse file:', err);
					alert(`Failed to parse file: ${message}`);
				}
			};

			input.click();
		} catch (err) {
			console.error('Import failed:', err);
			alert('Failed to import conversations');
		}
	}

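	// Persist only the conversations picked in the dialog, then reload the conversation list.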
	async function handleImportConfirm(selectedConversations: DatabaseConversation[]) {
		try {
			const selectedIds = new Set(selectedConversations.map((c) => c.id));
			const selectedData = $state
				.snapshot(fullImportData)
				.filter((item) => selectedIds.has(item.conv.id));

			await DatabaseStore.importConversations(selectedData);

			await chatStore.loadConversations();

			importedConversations = selectedConversations;
			showImportSummary = true;
			showExportSummary = false;
			showImportDialog = false;
		} catch (err) {
			console.error('Import failed:', err);
			alert('Failed to import conversations. Please check the file format.');
		}
	}
</script>

<div class="space-y-6">
	<div class="space-y-4">
		<div class="grid">
			<h4 class="mb-2 text-sm font-medium">Export Conversations</h4>

			<p class="mb-4 text-sm text-muted-foreground">
				Download all your conversations as a JSON file. This includes all messages, attachments, and
				conversation history.
			</p>

			<Button
				class="w-full justify-start justify-self-start md:w-auto"
				onclick={handleExportClick}
				variant="outline"
			>
				<Download class="mr-2 h-4 w-4" />

				Export conversations
			</Button>

			{#if showExportSummary && exportedConversations.length > 0}
				<div class="mt-4 grid overflow-x-auto rounded-lg border border-border/50 bg-muted/30 p-4">
					<h5 class="mb-2 text-sm font-medium">
						Exported {exportedConversations.length} conversation{exportedConversations.length === 1
							? ''
							: 's'}
					</h5>

					<ul class="space-y-1 text-sm text-muted-foreground">
						{#each exportedConversations.slice(0, 10) as conv (conv.id)}
							<li class="truncate">• {conv.name || 'Untitled conversation'}</li>
						{/each}

						{#if exportedConversations.length > 10}
							<li class="italic">
								... and {exportedConversations.length - 10} more
							</li>
						{/if}
					</ul>
				</div>
			{/if}
		</div>

		<div class="grid border-t border-border/30 pt-4">
			<h4 class="mb-2 text-sm font-medium">Import Conversations</h4>

			<p class="mb-4 text-sm text-muted-foreground">
				Import one or more conversations from a previously exported JSON file. This will merge with
				your existing conversations.
			</p>

			<Button
				class="w-full justify-start justify-self-start md:w-auto"
				onclick={handleImportClick}
				variant="outline"
			>
				<Upload class="mr-2 h-4 w-4" />
				Import conversations
			</Button>

			{#if showImportSummary && importedConversations.length > 0}
				<div class="mt-4 grid overflow-x-auto rounded-lg border border-border/50 bg-muted/30 p-4">
					<h5 class="mb-2 text-sm font-medium">
						Imported {importedConversations.length} conversation{importedConversations.length === 1
							? ''
							: 's'}
					</h5>

					<ul class="space-y-1 text-sm text-muted-foreground">
						{#each importedConversations.slice(0, 10) as conv (conv.id)}
							<li class="truncate">• {conv.name || 'Untitled conversation'}</li>
						{/each}

						{#if importedConversations.length > 10}
							<li class="italic">
								... and {importedConversations.length - 10} more
							</li>
						{/if}
					</ul>
				</div>
			{/if}
		</div>
	</div>
</div>

<ConversationSelectionDialog
	conversations={availableConversations}
	{messageCountMap}
	mode="export"
	bind:open={showExportDialog}
	onCancel={() => (showExportDialog = false)}
	onConfirm={handleExportConfirm}
/>

<ConversationSelectionDialog
	conversations={availableConversations}
	{messageCountMap}
	mode="import"
	bind:open={showImportDialog}
	onCancel={() => (showImportDialog = false)}
	onConfirm={handleImportConfirm}
/>
@ -1,9 +1,8 @@
<script lang="ts">
	import { Search, SquarePen, X, Download, Upload } from '@lucide/svelte';
	import { Search, SquarePen, X } from '@lucide/svelte';
	import { KeyboardShortcutInfo } from '$lib/components/app';
	import { Button } from '$lib/components/ui/button';
	import { Input } from '$lib/components/ui/input';
	import { exportAllConversations, importConversations } from '$lib/stores/chat.svelte';

	interface Props {
		handleMobileSidebarItemClick: () => void;

@ -78,34 +77,5 @@

			<KeyboardShortcutInfo keys={['cmd', 'k']} />
		</Button>

		<Button
			class="w-full justify-start text-sm"
			onclick={() => {
				importConversations().catch((err) => {
					console.error('Import failed:', err);
					// Optional: show toast or dialog
				});
			}}
			variant="ghost"
		>
			<div class="flex items-center gap-2">
				<Upload class="h-4 w-4" />
				Import conversations
			</div>
		</Button>

		<Button
			class="w-full justify-start text-sm"
			onclick={() => {
				exportAllConversations();
			}}
			variant="ghost"
		>
			<div class="flex items-center gap-2">
				<Download class="h-4 w-4" />
				Export all conversations
			</div>
		</Button>
	{/if}
</div>
@ -1,7 +1,7 @@
<script lang="ts">
	import { Trash2, Pencil, MoreHorizontal, Download } from '@lucide/svelte';
	import { Trash2, Pencil, MoreHorizontal, Download, Loader2 } from '@lucide/svelte';
	import { ActionDropdown } from '$lib/components/app';
	import { downloadConversation } from '$lib/stores/chat.svelte';
	import { downloadConversation, getAllLoadingConversations } from '$lib/stores/chat.svelte';
	import { onMount } from 'svelte';

	interface Props {

@ -25,6 +25,8 @@
	let renderActionsDropdown = $state(false);
	let dropdownOpen = $state(false);

	let isLoading = $derived(getAllLoadingConversations().includes(conversation.id));

	function handleEdit(event: Event) {
		event.stopPropagation();
		onEdit?.(conversation.id);

@ -83,11 +85,16 @@
	onmouseover={handleMouseOver}
	onmouseleave={handleMouseLeave}
>
	<!-- svelte-ignore a11y_click_events_have_key_events -->
	<!-- svelte-ignore a11y_no_static_element_interactions -->
	<span class="truncate text-sm font-medium" onclick={handleMobileSidebarItemClick}>
		{conversation.name}
	</span>
	<div class="flex min-w-0 flex-1 items-center gap-2">
		{#if isLoading}
			<Loader2 class="h-3.5 w-3.5 shrink-0 animate-spin text-muted-foreground" />
		{/if}
		<!-- svelte-ignore a11y_click_events_have_key_events -->
		<!-- svelte-ignore a11y_no_static_element_interactions -->
		<span class="truncate text-sm font-medium" onclick={handleMobileSidebarItemClick}>
			{conversation.name}
		</span>
	</div>

	{#if renderActionsDropdown}
		<div class="actions flex items-center">
@ -25,6 +25,8 @@ export { default as ChatScreen } from './chat/ChatScreen/ChatScreen.svelte';
export { default as ChatSettingsDialog } from './chat/ChatSettings/ChatSettingsDialog.svelte';
export { default as ChatSettingsFooter } from './chat/ChatSettings/ChatSettingsFooter.svelte';
export { default as ChatSettingsFields } from './chat/ChatSettings/ChatSettingsFields.svelte';
export { default as ImportExportTab } from './chat/ChatSettings/ImportExportTab.svelte';
export { default as ConversationSelectionDialog } from './chat/ChatSettings/ConversationSelectionDialog.svelte';
export { default as ParameterSourceIndicator } from './chat/ChatSettings/ParameterSourceIndicator.svelte';

export { default as ChatSidebar } from './chat/ChatSidebar/ChatSidebar.svelte';
@ -154,9 +154,20 @@
	return mutated ? tempDiv.innerHTML : html;
}

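// Rewrite LaTeX-style \[ ... \] and \( ... \) delimiters into the $$ ... $$ / $ ... $
// forms the markdown math pipeline recognizes; the (^|[^\\]) prefix leaves
// backslash-escaped delimiters untouched.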
function normalizeMathDelimiters(text: string): string {
	return text
		.replace(/(^|[^\\])\\\[((?:\\.|[\s\S])*?)\\\]/g, (_, prefix: string, content: string) => {
			return `${prefix}$$${content}$$`;
		})
		.replace(/(^|[^\\])\\\(((?:\\.|[\s\S])*?)\\\)/g, (_, prefix: string, content: string) => {
			return `${prefix}$${content}$`;
		});
}

async function processMarkdown(text: string): Promise<string> {
	try {
		const result = await processor().process(text);
		const normalized = normalizeMathDelimiters(text);
		const result = await processor().process(normalized);
		const html = String(result);
		const enhancedLinks = enhanceLinks(html);

@ -29,7 +29,7 @@ import { slotsService } from './slots';
 * - Request lifecycle management (abort, cleanup)
 */
export class ChatService {
	private abortController: AbortController | null = null;
	private abortControllers: Map<string, AbortController> = new Map();

	/**
	 * Sends a chat completion request to the llama.cpp server.

@ -43,7 +43,8 @@ export class ChatService {
	 */
	async sendMessage(
		messages: ApiChatMessageData[] | (DatabaseMessage & { extra?: DatabaseMessageExtra[] })[],
		options: SettingsChatServiceOptions = {}
		options: SettingsChatServiceOptions = {},
		conversationId?: string
	): Promise<string | void> {
		const {
			stream,

@ -79,25 +80,25 @@ export class ChatService {

		const currentConfig = config();

		// Cancel any ongoing request and create a new abort controller
		this.abort();
		this.abortController = new AbortController();
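		// One in-flight request per conversation; requests without an id share the 'default' slot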
		const requestId = conversationId || 'default';

		if (this.abortControllers.has(requestId)) {
			this.abortControllers.get(requestId)?.abort();
		}

		const abortController = new AbortController();
		this.abortControllers.set(requestId, abortController);

		// Convert database messages with attachments to API format if needed
		const normalizedMessages: ApiChatMessageData[] = messages
			.map((msg) => {
				// Check if this is a DatabaseMessage by checking for DatabaseMessage-specific fields
				if ('id' in msg && 'convId' in msg && 'timestamp' in msg) {
					// This is a DatabaseMessage, convert it
					const dbMsg = msg as DatabaseMessage & { extra?: DatabaseMessageExtra[] };
					return ChatService.convertMessageToChatServiceData(dbMsg);
				} else {
					// This is already an ApiChatMessageData object
					return msg as ApiChatMessageData;
				}
			})
			.filter((msg) => {
				// Filter out empty system messages
				if (msg.role === 'system') {
					const content = typeof msg.content === 'string' ? msg.content : '';

@ -107,7 +108,6 @@ export class ChatService {
				return true;
			});

		// Build base request body with system message injection
		const processedMessages = this.injectSystemMessage(normalizedMessages);

		const requestBody: ApiChatCompletionRequest = {

@ -172,11 +172,10 @@ export class ChatService {
				...(apiKey ? { Authorization: `Bearer ${apiKey}` } : {})
			},
			body: JSON.stringify(requestBody),
			signal: this.abortController.signal
			signal: abortController.signal
		});

		if (!response.ok) {
			// Use the new parseErrorResponse method to handle structured errors
			const error = await this.parseErrorResponse(response);
			if (onError) {
				onError(error);

@ -185,13 +184,16 @@ export class ChatService {
		}

		if (stream) {
			return this.handleStreamResponse(
			await this.handleStreamResponse(
				response,
				onChunk,
				onComplete,
				onError,
				options.onReasoningChunk
				options.onReasoningChunk,
				conversationId,
				abortController.signal
			);
			return;
		} else {
			return this.handleNonStreamResponse(response, onComplete, onError);
		}

@ -227,18 +229,19 @@ export class ChatService {
				onError(userFriendlyError);
			}
			throw userFriendlyError;
		} finally {
			this.abortControllers.delete(requestId);
		}
	}

	/**
	 * Handles streaming response from the chat completion API.
	 * Processes server-sent events and extracts content chunks from the stream.
	 *
	 * @param response - The fetch Response object containing the streaming data
	 * Handles streaming response from the chat completion API
	 * @param response - The Response object from the fetch request
	 * @param onChunk - Optional callback invoked for each content chunk received
	 * @param onComplete - Optional callback invoked when the stream is complete with full response
	 * @param onError - Optional callback invoked if an error occurs during streaming
	 * @param onReasoningChunk - Optional callback invoked for each reasoning content chunk
	 * @param conversationId - Optional conversation ID for per-conversation state tracking
	 * @returns {Promise<void>} Promise that resolves when streaming is complete
	 * @throws {Error} if the stream cannot be read or parsed
	 */

@ -251,7 +254,9 @@ export class ChatService {
			timings?: ChatMessageTimings
		) => void,
		onError?: (error: Error) => void,
		onReasoningChunk?: (chunk: string) => void
		onReasoningChunk?: (chunk: string) => void,
		conversationId?: string,
		abortSignal?: AbortSignal
	): Promise<void> {
		const reader = response.body?.getReader();

@ -269,14 +274,20 @@ export class ChatService {
		try {
			let chunk = '';
			while (true) {
				if (abortSignal?.aborted) break;

				const { done, value } = await reader.read();
				if (done) break;

				if (abortSignal?.aborted) break;

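				// SSE frames can split across reads: keep the trailing partial line buffered
				// and re-join it with the next decoded chunk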
				chunk += decoder.decode(value, { stream: true });
				const lines = chunk.split('\n');
				chunk = lines.pop() || ''; // Save incomplete line for next read
				chunk = lines.pop() || '';

				for (const line of lines) {
					if (abortSignal?.aborted) break;

					if (line.startsWith('data: ')) {
						const data = line.slice(6);
						if (data === '[DONE]') {

@ -293,9 +304,7 @@ export class ChatService {
						const promptProgress = parsed.prompt_progress;

						if (timings || promptProgress) {
							this.updateProcessingState(timings, promptProgress);

							// Store the latest timing data
							this.updateProcessingState(timings, promptProgress, conversationId);
							if (timings) {
								lastTimings = timings;
							}

@ -304,21 +313,29 @@ export class ChatService {
						if (content) {
							hasReceivedData = true;
							aggregatedContent += content;
							onChunk?.(content);
							if (!abortSignal?.aborted) {
								onChunk?.(content);
							}
						}

						if (reasoningContent) {
							hasReceivedData = true;
							fullReasoningContent += reasoningContent;
							onReasoningChunk?.(reasoningContent);
							if (!abortSignal?.aborted) {
								onReasoningChunk?.(reasoningContent);
							}
						}
					} catch (e) {
						console.error('Error parsing JSON chunk:', e);
					}
				}
			}

			if (abortSignal?.aborted) break;
		}

		if (abortSignal?.aborted) return;

		if (streamFinished) {
			if (!hasReceivedData && aggregatedContent.length === 0) {
				const noResponseError = new Error('No response received from server. Please try again.');

@ -445,6 +462,19 @@ export class ChatService {
			});
		}

		// Handle legacy 'context' type from old webui (pasted content)
		const legacyContextFiles = message.extra.filter(
			(extra: DatabaseMessageExtra): extra is DatabaseMessageExtraLegacyContext =>
				extra.type === 'context'
		);

		for (const legacyContextFile of legacyContextFiles) {
			contentParts.push({
				type: 'text',
				text: `\n\n--- File: ${legacyContextFile.name} ---\n${legacyContextFile.content}`
			});
		}

		const audioFiles = message.extra.filter(
			(extra: DatabaseMessageExtra): extra is DatabaseMessageExtraAudioFile =>
				extra.type === 'audioFile'

@ -520,10 +550,18 @@ export class ChatService {
	 *
	 * @public
	 */
	public abort(): void {
		if (this.abortController) {
			this.abortController.abort();
			this.abortController = null;
	public abort(conversationId?: string): void {
		if (conversationId) {
			const abortController = this.abortControllers.get(conversationId);
			if (abortController) {
				abortController.abort();
				this.abortControllers.delete(conversationId);
			}
		} else {
			for (const controller of this.abortControllers.values()) {
				controller.abort();
			}
			this.abortControllers.clear();
		}
	}

@ -581,7 +619,6 @@ export class ChatService {

			return error;
		} catch {
			// If we can't parse the error response, return a generic error
			const fallback = new Error(`Server error (${response.status}): ${response.statusText}`);
			fallback.name = 'HttpError';
			return fallback;

@ -590,23 +627,25 @@ export class ChatService {

	private updateProcessingState(
		timings?: ChatMessageTimings,
		promptProgress?: ChatMessagePromptProgress
		promptProgress?: ChatMessagePromptProgress,
		conversationId?: string
	): void {
		// Calculate tokens per second from timing data
		const tokensPerSecond =
			timings?.predicted_ms && timings?.predicted_n
				? (timings.predicted_n / timings.predicted_ms) * 1000
				: 0;

		// Update slots service with timing data (async but don't wait)
		slotsService
			.updateFromTimingData({
				prompt_n: timings?.prompt_n || 0,
				predicted_n: timings?.predicted_n || 0,
				predicted_per_second: tokensPerSecond,
				cache_n: timings?.cache_n || 0,
				prompt_progress: promptProgress
			})
			.updateFromTimingData(
				{
					prompt_n: timings?.prompt_n || 0,
					predicted_n: timings?.predicted_n || 0,
					predicted_per_second: tokensPerSecond,
					cache_n: timings?.cache_n || 0,
					prompt_progress: promptProgress
				},
				conversationId
			)
			.catch((error) => {
				console.warn('Failed to update processing state:', error);
			});
@ -37,6 +37,8 @@ export class SlotsService {
	private callbacks: Set<(state: ApiProcessingState | null) => void> = new Set();
	private isStreamingActive: boolean = false;
	private lastKnownState: ApiProcessingState | null = null;
	private conversationStates: Map<string, ApiProcessingState | null> = new Map();
	private activeConversationId: string | null = null;

	/**
	 * Start streaming session tracking

@ -75,6 +77,62 @@ export class SlotsService {
		return this.isStreamingActive;
	}

	/**
	 * Set the active conversation for statistics display
	 */
	setActiveConversation(conversationId: string | null): void {
		this.activeConversationId = conversationId;
		this.notifyCallbacks();
	}

	/**
	 * Update processing state for a specific conversation
	 */
	updateConversationState(conversationId: string, state: ApiProcessingState | null): void {
		this.conversationStates.set(conversationId, state);

		if (conversationId === this.activeConversationId) {
			this.lastKnownState = state;
			this.notifyCallbacks();
		}
	}

	/**
	 * Get processing state for a specific conversation
	 */
	getConversationState(conversationId: string): ApiProcessingState | null {
		return this.conversationStates.get(conversationId) || null;
	}

	/**
	 * Clear state for a specific conversation
	 */
	clearConversationState(conversationId: string): void {
		this.conversationStates.delete(conversationId);

		if (conversationId === this.activeConversationId) {
			this.lastKnownState = null;
			this.notifyCallbacks();
		}
	}

	/**
	 * Notify all callbacks with current state
	 */
	private notifyCallbacks(): void {
		const currentState = this.activeConversationId
			? this.conversationStates.get(this.activeConversationId) || null
			: this.lastKnownState;

		for (const callback of this.callbacks) {
			try {
				callback(currentState);
			} catch (error) {
				console.error('Error in slots service callback:', error);
			}
		}
	}

	/**
	 * @deprecated Polling is no longer used - timing data comes from ChatService streaming response
	 * This method logs a warning if called to help identify outdated usage

@ -100,29 +158,29 @@ export class SlotsService {
	/**
	 * Updates processing state with timing data from ChatService streaming response
	 */
	async updateFromTimingData(timingData: {
		prompt_n: number;
		predicted_n: number;
		predicted_per_second: number;
		cache_n: number;
		prompt_progress?: ChatMessagePromptProgress;
	}): Promise<void> {
	async updateFromTimingData(
		timingData: {
			prompt_n: number;
			predicted_n: number;
			predicted_per_second: number;
			cache_n: number;
			prompt_progress?: ChatMessagePromptProgress;
		},
		conversationId?: string
	): Promise<void> {
		const processingState = await this.parseCompletionTimingData(timingData);

		// Only update if we successfully parsed the state
		if (processingState === null) {
			console.warn('Failed to parse timing data - skipping update');

			return;
		}

		this.lastKnownState = processingState;

		for (const callback of this.callbacks) {
			try {
				callback(processingState);
			} catch (error) {
				console.error('Error in timing callback:', error);
			}
		if (conversationId) {
			this.updateConversationState(conversationId, processingState);
		} else {
			this.lastKnownState = processingState;
			this.notifyCallbacks();
		}
	}

@ -143,6 +201,7 @@ export class SlotsService {
				...(apiKey ? { Authorization: `Bearer ${apiKey}` } : {})
			}
		});

		if (response.ok) {
			const slotsData = await response.json();
			if (Array.isArray(slotsData) && slotsData.length > 0) {

@ -179,6 +238,7 @@ export class SlotsService {

		if (contextTotal === null) {
			console.warn('No context total available - cannot calculate processing state');

			return null;
		}

@ -214,13 +274,21 @@ export class SlotsService {
	/**
	 * Get current processing state
	 * Returns the last known state from timing data, or null if no data available
	 * If activeConversationId is set, returns state for that conversation
	 */
	async getCurrentState(): Promise<ApiProcessingState | null> {
		if (this.activeConversationId) {
			const conversationState = this.conversationStates.get(this.activeConversationId);

			if (conversationState) {
				return conversationState;
			}
		}

		if (this.lastKnownState) {
			return this.lastKnownState;
		}
		try {
			// Import dynamically to avoid circular dependency
			const { chatStore } = await import('$lib/stores/chat.svelte');
			const messages = chatStore.activeMessages;

@ -6,6 +6,7 @@ import { filterByLeafNodeId, findLeafNode, findDescendantMessages } from '$lib/u
import { browser } from '$app/environment';
import { goto } from '$app/navigation';
import { toast } from 'svelte-sonner';
import { SvelteMap } from 'svelte/reactivity';
import type { ExportedConversations } from '$lib/types/database';

/**

@ -50,6 +51,8 @@ class ChatStore {
	errorDialogState = $state<{ type: 'timeout' | 'server'; message: string } | null>(null);
	isInitialized = $state(false);
	isLoading = $state(false);
	conversationLoadingStates = new SvelteMap<string, boolean>();
	conversationStreamingStates = new SvelteMap<string, { response: string; messageId: string }>();
	titleUpdateConfirmationCallback?: (currentTitle: string, newTitle: string) => Promise<boolean>;

	constructor() {

@ -94,6 +97,13 @@ class ChatStore {
		this.activeConversation = conversation;
		this.activeMessages = [];

		slotsService.setActiveConversation(conversation.id);

		const isConvLoading = this.isConversationLoading(conversation.id);
		this.isLoading = isConvLoading;

		this.currentResponse = '';

		await goto(`#/chat/${conversation.id}`);

		return conversation.id;

@ -114,6 +124,14 @@ class ChatStore {

		this.activeConversation = conversation;

		slotsService.setActiveConversation(convId);

		const isConvLoading = this.isConversationLoading(convId);
		this.isLoading = isConvLoading;

		const streamingState = this.getConversationStreaming(convId);
		this.currentResponse = streamingState?.response || '';

		if (conversation.currNode) {
			const allMessages = await DatabaseStore.getConversationMessages(convId);
			this.activeMessages = filterByLeafNodeId(

@ -285,6 +303,47 @@ class ChatStore {
		return apiOptions;
	}

	/**
	 * Helper methods for per-conversation loading state management
	 */
	private setConversationLoading(convId: string, loading: boolean): void {
		if (loading) {
			this.conversationLoadingStates.set(convId, true);
			if (this.activeConversation?.id === convId) {
				this.isLoading = true;
			}
		} else {
			this.conversationLoadingStates.delete(convId);
			if (this.activeConversation?.id === convId) {
				this.isLoading = false;
			}
		}
	}

	private isConversationLoading(convId: string): boolean {
		return this.conversationLoadingStates.get(convId) || false;
	}

	private setConversationStreaming(convId: string, response: string, messageId: string): void {
		this.conversationStreamingStates.set(convId, { response, messageId });
		if (this.activeConversation?.id === convId) {
			this.currentResponse = response;
		}
	}

	private clearConversationStreaming(convId: string): void {
		this.conversationStreamingStates.delete(convId);
		if (this.activeConversation?.id === convId) {
			this.currentResponse = '';
		}
	}

	private getConversationStreaming(
		convId: string
	): { response: string; messageId: string } | undefined {
		return this.conversationStreamingStates.get(convId);
	}

	/**
	 * Handles streaming chat completion with the AI model
	 * @param allMessages - All messages in the conversation

@ -325,125 +384,132 @@ class ChatStore {
		};

		slotsService.startStreaming();
		slotsService.setActiveConversation(assistantMessage.convId);

		await chatService.sendMessage(allMessages, {
			...this.getApiOptions(),
		await chatService.sendMessage(
			allMessages,
			{
				...this.getApiOptions(),

			onChunk: (chunk: string) => {
				streamedContent += chunk;
				this.currentResponse = streamedContent;
				onChunk: (chunk: string) => {
					streamedContent += chunk;
					this.setConversationStreaming(
						assistantMessage.convId,
						streamedContent,
						assistantMessage.id
					);

				captureModelIfNeeded();
				const messageIndex = this.findMessageIndex(assistantMessage.id);
				this.updateMessageAtIndex(messageIndex, {
					content: streamedContent
				});
			},
					captureModelIfNeeded();
					const messageIndex = this.findMessageIndex(assistantMessage.id);
					this.updateMessageAtIndex(messageIndex, {
						content: streamedContent
					});
				},

			onReasoningChunk: (reasoningChunk: string) => {
				streamedReasoningContent += reasoningChunk;
				onReasoningChunk: (reasoningChunk: string) => {
					streamedReasoningContent += reasoningChunk;

				captureModelIfNeeded();
					captureModelIfNeeded();

				const messageIndex = this.findMessageIndex(assistantMessage.id);
					const messageIndex = this.findMessageIndex(assistantMessage.id);

				this.updateMessageAtIndex(messageIndex, { thinking: streamedReasoningContent });
			},
					this.updateMessageAtIndex(messageIndex, { thinking: streamedReasoningContent });
				},

			onComplete: async (
				finalContent?: string,
				reasoningContent?: string,
				timings?: ChatMessageTimings
			) => {
				slotsService.stopStreaming();
				onComplete: async (
					finalContent?: string,
					reasoningContent?: string,
					timings?: ChatMessageTimings
				) => {
					slotsService.stopStreaming();

				const updateData: {
					content: string;
					thinking: string;
					timings?: ChatMessageTimings;
					model?: string;
				} = {
					content: finalContent || streamedContent,
					thinking: reasoningContent || streamedReasoningContent,
					timings: timings
				};
					const updateData: {
						content: string;
						thinking: string;
						timings?: ChatMessageTimings;
						model?: string;
					} = {
						content: finalContent || streamedContent,
						thinking: reasoningContent || streamedReasoningContent,
						timings: timings
					};

				const capturedModel = captureModelIfNeeded(false);
					const capturedModel = captureModelIfNeeded(false);

				if (capturedModel) {
					updateData.model = capturedModel;
				}
					if (capturedModel) {
						updateData.model = capturedModel;
					}

				await DatabaseStore.updateMessage(assistantMessage.id, updateData);
					await DatabaseStore.updateMessage(assistantMessage.id, updateData);

				const messageIndex = this.findMessageIndex(assistantMessage.id);
					const messageIndex = this.findMessageIndex(assistantMessage.id);

				const localUpdateData: { timings?: ChatMessageTimings; model?: string } = {
					timings: timings
				};
					const localUpdateData: { timings?: ChatMessageTimings; model?: string } = {
						timings: timings
					};

				if (updateData.model) {
					localUpdateData.model = updateData.model;
				}
					if (updateData.model) {
						localUpdateData.model = updateData.model;
					}

				this.updateMessageAtIndex(messageIndex, localUpdateData);
					this.updateMessageAtIndex(messageIndex, localUpdateData);

				await DatabaseStore.updateCurrentNode(this.activeConversation!.id, assistantMessage.id);
				this.activeConversation!.currNode = assistantMessage.id;
				await this.refreshActiveMessages();
					await DatabaseStore.updateCurrentNode(assistantMessage.convId, assistantMessage.id);

				if (onComplete) {
					await onComplete(streamedContent);
				}
					if (this.activeConversation?.id === assistantMessage.convId) {
						this.activeConversation.currNode = assistantMessage.id;
						await this.refreshActiveMessages();
					}

				this.isLoading = false;
				this.currentResponse = '';
			},
					if (onComplete) {
						await onComplete(streamedContent);
					}

			onError: (error: Error) => {
				slotsService.stopStreaming();
					this.setConversationLoading(assistantMessage.convId, false);
					this.clearConversationStreaming(assistantMessage.convId);
					slotsService.clearConversationState(assistantMessage.convId);
				},

				if (error.name === 'AbortError' || error instanceof DOMException) {
					this.isLoading = false;
					this.currentResponse = '';
					return;
				}
				onError: (error: Error) => {
					slotsService.stopStreaming();

				console.error('Streaming error:', error);
				this.isLoading = false;
				this.currentResponse = '';
					if (this.isAbortError(error)) {
						this.setConversationLoading(assistantMessage.convId, false);
						this.clearConversationStreaming(assistantMessage.convId);
						slotsService.clearConversationState(assistantMessage.convId);
						return;
					}

				const messageIndex = this.activeMessages.findIndex(
					(m: DatabaseMessage) => m.id === assistantMessage.id
				);
					console.error('Streaming error:', error);
					this.setConversationLoading(assistantMessage.convId, false);
					this.clearConversationStreaming(assistantMessage.convId);
					slotsService.clearConversationState(assistantMessage.convId);

				if (messageIndex !== -1) {
					const [failedMessage] = this.activeMessages.splice(messageIndex, 1);
					const messageIndex = this.activeMessages.findIndex(
						(m: DatabaseMessage) => m.id === assistantMessage.id
					);

					if (failedMessage) {
						DatabaseStore.deleteMessage(failedMessage.id).catch((cleanupError) => {
							console.error('Failed to remove assistant message after error:', cleanupError);
						});
					if (messageIndex !== -1) {
						const [failedMessage] = this.activeMessages.splice(messageIndex, 1);

						if (failedMessage) {
							DatabaseStore.deleteMessage(failedMessage.id).catch((cleanupError) => {
								console.error('Failed to remove assistant message after error:', cleanupError);
							});
						}
					}

				const dialogType = error.name === 'TimeoutError' ? 'timeout' : 'server';

				this.showErrorDialog(dialogType, error.message);

				if (onError) {
					onError(error);
				}
			}

					const dialogType = error.name === 'TimeoutError' ? 'timeout' : 'server';

					this.showErrorDialog(dialogType, error.message);

					if (onError) {
						onError(error);
					}
				}
		});
	}

	private showErrorDialog(type: 'timeout' | 'server', message: string): void {
		this.errorDialogState = { type, message };
	}

	dismissErrorDialog(): void {
		this.errorDialogState = null;
			},
			assistantMessage.convId
		);
	}

	/**

@ -455,6 +521,14 @@ class ChatStore {
		return error instanceof Error && (error.name === 'AbortError' || error instanceof DOMException);
	}

	private showErrorDialog(type: 'timeout' | 'server', message: string): void {
		this.errorDialogState = { type, message };
	}

	dismissErrorDialog(): void {
		this.errorDialogState = null;
	}

	/**
	 * Finds the index of a message in the active messages array
	 * @param messageId - The message ID to find

@ -519,7 +593,12 @@ class ChatStore {
	 * @param extras - Optional extra data (files, attachments, etc.)
	 */
	async sendMessage(content: string, extras?: DatabaseMessageExtra[]): Promise<void> {
		if ((!content.trim() && (!extras || extras.length === 0)) || this.isLoading) return;
		if (!content.trim() && (!extras || extras.length === 0)) return;

		if (this.activeConversation && this.isConversationLoading(this.activeConversation.id)) {
			console.log('Cannot send message: current conversation is already processing a message');
			return;
		}

		let isNewConversation = false;

@ -534,8 +613,9 @@ class ChatStore {
		}

		this.errorDialogState = null;
		this.isLoading = true;
		this.currentResponse = '';

		this.setConversationLoading(this.activeConversation.id, true);
		this.clearConversationStreaming(this.activeConversation.id);

		let userMessage: DatabaseMessage | null = null;

@ -546,7 +626,6 @@ class ChatStore {
			throw new Error('Failed to add user message');
		}

		// If this is a new conversation, update the title with the first user prompt
		if (isNewConversation && content) {
			const title = content.trim();
			await this.updateConversationName(this.activeConversation.id, title);

@ -559,19 +638,18 @@ class ChatStore {
		}

		this.activeMessages.push(assistantMessage);
		// Don't update currNode until after streaming completes to maintain proper conversation path

		const conversationContext = this.activeMessages.slice(0, -1);

		await this.streamChatCompletion(conversationContext, assistantMessage);
	} catch (error) {
		if (this.isAbortError(error)) {
			this.isLoading = false;
			this.setConversationLoading(this.activeConversation!.id, false);
			return;
		}

		console.error('Failed to send message:', error);
		this.isLoading = false;
		this.setConversationLoading(this.activeConversation!.id, false);
		if (!this.errorDialogState) {
			if (error instanceof Error) {
				const dialogType = error.name === 'TimeoutError' ? 'timeout' : 'server';

@ -587,12 +665,19 @@ class ChatStore {
	 * Stops the current message generation
	 * Aborts ongoing requests and saves partial response if available
	 */
	stopGeneration(): void {
	async stopGeneration(): Promise<void> {
		if (!this.activeConversation) return;

		const convId = this.activeConversation.id;

		await this.savePartialResponseIfNeeded(convId);

		slotsService.stopStreaming();
		chatService.abort();
		this.savePartialResponseIfNeeded();
		this.isLoading = false;
		this.currentResponse = '';
		chatService.abort(convId);

		this.setConversationLoading(convId, false);
		this.clearConversationStreaming(convId);
		slotsService.clearConversationState(convId);
	}

	/**

@ -604,6 +689,9 @@ class ChatStore {
		slotsService.stopStreaming();
		chatService.abort();
		await this.savePartialResponseIfNeeded();

		this.conversationLoadingStates.clear();
		this.conversationStreamingStates.clear();
		this.isLoading = false;
		this.currentResponse = '';
	}

@ -612,12 +700,23 @@ class ChatStore {
	 * Saves partial response if generation was interrupted
	 * Preserves user's partial content and timing data when generation is stopped early
	 */
	private async savePartialResponseIfNeeded(): Promise<void> {
		if (!this.currentResponse.trim() || !this.activeMessages.length) {
	private async savePartialResponseIfNeeded(convId?: string): Promise<void> {
		const conversationId = convId || this.activeConversation?.id;
		if (!conversationId) return;

		const streamingState = this.conversationStreamingStates.get(conversationId);
		if (!streamingState || !streamingState.response.trim()) {
			return;
		}

		const lastMessage = this.activeMessages[this.activeMessages.length - 1];
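		// For a background conversation the messages are not in memory, so read them from the database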
const messages =
|
||||
conversationId === this.activeConversation?.id
|
||||
? this.activeMessages
|
||||
: await DatabaseStore.getConversationMessages(conversationId);
|
||||
|
||||
if (!messages.length) return;
|
||||
|
||||
const lastMessage = messages[messages.length - 1];
|
||||
|
||||
if (lastMessage && lastMessage.role === 'assistant') {
|
||||
try {
|
||||
|
|
@ -626,7 +725,7 @@ class ChatStore {
|
|||
thinking?: string;
|
||||
timings?: ChatMessageTimings;
|
||||
} = {
|
||||
content: this.currentResponse
|
||||
content: streamingState.response
|
||||
};
|
||||
|
||||
if (lastMessage.thinking?.trim()) {
|
||||
|
|
@ -640,7 +739,6 @@ class ChatStore {
|
|||
prompt_n: lastKnownState.promptTokens || 0,
|
||||
predicted_n: lastKnownState.tokensDecoded || 0,
|
||||
cache_n: lastKnownState.cacheTokens || 0,
|
||||
// We don't have ms data from the state, but we can estimate
|
||||
predicted_ms:
|
||||
lastKnownState.tokensPerSecond && lastKnownState.tokensDecoded
|
||||
? (lastKnownState.tokensDecoded / lastKnownState.tokensPerSecond) * 1000
|
||||
|
|
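
The `predicted_ms` fallback above derives elapsed decode time from the streaming rate, since the saved state carries token counts and tokens/second but no wall-clock data. The same arithmetic as a standalone helper (a sketch; the field names follow the hunk):

// Example sketch: predicted_ms = tokensDecoded / tokensPerSecond * 1000
function estimatePredictedMs(tokensDecoded: number, tokensPerSecond: number): number {
  if (!tokensDecoded || !tokensPerSecond) return 0;
  return (tokensDecoded / tokensPerSecond) * 1000;
}

// e.g. 150 tokens at 30 tokens/s -> 5000 ms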

@ -701,7 +799,6 @@ class ChatStore {
      this.updateMessageAtIndex(messageIndex, { content: newContent });
      await DatabaseStore.updateMessage(messageId, { content: newContent });

      // If this is the first user message, update the conversation title with confirmation if needed
      if (isFirstUserMessage && newContent.trim()) {
        await this.updateConversationTitleWithConfirmation(
          this.activeConversation.id,

@ -718,8 +815,8 @@ class ChatStore {
      this.activeMessages = this.activeMessages.slice(0, messageIndex + 1);
      this.updateConversationTimestamp();

      this.isLoading = true;
      this.currentResponse = '';
      this.setConversationLoading(this.activeConversation.id, true);
      this.clearConversationStreaming(this.activeConversation.id);

      try {
        const assistantMessage = await this.createAssistantMessage();

@ -742,7 +839,7 @@ class ChatStore {
        );
      } catch (regenerateError) {
        console.error('Failed to regenerate response:', regenerateError);
        this.isLoading = false;
        this.setConversationLoading(this.activeConversation!.id, false);

        const messageIndex = this.findMessageIndex(messageId);
        this.updateMessageAtIndex(messageIndex, { content: originalContent });

@ -784,8 +881,8 @@ class ChatStore {
      this.activeMessages = this.activeMessages.slice(0, messageIndex);
      this.updateConversationTimestamp();

      this.isLoading = true;
      this.currentResponse = '';
      this.setConversationLoading(this.activeConversation.id, true);
      this.clearConversationStreaming(this.activeConversation.id);

      try {
        const parentMessageId =

@ -806,7 +903,7 @@ class ChatStore {
        await this.streamChatCompletion(conversationContext, assistantMessage);
      } catch (regenerateError) {
        console.error('Failed to regenerate response:', regenerateError);
        this.isLoading = false;
        this.setConversationLoading(this.activeConversation!.id, false);
      }
    } catch (error) {
      if (this.isAbortError(error)) return;

@ -862,7 +959,6 @@ class ChatStore {
    try {
      const currentConfig = config();

      // Only ask for confirmation if the setting is enabled and callback is provided
      if (currentConfig.askForTitleConfirmation && onConfirmationNeeded) {
        const conversation = await DatabaseStore.getConversation(convId);
        if (!conversation) return false;

@ -944,8 +1040,9 @@ class ChatStore {

  /**
   * Exports all conversations with their messages as a JSON file
   * Returns the list of exported conversations
   */
  async exportAllConversations(): Promise<void> {
  async exportAllConversations(): Promise<DatabaseConversation[]> {
    try {
      const allConversations = await DatabaseStore.getAllConversations();
      if (allConversations.length === 0) {

@ -972,6 +1069,7 @@ class ChatStore {
      URL.revokeObjectURL(url);

      toast.success(`All conversations (${allConversations.length}) prepared for download`);
      return allConversations;
    } catch (err) {
      console.error('Failed to export conversations:', err);
      throw err;
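
`exportAllConversations()` drives a standard client-side JSON download: serialize to a Blob, trigger a click on a temporary anchor, then revoke the object URL (the `URL.revokeObjectURL(url)` call visible above). A self-contained sketch of that browser flow (the helper name is ours, not the store's):

// Example sketch: download arbitrary data as a JSON file in the browser.
function downloadJson(data: unknown, filename: string): void {
  const blob = new Blob([JSON.stringify(data, null, 2)], { type: 'application/json' });
  const url = URL.createObjectURL(blob);

  const a = document.createElement('a');
  a.href = url;
  a.download = filename;
  a.click();

  // Release the object URL once the download has been dispatched.
  URL.revokeObjectURL(url);
}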

@ -982,8 +1080,9 @@ class ChatStore {
   * Imports conversations from a JSON file.
   * Supports both single conversation (object) and multiple conversations (array).
   * Uses DatabaseStore for safe, encapsulated data access
   * Returns the list of imported conversations
   */
  async importConversations(): Promise<void> {
  async importConversations(): Promise<DatabaseConversation[]> {
    return new Promise((resolve, reject) => {
      const input = document.createElement('input');
      input.type = 'file';

@ -1024,7 +1123,9 @@ class ChatStore {

          toast.success(`Imported ${result.imported} conversation(s), skipped ${result.skipped}`);

          resolve(undefined);
          // Extract the conversation objects from imported data
          const importedConversations = importedData.map((item) => item.conv);
          resolve(importedConversations);
        } catch (err: unknown) {
          const message = err instanceof Error ? err.message : 'Unknown error';
          console.error('Failed to import conversations:', err);

@ -1170,14 +1271,16 @@ class ChatStore {
  }

  /**
   * Clears the active conversation and resets state
   * Clears the active conversation and messages
   * Used when navigating away from chat or starting fresh
   * Note: Does not stop ongoing streaming to allow background completion
   */
  clearActiveConversation(): void {
    this.activeConversation = null;
    this.activeMessages = [];
    this.currentResponse = '';
    this.isLoading = false;
    this.currentResponse = '';
    slotsService.setActiveConversation(null);
  }

  /** Refreshes active messages based on currNode after branch navigation */

@ -1419,8 +1522,8 @@ class ChatStore {
      return;
    }

    this.isLoading = true;
    this.currentResponse = '';
    this.setConversationLoading(this.activeConversation.id, true);
    this.clearConversationStreaming(this.activeConversation.id);

    const newAssistantMessage = await DatabaseStore.createMessageBranch(
      {

@ -1454,7 +1557,7 @@ class ChatStore {
      if (this.isAbortError(error)) return;

      console.error('Failed to regenerate message with branching:', error);
      this.isLoading = false;
      this.setConversationLoading(this.activeConversation!.id, false);
    }
  }

@ -1466,8 +1569,8 @@ class ChatStore {
    if (!this.activeConversation) return;

    this.errorDialogState = null;
    this.isLoading = true;
    this.currentResponse = '';
    this.setConversationLoading(this.activeConversation.id, true);
    this.clearConversationStreaming(this.activeConversation.id);

    try {
      // Get conversation path up to the user message

@ -1499,9 +1602,30 @@ class ChatStore {
      await this.streamChatCompletion(conversationPath, assistantMessage);
    } catch (error) {
      console.error('Failed to generate response:', error);
      this.isLoading = false;
      this.setConversationLoading(this.activeConversation!.id, false);
    }
  }

  /**
   * Public methods for accessing per-conversation states
   */
  public isConversationLoadingPublic(convId: string): boolean {
    return this.isConversationLoading(convId);
  }

  public getConversationStreamingPublic(
    convId: string
  ): { response: string; messageId: string } | undefined {
    return this.getConversationStreaming(convId);
  }

  public getAllLoadingConversations(): string[] {
    return Array.from(this.conversationLoadingStates.keys());
  }

  public getAllStreamingConversations(): string[] {
    return Array.from(this.conversationStreamingStates.keys());
  }
}

export const chatStore = new ChatStore();

@ -1541,3 +1665,11 @@ export function stopGeneration() {
  chatStore.stopGeneration();
}
export const messages = () => chatStore.activeMessages;

// Per-conversation state access
export const isConversationLoading = (convId: string) =>
  chatStore.isConversationLoadingPublic(convId);
export const getConversationStreaming = (convId: string) =>
  chatStore.getConversationStreamingPublic(convId);
export const getAllLoadingConversations = () => chatStore.getAllLoadingConversations();
export const getAllStreamingConversations = () => chatStore.getAllStreamingConversations();
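
These module-level wrappers let components read per-conversation state without touching the store internals. An illustrative consumer (the helper below is hypothetical, not part of the diff):

// Example sketch: derive sidebar state for one conversation.
import { isConversationLoading, getConversationStreaming } from '$lib/stores/chat.svelte';

function sidebarItemState(convId: string) {
  return {
    loading: isConversationLoading(convId), // show a spinner while generating
    partial: getConversationStreaming(convId)?.response ?? '' // live preview text
  };
}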

@ -34,11 +34,22 @@ export interface DatabaseMessageExtraPdfFile {
  processedAsImages: boolean; // Whether PDF was processed as images
}

/**
 * Legacy format from old webui - pasted content was stored as "context" type
 * @deprecated Use DatabaseMessageExtraTextFile instead
 */
export interface DatabaseMessageExtraLegacyContext {
  type: 'context';
  name: string;
  content: string;
}

export type DatabaseMessageExtra =
  | DatabaseMessageExtraImageFile
  | DatabaseMessageExtraTextFile
  | DatabaseMessageExtraAudioFile
  | DatabaseMessageExtraPdfFile;
  | DatabaseMessageExtraPdfFile
  | DatabaseMessageExtraLegacyContext;

export interface DatabaseMessage {
  id: string;
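
Since `DatabaseMessageExtra` is a discriminated union on `type`, code that still encounters legacy `'context'` entries can narrow and upgrade them with a plain switch. A sketch of such a migration (the `'textFile'` tag and the field mapping onto `DatabaseMessageExtraTextFile` are assumptions, following the `@deprecated` note above):

// Example sketch: upgrade a legacy 'context' extra to the current text-file shape.
function upgradeExtra(extra: DatabaseMessageExtra): DatabaseMessageExtra {
  switch (extra.type) {
    case 'context':
      // Assumed target shape; legacy pasted content becomes a text attachment.
      return { type: 'textFile', name: extra.name, content: extra.content } as DatabaseMessageExtra;
    default:
      return extra;
  }
}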

@ -0,0 +1,30 @@
/**
 * Utility functions for conversation data manipulation
 */

/**
 * Creates a map of conversation IDs to their message counts from exported conversation data
 * @param exportedData - Array of exported conversations with their messages
 * @returns Map of conversation ID to message count
 */
export function createMessageCountMap(
  exportedData: Array<{ conv: DatabaseConversation; messages: DatabaseMessage[] }>
): Map<string, number> {
  const countMap = new Map<string, number>();

  for (const item of exportedData) {
    countMap.set(item.conv.id, item.messages.length);
  }

  return countMap;
}

/**
 * Gets the message count for a specific conversation from the count map
 * @param conversationId - The ID of the conversation
 * @param countMap - Map of conversation IDs to message counts
 * @returns The message count, or 0 if not found
 */
export function getMessageCount(conversationId: string, countMap: Map<string, number>): number {
  return countMap.get(conversationId) ?? 0;
}
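
These helpers pair with the new `DatabaseConversation[]` return values of `exportAllConversations()` and `importConversations()`: build the map once from the exported payload, then look counts up per conversation. A small usage sketch (the literal objects are stand-ins for real records):

// Example sketch of createMessageCountMap / getMessageCount.
const exportedData = [
  { conv: { id: 'a' } as DatabaseConversation, messages: [{}, {}] as DatabaseMessage[] },
  { conv: { id: 'b' } as DatabaseConversation, messages: [{}] as DatabaseMessage[] }
];

const counts = createMessageCountMap(exportedData); // Map { 'a' => 2, 'b' => 1 }
getMessageCount('a', counts); // 2
getMessageCount('missing', counts); // 0 (the ?? 0 fallback)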

@ -0,0 +1,5 @@
export function isIMEComposing(event: KeyboardEvent) {
  // Check for IME composition using isComposing property and keyCode 229 (specifically for IME composition on Safari, which is notorious for not supporting KeyboardEvent.isComposing)
  // This prevents form submission when confirming IME word selection (e.g., Japanese/Chinese input)
  return event.isComposing || event.keyCode === 229;
}
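
A typical call site guards an Enter-to-send handler so that confirming an IME candidate does not submit the form (the handler names below are illustrative):

// Example sketch: skip submission while the user is composing via an IME.
function onKeydown(event: KeyboardEvent) {
  if (event.key === 'Enter' && !event.shiftKey && !isIMEComposing(event)) {
    event.preventDefault();
    sendMessage(); // hypothetical submit handler
  }
}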

@ -1,45 +1,26 @@
<script lang="ts">
  import { goto } from '$app/navigation';
  import { page } from '$app/state';
  import { beforeNavigate } from '$app/navigation';
  import { ChatScreen } from '$lib/components/app';
  import {
    chatStore,
    activeConversation,
    isLoading,
    stopGeneration,
    gracefulStop
    stopGeneration
  } from '$lib/stores/chat.svelte';
  import { onDestroy } from 'svelte';

  let chatId = $derived(page.params.id);
  let currentChatId: string | undefined = undefined;

  beforeNavigate(async ({ cancel, to }) => {
    if (isLoading()) {
      console.log(
        'Navigation detected while streaming - aborting stream and saving partial response'
      );

      cancel();

      await gracefulStop();

      if (to?.url) {
        await goto(to.url.pathname + to.url.search + to.url.hash);
      }
    }
  });

  $effect(() => {
    if (chatId && chatId !== currentChatId) {
      if (isLoading()) {
        console.log('Chat switch detected while streaming - aborting stream');
        stopGeneration();
      }

      currentChatId = chatId;

      // Skip loading if this conversation is already active (e.g., just created)
      if (activeConversation()?.id === chatId) {
        return;
      }

      (async () => {
        const success = await chatStore.loadConversation(chatId);
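
The `beforeNavigate` block removed above followed SvelteKit's cancel-then-resume pattern: cancel the pending navigation, finish cleanup while the user stays on the page, then replay the same navigation. A minimal sketch of that pattern for reference (now superseded by background per-conversation streaming):

// Example sketch of the removed guard, using the store's own isLoading/stopGeneration.
beforeNavigate(async ({ cancel, to }) => {
  if (!isLoading()) return;

  cancel(); // keep the user here while the stream is flushed
  await stopGeneration(); // abort the request and save the partial response

  if (to?.url) {
    await goto(to.url.pathname + to.url.search + to.url.hash); // resume navigation
  }
});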

@ -66,12 +47,6 @@
      };
    }
  });

  onDestroy(() => {
    if (isLoading()) {
      stopGeneration();
    }
  });
</script>

<svelte:head>

@ -7,6 +7,7 @@ const config = {
  // Consult https://svelte.dev/docs/kit/integrations
  // for more information about preprocessors
  preprocess: [vitePreprocess(), mdsvex()],

  kit: {
    paths: {
      relative: true

@ -23,6 +24,7 @@ const config = {
      bundleStrategy: 'inline'
    }
  },

  extensions: ['.svelte', '.svx']
};

@ -75,7 +75,12 @@ function llamaCppBuildPlugin() {
}

export default defineConfig({
  build: {
    chunkSizeWarningLimit: 3072
  },

  plugins: [tailwindcss(), sveltekit(), devtoolsJson(), llamaCppBuildPlugin()],

  test: {
    projects: [
      {

@ -123,6 +128,7 @@ export default defineConfig({
      }
    ]
  },

  server: {
    proxy: {
      '/v1': 'http://localhost:8080',