Compare commits
14 Commits
dea085e9d7
...
33902ef1b6
| Author | SHA1 | Date |
|---|---|---|
|
|
33902ef1b6 | |
|
|
aa00911d12 | |
|
|
ce8fd4b1a6 | |
|
|
9f5e1edb10 | |
|
|
920b3e78cb | |
|
|
974c8c94cc | |
|
|
227ed28e12 | |
|
|
bafae27654 | |
|
|
873c825611 | |
|
|
82764d8f40 | |
|
|
21a4933042 | |
|
|
1e9d771e2c | |
|
|
aa4695c5e5 | |
|
|
e459796110 |
|
|
@ -258,6 +258,9 @@ static bool common_pull_file(httplib::Client & cli,
|
|||
if (progress_step >= p.total / 1000 || p.downloaded == p.total) {
|
||||
if (callback) {
|
||||
callback->on_update(p);
|
||||
if (callback->is_cancelled()) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
progress_step = 0;
|
||||
}
|
||||
|
|
@ -373,6 +376,9 @@ static int common_download_file_single_online(const std::string & url,
|
|||
}
|
||||
|
||||
for (int i = 0; i < max_attempts; ++i) {
|
||||
if (opts.callback && opts.callback->is_cancelled()) {
|
||||
break;
|
||||
}
|
||||
if (i) {
|
||||
LOG_WRN("%s: retrying after %d seconds...\n", __func__, delay);
|
||||
std::this_thread::sleep_for(std::chrono::seconds(delay));
|
||||
|
|
@ -412,6 +418,12 @@ static int common_download_file_single_online(const std::string & url,
|
|||
if (opts.callback) {
|
||||
opts.callback->on_done(p, success);
|
||||
}
|
||||
if (opts.callback && opts.callback->is_cancelled() &&
|
||||
std::filesystem::exists(path_temporary)) {
|
||||
if (remove(path_temporary.c_str()) != 0) {
|
||||
LOG_ERR("%s: unable to delete temporary file: %s\n", __func__, path_temporary.c_str());
|
||||
}
|
||||
}
|
||||
if (!success) {
|
||||
LOG_ERR("%s: download failed after %d attempts\n", __func__, max_attempts);
|
||||
return -1; // max attempts reached
|
||||
|
|
|
|||
|
|
@ -21,6 +21,7 @@ public:
|
|||
virtual void on_start(const common_download_progress & p) = 0;
|
||||
virtual void on_update(const common_download_progress & p) = 0;
|
||||
virtual void on_done(const common_download_progress & p, bool ok) = 0;
|
||||
virtual bool is_cancelled() const { return false; }
|
||||
};
|
||||
|
||||
struct common_remote_params {
|
||||
|
|
|
|||
|
|
@ -4258,9 +4258,7 @@ class Qwen2VLVisionModel(MmprojModel):
|
|||
yield from super().modify_tensors(data_torch, name, bid)
|
||||
|
||||
|
||||
@ModelBase.register("Qwen2_5OmniModel")
|
||||
class Qwen25OmniModel(Qwen2VLVisionModel):
|
||||
has_vision_encoder = True
|
||||
class Qwen25AudioModel(MmprojModel):
|
||||
has_audio_encoder = True
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
|
|
@ -4276,12 +4274,6 @@ class Qwen25OmniModel(Qwen2VLVisionModel):
|
|||
self.gguf_writer.add_audio_num_mel_bins(self.hparams_audio["num_mel_bins"])
|
||||
self.gguf_writer.add_audio_attention_layernorm_eps(self.hparams_audio.get("layer_norm_eps", 1e-5))
|
||||
|
||||
def get_vision_config(self) -> dict[str, Any] | None:
|
||||
return self.global_config["thinker_config"].get("vision_config")
|
||||
|
||||
def get_audio_config(self) -> dict[str, Any] | None:
|
||||
return self.global_config["thinker_config"].get("audio_config")
|
||||
|
||||
def generate_extra_tensors(self) -> Iterable[tuple[str, Tensor]]:
|
||||
# SinusoidsPositionEmbedding
|
||||
assert self.hparams_audio is not None
|
||||
|
|
@ -4312,7 +4304,32 @@ class Qwen25OmniModel(Qwen2VLVisionModel):
|
|||
# this tensor is left unused in transformers code
|
||||
# https://github.com/huggingface/transformers/blob/6e3063422c4b1c014aa60c32b9254fd2902f0f28/src/transformers/models/qwen2_5_omni/modular_qwen2_5_omni.py#L1809
|
||||
return
|
||||
yield from super().modify_tensors(data_torch, name, bid)
|
||||
yield from MmprojModel.modify_tensors(self, data_torch, name, bid)
|
||||
|
||||
return # skip other tensors
|
||||
|
||||
|
||||
@ModelBase.register("Qwen2_5OmniModel")
|
||||
class Qwen25OmniModel(Qwen2VLVisionModel, Qwen25AudioModel):
|
||||
has_audio_encoder = True
|
||||
has_vision_encoder = True
|
||||
|
||||
def get_vision_config(self) -> dict[str, Any] | None:
|
||||
return self.global_config["thinker_config"].get("vision_config")
|
||||
|
||||
def get_audio_config(self) -> dict[str, Any] | None:
|
||||
return self.global_config["thinker_config"].get("audio_config")
|
||||
|
||||
def set_gguf_parameters(self):
|
||||
super().set_gguf_parameters()
|
||||
self.gguf_writer.add_clip_projector_type(gguf.VisionProjectorType.QWEN25O)
|
||||
|
||||
def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
|
||||
if "visual." in name:
|
||||
yield from Qwen2VLVisionModel.modify_tensors(self, data_torch, name, bid)
|
||||
elif "audio_tower." in name:
|
||||
yield from Qwen25AudioModel.modify_tensors(self, data_torch, name, bid)
|
||||
return # skip other tensors
|
||||
|
||||
|
||||
@ModelBase.register("InternVisionModel")
|
||||
|
|
@ -4816,7 +4833,10 @@ class RND1Model(Qwen2MoeModel):
|
|||
class Qwen3VLVisionModel(MmprojModel):
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
assert self.hparams_vision is not None
|
||||
if self.hparams_vision is None:
|
||||
logger.info("No vision config found, skipping vision tensor processing")
|
||||
return
|
||||
|
||||
# Compute image_size if not present
|
||||
if "image_size" not in self.hparams_vision:
|
||||
# For Qwen3VL/Qwen3VLMoe, compute from num_position_embeddings
|
||||
|
|
@ -4837,7 +4857,9 @@ class Qwen3VLVisionModel(MmprojModel):
|
|||
|
||||
def set_gguf_parameters(self):
|
||||
super().set_gguf_parameters()
|
||||
self.gguf_writer.add_clip_projector_type(gguf.VisionProjectorType.QWEN3VL)
|
||||
# in case mixed modalities, the arch will be handled by subclass
|
||||
if not self.has_audio_encoder:
|
||||
self.gguf_writer.add_clip_projector_type(gguf.VisionProjectorType.QWEN3VL)
|
||||
self.gguf_writer.add_vision_use_gelu(True)
|
||||
|
||||
if self.hparams_vision is not None:
|
||||
|
|
@ -4925,11 +4947,64 @@ class Qwen3VLVisionModel(MmprojModel):
|
|||
return
|
||||
|
||||
if name.startswith("visual."):
|
||||
yield from super().modify_tensors(data_torch, name, bid)
|
||||
return
|
||||
yield from MmprojModel.modify_tensors(self, data_torch, name, bid)
|
||||
return # skip other tensors
|
||||
|
||||
# Fall back to parent class for other tensors
|
||||
yield from super().modify_tensors(data_torch, name, bid)
|
||||
|
||||
@ModelBase.register("Qwen3OmniMoeForConditionalGeneration")
|
||||
class Qwen3OmniMmprojModel(Qwen3VLVisionModel, Qwen25AudioModel):
|
||||
has_audio_encoder = True
|
||||
has_vision_encoder = True
|
||||
|
||||
def get_vision_config(self) -> dict[str, Any] | None:
|
||||
if self.has_vision_encoder:
|
||||
return self.global_config["thinker_config"].get("vision_config")
|
||||
else:
|
||||
return None
|
||||
|
||||
def get_audio_config(self) -> dict[str, Any] | None:
|
||||
if self.has_audio_encoder:
|
||||
return self.global_config["thinker_config"].get("audio_config")
|
||||
else:
|
||||
return None
|
||||
|
||||
def set_gguf_parameters(self):
|
||||
if self.has_vision_encoder:
|
||||
Qwen3VLVisionModel.set_gguf_parameters(self)
|
||||
self.gguf_writer.add_clip_vision_projector_type(gguf.VisionProjectorType.QWEN3VL)
|
||||
if self.has_audio_encoder:
|
||||
Qwen25AudioModel.set_gguf_parameters(self)
|
||||
self.gguf_writer.add_clip_audio_projector_type(gguf.VisionProjectorType.QWEN3A)
|
||||
|
||||
def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
|
||||
if "visual." in name:
|
||||
if not self.has_vision_encoder:
|
||||
raise ValueError(f"Model does not have vision encoder, but found tensor {name}")
|
||||
# need to transform vision tensor naming, so that modify_tensors() logic can be used correctly
|
||||
name = name.replace("thinker.visual.", "model.visual.")
|
||||
if ".merger_list." in name:
|
||||
name = name.replace(".merger_list.", ".deepstack_merger_list.")
|
||||
name = name.replace(".ln_q", ".norm")
|
||||
name = name.replace(".mlp.0", ".linear_fc1")
|
||||
name = name.replace(".mlp.2", ".linear_fc2")
|
||||
elif ".merger." in name:
|
||||
name = name.replace(".ln_q", ".norm")
|
||||
name = name.replace(".mlp.0", ".linear_fc1")
|
||||
name = name.replace(".mlp.2", ".linear_fc2")
|
||||
yield from Qwen3VLVisionModel.modify_tensors(self, data_torch, name, bid)
|
||||
elif "audio_tower." in name:
|
||||
if not self.has_audio_encoder:
|
||||
raise ValueError(f"Model does not have audio encoder, but found tensor {name}")
|
||||
if "conv2d" in name and name.endswith(".bias"):
|
||||
# transform conv2d bias [n_embd] --> [1, 1, n_embd]
|
||||
data_torch = data_torch.unsqueeze(-1).unsqueeze(-1)
|
||||
yield from Qwen25AudioModel.modify_tensors(self, data_torch, name, bid)
|
||||
|
||||
|
||||
@ModelBase.register("Qwen3ASRForConditionalGeneration")
|
||||
class Qwen3ASRMmprojModel(Qwen3OmniMmprojModel):
|
||||
has_audio_encoder = True
|
||||
has_vision_encoder = False
|
||||
|
||||
|
||||
@ModelBase.register("Glm4vForConditionalGeneration", "Glm4vMoeForConditionalGeneration", "GlmOcrForConditionalGeneration")
|
||||
|
|
@ -4992,6 +5067,8 @@ class Step3VLVisionModel(MmprojModel):
|
|||
def tensor_force_quant(self, name, new_name, bid, n_dims):
|
||||
if ".position_embd." in new_name:
|
||||
return gguf.GGMLQuantizationType.F32
|
||||
if ("mm.0." in new_name or "mm.1." in new_name) and new_name.endswith(".weight"):
|
||||
return gguf.GGMLQuantizationType.F16 if self.ftype == gguf.LlamaFileType.MOSTLY_F16 else gguf.GGMLQuantizationType.F32
|
||||
return super().tensor_force_quant(name, new_name, bid, n_dims)
|
||||
|
||||
def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
|
||||
|
|
@ -5030,9 +5107,10 @@ class Qwen3VLTextModel(Qwen3Model):
|
|||
|
||||
def set_gguf_parameters(self):
|
||||
super().set_gguf_parameters()
|
||||
|
||||
# Handle MRoPE (Multi-axis Rotary Position Embedding) for Qwen3-VL
|
||||
vision_config = self.hparams.get("vision_config", {})
|
||||
if "thinker_config" in self.hparams:
|
||||
vision_config = self.hparams["thinker_config"].get("vision_config", {})
|
||||
else:
|
||||
vision_config = self.hparams.get("vision_config", {})
|
||||
deepstack_layer_num = len(vision_config.get("deepstack_visual_indexes", []))
|
||||
self.gguf_writer.add_num_deepstack_layers(deepstack_layer_num)
|
||||
|
||||
|
|
@ -5101,6 +5179,70 @@ class Qwen3VLMoeTextModel(Qwen3MoeModel):
|
|||
yield from super().modify_tensors(data_torch, name, bid)
|
||||
|
||||
|
||||
@ModelBase.register("Qwen3OmniMoeForConditionalGeneration")
|
||||
class Qwen3OmniMoeTextModel(Qwen3VLMoeTextModel):
|
||||
model_arch = gguf.MODEL_ARCH.QWEN3VLMOE
|
||||
|
||||
def set_vocab(self):
|
||||
super().set_vocab()
|
||||
# correct BOS/EOS tokens
|
||||
with open(self.dir_model / "tokenizer_config.json", "r", encoding="utf-8") as f:
|
||||
tokenizer_config = json.load(f)
|
||||
added_tokens = tokenizer_config.get("added_tokens_decoder", {})
|
||||
for token_id, data in added_tokens.items():
|
||||
if data.get("content") == "<|im_end|>":
|
||||
self.gguf_writer.add_bos_token_id(int(token_id))
|
||||
self.gguf_writer.add_eos_token_id(int(token_id))
|
||||
break
|
||||
|
||||
def set_gguf_parameters(self):
|
||||
super().set_gguf_parameters()
|
||||
self.gguf_writer.add_num_deepstack_layers(0)
|
||||
|
||||
def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
|
||||
# Skip vision and audio tensors - they go in the mmproj file
|
||||
if "visual." in name or "audio_tower." in name \
|
||||
or "talker." in name or "code2wav." in name:
|
||||
return
|
||||
|
||||
name = name.replace("thinker.", "")
|
||||
yield from super().modify_tensors(data_torch, name, bid)
|
||||
|
||||
|
||||
@ModelBase.register("Qwen3ASRForConditionalGeneration")
|
||||
class Qwen3ASRTextModel(Qwen3VLTextModel):
|
||||
model_arch = gguf.MODEL_ARCH.QWEN3VL
|
||||
|
||||
def set_gguf_parameters(self):
|
||||
super().set_gguf_parameters()
|
||||
self.gguf_writer.add_num_deepstack_layers(0)
|
||||
|
||||
def set_vocab(self):
|
||||
super().set_vocab()
|
||||
# fix chat template, use correct chatml format
|
||||
self.gguf_writer.add_chat_template("{% for message in messages %}{{'<|im_start|>' + message['role'] + '\\n' + message['content'] + '<|im_end|>' + '\\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\\n' }}{% endif %}")
|
||||
# correct BOS/EOS tokens
|
||||
with open(self.dir_model / "tokenizer_config.json", "r", encoding="utf-8") as f:
|
||||
tokenizer_config = json.load(f)
|
||||
added_tokens = tokenizer_config.get("added_tokens_decoder", {})
|
||||
for token_id, data in added_tokens.items():
|
||||
if data.get("content") == "<|im_end|>":
|
||||
self.gguf_writer.add_bos_token_id(int(token_id))
|
||||
self.gguf_writer.add_eos_token_id(int(token_id))
|
||||
break
|
||||
|
||||
def modify_tensors(self, data_torch, name, bid):
|
||||
# qwen3-omni
|
||||
name = name.replace("thinker.", "")
|
||||
|
||||
# Skip vision and audio tensors - they go in the mmproj file
|
||||
if "visual." in name or "audio_tower." in name \
|
||||
or "talker." in name or "code2wav." in name:
|
||||
return
|
||||
|
||||
yield from super().modify_tensors(data_torch, name, bid)
|
||||
|
||||
|
||||
class _LinearAttentionVReorderBase(Qwen3NextModel):
|
||||
model_arch = gguf.MODEL_ARCH.QWEN3NEXT # overridden by subclasses
|
||||
"""reorders V heads from grouped to tiled order for ggml broadcast
|
||||
|
|
|
|||
|
|
@ -94,6 +94,11 @@ NOTE: some models may require large context window, for example: `-c 8192`
|
|||
# Moondream2 20250414 version
|
||||
(tool_name) -hf ggml-org/moondream2-20250414-GGUF
|
||||
|
||||
# Gemma 4
|
||||
(tool_name) -hf ggml-org/gemma-4-E2B-it-GGUF
|
||||
(tool_name) -hf ggml-org/gemma-4-E4B-it-GGUF
|
||||
(tool_name) -hf ggml-org/gemma-4-26B-A4B-it-GGUF
|
||||
(tool_name) -hf ggml-org/gemma-4-31B-it-GGUF
|
||||
```
|
||||
|
||||
**Audio models**:
|
||||
|
|
@ -118,6 +123,11 @@ NOTE: some models may require large context window, for example: `-c 8192`
|
|||
# Capabilities: audio input, vision input
|
||||
(tool_name) -hf ggml-org/Qwen2.5-Omni-3B-GGUF
|
||||
(tool_name) -hf ggml-org/Qwen2.5-Omni-7B-GGUF
|
||||
|
||||
# Gemma 4
|
||||
# Capabilities: audio input, vision input
|
||||
(tool_name) -hf ggml-org/gemma-4-E2B-it-GGUF
|
||||
(tool_name) -hf ggml-org/gemma-4-E4B-it-GGUF
|
||||
```
|
||||
|
||||
## Finding more models:
|
||||
|
|
|
|||
|
|
@ -58,26 +58,48 @@ void argsort_f32_i32_cuda_cub(ggml_cuda_pool & pool,
|
|||
|
||||
size_t temp_storage_bytes = 0;
|
||||
|
||||
bool is_capturing = false;
|
||||
#ifdef USE_CUDA_GRAPH
|
||||
// Currently (confirmed for CCCL <= 3.2) DeviceSegmentedSort does not support stream capture, while DeviceSegmentedRadixSort does.
|
||||
// See https://github.com/NVIDIA/cccl/issues/5661#issuecomment-3229037149
|
||||
// TODO: constrain this to the CCCL versions that have this issue once it's resolved in a future CCCL release.
|
||||
cudaStreamCaptureStatus capture_status;
|
||||
CUDA_CHECK(cudaStreamIsCapturing(stream, &capture_status));
|
||||
is_capturing = (capture_status != cudaStreamCaptureStatusNone);
|
||||
#endif // USE_CUDA_GRAPH
|
||||
|
||||
if (order == GGML_SORT_ORDER_ASC) {
|
||||
if (nrows == 1) {
|
||||
CUDA_CHECK(DeviceRadixSort::SortPairs(nullptr, temp_storage_bytes, temp_keys, temp_keys, // keys (in-place)
|
||||
temp_indices, dst, // values (indices)
|
||||
ncols, 0, sizeof(float) * 8, stream));
|
||||
temp_indices, dst, // values (indices)
|
||||
ncols, 0, sizeof(float) * 8, stream));
|
||||
} else if (is_capturing) {
|
||||
CUDA_CHECK(DeviceSegmentedRadixSort::SortPairs(
|
||||
nullptr, temp_storage_bytes, temp_keys, temp_keys, // keys (in-place)
|
||||
temp_indices, dst, // values (indices)
|
||||
ncols * nrows, nrows, // num items, num segments
|
||||
offset_iterator, offset_iterator + 1, 0, sizeof(float) * 8, stream));
|
||||
} else {
|
||||
CUDA_CHECK(DeviceSegmentedSort::SortPairs(nullptr, temp_storage_bytes, temp_keys, temp_keys, // keys (in-place)
|
||||
temp_indices, dst, // values (indices)
|
||||
ncols * nrows, nrows, // num items, num segments
|
||||
offset_iterator, offset_iterator + 1, stream));
|
||||
CUDA_CHECK(DeviceSegmentedSort::SortPairs(nullptr, temp_storage_bytes, temp_keys,
|
||||
temp_keys, // keys (in-place)
|
||||
temp_indices, dst, // values (indices)
|
||||
ncols * nrows, nrows, // num items, num segments
|
||||
offset_iterator, offset_iterator + 1, stream));
|
||||
}
|
||||
} else {
|
||||
if (nrows == 1) {
|
||||
CUDA_CHECK(DeviceRadixSort::SortPairsDescending(nullptr, temp_storage_bytes, temp_keys, temp_keys, // keys (in-place)
|
||||
temp_indices, dst, // values (indices)
|
||||
ncols, 0, sizeof(float) * 8, stream));
|
||||
CUDA_CHECK(DeviceRadixSort::SortPairsDescending(nullptr, temp_storage_bytes, temp_keys,
|
||||
temp_keys, // keys (in-place)
|
||||
temp_indices, dst, // values (indices)
|
||||
ncols, 0, sizeof(float) * 8, stream));
|
||||
} else if (is_capturing) {
|
||||
CUDA_CHECK(DeviceSegmentedRadixSort::SortPairsDescending(
|
||||
nullptr, temp_storage_bytes, temp_keys, temp_keys, temp_indices, dst, ncols * nrows, nrows,
|
||||
offset_iterator, offset_iterator + 1, 0, sizeof(float) * 8, stream));
|
||||
} else {
|
||||
CUDA_CHECK(DeviceSegmentedSort::SortPairsDescending(nullptr, temp_storage_bytes, temp_keys, temp_keys, temp_indices,
|
||||
dst, ncols * nrows, nrows, offset_iterator, offset_iterator + 1,
|
||||
stream));
|
||||
CUDA_CHECK(DeviceSegmentedSort::SortPairsDescending(nullptr, temp_storage_bytes, temp_keys, temp_keys,
|
||||
temp_indices, dst, ncols * nrows, nrows,
|
||||
offset_iterator, offset_iterator + 1, stream));
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -86,22 +108,33 @@ void argsort_f32_i32_cuda_cub(ggml_cuda_pool & pool,
|
|||
|
||||
if (order == GGML_SORT_ORDER_ASC) {
|
||||
if (nrows == 1) {
|
||||
CUDA_CHECK(DeviceRadixSort::SortPairs(d_temp_storage, temp_storage_bytes, temp_keys, temp_keys, // keys (in-place)
|
||||
temp_indices, dst, // values (indices)
|
||||
ncols, 0, sizeof(float) * 8, stream));
|
||||
CUDA_CHECK(DeviceRadixSort::SortPairs(d_temp_storage, temp_storage_bytes, temp_keys,
|
||||
temp_keys, // keys (in-place)
|
||||
temp_indices, dst, // values (indices)
|
||||
ncols, 0, sizeof(float) * 8, stream));
|
||||
} else if (is_capturing) {
|
||||
CUDA_CHECK(DeviceSegmentedRadixSort::SortPairs(d_temp_storage, temp_storage_bytes, temp_keys, temp_keys,
|
||||
temp_indices, dst, ncols * nrows, nrows, offset_iterator,
|
||||
offset_iterator + 1, 0, sizeof(float) * 8, stream));
|
||||
} else {
|
||||
CUDA_CHECK(DeviceSegmentedSort::SortPairs(d_temp_storage, temp_storage_bytes, temp_keys, temp_keys, temp_indices, dst,
|
||||
ncols * nrows, nrows, offset_iterator, offset_iterator + 1, stream));
|
||||
CUDA_CHECK(DeviceSegmentedSort::SortPairs(d_temp_storage, temp_storage_bytes, temp_keys, temp_keys,
|
||||
temp_indices, dst, ncols * nrows, nrows, offset_iterator,
|
||||
offset_iterator + 1, stream));
|
||||
}
|
||||
} else {
|
||||
if (nrows == 1) {
|
||||
CUDA_CHECK(DeviceRadixSort::SortPairsDescending(d_temp_storage, temp_storage_bytes, temp_keys, temp_keys, // keys (in-place)
|
||||
temp_indices, dst, // values (indices)
|
||||
ncols, 0, sizeof(float) * 8, stream));
|
||||
CUDA_CHECK(DeviceRadixSort::SortPairsDescending(d_temp_storage, temp_storage_bytes, temp_keys,
|
||||
temp_keys, // keys (in-place)
|
||||
temp_indices, dst, // values (indices)
|
||||
ncols, 0, sizeof(float) * 8, stream));
|
||||
} else if (is_capturing) {
|
||||
CUDA_CHECK(DeviceSegmentedRadixSort::SortPairsDescending(
|
||||
d_temp_storage, temp_storage_bytes, temp_keys, temp_keys, temp_indices, dst, ncols * nrows, nrows,
|
||||
offset_iterator, offset_iterator + 1, 0, sizeof(float) * 8, stream));
|
||||
} else {
|
||||
CUDA_CHECK(DeviceSegmentedSort::SortPairsDescending(d_temp_storage, temp_storage_bytes, temp_keys, temp_keys,
|
||||
temp_indices, dst, ncols * nrows, nrows, offset_iterator,
|
||||
offset_iterator + 1, stream));
|
||||
CUDA_CHECK(DeviceSegmentedSort::SortPairsDescending(d_temp_storage, temp_storage_bytes, temp_keys,
|
||||
temp_keys, temp_indices, dst, ncols * nrows, nrows,
|
||||
offset_iterator, offset_iterator + 1, stream));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -133,8 +133,16 @@ ggml_metal_t ggml_metal_init(ggml_metal_device_t dev) {
|
|||
|
||||
res->d_queue = dispatch_queue_create("ggml-metal", DISPATCH_QUEUE_CONCURRENT);
|
||||
|
||||
res->use_fusion = getenv("GGML_METAL_FUSION_DISABLE") == nil;
|
||||
res->use_concurrency = getenv("GGML_METAL_CONCURRENCY_DISABLE") == nil;
|
||||
res->use_fusion = getenv("GGML_METAL_FUSION_DISABLE") == nil;
|
||||
|
||||
// Auto-disable concurrent dispatch on non-Apple GPUs (AMD/Intel)
|
||||
// MTLDispatchTypeConcurrent has broken memory barriers on these GPUs
|
||||
bool is_apple_gpu = props_dev->supports_gpu_family_apple7;
|
||||
res->use_concurrency = is_apple_gpu && (getenv("GGML_METAL_CONCURRENCY_DISABLE") == nil);
|
||||
|
||||
if (!is_apple_gpu && getenv("GGML_METAL_CONCURRENCY_DISABLE") == nil) {
|
||||
GGML_LOG_INFO("%s: disabling concurrent dispatch (non-Apple GPU detected)\n", __func__);
|
||||
}
|
||||
|
||||
{
|
||||
const char * val = getenv("GGML_METAL_GRAPH_DEBUG");
|
||||
|
|
|
|||
|
|
@ -228,6 +228,7 @@ struct ggml_metal_device_props {
|
|||
bool has_tensor;
|
||||
bool use_residency_sets;
|
||||
bool use_shared_buffers;
|
||||
bool use_managed_buffers; // Use Managed mode for discrete GPUs
|
||||
|
||||
bool supports_gpu_family_apple7;
|
||||
|
||||
|
|
|
|||
|
|
@ -787,6 +787,17 @@ ggml_metal_device_t ggml_metal_device_init(int device) {
|
|||
dev->props.use_shared_buffers = true;
|
||||
}
|
||||
|
||||
// Use Managed mode on discrete GPUs for cached PCIe reads
|
||||
dev->props.use_managed_buffers = !dev->props.has_unified_memory;
|
||||
|
||||
// Environment variable overrides for testing
|
||||
if (getenv("GGML_METAL_MANAGED_BUFFERS_DISABLE") != NULL) {
|
||||
dev->props.use_managed_buffers = false;
|
||||
}
|
||||
if (getenv("GGML_METAL_MANAGED_BUFFERS_ENABLE") != NULL) {
|
||||
dev->props.use_managed_buffers = true;
|
||||
}
|
||||
|
||||
dev->props.supports_gpu_family_apple7 = [dev->mtl_device supportsFamily:MTLGPUFamilyApple7];
|
||||
|
||||
dev->props.op_offload_min_batch_size = getenv("GGML_OP_OFFLOAD_MIN_BATCH") ? atoi(getenv("GGML_OP_OFFLOAD_MIN_BATCH")) : 32;
|
||||
|
|
@ -849,6 +860,7 @@ ggml_metal_device_t ggml_metal_device_init(int device) {
|
|||
GGML_LOG_INFO("%s: has tensor = %s\n", __func__, dev->props.has_tensor ? "true" : "false");
|
||||
GGML_LOG_INFO("%s: use residency sets = %s\n", __func__, dev->props.use_residency_sets ? "true" : "false");
|
||||
GGML_LOG_INFO("%s: use shared buffers = %s\n", __func__, dev->props.use_shared_buffers ? "true" : "false");
|
||||
GGML_LOG_INFO("%s: use managed buffers = %s\n", __func__, dev->props.use_managed_buffers ? "true" : "false");
|
||||
|
||||
#if TARGET_OS_OSX || (TARGET_OS_IOS && __clang_major__ >= 15)
|
||||
if (@available(macOS 10.12, iOS 16.0, *)) {
|
||||
|
|
@ -1438,10 +1450,19 @@ ggml_metal_buffer_t ggml_metal_buffer_init(ggml_metal_device_t dev, size_t size,
|
|||
|
||||
if (size_aligned > 0) {
|
||||
if (props_dev->use_shared_buffers && shared) {
|
||||
MTLResourceOptions storage_mode = props_dev->use_managed_buffers
|
||||
? MTLResourceStorageModeManaged
|
||||
: MTLResourceStorageModeShared;
|
||||
|
||||
res->buffers[0].metal = [res->dev->mtl_device newBufferWithBytesNoCopy:res->all_data
|
||||
length:size_aligned
|
||||
options:MTLResourceStorageModeShared
|
||||
options:storage_mode
|
||||
deallocator:nil];
|
||||
|
||||
// For Managed buffers, sync CPU→GPU after creation
|
||||
if (props_dev->use_managed_buffers && res->buffers[0].metal) {
|
||||
[(id<MTLBuffer>)res->buffers[0].metal didModifyRange:NSMakeRange(0, size_aligned)];
|
||||
}
|
||||
} else {
|
||||
res->buffers[0].metal = [res->dev->mtl_device newBufferWithLength:size_aligned options:MTLResourceStorageModePrivate];
|
||||
}
|
||||
|
|
@ -1507,13 +1528,22 @@ ggml_metal_buffer_t ggml_metal_buffer_map(ggml_metal_device_t dev, void * ptr, s
|
|||
res->buffers[res->n_buffers].metal = nil;
|
||||
|
||||
if (size_aligned > 0) {
|
||||
res->buffers[res->n_buffers].metal = [res->dev->mtl_device newBufferWithBytesNoCopy:ptr length:size_aligned options:MTLResourceStorageModeShared deallocator:nil];
|
||||
MTLResourceOptions storage_mode = props_dev->use_managed_buffers
|
||||
? MTLResourceStorageModeManaged
|
||||
: MTLResourceStorageModeShared;
|
||||
|
||||
res->buffers[res->n_buffers].metal = [res->dev->mtl_device newBufferWithBytesNoCopy:ptr length:size_aligned options:storage_mode deallocator:nil];
|
||||
|
||||
if (res->buffers[res->n_buffers].metal == nil) {
|
||||
GGML_LOG_ERROR("%s: error: failed to allocate buffer, size = %8.2f MiB\n", __func__, size_aligned / 1024.0 / 1024.0);
|
||||
free(res);
|
||||
return NULL;
|
||||
}
|
||||
|
||||
// For Managed buffers, sync CPU→GPU after creation
|
||||
if (props_dev->use_managed_buffers) {
|
||||
[(id<MTLBuffer>)res->buffers[res->n_buffers].metal didModifyRange:NSMakeRange(0, size_aligned)];
|
||||
}
|
||||
}
|
||||
|
||||
ggml_metal_log_allocated_size(res->dev->mtl_device, size_aligned);
|
||||
|
|
@ -1534,13 +1564,22 @@ ggml_metal_buffer_t ggml_metal_buffer_map(ggml_metal_device_t dev, void * ptr, s
|
|||
res->buffers[res->n_buffers].metal = nil;
|
||||
|
||||
if (size_step_aligned > 0) {
|
||||
res->buffers[res->n_buffers].metal = [res->dev->mtl_device newBufferWithBytesNoCopy:(void *) ((uint8_t *) ptr + i) length:size_step_aligned options:MTLResourceStorageModeShared deallocator:nil];
|
||||
MTLResourceOptions storage_mode = props_dev->use_managed_buffers
|
||||
? MTLResourceStorageModeManaged
|
||||
: MTLResourceStorageModeShared;
|
||||
|
||||
res->buffers[res->n_buffers].metal = [res->dev->mtl_device newBufferWithBytesNoCopy:(void *) ((uint8_t *) ptr + i) length:size_step_aligned options:storage_mode deallocator:nil];
|
||||
|
||||
if (res->buffers[res->n_buffers].metal == nil) {
|
||||
GGML_LOG_ERROR("%s: error: failed to allocate buffer, size = %8.2f MiB\n", __func__, size_step_aligned / 1024.0 / 1024.0);
|
||||
free(res);
|
||||
return NULL;
|
||||
}
|
||||
|
||||
// For Managed buffers, sync CPU→GPU after creation
|
||||
if (props_dev->use_managed_buffers) {
|
||||
[(id<MTLBuffer>)res->buffers[res->n_buffers].metal didModifyRange:NSMakeRange(0, size_step_aligned)];
|
||||
}
|
||||
}
|
||||
|
||||
ggml_metal_log_allocated_size(res->dev->mtl_device, size_step_aligned);
|
||||
|
|
@ -1597,6 +1636,13 @@ bool ggml_metal_buffer_is_shared(ggml_metal_buffer_t buf) {
|
|||
void ggml_metal_buffer_memset_tensor(ggml_metal_buffer_t buf, struct ggml_tensor * tensor, uint8_t value, size_t offset, size_t size) {
|
||||
if (buf->is_shared) {
|
||||
memset((char *) tensor->data + offset, value, size);
|
||||
|
||||
// Sync Managed buffer after CPU write
|
||||
if (buf->dev->props.use_managed_buffers) {
|
||||
struct ggml_metal_buffer_id bid = ggml_metal_buffer_get_id(buf, tensor);
|
||||
[(id<MTLBuffer>)bid.metal didModifyRange:NSMakeRange(bid.offs + offset, size)];
|
||||
}
|
||||
|
||||
return;
|
||||
}
|
||||
|
||||
|
|
@ -1625,6 +1671,13 @@ void ggml_metal_buffer_memset_tensor(ggml_metal_buffer_t buf, struct ggml_tensor
|
|||
void ggml_metal_buffer_set_tensor(ggml_metal_buffer_t buf, struct ggml_tensor * tensor, const void * data, size_t offset, size_t size) {
|
||||
if (buf->is_shared) {
|
||||
memcpy((char *) tensor->data + offset, data, size);
|
||||
|
||||
// Sync Managed buffer after CPU write
|
||||
if (buf->dev->props.use_managed_buffers) {
|
||||
struct ggml_metal_buffer_id bid = ggml_metal_buffer_get_id(buf, tensor);
|
||||
[(id<MTLBuffer>)bid.metal didModifyRange:NSMakeRange(bid.offs + offset, size)];
|
||||
}
|
||||
|
||||
return;
|
||||
}
|
||||
|
||||
|
|
@ -1661,9 +1714,6 @@ void ggml_metal_buffer_set_tensor(ggml_metal_buffer_t buf, struct ggml_tensor *
|
|||
}
|
||||
|
||||
[cmd_buf addCompletedHandler:^(id<MTLCommandBuffer> cb) {
|
||||
// TODO: can check for errors here
|
||||
GGML_UNUSED(cb);
|
||||
|
||||
dispatch_semaphore_signal(completion_semaphore);
|
||||
}];
|
||||
|
||||
|
|
@ -1678,6 +1728,14 @@ void ggml_metal_buffer_set_tensor(ggml_metal_buffer_t buf, struct ggml_tensor *
|
|||
|
||||
void ggml_metal_buffer_get_tensor(ggml_metal_buffer_t buf, const struct ggml_tensor * tensor, void * data, size_t offset, size_t size) {
|
||||
if (buf->is_shared) {
|
||||
// For Managed buffers, sync GPU→CPU then direct memcpy
|
||||
if (buf->dev->props.use_managed_buffers) {
|
||||
struct ggml_metal_buffer_id bid = ggml_metal_buffer_get_id(buf, tensor);
|
||||
@autoreleasepool {
|
||||
[(id<MTLBuffer>)bid.metal synchronizeResource];
|
||||
}
|
||||
}
|
||||
// Direct memcpy (Shared or Managed after sync)
|
||||
memcpy(data, (const char *) tensor->data + offset, size);
|
||||
return;
|
||||
}
|
||||
|
|
@ -1715,7 +1773,9 @@ void ggml_metal_buffer_get_tensor(ggml_metal_buffer_t buf, const struct ggml_ten
|
|||
}
|
||||
|
||||
void ggml_metal_buffer_clear(ggml_metal_buffer_t buf, uint8_t value) {
|
||||
if (buf->is_shared) {
|
||||
// For Managed buffers, use GPU blit to avoid reading unsynced data
|
||||
if (buf->is_shared && !buf->dev->props.use_managed_buffers) {
|
||||
// True Shared mode (unified memory): direct memset OK
|
||||
memset(buf->all_data, value, buf->all_size);
|
||||
return;
|
||||
}
|
||||
|
|
|
|||
|
|
@ -488,7 +488,7 @@ static void dequantize_row_nvfp4_sycl(const void * vx, dst_t * y, const int64_t
|
|||
const int nb = k / QK_NVFP4;
|
||||
stream->parallel_for(
|
||||
sycl::nd_range<3>(sycl::range<3>(1, 1, nb) * sycl::range<3>(1, 1, 32), sycl::range<3>(1, 1, 32)),
|
||||
[=](sycl::nd_item<3> item_ct1) {
|
||||
[=](sycl::nd_item<3> /*item_ct1*/) {
|
||||
dequantize_block_nvfp4(vx, y, k);
|
||||
});
|
||||
}
|
||||
|
|
|
|||
|
|
@ -14,6 +14,7 @@
|
|||
#define GGML_SYCL_DEQUANTIZE_HPP
|
||||
|
||||
#include "common.hpp"
|
||||
#include "convert.hpp"
|
||||
|
||||
typedef void (*dequantize_kernel_t)(const void * vx, const int64_t ib, const int iqs, dfloat2 & v);
|
||||
typedef void (*dequantize_kernel_t_reorder)(const void *d, const int64_t ib, const void *qs,
|
||||
|
|
|
|||
|
|
@ -355,7 +355,7 @@ static void acc_f32_sycl(const float *x, const float *y, float *dst,
|
|||
const int num_blocks = (n_elements + SYCL_ACC_BLOCK_SIZE - 1) / SYCL_ACC_BLOCK_SIZE;
|
||||
stream->parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, SYCL_ACC_BLOCK_SIZE),
|
||||
sycl::range<3>(1, 1, SYCL_ACC_BLOCK_SIZE)),
|
||||
[=](sycl::nd_item<3> item_ct1) {
|
||||
[=](sycl::nd_item<3> /*item_ct1*/) {
|
||||
acc_f32(x, y, dst, n_elements, ne10, ne11, ne12, ne13, s1, s2, s3, offset);
|
||||
});
|
||||
}
|
||||
|
|
|
|||
|
|
@ -176,14 +176,12 @@ static void launch_gated_delta_net(const float * q_d,
|
|||
const sycl::uint3 neqk1_magic = init_fastdiv_values(neqk1);
|
||||
const sycl::uint3 rq3_magic = init_fastdiv_values(rq3);
|
||||
|
||||
int cc = ggml_sycl_info().devices[ggml_sycl_get_device()].cc;
|
||||
|
||||
switch (S_v) {
|
||||
case 16:
|
||||
{
|
||||
constexpr int sv = 16;
|
||||
stream->parallel_for(sycl::nd_range<3>(grid_dims * block_dims, block_dims),
|
||||
[=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
|
||||
[=](sycl::nd_item<3> /*item_ct1*/) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
|
||||
gated_delta_net_sycl<sv, KDA>(q_d, k_d, v_d, g_d, b_d, s_d, dst_d, H, n_tokens,
|
||||
n_seqs, sq1, sq2, sq3, sv1, sv2, sv3, sb1, sb2,
|
||||
sb3, neqk1_magic, rq3_magic, scale);
|
||||
|
|
@ -194,7 +192,7 @@ static void launch_gated_delta_net(const float * q_d,
|
|||
{
|
||||
constexpr int sv = 32;
|
||||
stream->parallel_for(sycl::nd_range<3>(grid_dims * block_dims, block_dims),
|
||||
[=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
|
||||
[=](sycl::nd_item<3> /*item_ct1*/) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
|
||||
gated_delta_net_sycl<sv, KDA>(q_d, k_d, v_d, g_d, b_d, s_d, dst_d, H, n_tokens,
|
||||
n_seqs, sq1, sq2, sq3, sv1, sv2, sv3, sb1, sb2,
|
||||
sb3, neqk1_magic, rq3_magic, scale);
|
||||
|
|
@ -205,7 +203,7 @@ static void launch_gated_delta_net(const float * q_d,
|
|||
{
|
||||
constexpr int sv = 64;
|
||||
stream->parallel_for(sycl::nd_range<3>(grid_dims * block_dims, block_dims),
|
||||
[=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
|
||||
[=](sycl::nd_item<3> /*item_ct1*/) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
|
||||
gated_delta_net_sycl<sv, KDA>(
|
||||
q_d, k_d, v_d, g_d, b_d, s_d, dst_d, H, n_tokens, n_seqs, sq1, sq2,
|
||||
sq3, sv1, sv2, sv3, sb1, sb2, sb3, neqk1_magic, rq3_magic, scale);
|
||||
|
|
@ -217,7 +215,7 @@ static void launch_gated_delta_net(const float * q_d,
|
|||
{
|
||||
constexpr int sv = 128;
|
||||
stream->parallel_for(sycl::nd_range<3>(grid_dims * block_dims, block_dims),
|
||||
[=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
|
||||
[=](sycl::nd_item<3> /*item_ct1*/) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
|
||||
gated_delta_net_sycl<sv, KDA>(
|
||||
q_d, k_d, v_d, g_d, b_d, s_d, dst_d, H, n_tokens, n_seqs, sq1, sq2,
|
||||
sq3, sv1, sv2, sv3, sb1, sb2, sb3, neqk1_magic, rq3_magic, scale);
|
||||
|
|
|
|||
|
|
@ -4727,12 +4727,19 @@ static bool ggml_backend_sycl_device_supports_op(ggml_backend_dev_t dev, const g
|
|||
struct ggml_tensor * a = op->src[0];
|
||||
struct ggml_tensor * b = op->src[1];
|
||||
|
||||
// disable Q1_0 until implementation
|
||||
if (a->type == GGML_TYPE_Q1_0 || b->type == GGML_TYPE_Q1_0) {
|
||||
return false;
|
||||
}
|
||||
|
||||
if (a->ne[3] != b->ne[3]) {
|
||||
return false;
|
||||
}
|
||||
|
||||
ggml_type src0_type = op->src[0]->type;
|
||||
|
||||
|
||||
|
||||
// TODO: The configuration below needs more work to be supported with oneDNN
|
||||
if (ggml_is_permuted(a) && !ggml_is_contiguous(a) &&
|
||||
a->ne[2] > 1 && a->ne[3] > 1 && src0_type == GGML_TYPE_F16) {
|
||||
|
|
|
|||
|
|
@ -272,7 +272,7 @@ static void upscale_f32_sycl(const float * x,
|
|||
sycl::nd_range<3>(
|
||||
sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, SYCL_UPSCALE_BLOCK_SIZE),
|
||||
sycl::range<3>(1, 1, SYCL_UPSCALE_BLOCK_SIZE)),
|
||||
[=](sycl::nd_item<3> item_ct1) {
|
||||
[=](sycl::nd_item<3> /*item_ct1*/) {
|
||||
upscale_f32(x, dst, nb00, nb01, nb02, nb03, ne10, ne11, ne12, ne13, sf0, sf1, sf2, sf3);
|
||||
});
|
||||
}
|
||||
|
|
@ -304,7 +304,7 @@ static void upscale_f32_bilinear_sycl(const float * x,
|
|||
sycl::nd_range<3>(
|
||||
sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, SYCL_UPSCALE_BLOCK_SIZE),
|
||||
sycl::range<3>(1, 1, SYCL_UPSCALE_BLOCK_SIZE)),
|
||||
[=](sycl::nd_item<3> item_ct1) {
|
||||
[=](sycl::nd_item<3> /*item_ct1*/) {
|
||||
upscale_f32_bilinear_antialias(
|
||||
x, dst, nb00, nb01, nb02, nb03, ne00_src, ne01_src, ne10_dst, ne11_dst,
|
||||
ne12_dst, ne13_dst, sf0, sf1, sf2, sf3, pixel_offset);
|
||||
|
|
@ -314,7 +314,7 @@ static void upscale_f32_bilinear_sycl(const float * x,
|
|||
sycl::nd_range<3>(
|
||||
sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, SYCL_UPSCALE_BLOCK_SIZE),
|
||||
sycl::range<3>(1, 1, SYCL_UPSCALE_BLOCK_SIZE)),
|
||||
[=](sycl::nd_item<3> item_ct1) {
|
||||
[=](sycl::nd_item<3> /*item_ct1*/) {
|
||||
upscale_f32_bilinear(
|
||||
x, dst, nb00, nb01, nb02, nb03, ne00_src, ne01_src, ne10_dst, ne11_dst, ne12_dst,
|
||||
ne13_dst, sf0, sf1, sf2, sf3, pixel_offset);
|
||||
|
|
@ -349,7 +349,7 @@ static void upscale_f32_bicubic_sycl(const float * x,
|
|||
sycl::nd_range<3>(
|
||||
sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, SYCL_UPSCALE_BLOCK_SIZE),
|
||||
sycl::range<3>(1, 1, SYCL_UPSCALE_BLOCK_SIZE)),
|
||||
[=](sycl::nd_item<3> item_ct1) {
|
||||
[=](sycl::nd_item<3> /*item_ct1*/) {
|
||||
upscale_f32_bicubic(
|
||||
x, dst, nb00, nb01, nb02, nb03, ne00_src, ne01_src, ne10_dst, ne11_dst,
|
||||
ne12_dst, ne13_dst, sf0, sf1, sf2, sf3, pixel_offset);
|
||||
|
|
|
|||
|
|
@ -534,11 +534,7 @@ static void ggml_backend_webgpu_debug(webgpu_global_context & ctx) {
|
|||
encoder.CopyBufferToBuffer(ctx->debug_dev_buf, 0, ctx->debug_host_buf, 0, ctx->debug_host_buf.GetSize());
|
||||
wgpu::CommandBuffer commands = encoder.Finish();
|
||||
ctx->queue.Submit(1, &commands);
|
||||
if (!ggml_backend_webgpu_map_buffer(ctx, ctx->debug_host_buf, wgpu::MapMode::Read, 0,
|
||||
ctx->debug_host_buf.GetSize())) {
|
||||
GGML_LOG_ERROR("ggml_webgpu: Debug buffer map failed\n");
|
||||
return;
|
||||
}
|
||||
ggml_backend_webgpu_map_buffer(ctx, ctx->debug_host_buf, wgpu::MapMode::Read, 0, ctx->debug_host_buf.GetSize());
|
||||
const float * debug_data = (const float *) ctx->debug_host_buf.GetConstMappedRange();
|
||||
std::cout << "debug[0]: " << debug_data[0] << "\n";
|
||||
ctx->debug_host_buf.Unmap();
|
||||
|
|
|
|||
|
|
@ -798,6 +798,8 @@ class MODEL_TENSOR(IntEnum):
|
|||
A_ENC_INP_PROJ = auto() # gemma4
|
||||
A_ENC_CONV1D = auto()
|
||||
A_ENC_CONV1D_NORM = auto() # gemma3n
|
||||
A_ENC_CONV2D = auto()
|
||||
A_ENC_CONV_OUT = auto()
|
||||
A_PRE_NORM = auto()
|
||||
A_POST_NORM = auto()
|
||||
A_ENC_LAYER_PRE_NORM = auto() # gemma3n
|
||||
|
|
@ -1280,6 +1282,8 @@ TENSOR_NAMES: dict[MODEL_TENSOR, str] = {
|
|||
MODEL_TENSOR.A_ENC_EMBD_TO_LOGITS: "a.embd_to_logits",
|
||||
MODEL_TENSOR.A_ENC_INP_PROJ: "a.input_projection",
|
||||
MODEL_TENSOR.A_ENC_CONV1D: "a.conv1d.{bid}",
|
||||
MODEL_TENSOR.A_ENC_CONV2D: "a.conv2d.{bid}",
|
||||
MODEL_TENSOR.A_ENC_CONV_OUT: "a.conv_out",
|
||||
MODEL_TENSOR.A_ENC_CONV1D_NORM: "a.conv1d.{bid}.norm",
|
||||
MODEL_TENSOR.A_PRE_NORM: "a.pre_ln",
|
||||
MODEL_TENSOR.A_POST_NORM: "a.post_ln",
|
||||
|
|
@ -1426,6 +1430,8 @@ MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = {
|
|||
MODEL_TENSOR.A_ENC_EMBD_TO_LOGITS,
|
||||
MODEL_TENSOR.A_ENC_INP_PROJ,
|
||||
MODEL_TENSOR.A_ENC_CONV1D,
|
||||
MODEL_TENSOR.A_ENC_CONV2D,
|
||||
MODEL_TENSOR.A_ENC_CONV_OUT,
|
||||
MODEL_TENSOR.A_ENC_CONV1D_NORM,
|
||||
MODEL_TENSOR.A_PRE_NORM,
|
||||
MODEL_TENSOR.A_POST_NORM,
|
||||
|
|
@ -4112,6 +4118,7 @@ class VisionProjectorType:
|
|||
ULTRAVOX = "ultravox"
|
||||
INTERNVL = "internvl"
|
||||
QWEN2A = "qwen2a" # audio
|
||||
QWEN3A = "qwen3a" # audio
|
||||
GLMA = "glma" # audio
|
||||
QWEN25O = "qwen2.5o" # omni
|
||||
VOXTRAL = "voxtral"
|
||||
|
|
|
|||
|
|
@ -1892,6 +1892,14 @@ class TensorNameMap:
|
|||
"conformer.subsample_conv_projection.input_proj_linear", # gemma4
|
||||
),
|
||||
|
||||
MODEL_TENSOR.A_ENC_CONV2D: (
|
||||
"audio_tower.conv2d{bid}", # qwen3omni
|
||||
),
|
||||
|
||||
MODEL_TENSOR.A_ENC_CONV_OUT: (
|
||||
"audio_tower.conv_out", # qwen3omni
|
||||
),
|
||||
|
||||
MODEL_TENSOR.A_PRE_NORM: (),
|
||||
|
||||
MODEL_TENSOR.A_POST_NORM: (
|
||||
|
|
@ -2042,7 +2050,8 @@ class TensorNameMap:
|
|||
|
||||
MODEL_TENSOR.A_MMPROJ: (
|
||||
"audio.multi_modal_projector.linear_{bid}", # ultravox, meralion
|
||||
"audio_adapter.model.{bid}" # lfm2
|
||||
"audio_adapter.model.{bid}", # lfm2
|
||||
"audio_tower.proj{bid}", # qwen3omni
|
||||
),
|
||||
|
||||
MODEL_TENSOR.A_MMPROJ_FC: (
|
||||
|
|
|
|||
|
|
@ -8397,6 +8397,7 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
|
|||
test_cases.emplace_back(new test_argsort(GGML_TYPE_F32, {2048, 2, 1, 3}, order));
|
||||
test_cases.emplace_back(new test_argsort(GGML_TYPE_F32, {2049, 2, 1, 3}, order));
|
||||
test_cases.emplace_back(new test_argsort(GGML_TYPE_F32, {2, 8, 8192, 1}, order)); // bailingmoe2 (group selection)
|
||||
test_cases.emplace_back(new test_argsort(GGML_TYPE_F32, {2048, 512, 1, 1}, order)); // test CUDA dispatching to radix sort for nrows > = 1 in graph mode
|
||||
}
|
||||
|
||||
for (int n = 1; n < 5; ++n) {
|
||||
|
|
|
|||
|
|
@ -33,6 +33,7 @@ add_library(mtmd
|
|||
models/pixtral.cpp
|
||||
models/qwen2vl.cpp
|
||||
models/qwen3vl.cpp
|
||||
models/qwen3a.cpp
|
||||
models/step3vl.cpp
|
||||
models/siglip.cpp
|
||||
models/whisper-enc.cpp
|
||||
|
|
|
|||
|
|
@ -135,6 +135,8 @@
|
|||
|
||||
// ultravox
|
||||
#define TN_CONV1D "a.conv1d.%d.%s"
|
||||
#define TN_CONV2D "a.conv2d.%d.%s"
|
||||
#define TN_CONV_OUT "a.conv_out.%s"
|
||||
#define TN_MM_AUDIO_MLP "mm.a.mlp.%d.%s"
|
||||
#define TN_MM_AUDIO_FC "mm.a.fc.%s" // fully connected layer
|
||||
#define TN_MM_NORM_PRE "mm.a.norm_pre.%s"
|
||||
|
|
@ -271,6 +273,7 @@ enum projector_type {
|
|||
PROJECTOR_TYPE_INTERNVL,
|
||||
PROJECTOR_TYPE_LLAMA4,
|
||||
PROJECTOR_TYPE_QWEN2A,
|
||||
PROJECTOR_TYPE_QWEN3A,
|
||||
PROJECTOR_TYPE_GLMA,
|
||||
PROJECTOR_TYPE_QWEN25O, // will be replaced by QWEN2A or QWEN25VL depending on clip_ctx
|
||||
PROJECTOR_TYPE_VOXTRAL,
|
||||
|
|
@ -315,6 +318,7 @@ static std::map<projector_type, std::string> PROJECTOR_TYPE_NAMES = {
|
|||
{ PROJECTOR_TYPE_INTERNVL, "internvl"},
|
||||
{ PROJECTOR_TYPE_LLAMA4, "llama4"},
|
||||
{ PROJECTOR_TYPE_QWEN2A, "qwen2a"},
|
||||
{ PROJECTOR_TYPE_QWEN3A, "qwen3a"},
|
||||
{ PROJECTOR_TYPE_GLMA, "glma"},
|
||||
{ PROJECTOR_TYPE_QWEN25O, "qwen2.5o"},
|
||||
{ PROJECTOR_TYPE_VOXTRAL, "voxtral"},
|
||||
|
|
|
|||
|
|
@ -413,10 +413,20 @@ struct clip_model {
|
|||
ggml_tensor * conv1d_1_b = nullptr;
|
||||
ggml_tensor * conv1d_2_w = nullptr;
|
||||
ggml_tensor * conv1d_2_b = nullptr;
|
||||
ggml_tensor * conv_out_w = nullptr;
|
||||
ggml_tensor * conv_out_b = nullptr;
|
||||
ggml_tensor * mm_norm_pre_w = nullptr;
|
||||
ggml_tensor * mm_norm_pre_b = nullptr;
|
||||
ggml_tensor * mm_norm_mid_w = nullptr;
|
||||
|
||||
// qwen3a
|
||||
ggml_tensor * conv2d_1_w = nullptr;
|
||||
ggml_tensor * conv2d_1_b = nullptr;
|
||||
ggml_tensor * conv2d_2_w = nullptr;
|
||||
ggml_tensor * conv2d_2_b = nullptr;
|
||||
ggml_tensor * conv2d_3_w = nullptr;
|
||||
ggml_tensor * conv2d_3_b = nullptr;
|
||||
|
||||
// cogvlm
|
||||
ggml_tensor * mm_post_fc_norm_w = nullptr;
|
||||
ggml_tensor * mm_post_fc_norm_b = nullptr;
|
||||
|
|
|
|||
|
|
@ -939,6 +939,10 @@ static ggml_cgraph * clip_image_build_graph(clip_ctx * ctx, const clip_image_f32
|
|||
{
|
||||
builder = std::make_unique<clip_graph_glm4v>(ctx, img);
|
||||
} break;
|
||||
case PROJECTOR_TYPE_QWEN3A:
|
||||
{
|
||||
builder = std::make_unique<clip_graph_qwen3a>(ctx, img);
|
||||
} break;
|
||||
case PROJECTOR_TYPE_YOUTUVL:
|
||||
{
|
||||
builder = std::make_unique<clip_graph_youtuvl>(ctx, img);
|
||||
|
|
@ -1402,6 +1406,7 @@ struct clip_model_loader {
|
|||
} break;
|
||||
case PROJECTOR_TYPE_ULTRAVOX:
|
||||
case PROJECTOR_TYPE_QWEN2A:
|
||||
case PROJECTOR_TYPE_QWEN3A:
|
||||
case PROJECTOR_TYPE_GLMA:
|
||||
case PROJECTOR_TYPE_VOXTRAL:
|
||||
case PROJECTOR_TYPE_MERALION:
|
||||
|
|
@ -2072,6 +2077,20 @@ struct clip_model_loader {
|
|||
model.mm_fc_w = get_tensor(string_format(TN_MM_AUDIO_FC, "weight"));
|
||||
model.mm_fc_b = get_tensor(string_format(TN_MM_AUDIO_FC, "bias"));
|
||||
} break;
|
||||
case PROJECTOR_TYPE_QWEN3A:
|
||||
{
|
||||
model.conv2d_1_w = get_tensor(string_format(TN_CONV2D, 1, "weight"));
|
||||
model.conv2d_1_b = get_tensor(string_format(TN_CONV2D, 1, "bias"));
|
||||
model.conv2d_2_w = get_tensor(string_format(TN_CONV2D, 2, "weight"));
|
||||
model.conv2d_2_b = get_tensor(string_format(TN_CONV2D, 2, "bias"));
|
||||
model.conv2d_3_w = get_tensor(string_format(TN_CONV2D, 3, "weight"));
|
||||
model.conv2d_3_b = get_tensor(string_format(TN_CONV2D, 3, "bias"));
|
||||
model.conv_out_w = get_tensor(string_format(TN_CONV_OUT, "weight")); // no bias
|
||||
model.mm_1_w = get_tensor(string_format(TN_MM_AUDIO_MLP, 1, "weight"));
|
||||
model.mm_1_b = get_tensor(string_format(TN_MM_AUDIO_MLP, 1, "bias"));
|
||||
model.mm_2_w = get_tensor(string_format(TN_MM_AUDIO_MLP, 2, "weight"));
|
||||
model.mm_2_b = get_tensor(string_format(TN_MM_AUDIO_MLP, 2, "bias"));
|
||||
} break;
|
||||
case PROJECTOR_TYPE_VOXTRAL:
|
||||
{
|
||||
model.conv1d_1_w = get_tensor(string_format(TN_CONV1D, 1, "weight"));
|
||||
|
|
@ -2948,6 +2967,15 @@ int clip_n_output_tokens(const struct clip_ctx * ctx, struct clip_image_f32 * im
|
|||
n_patches /= 2;
|
||||
}
|
||||
} break;
|
||||
case PROJECTOR_TYPE_QWEN3A:
|
||||
{
|
||||
// 3x stride-2 conv2d: each step is floor((n-1)/2)+1
|
||||
int n = img->nx;
|
||||
n = (n - 1) / 2 + 1;
|
||||
n = (n - 1) / 2 + 1;
|
||||
n = (n - 1) / 2 + 1;
|
||||
n_patches = n;
|
||||
} break;
|
||||
case PROJECTOR_TYPE_GLMA:
|
||||
{
|
||||
n_patches = img->nx;
|
||||
|
|
@ -3424,6 +3452,7 @@ bool clip_image_batch_encode(clip_ctx * ctx, const int n_threads, const clip_ima
|
|||
case PROJECTOR_TYPE_INTERNVL:
|
||||
case PROJECTOR_TYPE_NEMOTRON_V2_VL:
|
||||
case PROJECTOR_TYPE_QWEN2A:
|
||||
case PROJECTOR_TYPE_QWEN3A:
|
||||
case PROJECTOR_TYPE_GLMA:
|
||||
case PROJECTOR_TYPE_ULTRAVOX:
|
||||
case PROJECTOR_TYPE_LFM2:
|
||||
|
|
@ -3653,8 +3682,9 @@ int clip_n_mmproj_embd(const struct clip_ctx * ctx) {
|
|||
return ctx->model.mm_model_proj->ne[1];
|
||||
case PROJECTOR_TYPE_QWEN2A:
|
||||
return ctx->model.mm_fc_w->ne[1];
|
||||
case PROJECTOR_TYPE_GLMA:
|
||||
case PROJECTOR_TYPE_QWEN3A:
|
||||
return ctx->model.mm_2_w->ne[1];
|
||||
case PROJECTOR_TYPE_GLMA:
|
||||
case PROJECTOR_TYPE_LFM2:
|
||||
case PROJECTOR_TYPE_KIMIVL:
|
||||
case PROJECTOR_TYPE_PADDLEOCR:
|
||||
|
|
@ -3706,6 +3736,7 @@ bool clip_has_whisper_encoder(const struct clip_ctx * ctx) {
|
|||
switch (ctx->proj_type()) {
|
||||
case PROJECTOR_TYPE_ULTRAVOX:
|
||||
case PROJECTOR_TYPE_QWEN2A:
|
||||
case PROJECTOR_TYPE_QWEN3A:
|
||||
case PROJECTOR_TYPE_GLMA:
|
||||
case PROJECTOR_TYPE_VOXTRAL:
|
||||
case PROJECTOR_TYPE_MERALION:
|
||||
|
|
|
|||
|
|
@ -152,6 +152,11 @@ struct clip_graph_mobilenetv5 : clip_graph {
|
|||
const mobilenetv5_block & block);
|
||||
};
|
||||
|
||||
struct clip_graph_qwen3a : clip_graph {
|
||||
clip_graph_qwen3a(clip_ctx * ctx, const clip_image_f32 & img) : clip_graph(ctx, img) {}
|
||||
ggml_cgraph * build() override;
|
||||
};
|
||||
|
||||
struct clip_graph_kimik25 : clip_graph {
|
||||
clip_graph_kimik25(clip_ctx * ctx, const clip_image_f32 & img) : clip_graph(ctx, img) {}
|
||||
ggml_cgraph * build() override;
|
||||
|
|
|
|||
|
|
@ -0,0 +1,68 @@
|
|||
#include "models.h"
|
||||
|
||||
ggml_cgraph * clip_graph_qwen3a::build() {
|
||||
ggml_tensor * inp = build_inp_raw(1);
|
||||
|
||||
// conv2d block
|
||||
// TODO: do we need to split by chunks of n_window each like on transformers impl?
|
||||
{
|
||||
inp = ggml_conv_2d(ctx0, model.conv2d_1_w, inp, 2, 2, 1, 1, 1, 1);
|
||||
inp = ggml_add(ctx0, inp, model.conv2d_1_b);
|
||||
inp = ggml_gelu_erf(ctx0, inp);
|
||||
|
||||
inp = ggml_conv_2d(ctx0, model.conv2d_2_w, inp, 2, 2, 1, 1, 1, 1);
|
||||
inp = ggml_add(ctx0, inp, model.conv2d_2_b);
|
||||
inp = ggml_gelu_erf(ctx0, inp);
|
||||
|
||||
inp = ggml_conv_2d(ctx0, model.conv2d_3_w, inp, 2, 2, 1, 1, 1, 1);
|
||||
inp = ggml_add(ctx0, inp, model.conv2d_3_b);
|
||||
inp = ggml_gelu_erf(ctx0, inp);
|
||||
|
||||
// inp [n_pos, n_mels/8, channels, 1] (W, H, C, N)
|
||||
cb(inp, "after_conv_blocks", -1);
|
||||
|
||||
const int64_t n_pos_after_conv = inp->ne[0];
|
||||
const int64_t n_mel_after_conv = inp->ne[1]; // 128/8 = 16
|
||||
|
||||
inp = ggml_cont(ctx0, ggml_permute(ctx0, inp, 0, 2, 3, 1));
|
||||
inp = ggml_reshape_2d(ctx0, inp, n_pos_after_conv, n_mel_after_conv * inp->ne[3]); // [n_pos, 7680]
|
||||
inp = ggml_cont(ctx0, ggml_transpose(ctx0, inp)); // [7680, n_pos]
|
||||
|
||||
// project to n_embd
|
||||
inp = ggml_mul_mat(ctx0, model.conv_out_w, inp);
|
||||
if (model.conv_out_b) {
|
||||
inp = ggml_add(ctx0, inp, model.conv_out_b);
|
||||
}
|
||||
cb(inp, "after_conv_out", -1);
|
||||
}
|
||||
|
||||
auto n_pos = inp->ne[1];
|
||||
|
||||
ggml_tensor * pos_embd_selected = ggml_view_2d(
|
||||
ctx0, model.position_embeddings,
|
||||
model.position_embeddings->ne[0], n_pos,
|
||||
model.position_embeddings->nb[1], 0
|
||||
);
|
||||
ggml_tensor * cur = build_vit(
|
||||
inp, n_pos,
|
||||
NORM_TYPE_NORMAL,
|
||||
hparams.ffn_op,
|
||||
pos_embd_selected,
|
||||
nullptr);
|
||||
|
||||
cb(cur, "after_transformer", -1);
|
||||
|
||||
// projector
|
||||
cur = build_ffn(cur,
|
||||
model.mm_1_w, model.mm_1_b,
|
||||
nullptr, nullptr,
|
||||
model.mm_2_w, model.mm_2_b,
|
||||
FFN_GELU_ERF,
|
||||
-1);
|
||||
|
||||
cb(cur, "projected", -1);
|
||||
|
||||
ggml_build_forward_expand(gf, cur);
|
||||
|
||||
return gf;
|
||||
}
|
||||
|
|
@ -274,7 +274,8 @@ int32_t mtmd_helper_decode_image_chunk(
|
|||
batch_embd.set_position_normal(n_past, seq_id);
|
||||
}
|
||||
|
||||
if (mtmd_decode_use_non_causal(ctx)) {
|
||||
const bool use_non_causal = mtmd_decode_use_non_causal(ctx, chunk);
|
||||
if (use_non_causal) {
|
||||
llama_set_causal_attn(lctx, false);
|
||||
// TODO @ngxson : need to make sure only one image is processed at a time, and n_ubatch must be enough to hold the image
|
||||
}
|
||||
|
|
@ -302,7 +303,7 @@ int32_t mtmd_helper_decode_image_chunk(
|
|||
n_past += mtmd_input_chunk_get_n_pos(chunk);
|
||||
*new_n_past = n_past;
|
||||
|
||||
if (mtmd_decode_use_non_causal(ctx)) {
|
||||
if (use_non_causal) {
|
||||
llama_set_causal_attn(lctx, true);
|
||||
}
|
||||
return 0;
|
||||
|
|
|
|||
|
|
@ -198,35 +198,38 @@ struct img_tool {
|
|||
private:
|
||||
// Bilinear resize function
|
||||
static void resize_bilinear(const clip_image_u8 & src, clip_image_u8 & dst, int target_width, int target_height) {
|
||||
GGML_ASSERT(src.nx >= 2 && src.ny >= 2);
|
||||
if (src.nx == 0 || src.ny == 0) { dst.nx = dst.ny = 0; dst.buf.clear(); return; }
|
||||
if (target_width <= 0) target_width = 1;
|
||||
if (target_height <= 0) target_height = 1;
|
||||
|
||||
dst.nx = target_width;
|
||||
dst.ny = target_height;
|
||||
dst.buf.resize(3 * target_width * target_height);
|
||||
|
||||
float x_ratio = static_cast<float>(src.nx - 1) / target_width;
|
||||
float y_ratio = static_cast<float>(src.ny - 1) / target_height;
|
||||
float x_ratio = target_width > 1 ? static_cast<float>(src.nx - 1) / (target_width - 1) : 0.0f;
|
||||
float y_ratio = target_height > 1 ? static_cast<float>(src.ny - 1) / (target_height - 1) : 0.0f;
|
||||
|
||||
for (int y = 0; y < target_height; y++) {
|
||||
for (int x = 0; x < target_width; x++) {
|
||||
float px = x_ratio * x;
|
||||
float py = y_ratio * y;
|
||||
int x_floor = std::min(static_cast<int>(px), src.nx - 2);
|
||||
int y_floor = std::min(static_cast<int>(py), src.ny - 2);
|
||||
float x_lerp = px - x_floor;
|
||||
float y_lerp = py - y_floor;
|
||||
for (int y = 0; y < target_height; ++y) {
|
||||
for (int x = 0; x < target_width; ++x) {
|
||||
float px = x * x_ratio;
|
||||
float py = y * y_ratio;
|
||||
|
||||
for (int c = 0; c < 3; c++) {
|
||||
float top = lerp(
|
||||
static_cast<float>(src.buf[3 * (y_floor * src.nx + x_floor) + c]),
|
||||
static_cast<float>(src.buf[3 * (y_floor * src.nx + (x_floor + 1)) + c]),
|
||||
x_lerp
|
||||
);
|
||||
float bottom = lerp(
|
||||
static_cast<float>(src.buf[3 * ((y_floor + 1) * src.nx + x_floor) + c]),
|
||||
static_cast<float>(src.buf[3 * ((y_floor + 1) * src.nx + (x_floor + 1)) + c]),
|
||||
x_lerp
|
||||
);
|
||||
dst.buf[3 * (y * target_width + x) + c] = static_cast<uint8_t>(lerp(top, bottom, y_lerp));
|
||||
int x0 = std::min(static_cast<int>(px), src.nx - 1);
|
||||
int y0 = std::min(static_cast<int>(py), src.ny - 1);
|
||||
int x1 = std::min(x0 + 1, src.nx - 1);
|
||||
int y1 = std::min(y0 + 1, src.ny - 1);
|
||||
|
||||
float xf = px - x0;
|
||||
float yf = py - y0;
|
||||
|
||||
for (int c = 0; c < 3; ++c) {
|
||||
float top = lerp(static_cast<float>(src.buf[3 * (y0 * src.nx + x0) + c]),
|
||||
static_cast<float>(src.buf[3 * (y0 * src.nx + x1) + c]),
|
||||
xf);
|
||||
float bottom = lerp(static_cast<float>(src.buf[3 * (y1 * src.nx + x0) + c]),
|
||||
static_cast<float>(src.buf[3 * (y1 * src.nx + x1) + c]),
|
||||
xf);
|
||||
dst.buf[3 * (y * target_width + x) + c] = static_cast<uint8_t>(lerp(top, bottom, yf));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -455,6 +455,7 @@ struct mtmd_context {
|
|||
// set preprocessor
|
||||
switch (proj) {
|
||||
case PROJECTOR_TYPE_QWEN2A:
|
||||
case PROJECTOR_TYPE_QWEN3A:
|
||||
case PROJECTOR_TYPE_QWEN25O:
|
||||
{
|
||||
// <|audio_bos|> ... (embeddings) ... <|audio_eos|>
|
||||
|
|
@ -1016,8 +1017,12 @@ float * mtmd_get_output_embd(mtmd_context * ctx) {
|
|||
return ctx->image_embd_v.data();
|
||||
}
|
||||
|
||||
bool mtmd_decode_use_non_causal(mtmd_context * ctx) {
|
||||
switch (ctx->proj_type_v()) {
|
||||
bool mtmd_decode_use_non_causal(mtmd_context * ctx, const mtmd_input_chunk * chunk) {
|
||||
auto proj_type = ctx->proj_type_v();
|
||||
if (chunk && chunk->type == MTMD_INPUT_CHUNK_TYPE_AUDIO) {
|
||||
proj_type = ctx->proj_type_a();
|
||||
}
|
||||
switch (proj_type) {
|
||||
case PROJECTOR_TYPE_GEMMA3:
|
||||
case PROJECTOR_TYPE_GEMMA4V:
|
||||
return true;
|
||||
|
|
@ -1027,6 +1032,10 @@ bool mtmd_decode_use_non_causal(mtmd_context * ctx) {
|
|||
}
|
||||
|
||||
bool mtmd_decode_use_mrope(mtmd_context * ctx) {
|
||||
if (ctx->ctx_v == nullptr && ctx->proj_type_a() == PROJECTOR_TYPE_QWEN3A) {
|
||||
// qwen3-asr
|
||||
return true;
|
||||
}
|
||||
switch (ctx->proj_type_v()) {
|
||||
case PROJECTOR_TYPE_QWEN2VL:
|
||||
case PROJECTOR_TYPE_QWEN25VL:
|
||||
|
|
|
|||
|
|
@ -114,7 +114,8 @@ MTMD_API mtmd_context * mtmd_init_from_file(const char * mmproj_fname,
|
|||
MTMD_API void mtmd_free(mtmd_context * ctx);
|
||||
|
||||
// whether we need to set non-causal mask before llama_decode
|
||||
MTMD_API bool mtmd_decode_use_non_causal(mtmd_context * ctx);
|
||||
// if chunk is nullptr, we assume the default case where chunk is an image chunk
|
||||
MTMD_API bool mtmd_decode_use_non_causal(mtmd_context * ctx, const mtmd_input_chunk * chunk);
|
||||
|
||||
// whether the current model use M-RoPE for llama_decode
|
||||
MTMD_API bool mtmd_decode_use_mrope(mtmd_context * ctx);
|
||||
|
|
|
|||
|
|
@ -91,11 +91,13 @@ add_test_vision "ggml-org/LightOnOCR-1B-1025-GGUF:Q8_0"
|
|||
add_test_vision "ggml-org/DeepSeek-OCR-GGUF:Q8_0" -p "Free OCR." --chat-template deepseek-ocr
|
||||
add_test_vision "ggml-org/dots.ocr-GGUF:Q8_0" -p "OCR"
|
||||
add_test_vision "ggml-org/HunyuanOCR-GGUF:Q8_0" -p "OCR"
|
||||
add_test_vision "ggml-org/gemma-4-E2B-it-GGUF:Q8_0" --jinja
|
||||
|
||||
add_test_audio "ggml-org/ultravox-v0_5-llama-3_2-1b-GGUF:Q8_0"
|
||||
add_test_audio "ggml-org/Qwen2.5-Omni-3B-GGUF:Q4_K_M"
|
||||
add_test_audio "ggml-org/Voxtral-Mini-3B-2507-GGUF:Q4_K_M"
|
||||
add_test_audio "ggml-org/LFM2-Audio-1.5B-GGUF:Q8_0"
|
||||
add_test_audio "ggml-org/gemma-4-E2B-it-GGUF:Q8_0" --jinja
|
||||
|
||||
# to test the big models, run: ./tests.sh big
|
||||
if [ "$RUN_BIG_TESTS" = true ]; then
|
||||
|
|
|
|||
File diff suppressed because one or more lines are too long
File diff suppressed because it is too large
Load Diff
|
|
@ -926,7 +926,8 @@ void server_models_routes::init_routes() {
|
|||
res_ok(res, {
|
||||
// TODO: add support for this on web UI
|
||||
{"role", "router"},
|
||||
{"max_instances", 4}, // dummy value for testing
|
||||
{"max_instances", params.models_max},
|
||||
{"models_autoload", params.models_autoload},
|
||||
// this is a dummy response to make sure webui doesn't break
|
||||
{"model_alias", "llama-server"},
|
||||
{"model_path", "none"},
|
||||
|
|
@ -935,6 +936,7 @@ void server_models_routes::init_routes() {
|
|||
{"n_ctx", 0},
|
||||
}},
|
||||
{"webui_settings", webui_settings},
|
||||
{"build_info", build_info},
|
||||
});
|
||||
return res;
|
||||
}
|
||||
|
|
|
|||
|
|
@ -9,6 +9,19 @@ def create_server():
|
|||
server = ServerPreset.router()
|
||||
|
||||
|
||||
def test_router_props():
|
||||
global server
|
||||
server.models_max = 2
|
||||
server.no_models_autoload = True
|
||||
server.start()
|
||||
res = server.make_request("GET", "/props")
|
||||
assert res.status_code == 200
|
||||
assert res.body["role"] == "router"
|
||||
assert res.body["max_instances"] == 2
|
||||
assert res.body["models_autoload"] is False
|
||||
assert res.body["build_info"].startswith("b")
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"model,success",
|
||||
[
|
||||
|
|
|
|||
|
|
@ -89,6 +89,11 @@
|
|||
key: SETTINGS_KEYS.ASK_FOR_TITLE_CONFIRMATION,
|
||||
label: 'Ask for confirmation before changing conversation title',
|
||||
type: SettingsFieldType.CHECKBOX
|
||||
},
|
||||
{
|
||||
key: SETTINGS_KEYS.TITLE_GENERATION_USE_FIRST_LINE,
|
||||
label: 'Use first non-empty line for conversation title',
|
||||
type: SettingsFieldType.CHECKBOX
|
||||
}
|
||||
]
|
||||
},
|
||||
|
|
|
|||
|
|
@ -15,6 +15,18 @@
|
|||
let { logs, connectionTimeMs, defaultExpanded = false, class: className }: Props = $props();
|
||||
|
||||
let isExpanded = $derived(defaultExpanded);
|
||||
|
||||
function formatLogDetails(details: unknown): string {
|
||||
if (details == null) {
|
||||
return '';
|
||||
}
|
||||
|
||||
try {
|
||||
return JSON.stringify(details, null, 2);
|
||||
} catch {
|
||||
return String(details);
|
||||
}
|
||||
}
|
||||
</script>
|
||||
|
||||
{#if logs.length > 0}
|
||||
|
|
@ -53,6 +65,16 @@
|
|||
|
||||
<span class="break-all">{log.message}</span>
|
||||
</div>
|
||||
|
||||
{#if log.details !== undefined}
|
||||
<details class="ml-11">
|
||||
<summary class="cursor-pointer text-[10px] text-muted-foreground"> details </summary>
|
||||
|
||||
<pre
|
||||
class="mt-1 overflow-x-auto rounded bg-background/70 p-2 text-[10px] break-all whitespace-pre-wrap text-foreground/80">
|
||||
{formatLogDetails(log.details)}</pre>
|
||||
</details>
|
||||
{/if}
|
||||
{/each}
|
||||
</div>
|
||||
</Collapsible.Content>
|
||||
|
|
|
|||
|
|
@ -48,6 +48,26 @@ export const EXPECTED_THEMED_ICON_PAIR_COUNT = 2;
|
|||
/** CORS proxy URL query parameter name */
|
||||
export const CORS_PROXY_URL_PARAM = 'url';
|
||||
|
||||
/** Number of trailing characters to keep visible when partially redacting mcp-session-id */
|
||||
export const MCP_SESSION_ID_VISIBLE_CHARS = 5;
|
||||
|
||||
/** Partial-redaction rules for MCP headers: header name -> visible trailing chars */
|
||||
export const MCP_PARTIAL_REDACT_HEADERS = new Map<string, number>([
|
||||
['mcp-session-id', MCP_SESSION_ID_VISIBLE_CHARS]
|
||||
]);
|
||||
|
||||
/** Header names whose values should be redacted in diagnostic logs */
|
||||
export const REDACTED_HEADERS = new Set([
|
||||
'authorization',
|
||||
'api-key',
|
||||
'cookie',
|
||||
'mcp-session-id',
|
||||
'proxy-authorization',
|
||||
'set-cookie',
|
||||
'x-auth-token',
|
||||
'x-api-key'
|
||||
]);
|
||||
|
||||
/** Human-readable labels for MCP transport types */
|
||||
export const MCP_TRANSPORT_LABELS: Record<MCPTransportType, string> = {
|
||||
[MCPTransportType.WEBSOCKET]: 'WebSocket',
|
||||
|
|
|
|||
|
|
@ -15,6 +15,7 @@ export const SETTING_CONFIG_DEFAULT: Record<string, string | number | boolean |
|
|||
keepStatsVisible: false,
|
||||
showMessageStats: true,
|
||||
askForTitleConfirmation: false,
|
||||
titleGenerationUseFirstLine: false,
|
||||
pasteLongTextToFileLen: 2500,
|
||||
copyTextAttachmentsAsPlainText: false,
|
||||
pdfAsImage: false,
|
||||
|
|
@ -118,6 +119,8 @@ export const SETTING_CONFIG_INFO: Record<string, string> = {
|
|||
'Display generation statistics (tokens/second, token count, duration) below each assistant message.',
|
||||
askForTitleConfirmation:
|
||||
'Ask for confirmation before automatically changing conversation title when editing the first message.',
|
||||
titleGenerationUseFirstLine:
|
||||
'Use only the first non-empty line of the prompt to generate the conversation title.',
|
||||
pdfAsImage:
|
||||
'Parse PDF as image instead of text. Automatically falls back to text processing for non-vision models.',
|
||||
disableAutoScroll:
|
||||
|
|
|
|||
|
|
@ -15,6 +15,7 @@ export const SETTINGS_KEYS = {
|
|||
ENABLE_CONTINUE_GENERATION: 'enableContinueGeneration',
|
||||
PDF_AS_IMAGE: 'pdfAsImage',
|
||||
ASK_FOR_TITLE_CONFIRMATION: 'askForTitleConfirmation',
|
||||
TITLE_GENERATION_USE_FIRST_LINE: 'titleGenerationUseFirstLine',
|
||||
// Display
|
||||
SHOW_MESSAGE_STATS: 'showMessageStats',
|
||||
SHOW_THOUGHT_IN_PROGRESS: 'showThoughtInProgress',
|
||||
|
|
|
|||
|
|
@ -15,7 +15,8 @@ import type { Transport } from '@modelcontextprotocol/sdk/shared/transport.js';
|
|||
import {
|
||||
DEFAULT_MCP_CONFIG,
|
||||
DEFAULT_CLIENT_VERSION,
|
||||
DEFAULT_IMAGE_MIME_TYPE
|
||||
DEFAULT_IMAGE_MIME_TYPE,
|
||||
MCP_PARTIAL_REDACT_HEADERS
|
||||
} from '$lib/constants';
|
||||
import {
|
||||
MCPConnectionPhase,
|
||||
|
|
@ -43,9 +44,17 @@ import {
|
|||
buildProxiedUrl,
|
||||
buildProxiedHeaders,
|
||||
getAuthHeaders,
|
||||
sanitizeHeaders,
|
||||
throwIfAborted,
|
||||
isAbortError,
|
||||
createBase64DataUrl
|
||||
createBase64DataUrl,
|
||||
getRequestUrl,
|
||||
getRequestMethod,
|
||||
getRequestBody,
|
||||
summarizeRequestBody,
|
||||
formatDiagnosticErrorMessage,
|
||||
extractJsonRpcMethods,
|
||||
type RequestBodySummary
|
||||
} from '$lib/utils';
|
||||
|
||||
interface ToolResultContentItem {
|
||||
|
|
@ -62,6 +71,16 @@ interface ToolCallResult {
|
|||
_meta?: Record<string, unknown>;
|
||||
}
|
||||
|
||||
interface DiagnosticRequestDetails {
|
||||
url: string;
|
||||
method: string;
|
||||
credentials?: RequestCredentials;
|
||||
mode?: RequestMode;
|
||||
headers: Record<string, string>;
|
||||
body: RequestBodySummary;
|
||||
jsonRpcMethods?: string[];
|
||||
}
|
||||
|
||||
export class MCPService {
|
||||
/**
|
||||
* Create a connection log entry for phase tracking.
|
||||
|
|
@ -87,6 +106,225 @@ export class MCPService {
|
|||
};
|
||||
}
|
||||
|
||||
private static createDiagnosticRequestDetails(
|
||||
input: RequestInfo | URL,
|
||||
init: RequestInit | undefined,
|
||||
baseInit: RequestInit,
|
||||
requestHeaders: Headers,
|
||||
extraRedactedHeaders?: Iterable<string>
|
||||
): DiagnosticRequestDetails {
|
||||
const body = getRequestBody(input, init);
|
||||
const details: DiagnosticRequestDetails = {
|
||||
url: getRequestUrl(input),
|
||||
method: getRequestMethod(input, init, baseInit).toUpperCase(),
|
||||
credentials: init?.credentials ?? baseInit.credentials,
|
||||
mode: init?.mode ?? baseInit.mode,
|
||||
headers: sanitizeHeaders(requestHeaders, extraRedactedHeaders, MCP_PARTIAL_REDACT_HEADERS),
|
||||
body: summarizeRequestBody(body)
|
||||
};
|
||||
const jsonRpcMethods = extractJsonRpcMethods(body);
|
||||
|
||||
if (jsonRpcMethods) {
|
||||
details.jsonRpcMethods = jsonRpcMethods;
|
||||
}
|
||||
|
||||
return details;
|
||||
}
|
||||
|
||||
private static summarizeError(error: unknown): Record<string, unknown> {
|
||||
if (error instanceof Error) {
|
||||
return {
|
||||
name: error.name,
|
||||
message: error.message,
|
||||
cause:
|
||||
error.cause instanceof Error
|
||||
? { name: error.cause.name, message: error.cause.message }
|
||||
: error.cause,
|
||||
stack: error.stack?.split('\n').slice(0, 6).join('\n')
|
||||
};
|
||||
}
|
||||
|
||||
return { value: String(error) };
|
||||
}
|
||||
|
||||
private static getBrowserContext(
|
||||
targetUrl: URL,
|
||||
useProxy: boolean
|
||||
): Record<string, unknown> | undefined {
|
||||
if (typeof window === 'undefined') {
|
||||
return undefined;
|
||||
}
|
||||
|
||||
return {
|
||||
location: window.location.href,
|
||||
origin: window.location.origin,
|
||||
protocol: window.location.protocol,
|
||||
isSecureContext: window.isSecureContext,
|
||||
targetOrigin: targetUrl.origin,
|
||||
targetProtocol: targetUrl.protocol,
|
||||
sameOrigin: window.location.origin === targetUrl.origin,
|
||||
useProxy
|
||||
};
|
||||
}
|
||||
|
||||
/**
 * Builds a list of human-readable troubleshooting hints for a failed MCP
 * connection, based on the page/target protocols and origins, the configured
 * headers and credentials, and the error message text.
 *
 * @param targetUrl - The effective URL the connection was attempted against
 * @param config - The MCP server configuration used for the attempt
 * @param error - The error thrown by the failed attempt
 * @returns Zero or more hint strings, in priority order
 */
private static getConnectionHints(
  targetUrl: URL,
  config: MCPServerConfig,
  error: unknown
): string[] {
  const hints: string[] = [];
  const message = error instanceof Error ? error.message : String(error);
  const headerNames = Object.keys(config.headers ?? {});

  // Browser-only checks: mixed content and cross-origin issues cannot occur
  // outside a window environment.
  if (typeof window !== 'undefined') {
    // HTTPS page -> HTTP server without the proxy is classic mixed content.
    if (
      window.location.protocol === 'https:' &&
      targetUrl.protocol === 'http:' &&
      !config.useProxy
    ) {
      hints.push(
        'The page is running over HTTPS but the MCP server is HTTP. Browsers often block this as mixed content; enable the proxy or use HTTPS/WSS for the MCP server.'
      );
    }

    // Cross-origin without the proxy: CORS is the usual suspect.
    if (window.location.origin !== targetUrl.origin && !config.useProxy) {
      hints.push(
        'This is a cross-origin browser request. If the server is reachable from curl or Node but not from the browser, missing CORS headers are the most likely cause.'
      );
    }
  }

  // Any custom header forces a CORS preflight (OPTIONS) round trip.
  if (headerNames.length > 0) {
    hints.push(
      `Custom request headers are configured (${headerNames.join(', ')}). That triggers a CORS preflight, so the server must allow OPTIONS and include the matching Access-Control-Allow-Headers response.`
    );
  }

  // Credentialed cross-origin requests have stricter CORS requirements.
  if (config.credentials && config.credentials !== 'omit') {
    hints.push(
      'Credentials are enabled for this connection. Cross-origin credentialed requests need Access-Control-Allow-Credentials: true and cannot use a wildcard Access-Control-Allow-Origin.'
    );
  }

  // "Failed to fetch" is the browser's opaque catch-all network error.
  if (message.includes('Failed to fetch')) {
    hints.push(
      '"Failed to fetch" is a browser-level network failure. Common causes are CORS rejection, mixed-content blocking, certificate/TLS errors, DNS failures, or nothing listening on the target port.'
    );
  }

  return hints;
}
|
||||
|
||||
/**
 * Wraps the global fetch with diagnostic logging for the MCP handshake.
 * Every request and its response (or failure) is reported through `onLog`
 * until `disable()` is called, after which the wrapper passes traffic
 * through silently.
 *
 * @param serverName - Name of the MCP server, included in log details
 * @param config - Server configuration (used for header names and hints)
 * @param baseInit - Base RequestInit merged under the per-call init
 * @param targetUrl - Effective target URL, used for browser-context details
 * @param useProxy - Whether the CORS proxy is in use
 * @param onLog - Optional sink for generated connection logs
 * @returns The wrapping fetch plus a `disable` switch for post-handshake use
 */
private static createDiagnosticFetch(
  serverName: string,
  config: MCPServerConfig,
  baseInit: RequestInit,
  targetUrl: URL,
  useProxy: boolean,
  onLog?: (log: MCPConnectionLog) => void
): {
  fetch: typeof fetch;
  disable: () => void;
} {
  // Closure flag: once flipped by disable(), no further logs are emitted.
  let enabled = true;
  const logIfEnabled = (log: MCPConnectionLog) => {
    if (enabled) {
      onLog?.(log);
    }
  };

  return {
    fetch: async (input, init) => {
      const startedAt = performance.now();
      // Merge headers in fetch's effective precedence: base, then Request's
      // own headers, then the per-call init headers (last writer wins).
      const requestHeaders = new Headers(baseInit.headers);

      if (typeof Request !== 'undefined' && input instanceof Request) {
        for (const [key, value] of input.headers.entries()) {
          requestHeaders.set(key, value);
        }
      }

      if (init?.headers) {
        for (const [key, value] of new Headers(init.headers).entries()) {
          requestHeaders.set(key, value);
        }
      }

      // Build a sanitized, payload-free description of the outgoing request;
      // configured custom header names are fully redacted.
      const request = this.createDiagnosticRequestDetails(
        input,
        init,
        baseInit,
        requestHeaders,
        Object.keys(config.headers ?? {})
      );
      const { method, url } = request;

      // Log the outgoing request before dispatching it.
      logIfEnabled(
        this.createLog(
          MCPConnectionPhase.INITIALIZING,
          `HTTP ${method} ${url}`,
          MCPLogLevel.INFO,
          {
            serverName,
            request
          }
        )
      );

      try {
        const response = await fetch(input, {
          ...baseInit,
          ...init,
          headers: requestHeaders
        });
        const durationMs = Math.round(performance.now() - startedAt);

        // Non-2xx responses are logged at WARN, successes at INFO.
        logIfEnabled(
          this.createLog(
            MCPConnectionPhase.INITIALIZING,
            `HTTP ${response.status} ${method} ${url} (${durationMs}ms)`,
            response.ok ? MCPLogLevel.INFO : MCPLogLevel.WARN,
            {
              response: {
                url,
                status: response.status,
                statusText: response.statusText,
                headers: sanitizeHeaders(response.headers, undefined, MCP_PARTIAL_REDACT_HEADERS),
                durationMs
              }
            }
          )
        );

        return response;
      } catch (error) {
        const durationMs = Math.round(performance.now() - startedAt);

        // Network-level failure: attach error summary, browser context, and
        // troubleshooting hints, then rethrow so the caller still sees it.
        logIfEnabled(
          this.createLog(
            MCPConnectionPhase.ERROR,
            `HTTP ${method} ${url} failed: ${formatDiagnosticErrorMessage(error)}`,
            MCPLogLevel.ERROR,
            {
              serverName,
              request,
              error: this.summarizeError(error),
              browser: this.getBrowserContext(targetUrl, useProxy),
              hints: this.getConnectionHints(targetUrl, config, error),
              durationMs
            }
          )
        );

        throw error;
      }
    },
    disable: () => {
      enabled = false;
    }
  };
}
|
||||
|
||||
/**
|
||||
* Detect if an error indicates an expired/invalidated MCP session.
|
||||
* Per MCP spec 2025-11-25: HTTP 404 means session invalidated, client MUST
|
||||
|
|
@ -113,9 +351,14 @@ export class MCPService {
|
|||
* @returns Object containing the created transport and the transport type used
|
||||
* @throws {Error} If url is missing, WebSocket + proxy combination, or all transports fail
|
||||
*/
|
||||
static createTransport(config: MCPServerConfig): {
|
||||
static createTransport(
|
||||
serverName: string,
|
||||
config: MCPServerConfig,
|
||||
onLog?: (log: MCPConnectionLog) => void
|
||||
): {
|
||||
transport: Transport;
|
||||
type: MCPTransportType;
|
||||
stopPhaseLogging: () => void;
|
||||
} {
|
||||
if (!config.url) {
|
||||
throw new Error('MCP server configuration is missing url');
|
||||
|
|
@ -154,11 +397,20 @@ export class MCPService {
|
|||
|
||||
return {
|
||||
transport: new WebSocketClientTransport(url),
|
||||
type: MCPTransportType.WEBSOCKET
|
||||
type: MCPTransportType.WEBSOCKET,
|
||||
stopPhaseLogging: () => {}
|
||||
};
|
||||
}
|
||||
|
||||
const url = useProxy ? buildProxiedUrl(config.url) : new URL(config.url);
|
||||
const { fetch: diagnosticFetch, disable: stopPhaseLogging } = this.createDiagnosticFetch(
|
||||
serverName,
|
||||
config,
|
||||
requestInit,
|
||||
url,
|
||||
useProxy,
|
||||
onLog
|
||||
);
|
||||
|
||||
if (useProxy && import.meta.env.DEV) {
|
||||
console.log(`[MCPService] Using CORS proxy for ${config.url} -> ${url.href}`);
|
||||
|
|
@ -171,17 +423,24 @@ export class MCPService {
|
|||
|
||||
return {
|
||||
transport: new StreamableHTTPClientTransport(url, {
|
||||
requestInit
|
||||
requestInit,
|
||||
fetch: diagnosticFetch
|
||||
}),
|
||||
type: MCPTransportType.STREAMABLE_HTTP
|
||||
type: MCPTransportType.STREAMABLE_HTTP,
|
||||
stopPhaseLogging
|
||||
};
|
||||
} catch (httpError) {
|
||||
console.warn(`[MCPService] StreamableHTTP failed, trying SSE transport...`, httpError);
|
||||
|
||||
try {
|
||||
return {
|
||||
transport: new SSEClientTransport(url, { requestInit }),
|
||||
type: MCPTransportType.SSE
|
||||
transport: new SSEClientTransport(url, {
|
||||
requestInit,
|
||||
fetch: diagnosticFetch,
|
||||
eventSourceInit: { fetch: diagnosticFetch }
|
||||
}),
|
||||
type: MCPTransportType.SSE,
|
||||
stopPhaseLogging
|
||||
};
|
||||
} catch (sseError) {
|
||||
const httpMsg = httpError instanceof Error ? httpError.message : String(httpError);
|
||||
|
|
@ -263,7 +522,11 @@ export class MCPService {
|
|||
console.log(`[MCPService][${serverName}] Creating transport...`);
|
||||
}
|
||||
|
||||
const { transport, type: transportType } = this.createTransport(serverConfig);
|
||||
const {
|
||||
transport,
|
||||
type: transportType,
|
||||
stopPhaseLogging
|
||||
} = this.createTransport(serverName, serverConfig, (log) => onPhase?.(log.phase, log));
|
||||
|
||||
// Setup WebSocket reconnection handler
|
||||
if (transportType === MCPTransportType.WEBSOCKET) {
|
||||
|
|
@ -294,6 +557,24 @@ export class MCPService {
|
|||
}
|
||||
);
|
||||
|
||||
const runtimeErrorHandler = (error: Error) => {
|
||||
console.error(`[MCPService][${serverName}] Protocol error after initialize:`, error);
|
||||
};
|
||||
|
||||
client.onerror = (error) => {
|
||||
onPhase?.(
|
||||
MCPConnectionPhase.ERROR,
|
||||
this.createLog(
|
||||
MCPConnectionPhase.ERROR,
|
||||
`Protocol error: ${error.message}`,
|
||||
MCPLogLevel.ERROR,
|
||||
{
|
||||
error: this.summarizeError(error)
|
||||
}
|
||||
)
|
||||
);
|
||||
};
|
||||
|
||||
// Phase: Initializing
|
||||
onPhase?.(
|
||||
MCPConnectionPhase.INITIALIZING,
|
||||
|
|
@ -301,7 +582,49 @@ export class MCPService {
|
|||
);
|
||||
|
||||
console.log(`[MCPService][${serverName}] Connecting to server...`);
|
||||
await client.connect(transport);
|
||||
try {
|
||||
await client.connect(transport);
|
||||
// Transport diagnostics are only for the initial handshake, not long-lived traffic.
|
||||
stopPhaseLogging();
|
||||
client.onerror = runtimeErrorHandler;
|
||||
} catch (error) {
|
||||
client.onerror = runtimeErrorHandler;
|
||||
const url =
|
||||
(serverConfig.useProxy ?? false)
|
||||
? buildProxiedUrl(serverConfig.url)
|
||||
: new URL(serverConfig.url);
|
||||
|
||||
onPhase?.(
|
||||
MCPConnectionPhase.ERROR,
|
||||
this.createLog(
|
||||
MCPConnectionPhase.ERROR,
|
||||
`Connection failed during initialize: ${
|
||||
error instanceof Error ? error.message : String(error)
|
||||
}`,
|
||||
MCPLogLevel.ERROR,
|
||||
{
|
||||
error: this.summarizeError(error),
|
||||
config: {
|
||||
serverName,
|
||||
configuredUrl: serverConfig.url,
|
||||
effectiveUrl: url.href,
|
||||
transportType,
|
||||
useProxy: serverConfig.useProxy ?? false,
|
||||
headers: sanitizeHeaders(
|
||||
serverConfig.headers,
|
||||
Object.keys(serverConfig.headers ?? {}),
|
||||
MCP_PARTIAL_REDACT_HEADERS
|
||||
),
|
||||
credentials: serverConfig.credentials
|
||||
},
|
||||
browser: this.getBrowserContext(url, serverConfig.useProxy ?? false),
|
||||
hints: this.getConnectionHints(url, serverConfig, error)
|
||||
}
|
||||
)
|
||||
);
|
||||
|
||||
throw error;
|
||||
}
|
||||
|
||||
const serverVersion = client.getServerVersion();
|
||||
const serverCapabilities = client.getServerCapabilities();
|
||||
|
|
|
|||
|
|
@ -130,6 +130,12 @@ export const SYNCABLE_PARAMETERS: SyncableParameter[] = [
|
|||
type: SyncableParameterType.BOOLEAN,
|
||||
canSync: true
|
||||
},
|
||||
{
|
||||
key: 'titleGenerationUseFirstLine',
|
||||
serverKey: 'titleGenerationUseFirstLine',
|
||||
type: SyncableParameterType.BOOLEAN,
|
||||
canSync: true
|
||||
},
|
||||
{
|
||||
key: 'disableAutoScroll',
|
||||
serverKey: 'disableAutoScroll',
|
||||
|
|
|
|||
|
|
@ -30,7 +30,8 @@ import {
|
|||
findDescendantMessages,
|
||||
findLeafNode,
|
||||
findMessageById,
|
||||
isAbortError
|
||||
isAbortError,
|
||||
generateConversationTitle
|
||||
} from '$lib/utils';
|
||||
import {
|
||||
MAX_INACTIVE_CONVERSATION_STATES,
|
||||
|
|
@ -504,7 +505,10 @@ class ChatStore {
|
|||
allExtras
|
||||
);
|
||||
if (isNewConversation && content)
|
||||
await conversationsStore.updateConversationName(currentConv.id, content.trim());
|
||||
await conversationsStore.updateConversationName(
|
||||
currentConv.id,
|
||||
generateConversationTitle(content, Boolean(config().titleGenerationUseFirstLine))
|
||||
);
|
||||
const assistantMessage = await this.createAssistantMessage(userMessage.id);
|
||||
conversationsStore.addMessageToActive(assistantMessage);
|
||||
await this.streamChatCompletion(
|
||||
|
|
@ -896,7 +900,7 @@ class ChatStore {
|
|||
if (isFirstUserMessage && newContent.trim())
|
||||
await conversationsStore.updateConversationTitleWithConfirmation(
|
||||
activeConv.id,
|
||||
newContent.trim()
|
||||
generateConversationTitle(newContent, Boolean(config().titleGenerationUseFirstLine))
|
||||
);
|
||||
const messagesToRemove = conversationsStore.activeMessages.slice(messageIndex + 1);
|
||||
for (const message of messagesToRemove) await DatabaseService.deleteMessage(message.id);
|
||||
|
|
@ -1317,7 +1321,7 @@ class ChatStore {
|
|||
if (rootMessage && msg.parent === rootMessage.id && newContent.trim()) {
|
||||
await conversationsStore.updateConversationTitleWithConfirmation(
|
||||
activeConv.id,
|
||||
newContent.trim()
|
||||
generateConversationTitle(newContent, Boolean(config().titleGenerationUseFirstLine))
|
||||
);
|
||||
}
|
||||
|
||||
|
|
@ -1391,7 +1395,7 @@ class ChatStore {
|
|||
if (isFirstUserMessage && newContent.trim())
|
||||
await conversationsStore.updateConversationTitleWithConfirmation(
|
||||
activeConv.id,
|
||||
newContent.trim()
|
||||
generateConversationTitle(newContent, Boolean(config().titleGenerationUseFirstLine))
|
||||
);
|
||||
await conversationsStore.refreshActiveMessages();
|
||||
if (msg.role === MessageRole.USER)
|
||||
|
|
|
|||
|
|
@ -23,7 +23,12 @@ import { browser } from '$app/environment';
|
|||
import { toast } from 'svelte-sonner';
|
||||
import { DatabaseService } from '$lib/services/database.service';
|
||||
import { config } from '$lib/stores/settings.svelte';
|
||||
import { filterByLeafNodeId, findLeafNode, runLegacyMigration } from '$lib/utils';
|
||||
import {
|
||||
filterByLeafNodeId,
|
||||
findLeafNode,
|
||||
runLegacyMigration,
|
||||
generateConversationTitle
|
||||
} from '$lib/utils';
|
||||
import type { McpServerOverride } from '$lib/types/database';
|
||||
import { MessageRole } from '$lib/enums';
|
||||
import {
|
||||
|
|
@ -548,7 +553,10 @@ class ConversationsStore {
|
|||
) {
|
||||
await this.updateConversationTitleWithConfirmation(
|
||||
this.activeConversation.id,
|
||||
newFirstUserMessage.content.trim()
|
||||
generateConversationTitle(
|
||||
newFirstUserMessage.content,
|
||||
Boolean(config().titleGenerationUseFirstLine)
|
||||
)
|
||||
);
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -1460,12 +1460,14 @@ class MCPStore {
|
|||
} catch (error) {
|
||||
const message = error instanceof Error ? error.message : 'Unknown error occurred';
|
||||
|
||||
logs.push({
|
||||
timestamp: new Date(),
|
||||
phase: MCPConnectionPhase.ERROR,
|
||||
message: `Connection failed: ${message}`,
|
||||
level: MCPLogLevel.ERROR
|
||||
});
|
||||
if (logs.at(-1)?.phase !== MCPConnectionPhase.ERROR) {
|
||||
logs.push({
|
||||
timestamp: new Date(),
|
||||
phase: MCPConnectionPhase.ERROR,
|
||||
message: `Connection failed: ${message}`,
|
||||
level: MCPLogLevel.ERROR
|
||||
});
|
||||
}
|
||||
|
||||
this.updateHealthCheck(server.id, {
|
||||
status: HealthCheckStatus.ERROR,
|
||||
|
|
|
|||
|
|
@ -1,4 +1,6 @@
|
|||
import { config } from '$lib/stores/settings.svelte';
|
||||
import { REDACTED_HEADERS } from '$lib/constants';
|
||||
import { redactValue } from './redact';
|
||||
|
||||
/**
|
||||
* Get authorization headers for API requests
|
||||
|
|
@ -20,3 +22,46 @@ export function getJsonHeaders(): Record<string, string> {
|
|||
...getAuthHeaders()
|
||||
};
|
||||
}
|
||||
|
||||
/**
|
||||
* Sanitize HTTP headers by redacting sensitive values.
|
||||
* Known sensitive headers (from REDACTED_HEADERS) and any extra headers
|
||||
* specified by the caller are fully redacted. Headers listed in
|
||||
* `partialRedactHeaders` are partially redacted, showing only the
|
||||
* specified number of trailing characters.
|
||||
*
|
||||
* @param headers - Headers to sanitize
|
||||
* @param extraRedactedHeaders - Additional header names to fully redact
|
||||
* @param partialRedactHeaders - Map of header name -> number of trailing chars to keep visible
|
||||
* @returns Object with header names as keys and (possibly redacted) values
|
||||
*/
|
||||
export function sanitizeHeaders(
|
||||
headers?: HeadersInit,
|
||||
extraRedactedHeaders?: Iterable<string>,
|
||||
partialRedactHeaders?: Map<string, number>
|
||||
): Record<string, string> {
|
||||
if (!headers) {
|
||||
return {};
|
||||
}
|
||||
|
||||
const normalized = new Headers(headers);
|
||||
const sanitized: Record<string, string> = {};
|
||||
const redactedHeaders = new Set(
|
||||
Array.from(extraRedactedHeaders ?? [], (header) => header.toLowerCase())
|
||||
);
|
||||
|
||||
for (const [key, value] of normalized.entries()) {
|
||||
const normalizedKey = key.toLowerCase();
|
||||
const partialChars = partialRedactHeaders?.get(normalizedKey);
|
||||
|
||||
if (partialChars !== undefined) {
|
||||
sanitized[key] = redactValue(value, partialChars);
|
||||
} else if (REDACTED_HEADERS.has(normalizedKey) || redactedHeaders.has(normalizedKey)) {
|
||||
sanitized[key] = redactValue(value);
|
||||
} else {
|
||||
sanitized[key] = value;
|
||||
}
|
||||
}
|
||||
|
||||
return sanitized;
|
||||
}
|
||||
|
|
|
|||
|
|
@ -8,7 +8,7 @@
|
|||
*/
|
||||
|
||||
// API utilities
|
||||
export { getAuthHeaders, getJsonHeaders } from './api-headers';
|
||||
export { getAuthHeaders, getJsonHeaders, sanitizeHeaders } from './api-headers';
|
||||
export { apiFetch, apiFetchWithParams, apiPost, type ApiFetchOptions } from './api-fetch';
|
||||
export { validateApiKey } from './api-key-validation';
|
||||
|
||||
|
|
@ -55,7 +55,7 @@ export {
|
|||
|
||||
// File preview utilities
|
||||
export { getFileTypeLabel } from './file-preview';
|
||||
export { getPreviewText } from './text';
|
||||
export { getPreviewText, generateConversationTitle } from './text';
|
||||
|
||||
// File type utilities
|
||||
export {
|
||||
|
|
@ -164,6 +164,20 @@ export { runLegacyMigration, isMigrationNeeded } from './legacy-migration';
|
|||
// Cache utilities
|
||||
export { TTLCache, ReactiveTTLMap, type TTLCacheOptions } from './cache-ttl';
|
||||
|
||||
// Redaction utilities
|
||||
export { redactValue } from './redact';
|
||||
|
||||
// Request inspection utilities
|
||||
export {
|
||||
getRequestUrl,
|
||||
getRequestMethod,
|
||||
getRequestBody,
|
||||
summarizeRequestBody,
|
||||
formatDiagnosticErrorMessage,
|
||||
extractJsonRpcMethods,
|
||||
type RequestBodySummary
|
||||
} from './request-helpers';
|
||||
|
||||
// Abort signal utilities
|
||||
export {
|
||||
throwIfAborted,
|
||||
|
|
|
|||
|
|
@ -0,0 +1,14 @@
|
|||
/**
|
||||
* Redacts a sensitive value, optionally showing the last N characters.
|
||||
*
|
||||
* @param value - The value to redact
|
||||
* @param showLastChars - If provided, reveals the last N characters with a leading mask
|
||||
* @returns The redacted string
|
||||
*/
|
||||
export function redactValue(value: string, showLastChars?: number): string {
|
||||
if (showLastChars) {
|
||||
return `....${value.slice(-showLastChars)}`;
|
||||
}
|
||||
|
||||
return '[redacted]';
|
||||
}
|
||||
|
|
@ -0,0 +1,111 @@
|
|||
/**
|
||||
* HTTP request inspection utilities for diagnostic logging.
|
||||
* These helpers extract metadata from fetch-style request arguments
|
||||
* without exposing sensitive payload data.
|
||||
*/
|
||||
|
||||
/** Compact, payload-free description of a request body for diagnostic logs. */
export interface RequestBodySummary {
  // Body category: 'empty', 'string', 'blob', 'urlsearchparams', 'formdata',
  // 'arraybuffer', a typed-array constructor name, or a typeof fallback.
  kind: string;
  // Size in characters/bytes, when it can be computed without reading the body.
  size?: number;
}
|
||||
|
||||
export function getRequestUrl(input: RequestInfo | URL): string {
|
||||
if (typeof input === 'string') {
|
||||
return input;
|
||||
}
|
||||
|
||||
if (input instanceof URL) {
|
||||
return input.href;
|
||||
}
|
||||
|
||||
return input.url;
|
||||
}
|
||||
|
||||
export function getRequestMethod(
|
||||
input: RequestInfo | URL,
|
||||
init?: RequestInit,
|
||||
baseInit?: RequestInit
|
||||
): string {
|
||||
if (init?.method) {
|
||||
return init.method;
|
||||
}
|
||||
|
||||
if (typeof Request !== 'undefined' && input instanceof Request) {
|
||||
return input.method;
|
||||
}
|
||||
|
||||
return baseInit?.method ?? 'GET';
|
||||
}
|
||||
|
||||
export function getRequestBody(
|
||||
input: RequestInfo | URL,
|
||||
init?: RequestInit
|
||||
): BodyInit | null | undefined {
|
||||
if (init?.body !== undefined) {
|
||||
return init.body;
|
||||
}
|
||||
|
||||
if (typeof Request !== 'undefined' && input instanceof Request) {
|
||||
return input.body;
|
||||
}
|
||||
|
||||
return undefined;
|
||||
}
|
||||
|
||||
export function summarizeRequestBody(body: BodyInit | null | undefined): RequestBodySummary {
|
||||
if (body == null) {
|
||||
return { kind: 'empty' };
|
||||
}
|
||||
|
||||
if (typeof body === 'string') {
|
||||
return { kind: 'string', size: body.length };
|
||||
}
|
||||
|
||||
if (body instanceof Blob) {
|
||||
return { kind: 'blob', size: body.size };
|
||||
}
|
||||
|
||||
if (body instanceof URLSearchParams) {
|
||||
return { kind: 'urlsearchparams', size: body.toString().length };
|
||||
}
|
||||
|
||||
if (body instanceof FormData) {
|
||||
return { kind: 'formdata' };
|
||||
}
|
||||
|
||||
if (body instanceof ArrayBuffer) {
|
||||
return { kind: 'arraybuffer', size: body.byteLength };
|
||||
}
|
||||
|
||||
if (ArrayBuffer.isView(body)) {
|
||||
return { kind: body.constructor.name, size: body.byteLength };
|
||||
}
|
||||
|
||||
return { kind: typeof body };
|
||||
}
|
||||
|
||||
export function formatDiagnosticErrorMessage(error: unknown): string {
|
||||
const message = error instanceof Error ? error.message : String(error);
|
||||
|
||||
return message.includes('Failed to fetch') ? `${message} (check CORS?)` : message;
|
||||
}
|
||||
|
||||
export function extractJsonRpcMethods(body: BodyInit | null | undefined): string[] | undefined {
|
||||
if (typeof body !== 'string') {
|
||||
return undefined;
|
||||
}
|
||||
|
||||
try {
|
||||
const parsed = JSON.parse(body);
|
||||
const messages = Array.isArray(parsed) ? parsed : [parsed];
|
||||
const methods = messages
|
||||
.map((message: Record<string, unknown>) =>
|
||||
typeof message?.method === 'string' ? (message.method as string) : undefined
|
||||
)
|
||||
.filter((method: string | undefined): method is string => Boolean(method));
|
||||
|
||||
return methods.length > 0 ? methods : undefined;
|
||||
} catch {
|
||||
return undefined;
|
||||
}
|
||||
}
|
||||
|
|
@ -1,3 +1,5 @@
|
|||
import { NEWLINE_SEPARATOR } from '$lib/constants';
|
||||
|
||||
/**
|
||||
* Returns a shortened preview of the provided content capped at the given length.
|
||||
* Appends an ellipsis when the content exceeds the maximum.
|
||||
|
|
@ -5,3 +7,16 @@
|
|||
export function getPreviewText(content: string, max = 150): string {
|
||||
return content.length > max ? content.slice(0, max) + '...' : content;
|
||||
}
|
||||
|
||||
/**
|
||||
* Generates a single-line title from a potentially multi-line prompt.
|
||||
* Uses the first non-empty line if `useFirstLine` is true.
|
||||
*/
|
||||
export function generateConversationTitle(content: string, useFirstLine: boolean = false): string {
|
||||
if (useFirstLine) {
|
||||
const firstLine = content.split(NEWLINE_SEPARATOR).find((line) => line.trim().length > 0);
|
||||
return firstLine ? firstLine.trim() : content.trim();
|
||||
}
|
||||
|
||||
return content.trim();
|
||||
}
|
||||
|
|
|
|||
|
|
@ -0,0 +1,252 @@
|
|||
import { afterEach, describe, expect, it, vi } from 'vitest';
|
||||
import { Client } from '@modelcontextprotocol/sdk/client';
|
||||
import { MCPService } from '$lib/services/mcp.service';
|
||||
import { MCPConnectionPhase, MCPTransportType } from '$lib/enums';
|
||||
import type { MCPConnectionLog, MCPServerConfig } from '$lib/types';
|
||||
|
||||
// Signature of MCPService's private static createDiagnosticFetch, re-declared
// here so the tests can call it through an unknown-cast without widening the
// service's visibility.
type DiagnosticFetchFactory = (
  serverName: string,
  config: MCPServerConfig,
  baseInit: RequestInit,
  targetUrl: URL,
  useProxy: boolean,
  onLog?: (log: MCPConnectionLog) => void
) => { fetch: typeof fetch; disable: () => void };

// Test helper: invokes the private factory with a fixed server name, the
// config's own URL as target, and proxy disabled.
const createDiagnosticFetch = (
  config: MCPServerConfig,
  onLog?: (log: MCPConnectionLog) => void,
  baseInit: RequestInit = {}
) =>
  (
    MCPService as unknown as { createDiagnosticFetch: DiagnosticFetchFactory }
  ).createDiagnosticFetch('test-server', config, baseInit, new URL(config.url), false, onLog);
|
||||
|
||||
describe('MCPService', () => {
  afterEach(() => {
    // Undo vi.spyOn mocks and vi.stubGlobal('fetch', ...) between tests.
    vi.restoreAllMocks();
    vi.unstubAllGlobals();
  });

  it('stops transport phase logging after handshake diagnostics are disabled', async () => {
    const logs: MCPConnectionLog[] = [];
    const response = new Response('{}', {
      status: 200,
      headers: { 'content-type': 'application/json' }
    });

    vi.stubGlobal('fetch', vi.fn().mockResolvedValue(response));

    const config: MCPServerConfig = {
      url: 'https://example.com/mcp',
      transport: MCPTransportType.STREAMABLE_HTTP
    };

    const controller = createDiagnosticFetch(config, (log) => logs.push(log));

    // One request should produce exactly two logs: request + response.
    await controller.fetch(config.url, { method: 'POST', body: '{}' });
    expect(logs).toHaveLength(2);
    expect(logs.every((log) => log.message.includes('https://example.com/mcp'))).toBe(true);

    // After disable(), further traffic must not emit any logs.
    controller.disable();
    await controller.fetch(config.url, { method: 'POST', body: '{}' });

    expect(logs).toHaveLength(2);
  });

  it('redacts all configured custom headers in diagnostic request logs', async () => {
    const logs: MCPConnectionLog[] = [];
    const response = new Response('{}', {
      status: 200,
      headers: { 'content-type': 'application/json' }
    });

    vi.stubGlobal('fetch', vi.fn().mockResolvedValue(response));

    const config: MCPServerConfig = {
      url: 'https://example.com/mcp',
      transport: MCPTransportType.STREAMABLE_HTTP,
      headers: {
        'x-auth-token': 'secret-token',
        'x-vendor-api-key': 'secret-key'
      }
    };

    const controller = createDiagnosticFetch(config, (log) => logs.push(log), {
      headers: config.headers
    });

    await controller.fetch(config.url, {
      method: 'POST',
      headers: { 'content-type': 'application/json' },
      body: '{}'
    });

    expect(logs).toHaveLength(2);
    // Custom header values are fully redacted; benign headers pass through.
    expect(logs[0].details).toMatchObject({
      request: {
        headers: {
          'x-auth-token': '[redacted]',
          'x-vendor-api-key': '[redacted]',
          'content-type': 'application/json'
        }
      }
    });
  });

  it('partially redacts mcp-session-id in diagnostic request and response logs', async () => {
    const logs: MCPConnectionLog[] = [];
    const response = new Response('{}', {
      status: 200,
      headers: {
        'content-type': 'application/json',
        'mcp-session-id': 'session-response-67890'
      }
    });

    vi.stubGlobal('fetch', vi.fn().mockResolvedValue(response));

    const config: MCPServerConfig = {
      url: 'https://example.com/mcp',
      transport: MCPTransportType.STREAMABLE_HTTP
    };

    const controller = createDiagnosticFetch(config, (log) => logs.push(log));

    await controller.fetch(config.url, {
      method: 'POST',
      headers: {
        'content-type': 'application/json',
        'mcp-session-id': 'session-request-12345'
      },
      body: '{}'
    });

    expect(logs).toHaveLength(2);
    // Only the trailing 5 characters of the session id stay visible.
    expect(logs[0].details).toMatchObject({
      request: {
        headers: {
          'content-type': 'application/json',
          'mcp-session-id': '....12345'
        }
      }
    });
    expect(logs[1].details).toMatchObject({
      response: {
        headers: {
          'content-type': 'application/json',
          'mcp-session-id': '....67890'
        }
      }
    });
  });

  it('extracts JSON-RPC methods without logging the raw request body', async () => {
    const logs: MCPConnectionLog[] = [];
    const response = new Response('{}', {
      status: 200,
      headers: { 'content-type': 'application/json' }
    });

    vi.stubGlobal('fetch', vi.fn().mockResolvedValue(response));

    const config: MCPServerConfig = {
      url: 'https://example.com/mcp',
      transport: MCPTransportType.STREAMABLE_HTTP
    };

    const controller = createDiagnosticFetch(config, (log) => logs.push(log));

    await controller.fetch(config.url, {
      method: 'POST',
      body: JSON.stringify([
        { jsonrpc: '2.0', id: 1, method: 'initialize' },
        { jsonrpc: '2.0', method: 'notifications/initialized' }
      ])
    });

    // The log carries only a kind/size summary plus the method names.
    expect(logs[0].details).toMatchObject({
      request: {
        method: 'POST',
        body: {
          kind: 'string',
          size: expect.any(Number)
        },
        jsonRpcMethods: ['initialize', 'notifications/initialized']
      }
    });
  });

  it('adds a CORS hint to Failed to fetch diagnostic log messages', async () => {
    const logs: MCPConnectionLog[] = [];
    const fetchError = new TypeError('Failed to fetch');

    vi.stubGlobal('fetch', vi.fn().mockRejectedValue(fetchError));

    const config: MCPServerConfig = {
      url: 'http://localhost:8000/mcp',
      transport: MCPTransportType.STREAMABLE_HTTP
    };

    const controller = createDiagnosticFetch(config, (log) => logs.push(log));

    // The wrapper rethrows the original error after logging.
    await expect(controller.fetch(config.url, { method: 'POST', body: '{}' })).rejects.toThrow(
      'Failed to fetch'
    );

    expect(logs).toHaveLength(2);
    expect(logs[1].message).toBe(
      'HTTP POST http://localhost:8000/mcp failed: Failed to fetch (check CORS?)'
    );
  });

  it('detaches phase error logging after the initialize handshake completes', async () => {
    const phaseLogs: Array<{ phase: MCPConnectionPhase; log: MCPConnectionLog }> = [];
    const stopPhaseLogging = vi.fn();
    let emitClientError: ((error: Error) => void) | undefined;

    vi.spyOn(MCPService, 'createTransport').mockReturnValue({
      transport: {} as never,
      type: MCPTransportType.WEBSOCKET,
      stopPhaseLogging
    });
    vi.spyOn(MCPService, 'listTools').mockResolvedValue([]);
    vi.spyOn(Client.prototype, 'getServerVersion').mockReturnValue(undefined);
    vi.spyOn(Client.prototype, 'getServerCapabilities').mockReturnValue(undefined);
    vi.spyOn(Client.prototype, 'getInstructions').mockReturnValue(undefined);
    vi.spyOn(Client.prototype, 'connect').mockImplementation(async function (this: Client) {
      // Capture the error hook so we can fire it again after connect returns.
      emitClientError = (error: Error) => this.onerror?.(error);
      this.onerror?.(new Error('handshake protocol error'));
    });

    await MCPService.connect(
      'test-server',
      {
        url: 'ws://example.com/mcp',
        transport: MCPTransportType.WEBSOCKET
      },
      undefined,
      undefined,
      (phase, log) => phaseLogs.push({ phase, log })
    );

    expect(stopPhaseLogging).toHaveBeenCalledTimes(1);
    // The error raised during the handshake reaches the phase callback once.
    expect(
      phaseLogs.filter(
        ({ phase, log }) =>
          phase === MCPConnectionPhase.ERROR &&
          log.message === 'Protocol error: handshake protocol error'
      )
    ).toHaveLength(1);

    // Errors raised after connect() must no longer reach the phase callback.
    emitClientError?.(new Error('runtime protocol error'));

    expect(
      phaseLogs.filter(
        ({ phase, log }) =>
          phase === MCPConnectionPhase.ERROR &&
          log.message === 'Protocol error: runtime protocol error'
      )
    ).toHaveLength(0);
  });
});
|
||||
|
|
@ -0,0 +1,20 @@
|
|||
import { describe, expect, it } from 'vitest';
|
||||
import { redactValue } from '$lib/utils/redact';
|
||||
|
||||
describe('redactValue', () => {
  it('returns [redacted] by default', () => {
    expect(redactValue('secret-token')).toBe('[redacted]');
  });

  it('shows last N characters when showLastChars is provided', () => {
    expect(redactValue('session-abc12', 5)).toBe('....abc12');
  });

  it('handles value shorter than showLastChars', () => {
    // slice(-5) on a 2-char string returns the whole string.
    expect(redactValue('ab', 5)).toBe('....ab');
  });

  it('returns [redacted] when showLastChars is 0', () => {
    // 0 is falsy, so it behaves like the no-argument full redaction.
    expect(redactValue('secret', 0)).toBe('[redacted]');
  });
});
|
||||
|
|
@ -0,0 +1,124 @@
|
|||
import { describe, expect, it } from 'vitest';
|
||||
import {
|
||||
getRequestUrl,
|
||||
getRequestMethod,
|
||||
getRequestBody,
|
||||
summarizeRequestBody,
|
||||
formatDiagnosticErrorMessage,
|
||||
extractJsonRpcMethods
|
||||
} from '$lib/utils/request-helpers';
|
||||
|
||||
describe('getRequestUrl', () => {
  it('returns a plain string input as-is', () => {
    expect(getRequestUrl('https://example.com/mcp')).toBe('https://example.com/mcp');
  });

  it('returns href from a URL object', () => {
    expect(getRequestUrl(new URL('https://example.com/mcp'))).toBe('https://example.com/mcp');
  });

  it('returns url from a Request object', () => {
    const req = new Request('https://example.com/mcp');
    expect(getRequestUrl(req)).toBe('https://example.com/mcp');
  });
});
|
||||
|
||||
describe('getRequestMethod', () => {
|
||||
it('prefers method from init', () => {
|
||||
expect(getRequestMethod('https://example.com', { method: 'POST' })).toBe('POST');
|
||||
});
|
||||
|
||||
it('falls back to Request.method', () => {
|
||||
const req = new Request('https://example.com', { method: 'PUT' });
|
||||
expect(getRequestMethod(req)).toBe('PUT');
|
||||
});
|
||||
|
||||
it('falls back to baseInit.method', () => {
|
||||
expect(getRequestMethod('https://example.com', undefined, { method: 'DELETE' })).toBe('DELETE');
|
||||
});
|
||||
|
||||
it('defaults to GET', () => {
|
||||
expect(getRequestMethod('https://example.com')).toBe('GET');
|
||||
});
|
||||
});
|
||||
|
||||
describe('getRequestBody', () => {
|
||||
it('returns body from init', () => {
|
||||
expect(getRequestBody('https://example.com', { body: 'payload' })).toBe('payload');
|
||||
});
|
||||
|
||||
it('returns undefined when no body is present', () => {
|
||||
expect(getRequestBody('https://example.com')).toBeUndefined();
|
||||
});
|
||||
});
|
||||
|
||||
describe('summarizeRequestBody', () => {
|
||||
it('returns empty for null', () => {
|
||||
expect(summarizeRequestBody(null)).toEqual({ kind: 'empty' });
|
||||
});
|
||||
|
||||
it('returns empty for undefined', () => {
|
||||
expect(summarizeRequestBody(undefined)).toEqual({ kind: 'empty' });
|
||||
});
|
||||
|
||||
it('returns string kind with size', () => {
|
||||
expect(summarizeRequestBody('hello')).toEqual({ kind: 'string', size: 5 });
|
||||
});
|
||||
|
||||
it('returns blob kind with size', () => {
|
||||
const blob = new Blob(['abc']);
|
||||
expect(summarizeRequestBody(blob)).toEqual({ kind: 'blob', size: 3 });
|
||||
});
|
||||
|
||||
it('returns formdata kind', () => {
|
||||
expect(summarizeRequestBody(new FormData())).toEqual({ kind: 'formdata' });
|
||||
});
|
||||
|
||||
it('returns arraybuffer kind with size', () => {
|
||||
expect(summarizeRequestBody(new ArrayBuffer(8))).toEqual({ kind: 'arraybuffer', size: 8 });
|
||||
});
|
||||
});
|
||||
|
||||
describe('formatDiagnosticErrorMessage', () => {
|
||||
it('appends CORS hint for Failed to fetch', () => {
|
||||
expect(formatDiagnosticErrorMessage(new TypeError('Failed to fetch'))).toBe(
|
||||
'Failed to fetch (check CORS?)'
|
||||
);
|
||||
});
|
||||
|
||||
it('passes through other error messages unchanged', () => {
|
||||
expect(formatDiagnosticErrorMessage(new Error('timeout'))).toBe('timeout');
|
||||
});
|
||||
|
||||
it('handles non-Error values', () => {
|
||||
expect(formatDiagnosticErrorMessage('some string')).toBe('some string');
|
||||
});
|
||||
});
|
||||
|
||||
describe('extractJsonRpcMethods', () => {
|
||||
it('extracts methods from a JSON-RPC array', () => {
|
||||
const body = JSON.stringify([
|
||||
{ jsonrpc: '2.0', id: 1, method: 'initialize' },
|
||||
{ jsonrpc: '2.0', method: 'notifications/initialized' }
|
||||
]);
|
||||
expect(extractJsonRpcMethods(body)).toEqual(['initialize', 'notifications/initialized']);
|
||||
});
|
||||
|
||||
it('extracts method from a single JSON-RPC message', () => {
|
||||
const body = JSON.stringify({ jsonrpc: '2.0', id: 1, method: 'tools/list' });
|
||||
expect(extractJsonRpcMethods(body)).toEqual(['tools/list']);
|
||||
});
|
||||
|
||||
it('returns undefined for non-string body', () => {
|
||||
expect(extractJsonRpcMethods(null)).toBeUndefined();
|
||||
expect(extractJsonRpcMethods(undefined)).toBeUndefined();
|
||||
});
|
||||
|
||||
it('returns undefined for invalid JSON', () => {
|
||||
expect(extractJsonRpcMethods('not json')).toBeUndefined();
|
||||
});
|
||||
|
||||
it('returns undefined when no methods found', () => {
|
||||
expect(extractJsonRpcMethods(JSON.stringify({ foo: 'bar' }))).toBeUndefined();
|
||||
});
|
||||
});
|
||||
|
|
@ -0,0 +1,55 @@
|
|||
import { describe, expect, it } from 'vitest';
|
||||
import { sanitizeHeaders } from '$lib/utils/api-headers';
|
||||
|
||||
describe('sanitizeHeaders', () => {
|
||||
it('returns empty object for undefined input', () => {
|
||||
expect(sanitizeHeaders()).toEqual({});
|
||||
});
|
||||
|
||||
it('passes through non-sensitive headers', () => {
|
||||
const headers = new Headers({ 'content-type': 'application/json', accept: 'text/html' });
|
||||
expect(sanitizeHeaders(headers)).toEqual({
|
||||
'content-type': 'application/json',
|
||||
accept: 'text/html'
|
||||
});
|
||||
});
|
||||
|
||||
it('redacts known sensitive headers', () => {
|
||||
const headers = new Headers({
|
||||
authorization: 'Bearer secret',
|
||||
'x-api-key': 'key-123',
|
||||
'content-type': 'application/json'
|
||||
});
|
||||
const result = sanitizeHeaders(headers);
|
||||
expect(result.authorization).toBe('[redacted]');
|
||||
expect(result['x-api-key']).toBe('[redacted]');
|
||||
expect(result['content-type']).toBe('application/json');
|
||||
});
|
||||
|
||||
it('partially redacts headers specified in partialRedactHeaders', () => {
|
||||
const headers = new Headers({ 'mcp-session-id': 'session-12345' });
|
||||
const partial = new Map([['mcp-session-id', 5]]);
|
||||
expect(sanitizeHeaders(headers, undefined, partial)['mcp-session-id']).toBe('....12345');
|
||||
});
|
||||
|
||||
it('fully redacts mcp-session-id when no partialRedactHeaders is given', () => {
|
||||
const headers = new Headers({ 'mcp-session-id': 'session-12345' });
|
||||
expect(sanitizeHeaders(headers)['mcp-session-id']).toBe('[redacted]');
|
||||
});
|
||||
|
||||
it('redacts extra headers provided by the caller', () => {
|
||||
const headers = new Headers({
|
||||
'x-vendor-key': 'vendor-secret',
|
||||
'content-type': 'application/json'
|
||||
});
|
||||
const result = sanitizeHeaders(headers, ['x-vendor-key']);
|
||||
expect(result['x-vendor-key']).toBe('[redacted]');
|
||||
expect(result['content-type']).toBe('application/json');
|
||||
});
|
||||
|
||||
it('handles case-insensitive extra header names', () => {
|
||||
const headers = new Headers({ 'X-Custom-Token': 'token-value' });
|
||||
const result = sanitizeHeaders(headers, ['X-CUSTOM-TOKEN']);
|
||||
expect(result['x-custom-token']).toBe('[redacted]');
|
||||
});
|
||||
});
|
||||
Loading…
Reference in New Issue