Merge branch 'master' into cuda_graph_plan
commit 4bbe5b1e59
docs/ops.md
@@ -22,6 +22,7 @@ Legend:
 | ARANGE | ❌ | ✅ | ✅ | ✅ | ✅ | ❌ | ❌ | ❌ | ❌ |
 | ARGMAX | ❌ | ✅ | ✅ | ✅ | ✅ | ❌ | ✅ | ✅ | ❌ |
 | ARGSORT | ❌ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ |
+| CEIL | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ |
 | CLAMP | ❌ | ✅ | ✅ | ✅ | 🟡 | 🟡 | ✅ | 🟡 | ❌ |
 | CONCAT | ❌ | ✅ | ✅ | 🟡 | ✅ | 🟡 | 🟡 | ✅ | ❌ |
 | CONT | ❌ | 🟡 | ✅ | ✅ | ✅ | 🟡 | 🟡 | 🟡 | ❌ |
@@ -41,6 +42,7 @@ Legend:
 | ELU | ❌ | ✅ | ✅ | 🟡 | 🟡 | ❌ | 🟡 | ❌ | ❌ |
 | EXP | ❌ | ✅ | ✅ | 🟡 | 🟡 | ❌ | 🟡 | ❌ | ❌ |
 | FLASH_ATTN_EXT | ❌ | 🟡 | ✅ | 🟡 | 🟡 | ❌ | ❌ | 🟡 | ❌ |
+| FLOOR | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ |
 | GATED_LINEAR_ATTN | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ | ✅ | ❌ | ❌ |
 | GEGLU | ❌ | ✅ | ✅ | ✅ | 🟡 | ✅ | ✅ | 🟡 | ❌ |
 | GEGLU_ERF | ❌ | ✅ | ✅ | ✅ | 🟡 | ✅ | ✅ | 🟡 | ❌ |
@@ -82,6 +84,7 @@ Legend:
 | ROLL | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ | ❌ | ✅ | ❌ |
 | ROPE | ❌ | 🟡 | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ |
 | ROPE_BACK | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ | ❌ | ✅ | ❌ |
+| ROUND | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ |
 | RWKV_WKV6 | ❌ | ❌ | ✅ | ✅ | ✅ | ❌ | ✅ | ✅ | ❌ |
 | RWKV_WKV7 | ❌ | ❌ | ✅ | ✅ | ✅ | ❌ | ✅ | ✅ | ❌ |
 | SCALE | ❌ | 🟡 | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ |
@@ -108,5 +111,6 @@ Legend:
 | TANH | ❌ | ✅ | ✅ | 🟡 | 🟡 | ✅ | 🟡 | 🟡 | ❌ |
 | TIMESTEP_EMBEDDING | ❌ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ |
 | TOPK_MOE | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ✅ | ❌ | ❌ |
+| TRUNC | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ |
 | UPSCALE | ❌ | 🟡 | ✅ | ✅ | 🟡 | ✅ | 🟡 | ✅ | ❌ |
 | XIELU | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ |
docs/ops/CPU.csv
@@ -59,6 +59,14 @@
 "CPU","EXP","type=f16,ne_a=[5,7,11,13],v=1","support","1","yes","CPU"
 "CPU","GELU_ERF","type=f16,ne_a=[128,2,2,2],v=1","support","1","yes","CPU"
 "CPU","GELU_ERF","type=f16,ne_a=[5,7,11,13],v=1","support","1","yes","CPU"
+"CPU","FLOOR","type=f16,ne_a=[128,2,2,2],v=0","support","1","yes","CPU"
+"CPU","FLOOR","type=f16,ne_a=[5,7,11,13],v=0","support","1","yes","CPU"
+"CPU","CEIL","type=f16,ne_a=[128,2,2,2],v=0","support","1","yes","CPU"
+"CPU","CEIL","type=f16,ne_a=[5,7,11,13],v=0","support","1","yes","CPU"
+"CPU","ROUND","type=f16,ne_a=[128,2,2,2],v=0","support","1","yes","CPU"
+"CPU","ROUND","type=f16,ne_a=[5,7,11,13],v=0","support","1","yes","CPU"
+"CPU","TRUNC","type=f16,ne_a=[128,2,2,2],v=0","support","1","yes","CPU"
+"CPU","TRUNC","type=f16,ne_a=[5,7,11,13],v=0","support","1","yes","CPU"
 "CPU","ABS","type=f32,ne_a=[128,2,2,2],v=0","support","1","yes","CPU"
 "CPU","ABS","type=f32,ne_a=[5,7,11,13],v=0","support","1","yes","CPU"
 "CPU","SGN","type=f32,ne_a=[128,2,2,2],v=0","support","1","yes","CPU"
@@ -119,6 +127,14 @@
 "CPU","EXP","type=f32,ne_a=[5,7,11,13],v=1","support","1","yes","CPU"
 "CPU","GELU_ERF","type=f32,ne_a=[128,2,2,2],v=1","support","1","yes","CPU"
 "CPU","GELU_ERF","type=f32,ne_a=[5,7,11,13],v=1","support","1","yes","CPU"
+"CPU","FLOOR","type=f32,ne_a=[128,2,2,2],v=0","support","1","yes","CPU"
+"CPU","FLOOR","type=f32,ne_a=[5,7,11,13],v=0","support","1","yes","CPU"
+"CPU","CEIL","type=f32,ne_a=[128,2,2,2],v=0","support","1","yes","CPU"
+"CPU","CEIL","type=f32,ne_a=[5,7,11,13],v=0","support","1","yes","CPU"
+"CPU","ROUND","type=f32,ne_a=[128,2,2,2],v=0","support","1","yes","CPU"
+"CPU","ROUND","type=f32,ne_a=[5,7,11,13],v=0","support","1","yes","CPU"
+"CPU","TRUNC","type=f32,ne_a=[128,2,2,2],v=0","support","1","yes","CPU"
+"CPU","TRUNC","type=f32,ne_a=[5,7,11,13],v=0","support","1","yes","CPU"
 "CPU","REGLU","type=f16,ne_a=[128,2,2,2],v=0,swapped=0","support","1","yes","CPU"
 "CPU","REGLU","type=f16,ne_a=[5,7,11,13],v=0,swapped=0","support","1","yes","CPU"
 "CPU","REGLU","type=f16,ne_a=[128,2,2,2],v=0,swapped=1","support","1","yes","CPU"

(remainder of this file not rendered: too large)
@ -577,6 +577,10 @@ extern "C" {
|
|||
GGML_UNARY_OP_EXP,
|
||||
GGML_UNARY_OP_GELU_ERF,
|
||||
GGML_UNARY_OP_XIELU,
|
||||
GGML_UNARY_OP_FLOOR,
|
||||
GGML_UNARY_OP_CEIL,
|
||||
GGML_UNARY_OP_ROUND,
|
||||
GGML_UNARY_OP_TRUNC,
|
||||
|
||||
GGML_UNARY_OP_COUNT,
|
||||
};
|
||||
|
|
@ -1151,6 +1155,46 @@ extern "C" {
|
|||
struct ggml_context * ctx,
|
||||
struct ggml_tensor * a);
|
||||
|
||||
GGML_API struct ggml_tensor * ggml_floor(
|
||||
struct ggml_context * ctx,
|
||||
struct ggml_tensor * a);
|
||||
|
||||
GGML_API struct ggml_tensor * ggml_floor_inplace(
|
||||
struct ggml_context * ctx,
|
||||
struct ggml_tensor * a);
|
||||
|
||||
GGML_API struct ggml_tensor * ggml_ceil(
|
||||
struct ggml_context * ctx,
|
||||
struct ggml_tensor * a);
|
||||
|
||||
GGML_API struct ggml_tensor * ggml_ceil_inplace(
|
||||
struct ggml_context * ctx,
|
||||
struct ggml_tensor * a);
|
||||
|
||||
GGML_API struct ggml_tensor * ggml_round(
|
||||
struct ggml_context * ctx,
|
||||
struct ggml_tensor * a);
|
||||
|
||||
GGML_API struct ggml_tensor * ggml_round_inplace(
|
||||
struct ggml_context * ctx,
|
||||
struct ggml_tensor * a);
|
||||
|
||||
/**
|
||||
* Truncates the fractional part of each element in the tensor (towards zero).
|
||||
* For example: trunc(3.7) = 3.0, trunc(-2.9) = -2.0
|
||||
* Similar to std::trunc in C/C++.
|
||||
*/
|
||||
|
||||
GGML_API struct ggml_tensor * ggml_trunc(
|
||||
struct ggml_context * ctx,
|
||||
struct ggml_tensor * a);
|
||||
|
||||
GGML_API struct ggml_tensor * ggml_trunc_inplace(
|
||||
struct ggml_context * ctx,
|
||||
struct ggml_tensor * a);
|
||||
|
||||
|
||||
|
||||
// xIELU activation function
|
||||
// x = x * (c_a(alpha_n) + c_b(alpha_p, beta) * sigmoid(beta * x)) + eps * (x > 0)
|
||||
// where c_a = softplus and c_b(a, b) = softplus(a) + b are constraining functions
|
||||
|
|
|
|||
|
|
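A minimal usage sketch for the new rounding ops declared above (not part of the patch). It assumes the usual CPU compute flow where ggml-cpu.h provides ggml_graph_compute_with_ctx; the tensor size and input values are made up for illustration:

#include "ggml.h"
#include "ggml-cpu.h"
#include <cstdio>
#include <cstring>

int main() {
    struct ggml_init_params ip = {
        /*.mem_size   =*/ 16*1024*1024,
        /*.mem_buffer =*/ NULL,
        /*.no_alloc   =*/ false,
    };
    struct ggml_context * ctx = ggml_init(ip);

    struct ggml_tensor * a = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 4);
    struct ggml_tensor * f = ggml_floor(ctx, a); // floor(-2.9) = -3.0
    struct ggml_tensor * t = ggml_trunc(ctx, a); // trunc(-2.9) = -2.0 (towards zero)

    struct ggml_cgraph * gf = ggml_new_graph(ctx);
    ggml_build_forward_expand(gf, f);
    ggml_build_forward_expand(gf, t);

    const float src[4] = { 3.7f, -2.9f, 0.5f, -0.5f };
    std::memcpy(a->data, src, sizeof(src));

    ggml_graph_compute_with_ctx(ctx, gf, /*n_threads =*/ 1);

    for (int i = 0; i < 4; ++i) {
        std::printf("x=%5.2f floor=%5.2f trunc=%5.2f\n",
            src[i], ((float *) f->data)[i], ((float *) t->data)[i]);
    }

    ggml_free(ctx);
}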
ggml/src/ggml-cpu/ggml-cpu.c
@@ -2184,6 +2184,10 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads) {
                case GGML_UNARY_OP_HARDSWISH:
                case GGML_UNARY_OP_HARDSIGMOID:
                case GGML_UNARY_OP_EXP:
+                case GGML_UNARY_OP_FLOOR:
+                case GGML_UNARY_OP_CEIL:
+                case GGML_UNARY_OP_ROUND:
+                case GGML_UNARY_OP_TRUNC:
                    {
                        n_tasks = 1;
                    } break;
ggml/src/ggml-cpu/ops.cpp
@@ -8993,6 +8993,22 @@ void ggml_compute_forward_unary(
            {
                ggml_compute_forward_exp(params, dst);
            } break;
+        case GGML_UNARY_OP_FLOOR:
+            {
+                ggml_compute_forward_floor(params, dst);
+            } break;
+        case GGML_UNARY_OP_CEIL:
+            {
+                ggml_compute_forward_ceil(params, dst);
+            } break;
+        case GGML_UNARY_OP_ROUND:
+            {
+                ggml_compute_forward_round(params, dst);
+            } break;
+        case GGML_UNARY_OP_TRUNC:
+            {
+                ggml_compute_forward_trunc(params, dst);
+            } break;
        case GGML_UNARY_OP_XIELU:
            {
                ggml_compute_forward_xielu(params, dst);
ggml/src/ggml-cpu/unary-ops.cpp
@@ -73,6 +73,22 @@ static inline float op_log(float x) {
     return logf(x);
 }
 
+static inline float op_floor(float x) {
+    return floorf(x);
+}
+
+static inline float op_ceil(float x) {
+    return ceilf(x);
+}
+
+static inline float op_round(float x) {
+    return roundf(x);
+}
+
+static inline float op_trunc(float x) {
+    return truncf(x);
+}
+
 template <float (*op)(float), typename src0_t, typename dst_t>
 static inline void vec_unary_op(int64_t n, dst_t * y, const src0_t * x) {
     constexpr auto src0_to_f32 = type_conversion_table<src0_t>::to_f32;
@@ -274,6 +290,22 @@ void ggml_compute_forward_log(const ggml_compute_params * params, ggml_tensor * dst) {
     unary_op<op_log>(params, dst);
 }
 
+void ggml_compute_forward_floor(const ggml_compute_params * params, ggml_tensor * dst) {
+    unary_op<op_floor>(params, dst);
+}
+
+void ggml_compute_forward_ceil(const ggml_compute_params * params, ggml_tensor * dst) {
+    unary_op<op_ceil>(params, dst);
+}
+
+void ggml_compute_forward_round(const ggml_compute_params * params, ggml_tensor * dst) {
+    unary_op<op_round>(params, dst);
+}
+
+void ggml_compute_forward_trunc(const ggml_compute_params * params, ggml_tensor * dst) {
+    unary_op<op_trunc>(params, dst);
+}
+
 void ggml_compute_forward_xielu(const ggml_compute_params * params, ggml_tensor * dst) {
     const float alpha_n = ggml_get_op_params_f32(dst, 1);
     const float alpha_p = ggml_get_op_params_f32(dst, 2);
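A brief aside on the semantics these kernels inherit from libm (illustrative only, not from the patch): roundf rounds halfway cases away from zero rather than to even, truncf goes towards zero, and floorf towards negative infinity, so the four ops genuinely differ on negative inputs:

#include <cmath>
#include <cstdio>

int main() {
    const float xs[] = { 2.5f, -2.5f, -2.9f };
    for (float x : xs) {
        std::printf("x=%5.2f floor=%5.2f ceil=%5.2f round=%5.2f trunc=%5.2f\n",
            x, floorf(x), ceilf(x), roundf(x), truncf(x));
    }
    // 2.5 -> round 3, -2.5 -> round -3; -2.9 -> floor -3 but trunc -2
}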
ggml/src/ggml-cpu/unary-ops.h
@@ -22,6 +22,10 @@ void ggml_compute_forward_sqrt(const struct ggml_compute_params * params, struct ggml_tensor * dst);
 void ggml_compute_forward_sin(const struct ggml_compute_params * params, struct ggml_tensor * dst);
 void ggml_compute_forward_cos(const struct ggml_compute_params * params, struct ggml_tensor * dst);
 void ggml_compute_forward_log(const struct ggml_compute_params * params, struct ggml_tensor * dst);
+void ggml_compute_forward_floor(const struct ggml_compute_params * params, struct ggml_tensor * dst);
+void ggml_compute_forward_ceil(const struct ggml_compute_params * params, struct ggml_tensor * dst);
+void ggml_compute_forward_round(const struct ggml_compute_params * params, struct ggml_tensor * dst);
+void ggml_compute_forward_trunc(const struct ggml_compute_params * params, struct ggml_tensor * dst);
 void ggml_compute_forward_xielu(const struct ggml_compute_params * params, struct ggml_tensor * dst);
 
 #ifdef __cplusplus
ggml/src/ggml-cuda/ggml-cuda.cu
@@ -273,6 +273,15 @@ static ggml_cuda_device_info ggml_cuda_init() {
        } else if (device_name.substr(0, 21) == "NVIDIA GeForce GTX 16") {
            turing_devices_without_mma.push_back({ id, device_name });
        }
+
+        // Temporary performance fix:
+        // Setting device scheduling strategy for iGPUs with cc121 to "spinning" to avoid delays in cuda synchronize calls.
+        // TODO: Check for future drivers the default scheduling strategy and
+        //       remove this call again when cudaDeviceScheduleSpin is default.
+        if (prop.major == 12 && prop.minor == 1) {
+            CUDA_CHECK(cudaSetDeviceFlags(cudaDeviceScheduleSpin));
+        }
+
 #endif // defined(GGML_USE_HIP)
    }
@@ -3644,9 +3653,10 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const ggml_tensor * op) {
        case GGML_OP_CONV_2D_DW:
        case GGML_OP_CONV_TRANSPOSE_2D:
        case GGML_OP_POOL_2D:
-        case GGML_OP_SUM:
        case GGML_OP_ACC:
            return true;
+        case GGML_OP_SUM:
+            return ggml_is_contiguous_rows(op->src[0]);
        case GGML_OP_ARGSORT:
            // TODO: Support arbitrary column width
            return op->src[0]->ne[0] <= 1024;
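The CUDA SUM gate above is looser than full contiguity: it only requires that elements within each row are packed, while the rows themselves may be permuted. A rough sketch of the property being checked (illustrative; ggml_is_contiguous_rows is the authoritative implementation):

#include "ggml.h"

// a row-contiguous tensor may still be permuted in its outer dimensions;
// only the innermost stride has to equal the element size
static bool rows_are_contiguous_sketch(const struct ggml_tensor * t) {
    return t->nb[0] == ggml_type_size(t->type);
}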
ggml/src/ggml-metal/ggml-metal-device.m
@@ -7,6 +7,8 @@
 
 #include <Metal/Metal.h>
 
+#include <stdatomic.h>
+
 #ifndef TARGET_OS_VISION
 #define TARGET_OS_VISION 0
 #endif
@@ -22,6 +24,9 @@
 // overload of MTLGPUFamilyMetal3 (not available in some environments)
 static const NSInteger MTLGPUFamilyMetal3_GGML = 5001;
 
+// virtual address for GPU memory allocations
+static atomic_uintptr_t g_addr_device = 0x000000400ULL;
+
 #if !GGML_METAL_EMBED_LIBRARY
 // Here to assist with NSBundle Path Hack
 @interface GGMLMetalClass : NSObject
@@ -657,6 +662,7 @@ bool ggml_metal_device_supports_op(ggml_metal_device_t dev, const struct ggml_tensor * op) {
        case GGML_OP_LOG:
            return ggml_is_contiguous(op->src[0]) && op->src[0]->type == GGML_TYPE_F32;
        case GGML_OP_SUM:
+            return has_simdgroup_reduction && ggml_is_contiguous(op->src[0]);
        case GGML_OP_SUM_ROWS:
        case GGML_OP_MEAN:
        case GGML_OP_SOFT_MAX:
@@ -827,7 +833,7 @@ struct ggml_metal_buffer_wrapper {
 };
 
 struct ggml_metal_buffer {
-    void * all_data; // TODO: https://github.com/ggml-org/llama.cpp/pull/15985
+    void * all_data;
    size_t all_size;

    // if false, the Metal buffer data is allocated in private GPU memory and is not shared with the host
@@ -965,14 +971,15 @@ ggml_metal_buffer_t ggml_metal_buffer_init(ggml_metal_device_t dev, size_t size, bool shared) {
    if (shared) {
        res->all_data = ggml_metal_host_malloc(size_aligned);
        res->is_shared = true;
-        res->owned = true;
    } else {
-        // dummy, non-NULL value - we'll populate this after creating the Metal buffer below
-        res->all_data = (void *) 0x000000400ULL;
+        // use virtual address from g_addr_device counter
+        res->all_data = (void *) atomic_fetch_add_explicit(&g_addr_device, size_aligned, memory_order_relaxed);
        res->is_shared = false;
    }
    res->all_size = size_aligned;

+    res->owned = true;
+
    res->device = ggml_metal_device_get_obj(dev);
    res->queue = ggml_metal_device_get_queue(dev);
@@ -983,15 +990,13 @@ ggml_metal_buffer_t ggml_metal_buffer_init(ggml_metal_device_t dev, size_t size, bool shared) {
    res->buffers[0].metal = nil;

    if (size_aligned > 0) {
-        if (props_dev->use_shared_buffers &&shared) {
+        if (props_dev->use_shared_buffers && shared) {
            res->buffers[0].metal = [res->device newBufferWithBytesNoCopy:res->all_data
                                                                   length:size_aligned
                                                                  options:MTLResourceStorageModeShared
                                                              deallocator:nil];
        } else {
            res->buffers[0].metal = [res->device newBufferWithLength:size_aligned options:MTLResourceStorageModePrivate];
-
-            res->all_data = (void *) (res->buffers[0].metal.gpuAddress);
        }
    }
@@ -1139,7 +1144,7 @@ bool ggml_metal_buffer_is_shared(ggml_metal_buffer_t buf) {
 
 void ggml_metal_buffer_memset_tensor(ggml_metal_buffer_t buf, struct ggml_tensor * tensor, uint8_t value, size_t offset, size_t size) {
    if (buf->is_shared) {
-        memset((char *)tensor->data + offset, value, size);
+        memset((char *) tensor->data + offset, value, size);
        return;
    }
@@ -1168,7 +1173,7 @@ void ggml_metal_buffer_memset_tensor(ggml_metal_buffer_t buf, struct ggml_tensor * tensor, uint8_t value, size_t offset, size_t size) {
 
 void ggml_metal_buffer_set_tensor(ggml_metal_buffer_t buf, struct ggml_tensor * tensor, const void * data, size_t offset, size_t size) {
    if (buf->is_shared) {
-        memcpy((char *)tensor->data + offset, data, size);
+        memcpy((char *) tensor->data + offset, data, size);
        return;
    }
@@ -1223,7 +1228,7 @@ void ggml_metal_buffer_set_tensor(ggml_metal_buffer_t buf, struct ggml_tensor * tensor, const void * data, size_t offset, size_t size) {
 
 void ggml_metal_buffer_get_tensor(ggml_metal_buffer_t buf, const struct ggml_tensor * tensor, void * data, size_t offset, size_t size) {
    if (buf->is_shared) {
-        memcpy(data, (const char *)tensor->data + offset, size);
+        memcpy(data, (const char *) tensor->data + offset, size);
        return;
    }
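The g_addr_device change replaces the old fixed dummy pointer: private (non-shared) Metal buffers have no host-visible address, yet ggml still needs each tensor->data to be a distinct, stable value. A small host-side sketch of the same idea (not from the patch, C++ atomics instead of C11):

#include <atomic>
#include <cstdint>
#include <cstdio>

static std::atomic<std::uintptr_t> g_addr{0x400ULL}; // same starting value as g_addr_device

static void * reserve_virtual_range(std::size_t size) {
    // hands out the disjoint range [base, base + size); the pointer is an
    // identifier only and is never dereferenced on the host
    return (void *) g_addr.fetch_add(size, std::memory_order_relaxed);
}

int main() {
    void * a = reserve_virtual_range(4096);
    void * b = reserve_virtual_range(4096);
    std::printf("a=%p b=%p\n", a, b); // unique, non-overlapping placeholders
}

Because every buffer advances the counter by its aligned size, pointer arithmetic on tensor views stays consistent even though the memory behind it is GPU-private.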
ggml/src/ggml-metal/ggml-metal-impl.h
@@ -251,6 +251,7 @@ typedef struct {
    int32_t sect_1;
    int32_t sect_2;
    int32_t sect_3;
+    bool    src2;
 } ggml_metal_kargs_rope;
 
 typedef struct {
ggml/src/ggml-metal/ggml-metal-ops.cpp
@@ -866,12 +866,25 @@ int ggml_metal_op_sum(ggml_metal_op_t ctx, int idx) {
 
    ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_sum(lib, op);

+    int nth = 32; // SIMD width
+
+    while (nth < (int) n && nth < ggml_metal_pipeline_max_theads_per_threadgroup(pipeline)) {
+        nth *= 2;
+    }
+
+    nth = std::min(nth, ggml_metal_pipeline_max_theads_per_threadgroup(pipeline));
+    nth = std::min(nth, (int) n);
+
+    const int nsg = (nth + 31) / 32;
+
    ggml_metal_encoder_set_pipeline(enc, pipeline);
    ggml_metal_encoder_set_bytes   (enc, &args, sizeof(args), 0);
    ggml_metal_encoder_set_buffer  (enc, ggml_metal_get_buffer_id(op->src[0]), 1);
    ggml_metal_encoder_set_buffer  (enc, ggml_metal_get_buffer_id(op), 2);

-    ggml_metal_encoder_dispatch_threadgroups(enc, 1, 1, 1, 1, 1, 1);
+    ggml_metal_encoder_set_threadgroup_memory_size(enc, nsg * sizeof(float), 0);
+
+    ggml_metal_encoder_dispatch_threadgroups(enc, 1, 1, 1, nth, 1, 1);

    return 1;
 }
@@ -2969,6 +2982,7 @@ int ggml_metal_op_rope(ggml_metal_op_t ctx, int idx) {
        /* sect_1  =*/ sect_1,
        /* sect_2  =*/ sect_2,
        /* sect_3  =*/ sect_3,
+        /* src2    =*/ op->src[2] != nullptr,
    };

    ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_rope(lib, op);
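To make the new SUM dispatch concrete, here is a hedged host-side re-derivation of the sizing logic above (the shapes are invented; the upstream identifier max_theads_per_threadgroup keeps its spelling in the real code):

#include <algorithm>
#include <cstdio>

int main() {
    const int n = 1000, max_threads = 1024; // example: 1000 elements to reduce
    int nth = 32;                           // SIMD width, as in the code above
    while (nth < n && nth < max_threads) {
        nth *= 2;                           // 32 -> 64 -> ... -> 1024
    }
    nth = std::min(nth, max_threads);
    nth = std::min(nth, n);                 // -> 1000 threads in one threadgroup
    const int nsg = (nth + 31) / 32;        // -> 32 simdgroups
    std::printf("nth=%d nsg=%d shmem=%zu bytes\n", nth, nsg, nsg * sizeof(float));
}

Each simdgroup reduces its slice with simd_sum and writes one partial into the nsg floats of threadgroup memory; simdgroup 0 then reduces those partials, which is exactly the kernel change that follows.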
ggml/src/ggml-metal/ggml-metal.metal
@@ -1727,18 +1727,48 @@ kernel void kernel_op_sum_f32(
        constant ggml_metal_kargs_sum & args,
        device const float * src0,
        device       float * dst,
-        ushort tiitg[[thread_index_in_threadgroup]]) {
+        threadgroup float * shmem_f32 [[threadgroup(0)]],
+        uint3   tgpig[[threadgroup_position_in_grid]],
+        ushort3 tpitg[[thread_position_in_threadgroup]],
+        ushort  sgitg[[simdgroup_index_in_threadgroup]],
+        ushort  tiisg[[thread_index_in_simdgroup]],
+        ushort3   ntg[[threads_per_threadgroup]]) {

-    if (tiitg != 0) {
+    if (args.np == 0) {
        return;
    }

-    float acc = 0.0f;
-    for (ulong i = 0; i < args.np; ++i) {
-        acc += src0[i];
+    const uint nsg = (ntg.x + 31) / 32;
+
+    float sumf = 0;
+
+    for (int64_t i0 = tpitg.x; i0 < args.np; i0 += ntg.x) {
+        sumf += src0[i0];
    }

-    dst[0] = acc;
+    sumf = simd_sum(sumf);
+
+    if (tiisg == 0) {
+        shmem_f32[sgitg] = sumf;
+    }
+
+    threadgroup_barrier(mem_flags::mem_threadgroup);
+
+    float total = 0;
+
+    if (sgitg == 0) {
+        float v = 0;
+
+        if (tpitg.x < nsg) {
+            v = shmem_f32[tpitg.x];
+        }
+
+        total = simd_sum(v);
+
+        if (tpitg.x == 0) {
+            dst[0] = total;
+        }
+    }
 }
 
 template <bool norm>
@@ -3748,7 +3778,7 @@ kernel void kernel_rope_norm(
 
        const float theta = theta_base * pow(args.freq_base, inv_ndims*i0);

-        const float freq_factor = src2 != src0 ? ((device const float *) src2)[ic] : 1.0f;
+        const float freq_factor = args.src2 ? ((device const float *) src2)[ic] : 1.0f;

        rope_yarn(theta/freq_factor, args.freq_scale, corr_dims, i0, args.ext_factor, args.attn_factor, &cos_theta, &sin_theta);
@@ -3801,7 +3831,7 @@ kernel void kernel_rope_neox(
 
        const float theta = theta_base * pow(args.freq_base, inv_ndims*i0);

-        const float freq_factor = src2 != src0 ? ((device const float *) src2)[ic] : 1.0f;
+        const float freq_factor = args.src2 ? ((device const float *) src2)[ic] : 1.0f;

        rope_yarn(theta/freq_factor, args.freq_scale, corr_dims, i0, args.ext_factor, args.attn_factor, &cos_theta, &sin_theta);
@@ -3872,7 +3902,7 @@ kernel void kernel_rope_multi(
 
        const float theta = theta_base * pow(args.freq_base, inv_ndims*i0);

-        const float freq_factor = src2 != src0 ? ((device const float *) src2)[ic] : 1.0f;
+        const float freq_factor = args.src2 ? ((device const float *) src2)[ic] : 1.0f;

        rope_yarn(theta/freq_factor, args.freq_scale, corr_dims, i0, args.ext_factor, args.attn_factor, &cos_theta, &sin_theta);
@@ -3939,7 +3969,7 @@ kernel void kernel_rope_vision(
        const float theta = theta_base * pow(args.freq_base, 2.0f * inv_ndims * p);
        // end of mrope

-        const float freq_factor = src2 != src0 ? ((device const float *) src2)[ic] : 1.0f;
+        const float freq_factor = args.src2 ? ((device const float *) src2)[ic] : 1.0f;

        rope_yarn(theta/freq_factor, args.freq_scale, corr_dims, i0, args.ext_factor, args.attn_factor, &cos_theta, &sin_theta);
ggml/src/ggml-opencl/CMakeLists.txt
@@ -93,6 +93,7 @@ set(GGML_OPENCL_KERNELS
    mul_mv_id_mxfp4_f32_flat
    mul_mm_f32_f32_l4_lm
    mul_mm_f16_f32_l4_lm
+    mul_mm_q8_0_f32_l4_lm
    mul
    norm
    relu
ggml/src/ggml-opencl/ggml-opencl.cpp
@@ -408,6 +408,7 @@ struct ggml_backend_opencl_context {
    cl_program program_mul_mv_id_mxfp4_f32_flat;
    cl_program program_mul_mm_f32_f32_l4_lm;
    cl_program program_mul_mm_f16_f32_l4_lm;
+    cl_program program_mul_mm_q8_0_f32_l4_lm;

    cl_kernel kernel_add, kernel_add_row, kernel_add_f16, kernel_add_row_f16;
    cl_kernel kernel_mul, kernel_mul_row, kernel_mul_f16, kernel_mul_row_f16;
@@ -480,6 +481,7 @@ struct ggml_backend_opencl_context {
    cl_kernel kernel_mul_mv_id_mxfp4_f32_flat;
    cl_kernel kernel_mul_mm_f32_f32_l4_lm;
    cl_kernel kernel_mul_mm_f16_f32_l4_lm;
+    cl_kernel kernel_mul_mm_q8_0_f32_l4_lm;

    std::vector<ProfilingInfo> profiling_info;
@@ -1191,6 +1193,22 @@ static void load_cl_kernels(ggml_backend_opencl_context *backend_ctx, ggml_cl_version opencl_c_version) {
        GGML_LOG_CONT(".");
    }

+    // mul_mm_q8_0_f32_l4_lm
+    {
+#ifdef GGML_OPENCL_EMBED_KERNELS
+        const std::string kernel_src {
+            #include "mul_mm_q8_0_f32_l4_lm.cl.h"
+        };
+#else
+        const std::string kernel_src = read_file("mul_mm_q8_0_f32_l4_lm.cl");
+#endif
+        backend_ctx->program_mul_mm_q8_0_f32_l4_lm =
+            build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts);
+
+        CL_CHECK((backend_ctx->kernel_mul_mm_q8_0_f32_l4_lm = clCreateKernel(backend_ctx->program_mul_mm_q8_0_f32_l4_lm, "kernel_mul_mm_q8_0_f32_l4_lm", &err), err));
+        GGML_LOG_CONT(".");
+    }
+
    // mul
    {
#ifdef GGML_OPENCL_EMBED_KERNELS
@@ -6961,6 +6979,44 @@ static void ggml_cl_mul_mat(ggml_backend_t backend, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
            backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size, dst);
            return;
        }
+        case GGML_TYPE_Q8_0: {
+            if (ne11 < 32) {
+                break;
+            }
+            kernel = backend_ctx->kernel_mul_mm_q8_0_f32_l4_lm;
+            nth0 = 128; // calculated as (BM*BN)/(TM*TN)
+
+            int batch_stride_a = ne00*ne01;
+            int batch_stride_b = ne10*ne11;
+            int batch_stride_d = ne0*ne1;
+
+            CL_CHECK(clSetKernelArg(kernel,  0, sizeof(cl_mem),   &extra0_q8_0->q));
+            CL_CHECK(clSetKernelArg(kernel,  1, sizeof(cl_mem),   &extra0_q8_0->d));
+            CL_CHECK(clSetKernelArg(kernel,  2, sizeof(cl_mem),   &extra1->data_device));
+            CL_CHECK(clSetKernelArg(kernel,  3, sizeof(cl_ulong), &offset1));
+            CL_CHECK(clSetKernelArg(kernel,  4, sizeof(cl_mem),   &extrad->data_device));
+            CL_CHECK(clSetKernelArg(kernel,  5, sizeof(cl_ulong), &offsetd));
+            CL_CHECK(clSetKernelArg(kernel,  6, sizeof(int),      &ne00));
+            CL_CHECK(clSetKernelArg(kernel,  7, sizeof(int),      &ne01));
+            CL_CHECK(clSetKernelArg(kernel,  8, sizeof(int),      &ne02));
+            CL_CHECK(clSetKernelArg(kernel,  9, sizeof(int),      &ne11));
+            CL_CHECK(clSetKernelArg(kernel, 10, sizeof(int),      &ne12));
+            CL_CHECK(clSetKernelArg(kernel, 11, sizeof(int),      &ne10)); // stride_a
+            CL_CHECK(clSetKernelArg(kernel, 12, sizeof(int),      &ne10)); // stride_b
+            CL_CHECK(clSetKernelArg(kernel, 13, sizeof(int),      &ne01)); // stride_d
+            CL_CHECK(clSetKernelArg(kernel, 14, sizeof(int),      &batch_stride_a));
+            CL_CHECK(clSetKernelArg(kernel, 15, sizeof(int),      &batch_stride_b));
+            CL_CHECK(clSetKernelArg(kernel, 16, sizeof(int),      &batch_stride_d));
+            CL_CHECK(clSetKernelArg(kernel, 17, sizeof(int),      &r2));
+            CL_CHECK(clSetKernelArg(kernel, 18, sizeof(int),      &r3));
+
+            // 64 is block tile size BM and BN - change here when BM and BN in the kernel are changed.
+            size_t global_work_size[] = {(size_t)(CEIL_DIV(ne01, 64)*nth0), (size_t)(CEIL_DIV(ne11, 64)), (size_t)ne12*ne13};
+            size_t local_work_size[] = {(size_t)nth0, 1, 1};
+
+            backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size, dst);
+            return;
+        }
        default:
            break;
    }
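A quick arithmetic check of that dispatch (hedged sketch with invented shapes): each 128-thread workgroup produces one 64x64 output tile, so the grid is tiles-in-rows by tiles-in-columns by batches.

#include <cstdio>

constexpr int CEIL_DIV(int a, int b) { return (a + b - 1) / b; }

int main() {
    const int ne01 = 1000, ne11 = 200, batches = 2, nth0 = 128; // example shapes
    std::printf("global = {%d, %d, %d}, local = {%d, 1, 1}\n",
        CEIL_DIV(ne01, 64) * nth0, CEIL_DIV(ne11, 64), batches, nth0);
    // -> global = {2048, 4, 2}: 16 x 4 x 2 = 128 workgroups, one per 64x64 tile
}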
ggml/src/ggml-opencl/kernels/flash_attn_f32.cl
@@ -4,6 +4,7 @@
 #define ACC_TYPE4 float4
 #define DATA_TYPE float
 #define DATA_TYPE4 float4
+#define MASK_DATA_TYPE half
 #define CONVERT_ACC4(x) (x)
 #define CONVERT_DATA4(x) (x)
@@ -148,7 +149,7 @@ __kernel void flash_attn_f32(
        if (k_row1 >= n_kv) score1 = -INFINITY;

        if (mask_base != NULL) {
-            const global DATA_TYPE* mask_ptr = (const global DATA_TYPE*)(mask_base + my_query_row * mask_nb1);
+            const global MASK_DATA_TYPE* mask_ptr = (const global MASK_DATA_TYPE*)(mask_base + my_query_row * mask_nb1);
            if (k_row0 < n_kv) score0 += slope * (ACC_TYPE)mask_ptr[k_row0];
            if (k_row1 < n_kv) score1 += slope * (ACC_TYPE)mask_ptr[k_row1];
        }
@@ -281,7 +282,7 @@ __kernel void flash_attn_f32_q1(
        }
        ACC_TYPE score = (dot_acc.s0 + dot_acc.s1 + dot_acc.s2 + dot_acc.s3) * scale;
        if (mask_base != NULL) {
-            const global DATA_TYPE* mask_ptr = (const global DATA_TYPE*)(mask_base);
+            const global MASK_DATA_TYPE* mask_ptr = (const global MASK_DATA_TYPE*)(mask_base);
            score += slope * (ACC_TYPE)mask_ptr[k_idx];
        }
        if (logit_softcap > 0.0f) {
@@ -317,7 +318,7 @@ __kernel void flash_attn_f32_q1(
        }
        ACC_TYPE score = (dot_acc.s0 + dot_acc.s1 + dot_acc.s2 + dot_acc.s3) * scale;
        if (mask_base != NULL) {
-            const global DATA_TYPE* mask_ptr = (const global DATA_TYPE*)(mask_base);
+            const global MASK_DATA_TYPE* mask_ptr = (const global MASK_DATA_TYPE*)(mask_base);
            score += slope * (ACC_TYPE)mask_ptr[k_idx];
        }
        if (logit_softcap > 0.0f) {
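A hedged illustration of the bug the MASK_DATA_TYPE define fixes (not from the patch): the mask buffer stores fp16 values, so indexing it through a float pointer reinterprets raw bits instead of converting each element. Example using raw fp16 bit patterns on a little-endian host; 0x3C00 is 1.0 in half precision:

#include <cstdint>
#include <cstdio>
#include <cstring>

int main() {
    const uint16_t mask_f16[2] = { 0x3C00, 0x3C00 }; // two fp16 values, each 1.0
    float punned;
    std::memcpy(&punned, mask_f16, sizeof(punned));  // what the old float* cast did
    std::printf("punned 'float' read: %g (garbage, should be 1.0)\n", punned);
    // the fix: index as fp16 and convert per element, as the kernel now does
    // via (ACC_TYPE)mask_ptr[k] with mask_ptr of type MASK_DATA_TYPE*
}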
ggml/src/ggml-opencl/kernels/mul_mm_f16_f32_l4_lm.cl
@@ -79,19 +79,33 @@ kernel void kernel_mul_mm_f16_f32_l4_lm(
 
    for (int block = 0; block < ne00; block += BK) {
        for (int l = 0; l < BM; l += loadstride_a) {
-            const int idx = pos_a + (loadc_a + l) * stride_a / LOAD_VEC_A + loadr_a;
-            buf_a[(loadr_a * LOAD_VEC_A + 0) * BM + loadc_a + l] = src0[idx].s0;
-            buf_a[(loadr_a * LOAD_VEC_A + 1) * BM + loadc_a + l] = src0[idx].s1;
-            buf_a[(loadr_a * LOAD_VEC_A + 2) * BM + loadc_a + l] = src0[idx].s2;
-            buf_a[(loadr_a * LOAD_VEC_A + 3) * BM + loadc_a + l] = src0[idx].s3;
+            if (loadc_a + l < ne01) {
+                const int idx = pos_a + (loadc_a + l) * stride_a / LOAD_VEC_A + loadr_a;
+                buf_a[(loadr_a * LOAD_VEC_A + 0) * BM + loadc_a + l] = src0[idx].s0;
+                buf_a[(loadr_a * LOAD_VEC_A + 1) * BM + loadc_a + l] = src0[idx].s1;
+                buf_a[(loadr_a * LOAD_VEC_A + 2) * BM + loadc_a + l] = src0[idx].s2;
+                buf_a[(loadr_a * LOAD_VEC_A + 3) * BM + loadc_a + l] = src0[idx].s3;
+            } else {
+                buf_a[(loadr_a * LOAD_VEC_A + 0) * BM + loadc_a + l] = 0.0h;
+                buf_a[(loadr_a * LOAD_VEC_A + 1) * BM + loadc_a + l] = 0.0h;
+                buf_a[(loadr_a * LOAD_VEC_A + 2) * BM + loadc_a + l] = 0.0h;
+                buf_a[(loadr_a * LOAD_VEC_A + 3) * BM + loadc_a + l] = 0.0h;
+            }
        }

        for (int l = 0; l < BN; l += loadstride_b) {
-            const int idx = pos_b + (loadc_b + l) * stride_b / LOAD_VEC_B + loadr_b;
-            buf_b[(loadr_b * LOAD_VEC_B + 0) * BN + loadc_b + l] = src1[idx].s0;
-            buf_b[(loadr_b * LOAD_VEC_B + 1) * BN + loadc_b + l] = src1[idx].s1;
-            buf_b[(loadr_b * LOAD_VEC_B + 2) * BN + loadc_b + l] = src1[idx].s2;
-            buf_b[(loadr_b * LOAD_VEC_B + 3) * BN + loadc_b + l] = src1[idx].s3;
+            if (loadc_b + l < ne11) {
+                const int idx = pos_b + (loadc_b + l) * stride_b / LOAD_VEC_B + loadr_b;
+                buf_b[(loadr_b * LOAD_VEC_B + 0) * BN + loadc_b + l] = src1[idx].s0;
+                buf_b[(loadr_b * LOAD_VEC_B + 1) * BN + loadc_b + l] = src1[idx].s1;
+                buf_b[(loadr_b * LOAD_VEC_B + 2) * BN + loadc_b + l] = src1[idx].s2;
+                buf_b[(loadr_b * LOAD_VEC_B + 3) * BN + loadc_b + l] = src1[idx].s3;
+            } else {
+                buf_b[(loadr_b * LOAD_VEC_B + 0) * BN + loadc_b + l] = 0.0h;
+                buf_b[(loadr_b * LOAD_VEC_B + 1) * BN + loadc_b + l] = 0.0h;
+                buf_b[(loadr_b * LOAD_VEC_B + 2) * BN + loadc_b + l] = 0.0h;
+                buf_b[(loadr_b * LOAD_VEC_B + 3) * BN + loadc_b + l] = 0.0h;
+            }
        }

        barrier(CLK_LOCAL_MEM_FENCE);
ggml/src/ggml-opencl/kernels/mul_mm_f32_f32_l4_lm.cl
@@ -79,19 +79,33 @@ kernel void kernel_mul_mm_f32_f32_l4_lm(
 
    for (int block = 0; block < ne00; block += BK) {
        for (int l = 0; l < BM; l += loadstride_a) {
-            const int idx = pos_a + (loadc_a + l) * stride_a / LOAD_VEC_A + loadr_a;
-            buf_a[(loadr_a * LOAD_VEC_A + 0) * BM + loadc_a + l] = src0[idx].s0;
-            buf_a[(loadr_a * LOAD_VEC_A + 1) * BM + loadc_a + l] = src0[idx].s1;
-            buf_a[(loadr_a * LOAD_VEC_A + 2) * BM + loadc_a + l] = src0[idx].s2;
-            buf_a[(loadr_a * LOAD_VEC_A + 3) * BM + loadc_a + l] = src0[idx].s3;
+            if (loadc_a + l < ne01) {
+                const int idx = pos_a + (loadc_a + l) * stride_a / LOAD_VEC_A + loadr_a;
+                buf_a[(loadr_a * LOAD_VEC_A + 0) * BM + loadc_a + l] = src0[idx].s0;
+                buf_a[(loadr_a * LOAD_VEC_A + 1) * BM + loadc_a + l] = src0[idx].s1;
+                buf_a[(loadr_a * LOAD_VEC_A + 2) * BM + loadc_a + l] = src0[idx].s2;
+                buf_a[(loadr_a * LOAD_VEC_A + 3) * BM + loadc_a + l] = src0[idx].s3;
+            } else {
+                buf_a[(loadr_a * LOAD_VEC_A + 0) * BM + loadc_a + l] = 0.0f;
+                buf_a[(loadr_a * LOAD_VEC_A + 1) * BM + loadc_a + l] = 0.0f;
+                buf_a[(loadr_a * LOAD_VEC_A + 2) * BM + loadc_a + l] = 0.0f;
+                buf_a[(loadr_a * LOAD_VEC_A + 3) * BM + loadc_a + l] = 0.0f;
+            }
        }

        for (int l = 0; l < BN; l += loadstride_b) {
-            const int idx = pos_b + (loadc_b + l) * stride_b / LOAD_VEC_B + loadr_b;
-            buf_b[(loadr_b * LOAD_VEC_B + 0) * BN + loadc_b + l] = src1[idx].s0;
-            buf_b[(loadr_b * LOAD_VEC_B + 1) * BN + loadc_b + l] = src1[idx].s1;
-            buf_b[(loadr_b * LOAD_VEC_B + 2) * BN + loadc_b + l] = src1[idx].s2;
-            buf_b[(loadr_b * LOAD_VEC_B + 3) * BN + loadc_b + l] = src1[idx].s3;
+            if (loadc_b + l < ne11) {
+                const int idx = pos_b + (loadc_b + l) * stride_b / LOAD_VEC_B + loadr_b;
+                buf_b[(loadr_b * LOAD_VEC_B + 0) * BN + loadc_b + l] = src1[idx].s0;
+                buf_b[(loadr_b * LOAD_VEC_B + 1) * BN + loadc_b + l] = src1[idx].s1;
+                buf_b[(loadr_b * LOAD_VEC_B + 2) * BN + loadc_b + l] = src1[idx].s2;
+                buf_b[(loadr_b * LOAD_VEC_B + 3) * BN + loadc_b + l] = src1[idx].s3;
+            } else {
+                buf_b[(loadr_b * LOAD_VEC_B + 0) * BN + loadc_b + l] = 0.0f;
+                buf_b[(loadr_b * LOAD_VEC_B + 1) * BN + loadc_b + l] = 0.0f;
+                buf_b[(loadr_b * LOAD_VEC_B + 2) * BN + loadc_b + l] = 0.0f;
+                buf_b[(loadr_b * LOAD_VEC_B + 3) * BN + loadc_b + l] = 0.0f;
+            }
        }

        barrier(CLK_LOCAL_MEM_FENCE);
ggml/src/ggml-opencl/kernels/mul_mm_q8_0_f32_l4_lm.cl (new file)
@@ -0,0 +1,154 @@
+#pragma OPENCL EXTENSION cl_khr_fp16 : enable
+
+#define LOAD_VEC_A 4
+#define LOAD_VEC_B 4
+
+#define BM 64
+#define BN 64
+#define BK 32
+#define TM 4
+#define TN 8
+
+kernel void kernel_mul_mm_q8_0_f32_l4_lm(
+        global char4  * src0_q,
+        global half   * src0_d,
+        global float4 * src1,
+        ulong offset1,
+        global float  * dst,
+        ulong offsetd,
+
+        int ne00,
+        int ne01,
+        int ne02,
+        int ne11,
+        int ne12,
+
+        int stride_a,
+        int stride_b,
+        int stride_d,
+
+        int batch_stride_a,
+        int batch_stride_b,
+        int batch_stride_d,
+
+        int r2,
+        int r3
+) {
+    src1 = (global float4*)((global char*)src1 + offset1);
+    dst  = (global float *)((global char*)dst  + offsetd);
+
+    local float buf_a[BM * BK];
+    local float buf_b[BN * BK];
+
+    const int batch_idx = get_global_id(2);
+
+    const int i13 = batch_idx / ne12;
+    const int i12 = batch_idx % ne12;
+
+    const int i03 = i13 / r3;
+    const int i02 = i12 / r2;
+
+    const int batch_idx_a = i03 * ne02 + i02;
+
+    const int ir = get_group_id(0);
+    const int ic = get_group_id(1);
+
+    const int tid  = get_local_id(0);
+    const int th_r = tid % (BM / TM);
+    const int th_c = tid / (BM / TM);
+
+    const int loadr_a = get_local_id(0) % (BK / LOAD_VEC_A);
+    const int loadc_a = get_local_id(0) / (BK / LOAD_VEC_A);
+    const int loadr_b = get_local_id(0) % (BK / LOAD_VEC_B);
+    const int loadc_b = get_local_id(0) / (BK / LOAD_VEC_B);
+
+    const int loadstride_a = get_local_size(0) * LOAD_VEC_A / BK;
+    const int loadstride_b = get_local_size(0) * LOAD_VEC_B / BK;
+
+    int pos_a = (batch_idx_a * batch_stride_a + ir * BM * stride_a) / LOAD_VEC_A;
+    int pos_b = (batch_idx   * batch_stride_b + ic * BN * stride_b) / LOAD_VEC_B;
+
+    float sums[TM * TN];
+    float cache_a[TM];
+    float cache_b[TN];
+
+    for (int i = 0; i < TM * TN; i++) {
+        sums[i] = 0.0f;
+    }
+
+    for (int block = 0; block < ne00; block += BK) {
+        for (int l = 0; l < BM; l += loadstride_a) {
+            if (loadc_a + l < ne01) {
+                int idx = pos_a + (loadc_a + l) * stride_a / LOAD_VEC_A + loadr_a;
+                int ib  = idx / 8;
+                int iqs = idx % 8;
+
+                float d = (float)src0_d[ib];
+                global char4 * qs = src0_q + ib*8 + iqs;
+                char4  q = *qs;
+                float4 v = convert_float4(q)*d;
+
+                buf_a[(loadr_a * LOAD_VEC_A + 0) * BM + loadc_a + l] = v.s0;
+                buf_a[(loadr_a * LOAD_VEC_A + 1) * BM + loadc_a + l] = v.s1;
+                buf_a[(loadr_a * LOAD_VEC_A + 2) * BM + loadc_a + l] = v.s2;
+                buf_a[(loadr_a * LOAD_VEC_A + 3) * BM + loadc_a + l] = v.s3;
+            } else {
+                buf_a[(loadr_a * LOAD_VEC_A + 0) * BM + loadc_a + l] = 0.0f;
+                buf_a[(loadr_a * LOAD_VEC_A + 1) * BM + loadc_a + l] = 0.0f;
+                buf_a[(loadr_a * LOAD_VEC_A + 2) * BM + loadc_a + l] = 0.0f;
+                buf_a[(loadr_a * LOAD_VEC_A + 3) * BM + loadc_a + l] = 0.0f;
+            }
+        }
+
+        for (int l = 0; l < BN; l += loadstride_b) {
+            if (loadc_b + l < ne11) {
+                int idx = pos_b + (loadc_b + l) * stride_b / LOAD_VEC_B + loadr_b;
+                buf_b[(loadr_b * LOAD_VEC_B + 0) * BN + loadc_b + l] = src1[idx].s0;
+                buf_b[(loadr_b * LOAD_VEC_B + 1) * BN + loadc_b + l] = src1[idx].s1;
+                buf_b[(loadr_b * LOAD_VEC_B + 2) * BN + loadc_b + l] = src1[idx].s2;
+                buf_b[(loadr_b * LOAD_VEC_B + 3) * BN + loadc_b + l] = src1[idx].s3;
+            } else {
+                buf_b[(loadr_b * LOAD_VEC_B + 0) * BN + loadc_b + l] = 0.0f;
+                buf_b[(loadr_b * LOAD_VEC_B + 1) * BN + loadc_b + l] = 0.0f;
+                buf_b[(loadr_b * LOAD_VEC_B + 2) * BN + loadc_b + l] = 0.0f;
+                buf_b[(loadr_b * LOAD_VEC_B + 3) * BN + loadc_b + l] = 0.0f;
+            }
+        }
+
+        barrier(CLK_LOCAL_MEM_FENCE);
+
+        pos_a += BK / LOAD_VEC_A;
+        pos_b += BK / LOAD_VEC_B;
+
+        for (int i = 0; i < BK; i++) {
+            for (int j = 0; j < TM; j++) {
+                cache_a[j] = buf_a[(i) * BM + th_r * TM + j];
+            }
+
+            for (int j = 0; j < TN; j++) {
+                cache_b[j] = buf_b[(i) * BN + th_c * TN + j];
+            }
+
+            for (int cc = 0; cc < TN; cc++) {
+                for (int cr = 0; cr < TM; cr++) {
+                    const int sums_idx = cc*TM + cr;
+                    sums[sums_idx] = mad(cache_a[cr], cache_b[cc], sums[sums_idx]);
+                }
+            }
+        }
+        barrier(CLK_LOCAL_MEM_FENCE);
+    }
+
+    const int dr = ir * BM + th_r * TM;
+    const int dc = ic * BN + th_c * TN;
+
+    const int offsets = batch_idx * batch_stride_d;
+
+    for (int cc = 0; cc < TN; cc++) {
+        for (int cr = 0; cr < TM; cr++) {
+            if (dr + cr < ne01 && dc + cc < ne11) {
+                dst[offsets + (dc + cc) * stride_d + dr + cr] = sums[cc * TM + cr];
+            }
+        }
+    }
+}
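A hedged note on the Q8_0 indexing in the kernel above: a Q8_0 block is 32 int8 weights sharing one fp16 scale, i.e. 8 char4 vectors per block, so a flat char4 index splits into (block, vector-within-block) as idx / 8 and idx % 8. The example values below are made up:

#include <cstdio>

int main() {
    const int idx = 45;      // the 46th char4 of the weight matrix
    const int ib  = idx / 8; // -> block 5, whose scale is src0_d[5]
    const int iqs = idx % 8; // -> 6th char4 inside that block
    std::printf("ib=%d iqs=%d\n", ib, iqs);
    // dequant per vector: v = convert_float4(src0_q[ib*8 + iqs]) * (float) src0_d[ib]
}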
ggml/src/ggml.c
@@ -1144,9 +1144,13 @@ static const char * GGML_UNARY_OP_NAME[GGML_UNARY_OP_COUNT] = {
    "EXP",
    "GELU_ERF",
    "XIELU",
+    "FLOOR",
+    "CEIL",
+    "ROUND",
+    "TRUNC",
 };
 
-static_assert(GGML_UNARY_OP_COUNT == 16, "GGML_UNARY_OP_COUNT != 16");
+static_assert(GGML_UNARY_OP_COUNT == 20, "GGML_UNARY_OP_COUNT != 20");
 
 static const char * GGML_GLU_OP_NAME[GGML_GLU_OP_COUNT] = {
    "REGLU",
@@ -2749,6 +2753,62 @@ static struct ggml_tensor * ggml_glu_impl(
    return result;
 }
 
+// ggml_floor
+
+struct ggml_tensor * ggml_floor(
+        struct ggml_context * ctx,
+        struct ggml_tensor * a) {
+    return ggml_unary(ctx, a, GGML_UNARY_OP_FLOOR);
+}
+
+struct ggml_tensor * ggml_floor_inplace(
+        struct ggml_context * ctx,
+        struct ggml_tensor * a) {
+    return ggml_unary_inplace(ctx, a, GGML_UNARY_OP_FLOOR);
+}
+
+// ggml_ceil
+
+struct ggml_tensor * ggml_ceil(
+        struct ggml_context * ctx,
+        struct ggml_tensor * a) {
+    return ggml_unary(ctx, a, GGML_UNARY_OP_CEIL);
+}
+
+struct ggml_tensor * ggml_ceil_inplace(
+        struct ggml_context * ctx,
+        struct ggml_tensor * a) {
+    return ggml_unary_inplace(ctx, a, GGML_UNARY_OP_CEIL);
+}
+
+// ggml_round
+
+struct ggml_tensor * ggml_round(
+        struct ggml_context * ctx,
+        struct ggml_tensor * a) {
+    return ggml_unary(ctx, a, GGML_UNARY_OP_ROUND);
+}
+
+struct ggml_tensor * ggml_round_inplace(
+        struct ggml_context * ctx,
+        struct ggml_tensor * a) {
+    return ggml_unary_inplace(ctx, a, GGML_UNARY_OP_ROUND);
+}
+
+// ggml_trunc
+
+struct ggml_tensor * ggml_trunc(
+        struct ggml_context * ctx,
+        struct ggml_tensor * a) {
+    return ggml_unary(ctx, a, GGML_UNARY_OP_TRUNC);
+}
+
+struct ggml_tensor * ggml_trunc_inplace(
+        struct ggml_context * ctx,
+        struct ggml_tensor * a) {
+    return ggml_unary_inplace(ctx, a, GGML_UNARY_OP_TRUNC);
+}
+
 struct ggml_tensor * ggml_glu(
        struct ggml_context * ctx,
        struct ggml_tensor * a,
gguf-py/gguf/scripts/gguf_convert_endian.py
@@ -91,6 +91,7 @@ def convert_byteorder(reader: gguf.GGUFReader, args: argparse.Namespace) -> None:
        tensor.tensor_type not in (
            gguf.GGMLQuantizationType.F32,
            gguf.GGMLQuantizationType.F16,
+            gguf.GGMLQuantizationType.BF16,
        ):
            raise ValueError(f"Cannot handle type {tensor.tensor_type.name} for tensor {repr(tensor.name)}")
    logger.info(f"* Preparing to convert from {file_endian} to {order}")
@@ -148,6 +149,11 @@ def convert_byteorder(reader: gguf.GGUFReader, args: argparse.Namespace) -> None:
 
            # restore old shape in case it's ever used
            tensor.data.resize(oldshape)
+        elif tensor.tensor_type == gguf.GGMLQuantizationType.BF16:
+            # Special case for BF16
+            # It is 2-bytes data, but by default view loads it as 1-byte data.
+            # Change to correct view before byteswapping.
+            tensor.data.view(dtype=np.uint16).byteswap(inplace=True)
        else:
            # Handle other tensor types
            tensor.data.byteswap(inplace=True)
src/llama-arch.cpp
@@ -5,6 +5,7 @@
 #include <map>
 
 static const std::map<llm_arch, const char *> LLM_ARCH_NAMES = {
+    { LLM_ARCH_CLIP,             "clip"             }, // dummy, only used by llama-quantize
    { LLM_ARCH_LLAMA,            "llama"            },
    { LLM_ARCH_LLAMA4,           "llama4"           },
    { LLM_ARCH_DECI,             "deci"             },
@@ -275,6 +276,10 @@ static const std::map<llm_kv, const char *> LLM_KV_NAMES = {
 };
 
 static const std::map<llm_arch, std::map<llm_tensor, const char *>> LLM_TENSOR_NAMES = {
+    {
+        LLM_ARCH_CLIP,
+        {},
+    },
    {
        LLM_ARCH_LLAMA,
        {
src/llama-arch.h
@@ -9,6 +9,7 @@
 //
 
 enum llm_arch {
+    LLM_ARCH_CLIP,
    LLM_ARCH_LLAMA,
    LLM_ARCH_LLAMA4,
    LLM_ARCH_DECI,
src/llama-model.cpp
@@ -478,7 +478,8 @@ void llama_model::load_hparams(llama_model_loader & ml) {
    ml.get_key(LLM_KV_GENERAL_NAME, name, false);

    // everything past this point is not vocab-related
-    if (hparams.vocab_only) {
+    // for CLIP models, we only need to load tensors, no hparams
+    if (hparams.vocab_only || ml.get_arch() == LLM_ARCH_CLIP) {
        return;
    }
@@ -20013,6 +20014,7 @@ int32_t llama_n_head(const llama_model * model) {
 llama_rope_type llama_model_rope_type(const llama_model * model) {
    switch (model->arch) {
        // these models do not use RoPE
+        case LLM_ARCH_CLIP:
        case LLM_ARCH_GPT2:
        case LLM_ARCH_GPTJ:
        case LLM_ARCH_MPT:
src/llama-quant.cpp
@@ -701,6 +701,7 @@ static void llama_model_quantize_impl(const std::string & fname_inp, const std::string & fname_out, const llama_model_quantize_params * params) {
        });
    }

+    bool is_clip_model = false;
    for (const auto * it : tensors) {
        const struct ggml_tensor * tensor = it->tensor;

@@ -714,12 +715,14 @@ static void llama_model_quantize_impl(const std::string & fname_inp, const std::string & fname_out, const llama_model_quantize_params * params) {
        } else if (name == LLM_TN(model.arch)(LLM_TENSOR_OUTPUT, "weight")) {
            qs.has_output = true;
        }
+
+        is_clip_model |= name.rfind("mm.", 0) == 0; // check the "mm." prefix
    }

    qs.n_ffn_down = qs.n_ffn_gate = qs.n_ffn_up = (int)model.hparams.n_layer;

    // sanity checks for models that have attention layers
-    if (qs.n_attention_wv != 0)
+    if (qs.n_attention_wv != 0 && !is_clip_model)
    {
        const auto & n_head_kv_iter = model.hparams.n_head_kv_arr.begin();
        // attention layers have a non-zero number of kv heads
@@ -881,6 +884,9 @@ static void llama_model_quantize_impl(const std::string & fname_inp, const std::string & fname_out, const llama_model_quantize_params * params) {
        // do not quantize relative position bias (T5)
        quantize &= name.find("attn_rel_b.weight") == std::string::npos;

+        // do not quantize specific multimodal tensors
+        quantize &= name.find(".position_embd.") == std::string::npos;
+
        ggml_type new_type;
        void * new_data;
        size_t new_size;
src/llama.cpp
@@ -124,6 +124,9 @@ static int llama_model_load(const std::string & fname, std::vector<std::string> & splits, llama_model & model, llama_model_params & params) {
    } catch(const std::exception & e) {
        throw std::runtime_error("error loading model hyperparameters: " + std::string(e.what()));
    }
+    if (model.arch == LLM_ARCH_CLIP) {
+        throw std::runtime_error("CLIP cannot be used as main model, use it with --mmproj instead");
+    }
    try {
        model.load_vocab(ml);
    } catch(const std::exception & e) {
tests/test-backend-ops.cpp
@@ -4588,20 +4588,31 @@ struct test_topk_moe: public test_case {
 struct test_sum : public test_case {
    const ggml_type type;
    const std::array<int64_t, 4> ne;
+    const std::array<int64_t, 4> permute;
+    bool _use_permute;

    std::string vars() override {
-        return VARS_TO_STR2(type, ne);
+        std::string v = VARS_TO_STR2(type, ne);
+        if (_use_permute) v += "," + VAR_TO_STR(permute);
+        return v;
    }

    test_sum(ggml_type type = GGML_TYPE_F32,
-            std::array<int64_t, 4> ne = {10, 5, 4, 3})
-        : type(type), ne(ne) {}
+            std::array<int64_t, 4> ne = {10, 5, 4, 3},
+            std::array<int64_t, 4> permute = {0, 0, 0, 0})
+        : type(type), ne(ne), permute(permute),
+            _use_permute(permute[0] + permute[1] + permute[2] + permute[3] > 0) {}

    ggml_tensor * build_graph(ggml_context * ctx) override {
        ggml_tensor * a = ggml_new_tensor(ctx, type, 4, ne.data());
        ggml_set_param(a);
        ggml_set_name(a, "a");

+        if (_use_permute) {
+            a = ggml_permute(ctx, a, permute[0], permute[1], permute[2], permute[3]);
+            ggml_set_name(a, "a_permuted");
+        }
+
        ggml_tensor * out = ggml_sum(ctx, a);
        ggml_set_name(out, "out");

@@ -6354,6 +6365,19 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
        }
    }

+#if 0
+    {
+        // Test paths in OpenCL
+        std::vector<int> ns = {32, 64, 128, 256, 512, 1024, 4096};
+        std::vector<int> ks = {896, 1536, 4096};
+        for (auto n : ns) {
+            for (auto k : ks) {
+                test_cases.emplace_back(new test_mul_mat(GGML_TYPE_Q8_0, GGML_TYPE_F32, 1024, n, k, {1, 1}, {1, 1}));
+            }
+        }
+    }
+#endif
+
 #if 1
    for (ggml_type type_a : base_types) {
        for (ggml_type type_b : {GGML_TYPE_F32, GGML_TYPE_F16}) {
@@ -6724,6 +6748,9 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
 
    test_cases.emplace_back(new test_sum());
    test_cases.emplace_back(new test_sum_rows());
+    test_cases.emplace_back(new test_sum(GGML_TYPE_F32, {11, 5, 6, 3}, {0, 2, 1, 3})); // row-contiguous but non-contiguous
+    test_cases.emplace_back(new test_sum(GGML_TYPE_F32, {11, 5, 6, 3}, {0, 3, 2, 1}));
+    test_cases.emplace_back(new test_sum(GGML_TYPE_F32, {11, 5, 6, 3}, {0, 1, 3, 2}));
    test_cases.emplace_back(new test_sum_rows(GGML_TYPE_F32, { 11, 5, 6, 3 }, true, false));
    test_cases.emplace_back(new test_sum_rows(GGML_TYPE_F32, { 11, 5, 6, 3 }, false, true));
    test_cases.emplace_back(new test_sum_rows(GGML_TYPE_F32, { 11, 5, 6, 3 }, true, true));
@@ -6734,6 +6761,7 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
    test_cases.emplace_back(new test_sum(GGML_TYPE_F32, { 33, 1024, 1, 1 }));
    test_cases.emplace_back(new test_sum_rows(GGML_TYPE_F32, { 33, 1024, 1, 1 }));
    test_cases.emplace_back(new test_sum(GGML_TYPE_F32, { 33, 256, 1, 1 }));
+    test_cases.emplace_back(new test_sum(GGML_TYPE_F32, { 33, 256, 1, 1 }, { 1, 0, 2, 3 })); // sum dst not-contiguous
    test_cases.emplace_back(new test_sum_rows(GGML_TYPE_F32, { 33, 256, 1, 1 }));
    test_cases.emplace_back(new test_mean(GGML_TYPE_F32, { 33, 256, 1, 1 }));
    test_cases.emplace_back(new test_mean(GGML_TYPE_F32, { 32769, 1, 1, 1 }));
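A hedged sketch of what the permuted test_sum cases exercise (assuming ggml_is_contiguous_rows is exposed in ggml.h, as its use in the CUDA hunk above suggests): a {0, 2, 1, 3} permutation swaps the middle axes, so rows stay densely packed while the tensor as a whole is no longer contiguous, which is exactly the shape class the new CUDA SUM gate accepts.

#include "ggml.h"
#include <cassert>

void check(struct ggml_context * ctx) {
    int64_t ne[4] = { 11, 5, 6, 3 };
    struct ggml_tensor * a = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne);
    struct ggml_tensor * p = ggml_permute(ctx, a, 0, 2, 1, 3);
    assert( ggml_is_contiguous_rows(p)); // dim 0 untouched -> packed rows
    assert(!ggml_is_contiguous(p));      // strides of dims 1 and 2 are swapped
}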
Binary file not shown.
tools/server/server.cpp
@@ -3812,7 +3812,7 @@ struct server_context {
    if (slot.n_past > 0 && slot.n_past < (int) slot.prompt.tokens.size()) {
        const auto pos_min = llama_memory_seq_pos_min(llama_get_memory(ctx), slot.id);
        if (pos_min == -1) {
-            SLT_ERR(slot, "n_past = %d, cache_tokens.size() = %d, seq_id = %d, pos_min = %d\n", slot.n_past, (int) slot.prompt.tokens.size(), slot.id, pos_min);
+            SLT_ERR(slot, "n_past = %d, slot.prompt.tokens.size() = %d, seq_id = %d, pos_min = %d\n", slot.n_past, (int) slot.prompt.tokens.size(), slot.id, pos_min);
            GGML_ABORT("pos_min == -1, but n_past > 0 - should not happen: https://github.com/ggml-org/llama.cpp/pull/13833#discussion_r2116181237");
        }
@@ -3839,14 +3839,14 @@ struct server_context {
 
        {
            const auto token = slot.prompt.tokens[i];
-            const auto piece = common_token_to_piece(ctx, token);
+            const auto piece = token != LLAMA_TOKEN_NULL ? common_token_to_piece(ctx, token) : "[mtmd]";
            ss0 << piece;
            st0 << std::setw(8) << token;
        }

        {
            const auto token = slot.task->tokens[i];
-            const auto piece = common_token_to_piece(ctx, token);
+            const auto piece = token != LLAMA_TOKEN_NULL ? common_token_to_piece(ctx, token) : "[mtmd]";
            ss1 << piece;
            st1 << std::setw(8) << token;
        }
@@ -3860,7 +3860,7 @@ struct server_context {
    }

    if (pos_min > pos_min_thold) {
-        SLT_WRN(slot, "n_past = %d, cache_tokens.size() = %d, seq_id = %d, pos_min = %d, n_swa = %d\n", slot.n_past, (int) slot.prompt.tokens.size(), slot.id, pos_min, n_swa);
+        SLT_WRN(slot, "n_past = %d, slot.prompt.tokens.size() = %d, seq_id = %d, pos_min = %d, n_swa = %d\n", slot.n_past, (int) slot.prompt.tokens.size(), slot.id, pos_min, n_swa);

        // search for a context checkpoint
        const auto it = std::find_if(
@@ -4028,7 +4028,7 @@ struct server_context {
        }
    }

-    // SLT_INF(slot, "new cache_tokens: %s\n", slot.cache_tokens.str().c_str());
+    // SLT_INF(slot, "new slot.prompt.tokens: %s\n", slot.slot.prompt.tokens.str().c_str());

    SLT_INF(slot, "prompt processing progress, n_past = %d, n_tokens = %d, progress = %f\n", slot.n_past, batch.n_tokens, (float) slot.n_past / slot.n_prompt_tokens());
tools/server/utils.hpp
@@ -1237,9 +1237,10 @@
    // allowed to resize    ^                    ^
    // disallowed to resize        ^       ^               ^
    if (n > 0) {
-        llama_token last_token = tokens[n - 1];
        // make sure we never remove tokens in the middle of an image
-        if (last_token == LLAMA_TOKEN_NULL) {
+        // note that the case where we keep a full image at the end is allowed:
+        //   tokens[n - 1] == LLAMA_TOKEN_NULL && tokens[n] != LLAMA_TOKEN_NULL
+        if (tokens[n - 1] == LLAMA_TOKEN_NULL && tokens[n] == LLAMA_TOKEN_NULL) {
            find_chunk(n - 1); // will throw an error if the token is not begin-of-chunk
        }
    }
tools/server/webui/src/lib/components/app/chat/ChatSettings/ChatSettingsDialog.svelte
@@ -14,8 +14,7 @@
    import { ChatSettingsFooter, ChatSettingsFields } from '$lib/components/app';
    import * as Dialog from '$lib/components/ui/dialog';
    import { ScrollArea } from '$lib/components/ui/scroll-area';
    import { SETTING_CONFIG_DEFAULT } from '$lib/constants/settings-config';
-    import { config, updateMultipleConfig, resetConfig } from '$lib/stores/settings.svelte';
+    import { config, updateMultipleConfig } from '$lib/stores/settings.svelte';
    import { setMode } from 'mode-watcher';
    import type { Component } from 'svelte';

@@ -267,16 +266,13 @@
    }

    function handleReset() {
-        resetConfig();
-        localConfig = { ...config() };
-
-        setMode(SETTING_CONFIG_DEFAULT.theme as 'light' | 'dark' | 'system');
-        originalTheme = SETTING_CONFIG_DEFAULT.theme as string;
+        localConfig = { ...SETTING_CONFIG_DEFAULT };
+
+        setMode(localConfig.theme as 'light' | 'dark' | 'system');
+        originalTheme = localConfig.theme as string;
    }

    function handleSave() {
        // Validate custom JSON if provided
        if (localConfig.custom && typeof localConfig.custom === 'string' && localConfig.custom.trim()) {
            try {
                JSON.parse(localConfig.custom);
tools/server/webui/src/lib/components/app/chat/ChatSettings/ChatSettingsFields.svelte
@@ -1,4 +1,5 @@
 <script lang="ts">
+    import { RotateCcw } from '@lucide/svelte';
    import { Checkbox } from '$lib/components/ui/checkbox';
    import { Input } from '$lib/components/ui/input';
    import Label from '$lib/components/ui/label/label.svelte';
@@ -6,6 +7,9 @@
    import { Textarea } from '$lib/components/ui/textarea';
    import { SETTING_CONFIG_DEFAULT, SETTING_CONFIG_INFO } from '$lib/constants/settings-config';
    import { supportsVision } from '$lib/stores/server.svelte';
+    import { getParameterInfo, resetParameterToServerDefault } from '$lib/stores/settings.svelte';
+    import { ParameterSyncService } from '$lib/services/parameter-sync';
+    import ParameterSourceIndicator from './ParameterSourceIndicator.svelte';
    import type { Component } from 'svelte';

    interface Props {
@@ -16,22 +20,77 @@
    }

    let { fields, localConfig, onConfigChange, onThemeChange }: Props = $props();
+
+    // Helper function to get parameter source info for syncable parameters
+    function getParameterSourceInfo(key: string) {
+        if (!ParameterSyncService.canSyncParameter(key)) {
+            return null;
+        }
+
+        return getParameterInfo(key);
+    }
 </script>

 {#each fields as field (field.key)}
    <div class="space-y-2">
        {#if field.type === 'input'}
-            <Label for={field.key} class="block text-sm font-medium">
-                {field.label}
-            </Label>
+            {@const paramInfo = getParameterSourceInfo(field.key)}
+            {@const currentValue = String(localConfig[field.key] ?? '')}
+            {@const propsDefault = paramInfo?.serverDefault}
+            {@const isCustomRealTime = (() => {
+                if (!paramInfo || propsDefault === undefined) return false;

-            <Input
-                id={field.key}
-                value={String(localConfig[field.key] ?? '')}
-                onchange={(e) => onConfigChange(field.key, e.currentTarget.value)}
-                placeholder={`Default: ${SETTING_CONFIG_DEFAULT[field.key] ?? 'none'}`}
-                class="w-full md:max-w-md"
-            />
+                // Apply same rounding logic for real-time comparison
+                const inputValue = currentValue;
+                const numericInput = parseFloat(inputValue);
+                const normalizedInput = !isNaN(numericInput)
+                    ? Math.round(numericInput * 1000000) / 1000000
+                    : inputValue;
+                const normalizedDefault =
+                    typeof propsDefault === 'number'
+                        ? Math.round(propsDefault * 1000000) / 1000000
+                        : propsDefault;
+
+                return normalizedInput !== normalizedDefault;
+            })()}
+
+            <div class="flex items-center gap-2">
+                <Label for={field.key} class="text-sm font-medium">
+                    {field.label}
+                </Label>
+                {#if isCustomRealTime}
+                    <ParameterSourceIndicator />
+                {/if}
+            </div>
+
+            <div class="relative w-full md:max-w-md">
+                <Input
+                    id={field.key}
+                    value={currentValue}
+                    oninput={(e) => {
+                        // Update local config immediately for real-time badge feedback
+                        onConfigChange(field.key, e.currentTarget.value);
+                    }}
+                    placeholder={`Default: ${SETTING_CONFIG_DEFAULT[field.key] ?? 'none'}`}
+                    class="w-full {isCustomRealTime ? 'pr-8' : ''}"
+                />
+                {#if isCustomRealTime}
+                    <button
+                        type="button"
+                        onclick={() => {
+                            resetParameterToServerDefault(field.key);
+                            // Trigger UI update by calling onConfigChange with the default value
+                            const defaultValue = propsDefault ?? SETTING_CONFIG_DEFAULT[field.key];
+                            onConfigChange(field.key, String(defaultValue));
+                        }}
+                        class="absolute top-1/2 right-2 inline-flex h-5 w-5 -translate-y-1/2 items-center justify-center rounded transition-colors hover:bg-muted"
+                        aria-label="Reset to default"
+                        title="Reset to default"
+                    >
+                        <RotateCcw class="h-3 w-3" />
+                    </button>
+                {/if}
+            </div>
            {#if field.help || SETTING_CONFIG_INFO[field.key]}
                <p class="mt-1 text-xs text-muted-foreground">
                    {field.help || SETTING_CONFIG_INFO[field.key]}
@@ -59,14 +118,28 @@
            (opt: { value: string; label: string; icon?: Component }) =>
                opt.value === localConfig[field.key]
        )}
+        {@const paramInfo = getParameterSourceInfo(field.key)}
+        {@const currentValue = localConfig[field.key]}
+        {@const propsDefault = paramInfo?.serverDefault}
+        {@const isCustomRealTime = (() => {
+            if (!paramInfo || propsDefault === undefined) return false;

-        <Label for={field.key} class="block text-sm font-medium">
-            {field.label}
-        </Label>
+            // For select fields, do direct comparison (no rounding needed)
+            return currentValue !== propsDefault;
+        })()}
+
+        <div class="flex items-center gap-2">
+            <Label for={field.key} class="text-sm font-medium">
+                {field.label}
+            </Label>
+            {#if isCustomRealTime}
+                <ParameterSourceIndicator />
+            {/if}
+        </div>

        <Select.Root
            type="single"
-            value={localConfig[field.key]}
+            value={currentValue}
            onValueChange={(value) => {
                if (field.key === 'theme' && value && onThemeChange) {
                    onThemeChange(value);
@@ -75,16 +148,34 @@
                }
            }}
        >
-            <Select.Trigger class="w-full md:w-auto md:max-w-md">
-                <div class="flex items-center gap-2">
-                    {#if selectedOption?.icon}
-                        {@const IconComponent = selectedOption.icon}
-                        <IconComponent class="h-4 w-4" />
-                    {/if}
+            <div class="relative w-full md:w-auto md:max-w-md">
+                <Select.Trigger class="w-full">
+                    <div class="flex items-center gap-2">
+                        {#if selectedOption?.icon}
+                            {@const IconComponent = selectedOption.icon}
+                            <IconComponent class="h-4 w-4" />
+                        {/if}

-                    {selectedOption?.label || `Select ${field.label.toLowerCase()}`}
-                </div>
-            </Select.Trigger>
+                        {selectedOption?.label || `Select ${field.label.toLowerCase()}`}
+                    </div>
+                </Select.Trigger>
+                {#if isCustomRealTime}
+                    <button
+                        type="button"
+                        onclick={() => {
+                            resetParameterToServerDefault(field.key);
+                            // Trigger UI update by calling onConfigChange with the default value
+                            const defaultValue = propsDefault ?? SETTING_CONFIG_DEFAULT[field.key];
+                            onConfigChange(field.key, String(defaultValue));
+                        }}
+                        class="absolute top-1/2 right-8 inline-flex h-5 w-5 -translate-y-1/2 items-center justify-center rounded transition-colors hover:bg-muted"
+                        aria-label="Reset to default"
+                        title="Reset to default"
+                    >
+                        <RotateCcw class="h-3 w-3" />
+                    </button>
+                {/if}
+            </div>
            <Select.Content>
                {#if field.options}
                    {#each field.options as option (option.value)}
@@ -1,6 +1,8 @@
 <script lang="ts">
   import { Button } from '$lib/components/ui/button';
   import * as AlertDialog from '$lib/components/ui/alert-dialog';
+  import { forceSyncWithServerDefaults } from '$lib/stores/settings.svelte';
+  import { RotateCcw } from '@lucide/svelte';

   interface Props {
     onReset?: () => void;

@@ -16,7 +18,9 @@
   }

   function handleConfirmReset() {
+    forceSyncWithServerDefaults();
     onReset?.();
+
     showResetDialog = false;
   }

@@ -26,7 +30,13 @@
 </script>

 <div class="flex justify-between border-t border-border/30 p-6">
-  <Button variant="outline" onclick={handleResetClick}>Reset to default</Button>
+  <div class="flex gap-2">
+    <Button variant="outline" onclick={handleResetClick}>
+      <RotateCcw class="h-3 w-3" />
+      Reset to default
+    </Button>
+  </div>
+
   <Button onclick={handleSave}>Save settings</Button>
 </div>

@@ -36,8 +46,9 @@
   <AlertDialog.Header>
     <AlertDialog.Title>Reset Settings to Default</AlertDialog.Title>
     <AlertDialog.Description>
-      Are you sure you want to reset all settings to their default values? This action cannot be
-      undone and will permanently remove all your custom configurations.
+      Are you sure you want to reset all settings to their default values? This will reset all
+      parameters to the values provided by the server's /props endpoint and remove all your custom
+      configurations.
     </AlertDialog.Description>
   </AlertDialog.Header>
   <AlertDialog.Footer>
@@ -0,0 +1,18 @@
<script lang="ts">
  import { Wrench } from '@lucide/svelte';
  import { Badge } from '$lib/components/ui/badge';

  interface Props {
    class?: string;
  }

  let { class: className = '' }: Props = $props();
</script>

<Badge
  variant="secondary"
  class="h-5 bg-orange-100 px-1.5 py-0.5 text-xs text-orange-800 dark:bg-orange-900 dark:text-orange-200 {className}"
>
  <Wrench class="mr-1 h-3 w-3" />
  Custom
</Badge>
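A usage sketch (hypothetical parent markup, not part of the diff) — the badge is purely presentational, so callers decide when to render it:

```svelte
<script lang="ts">
  import { ParameterSourceIndicator } from '$lib/components';

  let isCustom = $state(true); // e.g. derived from getParameterSourceInfo
</script>

{#if isCustom}
  <ParameterSourceIndicator class="ml-2" />
{/if}
```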
@@ -25,6 +25,7 @@ export { default as ChatScreen } from './chat/ChatScreen/ChatScreen.svelte';
 export { default as ChatSettingsDialog } from './chat/ChatSettings/ChatSettingsDialog.svelte';
 export { default as ChatSettingsFooter } from './chat/ChatSettings/ChatSettingsFooter.svelte';
 export { default as ChatSettingsFields } from './chat/ChatSettings/ChatSettingsFields.svelte';
+export { default as ParameterSourceIndicator } from './chat/ChatSettings/ParameterSourceIndicator.svelte';

 export { default as ChatSidebar } from './chat/ChatSidebar/ChatSidebar.svelte';
 export { default as ChatSidebarConversationItem } from './chat/ChatSidebar/ChatSidebarConversationItem.svelte';
@@ -0,0 +1,2 @@
export const PRECISION_MULTIPLIER = 1000000;
export const PRECISION_DECIMAL_PLACES = 6;
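The two constants describe the same precision (10^6 ↔ six decimal places). A quick check of the arithmetic they enable (sketch; the noisy value is taken from the tests later in this diff):

```ts
import { PRECISION_MULTIPLIER } from '$lib/constants/precision';

const noisy = 0.949999988079071; // float32 0.95 widened to float64 by the server
const clean = Math.round(noisy * PRECISION_MULTIPLIER) / PRECISION_MULTIPLIER;
console.log(clean === 0.95); // true — six decimals absorb the conversion noise
```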
@@ -0,0 +1,135 @@
import { describe, it, expect } from 'vitest';
import { ParameterSyncService } from './parameter-sync';
import type { ApiLlamaCppServerProps } from '$lib/types/api';

describe('ParameterSyncService', () => {
  describe('roundFloatingPoint', () => {
    it('should fix JavaScript floating-point precision issues', () => {
      // Test the specific values from the screenshot
      const mockServerParams = {
        top_p: 0.949999988079071,
        min_p: 0.009999999776482582,
        temperature: 0.800000011920929,
        top_k: 40,
        samplers: ['top_k', 'typ_p', 'top_p', 'min_p', 'temperature']
      };

      const result = ParameterSyncService.extractServerDefaults({
        ...mockServerParams,
        // Add other required fields to match the API type
        n_predict: 512,
        seed: -1,
        dynatemp_range: 0.0,
        dynatemp_exponent: 1.0,
        xtc_probability: 0.0,
        xtc_threshold: 0.1,
        typ_p: 1.0,
        repeat_last_n: 64,
        repeat_penalty: 1.0,
        presence_penalty: 0.0,
        frequency_penalty: 0.0,
        dry_multiplier: 0.0,
        dry_base: 1.75,
        dry_allowed_length: 2,
        dry_penalty_last_n: -1,
        mirostat: 0,
        mirostat_tau: 5.0,
        mirostat_eta: 0.1,
        stop: [],
        max_tokens: -1,
        n_keep: 0,
        n_discard: 0,
        ignore_eos: false,
        stream: true,
        logit_bias: [],
        n_probs: 0,
        min_keep: 0,
        grammar: '',
        grammar_lazy: false,
        grammar_triggers: [],
        preserved_tokens: [],
        chat_format: '',
        reasoning_format: '',
        reasoning_in_content: false,
        thinking_forced_open: false,
        'speculative.n_max': 0,
        'speculative.n_min': 0,
        'speculative.p_min': 0.0,
        timings_per_token: false,
        post_sampling_probs: false,
        lora: [],
        top_n_sigma: 0.0,
        dry_sequence_breakers: []
      } as ApiLlamaCppServerProps['default_generation_settings']['params']);

      // Check that the problematic floating-point values are rounded correctly
      expect(result.top_p).toBe(0.95);
      expect(result.min_p).toBe(0.01);
      expect(result.temperature).toBe(0.8);
      expect(result.top_k).toBe(40); // Integer should remain unchanged
      expect(result.samplers).toBe('top_k;typ_p;top_p;min_p;temperature');
    });

    it('should preserve non-numeric values', () => {
      const mockServerParams = {
        samplers: ['top_k', 'temperature'],
        max_tokens: -1,
        temperature: 0.7
      };

      const result = ParameterSyncService.extractServerDefaults({
        ...mockServerParams,
        // Minimal required fields
        n_predict: 512,
        seed: -1,
        dynatemp_range: 0.0,
        dynatemp_exponent: 1.0,
        top_k: 40,
        top_p: 0.95,
        min_p: 0.05,
        xtc_probability: 0.0,
        xtc_threshold: 0.1,
        typ_p: 1.0,
        repeat_last_n: 64,
        repeat_penalty: 1.0,
        presence_penalty: 0.0,
        frequency_penalty: 0.0,
        dry_multiplier: 0.0,
        dry_base: 1.75,
        dry_allowed_length: 2,
        dry_penalty_last_n: -1,
        mirostat: 0,
        mirostat_tau: 5.0,
        mirostat_eta: 0.1,
        stop: [],
        n_keep: 0,
        n_discard: 0,
        ignore_eos: false,
        stream: true,
        logit_bias: [],
        n_probs: 0,
        min_keep: 0,
        grammar: '',
        grammar_lazy: false,
        grammar_triggers: [],
        preserved_tokens: [],
        chat_format: '',
        reasoning_format: '',
        reasoning_in_content: false,
        thinking_forced_open: false,
        'speculative.n_max': 0,
        'speculative.n_min': 0,
        'speculative.p_min': 0.0,
        timings_per_token: false,
        post_sampling_probs: false,
        lora: [],
        top_n_sigma: 0.0,
        dry_sequence_breakers: []
      } as ApiLlamaCppServerProps['default_generation_settings']['params']);

      expect(result.samplers).toBe('top_k;temperature');
      expect(result.max_tokens).toBe(-1);
      expect(result.temperature).toBe(0.7);
    });
  });
});
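Both tests repeat a ~40-field literal only to satisfy the API type. A hypothetical test helper (the name and baseline values are assumptions, not part of the diff) would keep future cases focused on the values under test:

```ts
import type { ApiLlamaCppServerProps } from '$lib/types/api';

type ServerParams = ApiLlamaCppServerProps['default_generation_settings']['params'];

// Hypothetical: one shared baseline plus per-test overrides. A real helper
// would list every required field exactly as the test literals above do;
// the cast papers over the omitted ones for this sketch.
const BASE_PARAMS: Partial<ServerParams> = {
  n_predict: 512,
  seed: -1,
  temperature: 0.8,
  top_k: 40,
  top_p: 0.95,
  min_p: 0.05
};

function makeMockServerParams(overrides: Partial<ServerParams>): ServerParams {
  return { ...BASE_PARAMS, ...overrides } as ServerParams;
}

// usage: ParameterSyncService.extractServerDefaults(makeMockServerParams({ temperature: 0.7 }))
```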
@@ -0,0 +1,202 @@
/**
 * ParameterSyncService - Handles synchronization between server defaults and user settings
 *
 * This service manages the complex logic of merging server-provided default parameters
 * with user-configured overrides, ensuring the UI reflects the actual server state
 * while preserving user customizations.
 *
 * **Key Responsibilities:**
 * - Extract syncable parameters from server props
 * - Merge server defaults with user overrides
 * - Track parameter sources ('default' vs 'custom')
 * - Provide sync utilities for settings store integration
 */

import type { ApiLlamaCppServerProps } from '$lib/types/api';
import { normalizeFloatingPoint } from '$lib/utils/precision';

export type ParameterSource = 'default' | 'custom';
export type ParameterValue = string | number | boolean;
export type ParameterRecord = Record<string, ParameterValue>;

export interface ParameterInfo {
  value: string | number | boolean;
  source: ParameterSource;
  serverDefault?: string | number | boolean;
  userOverride?: string | number | boolean;
}

export interface SyncableParameter {
  key: string;
  serverKey: string;
  type: 'number' | 'string' | 'boolean';
  canSync: boolean;
}

/**
 * Mapping of webui setting keys to server parameter keys
 * Only parameters that should be synced from server are included
 */
export const SYNCABLE_PARAMETERS: SyncableParameter[] = [
  { key: 'temperature', serverKey: 'temperature', type: 'number', canSync: true },
  { key: 'top_k', serverKey: 'top_k', type: 'number', canSync: true },
  { key: 'top_p', serverKey: 'top_p', type: 'number', canSync: true },
  { key: 'min_p', serverKey: 'min_p', type: 'number', canSync: true },
  { key: 'dynatemp_range', serverKey: 'dynatemp_range', type: 'number', canSync: true },
  { key: 'dynatemp_exponent', serverKey: 'dynatemp_exponent', type: 'number', canSync: true },
  { key: 'xtc_probability', serverKey: 'xtc_probability', type: 'number', canSync: true },
  { key: 'xtc_threshold', serverKey: 'xtc_threshold', type: 'number', canSync: true },
  { key: 'typ_p', serverKey: 'typ_p', type: 'number', canSync: true },
  { key: 'repeat_last_n', serverKey: 'repeat_last_n', type: 'number', canSync: true },
  { key: 'repeat_penalty', serverKey: 'repeat_penalty', type: 'number', canSync: true },
  { key: 'presence_penalty', serverKey: 'presence_penalty', type: 'number', canSync: true },
  { key: 'frequency_penalty', serverKey: 'frequency_penalty', type: 'number', canSync: true },
  { key: 'dry_multiplier', serverKey: 'dry_multiplier', type: 'number', canSync: true },
  { key: 'dry_base', serverKey: 'dry_base', type: 'number', canSync: true },
  { key: 'dry_allowed_length', serverKey: 'dry_allowed_length', type: 'number', canSync: true },
  { key: 'dry_penalty_last_n', serverKey: 'dry_penalty_last_n', type: 'number', canSync: true },
  { key: 'max_tokens', serverKey: 'max_tokens', type: 'number', canSync: true },
  { key: 'samplers', serverKey: 'samplers', type: 'string', canSync: true }
];

export class ParameterSyncService {
  /**
   * Round floating-point numbers to avoid JavaScript precision issues
   */
  private static roundFloatingPoint(value: ParameterValue): ParameterValue {
    return normalizeFloatingPoint(value) as ParameterValue;
  }

  /**
   * Extract server default parameters that can be synced
   */
  static extractServerDefaults(
    serverParams: ApiLlamaCppServerProps['default_generation_settings']['params'] | null
  ): ParameterRecord {
    if (!serverParams) return {};

    const extracted: ParameterRecord = {};

    for (const param of SYNCABLE_PARAMETERS) {
      if (param.canSync && param.serverKey in serverParams) {
        const value = (serverParams as unknown as Record<string, ParameterValue>)[param.serverKey];
        if (value !== undefined) {
          // Apply precision rounding to avoid JavaScript floating-point issues
          extracted[param.key] = this.roundFloatingPoint(value);
        }
      }
    }

    // Handle samplers array conversion to string
    if (serverParams.samplers && Array.isArray(serverParams.samplers)) {
      extracted.samplers = serverParams.samplers.join(';');
    }

    return extracted;
  }

  /**
   * Merge server defaults with current user settings
   * Returns updated settings that respect user overrides while using server defaults
   */
  static mergeWithServerDefaults(
    currentSettings: ParameterRecord,
    serverDefaults: ParameterRecord,
    userOverrides: Set<string> = new Set()
  ): ParameterRecord {
    const merged = { ...currentSettings };

    for (const [key, serverValue] of Object.entries(serverDefaults)) {
      // Only update if user hasn't explicitly overridden this parameter
      if (!userOverrides.has(key)) {
        merged[key] = this.roundFloatingPoint(serverValue);
      }
    }

    return merged;
  }

  /**
   * Get parameter information including source and values
   */
  static getParameterInfo(
    key: string,
    currentValue: ParameterValue,
    propsDefaults: ParameterRecord,
    userOverrides: Set<string>
  ): ParameterInfo {
    const hasPropsDefault = propsDefaults[key] !== undefined;
    const isUserOverride = userOverrides.has(key);

    // Simple logic: either using default (from props) or custom (user override)
    const source: ParameterSource = isUserOverride ? 'custom' : 'default';

    return {
      value: currentValue,
      source,
      serverDefault: hasPropsDefault ? propsDefaults[key] : undefined, // Keep same field name for compatibility
      userOverride: isUserOverride ? currentValue : undefined
    };
  }

  /**
   * Check if a parameter can be synced from server
   */
  static canSyncParameter(key: string): boolean {
    return SYNCABLE_PARAMETERS.some((param) => param.key === key && param.canSync);
  }

  /**
   * Get all syncable parameter keys
   */
  static getSyncableParameterKeys(): string[] {
    return SYNCABLE_PARAMETERS.filter((param) => param.canSync).map((param) => param.key);
  }

  /**
   * Validate server parameter value
   */
  static validateServerParameter(key: string, value: ParameterValue): boolean {
    const param = SYNCABLE_PARAMETERS.find((p) => p.key === key);
    if (!param) return false;

    switch (param.type) {
      case 'number':
        return typeof value === 'number' && !isNaN(value);
      case 'string':
        return typeof value === 'string';
      case 'boolean':
        return typeof value === 'boolean';
      default:
        return false;
    }
  }

  /**
   * Create a diff between current settings and server defaults
   */
  static createParameterDiff(
    currentSettings: ParameterRecord,
    serverDefaults: ParameterRecord
  ): Record<string, { current: ParameterValue; server: ParameterValue; differs: boolean }> {
    const diff: Record<
      string,
      { current: ParameterValue; server: ParameterValue; differs: boolean }
    > = {};

    for (const key of this.getSyncableParameterKeys()) {
      const currentValue = currentSettings[key];
      const serverValue = serverDefaults[key];

      if (serverValue !== undefined) {
        diff[key] = {
          current: currentValue,
          server: serverValue,
          differs: currentValue !== serverValue
        };
      }
    }

    return diff;
  }
}
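A minimal end-to-end sketch of how these pieces compose (illustrative values; the double cast stands in for the full /props payload, mirroring the style used in the tests above):

```ts
import { ParameterSyncService } from '$lib/services/parameter-sync';
import type { ApiLlamaCppServerProps } from '$lib/types/api';

type ServerParams = ApiLlamaCppServerProps['default_generation_settings']['params'];

const raw = { temperature: 0.800000011920929, top_p: 0.949999988079071 };
const defaults = ParameterSyncService.extractServerDefaults(raw as unknown as ServerParams);
// defaults => { temperature: 0.8, top_p: 0.95 }

// Merge into current settings; the explicit user override on top_p survives.
const merged = ParameterSyncService.mergeWithServerDefaults(
  { temperature: 0.7, top_p: 0.9 },
  defaults,
  new Set(['top_p'])
);
// merged => { temperature: 0.8, top_p: 0.9 }
```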
@@ -125,6 +125,12 @@ class ServerStore {
     return this._slotsEndpointAvailable;
   }

+  get serverDefaultParams():
+    | ApiLlamaCppServerProps['default_generation_settings']['params']
+    | null {
+    return this._serverProps?.default_generation_settings?.params || null;
+  }
+
   /**
    * Check if slots endpoint is available based on server properties and endpoint support
    */

@@ -273,3 +279,4 @@ export const supportedModalities = () => serverStore.supportedModalities;
 export const supportsVision = () => serverStore.supportsVision;
 export const supportsAudio = () => serverStore.supportsAudio;
 export const slotsEndpointAvailable = () => serverStore.slotsEndpointAvailable;
+export const serverDefaultParams = () => serverStore.serverDefaultParams;
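Consumers should treat the getter as nullable until /props has been fetched; a typical guard (sketch):

```ts
import { serverDefaultParams } from '$lib/stores/server.svelte';

const params = serverDefaultParams();
if (params) {
  console.log('server default temperature:', params.temperature);
} else {
  // /props not loaded yet — callers fall back to SETTING_CONFIG_DEFAULT
}
```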
@@ -33,11 +33,25 @@

 import { browser } from '$app/environment';
 import { SETTING_CONFIG_DEFAULT } from '$lib/constants/settings-config';
+import { normalizeFloatingPoint } from '$lib/utils/precision';
+import { ParameterSyncService } from '$lib/services/parameter-sync';
+import { serverStore } from '$lib/stores/server.svelte';
+import { setConfigValue, getConfigValue, configToParameterRecord } from '$lib/utils/config-helpers';

 class SettingsStore {
   config = $state<SettingsConfigType>({ ...SETTING_CONFIG_DEFAULT });
   theme = $state<string>('auto');
   isInitialized = $state(false);
+  userOverrides = $state<Set<string>>(new Set());
+
+  /**
+   * Helper method to get server defaults with null safety
+   * Centralizes the pattern of getting and extracting server defaults
+   */
+  private getServerDefaults(): Record<string, string | number | boolean> {
+    const serverParams = serverStore.serverDefaultParams;
+    return serverParams ? ParameterSyncService.extractServerDefaults(serverParams) : {};
+  }

   constructor() {
     if (browser) {

@@ -67,14 +81,20 @@ class SettingsStore {
     try {
       const savedVal = JSON.parse(localStorage.getItem('config') || '{}');

       // Merge with defaults to prevent breaking changes
       this.config = {
         ...SETTING_CONFIG_DEFAULT,
         ...savedVal
       };
+
+      // Load user overrides
+      const savedOverrides = JSON.parse(localStorage.getItem('userOverrides') || '[]');
+      this.userOverrides = new Set(savedOverrides);
     } catch (error) {
       console.warn('Failed to parse config from localStorage, using defaults:', error);
       this.config = { ...SETTING_CONFIG_DEFAULT };
+      this.userOverrides = new Set();
     }
   }

@@ -86,14 +106,30 @@ class SettingsStore {
     this.theme = localStorage.getItem('theme') || 'auto';
   }

   /**
    * Update a specific configuration setting
    * @param key - The configuration key to update
    * @param value - The new value for the configuration key
    */
-  updateConfig<K extends keyof SettingsConfigType>(key: K, value: SettingsConfigType[K]) {
+  updateConfig<K extends keyof SettingsConfigType>(key: K, value: SettingsConfigType[K]): void {
     this.config[key] = value;
+
+    if (ParameterSyncService.canSyncParameter(key as string)) {
+      const propsDefaults = this.getServerDefaults();
+      const propsDefault = propsDefaults[key as string];
+
+      if (propsDefault !== undefined) {
+        const normalizedValue = normalizeFloatingPoint(value);
+        const normalizedDefault = normalizeFloatingPoint(propsDefault);
+
+        if (normalizedValue === normalizedDefault) {
+          this.userOverrides.delete(key as string);
+        } else {
+          this.userOverrides.add(key as string);
+        }
+      }
+    }
+
     this.saveConfig();
   }

@@ -103,6 +139,26 @@ class SettingsStore {
    */
   updateMultipleConfig(updates: Partial<SettingsConfigType>) {
     Object.assign(this.config, updates);
+
+    const propsDefaults = this.getServerDefaults();
+
+    for (const [key, value] of Object.entries(updates)) {
+      if (ParameterSyncService.canSyncParameter(key)) {
+        const propsDefault = propsDefaults[key];
+
+        if (propsDefault !== undefined) {
+          const normalizedValue = normalizeFloatingPoint(value);
+          const normalizedDefault = normalizeFloatingPoint(propsDefault);
+
+          if (normalizedValue === normalizedDefault) {
+            this.userOverrides.delete(key);
+          } else {
+            this.userOverrides.add(key);
+          }
+        }
+      }
+    }
+
     this.saveConfig();
   }

@@ -114,6 +170,8 @@ class SettingsStore {
     try {
       localStorage.setItem('config', JSON.stringify(this.config));
+
+      localStorage.setItem('userOverrides', JSON.stringify(Array.from(this.userOverrides)));
     } catch (error) {
       console.error('Failed to save config to localStorage:', error);
     }

@@ -185,6 +243,129 @@ class SettingsStore {
   getAllConfig(): SettingsConfigType {
     return { ...this.config };
   }
+
+  /**
+   * Initialize settings with props defaults when server properties are first loaded
+   * This sets up the default values from /props endpoint
+   */
+  syncWithServerDefaults(): void {
+    const serverParams = serverStore.serverDefaultParams;
+    if (!serverParams) {
+      console.warn('No server parameters available for initialization');
+
+      return;
+    }
+
+    const propsDefaults = this.getServerDefaults();
+
+    for (const [key, propsValue] of Object.entries(propsDefaults)) {
+      const currentValue = getConfigValue(this.config, key);
+
+      const normalizedCurrent = normalizeFloatingPoint(currentValue);
+      const normalizedDefault = normalizeFloatingPoint(propsValue);
+
+      if (normalizedCurrent === normalizedDefault) {
+        this.userOverrides.delete(key);
+        setConfigValue(this.config, key, propsValue);
+      } else if (!this.userOverrides.has(key)) {
+        setConfigValue(this.config, key, propsValue);
+      }
+    }
+
+    this.saveConfig();
+    console.log('Settings initialized with props defaults:', propsDefaults);
+    console.log('Current user overrides after sync:', Array.from(this.userOverrides));
+  }
+
+  /**
+   * Clear all user overrides (for debugging)
+   */
+  clearAllUserOverrides(): void {
+    this.userOverrides.clear();
+    this.saveConfig();
+    console.log('Cleared all user overrides');
+  }
+
+  /**
+   * Reset all parameters to their default values (from props)
+   * This is used by the "Reset to Default" functionality
+   * Prioritizes server defaults from /props, falls back to webui defaults
+   */
+  forceSyncWithServerDefaults(): void {
+    const propsDefaults = this.getServerDefaults();
+    const syncableKeys = ParameterSyncService.getSyncableParameterKeys();
+
+    for (const key of syncableKeys) {
+      if (propsDefaults[key] !== undefined) {
+        const normalizedValue = normalizeFloatingPoint(propsDefaults[key]);
+
+        setConfigValue(this.config, key, normalizedValue);
+      } else {
+        if (key in SETTING_CONFIG_DEFAULT) {
+          const defaultValue = getConfigValue(SETTING_CONFIG_DEFAULT, key);
+
+          setConfigValue(this.config, key, defaultValue);
+        }
+      }
+
+      this.userOverrides.delete(key);
+    }
+
+    this.saveConfig();
+  }
+
+  /**
+   * Get parameter information including source for a specific parameter
+   */
+  getParameterInfo(key: string) {
+    const propsDefaults = this.getServerDefaults();
+    const currentValue = getConfigValue(this.config, key);
+
+    return ParameterSyncService.getParameterInfo(
+      key,
+      currentValue ?? '',
+      propsDefaults,
+      this.userOverrides
+    );
+  }
+
+  /**
+   * Reset a parameter to server default (or webui default if no server default)
+   */
+  resetParameterToServerDefault(key: string): void {
+    const serverDefaults = this.getServerDefaults();
+
+    if (serverDefaults[key] !== undefined) {
+      const value = normalizeFloatingPoint(serverDefaults[key]);
+
+      this.config[key as keyof SettingsConfigType] =
+        value as SettingsConfigType[keyof SettingsConfigType];
+    } else {
+      if (key in SETTING_CONFIG_DEFAULT) {
+        const defaultValue = getConfigValue(SETTING_CONFIG_DEFAULT, key);

+        setConfigValue(this.config, key, defaultValue);
+      }
+    }
+
+    this.userOverrides.delete(key);
+    this.saveConfig();
+  }
+
+  /**
+   * Get diff between current settings and server defaults
+   */
+  getParameterDiff() {
+    const serverDefaults = this.getServerDefaults();
+    if (Object.keys(serverDefaults).length === 0) return {};
+
+    const configAsRecord = configToParameterRecord(
+      this.config,
+      ParameterSyncService.getSyncableParameterKeys()
+    );
+
+    return ParameterSyncService.createParameterDiff(configAsRecord, serverDefaults);
+  }
 }

 // Create and export the settings store instance

@@ -204,3 +385,11 @@ export const resetTheme = settingsStore.resetTheme.bind(settingsStore);
 export const resetAll = settingsStore.resetAll.bind(settingsStore);
 export const getConfig = settingsStore.getConfig.bind(settingsStore);
 export const getAllConfig = settingsStore.getAllConfig.bind(settingsStore);
+export const syncWithServerDefaults = settingsStore.syncWithServerDefaults.bind(settingsStore);
+export const forceSyncWithServerDefaults =
+  settingsStore.forceSyncWithServerDefaults.bind(settingsStore);
+export const getParameterInfo = settingsStore.getParameterInfo.bind(settingsStore);
+export const resetParameterToServerDefault =
+  settingsStore.resetParameterToServerDefault.bind(settingsStore);
+export const getParameterDiff = settingsStore.getParameterDiff.bind(settingsStore);
+export const clearAllUserOverrides = settingsStore.clearAllUserOverrides.bind(settingsStore);
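The override lifecycle these methods implement, sketched with illustrative values (assumes /props reported temperature = 0.8 and that 'temperature' is a key of SettingsConfigType, as the syncable list suggests):

```ts
import { settingsStore } from '$lib/stores/settings.svelte';

// Editing away from the server default marks the key as a user override.
settingsStore.updateConfig('temperature', 1.2);
// settingsStore.userOverrides.has('temperature') === true

// Typing the server default back clears the override automatically.
settingsStore.updateConfig('temperature', 0.8);
// settingsStore.userOverrides.has('temperature') === false

// "Reset to default" returns every syncable key to its /props value.
settingsStore.forceSyncWithServerDefaults();
```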
@@ -0,0 +1,53 @@
/**
 * Type-safe configuration helpers
 *
 * Provides utilities for safely accessing and modifying configuration objects
 * with dynamic keys while maintaining TypeScript type safety.
 */

import type { SettingsConfigType } from '$lib/types/settings';

/**
 * Type-safe helper to access config properties dynamically
 * Provides better type safety than direct casting to Record
 */
export function setConfigValue<T extends SettingsConfigType>(
  config: T,
  key: string,
  value: unknown
): void {
  if (key in config) {
    (config as Record<string, unknown>)[key] = value;
  }
}

/**
 * Type-safe helper to get config values dynamically
 */
export function getConfigValue<T extends SettingsConfigType>(
  config: T,
  key: string
): string | number | boolean | undefined {
  const value = (config as Record<string, unknown>)[key];
  return value as string | number | boolean | undefined;
}

/**
 * Convert a SettingsConfigType to a ParameterRecord for specific keys
 * Useful for parameter synchronization operations
 */
export function configToParameterRecord<T extends SettingsConfigType>(
  config: T,
  keys: string[]
): Record<string, string | number | boolean> {
  const record: Record<string, string | number | boolean> = {};

  for (const key of keys) {
    const value = getConfigValue(config, key);
    if (value !== undefined) {
      record[key] = value;
    }
  }

  return record;
}
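Usage sketch — the helpers confine the unavoidable `Record<string, unknown>` cast to one place instead of scattering it through the store:

```ts
import {
  getConfigValue,
  setConfigValue,
  configToParameterRecord
} from '$lib/utils/config-helpers';
import { SETTING_CONFIG_DEFAULT } from '$lib/constants/settings-config';

const config = { ...SETTING_CONFIG_DEFAULT };

setConfigValue(config, 'temperature', 0.8); // no-op if the key is unknown
console.log(getConfigValue(config, 'temperature')); // 0.8

// Narrow to the syncable keys when diffing against /props defaults.
const record = configToParameterRecord(config, ['temperature', 'top_p']);
```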
@@ -0,0 +1,25 @@
/**
 * Floating-point precision utilities
 *
 * Provides functions to normalize floating-point numbers for consistent comparison
 * and display, addressing JavaScript's floating-point precision issues.
 */

import { PRECISION_MULTIPLIER } from '$lib/constants/precision';

/**
 * Normalize floating-point numbers for consistent comparison
 * Addresses JavaScript floating-point precision issues (e.g., 0.949999988079071 → 0.95)
 */
export function normalizeFloatingPoint(value: unknown): unknown {
  return typeof value === 'number'
    ? Math.round(value * PRECISION_MULTIPLIER) / PRECISION_MULTIPLIER
    : value;
}

/**
 * Type-safe version that only accepts numbers
 */
export function normalizeNumber(value: number): number {
  return Math.round(value * PRECISION_MULTIPLIER) / PRECISION_MULTIPLIER;
}
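Behavioral sketch of the two exports — numbers are rounded to six decimals, everything else passes through untouched, which makes `normalizeFloatingPoint` safe on mixed parameter records:

```ts
import { normalizeFloatingPoint, normalizeNumber } from '$lib/utils/precision';

console.log(normalizeNumber(0.949999988079071)); // 0.95
console.log(normalizeNumber(0.009999999776482582)); // 0.01
console.log(normalizeFloatingPoint('top_k;temperature')); // 'top_k;temperature'
console.log(normalizeFloatingPoint(true)); // true
```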
@@ -9,7 +9,7 @@
   } from '$lib/stores/chat.svelte';
   import * as Sidebar from '$lib/components/ui/sidebar/index.js';
   import { serverStore } from '$lib/stores/server.svelte';
-  import { config } from '$lib/stores/settings.svelte';
+  import { config, settingsStore } from '$lib/stores/settings.svelte';
   import { ModeWatcher } from 'mode-watcher';
   import { Toaster } from 'svelte-sonner';
   import { goto } from '$app/navigation';

@@ -95,6 +95,15 @@
     serverStore.fetchServerProps();
   });

+  // Sync settings when server props are loaded
+  $effect(() => {
+    const serverProps = serverStore.serverProps;
+
+    if (serverProps?.default_generation_settings?.params) {
+      settingsStore.syncWithServerDefaults();
+    }
+  });
+
   // Monitor API key changes and redirect to error page if removed or changed when required
   $effect(() => {
     const apiKey = config().apiKey;
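The startup flow, sketched imperatively (simplifying assumptions: fetchServerProps resolves a promise; the real code is reactive via $effect, which re-runs safely because syncWithServerDefaults only rewrites non-overridden keys):

```ts
import { serverStore } from '$lib/stores/server.svelte';
import { settingsStore } from '$lib/stores/settings.svelte';

async function initSettings(): Promise<void> {
  await serverStore.fetchServerProps(); // populates serverProps from /props
  if (serverStore.serverDefaultParams) {
    settingsStore.syncWithServerDefaults(); // seeds config, prunes stale overrides
  }
}
```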