diff --git a/ggml/include/ggml.h b/ggml/include/ggml.h
index 566e271479..547ccc42aa 100644
--- a/ggml/include/ggml.h
+++ b/ggml/include/ggml.h
@@ -557,6 +557,7 @@ extern "C" {
         GGML_OP_RWKV_WKV7,
         GGML_OP_SOLVE_TRI,
         GGML_OP_GATED_DELTA_NET,
+        GGML_OP_HADAMARD,
 
         GGML_OP_UNARY,
 
@@ -2473,6 +2474,11 @@ extern "C" {
             struct ggml_tensor  * beta,
             struct ggml_tensor  * state);
 
+    GGML_API struct ggml_tensor * ggml_hadamard(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a,
+            int n);
+
     // custom operators
 
     typedef void (*ggml_custom1_op_t)(struct ggml_tensor * dst , const struct ggml_tensor * a, int ith, int nth, void * userdata);
diff --git a/ggml/src/ggml-cpu/ggml-cpu.c b/ggml/src/ggml-cpu/ggml-cpu.c
index dc2b5ffaa7..bed01ae65c 100644
--- a/ggml/src/ggml-cpu/ggml-cpu.c
+++ b/ggml/src/ggml-cpu/ggml-cpu.c
@@ -2025,6 +2025,10 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
             {
                 ggml_compute_forward_gated_delta_net(params, tensor);
             } break;
+        case GGML_OP_HADAMARD:
+            {
+                ggml_compute_forward_hadamard(params, tensor);
+            } break;
         case GGML_OP_MAP_CUSTOM1:
             {
                 ggml_compute_forward_map_custom1(params, tensor);
@@ -2347,6 +2351,7 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads) {
         case GGML_OP_RWKV_WKV6:
         case GGML_OP_GATED_LINEAR_ATTN:
         case GGML_OP_RWKV_WKV7:
+        case GGML_OP_HADAMARD:
             {
                 n_tasks = n_threads;
             } break;
diff --git a/ggml/src/ggml-cpu/ops.cpp b/ggml/src/ggml-cpu/ops.cpp
index 331e071a26..111a474a6f 100644
--- a/ggml/src/ggml-cpu/ops.cpp
+++ b/ggml/src/ggml-cpu/ops.cpp
@@ -11165,3 +11165,94 @@ void ggml_compute_forward_opt_step_sgd(const ggml_compute_params * params, ggml_
         }
     }
 }
+
+// ggml_compute_forward_hadamard
+
+// Based on a source code from: https://github.com/ikawrakow/ik_llama.cpp
+// Copyright (C) 2025 Iwan Kawrakow
+// MIT license
+// SPDX-License-Identifier: MIT
+
+#if defined(_MSC_VER)
+#pragma warning(disable: 4244 4267) // possible loss of data
+#include <intrin.h>
+#include <ammintrin.h>
+#include <nmmintrin.h>
+#include <immintrin.h>
+#include <stdlib.h>
+inline int popcount(uint32_t x) { return __popcnt(x); }
+#else
+inline int popcount(uint32_t x) { return __builtin_popcount(x); }
+#endif
+
+template <typename T>
+void fast_ht(int n, T * values) {
+    constexpr float ksqrt2 = 0.707106781f;
+    float scale = 1;
+    for (int h = 1; h < n; h <<= 1) {
+        for (int i = 0; i < n; i += 2*h) {
+            for (int j = i; j < i + h; ++j) {
+                T x = values[j], y = values[j + h];
+                values[j+0] = x + y;
+                values[j+h] = x - y;
+            }
+        }
+        scale *= ksqrt2;
+    }
+    for (int i = 0; i < n; ++i) values[i] *= scale;
+}
+
+static void ggml_compute_forward_hadamard_f32(
+        const ggml_compute_params * params,
+        ggml_tensor * dst) {
+
+    const ggml_tensor * src0 = dst->src[0];
+
+    GGML_ASSERT(ggml_is_contiguous(src0));
+    GGML_ASSERT(ggml_is_contiguous(dst));
+    GGML_ASSERT(ggml_are_same_shape(src0, dst));
+
+    const int ith = params->ith;
+    const int nth = params->nth;
+
+    int nh = dst->op_params[0];
+    GGML_ASSERT(nh > 1 && popcount(uint32_t(nh)) == 1);
+    GGML_ASSERT(dst->ne[0] % nh == 0);
+
+    int nc = dst->ne[0]/nh;
+    int nr = ggml_nrows(dst) * nc;
+
+    int npt = (nr + nth - 1)/nth;
+    int first = npt*ith;
+    int last = std::min(first + npt, nr);
+
+    for (int ir = first; ir < last; ++ir) {
+        int i3 = ir / (dst->ne[1] * dst->ne[2] * nc);
+        int i2 = (ir - i3*dst->ne[1] * dst->ne[2] * nc)/(dst->ne[1] * nc);
+        int i1 = (ir - i3*dst->ne[1] * dst->ne[2] * nc - i2*dst->ne[1]*nc)/nc;
+        int ic = (ir - i3*dst->ne[1] * dst->ne[2] * nc - i2*dst->ne[1]*nc - i1*nc);
+
+        auto x = (const float *)((const char *)src0->data + i3*src0->nb[3] + i2*src0->nb[2] + i1*src0->nb[1]) + ic*nh;
+        auto y = (      float *)((      char *)dst->data  + i3*dst->nb[3]  + i2*dst->nb[2]  + i1*dst->nb[1])  + ic*nh;
+        memcpy(y, x, nh*sizeof(float));
+        fast_ht(nh, y);
+    }
+}
+
+void ggml_compute_forward_hadamard(
+        const ggml_compute_params * params,
+        ggml_tensor * dst) {
+
+    const ggml_tensor * src0 = dst->src[0];
+
+    switch (src0->type) {
+        case GGML_TYPE_F32:
+            {
+                ggml_compute_forward_hadamard_f32(params, dst);
+            } break;
+        default:
+            {
+                GGML_ABORT("fatal error");
+            }
+    }
+}
diff --git a/ggml/src/ggml-cpu/ops.h b/ggml/src/ggml-cpu/ops.h
index 3fa1443abc..c28d32ea91 100644
--- a/ggml/src/ggml-cpu/ops.h
+++ b/ggml/src/ggml-cpu/ops.h
@@ -103,6 +103,7 @@ void ggml_compute_forward_rwkv_wkv7(const struct ggml_compute_params * params, s
 void ggml_compute_forward_solve_tri(const struct ggml_compute_params * params, struct ggml_tensor * dst);
 void ggml_compute_forward_gla(const struct ggml_compute_params * params, struct ggml_tensor * dst);
 void ggml_compute_forward_gated_delta_net(const struct ggml_compute_params * params, struct ggml_tensor * dst);
+void ggml_compute_forward_hadamard(const struct ggml_compute_params * params, struct ggml_tensor * dst);
 void ggml_compute_forward_map_custom1(const struct ggml_compute_params * params, struct ggml_tensor * dst);
 void ggml_compute_forward_map_custom2(const struct ggml_compute_params * params, struct ggml_tensor * dst);
 void ggml_compute_forward_map_custom3(const struct ggml_compute_params * params, struct ggml_tensor * dst);
diff --git a/ggml/src/ggml.c b/ggml/src/ggml.c
index aeafc395d7..a01ee49ee3 100644
--- a/ggml/src/ggml.c
+++ b/ggml/src/ggml.c
@@ -1032,6 +1032,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
     "RWKV_WKV7",
     "SOLVE_TRI",
     "GATED_DELTA_NET",
+    "HADAMARD",
 
     "UNARY",
 
@@ -1049,7 +1050,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
     "GLU",
 };
 
-static_assert(GGML_OP_COUNT == 96, "GGML_OP_COUNT != 96");
+static_assert(GGML_OP_COUNT == 97, "GGML_OP_COUNT != 97");
 
 static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
     "none",
@@ -1142,6 +1143,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
     "rwkv_wkv7(r, w, k, v, a, b, s)",
     "A X = B, A triangular, solve X",
     "gated_delta_net(q, k, v, g, beta, s)",
+    "hadamard(x)",
 
     "unary(x)",
 
@@ -1159,7 +1161,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
     "glu(x)",
 };
 
-static_assert(GGML_OP_COUNT == 96, "GGML_OP_COUNT != 96");
+static_assert(GGML_OP_COUNT == 97, "GGML_OP_COUNT != 97");
 
 static_assert(GGML_OP_POOL_COUNT == 2, "GGML_OP_POOL_COUNT != 2");
 
@@ -6177,6 +6179,28 @@ struct ggml_tensor * ggml_gated_delta_net(
     return result;
 }
+
+// ggml_hadamard
+
+struct ggml_tensor * ggml_hadamard(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * a,
+        int n) {
+
+    GGML_ASSERT(a->type == GGML_TYPE_F32); // will not bother implementing for other data types
+    GGML_ASSERT(n > 1); // no point in Hadamard transforms with less than 2 elements
+    GGML_ASSERT(a->ne[0] % n == 0);
+    GGML_ASSERT(n > 0 && ((n & (n - 1)) == 0)); // must be a power of 2
+
+    struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, GGML_MAX_DIMS, a->ne);
+
+    result->op     = GGML_OP_HADAMARD;
+    result->src[0] = a;
+
+    result->op_params[0] = n;
+
+    return result;
+}
 
 ////////////////////////////////////////////////////////////////////////////////
 
 struct ggml_hash_set ggml_hash_set_new(size_t size) {