simd_gemm: convert everything to int
This commit is contained in:
parent
8d1be6c4cd
commit
1b44835c2b
|
|
@ -29,8 +29,8 @@ static inline void simd_gemm_ukernel(
|
|||
float * GGML_RESTRICT C,
|
||||
const float * GGML_RESTRICT A,
|
||||
const float * GGML_RESTRICT B,
|
||||
int64_t K, int64_t N,
|
||||
int ii, int64_t jj)
|
||||
int K, int N,
|
||||
int ii, int jj)
|
||||
{
|
||||
static constexpr int KN = GGML_F32_EPR;
|
||||
|
||||
|
|
@ -66,7 +66,7 @@ static void simd_gemm(
|
|||
float * GGML_RESTRICT C,
|
||||
const float * GGML_RESTRICT A,
|
||||
const float * GGML_RESTRICT B,
|
||||
int64_t M, int64_t K, int64_t N)
|
||||
int M, int K, int N)
|
||||
{
|
||||
static constexpr int KN = GGML_F32_EPR;
|
||||
|
||||
|
|
@ -92,7 +92,7 @@ static void simd_gemm(
|
|||
|
||||
// Tail rows: one at a time
|
||||
for (; ii < M; ii++) {
|
||||
int64_t jj = 0;
|
||||
int jj = 0;
|
||||
for (; jj + GEMM_RN * KN <= N; jj += GEMM_RN * KN) {
|
||||
simd_gemm_ukernel<1, GEMM_RN>(C, A, B, K, N, ii, jj);
|
||||
}
|
||||
|
|
@ -115,12 +115,12 @@ static void simd_gemm(
|
|||
float * GGML_RESTRICT C,
|
||||
const float * GGML_RESTRICT A,
|
||||
const float * GGML_RESTRICT B,
|
||||
int64_t M, int64_t K, int64_t N)
|
||||
int M, int K, int N)
|
||||
{
|
||||
for (int i = 0; i < M; i++) {
|
||||
for (int64_t j = 0; j < N; j++) {
|
||||
for (int j = 0; j < N; j++) {
|
||||
float sum = C[i * N + j];
|
||||
for (int64_t kk = 0; kk < K; kk++) {
|
||||
for (int kk = 0; kk < K; kk++) {
|
||||
sum += A[i * K + kk] * B[kk * N + j];
|
||||
}
|
||||
C[i * N + j] = sum;
|
||||
|
|
|
|||
Loading…
Reference in New Issue