diff --git a/ggml/src/ggml-cann/aclnn_ops.cpp b/ggml/src/ggml-cann/aclnn_ops.cpp
index 02867e4fdb..ec8f144c6e 100644
--- a/ggml/src/ggml-cann/aclnn_ops.cpp
+++ b/ggml/src/ggml-cann/aclnn_ops.cpp
@@ -3309,107 +3309,260 @@ static void ggml_cann_mul_mat_id_fp(ggml_backend_cann_context & ctx, ggml_tensor
  * MoE architectures with potential sparse expert routing.
  */
 static void ggml_cann_mul_mat_id_quant(ggml_backend_cann_context & ctx, ggml_tensor * dst) {
-    // TODO: Use aclnnGroupedMatMul
     //dst   [M, K, N, 1]
     ggml_tensor * src0 = dst->src[0];  //src0  [D, M, A, 1]
     ggml_tensor * src1 = dst->src[1];  //src1  [D, B, N, 1], B = K or B = 1
     ggml_tensor * ids  = dst->src[2];  //ids   [K, N]
 
-    GGML_TENSOR_BINARY_OP_LOCALS
+    GGML_ASSERT(src0->ne[3] == 1);
+    GGML_ASSERT(src1->ne[3] == 1);
+    GGML_ASSERT(dst->ne[3] == 1);
 
-    // copy index from npu to cpu
-    int64_t n_as  = ne02;        // A
-    int64_t n_ids = ids->ne[0];  // K
+    int64_t batch = src1->ne[2];
+    GGML_ASSERT(batch == ids->ne[1]);
 
-    std::vector<char> ids_host(ggml_nbytes(ids));
-    ACL_CHECK(aclrtMemcpyAsync(ids_host.data(), ggml_nbytes(ids), ids->data, ggml_nbytes(ids),
-                               ACL_MEMCPY_DEVICE_TO_HOST, ctx.stream()));
-    ACL_CHECK(aclrtSynchronizeStream(ctx.stream()));
-
-    char * src0_original = (char *) src0->data;
-    char * src1_original = (char *) src1->data;
-    char * dst_original  = (char *) dst->data;
-
-    ggml_tensor src0_row = *src0;
-    ggml_tensor src1_row = *src1;
-    ggml_tensor dst_row  = *dst;
-
-    const enum ggml_type type = dst->src[0]->type;
-    float weight_elem_size;
+    const enum ggml_type type = src0->type;
+    float weight_elem_size;
     if (type == GGML_TYPE_Q4_0) {
         weight_elem_size = float(sizeof(uint8_t)) / 2;
     } else if (type == GGML_TYPE_Q8_0) {
         weight_elem_size = float(sizeof(uint8_t));
     } else {
-        GGML_ABORT("MUL_MAT_ID only support quant type Q4_0 and Q8_0 ");
+        GGML_ABORT("MUL_MAT_ID only support quant type Q4_0 and Q8_0");
     }
 
-    // src0_row [D, M, 1, 1] weight without permute
-    src0_row.ne[2] = 1;
-    src0_row.ne[3] = 1;
-    src0_row.nb[0] = weight_elem_size;
-    src0_row.nb[1] = weight_elem_size * ne00;
-    src0_row.nb[2] = weight_elem_size * ne00;
-    src0_row.nb[3] = weight_elem_size * ne00;
-    size_t weight_stride = ne00 * ne01 * weight_elem_size;
-    size_t weight_size   = weight_stride * ne02 * ne03;
+    // Calculate memory layout
+    size_t weight_stride = src0->ne[0] * src0->ne[1] * weight_elem_size;
+    size_t weight_size   = weight_stride * src0->ne[2] * src0->ne[3];
 
-    // scale [D, M, 1, 1] -> scale && permute
     size_t scale_elem_size = sizeof(uint16_t);
-    size_t scale_stride    = src0->ne[1] * src0->ne[0] / QK8_0 * scale_elem_size;
+    char * scale_offset = (char *) src0->data + weight_size;
 
-    // src1_row [D, 1, 1, 1] -> input
-    src1_row.ne[1] = 1;
-    src1_row.ne[2] = 1;
-    src1_row.ne[3] = 1;
-    src1_row.nb[2] = nb11;
-    src1_row.nb[3] = nb11;
+    // Allocate temporary buffers for selected weights and scales
+    size_t export_weight_size = src0->ne[0] * src0->ne[1] * ids->ne[0] * weight_elem_size;
+    ggml_cann_pool_alloc export_weight_allocator(ctx.pool(), export_weight_size);
+    void * export_weight_ptr = export_weight_allocator.get();
 
-    // dst_row [M, 1, 1, 1] -> out
-    dst_row.ne[1] = 1;
-    dst_row.ne[2] = 1;
-    dst_row.ne[3] = 1;
-    dst_row.nb[2] = nb1;
-    dst_row.nb[3] = nb1;
+    size_t export_scale_size = (src0->ne[0] / QK8_0) * src0->ne[1] * ids->ne[0] * scale_elem_size;
+    ggml_cann_pool_alloc export_scale_allocator(ctx.pool(), export_scale_size);
+    void * export_scale_ptr = export_scale_allocator.get();
 
-    //create weight for one row
-    ggml_cann_pool_alloc weight_allocator(ctx.pool());
-    void * weight_buffer = weight_allocator.alloc(nb02);
-    for (int64_t iid1 = 0; iid1 < ids->ne[1]; iid1++) {
-        for (int64_t id = 0; id < n_ids; id++) {
-            // expert index
-            int32_t i02 = *(int32_t *) (ids_host.data() + iid1 * ids->nb[1] + id * ids->nb[0]);
-            GGML_ASSERT(i02 >= 0 && i02 < n_as);
+    // Prepare input buffer (convert to F16 if needed)
+    size_t input_elem_size = sizeof(uint16_t);
+    ggml_cann_pool_alloc input_allocator(ctx.pool());
+    void * input_buffer = src1->data;
 
-            // If B = 1 (broadcast), always use 0; otherwise, use id.
-            int64_t i11 = (ne11 == 1 ? 0 : id);
-            int64_t i12 = iid1;
+    if (src1->type != GGML_TYPE_F16) {
+        size_t total_input_size = input_elem_size;
+        for (int i = 0; i < GGML_MAX_DIMS; i++) {
+            total_input_size *= src1->ne[i];
+        }
+        input_buffer = input_allocator.alloc(total_input_size);
 
-            int64_t i1 = id;
-            int64_t i2 = i12;
+        acl_tensor_ptr acl_src1_tensor = ggml_cann_create_tensor(src1);
 
-            void * src0_tmp_ptr  = src0_original + i02 * weight_stride;
-            void * scale_tmp_ptr = src0_original + weight_size + i02 * scale_stride;
-            void * src1_tmp_ptr  = src1_original + i11 * nb11 + i12 * nb12;
-            void * dst_tmp_ptr   = dst_original + i1 * nb1 + i2 * nb2;
+        int64_t input_cast_ne[GGML_MAX_DIMS];
+        size_t input_cast_nb[GGML_MAX_DIMS];
 
-            // mem cpy
-            ACL_CHECK(aclrtMemcpyAsync(weight_buffer, weight_stride, src0_tmp_ptr, weight_stride,
-                                       ACL_MEMCPY_DEVICE_TO_DEVICE, ctx.stream()));
-            void * scale_buffer = (char *) weight_buffer + weight_stride;
-            ACL_CHECK(aclrtMemcpyAsync(scale_buffer, scale_stride, scale_tmp_ptr, scale_stride,
-                                       ACL_MEMCPY_DEVICE_TO_DEVICE, ctx.stream()));
+        for (int i = 0; i < GGML_MAX_DIMS; i++) {
+            input_cast_ne[i] = src1->ne[i];
+        }
 
-            src0_row.data  = weight_buffer;
-            src1_row.data  = src1_tmp_ptr;
-            dst_row.data   = dst_tmp_ptr;
-            dst_row.src[0] = &src0_row;
-            dst_row.src[1] = &src1_row;
+        input_cast_nb[0] = input_elem_size;
+        for (int i = 1; i < GGML_MAX_DIMS; i++) {
+            input_cast_nb[i] = input_cast_nb[i - 1] * input_cast_ne[i - 1];
+        }
 
-            ggml_cann_mul_mat(ctx, &dst_row);
+        acl_tensor_ptr acl_input_tensor = ggml_cann_create_tensor(
+            input_buffer, ACL_FLOAT16, input_elem_size,
+            input_cast_ne, input_cast_nb, GGML_MAX_DIMS);
+
+        aclnn_cast(ctx, acl_src1_tensor.get(), acl_input_tensor.get(), ACL_FLOAT16);
+    }
+
+    // Prepare output buffer (use temp buffer if not F16)
+    size_t output_elem_size = sizeof(uint16_t);
+    ggml_cann_pool_alloc output_allocator(ctx.pool());
+    void * output_buffer = dst->data;
+
+    if (dst->type != GGML_TYPE_F16) {
+        size_t total_output_size = output_elem_size;
+        for (int i = 0; i < GGML_MAX_DIMS; i++) {
+            total_output_size *= dst->ne[i];
+        }
+        output_buffer = output_allocator.alloc(total_output_size);
+    }
+
+    // Process each batch
+    for (int64_t i = 0; i < batch; i++) {
+        // Create index tensor for this batch
+        acl_tensor_ptr select_index = ggml_cann_create_tensor(
+            ids, ids->ne, ids->nb, 1, ACL_FORMAT_ND, i * ids->nb[1]);
+
+        // IndexSelect for quantized weights (using int8 type)
+        int64_t weight_ne_for_select[3];
+        if (type == GGML_TYPE_Q4_0) {
+            weight_ne_for_select[0] = src0->ne[0] / 2;  // 2 Q4_0 values per byte
+        } else {
+            weight_ne_for_select[0] = src0->ne[0];  // Q8_0
+        }
+        weight_ne_for_select[1] = src0->ne[1];
+        weight_ne_for_select[2] = src0->ne[2];
+
+        size_t weight_nb_for_select[3];
+        weight_nb_for_select[0] = sizeof(int8_t);
+        weight_nb_for_select[1] = weight_nb_for_select[0] * weight_ne_for_select[0];
+        weight_nb_for_select[2] = weight_nb_for_select[1] * weight_ne_for_select[1];
+
+        acl_tensor_ptr export_weight = ggml_cann_create_tensor(
+            src0->data, ACL_INT8, sizeof(int8_t),
+            weight_ne_for_select, weight_nb_for_select, 3);
+
+        int64_t select_export_weight_ne[3] = {
+            weight_ne_for_select[0],
+            weight_ne_for_select[1],
+            ids->ne[0]
+        };
+        size_t select_export_weight_nb[3];
+        select_export_weight_nb[0] = sizeof(int8_t);
+        select_export_weight_nb[1] = select_export_weight_nb[0] * select_export_weight_ne[0];
+        select_export_weight_nb[2] = select_export_weight_nb[1] * select_export_weight_ne[1];
+
+        acl_tensor_ptr select_export_weight = ggml_cann_create_tensor(
+            export_weight_ptr, ACL_INT8, sizeof(int8_t),
+            select_export_weight_ne, select_export_weight_nb, 3);
+
+        GGML_CANN_CALL_ACLNN_OP(ctx, IndexSelect,
+            export_weight.get(), 0, select_index.get(), select_export_weight.get());
+
+        // IndexSelect for scales
+        int64_t scale_ne[3] = {
+            src0->ne[0] / QK8_0,
+            src0->ne[1],
+            src0->ne[2]
+        };
+        size_t scale_nb[3];
+        scale_nb[0] = scale_elem_size;
+        scale_nb[1] = scale_nb[0] * scale_ne[0];
+        scale_nb[2] = scale_nb[1] * scale_ne[1];
+
+        acl_tensor_ptr export_scale = ggml_cann_create_tensor(
+            scale_offset, ACL_FLOAT16, scale_elem_size,
+            scale_ne, scale_nb, 3);
+
+        int64_t select_export_scale_ne[3] = {
+            scale_ne[0],
+            scale_ne[1],
+            ids->ne[0]
+        };
+        size_t select_export_scale_nb[3];
+        select_export_scale_nb[0] = scale_elem_size;
+        select_export_scale_nb[1] = select_export_scale_nb[0] * select_export_scale_ne[0];
+        select_export_scale_nb[2] = select_export_scale_nb[1] * select_export_scale_ne[1];
+
+        acl_tensor_ptr select_export_scale = ggml_cann_create_tensor(
+            export_scale_ptr, ACL_FLOAT16, scale_elem_size,
+            select_export_scale_ne, select_export_scale_nb, 3);
+
+        GGML_CANN_CALL_ACLNN_OP(ctx, IndexSelect,
+            export_scale.get(), 0, select_index.get(), select_export_scale.get());
+
+        // IndexSelect output is [D, M, K] in contiguous layout
+        // For WeightQuantBatchMatmulV2, we need each expert as [M, D] with M major stride
+        for (int64_t k = 0; k < ids->ne[0]; k++) {
+            // Input offset: if src1->ne[1] == 1, broadcast (all k use same input); otherwise each k has its own input
+            size_t input_offset  = (i * src1->ne[1] + (src1->ne[1] == 1 ? 0 : k)) * src1->ne[0] * input_elem_size;
+            size_t output_offset = (i * dst->ne[1] + k) * dst->ne[0] * output_elem_size;
+
+            // Create view for the k-th expert weight from [D, M, K] -> [M, D]
+            // Data layout in memory is [D0M0, D0M1, ..., D0M_{M-1}, D1M0, D1M1, ...]
+            // We need [M, D] format with stride[0]=D*elemsize, stride[1]=elemsize
+            int64_t weight_view_ne[2] = {
+                src0->ne[1],  // M = src0->ne[1]
+                src0->ne[0]   // D = src0->ne[0] (adjusted for Q4_0/Q8_0)
+            };
+            float weight_view_nb[2] = {
+                src0->ne[0] * weight_elem_size,  // M stride: one row = D * elemsize
+                weight_elem_size                 // D stride: one element
+            };
+            size_t weight_view_offset = k * select_export_weight_nb[2];
+
+            acl_tensor_ptr weight_view = ggml_cann_create_tensor(
+                export_weight_ptr, ggml_cann_type_mapping(type), weight_elem_size,
+                weight_view_ne, weight_view_nb, 2,
+                ACL_FORMAT_ND, weight_view_offset);
+
+            // Create view for the k-th expert scale from [D, M, K] -> [M, D]
+            int64_t scale_view_ne[2] = {
+                select_export_scale_ne[1],  // M = src0->ne[1]
+                select_export_scale_ne[0]   // D = src0->ne[0] / QK8_0
+            };
+            size_t scale_view_nb[2] = {
+                select_export_scale_nb[1],  // M stride
+                select_export_scale_nb[0]   // D stride
+            };
+            size_t scale_view_offset = k * select_export_scale_nb[2];
+
+            acl_tensor_ptr scale_view = ggml_cann_create_tensor(
+                export_scale_ptr, ACL_FLOAT16, scale_elem_size,
+                scale_view_ne, scale_view_nb, 2,
+                ACL_FORMAT_ND, scale_view_offset);
+
+            // Prepare input tensor [D, 1]
+            int64_t active_tensor_ne[2] = { src1->ne[0], 1 };
+            size_t active_tensor_nb[2] = { input_elem_size, src1->ne[0] * input_elem_size };
+
+            acl_tensor_ptr active_tensor = ggml_cann_create_tensor(
+                input_buffer, ACL_FLOAT16, input_elem_size,
+                active_tensor_ne, active_tensor_nb, 2,
+                ACL_FORMAT_ND, input_offset);
+
+            // Prepare output tensor [M, 1]
+            int64_t dst_ne[2] = { dst->ne[0], 1 };
+            size_t dst_nb[2] = { output_elem_size, dst->ne[0] * output_elem_size };
+
+            acl_tensor_ptr acl_dst = ggml_cann_create_tensor(
+                output_buffer, ACL_FLOAT16, output_elem_size,
+                dst_ne, dst_nb, 2,
+                ACL_FORMAT_ND, output_offset);
+
+            // Call WeightQuantBatchMatmulV2
+            GGML_CANN_CALL_ACLNN_OP(ctx, WeightQuantBatchMatmulV2,
+                active_tensor.get(),
+                weight_view.get(),
+                scale_view.get(),
+                nullptr,
+                nullptr,
+                nullptr,
+                nullptr,
+                QK8_0,
+                acl_dst.get());
         }
     }
-    return;
+
+    // Cast output back to target type if needed
+    if (dst->type != GGML_TYPE_F16) {
+        int64_t output_cast_ne[GGML_MAX_DIMS];
+        size_t output_cast_nb[GGML_MAX_DIMS];
+
+        for (int i = 0; i < GGML_MAX_DIMS; i++) {
+            output_cast_ne[i] = dst->ne[i];
+        }
+
+        output_cast_nb[0] = output_elem_size;
+        for (int i = 1; i < GGML_MAX_DIMS; i++) {
+            output_cast_nb[i] = output_cast_nb[i - 1] * output_cast_ne[i - 1];
+        }
+
+        acl_tensor_ptr acl_output_tensor = ggml_cann_create_tensor(
+            output_buffer, ACL_FLOAT16, output_elem_size,
+            output_cast_ne, output_cast_nb, GGML_MAX_DIMS);
+
+        acl_tensor_ptr acl_dst_tensor = ggml_cann_create_tensor(dst);
+
+        aclnn_cast(ctx, acl_output_tensor.get(), acl_dst_tensor.get(),
+            ggml_cann_type_mapping(dst->type));
+    }
 }
 
 void ggml_cann_mul_mat_id(ggml_backend_cann_context & ctx, ggml_tensor * dst) {
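Note (not part of the patch): the new path leans on the packed layout the diff itself uses for quantized `src0`, where all experts' quantized weight bytes come first and the fp16 scales follow at `src0->data + weight_size` (see `scale_offset` above). Below is a minimal standalone sketch of that offset arithmetic for Q8_0 with `QK8_0 = 32`; the dimensions are invented for the example and are not taken from the PR.

```cpp
// Illustrative sketch only; dimensions below are made up for the example.
#include <cstdint>
#include <cstdio>

int main() {
    const int64_t QK8_0 = 32;  // ggml Q8_0 block size: 32 weights share one fp16 scale
    const int64_t D = 4096;    // src0->ne[0]: input features per expert row (example value)
    const int64_t M = 14336;   // src0->ne[1]: output rows per expert (example value)
    const int64_t A = 8;       // src0->ne[2]: number of experts (example value)

    const size_t weight_elem_size = sizeof(uint8_t);   // Q8_0: one int8 per weight
    const size_t scale_elem_size  = sizeof(uint16_t);  // one fp16 scale per QK8_0 block

    // All experts' quantized weights are stored back to back first ...
    const size_t weight_stride = (size_t) (D * M) * weight_elem_size;  // one expert
    const size_t weight_size   = weight_stride * A;                    // all experts

    // ... and the fp16 scales start right after them (scale_offset in the patch).
    const size_t scale_stride = (size_t) (D / QK8_0 * M) * scale_elem_size;  // one expert

    printf("scales begin at byte offset %zu; per expert: %zu weight bytes, %zu scale bytes\n",
           weight_size, weight_stride, scale_stride);
    return 0;
}
```

The two IndexSelect calls in the patch gather the routed experts' weight and scale blocks out of this layout into the `export_*` buffers, after which each expert slice is re-viewed as [M, D] and fed to WeightQuantBatchMatmulV2.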