diff --git a/docs/ops.md b/docs/ops.md index bd26c0eb45..5df72d2501 100644 --- a/docs/ops.md +++ b/docs/ops.md @@ -22,6 +22,7 @@ Legend: | ARANGE | ❌ | ✅ | ✅ | ✅ | ✅ | ❌ | ❌ | ❌ | ❌ | | ARGMAX | ❌ | ✅ | ✅ | ✅ | ✅ | ❌ | ✅ | ✅ | ❌ | | ARGSORT | ❌ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ | +| CEIL | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | | CLAMP | ❌ | ✅ | ✅ | ✅ | 🟡 | 🟡 | ✅ | 🟡 | ❌ | | CONCAT | ❌ | ✅ | ✅ | 🟡 | ✅ | 🟡 | 🟡 | ✅ | ❌ | | CONT | ❌ | 🟡 | ✅ | ✅ | ✅ | 🟡 | 🟡 | 🟡 | ❌ | @@ -41,6 +42,7 @@ Legend: | ELU | ❌ | ✅ | ✅ | 🟡 | 🟡 | ❌ | 🟡 | ❌ | ❌ | | EXP | ❌ | ✅ | ✅ | 🟡 | 🟡 | ❌ | 🟡 | ❌ | ❌ | | FLASH_ATTN_EXT | ❌ | 🟡 | ✅ | 🟡 | 🟡 | ❌ | ❌ | 🟡 | ❌ | +| FLOOR | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | | GATED_LINEAR_ATTN | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ | ✅ | ❌ | ❌ | | GEGLU | ❌ | ✅ | ✅ | ✅ | 🟡 | ✅ | ✅ | 🟡 | ❌ | | GEGLU_ERF | ❌ | ✅ | ✅ | ✅ | 🟡 | ✅ | ✅ | 🟡 | ❌ | @@ -82,6 +84,7 @@ Legend: | ROLL | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ | ❌ | ✅ | ❌ | | ROPE | ❌ | 🟡 | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ | | ROPE_BACK | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ | ❌ | ✅ | ❌ | +| ROUND | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | | RWKV_WKV6 | ❌ | ❌ | ✅ | ✅ | ✅ | ❌ | ✅ | ✅ | ❌ | | RWKV_WKV7 | ❌ | ❌ | ✅ | ✅ | ✅ | ❌ | ✅ | ✅ | ❌ | | SCALE | ❌ | 🟡 | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ | @@ -108,5 +111,6 @@ Legend: | TANH | ❌ | ✅ | ✅ | 🟡 | 🟡 | ✅ | 🟡 | 🟡 | ❌ | | TIMESTEP_EMBEDDING | ❌ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ | | TOPK_MOE | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ✅ | ❌ | ❌ | +| TRUNC | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | | UPSCALE | ❌ | 🟡 | ✅ | ✅ | 🟡 | ✅ | 🟡 | ✅ | ❌ | | XIELU | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | diff --git a/docs/ops/CPU.csv b/docs/ops/CPU.csv index 21e0d1b3c9..1820028c9a 100644 --- a/docs/ops/CPU.csv +++ b/docs/ops/CPU.csv @@ -59,6 +59,14 @@ "CPU","EXP","type=f16,ne_a=[5,7,11,13],v=1","support","1","yes","CPU" "CPU","GELU_ERF","type=f16,ne_a=[128,2,2,2],v=1","support","1","yes","CPU" "CPU","GELU_ERF","type=f16,ne_a=[5,7,11,13],v=1","support","1","yes","CPU" +"CPU","FLOOR","type=f16,ne_a=[128,2,2,2],v=0","support","1","yes","CPU" +"CPU","FLOOR","type=f16,ne_a=[5,7,11,13],v=0","support","1","yes","CPU" +"CPU","CEIL","type=f16,ne_a=[128,2,2,2],v=0","support","1","yes","CPU" +"CPU","CEIL","type=f16,ne_a=[5,7,11,13],v=0","support","1","yes","CPU" +"CPU","ROUND","type=f16,ne_a=[128,2,2,2],v=0","support","1","yes","CPU" +"CPU","ROUND","type=f16,ne_a=[5,7,11,13],v=0","support","1","yes","CPU" +"CPU","TRUNC","type=f16,ne_a=[128,2,2,2],v=0","support","1","yes","CPU" +"CPU","TRUNC","type=f16,ne_a=[5,7,11,13],v=0","support","1","yes","CPU" "CPU","ABS","type=f32,ne_a=[128,2,2,2],v=0","support","1","yes","CPU" "CPU","ABS","type=f32,ne_a=[5,7,11,13],v=0","support","1","yes","CPU" "CPU","SGN","type=f32,ne_a=[128,2,2,2],v=0","support","1","yes","CPU" @@ -119,6 +127,14 @@ "CPU","EXP","type=f32,ne_a=[5,7,11,13],v=1","support","1","yes","CPU" "CPU","GELU_ERF","type=f32,ne_a=[128,2,2,2],v=1","support","1","yes","CPU" "CPU","GELU_ERF","type=f32,ne_a=[5,7,11,13],v=1","support","1","yes","CPU" +"CPU","FLOOR","type=f32,ne_a=[128,2,2,2],v=0","support","1","yes","CPU" +"CPU","FLOOR","type=f32,ne_a=[5,7,11,13],v=0","support","1","yes","CPU" +"CPU","CEIL","type=f32,ne_a=[128,2,2,2],v=0","support","1","yes","CPU" +"CPU","CEIL","type=f32,ne_a=[5,7,11,13],v=0","support","1","yes","CPU" +"CPU","ROUND","type=f32,ne_a=[128,2,2,2],v=0","support","1","yes","CPU" +"CPU","ROUND","type=f32,ne_a=[5,7,11,13],v=0","support","1","yes","CPU" +"CPU","TRUNC","type=f32,ne_a=[128,2,2,2],v=0","support","1","yes","CPU" +"CPU","TRUNC","type=f32,ne_a=[5,7,11,13],v=0","support","1","yes","CPU" "CPU","REGLU","type=f16,ne_a=[128,2,2,2],v=0,swapped=0","support","1","yes","CPU" "CPU","REGLU","type=f16,ne_a=[5,7,11,13],v=0,swapped=0","support","1","yes","CPU" "CPU","REGLU","type=f16,ne_a=[128,2,2,2],v=0,swapped=1","support","1","yes","CPU" diff --git a/ggml/include/ggml.h b/ggml/include/ggml.h index 60c6b63d05..d948b00cc7 100644 --- a/ggml/include/ggml.h +++ b/ggml/include/ggml.h @@ -577,6 +577,10 @@ extern "C" { GGML_UNARY_OP_EXP, GGML_UNARY_OP_GELU_ERF, GGML_UNARY_OP_XIELU, + GGML_UNARY_OP_FLOOR, + GGML_UNARY_OP_CEIL, + GGML_UNARY_OP_ROUND, + GGML_UNARY_OP_TRUNC, GGML_UNARY_OP_COUNT, }; @@ -1151,6 +1155,46 @@ extern "C" { struct ggml_context * ctx, struct ggml_tensor * a); + GGML_API struct ggml_tensor * ggml_floor( + struct ggml_context * ctx, + struct ggml_tensor * a); + + GGML_API struct ggml_tensor * ggml_floor_inplace( + struct ggml_context * ctx, + struct ggml_tensor * a); + + GGML_API struct ggml_tensor * ggml_ceil( + struct ggml_context * ctx, + struct ggml_tensor * a); + + GGML_API struct ggml_tensor * ggml_ceil_inplace( + struct ggml_context * ctx, + struct ggml_tensor * a); + + GGML_API struct ggml_tensor * ggml_round( + struct ggml_context * ctx, + struct ggml_tensor * a); + + GGML_API struct ggml_tensor * ggml_round_inplace( + struct ggml_context * ctx, + struct ggml_tensor * a); + + /** + * Truncates the fractional part of each element in the tensor (towards zero). + * For example: trunc(3.7) = 3.0, trunc(-2.9) = -2.0 + * Similar to std::trunc in C/C++. + */ + + GGML_API struct ggml_tensor * ggml_trunc( + struct ggml_context * ctx, + struct ggml_tensor * a); + + GGML_API struct ggml_tensor * ggml_trunc_inplace( + struct ggml_context * ctx, + struct ggml_tensor * a); + + + // xIELU activation function // x = x * (c_a(alpha_n) + c_b(alpha_p, beta) * sigmoid(beta * x)) + eps * (x > 0) // where c_a = softplus and c_b(a, b) = softplus(a) + b are constraining functions diff --git a/ggml/src/ggml-cpu/ggml-cpu.c b/ggml/src/ggml-cpu/ggml-cpu.c index ba2a36d999..29c870600b 100644 --- a/ggml/src/ggml-cpu/ggml-cpu.c +++ b/ggml/src/ggml-cpu/ggml-cpu.c @@ -2184,6 +2184,10 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads) { case GGML_UNARY_OP_HARDSWISH: case GGML_UNARY_OP_HARDSIGMOID: case GGML_UNARY_OP_EXP: + case GGML_UNARY_OP_FLOOR: + case GGML_UNARY_OP_CEIL: + case GGML_UNARY_OP_ROUND: + case GGML_UNARY_OP_TRUNC: { n_tasks = 1; } break; diff --git a/ggml/src/ggml-cpu/ops.cpp b/ggml/src/ggml-cpu/ops.cpp index 1c43865ff6..b52f0f8472 100644 --- a/ggml/src/ggml-cpu/ops.cpp +++ b/ggml/src/ggml-cpu/ops.cpp @@ -8993,6 +8993,22 @@ void ggml_compute_forward_unary( { ggml_compute_forward_exp(params, dst); } break; + case GGML_UNARY_OP_FLOOR: + { + ggml_compute_forward_floor(params, dst); + } break; + case GGML_UNARY_OP_CEIL: + { + ggml_compute_forward_ceil(params, dst); + } break; + case GGML_UNARY_OP_ROUND: + { + ggml_compute_forward_round(params, dst); + } break; + case GGML_UNARY_OP_TRUNC: + { + ggml_compute_forward_trunc(params, dst); + } break; case GGML_UNARY_OP_XIELU: { ggml_compute_forward_xielu(params, dst); diff --git a/ggml/src/ggml-cpu/unary-ops.cpp b/ggml/src/ggml-cpu/unary-ops.cpp index cf1a4615d0..a047537b34 100644 --- a/ggml/src/ggml-cpu/unary-ops.cpp +++ b/ggml/src/ggml-cpu/unary-ops.cpp @@ -73,6 +73,22 @@ static inline float op_log(float x) { return logf(x); } +static inline float op_floor(float x) { + return floorf(x); +} + +static inline float op_ceil(float x) { + return ceilf(x); +} + +static inline float op_round(float x) { + return roundf(x); +} + +static inline float op_trunc(float x) { + return truncf(x); +} + template static inline void vec_unary_op(int64_t n, dst_t * y, const src0_t * x) { constexpr auto src0_to_f32 = type_conversion_table::to_f32; @@ -274,6 +290,22 @@ void ggml_compute_forward_log(const ggml_compute_params * params, ggml_tensor * unary_op(params, dst); } +void ggml_compute_forward_floor(const ggml_compute_params * params, ggml_tensor * dst) { + unary_op(params, dst); +} + +void ggml_compute_forward_ceil(const ggml_compute_params * params, ggml_tensor * dst) { + unary_op(params, dst); +} + +void ggml_compute_forward_round(const ggml_compute_params * params, ggml_tensor * dst) { + unary_op(params, dst); +} + +void ggml_compute_forward_trunc(const ggml_compute_params * params, ggml_tensor * dst) { + unary_op(params, dst); +} + void ggml_compute_forward_xielu(const ggml_compute_params * params, ggml_tensor * dst) { const float alpha_n = ggml_get_op_params_f32(dst, 1); const float alpha_p = ggml_get_op_params_f32(dst, 2); diff --git a/ggml/src/ggml-cpu/unary-ops.h b/ggml/src/ggml-cpu/unary-ops.h index 697c1e0da0..fa45d9f0e6 100644 --- a/ggml/src/ggml-cpu/unary-ops.h +++ b/ggml/src/ggml-cpu/unary-ops.h @@ -22,6 +22,10 @@ void ggml_compute_forward_sqrt(const struct ggml_compute_params * params, struct void ggml_compute_forward_sin(const struct ggml_compute_params * params, struct ggml_tensor * dst); void ggml_compute_forward_cos(const struct ggml_compute_params * params, struct ggml_tensor * dst); void ggml_compute_forward_log(const struct ggml_compute_params * params, struct ggml_tensor * dst); +void ggml_compute_forward_floor(const struct ggml_compute_params * params, struct ggml_tensor * dst); +void ggml_compute_forward_ceil(const struct ggml_compute_params * params, struct ggml_tensor * dst); +void ggml_compute_forward_round(const struct ggml_compute_params * params, struct ggml_tensor * dst); +void ggml_compute_forward_trunc(const struct ggml_compute_params * params, struct ggml_tensor * dst); void ggml_compute_forward_xielu(const struct ggml_compute_params * params, struct ggml_tensor * dst); #ifdef __cplusplus diff --git a/ggml/src/ggml-cuda/ggml-cuda.cu b/ggml/src/ggml-cuda/ggml-cuda.cu index 5efd3e5f4e..9f6a0500aa 100644 --- a/ggml/src/ggml-cuda/ggml-cuda.cu +++ b/ggml/src/ggml-cuda/ggml-cuda.cu @@ -273,6 +273,15 @@ static ggml_cuda_device_info ggml_cuda_init() { } else if (device_name.substr(0, 21) == "NVIDIA GeForce GTX 16") { turing_devices_without_mma.push_back({ id, device_name }); } + + // Temporary performance fix: + // Setting device scheduling strategy for iGPUs with cc121 to "spinning" to avoid delays in cuda synchronize calls. + // TODO: Check for future drivers the default scheduling strategy and + // remove this call again when cudaDeviceScheduleSpin is default. + if (prop.major == 12 && prop.minor == 1) { + CUDA_CHECK(cudaSetDeviceFlags(cudaDeviceScheduleSpin)); + } + #endif // defined(GGML_USE_HIP) } @@ -3644,9 +3653,10 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g case GGML_OP_CONV_2D_DW: case GGML_OP_CONV_TRANSPOSE_2D: case GGML_OP_POOL_2D: - case GGML_OP_SUM: case GGML_OP_ACC: return true; + case GGML_OP_SUM: + return ggml_is_contiguous_rows(op->src[0]); case GGML_OP_ARGSORT: // TODO: Support arbitrary column width return op->src[0]->ne[0] <= 1024; diff --git a/ggml/src/ggml-metal/ggml-metal-device.m b/ggml/src/ggml-metal/ggml-metal-device.m index c3fe8f4e91..c3c83abe4e 100644 --- a/ggml/src/ggml-metal/ggml-metal-device.m +++ b/ggml/src/ggml-metal/ggml-metal-device.m @@ -7,6 +7,8 @@ #include +#include + #ifndef TARGET_OS_VISION #define TARGET_OS_VISION 0 #endif @@ -22,6 +24,9 @@ // overload of MTLGPUFamilyMetal3 (not available in some environments) static const NSInteger MTLGPUFamilyMetal3_GGML = 5001; +// virtual address for GPU memory allocations +static atomic_uintptr_t g_addr_device = 0x000000400ULL; + #if !GGML_METAL_EMBED_LIBRARY // Here to assist with NSBundle Path Hack @interface GGMLMetalClass : NSObject @@ -657,6 +662,7 @@ bool ggml_metal_device_supports_op(ggml_metal_device_t dev, const struct ggml_te case GGML_OP_LOG: return ggml_is_contiguous(op->src[0]) && op->src[0]->type == GGML_TYPE_F32; case GGML_OP_SUM: + return has_simdgroup_reduction && ggml_is_contiguous(op->src[0]); case GGML_OP_SUM_ROWS: case GGML_OP_MEAN: case GGML_OP_SOFT_MAX: @@ -827,7 +833,7 @@ struct ggml_metal_buffer_wrapper { }; struct ggml_metal_buffer { - void * all_data; // TODO: https://github.com/ggml-org/llama.cpp/pull/15985 + void * all_data; size_t all_size; // if false, the Metal buffer data is allocated in private GPU memory and is not shared with the host @@ -965,14 +971,15 @@ ggml_metal_buffer_t ggml_metal_buffer_init(ggml_metal_device_t dev, size_t size, if (shared) { res->all_data = ggml_metal_host_malloc(size_aligned); res->is_shared = true; - res->owned = true; } else { - // dummy, non-NULL value - we'll populate this after creating the Metal buffer below - res->all_data = (void *) 0x000000400ULL; + // use virtual address from g_addr_device counter + res->all_data = (void *) atomic_fetch_add_explicit(&g_addr_device, size_aligned, memory_order_relaxed); res->is_shared = false; } res->all_size = size_aligned; + res->owned = true; + res->device = ggml_metal_device_get_obj(dev); res->queue = ggml_metal_device_get_queue(dev); @@ -983,15 +990,13 @@ ggml_metal_buffer_t ggml_metal_buffer_init(ggml_metal_device_t dev, size_t size, res->buffers[0].metal = nil; if (size_aligned > 0) { - if (props_dev->use_shared_buffers &&shared) { + if (props_dev->use_shared_buffers && shared) { res->buffers[0].metal = [res->device newBufferWithBytesNoCopy:res->all_data length:size_aligned options:MTLResourceStorageModeShared deallocator:nil]; } else { res->buffers[0].metal = [res->device newBufferWithLength:size_aligned options:MTLResourceStorageModePrivate]; - - res->all_data = (void *) (res->buffers[0].metal.gpuAddress); } } @@ -1139,7 +1144,7 @@ bool ggml_metal_buffer_is_shared(ggml_metal_buffer_t buf) { void ggml_metal_buffer_memset_tensor(ggml_metal_buffer_t buf, struct ggml_tensor * tensor, uint8_t value, size_t offset, size_t size) { if (buf->is_shared) { - memset((char *)tensor->data + offset, value, size); + memset((char *) tensor->data + offset, value, size); return; } @@ -1168,7 +1173,7 @@ void ggml_metal_buffer_memset_tensor(ggml_metal_buffer_t buf, struct ggml_tensor void ggml_metal_buffer_set_tensor(ggml_metal_buffer_t buf, struct ggml_tensor * tensor, const void * data, size_t offset, size_t size) { if (buf->is_shared) { - memcpy((char *)tensor->data + offset, data, size); + memcpy((char *) tensor->data + offset, data, size); return; } @@ -1223,7 +1228,7 @@ void ggml_metal_buffer_set_tensor(ggml_metal_buffer_t buf, struct ggml_tensor * void ggml_metal_buffer_get_tensor(ggml_metal_buffer_t buf, const struct ggml_tensor * tensor, void * data, size_t offset, size_t size) { if (buf->is_shared) { - memcpy(data, (const char *)tensor->data + offset, size); + memcpy(data, (const char *) tensor->data + offset, size); return; } diff --git a/ggml/src/ggml-metal/ggml-metal-impl.h b/ggml/src/ggml-metal/ggml-metal-impl.h index a448c14f66..fa2d82cefb 100644 --- a/ggml/src/ggml-metal/ggml-metal-impl.h +++ b/ggml/src/ggml-metal/ggml-metal-impl.h @@ -251,6 +251,7 @@ typedef struct { int32_t sect_1; int32_t sect_2; int32_t sect_3; + bool src2; } ggml_metal_kargs_rope; typedef struct { diff --git a/ggml/src/ggml-metal/ggml-metal-ops.cpp b/ggml/src/ggml-metal/ggml-metal-ops.cpp index a61ea8fb5a..4f9f6bda00 100644 --- a/ggml/src/ggml-metal/ggml-metal-ops.cpp +++ b/ggml/src/ggml-metal/ggml-metal-ops.cpp @@ -866,12 +866,25 @@ int ggml_metal_op_sum(ggml_metal_op_t ctx, int idx) { ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_sum(lib, op); + int nth = 32; // SIMD width + + while (nth < (int) n && nth < ggml_metal_pipeline_max_theads_per_threadgroup(pipeline)) { + nth *= 2; + } + + nth = std::min(nth, ggml_metal_pipeline_max_theads_per_threadgroup(pipeline)); + nth = std::min(nth, (int) n); + + const int nsg = (nth + 31) / 32; + ggml_metal_encoder_set_pipeline(enc, pipeline); ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0); ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[0]), 1); ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op), 2); - ggml_metal_encoder_dispatch_threadgroups(enc, 1, 1, 1, 1, 1, 1); + ggml_metal_encoder_set_threadgroup_memory_size(enc, nsg * sizeof(float), 0); + + ggml_metal_encoder_dispatch_threadgroups(enc, 1, 1, 1, nth, 1, 1); return 1; } @@ -2969,6 +2982,7 @@ int ggml_metal_op_rope(ggml_metal_op_t ctx, int idx) { /* sect_1 =*/ sect_1, /* sect_2 =*/ sect_2, /* sect_3 =*/ sect_3, + /* src2 =*/ op->src[2] != nullptr, }; ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_rope(lib, op); diff --git a/ggml/src/ggml-metal/ggml-metal.metal b/ggml/src/ggml-metal/ggml-metal.metal index 1029cf8f9a..496610b154 100644 --- a/ggml/src/ggml-metal/ggml-metal.metal +++ b/ggml/src/ggml-metal/ggml-metal.metal @@ -1727,18 +1727,48 @@ kernel void kernel_op_sum_f32( constant ggml_metal_kargs_sum & args, device const float * src0, device float * dst, - ushort tiitg[[thread_index_in_threadgroup]]) { + threadgroup float * shmem_f32 [[threadgroup(0)]], + uint3 tgpig[[threadgroup_position_in_grid]], + ushort3 tpitg[[thread_position_in_threadgroup]], + ushort sgitg[[simdgroup_index_in_threadgroup]], + ushort tiisg[[thread_index_in_simdgroup]], + ushort3 ntg[[threads_per_threadgroup]]) { - if (tiitg != 0) { + if (args.np == 0) { return; } - float acc = 0.0f; - for (ulong i = 0; i < args.np; ++i) { - acc += src0[i]; + const uint nsg = (ntg.x + 31) / 32; + + float sumf = 0; + + for (int64_t i0 = tpitg.x; i0 < args.np; i0 += ntg.x) { + sumf += src0[i0]; } - dst[0] = acc; + sumf = simd_sum(sumf); + + if (tiisg == 0) { + shmem_f32[sgitg] = sumf; + } + + threadgroup_barrier(mem_flags::mem_threadgroup); + + float total = 0; + + if (sgitg == 0) { + float v = 0; + + if (tpitg.x < nsg) { + v = shmem_f32[tpitg.x]; + } + + total = simd_sum(v); + + if (tpitg.x == 0) { + dst[0] = total; + } + } } template @@ -3748,7 +3778,7 @@ kernel void kernel_rope_norm( const float theta = theta_base * pow(args.freq_base, inv_ndims*i0); - const float freq_factor = src2 != src0 ? ((device const float *) src2)[ic] : 1.0f; + const float freq_factor = args.src2 ? ((device const float *) src2)[ic] : 1.0f; rope_yarn(theta/freq_factor, args.freq_scale, corr_dims, i0, args.ext_factor, args.attn_factor, &cos_theta, &sin_theta); @@ -3801,7 +3831,7 @@ kernel void kernel_rope_neox( const float theta = theta_base * pow(args.freq_base, inv_ndims*i0); - const float freq_factor = src2 != src0 ? ((device const float *) src2)[ic] : 1.0f; + const float freq_factor = args.src2 ? ((device const float *) src2)[ic] : 1.0f; rope_yarn(theta/freq_factor, args.freq_scale, corr_dims, i0, args.ext_factor, args.attn_factor, &cos_theta, &sin_theta); @@ -3872,7 +3902,7 @@ kernel void kernel_rope_multi( const float theta = theta_base * pow(args.freq_base, inv_ndims*i0); - const float freq_factor = src2 != src0 ? ((device const float *) src2)[ic] : 1.0f; + const float freq_factor = args.src2 ? ((device const float *) src2)[ic] : 1.0f; rope_yarn(theta/freq_factor, args.freq_scale, corr_dims, i0, args.ext_factor, args.attn_factor, &cos_theta, &sin_theta); @@ -3939,7 +3969,7 @@ kernel void kernel_rope_vision( const float theta = theta_base * pow(args.freq_base, 2.0f * inv_ndims * p); // end of mrope - const float freq_factor = src2 != src0 ? ((device const float *) src2)[ic] : 1.0f; + const float freq_factor = args.src2 ? ((device const float *) src2)[ic] : 1.0f; rope_yarn(theta/freq_factor, args.freq_scale, corr_dims, i0, args.ext_factor, args.attn_factor, &cos_theta, &sin_theta); diff --git a/ggml/src/ggml-opencl/CMakeLists.txt b/ggml/src/ggml-opencl/CMakeLists.txt index 7e6c843846..6f6bba55e2 100644 --- a/ggml/src/ggml-opencl/CMakeLists.txt +++ b/ggml/src/ggml-opencl/CMakeLists.txt @@ -93,6 +93,7 @@ set(GGML_OPENCL_KERNELS mul_mv_id_mxfp4_f32_flat mul_mm_f32_f32_l4_lm mul_mm_f16_f32_l4_lm + mul_mm_q8_0_f32_l4_lm mul norm relu diff --git a/ggml/src/ggml-opencl/ggml-opencl.cpp b/ggml/src/ggml-opencl/ggml-opencl.cpp index 0693d38d80..2ec896fd0e 100644 --- a/ggml/src/ggml-opencl/ggml-opencl.cpp +++ b/ggml/src/ggml-opencl/ggml-opencl.cpp @@ -408,6 +408,7 @@ struct ggml_backend_opencl_context { cl_program program_mul_mv_id_mxfp4_f32_flat; cl_program program_mul_mm_f32_f32_l4_lm; cl_program program_mul_mm_f16_f32_l4_lm; + cl_program program_mul_mm_q8_0_f32_l4_lm; cl_kernel kernel_add, kernel_add_row, kernel_add_f16, kernel_add_row_f16; cl_kernel kernel_mul, kernel_mul_row, kernel_mul_f16, kernel_mul_row_f16; @@ -480,6 +481,7 @@ struct ggml_backend_opencl_context { cl_kernel kernel_mul_mv_id_mxfp4_f32_flat; cl_kernel kernel_mul_mm_f32_f32_l4_lm; cl_kernel kernel_mul_mm_f16_f32_l4_lm; + cl_kernel kernel_mul_mm_q8_0_f32_l4_lm; std::vector profiling_info; @@ -1191,6 +1193,22 @@ static void load_cl_kernels(ggml_backend_opencl_context *backend_ctx, ggml_cl_ve GGML_LOG_CONT("."); } + // mul_mm_q8_0_f32_l4_lm + { +#ifdef GGML_OPENCL_EMBED_KERNELS + const std::string kernel_src { + #include "mul_mm_q8_0_f32_l4_lm.cl.h" + }; +#else + const std::string kernel_src = read_file("mul_mm_q8_0_f32_l4_lm.cl"); +#endif + backend_ctx->program_mul_mm_q8_0_f32_l4_lm = + build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts); + + CL_CHECK((backend_ctx->kernel_mul_mm_q8_0_f32_l4_lm = clCreateKernel(backend_ctx->program_mul_mm_q8_0_f32_l4_lm, "kernel_mul_mm_q8_0_f32_l4_lm", &err), err)); + GGML_LOG_CONT("."); + } + // mul { #ifdef GGML_OPENCL_EMBED_KERNELS @@ -6961,6 +6979,44 @@ static void ggml_cl_mul_mat(ggml_backend_t backend, const ggml_tensor * src0, co backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size, dst); return; } + case GGML_TYPE_Q8_0: { + if (ne11 < 32) { + break; + } + kernel = backend_ctx->kernel_mul_mm_q8_0_f32_l4_lm; + nth0 = 128; // calculated as (BM*BN)/(TM*TN) + + int batch_stride_a = ne00*ne01; + int batch_stride_b = ne10*ne11; + int batch_stride_d = ne0*ne1; + + CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra0_q8_0->q)); + CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &extra0_q8_0->d)); + CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extra1->data_device)); + CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_ulong), &offset1)); + CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_mem), &extrad->data_device)); + CL_CHECK(clSetKernelArg(kernel, 5, sizeof(cl_ulong), &offsetd)); + CL_CHECK(clSetKernelArg(kernel, 6, sizeof(int), &ne00)); + CL_CHECK(clSetKernelArg(kernel, 7, sizeof(int), &ne01)); + CL_CHECK(clSetKernelArg(kernel, 8, sizeof(int), &ne02)); + CL_CHECK(clSetKernelArg(kernel, 9, sizeof(int), &ne11)); + CL_CHECK(clSetKernelArg(kernel, 10, sizeof(int), &ne12)); + CL_CHECK(clSetKernelArg(kernel, 11, sizeof(int), &ne10)); // stride_a + CL_CHECK(clSetKernelArg(kernel, 12, sizeof(int), &ne10)); // stride_b + CL_CHECK(clSetKernelArg(kernel, 13, sizeof(int), &ne01)); // stride_d + CL_CHECK(clSetKernelArg(kernel, 14, sizeof(int), &batch_stride_a)); + CL_CHECK(clSetKernelArg(kernel, 15, sizeof(int), &batch_stride_b)); + CL_CHECK(clSetKernelArg(kernel, 16, sizeof(int), &batch_stride_d)); + CL_CHECK(clSetKernelArg(kernel, 17, sizeof(int), &r2)); + CL_CHECK(clSetKernelArg(kernel, 18, sizeof(int), &r3)); + + // 64 is block tile size BM and BN - change here when BM and BN in the kernel are changed. + size_t global_work_size[] = {(size_t)(CEIL_DIV(ne01, 64)*nth0), (size_t)(CEIL_DIV(ne11, 64)), (size_t)ne12*ne13}; + size_t local_work_size[] = {(size_t)nth0, 1, 1}; + + backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size, dst); + return; + } default: break; } diff --git a/ggml/src/ggml-opencl/kernels/flash_attn_f32.cl b/ggml/src/ggml-opencl/kernels/flash_attn_f32.cl index 9c0bab135a..a6d7479037 100644 --- a/ggml/src/ggml-opencl/kernels/flash_attn_f32.cl +++ b/ggml/src/ggml-opencl/kernels/flash_attn_f32.cl @@ -4,6 +4,7 @@ #define ACC_TYPE4 float4 #define DATA_TYPE float #define DATA_TYPE4 float4 +#define MASK_DATA_TYPE half #define CONVERT_ACC4(x) (x) #define CONVERT_DATA4(x) (x) @@ -148,7 +149,7 @@ __kernel void flash_attn_f32( if (k_row1 >= n_kv) score1 = -INFINITY; if (mask_base != NULL) { - const global DATA_TYPE* mask_ptr = (const global DATA_TYPE*)(mask_base + my_query_row * mask_nb1); + const global MASK_DATA_TYPE* mask_ptr = (const global MASK_DATA_TYPE*)(mask_base + my_query_row * mask_nb1); if (k_row0 < n_kv) score0 += slope * (ACC_TYPE)mask_ptr[k_row0]; if (k_row1 < n_kv) score1 += slope * (ACC_TYPE)mask_ptr[k_row1]; } @@ -281,7 +282,7 @@ __kernel void flash_attn_f32_q1( } ACC_TYPE score = (dot_acc.s0 + dot_acc.s1 + dot_acc.s2 + dot_acc.s3) * scale; if (mask_base != NULL) { - const global DATA_TYPE* mask_ptr = (const global DATA_TYPE*)(mask_base); + const global MASK_DATA_TYPE* mask_ptr = (const global MASK_DATA_TYPE*)(mask_base); score += slope * (ACC_TYPE)mask_ptr[k_idx]; } if (logit_softcap > 0.0f) { @@ -317,7 +318,7 @@ __kernel void flash_attn_f32_q1( } ACC_TYPE score = (dot_acc.s0 + dot_acc.s1 + dot_acc.s2 + dot_acc.s3) * scale; if (mask_base != NULL) { - const global DATA_TYPE* mask_ptr = (const global DATA_TYPE*)(mask_base); + const global MASK_DATA_TYPE* mask_ptr = (const global MASK_DATA_TYPE*)(mask_base); score += slope * (ACC_TYPE)mask_ptr[k_idx]; } if (logit_softcap > 0.0f) { diff --git a/ggml/src/ggml-opencl/kernels/mul_mm_f16_f32_l4_lm.cl b/ggml/src/ggml-opencl/kernels/mul_mm_f16_f32_l4_lm.cl index 9599a0e157..1a1bfe144f 100644 --- a/ggml/src/ggml-opencl/kernels/mul_mm_f16_f32_l4_lm.cl +++ b/ggml/src/ggml-opencl/kernels/mul_mm_f16_f32_l4_lm.cl @@ -79,19 +79,33 @@ kernel void kernel_mul_mm_f16_f32_l4_lm( for (int block = 0; block < ne00; block += BK) { for (int l = 0; l < BM; l += loadstride_a) { + if (loadc_a + l < ne01) { const int idx = pos_a + (loadc_a + l) * stride_a / LOAD_VEC_A + loadr_a; - buf_a[(loadr_a * LOAD_VEC_A + 0) * BM + loadc_a + l] = src0[idx].s0; - buf_a[(loadr_a * LOAD_VEC_A + 1) * BM + loadc_a + l] = src0[idx].s1; - buf_a[(loadr_a * LOAD_VEC_A + 2) * BM + loadc_a + l] = src0[idx].s2; - buf_a[(loadr_a * LOAD_VEC_A + 3) * BM + loadc_a + l] = src0[idx].s3; + buf_a[(loadr_a * LOAD_VEC_A + 0) * BM + loadc_a + l] = src0[idx].s0; + buf_a[(loadr_a * LOAD_VEC_A + 1) * BM + loadc_a + l] = src0[idx].s1; + buf_a[(loadr_a * LOAD_VEC_A + 2) * BM + loadc_a + l] = src0[idx].s2; + buf_a[(loadr_a * LOAD_VEC_A + 3) * BM + loadc_a + l] = src0[idx].s3; + } else { + buf_a[(loadr_a * LOAD_VEC_A + 0) * BM + loadc_a + l] = 0.0h; + buf_a[(loadr_a * LOAD_VEC_A + 1) * BM + loadc_a + l] = 0.0h; + buf_a[(loadr_a * LOAD_VEC_A + 2) * BM + loadc_a + l] = 0.0h; + buf_a[(loadr_a * LOAD_VEC_A + 3) * BM + loadc_a + l] = 0.0h; + } } for (int l = 0; l < BN; l += loadstride_b) { - const int idx = pos_b + (loadc_b + l) * stride_b / LOAD_VEC_B + loadr_b; - buf_b[(loadr_b * LOAD_VEC_B + 0) * BN + loadc_b + l] = src1[idx].s0; - buf_b[(loadr_b * LOAD_VEC_B + 1) * BN + loadc_b + l] = src1[idx].s1; - buf_b[(loadr_b * LOAD_VEC_B + 2) * BN + loadc_b + l] = src1[idx].s2; - buf_b[(loadr_b * LOAD_VEC_B + 3) * BN + loadc_b + l] = src1[idx].s3; + if (loadc_b + l < ne11) { + const int idx = pos_b + (loadc_b + l) * stride_b / LOAD_VEC_B + loadr_b; + buf_b[(loadr_b * LOAD_VEC_B + 0) * BN + loadc_b + l] = src1[idx].s0; + buf_b[(loadr_b * LOAD_VEC_B + 1) * BN + loadc_b + l] = src1[idx].s1; + buf_b[(loadr_b * LOAD_VEC_B + 2) * BN + loadc_b + l] = src1[idx].s2; + buf_b[(loadr_b * LOAD_VEC_B + 3) * BN + loadc_b + l] = src1[idx].s3; + } else { + buf_b[(loadr_b * LOAD_VEC_B + 0) * BN + loadc_b + l] = 0.0h; + buf_b[(loadr_b * LOAD_VEC_B + 1) * BN + loadc_b + l] = 0.0h; + buf_b[(loadr_b * LOAD_VEC_B + 2) * BN + loadc_b + l] = 0.0h; + buf_b[(loadr_b * LOAD_VEC_B + 3) * BN + loadc_b + l] = 0.0h; + } } barrier(CLK_LOCAL_MEM_FENCE); diff --git a/ggml/src/ggml-opencl/kernels/mul_mm_f32_f32_l4_lm.cl b/ggml/src/ggml-opencl/kernels/mul_mm_f32_f32_l4_lm.cl index 58c5178e39..39a5d4868f 100644 --- a/ggml/src/ggml-opencl/kernels/mul_mm_f32_f32_l4_lm.cl +++ b/ggml/src/ggml-opencl/kernels/mul_mm_f32_f32_l4_lm.cl @@ -79,19 +79,33 @@ kernel void kernel_mul_mm_f32_f32_l4_lm( for (int block = 0; block < ne00; block += BK) { for (int l = 0; l < BM; l += loadstride_a) { - const int idx = pos_a + (loadc_a + l) * stride_a / LOAD_VEC_A + loadr_a; - buf_a[(loadr_a * LOAD_VEC_A + 0) * BM + loadc_a + l] = src0[idx].s0; - buf_a[(loadr_a * LOAD_VEC_A + 1) * BM + loadc_a + l] = src0[idx].s1; - buf_a[(loadr_a * LOAD_VEC_A + 2) * BM + loadc_a + l] = src0[idx].s2; - buf_a[(loadr_a * LOAD_VEC_A + 3) * BM + loadc_a + l] = src0[idx].s3; + if (loadc_a + l < ne01) { + const int idx = pos_a + (loadc_a + l) * stride_a / LOAD_VEC_A + loadr_a; + buf_a[(loadr_a * LOAD_VEC_A + 0) * BM + loadc_a + l] = src0[idx].s0; + buf_a[(loadr_a * LOAD_VEC_A + 1) * BM + loadc_a + l] = src0[idx].s1; + buf_a[(loadr_a * LOAD_VEC_A + 2) * BM + loadc_a + l] = src0[idx].s2; + buf_a[(loadr_a * LOAD_VEC_A + 3) * BM + loadc_a + l] = src0[idx].s3; + } else { + buf_a[(loadr_a * LOAD_VEC_A + 0) * BM + loadc_a + l] = 0.0f; + buf_a[(loadr_a * LOAD_VEC_A + 1) * BM + loadc_a + l] = 0.0f; + buf_a[(loadr_a * LOAD_VEC_A + 2) * BM + loadc_a + l] = 0.0f; + buf_a[(loadr_a * LOAD_VEC_A + 3) * BM + loadc_a + l] = 0.0f; + } } for (int l = 0; l < BN; l += loadstride_b) { - const int idx = pos_b + (loadc_b + l) * stride_b / LOAD_VEC_B + loadr_b; - buf_b[(loadr_b * LOAD_VEC_B + 0) * BN + loadc_b + l] = src1[idx].s0; - buf_b[(loadr_b * LOAD_VEC_B + 1) * BN + loadc_b + l] = src1[idx].s1; - buf_b[(loadr_b * LOAD_VEC_B + 2) * BN + loadc_b + l] = src1[idx].s2; - buf_b[(loadr_b * LOAD_VEC_B + 3) * BN + loadc_b + l] = src1[idx].s3; + if (loadc_b + l < ne11) { + const int idx = pos_b + (loadc_b + l) * stride_b / LOAD_VEC_B + loadr_b; + buf_b[(loadr_b * LOAD_VEC_B + 0) * BN + loadc_b + l] = src1[idx].s0; + buf_b[(loadr_b * LOAD_VEC_B + 1) * BN + loadc_b + l] = src1[idx].s1; + buf_b[(loadr_b * LOAD_VEC_B + 2) * BN + loadc_b + l] = src1[idx].s2; + buf_b[(loadr_b * LOAD_VEC_B + 3) * BN + loadc_b + l] = src1[idx].s3; + } else { + buf_b[(loadr_b * LOAD_VEC_B + 0) * BN + loadc_b + l] = 0.0f; + buf_b[(loadr_b * LOAD_VEC_B + 1) * BN + loadc_b + l] = 0.0f; + buf_b[(loadr_b * LOAD_VEC_B + 2) * BN + loadc_b + l] = 0.0f; + buf_b[(loadr_b * LOAD_VEC_B + 3) * BN + loadc_b + l] = 0.0f; + } } barrier(CLK_LOCAL_MEM_FENCE); diff --git a/ggml/src/ggml-opencl/kernels/mul_mm_q8_0_f32_l4_lm.cl b/ggml/src/ggml-opencl/kernels/mul_mm_q8_0_f32_l4_lm.cl new file mode 100644 index 0000000000..fd47e8a89d --- /dev/null +++ b/ggml/src/ggml-opencl/kernels/mul_mm_q8_0_f32_l4_lm.cl @@ -0,0 +1,154 @@ +#pragma OPENCL EXTENSION cl_khr_fp16 : enable + +#define LOAD_VEC_A 4 +#define LOAD_VEC_B 4 + +#define BM 64 +#define BN 64 +#define BK 32 +#define TM 4 +#define TN 8 + +kernel void kernel_mul_mm_q8_0_f32_l4_lm( + global char4 * src0_q, + global half * src0_d, + global float4 * src1, + ulong offset1, + global float * dst, + ulong offsetd, + + int ne00, + int ne01, + int ne02, + int ne11, + int ne12, + + int stride_a, + int stride_b, + int stride_d, + + int batch_stride_a, + int batch_stride_b, + int batch_stride_d, + + int r2, + int r3 +) { + src1 = (global float4*)((global char*)src1 + offset1); + dst = (global float *)((global char*)dst + offsetd); + + local float buf_a[BM * BK]; + local float buf_b[BN * BK]; + + const int batch_idx = get_global_id(2); + + const int i13 = batch_idx / ne12; + const int i12 = batch_idx % ne12; + + const int i03 = i13 / r3; + const int i02 = i12 / r2; + + const int batch_idx_a = i03 * ne02 + i02; + + const int ir = get_group_id(0); + const int ic = get_group_id(1); + + const int tid = get_local_id(0); + const int th_r = tid % (BM / TM); + const int th_c = tid / (BM / TM); + + const int loadr_a = get_local_id(0) % (BK / LOAD_VEC_A); + const int loadc_a = get_local_id(0) / (BK / LOAD_VEC_A); + const int loadr_b = get_local_id(0) % (BK / LOAD_VEC_B); + const int loadc_b = get_local_id(0) / (BK / LOAD_VEC_B); + + const int loadstride_a = get_local_size(0) * LOAD_VEC_A / BK; + const int loadstride_b = get_local_size(0) * LOAD_VEC_B / BK; + + int pos_a = (batch_idx_a * batch_stride_a + ir * BM * stride_a) / LOAD_VEC_A; + int pos_b = (batch_idx * batch_stride_b + ic * BN * stride_b) / LOAD_VEC_B; + + float sums[TM * TN]; + float cache_a[TM]; + float cache_b[TN]; + + for (int i = 0; i < TM * TN; i++) { + sums[i] = 0.0f; + } + + for (int block = 0; block < ne00; block += BK) { + for (int l = 0; l < BM; l += loadstride_a) { + if (loadc_a + l < ne01) { + int idx = pos_a + (loadc_a + l) * stride_a / LOAD_VEC_A + loadr_a; + int ib = idx / 8; + int iqs = idx % 8; + + float d = (float)src0_d[ib]; + global char4 * qs = src0_q + ib*8 + iqs; + char4 q = *qs; + float4 v = convert_float4(q)*d; + + buf_a[(loadr_a * LOAD_VEC_A + 0) * BM + loadc_a + l] = v.s0; + buf_a[(loadr_a * LOAD_VEC_A + 1) * BM + loadc_a + l] = v.s1; + buf_a[(loadr_a * LOAD_VEC_A + 2) * BM + loadc_a + l] = v.s2; + buf_a[(loadr_a * LOAD_VEC_A + 3) * BM + loadc_a + l] = v.s3; + } else { + buf_a[(loadr_a * LOAD_VEC_A + 0) * BM + loadc_a + l] = 0.0f; + buf_a[(loadr_a * LOAD_VEC_A + 1) * BM + loadc_a + l] = 0.0f; + buf_a[(loadr_a * LOAD_VEC_A + 2) * BM + loadc_a + l] = 0.0f; + buf_a[(loadr_a * LOAD_VEC_A + 3) * BM + loadc_a + l] = 0.0f; + } + } + + for (int l = 0; l < BN; l += loadstride_b) { + if (loadc_b + l < ne11) { + int idx = pos_b + (loadc_b + l) * stride_b / LOAD_VEC_B + loadr_b; + buf_b[(loadr_b * LOAD_VEC_B + 0) * BN + loadc_b + l] = src1[idx].s0; + buf_b[(loadr_b * LOAD_VEC_B + 1) * BN + loadc_b + l] = src1[idx].s1; + buf_b[(loadr_b * LOAD_VEC_B + 2) * BN + loadc_b + l] = src1[idx].s2; + buf_b[(loadr_b * LOAD_VEC_B + 3) * BN + loadc_b + l] = src1[idx].s3; + } else { + buf_b[(loadr_b * LOAD_VEC_B + 0) * BN + loadc_b + l] = 0.0f; + buf_b[(loadr_b * LOAD_VEC_B + 1) * BN + loadc_b + l] = 0.0f; + buf_b[(loadr_b * LOAD_VEC_B + 2) * BN + loadc_b + l] = 0.0f; + buf_b[(loadr_b * LOAD_VEC_B + 3) * BN + loadc_b + l] = 0.0f; + } + } + + barrier(CLK_LOCAL_MEM_FENCE); + + pos_a += BK / LOAD_VEC_A; + pos_b += BK / LOAD_VEC_B; + + for (int i = 0; i < BK; i++) { + for (int j = 0; j < TM; j++) { + cache_a[j] = buf_a[(i) * BM + th_r * TM + j]; + } + + for (int j = 0; j < TN; j++) { + cache_b[j] = buf_b[(i) * BN + th_c * TN + j]; + } + + for (int cc = 0; cc < TN; cc++) { + for (int cr = 0; cr < TM; cr++) { + const int sums_idx = cc*TM + cr; + sums[sums_idx] = mad(cache_a[cr], cache_b[cc], sums[sums_idx]); + } + } + } + barrier(CLK_LOCAL_MEM_FENCE); + } + + const int dr = ir * BM + th_r * TM; + const int dc = ic * BN + th_c * TN; + + const int offsets = batch_idx * batch_stride_d; + + for (int cc = 0; cc < TN; cc++) { + for (int cr = 0; cr < TM; cr++) { + if (dr + cr < ne01 && dc + cc < ne11) { + dst[offsets + (dc + cc) * stride_d + dr + cr] = sums[cc * TM + cr]; + } + } + } +} diff --git a/ggml/src/ggml.c b/ggml/src/ggml.c index 2bce1375ba..86f1c31afd 100644 --- a/ggml/src/ggml.c +++ b/ggml/src/ggml.c @@ -1144,9 +1144,13 @@ static const char * GGML_UNARY_OP_NAME[GGML_UNARY_OP_COUNT] = { "EXP", "GELU_ERF", "XIELU", + "FLOOR", + "CEIL", + "ROUND", + "TRUNC", }; -static_assert(GGML_UNARY_OP_COUNT == 16, "GGML_UNARY_OP_COUNT != 16"); +static_assert(GGML_UNARY_OP_COUNT == 20, "GGML_UNARY_OP_COUNT != 20"); static const char * GGML_GLU_OP_NAME[GGML_GLU_OP_COUNT] = { "REGLU", @@ -2749,6 +2753,62 @@ static struct ggml_tensor * ggml_glu_impl( return result; } +// ggml_floor + +struct ggml_tensor * ggml_floor( + struct ggml_context * ctx, + struct ggml_tensor * a) { + return ggml_unary(ctx, a, GGML_UNARY_OP_FLOOR); +} + +struct ggml_tensor * ggml_floor_inplace( + struct ggml_context * ctx, + struct ggml_tensor * a) { + return ggml_unary_inplace(ctx, a, GGML_UNARY_OP_FLOOR); +} + +// ggml_ceil + +struct ggml_tensor * ggml_ceil( + struct ggml_context * ctx, + struct ggml_tensor * a) { + return ggml_unary(ctx, a, GGML_UNARY_OP_CEIL); +} + +struct ggml_tensor * ggml_ceil_inplace( + struct ggml_context * ctx, + struct ggml_tensor * a) { + return ggml_unary_inplace(ctx, a, GGML_UNARY_OP_CEIL); +} + +//ggml_round + +struct ggml_tensor * ggml_round( + struct ggml_context * ctx, + struct ggml_tensor * a) { + return ggml_unary(ctx, a, GGML_UNARY_OP_ROUND); +} + +struct ggml_tensor * ggml_round_inplace( + struct ggml_context * ctx, + struct ggml_tensor * a) { + return ggml_unary_inplace(ctx, a, GGML_UNARY_OP_ROUND); +} + +//ggml_trunc + +struct ggml_tensor * ggml_trunc( + struct ggml_context * ctx, + struct ggml_tensor * a) { + return ggml_unary(ctx, a, GGML_UNARY_OP_TRUNC); +} + +struct ggml_tensor * ggml_trunc_inplace( + struct ggml_context * ctx, + struct ggml_tensor * a) { + return ggml_unary_inplace(ctx, a, GGML_UNARY_OP_TRUNC); +} + struct ggml_tensor * ggml_glu( struct ggml_context * ctx, struct ggml_tensor * a, diff --git a/gguf-py/gguf/scripts/gguf_convert_endian.py b/gguf-py/gguf/scripts/gguf_convert_endian.py index 211a3f536a..0bda490a20 100755 --- a/gguf-py/gguf/scripts/gguf_convert_endian.py +++ b/gguf-py/gguf/scripts/gguf_convert_endian.py @@ -91,6 +91,7 @@ def convert_byteorder(reader: gguf.GGUFReader, args: argparse.Namespace) -> None tensor.tensor_type not in ( gguf.GGMLQuantizationType.F32, gguf.GGMLQuantizationType.F16, + gguf.GGMLQuantizationType.BF16, ): raise ValueError(f"Cannot handle type {tensor.tensor_type.name} for tensor {repr(tensor.name)}") logger.info(f"* Preparing to convert from {file_endian} to {order}") @@ -148,6 +149,11 @@ def convert_byteorder(reader: gguf.GGUFReader, args: argparse.Namespace) -> None # restore old shape in case it's ever used tensor.data.resize(oldshape) + elif tensor.tensor_type == gguf.GGMLQuantizationType.BF16: + # Special case for BF16 + # It is 2-bytes data, but by default view loads it as 1-byte data. + # Change to correct view before byteswapping. + tensor.data.view(dtype=np.uint16).byteswap(inplace=True) else: # Handle other tensor types tensor.data.byteswap(inplace=True) diff --git a/src/llama-arch.cpp b/src/llama-arch.cpp index 869e4dccf0..b7e00b275b 100644 --- a/src/llama-arch.cpp +++ b/src/llama-arch.cpp @@ -5,6 +5,7 @@ #include static const std::map LLM_ARCH_NAMES = { + { LLM_ARCH_CLIP, "clip" }, // dummy, only used by llama-quantize { LLM_ARCH_LLAMA, "llama" }, { LLM_ARCH_LLAMA4, "llama4" }, { LLM_ARCH_DECI, "deci" }, @@ -275,6 +276,10 @@ static const std::map LLM_KV_NAMES = { }; static const std::map> LLM_TENSOR_NAMES = { + { + LLM_ARCH_CLIP, + {}, + }, { LLM_ARCH_LLAMA, { diff --git a/src/llama-arch.h b/src/llama-arch.h index c3ae71655b..c41de89859 100644 --- a/src/llama-arch.h +++ b/src/llama-arch.h @@ -9,6 +9,7 @@ // enum llm_arch { + LLM_ARCH_CLIP, LLM_ARCH_LLAMA, LLM_ARCH_LLAMA4, LLM_ARCH_DECI, diff --git a/src/llama-model.cpp b/src/llama-model.cpp index 0cdad9babd..5002bd42ff 100644 --- a/src/llama-model.cpp +++ b/src/llama-model.cpp @@ -478,7 +478,8 @@ void llama_model::load_hparams(llama_model_loader & ml) { ml.get_key(LLM_KV_GENERAL_NAME, name, false); // everything past this point is not vocab-related - if (hparams.vocab_only) { + // for CLIP models, we only need to load tensors, no hparams + if (hparams.vocab_only || ml.get_arch() == LLM_ARCH_CLIP) { return; } @@ -20013,6 +20014,7 @@ int32_t llama_n_head(const llama_model * model) { llama_rope_type llama_model_rope_type(const llama_model * model) { switch (model->arch) { // these models do not use RoPE + case LLM_ARCH_CLIP: case LLM_ARCH_GPT2: case LLM_ARCH_GPTJ: case LLM_ARCH_MPT: diff --git a/src/llama-quant.cpp b/src/llama-quant.cpp index 97228b2a69..6dd40412b4 100644 --- a/src/llama-quant.cpp +++ b/src/llama-quant.cpp @@ -701,6 +701,7 @@ static void llama_model_quantize_impl(const std::string & fname_inp, const std:: }); } + bool is_clip_model = false; for (const auto * it : tensors) { const struct ggml_tensor * tensor = it->tensor; @@ -714,12 +715,14 @@ static void llama_model_quantize_impl(const std::string & fname_inp, const std:: } else if (name == LLM_TN(model.arch)(LLM_TENSOR_OUTPUT, "weight")) { qs.has_output = true; } + + is_clip_model |= name.rfind("mm.", 0) == 0; // check the "mm." prefix } qs.n_ffn_down = qs.n_ffn_gate = qs.n_ffn_up = (int)model.hparams.n_layer; // sanity checks for models that have attention layers - if (qs.n_attention_wv != 0) + if (qs.n_attention_wv != 0 && !is_clip_model) { const auto & n_head_kv_iter = model.hparams.n_head_kv_arr.begin(); // attention layers have a non-zero number of kv heads @@ -881,6 +884,9 @@ static void llama_model_quantize_impl(const std::string & fname_inp, const std:: // do not quantize relative position bias (T5) quantize &= name.find("attn_rel_b.weight") == std::string::npos; + // do not quantize specific multimodal tensors + quantize &= name.find(".position_embd.") == std::string::npos; + ggml_type new_type; void * new_data; size_t new_size; diff --git a/src/llama.cpp b/src/llama.cpp index 38700f97a0..ab2e9868af 100644 --- a/src/llama.cpp +++ b/src/llama.cpp @@ -124,6 +124,9 @@ static int llama_model_load(const std::string & fname, std::vector } catch(const std::exception & e) { throw std::runtime_error("error loading model hyperparameters: " + std::string(e.what())); } + if (model.arch == LLM_ARCH_CLIP) { + throw std::runtime_error("CLIP cannot be used as main model, use it with --mmproj instead"); + } try { model.load_vocab(ml); } catch(const std::exception & e) { diff --git a/tests/test-backend-ops.cpp b/tests/test-backend-ops.cpp index 14472dcf12..d5c5a2a665 100644 --- a/tests/test-backend-ops.cpp +++ b/tests/test-backend-ops.cpp @@ -4588,20 +4588,31 @@ struct test_topk_moe: public test_case { struct test_sum : public test_case { const ggml_type type; const std::array ne; + const std::array permute; + bool _use_permute; std::string vars() override { - return VARS_TO_STR2(type, ne); + std::string v = VARS_TO_STR2(type, ne); + if (_use_permute) v += "," + VAR_TO_STR(permute); + return v; } test_sum(ggml_type type = GGML_TYPE_F32, - std::array ne = {10, 5, 4, 3}) - : type(type), ne(ne) {} + std::array ne = {10, 5, 4, 3}, + std::array permute = {0, 0, 0, 0}) + : type(type), ne(ne), permute(permute), + _use_permute(permute[0] + permute[1] + permute[2] + permute[3] > 0) {} ggml_tensor * build_graph(ggml_context * ctx) override { ggml_tensor * a = ggml_new_tensor(ctx, type, 4, ne.data()); ggml_set_param(a); ggml_set_name(a, "a"); + if (_use_permute) { + a = ggml_permute(ctx, a, permute[0], permute[1], permute[2], permute[3]); + ggml_set_name(a, "a_permuted"); + } + ggml_tensor * out = ggml_sum(ctx, a); ggml_set_name(out, "out"); @@ -6354,6 +6365,19 @@ static std::vector> make_test_cases_eval() { } } +#if 0 + { + // Test paths in OpenCL + std::vector ns = {32, 64, 128, 256, 512, 1024, 4096}; + std::vector ks = {896, 1536, 4096}; + for (auto n : ns) { + for (auto k : ks) { + test_cases.emplace_back(new test_mul_mat(GGML_TYPE_Q8_0, GGML_TYPE_F32, 1024, n, k, {1, 1}, {1, 1})); + } + } + } +#endif + #if 1 for (ggml_type type_a : base_types) { for (ggml_type type_b : {GGML_TYPE_F32, GGML_TYPE_F16}) { @@ -6724,6 +6748,9 @@ static std::vector> make_test_cases_eval() { test_cases.emplace_back(new test_sum()); test_cases.emplace_back(new test_sum_rows()); + test_cases.emplace_back(new test_sum(GGML_TYPE_F32, {11, 5, 6, 3}, {0, 2, 1, 3})); // row-contiguous but non-contiguous + test_cases.emplace_back(new test_sum(GGML_TYPE_F32, {11, 5, 6, 3}, {0, 3, 2, 1})); + test_cases.emplace_back(new test_sum(GGML_TYPE_F32, {11, 5, 6, 3}, {0, 1, 3, 2})); test_cases.emplace_back(new test_sum_rows(GGML_TYPE_F32, { 11, 5, 6, 3 }, true, false)); test_cases.emplace_back(new test_sum_rows(GGML_TYPE_F32, { 11, 5, 6, 3 }, false, true)); test_cases.emplace_back(new test_sum_rows(GGML_TYPE_F32, { 11, 5, 6, 3 }, true, true)); @@ -6734,6 +6761,7 @@ static std::vector> make_test_cases_eval() { test_cases.emplace_back(new test_sum(GGML_TYPE_F32, { 33, 1024, 1, 1 })); test_cases.emplace_back(new test_sum_rows(GGML_TYPE_F32, { 33, 1024, 1, 1 })); test_cases.emplace_back(new test_sum(GGML_TYPE_F32, { 33, 256, 1, 1 })); + test_cases.emplace_back(new test_sum(GGML_TYPE_F32, { 33, 256, 1, 1 }, { 1, 0, 2, 3 })); // sum dst not-contiguous test_cases.emplace_back(new test_sum_rows(GGML_TYPE_F32, { 33, 256, 1, 1 })); test_cases.emplace_back(new test_mean(GGML_TYPE_F32, { 33, 256, 1, 1 })); test_cases.emplace_back(new test_mean(GGML_TYPE_F32, { 32769, 1, 1, 1 })); diff --git a/tools/server/public/index.html.gz b/tools/server/public/index.html.gz index c026f36c48..1c62ebe968 100644 Binary files a/tools/server/public/index.html.gz and b/tools/server/public/index.html.gz differ diff --git a/tools/server/server.cpp b/tools/server/server.cpp index 77969d24e1..8737fba124 100644 --- a/tools/server/server.cpp +++ b/tools/server/server.cpp @@ -3812,7 +3812,7 @@ struct server_context { if (slot.n_past > 0 && slot.n_past < (int) slot.prompt.tokens.size()) { const auto pos_min = llama_memory_seq_pos_min(llama_get_memory(ctx), slot.id); if (pos_min == -1) { - SLT_ERR(slot, "n_past = %d, cache_tokens.size() = %d, seq_id = %d, pos_min = %d\n", slot.n_past, (int) slot.prompt.tokens.size(), slot.id, pos_min); + SLT_ERR(slot, "n_past = %d, slot.prompt.tokens.size() = %d, seq_id = %d, pos_min = %d\n", slot.n_past, (int) slot.prompt.tokens.size(), slot.id, pos_min); GGML_ABORT("pos_min == -1, but n_past > 0 - should not happen: https://github.com/ggml-org/llama.cpp/pull/13833#discussion_r2116181237"); } @@ -3839,14 +3839,14 @@ struct server_context { { const auto token = slot.prompt.tokens[i]; - const auto piece = common_token_to_piece(ctx, token); + const auto piece = token != LLAMA_TOKEN_NULL ? common_token_to_piece(ctx, token) : "[mtmd]"; ss0 << piece; st0 << std::setw(8) << token; } { const auto token = slot.task->tokens[i]; - const auto piece = common_token_to_piece(ctx, token); + const auto piece = token != LLAMA_TOKEN_NULL ? common_token_to_piece(ctx, token) : "[mtmd]"; ss1 << piece; st1 << std::setw(8) << token; } @@ -3860,7 +3860,7 @@ struct server_context { } if (pos_min > pos_min_thold) { - SLT_WRN(slot, "n_past = %d, cache_tokens.size() = %d, seq_id = %d, pos_min = %d, n_swa = %d\n", slot.n_past, (int) slot.prompt.tokens.size(), slot.id, pos_min, n_swa); + SLT_WRN(slot, "n_past = %d, slot.prompt.tokens.size() = %d, seq_id = %d, pos_min = %d, n_swa = %d\n", slot.n_past, (int) slot.prompt.tokens.size(), slot.id, pos_min, n_swa); // search for a context checkpoint const auto it = std::find_if( @@ -4028,7 +4028,7 @@ struct server_context { } } - // SLT_INF(slot, "new cache_tokens: %s\n", slot.cache_tokens.str().c_str()); + // SLT_INF(slot, "new slot.prompt.tokens: %s\n", slot.slot.prompt.tokens.str().c_str()); SLT_INF(slot, "prompt processing progress, n_past = %d, n_tokens = %d, progress = %f\n", slot.n_past, batch.n_tokens, (float) slot.n_past / slot.n_prompt_tokens()); diff --git a/tools/server/utils.hpp b/tools/server/utils.hpp index fd0bc8de53..cc48f5a9d0 100644 --- a/tools/server/utils.hpp +++ b/tools/server/utils.hpp @@ -1237,9 +1237,10 @@ public: // allowed to resize ^ ^ // disallowed to resize ^ ^ ^ if (n > 0) { - llama_token last_token = tokens[n - 1]; // make sure we never remove tokens in the middle of an image - if (last_token == LLAMA_TOKEN_NULL) { + // note that the case where we keep a full image at the end is allowed: + // tokens[n - 1] == LLAMA_TOKEN_NULL && tokens[n] != LLAMA_TOKEN_NULL + if (tokens[n - 1] == LLAMA_TOKEN_NULL && tokens[n] == LLAMA_TOKEN_NULL) { find_chunk(n - 1); // will throw an error if the token is not begin-of-chunk } } diff --git a/tools/server/webui/src/lib/components/app/chat/ChatSettings/ChatSettingsDialog.svelte b/tools/server/webui/src/lib/components/app/chat/ChatSettings/ChatSettingsDialog.svelte index dc617afdcd..d5d4c7fe3f 100644 --- a/tools/server/webui/src/lib/components/app/chat/ChatSettings/ChatSettingsDialog.svelte +++ b/tools/server/webui/src/lib/components/app/chat/ChatSettings/ChatSettingsDialog.svelte @@ -14,8 +14,7 @@ import { ChatSettingsFooter, ChatSettingsFields } from '$lib/components/app'; import * as Dialog from '$lib/components/ui/dialog'; import { ScrollArea } from '$lib/components/ui/scroll-area'; - import { SETTING_CONFIG_DEFAULT } from '$lib/constants/settings-config'; - import { config, updateMultipleConfig, resetConfig } from '$lib/stores/settings.svelte'; + import { config, updateMultipleConfig } from '$lib/stores/settings.svelte'; import { setMode } from 'mode-watcher'; import type { Component } from 'svelte'; @@ -267,16 +266,13 @@ } function handleReset() { - resetConfig(); + localConfig = { ...config() }; - localConfig = { ...SETTING_CONFIG_DEFAULT }; - - setMode(SETTING_CONFIG_DEFAULT.theme as 'light' | 'dark' | 'system'); - originalTheme = SETTING_CONFIG_DEFAULT.theme as string; + setMode(localConfig.theme as 'light' | 'dark' | 'system'); + originalTheme = localConfig.theme as string; } function handleSave() { - // Validate custom JSON if provided if (localConfig.custom && typeof localConfig.custom === 'string' && localConfig.custom.trim()) { try { JSON.parse(localConfig.custom); diff --git a/tools/server/webui/src/lib/components/app/chat/ChatSettings/ChatSettingsFields.svelte b/tools/server/webui/src/lib/components/app/chat/ChatSettings/ChatSettingsFields.svelte index e06399e0bc..d17f7e4229 100644 --- a/tools/server/webui/src/lib/components/app/chat/ChatSettings/ChatSettingsFields.svelte +++ b/tools/server/webui/src/lib/components/app/chat/ChatSettings/ChatSettingsFields.svelte @@ -1,4 +1,5 @@ {#each fields as field (field.key)}
{#if field.type === 'input'} - + {@const paramInfo = getParameterSourceInfo(field.key)} + {@const currentValue = String(localConfig[field.key] ?? '')} + {@const propsDefault = paramInfo?.serverDefault} + {@const isCustomRealTime = (() => { + if (!paramInfo || propsDefault === undefined) return false; - onConfigChange(field.key, e.currentTarget.value)} - placeholder={`Default: ${SETTING_CONFIG_DEFAULT[field.key] ?? 'none'}`} - class="w-full md:max-w-md" - /> + // Apply same rounding logic for real-time comparison + const inputValue = currentValue; + const numericInput = parseFloat(inputValue); + const normalizedInput = !isNaN(numericInput) + ? Math.round(numericInput * 1000000) / 1000000 + : inputValue; + const normalizedDefault = + typeof propsDefault === 'number' + ? Math.round(propsDefault * 1000000) / 1000000 + : propsDefault; + + return normalizedInput !== normalizedDefault; + })()} + +
+ + {#if isCustomRealTime} + + {/if} +
+ +
+ { + // Update local config immediately for real-time badge feedback + onConfigChange(field.key, e.currentTarget.value); + }} + placeholder={`Default: ${SETTING_CONFIG_DEFAULT[field.key] ?? 'none'}`} + class="w-full {isCustomRealTime ? 'pr-8' : ''}" + /> + {#if isCustomRealTime} + + {/if} +
{#if field.help || SETTING_CONFIG_INFO[field.key]}

{field.help || SETTING_CONFIG_INFO[field.key]} @@ -59,14 +118,28 @@ (opt: { value: string; label: string; icon?: Component }) => opt.value === localConfig[field.key] )} + {@const paramInfo = getParameterSourceInfo(field.key)} + {@const currentValue = localConfig[field.key]} + {@const propsDefault = paramInfo?.serverDefault} + {@const isCustomRealTime = (() => { + if (!paramInfo || propsDefault === undefined) return false; - + // For select fields, do direct comparison (no rounding needed) + return currentValue !== propsDefault; + })()} + +

+ + {#if isCustomRealTime} + + {/if} +
{ if (field.key === 'theme' && value && onThemeChange) { onThemeChange(value); @@ -75,16 +148,34 @@ } }} > - -
- {#if selectedOption?.icon} - {@const IconComponent = selectedOption.icon} - - {/if} +
+ +
+ {#if selectedOption?.icon} + {@const IconComponent = selectedOption.icon} + + {/if} - {selectedOption?.label || `Select ${field.label.toLowerCase()}`} -
-
+ {selectedOption?.label || `Select ${field.label.toLowerCase()}`} +
+ + {#if isCustomRealTime} + + {/if} +
{#if field.options} {#each field.options as option (option.value)} diff --git a/tools/server/webui/src/lib/components/app/chat/ChatSettings/ChatSettingsFooter.svelte b/tools/server/webui/src/lib/components/app/chat/ChatSettings/ChatSettingsFooter.svelte index 3408fe3ce4..4f2d978ab8 100644 --- a/tools/server/webui/src/lib/components/app/chat/ChatSettings/ChatSettingsFooter.svelte +++ b/tools/server/webui/src/lib/components/app/chat/ChatSettings/ChatSettingsFooter.svelte @@ -1,6 +1,8 @@
- +
+ +
@@ -36,8 +46,9 @@ Reset Settings to Default - Are you sure you want to reset all settings to their default values? This action cannot be - undone and will permanently remove all your custom configurations. + Are you sure you want to reset all settings to their default values? This will reset all + parameters to the values provided by the server's /props endpoint and remove all your custom + configurations. diff --git a/tools/server/webui/src/lib/components/app/chat/ChatSettings/ParameterSourceIndicator.svelte b/tools/server/webui/src/lib/components/app/chat/ChatSettings/ParameterSourceIndicator.svelte new file mode 100644 index 0000000000..b566985ba0 --- /dev/null +++ b/tools/server/webui/src/lib/components/app/chat/ChatSettings/ParameterSourceIndicator.svelte @@ -0,0 +1,18 @@ + + + + + Custom + diff --git a/tools/server/webui/src/lib/components/app/index.ts b/tools/server/webui/src/lib/components/app/index.ts index 63a99f4343..4c2cbdebe1 100644 --- a/tools/server/webui/src/lib/components/app/index.ts +++ b/tools/server/webui/src/lib/components/app/index.ts @@ -25,6 +25,7 @@ export { default as ChatScreen } from './chat/ChatScreen/ChatScreen.svelte'; export { default as ChatSettingsDialog } from './chat/ChatSettings/ChatSettingsDialog.svelte'; export { default as ChatSettingsFooter } from './chat/ChatSettings/ChatSettingsFooter.svelte'; export { default as ChatSettingsFields } from './chat/ChatSettings/ChatSettingsFields.svelte'; +export { default as ParameterSourceIndicator } from './chat/ChatSettings/ParameterSourceIndicator.svelte'; export { default as ChatSidebar } from './chat/ChatSidebar/ChatSidebar.svelte'; export { default as ChatSidebarConversationItem } from './chat/ChatSidebar/ChatSidebarConversationItem.svelte'; diff --git a/tools/server/webui/src/lib/constants/precision.ts b/tools/server/webui/src/lib/constants/precision.ts new file mode 100644 index 0000000000..8df5c4f966 --- /dev/null +++ b/tools/server/webui/src/lib/constants/precision.ts @@ -0,0 +1,2 @@ +export const PRECISION_MULTIPLIER = 1000000; +export const PRECISION_DECIMAL_PLACES = 6; diff --git a/tools/server/webui/src/lib/services/parameter-sync.spec.ts b/tools/server/webui/src/lib/services/parameter-sync.spec.ts new file mode 100644 index 0000000000..9ced55faa0 --- /dev/null +++ b/tools/server/webui/src/lib/services/parameter-sync.spec.ts @@ -0,0 +1,135 @@ +import { describe, it, expect } from 'vitest'; +import { ParameterSyncService } from './parameter-sync'; +import type { ApiLlamaCppServerProps } from '$lib/types/api'; + +describe('ParameterSyncService', () => { + describe('roundFloatingPoint', () => { + it('should fix JavaScript floating-point precision issues', () => { + // Test the specific values from the screenshot + const mockServerParams = { + top_p: 0.949999988079071, + min_p: 0.009999999776482582, + temperature: 0.800000011920929, + top_k: 40, + samplers: ['top_k', 'typ_p', 'top_p', 'min_p', 'temperature'] + }; + + const result = ParameterSyncService.extractServerDefaults({ + ...mockServerParams, + // Add other required fields to match the API type + n_predict: 512, + seed: -1, + dynatemp_range: 0.0, + dynatemp_exponent: 1.0, + xtc_probability: 0.0, + xtc_threshold: 0.1, + typ_p: 1.0, + repeat_last_n: 64, + repeat_penalty: 1.0, + presence_penalty: 0.0, + frequency_penalty: 0.0, + dry_multiplier: 0.0, + dry_base: 1.75, + dry_allowed_length: 2, + dry_penalty_last_n: -1, + mirostat: 0, + mirostat_tau: 5.0, + mirostat_eta: 0.1, + stop: [], + max_tokens: -1, + n_keep: 0, + n_discard: 0, + ignore_eos: false, + stream: true, + logit_bias: [], + n_probs: 0, + min_keep: 0, + grammar: '', + grammar_lazy: false, + grammar_triggers: [], + preserved_tokens: [], + chat_format: '', + reasoning_format: '', + reasoning_in_content: false, + thinking_forced_open: false, + 'speculative.n_max': 0, + 'speculative.n_min': 0, + 'speculative.p_min': 0.0, + timings_per_token: false, + post_sampling_probs: false, + lora: [], + top_n_sigma: 0.0, + dry_sequence_breakers: [] + } as ApiLlamaCppServerProps['default_generation_settings']['params']); + + // Check that the problematic floating-point values are rounded correctly + expect(result.top_p).toBe(0.95); + expect(result.min_p).toBe(0.01); + expect(result.temperature).toBe(0.8); + expect(result.top_k).toBe(40); // Integer should remain unchanged + expect(result.samplers).toBe('top_k;typ_p;top_p;min_p;temperature'); + }); + + it('should preserve non-numeric values', () => { + const mockServerParams = { + samplers: ['top_k', 'temperature'], + max_tokens: -1, + temperature: 0.7 + }; + + const result = ParameterSyncService.extractServerDefaults({ + ...mockServerParams, + // Minimal required fields + n_predict: 512, + seed: -1, + dynatemp_range: 0.0, + dynatemp_exponent: 1.0, + top_k: 40, + top_p: 0.95, + min_p: 0.05, + xtc_probability: 0.0, + xtc_threshold: 0.1, + typ_p: 1.0, + repeat_last_n: 64, + repeat_penalty: 1.0, + presence_penalty: 0.0, + frequency_penalty: 0.0, + dry_multiplier: 0.0, + dry_base: 1.75, + dry_allowed_length: 2, + dry_penalty_last_n: -1, + mirostat: 0, + mirostat_tau: 5.0, + mirostat_eta: 0.1, + stop: [], + n_keep: 0, + n_discard: 0, + ignore_eos: false, + stream: true, + logit_bias: [], + n_probs: 0, + min_keep: 0, + grammar: '', + grammar_lazy: false, + grammar_triggers: [], + preserved_tokens: [], + chat_format: '', + reasoning_format: '', + reasoning_in_content: false, + thinking_forced_open: false, + 'speculative.n_max': 0, + 'speculative.n_min': 0, + 'speculative.p_min': 0.0, + timings_per_token: false, + post_sampling_probs: false, + lora: [], + top_n_sigma: 0.0, + dry_sequence_breakers: [] + } as ApiLlamaCppServerProps['default_generation_settings']['params']); + + expect(result.samplers).toBe('top_k;temperature'); + expect(result.max_tokens).toBe(-1); + expect(result.temperature).toBe(0.7); + }); + }); +}); diff --git a/tools/server/webui/src/lib/services/parameter-sync.ts b/tools/server/webui/src/lib/services/parameter-sync.ts new file mode 100644 index 0000000000..ee147ae194 --- /dev/null +++ b/tools/server/webui/src/lib/services/parameter-sync.ts @@ -0,0 +1,202 @@ +/** + * ParameterSyncService - Handles synchronization between server defaults and user settings + * + * This service manages the complex logic of merging server-provided default parameters + * with user-configured overrides, ensuring the UI reflects the actual server state + * while preserving user customizations. + * + * **Key Responsibilities:** + * - Extract syncable parameters from server props + * - Merge server defaults with user overrides + * - Track parameter sources (server, user, default) + * - Provide sync utilities for settings store integration + */ + +import type { ApiLlamaCppServerProps } from '$lib/types/api'; +import { normalizeFloatingPoint } from '$lib/utils/precision'; + +export type ParameterSource = 'default' | 'custom'; +export type ParameterValue = string | number | boolean; +export type ParameterRecord = Record; + +export interface ParameterInfo { + value: string | number | boolean; + source: ParameterSource; + serverDefault?: string | number | boolean; + userOverride?: string | number | boolean; +} + +export interface SyncableParameter { + key: string; + serverKey: string; + type: 'number' | 'string' | 'boolean'; + canSync: boolean; +} + +/** + * Mapping of webui setting keys to server parameter keys + * Only parameters that should be synced from server are included + */ +export const SYNCABLE_PARAMETERS: SyncableParameter[] = [ + { key: 'temperature', serverKey: 'temperature', type: 'number', canSync: true }, + { key: 'top_k', serverKey: 'top_k', type: 'number', canSync: true }, + { key: 'top_p', serverKey: 'top_p', type: 'number', canSync: true }, + { key: 'min_p', serverKey: 'min_p', type: 'number', canSync: true }, + { key: 'dynatemp_range', serverKey: 'dynatemp_range', type: 'number', canSync: true }, + { key: 'dynatemp_exponent', serverKey: 'dynatemp_exponent', type: 'number', canSync: true }, + { key: 'xtc_probability', serverKey: 'xtc_probability', type: 'number', canSync: true }, + { key: 'xtc_threshold', serverKey: 'xtc_threshold', type: 'number', canSync: true }, + { key: 'typ_p', serverKey: 'typ_p', type: 'number', canSync: true }, + { key: 'repeat_last_n', serverKey: 'repeat_last_n', type: 'number', canSync: true }, + { key: 'repeat_penalty', serverKey: 'repeat_penalty', type: 'number', canSync: true }, + { key: 'presence_penalty', serverKey: 'presence_penalty', type: 'number', canSync: true }, + { key: 'frequency_penalty', serverKey: 'frequency_penalty', type: 'number', canSync: true }, + { key: 'dry_multiplier', serverKey: 'dry_multiplier', type: 'number', canSync: true }, + { key: 'dry_base', serverKey: 'dry_base', type: 'number', canSync: true }, + { key: 'dry_allowed_length', serverKey: 'dry_allowed_length', type: 'number', canSync: true }, + { key: 'dry_penalty_last_n', serverKey: 'dry_penalty_last_n', type: 'number', canSync: true }, + { key: 'max_tokens', serverKey: 'max_tokens', type: 'number', canSync: true }, + { key: 'samplers', serverKey: 'samplers', type: 'string', canSync: true } +]; + +export class ParameterSyncService { + /** + * Round floating-point numbers to avoid JavaScript precision issues + */ + private static roundFloatingPoint(value: ParameterValue): ParameterValue { + return normalizeFloatingPoint(value) as ParameterValue; + } + + /** + * Extract server default parameters that can be synced + */ + static extractServerDefaults( + serverParams: ApiLlamaCppServerProps['default_generation_settings']['params'] | null + ): ParameterRecord { + if (!serverParams) return {}; + + const extracted: ParameterRecord = {}; + + for (const param of SYNCABLE_PARAMETERS) { + if (param.canSync && param.serverKey in serverParams) { + const value = (serverParams as unknown as Record)[param.serverKey]; + if (value !== undefined) { + // Apply precision rounding to avoid JavaScript floating-point issues + extracted[param.key] = this.roundFloatingPoint(value); + } + } + } + + // Handle samplers array conversion to string + if (serverParams.samplers && Array.isArray(serverParams.samplers)) { + extracted.samplers = serverParams.samplers.join(';'); + } + + return extracted; + } + + /** + * Merge server defaults with current user settings + * Returns updated settings that respect user overrides while using server defaults + */ + static mergeWithServerDefaults( + currentSettings: ParameterRecord, + serverDefaults: ParameterRecord, + userOverrides: Set = new Set() + ): ParameterRecord { + const merged = { ...currentSettings }; + + for (const [key, serverValue] of Object.entries(serverDefaults)) { + // Only update if user hasn't explicitly overridden this parameter + if (!userOverrides.has(key)) { + merged[key] = this.roundFloatingPoint(serverValue); + } + } + + return merged; + } + + /** + * Get parameter information including source and values + */ + static getParameterInfo( + key: string, + currentValue: ParameterValue, + propsDefaults: ParameterRecord, + userOverrides: Set + ): ParameterInfo { + const hasPropsDefault = propsDefaults[key] !== undefined; + const isUserOverride = userOverrides.has(key); + + // Simple logic: either using default (from props) or custom (user override) + const source: ParameterSource = isUserOverride ? 'custom' : 'default'; + + return { + value: currentValue, + source, + serverDefault: hasPropsDefault ? propsDefaults[key] : undefined, // Keep same field name for compatibility + userOverride: isUserOverride ? currentValue : undefined + }; + } + + /** + * Check if a parameter can be synced from server + */ + static canSyncParameter(key: string): boolean { + return SYNCABLE_PARAMETERS.some((param) => param.key === key && param.canSync); + } + + /** + * Get all syncable parameter keys + */ + static getSyncableParameterKeys(): string[] { + return SYNCABLE_PARAMETERS.filter((param) => param.canSync).map((param) => param.key); + } + + /** + * Validate server parameter value + */ + static validateServerParameter(key: string, value: ParameterValue): boolean { + const param = SYNCABLE_PARAMETERS.find((p) => p.key === key); + if (!param) return false; + + switch (param.type) { + case 'number': + return typeof value === 'number' && !isNaN(value); + case 'string': + return typeof value === 'string'; + case 'boolean': + return typeof value === 'boolean'; + default: + return false; + } + } + + /** + * Create a diff between current settings and server defaults + */ + static createParameterDiff( + currentSettings: ParameterRecord, + serverDefaults: ParameterRecord + ): Record { + const diff: Record< + string, + { current: ParameterValue; server: ParameterValue; differs: boolean } + > = {}; + + for (const key of this.getSyncableParameterKeys()) { + const currentValue = currentSettings[key]; + const serverValue = serverDefaults[key]; + + if (serverValue !== undefined) { + diff[key] = { + current: currentValue, + server: serverValue, + differs: currentValue !== serverValue + }; + } + } + + return diff; + } +} diff --git a/tools/server/webui/src/lib/stores/server.svelte.ts b/tools/server/webui/src/lib/stores/server.svelte.ts index 0b6855404c..1fd4afb040 100644 --- a/tools/server/webui/src/lib/stores/server.svelte.ts +++ b/tools/server/webui/src/lib/stores/server.svelte.ts @@ -125,6 +125,12 @@ class ServerStore { return this._slotsEndpointAvailable; } + get serverDefaultParams(): + | ApiLlamaCppServerProps['default_generation_settings']['params'] + | null { + return this._serverProps?.default_generation_settings?.params || null; + } + /** * Check if slots endpoint is available based on server properties and endpoint support */ @@ -273,3 +279,4 @@ export const supportedModalities = () => serverStore.supportedModalities; export const supportsVision = () => serverStore.supportsVision; export const supportsAudio = () => serverStore.supportsAudio; export const slotsEndpointAvailable = () => serverStore.slotsEndpointAvailable; +export const serverDefaultParams = () => serverStore.serverDefaultParams; diff --git a/tools/server/webui/src/lib/stores/settings.svelte.ts b/tools/server/webui/src/lib/stores/settings.svelte.ts index e5bc5ca9c9..b330cbb4bf 100644 --- a/tools/server/webui/src/lib/stores/settings.svelte.ts +++ b/tools/server/webui/src/lib/stores/settings.svelte.ts @@ -33,11 +33,25 @@ import { browser } from '$app/environment'; import { SETTING_CONFIG_DEFAULT } from '$lib/constants/settings-config'; +import { normalizeFloatingPoint } from '$lib/utils/precision'; +import { ParameterSyncService } from '$lib/services/parameter-sync'; +import { serverStore } from '$lib/stores/server.svelte'; +import { setConfigValue, getConfigValue, configToParameterRecord } from '$lib/utils/config-helpers'; class SettingsStore { config = $state({ ...SETTING_CONFIG_DEFAULT }); theme = $state('auto'); isInitialized = $state(false); + userOverrides = $state>(new Set()); + + /** + * Helper method to get server defaults with null safety + * Centralizes the pattern of getting and extracting server defaults + */ + private getServerDefaults(): Record { + const serverParams = serverStore.serverDefaultParams; + return serverParams ? ParameterSyncService.extractServerDefaults(serverParams) : {}; + } constructor() { if (browser) { @@ -67,14 +81,20 @@ class SettingsStore { try { const savedVal = JSON.parse(localStorage.getItem('config') || '{}'); + // Merge with defaults to prevent breaking changes this.config = { ...SETTING_CONFIG_DEFAULT, ...savedVal }; + + // Load user overrides + const savedOverrides = JSON.parse(localStorage.getItem('userOverrides') || '[]'); + this.userOverrides = new Set(savedOverrides); } catch (error) { console.warn('Failed to parse config from localStorage, using defaults:', error); this.config = { ...SETTING_CONFIG_DEFAULT }; + this.userOverrides = new Set(); } } @@ -86,14 +106,30 @@ class SettingsStore { this.theme = localStorage.getItem('theme') || 'auto'; } - /** * Update a specific configuration setting * @param key - The configuration key to update * @param value - The new value for the configuration key */ - updateConfig(key: K, value: SettingsConfigType[K]) { + updateConfig(key: K, value: SettingsConfigType[K]): void { this.config[key] = value; + + if (ParameterSyncService.canSyncParameter(key as string)) { + const propsDefaults = this.getServerDefaults(); + const propsDefault = propsDefaults[key as string]; + + if (propsDefault !== undefined) { + const normalizedValue = normalizeFloatingPoint(value); + const normalizedDefault = normalizeFloatingPoint(propsDefault); + + if (normalizedValue === normalizedDefault) { + this.userOverrides.delete(key as string); + } else { + this.userOverrides.add(key as string); + } + } + } + this.saveConfig(); } @@ -103,6 +139,26 @@ class SettingsStore { */ updateMultipleConfig(updates: Partial) { Object.assign(this.config, updates); + + const propsDefaults = this.getServerDefaults(); + + for (const [key, value] of Object.entries(updates)) { + if (ParameterSyncService.canSyncParameter(key)) { + const propsDefault = propsDefaults[key]; + + if (propsDefault !== undefined) { + const normalizedValue = normalizeFloatingPoint(value); + const normalizedDefault = normalizeFloatingPoint(propsDefault); + + if (normalizedValue === normalizedDefault) { + this.userOverrides.delete(key); + } else { + this.userOverrides.add(key); + } + } + } + } + this.saveConfig(); } @@ -114,6 +170,8 @@ class SettingsStore { try { localStorage.setItem('config', JSON.stringify(this.config)); + + localStorage.setItem('userOverrides', JSON.stringify(Array.from(this.userOverrides))); } catch (error) { console.error('Failed to save config to localStorage:', error); } @@ -185,6 +243,129 @@ class SettingsStore { getAllConfig(): SettingsConfigType { return { ...this.config }; } + + /** + * Initialize settings with props defaults when server properties are first loaded + * This sets up the default values from /props endpoint + */ + syncWithServerDefaults(): void { + const serverParams = serverStore.serverDefaultParams; + if (!serverParams) { + console.warn('No server parameters available for initialization'); + + return; + } + + const propsDefaults = this.getServerDefaults(); + + for (const [key, propsValue] of Object.entries(propsDefaults)) { + const currentValue = getConfigValue(this.config, key); + + const normalizedCurrent = normalizeFloatingPoint(currentValue); + const normalizedDefault = normalizeFloatingPoint(propsValue); + + if (normalizedCurrent === normalizedDefault) { + this.userOverrides.delete(key); + setConfigValue(this.config, key, propsValue); + } else if (!this.userOverrides.has(key)) { + setConfigValue(this.config, key, propsValue); + } + } + + this.saveConfig(); + console.log('Settings initialized with props defaults:', propsDefaults); + console.log('Current user overrides after sync:', Array.from(this.userOverrides)); + } + + /** + * Clear all user overrides (for debugging) + */ + clearAllUserOverrides(): void { + this.userOverrides.clear(); + this.saveConfig(); + console.log('Cleared all user overrides'); + } + + /** + * Reset all parameters to their default values (from props) + * This is used by the "Reset to Default" functionality + * Prioritizes server defaults from /props, falls back to webui defaults + */ + forceSyncWithServerDefaults(): void { + const propsDefaults = this.getServerDefaults(); + const syncableKeys = ParameterSyncService.getSyncableParameterKeys(); + + for (const key of syncableKeys) { + if (propsDefaults[key] !== undefined) { + const normalizedValue = normalizeFloatingPoint(propsDefaults[key]); + + setConfigValue(this.config, key, normalizedValue); + } else { + if (key in SETTING_CONFIG_DEFAULT) { + const defaultValue = getConfigValue(SETTING_CONFIG_DEFAULT, key); + + setConfigValue(this.config, key, defaultValue); + } + } + + this.userOverrides.delete(key); + } + + this.saveConfig(); + } + + /** + * Get parameter information including source for a specific parameter + */ + getParameterInfo(key: string) { + const propsDefaults = this.getServerDefaults(); + const currentValue = getConfigValue(this.config, key); + + return ParameterSyncService.getParameterInfo( + key, + currentValue ?? '', + propsDefaults, + this.userOverrides + ); + } + + /** + * Reset a parameter to server default (or webui default if no server default) + */ + resetParameterToServerDefault(key: string): void { + const serverDefaults = this.getServerDefaults(); + + if (serverDefaults[key] !== undefined) { + const value = normalizeFloatingPoint(serverDefaults[key]); + + this.config[key as keyof SettingsConfigType] = + value as SettingsConfigType[keyof SettingsConfigType]; + } else { + if (key in SETTING_CONFIG_DEFAULT) { + const defaultValue = getConfigValue(SETTING_CONFIG_DEFAULT, key); + + setConfigValue(this.config, key, defaultValue); + } + } + + this.userOverrides.delete(key); + this.saveConfig(); + } + + /** + * Get diff between current settings and server defaults + */ + getParameterDiff() { + const serverDefaults = this.getServerDefaults(); + if (Object.keys(serverDefaults).length === 0) return {}; + + const configAsRecord = configToParameterRecord( + this.config, + ParameterSyncService.getSyncableParameterKeys() + ); + + return ParameterSyncService.createParameterDiff(configAsRecord, serverDefaults); + } } // Create and export the settings store instance @@ -204,3 +385,11 @@ export const resetTheme = settingsStore.resetTheme.bind(settingsStore); export const resetAll = settingsStore.resetAll.bind(settingsStore); export const getConfig = settingsStore.getConfig.bind(settingsStore); export const getAllConfig = settingsStore.getAllConfig.bind(settingsStore); +export const syncWithServerDefaults = settingsStore.syncWithServerDefaults.bind(settingsStore); +export const forceSyncWithServerDefaults = + settingsStore.forceSyncWithServerDefaults.bind(settingsStore); +export const getParameterInfo = settingsStore.getParameterInfo.bind(settingsStore); +export const resetParameterToServerDefault = + settingsStore.resetParameterToServerDefault.bind(settingsStore); +export const getParameterDiff = settingsStore.getParameterDiff.bind(settingsStore); +export const clearAllUserOverrides = settingsStore.clearAllUserOverrides.bind(settingsStore); diff --git a/tools/server/webui/src/lib/utils/config-helpers.ts b/tools/server/webui/src/lib/utils/config-helpers.ts new file mode 100644 index 0000000000..2d023f8d5c --- /dev/null +++ b/tools/server/webui/src/lib/utils/config-helpers.ts @@ -0,0 +1,53 @@ +/** + * Type-safe configuration helpers + * + * Provides utilities for safely accessing and modifying configuration objects + * with dynamic keys while maintaining TypeScript type safety. + */ + +import type { SettingsConfigType } from '$lib/types/settings'; + +/** + * Type-safe helper to access config properties dynamically + * Provides better type safety than direct casting to Record + */ +export function setConfigValue( + config: T, + key: string, + value: unknown +): void { + if (key in config) { + (config as Record)[key] = value; + } +} + +/** + * Type-safe helper to get config values dynamically + */ +export function getConfigValue( + config: T, + key: string +): string | number | boolean | undefined { + const value = (config as Record)[key]; + return value as string | number | boolean | undefined; +} + +/** + * Convert a SettingsConfigType to a ParameterRecord for specific keys + * Useful for parameter synchronization operations + */ +export function configToParameterRecord( + config: T, + keys: string[] +): Record { + const record: Record = {}; + + for (const key of keys) { + const value = getConfigValue(config, key); + if (value !== undefined) { + record[key] = value; + } + } + + return record; +} diff --git a/tools/server/webui/src/lib/utils/precision.ts b/tools/server/webui/src/lib/utils/precision.ts new file mode 100644 index 0000000000..6da200cf0b --- /dev/null +++ b/tools/server/webui/src/lib/utils/precision.ts @@ -0,0 +1,25 @@ +/** + * Floating-point precision utilities + * + * Provides functions to normalize floating-point numbers for consistent comparison + * and display, addressing JavaScript's floating-point precision issues. + */ + +import { PRECISION_MULTIPLIER } from '$lib/constants/precision'; + +/** + * Normalize floating-point numbers for consistent comparison + * Addresses JavaScript floating-point precision issues (e.g., 0.949999988079071 → 0.95) + */ +export function normalizeFloatingPoint(value: unknown): unknown { + return typeof value === 'number' + ? Math.round(value * PRECISION_MULTIPLIER) / PRECISION_MULTIPLIER + : value; +} + +/** + * Type-safe version that only accepts numbers + */ +export function normalizeNumber(value: number): number { + return Math.round(value * PRECISION_MULTIPLIER) / PRECISION_MULTIPLIER; +} diff --git a/tools/server/webui/src/routes/+layout.svelte b/tools/server/webui/src/routes/+layout.svelte index 0245cf3abc..8912f642ce 100644 --- a/tools/server/webui/src/routes/+layout.svelte +++ b/tools/server/webui/src/routes/+layout.svelte @@ -9,7 +9,7 @@ } from '$lib/stores/chat.svelte'; import * as Sidebar from '$lib/components/ui/sidebar/index.js'; import { serverStore } from '$lib/stores/server.svelte'; - import { config } from '$lib/stores/settings.svelte'; + import { config, settingsStore } from '$lib/stores/settings.svelte'; import { ModeWatcher } from 'mode-watcher'; import { Toaster } from 'svelte-sonner'; import { goto } from '$app/navigation'; @@ -95,6 +95,15 @@ serverStore.fetchServerProps(); }); + // Sync settings when server props are loaded + $effect(() => { + const serverProps = serverStore.serverProps; + + if (serverProps?.default_generation_settings?.params) { + settingsStore.syncWithServerDefaults(); + } + }); + // Monitor API key changes and redirect to error page if removed or changed when required $effect(() => { const apiKey = config().apiKey;