simd_gemm: convert everything to int

This commit is contained in:
Aman Gupta 2026-02-13 18:34:48 +05:30
parent 8d1be6c4cd
commit 1b44835c2b
1 changed files with 7 additions and 7 deletions

View File

@ -29,8 +29,8 @@ static inline void simd_gemm_ukernel(
float * GGML_RESTRICT C,
const float * GGML_RESTRICT A,
const float * GGML_RESTRICT B,
int64_t K, int64_t N,
int ii, int64_t jj)
int K, int N,
int ii, int jj)
{
static constexpr int KN = GGML_F32_EPR;
@ -66,7 +66,7 @@ static void simd_gemm(
float * GGML_RESTRICT C,
const float * GGML_RESTRICT A,
const float * GGML_RESTRICT B,
int64_t M, int64_t K, int64_t N)
int M, int K, int N)
{
static constexpr int KN = GGML_F32_EPR;
@ -92,7 +92,7 @@ static void simd_gemm(
// Tail rows: one at a time
for (; ii < M; ii++) {
int64_t jj = 0;
int jj = 0;
for (; jj + GEMM_RN * KN <= N; jj += GEMM_RN * KN) {
simd_gemm_ukernel<1, GEMM_RN>(C, A, B, K, N, ii, jj);
}
@ -115,12 +115,12 @@ static void simd_gemm(
float * GGML_RESTRICT C,
const float * GGML_RESTRICT A,
const float * GGML_RESTRICT B,
int64_t M, int64_t K, int64_t N)
int M, int K, int N)
{
for (int i = 0; i < M; i++) {
for (int64_t j = 0; j < N; j++) {
for (int j = 0; j < N; j++) {
float sum = C[i * N + j];
for (int64_t kk = 0; kk < K; kk++) {
for (int kk = 0; kk < K; kk++) {
sum += A[i * K + kk] * B[kk * N + j];
}
C[i * N + j] = sum;