Merge branch 'ggml-org:master' into master

commit c0f86114f9
@@ -16,7 +16,7 @@ The project differentiates between 3 levels of contributors:
 - If you modified a `ggml` operator or added a new one, add the corresponding test cases to `test-backend-ops`
 - Create separate PRs for each feature or fix. Avoid combining unrelated changes in a single PR
 - Consider allowing write access to your branch for faster reviews, as reviewers can push commits directly
-- If your PR becomes stale, don't hesitate to ping the maintainers in the comments
+- If your PR becomes stale, rebase it on top of latest `master` to get maintainers attention
 - Maintainers will rely on your insights and approval when making a final decision to approve and merge a PR
 - Consider adding yourself to [CODEOWNERS](CODEOWNERS) to indicate your availability for reviewing related PRs
 - Using AI to generate PRs is permitted. However, you must (1) explicitly disclose how AI was used and (2) conduct a thorough manual review before publishing the PR. Note that trivial tab autocompletions do not require disclosure.
@@ -1524,6 +1524,79 @@ class TextModel(ModelBase):
         special_vocab._set_special_token("bos", 151643)
         special_vocab.add_to_gguf(self.gguf_writer)
 
+    def _set_vocab_mistral(self):
+        if not _mistral_common_installed:
+            raise ImportError(_mistral_import_error_msg)
+
+        vocab = MistralVocab(self.dir_model)
+        logger.info(
+            f"Converting tokenizer {vocab.tokenizer_type} of size {vocab.vocab_size}."
+        )
+
+        self.gguf_writer.add_tokenizer_model(vocab.gguf_tokenizer_model)
+
+        tokens = []
+        scores = []
+        toktypes = []
+
+        for text, score, toktype in vocab.all_tokens():
+            tokens.append(text)
+            scores.append(score)
+            toktypes.append(toktype)
+
+        assert len(tokens) == vocab.vocab_size, (
+            f"token count ({len(tokens)}) != vocab size ({vocab.vocab_size})"
+        )
+
+        if vocab.tokenizer_type == MistralTokenizerType.tekken:
+            self.gguf_writer.add_tokenizer_pre("tekken")
+            self.gguf_writer.add_token_merges(
+                vocab.extract_vocab_merges_from_model()
+            )
+
+        logger.info(
+            f"Setting bos, eos, unk and pad token IDs to {vocab.bos_id}, {vocab.eos_id}, {vocab.unk_id}, {vocab.pad_id}."
+        )
+
+        self.gguf_writer.add_bos_token_id(vocab.bos_id)
+        self.gguf_writer.add_eos_token_id(vocab.eos_id)
+        self.gguf_writer.add_unk_token_id(vocab.unk_id)
+        self.gguf_writer.add_pad_token_id(vocab.pad_id)
+
+        self.gguf_writer.add_token_list(tokens)
+        self.gguf_writer.add_token_scores(scores)
+        self.gguf_writer.add_token_types(toktypes)
+        self.gguf_writer.add_vocab_size(vocab.vocab_size)
+
+        self.gguf_writer.add_add_bos_token(True)
+        self.gguf_writer.add_add_eos_token(False)
+
+        local_template_file_path = self.dir_model / "chat_template.jinja"
+
+        if self.is_mistral_format and local_template_file_path.is_file():
+            # Ministral-3 and other new Mistral models come with chat templates.
+            # ref: https://huggingface.co/mistralai/Ministral-3-14B-Instruct-2512/tree/main
+            logger.info("Using an existing Mistral local chat template.")
+
+            with open(local_template_file_path, "r", encoding="utf-8") as f:
+                template = f.read()
+        elif not self.is_mistral_format or not self.disable_mistral_community_chat_template:
+            template_dir = Path(__file__).parent / "models/templates/"
+
+            # Log only for Mistral format that the official tokenization and detokenization is via `mistral-common`.
+            if self.is_mistral_format:
+                logger.info(
+                    "Using a Mistral community chat template. These templates can be subject to errors in early days or weeks after a release. "
+                    "Mistral recommends to use `mistral-common` to perform tokenization and detokenization."
+                )
+            template = MistralModel.get_community_chat_template(vocab, template_dir, self.is_mistral_format)
+        else:
+            logger.info("Not using a Mistral local or community chat template. Ensure to perform the tokenization and detokenization via `mistral-common`.")
+            template = None
+
+        if template is not None:
+            self.gguf_writer.add_chat_template(template)
+
+
 class MmprojModel(ModelBase):
     model_type = ModelType.MMPROJ
 

@@ -2294,79 +2367,6 @@ class LlamaModel(TextModel):
         if self.hf_arch == "VLlama3ForCausalLM":
             self.hparams["num_attention_heads"] = self.hparams.get("num_attention_heads", 32)
 
-    def _set_vocab_mistral(self):
-        if not _mistral_common_installed:
-            raise ImportError(_mistral_import_error_msg)
-
-        vocab = MistralVocab(self.dir_model)
-        logger.info(
-            f"Converting tokenizer {vocab.tokenizer_type} of size {vocab.vocab_size}."
-        )
-
-        self.gguf_writer.add_tokenizer_model(vocab.gguf_tokenizer_model)
-
-        tokens = []
-        scores = []
-        toktypes = []
-
-        for text, score, toktype in vocab.all_tokens():
-            tokens.append(text)
-            scores.append(score)
-            toktypes.append(toktype)
-
-        assert len(tokens) == vocab.vocab_size, (
-            f"token count ({len(tokens)}) != vocab size ({vocab.vocab_size})"
-        )
-
-        if vocab.tokenizer_type == MistralTokenizerType.tekken:
-            self.gguf_writer.add_tokenizer_pre("tekken")
-            self.gguf_writer.add_token_merges(
-                vocab.extract_vocab_merges_from_model()
-            )
-
-        logger.info(
-            f"Setting bos, eos, unk and pad token IDs to {vocab.bos_id}, {vocab.eos_id}, {vocab.unk_id}, {vocab.pad_id}."
-        )
-
-        self.gguf_writer.add_bos_token_id(vocab.bos_id)
-        self.gguf_writer.add_eos_token_id(vocab.eos_id)
-        self.gguf_writer.add_unk_token_id(vocab.unk_id)
-        self.gguf_writer.add_pad_token_id(vocab.pad_id)
-
-        self.gguf_writer.add_token_list(tokens)
-        self.gguf_writer.add_token_scores(scores)
-        self.gguf_writer.add_token_types(toktypes)
-        self.gguf_writer.add_vocab_size(vocab.vocab_size)
-
-        self.gguf_writer.add_add_bos_token(True)
-        self.gguf_writer.add_add_eos_token(False)
-
-        local_template_file_path = self.dir_model / "chat_template.jinja"
-
-        if self.is_mistral_format and local_template_file_path.is_file():
-            # Ministral-3 and other new Mistral models come with chat templates.
-            # ref: https://huggingface.co/mistralai/Ministral-3-14B-Instruct-2512/tree/main
-            logger.info("Using an existing Mistral local chat template.")
-
-            with open(local_template_file_path, "r", encoding="utf-8") as f:
-                template = f.read()
-        elif not self.is_mistral_format or not self.disable_mistral_community_chat_template:
-            template_dir = Path(__file__).parent / "models/templates/"
-
-            # Log only for Mistral format that the official tokenization and detokenization is via `mistral-common`.
-            if self.is_mistral_format:
-                logger.info(
-                    "Using a Mistral community chat template. These templates can be subject to errors in early days or weeks after a release. "
-                    "Mistral recommends to use `mistral-common` to perform tokenization and detokenization."
-                )
-            template = MistralModel.get_community_chat_template(vocab, template_dir, self.is_mistral_format)
-        else:
-            logger.info("Not using a Mistral local or community chat template. Ensure to perform the tokenization and detokenization via `mistral-common`.")
-            template = None
-
-        if template is not None:
-            self.gguf_writer.add_chat_template(template)
-
     def set_vocab(self):
         if self.is_mistral_format:
             return self._set_vocab_mistral()

@@ -9924,17 +9924,109 @@ class MistralModel(LlamaModel):
 
     def set_gguf_parameters(self):
         super().set_gguf_parameters()
-        if "yarn" in self.hparams:
-            yarn_params = self.hparams["yarn"]
-            self.gguf_writer.add_rope_scaling_type(gguf.RopeScalingType.YARN)
-            self.gguf_writer.add_rope_scaling_factor(yarn_params["factor"])
-            self.gguf_writer.add_rope_scaling_yarn_beta_fast(yarn_params["beta"])
-            self.gguf_writer.add_rope_scaling_yarn_beta_slow(yarn_params["alpha"])
-            self.gguf_writer.add_rope_scaling_yarn_log_mul(1.0) # mscale_all_dim
-            self.gguf_writer.add_rope_scaling_orig_ctx_len(yarn_params["original_max_position_embeddings"])
+        MistralModel.set_mistral_config(self.gguf_writer, self.hparams)
 
-        if "llama_4_scaling" in self.hparams:
-            self.gguf_writer.add_attn_temperature_scale(self.hparams["llama_4_scaling"]["beta"])
+    @staticmethod
+    def set_mistral_config(gguf_writer: gguf.GGUFWriter, hparams: dict):
+        if "yarn" in hparams:
+            yarn_params = hparams["yarn"]
+            gguf_writer.add_rope_scaling_type(gguf.RopeScalingType.YARN)
+            gguf_writer.add_rope_scaling_factor(yarn_params["factor"])
+            gguf_writer.add_rope_scaling_yarn_beta_fast(yarn_params["beta"])
+            gguf_writer.add_rope_scaling_yarn_beta_slow(yarn_params["alpha"])
+            gguf_writer.add_rope_scaling_yarn_log_mul(1.0) # mscale_all_dim
+            gguf_writer.add_rope_scaling_orig_ctx_len(yarn_params["original_max_position_embeddings"])
+
+        if "llama_4_scaling" in hparams:
+            gguf_writer.add_attn_temperature_scale(hparams["llama_4_scaling"]["beta"])
+
+
+class MistralMoeModel(DeepseekV2Model):
+    model_arch = gguf.MODEL_ARCH.DEEPSEEK2
+    model_name = "Mistral"
+    hf_arch = ""
+    is_mistral_format = True
+
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        logger.info("Using MistralMoeModel")
+        # remap hparams from Mistral MoE format to DeepseekV2 format
+        # we do this way to be able to reuse DeepseekV2Model set_gguf_parameters logic
+        # ref: https://github.com/vllm-project/vllm/blob/b294e28db2c5dee61bc25157664edcada8b90b31/vllm/transformers_utils/configs/mistral.py
+        config = self.hparams
+        # Mistral key -> HF key
+        config_mapping = {
+            "dim": "hidden_size",
+            "norm_eps": "rms_norm_eps",
+            "n_kv_heads": "num_key_value_heads",
+            "n_layers": "num_hidden_layers",
+            "n_heads": "num_attention_heads",
+            "hidden_dim": "intermediate_size",
+        }
+        # HF key -> (Mistral key, default value)
+        top_level_mapping_with_default = {
+            "model_type": ("model_type", "transformer"),
+            "hidden_act": ("activation", "silu"),
+            "tie_word_embeddings": ("tied_embeddings", False),
+            "max_seq_len": ("max_seq_len", config.get("max_position_embeddings", 128_000)),
+            "max_position_embeddings": ("max_position_embeddings", 128_000),
+        }
+        # mapping top-level keys
+        for key, new_key in config_mapping.items():
+            if key in config:
+                config[new_key] = config[key]
+        for new_key, (key, default_value) in top_level_mapping_with_default.items():
+            config[new_key] = config.get(key, default_value)
+        # mapping MoE-specific keys
+        moe_config_map = {
+            "route_every_n": "moe_layer_freq",
+            "first_k_dense_replace": "first_k_dense_replace",
+            "num_experts_per_tok": "num_experts_per_tok",
+            "num_experts": "n_routed_experts",
+            "expert_hidden_dim": "moe_intermediate_size",
+            "routed_scale": "routed_scaling_factor",
+            "num_shared_experts": "n_shared_experts",
+            "num_expert_groups": "n_group",
+            "num_expert_groups_per_tok": "topk_group",
+        }
+        moe = config["moe"]
+        for key, new_key in moe_config_map.items():
+            if key in moe:
+                config[new_key] = moe[key]
+        # provide missing values
+        config["topk_method"] = None
+        config["norm_topk_prob"] = True
+        config["scoring_func"] = "softmax"
+
+    def set_vocab(self):
+        self._set_vocab_mistral()
+
+    def set_gguf_parameters(self):
+        super().set_gguf_parameters()
+        MistralModel.set_mistral_config(self.gguf_writer, self.hparams)
+        yarn_params = self.hparams["yarn"]
+        self.gguf_writer.add_attn_temperature_length(yarn_params["original_max_position_embeddings"])
+        self.gguf_writer.add_rope_scaling_yarn_log_mul(0.1) # mscale_all_dim * 0.1
+
+    def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None):
+        if name.startswith("vision_") or name.startswith("patch_merger.") or "mm_projector" in name:
+            return []
+
+        # rename certain tensors so that we can reuse DeepseekV2Model modify_tensors logic
+        if name.endswith(".qscale_act"):
+            name = name.replace(".qscale_act", ".input_scale")
+        if name.endswith(".qscale_weight"):
+            name = name.replace(".qscale_weight", ".weight_scale")
+        if ".wkv_b." in name:
+            name = name.replace(".wkv_b.", ".kv_b_proj.")
+        if ".experts." in name:
+            name = name.replace(".experts.", ".mlp.experts.")
+            name = name.replace(".w1.", ".gate_proj.")
+            name = name.replace(".w2.", ".down_proj.")
+            name = name.replace(".w3.", ".up_proj.")
+            name = "model." + name
+
+        return super().modify_tensors(data_torch, name, bid)
 
 
 class PixtralModel(LlavaVisionModel):

@@ -10490,6 +10582,8 @@ def main() -> None:
     elif args.mmproj:
         assert hparams.get("vision_encoder") is not None, "This model does not support multimodal"
         model_class = PixtralModel
+    elif "moe" in hparams:
+        model_class = MistralMoeModel
     else:
        model_class = MistralModel
@@ -241,6 +241,12 @@ int main(int argc, char ** argv) {
 
     llama_batch_free(batch);
 
+    // this one is managed by common_init_result
+    //llama_free(ctx);
+
+    llama_free(ctx2);
+    llama_free(ctx3);
+
     if (result0 != result2) {
         fprintf(stderr, "\n%s : error : the seq restore generation is different\n", __func__);
         return 1;
@@ -2196,6 +2196,15 @@ extern "C" {
             int                   p2,
             int                   p3);
 
+    // pad each dimension with values on the other side of the torus (looping around)
+    GGML_API struct ggml_tensor * ggml_pad_circular(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a,
+            int                   p0,
+            int                   p1,
+            int                   p2,
+            int                   p3);
+
     GGML_API struct ggml_tensor * ggml_pad_ext(
             struct ggml_context * ctx,
             struct ggml_tensor  * a,

@@ -2209,6 +2218,19 @@ extern "C" {
             int                   rp3
             );
 
+    // pad each dimension with values on the other side of the torus (looping around)
+    GGML_API struct ggml_tensor * ggml_pad_ext_circular(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a,
+            int                   lp0,
+            int                   rp0,
+            int                   lp1,
+            int                   rp1,
+            int                   lp2,
+            int                   rp2,
+            int                   lp3,
+            int                   rp3);
+
     // pad each dimension with reflection: [a, b, c, d] -> [b, a, b, c, d, c]
     GGML_API struct ggml_tensor * ggml_pad_reflect_1d(
             struct ggml_context * ctx,
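The two declarations above mirror `ggml_pad`/`ggml_pad_ext`, but out-of-range reads wrap around the tensor instead of being zero-filled. A minimal usage sketch (not part of this commit; it assumes the CPU backend via `ggml-cpu.h` and a `no_alloc = false` context):

    #include <cstdio>
    #include "ggml.h"
    #include "ggml-cpu.h"

    int main() {
        struct ggml_init_params ip = { /*mem_size*/ 16*1024*1024, /*mem_buffer*/ NULL, /*no_alloc*/ false };
        struct ggml_context * ctx = ggml_init(ip);

        struct ggml_tensor * a = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 4);
        float * ad = (float *) a->data;
        ad[0] = 0; ad[1] = 1; ad[2] = 2; ad[3] = 3;

        // two elements of circular padding on each side of dim 0
        struct ggml_tensor * p = ggml_pad_ext_circular(ctx, a, 2, 2, 0, 0, 0, 0, 0, 0);

        struct ggml_cgraph * gf = ggml_new_graph(ctx);
        ggml_build_forward_expand(gf, p);
        ggml_graph_compute_with_ctx(ctx, gf, /*n_threads*/ 1);

        // prints: 2 3 0 1 2 3 0 1 -- both pads read "around" the torus
        for (int i = 0; i < 8; ++i) {
            printf("%g ", ((const float *) p->data)[i]);
        }
        printf("\n");

        ggml_free(ctx);
        return 0;
    }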
@@ -534,8 +534,12 @@ static ggml_backend_reg_t ggml_backend_load_best(const char * name, bool silent,
     fs::path best_path;
 
     for (const auto & search_path : search_paths) {
-        if (!fs::exists(search_path)) {
-            GGML_LOG_DEBUG("%s: search path %s does not exist\n", __func__, path_str(search_path).c_str());
+        if (std::error_code ec; !fs::exists(search_path, ec)) {
+            if (ec) {
+                GGML_LOG_DEBUG("%s: posix_stat(%s) failure, error-message: %s\n", __func__, path_str(search_path).c_str(), ec.message().c_str());
+            } else {
+                GGML_LOG_DEBUG("%s: search path %s does not exist\n", __func__, path_str(search_path).c_str());
+            }
             continue;
         }
         fs::directory_iterator dir_it(search_path, fs::directory_options::skip_permission_denied);

@@ -575,8 +579,12 @@ static ggml_backend_reg_t ggml_backend_load_best(const char * name, bool silent,
     for (const auto & search_path : search_paths) {
         fs::path filename = backend_filename_prefix().native() + name_path.native() + backend_filename_extension().native();
         fs::path path = search_path / filename;
-        if (fs::exists(path)) {
+        if (std::error_code ec; fs::exists(path, ec)) {
             return get_reg().load_backend(path, silent);
+        } else {
+            if (ec) {
+                GGML_LOG_DEBUG("%s: posix_stat(%s) failure, error-message: %s\n", __func__, path_str(path).c_str(), ec.message().c_str());
+            }
         }
     }
     return nullptr;
@@ -2551,6 +2551,8 @@ static bool ggml_backend_cann_supports_op(ggml_backend_dev_t dev, const ggml_ten
         case GGML_OP_ACC:
         case GGML_OP_GROUP_NORM:
         case GGML_OP_PAD:
+            // TODO: add circular padding support for cann, see https://github.com/ggml-org/llama.cpp/pull/16985
+            return ggml_get_op_params_i32(op, 8) == 0;
         case GGML_OP_ARANGE:
         case GGML_OP_TIMESTEP_EMBEDDING:
         case GGML_OP_LEAKY_RELU:
@@ -6554,8 +6554,13 @@ static void ggml_call_mul_mat(ggml_type type, const ggml_compute_params * params
     ggml_compute_forward_mul_mat(params, &dst);
 }
 
+static inline int64_t ggml_wrap_around(int64_t coord, int64_t size) {
+    return (coord + size) % size; // adding size avoids negative number weirdness
+}
+
 // ggml_compute_forward_conv_2d
 
+
 static void ggml_compute_forward_conv_2d_impl(const ggml_compute_params * params,
                                               const ggml_tensor *         kernel,  // [KW, KH, IC, OC]
                                               const ggml_tensor *         src,     // [W, H, C, N]
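The `+ size` in the helper matters because C/C++ `%` truncates toward zero, so a bare `coord % size` would yield a negative remainder for left-padding reads. A tiny standalone check (illustrative only; it mirrors the helper under the assumption `coord >= -size`, which holds for the pad kernels as long as the left pad does not exceed the source extent):

    #include <cassert>
    #include <cstdint>

    static int64_t wrap_around(int64_t coord, int64_t size) {
        return (coord + size) % size; // valid for coord >= -size
    }

    int main() {
        assert(-2 % 4 == -2);            // why the bare remainder is not enough
        assert(wrap_around(-2, 4) == 2); // left pad reads from the right edge
        assert(wrap_around( 5, 4) == 1); // right pad wraps past the left edge
        return 0;
    }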
@@ -7591,6 +7596,7 @@ void ggml_compute_forward_upscale(
 
 // ggml_compute_forward_pad
 
+template<bool circular_t>
 static void ggml_compute_forward_pad_f32(
     const ggml_compute_params * params,
           ggml_tensor * dst) {

@@ -7615,23 +7621,40 @@ static void ggml_compute_forward_pad_f32(
     const int32_t lp3 = ggml_get_op_params_i32(dst, 6);
     const int32_t rp3 = ggml_get_op_params_i32(dst, 7);
 
 
     // TODO: optimize
 
     for (int64_t i2 = 0; i2 < ne2; ++i2) {
         for (int64_t i1 = ith; i1 < ne1; i1 += nth) {
             for (int64_t i0 = 0; i0 < ne0; ++i0) {
                 for (int64_t i3 = 0; i3 < ne3; ++i3) {
-                    const int64_t dst_idx = i3*(ne0*ne1*ne2) + i2*(ne0*ne1) + i1*ne0 + i0;
-                    if ((i0 >= lp0 && i0 < ne0 - rp0) \
-                        && (i1 >= lp1 && i1 < ne1 - rp1) \
-                        && (i2 >= lp2 && i2 < ne2 - rp2) \
-                        && (i3 >= lp3 && i3 < ne3 - rp3)) {
-                        const int64_t src_idx = (i3 - lp3)*nb03 + (i2 - lp2)*nb02 + (i1 - lp1)*nb01 + (i0 - lp0)*nb00;
+                    // circular means wrap around on a torus, so x and y loop around
+                    if constexpr (circular_t) {
+                        const int64_t dst_idx = i3*(ne0*ne1*ne2) + i2*(ne0*ne1) + i1*ne0 + i0;
+                        const int64_t src_i0 = ggml_wrap_around(i0 - lp0, ne00);
+                        const int64_t src_i1 = ggml_wrap_around(i1 - lp1, ne01);
+                        const int64_t src_i2 = ggml_wrap_around(i2 - lp2, ne02);
+                        const int64_t src_i3 = ggml_wrap_around(i3 - lp3, ne03);
+
+                        const int64_t src_idx =
+                            src_i3*nb03 +
+                            src_i2*nb02 +
+                            src_i1*nb01 +
+                            src_i0*nb00;
+
                         const float * src_ptr = (const float *)((char *) src0->data + src_idx);
                         dst_ptr[dst_idx] = *src_ptr;
                     } else {
-                        dst_ptr[dst_idx] = 0;
+                        const int64_t dst_idx = i3*(ne0*ne1*ne2) + i2*(ne0*ne1) + i1*ne0 + i0;
+                        if ((i0 >= lp0 && i0 < ne0 - rp0) \
+                            && (i1 >= lp1 && i1 < ne1 - rp1) \
+                            && (i2 >= lp2 && i2 < ne2 - rp2) \
+                            && (i3 >= lp3 && i3 < ne3 - rp3)) {
+                            const int64_t src_idx = (i3 - lp3)*nb03 + (i2 - lp2)*nb02 + (i1 - lp1)*nb01 + (i0 - lp0)*nb00;
+                            const float * src_ptr = (const float *)((char *) src0->data + src_idx);
+                            dst_ptr[dst_idx] = *src_ptr;
+                        } else {
+                            dst_ptr[dst_idx] = 0;
+                        }
                     }
                 }
             }

@@ -7639,16 +7662,20 @@ static void ggml_compute_forward_pad_f32(
         }
     }
 }
 
 
 void ggml_compute_forward_pad(
     const ggml_compute_params * params,
           ggml_tensor * dst) {
 
     const ggml_tensor * src0 = dst->src[0];
 
+    const bool circular = (bool) ggml_get_op_params_i32(dst, 8);
     switch (src0->type) {
         case GGML_TYPE_F32:
             {
-                ggml_compute_forward_pad_f32(params, dst);
+                if (circular) {
+                    ggml_compute_forward_pad_f32<true>(params, dst);
+                } else {
+                    ggml_compute_forward_pad_f32<false>(params, dst);
+                }
             } break;
         default:
             {
@@ -160,9 +160,9 @@ bool ggml_cuda_should_use_mmf(enum ggml_type type, int cc, int warp_size, const
         case GGML_TYPE_F32:
             return ampere_mma_available(cc);
         case GGML_TYPE_F16:
-            return volta_mma_available(cc) || turing_mma_available(cc) || amd_wmma_available(cc);
+            return volta_mma_available(cc) || turing_mma_available(cc) || (amd_wmma_available(cc) && GGML_CUDA_CC_IS_RDNA4(cc));
         case GGML_TYPE_BF16:
-            return ampere_mma_available(cc) || amd_wmma_available(cc);
+            return ampere_mma_available(cc) || (amd_wmma_available(cc) && GGML_CUDA_CC_IS_RDNA4(cc));
         default:
             return false;
     }
@@ -1,9 +1,17 @@
 #include "pad.cuh"
 
+#include <stdint.h>
+
+__device__ __forceinline__ int64_t wrap_around(int64_t coord, int64_t size) {
+    // + size ensures negatives are handled properly
+    return (coord + size) % size;
+}
+
 static __global__ void pad_f32(const float * src, float * dst,
     const int lp0, const int rp0, const int lp1, const int rp1,
     const int lp2, const int rp2, const int lp3, const int rp3,
-    const int ne0, const int ne1, const int ne2, const int ne3) {
+    const int ne0, const int ne1, const int ne2, const int ne3,
+    const bool circular) {
     // blockIdx.z: i3*ne2+i2
     // blockIdx.y: i1
     // blockIDx.x: i0 / CUDA_PAD_BLOCK_SIZE

@@ -12,61 +20,84 @@ static __global__ void pad_f32(const float * src, float * dst,
     int i1 = blockIdx.y;
     int i2 = blockIdx.z % ne2;
     int i3 = blockIdx.z / ne2;
 
     if (i0 >= ne0 || i1 >= ne1 || i2 >= ne2 || i3 >= ne3) {
         return;
     }
 
     // operation
-    const int64_t dst_idx = i3*(ne0*ne1*ne2) + i2*(ne0*ne1) + i1*ne0 + i0;
-    if ((i0 >= lp0 && i0 < ne0 - rp0) &&
-        (i1 >= lp1 && i1 < ne1 - rp1) &&
-        (i2 >= lp2 && i2 < ne2 - rp2) &&
-        (i3 >= lp3 && i3 < ne3 - rp3)) {
-        const int64_t i00 = i0 - lp0;
-        const int64_t i01 = i1 - lp1;
-        const int64_t i02 = i2 - lp2;
-        const int64_t i03 = i3 - lp3;
-        const int64_t ne02 = ne2 - lp2 - rp2;
-        const int64_t ne01 = ne1 - lp1 - rp1;
-        const int64_t ne00 = ne0 - lp0 - rp0;
+    const int64_t dst_idx = i3 * (ne0 * ne1 * ne2) + i2 * (ne0 * ne1) + i1 * ne0 + i0;
 
-        const int64_t src_idx = i03*(ne00*ne01*ne02) + i02*(ne00*ne01) + i01*ne00 + i00;
+    if (!circular) {
+        if ((i0 >= lp0 && i0 < ne0 - rp0) && (i1 >= lp1 && i1 < ne1 - rp1) && (i2 >= lp2 && i2 < ne2 - rp2) &&
+            (i3 >= lp3 && i3 < ne3 - rp3)) {
+            const int64_t i00 = i0 - lp0;
+            const int64_t i01 = i1 - lp1;
+            const int64_t i02 = i2 - lp2;
+            const int64_t i03 = i3 - lp3;
+            const int64_t ne02 = ne2 - lp2 - rp2;
+            const int64_t ne01 = ne1 - lp1 - rp1;
+            const int64_t ne00 = ne0 - lp0 - rp0;
 
-        dst[dst_idx] = src[src_idx];
-    } else {
-        dst[dst_idx] = 0.0f;
+            const int64_t src_idx = i03 * (ne00 * ne01 * ne02) + i02 * (ne00 * ne01) + i01 * ne00 + i00;
+
+            dst[dst_idx] = src[src_idx];
+        } else {
+            dst[dst_idx] = 0.0f;
+        }
+    }
+    // circular means on a torus, so x and y wrap around
+    else {
+        const int64_t ne00 = ne0 - lp0 - rp0;
+        const int64_t ne01 = ne1 - lp1 - rp1;
+        const int64_t ne02 = ne2 - lp2 - rp2;
+        const int64_t ne03 = ne3 - lp3 - rp3;
+
+        const int64_t i00 = wrap_around(i0 - lp0, ne00);
+        const int64_t i01 = wrap_around(i1 - lp1, ne01);
+        const int64_t i02 = wrap_around(i2 - lp2, ne02);
+        const int64_t i03 = wrap_around(i3 - lp3, ne03);
+
+        const int64_t src_idx = i03 * (ne00 * ne01 * ne02) + i02 * (ne00 * ne01) + i01 * ne00 + i00;
+
+        dst[dst_idx] = src[src_idx];
     }
 }
 
 static void pad_f32_cuda(const float * src, float * dst,
     const int lp0, const int rp0, const int lp1, const int rp1,
     const int lp2, const int rp2, const int lp3, const int rp3,
-    const int ne0, const int ne1, const int ne2, const int ne3, cudaStream_t stream) {
-    int num_blocks = (ne0 + CUDA_PAD_BLOCK_SIZE - 1) / CUDA_PAD_BLOCK_SIZE;
-    dim3 gridDim(num_blocks, ne1, ne2*ne3);
-    pad_f32<<<gridDim, CUDA_PAD_BLOCK_SIZE, 0, stream>>>(src, dst, lp0, rp0, lp1, rp1, lp2, rp2, lp3, rp3, ne0, ne1, ne2, ne3);
+    const int ne0, const int ne1, const int ne2, const int ne3,
+    const bool circular, cudaStream_t stream) {
+    int num_blocks = (ne0 + CUDA_PAD_BLOCK_SIZE - 1) / CUDA_PAD_BLOCK_SIZE;
+    dim3 gridDim(num_blocks, ne1, ne2 * ne3);
+    pad_f32<<<gridDim, CUDA_PAD_BLOCK_SIZE, 0, stream>>>(src, dst,
+        lp0, rp0, lp1, rp1, lp2, rp2, lp3, rp3,
+        ne0, ne1, ne2, ne3, circular);
 }
 
 void ggml_cuda_op_pad(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
-    const ggml_tensor * src0 = dst->src[0];
-    const float * src0_d = (const float *)src0->data;
-    float * dst_d = (float *)dst->data;
-    cudaStream_t stream = ctx.stream();
+    const ggml_tensor * src0 = dst->src[0];
+    const float * src0_d = (const float *) src0->data;
+    float * dst_d = (float *) dst->data;
+    cudaStream_t stream = ctx.stream();
 
     GGML_ASSERT(src0->type == GGML_TYPE_F32);
     GGML_ASSERT(dst->type == GGML_TYPE_F32);
     GGML_ASSERT(ggml_is_contiguous(src0));
 
-    const int32_t lp0 = ((const int32_t*)(dst->op_params))[0];
-    const int32_t rp0 = ((const int32_t*)(dst->op_params))[1];
-    const int32_t lp1 = ((const int32_t*)(dst->op_params))[2];
-    const int32_t rp1 = ((const int32_t*)(dst->op_params))[3];
-    const int32_t lp2 = ((const int32_t*)(dst->op_params))[4];
-    const int32_t rp2 = ((const int32_t*)(dst->op_params))[5];
-    const int32_t lp3 = ((const int32_t*)(dst->op_params))[6];
-    const int32_t rp3 = ((const int32_t*)(dst->op_params))[7];
+    const int32_t lp0 = ((const int32_t *) (dst->op_params))[0];
+    const int32_t rp0 = ((const int32_t *) (dst->op_params))[1];
+    const int32_t lp1 = ((const int32_t *) (dst->op_params))[2];
+    const int32_t rp1 = ((const int32_t *) (dst->op_params))[3];
+    const int32_t lp2 = ((const int32_t *) (dst->op_params))[4];
+    const int32_t rp2 = ((const int32_t *) (dst->op_params))[5];
+    const int32_t lp3 = ((const int32_t *) (dst->op_params))[6];
+    const int32_t rp3 = ((const int32_t *) (dst->op_params))[7];
+    const int32_t circular = ((const int32_t *) (dst->op_params))[8];
 
     pad_f32_cuda(src0_d, dst_d,
         lp0, rp0, lp1, rp1, lp2, rp2, lp3, rp3,
-        dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3], stream);
+        dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3],
+        (bool) circular, stream);
 }
@@ -574,22 +574,26 @@ ggml_metal_rsets_t ggml_metal_rsets_init(void) {
     // the requests stop after a certain amount of time (keep_alive_s) of inactivity
     dispatch_queue_t d_queue = dispatch_get_global_queue(QOS_CLASS_DEFAULT, 0);
     dispatch_group_async(res->d_group, d_queue, ^{
-        while (!atomic_load_explicit(&res->d_stop, memory_order_relaxed)) {
-            if (atomic_load_explicit(&res->d_loop, memory_order_relaxed) > 0) {
-                [res->lock lock];
+#if defined(GGML_METAL_HAS_RESIDENCY_SETS)
+        if (@available(macOS 15.0, iOS 18.0, tvOS 18.0, visionOS 2.0, *)) {
+            while (!atomic_load_explicit(&res->d_stop, memory_order_relaxed)) {
+                if (atomic_load_explicit(&res->d_loop, memory_order_relaxed) > 0) {
+                    [res->lock lock];
 
-                for (int i = 0; i < (int) res->data.count; ++i) {
-                    [res->data[i] requestResidency];
+                    for (int i = 0; i < (int) res->data.count; ++i) {
+                        [res->data[i] requestResidency];
+                    }
+
+                    atomic_fetch_sub_explicit(&res->d_loop, 1, memory_order_relaxed);
+
+                    [res->lock unlock];
                 }
 
-                atomic_fetch_sub_explicit(&res->d_loop, 1, memory_order_relaxed);
-
-                [res->lock unlock];
+                // half a second
+                usleep(500 * 1000);
             }
-
-            // half a second
-            usleep(500 * 1000);
         }
+#endif
     });
 
     return res;

@@ -600,6 +604,7 @@ void ggml_metal_rsets_free(ggml_metal_rsets_t rsets) {
         return;
     }
 
+    // note: if you hit this assert, most likely you haven't deallocated all Metal resources before exiting
     GGML_ASSERT([rsets->data count] == 0);
 
     atomic_store_explicit(&rsets->d_stop, true, memory_order_relaxed);

@@ -787,9 +792,6 @@ ggml_metal_device_t ggml_metal_device_init(void) {
         dev->rsets = nil;
     }
 
-
-    // --------------------------------------------------
-
     // print MTL GPU family:
     GGML_LOG_INFO("%s: GPU name: %s\n", __func__, dev->props.name);
 

@@ -1035,6 +1037,11 @@ bool ggml_metal_device_supports_op(ggml_metal_device_t dev, const struct ggml_te
         case GGML_OP_POOL_2D:
             return op->src[0]->type == GGML_TYPE_F32;
         case GGML_OP_PAD:
+            // TODO: add circular padding support for metal, see https://github.com/ggml-org/llama.cpp/pull/16985
+            if (ggml_get_op_params_i32(op, 8) != 0) {
+                return false;
+            }
+
             return (ggml_get_op_params_i32(op, 0) == 0) && (ggml_get_op_params_i32(op, 2) == 0) &&
                    (ggml_get_op_params_i32(op, 4) == 0) && (ggml_get_op_params_i32(op, 6) == 0);
         case GGML_OP_PAD_REFLECT_1D:
@@ -3083,6 +3083,10 @@ static bool ggml_opencl_supports_op(ggml_backend_dev_t dev, const struct ggml_te
         case GGML_OP_REPEAT:
             return op->src[0]->type == GGML_TYPE_F32 && op->type == GGML_TYPE_F32; // Assuming F32 for now, can be expanded
         case GGML_OP_PAD:
+            // TODO: add circular padding support for opencl, see https://github.com/ggml-org/llama.cpp/pull/16985
+            if (ggml_get_op_params_i32(op, 8) != 0) {
+                return false;
+            }
             return op->src[0]->type == GGML_TYPE_F32 && op->type == GGML_TYPE_F32;
         case GGML_OP_UPSCALE: {
             ggml_scale_mode mode = (ggml_scale_mode)(ggml_get_op_params_i32(op, 0) & 0xFF);
@@ -1257,7 +1257,8 @@ bool rpc_server::get_cached_file(uint64_t hash, std::vector<uint8_t> & data) {
     char hash_str[17];
     snprintf(hash_str, sizeof(hash_str), "%016" PRIx64, hash);
     fs::path cache_file = fs::path(cache_dir) / hash_str;
-    if (!fs::exists(cache_file)) {
+    std::error_code ec;
+    if (!fs::exists(cache_file, ec)) {
         return false;
     }
     std::ifstream ifs(cache_file, std::ios::binary);
@@ -4613,6 +4613,10 @@ static bool ggml_backend_sycl_device_supports_op(ggml_backend_dev_t dev, const g
         case GGML_OP_ACC:
             return true;
         case GGML_OP_PAD:
+            // TODO: add circular padding support for syscl, see https://github.com/ggml-org/llama.cpp/pull/16985
+            if (ggml_get_op_params_i32(op, 8) != 0) {
+                return false;
+            }
             return ggml_is_contiguous(op->src[0]);
         case GGML_OP_LEAKY_RELU:
         case GGML_OP_TIMESTEP_EMBEDDING:
@@ -1050,6 +1050,7 @@ struct vk_op_pad_push_constants {
     uint32_t ne00; uint32_t ne01; uint32_t ne02; uint32_t ne03; uint32_t nb00; uint32_t nb01; uint32_t nb02; uint32_t nb03;
     uint32_t ne10; uint32_t ne11; uint32_t ne12; uint32_t ne13; uint32_t nb10; uint32_t nb11; uint32_t nb12; uint32_t nb13;
     uint32_t misalign_offsets;
+    uint32_t circular;
 
     uint32_t lp0; uint32_t rp0;
     uint32_t lp1; uint32_t rp1;

@@ -1092,6 +1093,7 @@ static vk_op_pad_push_constants vk_op_pad_push_constants_init(const ggml_tensor
     p.rp2 = dst->op_params[5];
     p.lp3 = dst->op_params[6];
     p.rp3 = dst->op_params[7];
+    p.circular = dst->op_params[8];
 
     return p; // fastdiv values and offsets are initialized later in ggml_vk_op
 }

@@ -3537,7 +3539,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
                                            SHADER_REDUCTION_MODE_SHMEM;
 
         for (uint32_t i = 0; i < mul_mat_vec_max_cols; ++i) {
-            ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[w][GGML_TYPE_F32 ][i], "mul_mat_vec_f32_f32_f32", arr_dmmv_f32_f32_f32_len[reduc], arr_dmmv_f32_f32_f32_data[reduc], "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {2, 1, 1}, {wg_size_subgroup, 2, i+1}, 1, false, use_subgroups, force_subgroup_size);
+            ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[w][GGML_TYPE_F32 ][i], "mul_mat_vec_f32_f32_f32", arr_dmmv_f32_f32_f32_len[reduc], arr_dmmv_f32_f32_f32_data[reduc], "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {1, 1, 1}, {wg_size_subgroup, 1, i+1}, 1, false, use_subgroups, force_subgroup_size);
             ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[w][GGML_TYPE_F16 ][i], "mul_mat_vec_f16_f32_f32", arr_dmmv_f16_f32_f32_len[reduc], arr_dmmv_f16_f32_f32_data[reduc], "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {2, 1, 1}, {wg_size_subgroup, 2, i+1}, 1, false, use_subgroups, force_subgroup_size);
             ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[w][GGML_TYPE_BF16][i], "mul_mat_vec_bf16_f32_f32", arr_dmmv_bf16_f32_f32_len[reduc], arr_dmmv_bf16_f32_f32_data[reduc], "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {2, 1, 1}, {wg_size_subgroup, 2, i+1}, 1, false, use_subgroups, force_subgroup_size);
             ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[w][GGML_TYPE_Q4_0][i], "mul_mat_vec_q4_0_f32_f32", arr_dmmv_q4_0_f32_f32_len[reduc], arr_dmmv_q4_0_f32_f32_data[reduc], "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {2*rm_stdq, 1, 1}, {wg_size_subgroup, 2*rm_stdq, i+1}, 1, true, use_subgroups, force_subgroup_size);

@@ -3561,7 +3563,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
             ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[w][GGML_TYPE_IQ4_NL][i], "mul_mat_vec_iq4_nl_f32_f32", arr_dmmv_iq4_nl_f32_f32_len[reduc16], arr_dmmv_iq4_nl_f32_f32_data[reduc16], "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {wg_size_subgroup16, rm_iq, i+1}, 1, true, use_subgroups16, force_subgroup_size16);
             ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[w][GGML_TYPE_MXFP4][i], "mul_mat_vec_mxfp4_f32_f32", arr_dmmv_mxfp4_f32_f32_len[reduc16], arr_dmmv_mxfp4_f32_f32_data[reduc16], "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {wg_size_subgroup16, rm_iq, i+1}, 1, true, use_subgroups16, force_subgroup_size16);
 
-            ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[w][GGML_TYPE_F32 ][i], "mul_mat_vec_f32_f16_f32", arr_dmmv_f32_f16_f32_len[reduc], arr_dmmv_f32_f16_f32_data[reduc], "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {2, 1, 1}, {wg_size_subgroup, 2, i+1}, 1, false, use_subgroups, force_subgroup_size);
+            ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[w][GGML_TYPE_F32 ][i], "mul_mat_vec_f32_f16_f32", arr_dmmv_f32_f16_f32_len[reduc], arr_dmmv_f32_f16_f32_data[reduc], "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {1, 1, 1}, {wg_size_subgroup, 1, i+1}, 1, false, use_subgroups, force_subgroup_size);
             ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[w][GGML_TYPE_F16 ][i], "mul_mat_vec_f16_f16_f32", arr_dmmv_f16_f16_f32_len[reduc], arr_dmmv_f16_f16_f32_data[reduc], "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {2, 1, 1}, {wg_size_subgroup, 2, i+1}, 1, false, use_subgroups, force_subgroup_size);
             ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[w][GGML_TYPE_BF16][i], "mul_mat_vec_bf16_f16_f32", arr_dmmv_bf16_f16_f32_len[reduc], arr_dmmv_bf16_f16_f32_data[reduc], "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {2, 1, 1}, {wg_size_subgroup, 2, i+1}, 1, false, use_subgroups, force_subgroup_size);
             ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[w][GGML_TYPE_Q4_0][i], "mul_mat_vec_q4_0_f16_f32", arr_dmmv_q4_0_f16_f32_len[reduc], arr_dmmv_q4_0_f16_f32_data[reduc], "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {2*rm_stdq, 1, 1}, {wg_size_subgroup, 2*rm_stdq, i+1}, 1, true, use_subgroups, force_subgroup_size);

@@ -3607,7 +3609,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
 #endif // GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT
     }
 
-    ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[w][GGML_TYPE_F32 ], "mul_mat_vec_id_f32_f32", arr_dmmv_id_f32_f32_f32_len[reduc], arr_dmmv_id_f32_f32_f32_data[reduc], "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {2, 1, 1}, {wg_size_subgroup, 2}, 1, false, use_subgroups, force_subgroup_size);
+    ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[w][GGML_TYPE_F32 ], "mul_mat_vec_id_f32_f32", arr_dmmv_id_f32_f32_f32_len[reduc], arr_dmmv_id_f32_f32_f32_data[reduc], "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {1, 1, 1}, {wg_size_subgroup, 1}, 1, false, use_subgroups, force_subgroup_size);
     ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[w][GGML_TYPE_F16 ], "mul_mat_vec_id_f16_f32", arr_dmmv_id_f16_f32_f32_len[reduc], arr_dmmv_id_f16_f32_f32_data[reduc], "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {2, 1, 1}, {wg_size_subgroup, 2}, 1, false, use_subgroups, force_subgroup_size);
     ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[w][GGML_TYPE_BF16], "mul_mat_vec_id_bf16_f32", arr_dmmv_id_bf16_f32_f32_len[reduc], arr_dmmv_id_bf16_f32_f32_data[reduc], "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {2, 1, 1}, {wg_size_subgroup, 2}, 1, false, use_subgroups, force_subgroup_size);
     ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[w][GGML_TYPE_Q4_0], "mul_mat_vec_id_q4_0_f32", arr_dmmv_id_q4_0_f32_f32_len[reduc], arr_dmmv_id_q4_0_f32_f32_data[reduc], "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {2*rm_stdq, 1, 1}, {wg_size_subgroup, 2*rm_stdq}, 1, true, use_subgroups, force_subgroup_size);

@@ -4033,10 +4035,16 @@ static void ggml_vk_load_shaders(vk_device& device) {
 
     for (auto &s : device->pipeline_solve_tri_f32) {
         const vk_solve_tri_pipeline_state &state = s.first;
+
+        // Max number of rows to load at a time, limited by shared memory
+        const uint32_t batch_N = device->properties.limits.maxComputeSharedMemorySize / ((state.N + state.K) * sizeof(float));
+        // Need at least K invocations, and prefer a minimum of 128 to spread out loading shared memory
+        const uint32_t block_size = std::max(128u, 1u << (uint32_t)ceilf(log2f(float(state.K))));
+
         ggml_vk_create_pipeline(
             device, s.second, "solve_tri_f32",
             solve_tri_f32_len, solve_tri_f32_data, "main", 3,
-            sizeof(vk_op_binary_push_constants), {1, 1, 1}, { 0, state.N, state.K }, 1, true);
+            sizeof(vk_op_binary_push_constants), {1, 1, 1}, { 0, state.N, state.K, batch_N, block_size }, 1, true);
     }
 
 #define IM2COL(bda) \

@@ -4174,9 +4182,9 @@ static void ggml_vk_load_shaders(vk_device& device) {
     ggml_vk_create_pipeline(device, device->pipeline_conv2d_dw_cwhn_f16_f32, "conv2d_dw_cwhn_f16_f32", conv2d_dw_cwhn_f16_f32_len, conv2d_dw_cwhn_f16_f32_data, "main", 3, sizeof(vk_op_conv2d_dw_push_constants), {512, 1, 1}, {}, 1);
 
     for (uint32_t i = 0; i < num_topk_moe_pipelines; ++i) {
-        ggml_vk_create_pipeline2(device, device->pipeline_topk_moe[i][TOPK_MOE_EARLY_SOFTMAX],      "topk_moe_f32_early_softmax_"+std::to_string(i),     topk_moe_f32_len, topk_moe_f32_data, "main", 3, sizeof(vk_op_topk_moe_push_constants), {1, 1, 1}, {device->subgroup_size, 1u<<i, 0, 0}, 1, true, true);
-        ggml_vk_create_pipeline2(device, device->pipeline_topk_moe[i][TOPK_MOE_EARLY_SOFTMAX_NORM], "topk_moe_f32_early_softmax_norm"+std::to_string(i), topk_moe_f32_len, topk_moe_f32_data, "main", 3, sizeof(vk_op_topk_moe_push_constants), {1, 1, 1}, {device->subgroup_size, 1u<<i, 1, 0}, 1, true, true);
-        ggml_vk_create_pipeline2(device, device->pipeline_topk_moe[i][TOPK_MOE_LATE_SOFTMAX],       "topk_moe_f32_late_softmax"+std::to_string(i),       topk_moe_f32_len, topk_moe_f32_data, "main", 3, sizeof(vk_op_topk_moe_push_constants), {1, 1, 1}, {device->subgroup_size, 1u<<i, 0, 1}, 1, true, true);
+        ggml_vk_create_pipeline2(device, device->pipeline_topk_moe[i][TOPK_MOE_EARLY_SOFTMAX],      "topk_moe_f32_early_softmax_"+std::to_string(i),     topk_moe_f32_len, topk_moe_f32_data, "main", 3, sizeof(vk_op_topk_moe_push_constants), {1, 1, 1}, {device->subgroup_size, 1u<<i, 0, 0}, 1, true, true, device->subgroup_size);
+        ggml_vk_create_pipeline2(device, device->pipeline_topk_moe[i][TOPK_MOE_EARLY_SOFTMAX_NORM], "topk_moe_f32_early_softmax_norm"+std::to_string(i), topk_moe_f32_len, topk_moe_f32_data, "main", 3, sizeof(vk_op_topk_moe_push_constants), {1, 1, 1}, {device->subgroup_size, 1u<<i, 1, 0}, 1, true, true, device->subgroup_size);
+        ggml_vk_create_pipeline2(device, device->pipeline_topk_moe[i][TOPK_MOE_LATE_SOFTMAX],       "topk_moe_f32_late_softmax"+std::to_string(i),       topk_moe_f32_len, topk_moe_f32_data, "main", 3, sizeof(vk_op_topk_moe_push_constants), {1, 1, 1}, {device->subgroup_size, 1u<<i, 0, 1}, 1, true, true, device->subgroup_size);
     }
 
     for (auto &c : compiles) {

@@ -5062,7 +5070,7 @@ static void ggml_vk_print_gpu_info(size_t idx) {
     }
 }
 
-static bool ggml_vk_instance_validation_ext_available();
+static bool ggml_vk_instance_layer_settings_available();
 static bool ggml_vk_instance_portability_enumeration_ext_available(const std::vector<vk::ExtensionProperties>& instance_extensions);
 static bool ggml_vk_instance_debug_utils_ext_available(const std::vector<vk::ExtensionProperties> & instance_extensions);
 static bool ggml_vk_device_is_supported(const vk::PhysicalDevice & vkdev);

@@ -5091,19 +5099,19 @@ static void ggml_vk_instance_init() {
     vk::ApplicationInfo app_info{ "ggml-vulkan", 1, nullptr, 0, api_version };
 
     const std::vector<vk::ExtensionProperties> instance_extensions = vk::enumerateInstanceExtensionProperties();
-    const bool validation_ext = ggml_vk_instance_validation_ext_available();
+    const bool layer_settings = ggml_vk_instance_layer_settings_available();
 #ifdef __APPLE__
     const bool portability_enumeration_ext = ggml_vk_instance_portability_enumeration_ext_available(instance_extensions);
 #endif
     const bool debug_utils_ext = ggml_vk_instance_debug_utils_ext_available(instance_extensions) && getenv("GGML_VK_DEBUG_MARKERS") != nullptr;
     std::vector<const char*> layers;
 
-    if (validation_ext) {
+    if (layer_settings) {
         layers.push_back("VK_LAYER_KHRONOS_validation");
     }
     std::vector<const char*> extensions;
-    if (validation_ext) {
-        extensions.push_back("VK_EXT_validation_features");
+    if (layer_settings) {
+        extensions.push_back("VK_EXT_layer_settings");
     }
 #ifdef __APPLE__
     if (portability_enumeration_ext) {

@@ -5113,26 +5121,24 @@ static void ggml_vk_instance_init() {
     if (debug_utils_ext) {
         extensions.push_back("VK_EXT_debug_utils");
    }
-    vk::InstanceCreateInfo instance_create_info(vk::InstanceCreateFlags{}, &app_info, layers, extensions);
+    VkBool32 enable_best_practice = layer_settings;
+    std::vector<vk::LayerSettingEXT> settings = {
+        {
+            "VK_LAYER_KHRONOS_validation",
+            "validate_best_practices",
+            vk::LayerSettingTypeEXT::eBool32,
+            1,
+            &enable_best_practice
+        },
+    };
+    vk::LayerSettingsCreateInfoEXT layer_setting_info(settings);
+    vk::InstanceCreateInfo instance_create_info(vk::InstanceCreateFlags{}, &app_info, layers, extensions, &layer_setting_info);
 #ifdef __APPLE__
     if (portability_enumeration_ext) {
         instance_create_info.flags |= vk::InstanceCreateFlagBits::eEnumeratePortabilityKHR;
     }
 #endif
 
-    std::vector<vk::ValidationFeatureEnableEXT> features_enable;
-    vk::ValidationFeaturesEXT validation_features;
-
-    if (validation_ext) {
-        features_enable = { vk::ValidationFeatureEnableEXT::eBestPractices };
-        validation_features = {
-            features_enable,
-            {},
-        };
-        validation_features.setPNext(nullptr);
-        instance_create_info.setPNext(&validation_features);
-        GGML_LOG_DEBUG("ggml_vulkan: Validation layers enabled\n");
-    }
 
     vk_instance.instance = vk::createInstance(instance_create_info);
     vk_instance_initialized = true;

@@ -14027,10 +14033,12 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
             const uint32_t N = op->src[0]->ne[0];
             const uint32_t K = op->src[1]->ne[0];
             // K dimension limited to workgroup size
-            if (K > 128) {
+            if (K > 1u << device->max_workgroup_size_log2) {
                 return false;
             }
-            if (N * N * sizeof(float) + N * K * sizeof(float) > device->properties.limits.maxComputeSharedMemorySize) {
+            const uint32_t batch_N = device->properties.limits.maxComputeSharedMemorySize / ((N + K) * sizeof(float));
+
+            if (batch_N == 0) {
                 return false;
             }
             return true;
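To make the relaxed limit concrete, a worked example with assumed device numbers (illustrative only; real devices report their own maxComputeSharedMemorySize):

    #include <algorithm>
    #include <cmath>
    #include <cstdint>
    #include <cstdio>

    int main() {
        // assumed: a device reporting 32 KiB of compute shared memory
        const uint32_t max_shmem = 32u * 1024u;
        const uint32_t N = 256, K = 128;
        // rows are now staged in batches; each row keeps N + K floats resident
        const uint32_t batch_N = max_shmem / ((N + K) * sizeof(float)); // = 21
        // mirrors the block-size choice made in ggml_vk_load_shaders above
        const uint32_t block_size = std::max(128u, 1u << (uint32_t) ceilf(log2f((float) K))); // = 128
        printf("batch_N = %u, block_size = %u\n", batch_N, block_size);
        return 0;
    }

With these numbers the op stays supported (batch_N > 0), whereas the old full-matrix test would have rejected it: 256*256*4 + 256*128*4 = 384 KiB of shared memory, far above the assumed 32 KiB.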
@@ -14227,21 +14235,21 @@ ggml_backend_reg_t ggml_backend_vk_reg() {
 }
 
 // Extension availability
-static bool ggml_vk_instance_validation_ext_available() {
+static bool ggml_vk_instance_layer_settings_available() {
 #ifdef GGML_VULKAN_VALIDATE
     // Check if validation layer provides the extension
     const std::string layer_name = "VK_LAYER_KHRONOS_validation";
     for (const auto& layer : vk::enumerateInstanceLayerProperties()) {
         if (layer_name == layer.layerName.data()) {
             for (const auto& ext : vk::enumerateInstanceExtensionProperties(layer_name)) {
-                if (strcmp("VK_EXT_validation_features", ext.extensionName.data()) == 0) {
+                if (strcmp("VK_EXT_layer_settings", ext.extensionName.data()) == 0) {
                     return true;
                 }
             }
         }
     }
 
-    std::cerr << "ggml_vulkan: WARNING: Validation layer or layer extension VK_EXT_validation_features not found." << std::endl;
+    std::cerr << "ggml_vulkan: WARNING: Validation layer or layer extension VK_EXT_layer_settings not found." << std::endl;
 #endif
     return false;
 }
@@ -8,6 +8,7 @@ layout (push_constant) uniform parameter
     uint ne00; uint ne01; uint ne02; uint ne03; uint nb00; uint nb01; uint nb02; uint nb03;
     uint ne10; uint ne11; uint ne12; uint ne13; uint nb10; uint nb11; uint nb12; uint nb13;
     uint misalign_offsets;
+    uint circular;
 
     uint lp0; uint rp0;
     uint lp1; uint rp1;

@@ -18,6 +19,10 @@ layout (push_constant) uniform parameter
 uint get_aoffset() { return p.misalign_offsets >> 16; }
 uint get_doffset() { return p.misalign_offsets & 0xFFFF; }
 
+uint wrap_around(int coord, uint size) {
+    return (uint(coord + int(size))) % size; // add size to avoid issues with negative
+}
+
 layout (binding = 0) readonly buffer A {A_TYPE data_a[];};
 layout (binding = 1) writeonly buffer D {D_TYPE data_d[];};
 

@@ -40,10 +45,20 @@ void main() {
     const uint src0_idx = (i3 - p.lp3)*p.nb03 + (i2 - p.lp2)*p.nb02 + (i1 - p.lp1)*p.nb01 + (i0 - p.lp0)*p.nb00;
     const uint dst_idx = i3*p.nb13 + i2*p.nb12 + i1*p.nb11 + i0*p.nb10;
 
-    const bool is_src0 = i0 >= p.lp0 && i0 < p.ne10 - p.rp0 &&
-                         i1 >= p.lp1 && i1 < p.ne11 - p.rp1 &&
-                         i2 >= p.lp2 && i2 < p.ne12 - p.rp2 &&
-                         i3 >= p.lp3 && i3 < p.ne13 - p.rp3;
+    if (p.circular != 0u) {
+        const uint ci0 = wrap_around(int(i0) - int(p.lp0), p.ne00);
+        const uint ci1 = wrap_around(int(i1) - int(p.lp1), p.ne01);
+        const uint ci2 = wrap_around(int(i2) - int(p.lp2), p.ne02);
+        const uint ci3 = wrap_around(int(i3) - int(p.lp3), p.ne03);
+        const uint circular_src_idx = ci3*p.nb03 + ci2*p.nb02 + ci1*p.nb01 + ci0*p.nb00;
+        data_d[get_doffset() + dst_idx] = D_TYPE(data_a[get_aoffset() + circular_src_idx]);
+    } else {
+        const bool is_src0 = i0 >= p.lp0 && i0 < p.ne10 - p.rp0 &&
+                             i1 >= p.lp1 && i1 < p.ne11 - p.rp1 &&
+                             i2 >= p.lp2 && i2 < p.ne12 - p.rp2 &&
+                             i3 >= p.lp3 && i3 < p.ne13 - p.rp3;
 
-    data_d[get_doffset() + dst_idx] = D_TYPE(is_src0 ? data_a[get_aoffset() + src0_idx] : 0.0f);
+        data_d[get_doffset() + dst_idx] = D_TYPE(is_src0 ? data_a[get_aoffset() + src0_idx] : 0.0f);
+    }
 }
@@ -5,8 +5,9 @@
 
 layout (constant_id = 1) const uint N = 64;
 layout (constant_id = 2) const uint K = 32;
+layout (constant_id = 3) const uint BATCH_N = 32;
 
-layout(local_size_x = 128, local_size_y = 1, local_size_z = 1) in;
+layout(local_size_x_id = 4, local_size_y = 1, local_size_z = 1) in;
 
 uint a_base, b_base, x_base;
 

@@ -22,8 +23,8 @@ void store_x(uint r, uint c, FLOAT_TYPE v) {
     data_d[x_base + r * p.nb21 + c * p.nb20] = D_TYPE(v);
 }
 
-shared FLOAT_TYPE shA[N * N];
-shared FLOAT_TYPE shB[N * K];
+shared FLOAT_TYPE shA[BATCH_N * N];
+shared FLOAT_TYPE shB[BATCH_N * K];
 
 void main() {
     const uint batch = gl_WorkGroupID.z * 262144 + gl_WorkGroupID.y * 512 + gl_WorkGroupID.x;

@@ -39,34 +40,42 @@ void main() {
     b_base = get_boffset() + i2 * p.nb12 + i3 * p.nb13;
     x_base = get_doffset() + i2 * p.nb22 + i3 * p.nb23;
 
-    // Load the A matrix into shA
-    [[unroll]] for (uint i = 0; i < N * N; i += gl_WorkGroupSize.x) {
-        uint idx = i + tid;
-        if (((N * N) % gl_WorkGroupSize.x == 0) || idx < N * N) {
-            shA[idx] = get_a(idx / N, idx % N);
-        }
-    }
-    // Load the B matrix into shB
-    [[unroll]] for (uint i = 0; i < N * K; i += gl_WorkGroupSize.x) {
-        uint idx = i + tid;
-        if (((N * K) % gl_WorkGroupSize.x == 0) || idx < N * K) {
-            shB[idx] = get_b(idx / K, idx % K);
-        }
-    }
-    barrier();
-
     FLOAT_TYPE X[N];
-    // Each thread solves one column
-    if (tid < K) {
-        [[unroll]] for (int r = 0; r < N; ++r) {
-            FLOAT_TYPE b = shB[r * K + tid];
-            // Compute x[r,c] = (b[r,c] - sum(a[r,c]*x[c])) / a[r,r]
-            [[unroll]] for (int c = 0; c < r; ++c) {
-                b -= shA[r * N + c] * X[c];
-            }
-            FLOAT_TYPE x = b / shA[r * N + r];
-            X[r] = x;
-            store_x(r, tid, x);
-        }
-    }
+
+    // Loop over batches of rows
+    [[unroll]] for (uint row_base = 0; row_base < N; row_base += BATCH_N) {
+        const uint cur_N = min(BATCH_N, N - row_base);
+
+        // Load the A matrix batch into shA
+        [[unroll]] for (uint i = 0; i < cur_N * N; i += gl_WorkGroupSize.x) {
+            uint idx = i + tid;
+            if (((cur_N * N) % gl_WorkGroupSize.x == 0) || idx < cur_N * N) {
+                shA[idx] = get_a(row_base + idx / N, idx % N);
+            }
+        }
+        // Load the B matrix batch into shB
+        [[unroll]] for (uint i = 0; i < cur_N * K; i += gl_WorkGroupSize.x) {
+            uint idx = i + tid;
+            if (((cur_N * K) % gl_WorkGroupSize.x == 0) || idx < cur_N * K) {
+                shB[idx] = get_b(row_base + idx / K, idx % K);
+            }
+        }
+        barrier();
+
+        // Each thread solves one column
+        if (tid < K) {
+            [[unroll]] for (uint row_offset = 0; row_offset < cur_N; ++row_offset) {
+                uint r = row_base + row_offset;
+                FLOAT_TYPE b = shB[row_offset * K + tid];
+                // Compute x[r,c] = (b[r,c] - sum(a[r,c]*x[c])) / a[r,r]
+                [[unroll]] for (int c = 0; c < r; ++c) {
+                    b -= shA[row_offset * N + c] * X[c];
+                }
+                FLOAT_TYPE x = b / shA[row_offset * N + r];
+                X[r] = x;
+                store_x(r, tid, x);
+            }
+        }
+        barrier();
+    }
 }
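The batching changes only which rows of A and B are resident in shared memory; each of the K threads still owns one column c of the solution and applies the forward-substitution recurrence stated in the shader comment:

\[
x_{r,c} \;=\; \frac{b_{r,c} - \sum_{j<r} A_{r,j}\, x_{j,c}}{A_{r,r}}
\]

Because every already-solved x_{j,c} with j < r lives in the thread's private X[] array, a row batch can be eliminated as soon as its slice of A and B has been loaded and the `barrier()` has been passed.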
@@ -75,7 +75,7 @@ void softmax_warp_inplace(inout float vals[experts_per_thread], const uint limit
 }
 
 void main() {
-    const uint row = gl_WorkGroupID.x * gl_WorkGroupSize.y + gl_LocalInvocationID.y;
+    const uint row = gl_WorkGroupID.x * gl_WorkGroupSize.y + gl_SubgroupID;
     if (row >= n_rows) {
         return;
     }

@@ -83,17 +83,18 @@ void main() {
     const uint logits_offset = n_experts * row;
     const uint weights_offset = n_expert_used * row;
     const uint ids_offset = n_experts * row;
+    const uint lane = gl_SubgroupInvocationID;
 
     float wt[experts_per_thread];
 
     [[unroll]]
     for (uint i = 0; i < n_experts; i += WARP_SIZE) {
-        const uint expert = i + gl_LocalInvocationID.x;
+        const uint expert = i + lane;
         wt[i / WARP_SIZE] = (n_experts % WARP_SIZE == 0 || expert < n_experts) ? logits[logits_offset + expert] : -INFINITY;
     }
 
     if (!late_softmax) {
-        softmax_warp_inplace(wt, n_experts, gl_LocalInvocationID.x, false);
+        softmax_warp_inplace(wt, n_experts, lane, false);
     }
 
     // at this point, each thread holds a portion of softmax,

@@ -111,11 +112,11 @@ void main() {
 
     for (int k = 0; k < n_expert_used; k++) {
         float max_val = wt[0];
-        uint max_expert = gl_LocalInvocationID.x;
+        uint max_expert = lane;
 
         [[unroll]]
         for (int i = 1; i < experts_per_thread; i++) {
-            const uint expert = gl_LocalInvocationID.x + i * WARP_SIZE;
+            const uint expert = lane + i * WARP_SIZE;
             if ((n_experts % WARP_SIZE == 0 || expert < n_experts) && wt[i] > max_val) {
                 max_val = wt[i];
                 max_expert = expert;

@@ -132,11 +133,11 @@ void main() {
             }
         }
 
-        if ((k & (WARP_SIZE - 1)) == gl_LocalInvocationID.x) {
+        if ((k & (WARP_SIZE - 1)) == lane) {
             output_weights[k / WARP_SIZE] = max_val;
         }
 
-        if ((max_expert & (WARP_SIZE - 1)) == gl_LocalInvocationID.x) {
+        if ((max_expert & (WARP_SIZE - 1)) == lane) {
             wt[max_expert / WARP_SIZE] = -INFINITY;
 
             ids[ids_offset + k] = max_expert;

@@ -158,12 +159,12 @@ void main() {
     }
 
     if (late_softmax) {
-        softmax_warp_inplace(output_weights, n_expert_used, gl_LocalInvocationID.x, true);
+        softmax_warp_inplace(output_weights, n_expert_used, lane, true);
     }
 
     [[unroll]]
     for (uint i = 0; i < experts_per_thread; ++i) {
-        uint idx = i * WARP_SIZE + gl_LocalInvocationID.x;
+        uint idx = i * WARP_SIZE + lane;
         if (idx < n_expert_used) {
             weights[weights_offset + idx] = output_weights[i];
         }
@@ -4947,6 +4947,18 @@ struct ggml_tensor * ggml_pad(
     return ggml_pad_ext(ctx, a, 0, p0, 0, p1, 0, p2, 0, p3);
 }
 
+// ggml_pad_circular
+
+struct ggml_tensor * ggml_pad_circular(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * a,
+        int                   p0,
+        int                   p1,
+        int                   p2,
+        int                   p3) {
+    return ggml_pad_ext_circular(ctx, a, 0, p0, 0, p1, 0, p2, 0, p3);
+}
+
 struct ggml_tensor * ggml_pad_ext(
         struct ggml_context * ctx,
         struct ggml_tensor  * a,

@@ -4973,6 +4985,7 @@ struct ggml_tensor * ggml_pad_ext(
     ggml_set_op_params_i32(result, 5, rp2);
     ggml_set_op_params_i32(result, 6, lp3);
     ggml_set_op_params_i32(result, 7, rp3);
+    ggml_set_op_params_i32(result, 8, 0); // not circular by default
 
 
     result->op = GGML_OP_PAD;

@@ -4981,6 +4994,25 @@ struct ggml_tensor * ggml_pad_ext(
     return result;
 }
 
+// ggml_pad_ext_circular
+
+struct ggml_tensor * ggml_pad_ext_circular(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * a,
+        int                   lp0,
+        int                   rp0,
+        int                   lp1,
+        int                   rp1,
+        int                   lp2,
+        int                   rp2,
+        int                   lp3,
+        int                   rp3
+    ) {
+    struct ggml_tensor * result = ggml_pad_ext(ctx, a, lp0, rp0, lp1, rp1, lp2, rp2, lp3, rp3);
+    ggml_set_op_params_i32(result, 8, 1); // circular
+    return result;
+}
+
 // ggml_pad_reflect_1d
 
 struct ggml_tensor * ggml_pad_reflect_1d(
|
||||
|
|
|
|||
|
|
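A usage sketch for the new entry points (assuming they are declared in `ggml.h` next to `ggml_pad`, and that the circular flag makes the kernels wrap reads modulo the source extent rather than fill with zeros; both are assumptions drawn from the names and the op-param flag above):

```cpp
#include "ggml.h"

int main(void) {
    struct ggml_init_params ip = {
        /*.mem_size   =*/ 16u*1024*1024,
        /*.mem_buffer =*/ NULL,
        /*.no_alloc   =*/ false,
    };
    struct ggml_context * ctx = ggml_init(ip);

    // input row [a, b, c, d]; with one element of circular padding on each
    // side of dim 0 the expected output is [d, a, b, c, d, a] (wrap-around)
    struct ggml_tensor * a    = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 4);
    struct ggml_tensor * out  = ggml_pad_ext_circular(ctx, a, 1, 1, 0, 0, 0, 0, 0, 0);

    // right-side-only padding uses the short form, mirroring ggml_pad:
    struct ggml_tensor * out2 = ggml_pad_circular(ctx, a, 2, 0, 0, 0);

    (void) out; (void) out2;
    ggml_free(ctx);
    return 0;
}
```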
@@ -376,6 +376,7 @@ class TensorNameMap:
             "model.layers.{bid}.block_sparse_moe.primary_router", # smallthinker
             "model.layers.{bid}.feed_forward.gate", # lfm2moe
             "model.layers.{bid}.mlp.router.gate", # afmoe
+            "layers.{bid}.gate", # mistral-large
         ),

         MODEL_TENSOR.FFN_GATE_INP_SHEXP: (

@@ -450,6 +451,7 @@ class TensorNameMap:
             "model.layers.{bid}.feed_forward.shared_expert.up_proj", # llama4
             "model.layers.{bid}.feed_forward.down_proj",
             "model.layers.{bid}.mlp.shared_mlp.up_proj", # hunyuan
+            "layers.{bid}.shared_experts.w3", # mistral-large
         ),

         MODEL_TENSOR.FFN_UP_CHEXP: (

@@ -496,6 +498,7 @@ class TensorNameMap:
             "model.layers.{bid}.mlp.shared_experts.gate_proj", # deepseek deepseek2
             "model.layers.{bid}.feed_forward.shared_expert.gate_proj", # llama4
             "model.layers.{bid}.mlp.shared_mlp.gate_proj", # hunyuan
+            "layers.{bid}.shared_experts.w1", # mistral-large
         ),

         MODEL_TENSOR.FFN_GATE_CHEXP: (

@@ -557,6 +560,7 @@ class TensorNameMap:
             "model.layers.{bid}.feed_forward.shared_expert.down_proj", # llama4
             "model.layers.{bid}.shared_mlp.output_linear", # granitemoe
             "model.layers.{bid}.mlp.shared_mlp.down_proj", # hunyuan
+            "layers.{bid}.shared_experts.w2", # mistral-large
         ),

         MODEL_TENSOR.FFN_DOWN_CHEXP: (

@@ -924,14 +928,17 @@ class TensorNameMap:

         MODEL_TENSOR.ATTN_Q_A: (
             "model.layers.{bid}.self_attn.q_a_proj", # deepseek2
+            "layers.{bid}.attention.wq_a", # mistral-large
         ),

         MODEL_TENSOR.ATTN_Q_B: (
             "model.layers.{bid}.self_attn.q_b_proj", # deepseek2
+            "layers.{bid}.attention.wq_b", # mistral-large
         ),

         MODEL_TENSOR.ATTN_KV_A_MQA: (
             "model.layers.{bid}.self_attn.kv_a_proj_with_mqa", # deepseek2
+            "layers.{bid}.attention.wkv_a_with_mqa", # mistral-large
         ),

         MODEL_TENSOR.ATTN_KV_B: (

@@ -940,18 +947,22 @@ class TensorNameMap:

         MODEL_TENSOR.ATTN_K_B: (
             "model.layers.{bid}.self_attn.k_b_proj", # deepseek2
+            "layers.{bid}.attention.k_b_proj", # mistral-large
         ),

         MODEL_TENSOR.ATTN_V_B: (
             "model.layers.{bid}.self_attn.v_b_proj", # deepseek2
+            "layers.{bid}.attention.v_b_proj", # mistral-large
         ),

         MODEL_TENSOR.ATTN_Q_A_NORM: (
             "model.layers.{bid}.self_attn.q_a_layernorm", # deepseek2
+            "layers.{bid}.attention.q_a_norm", # mistral-large
         ),

         MODEL_TENSOR.ATTN_KV_A_NORM: (
             "model.layers.{bid}.self_attn.kv_a_layernorm", # deepseek2
+            "layers.{bid}.attention.kv_a_norm", # mistral-large
         ),

         MODEL_TENSOR.ATTN_SUB_NORM: (
@@ -666,7 +666,6 @@ static void llama_model_quantize_impl(const std::string & fname_inp, const std::
     std::map<int, std::string> mapped;
     int blk_id = 0;
-    int pruned_attention_w = 0;

     // make a list of weights
     std::vector<const llama_model_loader::llama_tensor_weight *> tensors;

@@ -674,11 +673,6 @@ static void llama_model_quantize_impl(const std::string & fname_inp, const std::
     for (const auto & it : ml.weights_map) {
         const std::string remapped_name(remap_layer(it.first, prune_list, mapped, blk_id));
         if (remapped_name.empty()) {
-            if (it.first.find("attn_v.weight") != std::string::npos ||
-                it.first.find("attn_qkv.weight") != std::string::npos ||
-                it.first.find("attn_kv_b.weight") != std::string::npos) {
-                pruned_attention_w++;
-            }
             LLAMA_LOG_DEBUG("%s: pruning tensor %s\n", __func__, it.first.c_str());
             continue;
         }

@@ -703,7 +697,6 @@ static void llama_model_quantize_impl(const std::string & fname_inp, const std::
         });
     }

-    bool is_clip_model = false;
     for (const auto * it : tensors) {
         const struct ggml_tensor * tensor = it->tensor;

@@ -717,30 +710,10 @@ static void llama_model_quantize_impl(const std::string & fname_inp, const std::
         } else if (name == LLM_TN(model.arch)(LLM_TENSOR_OUTPUT, "weight")) {
             qs.has_output = true;
         }
-
-        is_clip_model |= name.rfind("mm.", 0) == 0; // check the "mm." prefix
     }

     qs.n_ffn_down = qs.n_ffn_gate = qs.n_ffn_up = (int)model.hparams.n_layer;

-    // sanity checks for models that have attention layers
-    if (qs.n_attention_wv != 0 && !is_clip_model)
-    {
-        int32_t n_layer_all = model.hparams.n_layer;
-        if (llama_model_has_encoder(&model)) {
-            // now n_layer_all is the number of attention layers in the encoder
-            // for each decoder block, there are 2 attention layers
-            n_layer_all += 2 * model.hparams.dec_n_layer;
-        }
-
-        // note: for linear-attention models (such as Qwen3 Next) this is the number of linear layers
-        const int32_t n_layer_recr = std::count(model.hparams.recurrent_layer_arr.begin(), model.hparams.recurrent_layer_arr.end(), true);
-
-        LLAMA_LOG_INFO("%s: n_layer_all = %d, n_layer_recr = %d, pruned_attention_w = %d\n", __func__, n_layer_all, n_layer_recr, pruned_attention_w);
-
-        GGML_ASSERT((qs.n_attention_wv == n_layer_all - pruned_attention_w - n_layer_recr) && "n_attention_wv is unexpected");
-    }
-
     size_t total_size_org = 0;
     size_t total_size_new = 0;
@@ -5604,21 +5604,24 @@ struct test_pad : public test_case {
     const std::array<int64_t, 4> ne_a;
     const int pad_0;
     const int pad_1;
+    const bool circular;

     std::string vars() override {
-        return VARS_TO_STR4(type, ne_a, pad_0, pad_1);
+        return VARS_TO_STR5(type, ne_a, pad_0, pad_1, circular);
     }

     test_pad(ggml_type type = GGML_TYPE_F32,
             std::array<int64_t, 4> ne_a = {512, 512, 1, 1},
-            int pad_0 = 1, int pad_1 = 1)
-        : type(type), ne_a(ne_a), pad_0(pad_0), pad_1(pad_1) {}
+            int pad_0 = 1, int pad_1 = 1, bool circular = false)
+        : type(type), ne_a(ne_a), pad_0(pad_0), pad_1(pad_1), circular(circular) {}

     ggml_tensor * build_graph(ggml_context * ctx) override {
         ggml_tensor * a = ggml_new_tensor(ctx, type, 4, ne_a.data());
         ggml_set_name(a, "a");

-        ggml_tensor * out = ggml_pad(ctx, a, pad_0, pad_1, 0, 0);
+        ggml_tensor * out = circular
+            ? ggml_pad_circular(ctx, a, pad_0, pad_1, 0, 0)
+            : ggml_pad(ctx, a, pad_0, pad_1, 0, 0);
         ggml_set_name(out, "out");

         return out;

@@ -5638,17 +5641,19 @@ struct test_pad_ext : public test_case {
     const int lp3;
     const int rp3;
     const bool v;
+    const bool circular;

     std::string vars() override {
-        return VARS_TO_STR11(type, ne_a, lp0, rp0, lp1, rp1, lp2, rp2, lp3, rp3, v);
+        return VARS_TO_STR12(type, ne_a, lp0, rp0, lp1, rp1, lp2, rp2, lp3, rp3, v, circular);
     }

     test_pad_ext(ggml_type type = GGML_TYPE_F32,
             std::array<int64_t, 4> ne_a = {512, 512, 3, 1},
             int lp0 = 1, int rp0 = 1, int lp1 = 1, int rp1 = 1,
             int lp2 = 1, int rp2 = 1, int lp3 = 1, int rp3 = 1,
-            bool v = false)
-        : type(type), ne_a(ne_a), lp0(lp0), rp0(rp0), lp1(lp1), rp1(rp1), lp2(lp2), rp2(rp2), lp3(lp3), rp3(rp3), v(v) {}
+            bool v = false, bool circular = false)
+        : type(type), ne_a(ne_a), lp0(lp0), rp0(rp0), lp1(lp1), rp1(rp1), lp2(lp2), rp2(rp2), lp3(lp3), rp3(rp3),
+          v(v), circular(circular) {}

     ggml_tensor * build_graph(ggml_context * ctx) override {
         ggml_tensor * a = ggml_new_tensor(ctx, type, 4, ne_a.data());

@@ -5659,7 +5664,9 @@ struct test_pad_ext : public test_case {
         ggml_set_name(a, "view of a");
     }

-    ggml_tensor * out = ggml_pad_ext(ctx, a, lp0, rp0, lp1, rp1, lp2, rp2, lp3, rp3);
+    ggml_tensor * out = circular
+        ? ggml_pad_ext_circular(ctx, a, lp0, rp0, lp1, rp1, lp2, rp2, lp3, rp3)
+        : ggml_pad_ext(ctx, a, lp0, rp0, lp1, rp1, lp2, rp2, lp3, rp3);
     ggml_set_name(out, "out");

     return out;

@@ -6204,6 +6211,15 @@ struct test_solve_tri : public test_case {

     std::string vars() override { return VARS_TO_STR3(type, ne_lhs, ne_rhs); }

+    uint64_t op_flops(ggml_tensor * t) override {
+        GGML_UNUSED(t);
+        int64_t n = ne_lhs[0];
+        int64_t k = ne_rhs[0];
+        int64_t batch = ne_lhs[2] * ne_lhs[3];
+        // n * (n + 1) / 2 non-zero elements of lhs, 2 flops each, for each col of rhs
+        return n * (n + 1) * k * batch;
+    }
+
     test_solve_tri(ggml_type type = GGML_TYPE_F32,
             std::array<int64_t, 4> ne_lhs = { 10, 10, 4, 3 },
             std::array<int64_t, 4> ne_rhs = { 3, 10, 4, 3 }
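The comment inside `op_flops` compresses a small derivation; spelled out, a lower-triangular n-by-n solve touches each of the n(n+1)/2 non-zero entries of `lhs` once per right-hand-side column, at one multiply plus one add apiece, over every matrix in the batch:

```latex
\mathrm{flops} \;=\; 2 \cdot \frac{n(n+1)}{2} \cdot k \cdot b \;=\; n\,(n+1)\,k\,b
```

which is exactly the `n * (n + 1) * k * batch` returned above.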
@@ -7773,6 +7789,7 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
     test_cases.emplace_back(new test_group_norm_mul_add(GGML_TYPE_F32, {9, 9, 1280, 1}));
     test_cases.emplace_back(new test_acc());
     test_cases.emplace_back(new test_pad());
+    test_cases.emplace_back(new test_pad(GGML_TYPE_F32, {33, 17, 2, 1}, 4, 3, true)); // circular
     test_cases.emplace_back(new test_pad_ext());
     test_cases.emplace_back(new test_pad_reflect_1d());
     test_cases.emplace_back(new test_pad_reflect_1d(GGML_TYPE_F32, {3000, 384, 4, 1}));

@@ -7816,10 +7833,14 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
     test_cases.emplace_back(new test_solve_tri(GGML_TYPE_F32, { 42, 42, 5, 2 }, { 10, 42, 5, 2 }));
     test_cases.emplace_back(new test_solve_tri(GGML_TYPE_F32, { 64, 64, 2, 2 }, { 10, 64, 2, 2 }));
     test_cases.emplace_back(new test_solve_tri(GGML_TYPE_F32, { 100, 100, 4, 4 }, { 41, 100, 4, 4 }));
     test_cases.emplace_back(new test_solve_tri(GGML_TYPE_F32, { 128, 128, 4, 4 }, { 31, 128, 4, 4 }));
     test_cases.emplace_back(new test_solve_tri(GGML_TYPE_F32, { 64, 64, 4, 4 }, { 300, 64, 4, 4 }));

     for (bool v : {false, true}) {
-        test_cases.emplace_back(new test_pad_ext(GGML_TYPE_F32, {512, 512, 1, 1}, 0, 1, 0, 1, 0, 0, 0, 0, v));
-        test_cases.emplace_back(new test_pad_ext(GGML_TYPE_F32, {11, 22, 33, 44}, 1, 2, 3, 4, 5, 6, 7, 8, v));
+        for (bool circular : {false, true}) {
+            test_cases.emplace_back(new test_pad_ext(GGML_TYPE_F32, {512, 512, 1, 1}, 0, 1, 0, 1, 0, 0, 0, 0, v, circular));
+            test_cases.emplace_back(new test_pad_ext(GGML_TYPE_F32, {11, 22, 33, 44}, 1, 2, 3, 4, 5, 6, 7, 8, v, circular));
+        }
     }

     for (int hsk : { 40, 64, 72, 80, 96, 128, 192, 256, 576 }) {

@@ -8016,6 +8037,10 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_perf() {

     test_cases.emplace_back(new test_solve_tri(GGML_TYPE_F32, { 64, 64, 4, 2 }, { 6, 64, 4, 2 }));
     test_cases.emplace_back(new test_solve_tri(GGML_TYPE_F32, { 128, 128, 4, 1 }, { 8, 128, 4, 1 }));
+    // qwen3next with CHUNK_SIZE 64
+    test_cases.emplace_back(new test_solve_tri(GGML_TYPE_F32, { 64, 64, 8, 32 }, { 64, 64, 8, 32 }));
+    // qwen3next with CHUNK_SIZE 128
+    test_cases.emplace_back(new test_solve_tri(GGML_TYPE_F32, { 128, 128, 4, 32 }, { 128, 128, 4, 32 }));

     test_cases.emplace_back(new test_tri(GGML_TRI_TYPE_LOWER, GGML_TYPE_F32, { 256, 256, 4, 4 }));
     test_cases.emplace_back(new test_tri(GGML_TRI_TYPE_UPPER_DIAG, GGML_TYPE_F32, { 1024, 1024, 8, 4 }));
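For scale, plugging the new CHUNK_SIZE-64 qwen3next perf shape into that estimate (my arithmetic, not a number from the patch): n = 64, k = 64, batch = 8 * 32 = 256, so

```latex
64 \cdot 65 \cdot 64 \cdot 256 \;=\; 68{,}157{,}440 \;\approx\; 6.8 \times 10^{7}\ \mathrm{flops\ per\ op}
```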
Binary file not shown.
@@ -3,6 +3,7 @@
 import { copyToClipboard, isIMEComposing } from '$lib/utils';
 import ChatMessageAssistant from './ChatMessageAssistant.svelte';
 import ChatMessageUser from './ChatMessageUser.svelte';
+import ChatMessageSystem from './ChatMessageSystem.svelte';

 interface Props {
     class?: string;

@@ -140,8 +141,7 @@
     }

     function handleSaveEdit() {
-        if (message.role === 'user') {
-            // For user messages, trim to avoid accidental whitespace
+        if (message.role === 'user' || message.role === 'system') {
             onEditWithBranching?.(message, editedContent.trim());
         } else {
             // For assistant messages, preserve exact content including trailing whitespace

@@ -167,7 +167,28 @@
     }
 </script>

-{#if message.role === 'user'}
+{#if message.role === 'system'}
+    <ChatMessageSystem
+        bind:textareaElement
+        class={className}
+        {deletionInfo}
+        {editedContent}
+        {isEditing}
+        {message}
+        onCancelEdit={handleCancelEdit}
+        onConfirmDelete={handleConfirmDelete}
+        onCopy={handleCopy}
+        onDelete={handleDelete}
+        onEdit={handleEdit}
+        onEditKeydown={handleEditKeydown}
+        onEditedContentChange={handleEditedContentChange}
+        {onNavigateToSibling}
+        onSaveEdit={handleSaveEdit}
+        onShowDeleteDialogChange={handleShowDeleteDialogChange}
+        {showDeleteDialog}
+        {siblingInfo}
+    />
+{:else if message.role === 'user'}
     <ChatMessageUser
         bind:textareaElement
         class={className}
@@ -0,0 +1,216 @@
<script lang="ts">
    import { Check, X } from '@lucide/svelte';
    import { Card } from '$lib/components/ui/card';
    import { Button } from '$lib/components/ui/button';
    import { MarkdownContent } from '$lib/components/app';
    import { INPUT_CLASSES } from '$lib/constants/input-classes';
    import { config } from '$lib/stores/settings.svelte';
    import ChatMessageActions from './ChatMessageActions.svelte';

    interface Props {
        class?: string;
        message: DatabaseMessage;
        isEditing: boolean;
        editedContent: string;
        siblingInfo?: ChatMessageSiblingInfo | null;
        showDeleteDialog: boolean;
        deletionInfo: {
            totalCount: number;
            userMessages: number;
            assistantMessages: number;
            messageTypes: string[];
        } | null;
        onCancelEdit: () => void;
        onSaveEdit: () => void;
        onEditKeydown: (event: KeyboardEvent) => void;
        onEditedContentChange: (content: string) => void;
        onCopy: () => void;
        onEdit: () => void;
        onDelete: () => void;
        onConfirmDelete: () => void;
        onNavigateToSibling?: (siblingId: string) => void;
        onShowDeleteDialogChange: (show: boolean) => void;
        textareaElement?: HTMLTextAreaElement;
    }

    let {
        class: className = '',
        message,
        isEditing,
        editedContent,
        siblingInfo = null,
        showDeleteDialog,
        deletionInfo,
        onCancelEdit,
        onSaveEdit,
        onEditKeydown,
        onEditedContentChange,
        onCopy,
        onEdit,
        onDelete,
        onConfirmDelete,
        onNavigateToSibling,
        onShowDeleteDialogChange,
        textareaElement = $bindable()
    }: Props = $props();

    let isMultiline = $state(false);
    let messageElement: HTMLElement | undefined = $state();
    let isExpanded = $state(false);
    let contentHeight = $state(0);
    const MAX_HEIGHT = 200; // pixels
    const currentConfig = config();

    let showExpandButton = $derived(contentHeight > MAX_HEIGHT);

    $effect(() => {
        if (!messageElement || !message.content.trim()) return;

        if (message.content.includes('\n')) {
            isMultiline = true;
        }

        const resizeObserver = new ResizeObserver((entries) => {
            for (const entry of entries) {
                const element = entry.target as HTMLElement;
                const estimatedSingleLineHeight = 24;

                isMultiline = element.offsetHeight > estimatedSingleLineHeight * 1.5;
                contentHeight = element.scrollHeight;
            }
        });

        resizeObserver.observe(messageElement);

        return () => {
            resizeObserver.disconnect();
        };
    });

    function toggleExpand() {
        isExpanded = !isExpanded;
    }
</script>

<div
    aria-label="System message with actions"
    class="group flex flex-col items-end gap-3 md:gap-2 {className}"
    role="group"
>
    {#if isEditing}
        <div class="w-full max-w-[80%]">
            <textarea
                bind:this={textareaElement}
                bind:value={editedContent}
                class="min-h-[60px] w-full resize-none rounded-2xl px-3 py-2 text-sm {INPUT_CLASSES}"
                onkeydown={onEditKeydown}
                oninput={(e) => onEditedContentChange(e.currentTarget.value)}
                placeholder="Edit system message..."
            ></textarea>

            <div class="mt-2 flex justify-end gap-2">
                <Button class="h-8 px-3" onclick={onCancelEdit} size="sm" variant="outline">
                    <X class="mr-1 h-3 w-3" />
                    Cancel
                </Button>

                <Button class="h-8 px-3" onclick={onSaveEdit} disabled={!editedContent.trim()} size="sm">
                    <Check class="mr-1 h-3 w-3" />
                    Send
                </Button>
            </div>
        </div>
    {:else}
        {#if message.content.trim()}
            <div class="relative max-w-[80%]">
                <button
                    class="group/expand w-full text-left {!isExpanded && showExpandButton
                        ? 'cursor-pointer'
                        : 'cursor-auto'}"
                    onclick={showExpandButton && !isExpanded ? toggleExpand : undefined}
                    type="button"
                >
                    <Card
                        class="rounded-[1.125rem] !border-2 !border-dashed !border-border/50 bg-muted px-3.75 py-1.5 data-[multiline]:py-2.5"
                        data-multiline={isMultiline ? '' : undefined}
                        style="border: 2px dashed hsl(var(--border));"
                    >
                        <div
                            class="relative overflow-hidden transition-all duration-300 {isExpanded
                                ? 'cursor-text select-text'
                                : 'select-none'}"
                            style={!isExpanded && showExpandButton
                                ? `max-height: ${MAX_HEIGHT}px;`
                                : 'max-height: none;'}
                        >
                            {#if currentConfig.renderUserContentAsMarkdown}
                                <div bind:this={messageElement} class="text-md {isExpanded ? 'cursor-text' : ''}">
                                    <MarkdownContent class="markdown-system-content" content={message.content} />
                                </div>
                            {:else}
                                <span
                                    bind:this={messageElement}
                                    class="text-md whitespace-pre-wrap {isExpanded ? 'cursor-text' : ''}"
                                >
                                    {message.content}
                                </span>
                            {/if}

                            {#if !isExpanded && showExpandButton}
                                <div
                                    class="pointer-events-none absolute right-0 bottom-0 left-0 h-48 bg-gradient-to-t from-muted to-transparent"
                                ></div>
                                <div
                                    class="pointer-events-none absolute right-0 bottom-4 left-0 flex justify-center opacity-0 transition-opacity group-hover/expand:opacity-100"
                                >
                                    <Button
                                        class="rounded-full px-4 py-1.5 text-xs shadow-md"
                                        size="sm"
                                        variant="outline"
                                    >
                                        Show full system message
                                    </Button>
                                </div>
                            {/if}
                        </div>

                        {#if isExpanded && showExpandButton}
                            <div class="mb-2 flex justify-center">
                                <Button
                                    class="rounded-full px-4 py-1.5 text-xs"
                                    onclick={(e) => {
                                        e.stopPropagation();
                                        toggleExpand();
                                    }}
                                    size="sm"
                                    variant="outline"
                                >
                                    Collapse System Message
                                </Button>
                            </div>
                        {/if}
                    </Card>
                </button>
            </div>
        {/if}

        {#if message.timestamp}
            <div class="max-w-[80%]">
                <ChatMessageActions
                    actionsPosition="right"
                    {deletionInfo}
                    justify="end"
                    {onConfirmDelete}
                    {onCopy}
                    {onDelete}
                    {onEdit}
                    {onNavigateToSibling}
                    {onShowDeleteDialogChange}
                    {siblingInfo}
                    {showDeleteDialog}
                    role="user"
                />
            </div>
        {/if}
    {/if}
</div>
@@ -145,7 +145,7 @@

 {#if message.content.trim()}
     <Card
-        class="max-w-[80%] rounded-[1.125rem] bg-primary px-3.75 py-1.5 text-primary-foreground data-[multiline]:py-2.5"
+        class="max-w-[80%] rounded-[1.125rem] border-none bg-primary px-3.75 py-1.5 text-primary-foreground data-[multiline]:py-2.5"
         data-multiline={isMultiline ? '' : undefined}
     >
         {#if currentConfig.renderUserContentAsMarkdown}
@@ -2,6 +2,7 @@
 import { ChatMessage } from '$lib/components/app';
 import { chatStore } from '$lib/stores/chat.svelte';
 import { conversationsStore, activeConversation } from '$lib/stores/conversations.svelte';
+import { config } from '$lib/stores/settings.svelte';
 import { getMessageSiblings } from '$lib/utils';

 interface Props {

@@ -13,6 +14,7 @@
 let { class: className, messages = [], onUserAction }: Props = $props();

 let allConversationMessages = $state<DatabaseMessage[]>([]);
+const currentConfig = config();

 function refreshAllMessages() {
     const conversation = activeConversation();

@@ -40,7 +42,12 @@
         return [];
     }

-    return messages.map((message) => {
+    // Filter out system messages if showSystemMessage is false
+    const filteredMessages = currentConfig.showSystemMessage
+        ? messages
+        : messages.filter((msg) => msg.type !== 'system');
+
+    return filteredMessages.map((message) => {
         const siblingInfo = getMessageSiblings(allConversationMessages, message.id);

         return {
@@ -36,12 +36,6 @@
     title: 'General',
     icon: Settings,
     fields: [
-        { key: 'apiKey', label: 'API Key', type: 'input' },
-        {
-            key: 'systemMessage',
-            label: 'System Message (will be disabled if left empty)',
-            type: 'textarea'
-        },
         {
             key: 'theme',
             label: 'Theme',

@@ -52,6 +46,12 @@
             { value: 'dark', label: 'Dark', icon: Moon }
         ]
     },
+    { key: 'apiKey', label: 'API Key', type: 'input' },
+    {
+        key: 'systemMessage',
+        label: 'System Message',
+        type: 'textarea'
+    },
     {
         key: 'pasteLongTextToFileLen',
         label: 'Paste long text to file length',
@@ -95,7 +95,7 @@
     </div>
     {#if field.help || SETTING_CONFIG_INFO[field.key]}
         <p class="mt-1 text-xs text-muted-foreground">
-            {field.help || SETTING_CONFIG_INFO[field.key]}
+            {@html field.help || SETTING_CONFIG_INFO[field.key]}
         </p>
     {/if}
 {:else if field.type === 'textarea'}

@@ -112,13 +112,28 @@
     value={String(localConfig[field.key] ?? '')}
     onchange={(e) => onConfigChange(field.key, e.currentTarget.value)}
     placeholder={`Default: ${SETTING_CONFIG_DEFAULT[field.key] ?? 'none'}`}
-    class="min-h-[100px] w-full md:max-w-2xl"
+    class="min-h-[10rem] w-full md:max-w-2xl"
 />

 {#if field.help || SETTING_CONFIG_INFO[field.key]}
     <p class="mt-1 text-xs text-muted-foreground">
         {field.help || SETTING_CONFIG_INFO[field.key]}
     </p>
 {/if}

+{#if field.key === 'systemMessage'}
+    <div class="mt-3 flex items-center gap-2">
+        <Checkbox
+            id="showSystemMessage"
+            checked={Boolean(localConfig.showSystemMessage ?? true)}
+            onCheckedChange={(checked) => onConfigChange('showSystemMessage', Boolean(checked))}
+        />
+
+        <Label for="showSystemMessage" class="cursor-pointer text-sm font-normal">
+            Show system message in conversations
+        </Label>
+    </div>
+{/if}
 {:else if field.type === 'select'}
     {@const selectedOption = field.options?.find(
         (opt: { value: string; label: string; icon?: Component }) =>
@@ -8,6 +8,7 @@
 import * as AlertDialog from '$lib/components/ui/alert-dialog';
 import Input from '$lib/components/ui/input/input.svelte';
 import { conversationsStore, conversations } from '$lib/stores/conversations.svelte';
+import { chatStore } from '$lib/stores/chat.svelte';
 import ChatSidebarActions from './ChatSidebarActions.svelte';

 const sidebar = Sidebar.useSidebar();

@@ -98,6 +99,10 @@

     await goto(`#/chat/${id}`);
 }
+
+function handleStopGeneration(id: string) {
+    chatStore.stopGenerationForChat(id);
+}
 </script>

 <ScrollArea class="h-[100vh]">

@@ -132,6 +137,7 @@
     onSelect={selectConversation}
     onEdit={handleEditConversation}
     onDelete={handleDeleteConversation}
+    onStop={handleStopGeneration}
 />
 </Sidebar.MenuItem>
 {/each}
@@ -1,6 +1,7 @@
 <script lang="ts">
-    import { Trash2, Pencil, MoreHorizontal, Download, Loader2 } from '@lucide/svelte';
+    import { Trash2, Pencil, MoreHorizontal, Download, Loader2, Square } from '@lucide/svelte';
     import { ActionDropdown } from '$lib/components/app';
     import * as Tooltip from '$lib/components/ui/tooltip';
     import { getAllLoadingChats } from '$lib/stores/chat.svelte';
     import { conversationsStore } from '$lib/stores/conversations.svelte';
     import { onMount } from 'svelte';

@@ -12,6 +13,7 @@
     onDelete?: (id: string) => void;
     onEdit?: (id: string) => void;
     onSelect?: (id: string) => void;
+    onStop?: (id: string) => void;
 }

 let {

@@ -20,6 +22,7 @@
     onDelete,
     onEdit,
     onSelect,
+    onStop,
     isActive = false
 }: Props = $props();

@@ -38,8 +41,14 @@
     onDelete?.(conversation.id);
 }

+function handleStop(event: Event) {
+    event.stopPropagation();
+    onStop?.(conversation.id);
+}
+
 function handleGlobalEditEvent(event: Event) {
     const customEvent = event as CustomEvent<{ conversationId: string }>;

     if (customEvent.detail.conversationId === conversation.id && isActive) {
         handleEdit(event);
     }

@@ -88,8 +97,28 @@
 >
     <div class="flex min-w-0 flex-1 items-center gap-2">
         {#if isLoading}
-            <Loader2 class="h-3.5 w-3.5 shrink-0 animate-spin text-muted-foreground" />
+            <Tooltip.Root>
+                <Tooltip.Trigger>
+                    <div
+                        class="stop-button flex h-4 w-4 shrink-0 cursor-pointer items-center justify-center rounded text-muted-foreground transition-colors hover:text-foreground"
+                        onclick={handleStop}
+                        onkeydown={(e) => e.key === 'Enter' && handleStop(e)}
+                        role="button"
+                        tabindex="0"
+                        aria-label="Stop generation"
+                    >
+                        <Loader2 class="loading-icon h-3.5 w-3.5 animate-spin" />
+
+                        <Square class="stop-icon hidden h-3 w-3 fill-current text-destructive" />
+                    </div>
+                </Tooltip.Trigger>
+
+                <Tooltip.Content>
+                    <p>Stop generation</p>
+                </Tooltip.Content>
+            </Tooltip.Root>
         {/if}

         <!-- svelte-ignore a11y_click_events_have_key_events -->
         <!-- svelte-ignore a11y_no_static_element_interactions -->
         <span class="truncate text-sm font-medium" onclick={handleMobileSidebarItemClick}>

@@ -147,5 +176,25 @@
         opacity: 1 !important;
     }
 }

+    .stop-button {
+        :global(.stop-icon) {
+            display: none;
+        }
+
+        :global(.loading-icon) {
+            display: block;
+        }
+    }
+
+    &:is(:hover) .stop-button {
+        :global(.stop-icon) {
+            display: block;
+        }
+
+        :global(.loading-icon) {
+            display: none;
+        }
+    }
 }
 </style>
@@ -19,8 +19,10 @@ export { default as ChatMessage } from './chat/ChatMessages/ChatMessage.svelte';
 export { default as ChatMessageActions } from './chat/ChatMessages/ChatMessageActions.svelte';
 export { default as ChatMessageBranchingControls } from './chat/ChatMessages/ChatMessageBranchingControls.svelte';
 export { default as ChatMessageStatistics } from './chat/ChatMessages/ChatMessageStatistics.svelte';
+export { default as ChatMessageSystem } from './chat/ChatMessages/ChatMessageSystem.svelte';
 export { default as ChatMessageThinkingBlock } from './chat/ChatMessages/ChatMessageThinkingBlock.svelte';
 export { default as ChatMessages } from './chat/ChatMessages/ChatMessages.svelte';
+export { default as MessageBranchingControls } from './chat/ChatMessages/ChatMessageBranchingControls.svelte';

 export { default as ChatScreen } from './chat/ChatScreen/ChatScreen.svelte';
 export { default as ChatScreenHeader } from './chat/ChatScreen/ChatScreenHeader.svelte';
@@ -337,19 +337,23 @@
     line-height: 1.75;
 }

+div :global(:is(h1, h2, h3, h4, h5, h6):first-child) {
+    margin-top: 0;
+}
+
 /* Headers with consistent spacing */
 div :global(h1) {
     font-size: 1.875rem;
     font-weight: 700;
-    margin: 1.5rem 0 0.75rem 0;
     line-height: 1.2;
+    margin: 1.5rem 0 0.75rem 0;
 }

 div :global(h2) {
     font-size: 1.5rem;
     font-weight: 600;
-    margin: 1.25rem 0 0.5rem 0;
     line-height: 1.3;
+    margin: 1.25rem 0 0.5rem 0;
 }

 div :global(h3) {
@@ -3,6 +3,7 @@ export const SETTING_CONFIG_DEFAULT: Record<string, string | number | boolean> =
     // Do not use nested objects, keep it single level. Prefix the key if you need to group them.
     apiKey: '',
     systemMessage: '',
+    showSystemMessage: true,
     theme: 'system',
     showThoughtInProgress: false,
     showToolCalls: false,

@@ -42,8 +43,9 @@ export const SETTING_CONFIG_DEFAULT: Record<string, string | number | boolean> =
 };

 export const SETTING_CONFIG_INFO: Record<string, string> = {
-    apiKey: 'Set the API Key if you are using --api-key option for the server.',
+    apiKey: 'Set the API Key if you are using <code>--api-key</code> option for the server.',
     systemMessage: 'The starting message that defines how model should behave.',
+    showSystemMessage: 'Display the system message at the top of each conversation.',
     theme:
         'Choose the color theme for the interface. You can choose between System (follows your device settings), Light, or Dark.',
     pasteLongTextToFileLen:
@@ -89,7 +89,6 @@ export class ChatService {
     custom,
     timings_per_token,
     // Config options
-    systemMessage,
     disableReasoningFormat
 } = options;

@@ -103,6 +102,7 @@ export class ChatService {
     }
 })
 .filter((msg) => {
     // Filter out empty system messages
     if (msg.role === 'system') {
         const content = typeof msg.content === 'string' ? msg.content : '';

@@ -112,10 +112,8 @@ export class ChatService {
     return true;
 });

-const processedMessages = ChatService.injectSystemMessage(normalizedMessages, systemMessage);
-
 const requestBody: ApiChatCompletionRequest = {
-    messages: processedMessages.map((msg: ApiChatMessageData) => ({
+    messages: normalizedMessages.map((msg: ApiChatMessageData) => ({
         role: msg.role,
         content: msg.content
     })),

@@ -677,46 +675,6 @@ export class ChatService {
     // Utilities
     // ─────────────────────────────────────────────────────────────────────────────

-    /**
-     * Injects a system message at the beginning of the conversation if provided.
-     * Checks for existing system messages to avoid duplication.
-     *
-     * @param messages - Array of chat messages to process
-     * @param systemMessage - Optional system message to inject
-     * @returns Array of messages with system message injected at the beginning if provided
-     * @private
-     */
-    private static injectSystemMessage(
-        messages: ApiChatMessageData[],
-        systemMessage?: string
-    ): ApiChatMessageData[] {
-        const trimmedSystemMessage = systemMessage?.trim();
-
-        if (!trimmedSystemMessage) {
-            return messages;
-        }
-
-        if (messages.length > 0 && messages[0].role === 'system') {
-            if (messages[0].content !== trimmedSystemMessage) {
-                const updatedMessages = [...messages];
-                updatedMessages[0] = {
-                    role: 'system',
-                    content: trimmedSystemMessage
-                };
-                return updatedMessages;
-            }
-
-            return messages;
-        }
-
-        const systemMsg: ApiChatMessageData = {
-            role: 'system',
-            content: trimmedSystemMessage
-        };
-
-        return [systemMsg, ...messages];
-    }
-
     /**
      * Parses error response and creates appropriate error with context information
      * @param response - HTTP response object
@@ -166,6 +166,49 @@ export class DatabaseService {
         return rootMessage.id;
     }

+    /**
+     * Creates a system prompt message for a conversation.
+     *
+     * @param convId - Conversation ID
+     * @param systemPrompt - The system prompt content (must be non-empty)
+     * @param parentId - Parent message ID (typically the root message)
+     * @returns The created system message
+     * @throws Error if systemPrompt is empty
+     */
+    static async createSystemMessage(
+        convId: string,
+        systemPrompt: string,
+        parentId: string
+    ): Promise<DatabaseMessage> {
+        const trimmedPrompt = systemPrompt.trim();
+        if (!trimmedPrompt) {
+            throw new Error('Cannot create system message with empty content');
+        }
+
+        const systemMessage: DatabaseMessage = {
+            id: uuid(),
+            convId,
+            type: 'system',
+            timestamp: Date.now(),
+            role: 'system',
+            content: trimmedPrompt,
+            parent: parentId,
+            thinking: '',
+            children: []
+        };
+
+        await db.messages.add(systemMessage);
+
+        const parentMessage = await db.messages.get(parentId);
+        if (parentMessage) {
+            await db.messages.update(parentId, {
+                children: [...parentMessage.children, systemMessage.id]
+            });
+        }
+
+        return systemMessage;
+    }
+
     /**
      * Deletes a conversation and all its messages.
      *
@@ -2,7 +2,11 @@ import { DatabaseService, ChatService } from '$lib/services';
 import { conversationsStore } from '$lib/stores/conversations.svelte';
 import { config } from '$lib/stores/settings.svelte';
 import { contextSize, isRouterMode } from '$lib/stores/server.svelte';
-import { selectedModelName, modelsStore } from '$lib/stores/models.svelte';
+import {
+    selectedModelName,
+    modelsStore,
+    selectedModelContextSize
+} from '$lib/stores/models.svelte';
 import {
     normalizeModelName,
     filterByLeafNodeId,

@@ -261,6 +265,13 @@ class ChatStore {
     return activeState.contextTotal;
 }

+if (isRouterMode()) {
+    const modelContextSize = selectedModelContextSize();
+    if (modelContextSize && modelContextSize > 0) {
+        return modelContextSize;
+    }
+}
+
 const propsContextSize = contextSize();
 if (propsContextSize && propsContextSize > 0) {
     return propsContextSize;

@@ -458,6 +469,14 @@ class ChatStore {
     onError?: (error: Error) => void,
     modelOverride?: string | null
 ): Promise<void> {
+    // Ensure model props are cached before streaming (for correct n_ctx in processing info)
+    if (isRouterMode()) {
+        const modelName = modelOverride || selectedModelName();
+        if (modelName && !modelsStore.getModelProps(modelName)) {
+            await modelsStore.fetchModelProps(modelName);
+        }
+    }
+
     let streamedContent = '';
     let streamedReasoningContent = '';
     let streamedToolCallContent = '';

@@ -624,6 +643,22 @@ class ChatStore {
     this.clearChatStreaming(currentConv.id);

     try {
+        if (isNewConversation) {
+            const rootId = await DatabaseService.createRootMessage(currentConv.id);
+            const currentConfig = config();
+            const systemPrompt = currentConfig.systemMessage?.toString().trim();
+
+            if (systemPrompt) {
+                const systemMessage = await DatabaseService.createSystemMessage(
+                    currentConv.id,
+                    systemPrompt,
+                    rootId
+                );
+
+                conversationsStore.addMessageToActive(systemMessage);
+            }
+        }
+
         const userMessage = await this.addMessage('user', content, 'text', '-1', extras);
         if (!userMessage) throw new Error('Failed to add user message');
         if (isNewConversation && content)

@@ -666,13 +701,17 @@ class ChatStore {

     if (!activeConv) return;

-    await this.savePartialResponseIfNeeded(activeConv.id);
+    await this.stopGenerationForChat(activeConv.id);
+}
+
+async stopGenerationForChat(convId: string): Promise<void> {
+    await this.savePartialResponseIfNeeded(convId);

     this.stopStreaming();
-    this.abortRequest(activeConv.id);
-    this.setChatLoading(activeConv.id, false);
-    this.clearChatStreaming(activeConv.id);
-    this.clearProcessingState(activeConv.id);
+    this.abortRequest(convId);
+    this.setChatLoading(convId, false);
+    this.clearChatStreaming(convId);
+    this.clearProcessingState(convId);
 }

 /**

@@ -999,14 +1038,20 @@ class ChatStore {
     const activeConv = conversationsStore.activeConversation;
     if (!activeConv || this.isLoading) return;

-    const result = this.getMessageByIdWithRole(messageId, 'user');
+    let result = this.getMessageByIdWithRole(messageId, 'user');
+
+    if (!result) {
+        result = this.getMessageByIdWithRole(messageId, 'system');
+    }
+
     if (!result) return;
     const { message: msg } = result;

     try {
         const allMessages = await conversationsStore.getConversationMessages(activeConv.id);
         const rootMessage = allMessages.find((m) => m.type === 'root' && m.parent === null);
-        const isFirstUserMessage = rootMessage && msg.parent === rootMessage.id;
+        const isFirstUserMessage =
+            msg.role === 'user' && rootMessage && msg.parent === rootMessage.id;

         const parentId = msg.parent || rootMessage?.id;
         if (!parentId) return;

@@ -1037,7 +1082,10 @@ class ChatStore {
         );
     }
     await conversationsStore.refreshActiveMessages();
-    await this.generateResponseForMessage(newMessage.id);
+
+    if (msg.role === 'user') {
+        await this.generateResponseForMessage(newMessage.id);
+    }
 } catch (error) {
     console.error('Failed to edit message with branching:', error);
 }
@@ -158,6 +158,22 @@ class ModelsStore {
         return this.modelPropsCache.get(modelId) ?? null;
     }

+    /**
+     * Get context size (n_ctx) for a specific model from cached props
+     */
+    getModelContextSize(modelId: string): number | null {
+        const props = this.modelPropsCache.get(modelId);
+        return props?.default_generation_settings?.n_ctx ?? null;
+    }
+
+    /**
+     * Get context size for the currently selected model or null if no model is selected
+     */
+    get selectedModelContextSize(): number | null {
+        if (!this.selectedModelName) return null;
+        return this.getModelContextSize(this.selectedModelName);
+    }
+
     /**
      * Check if props are being fetched for a model
      */

@@ -579,3 +595,4 @@ export const loadedModelIds = () => modelsStore.loadedModelIds;
 export const loadingModelIds = () => modelsStore.loadingModelIds;
 export const propsCacheVersion = () => modelsStore.propsCacheVersion;
 export const singleModelName = () => modelsStore.singleModelName;
+export const selectedModelContextSize = () => modelsStore.selectedModelContextSize;
@@ -1,4 +1,4 @@
-export type ChatMessageType = 'root' | 'text' | 'think';
+export type ChatMessageType = 'root' | 'text' | 'think' | 'system';
 export type ChatRole = 'user' | 'assistant' | 'system';

 export interface ChatUploadedFile {