opencl: refine condition for kqv mm (#17392)
This commit is contained in:
parent
23bc779a6e
commit
8e9ddba610
|
|
@ -6895,9 +6895,23 @@ static void ggml_cl_mul_mat(ggml_backend_t backend, const ggml_tensor * src0, co
|
||||||
cl_context context = backend_ctx->context;
|
cl_context context = backend_ctx->context;
|
||||||
|
|
||||||
if(src0t == GGML_TYPE_F16 && src1t == GGML_TYPE_F32){
|
if(src0t == GGML_TYPE_F16 && src1t == GGML_TYPE_F32){
|
||||||
if (ne01 >= 64 && ne1 >= 32 && ne00 >= 16 && (ne12 % ne02) == 0){
|
if (ne01 >= 64 && ne1 >= 32 && ne00 >= 16 && (ne12 % ne02) == 0) {
|
||||||
ggml_cl_mul_mat_kq_kqv_adreno(backend, src0, src1, dst);
|
// For KQ
|
||||||
return;
|
if (ggml_is_permuted(src0) && ggml_is_permuted(src1) &&
|
||||||
|
nb00 <= nb02 &&
|
||||||
|
nb02 <= nb01 &&
|
||||||
|
nb01 <= nb03 &&
|
||||||
|
nb10 <= nb12 &&
|
||||||
|
nb12 <= nb11 &&
|
||||||
|
nb11 <= nb13) {
|
||||||
|
ggml_cl_mul_mat_kq_kqv_adreno(backend, src0, src1, dst);
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
// For KQV
|
||||||
|
if (!ggml_is_contiguous(src0) && ggml_is_contiguous(src1)) {
|
||||||
|
ggml_cl_mul_mat_kq_kqv_adreno(backend, src0, src1, dst);
|
||||||
|
return;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue