diff --git a/.devops/cann.Dockerfile b/.devops/cann.Dockerfile index 83182c9700..db221b0b81 100644 --- a/.devops/cann.Dockerfile +++ b/.devops/cann.Dockerfile @@ -107,7 +107,7 @@ ENTRYPOINT ["/app/tools.sh"] # ENTRYPOINT ["/app/llama-server"] ### Target: light -# Lightweight image containing only llama-cli +# Lightweight image containing only llama-cli and llama-completion # ============================================================================== FROM base AS light diff --git a/.devops/llama-cli-cann.Dockerfile b/.devops/llama-cli-cann.Dockerfile index ef43d78cd2..6581187f32 100644 --- a/.devops/llama-cli-cann.Dockerfile +++ b/.devops/llama-cli-cann.Dockerfile @@ -23,11 +23,12 @@ ENV LD_LIBRARY_PATH=${ASCEND_TOOLKIT_HOME}/runtime/lib64/stub:$LD_LIBRARY_PATH RUN echo "Building with static libs" && \ source /usr/local/Ascend/ascend-toolkit/set_env.sh --force && \ cmake -B build -DGGML_NATIVE=OFF -DGGML_CANN=ON -DBUILD_SHARED_LIBS=OFF -DLLAMA_BUILD_TESTS=OFF && \ - cmake --build build --config Release --target llama-cli + cmake --build build --config Release --target llama-cli && \ + cmake --build build --config Release --target llama-completion # TODO: use image with NNRT FROM ascendai/cann:$ASCEND_VERSION AS runtime -COPY --from=build /app/build/bin/llama-cli /llama-cli +COPY --from=build /app/build/bin/llama-cli /app/build/bin/llama-completion / ENV LC_ALL=C.utf8 diff --git a/.devops/llama-cpp-cuda.srpm.spec b/.devops/llama-cpp-cuda.srpm.spec index 3bbf4a4def..4d42a906b1 100644 --- a/.devops/llama-cpp-cuda.srpm.spec +++ b/.devops/llama-cpp-cuda.srpm.spec @@ -37,6 +37,7 @@ make -j GGML_CUDA=1 %install mkdir -p %{buildroot}%{_bindir}/ cp -p llama-cli %{buildroot}%{_bindir}/llama-cuda-cli +cp -p llama-completion %{buildroot}%{_bindir}/llama-cuda-completion cp -p llama-server %{buildroot}%{_bindir}/llama-cuda-server cp -p llama-simple %{buildroot}%{_bindir}/llama-cuda-simple @@ -68,6 +69,7 @@ rm -rf %{_builddir}/* %files %{_bindir}/llama-cuda-cli +%{_bindir}/llama-cuda-completion %{_bindir}/llama-cuda-server %{_bindir}/llama-cuda-simple /usr/lib/systemd/system/llamacuda.service diff --git a/.devops/llama-cpp.srpm.spec b/.devops/llama-cpp.srpm.spec index 45902dcf89..0a4f43058d 100644 --- a/.devops/llama-cpp.srpm.spec +++ b/.devops/llama-cpp.srpm.spec @@ -39,6 +39,7 @@ make -j %install mkdir -p %{buildroot}%{_bindir}/ cp -p llama-cli %{buildroot}%{_bindir}/llama-cli +cp -p llama-completion %{buildroot}%{_bindir}/llama-completion cp -p llama-server %{buildroot}%{_bindir}/llama-server cp -p llama-simple %{buildroot}%{_bindir}/llama-simple @@ -70,6 +71,7 @@ rm -rf %{_builddir}/* %files %{_bindir}/llama-cli +%{_bindir}/llama-completion %{_bindir}/llama-server %{_bindir}/llama-simple /usr/lib/systemd/system/llama.service diff --git a/SECURITY.md b/SECURITY.md index 9c86ae91b5..ae496f4e3d 100644 --- a/SECURITY.md +++ b/SECURITY.md @@ -68,3 +68,6 @@ Please disclose it as a private [security advisory](https://github.com/ggml-org/ Please note that using AI to identify vulnerabilities and generate reports is permitted. However, you must (1) explicitly disclose how AI was used and (2) conduct a thorough manual review before submitting the report. A team of volunteers on a reasonable-effort basis maintains this project. As such, please give us at least 90 days to work on a fix before public exposure. 
+
+> [!IMPORTANT]
+> For collaborators: if you are interested in helping out with reviewing private security disclosures, please see: https://github.com/ggml-org/llama.cpp/discussions/18080
diff --git a/common/chat-peg-parser.cpp b/common/chat-peg-parser.cpp
index 74a7b6a46d..1bcba9cd86 100644
--- a/common/chat-peg-parser.cpp
+++ b/common/chat-peg-parser.cpp
@@ -4,9 +4,14 @@ using json = nlohmann::json;
 
-static std::string_view trim_trailing_space(std::string_view sv) {
+static std::string_view trim_trailing_space(std::string_view sv, int max = -1) {
+    int count = 0;
     while (!sv.empty() && std::isspace(static_cast<unsigned char>(sv.back()))) {
+        if (max != -1 && count >= max) {
+            break;
+        }
         sv.remove_suffix(1);
+        count++;
     }
     return sv;
 }
@@ -93,7 +98,7 @@ void common_chat_peg_constructed_mapper::map(const common_peg_ast_node & node) {
 
     if (is_arg_string && current_tool) {
         // Serialize to JSON, but exclude the end quote
-        std::string dumped = json(node.text).dump();
+        std::string dumped = json(trim_trailing_space(node.text)).dump();
         current_tool->arguments += dumped.substr(0, dumped.size() - 1);
         needs_closing_quote = true;
     }
@@ -101,6 +106,7 @@ void common_chat_peg_constructed_mapper::map(const common_peg_ast_node & node) {
     if (is_arg_close && current_tool) {
         if (needs_closing_quote) {
             current_tool->arguments += "\"";
+            needs_closing_quote = false;
         }
     }
@@ -109,6 +115,10 @@ void common_chat_peg_constructed_mapper::map(const common_peg_ast_node & node) {
     }
 
     if (is_tool_close && current_tool) {
+        if (needs_closing_quote) {
+            current_tool->arguments += "\"";
+            needs_closing_quote = false;
+        }
         current_tool->arguments += "}";
     }
 }
diff --git a/common/chat.cpp b/common/chat.cpp
index c371edaa5a..0a426f4478 100644
--- a/common/chat.cpp
+++ b/common/chat.cpp
@@ -711,6 +711,25 @@ static void foreach_function(const json & tools, const std::function<void(const json &)> & fn) {
+static void foreach_parameter(const json & function, const std::function<void(const std::string &, const json &, bool)> & fn) {
+    if (!function.contains("parameters") || !function.at("parameters").is_object()) {
+        return;
+    }
+    const auto & params = function.at("parameters");
+    if (!params.contains("properties") || !params.at("properties").is_object()) {
+        return;
+    }
+    const auto & props = params.at("properties");
+    std::set<std::string> required;
+    if (params.contains("required") && params.at("required").is_array()) {
+        params.at("required").get_to(required);
+    }
+    for (const auto & [name, prop] : props.items()) {
+        bool is_required = (required.find(name) != required.end());
+        fn(name, prop, is_required);
+    }
+}
+
 static std::string apply(
     const common_chat_template & tmpl,
     const struct templates_params & inputs,
@@ -1409,6 +1428,123 @@ static common_chat_params common_chat_params_init_nemotron_v2(const common_chat_
     return data;
 }
 
+static common_chat_params common_chat_params_init_nemotron_v3(const common_chat_template & tmpl, const struct templates_params & inputs) {
+    common_chat_params data;
+
+    data.prompt = apply(tmpl, inputs);
+    data.format = COMMON_CHAT_FORMAT_PEG_CONSTRUCTED;
+
+    // Handle thinking tags appropriately based on inputs.enable_thinking
+    if (string_ends_with(data.prompt, "<think>\n")) {
+        if (!inputs.enable_thinking) {
+            data.prompt += "</think>";
+        } else {
+            data.thinking_forced_open = true;
+        }
+    }
+
+    data.preserved_tokens = {
+        "<think>",
+        "</think>",
+        "<tool_call>",
+        "</tool_call>",
+    };
+
+    auto has_tools = inputs.tools.is_array() && !inputs.tools.empty();
+    auto extract_reasoning = inputs.reasoning_format != COMMON_REASONING_FORMAT_NONE;
+    auto include_grammar = true;
+
+    auto parser = build_chat_peg_constructed_parser([&](auto & p) {
+        auto reasoning = p.eps();
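+        // When the template forces the reasoning block open (the prompt ends with "<think>\n"),
+        // everything up to the closing "</think>" is captured as reasoning before content parsing starts.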
+        if (inputs.enable_thinking && extract_reasoning) {
+            auto reasoning_content = p.reasoning(p.until("</think>")) + ("</think>" | p.end());
+            if (data.thinking_forced_open) {
+                reasoning = reasoning_content;
+            }
+        }
+
+        // Response format parser
+        if (inputs.json_schema.is_object() && !inputs.json_schema.empty()) {
+            return reasoning << p.content(p.schema(p.json(), "response-format", inputs.json_schema));
+        }
+
+        // Tool call parser
+        if (has_tools && inputs.tool_choice != COMMON_CHAT_TOOL_CHOICE_NONE) {
+            auto tool_choice = p.choice();
+            foreach_function(inputs.tools, [&](const json & tool) {
+                const auto & function = tool.at("function");
+                std::string name = function.at("name");
+                auto parameters = function.at("parameters");
+
+                auto schema_info = common_schema_info();
+                schema_info.resolve_refs(parameters);
+
+                auto tool_open = "<function=" + name + ">\n";
+                auto tool_close = p.literal("</function>\n");
+                auto args = p.sequence();
+                auto arg_string = p.rule("xml-arg-string", p.until_one_of({
+                    "\n</parameter>",
+                    "\n</function>"
+                }));
+
+                foreach_parameter(function, [&](const auto & param_name, const json & param_schema, bool is_required) {
+                    auto rule_name = "tool-" + name + "-arg-" + param_name;
+
+                    auto arg_open = "<parameter=" + param_name + ">\n";
+                    auto arg_close = p.literal("</parameter>\n");
+                    auto arg_value = p.eps();
+
+                    if (schema_info.resolves_to_string(param_schema)) {
+                        arg_value = p.tool_arg_string_value(arg_string) + "\n";
+                    } else {
+                        arg_value = p.tool_arg_json_value(p.schema(p.json(), rule_name + "-schema", param_schema));
+                    }
+
+                    // Model may or may not close with </parameter>
+                    auto arg_rule = p.rule(rule_name, p.tool_arg_open(arg_open) + arg_value + p.optional(p.tool_arg_close(arg_close)));
+                    args += p.repeat(arg_rule, /* min = */ is_required ? 1 : 0, /* max = */ 1);
+                });
+
+                tool_choice |= p.rule("tool-" + name, p.tool_open(tool_open) + args + p.tool_close(tool_close));
+            });
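+            // tool_choice == REQUIRED forces at least one call; parallel tool calls lift the
+            // upper bound on repeats (max = -1 below means unbounded)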
+            auto min_calls = inputs.tool_choice == COMMON_CHAT_TOOL_CHOICE_REQUIRED ? 1 : 0;
+            auto max_calls = inputs.parallel_tool_calls ? -1 : 1;
+            auto tool_call = p.rule("tool-call", "<tool_call>\n" + tool_choice + "</tool_call>" + p.space());
+            auto tool_calls = p.trigger_rule("tool-call-root", p.repeat(tool_call, /* min = */ min_calls, /* max = */ max_calls));
+
+            return reasoning << p.content(p.until("<tool_call>")) << tool_calls;
+        }
+
+        // Content only parser
+        include_grammar = false;
+        return reasoning << p.content(p.rest());
+    });
+
+    data.parser = parser.save();
+
+    if (include_grammar) {
+        data.grammar_lazy = has_tools && inputs.tool_choice == COMMON_CHAT_TOOL_CHOICE_AUTO;
+
+        data.grammar = build_grammar([&](const common_grammar_builder & builder) {
+            foreach_function(inputs.tools, [&](const json & tool) {
+                const auto & function = tool.at("function");
+                auto schema = function.at("parameters");
+                builder.resolve_refs(schema);
+            });
+            parser.build_grammar(builder, data.grammar_lazy);
+        });
+
+        data.grammar_triggers = {
+            {COMMON_GRAMMAR_TRIGGER_TYPE_WORD, "<tool_call>"}
+        };
+    }
+
+    return data;
+}
+
+
 static common_chat_params common_chat_params_init_apertus(const common_chat_template & tmpl, const struct templates_params & inputs) {
     common_chat_params data;
@@ -2534,6 +2670,10 @@ static common_chat_params common_chat_templates_apply_jinja(
+        if (src.find("") != std::string::npos &&
+            src.find("") != std::string::npos) {
+            return common_chat_params_init_nemotron_v3(tmpl, params);
+        }
         return common_chat_params_init_qwen3_coder_xml(tmpl, params);
     }
diff --git a/common/json-schema-to-grammar.cpp b/common/json-schema-to-grammar.cpp
index c3b4e5d9dc..2f67c74d79 100644
--- a/common/json-schema-to-grammar.cpp
+++ b/common/json-schema-to-grammar.cpp
@@ -305,8 +305,9 @@ static std::string format_literal(const std::string & literal) {
 std::string gbnf_format_literal(const std::string & literal) {
     return format_literal(literal);
 }
 
-class SchemaConverter {
+class common_schema_converter {
 private:
+    friend class common_schema_info;
     friend std::string build_grammar(const std::function<void(const common_grammar_builder &)> & cb, const common_grammar_options & options);
     std::function<json(const std::string &)> _fetch_json;
     bool _dotall;
@@ -729,7 +730,7 @@ private:
     }
 
 public:
-    SchemaConverter(
+    common_schema_converter(
         const std::function<json(const std::string &)> & fetch_json,
         bool dotall)
         : _fetch_json(fetch_json), _dotall(dotall)
@@ -990,6 +991,134 @@ public:
     }
 };
 
+// common_schema_info implementation (pimpl)
+
+common_schema_info::common_schema_info()
+    : impl_(std::make_unique<common_schema_converter>(
+        [](const std::string &) { return json(); },
+        false)) {}
+
+common_schema_info::~common_schema_info() = default;
+
+common_schema_info::common_schema_info(common_schema_info &&) noexcept = default;
+common_schema_info & common_schema_info::operator=(common_schema_info &&) noexcept = default;
+
+void common_schema_info::resolve_refs(nlohmann::ordered_json & schema) {
+    impl_->resolve_refs(schema, "");
+}
+
+// Determines if a JSON schema can resolve to a string type through any path.
+// Some models emit raw string values rather than JSON-encoded strings for string parameters.
+// If any branch of the schema (via oneOf, anyOf, $ref, etc.) permits a string, this returns
+// true, allowing callers to handle the value as a raw string for simplicity.
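+// For example, {"type": "string"}, {"type": ["string", "null"]} and
+// {"anyOf": [{"type": "integer"}, {"type": "string"}]} all resolve to string,
+// while a plain {"type": "integer"} does not.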
+bool common_schema_info::resolves_to_string(const nlohmann::ordered_json & schema) {
+    std::unordered_set<std::string> visited_refs;
+
+    std::function<bool(const json &)> check = [&](const json & s) -> bool {
+        if (!s.is_object()) {
+            return false;
+        }
+
+        // Handle $ref
+        if (s.contains("$ref")) {
+            const std::string & ref = s["$ref"];
+            if (visited_refs.find(ref) != visited_refs.end()) {
+                // Circular reference, assume not a string to be safe
+                return false;
+            }
+            visited_refs.insert(ref);
+            auto it = impl_->_refs.find(ref);
+            if (it != impl_->_refs.end()) {
+                return check(it->second);
+            }
+            return false;
+        }
+
+        // Check type field
+        if (s.contains("type")) {
+            const json & schema_type = s["type"];
+            if (schema_type.is_string()) {
+                if (schema_type == "string") {
+                    return true;
+                }
+            } else if (schema_type.is_array()) {
+                // Type can be an array like ["string", "null"]
+                for (const auto & t : schema_type) {
+                    if (t == "string") {
+                        return true;
+                    }
+                }
+            }
+        }
+
+        // Check oneOf/anyOf - if any alternative can be a string
+        if (s.contains("oneOf")) {
+            for (const auto & alt : s["oneOf"]) {
+                if (check(alt)) {
+                    return true;
+                }
+            }
+        }
+        if (s.contains("anyOf")) {
+            for (const auto & alt : s["anyOf"]) {
+                if (check(alt)) {
+                    return true;
+                }
+            }
+        }
+
+        // Check allOf - all components must be compatible with string type
+        if (s.contains("allOf")) {
+            bool all_string = true;
+            for (const auto & component : s["allOf"]) {
+                if (!check(component)) {
+                    all_string = false;
+                    break;
+                }
+            }
+            if (all_string) {
+                return true;
+            }
+        }
+
+        // Check const - if the constant value is a string
+        if (s.contains("const")) {
+            if (s["const"].is_string()) {
+                return true;
+            }
+        }
+
+        // Check enum - if any enum value is a string
+        if (s.contains("enum")) {
+            for (const auto & val : s["enum"]) {
+                if (val.is_string()) {
+                    return true;
+                }
+            }
+        }
+
+        // String-specific keywords imply string type
+        if (s.contains("pattern") || s.contains("minLength") || s.contains("maxLength")) {
+            return true;
+        }
+
+        // Check format - many formats imply string
+        if (s.contains("format")) {
+            const std::string & fmt = s["format"];
+            if (fmt == "date" || fmt == "time" || fmt == "date-time" ||
+                fmt == "uri" || fmt == "email" || fmt == "hostname" ||
+                fmt == "ipv4" || fmt == "ipv6" || fmt == "uuid" ||
+                fmt.find("uuid") == 0) {
+                return true;
+            }
+        }
+
+        return false;
+    };
+
+    return check(schema);
+}
+
 std::string json_schema_to_grammar(const json & schema, bool force_gbnf) {
 #ifdef LLAMA_USE_LLGUIDANCE
     if (!force_gbnf) {
@@ -1006,7 +1135,7 @@ std::string json_schema_to_grammar(const json & schema, bool force_gbnf) {
 }
 
 std::string build_grammar(const std::function<void(const common_grammar_builder &)> & cb, const common_grammar_options & options) {
-    SchemaConverter converter([&](const std::string &) { return json(); }, options.dotall);
+    common_schema_converter converter([&](const std::string &) { return json(); }, options.dotall);
     common_grammar_builder builder {
         /* .add_rule = */ [&](const std::string & name, const std::string & rule) {
             return converter._add_rule(name, rule);
diff --git a/common/json-schema-to-grammar.h b/common/json-schema-to-grammar.h
index c89ab7f997..240d642311 100644
--- a/common/json-schema-to-grammar.h
+++ b/common/json-schema-to-grammar.h
@@ -3,11 +3,31 @@
 #include <nlohmann/json_fwd.hpp>
 
 #include <functional>
+#include <memory>
 #include <string>
 
 std::string json_schema_to_grammar(const nlohmann::ordered_json & schema, bool force_gbnf = false);
 
+class common_schema_converter;
+
+// Probes a JSON schema to extract information about its structure and type constraints.
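+// Illustrative usage (mirrors how the chat handlers use it):
+//   common_schema_info info;
+//   info.resolve_refs(parameters);                   // inline $ref targets first
+//   bool raw = info.resolves_to_string(param_schema);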
+class common_schema_info {
+    std::unique_ptr<common_schema_converter> impl_;
+
+  public:
+    common_schema_info();
+    ~common_schema_info();
+
+    common_schema_info(const common_schema_info &) = delete;
+    common_schema_info & operator=(const common_schema_info &) = delete;
+    common_schema_info(common_schema_info &&) noexcept;
+    common_schema_info & operator=(common_schema_info &&) noexcept;
+
+    void resolve_refs(nlohmann::ordered_json & schema);
+    bool resolves_to_string(const nlohmann::ordered_json & schema);
+};
+
 struct common_grammar_builder {
     std::function<std::string(const std::string &, const std::string &)> add_rule;
     std::function<std::string(const std::string &, const nlohmann::ordered_json &)> add_schema;
diff --git a/common/peg-parser.cpp b/common/peg-parser.cpp
index dec99e1820..f2fc84500f 100644
--- a/common/peg-parser.cpp
+++ b/common/peg-parser.cpp
@@ -425,7 +425,7 @@ struct parser_executor {
 
         if (result.need_more_input()) {
             // Propagate - need to know what child would match before negating
-            return result;
+            return common_peg_parse_result(COMMON_PEG_PARSE_RESULT_NEED_MORE_INPUT, start_pos);
         }
 
         // Child failed, so negation succeeds
diff --git a/convert_hf_to_gguf.py b/convert_hf_to_gguf.py
index f09eab54ac..a5d74eef91 100755
--- a/convert_hf_to_gguf.py
+++ b/convert_hf_to_gguf.py
@@ -865,6 +865,14 @@ class TextModel(ModelBase):
                 logger.warning(f"Unknown RoPE type: {rope_type}")
             logger.info(f"gguf: rope scaling type = {rope_gguf_type.name}")
 
+        if "mrope_section" in self.rope_parameters:
+            mrope_section = self.rope_parameters["mrope_section"]
+            # Pad to 4 dimensions [time, height, width, extra]
+            while len(mrope_section) < 4:
+                mrope_section.append(0)
+            self.gguf_writer.add_rope_dimension_sections(mrope_section[:4])
+            logger.info(f"gguf: mrope sections: {mrope_section[:4]}")
+
         if (rope_theta := rope_params.get("rope_theta")) is not None:
             self.gguf_writer.add_rope_freq_base(rope_theta)
             logger.info(f"gguf: rope theta = {rope_theta}")
@@ -3742,9 +3750,6 @@ class Qwen2VLModel(TextModel):
 
     def set_gguf_parameters(self):
         super().set_gguf_parameters()
-        mrope_section = self.hparams["rope_scaling"]["mrope_section"]
-        mrope_section += [0] * max(0, 4 - len(mrope_section))
-        self.gguf_writer.add_rope_dimension_sections(mrope_section)
 
     def set_vocab(self):
         try:
@@ -4380,6 +4385,30 @@ class Qwen3VLVisionModel(MmprojModel):
         return super().modify_tensors(data_torch, name, bid)
 
+@ModelBase.register("Glm4vForConditionalGeneration", "Glm4vMoeForConditionalGeneration")
+class Glm4VVisionModel(Qwen3VLVisionModel):
+    def set_gguf_parameters(self):
+        MmprojModel.set_gguf_parameters(self)  # skip Qwen3VLVisionModel parameters
+        assert self.hparams_vision is not None
+        self.gguf_writer.add_clip_projector_type(gguf.VisionProjectorType.GLM4V)
+
+        hidden_act = str(self.hparams_vision.get("hidden_act", "")).lower()
+        if hidden_act == "gelu":
+            self.gguf_writer.add_vision_use_gelu(True)
+        elif hidden_act == "silu":
+            self.gguf_writer.add_vision_use_silu(True)
+
+        rms_norm_eps = self.hparams_vision.get("rms_norm_eps", 1e-5)
+        self.gguf_writer.add_vision_attention_layernorm_eps(rms_norm_eps)
+
+    def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
+        if name.startswith("model.visual."):
+            name = name.replace("model.visual.", "visual.")
+        if name.startswith("visual.merger."):
+            return [(self.map_tensor_name(name), data_torch)]
+        return super().modify_tensors(data_torch, name, bid)
+
 
 @ModelBase.register("Qwen3VLForConditionalGeneration")
 class Qwen3VLTextModel(Qwen3Model):
     model_arch = gguf.MODEL_ARCH.QWEN3VL
@@ -4388,20 +4417,6 @@ class Qwen3VLTextModel(Qwen3Model):
         super().set_gguf_parameters()
 
-        # Handle MRoPE (Multi-axis Rotary Position Embedding) for Qwen3-VL
-        text_config = self.hparams.get("text_config", {})
-        # rope_scaling is deprecated in V5, use rope_parameters instead
-        rope_scaling = text_config.get("rope_scaling") or text_config.get("rope_parameters") or {}
-
-        if rope_scaling.get("mrope_section"):
-            # mrope_section contains [time, height, width] dimensions
-            mrope_section = rope_scaling["mrope_section"]
-            # Pad to 4 dimensions [time, height, width, extra]
-            while len(mrope_section) < 4:
-                mrope_section.append(0)
-            self.gguf_writer.add_rope_dimension_sections(mrope_section[:4])
-
-            logger.info(f"MRoPE sections: {mrope_section[:4]}")
-
         vision_config = self.hparams.get("vision_config", {})
         deepstack_layer_num = len(vision_config.get("deepstack_visual_indexes", []))
         self.gguf_writer.add_num_deepstack_layers(deepstack_layer_num)
@@ -4420,22 +4435,6 @@ class Qwen3VLMoeTextModel(Qwen3MoeModel):
 
     def set_gguf_parameters(self):
         super().set_gguf_parameters()
-
-        # Handle MRoPE (Multi-axis Rotary Position Embedding) for Qwen3-VL
-        text_config = self.hparams.get("text_config", {})
-        # rope_scaling is deprecated in V5, use rope_parameters instead
-        rope_scaling = text_config.get("rope_scaling") or text_config.get("rope_parameters") or {}
-
-        if rope_scaling.get("mrope_section"):
-            # mrope_section contains [time, height, width] dimensions
-            mrope_section = rope_scaling["mrope_section"]
-            # Pad to 4 dimensions [time, height, width, extra]
-            while len(mrope_section) < 4:
-                mrope_section.append(0)
-            self.gguf_writer.add_rope_dimension_sections(mrope_section[:4])
-
-            logger.info(f"MRoPE sections: {mrope_section[:4]}")
-
         vision_config = self.hparams.get("vision_config", {})
         deepstack_layer_num = len(vision_config.get("deepstack_visual_indexes", []))
         self.gguf_writer.add_num_deepstack_layers(deepstack_layer_num)
@@ -7885,6 +7884,15 @@ class JaisModel(TextModel):
 @ModelBase.register("Glm4ForCausalLM", "Glm4vForConditionalGeneration")
 class Glm4Model(TextModel):
     model_arch = gguf.MODEL_ARCH.GLM4
+    use_mrope = False
+    partial_rotary_factor = 0.5
+
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        self.partial_rotary_factor = self.rope_parameters.get("partial_rotary_factor", 0.5)
+        if "mrope_section" in self.rope_parameters:
+            self.use_mrope = True
+            logger.info("Q/K weight will need to be permuted for M-RoPE")
 
     def set_vocab(self):
         from transformers import AutoTokenizer
@@ -7906,17 +7914,49 @@ class Glm4Model(TextModel):
         super().set_gguf_parameters()
         if (rope_dim := self.hparams.get("head_dim")) is None:
             rope_dim = self.hparams["hidden_size"] // self.hparams["num_attention_heads"]
-        self.gguf_writer.add_rope_dimension_count(int(rope_dim * self.hparams.get("partial_rotary_factor", 0.5)))
+        self.gguf_writer.add_rope_dimension_count(int(rope_dim * self.partial_rotary_factor))
+
+    @staticmethod
+    def normal_to_neox(weights: Tensor, n_head: int, n_head_kv: int, head_dim: int, partial_rotary_factor: float) -> Tensor:
+        orig_shape = weights.shape
+        if len(orig_shape) == 1:
+            weights = weights.unsqueeze(1)  # [out_dim, 1]
+        if len(weights.shape) != 2:
+            raise ValueError("Only 1D and 2D tensors are supported.")
+        n_effective_heads = weights.shape[0] // head_dim
+        if n_head_kv is not None and n_effective_heads != n_head:
+            if n_effective_heads != n_head_kv:
+                raise AssertionError(f"Mismatch in effective heads: computed {n_effective_heads}, expected {n_head} or {n_head_kv}")
+        rotary_dim = int(head_dim * partial_rotary_factor)
+        if rotary_dim % 2 != 0:
+            raise ValueError("rotary_dim must be even.")
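+        # e.g. with rotary_dim = 4, per-head rows [r0, r1, r2, r3] (interleaved cos/sin
+        # pairs) are reordered to the Neox layout [r0, r2, r1, r3]; rows past rotary_dim
+        # are left untouched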
be even.") + reshaped = weights.reshape(n_effective_heads, head_dim, -1) + rot_part = reshaped[:, :rotary_dim, :] + non_rot_part = reshaped[:, rotary_dim:, :] + permuted_rot = torch.cat((rot_part[:, ::2, :], rot_part[:, 1::2, :]), dim=1) + combined = torch.cat((permuted_rot, non_rot_part), dim=1) + result = combined.reshape(weights.shape) + return result if len(orig_shape) != 1 else result.squeeze(1) def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]: if name.startswith("model.visual."): # ignore visual part of Glm4v return [] elif name.startswith("model.language_model."): name = name.replace("language_model.", "") # for Glm4v + if self.use_mrope: + n_head = self.hparams["num_attention_heads"] + n_kv_head = self.hparams["num_key_value_heads"] + n_embd = self.hparams["hidden_size"] + head_dim = n_embd // n_head + # because llama.cpp M-RoPE kernel only supports Neox ordering, we have to permute the weights here + if name.endswith(("q_proj.weight", "q_proj.bias")): + data_torch = Glm4Model.normal_to_neox(data_torch, n_head, n_head, head_dim, self.partial_rotary_factor) + if name.endswith(("k_proj.weight", "k_proj.bias")): + data_torch = Glm4Model.normal_to_neox(data_torch, n_head, n_kv_head, head_dim, self.partial_rotary_factor) return super().modify_tensors(data_torch, name, bid) -@ModelBase.register("Glm4MoeForCausalLM") +@ModelBase.register("Glm4MoeForCausalLM", "Glm4vMoeForConditionalGeneration") class Glm4MoeModel(TextModel): model_arch = gguf.MODEL_ARCH.GLM4_MOE @@ -7983,6 +8023,7 @@ class Glm4MoeModel(TextModel): _experts: list[dict[str, Tensor]] | None = None + # note: unlike GLM4V non-MoE, we don't need to permute Q/K here since GLM4V_MOE uses Neox ordering already def modify_tensors( self, data_torch: Tensor, name: str, bid: int | None ) -> Iterable[tuple[str, Tensor]]: diff --git a/docs/backend/SYCL.md b/docs/backend/SYCL.md index 02a72a9d51..f44458ed3b 100644 --- a/docs/backend/SYCL.md +++ b/docs/backend/SYCL.md @@ -103,6 +103,8 @@ SYCL backend supports Intel GPU Family: - Intel Built-in Arc GPU - Intel iGPU in Core CPU (11th Generation Core CPU and newer, refer to [oneAPI supported GPU](https://www.intel.com/content/www/us/en/developer/articles/system-requirements/intel-oneapi-base-toolkit-system-requirements.html#inpage-nav-1-1)). +On older Intel GPUs, you may try [OpenCL](/docs/backend/OPENCL.md) although the performance is not optimal, and some GPUs may not support OpenCL nor have any GPGPU capabilities. + #### Verified devices | Intel GPU | Status | Verified Model | diff --git a/docs/docker.md b/docs/docker.md index b9e5015396..a3b263497c 100644 --- a/docs/docker.md +++ b/docs/docker.md @@ -7,9 +7,9 @@ ## Images We have three Docker images available for this project: -1. `ghcr.io/ggml-org/llama.cpp:full`: This image includes both the main executable file and the tools to convert LLaMA models into ggml and convert into 4-bit quantization. (platforms: `linux/amd64`, `linux/arm64`, `linux/s390x`) -2. `ghcr.io/ggml-org/llama.cpp:light`: This image only includes the main executable file. (platforms: `linux/amd64`, `linux/arm64`, `linux/s390x`) -3. `ghcr.io/ggml-org/llama.cpp:server`: This image only includes the server executable file. (platforms: `linux/amd64`, `linux/arm64`, `linux/s390x`) +1. `ghcr.io/ggml-org/llama.cpp:full`: This image includes both the `llama-cli` and `llama-completion` executables and the tools to convert LLaMA models into ggml and convert into 4-bit quantization. 
+2. `ghcr.io/ggml-org/llama.cpp:light`: This image only includes the `llama-cli` and `llama-completion` executables. (platforms: `linux/amd64`, `linux/arm64`, `linux/s390x`)
+3. `ghcr.io/ggml-org/llama.cpp:server`: This image only includes the `llama-server` executable. (platforms: `linux/amd64`, `linux/arm64`, `linux/s390x`)
 
 Additionally, there are the following images, similar to the above:
 
@@ -44,13 +44,15 @@ docker run -v /path/to/models:/models ghcr.io/ggml-org/llama.cpp:full --all-in-o
 On completion, you are ready to play!
 
 ```bash
-docker run -v /path/to/models:/models ghcr.io/ggml-org/llama.cpp:full --run -m /models/7B/ggml-model-q4_0.gguf -p "Building a website can be done in 10 simple steps:" -n 512
+docker run -v /path/to/models:/models ghcr.io/ggml-org/llama.cpp:full --run -m /models/7B/ggml-model-q4_0.gguf
+docker run -v /path/to/models:/models ghcr.io/ggml-org/llama.cpp:full --run-legacy -m /models/32B/ggml-model-q8_0.gguf -no-cnv -p "Building a mobile app can be done in 15 steps:" -n 512
 ```
 
 or with a light image:
 
 ```bash
-docker run -v /path/to/models:/models ghcr.io/ggml-org/llama.cpp:light -m /models/7B/ggml-model-q4_0.gguf -p "Building a website can be done in 10 simple steps:" -n 512
+docker run -v /path/to/models:/models --entrypoint /app/llama-cli ghcr.io/ggml-org/llama.cpp:light -m /models/7B/ggml-model-q4_0.gguf
+docker run -v /path/to/models:/models --entrypoint /app/llama-completion ghcr.io/ggml-org/llama.cpp:light -m /models/32B/ggml-model-q8_0.gguf -no-cnv -p "Building a mobile app can be done in 15 steps:" -n 512
 ```
 
 or with a server image:
 
@@ -59,6 +61,8 @@ or with a server image:
 docker run -v /path/to/models:/models -p 8080:8080 ghcr.io/ggml-org/llama.cpp:server -m /models/7B/ggml-model-q4_0.gguf --port 8080 --host 0.0.0.0 -n 512
 ```
 
+In the above examples, `--entrypoint /app/llama-cli` is specified for clarity, but you can safely omit it since it's the default entrypoint in the container.
+
 ## Docker With CUDA
 
 Assuming one has the [nvidia-container-toolkit](https://github.com/NVIDIA/nvidia-container-toolkit) properly installed on Linux, or is using a GPU enabled cloud, `cuBLAS` should be accessible inside the container.
@@ -80,9 +84,9 @@ The defaults are:
 
 The resulting images are essentially the same as the non-CUDA images:
 
-1. `local/llama.cpp:full-cuda`: This image includes both the main executable file and the tools to convert LLaMA models into ggml and convert into 4-bit quantization.
-2. `local/llama.cpp:light-cuda`: This image only includes the main executable file.
-3. `local/llama.cpp:server-cuda`: This image only includes the server executable file.
+1. `local/llama.cpp:full-cuda`: This image includes both the `llama-cli` and `llama-completion` executables and the tools to convert LLaMA models into ggml and convert into 4-bit quantization.
+2. `local/llama.cpp:light-cuda`: This image only includes the `llama-cli` and `llama-completion` executables.
+3. `local/llama.cpp:server-cuda`: This image only includes the `llama-server` executable.
 
 ## Usage
@@ -114,9 +118,9 @@ The defaults are:
 
 The resulting images are essentially the same as the non-MUSA images:
 
-1. `local/llama.cpp:full-musa`: This image includes both the main executable file and the tools to convert LLaMA models into ggml and convert into 4-bit quantization.
-2. `local/llama.cpp:light-musa`: This image only includes the main executable file.
-3. `local/llama.cpp:server-musa`: This image only includes the server executable file.
+1. `local/llama.cpp:full-musa`: This image includes both the `llama-cli` and `llama-completion` executables and the tools to convert LLaMA models into ggml and convert into 4-bit quantization.
+2. `local/llama.cpp:light-musa`: This image only includes the `llama-cli` and `llama-completion` executables.
+3. `local/llama.cpp:server-musa`: This image only includes the `llama-server` executable.
 
 ## Usage
diff --git a/examples/model-conversion/README.md b/examples/model-conversion/README.md
index 05d95d588b..8163b306b4 100644
--- a/examples/model-conversion/README.md
+++ b/examples/model-conversion/README.md
@@ -10,6 +10,13 @@ and in some cases perplexity checked of the quantized model. And finally the
 model/models need to be uploaded to the ggml-org on Hugging Face. This tool/example tries to help with this process.
 
+> 📝 **Note:** When adding a new model from an existing family, verify that the
+> previous version passes logits verification first. Existing models can have
+> subtle numerical differences that don't affect generation quality but cause
+> logits mismatches. Identifying upfront whether these exist in llama.cpp,
+> the conversion script, or an upstream implementation can save significant
+> debugging time.
+
 ### Overview
 
 The idea is that the makefile targets and scripts here can be used in the
 development/conversion process assisting with things like:
diff --git a/examples/model-conversion/scripts/embedding/compare-embeddings-logits.sh b/examples/model-conversion/scripts/embedding/compare-embeddings-logits.sh
index c48af3075c..984d03e95d 100755
--- a/examples/model-conversion/scripts/embedding/compare-embeddings-logits.sh
+++ b/examples/model-conversion/scripts/embedding/compare-embeddings-logits.sh
@@ -34,8 +34,11 @@ done
 MODEL_PATH="${MODEL_PATH:-"$EMBEDDING_MODEL_PATH"}"
 MODEL_NAME="${MODEL_NAME:-$(basename "$MODEL_PATH")}"
 
+CONVERTED_MODEL_PATH="${CONVERTED_EMBEDDING_PATH:-"$CONVERTED_EMBEDDING_MODEL"}"
+CONVERTED_MODEL_NAME="${CONVERTED_MODEL_NAME:-$(basename "$CONVERTED_MODEL_PATH" .gguf)}"
+
 if [ -t 0 ]; then
-    CPP_EMBEDDINGS="data/llamacpp-${MODEL_NAME}-embeddings.bin"
+    CPP_EMBEDDINGS="data/llamacpp-${CONVERTED_MODEL_NAME}-embeddings.bin"
 else
     # Process piped JSON data and convert to binary (matching logits.cpp format)
     TEMP_FILE=$(mktemp /tmp/tmp.XXXXXX.binn)
diff --git a/gguf-py/gguf/constants.py b/gguf-py/gguf/constants.py
index aa005d49b2..d08480d951 100644
--- a/gguf-py/gguf/constants.py
+++ b/gguf-py/gguf/constants.py
@@ -650,6 +650,7 @@ class MODEL_TENSOR(IntEnum):
     V_MMPROJ_PEG = auto()
     V_ENC_EMBD_CLS = auto()
     V_ENC_EMBD_PATCH = auto()
+    V_ENC_EMBD_NORM = auto()
     V_ENC_EMBD_POS = auto()
     V_ENC_INPUT_NORM = auto()
     V_ENC_ATTN_QKV = auto()
@@ -668,6 +669,7 @@ class MODEL_TENSOR(IntEnum):
     V_LAYER_SCALE_2 = auto()
     V_PRE_NORM = auto()
     V_POST_NORM = auto()
+    V_MM_POST_NORM = auto()
     V_MM_INP_NORM = auto()
     V_MM_INP_PROJ = auto()  # gemma3
     V_MM_SOFT_EMB_NORM = auto()  # gemma3
@@ -1040,6 +1042,7 @@ TENSOR_NAMES: dict[MODEL_TENSOR, str] = {
     MODEL_TENSOR.V_MMPROJ_PEG: "mm.model.peg.{bid}",
     MODEL_TENSOR.V_ENC_EMBD_CLS: "v.class_embd",
     MODEL_TENSOR.V_ENC_EMBD_PATCH: "v.patch_embd",
+    MODEL_TENSOR.V_ENC_EMBD_NORM: "v.norm_embd",
    MODEL_TENSOR.V_ENC_EMBD_POS: "v.position_embd",
     MODEL_TENSOR.V_ENC_ATTN_QKV: "v.blk.{bid}.attn_qkv",
     MODEL_TENSOR.V_ENC_ATTN_Q: "v.blk.{bid}.attn_q",
@@ -1058,6 +1061,7 @@ TENSOR_NAMES: dict[MODEL_TENSOR, str] = {
     MODEL_TENSOR.V_LAYER_SCALE_2: "v.blk.{bid}.ls2",
     MODEL_TENSOR.V_PRE_NORM: "v.pre_ln",
     MODEL_TENSOR.V_POST_NORM: "v.post_ln",
+    MODEL_TENSOR.V_MM_POST_NORM: 
"mm.post_norm", MODEL_TENSOR.V_MM_INP_PROJ: "mm.input_projection", MODEL_TENSOR.V_MM_INP_NORM: "mm.input_norm", MODEL_TENSOR.V_MM_SOFT_EMB_NORM: "mm.soft_emb_norm", @@ -1134,6 +1138,7 @@ MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = { MODEL_TENSOR.V_MMPROJ_PEG, MODEL_TENSOR.V_ENC_EMBD_CLS, MODEL_TENSOR.V_ENC_EMBD_PATCH, + MODEL_TENSOR.V_ENC_EMBD_NORM, MODEL_TENSOR.V_ENC_EMBD_POS, MODEL_TENSOR.V_ENC_EMBD_IMGNL, MODEL_TENSOR.V_ENC_EMBD_VSEP, @@ -1154,6 +1159,7 @@ MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = { MODEL_TENSOR.V_LAYER_SCALE_2, MODEL_TENSOR.V_PRE_NORM, MODEL_TENSOR.V_POST_NORM, + MODEL_TENSOR.V_MM_POST_NORM, MODEL_TENSOR.V_MM_INP_PROJ, MODEL_TENSOR.V_MM_INP_NORM, MODEL_TENSOR.V_MM_SOFT_EMB_NORM, @@ -3451,6 +3457,7 @@ class VisionProjectorType: COGVLM = "cogvlm" JANUS_PRO = "janus_pro" DEEPSEEKOCR = "deepseekocr" + GLM4V = "glm4v" # Items here are (block size, type size) diff --git a/gguf-py/gguf/tensor_mapping.py b/gguf-py/gguf/tensor_mapping.py index 0ec28034a0..7782fd439d 100644 --- a/gguf-py/gguf/tensor_mapping.py +++ b/gguf-py/gguf/tensor_mapping.py @@ -1213,6 +1213,7 @@ class TensorNameMap: "model.connector.modality_projection.proj", # SmolVLM "model.vision.linear_proj.linear_proj", # cogvlm "model.projector.layers", # Deepseek-OCR + "visual.merger.proj", # glm4v ), MODEL_TENSOR.V_MMPROJ_MLP: ( @@ -1248,6 +1249,10 @@ class TensorNameMap: "model.vision_model.embeddings.patch_embedding", # Deepseek-OCR CLIP ), + MODEL_TENSOR.V_ENC_EMBD_NORM: ( + "visual.post_conv_layernorm", # glm4v + ), + MODEL_TENSOR.V_ENC_EMBD_POS: ( "vision_tower.vision_model.embeddings.position_embedding", "model.vision_tower.embeddings.position_embeddings", # Intern-S1 @@ -1257,6 +1262,7 @@ class TensorNameMap: "vision_tower.patch_embed.pos_emb", # kimi-vl "visual.pos_embed", # qwen3vl "model.vision.patch_embedding.position_embedding", # cogvlm + "visual.embeddings.position_embedding", # glm4v ), MODEL_TENSOR.V_ENC_EMBD_IMGNL: ( @@ -1430,6 +1436,11 @@ class TensorNameMap: "vision_model.layernorm_post", # llama4 "visual.merger.ln_q", # qwen2vl "vision_tower.encoder.final_layernorm", # kimi-vl + "visual.post_layernorm", # glm4v + ), + + MODEL_TENSOR.V_MM_POST_NORM: ( + "visual.merger.post_projection_norm", # glm4v ), MODEL_TENSOR.V_MM_INP_PROJ: ( @@ -1499,6 +1510,7 @@ class TensorNameMap: MODEL_TENSOR.V_MM_PATCH_MERGER: ( "multi_modal_projector.patch_merger.merging_layer", # mistral small 3.1 - hf "patch_merger.merging_layer", # mistral + "visual.downsample", # glm4v ), MODEL_TENSOR.V_DS_NORM: ( @@ -1571,14 +1583,17 @@ class TensorNameMap: MODEL_TENSOR.V_MM_UP: ( "model.vision.linear_proj.dense_h_to_4h", # cogvlm + "visual.merger.up_proj", # glm4v ), MODEL_TENSOR.V_MM_DOWN: ( "model.vision.linear_proj.dense_4h_to_h", # cogvlm + "visual.merger.down_proj", # glm4v ), MODEL_TENSOR.V_MM_GATE: ( "model.vision.linear_proj.gate_proj", # cogvlm + "visual.merger.gate_proj", # glm4v ), MODEL_TENSOR.V_TOK_BOI: ( diff --git a/models/templates/NVIDIA-Nemotron-3-Nano-30B-A3B-BF16.jinja b/models/templates/NVIDIA-Nemotron-3-Nano-30B-A3B-BF16.jinja new file mode 100644 index 0000000000..a01e0861c6 --- /dev/null +++ b/models/templates/NVIDIA-Nemotron-3-Nano-30B-A3B-BF16.jinja @@ -0,0 +1,204 @@ +{% macro render_extra_keys(json_dict, handled_keys) %} + {%- if json_dict is mapping %} + {%- for json_key in json_dict if json_key not in handled_keys %} + {%- if json_dict[json_key] is mapping or (json_dict[json_key] is sequence and json_dict[json_key] is not string) %} + {{- '\n<' ~ json_key ~ '>' ~ 
(json_dict[json_key] | tojson | safe) ~ '</' ~ json_key ~ '>' }}
+        {%- else %}
+            {{-'\n<' ~ json_key ~ '>' ~ (json_dict[json_key] | string) ~ '</' ~ json_key ~ '>' }}
+        {%- endif %}
+    {%- endfor %}
+    {%- endif %}
+{% endmacro %}
+{%- set enable_thinking = enable_thinking if enable_thinking is defined else True %}
+{%- set truncate_history_thinking = truncate_history_thinking if truncate_history_thinking is defined else True %}
+
+{%- set ns = namespace(last_user_idx = -1) %}
+{%- set loop_messages = messages %}
+{%- for m in loop_messages %}
+    {%- if m["role"] == "user" %}
+        {%- set ns.last_user_idx = loop.index0 %}
+    {%- endif %}
+{%- endfor %}
+
+{%- if messages[0]["role"] == "system" %}
+    {%- set system_message = messages[0]["content"] %}
+    {%- set loop_messages = messages[1:] %}
+{%- else %}
+    {%- set system_message = "" %}
+    {%- set loop_messages = messages %}
+{%- endif %}
+{%- if not tools is defined %}
+    {%- set tools = [] %}
+{%- endif %}
+{# Recompute last_user_idx relative to loop_messages after handling system #}
+{%- set ns = namespace(last_user_idx = -1) %}
+{%- for m in loop_messages %}
+    {%- if m["role"] == "user" %}
+        {%- set ns.last_user_idx = loop.index0 %}
+    {%- endif %}
+{%- endfor %}
+{%- if system_message is defined %}
+    {{- "<|im_start|>system\n" + system_message }}
+{%- else %}
+    {%- if tools is iterable and tools | length > 0 %}
+        {{- "<|im_start|>system\n" }}
+    {%- endif %}
+{%- endif %}
+{%- if tools is iterable and tools | length > 0 %}
+    {%- if system_message is defined and system_message | length > 0 %}
+        {{- "\n\n" }}
+    {%- endif %}
+    {{- "# Tools\n\nYou have access to the following functions:\n\n" }}
+    {{- "" }}
+    {%- for tool in tools %}
+        {%- if tool.function is defined %}
+            {%- set tool = tool.function %}
+        {%- endif %}
+        {{- "\n\n<name>" ~ tool.name ~ "</name>" }}
+        {%- if tool.description is defined %}
+            {{- '\n<description>' ~ (tool.description | trim) ~ '</description>' }}
+        {%- endif %}
+        {{- '\n' }}
+        {%- if tool.parameters is defined and tool.parameters is mapping and tool.parameters.properties is defined and tool.parameters.properties is mapping %}
+            {%- for param_name, param_fields in tool.parameters.properties|items %}
+                {{- '\n' }}
+                {{- '\n<name>' ~ param_name ~ '</name>' }}
+                {%- if param_fields.type is defined %}
+                    {{- '\n<type>' ~ (param_fields.type | string) ~ '</type>' }}
+                {%- endif %}
+                {%- if param_fields.description is defined %}
+                    {{- '\n<description>' ~ (param_fields.description | trim) ~ '</description>' }}
+                {%- endif %}
+                {%- if param_fields.enum is defined %}
+                    {{- '\n<enum>' ~ (param_fields.enum | tojson | safe) ~ '</enum>' }}
+                {%- endif %}
+                {%- set handled_keys = ['name', 'type', 'description', 'enum'] %}
+                {{- render_extra_keys(param_fields, handled_keys) }}
+                {{- '\n' }}
+            {%- endfor %}
+        {%- endif %}
+        {% set handled_keys = ['type', 'properties', 'required'] %}
+        {{- render_extra_keys(tool.parameters, handled_keys) }}
+        {%- if tool.parameters is defined and tool.parameters.required is defined %}
+            {{- '\n<required>' ~ (tool.parameters.required | tojson | safe) ~ '</required>' }}
+        {%- endif %}
+        {{- '\n' }}
+        {%- set handled_keys = ['type', 'name', 'description', 'parameters'] %}
+        {{- render_extra_keys(tool, handled_keys) }}
+        {{- '\n' }}
+    {%- endfor %}
+    {{- "\n" }}
+
+    {{- '\n\nIf you choose to call a function ONLY reply in the following format with NO suffix:\n\n<tool_call>\n<function=example_function>\n<parameter=example_parameter_1>\nvalue_1\n</parameter>\n<parameter=example_parameter_2>\nThis is the value for the second parameter\nthat can span\nmultiple lines\n</parameter>\n</function>\n</tool_call>\n\n\nReminder:\n- Function calls MUST follow the specified format: an inner <function=...></function> block must be nested within <tool_call></tool_call> XML tags\n- Required parameters MUST be specified\n- You may provide optional reasoning for your function call in natural language BEFORE the function call, but NOT after\n- If there is no function call available, answer the question like normal with your current knowledge and do not tell the user about function calls\n' }}
+{%- endif %}
+
+
+{%- if system_message is defined %}
+    {{- '<|im_end|>\n' }}
+{%- else %}
+    {%- if tools is iterable and tools | length > 0 %}
+        {{- '<|im_end|>\n' }}
+    {%- endif %}
+{%- endif %}
+
+{%- for message in loop_messages %}
+    {%- if message.role == "assistant" %}
+        {# Add reasoning content into the content field for unified processing below. #}
+        {%- if message.reasoning_content is defined and message.reasoning_content is string and message.reasoning_content | trim | length > 0 %}
+            {%- set content = "<think>\n" ~ message.reasoning_content ~ "\n</think>\n\n" ~ (message.content | default('', true)) %}
+        {%- else %}
+            {%- set content = message.content | default('', true) %}
+            {%- if content is string -%}
+                {# Allow downstream logic to take care of broken thought, only handle coherent reasoning here. #}
+                {%- if '<think>' not in content and '</think>' not in content -%}
+                    {%- set content = "<think></think>" ~ content -%}
+                {%- endif -%}
+            {%- else -%}
+                {%- set content = content -%}
+            {%- endif -%}
+        {%- endif %}
+        {%- if message.tool_calls is defined and message.tool_calls is iterable and message.tool_calls | length > 0 %}
+            {# Assistant message has tool calls. #}
+            {{- '<|im_start|>assistant\n' }}
+            {%- set include_content = not (truncate_history_thinking and loop.index0 < ns.last_user_idx) %}
+            {%- if content is string and content | trim | length > 0 %}
+                {%- if include_content %}
+                    {{- (content | trim) ~ '\n' -}}
+                {%- else %}
+                    {%- set c = (content | string) %}
+                    {%- if '</think>' in c %}
+                        {# Keep only content after the last closing think. Also generation prompt causes this. #}
+                        {%- set c = c.split('</think>')[-1] %}
+                    {%- elif '<think>' in c %}
+                        {# If <think> was opened but never closed, drop the trailing think segment #}
+                        {%- set c = c.split('<think>')[0] %}
+                    {%- endif %}
+                    {%- set c = "<think></think>" ~ c | trim %}
+                    {%- if c | length > 0 %}
+                        {{- c ~ '\n' -}}
+                    {%- endif %}
+                {%- endif %}
+            {%- else %}
+                {{- "<think></think>" -}}
+            {%- endif %}
+            {%- for tool_call in message.tool_calls %}
+                {%- if tool_call.function is defined %}
+                    {%- set tool_call = tool_call.function %}
+                {%- endif %}
+                {{- '<tool_call>\n<function=' ~ tool_call.name ~ '>\n' -}}
+                {%- if tool_call.arguments is defined %}
+                    {%- for args_name, args_value in tool_call.arguments|items %}
+                        {{- '<parameter=' ~ args_name ~ '>\n' -}}
+                        {%- set args_value = args_value | tojson | safe if args_value is mapping or (args_value is sequence and args_value is not string) else args_value | string %}
+                        {{- args_value ~ '\n</parameter>\n' -}}
+                    {%- endfor %}
+                {%- endif %}
+                {{- '</function>\n</tool_call>\n' -}}
+            {%- endfor %}
+            {{- '<|im_end|>\n' }}
+        {%- else %}
+            {# Assistant message doesn't have tool calls. #}
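+            {# When truncate_history_thinking is set, reasoning is stripped from assistant turns before the last user message; only the post-</think> text is kept. #}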
+            {%- if not (truncate_history_thinking and loop.index0 < ns.last_user_idx) %}
+                {{- '<|im_start|>assistant\n' ~ (content | default('', true) | string | trim) ~ '<|im_end|>\n' }}
+            {%- else %}
+                {%- set c = (content | default('', true) | string) %}
+                {%- if '<think>' in c and '</think>' in c %}
+                    {%- set c = "<think></think>" ~ c.split('</think>')[-1] %}
+                {%- endif %}
+                {%- set c = c | trim %}
+                {%- if c | length > 0 %}
+                    {{- '<|im_start|>assistant\n' ~ c ~ '<|im_end|>\n' }}
+                {%- else %}
+                    {{- '<|im_start|>assistant\n<|im_end|>\n' }}
+                {%- endif %}
+            {%- endif %}
+        {%- endif %}
+    {%- elif message.role == "user" or message.role == "system" %}
+        {{- '<|im_start|>' + message.role + '\n' }}
+        {%- set content = message.content | string %}
+        {{- content }}
+        {{- '<|im_end|>\n' }}
+    {%- elif message.role == "tool" %}
+        {%- if loop.previtem and loop.previtem.role != "tool" %}
+            {{- '<|im_start|>user\n' }}
+        {%- endif %}
+        {{- '\n' }}
+        {{- message.content }}
+        {{- '\n\n' }}
+        {%- if not loop.last and loop.nextitem.role != "tool" %}
+            {{- '<|im_end|>\n' }}
+        {%- elif loop.last %}
+            {{- '<|im_end|>\n' }}
+        {%- endif %}
+    {%- else %}
+        {{- '<|im_start|>' + message.role + '\n' + message.content + '<|im_end|>\n' }}
+    {%- endif %}
+{%- endfor %}
+
+{%- if add_generation_prompt %}
+    {%- if enable_thinking %}
+        {{- '<|im_start|>assistant\n<think>\n' }}
+    {%- else %}
+        {{- '<|im_start|>assistant\n' }}
+    {%- endif %}
+{%- endif %}
diff --git a/src/llama-hparams.cpp b/src/llama-hparams.cpp
index 83d6d6ee3c..fe1fa4341d 100644
--- a/src/llama-hparams.cpp
+++ b/src/llama-hparams.cpp
@@ -231,3 +231,7 @@ bool llama_hparams::is_masked_swa(uint32_t n_swa, llama_swa_type swa_type, llama
 
     return false;
 }
+
+bool llama_hparams::use_mrope() const {
+    return rope_sections[0] > 0 && rope_sections[1] > 0;
+}
diff --git a/src/llama-hparams.h b/src/llama-hparams.h
index cecb476e91..f6e95b5d2a 100644
--- a/src/llama-hparams.h
+++ b/src/llama-hparams.h
@@ -270,6 +270,8 @@ struct llama_hparams {
     // TODO: think of a better place for this function
    // TODO: pack the SWA params in a struct?
    static bool is_masked_swa(uint32_t n_swa, llama_swa_type swa_type, llama_pos p0, llama_pos p1);
+
+    bool use_mrope() const;
 };
 
 static_assert(std::is_trivially_copyable<llama_hparams>::value, "llama_hparams must be trivially copyable");
diff --git a/src/llama-model.cpp b/src/llama-model.cpp
index 1143c7a606..2090725e27 100644
--- a/src/llama-model.cpp
+++ b/src/llama-model.cpp
@@ -1694,7 +1694,8 @@ void llama_model::load_hparams(llama_model_loader & ml) {
             } break;
         case LLM_ARCH_GLM4:
             {
-                ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
+                ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
+                ml.get_key_or_arr(LLM_KV_ROPE_DIMENSION_SECTIONS, hparams.rope_sections, 4, false);
                 switch (hparams.n_layer) {
                     case 40: type = LLM_TYPE_9B; break;
                     case 61: type = LLM_TYPE_32B; break;
@@ -1703,8 +1704,9 @@ void llama_model::load_hparams(llama_model_loader & ml) {
             } break;
         case LLM_ARCH_GLM4_MOE:
             {
-                ml.get_key(LLM_KV_EXPERT_FEED_FORWARD_LENGTH, hparams.n_ff_exp);
-                ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
+                ml.get_key(LLM_KV_EXPERT_FEED_FORWARD_LENGTH, hparams.n_ff_exp);
+                ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
+                ml.get_key_or_arr(LLM_KV_ROPE_DIMENSION_SECTIONS, hparams.rope_sections, 4, false);
 
                 // MoE parameters
                 ml.get_key(LLM_KV_EXPERT_COUNT, hparams.n_expert);
@@ -7830,7 +7832,6 @@ llama_rope_type llama_model_rope_type(const llama_model * model) {
         case LLM_ARCH_DEEPSEEK2OCR:
         case LLM_ARCH_PLM:
         case LLM_ARCH_CHATGLM:
-        case LLM_ARCH_GLM4:
         case LLM_ARCH_GRANITE:
         case LLM_ARCH_GRANITE_MOE:
         case LLM_ARCH_GRANITE_HYBRID:
@@ -7892,7 +7893,6 @@ llama_rope_type llama_model_rope_type(const llama_model * model) {
         case LLM_ARCH_LFM2:
         case LLM_ARCH_LFM2MOE:
         case LLM_ARCH_SMALLTHINKER:
-        case LLM_ARCH_GLM4_MOE:
         case LLM_ARCH_SEED_OSS:
         case LLM_ARCH_GROVEMOE:
         case LLM_ARCH_APERTUS:
@@ -7909,6 +7909,11 @@ llama_rope_type llama_model_rope_type(const llama_model * model) {
         case LLM_ARCH_QWEN3VLMOE:
             return LLAMA_ROPE_TYPE_IMROPE;
 
+        case LLM_ARCH_GLM4:
+            return model->hparams.use_mrope() ? LLAMA_ROPE_TYPE_MROPE : LLAMA_ROPE_TYPE_NORM;
+        case LLM_ARCH_GLM4_MOE:
+            return model->hparams.use_mrope() ? LLAMA_ROPE_TYPE_MROPE : LLAMA_ROPE_TYPE_NEOX;
+
         // all model arches should be listed explicitly here
         case LLM_ARCH_UNKNOWN:
             GGML_ABORT("unknown architecture");
diff --git a/src/models/glm4-moe.cpp b/src/models/glm4-moe.cpp
index 33ee707046..003f70f739 100644
--- a/src/models/glm4-moe.cpp
+++ b/src/models/glm4-moe.cpp
@@ -5,11 +5,20 @@ llm_build_glm4_moe::llm_build_glm4_moe(const llama_model & model, const llm_grap
 
     GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
 
+    int sections[4];
+    std::copy(std::begin(hparams.rope_sections), std::begin(hparams.rope_sections) + 4, sections);
+
     ggml_tensor * cur;
     ggml_tensor * inpL;
 
     inpL = build_inp_embd(model.tok_embd);
 
+    bool use_mrope = hparams.use_mrope();
+    if (ubatch.embd && !use_mrope) {
+        // unfortunately, we need to forcefully stop here, to avoid users complaining about wrong results
+        GGML_ABORT("This GGUF does not support multimodal. Please reconvert it.");
+    }
+
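+    // note: sections[] carries the per-axis rotary split {time, height, width, extra};
+    // an all-zero array means the GGUF was converted before M-RoPE support for this arch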
Please reconvert it."); + } + // inp_pos - contains the positions ggml_tensor * inp_pos = build_inp_pos(); @@ -60,17 +69,25 @@ llm_build_glm4_moe::llm_build_glm4_moe(const llama_model & model, const llm_grap Kcur = build_norm(Kcur, model.layers[il].attn_k_norm, NULL, LLM_NORM_RMS, il); cb(Kcur, "Kcur_normed", il); } - Qcur = ggml_rope_ext( - ctx0, Qcur, inp_pos, nullptr, - n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, - ext_factor, attn_factor, beta_fast, beta_slow - ); - Kcur = ggml_rope_ext( - ctx0, Kcur, inp_pos, nullptr, - n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, - ext_factor, attn_factor, beta_fast, beta_slow - ); + if (use_mrope) { + Qcur = ggml_rope_multi(ctx0, Qcur, inp_pos, nullptr, + n_rot, sections, rope_type, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow); + + Kcur = ggml_rope_multi(ctx0, Kcur, inp_pos, nullptr, + n_rot, sections, rope_type, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow); + } else { + // Normal RoPE + Qcur = ggml_rope_ext(ctx0, Qcur, inp_pos, nullptr, n_rot, + rope_type, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow); + + Kcur = ggml_rope_ext(ctx0, Kcur, inp_pos, nullptr, n_rot, + rope_type, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow); + } cb(Qcur, "Qcur", il); cb(Kcur, "Kcur", il); diff --git a/src/models/glm4.cpp b/src/models/glm4.cpp index f789b28248..204aa3932a 100644 --- a/src/models/glm4.cpp +++ b/src/models/glm4.cpp @@ -8,11 +8,20 @@ llm_build_glm4::llm_build_glm4(const llama_model & model, const llm_graph_params GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); + int sections[4]; + std::copy(std::begin(hparams.rope_sections), std::begin(hparams.rope_sections) + 4, sections); + ggml_tensor * cur; ggml_tensor * inpL; inpL = build_inp_embd(model.tok_embd); + bool use_mrope = hparams.use_mrope(); + if (ubatch.embd && !use_mrope) { + // unfortunately, we need to forcefully stop here, to avoid users complaining about wrong results + GGML_ABORT("This GGUF does not support multimodal. 
Please reconvert it."); + } + // inp_pos - contains the positions ggml_tensor * inp_pos = build_inp_pos(); @@ -63,11 +72,25 @@ llm_build_glm4::llm_build_glm4(const llama_model & model, const llm_graph_params Vcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head_kv, n_tokens, n_embd_head * sizeof(float), cur->nb[1], 1 * sizeof(float) * (n_embd + n_embd_gqa)); } - Qcur = ggml_rope_ext(ctx0, Qcur, inp_pos, nullptr, n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, - ext_factor, attn_factor, beta_fast, beta_slow); - Kcur = ggml_rope_ext(ctx0, Kcur, inp_pos, nullptr, n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, - ext_factor, attn_factor, beta_fast, beta_slow); + if (use_mrope) { + Qcur = ggml_rope_multi(ctx0, Qcur, inp_pos, nullptr, + n_rot, sections, rope_type, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow); + + Kcur = ggml_rope_multi(ctx0, Kcur, inp_pos, nullptr, + n_rot, sections, rope_type, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow); + } else { + // Normal RoPE + Qcur = ggml_rope_ext(ctx0, Qcur, inp_pos, nullptr, n_rot, + rope_type, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow); + + Kcur = ggml_rope_ext(ctx0, Kcur, inp_pos, nullptr, n_rot, + rope_type, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow); + } cb(Qcur, "Qcur", il); cb(Kcur, "Kcur", il); diff --git a/tests/test-chat.cpp b/tests/test-chat.cpp index 007929f517..02af5251cc 100644 --- a/tests/test-chat.cpp +++ b/tests/test-chat.cpp @@ -3588,6 +3588,163 @@ static void test_template_output_peg_parsers() { t.expect.content =R"({"amount": 123.45, "date": "2025-12-03"})"; }); } + + { + // NVIDIA Nemotron-3 Nano + auto tmpls = read_templates("models/templates/NVIDIA-Nemotron-3-Nano-30B-A3B-BF16.jinja"); + + // Test basic message + test_peg_parser(tmpls.get(), [&](auto & t) { + t.input = "Hello, world!\nWhat's up?"; + t.expect = message_assist; + }); + + // Test basic message and reasoning with reasoning_format = none + test_peg_parser(tmpls.get(), [&](auto & t) { + t.input = "I'm\nthinking\n\nHello, world!\nWhat's up?"; + t.expect.content = "I'm\nthinking\n\nHello, world!\nWhat's up?"; + }); + + // Test basic message and reasoning with reasoning_format = auto + test_peg_parser(tmpls.get(), [&](auto & t) { + t.input = "I'm\nthinking\n\nHello, world!\nWhat's up?"; + t.params.enable_thinking = true; + t.params.reasoning_format = COMMON_REASONING_FORMAT_AUTO; + + t.expect = message_assist_thoughts; + }); + + // Test tool call + test_peg_parser(tmpls.get(), [&](auto & t) { + t.input = + "\n" + "\n" + "\n" + "1\n" + "\n" + "\n" + ""; + t.params.enable_thinking = false; + t.params.reasoning_format = COMMON_REASONING_FORMAT_AUTO; + t.params.tools = {special_function_tool}; + + t.expect = message_assist_call; + }); + + // Test tool call with reasoning + test_peg_parser(tmpls.get(), [&](auto & t) { + t.input = + "I'm\nthinking\n\n" + "\n" + "\n" + "\n" + "1\n" + "\n" + "\n" + ""; + t.params.reasoning_format = COMMON_REASONING_FORMAT_AUTO; + t.params.tools = {special_function_tool}; + + t.expect = message_assist_call_thoughts; + }); + + // Test parallel tool calls + test_peg_parser(tmpls.get(), [&](auto & t) { + t.input = + "\n" + "\n" + "\n" + "1\n" + "\n" + "\n" + "\n" + "\n" + "\n" + "\n" + "1\n" + "\n" + "\n" + "2\n" + "\n" + "\n" + ""; + t.params.enable_thinking = false; + t.params.reasoning_format = COMMON_REASONING_FORMAT_AUTO; + t.params.parallel_tool_calls = true; + t.params.tools = 
+
+            t.expect.tool_calls = {{
+                /* .name = */ "special_function",
+                /* .arguments = */ R"({"arg1": 1})",
+                /* .id = */ {},
+            }, {
+                /* .name = */ "special_function_with_opt",
+                /* .arguments = */ R"({"arg1": 1, "arg2": 2})",
+                /* .id = */ {},
+            }};
+        });
+
+        // Test tool call with string parameter
+        test_peg_parser(tmpls.get(), [&](auto & t) {
+            t.input =
+                "<tool_call>\n"
+                "<function=python>\n"
+                "<parameter=code>\n"
+                "def hello():\n"
+                "    print(\"Hello, world!\")\n"
+                "\n"
+                "hello()\n"
+                "</parameter>\n"
+                "</function>\n"
+                "</tool_call>";
+            t.params.enable_thinking = false;
+            t.params.reasoning_format = COMMON_REASONING_FORMAT_AUTO;
+            t.params.tools = {python_tool};
+
+            t.expect.tool_calls = {{
+                /* .name = */ "python",
+                /* .arguments = */ "{\"code\": \"def hello():\\n    print(\\\"Hello, world!\\\")\\n\\nhello()\"}",
+                /* .id = */ {},
+            }};
+        });
+
+        // Test tool call with string parameter and no closing tag
+        test_peg_parser(tmpls.get(), [&](auto & t) {
+            t.input =
+                "<tool_call>\n"
+                "<function=python>\n"
+                "<parameter=code>\n"
+                "def hello():\n"
+                "    print(\"Hello, world!\")\n"
+                "\n"
+                "hello()\n"
+                "</function>\n"
+                "</tool_call>";
+            t.params.enable_thinking = false;
+            t.params.reasoning_format = COMMON_REASONING_FORMAT_AUTO;
+            t.params.tools = {python_tool};
+
+            t.expect.tool_calls = {{
+                /* .name = */ "python",
+                /* .arguments = */ "{\"code\": \"def hello():\\n    print(\\\"Hello, world!\\\")\\n\\nhello()\"}",
+                /* .id = */ {},
+            }};
+        });
+
+        // Test response format
+        test_peg_parser(tmpls.get(), [&](auto & t) {
+            t.input =
+                "I need to output the invoice details in JSON\n"
+                "</think>\n"
+                R"({"amount": 123.45, "date": "2025-12-03"})";
+            t.params.reasoning_format = COMMON_REASONING_FORMAT_AUTO;
+            t.params.json_schema = invoice_schema;
+
+            t.expect.reasoning_content = "I need to output the invoice details in JSON";
+            t.expect.content = R"({"amount": 123.45, "date": "2025-12-03"})";
+        });
+    }
 }
 
 static void test_msg_diffs_compute() {
diff --git a/tests/test-json-schema-to-grammar.cpp b/tests/test-json-schema-to-grammar.cpp
index 6a4bd8fb4d..a8e9ff33a4 100755
--- a/tests/test-json-schema-to-grammar.cpp
+++ b/tests/test-json-schema-to-grammar.cpp
@@ -1367,10 +1367,85 @@ static void test_all(const std::string & lang, std::function
 #include
+#define DEFAULT_INTERPOLATION_MODE (GGML_SCALE_MODE_BILINEAR | GGML_SCALE_FLAG_ANTIALIAS)
+
 struct clip_graph {
     const clip_model & model;
     const clip_hparams & hparams;
@@ -49,7 +51,7 @@ struct clip_graph {
     void cb(ggml_tensor * cur0, const char * name, int il) const;
 
     // siglip2 naflex
-    ggml_tensor * resize_position_embeddings();
+    ggml_tensor * resize_position_embeddings(uint32_t interpolation_mode = DEFAULT_INTERPOLATION_MODE);
 
     // build vision transformer (ViT) cgraph
     // this function should cover most of the models
diff --git a/tools/mtmd/clip-impl.h b/tools/mtmd/clip-impl.h
index ecf378524e..9faa01b1be 100644
--- a/tools/mtmd/clip-impl.h
+++ b/tools/mtmd/clip-impl.h
@@ -71,6 +71,7 @@
 #define TN_PATCH_EMBD       "v.patch_embd.weight"  // not rename tensor with ".0" postfix for backward compat
 #define TN_PATCH_EMBD_1     "v.patch_embd.weight.1"
 #define TN_PATCH_BIAS       "v.patch_embd.bias"
+#define TN_NORM_EMBD        "v.norm_embd.%s"
 #define TN_ATTN_QKV         "%s.blk.%d.attn_qkv.%s"
 #define TN_ATTN_K           "%s.blk.%d.attn_k.%s"
 #define TN_ATTN_Q           "%s.blk.%d.attn_q.%s"
@@ -89,6 +90,10 @@
 #define TN_LN_PRE           "%s.pre_ln.%s"
 #define TN_LN_POST          "%s.post_ln.%s"
 #define TN_LLAVA_PROJ       "mm.%d.%s"
+#define TN_MM_UP            "mm.up.%s"
+#define TN_MM_GATE          "mm.gate.%s"
+#define TN_MM_DOWN          "mm.down.%s"
+#define TN_MM_POST_NORM     "mm.post_norm.%s"
"mm.model.mlp.%d.%s" #define TN_MVLM_PROJ_BLOCK "mm.model.mb_block.%d.block.%d.%s" #define TN_MVLM_PROJ_PEG "mm.model.peg.%d.%s" @@ -99,7 +104,7 @@ #define TN_MM_INP_PROJ "mm.input_projection.weight" // gemma3 #define TN_MM_SOFT_EMB_N "mm.soft_emb_norm.weight" // gemma3 #define TN_MM_PROJECTOR "mm.model.fc.%s" // idefics3, deepseekocr -#define TN_MM_PATCH_MERGER "mm.patch_merger.weight" // mistral small 3.1 +#define TN_MM_PATCH_MERGER "mm.patch_merger.%s" // mistral small 3.1, glm4v #define TN_TOK_IMG_BREAK "v.token_embd.img_break" // pixtral #define TN_TOK_GLM_BOI "adapter.boi" // glm-edge (these embeddings are not in text model) #define TN_TOK_GLM_EOI "adapter.eoi" // glm-edge (these embeddings are not in text model) @@ -184,6 +189,7 @@ enum projector_type { PROJECTOR_TYPE_COGVLM, PROJECTOR_TYPE_JANUS_PRO, PROJECTOR_TYPE_DEEPSEEKOCR, + PROJECTOR_TYPE_GLM4V, PROJECTOR_TYPE_UNKNOWN, }; @@ -212,6 +218,7 @@ static std::map PROJECTOR_TYPE_NAMES = { { PROJECTOR_TYPE_COGVLM, "cogvlm"}, { PROJECTOR_TYPE_JANUS_PRO, "janus_pro"}, { PROJECTOR_TYPE_DEEPSEEKOCR,"deepseekocr"}, + { PROJECTOR_TYPE_GLM4V, "glm4v"}, }; static projector_type clip_projector_type_from_string(const std::string & str) { @@ -695,6 +702,8 @@ static void save_tensor_to_file(const struct ggml_tensor * tensor, const uint8_t fprintf(stderr, "Tensor saved successfully\n"); } +void clip_debug_encode(clip_ctx * ctx, int h, int w, float fill_value); + // // API used internally with mtmd // diff --git a/tools/mtmd/clip-model.h b/tools/mtmd/clip-model.h index a927fa53e0..24d12534f9 100644 --- a/tools/mtmd/clip-model.h +++ b/tools/mtmd/clip-model.h @@ -182,6 +182,8 @@ struct clip_model { ggml_tensor * patch_embeddings_1 = nullptr; // second Conv2D kernel when we decouple Conv3D along temproal dimension (Qwen2VL) ggml_tensor * patch_bias = nullptr; ggml_tensor * position_embeddings = nullptr; + ggml_tensor * norm_embd_w = nullptr; + ggml_tensor * norm_embd_b = nullptr; ggml_tensor * pre_ln_w = nullptr; ggml_tensor * pre_ln_b = nullptr; @@ -197,6 +199,14 @@ struct clip_model { ggml_tensor * fc_b; ggml_tensor * mm_fc_w; ggml_tensor * mm_fc_b; + ggml_tensor * mm_ffn_up_w = nullptr; + ggml_tensor * mm_ffn_up_b = nullptr; + ggml_tensor * mm_ffn_gate_w = nullptr; + ggml_tensor * mm_ffn_gate_b = nullptr; + ggml_tensor * mm_ffn_down_w = nullptr; + ggml_tensor * mm_ffn_down_b = nullptr; + ggml_tensor * mm_post_norm_w = nullptr; + ggml_tensor * mm_post_norm_b = nullptr; // LLaVA projection ggml_tensor * mm_input_norm_w = nullptr; @@ -280,9 +290,10 @@ struct clip_model { ggml_tensor * mm_input_proj_w = nullptr; ggml_tensor * mm_soft_emb_norm_w = nullptr; - // pixtral + // pixtral, glm4v ggml_tensor * token_embd_img_break = nullptr; ggml_tensor * mm_patch_merger_w = nullptr; + ggml_tensor * mm_patch_merger_b = nullptr; // ultravox / whisper encoder ggml_tensor * conv1d_1_w = nullptr; diff --git a/tools/mtmd/clip.cpp b/tools/mtmd/clip.cpp index 5e4daa261b..eb5f66d6cb 100644 --- a/tools/mtmd/clip.cpp +++ b/tools/mtmd/clip.cpp @@ -265,11 +265,11 @@ void clip_graph::cb(ggml_tensor * cur0, const char * name, int il) const { } // siglip2 naflex -ggml_tensor * clip_graph::resize_position_embeddings() { +ggml_tensor * clip_graph::resize_position_embeddings(uint32_t interpolation_mode) { ggml_tensor * pos_embd = model.position_embeddings; const int height = img.ny / patch_size; const int width = img.nx / patch_size; - const uint32_t mode = GGML_SCALE_MODE_BILINEAR | GGML_SCALE_FLAG_ANTIALIAS; + const uint32_t mode = interpolation_mode; const int n_per_side = 
diff --git a/tools/mtmd/clip.cpp b/tools/mtmd/clip.cpp
index 5e4daa261b..eb5f66d6cb 100644
--- a/tools/mtmd/clip.cpp
+++ b/tools/mtmd/clip.cpp
@@ -265,11 +265,11 @@ void clip_graph::cb(ggml_tensor * cur0, const char * name, int il) const {
 }
 
 // siglip2 naflex
-ggml_tensor * clip_graph::resize_position_embeddings() {
+ggml_tensor * clip_graph::resize_position_embeddings(uint32_t interpolation_mode) {
     ggml_tensor * pos_embd = model.position_embeddings;
     const int height = img.ny / patch_size;
     const int width  = img.nx / patch_size;
-    const uint32_t mode = GGML_SCALE_MODE_BILINEAR | GGML_SCALE_FLAG_ANTIALIAS;
+    const uint32_t mode = interpolation_mode;
     const int n_per_side = (int)std::sqrt(pos_embd->ne[1]);
 
     GGML_ASSERT(pos_embd);
@@ -453,11 +453,9 @@ ggml_tensor * clip_graph::build_vit(
     return inpL;
 }
 
-
-
-    // build the input after conv2d (inp_raw --> patches)
-    // returns tensor with shape [n_embd, n_patches]
-    ggml_tensor * clip_graph::build_inp() {
+// build the input after conv2d (inp_raw --> patches)
+// returns tensor with shape [n_embd, n_patches]
+ggml_tensor * clip_graph::build_inp() {
     ggml_tensor * inp_raw = build_inp_raw();
     ggml_tensor * inp = ggml_conv_2d(ctx0, model.patch_embeddings_0, inp_raw, patch_size, patch_size, 0, 0, 1, 1);
     inp = ggml_reshape_2d(ctx0, inp, n_patches, n_embd);
@@ -488,19 +486,14 @@ ggml_tensor * clip_graph::build_norm(
         ? ggml_rms_norm(ctx0, cur, norm_eps)
         : ggml_norm(ctx0, cur, norm_eps);
 
-    if (mw || mb) {
-        cb(cur, "norm", il);
-    }
-
     if (mw) {
         cur = ggml_mul(ctx0, cur, mw);
-        if (mb) {
-            cb(cur, "norm_w", il);
-        }
+        cb(cur, "norm_w", il);
     }
 
     if (mb) {
         cur = ggml_add(ctx0, cur, mb);
+        cb(cur, "norm_b", il);
     }
 
     return cur;
@@ -846,6 +839,10 @@ static ggml_cgraph * clip_image_build_graph(clip_ctx * ctx, const clip_image_f32
             {
                 builder = std::make_unique<clip_graph_deepseekocr>(ctx, img);
             } break;
+        case PROJECTOR_TYPE_GLM4V:
+            {
+                builder = std::make_unique<clip_graph_glm4v>(ctx, img);
+            } break;
         default:
             GGML_ABORT("missing cgraph builder");
     }
@@ -1159,6 +1156,14 @@ struct clip_model_loader {
                         LOG_WRN("%s: more info: https://github.com/ggml-org/llama.cpp/issues/16842\n\n", __func__);
                     }
                 } break;
+            case PROJECTOR_TYPE_GLM4V:
+                {
+                    hparams.rope_theta = 10000.0f;
+                    hparams.n_merge = 2; // default value for GLM4-V
+                    get_u32(KEY_SPATIAL_MERGE_SIZE, hparams.n_merge, false);
+                    hparams.set_limit_image_tokens(8, 4096);
+                    hparams.set_warmup_n_tokens(46*46); // avoid OOM on warmup
+                } break;
             case PROJECTOR_TYPE_LLAMA4:
                 {
                     hparams.rope_theta = 10000.0f;
@@ -1297,6 +1302,9 @@ struct clip_model_loader {
         model.patch_embeddings_0 = get_tensor(TN_PATCH_EMBD, false);
         model.patch_embeddings_1 = get_tensor(TN_PATCH_EMBD_1, false);
 
+        model.norm_embd_w = get_tensor(string_format(TN_NORM_EMBD, "weight"), false);
+        model.norm_embd_b = get_tensor(string_format(TN_NORM_EMBD, "bias"), false);
+
         model.position_embeddings = get_tensor(string_format(TN_POS_EMBD, prefix), false);
 
         // layers
@@ -1485,6 +1493,20 @@ struct clip_model_loader {
                     model.mm_1_w = get_tensor(string_format(TN_LLAVA_PROJ, 2, "weight"));
                     model.mm_1_b = get_tensor(string_format(TN_LLAVA_PROJ, 2, "bias"));
                 } break;
+            case PROJECTOR_TYPE_GLM4V:
+                {
+                    model.projection = get_tensor(TN_MM_PROJECTOR);
+                    model.mm_ffn_up_w = get_tensor(string_format(TN_MM_UP, "weight"));
+                    model.mm_ffn_up_b = get_tensor(string_format(TN_MM_UP, "bias"), false);
+                    model.mm_ffn_gate_w = get_tensor(string_format(TN_MM_GATE, "weight"));
+                    model.mm_ffn_gate_b = get_tensor(string_format(TN_MM_GATE, "bias"), false);
+                    model.mm_ffn_down_w = get_tensor(string_format(TN_MM_DOWN, "weight"));
+                    model.mm_ffn_down_b = get_tensor(string_format(TN_MM_DOWN, "bias"), false);
+                    model.mm_post_norm_w = get_tensor(string_format(TN_MM_POST_NORM, "weight"));
+                    model.mm_post_norm_b = get_tensor(string_format(TN_MM_POST_NORM, "bias"), false);
+                    model.mm_patch_merger_w = get_tensor(string_format(TN_MM_PATCH_MERGER, "weight"));
+                    model.mm_patch_merger_b = get_tensor(string_format(TN_MM_PATCH_MERGER, "bias"));
+                } break;
             case PROJECTOR_TYPE_GEMMA3:
                 {
                     model.mm_input_proj_w = get_tensor(TN_MM_INP_PROJ);
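A quick sanity check on the GLM4V limits chosen above. The arithmetic assumes patch_size = 14, which is typical for this encoder family, but the actual value comes from the GGUF at load time; set_limit_image_tokens appears to bound how many image tokens preprocessing may produce.

    // With n_merge = 2, one output token covers a (2*14) x (2*14) = 28x28 pixel
    // tile, so the 4096-token cap corresponds to roughly 4096 * 784 ≈ 3.2 Mpx,
    // and the 46*46-token warmup stays comfortably below the cap.
    static_assert(46 * 46 <= 4096, "warmup token count must respect the cap");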
@@ -1513,8 +1535,8 @@ struct clip_model_loader {
                     // [IMG_BREAK] token embedding
                     model.token_embd_img_break = get_tensor(TN_TOK_IMG_BREAK);
                     // for mistral small 3.1
-                    model.mm_input_norm_w = get_tensor(TN_MM_INP_NORM, false);
-                    model.mm_patch_merger_w = get_tensor(TN_MM_PATCH_MERGER, false);
+                    model.mm_input_norm_w   = get_tensor(TN_MM_INP_NORM, false);
+                    model.mm_patch_merger_w = get_tensor(string_format(TN_MM_PATCH_MERGER, "weight"), false);
                 } break;
             case PROJECTOR_TYPE_LIGHTONOCR:
                 {
@@ -1522,8 +1544,8 @@ struct clip_model_loader {
                     model.mm_1_w = get_tensor(string_format(TN_LLAVA_PROJ, 1, "weight"));
                     model.mm_1_b = get_tensor(string_format(TN_LLAVA_PROJ, 1, "bias"), false);
                     model.mm_2_w = get_tensor(string_format(TN_LLAVA_PROJ, 2, "weight"));
                     model.mm_2_b = get_tensor(string_format(TN_LLAVA_PROJ, 2, "bias"), false);
-                    model.mm_input_norm_w = get_tensor(TN_MM_INP_NORM, false);
-                    model.mm_patch_merger_w = get_tensor(TN_MM_PATCH_MERGER, false);
+                    model.mm_input_norm_w   = get_tensor(TN_MM_INP_NORM, false);
+                    model.mm_patch_merger_w = get_tensor(string_format(TN_MM_PATCH_MERGER, "weight"), false);
                 } break;
             case PROJECTOR_TYPE_ULTRAVOX:
                 {
@@ -1924,6 +1946,8 @@ struct clip_init_result clip_init(const char * fname, struct clip_context_params
         if (ctx_params.warmup) {
             loader.warmup(*ctx_vision);
         }
+
+        // clip_debug_encode(ctx_vision, 24*14, 24*14, 0.5f);
     }
 
     if (loader.has_audio) {
@@ -2942,6 +2966,7 @@ bool clip_image_preprocess(struct clip_ctx * ctx, const clip_image_u8 * img, str
         case PROJECTOR_TYPE_QWEN2VL:
         case PROJECTOR_TYPE_QWEN25VL:
         case PROJECTOR_TYPE_QWEN3VL:
+        case PROJECTOR_TYPE_GLM4V:
             {
                 GGML_ASSERT(params.image_min_pixels > 0 && params.image_max_pixels > 0);
                 clip_image_u8 resized;
@@ -3268,16 +3293,30 @@ const char * clip_patch_merge_type(const struct clip_ctx * ctx) {
 
 int clip_n_output_tokens_x(const struct clip_ctx * ctx, struct clip_image_f32 * img) {
     const auto & params = ctx->model.hparams;
     const int n_total = clip_n_output_tokens(ctx, img);
-    if (ctx->proj_type() == PROJECTOR_TYPE_QWEN2VL || ctx->proj_type() == PROJECTOR_TYPE_QWEN25VL || ctx->proj_type() == PROJECTOR_TYPE_QWEN3VL) {
-        return img->nx / (params.patch_size * 2);
+    const auto & proj = ctx->proj_type();
+    switch (proj) {
+        case PROJECTOR_TYPE_QWEN2VL:
+        case PROJECTOR_TYPE_QWEN25VL:
+        case PROJECTOR_TYPE_QWEN3VL:
+        case PROJECTOR_TYPE_GLM4V:
+            return (img->nx / params.patch_size) / 2;
+        default:
+            break;
     }
     return n_total;
 }
 
 int clip_n_output_tokens_y(const struct clip_ctx * ctx, struct clip_image_f32 * img) {
     const auto & params = ctx->model.hparams;
-    if (ctx->proj_type() == PROJECTOR_TYPE_QWEN2VL || ctx->proj_type() == PROJECTOR_TYPE_QWEN25VL || ctx->proj_type() == PROJECTOR_TYPE_QWEN3VL) {
-        return img->ny / (params.patch_size * 2);
+    const auto & proj = ctx->proj_type();
+    switch (proj) {
+        case PROJECTOR_TYPE_QWEN2VL:
+        case PROJECTOR_TYPE_QWEN25VL:
+        case PROJECTOR_TYPE_QWEN3VL:
+        case PROJECTOR_TYPE_GLM4V:
+            return (img->ny / params.patch_size) / 2;
+        default:
+            break;
     }
     return 1;
 }
@@ -3334,6 +3373,7 @@ int clip_n_output_tokens(const struct clip_ctx * ctx, struct clip_image_f32 * im
         case PROJECTOR_TYPE_QWEN2VL:
         case PROJECTOR_TYPE_QWEN25VL:
         case PROJECTOR_TYPE_QWEN3VL:
+        case PROJECTOR_TYPE_GLM4V:
             {
                 // dynamic size (2 conv, so double patch size)
                 int x_patch = img->nx / (params.patch_size * 2);
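Worked example for the token-count helpers above (sizes illustrative, again assuming patch_size = 14 and a 2x2 merge):

    // A 448x448 input at patch_size = 14 yields 32x32 raw patches; the 2x2
    // merge halves each side, so clip_n_output_tokens_x/y return (448/14)/2 = 16
    // and the total image token count is 16 * 16 = 256.
    static_assert((448 / 14) / 2 == 16, "tokens per side");
    static_assert(((448 / 14) / 2) * ((448 / 14) / 2) == 256, "total image tokens");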
@@ -3593,6 +3633,7 @@ bool clip_image_batch_encode(clip_ctx * ctx, const int n_threads, const clip_ima
             } break;
         case PROJECTOR_TYPE_QWEN2VL:
         case PROJECTOR_TYPE_QWEN3VL:
+        case PROJECTOR_TYPE_GLM4V:
             {
                 const int merge_ratio = hparams.n_merge;
                 const int pw = image_size_width / patch_size;
@@ -3862,7 +3903,9 @@ bool clip_image_batch_encode(clip_ctx * ctx, const int n_threads, const clip_ima
     }
 
     // copy the embeddings to the location passed by the user
-    ggml_backend_tensor_get(embeddings, vec, 0, ggml_nbytes(embeddings));
+    if (vec != nullptr) {
+        ggml_backend_tensor_get(embeddings, vec, 0, ggml_nbytes(embeddings));
+    }
 
     return true;
 }
@@ -3912,6 +3955,8 @@ int clip_n_mmproj_embd(const struct clip_ctx * ctx) {
             return ctx->model.mm_4h_to_h_w->ne[1];
         case PROJECTOR_TYPE_DEEPSEEKOCR:
             return ctx->model.fc_w->ne[1];
+        case PROJECTOR_TYPE_GLM4V:
+            return ctx->model.mm_ffn_down_w->ne[1];
         default:
             GGML_ABORT("Unknown projector type");
     }
@@ -3928,10 +3973,11 @@ bool clip_is_glm(const struct clip_ctx * ctx) {
     return ctx->proj_type() == PROJECTOR_TYPE_GLM_EDGE;
 }
 
-bool clip_is_qwen2vl(const struct clip_ctx * ctx) {
+bool clip_is_mrope(const struct clip_ctx * ctx) {
     return ctx->proj_type() == PROJECTOR_TYPE_QWEN2VL
         || ctx->proj_type() == PROJECTOR_TYPE_QWEN25VL
-        || ctx->proj_type() == PROJECTOR_TYPE_QWEN3VL;
+        || ctx->proj_type() == PROJECTOR_TYPE_QWEN3VL
+        || ctx->proj_type() == PROJECTOR_TYPE_GLM4V;
 }
 
 bool clip_is_llava(const struct clip_ctx * ctx) {
@@ -3996,3 +4042,22 @@ void clip_image_f32_batch_add_mel(struct clip_image_f32_batch * batch, int n_mel
 const clip_hparams * clip_get_hparams(const struct clip_ctx * ctx) {
     return &ctx->model.hparams;
 }
+
+//
+// API for debugging
+//
+
+void clip_debug_encode(clip_ctx * ctx, int h, int w, float fill_value) {
+    clip_image_f32 img;
+    img.nx = w;
+    img.ny = h;
+    img.buf.resize(h * w * 3);
+    for (int i = 0; i < h * w * 3; i++) {
+        img.buf[i] = static_cast<float>(fill_value);
+    }
+    bool cur_debug_graph = ctx->debug_graph;
+    ctx->debug_graph = true;
+    clip_image_encode(ctx, 1, &img, nullptr);
+    ctx->debug_graph = cur_debug_graph;
+    GGML_ASSERT(img.buf.empty() && "expected, always stop here");
+}
diff --git a/tools/mtmd/clip.h b/tools/mtmd/clip.h
index 144a3d7b44..379fb83df3 100644
--- a/tools/mtmd/clip.h
+++ b/tools/mtmd/clip.h
@@ -104,7 +104,7 @@ bool clip_image_batch_encode(struct clip_ctx * ctx, int n_threads, const struct
 
 int clip_is_minicpmv(const struct clip_ctx * ctx);
 bool clip_is_glm(const struct clip_ctx * ctx);
-bool clip_is_qwen2vl(const struct clip_ctx * ctx);
+bool clip_is_mrope(const struct clip_ctx * ctx);
 bool clip_is_llava(const struct clip_ctx * ctx);
 bool clip_is_gemma3(const struct clip_ctx * ctx);
 bool clip_is_deepseekocr(const struct clip_ctx * ctx);
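Note: clip_debug_encode is a development probe. It synthesizes a solid-color image, forces debug_graph on so that intermediate tensors named via cb(...) get dumped, and then deliberately trips the final GGML_ASSERT to halt. The commented-out call in clip_init above shows the intended call site; for example:

    // Illustrative only: encode a dummy 336x336 gray image (24 patches per
    // side at patch_size 14) and dump the debug tensors. This intentionally
    // aborts afterwards, so treat it as a temporary probe during development.
    clip_debug_encode(ctx_vision, /*h=*/24 * 14, /*w=*/24 * 14, /*fill_value=*/0.5f);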
diff --git a/tools/mtmd/models/glm4v.cpp b/tools/mtmd/models/glm4v.cpp
new file mode 100644
index 0000000000..f39b6922eb
--- /dev/null
+++ b/tools/mtmd/models/glm4v.cpp
@@ -0,0 +1,120 @@
+#include "models.h"
+
+ggml_cgraph * clip_graph_glm4v::build() {
+    GGML_ASSERT(model.patch_bias != nullptr);
+    GGML_ASSERT(model.position_embeddings != nullptr);
+    GGML_ASSERT(model.class_embedding == nullptr);
+
+    const int batch_size = 1;
+
+    norm_type norm_t = NORM_TYPE_RMS;
+
+    ggml_tensor * inp_raw = build_inp_raw();
+    ggml_tensor * inp = ggml_conv_2d(ctx0, model.patch_embeddings_0, inp_raw, patch_size, patch_size, 0, 0, 1, 1);
+
+    int mrope_sections[4] = {d_head/4, d_head/4, d_head/4, d_head/4};
+    ggml_tensor * positions = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_patches * 4);
+    ggml_set_name(positions, "positions");
+    ggml_set_input(positions);
+
+    GGML_ASSERT(img.nx % (patch_size * 2) == 0);
+    GGML_ASSERT(img.ny % (patch_size * 2) == 0);
+
+    // second conv dimension
+    {
+        auto inp_1 = ggml_conv_2d(ctx0, model.patch_embeddings_1, inp_raw, patch_size, patch_size, 0, 0, 1, 1);
+        inp = ggml_add(ctx0, inp, inp_1);
+
+        inp = ggml_permute(ctx0, inp, 1, 2, 0, 3); // [w, h, c, b] -> [c, w, h, b]
+        inp = ggml_cont_4d(
+            ctx0, inp,
+            n_embd * 2, n_patches_x / 2, n_patches_y, batch_size);
+        inp = ggml_reshape_4d(
+            ctx0, inp,
+            n_embd * 2, n_patches_x / 2, 2, batch_size * (n_patches_y / 2));
+        inp = ggml_permute(ctx0, inp, 0, 2, 1, 3);
+        inp = ggml_cont_3d(
+            ctx0, inp,
+            n_embd, n_patches_x * n_patches_y, batch_size);
+    }
+
+    // add patch bias
+    inp = ggml_add(ctx0, inp, model.patch_bias);
+    cb(inp, "patch_bias", -1);
+
+    // pos-conv norm
+    inp = build_norm(inp, model.norm_embd_w, model.norm_embd_b, norm_t, eps, -1);
+
+    // calculate absolute position embedding and apply
+    ggml_tensor * learned_pos_embd = resize_position_embeddings(GGML_SCALE_MODE_BICUBIC);
+    learned_pos_embd = ggml_cont_4d(
+        ctx0, learned_pos_embd,
+        n_embd * 2, n_patches_x / 2, n_patches_y, batch_size);
+    learned_pos_embd = ggml_reshape_4d(
+        ctx0, learned_pos_embd,
+        n_embd * 2, n_patches_x / 2, 2, batch_size * (n_patches_y / 2));
+    learned_pos_embd = ggml_permute(ctx0, learned_pos_embd, 0, 2, 1, 3);
+    learned_pos_embd = ggml_cont_3d(
+        ctx0, learned_pos_embd,
+        n_embd, n_patches_x * n_patches_y, batch_size);
+    cb(learned_pos_embd, "learned_pos_embd", -1);
+
+    auto add_pos = [&](ggml_tensor * cur, const clip_layer &) {
+        return ggml_rope_multi(
+            ctx0, cur, positions, nullptr,
+            d_head/2, mrope_sections, GGML_ROPE_TYPE_VISION,
+            32768, hparams.rope_theta, 1, 0, 1, 32, 1);
+    };
+
+    ggml_tensor * cur = build_vit(
+        inp, n_patches,
+        norm_t,
+        hparams.ffn_op,
+        learned_pos_embd,
+        add_pos);
+
+    cb(cur, "vit_out", -1);
+    // cb(ggml_sum(ctx0, cur), "vit_out_sum", -1);
+
+    // GLM4V projector
+    // ref: https://github.com/huggingface/transformers/blob/40dc11cd3eb4126652aa41ef8272525affd4a636/src/transformers/models/glm4v/modeling_glm4v.py#L116-L130
+
+    // patch merger (downsample)
+    {
+        int n_merge = hparams.n_merge;
+        GGML_ASSERT(n_merge > 0);
+
+        int n_token_out = n_patches / n_merge / n_merge;
+        cur = ggml_reshape_4d(ctx0, cur, n_embd, n_merge, n_merge, n_token_out);
+        cur = ggml_cont(ctx0, ggml_permute(ctx0, cur, 2, 0, 1, 3)); // [n_merge, n_merge, n_embd, n_token_out]
+        cur = ggml_conv_2d(ctx0, model.mm_patch_merger_w, cur, n_merge, n_merge, 0, 0, 1, 1);
+        cur = ggml_reshape_2d(ctx0, cur, cur->ne[2], n_token_out); // [n_embd_out, n_token_out]
+
+        cur = ggml_add(ctx0, cur, model.mm_patch_merger_b);
+    }
+
+    // FC projector
+    {
+        cur = ggml_mul_mat(ctx0, model.projection, cur);
+        // default LayerNorm (post_projection_norm)
+        cur = build_norm(cur, model.mm_post_norm_w, model.mm_post_norm_b, NORM_TYPE_NORMAL, 1e-5, -1);
+        cur = ggml_gelu_erf(ctx0, cur);
+        cb(cur, "after_fc_proj", -1);
+    }
+
+    // FFN projector
+    {
+        cur = build_ffn(cur,
+            model.mm_ffn_up_w,   model.mm_ffn_up_b,
+            model.mm_ffn_gate_w, model.mm_ffn_gate_b,
+            model.mm_ffn_down_w, model.mm_ffn_down_b,
+            hparams.ffn_op, -1);
+        cb(cur, "after_ffn_proj", -1);
+        // cb(ggml_sum(ctx0, cur), "merged_sum", -1);
+    }
+
+    // build the graph
+    ggml_build_forward_expand(gf, cur);
+
+    return gf;
+}
diff --git a/tools/mtmd/models/models.h b/tools/mtmd/models/models.h
index bf020516fb..ae592e5ad5 100644
--- a/tools/mtmd/models/models.h
+++ b/tools/mtmd/models/models.h
@@ -61,3 +61,8 @@ struct clip_graph_deepseekocr : clip_graph {
     clip_graph_deepseekocr(clip_ctx * ctx, const clip_image_f32 & img) : clip_graph(ctx, img) {}
     ggml_cgraph * build() override;
 };
+
+struct clip_graph_glm4v : clip_graph {
+    clip_graph_glm4v(clip_ctx * ctx, const clip_image_f32 & img) : clip_graph(ctx, img) {}
+    ggml_cgraph * build() override;
+};
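Note on the patch merger in glm4v.cpp above: the strided ggml_conv_2d implements "concatenate each 2x2 group of patch embeddings, then apply one linear projection". It relies on the "second conv dimension" shuffle earlier in build(), which arranges the rows so that every 4 consecutive patches form one 2x2 spatial block. A shape walkthrough with illustrative sizes (ne[] given in ggml order, innermost dimension first; n_embd = 1536, n_merge = 2, n_patches = 1024 are assumptions for the example):

    // vit_out                [1536, 1024]
    // ggml_reshape_4d        [1536, 2, 2, 256]   // group each 2x2 block
    // ggml_permute + cont    [2, 2, 1536, 256]   // conv_2d wants W,H innermost
    // ggml_conv_2d (k=s=2)   [1, 1, n_out, 256]  // one "pixel" per merged token
    // ggml_reshape_2d        [n_out, 256]        // 4x fewer, wider tokens
    static_assert(1024 / 2 / 2 == 256, "n_token_out for the illustrative sizes");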
diff --git a/tools/mtmd/mtmd.cpp b/tools/mtmd/mtmd.cpp
index d82695da01..036f940cca 100644
--- a/tools/mtmd/mtmd.cpp
+++ b/tools/mtmd/mtmd.cpp
@@ -217,7 +217,7 @@ struct mtmd_context {
     void init_vision() {
         GGML_ASSERT(ctx_v != nullptr);
-        use_mrope = clip_is_qwen2vl(ctx_v);
+        use_mrope = clip_is_mrope(ctx_v);
 
         projector_type proj = clip_get_projector_type(ctx_v);
         int minicpmv_version = clip_is_minicpmv(ctx_v);
@@ -309,6 +309,10 @@ struct mtmd_context {
             img_beg = "<|image_start|>";
             img_end = "<|image_end|>";
 
+        } else if (proj == PROJECTOR_TYPE_GLM4V) {
+            img_beg = "<|begin_of_image|>";
+            img_end = "<|end_of_image|>";
+
         }
     }
diff --git a/tools/server/README.md b/tools/server/README.md
index 073bcd2ccd..ef4990faf1 100644
--- a/tools/server/README.md
+++ b/tools/server/README.md
@@ -1430,7 +1430,7 @@ Model presets allow advanced users to define custom configurations using an `.in
 llama-server --models-preset ./my-models.ini
 ```
 
-Each section in the file defines a new preset. Keys within a section correspond to command-line arguments (without leading dashes). For example, the argument `--n-gpu-layer 123` is written as `n-gpu-layer = 123`.
+Each section in the file defines a new preset. Keys within a section correspond to command-line arguments (without leading dashes). For example, the argument `--n-gpu-layers 123` is written as `n-gpu-layers = 123`.
 
 Short argument forms (e.g., `c`, `ngl`) and environment variable names (e.g., `LLAMA_ARG_N_GPU_LAYERS`) are also supported as keys.
 
@@ -1445,7 +1445,7 @@
 version = 1
 
 ; string value
 chat-template = chatml
 
 ; numeric value
-n-gpu-layer = 123
+n-gpu-layers = 123
 
 ; flag value (for certain flags, you need to use the "no-" prefix for negation)
 jinja = true
 
 ; shorthand argument (for example, context size)
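Circling back to the mtmd.cpp hunk above: once use_mrope is set, an image chunk needs four position values per embedding rather than one. The sketch below shows one plausible layout, based on how Qwen2-VL-style M-RoPE positions are commonly arranged as four consecutive streams; build_mrope_pos is a hypothetical name for illustration, not an mtmd API.

    #include <cstdint>
    #include <vector>

    typedef int32_t llama_pos; // matches the typedef in llama.h

    // Hypothetical helper: positions for an ny x nx grid of merged image
    // tokens starting at text position p0, laid out as four streams.
    static std::vector<llama_pos> build_mrope_pos(int nx, int ny, llama_pos p0) {
        const int n = nx * ny;
        std::vector<llama_pos> pos(4 * n);
        for (int y = 0; y < ny; y++) {
            for (int x = 0; x < nx; x++) {
                const int i = y * nx + x;
                pos[0 * n + i] = p0;      // temporal: constant for a still image
                pos[1 * n + i] = p0 + y;  // height
                pos[2 * n + i] = p0 + x;  // width
                pos[3 * n + i] = 0;       // unused fourth section
            }
        }
        return pos;
    }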