vulkan: remove a couple unnecessary switches (#17419)
This commit is contained in:
parent
4949ac0f18
commit
54d83bbe85
|
|
@ -11381,13 +11381,13 @@ static void ggml_vk_preallocate_buffers(ggml_backend_vk_context * ctx, vk_contex
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
static bool ggml_vk_compute_forward(ggml_backend_vk_context* ctx, ggml_cgraph * cgraph, ggml_tensor* tensor, int tensor_idx, bool almost_ready);
|
static void ggml_vk_compute_forward(ggml_backend_vk_context* ctx, ggml_cgraph * cgraph, ggml_tensor* tensor, int tensor_idx, bool almost_ready);
|
||||||
|
|
||||||
// Returns true if node has enqueued work into the queue, false otherwise
|
// Returns true if node has enqueued work into the queue, false otherwise
|
||||||
// If submit is true the current all operations queued so far are being submitted to Vulkan to overlap cmdlist creation and GPU execution.
|
// If submit is true the current all operations queued so far are being submitted to Vulkan to overlap cmdlist creation and GPU execution.
|
||||||
static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgraph, int node_idx, ggml_tensor *node_begin, int node_idx_begin, bool last_node, bool almost_ready, bool submit){
|
static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgraph, int node_idx, ggml_tensor *node_begin, int node_idx_begin, bool last_node, bool almost_ready, bool submit){
|
||||||
ggml_tensor * node = cgraph->nodes[node_idx];
|
ggml_tensor * node = cgraph->nodes[node_idx];
|
||||||
if (ggml_is_empty(node) || !node->buffer) {
|
if (ggml_is_empty(node) || ggml_op_is_empty(node->op) || !node->buffer) {
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -11399,132 +11399,19 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr
|
||||||
ggml_tensor * src2 = node->src[2];
|
ggml_tensor * src2 = node->src[2];
|
||||||
ggml_tensor * src3 = node->src[3];
|
ggml_tensor * src3 = node->src[3];
|
||||||
|
|
||||||
switch (node->op) {
|
if (node->op == GGML_OP_ADD) {
|
||||||
// Return on empty ops to avoid generating a compute_ctx and setting exit_tensor
|
int next_node_idx = node_idx + 1 + ctx->num_additional_fused_ops;
|
||||||
case GGML_OP_RESHAPE:
|
if (next_node_idx < cgraph->n_nodes &&
|
||||||
case GGML_OP_VIEW:
|
cgraph->nodes[next_node_idx]->op == GGML_OP_RMS_NORM &&
|
||||||
case GGML_OP_PERMUTE:
|
cgraph->nodes[next_node_idx]->src[0] == cgraph->nodes[next_node_idx - 1] &&
|
||||||
case GGML_OP_TRANSPOSE:
|
ggml_nrows(cgraph->nodes[next_node_idx]) == 1 &&
|
||||||
case GGML_OP_NONE:
|
ctx->device->add_rms_fusion) {
|
||||||
return false;
|
uint32_t size = ggml_vk_rms_partials_size(ctx, cgraph->nodes[node_idx]);
|
||||||
case GGML_OP_UNARY:
|
ctx->do_add_rms_partials_offset_calculation = true;
|
||||||
switch (ggml_get_unary_op(node)) {
|
if (ctx->prealloc_size_add_rms_partials_offset + size <= ctx->prealloc_size_add_rms_partials) {
|
||||||
case GGML_UNARY_OP_EXP:
|
ctx->do_add_rms_partials = true;
|
||||||
case GGML_UNARY_OP_SILU:
|
|
||||||
case GGML_UNARY_OP_GELU:
|
|
||||||
case GGML_UNARY_OP_GELU_ERF:
|
|
||||||
case GGML_UNARY_OP_GELU_QUICK:
|
|
||||||
case GGML_UNARY_OP_RELU:
|
|
||||||
case GGML_UNARY_OP_NEG:
|
|
||||||
case GGML_UNARY_OP_TANH:
|
|
||||||
case GGML_UNARY_OP_SIGMOID:
|
|
||||||
case GGML_UNARY_OP_HARDSIGMOID:
|
|
||||||
case GGML_UNARY_OP_HARDSWISH:
|
|
||||||
case GGML_UNARY_OP_ABS:
|
|
||||||
case GGML_UNARY_OP_SOFTPLUS:
|
|
||||||
case GGML_UNARY_OP_STEP:
|
|
||||||
case GGML_UNARY_OP_ROUND:
|
|
||||||
case GGML_UNARY_OP_CEIL:
|
|
||||||
case GGML_UNARY_OP_FLOOR:
|
|
||||||
case GGML_UNARY_OP_TRUNC:
|
|
||||||
break;
|
|
||||||
default:
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
break;
|
|
||||||
case GGML_OP_GLU:
|
|
||||||
switch (ggml_get_glu_op(node)) {
|
|
||||||
case GGML_GLU_OP_GEGLU:
|
|
||||||
case GGML_GLU_OP_REGLU:
|
|
||||||
case GGML_GLU_OP_SWIGLU:
|
|
||||||
case GGML_GLU_OP_SWIGLU_OAI:
|
|
||||||
case GGML_GLU_OP_GEGLU_ERF:
|
|
||||||
case GGML_GLU_OP_GEGLU_QUICK:
|
|
||||||
break;
|
|
||||||
default:
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
break;
|
|
||||||
case GGML_OP_ADD:
|
|
||||||
{
|
|
||||||
int next_node_idx = node_idx + 1 + ctx->num_additional_fused_ops;
|
|
||||||
if (next_node_idx < cgraph->n_nodes &&
|
|
||||||
cgraph->nodes[next_node_idx]->op == GGML_OP_RMS_NORM &&
|
|
||||||
cgraph->nodes[next_node_idx]->src[0] == cgraph->nodes[next_node_idx - 1] &&
|
|
||||||
ggml_nrows(cgraph->nodes[next_node_idx]) == 1 &&
|
|
||||||
ctx->device->add_rms_fusion) {
|
|
||||||
uint32_t size = ggml_vk_rms_partials_size(ctx, cgraph->nodes[node_idx]);
|
|
||||||
ctx->do_add_rms_partials_offset_calculation = true;
|
|
||||||
if (ctx->prealloc_size_add_rms_partials_offset + size <= ctx->prealloc_size_add_rms_partials) {
|
|
||||||
ctx->do_add_rms_partials = true;
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
} break;
|
}
|
||||||
case GGML_OP_REPEAT:
|
|
||||||
case GGML_OP_REPEAT_BACK:
|
|
||||||
case GGML_OP_GET_ROWS:
|
|
||||||
case GGML_OP_ADD_ID:
|
|
||||||
case GGML_OP_ACC:
|
|
||||||
case GGML_OP_SUB:
|
|
||||||
case GGML_OP_MUL:
|
|
||||||
case GGML_OP_DIV:
|
|
||||||
case GGML_OP_ADD1:
|
|
||||||
case GGML_OP_ARANGE:
|
|
||||||
case GGML_OP_FILL:
|
|
||||||
case GGML_OP_CONCAT:
|
|
||||||
case GGML_OP_UPSCALE:
|
|
||||||
case GGML_OP_SCALE:
|
|
||||||
case GGML_OP_SQR:
|
|
||||||
case GGML_OP_SQRT:
|
|
||||||
case GGML_OP_SIN:
|
|
||||||
case GGML_OP_COS:
|
|
||||||
case GGML_OP_LOG:
|
|
||||||
case GGML_OP_CLAMP:
|
|
||||||
case GGML_OP_PAD:
|
|
||||||
case GGML_OP_ROLL:
|
|
||||||
case GGML_OP_CPY:
|
|
||||||
case GGML_OP_SET_ROWS:
|
|
||||||
case GGML_OP_CONT:
|
|
||||||
case GGML_OP_DUP:
|
|
||||||
case GGML_OP_SILU_BACK:
|
|
||||||
case GGML_OP_NORM:
|
|
||||||
case GGML_OP_GROUP_NORM:
|
|
||||||
case GGML_OP_RMS_NORM:
|
|
||||||
case GGML_OP_RMS_NORM_BACK:
|
|
||||||
case GGML_OP_L2_NORM:
|
|
||||||
case GGML_OP_DIAG_MASK_INF:
|
|
||||||
case GGML_OP_SOFT_MAX:
|
|
||||||
case GGML_OP_SOFT_MAX_BACK:
|
|
||||||
case GGML_OP_ROPE:
|
|
||||||
case GGML_OP_ROPE_BACK:
|
|
||||||
case GGML_OP_MUL_MAT:
|
|
||||||
case GGML_OP_MUL_MAT_ID:
|
|
||||||
case GGML_OP_ARGSORT:
|
|
||||||
case GGML_OP_SUM:
|
|
||||||
case GGML_OP_SUM_ROWS:
|
|
||||||
case GGML_OP_MEAN:
|
|
||||||
case GGML_OP_ARGMAX:
|
|
||||||
case GGML_OP_COUNT_EQUAL:
|
|
||||||
case GGML_OP_IM2COL:
|
|
||||||
case GGML_OP_IM2COL_3D:
|
|
||||||
case GGML_OP_TIMESTEP_EMBEDDING:
|
|
||||||
case GGML_OP_CONV_TRANSPOSE_1D:
|
|
||||||
case GGML_OP_POOL_2D:
|
|
||||||
case GGML_OP_CONV_2D:
|
|
||||||
case GGML_OP_CONV_TRANSPOSE_2D:
|
|
||||||
case GGML_OP_CONV_2D_DW:
|
|
||||||
case GGML_OP_RWKV_WKV6:
|
|
||||||
case GGML_OP_RWKV_WKV7:
|
|
||||||
case GGML_OP_SSM_SCAN:
|
|
||||||
case GGML_OP_SSM_CONV:
|
|
||||||
case GGML_OP_LEAKY_RELU:
|
|
||||||
case GGML_OP_FLASH_ATTN_EXT:
|
|
||||||
case GGML_OP_OPT_STEP_ADAMW:
|
|
||||||
case GGML_OP_OPT_STEP_SGD:
|
|
||||||
break;
|
|
||||||
default:
|
|
||||||
std::cerr << "ggml_vulkan: Error: Missing op: " << ggml_op_name(node->op) << std::endl;
|
|
||||||
GGML_ABORT("fatal error");
|
|
||||||
}
|
}
|
||||||
|
|
||||||
vk_context compute_ctx;
|
vk_context compute_ctx;
|
||||||
|
|
@ -11961,145 +11848,14 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr
|
||||||
|
|
||||||
ctx->compute_ctx.reset();
|
ctx->compute_ctx.reset();
|
||||||
|
|
||||||
bool ok = ggml_vk_compute_forward(ctx, cgraph, node_begin, node_idx_begin, almost_ready);
|
ggml_vk_compute_forward(ctx, cgraph, node_begin, node_idx_begin, almost_ready);
|
||||||
if (!ok) {
|
|
||||||
if (node->op == GGML_OP_UNARY) {
|
|
||||||
std::cerr << __func__ << ": error: op not supported UNARY " << node->name << " (" << ggml_unary_op_name(static_cast<ggml_unary_op>(node->op_params[0])) << ")" << std::endl;
|
|
||||||
} else if (node->op == GGML_OP_GLU) {
|
|
||||||
std::cerr << __func__ << ": error: op not supported GLU " << node->name << " (" << ggml_glu_op_name(static_cast<ggml_glu_op>(node->op_params[0])) << ")" << std::endl;
|
|
||||||
} else {
|
|
||||||
std::cerr << __func__ << ": error: op not supported " << node->name << " (" << ggml_op_name(node->op) << ")" << std::endl;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
}
|
}
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
static bool ggml_vk_compute_forward(ggml_backend_vk_context * ctx, ggml_cgraph * cgraph, ggml_tensor * tensor, int tensor_idx, bool almost_ready = false) {
|
static void ggml_vk_compute_forward(ggml_backend_vk_context * ctx, ggml_cgraph * cgraph, ggml_tensor * tensor, int tensor_idx, bool almost_ready = false) {
|
||||||
GGML_UNUSED(cgraph);
|
GGML_UNUSED(cgraph);
|
||||||
ggml_backend_buffer * buf = nullptr;
|
GGML_UNUSED(tensor);
|
||||||
|
|
||||||
switch (tensor->op) {
|
|
||||||
case GGML_OP_ADD:
|
|
||||||
case GGML_OP_ACC:
|
|
||||||
case GGML_OP_GET_ROWS:
|
|
||||||
case GGML_OP_SUB:
|
|
||||||
case GGML_OP_MUL:
|
|
||||||
case GGML_OP_DIV:
|
|
||||||
case GGML_OP_ADD1:
|
|
||||||
case GGML_OP_ARANGE:
|
|
||||||
case GGML_OP_FILL:
|
|
||||||
case GGML_OP_ADD_ID:
|
|
||||||
case GGML_OP_CONCAT:
|
|
||||||
case GGML_OP_UPSCALE:
|
|
||||||
case GGML_OP_SCALE:
|
|
||||||
case GGML_OP_SQR:
|
|
||||||
case GGML_OP_SQRT:
|
|
||||||
case GGML_OP_SIN:
|
|
||||||
case GGML_OP_COS:
|
|
||||||
case GGML_OP_LOG:
|
|
||||||
case GGML_OP_CLAMP:
|
|
||||||
case GGML_OP_PAD:
|
|
||||||
case GGML_OP_ROLL:
|
|
||||||
case GGML_OP_CPY:
|
|
||||||
case GGML_OP_SET_ROWS:
|
|
||||||
case GGML_OP_CONT:
|
|
||||||
case GGML_OP_DUP:
|
|
||||||
case GGML_OP_SILU_BACK:
|
|
||||||
case GGML_OP_NORM:
|
|
||||||
case GGML_OP_GROUP_NORM:
|
|
||||||
case GGML_OP_RMS_NORM:
|
|
||||||
case GGML_OP_RMS_NORM_BACK:
|
|
||||||
case GGML_OP_L2_NORM:
|
|
||||||
case GGML_OP_DIAG_MASK_INF:
|
|
||||||
case GGML_OP_SOFT_MAX:
|
|
||||||
case GGML_OP_SOFT_MAX_BACK:
|
|
||||||
case GGML_OP_ROPE:
|
|
||||||
case GGML_OP_ROPE_BACK:
|
|
||||||
case GGML_OP_RESHAPE:
|
|
||||||
case GGML_OP_VIEW:
|
|
||||||
case GGML_OP_PERMUTE:
|
|
||||||
case GGML_OP_TRANSPOSE:
|
|
||||||
case GGML_OP_NONE:
|
|
||||||
case GGML_OP_ARGSORT:
|
|
||||||
case GGML_OP_SUM:
|
|
||||||
case GGML_OP_SUM_ROWS:
|
|
||||||
case GGML_OP_MEAN:
|
|
||||||
case GGML_OP_ARGMAX:
|
|
||||||
case GGML_OP_COUNT_EQUAL:
|
|
||||||
case GGML_OP_IM2COL:
|
|
||||||
case GGML_OP_IM2COL_3D:
|
|
||||||
case GGML_OP_TIMESTEP_EMBEDDING:
|
|
||||||
case GGML_OP_CONV_TRANSPOSE_1D:
|
|
||||||
case GGML_OP_POOL_2D:
|
|
||||||
case GGML_OP_CONV_2D:
|
|
||||||
case GGML_OP_CONV_TRANSPOSE_2D:
|
|
||||||
case GGML_OP_CONV_2D_DW:
|
|
||||||
case GGML_OP_RWKV_WKV6:
|
|
||||||
case GGML_OP_RWKV_WKV7:
|
|
||||||
case GGML_OP_SSM_SCAN:
|
|
||||||
case GGML_OP_SSM_CONV:
|
|
||||||
case GGML_OP_LEAKY_RELU:
|
|
||||||
case GGML_OP_REPEAT:
|
|
||||||
case GGML_OP_REPEAT_BACK:
|
|
||||||
case GGML_OP_OPT_STEP_ADAMW:
|
|
||||||
case GGML_OP_OPT_STEP_SGD:
|
|
||||||
buf = tensor->buffer;
|
|
||||||
break;
|
|
||||||
case GGML_OP_UNARY:
|
|
||||||
switch (ggml_get_unary_op(tensor)) {
|
|
||||||
case GGML_UNARY_OP_EXP:
|
|
||||||
case GGML_UNARY_OP_SILU:
|
|
||||||
case GGML_UNARY_OP_GELU:
|
|
||||||
case GGML_UNARY_OP_GELU_ERF:
|
|
||||||
case GGML_UNARY_OP_GELU_QUICK:
|
|
||||||
case GGML_UNARY_OP_RELU:
|
|
||||||
case GGML_UNARY_OP_NEG:
|
|
||||||
case GGML_UNARY_OP_TANH:
|
|
||||||
case GGML_UNARY_OP_SIGMOID:
|
|
||||||
case GGML_UNARY_OP_HARDSIGMOID:
|
|
||||||
case GGML_UNARY_OP_HARDSWISH:
|
|
||||||
case GGML_UNARY_OP_ABS:
|
|
||||||
case GGML_UNARY_OP_SOFTPLUS:
|
|
||||||
case GGML_UNARY_OP_STEP:
|
|
||||||
case GGML_UNARY_OP_ROUND:
|
|
||||||
case GGML_UNARY_OP_CEIL:
|
|
||||||
case GGML_UNARY_OP_FLOOR:
|
|
||||||
case GGML_UNARY_OP_TRUNC:
|
|
||||||
buf = tensor->buffer;
|
|
||||||
break;
|
|
||||||
default:
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
break;
|
|
||||||
case GGML_OP_GLU:
|
|
||||||
switch (ggml_get_glu_op(tensor)) {
|
|
||||||
case GGML_GLU_OP_GEGLU:
|
|
||||||
case GGML_GLU_OP_REGLU:
|
|
||||||
case GGML_GLU_OP_SWIGLU:
|
|
||||||
case GGML_GLU_OP_SWIGLU_OAI:
|
|
||||||
case GGML_GLU_OP_GEGLU_ERF:
|
|
||||||
case GGML_GLU_OP_GEGLU_QUICK:
|
|
||||||
buf = tensor->buffer;
|
|
||||||
break;
|
|
||||||
default:
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
break;
|
|
||||||
case GGML_OP_MUL_MAT:
|
|
||||||
case GGML_OP_MUL_MAT_ID:
|
|
||||||
case GGML_OP_FLASH_ATTN_EXT:
|
|
||||||
buf = tensor->buffer;
|
|
||||||
|
|
||||||
break;
|
|
||||||
default:
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
|
|
||||||
if (buf == nullptr) {
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
|
|
||||||
VK_LOG_DEBUG("ggml_vk_compute_forward(" << tensor << ", name=" << tensor->name << ", op=" << ggml_op_name(tensor->op) << ", type=" << tensor->type << ", ne0=" << tensor->ne[0] << ", ne1=" << tensor->ne[1] << ", ne2=" << tensor->ne[2] << ", ne3=" << tensor->ne[3] << ", nb0=" << tensor->nb[0] << ", nb1=" << tensor->nb[1] << ", nb2=" << tensor->nb[2] << ", nb3=" << tensor->nb[3] << ", view_src=" << tensor->view_src << ", view_offs=" << tensor->view_offs << ")");
|
VK_LOG_DEBUG("ggml_vk_compute_forward(" << tensor << ", name=" << tensor->name << ", op=" << ggml_op_name(tensor->op) << ", type=" << tensor->type << ", ne0=" << tensor->ne[0] << ", ne1=" << tensor->ne[1] << ", ne2=" << tensor->ne[2] << ", ne3=" << tensor->ne[3] << ", nb0=" << tensor->nb[0] << ", nb1=" << tensor->nb[1] << ", nb2=" << tensor->nb[2] << ", nb3=" << tensor->nb[3] << ", view_src=" << tensor->view_src << ", view_offs=" << tensor->view_offs << ")");
|
||||||
|
|
||||||
|
|
@ -12143,8 +11899,6 @@ static bool ggml_vk_compute_forward(ggml_backend_vk_context * ctx, ggml_cgraph *
|
||||||
subctx->out_memcpys.clear();
|
subctx->out_memcpys.clear();
|
||||||
subctx->memsets.clear();
|
subctx->memsets.clear();
|
||||||
}
|
}
|
||||||
|
|
||||||
return true;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Clean up after graph processing is done
|
// Clean up after graph processing is done
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue