diff --git a/common/arg.cpp b/common/arg.cpp index f958b76d07..1ef8d70548 100644 --- a/common/arg.cpp +++ b/common/arg.cpp @@ -2105,7 +2105,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex "override tensor buffer type", [](common_params & params, const std::string & value) { parse_tensor_buffer_overrides(value, params.tensor_buft_overrides); } - )); + ).set_env("LLAMA_ARG_OVERRIDE_TENSOR")); add_opt(common_arg( {"-otd", "--override-tensor-draft"}, "=,...", "override tensor buffer type for draft model", [](common_params & params, const std::string & value) { @@ -3536,15 +3536,15 @@ void common_params_add_preset_options(std::vector & args) { [](common_params &, const std::string &) { /* unused */ } ).set_env(COMMON_ARG_PRESET_LOAD_ON_STARTUP).set_preset_only()); + args.push_back(common_arg( + {"stop-timeout"}, "SECONDS", + "in server router mode, force-kill model instance after this many seconds of graceful shutdown", + [](common_params &, int) { /* unused */ } + ).set_env(COMMON_ARG_PRESET_STOP_TIMEOUT).set_preset_only()); + // args.push_back(common_arg( // {"pin"}, // "in server router mode, do not unload this model if models_max is exceeded", // [](common_params &) { /* unused */ } // ).set_preset_only()); - - // args.push_back(common_arg( - // {"unload-idle-seconds"}, "SECONDS", - // "in server router mode, unload models idle for more than this many seconds", - // [](common_params &, int) { /* unused */ } - // ).set_preset_only()); } diff --git a/common/arg.h b/common/arg.h index f5111c658f..a1b6a14e67 100644 --- a/common/arg.h +++ b/common/arg.h @@ -10,6 +10,7 @@ // pseudo-env variable to identify preset-only arguments #define COMMON_ARG_PRESET_LOAD_ON_STARTUP "__PRESET_LOAD_ON_STARTUP" +#define COMMON_ARG_PRESET_STOP_TIMEOUT "__PRESET_STOP_TIMEOUT" // // CLI argument parsing diff --git a/convert_hf_to_gguf.py b/convert_hf_to_gguf.py index 16c5acf346..69abb7367d 100755 --- a/convert_hf_to_gguf.py +++ b/convert_hf_to_gguf.py @@ -7362,6 +7362,90 @@ class MiniMaxM2Model(TextModel): return super().modify_tensors(data_torch, name, bid) +@ModelBase.register("MiMoV2FlashForCausalLM") +class MimoV2Model(TextModel): + model_arch = gguf.MODEL_ARCH.MIMO2 + + def set_gguf_parameters(self): + super().set_gguf_parameters() + + assert self.hparams["swa_head_dim"] == self.hparams["head_dim"] + assert self.hparams["swa_num_attention_heads"] == self.hparams["num_attention_heads"] + assert self.hparams["swa_v_head_dim"] == self.hparams["v_head_dim"] + assert self.hparams["topk_method"] == "noaux_tc" + + n_head_kv = self.hparams["num_key_value_heads"] + n_head_kv_swa = self.hparams["swa_num_key_value_heads"] + n_head_kv_arr = [n_head_kv_swa if use_swa == 1 else n_head_kv for use_swa in self.hparams["hybrid_layer_pattern"]] + self.gguf_writer.add_head_count_kv(n_head_kv_arr) + + self.gguf_writer.add_sliding_window(self.hparams["sliding_window"]) + self.gguf_writer.add_sliding_window_pattern(self.hparams["hybrid_layer_pattern"]) + self.gguf_writer.add_rope_freq_base_swa(self.hparams["swa_rope_theta"]) + self.gguf_writer.add_value_length(self.hparams["v_head_dim"]) + self.gguf_writer.add_expert_count(self.hparams["n_routed_experts"]) + self.gguf_writer.add_expert_feed_forward_length(self.hparams["moe_intermediate_size"]) + + rope_dim = int(self.hparams["head_dim"] * self.hparams["partial_rotary_factor"]) + self.gguf_writer.add_rope_dimension_count(rope_dim) + + self.gguf_writer.add_layer_norm_rms_eps(self.hparams.get("layernorm_epsilon", 1e-5)) + + _experts: list[dict[str, Tensor]] | None = None + + def modify_tensors(self, data_torch, name, bid): + if name.endswith("e_score_correction_bias"): + name = name.replace("e_score_correction_bias", "e_score_correction.bias") + + if "attention_sink" in name and not name.endswith(".weight"): + name += ".weight" + + # TODO: mimo v2 does not indicate the number of next-token-prediction layers, therefore we cannot do the same way as GLM4_MOE + if "model.mtp." in name: + return [] + + # process the experts separately + if name.find("mlp.experts") != -1: + n_experts = self.hparams["n_routed_experts"] + assert bid is not None + + if self._experts is None: + self._experts = [{} for _ in range(self.block_count)] + + self._experts[bid][name] = data_torch + + if len(self._experts[bid]) >= n_experts * 3: + tensors: list[tuple[str, Tensor]] = [] + + # merge the experts into a single 3d tensor + for w_name in ["gate_proj", "up_proj", "down_proj"]: + datas: list[Tensor] = [] + + for xid in range(n_experts): + ename_to_retrieve = f"model.layers.{bid}.mlp.experts.{xid}.{w_name}.weight" + datas.append(self._experts[bid][ename_to_retrieve]) + del self._experts[bid][ename_to_retrieve] + + data_torch = torch.stack(datas, dim=0) + merged_name = f"model.layers.{bid}.mlp.experts.{w_name}.weight" + new_name = self.map_tensor_name(merged_name) + tensors.append((new_name, data_torch)) + + return tensors + else: + return [] + return [(self.map_tensor_name(name), data_torch)] + + def prepare_tensors(self): + super().prepare_tensors() + + if self._experts is not None: + # flatten `list[dict[str, Tensor]]` into `list[str]` + experts = [k for d in self._experts for k in d.keys()] + if len(experts) > 0: + raise ValueError(f"Unprocessed experts: {experts}") + + @ModelBase.register("PanguEmbeddedForCausalLM") class PanguEmbeddedModel(TextModel): model_arch = gguf.MODEL_ARCH.PANGU_EMBED @@ -8695,6 +8779,11 @@ class NemotronHModel(GraniteHybridModel): raise ValueError(f"Unprocessed experts: {experts}") +@ModelBase.register("LlamaBidirectionalModel") +class LlamaEmbedNemotronModel(LlamaModel): + model_arch = gguf.MODEL_ARCH.LLAMA_EMBED + + @ModelBase.register("BailingMoeForCausalLM") class BailingMoeModel(TextModel): model_arch = gguf.MODEL_ARCH.BAILINGMOE diff --git a/docs/backend/SYCL.md b/docs/backend/SYCL.md index f44458ed3b..bcb3ce6743 100644 --- a/docs/backend/SYCL.md +++ b/docs/backend/SYCL.md @@ -829,7 +829,7 @@ use 1 SYCL GPUs: [0] with Max compute units:512 No. We can't support Ollama issue directly, because we aren't familiar with Ollama. - Sugguest reproducing on llama.cpp and report similar issue to llama.cpp. We will surpport it. + Suggest reproducing on llama.cpp and report similar issue to llama.cpp. We will support it. It's same for other projects including llama.cpp SYCL backend. diff --git a/ggml/src/ggml-cann/aclnn_ops.cpp b/ggml/src/ggml-cann/aclnn_ops.cpp index 835b53f659..3688abdd58 100644 --- a/ggml/src/ggml-cann/aclnn_ops.cpp +++ b/ggml/src/ggml-cann/aclnn_ops.cpp @@ -2338,19 +2338,19 @@ static void aclnn_rope_cache_init(ggml_backend_cann_context & ctx, // Step1.2: prepare rope_yarn_ramp, if this part updated, should update theta_scale_tensor. // TODO: acl_yarn_ramp_tensor use rope cache. bool yarn_ramp_tensor_updated = false; - ggml_cann_pool_alloc yarn_ramp_allocator(ctx.pool()); acl_tensor_ptr acl_yarn_ramp_tensor; if (ext_factor != 0 && (theta_scale_updated || ctx.rope_cache.theta_scale_length != theta_scale_length || ctx.rope_cache.freq_scale != freq_scale)) { yarn_ramp_tensor_updated = true; - + if (ctx.rope_cache.yarn_ramp_cache != nullptr) { + ACL_CHECK(aclrtFree(ctx.rope_cache.yarn_ramp_cache)); + } + ACL_CHECK(aclrtMalloc(&ctx.rope_cache.yarn_ramp_cache, theta_scale_length * sizeof(float), ACL_MEM_MALLOC_HUGE_FIRST)); // -rope_yarn_ramp // const float y = (i0 / 2 - low) / MAX(0.001f, high - low); // return MIN(1, MAX(0, y)) - 1; - yarn_ramp_allocator.alloc(theta_scale_length * sizeof(float)); - void * yarn_ramp_buffer = yarn_ramp_allocator.get(); acl_yarn_ramp_tensor = - ggml_cann_create_tensor(yarn_ramp_buffer, ACL_FLOAT, sizeof(float), theta_scale_ne, theta_scale_nb, 1); + ggml_cann_create_tensor(ctx.rope_cache.yarn_ramp_cache, ACL_FLOAT, sizeof(float), theta_scale_ne, theta_scale_nb, 1); float zero_value = 0, one_value = 1; float denom_safe_value = MAX(0.001f, corr_dims[1] - corr_dims[0]); acl_scalar_ptr low = ggml_cann_create_scalar(&corr_dims[0], aclDataType::ACL_FLOAT); @@ -2380,8 +2380,10 @@ static void aclnn_rope_cache_init(ggml_backend_cann_context & ctx, acl_scalar_ptr freq_scale_1_sc = ggml_cann_create_scalar(&freq_scale_1, aclDataType::ACL_FLOAT); GGML_CANN_CALL_ACLNN_OP(ctx, InplaceMuls, acl_yarn_ramp_tensor.get(), freq_scale_1_sc.get()); GGML_CANN_CALL_ACLNN_OP(ctx, InplaceAdds, acl_yarn_ramp_tensor.get(), freq_scale_sc.get(), one.get()); + } else { + acl_yarn_ramp_tensor = + ggml_cann_create_tensor(ctx.rope_cache.yarn_ramp_cache, ACL_FLOAT, sizeof(float), theta_scale_ne, theta_scale_nb, 1); } - // Step 1.3: update theta_scale_tensor according to ext_factor or freq_scale. if (ext_factor != 0) { if (theta_scale_updated || yarn_ramp_tensor_updated) { @@ -2988,32 +2990,156 @@ void ggml_cann_argmax(ggml_backend_cann_context & ctx, ggml_tensor * dst) { GGML_CANN_CALL_ACLNN_OP(ctx, ArgMax, acl_src.get(), 3, false, acl_dst.get()); } -void ggml_cann_conv_transpose_1d(ggml_backend_cann_context & ctx, ggml_tensor * dst) { +void ggml_cann_conv_transpose_1d(ggml_backend_cann_context& ctx, ggml_tensor* dst){ ggml_tensor * src0 = dst->src[0]; ggml_tensor * src1 = dst->src[1]; // stride - int64_t s0 = ((const int32_t *) (dst->op_params))[0]; + int64_t s0 = ((const int32_t*)(dst->op_params))[0]; - acl_tensor_ptr acl_input = ggml_cann_create_tensor(src1, src1->ne, src1->nb, 3, ACL_FORMAT_NCL); + acl_tensor_ptr acl_input = ggml_cann_create_tensor(src1, src1->ne, src1->nb, 3, ACL_FORMAT_NCL); acl_tensor_ptr acl_weight = ggml_cann_create_tensor(src0, src0->ne, src0->nb, 3, ACL_FORMAT_NCL); - acl_tensor_ptr acl_dst = ggml_cann_create_tensor(dst, dst->ne, dst->nb, 3, ACL_FORMAT_NCL); + acl_tensor_ptr acl_dst = ggml_cann_create_tensor(dst, dst->ne, dst->nb, 3, ACL_FORMAT_NCL); + + // get base information of input and kernel + int64_t input_len = *(src1->ne); + int64_t dst_len = *(dst->ne); + int64_t kernel_size = *(src0->ne); + + // set the max kernel size for each conv + int64_t max_kernel_size = 255; + + // compute the partition of kernel + int64_t part_num = 1; + part_num = (kernel_size + max_kernel_size - 1) / max_kernel_size; int64_t strideVal[1]; - strideVal[0] = s0; - acl_int_array_ptr stride = ggml_cann_create_int_array(strideVal, 1); - int64_t paddingVal[] = { 0 }; - acl_int_array_ptr padding = ggml_cann_create_int_array(paddingVal, 1); - int64_t dilationVal[] = { 1 }; - acl_int_array_ptr dilation = ggml_cann_create_int_array(dilationVal, 1); - int8_t cubeMathType = 0; + strideVal[0] = s0; + acl_int_array_ptr stride = ggml_cann_create_int_array(strideVal, 1); + int64_t paddingVal[] = {0}; + acl_int_array_ptr padding = ggml_cann_create_int_array(paddingVal, 1); + int64_t dilationVal[] = {1}; + acl_int_array_ptr dilation = ggml_cann_create_int_array(dilationVal, 1); + bool transposed = true; + int64_t groups = 1; + int8_t cubeMathType = 0; #ifdef ASCEND_310P cubeMathType = 1; #endif - GGML_CANN_CALL_ACLNN_OP(ctx, Convolution, acl_input.get(), acl_weight.get(), nullptr, stride.get(), padding.get(), - dilation.get(), true, padding.get(), 1, acl_dst.get(), cubeMathType); + auto weight_type = ggml_cann_type_mapping(src0->type); + auto dst_type = ggml_cann_type_mapping(dst->type); + + // slice the kernel to make each conv available + int64_t slice_dim = -1; + int64_t slice_start = 0; + int64_t slice_end = max_kernel_size; + int64_t slice_step = 1; + int64_t interval = max_kernel_size; + + int64_t left_pad_len = dilationVal[0] * (max_kernel_size - 1) + 1 - 2 * paddingVal[0]; + int64_t right_pad_len = 0; + + acl_scalar_ptr alpha = nullptr; + float alphaValue = 1.0; + alpha = ggml_cann_create_scalar(&alphaValue, aclDataType::ACL_FLOAT); + + // set zero to destination + GGML_CANN_CALL_ACLNN_OP(ctx, InplaceZero, acl_dst.get()); + + for(int k = 0; k < part_num; k++){ + + // create part kernel tensor and slice from big kernel + slice_start = max_kernel_size * k; + if(k == part_num - 1){ + slice_end = kernel_size; + interval = kernel_size - max_kernel_size * k; + }else{ + slice_end = max_kernel_size * (k+1); + } + + int64_t part_ne[4]; + for(int i = 0; i < 4; i++) { + part_ne[i] = *(src0->ne + i); + } + part_ne[0] = interval; + + size_t part_nb[4]; + part_nb[0] = sizeof(weight_type); + for (int i = 1; i < 4; i++) { + part_nb[i] = part_nb[i - 1] * part_ne[i - 1]; + } + + ggml_cann_pool_alloc part_kernel_allocator; + part_kernel_allocator.alloc(ctx.pool(), part_nb[3]); + void* part_kernel_buf = part_kernel_allocator.get(); + + acl_tensor_ptr part_kernel = ggml_cann_create_tensor(part_kernel_buf, weight_type, + ggml_element_size(src0), part_ne, part_nb, 3, ACL_FORMAT_NCL); + + GGML_CANN_CALL_ACLNN_OP(ctx, Slice, acl_weight.get(), slice_dim, slice_start, slice_end, slice_step, part_kernel.get()); + + // create the part conv result tensor + int64_t part_dst_ne[4]; + for(int i = 0; i < 4; i++){ + part_dst_ne[i] = *(dst->ne + i); + } + part_dst_ne[0] = (input_len - 1) * strideVal[0] - 2 * paddingVal[0] + dilationVal[0] * (part_ne[0] - 1) + 1; + + size_t part_dst_nb[4]; + part_dst_nb[0] = sizeof(weight_type); + for (int i = 1; i < 4; i++) { + part_dst_nb[i] = part_dst_nb[i - 1] * part_dst_ne[i - 1]; + } + ggml_cann_pool_alloc part_dst_allocator; + part_dst_allocator.alloc(ctx.pool(), part_dst_nb[3]); + void* part_dst_buf = part_dst_allocator.get(); + + acl_tensor_ptr acl_part_dst = ggml_cann_create_tensor(part_dst_buf, dst_type, ggml_element_size(dst), + part_dst_ne, part_dst_nb, 3, ACL_FORMAT_NCL); + GGML_CANN_CALL_ACLNN_OP(ctx, InplaceZero, acl_part_dst.get()); + + // compute part conv transpose 1d + GGML_CANN_CALL_ACLNN_OP(ctx, Convolution, acl_input.get(), part_kernel.get(), nullptr, stride.get(), + padding.get(), dilation.get(), transposed, padding.get(), groups, acl_part_dst.get(), cubeMathType); + + // compute the position of part result in final result + int64_t global_start = slice_start; + int64_t global_end = std::min((input_len - 1) * strideVal[0] + slice_end, dst_len); + + left_pad_len = global_start; + right_pad_len = dst_len - global_end; + + std::vector padDataVal = {left_pad_len,right_pad_len}; + acl_int_array_ptr padData = ggml_cann_create_int_array(padDataVal.data(), 2); + + acl_scalar_ptr pad_value = nullptr; + float pad_valueVal = 0.0; + pad_value = ggml_cann_create_scalar(&pad_valueVal, aclDataType::ACL_FLOAT); + + int64_t conv_result_ne[4]; + for(int i = 0; i < 4; i++){ + conv_result_ne[i] = *(dst->ne + i); + } + + size_t conv_result_nb[4]; + conv_result_nb[0] = sizeof(weight_type); + for (int i = 1; i < 4; i++) { + conv_result_nb[i] = conv_result_nb[i - 1] * conv_result_ne[i - 1]; + } + + ggml_cann_pool_alloc conv_result_allocator; + conv_result_allocator.alloc(ctx.pool(), conv_result_nb[3]); + void* conv_result_buf = conv_result_allocator.get(); + + acl_tensor_ptr conv_result = ggml_cann_create_tensor(conv_result_buf, dst_type, ggml_element_size(dst), + conv_result_ne, conv_result_nb, 3, ACL_FORMAT_NCL); + + GGML_CANN_CALL_ACLNN_OP(ctx, InplaceZero, conv_result.get()); + GGML_CANN_CALL_ACLNN_OP(ctx, ConstantPadNd, acl_part_dst.get(), padData.get(), pad_value.get(), conv_result.get()); + GGML_CANN_CALL_ACLNN_OP(ctx, InplaceAdd, acl_dst.get(), conv_result.get(), alpha.get()); + } } void ggml_cann_elu(ggml_backend_cann_context & ctx, ggml_tensor * dst) { diff --git a/ggml/src/ggml-cann/aclnn_ops.h b/ggml/src/ggml-cann/aclnn_ops.h index 1ebbc769c7..6ec4640289 100644 --- a/ggml/src/ggml-cann/aclnn_ops.h +++ b/ggml/src/ggml-cann/aclnn_ops.h @@ -47,6 +47,7 @@ #include #include #include +#include #include #include diff --git a/ggml/src/ggml-cann/common.h b/ggml/src/ggml-cann/common.h index 45c7294e68..e9a21e1b05 100644 --- a/ggml/src/ggml-cann/common.h +++ b/ggml/src/ggml-cann/common.h @@ -229,6 +229,60 @@ struct ggml_graph_node_properties { // op ggml_op node_op; int32_t op_params[GGML_MAX_OP_PARAMS / sizeof(int32_t)]; + + /** + * @brief Check if a ggml tensor node matches this property set. + * + * This function compares all relevant fields (address, op type, shape, source inputs, op params) + * to determine whether the current node matches these previously recorded properties. + * + * @param node The current ggml tensor node. + * @return true if all fields match (excluding GGML_OP_VIEW); false otherwise. + */ + bool has_matching_properties(ggml_tensor * node) { + if (node->data != this->node_address && node->op != GGML_OP_VIEW) { + return false; + } + + if (node->op != this->node_op) { + return false; + } + + for (int i = 0; i < GGML_MAX_DIMS; i++) { + if (node->ne[i] != this->ne[i]) { + return false; + } + if (node->nb[i] != this->nb[i]) { + return false; + } + } + + for (int i = 0; i < GGML_MAX_SRC; i++) { + if (node->src[i]) { + if (node->src[i]->data != this->src_address[i] && node->op != GGML_OP_VIEW) { + return false; + } + + for (int d = 0; d < GGML_MAX_DIMS; d++) { + if (node->src[i]->ne[d] != this->src_ne[i][d]) { + return false; + } + if (node->src[i]->nb[d] != this->src_nb[i][d]) { + return false; + } + } + } else { + if (this->src_address[i] != nullptr) { + return false; + } + } + } + + if (node->op == GGML_OP_SCALE || node->op == GGML_OP_UNARY || node->op == GGML_OP_GLU) { + return memcmp(this->op_params, node->op_params, GGML_MAX_OP_PARAMS) == 0; + } + return true; + } }; struct ggml_cann_graph { @@ -241,6 +295,79 @@ struct ggml_cann_graph { aclmdlRI graph = nullptr; std::vector ggml_graph_properties; + + /** + * @brief Create a new CANN graph from a ggml computation graph. + * + * This function creates a new ggml_cann_graph object and fills its node properties + * (operation type, dimensions, strides, input sources, and operation parameters) + * based on the current ggml computation graph. + * + * Each node in the ggml graph is mapped to a property entry in the new CANN graph: + * - node address + * - operation type + * - shape (ne) and strides (nb) + * - source tensor addresses + * - operation parameters + * + * @param cgraph The current ggml computation graph. + * @return Pointer to the newly created ggml_cann_graph object. + */ + static ggml_cann_graph * create_from_cgraph(ggml_cgraph * cgraph) { + ggml_cann_graph * new_graph = new ggml_cann_graph(); + new_graph->ggml_graph_properties.resize(cgraph->n_nodes); + + for (int node_idx = 0; node_idx < cgraph->n_nodes; ++node_idx) { + ggml_tensor * node = cgraph->nodes[node_idx]; + auto & prop = new_graph->ggml_graph_properties[node_idx]; + + prop.node_address = node->data; + prop.node_op = node->op; + + std::copy_n(node->ne, GGML_MAX_DIMS, prop.ne); + std::copy_n(node->nb, GGML_MAX_DIMS, prop.nb); + + for (int src = 0; src < GGML_MAX_SRC; ++src) { + if (node->src[src]) { + prop.src_address[src] = node->src[src]->data; + std::copy_n(node->src[src]->ne, GGML_MAX_DIMS, prop.src_ne[src]); + std::copy_n(node->src[src]->nb, GGML_MAX_DIMS, prop.src_nb[src]); + } else { + prop.src_address[src] = nullptr; + std::fill_n(prop.src_ne[src], GGML_MAX_DIMS, 0); + std::fill_n(prop.src_nb[src], GGML_MAX_DIMS, 0); + } + } + + memcpy(prop.op_params, node->op_params, GGML_MAX_OP_PARAMS); + } + + return new_graph; + } + + /** + * @brief Check whether this CANN graph matches the given ggml computation graph. + * + * This function compares the number of nodes and each node's properties + * (operation type, dimensions, strides, inputs, and operation parameters) + * to determine whether this CANN graph matches the given ggml graph. + * + * @param cgraph The current ggml computation graph. + * @return true if this CANN graph matches the ggml graph; false otherwise. + */ + bool matches_cgraph(ggml_cgraph * cgraph) { + if (this->ggml_graph_properties.size() != static_cast(cgraph->n_nodes)) { + return false; + } + + for (int i = 0; i < cgraph->n_nodes; ++i) { + if (!this->ggml_graph_properties[i].has_matching_properties(cgraph->nodes[i])) { + return false; + } + } + + return true; + } }; /** @@ -272,15 +399,6 @@ struct ggml_cann_graph_lru_cache { cache_list.push_front(new_node); } - /** - * @brief Move an existing graph to the front of the cache. - * @param node Pointer to the ggml_cann_graph to move. - */ - void move_to_front(ggml_cann_graph * node) { - cache_list.remove(node); - cache_list.push_front(node); - } - /** * @brief Clear all graphs from the cache (also frees memory). */ @@ -295,6 +413,28 @@ struct ggml_cann_graph_lru_cache { * @brief Destructor that clears the cache and frees all cached graphs. */ ~ggml_cann_graph_lru_cache() { clear(); } + + /** + * @brief Find a cached CANN graph that matches the given ggml graph and move it to front. + * + * This function iterates through the cached CANN graphs stored in the LRU cache and + * compares them against the given ggml computation graph. If a matching graph is found, + * it is promoted to the front of the LRU cache and returned. Otherwise, the function + * returns nullptr. + * + * @param cgraph The current ggml computation graph. + * @return true if found; false otherwise. + */ + bool find_and_move_to_front(ggml_cgraph * cgraph) { + for (auto & graph_ptr : this->cache_list) { + if (graph_ptr->matches_cgraph(cgraph)) { + cache_list.remove(graph_ptr); + cache_list.push_front(graph_ptr); + return true; + } + } + return false; + } }; #endif // USE_ACL_GRAPH @@ -318,6 +458,9 @@ struct ggml_cann_rope_cache { if (position_select_index_host) { free(position_select_index_host); } + if (yarn_ramp_cache) { + ACL_CHECK(aclrtFree(yarn_ramp_cache)); + } } bool equal(int64_t theta_scale_length, @@ -370,6 +513,7 @@ struct ggml_cann_rope_cache { float * theta_scale_exp_host = nullptr; int * position_select_index_host = nullptr; void * position_select_index = nullptr; + void * yarn_ramp_cache = nullptr; // sin/cos cache, used only to accelerate first layer on each device void * sin_cache = nullptr; void * cos_cache = nullptr; diff --git a/ggml/src/ggml-cann/ggml-cann.cpp b/ggml/src/ggml-cann/ggml-cann.cpp index da624c587c..e90759f98b 100644 --- a/ggml/src/ggml-cann/ggml-cann.cpp +++ b/ggml/src/ggml-cann/ggml-cann.cpp @@ -2075,162 +2075,6 @@ static void ggml_backend_cann_synchronize(ggml_backend_t backend) { ACL_CHECK(aclrtSynchronizeStream(cann_ctx->stream())); } -#ifdef USE_ACL_GRAPH -/** - * @brief Add a new CANN graph to the LRU cache by populating node properties from the ggml graph. - * - * This function creates a new ggml_cann_graph object and fills its node properties - * (operation type, dimensions, strides, input sources, and operation parameters) - * based on the current ggml computation graph. - * - * Each node in the ggml graph is mapped to a property entry in the new CANN graph: - * - node address - * - operation type - * - shape (ne) and strides (nb) - * - source tensor addresses - * - operation parameters - * - * After initialization, the new graph is pushed into the LRU cache owned by the - * CANN backend context. The cache takes ownership of the graph and manages its - * lifetime (including deletion upon eviction). - * - * @param cann_ctx The CANN backend context containing the graph cache. - * @param cgraph The current ggml computation graph. - */ -static void add_lru_matched_graph_node_properties(ggml_backend_cann_context * cann_ctx, ggml_cgraph * cgraph) { - // Create a new ggml_cann_graph object on the heap (its lifetime is managed by the cache). - ggml_cann_graph * new_graph = new ggml_cann_graph(); - new_graph->ggml_graph_properties.resize(cgraph->n_nodes); - - for (int node_idx = 0; node_idx < cgraph->n_nodes; ++node_idx) { - ggml_tensor * node = cgraph->nodes[node_idx]; - auto & prop = new_graph->ggml_graph_properties[node_idx]; - - prop.node_address = node->data; - prop.node_op = node->op; - - std::copy_n(node->ne, GGML_MAX_DIMS, prop.ne); - std::copy_n(node->nb, GGML_MAX_DIMS, prop.nb); - - for (int src = 0; src < GGML_MAX_SRC; ++src) { - if (node->src[src]) { - prop.src_address[src] = node->src[src]->data; - std::copy_n(node->src[src]->ne, GGML_MAX_DIMS, prop.src_ne[src]); - std::copy_n(node->src[src]->nb, GGML_MAX_DIMS, prop.src_nb[src]); - } else { - prop.src_address[src] = nullptr; - std::fill_n(prop.src_ne[src], GGML_MAX_DIMS, 0); - std::fill_n(prop.src_nb[src], GGML_MAX_DIMS, 0); - } - } - - memcpy(prop.op_params, node->op_params, GGML_MAX_OP_PARAMS); - } - - // Insert into the LRU cache (cache takes ownership and will delete it when evicted). - cann_ctx->graph_lru_cache.push(new_graph); -} - -/** - * @brief Check if a ggml tensor node matches a previously captured CANN graph node. - * - * This function compares all relevant fields (address, op type, shape, source inputs, op params) - * to determine whether the current node matches a previously recorded version. - * - * @param node The current ggml tensor node. - * @param graph_node_properties The stored properties of a CANN graph node. - * @return true if all fields match (excluding GGML_OP_VIEW); false otherwise. - */ -static bool ggml_graph_node_has_matching_properties(ggml_tensor * node, - ggml_graph_node_properties * graph_node_properties) { - if (node->data != graph_node_properties->node_address && node->op != GGML_OP_VIEW) { - return false; - } - - if (node->op != graph_node_properties->node_op) { - return false; - } - - for (int i = 0; i < GGML_MAX_DIMS; i++) { - if (node->ne[i] != graph_node_properties->ne[i]) { - return false; - } - if (node->nb[i] != graph_node_properties->nb[i]) { - return false; - } - } - - for (int i = 0; i < GGML_MAX_SRC; i++) { - if (node->src[i]) { - if (node->src[i]->data != graph_node_properties->src_address[i] && node->op != GGML_OP_VIEW) { - return false; - } - - for (int d = 0; d < GGML_MAX_DIMS; d++) { - if (node->src[i]->ne[d] != graph_node_properties->src_ne[i][d]) { - return false; - } - if (node->src[i]->nb[d] != graph_node_properties->src_nb[i][d]) { - return false; - } - } - } else { - if (graph_node_properties->src_address[i] != nullptr) { - return false; - } - } - } - - if (node->op == GGML_OP_SCALE || node->op == GGML_OP_UNARY || node->op == GGML_OP_GLU) { - return memcmp(graph_node_properties->op_params, node->op_params, GGML_MAX_OP_PARAMS) == 0; - } - return true; -} - -/** - * @brief Check whether there is a cached CANN graph that matches the current ggml graph. - * - * This function iterates through the cached CANN graphs stored in the LRU cache and - * compares them against the given ggml computation graph. A match requires that the - * number of nodes is the same and that each node’s properties (operation type, - * dimensions, strides, inputs, and operation parameters) are identical. - * - * If a matching graph is found, it is promoted to the front of the LRU cache and the - * function returns true. Otherwise, the function returns false, indicating that a new - * CANN graph needs to be captured. - * - * @param cann_ctx The CANN backend context containing the graph cache. - * @param cgraph The current ggml computation graph. - * @return true if a matching cached graph exists; false otherwise. - */ -static bool is_matched_graph(ggml_backend_cann_context * cann_ctx, ggml_cgraph * cgraph) { - ggml_cann_graph_lru_cache & lru_cache = cann_ctx->graph_lru_cache; - for (auto & graph_ptr : lru_cache.cache_list) { - // Skip graphs with a different number of nodes. - if (graph_ptr->ggml_graph_properties.size() != static_cast(cgraph->n_nodes)) { - continue; - } - - // Check if all nodes match. - bool all_match = true; - for (int i = 0; i < cgraph->n_nodes; ++i) { - if (!ggml_graph_node_has_matching_properties(cgraph->nodes[i], &graph_ptr->ggml_graph_properties[i])) { - all_match = false; - break; - } - } - - if (all_match) { - // update cache_list && renturn graph_ptr - lru_cache.move_to_front(graph_ptr); - return true; - } - } - - return false; -} -#endif // USE_ACL_GRAPH - /** * @brief Evaluate the computation graph and optionally capture or execute it using CANN graph API. * @@ -2239,23 +2083,23 @@ static bool is_matched_graph(ggml_backend_cann_context * cann_ctx, ggml_cgraph * * * Otherwise, it falls back to op-by-op execution using the CANN compute kernel dispatcher. * - * @param cann_ctx The CANN backend context. - * @param cgraph The ggml computation graph. - * @param use_cann_graph Whether to use CANN graph execution. - * @param cann_graph_update_required Whether graph capture is needed due to graph changes. + * @param cann_ctx The CANN backend context. + * @param cgraph The ggml computation graph. + * @param use_cann_graph Whether to use CANN graph execution. + * @param cann_graph_capture_required Whether graph capture is needed due to graph changes. */ static void evaluate_and_capture_cann_graph(ggml_backend_cann_context * cann_ctx, ggml_cgraph * cgraph, - bool & use_cann_graph, - bool & cann_graph_update_required) { + bool use_cann_graph, + bool cann_graph_capture_required) { #ifdef USE_ACL_GRAPH - if (use_cann_graph && cann_graph_update_required) { // Begin CANN graph capture + if (use_cann_graph && cann_graph_capture_required) { // Begin CANN graph capture ACL_CHECK(aclmdlRICaptureBegin(cann_ctx->stream(), ACL_MODEL_RI_CAPTURE_MODE_GLOBAL)); } #endif // USE_ACL_GRAPH // Only perform the graph execution if CANN graphs are not enabled, or we are capturing the graph. // With the use of CANN graphs, the execution will be performed by the graph launch. - if (!use_cann_graph || cann_graph_update_required) { + if (!use_cann_graph || cann_graph_capture_required) { for (int i = 0; i < cgraph->n_nodes; i++) { ggml_tensor * node = cgraph->nodes[i]; @@ -2274,9 +2118,10 @@ static void evaluate_and_capture_cann_graph(ggml_backend_cann_context * cann_ctx #ifdef USE_ACL_GRAPH if (use_cann_graph) { + GGML_ASSERT(!cann_ctx->graph_lru_cache.cache_list.empty()); ggml_cann_graph * matched_graph = cann_ctx->graph_lru_cache.cache_list.front(); - if (cann_graph_update_required) { // End CANN graph capture + if (cann_graph_capture_required) { // End CANN graph capture ACL_CHECK(aclmdlRICaptureEnd(cann_ctx->stream(), &matched_graph->graph)); } @@ -2306,7 +2151,7 @@ static enum ggml_status ggml_backend_cann_graph_compute(ggml_backend_t backend, // calculate rope cache for fist layer in current device. cann_ctx->rope_cache.cached = false; - bool cann_graph_update_required = false; + bool graph_capture_required = false; #ifdef USE_ACL_GRAPH bool use_cann_graph = true; @@ -2331,16 +2176,17 @@ static enum ggml_status ggml_backend_cann_graph_compute(ggml_backend_t backend, if (use_cann_graph) { // If no matching graph is found, the graph needs to be recaptured. - cann_graph_update_required = !is_matched_graph(cann_ctx, cgraph); - if (cann_graph_update_required) { + graph_capture_required = !cann_ctx->graph_lru_cache.find_and_move_to_front(cgraph); + if (graph_capture_required) { // If no matching graph is found, add a new ACL graph. - add_lru_matched_graph_node_properties(cann_ctx, cgraph); + ggml_cann_graph * new_graph = ggml_cann_graph::create_from_cgraph(cgraph); + cann_ctx->graph_lru_cache.push(new_graph); } } #else bool use_cann_graph = false; #endif // USE_ACL_GRAPH - evaluate_and_capture_cann_graph(cann_ctx, cgraph, use_cann_graph, cann_graph_update_required); + evaluate_and_capture_cann_graph(cann_ctx, cgraph, use_cann_graph, graph_capture_required); return GGML_STATUS_SUCCESS; } @@ -2578,8 +2424,7 @@ static bool ggml_backend_cann_supports_op(ggml_backend_dev_t dev, const ggml_ten } } case GGML_OP_CONV_TRANSPOSE_1D: - // TODO: ((weightL - 1) * dilationW - padLeft)=1336 should not be larger than 255. - return (op->src[0]->ne[0] - 1) <= 255; + return true; case GGML_OP_SCALE: float bias; memcpy(&bias, (const float *) (op->op_params) + 1, sizeof(float)); diff --git a/ggml/src/ggml-cuda/CMakeLists.txt b/ggml/src/ggml-cuda/CMakeLists.txt index 67af1d8ccc..c0f8bcaa37 100644 --- a/ggml/src/ggml-cuda/CMakeLists.txt +++ b/ggml/src/ggml-cuda/CMakeLists.txt @@ -15,6 +15,7 @@ if (CUDAToolkit_FOUND) # 80 == Ampere, asynchronous data loading, faster tensor core instructions # 86 == RTX 3000, needs CUDA v11.1 # 89 == RTX 4000, needs CUDA v11.8 + # 120 == Blackwell, needs CUDA v12.8, FP4 tensor cores # # XX-virtual == compile CUDA code as PTX, do JIT compilation to binary code on first run # XX-real == compile CUDA code as device code for this specific architecture @@ -40,6 +41,32 @@ if (CUDAToolkit_FOUND) enable_language(CUDA) + # Replace any 12x-real architectures with 12x{a}-real. FP4 ptx instructions are not available in just 12x + if (GGML_NATIVE) + set(PROCESSED_ARCHITECTURES "") + if (CMAKE_CUDA_ARCHITECTURES_NATIVE) + set(ARCH_LIST ${CMAKE_CUDA_ARCHITECTURES_NATIVE}) + else() + set(ARCH_LIST ${CMAKE_CUDA_ARCHITECTURES}) + endif() + foreach(ARCH ${ARCH_LIST}) + if (ARCH MATCHES "^12[0-9](-real|-virtual)?$") + string(REGEX REPLACE "^(12[0-9]).*$" "\\1" BASE_ARCH ${ARCH}) + message(STATUS "Replacing ${ARCH} with ${BASE_ARCH}a-real") + list(APPEND PROCESSED_ARCHITECTURES "${BASE_ARCH}a-real") + else() + list(APPEND PROCESSED_ARCHITECTURES ${ARCH}) + endif() + endforeach() + set(CMAKE_CUDA_ARCHITECTURES ${PROCESSED_ARCHITECTURES}) + else() + foreach(ARCH ${CMAKE_CUDA_ARCHITECTURES}) + if(ARCH MATCHES "^12[0-9]$") + message(FATAL_ERROR "Compute capability ${ARCH} used, use ${ARCH}a or ${ARCH}f for Blackwell specific optimizations") + endif() + endforeach() + endif() + file(GLOB GGML_HEADERS_CUDA "*.cuh") list(APPEND GGML_HEADERS_CUDA "../../include/ggml-cuda.h") diff --git a/ggml/src/ggml-cuda/common.cuh b/ggml/src/ggml-cuda/common.cuh index 9fcb2f9fd2..62e618850b 100644 --- a/ggml/src/ggml-cuda/common.cuh +++ b/ggml/src/ggml-cuda/common.cuh @@ -50,6 +50,10 @@ #define GGML_CUDA_CC_TURING 750 #define GGML_CUDA_CC_AMPERE 800 #define GGML_CUDA_CC_ADA_LOVELACE 890 +// While BW spans CC 1000, 1100 & 1200, we are integrating Tensor Core instructions available to 1200 family, see +// https://docs.nvidia.com/cutlass/media/docs/cpp/blackwell_functionality.html#blackwell-sm120-gemms +#define GGML_CUDA_CC_BLACKWELL 1200 +#define GGML_CUDA_CC_RUBIN 1300 #define GGML_CUDA_CC_OFFSET_AMD 0x1000000 #define GGML_CUDA_CC_OFFSET_MTHREADS 0x0100000 #define GGML_CUDA_CC_IS_NVIDIA(cc) (cc < GGML_CUDA_CC_OFFSET_MTHREADS) @@ -246,6 +250,10 @@ static const char * cu_get_error_str(CUresult err) { #define AMPERE_MMA_AVAILABLE #endif // !defined(GGML_USE_HIP) && __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE +#if !defined(GGML_USE_HIP) && __CUDA_ARCH__ >= GGML_CUDA_CC_BLACKWELL && __CUDA_ARCH__ < GGML_CUDA_CC_RUBIN +# define BLACKWELL_MMA_AVAILABLE +#endif // !defined(GGML_USE_HIP) && __CUDA_ARCH__ >= GGML_CUDA_CC_BLACKWELL + #if !defined(GGML_USE_HIP) && __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE #define CP_ASYNC_AVAILABLE #endif // !defined(GGML_USE_HIP) && __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE @@ -316,6 +324,11 @@ static bool cp_async_available(const int cc) { return GGML_CUDA_CC_IS_NVIDIA(cc) && ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_AMPERE; } +static bool blackwell_mma_available(const int cc) { + return GGML_CUDA_CC_IS_NVIDIA(cc) && ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_BLACKWELL && + ggml_cuda_highest_compiled_arch(cc) < GGML_CUDA_CC_RUBIN; +} + static constexpr __device__ int ggml_cuda_get_physical_warp_size() { #if defined(GGML_USE_HIP) && (defined(__GFX9__) || defined(__GFX8__)) return 64; @@ -701,6 +714,28 @@ static __device__ __forceinline__ float ggml_cuda_e8m0_to_fp32(uint8_t x) { #endif // CUDART_VERSION >= 12050 } +__device__ __forceinline__ uint8_t ggml_cuda_float_to_fp4_e2m1(float x, float e) { + const uint8_t sign_bit = (x < 0.0f) << 3; + float ax = fabsf(x) * e; + + // Positive LUT + static constexpr float pos_lut[8] = { 0.0f, 0.5f, 1.0f, 1.5f, 2.0f, 3.0f, 4.0f, 6.0f }; + + int best_i = 0; + float best_err = fabsf(ax - pos_lut[0]); + +#pragma unroll + for (int i = 1; i < 8; ++i) { + const float err = fabsf(ax - pos_lut[i]); + if (err < best_err) { + best_err = err; + best_i = i; + } + } + + return static_cast(best_i | sign_bit); +} + // See https://gmplib.org/~tege/divcnst-pldi94.pdf figure 4.1. // Precompute mp (m' in the paper) and L such that division // can be computed using a multiply (high 32b of 64b result) diff --git a/ggml/src/ggml-cuda/cumsum.cu b/ggml/src/ggml-cuda/cumsum.cu index d2f2def8bd..e82171f9c2 100644 --- a/ggml/src/ggml-cuda/cumsum.cu +++ b/ggml/src/ggml-cuda/cumsum.cu @@ -5,7 +5,7 @@ #include "ggml.h" #ifdef GGML_CUDA_USE_CUB -# include +# include #endif // GGML_CUDA_USE_CUB template @@ -16,12 +16,14 @@ static __global__ void cumsum_cub_kernel( const int64_t s01, const int64_t s02, const int64_t s03, const int64_t s1, const int64_t s2, const int64_t s3) { #ifdef GGML_CUDA_USE_CUB - using BlockScan = cub::BlockScan; + using BlockScanT = cub::BlockScan; - __shared__ typename BlockScan::TempStorage temp_storage; - __shared__ T block_carry; // carry from previous tile + __shared__ typename BlockScanT::TempStorage temp_storage; + __shared__ T block_carry; const int tid = threadIdx.x; + constexpr int UNROLL_FACTOR = 4; + constexpr int TILE_SIZE = BLOCK_SIZE * UNROLL_FACTOR; const int64_t i1 = blockIdx.x; const int64_t i2 = blockIdx.y; @@ -39,29 +41,38 @@ static __global__ void cumsum_cub_kernel( } __syncthreads(); - for (int64_t start = 0; start < ne00; start += BLOCK_SIZE) { - int64_t idx = start + tid; - T x = (idx < ne00) ? src_row[idx] : T(0); + for (int64_t start = 0; start < ne00; start += TILE_SIZE) { + T items[UNROLL_FACTOR]; + T thread_sum = T(0); - T inclusive; - T block_total; - BlockScan(temp_storage).InclusiveSum(x, inclusive, block_total); - - __syncthreads(); - - T final_val = inclusive + block_carry; - - // store result - if (idx < ne00) { - dst_row[idx] = final_val; +#pragma unroll + for (int i = 0; i < UNROLL_FACTOR; i++) { + int64_t idx = start + tid * UNROLL_FACTOR + i; + T val = (idx < ne00) ? src_row[idx] : T(0); + thread_sum += val; + items[i] = thread_sum; } + // Block-wide scan on thread sums + T thread_prefix; + T block_total; + BlockScanT(temp_storage).InclusiveSum(thread_sum, thread_prefix, block_total); __syncthreads(); + // Add offset to each item and store + T thread_offset = thread_prefix - thread_sum + block_carry; + #pragma unroll + for (int i = 0; i < UNROLL_FACTOR; i++) { + int64_t idx = start + tid * UNROLL_FACTOR + i; + if (idx < ne00) { + dst_row[idx] = items[i] + thread_offset; + } + } + + // Update carry for next tile if (tid == 0) { block_carry += block_total; } - __syncthreads(); } #else @@ -69,7 +80,7 @@ static __global__ void cumsum_cub_kernel( #endif // GGML_CUDA_USE_CUB } -// Fallback kernel implementation (original) +// Fallback kernel implementation template static __global__ void cumsum_kernel( const T * src, T * dst, @@ -86,10 +97,10 @@ static __global__ void cumsum_kernel( const int warps_per_block = blockDim.x / warp_size; extern __shared__ float smem[]; - float * s_vals = smem; - float * s_warp_sums = smem + blockDim.x; - float * s_carry = smem + blockDim.x + warps_per_block; - float * s_chunk_total = s_carry + 1; + float * s_vals = smem; + float * s_warp_sums = smem + blockDim.x; + float * s_carry = smem + blockDim.x + warps_per_block; + float * s_chunk_total = s_carry + 1; // Initialize carry if (tid == 0) { @@ -107,21 +118,39 @@ static __global__ void cumsum_kernel( const T * src_row = src + i1 * s01 + i2 * s02 + i3 * s03; T * dst_row = dst + i1 * s1 + i2 * s2 + i3 * s3; - for (int64_t start = 0; start < ne00; start += blockDim.x) { - int64_t idx = start + tid; - float val = (idx < ne00) ? ggml_cuda_cast(src_row[idx]) : 0.0f; + // register blocking: process 4 elements per thread to hide latency + // and reduce synchronization overhead + constexpr int num_unroll = 4; + T temp[num_unroll]; - // 1. Warp inclusive scan + for (int64_t i = 0; i < ne00; i += num_unroll * blockDim.x) { + int64_t idx = i + tid * num_unroll; + + // thread local sequential scan + temp[0] = (idx < ne00 ? src_row[idx] : T(0)); +#pragma unroll + for (int64_t j = 1; j < num_unroll; j++) { + temp[j] = temp[j - 1]; + if (idx + j < ne00) { + temp[j] += src_row[idx + j]; + } else { + temp[j] += 0; + } + } + + // last emenent is sum of all values assigned to thread + float val = (idx < ne00) ? ggml_cuda_cast(temp[num_unroll - 1]) : 0.0f; + + // Warp inclusive scan val = warp_prefix_inclusive_sum(val); s_vals[tid] = val; - // Store warp total if (lane == warp_size - 1) { s_warp_sums[warp] = val; } __syncthreads(); - // 2. Exclusive scan of warp sums (warp 0 only) + // Exclusive scan of warp sums (warp 0 only) if (warp == 0) { float w = (tid < warps_per_block) ? s_warp_sums[tid] : 0.0f; float inc = warp_prefix_inclusive_sum(w); @@ -134,12 +163,17 @@ static __global__ void cumsum_kernel( } __syncthreads(); + // write back results float carry = *s_carry; - float final_val = s_vals[tid] + s_warp_sums[warp] + carry; - if (idx < ne00) { - dst_row[idx] = ggml_cuda_cast(final_val); + // calculate sum offset for this thread + float final_val_offset = s_vals[tid] + s_warp_sums[warp] + carry - temp[num_unroll - 1]; + +#pragma unroll + for (int32_t j = 0; j < num_unroll; j++) { + if (idx + j < ne00) { + dst_row[idx + j] = temp[j] + ggml_cuda_cast(final_val_offset); + } } - __syncthreads(); // Update carry for next chunk if (tid == 0) { @@ -177,7 +211,7 @@ static void cumsum_cuda( const int warps_per_block = block_size / warp_size; const size_t shmem_size = (block_size + warps_per_block + 2) * sizeof(float); - if (use_cub) { + if (use_cub && ne00 >= 1024) { cumsum_cub_kernel<<>>( src, dst, ne00, ne01, ne02, ne03, diff --git a/ggml/src/ggml-cuda/mma.cuh b/ggml/src/ggml-cuda/mma.cuh index 3268dadfe8..df9eed7117 100644 --- a/ggml/src/ggml-cuda/mma.cuh +++ b/ggml/src/ggml-cuda/mma.cuh @@ -900,6 +900,27 @@ namespace ggml_cuda_mma { #endif // AMPERE_MMA_AVAILABLE } + static __device__ __forceinline__ void mma_block_scaled(tile<16, 8, float> & D, + const tile<16, 8, int> & A, + const tile<8, 8, int> & B, + uint32_t a_scale, + uint32_t b_scale) { +#ifdef BLACKWELL_MMA_AVAILABLE + const int * Axi = (const int *) A.x; + const int * Bxi = (const int *) B.x; + float * Dxi = (float *) D.x; + + asm volatile( + "mma.sync.aligned.kind::mxf4.block_scale.scale_vec::2X.m16n8k64.row.col.f32.e2m1.e2m1.f32.ue8m0 " + "{%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9}, {%0, %1, %2, %3}, " + "%10, {0, 0}, %11, {0, 0};" + : "+f"(Dxi[0]), "+f"(Dxi[1]), "+f"(Dxi[2]), "+f"(Dxi[3]) + : "r"(Axi[0]), "r"(Axi[1]), "r"(Axi[2]), "r"(Axi[3]), "r"(Bxi[0]), "r"(Bxi[1]), "r"(a_scale), "r"(b_scale)); +#else + GGML_UNUSED_VARS(D, A, B, a_scale, b_scale); +#endif // BLACKWELL_MMA_AVAILABLE + } + static __device__ __forceinline__ void mma( tile<16, 8, float> & D, const tile<16, 8, half2> & A, const tile<8, 8, half2> & B) { #ifdef TURING_MMA_AVAILABLE diff --git a/ggml/src/ggml-cuda/mmq.cu b/ggml/src/ggml-cuda/mmq.cu index f7a2cbca90..6156dcdae7 100644 --- a/ggml/src/ggml-cuda/mmq.cu +++ b/ggml/src/ggml-cuda/mmq.cu @@ -1,3 +1,4 @@ +#include "common.cuh" #include "mmq.cuh" #include "quantize.cuh" #include "mmid.cuh" @@ -114,6 +115,9 @@ void ggml_cuda_mul_mat_q( const bool use_stream_k = (GGML_CUDA_CC_IS_NVIDIA(cc) && ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_VOLTA) || GGML_CUDA_CC_IS_CDNA(cc); + // TODO: tighter pool buffer size vs q8 path + const bool use_native_mxfp4 = blackwell_mma_available(cc) && src0->type == GGML_TYPE_MXFP4; + if (!ids) { const size_t nbytes_src1_q8_1 = ne13*ne12 * ne11*ne10_padded * sizeof(block_q8_1)/QK8_1 + get_mmq_x_max_host(cc)*sizeof(block_q8_1_mmq); @@ -123,12 +127,24 @@ void ggml_cuda_mul_mat_q( const int64_t s11 = src1->nb[1] / ts_src1; const int64_t s12 = src1->nb[2] / ts_src1; const int64_t s13 = src1->nb[3] / ts_src1; - quantize_mmq_q8_1_cuda(src1_d, nullptr, src1_q8_1.get(), src0->type, - ne10, s11, s12, s13, ne10_padded, ne11, ne12, ne13, stream); + if (use_native_mxfp4) { + static_assert(sizeof(block_fp4_mmq) == 4 * sizeof(block_q8_1)); + quantize_mmq_mxfp4_cuda(src1_d, nullptr, src1_q8_1.get(), src0->type, ne10, s11, s12, s13, ne10_padded, + ne11, ne12, ne13, stream); + + } else { + quantize_mmq_q8_1_cuda(src1_d, nullptr, src1_q8_1.get(), src0->type, ne10, s11, s12, s13, ne10_padded, + ne11, ne12, ne13, stream); + } CUDA_CHECK(cudaGetLastError()); } - const int64_t s12 = ne11*ne10_padded * sizeof(block_q8_1)/(QK8_1*sizeof(int)); + // Stride depends on quantization format + const int64_t s12 = use_native_mxfp4 ? + ne11 * ne10_padded * sizeof(block_fp4_mmq) / + (8 * QK_MXFP4 * sizeof(int)) // block_fp4_mmq holds 256 values (8 blocks of 32) + : + ne11 * ne10_padded * sizeof(block_q8_1) / (QK8_1 * sizeof(int)); const int64_t s13 = ne12*s12; const mmq_args args = { @@ -175,12 +191,19 @@ void ggml_cuda_mul_mat_q( const int64_t s11 = src1->nb[1] / ts_src1; const int64_t s12 = src1->nb[2] / ts_src1; const int64_t s13 = src1->nb[2] / ts_src1; - quantize_mmq_q8_1_cuda(src1_d, ids_src1.get(), src1_q8_1.get(), src0->type, - ne10, s11, s12, s13, ne10_padded, ne11_flat, ne12_flat, ne13_flat, stream); + + if (use_native_mxfp4) { + quantize_mmq_mxfp4_cuda(src1_d, ids_src1.get(), src1_q8_1.get(), src0->type, ne10, s11, s12, s13, + ne10_padded, ne11_flat, ne12_flat, ne13_flat, stream); + } else { + quantize_mmq_q8_1_cuda(src1_d, ids_src1.get(), src1_q8_1.get(), src0->type, ne10, s11, s12, s13, + ne10_padded, ne11_flat, ne12_flat, ne13_flat, stream); + } CUDA_CHECK(cudaGetLastError()); } - const int64_t s12 = ne11*ne10_padded * sizeof(block_q8_1)/(QK8_1*sizeof(int)); + const int64_t s12 = use_native_mxfp4 ? ne11 * ne10_padded * sizeof(block_fp4_mmq) / (8 * QK_MXFP4 * sizeof(int)) : + ne11 * ne10_padded * sizeof(block_q8_1) / (QK8_1 * sizeof(int)); const int64_t s13 = ne12*s12; // Note that ne02 is used instead of ne12 because the number of y channels determines the z dimension of the CUDA grid. diff --git a/ggml/src/ggml-cuda/mmq.cuh b/ggml/src/ggml-cuda/mmq.cuh index fa8a72c9c1..63451ffab7 100644 --- a/ggml/src/ggml-cuda/mmq.cuh +++ b/ggml/src/ggml-cuda/mmq.cuh @@ -11,6 +11,7 @@ using namespace ggml_cuda_mma; #define MMQ_DP4A_MAX_BATCH_SIZE 64 // Max. batch size to use for dp4a MMQ kernels when FP16 tensor cores are available. #define MMQ_ITER_K 256 +#define MMQ_ITER_K_MXFP4_FP4 512 #define MMQ_NWARPS 8 typedef void (*load_tiles_mmq_t)(const char * __restrict__ x, int * x_tile, const int kbx0, const int i_max, const int stride); @@ -44,8 +45,15 @@ struct block_q8_1_mmq { }; int8_t qs[4*QK8_1]; // 128 values quantized to 8 bit each }; + +struct block_fp4_mmq { + uint32_t d4[4]; // 8 E8M0 scales (1 per 32 values), 2 packed per uint32: d4[0]={s0,s1}, d4[1]={s2,s3}, etc. + int8_t qs[4 * 32]; // 256 FP4 values packed as 4-bit pairs (2 per byte), 8 blocks of 32 values +}; + static_assert(sizeof(block_q8_1_mmq) == 4*QK8_1 + 4*sizeof(half2), "Unexpected block_q8_1_mmq size"); static_assert(sizeof(block_q8_1_mmq) == 4*sizeof(block_q8_1), "Unexpected block_q8_1_mmq size"); +static_assert(sizeof(block_fp4_mmq) == sizeof(block_q8_1_mmq), "Unexpected block_fp4_mmq size"); static mmq_q8_1_ds_layout mmq_get_q8_1_ds_layout(const ggml_type type_x) { switch (type_x) { @@ -129,6 +137,14 @@ static int get_mmq_y_host(const int cc) { ((GGML_CUDA_CC_IS_NVIDIA(cc) && ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_VOLTA) ? 128 : 64); } +static constexpr __device__ int get_iter_k([[maybe_unused]] const ggml_type type) { +#if defined(BLACKWELL_MMA_AVAILABLE) + return type == GGML_TYPE_MXFP4 ? MMQ_ITER_K_MXFP4_FP4 : MMQ_ITER_K; +#else + return MMQ_ITER_K; +#endif // defined(BLACKWELL_MMA_AVAILABLE) +} + static constexpr __device__ int get_mmq_y_device() { #if defined(GGML_USE_HIP) #if defined(RDNA1) @@ -191,6 +207,7 @@ static constexpr __host__ __device__ tile_x_sizes mmq_get_dp4a_tile_x_sizes(ggml } #define MMQ_MMA_TILE_X_K_Q8_0 (2*MMQ_TILE_NE_K + 2*MMQ_TILE_NE_K/QI8_0 + 4) +#define MMQ_MMA_TILE_X_K_FP4 (2*MMQ_TILE_NE_K + 8 + 4) #define MMQ_MMA_TILE_X_K_Q8_1 (2*MMQ_TILE_NE_K + 2*MMQ_TILE_NE_K/QI8_0 + 4) #define MMQ_MMA_TILE_X_K_Q2_K (2*MMQ_TILE_NE_K + MMQ_TILE_NE_K + 4) #define MMQ_MMA_TILE_X_K_Q3_K (2*MMQ_TILE_NE_K + MMQ_TILE_NE_K/2 + 4) @@ -201,6 +218,8 @@ static_assert(MMQ_MMA_TILE_X_K_Q8_1 % 8 == 4, "Wrong padding."); static_assert(MMQ_MMA_TILE_X_K_Q2_K % 8 == 4, "Wrong padding."); static_assert(MMQ_MMA_TILE_X_K_Q3_K % 8 == 4, "Wrong padding."); static_assert(MMQ_MMA_TILE_X_K_Q6_K % 8 == 4, "Wrong padding."); +static_assert(MMQ_MMA_TILE_X_K_FP4 % 8 == 4, "Wrong padding."); +static_assert(MMQ_MMA_TILE_X_K_FP4 == MMQ_MMA_TILE_X_K_Q8_1, "Wrong tile size for MXFP4"); static constexpr __host__ __device__ int mmq_get_mma_tile_x_k(ggml_type type) { switch (type) { @@ -209,6 +228,7 @@ static constexpr __host__ __device__ int mmq_get_mma_tile_x_k(ggml_type type) { case GGML_TYPE_Q5_0: return MMQ_MMA_TILE_X_K_Q8_0; case GGML_TYPE_Q5_1: return MMQ_MMA_TILE_X_K_Q8_1; case GGML_TYPE_Q8_0: return MMQ_MMA_TILE_X_K_Q8_0; + // tile sizes are the same for Q8_1 and FP4 for blackwell case GGML_TYPE_MXFP4: return MMQ_MMA_TILE_X_K_Q8_1; case GGML_TYPE_Q2_K: return MMQ_MMA_TILE_X_K_Q2_K; case GGML_TYPE_Q3_K: return MMQ_MMA_TILE_X_K_Q3_K; @@ -228,7 +248,8 @@ static constexpr __host__ __device__ int mmq_get_mma_tile_x_k(ggml_type type) { } // block_q8_1_mmq has (128 8-bit ints == 32 32-bit ints + 4 32-bit scales) -#define MMQ_TILE_Y_K (MMQ_TILE_NE_K + MMQ_TILE_NE_K/QI8_1) +#define MMQ_TILE_Y_K (MMQ_TILE_NE_K + MMQ_TILE_NE_K / QI8_1) +#define MMQ_TILE_Y_FP4_K MMQ_TILE_Y_K static int mmq_get_granularity_host(const int mmq_x, const int cc) { if (amd_mfma_available(cc) || amd_wmma_available(cc)) { @@ -761,6 +782,50 @@ template static __device__ __forceinline__ void loa } } +template +static __device__ __forceinline__ void load_tiles_mxfp4_fp4(const char * __restrict__ x, + int * __restrict__ x_tile, + const int kbx0, + const int i_max, + const int stride) { + constexpr int nwarps = mmq_get_nwarps_device(); + constexpr int warp_size = ggml_cuda_get_physical_warp_size(); + + int * x_qs = (int *) x_tile; + uint32_t * x_sc = (uint32_t *) (x_qs + 2 * MMQ_TILE_NE_K); + + const int txi = threadIdx.x; + + constexpr int iter_k = get_iter_k(GGML_TYPE_MXFP4); + + constexpr int threads_per_row = iter_k / QK_MXFP4; // each thread processes 1 block + constexpr int rows_per_warp = warp_size / threads_per_row; + const int kbx = txi % threads_per_row; + const int row_in_warp = txi / threads_per_row; + +#pragma unroll + for (int i0 = 0; i0 < mmq_y; i0 += rows_per_warp * nwarps) { + int i = i0 + threadIdx.y * rows_per_warp + row_in_warp; + + if constexpr (need_check) { + i = min(i, i_max); + } + + const block_mxfp4 * bxi = (const block_mxfp4 *) x + kbx0 + i * stride + kbx; + + // quantize_mxfp4_mmq permutes nibbles to match the quantized format + const int k0 = kbx * 4; + memcpy(x_qs + i * MMQ_MMA_TILE_X_K_FP4 + k0, bxi->qs, 16); + + // Load E8M0 scales: pack 2 consecutive scales into one uint32 + if (kbx % 2 == 0) { + uint32_t e = bxi->e; + e |= ((bxi + 1)->e << 8); + x_sc[i * MMQ_MMA_TILE_X_K_FP4 + kbx / 2] = e; + } + } +} + template static __device__ __forceinline__ void vec_dot_q8_0_q8_1_dp4a( const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int k00) { @@ -931,6 +996,78 @@ static __device__ __forceinline__ void vec_dot_q8_0_q8_1_mma( #endif // defined(AMD_MFMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) } +template +static __device__ __forceinline__ void vec_dot_mxfp4_mxfp4_mma(const int * __restrict__ x, + const int * __restrict__ y, + float * __restrict__ sum, + const int k00) { + typedef tile<16, 8, int> tile_A; + typedef tile<8, 8, int> tile_B; + typedef tile<16, 8, float> tile_C; // Output is float for native scaled MMA + + constexpr int granularity = mmq_get_granularity_device(mmq_x); + constexpr int rows_per_warp = 2 * granularity; + constexpr int ntx = rows_per_warp / tile_C::I; // Number of x minitiles per warp. + + y += (threadIdx.y % ntx) * (tile_C::J * MMQ_TILE_Y_FP4_K); + + // Match layout from load_tiles_mxfp4_fp4 + const int * x_qs = (const int *) x; + const uint32_t * x_sc = (const uint32_t *) (x_qs + 2 * MMQ_TILE_NE_K); + const int * y_qs = (const int *) y + 4; + const uint32_t * y_sc = (const uint32_t *) y; + + // tile_A has a length of 64 logical values vs. 32 values in block_mxfp4 + tile_A A[ntx][MMQ_TILE_NE_K / (2 * QI_MXFP4)]; + uint32_t scaleA[ntx][MMQ_TILE_NE_K / (2 * QI_MXFP4)]; + + // Block scale + // Each thread has to point to a 4 byte scale value + // https://docs.nvidia.com/cuda/parallel-thread-execution/#warp-level-block-scaling + + const int i0 = (threadIdx.y / ntx) * rows_per_warp; + +#pragma unroll + for (int n = 0; n < ntx; ++n) { +#pragma unroll + for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += 2 * QI_MXFP4) { + const int k0 = k00 + k01; + + load_ldmatrix(A[n][k01 / (2 * QI_MXFP4)], x_qs + (i0 + n * tile_A::I) * MMQ_MMA_TILE_X_K_FP4 + k0, + MMQ_MMA_TILE_X_K_FP4); + + // based on block-scaling document, 2 threads in each quad need to supply to the scale value + const int tidx = threadIdx.x / 4 + (threadIdx.x % 2) * 8; + scaleA[n][k01 / (2 * QI_MXFP4)] = + *(x_sc + (i0 + n * tile_A::I + tidx) * MMQ_MMA_TILE_X_K_FP4 + k0 / (2 * QI_MXFP4)); + } + } + +#pragma unroll + for (int j0 = 0; j0 < mmq_x; j0 += ntx * tile_C::J) { +#pragma unroll + for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += 2 * QI_MXFP4) { + tile_B B; + uint32_t scaleB; // 2xN scales + + load_generic(B, y_qs + j0 * MMQ_TILE_Y_FP4_K + k01, MMQ_TILE_Y_FP4_K); + + scaleB = y_sc[(j0 + threadIdx.x / 4) * MMQ_TILE_Y_FP4_K + k01 / (2 * QI_MXFP4)]; + +#pragma unroll + for (int n = 0; n < ntx; ++n) { + tile_C C; + + mma_block_scaled(C, A[n][k01 / (2 * QI_MXFP4)], B, scaleA[n][k01 / (2 * QI_MXFP4)], scaleB); +#pragma unroll + for (int l = 0; l < tile_C::ne; ++l) { + sum[(j0 / tile_C::J + n) * tile_C::ne + l] += C.x[l]; + } + } + } + } +} + template static __device__ __forceinline__ void vec_dot_q8_1_q8_1_dp4a( const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int k00) { @@ -3109,8 +3246,13 @@ struct mmq_type_traits { template struct mmq_type_traits { static constexpr int vdr = VDR_MXFP4_Q8_1_MMQ; +#ifdef BLACKWELL_MMA_AVAILABLE + static constexpr load_tiles_mmq_t load_tiles = load_tiles_mxfp4_fp4; + static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_mxfp4_mxfp4_mma; +#else static constexpr load_tiles_mmq_t load_tiles = load_tiles_mxfp4; static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_0_q8_1_mma; +#endif // BLACKWELL_MMA_AVAILABLE static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q8_0_q8_1_dp4a; }; @@ -3243,17 +3385,26 @@ static __device__ __forceinline__ void mul_mat_q_process_tile( constexpr mmq_write_back_t write_back = mmq_write_back_dp4a; #endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) - constexpr int blocks_per_iter = MMQ_ITER_K / qk; +#if defined(BLACKWELL_MMA_AVAILABLE) + // FP4 tile stores 8 blocks + constexpr int ne_block = (type == GGML_TYPE_MXFP4) ? 8 * QK_MXFP4 : 4 * QK8_1; +#else + constexpr int ne_block = 4 * QK8_1; +#endif // defined(BLACKWELL_MMA_AVAILABLE) + + constexpr int ITER_K = get_iter_k(type); + constexpr int blocks_per_iter = ITER_K / qk; float sum[mmq_x*mmq_y / (nwarps*warp_size)] = {0.0f}; + constexpr int sz = sizeof(block_q8_1_mmq) / sizeof(int); + for (int kb0 = kb0_start; kb0 < kb0_stop; kb0 += blocks_per_iter) { load_tiles(x, tile_x, offset_x + kb0, tile_x_max_i, stride_row_x); - { - const int * by0 = y + ncols_y*(kb0*(qk*sizeof(block_q8_1_mmq) / (4*QK8_1*sizeof(int))) + 0*sizeof(block_q8_1_mmq)/sizeof(int)); + const int * by0 = y + ncols_y * (kb0 * qk / ne_block) * sz; #pragma unroll - for (int l0 = 0; l0 < mmq_x*MMQ_TILE_Y_K; l0 += nwarps*warp_size) { + for (int l0 = 0; l0 < mmq_x * MMQ_TILE_Y_K; l0 += nwarps * warp_size) { int l = l0 + threadIdx.y*warp_size + threadIdx.x; tile_y[l] = by0[l]; @@ -3267,9 +3418,9 @@ static __device__ __forceinline__ void mul_mat_q_process_tile( __syncthreads(); { - const int * by0 = y + ncols_y*(kb0*(qk*sizeof(block_q8_1_mmq) / (4*QK8_1*sizeof(int))) + 1*sizeof(block_q8_1_mmq)/sizeof(int)); + const int * by0 = y + ncols_y * ((kb0 * qk / ne_block) * sz + sz); #pragma unroll - for (int l0 = 0; l0 < mmq_x*MMQ_TILE_Y_K; l0 += nwarps*warp_size) { + for (int l0 = 0; l0 < mmq_x * MMQ_TILE_Y_K; l0 += nwarps * warp_size) { int l = l0 + threadIdx.y*warp_size + threadIdx.x; tile_y[l] = by0[l]; @@ -3401,8 +3552,10 @@ static __global__ void mul_mat_q( } #endif // (defined(GGML_USE_HIP) && !defined(CDNA3)) || __CUDA_ARCH__ < GGML_CUDA_CC_VOLTA + constexpr int ITER_K = get_iter_k(type); + const int64_t blocks_per_ne00 = ncols_x / qk; - constexpr int blocks_per_iter = MMQ_ITER_K / qk; + constexpr int blocks_per_iter = ITER_K / qk; // kbc == k block continuous, current index in continuous ijk space. int64_t kbc = (int64_t) blockIdx.x *nsamples_y*nchannels_y*ntx*nty*blocks_per_ne00 / gridDim.x; @@ -3463,7 +3616,7 @@ static __global__ void mul_mat_q( __syncthreads(); } - offset_y += (col_low + jt*mmq_x)*(sizeof(block_q8_1_mmq)/sizeof(int)); + offset_y += (col_low + jt * mmq_x) * (sizeof(block_q8_1_mmq) / sizeof(int)); offset_dst += it*mmq_y; const int tile_x_max_i = nrows_x - it*mmq_y - 1; @@ -3530,7 +3683,7 @@ static __global__ void mul_mat_q( __syncthreads(); } - offset_y += (col_low + jt*mmq_x)*(sizeof(block_q8_1_mmq)/sizeof(int)); + offset_y += (col_low + jt * mmq_x) * (sizeof(block_q8_1_mmq) / sizeof(int)); offset_dst += it*mmq_y; const int tile_x_max_i = nrows_x - it*mmq_y - 1; @@ -3553,7 +3706,9 @@ static __global__ void mul_mat_q_stream_k_fixup( const int ncols_max) { constexpr int mmq_y = get_mmq_y_device(); constexpr int qk = ggml_cuda_type_traits::qk; - constexpr int blocks_per_iter = MMQ_ITER_K / qk; + constexpr int ITER_K = get_iter_k(type); + + constexpr int blocks_per_iter = ITER_K / qk; const int64_t blocks_per_ne00 = ncols_x / qk; constexpr int nwarps = mmq_get_nwarps_device(); @@ -3711,7 +3866,7 @@ static size_t mmq_get_nbytes_shared(const int mmq_x, const int mmq_y, const int const int mmq_tile_x_k = mmq_get_mma_tile_x_k(type); const size_t nbs_ids = mmq_x*sizeof(int); const size_t nbs_x = (turing_mma_available(cc) || amd_mfma_available(cc) || amd_wmma_available(cc)) ? mmq_y*mmq_tile_x_k*sizeof(int) : txs.qs*sizeof(int) + txs.dm*sizeof(half2) + txs.sc*sizeof(int); - const size_t nbs_y = mmq_x*sizeof(block_q8_1_mmq); + const size_t nbs_y = mmq_x * (sizeof(block_q8_1_mmq)); return nbs_ids + nbs_x + GGML_PAD(nbs_y, nwarps*warp_size*sizeof(int)); } diff --git a/ggml/src/ggml-cuda/quantize.cu b/ggml/src/ggml-cuda/quantize.cu index 5117f9ffc0..a8c68e44b1 100644 --- a/ggml/src/ggml-cuda/quantize.cu +++ b/ggml/src/ggml-cuda/quantize.cu @@ -47,6 +47,131 @@ static __global__ void quantize_q8_1( y[ib].ds = make_half2(d, sum); } +__device__ __forceinline__ uint8_t compute_e8m0_scale(float amax) { + if (!(amax > 0.0f)) { + return 0; + } + + // FP4 E2M1: max exponent (unbiased) is 2. + constexpr int FP4_E2M1_EMAX = 2; + + const float e = log2f(amax); + + // "even" -> round-to-nearest integer, ties-to-even + const int e_int = __float2int_rn(e); + + const int shared_exp = e_int - FP4_E2M1_EMAX; + + int biased = shared_exp + 127; + + biased = max(biased, 0); + biased = min(biased, 254); + + return static_cast(biased); +} + +// quantize values in the format mxfp4 is stored which is interleaved nibbles +// i.e. a block a0-a31 is represented as a0a16,a1a17 ...a15a31 +static __global__ void quantize_mmq_mxfp4(const float * __restrict__ x, + const int32_t * __restrict__ ids, + void * __restrict__ vy, + const int64_t ne00, + const int64_t s01, + const int64_t s02, + const int64_t s03, + const int64_t ne0, + const int ne1, + const int ne2) { + constexpr int vals_per_scale = 32; + constexpr int vals_per_warp = 2 * vals_per_scale; // Each warp processes 2 blocks of 32 = 64 values + + const int warp_id = threadIdx.y; + const int lane_id_32 = threadIdx.x; + + const int nwarps = blockDim.y; + + const int64_t warp_start_offset = (blockIdx.y * nwarps + warp_id) * vals_per_warp; + + if (warp_start_offset >= ne0) { + return; + } + + const int64_t i1 = blockIdx.x; + const int64_t i2 = blockIdx.z % ne2; + const int64_t i3 = blockIdx.z / ne2; + + const int64_t i01 = ids ? ids[i1] : i1; + const int64_t i02 = i2; + const int64_t i03 = i3; + + block_fp4_mmq * y = (block_fp4_mmq *) vy; + + const int64_t block_fp4_mmq_size = 8 * QK_MXFP4; // 256 values + const int64_t ib0 = blockIdx.z * ((int64_t) ne1 * (ne0 / block_fp4_mmq_size)); + const int64_t ib = ib0 + (warp_start_offset / block_fp4_mmq_size) * ne1 + blockIdx.x; + const int64_t quad_idx_in_block = (warp_start_offset % block_fp4_mmq_size) / vals_per_warp; + + const int group_id = lane_id_32 / 4; + const int lane_in_group = lane_id_32 % 4; + const int base = group_id * 2; + char2 * yqs2 = (char2 *) y[ib].qs; + + const int64_t base_pos = i03 * s03 + i02 * s02 + i01 * s01; + + uint8_t scales[2]; + +#pragma unroll + for (int b = 0; b < 2; ++b) { + const int64_t i0 = warp_start_offset + b * vals_per_scale + lane_id_32; + const float xi = (i0 < ne00) ? x[base_pos + i0] : 0.0f; + + float amax = fabsf(xi); +#pragma unroll + for (int mask = 16; mask > 0; mask >>= 1) { + amax = fmaxf(amax, __shfl_xor_sync(0xFFFFFFFF, amax, mask, WARP_SIZE)); + } + + const uint8_t e = compute_e8m0_scale(amax); + scales[b] = e; + const float inv_s = (amax == 0.0f) ? 0.0f : __frcp_rn(ggml_cuda_e8m0_to_fp32(e)); + +#if CUDART_VERSION >= 12080 + const float scaled_val = xi * inv_s; + + const float val0 = __shfl_sync(0xFFFFFFFF, scaled_val, base, WARP_SIZE); + const float val1 = __shfl_sync(0xFFFFFFFF, scaled_val, base + 16, WARP_SIZE); + const float val2 = __shfl_sync(0xFFFFFFFF, scaled_val, base + 1, WARP_SIZE); + const float val3 = __shfl_sync(0xFFFFFFFF, scaled_val, base + 17, WARP_SIZE); + + if (lane_in_group == 0) { + __nv_fp4x4_e2m1 fp4_packed(make_float4(val0, val1, val2, val3)); + + yqs2[quad_idx_in_block * 16 + b * 8 + group_id] = *(char2 *) &fp4_packed; + } +#else + // Fallback: manual FP4 conversion using LUT + const uint8_t q_val = ggml_cuda_float_to_fp4_e2m1(xi, inv_s); + + const uint8_t q_lo_0 = __shfl_sync(0xFFFFFFFF, q_val, base, WARP_SIZE); + const uint8_t q_lo_1 = __shfl_sync(0xFFFFFFFF, q_val, base + 1, WARP_SIZE); + const uint8_t q_hi_0 = __shfl_sync(0xFFFFFFFF, q_val, base + 16, WARP_SIZE); + const uint8_t q_hi_1 = __shfl_sync(0xFFFFFFFF, q_val, base + 17, WARP_SIZE); + + if (lane_in_group == 0) { + char2 q; + q.x = (q_hi_0 << 4) | q_lo_0; + q.y = (q_hi_1 << 4) | q_lo_1; + yqs2[quad_idx_in_block * 16 + b * 8 + group_id] = q; + } +#endif // CUDART_VERSION >= 12080 + } + + if (lane_id_32 == 0) { + // Store 2 scales packed into 1 uint32 + y[ib].d4[quad_idx_in_block] = (scales[1] << 8) | scales[0]; + } +} + template static __global__ void quantize_mmq_q8_1( const float * __restrict__ x, const int32_t * __restrict__ ids, void * __restrict__ vy, @@ -190,3 +315,29 @@ void quantize_mmq_q8_1_cuda( break; } } + +void quantize_mmq_mxfp4_cuda(const float * x, + const int32_t * ids, + void * vy, + [[maybe_unused]] const ggml_type type_src0, + const int64_t ne00, + const int64_t s01, + const int64_t s02, + const int64_t s03, + const int64_t ne0, + const int64_t ne1, + const int64_t ne2, + const int64_t ne3, + cudaStream_t stream) { + GGML_ASSERT(ne0 % (2 * QK_MXFP4) == 0); + + constexpr int nwarps = 8; + constexpr int vals_per_warp = 2 * QK_MXFP4; + constexpr int vals_per_block = nwarps * vals_per_warp; + + const int64_t block_num_y = (ne0 + vals_per_block - 1) / vals_per_block; + const dim3 num_blocks(ne1, block_num_y, ne2 * ne3); + const dim3 block_size(WARP_SIZE, nwarps, 1); + + quantize_mmq_mxfp4<<>>(x, ids, vy, ne00, s01, s02, s03, ne0, ne1, ne2); +} diff --git a/ggml/src/ggml-cuda/quantize.cuh b/ggml/src/ggml-cuda/quantize.cuh index 725ab52443..6a91df6357 100644 --- a/ggml/src/ggml-cuda/quantize.cuh +++ b/ggml/src/ggml-cuda/quantize.cuh @@ -25,3 +25,17 @@ void quantize_mmq_q8_1_cuda( const float * x, const int32_t * ids, void * vy, ggml_type type_src0, int64_t ne00, int64_t s01, int64_t s02, int64_t s03, int64_t ne0, int64_t ne1, int64_t ne2, int64_t ne3, cudaStream_t stream); + +void quantize_mmq_mxfp4_cuda(const float * x, + const int32_t * ids, + void * vy, + ggml_type type_src0, + int64_t ne00, + int64_t s01, + int64_t s02, + int64_t s03, + int64_t ne0, + int64_t ne1, + int64_t ne2, + int64_t ne3, + cudaStream_t stream); diff --git a/ggml/src/ggml-cuda/vendors/cuda.h b/ggml/src/ggml-cuda/vendors/cuda.h index 3b3086778e..ba032cfab4 100644 --- a/ggml/src/ggml-cuda/vendors/cuda.h +++ b/ggml/src/ggml-cuda/vendors/cuda.h @@ -10,6 +10,10 @@ #include #endif // CUDART_VERSION >= 12050 +#if CUDART_VERSION >= 12080 +#include +#endif // CUDART_VERSION >= 12080 + #if CUDART_VERSION < 11020 #define CU_DEVICE_ATTRIBUTE_VIRTUAL_MEMORY_MANAGEMENT_SUPPORTED CU_DEVICE_ATTRIBUTE_VIRTUAL_ADDRESS_MANAGEMENT_SUPPORTED #define CUBLAS_TF32_TENSOR_OP_MATH CUBLAS_TENSOR_OP_MATH diff --git a/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ggml/src/ggml-vulkan/ggml-vulkan.cpp index a524adbe0c..1459b2608e 100644 --- a/ggml/src/ggml-vulkan/ggml-vulkan.cpp +++ b/ggml/src/ggml-vulkan/ggml-vulkan.cpp @@ -379,18 +379,18 @@ enum FaCodePath { }; struct vk_fa_pipeline_state { - vk_fa_pipeline_state(uint32_t HSK, uint32_t HSV, bool small_rows, FaCodePath path, bool aligned, bool f32acc) - : HSK(HSK), HSV(HSV), small_rows(small_rows), path(path), aligned(aligned), f32acc(f32acc) {} + vk_fa_pipeline_state(uint32_t HSK, uint32_t HSV, bool small_rows, bool small_cache, FaCodePath path, bool aligned, bool f32acc) + : HSK(HSK), HSV(HSV), small_rows(small_rows), small_cache(small_cache), path(path), aligned(aligned), f32acc(f32acc) {} uint32_t HSK, HSV; - bool small_rows; + bool small_rows, small_cache; FaCodePath path; bool aligned; bool f32acc; bool operator<(const vk_fa_pipeline_state &b) const { - return std::tie(HSK, HSV, small_rows, path, aligned, f32acc) < - std::tie(b.HSK, b.HSV, b.small_rows, b.path, b.aligned, b.f32acc); + return std::tie(HSK, HSV, small_rows, small_cache, path, aligned, f32acc) < + std::tie(b.HSK, b.HSV, b.small_rows, b.small_cache, b.path, b.aligned, b.f32acc); } }; @@ -2582,10 +2582,10 @@ static void ggml_vk_wait_events(vk_context& ctx, std::vector&& events static constexpr uint32_t flash_attention_num_small_rows = 32; static constexpr uint32_t scalar_flash_attention_num_small_rows = 1; -static uint32_t get_fa_scalar_num_large_rows(uint32_t hsk, uint32_t hsv) { +static uint32_t get_fa_scalar_num_large_rows(uint32_t hsk, uint32_t hsv, bool small_cache) { if (hsv >= 192) { return 2; - } else if ((hsv | hsk) & 8) { + } else if ((hsv | hsk) & 8 || small_cache) { return 4; } else { return 8; @@ -2607,9 +2607,8 @@ static uint32_t get_fa_num_small_rows(FaCodePath path) { } } -static std::array fa_rows_cols(FaCodePath path, uint32_t hsk, uint32_t hsv, uint32_t clamp, ggml_type type, bool small_rows) { +static std::array fa_rows_cols(FaCodePath path, uint32_t hsk, uint32_t hsv, uint32_t clamp, ggml_type type, bool small_rows, bool small_cache) { GGML_UNUSED(clamp); - GGML_UNUSED(hsv); if (path == FA_SCALAR) { if (small_rows) { @@ -2618,9 +2617,9 @@ static std::array fa_rows_cols(FaCodePath path, uint32_t hsk, uint3 if ((hsv | hsk) & 8) { // HSV/HSK not being a multiple of 16 makes D_split smaller, which makes cols_per_iter // larger, and Bc needs to be >= cols_per_thread. 64 is large enough, 32 is not. - return {get_fa_scalar_num_large_rows(hsk, hsv), 64}; + return {get_fa_scalar_num_large_rows(hsk, hsv, small_cache), 64}; } else { - return {get_fa_scalar_num_large_rows(hsk, hsv), 32}; + return {get_fa_scalar_num_large_rows(hsk, hsv, small_cache), 32}; } } } @@ -2649,8 +2648,8 @@ static std::array fa_rows_cols(FaCodePath path, uint32_t hsk, uint3 return {64, 64}; } -static uint32_t fa_align(FaCodePath path, uint32_t hsk, uint32_t hsv, ggml_type type, bool small_rows) { - return fa_rows_cols(path, hsk, hsv, 0, type, small_rows)[1]; +static uint32_t fa_align(FaCodePath path, uint32_t hsk, uint32_t hsv, ggml_type type, bool small_rows, bool small_cache) { + return fa_rows_cols(path, hsk, hsv, 0, type, small_rows, small_cache)[1]; } static bool ggml_vk_matmul_shmem_support(const vk_device& device, const std::vector& warptile, bool mul_mat_id, ggml_type src0_type) { @@ -2992,11 +2991,11 @@ static void ggml_vk_load_shaders(vk_device& device) { align, disable_robustness, require_full_subgroups, required_subgroup_size); }; - auto const &fa_wg_denoms = [&](FaCodePath path, uint32_t hsk, uint32_t hsv, uint32_t clamp, ggml_type type, bool small_rows) -> std::array { - return {fa_rows_cols(path, hsk, hsv, clamp, type, small_rows)[0], 1, 1}; + auto const &fa_wg_denoms = [&](FaCodePath path, uint32_t hsk, uint32_t hsv, uint32_t clamp, ggml_type type, bool small_rows, bool small_cache) -> std::array { + return {fa_rows_cols(path, hsk, hsv, clamp, type, small_rows, small_cache)[0], 1, 1}; }; - auto const &fa_spec_constants = [&](FaCodePath path, uint32_t hsk, uint32_t hsv, uint32_t clamp, ggml_type type, bool small_rows) -> std::vector { + auto const &fa_spec_constants = [&](FaCodePath path, uint32_t hsk, uint32_t hsv, uint32_t clamp, ggml_type type, bool small_rows, bool small_cache) -> std::vector { // For large number of rows, 128 invocations seems to work best. // For small number of rows (e.g. N==1), 256 works better. But matrix granularity for 256 is 32, so we // can't use 256 for D==80. @@ -3006,7 +3005,7 @@ static void ggml_vk_load_shaders(vk_device& device) { uint32_t wg_size = (path == FA_SCALAR || path == FA_COOPMAT1) ? scalar_flash_attention_workgroup_size : ((small_rows && (D % 32) == 0) ? 256 : 128); - auto rows_cols = fa_rows_cols(path, hsk, hsv, clamp, type, small_rows); + auto rows_cols = fa_rows_cols(path, hsk, hsv, clamp, type, small_rows, small_cache); // D_split can't be larger than a subgroup because we use subgroupShuffle to reduce it. // D_split can't be larger than the LSB of D divided by 4 due to vectorization in the shader. @@ -3021,21 +3020,22 @@ static void ggml_vk_load_shaders(vk_device& device) { uint32_t HSK = fa.first.HSK; \ uint32_t HSV = fa.first.HSV; \ bool small_rows = fa.first.small_rows; \ + bool small_cache = fa.first.small_cache; \ FaCodePath path = fa.first.path; \ bool aligned = fa.first.aligned; \ bool f32acc = fa.first.f32acc; \ if (path == FAPATH) { \ if (aligned) { \ if (f32acc) { \ - ggml_vk_create_pipeline(device, fa.second, "flash_attn_f32_f16_aligned_f32acc" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _data, "main", 6, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,0,TYPE,small_rows), fa_spec_constants(FAPATH, HSK,HSV,0,TYPE,small_rows), fa_align(FAPATH,HSK,HSV,TYPE,small_rows), true, true, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \ + ggml_vk_create_pipeline(device, fa.second, "flash_attn_f32_f16_aligned_f32acc" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _data, "main", 6, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,0,TYPE,small_rows,small_cache), fa_spec_constants(FAPATH, HSK,HSV,0,TYPE,small_rows,small_cache), fa_align(FAPATH,HSK,HSV,TYPE,small_rows,small_cache), true, true, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \ } else { \ - ggml_vk_create_pipeline(device, fa.second, "flash_attn_f32_f16_aligned_f16acc" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _data, "main", 6, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,0,TYPE,small_rows), fa_spec_constants(FAPATH, HSK,HSV,0,TYPE,small_rows), fa_align(FAPATH,HSK,HSV,TYPE,small_rows), true, true, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \ + ggml_vk_create_pipeline(device, fa.second, "flash_attn_f32_f16_aligned_f16acc" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _data, "main", 6, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,0,TYPE,small_rows,small_cache), fa_spec_constants(FAPATH, HSK,HSV,0,TYPE,small_rows,small_cache), fa_align(FAPATH,HSK,HSV,TYPE,small_rows,small_cache), true, true, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \ } \ } else { \ if (f32acc) { \ - ggml_vk_create_pipeline(device, fa.second, "flash_attn_f32_f16_f32acc" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _data, "main", 6, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,1,TYPE,small_rows), fa_spec_constants(FAPATH, HSK,HSV,1,TYPE,small_rows), 1, true, true, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \ + ggml_vk_create_pipeline(device, fa.second, "flash_attn_f32_f16_f32acc" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _data, "main", 6, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,1,TYPE,small_rows,small_cache), fa_spec_constants(FAPATH, HSK,HSV,1,TYPE,small_rows,small_cache), 1, true, true, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \ } else { \ - ggml_vk_create_pipeline(device, fa.second, "flash_attn_f32_f16_f16acc" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _data, "main", 6, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,1,TYPE,small_rows), fa_spec_constants(FAPATH, HSK,HSV,1,TYPE,small_rows), 1, true, true, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \ + ggml_vk_create_pipeline(device, fa.second, "flash_attn_f32_f16_f16acc" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _data, "main", 6, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,1,TYPE,small_rows,small_cache), fa_spec_constants(FAPATH, HSK,HSV,1,TYPE,small_rows,small_cache), 1, true, true, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \ } \ } \ } \ @@ -8008,11 +8008,11 @@ static void ggml_vk_mul_mat_id(ggml_backend_vk_context * ctx, vk_context& subctx } } -static bool ggml_vk_flash_attn_scalar_shmem_support(const vk_device& device, const uint32_t hsk, uint32_t hsv) { +static bool ggml_vk_flash_attn_scalar_shmem_support(const vk_device& device, const uint32_t hsk, uint32_t hsv, bool small_cache) { // Needs to be kept up to date on shader changes GGML_UNUSED(hsv); const uint32_t wg_size = scalar_flash_attention_workgroup_size; - const uint32_t Br = get_fa_scalar_num_large_rows(hsk, hsv); + const uint32_t Br = get_fa_scalar_num_large_rows(hsk, hsv, small_cache); const uint32_t Bc = scalar_flash_attention_Bc; const uint32_t tmpsh = wg_size * sizeof(float); @@ -8136,6 +8136,8 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx uint32_t workgroups_y = (uint32_t)neq2; uint32_t workgroups_z = (uint32_t)neq3; + const bool small_cache = nek1 < 1024; + // For scalar/coopmat1 FA, we can use the "large" size to accommodate qga. // For coopmat2 FA, we always use the small size (which is still pretty large for gqa). uint32_t max_gqa; @@ -8143,7 +8145,7 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx case FA_SCALAR: case FA_COOPMAT1: // We may switch from coopmat1 to scalar, so use the scalar limit for both - max_gqa = get_fa_scalar_num_large_rows(HSK, HSV); + max_gqa = get_fa_scalar_num_large_rows(HSK, HSV, small_cache); break; case FA_COOPMAT2: max_gqa = get_fa_num_small_rows(FA_COOPMAT2); @@ -8177,7 +8179,7 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx // with large hsk/hsv, scalar path may need to use small_rows to fit in shared memory if (path == FA_SCALAR && - !ggml_vk_flash_attn_scalar_shmem_support(ctx->device, HSK, HSV)) { + !ggml_vk_flash_attn_scalar_shmem_support(ctx->device, HSK, HSV, small_cache)) { small_rows = true; } @@ -8193,7 +8195,7 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx v_stride /= 4; } - uint32_t alignment = fa_align(path, HSK, HSV, k->type, small_rows); + uint32_t alignment = fa_align(path, HSK, HSV, k->type, small_rows, small_cache); bool aligned = (KV % alignment) == 0 && // the "aligned" shader variant will forcibly align strides, for performance (q_stride & 7) == 0 && (k_stride & 7) == 0 && (v_stride & 7) == 0; @@ -8205,7 +8207,7 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx bool f32acc = path == FA_SCALAR || dst->op_params[3] == GGML_PREC_F32; - vk_fa_pipeline_state fa_pipeline_state(HSK, HSV, small_rows, path, aligned, f32acc); + vk_fa_pipeline_state fa_pipeline_state(HSK, HSV, small_rows, small_cache, path, aligned, f32acc); vk_pipeline pipeline = nullptr; @@ -13716,6 +13718,7 @@ static void ggml_vk_graph_optimize(ggml_backend_t backend, struct ggml_cgraph * } static void ggml_backend_vk_event_record(ggml_backend_t backend, ggml_backend_event_t event) { + VK_LOG_DEBUG("ggml_backend_vk_event_record(backend=" << backend << ", event=" << event << ")"); ggml_backend_vk_context * ctx = (ggml_backend_vk_context *)backend->context; vk_event *vkev = (vk_event *)event->context; @@ -13745,6 +13748,7 @@ static void ggml_backend_vk_event_record(ggml_backend_t backend, ggml_backend_ev } static void ggml_backend_vk_event_wait(ggml_backend_t backend, ggml_backend_event_t event) { + VK_LOG_DEBUG("ggml_backend_vk_event_wait(backend=" << backend << ", event=" << event << ")"); ggml_backend_vk_context * ctx = (ggml_backend_vk_context *)backend->context; vk_event *vkev = (vk_event *)event->context; @@ -13760,6 +13764,8 @@ static void ggml_backend_vk_event_wait(ggml_backend_t backend, ggml_backend_even } ggml_vk_wait_events(transfer_ctx, {vkev->event}); + ggml_vk_ctx_end(transfer_ctx); + ctx->transfer_ctx.reset(); } // TODO: enable async and synchronize @@ -14519,6 +14525,7 @@ static void ggml_backend_vk_device_event_free(ggml_backend_dev_t dev, ggml_backe } static void ggml_backend_vk_device_event_synchronize(ggml_backend_dev_t dev, ggml_backend_event_t event) { + VK_LOG_DEBUG("ggml_backend_vk_device_event_synchronize(backend=" << dev << ", event=" << event << ")"); ggml_backend_vk_device_context * ctx = (ggml_backend_vk_device_context *)dev->context; auto device = ggml_vk_get_device(ctx->device); vk_event *vkev = (vk_event *)event->context; diff --git a/gguf-py/gguf/constants.py b/gguf-py/gguf/constants.py index 41d3bd4faf..27578daaf9 100644 --- a/gguf-py/gguf/constants.py +++ b/gguf-py/gguf/constants.py @@ -449,6 +449,8 @@ class MODEL_ARCH(IntEnum): RND1 = auto() PANGU_EMBED = auto() MISTRAL3 = auto() + MIMO2 = auto() + LLAMA_EMBED = auto() class VISION_PROJECTOR_TYPE(IntEnum): @@ -844,6 +846,8 @@ MODEL_ARCH_NAMES: dict[MODEL_ARCH, str] = { MODEL_ARCH.RND1: "rnd1", MODEL_ARCH.PANGU_EMBED: "pangu-embedded", MODEL_ARCH.MISTRAL3: "mistral3", + MODEL_ARCH.MIMO2: "mimo2", + MODEL_ARCH.LLAMA_EMBED: "llama-embed", } VISION_PROJECTOR_TYPE_NAMES: dict[VISION_PROJECTOR_TYPE, str] = { @@ -3196,6 +3200,46 @@ MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = { MODEL_TENSOR.FFN_DOWN_EXP, MODEL_TENSOR.FFN_UP_EXP, ], + MODEL_ARCH.MIMO2: [ + MODEL_TENSOR.TOKEN_EMBD, + MODEL_TENSOR.OUTPUT_NORM, + MODEL_TENSOR.OUTPUT, + MODEL_TENSOR.ATTN_NORM, + MODEL_TENSOR.ATTN_Q, + MODEL_TENSOR.ATTN_K, + MODEL_TENSOR.ATTN_V, + MODEL_TENSOR.ATTN_SINKS, + MODEL_TENSOR.ATTN_OUT, + MODEL_TENSOR.FFN_NORM, + MODEL_TENSOR.FFN_GATE, + MODEL_TENSOR.FFN_DOWN, + MODEL_TENSOR.FFN_UP, + MODEL_TENSOR.FFN_GATE_INP, + MODEL_TENSOR.FFN_GATE_EXP, + MODEL_TENSOR.FFN_DOWN_EXP, + MODEL_TENSOR.FFN_UP_EXP, + MODEL_TENSOR.FFN_EXP_PROBS_B, + ], + MODEL_ARCH.LLAMA_EMBED: [ + MODEL_TENSOR.TOKEN_EMBD, + MODEL_TENSOR.OUTPUT_NORM, + MODEL_TENSOR.OUTPUT, + MODEL_TENSOR.ROPE_FREQS, + MODEL_TENSOR.ATTN_NORM, + MODEL_TENSOR.ATTN_Q, + MODEL_TENSOR.ATTN_K, + MODEL_TENSOR.ATTN_V, + MODEL_TENSOR.ATTN_OUT, + MODEL_TENSOR.ATTN_ROT_EMBD, + MODEL_TENSOR.FFN_GATE_INP, + MODEL_TENSOR.FFN_NORM, + MODEL_TENSOR.FFN_GATE, + MODEL_TENSOR.FFN_DOWN, + MODEL_TENSOR.FFN_UP, + MODEL_TENSOR.FFN_GATE_EXP, + MODEL_TENSOR.FFN_DOWN_EXP, + MODEL_TENSOR.FFN_UP_EXP, + ], # TODO } diff --git a/gguf-py/gguf/tensor_mapping.py b/gguf-py/gguf/tensor_mapping.py index 276720fcde..1690d991f2 100644 --- a/gguf-py/gguf/tensor_mapping.py +++ b/gguf-py/gguf/tensor_mapping.py @@ -320,6 +320,7 @@ class TensorNameMap: MODEL_TENSOR.ATTN_SINKS: ( "model.layers.{bid}.self_attn.sinks", # openai-moe + "model.layers.{bid}.self_attn.attention_sink_bias", # mimov2 ), MODEL_TENSOR.ATTN_GATE: ( diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index 4ca8974916..1e155534bd 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -88,6 +88,7 @@ add_library(llama models/llama-iswa.cpp models/llama.cpp models/mamba.cpp + models/mimo2-iswa.cpp models/minicpm3.cpp models/minimax-m2.cpp models/modern-bert.cpp diff --git a/src/llama-arch.cpp b/src/llama-arch.cpp index 80f44ae1bf..75013d8d33 100644 --- a/src/llama-arch.cpp +++ b/src/llama-arch.cpp @@ -115,6 +115,8 @@ static const std::map LLM_ARCH_NAMES = { { LLM_ARCH_RND1, "rnd1" }, { LLM_ARCH_PANGU_EMBED, "pangu-embedded" }, { LLM_ARCH_MISTRAL3, "mistral3" }, + { LLM_ARCH_MIMO2, "mimo2" }, + { LLM_ARCH_LLAMA_EMBED, "llama-embed" }, { LLM_ARCH_UNKNOWN, "(unknown)" }, }; @@ -500,6 +502,7 @@ static std::set llm_get_tensor_names(llm_arch arch) { case LLM_ARCH_LLAMA: case LLM_ARCH_DECI: case LLM_ARCH_MISTRAL3: + case LLM_ARCH_LLAMA_EMBED: return { LLM_TENSOR_TOKEN_EMBD, LLM_TENSOR_OUTPUT_NORM, @@ -2188,6 +2191,27 @@ static std::set llm_get_tensor_names(llm_arch arch) { LLM_TENSOR_VISEXP_FFN_DOWN, LLM_TENSOR_VISEXP_FFN_UP, }; + case LLM_ARCH_MIMO2: + return { + LLM_TENSOR_TOKEN_EMBD, + LLM_TENSOR_OUTPUT_NORM, + LLM_TENSOR_OUTPUT, + LLM_TENSOR_ATTN_NORM, + LLM_TENSOR_ATTN_Q, + LLM_TENSOR_ATTN_K, + LLM_TENSOR_ATTN_V, + LLM_TENSOR_ATTN_SINKS, + LLM_TENSOR_ATTN_OUT, + LLM_TENSOR_FFN_NORM, + LLM_TENSOR_FFN_GATE, + LLM_TENSOR_FFN_DOWN, + LLM_TENSOR_FFN_UP, + LLM_TENSOR_FFN_GATE_INP, + LLM_TENSOR_FFN_GATE_EXPS, + LLM_TENSOR_FFN_DOWN_EXPS, + LLM_TENSOR_FFN_UP_EXPS, + LLM_TENSOR_FFN_EXP_PROBS_B, + }; case LLM_ARCH_GPTJ: case LLM_ARCH_UNKNOWN: return { diff --git a/src/llama-arch.h b/src/llama-arch.h index a53bc39d18..27bdedc83c 100644 --- a/src/llama-arch.h +++ b/src/llama-arch.h @@ -119,6 +119,8 @@ enum llm_arch { LLM_ARCH_RND1, LLM_ARCH_PANGU_EMBED, LLM_ARCH_MISTRAL3, + LLM_ARCH_MIMO2, + LLM_ARCH_LLAMA_EMBED, LLM_ARCH_UNKNOWN, }; diff --git a/src/llama-hparams.h b/src/llama-hparams.h index f6e95b5d2a..42def73f06 100644 --- a/src/llama-hparams.h +++ b/src/llama-hparams.h @@ -123,10 +123,11 @@ struct llama_hparams { llama_swa_type swa_type = LLAMA_SWA_TYPE_NONE; // the size of the sliding window (0 - no SWA) uint32_t n_swa = 0; - // if swa_layers[il] == true, then layer il is SWA - // if swa_layers[il] == false, then layer il is dense (i.e. non-SWA) + // if swa_layers[il] == 1, then layer il is SWA + // if swa_layers[il] == 0, then layer il is dense (i.e. non-SWA) // by default, all layers are dense - std::array swa_layers; + // note: using uint32_t type for compatibility reason + std::array swa_layers; // for State Space Models uint32_t ssm_d_conv = 0; diff --git a/src/llama-model.cpp b/src/llama-model.cpp index 0d5bcc64fe..69075742c9 100644 --- a/src/llama-model.cpp +++ b/src/llama-model.cpp @@ -130,6 +130,7 @@ const char * llm_type_name(llm_type type) { case LLM_TYPE_230B_A10B: return "230B.A10B"; case LLM_TYPE_235B_A22B: return "235B.A22B"; case LLM_TYPE_300B_A47B: return "300B.A47B"; + case LLM_TYPE_310B_A15B: return "310B.A15B"; case LLM_TYPE_355B_A32B: return "355B.A32B"; case LLM_TYPE_E2B: return "E2B"; case LLM_TYPE_E4B: return "E4B"; @@ -606,7 +607,7 @@ void llama_model::load_hparams(llama_model_loader & ml) { ml.get_key(LLM_KV_ROPE_DIMENSION_COUNT, hparams.n_rot, false); - if (arch == LLM_ARCH_LLAMA || arch == LLM_ARCH_DECI || arch == LLM_ARCH_FALCON) { + if (arch == LLM_ARCH_LLAMA || arch == LLM_ARCH_DECI || arch == LLM_ARCH_FALCON || arch == LLM_ARCH_LLAMA_EMBED) { if (hparams.n_rot != hparams.n_embd_head_k) { throw std::runtime_error(format("invalid n_rot: %u, expected %u", hparams.n_rot, hparams.n_embd_head_k)); } @@ -630,6 +631,7 @@ void llama_model::load_hparams(llama_model_loader & ml) { // arch-specific KVs switch (arch) { case LLM_ARCH_LLAMA: + case LLM_ARCH_LLAMA_EMBED: { ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); @@ -2338,6 +2340,22 @@ void llama_model::load_hparams(llama_model_loader & ml) { default: type = LLM_TYPE_UNKNOWN; } } break; + case LLM_ARCH_MIMO2: + { + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + + hparams.swa_type = LLAMA_SWA_TYPE_STANDARD; + + ml.get_key(LLM_KV_EXPERT_FEED_FORWARD_LENGTH, hparams.n_ff_exp); + ml.get_key(LLM_KV_ATTENTION_SLIDING_WINDOW, hparams.n_swa); + ml.get_key(LLM_KV_ROPE_FREQ_BASE_SWA, hparams.rope_freq_base_train_swa); + ml.get_key_or_arr(LLM_KV_ATTENTION_SLIDING_WINDOW_PATTERN, hparams.swa_layers, hparams.n_layer); + + switch (hparams.n_layer) { + case 48: type = LLM_TYPE_310B_A15B; break; + default: type = LLM_TYPE_UNKNOWN; + } + } break; default: throw std::runtime_error("unsupported model architecture"); } @@ -2652,6 +2670,7 @@ bool llama_model::load_tensors(llama_model_loader & ml) { case LLM_ARCH_GRANITE: case LLM_ARCH_GRANITE_MOE: case LLM_ARCH_MISTRAL3: + case LLM_ARCH_LLAMA_EMBED: { tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); @@ -6646,6 +6665,44 @@ bool llama_model::load_tensors(llama_model_loader & ml) { layer.ffn_down_shexp = create_tensor(tn(LLM_TENSOR_FFN_DOWN_SHEXP, "weight", i), { hparams.n_ff_shexp, n_embd }, 0); } } break; + case LLM_ARCH_MIMO2: + { + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); + + // output + output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); + output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, 0); + + for (int i = 0; i < n_layer; ++i) { + auto & layer = layers[i]; + uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(i); + uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(i); + uint32_t n_head = hparams.n_head(i); + + layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), { n_embd, n_embd_head_k * n_head }, 0); + layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), { n_embd, n_embd_k_gqa }, 0); + layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), { n_embd, n_embd_v_gqa }, 0); + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), { n_embd_head_v * n_head, n_embd }, 0); + + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); + layer.attn_sinks = create_tensor(tn(LLM_TENSOR_ATTN_SINKS, "weight", i), {n_head}, TENSOR_NOT_REQUIRED); + + layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); + + // non-MoE branch + layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, TENSOR_NOT_REQUIRED); + layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, TENSOR_NOT_REQUIRED); + layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, TENSOR_NOT_REQUIRED); + + // MoE branch + int64_t n_ff_exp = hparams.n_ff_exp; + layer.ffn_gate_inp = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), {n_embd, n_expert}, TENSOR_NOT_REQUIRED); + layer.ffn_gate_exps = create_tensor(tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i), {n_embd, n_ff_exp, n_expert}, TENSOR_NOT_REQUIRED); + layer.ffn_down_exps = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), {n_ff_exp, n_embd, n_expert}, TENSOR_NOT_REQUIRED); + layer.ffn_up_exps = create_tensor(tn(LLM_TENSOR_FFN_UP_EXPS, "weight", i), {n_embd, n_ff_exp, n_expert}, TENSOR_NOT_REQUIRED); + layer.ffn_exp_probs_b = create_tensor(tn(LLM_TENSOR_FFN_EXP_PROBS_B, "bias", i), {n_expert}, TENSOR_NOT_REQUIRED); + } + } break; default: throw std::runtime_error("unknown architecture"); } @@ -7269,16 +7326,20 @@ ggml_cgraph * llama_model::build_graph(const llm_graph_params & params) const { switch (arch) { case LLM_ARCH_LLAMA: { - llm = std::make_unique(*this, params); + llm = std::make_unique>(*this, params); } break; case LLM_ARCH_LLAMA4: { if (hparams.swa_type == LLAMA_SWA_TYPE_NONE) { - llm = std::make_unique(*this, params); + llm = std::make_unique>(*this, params); } else { llm = std::make_unique(*this, params); } } break; + case LLM_ARCH_LLAMA_EMBED: + { + llm = std::make_unique>(*this, params); + } break; case LLM_ARCH_DECI: { llm = std::make_unique(*this, params); @@ -7704,6 +7765,10 @@ ggml_cgraph * llama_model::build_graph(const llm_graph_params & params) const { { llm = std::make_unique(*this, params); } break; + case LLM_ARCH_MIMO2: + { + llm = std::make_unique(*this, params); + } break; default: GGML_ABORT("fatal error"); } @@ -7874,6 +7939,7 @@ llama_rope_type llama_model_rope_type(const llama_model * model) { case LLM_ARCH_ERNIE4_5: case LLM_ARCH_ERNIE4_5_MOE: case LLM_ARCH_MISTRAL3: + case LLM_ARCH_LLAMA_EMBED: return LLAMA_ROPE_TYPE_NORM; // the pairs of head values are offset by n_rot/2 @@ -7933,6 +7999,7 @@ llama_rope_type llama_model_rope_type(const llama_model * model) { case LLM_ARCH_PANGU_EMBED: case LLM_ARCH_AFMOE: case LLM_ARCH_QWEN3NEXT: + case LLM_ARCH_MIMO2: return LLAMA_ROPE_TYPE_NEOX; case LLM_ARCH_QWEN2VL: diff --git a/src/llama-model.h b/src/llama-model.h index 7f560d462f..9c00eec75f 100644 --- a/src/llama-model.h +++ b/src/llama-model.h @@ -123,6 +123,7 @@ enum llm_type { LLM_TYPE_230B_A10B, // Minimax M2 LLM_TYPE_235B_A22B, LLM_TYPE_300B_A47B, // Ernie MoE big + LLM_TYPE_310B_A15B, // /MiMo-V2-Flash LLM_TYPE_355B_A32B, // GLM-4.5 LLM_TYPE_E2B, LLM_TYPE_E4B, diff --git a/src/models/llama.cpp b/src/models/llama.cpp index ab7fd5d050..42b5fcdf42 100644 --- a/src/models/llama.cpp +++ b/src/models/llama.cpp @@ -1,6 +1,7 @@ #include "models.h" -llm_build_llama::llm_build_llama(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { +template +llm_build_llama::llm_build_llama(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { const int64_t n_embd_head = hparams.n_embd_head_v; GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); @@ -14,7 +15,14 @@ llm_build_llama::llm_build_llama(const llama_model & model, const llm_graph_para // inp_pos - contains the positions ggml_tensor * inp_pos = build_inp_pos(); - auto * inp_attn = build_attn_inp_kv(); + using inp_attn_type = std::conditional_t; + + inp_attn_type * inp_attn = nullptr; + if constexpr (embed) { + inp_attn = build_attn_inp_no_cache(); + } else { + inp_attn = build_attn_inp_kv(); + } const float kq_scale = hparams.f_attention_scale == 0.0f ? 1.0f/sqrtf(float(n_embd_head)) : hparams.f_attention_scale; @@ -145,11 +153,16 @@ llm_build_llama::llm_build_llama(const llama_model & model, const llm_graph_para cb(cur, "result_norm", -1); res->t_embd = cur; - // lm_head - cur = build_lora_mm(model.output, cur); + if constexpr (!embed) { + // lm_head + cur = build_lora_mm(model.output, cur); - cb(cur, "result_output", -1); - res->t_logits = cur; + cb(cur, "result_output", -1); + res->t_logits = cur; + } ggml_build_forward_expand(gf, cur); } + +template struct llm_build_llama; +template struct llm_build_llama; diff --git a/src/models/mimo2-iswa.cpp b/src/models/mimo2-iswa.cpp new file mode 100644 index 0000000000..edc87cc9f0 --- /dev/null +++ b/src/models/mimo2-iswa.cpp @@ -0,0 +1,123 @@ + +#include "models.h" + +llm_build_mimo2_iswa::llm_build_mimo2_iswa(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { + ggml_tensor * cur; + ggml_tensor * inpL; + + inpL = build_inp_embd(model.tok_embd); + + ggml_tensor * inp_pos = build_inp_pos(); + auto * inp_attn = build_attn_inp_kv_iswa(); + ggml_tensor * inp_out_ids = build_inp_out_ids(); + + for (int il = 0; il < n_layer; ++il) { + ggml_tensor * inpSA = inpL; + + uint32_t n_head_l = hparams.n_head(il); + uint32_t n_head_kv_l = hparams.n_head_kv(il); + const float freq_base_l = model.get_rope_freq_base(cparams, il); + const float freq_scale_l = model.get_rope_freq_scale(cparams, il); + + cur = inpL; + + // self_attention + { + cur = build_norm(inpL, model.layers[il].attn_norm, NULL, LLM_NORM_RMS, il); + cb(cur, "attn_norm", il); + + // compute Q and K and RoPE them + ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur); + cb(Qcur, "Qcur", il); + + ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur); + cb(Kcur, "Kcur", il); + + ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur); + cb(Vcur, "Vcur", il); + + Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head_k, n_head_l, n_tokens); + Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head_k, n_head_kv_l, n_tokens); + Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head_v, n_head_kv_l, n_tokens); + + Qcur = ggml_rope_ext( + ctx0, Qcur, inp_pos, nullptr, + n_rot, rope_type, n_ctx_orig, freq_base_l, freq_scale_l, + ext_factor, attn_factor, beta_fast, beta_slow + ); + + Kcur = ggml_rope_ext( + ctx0, Kcur, inp_pos, nullptr, + n_rot, rope_type, n_ctx_orig, freq_base_l, freq_scale_l, + ext_factor, attn_factor, beta_fast, beta_slow + ); + + cb(Qcur, "Qcur", il); + cb(Kcur, "Kcur", il); + cb(Vcur, "Vcur", il); + + ggml_tensor * sinks = model.layers[il].attn_sinks; + + cur = build_attn(inp_attn, + model.layers[il].wo, NULL, + Qcur, Kcur, Vcur, nullptr, sinks, nullptr, 1.0f/sqrtf(float(n_embd_head_k)), il); + } + + if (il == n_layer - 1 && inp_out_ids) { + cur = ggml_get_rows(ctx0, cur, inp_out_ids); + inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids); + } + + ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA); + cb(ffn_inp, "ffn_inp", il); + + cur = build_norm(ffn_inp, + model.layers[il].ffn_norm, NULL, + LLM_NORM_RMS, il); + cb(cur, "ffn_norm", il); + + // feed-forward network + if (model.layers[il].ffn_gate_inp == nullptr) { + // dense branch + cur = build_ffn(cur, + model.layers[il].ffn_up, model.layers[il].ffn_up_b, NULL, + model.layers[il].ffn_gate, model.layers[il].ffn_gate_b, NULL, + model.layers[il].ffn_down, model.layers[il].ffn_down_b, NULL, + NULL, + LLM_FFN_SILU, LLM_FFN_PAR, il); + cb(cur, "ffn_out", il); + } else { + // MoE branch + cur = build_moe_ffn(cur, model.layers[il].ffn_gate_inp, model.layers[il].ffn_up_exps, + model.layers[il].ffn_gate_exps, model.layers[il].ffn_down_exps, + model.layers[il].ffn_exp_probs_b, n_expert, n_expert_used, LLM_FFN_SILU, true, false, + 0.0, LLAMA_EXPERT_GATING_FUNC_TYPE_SIGMOID, il); + cb(cur, "ffn_moe_out", il); + } + + cur = ggml_add(ctx0, cur, ffn_inp); + + cur = build_cvec(cur, il); + cb(cur, "l_out", il); + + // input for next layer + inpL = cur; + } + + cur = inpL; + + cur = build_norm(cur, + model.output_norm, NULL, + LLM_NORM_RMS, -1); + + cb(cur, "result_norm", -1); + res->t_embd = cur; + + // lm_head + cur = build_lora_mm(model.output, cur); + + cb(cur, "result_output", -1); + res->t_logits = cur; + + ggml_build_forward_expand(gf, cur); +} diff --git a/src/models/models.h b/src/models/models.h index 53a5810659..dd0e286eda 100644 --- a/src/models/models.h +++ b/src/models/models.h @@ -303,6 +303,7 @@ struct llm_build_llada_moe : public llm_graph_context { llm_build_llada_moe(const llama_model & model, const llm_graph_params & params); }; +template struct llm_build_llama : public llm_graph_context { llm_build_llama(const llama_model & model, const llm_graph_params & params); }; @@ -315,6 +316,10 @@ struct llm_build_mamba : public llm_graph_context_mamba { llm_build_mamba(const llama_model & model, const llm_graph_params & params); }; +struct llm_build_mimo2_iswa : public llm_graph_context { + llm_build_mimo2_iswa(const llama_model & model, const llm_graph_params & params); +}; + struct llm_build_minicpm3 : public llm_graph_context { llm_build_minicpm3(const llama_model & model, const llm_graph_params & params); }; diff --git a/tools/fit-params/fit-params.cpp b/tools/fit-params/fit-params.cpp index 2c113c453e..de47763d3e 100644 --- a/tools/fit-params/fit-params.cpp +++ b/tools/fit-params/fit-params.cpp @@ -35,7 +35,7 @@ int main(int argc, char ** argv) { } LOG_INF("%s: printing fitted CLI arguments to stdout...\n", __func__); - std::this_thread::sleep_for(10ms); // to avoid a race between stderr and stdout + common_log_flush(common_log_main()); printf("-c %" PRIu32 " -ngl %" PRIu32, cparams.n_ctx, mparams.n_gpu_layers); size_t nd = llama_max_devices(); diff --git a/tools/server/README.md b/tools/server/README.md index 1ae5eae4c6..7d2f6f798e 100644 --- a/tools/server/README.md +++ b/tools/server/README.md @@ -1486,6 +1486,7 @@ The precedence rule for preset options is as follows: We also offer additional options that are exclusive to presets (these aren't treated as command-line arguments): - `load-on-startup` (boolean): Controls whether the model loads automatically when the server starts +- `stop-timeout` (int, seconds): After requested unload, wait for this many seconds before forcing termination (default: 10) ### Routing requests @@ -1574,8 +1575,7 @@ Payload: ```json { - "model": "ggml-org/gemma-3-4b-it-GGUF:Q4_K_M", - "extra_args": ["-n", "128", "--top-k", "4"] + "model": "ggml-org/gemma-3-4b-it-GGUF:Q4_K_M" } ``` diff --git a/tools/server/server-models.cpp b/tools/server/server-models.cpp index 08a0da5c87..cb7e70455a 100644 --- a/tools/server/server-models.cpp +++ b/tools/server/server-models.cpp @@ -34,6 +34,8 @@ #include #endif +#define DEFAULT_STOP_TIMEOUT 10 // seconds + #define CMD_ROUTER_TO_CHILD_EXIT "cmd_router_to_child:exit" #define CMD_CHILD_TO_ROUTER_READY "cmd_child_to_router:ready" @@ -203,13 +205,14 @@ void server_models::load_models() { // convert presets to server_model_meta and add to mapping for (const auto & preset : final_presets) { server_model_meta meta{ - /* preset */ preset.second, - /* name */ preset.first, - /* port */ 0, - /* status */ SERVER_MODEL_STATUS_UNLOADED, - /* last_used */ 0, - /* args */ std::vector(), - /* exit_code */ 0 + /* preset */ preset.second, + /* name */ preset.first, + /* port */ 0, + /* status */ SERVER_MODEL_STATUS_UNLOADED, + /* last_used */ 0, + /* args */ std::vector(), + /* exit_code */ 0, + /* stop_timeout */ DEFAULT_STOP_TIMEOUT, }; add_model(std::move(meta)); } @@ -227,6 +230,20 @@ void server_models::load_models() { } } + // handle custom stop-timeout option + for (auto & [name, inst] : mapping) { + std::string val; + if (inst.meta.preset.get_option(COMMON_ARG_PRESET_STOP_TIMEOUT, val)) { + try { + inst.meta.stop_timeout = std::stoi(val); + } catch (...) { + SRV_WRN("invalid stop-timeout value '%s' for model '%s', using default %d seconds\n", + val.c_str(), name.c_str(), DEFAULT_STOP_TIMEOUT); + inst.meta.stop_timeout = DEFAULT_STOP_TIMEOUT; + } + } + } + // load any autoload models std::vector models_to_load; for (const auto & [name, inst] : mapping) { @@ -362,7 +379,7 @@ void server_models::unload_lru() { int64_t lru_last_used = ggml_time_ms(); size_t count_active = 0; { - std::lock_guard lk(mutex); + std::unique_lock lk(mutex); for (const auto & m : mapping) { if (m.second.meta.is_active()) { count_active++; @@ -376,6 +393,13 @@ void server_models::unload_lru() { if (!lru_model_name.empty() && count_active >= (size_t)base_params.models_max) { SRV_INF("models_max limit reached, removing LRU name=%s\n", lru_model_name.c_str()); unload(lru_model_name); + // wait for unload to complete + { + std::unique_lock lk(mutex); + cv.wait(lk, [this, &lru_model_name]() { + return mapping[lru_model_name].meta.status == SERVER_MODEL_STATUS_UNLOADED; + }); + } } } @@ -436,38 +460,83 @@ void server_models::load(const std::string & name) { // start a thread to manage the child process // captured variables are guaranteed to be destroyed only after the thread is joined - inst.th = std::thread([this, name, child_proc = inst.subproc, port = inst.meta.port]() { - // read stdout/stderr and forward to main server log - bool state_received = false; // true if child state received - FILE * p_stdout_stderr = subprocess_stdout(child_proc.get()); - if (p_stdout_stderr) { - char buffer[4096]; - while (fgets(buffer, sizeof(buffer), p_stdout_stderr) != nullptr) { - LOG("[%5d] %s", port, buffer); - if (!state_received && std::strstr(buffer, CMD_CHILD_TO_ROUTER_READY) != nullptr) { - // child process is ready - this->update_status(name, SERVER_MODEL_STATUS_LOADED); - state_received = true; + inst.th = std::thread([this, name, child_proc = inst.subproc, port = inst.meta.port, stop_timeout = inst.meta.stop_timeout]() { + FILE * stdin_file = subprocess_stdin(child_proc.get()); + FILE * stdout_file = subprocess_stdout(child_proc.get()); // combined stdout/stderr + + std::thread log_thread([&]() { + // read stdout/stderr and forward to main server log + // also handle status report from child process + bool state_received = false; // true if child state received + if (stdout_file) { + char buffer[4096]; + while (fgets(buffer, sizeof(buffer), stdout_file) != nullptr) { + LOG("[%5d] %s", port, buffer); + if (!state_received && std::strstr(buffer, CMD_CHILD_TO_ROUTER_READY) != nullptr) { + // child process is ready + this->update_status(name, SERVER_MODEL_STATUS_LOADED, 0); + state_received = true; + } } + } else { + SRV_ERR("failed to get stdout/stderr of child process for name=%s\n", name.c_str()); } - } else { - SRV_ERR("failed to get stdout/stderr of child process for name=%s\n", name.c_str()); - } + }); + + std::thread stopping_thread([&]() { + // thread to monitor stopping signal + auto is_stopping = [this, &name]() { + return this->stopping_models.find(name) != this->stopping_models.end(); + }; + { + std::unique_lock lk(this->mutex); + this->cv_stop.wait(lk, is_stopping); + } + SRV_INF("stopping model instance name=%s\n", name.c_str()); + // send interrupt to child process + fprintf(stdin_file, "%s\n", CMD_ROUTER_TO_CHILD_EXIT); + fflush(stdin_file); + // wait to stop gracefully or timeout + int64_t start_time = ggml_time_ms(); + while (true) { + std::unique_lock lk(this->mutex); + if (!is_stopping()) { + return; // already stopped + } + int64_t elapsed = ggml_time_ms() - start_time; + if (elapsed >= stop_timeout * 1000) { + // timeout, force kill + SRV_WRN("force-killing model instance name=%s after %d seconds timeout\n", name.c_str(), stop_timeout); + subprocess_terminate(child_proc.get()); + return; + } + this->cv_stop.wait_for(lk, std::chrono::seconds(1)); + } + }); + // we reach here when the child process exits + // note: we cannot join() prior to this point because it will close stdin_file + if (log_thread.joinable()) { + log_thread.join(); + } + + // stop the timeout monitoring thread + { + std::lock_guard lk(this->mutex); + stopping_models.erase(name); + cv_stop.notify_all(); + } + if (stopping_thread.joinable()) { + stopping_thread.join(); + } + + // get the exit code int exit_code = 0; subprocess_join(child_proc.get(), &exit_code); subprocess_destroy(child_proc.get()); - // update PID and status - { - std::lock_guard lk(mutex); - auto it = mapping.find(name); - if (it != mapping.end()) { - auto & meta = it->second.meta; - meta.exit_code = exit_code; - meta.status = SERVER_MODEL_STATUS_UNLOADED; - } - cv.notify_all(); - } + + // update status and exit code + this->update_status(name, SERVER_MODEL_STATUS_UNLOADED, exit_code); SRV_INF("instance name=%s exited with status %d\n", name.c_str(), exit_code); }); @@ -488,22 +557,14 @@ void server_models::load(const std::string & name) { cv.notify_all(); } -static void interrupt_subprocess(FILE * stdin_file) { - // because subprocess.h does not provide a way to send SIGINT, - // we will send a command to the child process to exit gracefully - if (stdin_file) { - fprintf(stdin_file, "%s\n", CMD_ROUTER_TO_CHILD_EXIT); - fflush(stdin_file); - } -} - void server_models::unload(const std::string & name) { std::lock_guard lk(mutex); auto it = mapping.find(name); if (it != mapping.end()) { if (it->second.meta.is_active()) { SRV_INF("unloading model instance name=%s\n", name.c_str()); - interrupt_subprocess(it->second.stdin_file); + stopping_models.insert(name); + cv_stop.notify_all(); // status change will be handled by the managing thread } else { SRV_WRN("model instance name=%s is not loaded\n", name.c_str()); @@ -518,7 +579,8 @@ void server_models::unload_all() { for (auto & [name, inst] : mapping) { if (inst.meta.is_active()) { SRV_INF("unloading model instance name=%s\n", name.c_str()); - interrupt_subprocess(inst.stdin_file); + stopping_models.insert(name); + cv_stop.notify_all(); // status change will be handled by the managing thread } // moving the thread to join list to avoid deadlock @@ -532,16 +594,15 @@ void server_models::unload_all() { } } -void server_models::update_status(const std::string & name, server_model_status status) { - // for now, we only allow updating to LOADED status - if (status != SERVER_MODEL_STATUS_LOADED) { - throw std::runtime_error("invalid status value"); - } - auto meta = get_meta(name); - if (meta.has_value()) { - meta->status = status; - update_meta(name, meta.value()); +void server_models::update_status(const std::string & name, server_model_status status, int exit_code) { + std::unique_lock lk(mutex); + auto it = mapping.find(name); + if (it != mapping.end()) { + auto & meta = it->second.meta; + meta.status = status; + meta.exit_code = exit_code; } + cv.notify_all(); } void server_models::wait_until_loaded(const std::string & name) { @@ -568,6 +629,7 @@ bool server_models::ensure_model_loaded(const std::string & name) { load(name); } + // for loading state SRV_INF("waiting until model name=%s is fully loaded...\n", name.c_str()); wait_until_loaded(name); @@ -795,7 +857,7 @@ void server_models_routes::init_routes() { res_err(res, format_error_response("model is not found", ERROR_TYPE_INVALID_REQUEST)); return res; } - if (model->status != SERVER_MODEL_STATUS_LOADED) { + if (!model->is_active()) { res_err(res, format_error_response("model is not loaded", ERROR_TYPE_INVALID_REQUEST)); return res; } diff --git a/tools/server/server-models.h b/tools/server/server-models.h index 3e1868c27c..7e33537536 100644 --- a/tools/server/server-models.h +++ b/tools/server/server-models.h @@ -9,6 +9,7 @@ #include #include #include +#include /** * state diagram: @@ -56,6 +57,7 @@ struct server_model_meta { int64_t last_used = 0; // for LRU unloading std::vector args; // args passed to the model instance, will be populated by render_args() int exit_code = 0; // exit code of the model instance process (only valid if status == FAILED) + int stop_timeout = 0; // seconds to wait before force-killing the model instance during shutdown bool is_active() const { return status == SERVER_MODEL_STATUS_LOADED || status == SERVER_MODEL_STATUS_LOADING; @@ -83,6 +85,10 @@ private: std::condition_variable cv; std::map mapping; + // for stopping models + std::condition_variable cv_stop; + std::set stopping_models; + common_preset_context ctx_preset; common_params base_params; @@ -119,7 +125,7 @@ public: void unload_all(); // update the status of a model instance (thread-safe) - void update_status(const std::string & name, server_model_status status); + void update_status(const std::string & name, server_model_status status, int exit_code); // wait until the model instance is fully loaded (thread-safe) // return when the model is loaded or failed to load