cuda: add support for non-contig q,k,v
parent 15d83e0c87
commit 2f0ac21d4b
@@ -1,31 +1,42 @@
#include "ggml-cuda/common.cuh"
|
||||
#include "gated_delta_net.cuh"
|
||||
#include "ggml-cuda/common.cuh"
|
||||
|
||||
template<int S_v>
|
||||
__global__ void gated_delta_net_cuda(
|
||||
const float * q,
|
||||
const float * k,
|
||||
const float * v,
|
||||
const float * g,
|
||||
const float * beta,
|
||||
const float * curr_state,
|
||||
float * dst,
|
||||
int64_t H,
|
||||
int64_t n_tokens,
|
||||
int64_t n_seqs
|
||||
) {
|
||||
template <int S_v>
|
||||
__global__ void gated_delta_net_cuda(const float * q,
|
||||
const float * k,
|
||||
const float * v,
|
||||
const float * g,
|
||||
const float * beta,
|
||||
const float * curr_state,
|
||||
float * dst,
|
||||
int64_t H,
|
||||
int64_t n_tokens,
|
||||
int64_t n_seqs,
|
||||
int64_t sq1,
|
||||
int64_t sq2,
|
||||
int64_t sq3,
|
||||
int64_t sv1,
|
||||
int64_t sv2,
|
||||
int64_t sv3,
|
||||
int64_t sg1,
|
||||
int64_t sg2,
|
||||
int64_t rq1,
|
||||
int64_t rq3) {
|
||||
const int64_t h_idx = blockIdx.x;
|
||||
const int64_t sequence = blockIdx.y;
|
||||
const int col = threadIdx.x; // each thread owns one column
|
||||
const int col = threadIdx.x; // each thread owns one column
|
||||
|
||||
const int64_t iq1 = h_idx / rq1;
|
||||
const int64_t iq3 = sequence / rq3;
|
||||
|
||||
const int64_t attn_score_elems = S_v * H * n_tokens * n_seqs;
|
||||
float * attn_data = dst;
|
||||
float * state = dst + attn_score_elems;
|
||||
float * attn_data = dst;
|
||||
float * state = dst + attn_score_elems;
|
||||
|
||||
const int64_t state_offset = (sequence * H + h_idx) * S_v * S_v;
|
||||
state += state_offset;
|
||||
state += state_offset;
|
||||
curr_state += state_offset;
|
||||
attn_data += (sequence * n_tokens * H + h_idx) * S_v;
|
||||
attn_data += (sequence * n_tokens * H + h_idx) * S_v;
|
||||
|
||||
// Copy input state to output state (working area)
|
||||
#pragma unroll
|
||||
|
|
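The new `rq1`/`rq3` parameters are broadcast ratios: they let several v/g heads (or sequences) share one q/k head, analogous to GQA-style broadcasting elsewhere in ggml. A standalone sketch of the head mapping the kernel applies; the dimensions are made up for illustration:

```cpp
#include <cstdint>
#include <cstdio>

int main() {
    // Hypothetical head counts, for illustration only: v/g carry 8 heads,
    // q/k carry 2, so rq1 = 8/2 = 4; v-heads 0..3 read q-head 0, 4..7 read q-head 1.
    const int64_t H_v = 8;
    const int64_t H_q = 2;
    const int64_t rq1 = H_v / H_q;       // same ratio the host code derives as nev1 / neq1
    for (int64_t h_idx = 0; h_idx < H_v; ++h_idx) {
        const int64_t iq1 = h_idx / rq1; // the per-block mapping the kernel applies
        printf("v head %lld -> q head %lld\n", (long long) h_idx, (long long) iq1);
    }
    return 0;
}
```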
@@ -34,14 +45,15 @@ __global__ void gated_delta_net_cuda(
     }
 
     for (int t = 0; t < n_tokens; t++) {
-        const int64_t qkv_offset = sequence * n_tokens * H * S_v + t * H * S_v + h_idx * S_v;
-        const float * q_t = q + qkv_offset;
-        const float * k_t = k + qkv_offset;
-        const float * v_t = v + qkv_offset;
+        const float * q_t = q + iq3 * sq3 + t * sq2 + iq1 * sq1;
+        const float * k_t = k + iq3 * sq3 + t * sq2 + iq1 * sq1;
+        const float * v_t = v + sequence * sv3 + t * sv2 + h_idx * sv1;
 
-        const int64_t gb_offset = sequence * n_tokens * H + t * H + h_idx;
-        const float beta_val = 1.0f / (1.0f + expf(-beta[gb_offset]));
-        const float g_val = expf(g[gb_offset]);
+        const float * g_t = g + sequence * sg2 + t * sg1;
+        const float * beta_t = beta + sequence * sg2 + t * sg1;
+
+        const float beta_val = 1.0f / (1.0f + expf(-beta_t[h_idx]));
+        const float g_val = expf(g_t[h_idx]);
 
         // kv[col] = (S^T @ k)[col] = sum_i S[i][col] * k[i]
         float kv_col = 0.0f;
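Sanity check for the new addressing: when the tensors are contiguous (and `rq1 == rq3 == 1`, so `iq1 == h_idx` and `iq3 == sequence`), the stride-based offset reduces to the removed `qkv_offset`. A self-contained check with made-up dimensions:

```cpp
#include <cassert>
#include <cstdint>

int main() {
    // Hypothetical sizes, for illustration only.
    const int64_t S_v = 4, H = 3, n_tokens = 5, n_seqs = 2;

    // Strides of a contiguous [S_v, H, n_tokens, n_seqs] tensor, in floats.
    const int64_t sq1 = S_v;
    const int64_t sq2 = H * S_v;
    const int64_t sq3 = n_tokens * H * S_v;

    for (int64_t sequence = 0; sequence < n_seqs; ++sequence) {
        for (int64_t t = 0; t < n_tokens; ++t) {
            for (int64_t h_idx = 0; h_idx < H; ++h_idx) {
                // Old contiguous-only offset from the removed code.
                const int64_t qkv_offset = sequence * n_tokens * H * S_v + t * H * S_v + h_idx * S_v;
                // New stride-based offset with the identity head/sequence mapping.
                const int64_t strided = sequence * sq3 + t * sq2 + h_idx * sq1;
                assert(qkv_offset == strided);
            }
        }
    }
    return 0;
}
```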
@@ -70,9 +82,7 @@ __global__ void gated_delta_net_cuda(
     }
 }
 
-void ggml_cuda_op_gated_delta_net(ggml_backend_cuda_context & ctx,
-                                  ggml_tensor * dst) {
+void ggml_cuda_op_gated_delta_net(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
     ggml_tensor * src_q = dst->src[0];
     ggml_tensor * src_k = dst->src[1];
     ggml_tensor * src_v = dst->src[2];
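The next hunk pulls shapes and byte strides into locals via ggml's `GGML_TENSOR_LOCALS` macro. Roughly, it expands to one local per dimension; a minimal sketch, where `TENSOR_LOCALS` and `fake_tensor` are illustrative stand-ins rather than the real ggml definitions:

```cpp
#include <cstddef>
#include <cstdint>

// Minimal stand-in for ggml_tensor, for illustration only.
struct fake_tensor {
    int64_t ne[4]; // number of elements per dimension
    size_t  nb[4]; // stride per dimension, in bytes
};

// Sketch of what GGML_TENSOR_LOCALS(int64_t, neq, src_q, ne) expands to:
// one local per dimension (the real macro also silences unused-variable warnings).
#define TENSOR_LOCALS(type, prefix, pointer, array) \
    const type prefix##0 = (pointer)->array[0];     \
    const type prefix##1 = (pointer)->array[1];     \
    const type prefix##2 = (pointer)->array[2];     \
    const type prefix##3 = (pointer)->array[3];

int main() {
    const fake_tensor q = { { 128, 16, 64, 2 }, { 4, 512, 8192, 524288 } };
    const fake_tensor * src_q = &q;
    TENSOR_LOCALS(int64_t, neq, src_q, ne); // defines neq0..neq3
    (void) neq1; (void) neq2;
    return (neq0 == 128 && neq3 == 2) ? 0 : 1;
}
```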
@@ -80,42 +90,68 @@ void ggml_cuda_op_gated_delta_net(ggml_backend_cuda_context & ctx,
     ggml_tensor * src_beta = dst->src[4];
     ggml_tensor * src_state = dst->src[5];
 
-    const int64_t S_v = src_q->ne[0];
-    const int64_t H = src_q->ne[1];
-    const int64_t n_tokens = src_q->ne[2];
-    const int64_t n_seqs = src_q->ne[3];
+    GGML_TENSOR_LOCALS(int64_t, neq, src_q, ne);
+    GGML_TENSOR_LOCALS(size_t, nbq, src_q, nb);
+    GGML_TENSOR_LOCALS(int64_t, nev, src_v, ne);
+    GGML_TENSOR_LOCALS(size_t, nbv, src_v, nb);
+    GGML_TENSOR_LOCALS(size_t, nbg, src_g, nb);
+
+    const int64_t S_v = nev0;
+    const int64_t H = nev1;
+    const int64_t n_tokens = nev2;
+    const int64_t n_seqs = nev3;
+
+    const int64_t rq1 = nev1 / neq1;
+    const int64_t rq3 = nev3 / neq3;
 
     const float * q_d = (const float *) src_q->data;
     const float * k_d = (const float *) src_k->data;
     const float * v_d = (const float *) src_v->data;
     const float * g_d = (const float *) src_g->data;
     const float * b_d = (const float *) src_beta->data;
     const float * s_d = (const float *) src_state->data;
     float * dst_d = (float *) dst->data;
 
-    GGML_ASSERT(ggml_is_contiguous(src_q));
-    GGML_ASSERT(ggml_is_contiguous(src_k));
-    GGML_ASSERT(ggml_is_contiguous(src_v));
+    GGML_ASSERT(ggml_is_contiguous_rows(src_q));
+    GGML_ASSERT(ggml_is_contiguous_rows(src_k));
+    GGML_ASSERT(ggml_is_contiguous_rows(src_v));
+    GGML_ASSERT(ggml_are_same_stride(src_q, src_k));
+    GGML_ASSERT(ggml_are_same_stride(src_g, src_beta));
     GGML_ASSERT(ggml_is_contiguous(src_g));
     GGML_ASSERT(ggml_is_contiguous(src_beta));
     GGML_ASSERT(ggml_is_contiguous(src_state));
 
+    // strides in floats
+    const int64_t sq1 = nbq1 / sizeof(float);
+    const int64_t sq2 = nbq2 / sizeof(float);
+    const int64_t sq3 = nbq3 / sizeof(float);
+    const int64_t sv1 = nbv1 / sizeof(float);
+    const int64_t sv2 = nbv2 / sizeof(float);
+    const int64_t sv3 = nbv3 / sizeof(float);
+    const int64_t sg1 = nbg1 / sizeof(float);
+    const int64_t sg2 = nbg2 / sizeof(float);
+
     dim3 grid_dims(H, n_seqs, 1);
     dim3 block_dims(S_v, 1, 1);
 
     cudaStream_t stream = ctx.stream();
 
-    switch(S_v) {
+    switch (S_v) {
         case 32:
-            gated_delta_net_cuda<32><<<grid_dims, block_dims, 0, stream>>>(q_d, k_d, v_d, g_d, b_d, s_d, dst_d, H, n_tokens, n_seqs);
+            gated_delta_net_cuda<32><<<grid_dims, block_dims, 0, stream>>>(q_d, k_d, v_d, g_d, b_d, s_d, dst_d, H,
+                                                                           n_tokens, n_seqs, sq1, sq2, sq3, sv1, sv2,
+                                                                           sv3, sg1, sg2, rq1, rq3);
             break;
         case 64:
-            gated_delta_net_cuda<64><<<grid_dims, block_dims, 0, stream>>>(q_d, k_d, v_d, g_d, b_d, s_d, dst_d, H, n_tokens, n_seqs);
+            gated_delta_net_cuda<64><<<grid_dims, block_dims, 0, stream>>>(q_d, k_d, v_d, g_d, b_d, s_d, dst_d, H,
+                                                                           n_tokens, n_seqs, sq1, sq2, sq3, sv1, sv2,
+                                                                           sv3, sg1, sg2, rq1, rq3);
             break;
         case 128:
-            gated_delta_net_cuda<128><<<grid_dims, block_dims, 0, stream>>>(q_d, k_d, v_d, g_d, b_d, s_d, dst_d, H, n_tokens, n_seqs);
+            gated_delta_net_cuda<128><<<grid_dims, block_dims, 0, stream>>>(q_d, k_d, v_d, g_d, b_d, s_d, dst_d, H,
+                                                                           n_tokens, n_seqs, sq1, sq2, sq3, sv1, sv2,
+                                                                           sv3, sg1, sg2, rq1, rq3);
             break;
         default:
            GGML_ABORT("fatal error");
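The kernel indexes in units of floats, so the host code divides ggml's byte strides (`nb`) by `sizeof(float)`. That is only valid for F32 data whose innermost dimension is dense, which is what the relaxed `ggml_is_contiguous_rows` asserts guarantee. A hypothetical standalone version of that conversion, with made-up stride values:

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>

// Illustrative helper, not part of ggml: convert a byte stride to a float stride.
static int64_t stride_in_floats(const size_t nb[4], int dim) {
    assert(nb[0] == sizeof(float));       // rows must be contiguous F32
    assert(nb[dim] % sizeof(float) == 0); // byte stride must be float-aligned
    return (int64_t) (nb[dim] / sizeof(float));
}

int main() {
    // A non-contiguous view of a [4, 3, 5, 2] F32 tensor with a padded row
    // stride (numbers are made up for illustration).
    const size_t nb[4] = { 4, 32, 96, 480 };
    const int64_t s1 = stride_in_floats(nb, 1); // 8 floats between rows
    const int64_t s2 = stride_in_floats(nb, 2); // 24 floats between matrices
    const int64_t s3 = stride_in_floats(nb, 3); // 120 floats between sequences
    return (s1 == 8 && s2 == 24 && s3 == 120) ? 0 : 1;
}
```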