cuda : add FILL op support (#17851)
* cuda : add FILL op support * cuda : add missing FILL op files
This commit is contained in:
parent
37a4f63244
commit
51e0c2d917
|
|
@ -0,0 +1,37 @@
|
||||||
|
#include "fill.cuh"
|
||||||
|
#include "convert.cuh"
|
||||||
|
|
||||||
|
#define CUDA_FILL_BLOCK_SIZE 256
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
static __global__ void fill_kernel(T * __restrict__ dst, const int64_t k, const T value) {
|
||||||
|
const int64_t i = (int64_t)blockDim.x * blockIdx.x + threadIdx.x;
|
||||||
|
if (i >= k) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
dst[i] = value;
|
||||||
|
}
|
||||||
|
|
||||||
|
void ggml_cuda_op_fill(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
|
||||||
|
void * dst_d = dst->data;
|
||||||
|
cudaStream_t stream = ctx.stream();
|
||||||
|
|
||||||
|
GGML_ASSERT(ggml_is_contiguous(dst));
|
||||||
|
|
||||||
|
float value;
|
||||||
|
memcpy(&value, dst->op_params, sizeof(float));
|
||||||
|
|
||||||
|
const int64_t k = ggml_nelements(dst);
|
||||||
|
const int64_t num_blocks = (k + CUDA_FILL_BLOCK_SIZE - 1) / CUDA_FILL_BLOCK_SIZE;
|
||||||
|
|
||||||
|
switch (dst->type) {
|
||||||
|
case GGML_TYPE_F32:
|
||||||
|
fill_kernel<<<num_blocks, CUDA_FILL_BLOCK_SIZE, 0, stream>>>((float *)dst_d, k, value);
|
||||||
|
break;
|
||||||
|
case GGML_TYPE_F16:
|
||||||
|
fill_kernel<<<num_blocks, CUDA_FILL_BLOCK_SIZE, 0, stream>>>((half *)dst_d, k, ggml_cuda_cast<half>(value));
|
||||||
|
break;
|
||||||
|
default:
|
||||||
|
GGML_ABORT("unsupported type");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
@ -0,0 +1,3 @@
|
||||||
|
#include "common.cuh"
|
||||||
|
|
||||||
|
void ggml_cuda_op_fill(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
|
||||||
|
|
@ -56,6 +56,7 @@
|
||||||
#include "ggml-cuda/solve_tri.cuh"
|
#include "ggml-cuda/solve_tri.cuh"
|
||||||
#include "ggml-cuda/tri.cuh"
|
#include "ggml-cuda/tri.cuh"
|
||||||
#include "ggml-cuda/cumsum.cuh"
|
#include "ggml-cuda/cumsum.cuh"
|
||||||
|
#include "ggml-cuda/fill.cuh"
|
||||||
#include "ggml.h"
|
#include "ggml.h"
|
||||||
|
|
||||||
#include <algorithm>
|
#include <algorithm>
|
||||||
|
|
@ -2730,6 +2731,9 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
|
||||||
case GGML_OP_SOLVE_TRI:
|
case GGML_OP_SOLVE_TRI:
|
||||||
ggml_cuda_op_solve_tri(ctx, dst);
|
ggml_cuda_op_solve_tri(ctx, dst);
|
||||||
break;
|
break;
|
||||||
|
case GGML_OP_FILL:
|
||||||
|
ggml_cuda_op_fill(ctx, dst);
|
||||||
|
break;
|
||||||
default:
|
default:
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
|
@ -4617,6 +4621,7 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
|
||||||
case GGML_OP_CROSS_ENTROPY_LOSS_BACK:
|
case GGML_OP_CROSS_ENTROPY_LOSS_BACK:
|
||||||
case GGML_OP_OPT_STEP_ADAMW:
|
case GGML_OP_OPT_STEP_ADAMW:
|
||||||
case GGML_OP_OPT_STEP_SGD:
|
case GGML_OP_OPT_STEP_SGD:
|
||||||
|
case GGML_OP_FILL:
|
||||||
case GGML_OP_CUMSUM:
|
case GGML_OP_CUMSUM:
|
||||||
case GGML_OP_TRI:
|
case GGML_OP_TRI:
|
||||||
return true;
|
return true;
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue