ggml-blas: code clean up
Signed-off-by: Aaron Teo <aaron.teo1@ibm.com>
This commit is contained in:
parent
19c8ec9964
commit
1926e07e1a
|
|
@@ -27,13 +27,16 @@ struct ggml_backend_blas_context {
|
|||
#endif
|
||||
};
|
||||
|
||||
static void ggml_backend_blas_mul_mat(ggml_backend_blas_context * ctx, struct ggml_tensor * dst) {
|
||||
const struct ggml_tensor * src0 = dst->src[0];
|
||||
const struct ggml_tensor * src1 = dst->src[1];
|
||||
static void ggml_backend_blas_mul_mat(
|
||||
ggml_backend_blas_context * ctx,
|
||||
ggml_tensor * dst) {
|
||||
|
||||
const ggml_tensor * src0 = dst->src[0];
|
||||
const ggml_tensor * src1 = dst->src[1];
|
||||
|
||||
GGML_TENSOR_BINARY_OP_LOCALS
|
||||
|
||||
const enum ggml_type type = src0->type;
|
||||
const ggml_type type = src0->type;
|
||||
|
||||
GGML_ASSERT(ne0 == ne01);
|
||||
GGML_ASSERT(ne1 == ne11);
|
||||
|
|
@@ -84,7 +87,7 @@ static void ggml_backend_blas_mul_mat(ggml_backend_blas_context * ctx, struct gg
|
|||
}
|
||||
#else
|
||||
for (int i = 1; i < n_threads; i++) {
|
||||
const int64_t start = i*ne01/n_threads;
|
||||
const int64_t start = (i + 0) * ne01/n_threads;
|
||||
const int64_t end = (i + 1) * ne01/n_threads;
|
||||
if (start < end) {
|
||||
ctx->tasks.push_back(std::async(std::launch::async, [=]() {
|
||||
|
|
@@ -149,14 +152,17 @@ static void ggml_backend_blas_mul_mat(ggml_backend_blas_context * ctx, struct gg
|
|||
}
|
||||
}
|
||||
|
||||
static void ggml_backend_blas_mul_mat_id(ggml_backend_blas_context * ctx, ggml_tensor * dst) {
|
||||
const struct ggml_tensor * src0 = dst->src[0];
|
||||
const struct ggml_tensor * src1 = dst->src[1];
|
||||
const struct ggml_tensor * ids = dst->src[2];
|
||||
static void ggml_backend_blas_mul_mat_id(
|
||||
ggml_backend_blas_context * ctx,
|
||||
ggml_tensor * dst) {
|
||||
|
||||
const ggml_tensor * src0 = dst->src[0];
|
||||
const ggml_tensor * src1 = dst->src[1];
|
||||
const ggml_tensor * ids = dst->src[2];
|
||||
|
||||
GGML_TENSOR_BINARY_OP_LOCALS
|
||||
|
||||
const enum ggml_type type = src0->type;
|
||||
const ggml_type type = src0->type;
|
||||
|
||||
GGML_ASSERT(nb00 == ggml_type_size(type));
|
||||
GGML_ASSERT(nb10 == ggml_type_size(src1->type));
|
||||
|
|
@@ -173,15 +179,10 @@ static void ggml_backend_blas_mul_mat_id(ggml_backend_blas_context * ctx, ggml_t
|
|||
GGML_ASSERT(src1->type == GGML_TYPE_F32);
|
||||
GGML_ASSERT(ids->type == GGML_TYPE_I32);
|
||||
|
||||
// broadcast factors
|
||||
const int64_t r2 = ne12/ne02;
|
||||
const int64_t r3 = ne13/ne03;
|
||||
|
||||
GGML_UNUSED(r2);
|
||||
GGML_UNUSED(r3);
|
||||
|
||||
const int64_t ne_plane = ne01*ne00;
|
||||
const size_t desired_wsize = type == GGML_TYPE_F32 ? 0 : ne03*ne02*ne_plane*sizeof(float);
|
||||
const size_t desired_wsize = type == GGML_TYPE_F32
|
||||
? 0
|
||||
: ne03*ne02*ne_plane*sizeof(float);
|
||||
|
||||
if (ctx->work_size < desired_wsize) {
|
||||
ctx->work_data.reset(new char[desired_wsize]);
|
||||
|
|
@@ -210,7 +211,7 @@ static void ggml_backend_blas_mul_mat_id(ggml_backend_blas_context * ctx, ggml_t
|
|||
}
|
||||
#else
|
||||
for (int i = 1; i < n_threads; i++) {
|
||||
const int64_t start = i*ne01/n_threads;
|
||||
const int64_t start = (i + 0)*ne01/n_threads;
|
||||
const int64_t end = (i + 1)*ne01/n_threads;
|
||||
if (start < end) {
|
||||
ctx->tasks.push_back(std::async(std::launch::async, [=]() {
|
||||
|
|
@@ -555,15 +556,13 @@ static bool ggml_backend_blas_device_supports_op(ggml_backend_dev_t dev, const s
|
|||
|
||||
case GGML_OP_MUL_MAT_ID:
|
||||
{
|
||||
const struct ggml_tensor * src0 = op->src[0];
|
||||
const struct ggml_tensor * src1 = op->src[1];
|
||||
const struct ggml_tensor * src2 = op->src[2];
|
||||
const ggml_tensor * src0 = op->src[0];
|
||||
const ggml_tensor * src1 = op->src[1];
|
||||
|
||||
// GGML_LOG_INFO("%s: op=GGML_OP_MUL_MAT_ID src0_type=%s src1_type=%s src2_type=%s ne0=%lld ne1=%lld ne2=%lld ne3=%lld\n",
|
||||
// __func__, ggml_type_name(src0->type), ggml_type_name(src1->type), ggml_type_name(src2->type),
|
||||
// op->ne[0], op->ne[1], op->ne[2], op->ne[3]);
|
||||
|
||||
return src2->type == GGML_TYPE_I32;
|
||||
return ggml_is_contiguous(src0) &&
|
||||
ggml_is_contiguous(src1) &&
|
||||
src1->type == GGML_TYPE_F32 &&
|
||||
(src0->type == GGML_TYPE_F32 || ggml_get_type_traits(src0->type)->to_float != NULL);
|
||||
}
|
||||
|
||||
case GGML_OP_OUT_PROD:
|
||||
|
|
|
|||
Loading…
Reference in New Issue