reverted to original

This commit is contained in:
Yee Man Chan 2026-01-17 07:43:30 +08:00
parent 0aea18e718
commit f3d118d061
1 changed files with 64 additions and 47 deletions

View File

@ -7,10 +7,9 @@
#include "unary-ops.h"
#include "vec.h"
#include <cfloat>
#include <algorithm>
#include <cfloat>
#include <cmath>
#include <functional>
// ggml_compute_forward_dup
@ -7110,12 +7109,13 @@ void ggml_compute_forward_conv_2d_dw(
}
}
// ggml_compute_forward_pool_1d_sk_p0
static void ggml_compute_forward_pool_1d_sk_p0(
// ggml_compute_forward_pool_1d_ksp
static void ggml_compute_forward_pool_1d_ksp(
const ggml_compute_params * params,
const ggml_op_pool op,
const int k,
const int s,
const int p,
ggml_tensor * dst) {
const ggml_tensor * src = dst->src[0];
@ -7126,39 +7126,56 @@ static void ggml_compute_forward_pool_1d_sk_p0(
return;
}
const char * cdata = (const char *)src->data;
const char * const data_end = cdata + ggml_nbytes(src);
float * drow = (float *)dst->data;
const int64_t IW = src->ne[0];
const int64_t OW = dst->ne[0];
const int64_t rs = dst->ne[0];
const int64_t nr = ggml_nrows(src);
while (cdata < data_end) {
const void * srow = (const void *)cdata;
int j = 0;
for (int64_t i = 0; i < rs; ++i) {
for (int64_t ir = 0; ir < nr; ++ir) {
const char * srow_bytes = (const char *) src->data + ir * src->nb[1];
float * drow = (float *) (( char *) dst->data + ir * dst->nb[1]);
for (int64_t ow = 0; ow < OW; ++ow) {
float res = 0;
switch (op) {
case GGML_OP_POOL_AVG: drow[i] = 0; break;
case GGML_OP_POOL_MAX: drow[i] = -FLT_MAX; break;
case GGML_OP_POOL_AVG: res = 0.0f; break;
case GGML_OP_POOL_MAX: res = -FLT_MAX; break;
case GGML_OP_POOL_COUNT: GGML_ABORT("fatal error");
}
int count = 0;
const int base = (int) ow * s - p;
for (int ki = 0; ki < k; ++ki) {
const float srow_j = (src->type == GGML_TYPE_F32) ? ((const float*)srow)[j] : GGML_CPU_FP16_TO_FP32(((const ggml_fp16_t*)srow)[j]);
switch (op) {
case GGML_OP_POOL_AVG: drow[i] += srow_j; break;
case GGML_OP_POOL_MAX: if (srow_j > drow[i]) drow[i] = srow_j; break;
case GGML_OP_POOL_COUNT: GGML_ABORT("fatal error");
const int j = base + ki;
if (j < 0 || j >= (int) IW) {
continue;
}
++j;
float v;
if (src->type == GGML_TYPE_F32) {
v = ((const float *) srow_bytes)[j];
} else {
v = GGML_CPU_FP16_TO_FP32(((const ggml_fp16_t *) srow_bytes)[j]);
}
switch (op) {
case GGML_OP_POOL_AVG: res += v; break;
case GGML_OP_POOL_MAX: res = std::max(v, res); break;
case GGML_OP_POOL_COUNT: GGML_ABORT("fatal error");
}
++count;
}
switch (op) {
case GGML_OP_POOL_AVG: drow[i] /= k; break;
case GGML_OP_POOL_MAX: break;
case GGML_OP_POOL_AVG: res = (count > 0) ? (res / count) : 0.0f; break;
case GGML_OP_POOL_MAX: break;
case GGML_OP_POOL_COUNT: GGML_ABORT("fatal error");
}
}
cdata += src->nb[1];
drow += rs;
drow[ow] = res;
}
}
}
@ -7173,10 +7190,8 @@ void ggml_compute_forward_pool_1d(
const int k0 = opts[1];
const int s0 = opts[2];
const int p0 = opts[3];
GGML_ASSERT(p0 == 0); // padding not supported
GGML_ASSERT(k0 == s0); // only s = k supported
ggml_compute_forward_pool_1d_sk_p0(params, op, k0, dst);
ggml_compute_forward_pool_1d_ksp(params, op, k0, s0, p0, dst);
}
// ggml_compute_forward_pool_2d
@ -7194,6 +7209,7 @@ void ggml_compute_forward_pool_2d(
}
const int32_t * opts = (const int32_t *)dst->op_params;
ggml_op_pool op = static_cast<ggml_op_pool>(opts[0]);
const int k0 = opts[1];
const int k1 = opts[2];
@ -7217,11 +7233,13 @@ void ggml_compute_forward_pool_2d(
while (cdata < data_end) {
for (int oy = 0; oy < py; ++oy) {
float * const drow = dplane + oy * px;
float * const out = drow;
for (int ox = 0; ox < px; ++ox) {
float * const out = drow + ox;
float res = 0;
switch (op) {
case GGML_OP_POOL_AVG: *out = 0; break;
case GGML_OP_POOL_MAX: *out = -FLT_MAX; break;
case GGML_OP_POOL_AVG: res = 0; break;
case GGML_OP_POOL_MAX: res = -FLT_MAX; break;
case GGML_OP_POOL_COUNT: GGML_ABORT("fatal error");
}
@ -7229,24 +7247,32 @@ void ggml_compute_forward_pool_2d(
const int iy = offset1 + oy * s1;
for (int ky = 0; ky < k1; ++ky) {
if (iy + ky < 0 || iy + ky >= src->ne[1]) continue;
if (iy + ky < 0 || iy + ky >= src->ne[1]) {
continue;
}
const void * srow = (const void *)(cdata + src->nb[1] * (iy + ky));
for (int kx = 0; kx < k0; ++kx) {
int j = ix + kx;
if (j < 0 || j >= src->ne[0]) continue;
if (j < 0 || j >= src->ne[0]) {
continue;
}
const float srow_j = (src->type == GGML_TYPE_F32) ? ((const float*)srow)[j] : GGML_CPU_FP16_TO_FP32(((const ggml_fp16_t*)srow)[j]);
switch (op) {
case GGML_OP_POOL_AVG: *out += srow_j; break;
case GGML_OP_POOL_MAX: if (srow_j > *out) *out = srow_j; break;
case GGML_OP_POOL_AVG: res += srow_j; break;
case GGML_OP_POOL_MAX: res = std::max(srow_j, res); break;
case GGML_OP_POOL_COUNT: GGML_ABORT("fatal error");
}
}
}
switch (op) {
case GGML_OP_POOL_AVG: *out /= ka; break;
case GGML_OP_POOL_MAX: break;
case GGML_OP_POOL_AVG: res /= ka; break;
case GGML_OP_POOL_MAX: break;
case GGML_OP_POOL_COUNT: GGML_ABORT("fatal error");
}
out[ox] = res;
}
}
@ -8713,8 +8739,6 @@ static void ggml_compute_forward_ssm_conv_f32(
const int ir1 = MIN(ir0 + dr, nr);
const int ir = ir1 - ir0;
bool do_conv_debug = false; // (ith == 0 && conv_debug_count++ < 3);
for (int i3 = 0; i3 < n_s; ++i3) {
for (int i2 = 0; i2 < n_t; ++i2) {
// {d_conv - 1 + n_t, d_inner, n_seqs}
@ -8735,13 +8759,6 @@ static void ggml_compute_forward_ssm_conv_f32(
sumf += s[i0 + i1*ncs] * c[i0 + i1*nc];
}
x[i1] = sumf;
// Debug output
if (do_conv_debug && i1 == 0 && i2 == 0 && i3 == 0) {
fprintf(stderr, "DEBUG SSM_CONV: nc=%d, nr=%d, n_t=%d, n_s=%d\n", nc, nr, n_t, n_s);
fprintf(stderr, "DEBUG SSM_CONV: s[0..3]=%f,%f,%f,%f, c[0..3]=%f,%f,%f,%f, x[0]=%f\n",
s[0], s[1], s[2], s[3], c[0], c[1], c[2], c[3], x[0]);
}
}
}
}