Rename ne* to ne0* for consistent variable naming

This commit is contained in:
Oliver Simons 2026-02-04 18:00:43 +01:00
parent e4611f3b33
commit dac6e60c62
1 changed files with 71 additions and 71 deletions

View File

@ -43,9 +43,9 @@ static __device__ void rope_yarn(
template <bool forward, bool has_ff, typename T, typename D>
static __global__ void rope_norm(const T * x,
D * dst,
const int ne0,
const int ne1,
const int ne2,
const int ne00,
const int ne01,
const int ne02,
const int nb01,
const int nb02,
const int nb03,
@ -64,15 +64,15 @@ static __global__ void rope_norm(const T * x,
const int set_rows_stride) {
const int i0 = 2*(blockDim.y*blockIdx.y + threadIdx.y);
if (i0 >= ne0) {
if (i0 >= ne00) {
return;
}
const int row_dst = blockDim.x*blockIdx.x + threadIdx.x;
const uint32_t i3 = row_dst / (ne1*ne2);
const uint32_t i2 = (row_dst - i3 * ne1 * ne2) / ne1;
const uint32_t i1 = row_dst - i3 * ne1 * ne2 - i2 * ne1;
const uint32_t i3 = row_dst / (ne01 * ne02);
const uint32_t i2 = (row_dst - i3 * ne01 * ne02) / ne01;
const uint32_t i1 = row_dst - i3 * ne01 * ne02 - i2 * ne01;
int idst = i0 + i1 * nb11 + i2 * nb12 + i3 * nb13;
const int ix = i0 + i1 * nb01 + i2 * nb02 + i3 * nb03;
@ -115,9 +115,9 @@ static __global__ void rope_norm(const T * x,
template <bool forward, bool has_ff, typename T, typename D>
static __global__ void rope_neox(const T * x,
D * dst,
const int ne0,
const int ne1,
const int ne2,
const int ne00,
const int ne01,
const int ne02,
const int nb01,
const int nb02,
const int nb03,
@ -136,15 +136,15 @@ static __global__ void rope_neox(const T * x,
const int set_rows_stride) {
const int i0 = 2*(blockDim.y*blockIdx.y + threadIdx.y);
if (i0 >= ne0) {
if (i0 >= ne00) {
return;
}
const int row_dst = blockDim.x*blockIdx.x + threadIdx.x;
const uint32_t i3 = row_dst / (ne1*ne2);
const uint32_t i2 = (row_dst - i3 * ne1 * ne2) / ne1;
const uint32_t i1 = row_dst - i3 * ne1 * ne2 - i2 * ne1;
const uint32_t i3 = row_dst / (ne01 * ne02);
const uint32_t i2 = (row_dst - i3 * ne01 * ne02) / ne01;
const uint32_t i1 = row_dst - i3 * ne01 * ne02 - i2 * ne01;
int idst = i0 / 2 + i1 * nb11 + i2 * nb12 + i3 * nb13;
const int ix = i0 / 2 + + i1 * nb01 + i2 * nb02 + i3 * nb03;
@ -182,9 +182,9 @@ static __global__ void rope_neox(const T * x,
template <bool forward, bool has_ff, typename T>
static __global__ void rope_multi(const T * x,
T * dst,
const int ne0,
const int ne1,
const int ne2,
const int ne00,
const int ne01,
const int ne02,
const int nb01,
const int nb02,
const int nb03,
@ -203,15 +203,15 @@ static __global__ void rope_multi(const T * x,
const bool is_imrope) {
const int i0 = 2 * (blockDim.y * blockIdx.y + threadIdx.y);
if (i0 >= ne0) {
if (i0 >= ne00) {
return;
}
const int row_dst = blockDim.x*blockIdx.x + threadIdx.x;
const uint32_t i3 = row_dst / (ne1*ne2);
const uint32_t i2 = (row_dst - i3 * ne1 * ne2) / ne1;
const uint32_t i1 = row_dst - i3 * ne1 * ne2 - i2 * ne1;
const uint32_t i3 = row_dst / (ne01 * ne02);
const uint32_t i2 = (row_dst - i3 * ne01 * ne02) / ne01;
const uint32_t i1 = row_dst - i3 * ne01 * ne02 - i2 * ne01;
int idst = i0 / 2 + i1 * nb11 + i2 * nb12 + i3 * nb13;
const int ix = i0 / 2 + + i1 * nb01 + i2 * nb02 + i3 * nb03;
@ -230,23 +230,23 @@ static __global__ void rope_multi(const T * x,
float theta_base = 0.0;
if (is_imrope) {
if (sector % 3 == 1 && sector < 3 * sections.v[1]) { // h
theta_base = pos[i2 + ne2 * 1] * powf(theta_scale, i0 / 2.0f);
theta_base = pos[i2 + ne02 * 1] * powf(theta_scale, i0 / 2.0f);
} else if (sector % 3 == 2 && sector < 3 * sections.v[2]) { // w
theta_base = pos[i2 + ne2 * 2] * powf(theta_scale, i0 / 2.0f);
theta_base = pos[i2 + ne02 * 2] * powf(theta_scale, i0 / 2.0f);
} else if (sector % 3 == 0 && sector < 3 * sections.v[0]) { // t
theta_base = pos[i2] * powf(theta_scale, i0 / 2.0f);
} else {
theta_base = pos[i2 + ne2 * 3] * powf(theta_scale, i0 / 2.0f);
theta_base = pos[i2 + ne02 * 3] * powf(theta_scale, i0 / 2.0f);
}
} else {
if (sector < sections.v[0]) {
theta_base = pos[i2] * powf(theta_scale, i0 / 2.0f);
} else if (sector >= sections.v[0] && sector < sec_w) {
theta_base = pos[i2 + ne2 * 1] * powf(theta_scale, i0 / 2.0f);
theta_base = pos[i2 + ne02 * 1] * powf(theta_scale, i0 / 2.0f);
} else if (sector >= sec_w && sector < sec_w + sections.v[2]) {
theta_base = pos[i2 + ne2 * 2] * powf(theta_scale, i0 / 2.0f);
theta_base = pos[i2 + ne02 * 2] * powf(theta_scale, i0 / 2.0f);
} else if (sector >= sec_w + sections.v[2]) {
theta_base = pos[i2 + ne2 * 3] * powf(theta_scale, i0 / 2.0f);
theta_base = pos[i2 + ne02 * 3] * powf(theta_scale, i0 / 2.0f);
}
}
@ -267,9 +267,9 @@ static __global__ void rope_multi(const T * x,
template <bool forward, bool has_ff, typename T>
static __global__ void rope_vision(const T * x,
T * dst,
const int ne0,
const int ne1,
const int ne2,
const int ne00,
const int ne01,
const int ne02,
const int nb01,
const int nb02,
const int nb03,
@ -287,15 +287,15 @@ static __global__ void rope_vision(const T * x,
const mrope_sections sections) {
const int i0 = 2*(blockDim.y*blockIdx.y + threadIdx.y);
if (i0 >= ne0) {
if (i0 >= ne00) {
return;
}
const int row_dst = blockDim.x*blockIdx.x + threadIdx.x;
const uint32_t i3 = row_dst / (ne1 * ne2);
const uint32_t i2 = (row_dst - i3 * ne1 * ne2) / ne1;
const uint32_t i1 = row_dst - i3 * ne1 * ne2 - i2 * ne1;
const uint32_t i3 = row_dst / (ne01 * ne02);
const uint32_t i2 = (row_dst - i3 * ne01 * ne02) / ne01;
const uint32_t i1 = row_dst - i3 * ne01 * ne02 - i2 * ne01;
int idst = i0 / 2 + i1 * nb11 + i2 * nb12 + i3 * nb13;
const int ix = i0 / 2 + +i1 * nb01 + i2 * nb02 + i3 * nb03;
@ -310,7 +310,7 @@ static __global__ void rope_vision(const T * x,
theta_base = pos[i2] * powf(theta_scale, p);
} else if (sector >= sections.v[0] && sector < sec_w) {
const int p = sector - sections.v[0];
theta_base = pos[i2 + ne2] * powf(theta_scale, p);
theta_base = pos[i2 + ne02] * powf(theta_scale, p);
}
const float freq_factor = has_ff ? freq_factors[i0/2] : 1.0f;
@ -330,9 +330,9 @@ static __global__ void rope_vision(const T * x,
template <bool forward, typename T, typename D>
static void rope_norm_cuda(const T * x,
D * dst,
const int ne0,
const int ne1,
const int ne2,
const int ne00,
const int ne01,
const int ne02,
const int nb01,
const int nb02,
const int nb03,
@ -351,30 +351,30 @@ static void rope_norm_cuda(const T * x,
const int64_t * row_indices,
const int set_rows_stride,
cudaStream_t stream) {
GGML_ASSERT(ne0 % 2 == 0);
GGML_ASSERT(ne00 % 2 == 0);
const dim3 block_dims(1, CUDA_ROPE_BLOCK_SIZE, 1);
const int n_blocks_x = (ne0 + 2 * CUDA_ROPE_BLOCK_SIZE - 1) / (2 * CUDA_ROPE_BLOCK_SIZE);
const int n_blocks_x = (ne00 + 2 * CUDA_ROPE_BLOCK_SIZE - 1) / (2 * CUDA_ROPE_BLOCK_SIZE);
const dim3 block_nums(nr, n_blocks_x, 1);
const float theta_scale = powf(freq_base, -2.0f / n_dims);
if (freq_factors == nullptr) {
rope_norm<forward, false><<<block_nums, block_dims, 0, stream>>>(
x, dst, ne0, ne1, ne2, nb01, nb02, nb03, nb11, nb12, nb13, n_dims, pos, freq_scale, ext_factor, attn_factor,
corr_dims, theta_scale, freq_factors, row_indices, set_rows_stride);
x, dst, ne00, ne01, ne02, nb01, nb02, nb03, nb11, nb12, nb13, n_dims, pos, freq_scale, ext_factor,
attn_factor, corr_dims, theta_scale, freq_factors, row_indices, set_rows_stride);
} else {
rope_norm<forward, true><<<block_nums, block_dims, 0, stream>>>(
x, dst, ne0, ne1, ne2, nb01, nb02, nb03, nb11, nb12, nb13, n_dims, pos, freq_scale, ext_factor, attn_factor,
corr_dims, theta_scale, freq_factors, row_indices, set_rows_stride);
x, dst, ne00, ne01, ne02, nb01, nb02, nb03, nb11, nb12, nb13, n_dims, pos, freq_scale, ext_factor,
attn_factor, corr_dims, theta_scale, freq_factors, row_indices, set_rows_stride);
}
}
template <bool forward, typename T, typename D>
static void rope_neox_cuda(const T * x,
D * dst,
const int ne0,
const int ne1,
const int ne2,
const int ne00,
const int ne01,
const int ne02,
const int nb01,
const int nb02,
const int nb03,
@ -393,30 +393,30 @@ static void rope_neox_cuda(const T * x,
const int64_t * row_indices,
const int set_rows_stride,
cudaStream_t stream) {
GGML_ASSERT(ne0 % 2 == 0);
GGML_ASSERT(ne00 % 2 == 0);
const dim3 block_dims(1, CUDA_ROPE_BLOCK_SIZE, 1);
const int n_blocks_x = (ne0 + 2 * CUDA_ROPE_BLOCK_SIZE - 1) / (2 * CUDA_ROPE_BLOCK_SIZE);
const int n_blocks_x = (ne00 + 2 * CUDA_ROPE_BLOCK_SIZE - 1) / (2 * CUDA_ROPE_BLOCK_SIZE);
const dim3 block_nums(nr, n_blocks_x, 1);
const float theta_scale = powf(freq_base, -2.0f / n_dims);
if (freq_factors == nullptr) {
rope_neox<forward, false><<<block_nums, block_dims, 0, stream>>>(
x, dst, ne0, ne1, ne2, nb01, nb02, nb03, nb11, nb12, nb13, n_dims, pos, freq_scale, ext_factor, attn_factor,
corr_dims, theta_scale, freq_factors, row_indices, set_rows_stride);
x, dst, ne00, ne01, ne02, nb01, nb02, nb03, nb11, nb12, nb13, n_dims, pos, freq_scale, ext_factor,
attn_factor, corr_dims, theta_scale, freq_factors, row_indices, set_rows_stride);
} else {
rope_neox<forward, true><<<block_nums, block_dims, 0, stream>>>(
x, dst, ne0, ne1, ne2, nb01, nb02, nb03, nb11, nb12, nb13, n_dims, pos, freq_scale, ext_factor, attn_factor,
corr_dims, theta_scale, freq_factors, row_indices, set_rows_stride);
x, dst, ne00, ne01, ne02, nb01, nb02, nb03, nb11, nb12, nb13, n_dims, pos, freq_scale, ext_factor,
attn_factor, corr_dims, theta_scale, freq_factors, row_indices, set_rows_stride);
}
}
template <bool forward, typename T>
static void rope_multi_cuda(const T * x,
T * dst,
const int ne0,
const int ne1,
const int ne2,
const int ne00,
const int ne01,
const int ne02,
const int nb01,
const int nb02,
const int nb03,
@ -435,30 +435,30 @@ static void rope_multi_cuda(const T * x,
const mrope_sections sections,
const bool is_imrope,
cudaStream_t stream) {
GGML_ASSERT(ne0 % 2 == 0);
GGML_ASSERT(ne00 % 2 == 0);
const dim3 block_dims(1, CUDA_ROPE_BLOCK_SIZE, 1);
const int n_blocks_x = (ne0 + 2 * CUDA_ROPE_BLOCK_SIZE - 1) / (2 * CUDA_ROPE_BLOCK_SIZE);
const int n_blocks_x = (ne00 + 2 * CUDA_ROPE_BLOCK_SIZE - 1) / (2 * CUDA_ROPE_BLOCK_SIZE);
const dim3 block_nums(nr, n_blocks_x, 1);
const float theta_scale = powf(freq_base, -2.0f / n_dims);
if (freq_factors == nullptr) {
rope_multi<forward, false, T><<<block_nums, block_dims, 0, stream>>>(
x, dst, ne0, ne1, ne2, nb01, nb02, nb03, nb11, nb12, nb13, n_dims, pos, freq_scale, ext_factor, attn_factor,
corr_dims, theta_scale, freq_factors, sections, is_imrope);
x, dst, ne00, ne01, ne02, nb01, nb02, nb03, nb11, nb12, nb13, n_dims, pos, freq_scale, ext_factor,
attn_factor, corr_dims, theta_scale, freq_factors, sections, is_imrope);
} else {
rope_multi<forward, true, T><<<block_nums, block_dims, 0, stream>>>(
x, dst, ne0, ne1, ne2, nb01, nb02, nb03, nb11, nb12, nb13, n_dims, pos, freq_scale, ext_factor, attn_factor,
corr_dims, theta_scale, freq_factors, sections, is_imrope);
x, dst, ne00, ne01, ne02, nb01, nb02, nb03, nb11, nb12, nb13, n_dims, pos, freq_scale, ext_factor,
attn_factor, corr_dims, theta_scale, freq_factors, sections, is_imrope);
}
}
template <bool forward, typename T>
static void rope_vision_cuda(const T * x,
T * dst,
const int ne0,
const int ne1,
const int ne2,
const int ne00,
const int ne01,
const int ne02,
const int nb01,
const int nb02,
const int nb03,
@ -476,9 +476,9 @@ static void rope_vision_cuda(const T * x,
const float * freq_factors,
const mrope_sections sections,
cudaStream_t stream) {
GGML_ASSERT(ne0 % 2 == 0);
GGML_ASSERT(ne00 % 2 == 0);
const dim3 block_dims(1, CUDA_ROPE_BLOCK_SIZE, 1);
const int n_blocks_x = (ne0 + 2*CUDA_ROPE_BLOCK_SIZE - 1) / (2*CUDA_ROPE_BLOCK_SIZE);
const int n_blocks_x = (ne00 + 2 * CUDA_ROPE_BLOCK_SIZE - 1) / (2 * CUDA_ROPE_BLOCK_SIZE);
const dim3 block_nums(nr, n_blocks_x, 1);
// break down (head_dim, heads, seq) into (CUDA_ROPE_BLOCK_SIZE, x, heads * seq)
// where x ~= ceil(head_dim / CUDA_ROPE_BLOCK_SIZE);
@ -487,12 +487,12 @@ static void rope_vision_cuda(const T * x,
if (freq_factors == nullptr) {
rope_vision<forward, false, T><<<block_nums, block_dims, 0, stream>>>(
x, dst, ne0, ne1, ne2, nb01, nb02, nb03, nb11, nb12, nb13, n_dims, pos, freq_scale, ext_factor, attn_factor,
corr_dims, theta_scale, freq_factors, sections);
x, dst, ne00, ne01, ne02, nb01, nb02, nb03, nb11, nb12, nb13, n_dims, pos, freq_scale, ext_factor,
attn_factor, corr_dims, theta_scale, freq_factors, sections);
} else {
rope_vision<forward, true, T><<<block_nums, block_dims, 0, stream>>>(
x, dst, ne0, ne1, ne2, nb01, nb02, nb03, nb11, nb12, nb13, n_dims, pos, freq_scale, ext_factor, attn_factor,
corr_dims, theta_scale, freq_factors, sections);
x, dst, ne00, ne01, ne02, nb01, nb02, nb03, nb11, nb12, nb13, n_dims, pos, freq_scale, ext_factor,
attn_factor, corr_dims, theta_scale, freq_factors, sections);
}
}