diff --git a/common/arg.cpp b/common/arg.cpp index 039151d026..7865391214 100644 --- a/common/arg.cpp +++ b/common/arg.cpp @@ -873,7 +873,9 @@ common_params_context common_params_parser_init(common_params & params, llama_ex sampler_type_chars += common_sampler_type_to_chr(sampler); sampler_type_names += common_sampler_type_to_str(sampler) + ";"; } - sampler_type_names.pop_back(); + if (!sampler_type_names.empty()) { + sampler_type_names.pop_back(); // remove last semicolon + } /** diff --git a/convert_hf_to_gguf.py b/convert_hf_to_gguf.py index 806b3d7b47..86fe0b5f17 100755 --- a/convert_hf_to_gguf.py +++ b/convert_hf_to_gguf.py @@ -189,10 +189,10 @@ class ModelBase: return tensors prefix = "model" if not self.is_mistral_format else "consolidated" - part_names: set[str] = set(ModelBase.get_model_part_names(self.dir_model, prefix, ".safetensors")) + part_names: list[str] = ModelBase.get_model_part_names(self.dir_model, prefix, ".safetensors") is_safetensors: bool = len(part_names) > 0 if not is_safetensors: - part_names = set(ModelBase.get_model_part_names(self.dir_model, "pytorch_model", ".bin")) + part_names = ModelBase.get_model_part_names(self.dir_model, "pytorch_model", ".bin") tensor_names_from_index: set[str] = set() @@ -209,7 +209,8 @@ class ModelBase: if weight_map is None or not isinstance(weight_map, dict): raise ValueError(f"Can't load 'weight_map' from {index_name!r}") tensor_names_from_index.update(weight_map.keys()) - part_names |= set(weight_map.values()) + part_dict: dict[str, None] = dict.fromkeys(weight_map.values(), None) + part_names = sorted(part_dict.keys()) else: weight_map = {} else: diff --git a/ggml/src/ggml-cpu/CMakeLists.txt b/ggml/src/ggml-cpu/CMakeLists.txt index fc31089f3e..28fb7612e5 100644 --- a/ggml/src/ggml-cpu/CMakeLists.txt +++ b/ggml/src/ggml-cpu/CMakeLists.txt @@ -458,6 +458,7 @@ function(ggml_add_cpu_backend_variant_impl tag_name) if (GGML_RV_ZFH) string(APPEND MARCH_STR "_zfh") endif() + if (GGML_XTHEADVECTOR) string(APPEND MARCH_STR "_xtheadvector") elseif (GGML_RVV) @@ -465,6 +466,9 @@ function(ggml_add_cpu_backend_variant_impl tag_name) if (GGML_RV_ZVFH) string(APPEND MARCH_STR "_zvfh") endif() + if (GGML_RV_ZVFBFWMA) + string(APPEND MARCH_STR "_zvfbfwma") + endif() endif() if (GGML_RV_ZICBOP) string(APPEND MARCH_STR "_zicbop") diff --git a/ggml/src/ggml-cpu/ggml-cpu.c b/ggml/src/ggml-cpu/ggml-cpu.c index a59b518938..f7ba1fe317 100644 --- a/ggml/src/ggml-cpu/ggml-cpu.c +++ b/ggml/src/ggml-cpu/ggml-cpu.c @@ -3320,13 +3320,33 @@ void ggml_cpu_fp16_to_fp32(const ggml_fp16_t * x, float * y, int64_t n) { __m128 y_vec = _mm_cvtph_ps(x_vec); _mm_storeu_ps(y + i, y_vec); } -#elif defined(__riscv_zvfh) - for (int vl; i < n; i += vl) { - vl = __riscv_vsetvl_e16m1(n - i); - vfloat16m1_t vx = __riscv_vle16_v_f16m1((_Float16 *)&x[i], vl); - vfloat32m2_t vy = __riscv_vfwcvt_f_f_v_f32m2(vx, vl); - __riscv_vse32_v_f32m2(&y[i], vy, vl); + +#elif defined(__riscv_v_intrinsic) && defined(__riscv_zvfhmin) + // calculate step size + const int epr = __riscv_vsetvlmax_e16m2(); + const int step = epr * 2; + const int np = (n & ~(step - 1)); + + // unroll by 2 + for (; i < np; i += step) { + vfloat16m2_t ax0 = __riscv_vle16_v_f16m2((const _Float16*)x + i, epr); + vfloat32m4_t ay0 = __riscv_vfwcvt_f_f_v_f32m4(ax0, epr); + __riscv_vse32_v_f32m4(y + i, ay0, epr); + + vfloat16m2_t ax1 = __riscv_vle16_v_f16m2((const _Float16*)x + i + epr, epr); + vfloat32m4_t ay1 = __riscv_vfwcvt_f_f_v_f32m4(ax1, epr); + __riscv_vse32_v_f32m4(y + i + epr, ay1, epr); } + + // leftovers + int 
vl; + for (i = np; i < n; i += vl) { + vl = __riscv_vsetvl_e16m2(n - i); + vfloat16m2_t ax0 = __riscv_vle16_v_f16m2((const _Float16*)x + i, vl); + vfloat32m4_t ay0 = __riscv_vfwcvt_f_f_v_f32m4(ax0, vl); + __riscv_vse32_v_f32m4(y + i, ay0, vl); + } + #endif for (; i < n; ++i) { @@ -3371,6 +3391,31 @@ void ggml_cpu_bf16_to_fp32(const ggml_bf16_t * x, float * y, int64_t n) { (const __m128i *)(x + i))), 16))); } +#elif defined(__riscv_v_intrinsic) && defined(__riscv_zvfbfmin) + // calculate step size + const int epr = __riscv_vsetvlmax_e16m2(); + const int step = epr * 2; + const int np = (n & ~(step - 1)); + + // unroll by 2 + for (; i < np; i += step) { + vbfloat16m2_t ax0 = __riscv_vle16_v_bf16m2((const __bf16*)x + i, epr); + vfloat32m4_t ay0 = __riscv_vfwcvtbf16_f_f_v_f32m4(ax0, epr); + __riscv_vse32_v_f32m4(y + i, ay0, epr); + + vbfloat16m2_t ax1 = __riscv_vle16_v_bf16m2((const __bf16*)x + i + epr, epr); + vfloat32m4_t ay1 = __riscv_vfwcvtbf16_f_f_v_f32m4(ax1, epr); + __riscv_vse32_v_f32m4(y + i + epr, ay1, epr); + } + + // leftovers + int vl; + for (i = np; i < n; i += vl) { + vl = __riscv_vsetvl_e16m2(n - i); + vbfloat16m2_t ax0 = __riscv_vle16_v_bf16m2((const __bf16*)x + i, vl); + vfloat32m4_t ay0 = __riscv_vfwcvtbf16_f_f_v_f32m4(ax0, vl); + __riscv_vse32_v_f32m4(y + i, ay0, vl); + } #endif for (; i < n; i++) { y[i] = GGML_BF16_TO_FP32(x[i]); diff --git a/ggml/src/ggml-cpu/vec.cpp b/ggml/src/ggml-cpu/vec.cpp index ac8633e212..427e63245b 100644 --- a/ggml/src/ggml-cpu/vec.cpp +++ b/ggml/src/ggml-cpu/vec.cpp @@ -195,8 +195,48 @@ void ggml_vec_dot_bf16(int n, float * GGML_RESTRICT s, size_t bs, ggml_bf16_t * sumf += (ggml_float)_mm_cvtss_f32(g); #undef LOAD -#endif +#elif defined(__riscv_v_intrinsic) && defined(__riscv_zvfbfwma) + size_t vl = __riscv_vsetvlmax_e32m4(); + // initialize accumulators to all zeroes + vfloat32m4_t vsum0 = __riscv_vfmv_v_f_f32m4(0.0f, vl); + vfloat32m4_t vsum1 = __riscv_vfmv_v_f_f32m4(0.0f, vl); + + // calculate step size + const size_t epr = __riscv_vsetvlmax_e16m2(); + const size_t step = epr * 2; + const int np = (n & ~(step - 1)); + + // unroll by 2 + for (; i < np; i += step) { + vbfloat16m2_t ax0 = __riscv_vle16_v_bf16m2((const __bf16 *)&x[i], epr); + vbfloat16m2_t ay0 = __riscv_vle16_v_bf16m2((const __bf16 *)&y[i], epr); + vsum0 = __riscv_vfwmaccbf16_vv_f32m4(vsum0, ax0, ay0, epr); + __asm__ __volatile__ ("" ::: "memory"); + + vbfloat16m2_t ax1 = __riscv_vle16_v_bf16m2((const __bf16 *)&x[i + epr], epr); + vbfloat16m2_t ay1 = __riscv_vle16_v_bf16m2((const __bf16 *)&y[i + epr], epr); + vsum1 = __riscv_vfwmaccbf16_vv_f32m4(vsum1, ax1, ay1, epr); + __asm__ __volatile__ ("" ::: "memory"); + } + + // accumulate in 1 register + vsum0 = __riscv_vfadd_vv_f32m4(vsum0, vsum1, vl); + + // leftovers + for (i = np; i < n; i += vl) { + vl = __riscv_vsetvl_e16m2(n - i); + vbfloat16m2_t ax0 = __riscv_vle16_v_bf16m2((const __bf16 *)&x[i], vl); + vbfloat16m2_t ay0 = __riscv_vle16_v_bf16m2((const __bf16 *)&y[i], vl); + vsum0 = __riscv_vfwmaccbf16_vv_f32m4(vsum0, ax0, ay0, vl); + } + + // reduce + vl = __riscv_vsetvlmax_e32m4(); + vfloat32m1_t redsum = __riscv_vfredusum_vs_f32m4_f32m1(vsum0, __riscv_vfmv_v_f_f32m1(0.0f, 1), vl); + sumf += __riscv_vfmv_f_s_f32m1_f32(redsum); + +#endif for (; i < n; ++i) { sumf += (ggml_float)(GGML_BF16_TO_FP32(x[i]) * GGML_BF16_TO_FP32(y[i])); diff --git a/ggml/src/ggml-cpu/vec.h b/ggml/src/ggml-cpu/vec.h index bd80805fdc..3198b33b50 100644 --- a/ggml/src/ggml-cpu/vec.h +++ b/ggml/src/ggml-cpu/vec.h @@ -224,13 +224,71 @@ inline static void 
ggml_vec_dot_f16_unroll(const int n, const int xs, float * GG } GGML_F16x_VEC_REDUCE(sumf[0], sum_00, sum_01, sum_02, sum_03); GGML_F16x_VEC_REDUCE(sumf[1], sum_10, sum_11, sum_12, sum_13); - #elif defined(__riscv_v_intrinsic) - // todo: RVV impl - for (int i = 0; i < n; ++i) { - for (int j = 0; j < GGML_VEC_DOT_UNROLL; ++j) { - sumf[j] += (ggml_float)(GGML_CPU_FP16_TO_FP32(x[j][i])*GGML_CPU_FP16_TO_FP32(y[i])); - } - } + + #elif defined(__riscv_v_intrinsic) && defined(__riscv_zvfh) + size_t vl = __riscv_vsetvlmax_e32m4(); + + // initialize accumulators to all zeroes + vfloat32m4_t vsum0_0 = __riscv_vfmv_v_f_f32m4(0.0f, vl); + vfloat32m4_t vsum0_1 = __riscv_vfmv_v_f_f32m4(0.0f, vl); + vfloat32m4_t vsum1_0 = __riscv_vfmv_v_f_f32m4(0.0f, vl); + vfloat32m4_t vsum1_1 = __riscv_vfmv_v_f_f32m4(0.0f, vl); + + // calculate step size + const size_t epr = __riscv_vsetvlmax_e16m2(); + const size_t step = epr * 2; + const int np = (n & ~(step - 1)); + + // unroll by 2 along the row dimension + for (int i = 0; i < np; i += step) { + vfloat16m2_t ay0 = __riscv_vle16_v_f16m2((const _Float16 *)(y + i), epr); + vfloat16m2_t ax0_0 = __riscv_vle16_v_f16m2((const _Float16 *)(x[0] + i), epr); + vfloat16m2_t ax1_0 = __riscv_vle16_v_f16m2((const _Float16 *)(x[1] + i), epr); + vsum0_0 = __riscv_vfwmacc_vv_f32m4(vsum0_0, ax0_0, ay0, epr); + vsum1_0 = __riscv_vfwmacc_vv_f32m4(vsum1_0, ax1_0, ay0, epr); + + vfloat16m2_t ay1 = __riscv_vle16_v_f16m2((const _Float16 *)(y + i + epr), epr); + vfloat16m2_t ax0_1 = __riscv_vle16_v_f16m2((const _Float16 *)(x[0] + i + epr), epr); + vfloat16m2_t ax1_1 = __riscv_vle16_v_f16m2((const _Float16 *)(x[1] + i + epr), epr); + vsum0_1 = __riscv_vfwmacc_vv_f32m4(vsum0_1, ax0_1, ay1, epr); + vsum1_1 = __riscv_vfwmacc_vv_f32m4(vsum1_1, ax1_1, ay1, epr); + } + + vfloat32m4_t vsum0 = __riscv_vfadd_vv_f32m4(vsum0_0, vsum0_1, vl); + vfloat32m4_t vsum1 = __riscv_vfadd_vv_f32m4(vsum1_0, vsum1_1, vl); + + // leftovers + for (int i = np; i < n; i += vl) { + vl = __riscv_vsetvl_e16m2(n - i); + vfloat16m2_t ay = __riscv_vle16_v_f16m2((const _Float16 *)(y + i), vl); + vfloat16m2_t ax0 = __riscv_vle16_v_f16m2((const _Float16 *)(x[0] + i), vl); + vfloat16m2_t ax1 = __riscv_vle16_v_f16m2((const _Float16 *)(x[1] + i), vl); + + vsum0 = __riscv_vfwmacc_vv_f32m4(vsum0, ax0, ay, vl); + vsum1 = __riscv_vfwmacc_vv_f32m4(vsum1, ax1, ay, vl); + } + + // reduce + vl = __riscv_vsetvlmax_e32m2(); + vfloat32m2_t acc0_0 = __riscv_vfadd_vv_f32m2(__riscv_vget_v_f32m4_f32m2(vsum0, 0), + __riscv_vget_v_f32m4_f32m2(vsum0, 1), vl); + vl = __riscv_vsetvlmax_e32m1(); + vfloat32m1_t acc0_1 = __riscv_vfadd_vv_f32m1(__riscv_vget_v_f32m2_f32m1(acc0_0, 0), + __riscv_vget_v_f32m2_f32m1(acc0_0, 1), vl); + vfloat32m1_t redsum0 = __riscv_vfredusum_vs_f32m1_f32m1( + acc0_1, __riscv_vfmv_v_f_f32m1(0.0f, 1), vl); + + vl = __riscv_vsetvlmax_e32m2(); + vfloat32m2_t acc1_0 = __riscv_vfadd_vv_f32m2(__riscv_vget_v_f32m4_f32m2(vsum1, 0), + __riscv_vget_v_f32m4_f32m2(vsum1, 1), vl); + vl = __riscv_vsetvlmax_e32m1(); + vfloat32m1_t acc1_1 = __riscv_vfadd_vv_f32m1(__riscv_vget_v_f32m2_f32m1(acc1_0, 0), + __riscv_vget_v_f32m2_f32m1(acc1_0, 1), vl); + vfloat32m1_t redsum1 = __riscv_vfredusum_vs_f32m1_f32m1( + acc1_1, __riscv_vfmv_v_f_f32m1(0.0f, 1), vl); + sumf[0] = __riscv_vfmv_f_s_f32m1_f32(redsum0); + sumf[1] = __riscv_vfmv_f_s_f32m1_f32(redsum1); + #else const int np = (n & ~(GGML_F16_STEP - 1)); @@ -475,15 +533,39 @@ inline static void ggml_vec_mad_f16(const int n, ggml_fp16_t * GGML_RESTRICT y, } np = n; #elif defined(__riscv_zvfh) // 
implies __riscv_v_intrinsic - const int np = n; - _Float16 hv = (_Float16)v; - for (int i = 0, avl; i < n; i += avl) { - avl = __riscv_vsetvl_e16m8(n - i); - vfloat16m8_t ax = __riscv_vle16_v_f16m8((const _Float16 *)&x[i], avl); - vfloat16m8_t ay = __riscv_vle16_v_f16m8((_Float16 *)&y[i], avl); - vfloat16m8_t ny = __riscv_vfmadd_vf_f16m8(ax, hv, ay, avl); - __riscv_vse16_v_f16m8((_Float16 *)&y[i], ny, avl); + const ggml_fp16_t s = GGML_CPU_FP32_TO_FP16(v); + const _Float16 scale = *(const _Float16*)(&s); + + // calculate step size + const int epr = __riscv_vsetvlmax_e16m4(); + const int step = epr * 2; + int np = (n & ~(step - 1)); + + // unroll by 2 + for (int i = 0; i < np; i += step) { + vfloat16m4_t ax0 = __riscv_vle16_v_f16m4((const _Float16*)x + i, epr); + vfloat16m4_t ay0 = __riscv_vle16_v_f16m4((const _Float16*)y + i, epr); + ay0 = __riscv_vfmacc_vf_f16m4(ay0, scale, ax0, epr); + __riscv_vse16_v_f16m4((_Float16*)y + i, ay0, epr); + __asm__ __volatile__ ("" ::: "memory"); + + vfloat16m4_t ax1 = __riscv_vle16_v_f16m4((const _Float16*)x + i + epr, epr); + vfloat16m4_t ay1 = __riscv_vle16_v_f16m4((const _Float16*)y + i + epr, epr); + ay1 = __riscv_vfmacc_vf_f16m4(ay1, scale, ax1, epr); + __riscv_vse16_v_f16m4((_Float16*)y + i + epr, ay1, epr); + __asm__ __volatile__ ("" ::: "memory"); } + + // leftovers + int vl; + for (int i = np; i < n; i += vl) { + vl = __riscv_vsetvl_e16m4(n - i); + vfloat16m4_t ax0 = __riscv_vle16_v_f16m4((const _Float16*)x + i, vl); + vfloat16m4_t ay0 = __riscv_vle16_v_f16m4((const _Float16*)y + i, vl); + ay0 = __riscv_vfmacc_vf_f16m4(ay0, scale, ax0, vl); + __riscv_vse16_v_f16m4((_Float16*)y + i, ay0, vl); + } + np = n; #elif defined(GGML_SIMD) const int np = (n & ~(GGML_F16_STEP - 1)); @@ -724,13 +806,34 @@ inline static void ggml_vec_scale_f16(const int n, ggml_fp16_t * y, const float svst1_f16(pg, (__fp16 *)(y + np), out); } #elif defined(__riscv_v_intrinsic) && defined(__riscv_zvfh) - for (int i = 0, vl; i < n; i += vl) { - vl = __riscv_vsetvl_e16m2(n - i); - vfloat16m2_t vy = __riscv_vle16_v_f16m2((_Float16 *)&y[i], vl); - vfloat32m4_t vy32 = __riscv_vfwcvt_f_f_v_f32m4(vy, vl); - vy32 = __riscv_vfmul_vf_f32m4(vy32, v, vl); - vy = __riscv_vfncvt_f_f_w_f16m2(vy32, vl); - __riscv_vse16_v_f16m2((_Float16 *)&y[i], vy, vl); + const ggml_fp16_t s = GGML_CPU_FP32_TO_FP16(v); + const _Float16 scale = *(const _Float16*)(&s); + + // calculate step size + const int epr = __riscv_vsetvlmax_e16m4(); + const int step = epr * 2; + const int np = (n & ~(step - 1)); + + // unroll by 2 + for (int i = 0; i < np; i += step) { + vfloat16m4_t ay0 = __riscv_vle16_v_f16m4((const _Float16*)y + i, epr); + ay0 = __riscv_vfmul_vf_f16m4(ay0, scale, epr); + __riscv_vse16_v_f16m4((_Float16*)y + i, ay0, epr); + __asm__ __volatile__ ("" ::: "memory"); + + vfloat16m4_t ay1 = __riscv_vle16_v_f16m4((const _Float16*)y + i + epr, epr); + ay1 = __riscv_vfmul_vf_f16m4(ay1, scale, epr); + __riscv_vse16_v_f16m4((_Float16*)y + i + epr, ay1, epr); + __asm__ __volatile__ ("" ::: "memory"); + } + + // leftovers + int vl; + for (int i = np; i < n; i += vl) { + vl = __riscv_vsetvl_e16m4(n - i); + vfloat16m4_t ay0 = __riscv_vle16_v_f16m4((const _Float16*)y + i, vl); + ay0 = __riscv_vfmul_vf_f16m4(ay0, scale, vl); + __riscv_vse16_v_f16m4((_Float16*)y + i, ay0, vl); } #elif defined(GGML_SIMD) const int np = (n & ~(GGML_F16_STEP - 1)); diff --git a/ggml/src/ggml-cuda/mma.cuh b/ggml/src/ggml-cuda/mma.cuh index 794d90bdd1..3268dadfe8 100644 --- a/ggml/src/ggml-cuda/mma.cuh +++ b/ggml/src/ggml-cuda/mma.cuh @@ 
-78,27 +78,25 @@ namespace ggml_cuda_mma { // MIRRORED == Each data value is held exactly once per thread subgroup. DATA_LAYOUT_I_MAJOR = 0, // Always used for Turing, Ampere, Ada Lovelace, consumer Blackwell, matrix A&B for RDNA4 and CDNA. DATA_LAYOUT_J_MAJOR = 10, // Matrix C for CDNA and RDNA4, int and float matrix C for RDNA3. - DATA_LAYOUT_I_MAJOR_MIRRORED = 20, + DATA_LAYOUT_I_MAJOR_MIRRORED = 20, // Volta, matrix A&B for RDNA3. DATA_LAYOUT_J_MAJOR_MIRRORED = 30, - DATA_LAYOUT_I_MAJOR_DUAL = 40, // Matrix A&B for RDNA3. }; // Implemented mma combinations are: // - (I_MAJOR, I_MAJOR) -> I_MAJOR // - (I_MAJOR, I_MAJOR_MIRRORED) -> I_MAJOR // - (I_MAJOR, J_MAJOR_MIRRORED) -> I_MAJOR - constexpr bool is_i_major(const data_layout dl) { + static constexpr bool is_i_major(const data_layout dl) { return dl == DATA_LAYOUT_I_MAJOR || - dl == DATA_LAYOUT_I_MAJOR_MIRRORED || - dl == DATA_LAYOUT_I_MAJOR_DUAL; + dl == DATA_LAYOUT_I_MAJOR_MIRRORED; } - constexpr data_layout get_input_data_layout() { -#if defined(RDNA3) - return DATA_LAYOUT_I_MAJOR_DUAL; + static constexpr __device__ data_layout get_input_data_layout() { +#if defined(RDNA3) || __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA + return DATA_LAYOUT_I_MAJOR_MIRRORED; #else return DATA_LAYOUT_I_MAJOR; -#endif // defined(RDNA3) +#endif // defined(RDNA3) || __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA } template @@ -462,11 +460,65 @@ namespace ggml_cuda_mma { } }; + template + struct tile { + static constexpr int I = I_; + static constexpr int J = J_; + static constexpr data_layout dl = DATA_LAYOUT_I_MAJOR_MIRRORED; + + // RDNA3 + static constexpr int ne = I * J / 32 * 2; + + T x[ne] = {0}; + + static constexpr __device__ bool supported() { + if (I == 16 && J == 16) return true; + if (I == 16 && J == 8) return true; + if (I == 16 && J == 4) return true; + return false; + } + + static __device__ __forceinline__ int get_i(const int /*l*/) { + if constexpr (supported()) { + return threadIdx.x % 16; + } else { + NO_DEVICE_CODE; + return -1; + } + } + + static __device__ __forceinline__ int get_j(const int l) { + if constexpr (supported()) { + return l; + } else { + NO_DEVICE_CODE; + return -1; + } + } + }; + template struct tile { static constexpr int I = I_; static constexpr int J = J_; static constexpr data_layout dl = DATA_LAYOUT_I_MAJOR_MIRRORED; +#if defined(RDNA3) + static constexpr int ne = tile::ne; + + half2 x[ne] = {{0.0f, 0.0f}}; + + static constexpr __device__ bool supported() { + return tile::supported(); + } + + static __device__ __forceinline__ int get_i(const int l) { + return tile::get_i(l); + } + + static __device__ __forceinline__ int get_j(const int l) { + return tile::get_j(l); + } +#else // Volta static constexpr int ne = I * J / (WARP_SIZE/4); half2 x[ne] = {{0.0f, 0.0f}}; @@ -493,6 +545,29 @@ namespace ggml_cuda_mma { return -1; } } +#endif // defined(RDNA3) + }; + + template + struct tile { + static constexpr int I = I_; + static constexpr int J = J_; + static constexpr data_layout dl = DATA_LAYOUT_I_MAJOR_MIRRORED; + static constexpr int ne = tile::ne; + + nv_bfloat162 x[ne] = {{0.0f, 0.0f}}; + + static constexpr __device__ bool supported() { + return tile::supported(); + } + + static __device__ __forceinline__ int get_i(const int l) { + return tile::get_i(l); + } + + static __device__ __forceinline__ int get_j(const int l) { + return tile::get_j(l); + } }; template @@ -528,42 +603,6 @@ namespace ggml_cuda_mma { } }; - template - struct tile { - static constexpr int I = I_; - static constexpr int J = J_; - static constexpr data_layout dl = 
DATA_LAYOUT_I_MAJOR_DUAL; - - static constexpr int ne = I * J / 32 * 2; - - T x[ne] = {0}; - - static constexpr __device__ bool supported() { - if (I == 16 && J == 16) return true; - if (I == 16 && J == 8) return true; - if (I == 16 && J == 4) return true; - return false; - } - - static __device__ __forceinline__ int get_i(const int l) { - if constexpr (supported()) { - return threadIdx.x % 16; - } else { - NO_DEVICE_CODE; - return -1; - } - } - - static __device__ __forceinline__ int get_j(const int l) { - if constexpr (supported()) { - return l; - } else { - NO_DEVICE_CODE; - return -1; - } - } - }; - #if defined(TURING_MMA_AVAILABLE) template static __device__ __forceinline__ tile get_half2(const tile & tile_float) { diff --git a/gguf-py/gguf/utility.py b/gguf-py/gguf/utility.py index 7907e706d5..4918ae971a 100644 --- a/gguf-py/gguf/utility.py +++ b/gguf-py/gguf/utility.py @@ -288,7 +288,7 @@ class LocalTensor: data_range: LocalTensorRange def mmap_bytes(self) -> np.ndarray: - return np.memmap(self.data_range.filename, mode='r', offset=self.data_range.offset, shape=self.data_range.size) + return np.memmap(self.data_range.filename, mode='c', offset=self.data_range.offset, shape=self.data_range.size) class SafetensorsLocal: diff --git a/src/llama-mmap.cpp b/src/llama-mmap.cpp index 0641c2d22f..23b648a2e3 100644 --- a/src/llama-mmap.cpp +++ b/src/llama-mmap.cpp @@ -13,9 +13,10 @@ #ifdef __has_include #if __has_include() #include + #include + #include #if defined(_POSIX_MAPPED_FILES) #include - #include #endif #if defined(_POSIX_MEMLOCK_RANGE) #include @@ -74,7 +75,7 @@ struct llama_file::impl { return ret; } - impl(const char * fname, const char * mode) { + impl(const char * fname, const char * mode, [[maybe_unused]] const bool use_direct_io = false) { fp = ggml_fopen(fname, mode); if (fp == NULL) { throw std::runtime_error(format("failed to open %s: %s", fname, strerror(errno))); @@ -153,13 +154,40 @@ struct llama_file::impl { write_raw(&val, sizeof(val)); } + void read_aligned_chunk(size_t offset, void * dest, size_t size) const { + throw std::runtime_error("DirectIO is not implemented on Windows."); + } + ~impl() { if (fp) { std::fclose(fp); } } #else - impl(const char * fname, const char * mode) { + impl(const char * fname, const char * mode, [[maybe_unused]] const bool use_direct_io = false) { +#ifdef __linux__ + // Try unbuffered I/O for read only + if (use_direct_io && std::strcmp(mode, "rb") == 0) { + fd = open(fname, O_RDONLY | O_DIRECT); + + if (fd != -1) { + struct stat file_stats{}; + fstat(fd, &file_stats); + + size = file_stats.st_size; + alignment = file_stats.st_blksize; + + off_t ret = lseek(fd, 0, SEEK_SET); + if (ret == -1) { + throw std::runtime_error(format("seek error: %s", strerror(errno))); + } + return; + } + + LLAMA_LOG_WARN("Failed to open model %s with error: %s. Falling back to buffered I/O", + fname, strerror(errno)); + } +#endif fp = ggml_fopen(fname, mode); if (fp == NULL) { throw std::runtime_error(format("failed to open %s: %s", fname, strerror(errno))); @@ -170,27 +198,30 @@ struct llama_file::impl { } size_t tell() const { -// TODO: this ifdef is never true? 
-#ifdef _WIN32 - __int64 ret = _ftelli64(fp); -#else - long ret = std::ftell(fp); -#endif - if (ret == -1) { - throw std::runtime_error(format("ftell error: %s", strerror(errno))); + if (fd == -1) { + long ret = std::ftell(fp); + if (ret == -1) { + throw std::runtime_error(format("ftell error: %s", strerror(errno))); + } + + return (size_t) ret; } - return (size_t) ret; + off_t pos = lseek(fd, 0, SEEK_CUR); + if (pos == -1) { + throw std::runtime_error(format("lseek error: %s", strerror(errno))); + } + return (size_t) pos; } void seek(size_t offset, int whence) const { -// TODO: this ifdef is never true? -#ifdef _WIN32 - int ret = _fseeki64(fp, (__int64) offset, whence); -#else - int ret = std::fseek(fp, (long) offset, whence); -#endif - if (ret != 0) { + off_t ret = 0; + if (fd == -1) { + ret = std::fseek(fp, (long) offset, whence); + } else { + ret = lseek(fd, offset, whence); + } + if (ret == -1) { throw std::runtime_error(format("seek error: %s", strerror(errno))); } } @@ -200,13 +231,55 @@ struct llama_file::impl { return; } errno = 0; - std::size_t ret = std::fread(ptr, len, 1, fp); - if (ferror(fp)) { - throw std::runtime_error(format("read error: %s", strerror(errno))); + if (fd == -1) { + std::size_t ret = std::fread(ptr, len, 1, fp); + if (ferror(fp)) { + throw std::runtime_error(format("read error: %s", strerror(errno))); + } + if (ret != 1) { + throw std::runtime_error("unexpectedly reached end of file"); + } + } else { + bool successful = false; + while (!successful) { + off_t ret = read(fd, ptr, len); + + if (ret == -1) { + if (errno == EINTR) { + continue; // Interrupted by signal, retry + } + throw std::runtime_error(format("read error: %s", strerror(errno))); + } + if (ret == 0) { + throw std::runtime_error("unexpectedly reached end of file"); + } + + successful = true; + } } - if (ret != 1) { - throw std::runtime_error("unexpectedly reached end of file"); + } + + void read_aligned_chunk(size_t offset, void * dest, size_t size) const { + off_t aligned_offset = offset & ~(alignment - 1); + off_t offset_from_alignment = offset - aligned_offset; + size_t bytes_to_read = (offset_from_alignment + size + alignment - 1) & ~(alignment - 1); + + void * raw_buffer = nullptr; + int ret = posix_memalign(&raw_buffer, alignment, bytes_to_read); + if (ret != 0) { + throw std::runtime_error(format("posix_memalign failed with error %d", ret)); } + + struct aligned_buffer_deleter { + void operator()(void * p) const { free(p); } + }; + std::unique_ptr buffer(raw_buffer); + + seek(aligned_offset, SEEK_SET); + read_raw(buffer.get(), bytes_to_read); + + uintptr_t actual_data = reinterpret_cast(buffer.get()) + offset_from_alignment; + memcpy(dest, reinterpret_cast(actual_data), size); } uint32_t read_u32() const { @@ -231,22 +304,43 @@ struct llama_file::impl { } ~impl() { - if (fp) { + if (fd != -1) { + close(fd); + } else { std::fclose(fp); } } + int fd = -1; #endif - FILE * fp; - size_t size; + void read_raw_at(void * ptr, size_t len, size_t offset) const { + if (alignment != 1) { + read_aligned_chunk(offset, ptr, len); + } else { + seek(offset, SEEK_SET); + read_raw(ptr, len); + } + } + + size_t read_alignment() const { + return alignment; + } + + size_t alignment = 1; + + FILE * fp{}; + size_t size{}; }; -llama_file::llama_file(const char * fname, const char * mode) : pimpl(std::make_unique(fname, mode)) {} +llama_file::llama_file(const char * fname, const char * mode, const bool use_direct_io) : + pimpl(std::make_unique(fname, mode, use_direct_io)) {} llama_file::~llama_file() = default; 
size_t llama_file::tell() const { return pimpl->tell(); } size_t llama_file::size() const { return pimpl->size; } +size_t llama_file::read_alignment() const { return pimpl->read_alignment(); } + int llama_file::file_id() const { #ifdef _WIN32 return _fileno(pimpl->fp); @@ -261,6 +355,7 @@ int llama_file::file_id() const { void llama_file::seek(size_t offset, int whence) const { pimpl->seek(offset, whence); } void llama_file::read_raw(void * ptr, size_t len) const { pimpl->read_raw(ptr, len); } +void llama_file::read_raw_at(void * ptr, size_t len, size_t offset) const { pimpl->read_raw_at(ptr, len, offset); } uint32_t llama_file::read_u32() const { return pimpl->read_u32(); } diff --git a/src/llama-mmap.h b/src/llama-mmap.h index 4e5aec3f44..729aac164b 100644 --- a/src/llama-mmap.h +++ b/src/llama-mmap.h @@ -3,6 +3,7 @@ #include #include #include +#include struct llama_file; struct llama_mmap; @@ -13,7 +14,7 @@ using llama_mmaps = std::vector>; using llama_mlocks = std::vector>; struct llama_file { - llama_file(const char * fname, const char * mode); + llama_file(const char * fname, const char * mode, bool use_direct_io = false); ~llama_file(); size_t tell() const; @@ -24,11 +25,14 @@ struct llama_file { void seek(size_t offset, int whence) const; void read_raw(void * ptr, size_t len) const; + void read_raw_at(void * ptr, size_t len, size_t offset) const; + void read_aligned_chunk(size_t offset, void * dest, size_t size) const; uint32_t read_u32() const; void write_raw(const void * ptr, size_t len) const; void write_u32(uint32_t val) const; + size_t read_alignment() const; private: struct impl; std::unique_ptr pimpl; diff --git a/src/llama-model-loader.cpp b/src/llama-model-loader.cpp index ca2ea2461d..1da89515f7 100644 --- a/src/llama-model-loader.cpp +++ b/src/llama-model-loader.cpp @@ -504,7 +504,7 @@ llama_model_loader::llama_model_loader( get_key(llm_kv(LLM_KV_GENERAL_ARCHITECTURE), arch_name, false); llm_kv = LLM_KV(llm_arch_from_string(arch_name)); - files.emplace_back(new llama_file(fname.c_str(), "rb")); + files.emplace_back(new llama_file(fname.c_str(), "rb", !use_mmap)); contexts.emplace_back(ctx); // Save tensors data offset of the main file. @@ -572,7 +572,7 @@ llama_model_loader::llama_model_loader( } } - files.emplace_back(new llama_file(fname_split, "rb")); + files.emplace_back(new llama_file(fname_split, "rb", !use_mmap)); contexts.emplace_back(ctx); // Save tensors data offset info of the shard. @@ -935,7 +935,15 @@ bool llama_model_loader::load_all_data( // 4 staging buffers for async uploads, each sized 1MB seems to be a good default for single NVMe drives. // NVMe raid configurations might require more / larger buffers. constexpr size_t n_buffers = 4; - constexpr size_t buffer_size = 1 * 1024 * 1024; // 1MB + + size_t alignment = 1; + for (const auto & file : files) { + alignment = std::max(file->read_alignment(), alignment); + } + + // Buffer size: balance between memory usage and I/O efficiency + // 64MB works well for NVMe drives + const size_t buffer_size = alignment != 1 ? 64 * 1024 * 1024 + 2 * alignment : 1 * 1024 * 1024; std::vector host_buffers; std::vector events; @@ -985,6 +993,7 @@ bool llama_model_loader::load_all_data( // If the backend is supported, create pinned memory buffers and events for synchronisation. 
for (size_t idx = 0; idx < n_buffers; ++idx) { auto * buf = ggml_backend_buft_alloc_buffer(host_buft, buffer_size); + if (!buf) { LLAMA_LOG_DEBUG("%s: failed to allocate host buffer for async uploads for device %s\n", func, ggml_backend_dev_name(dev)); @@ -1066,9 +1075,9 @@ bool llama_model_loader::load_all_data( } } else { const auto & file = files.at(weight->idx); + if (ggml_backend_buffer_is_host(cur->buffer)) { - file->seek(weight->offs, SEEK_SET); - file->read_raw(cur->data, n_size); + file->read_raw_at(cur->data, n_size, weight->offs); if (check_tensors) { validation_result.emplace_back(std::async(std::launch::async, [cur, n_size] { return std::make_pair(cur, ggml_validate_row_data(cur->type, cur->data, n_size)); @@ -1077,26 +1086,60 @@ bool llama_model_loader::load_all_data( } else { // If upload_backend is valid load the tensor in chunks to pinned memory and upload the buffers asynchronously to the GPU. if (upload_backend) { - file->seek(weight->offs, SEEK_SET); + auto offset = (off_t) weight->offs; + alignment = file->read_alignment(); + off_t aligned_offset = offset & ~(alignment - 1); + off_t offset_from_alignment = offset - aligned_offset; + file->seek(aligned_offset, SEEK_SET); + + // Calculate aligned read boundaries + size_t read_start = aligned_offset; + size_t read_end = (offset + n_size + alignment - 1) & ~(alignment - 1); size_t bytes_read = 0; + size_t data_read = 0; // Actual tensor data copied (excluding padding) - while (bytes_read < n_size) { - size_t read_iteration = std::min(buffer_size, n_size - bytes_read); + while (bytes_read < read_end - read_start) { + size_t read_size = std::min(buffer_size, read_end - read_start - bytes_read); + // Align the destination pointer within the pinned buffer + uintptr_t ptr_dest_aligned = (reinterpret_cast(host_ptrs[buffer_idx]) + alignment - 1) & ~(alignment - 1); + + // Wait for previous upload to complete before reusing buffer ggml_backend_event_synchronize(events[buffer_idx]); - file->read_raw(host_ptrs[buffer_idx], read_iteration); - ggml_backend_tensor_set_async(upload_backend, cur, host_ptrs[buffer_idx], bytes_read, read_iteration); + + // Read aligned chunk from file + file->read_raw(reinterpret_cast(ptr_dest_aligned), read_size); + + // Calculate actual data portion (excluding alignment padding) + uintptr_t ptr_data = ptr_dest_aligned; + size_t data_to_copy = read_size; + + // Skip alignment padding at start of first chunk + if (bytes_read == 0) { + ptr_data += offset_from_alignment; + data_to_copy -= offset_from_alignment; + } + + // Trim alignment padding at end of last chunk + if (aligned_offset + bytes_read + read_size > offset + n_size) { + data_to_copy -= (read_end - (offset + n_size)); + } + + // Async upload actual data to GPU + ggml_backend_tensor_set_async(upload_backend, cur, + reinterpret_cast(ptr_data), data_read, data_to_copy); ggml_backend_event_record(events[buffer_idx], upload_backend); - bytes_read += read_iteration; + data_read += data_to_copy; + bytes_read += read_size; + ++buffer_idx; buffer_idx %= n_buffers; } } else { read_buf.resize(n_size); - file->seek(weight->offs, SEEK_SET); - file->read_raw(read_buf.data(), n_size); + file->read_raw_at(read_buf.data(), n_size, weight->offs); ggml_backend_tensor_set(cur, read_buf.data(), 0, n_size); if (check_tensors && !ggml_validate_row_data(cur->type, read_buf.data(), n_size)) { throw std::runtime_error(format("tensor '%s' has invalid data", ggml_get_name(cur))); diff --git a/src/llama-model.cpp b/src/llama-model.cpp index c9a3c5dfa2..d2270e8f2d 100644 --- 
a/src/llama-model.cpp +++ b/src/llama-model.cpp @@ -2378,10 +2378,10 @@ bool llama_model::load_tensors(llama_model_loader & ml) { if (cpu_dev == nullptr) { throw std::runtime_error(format("%s: no CPU backend found", __func__)); } - const int i_gpu_start = std::max((int) hparams.n_layer - n_gpu_layers, (int) 0); - const int act_gpu_layers = devices.empty() ? 0 : std::min(n_gpu_layers, (int)n_layer + 1); + const int i_gpu_start = std::max(int(hparams.n_layer) + 1 - n_gpu_layers, 0); + const int act_gpu_layers = devices.empty() ? 0 : std::min(n_gpu_layers, int(n_layer) + 1); auto get_layer_buft_list = [&](int il) -> llama_model::impl::layer_dev { - const bool is_swa = il < (int) hparams.n_layer && hparams.is_swa(il); + const bool is_swa = il < int(hparams.n_layer) && hparams.is_swa(il); if (il < i_gpu_start || (il - i_gpu_start) >= act_gpu_layers) { LLAMA_LOG_DEBUG("load_tensors: layer %3d assigned to device %s, is_swa = %d\n", il, ggml_backend_dev_name(cpu_dev), is_swa); return {cpu_dev, &pimpl->cpu_buft_list}; @@ -6693,10 +6693,12 @@ bool llama_model::load_tensors(llama_model_loader & ml) { if (llama_supports_gpu_offload()) { const int n_gpu = std::min(n_gpu_layers, int(hparams.n_layer)); - LLAMA_LOG_INFO("%s: offloading %d repeating layers to GPU\n", __func__, n_gpu); - if (n_gpu_layers > (int) hparams.n_layer) { + int n_repeating = n_gpu; + if (n_repeating > 0) { LLAMA_LOG_INFO("%s: offloading output layer to GPU\n", __func__); + n_repeating--; } + LLAMA_LOG_INFO("%s: offloading %d repeating layers to GPU\n", __func__, n_repeating); const int max_backend_supported_layers = hparams.n_layer + 1; const int max_offloadable_layers = hparams.n_layer + 1; diff --git a/src/llama.cpp b/src/llama.cpp index 708d879bc0..1e18637e36 100644 --- a/src/llama.cpp +++ b/src/llama.cpp @@ -292,10 +292,6 @@ static void llama_params_fit_impl( if (mparams->split_mode == LLAMA_SPLIT_MODE_ROW) { throw std::runtime_error("changing weight allocation for LLAMA_SPLIT_MODE_ROW not implemented, abort"); } - if (hp_ngl < 2*nd) { - throw std::runtime_error("model has only " + std::to_string(hp_ngl) + " layers but need at least " - + std::to_string(2*nd) + " to fit memory for " + std::to_string(nd) + " devices, abort"); - } } if (!tensor_buft_overrides) { throw std::runtime_error("did not provide buffer to set tensor_buft_overrides, abort"); @@ -362,8 +358,7 @@ static void llama_params_fit_impl( auto set_ngl_tensor_split_tbo = [&]( const std::vector & ngl_per_device, const std::vector & overflow_bufts, - llama_model_params & mparams, - const bool add_nonrepeating) { + llama_model_params & mparams) { mparams.n_gpu_layers = 0; for (size_t id = 0; id < nd; id++) { mparams.n_gpu_layers += ngl_per_device[id].n_layer; @@ -371,13 +366,9 @@ static void llama_params_fit_impl( tensor_split[id] = ngl_per_device[id].n_layer; } } - assert(uint32_t(mparams.n_gpu_layers) <= hp_ngl); - uint32_t il0 = hp_ngl - mparams.n_gpu_layers; // start index for tensor buft overrides + assert(uint32_t(mparams.n_gpu_layers) <= hp_ngl + 1); + uint32_t il0 = hp_ngl + 1 - mparams.n_gpu_layers; // start index for tensor buft overrides - if (add_nonrepeating) { - mparams.n_gpu_layers += 1; - tensor_split[nd - 1] += 1; - } mparams.tensor_split = tensor_split; size_t itbo = 0; @@ -408,10 +399,9 @@ static void llama_params_fit_impl( auto get_memory_for_layers = [&]( const char * func_name, const std::vector & ngl_per_device, - const std::vector & overflow_bufts, - const bool add_nonrepeating) -> std::vector { + const std::vector & overflow_bufts) -> 
std::vector { llama_model_params mparams_copy = *mparams; - set_ngl_tensor_split_tbo(ngl_per_device, overflow_bufts, mparams_copy, add_nonrepeating); + set_ngl_tensor_split_tbo(ngl_per_device, overflow_bufts, mparams_copy); const dmds_t dmd_nl = llama_get_device_memory_data( path_model, &mparams_copy, cparams, devs, hp_ngl, hp_nct, hp_nex, log_level); @@ -469,9 +459,6 @@ static void llama_params_fit_impl( LLAMA_LOG_DEBUG("%s: id=%zu, target=%" PRId64 " MiB\n", __func__, id, targets[id]/MiB); } - // whether for the optimal memory use we expect to load at least some MoE tensors: - const bool partial_moe = hp_nex > 0 && global_surplus_cpu_moe > 0; - std::vector overflow_bufts; // which bufts the partial layers of a device overflow to: overflow_bufts.reserve(nd); for (size_t id = 0; id < nd - 1; ++id) { @@ -480,7 +467,7 @@ static void llama_params_fit_impl( overflow_bufts.push_back(ggml_backend_cpu_buffer_type()); std::vector ngl_per_device(nd); - std::vector mem = get_memory_for_layers(__func__, ngl_per_device, overflow_bufts, partial_moe); + std::vector mem = get_memory_for_layers(__func__, ngl_per_device, overflow_bufts); if (hp_nex > 0) { for (size_t id = 0; id < nd; id++) { ngl_per_device[id].overflow_type = LAYER_FRACTION_MOE; @@ -493,13 +480,14 @@ static void llama_params_fit_impl( // - interpolate the memory use / layer between low and high linearly to get a guess where it meets our target // - check memory use of our guess, replace either the low or high bound // - once we only have a difference of a single layer, stop and return the lower bound that just barely still fits + // - the last device has the output layer, which cannot be a partial layer if (hp_nex == 0) { LLAMA_LOG_INFO("%s: filling dense layers back-to-front:\n", __func__); } else { LLAMA_LOG_INFO("%s: filling dense-only layers back-to-front:\n", __func__); } for (int id = nd - 1; id >= 0; id--) { - uint32_t n_unassigned = hp_ngl; + uint32_t n_unassigned = hp_ngl + 1; for (size_t jd = id + 1; jd < nd; ++jd) { assert(n_unassigned >= ngl_per_device[jd].n_layer); n_unassigned -= ngl_per_device[jd].n_layer; @@ -508,10 +496,10 @@ static void llama_params_fit_impl( std::vector ngl_per_device_high = ngl_per_device; ngl_per_device_high[id].n_layer = n_unassigned; if (hp_nex > 0) { - ngl_per_device_high[id].n_part = ngl_per_device_high[id].n_layer; + ngl_per_device_high[id].n_part = size_t(id) < nd - 1 ? 
ngl_per_device_high[id].n_layer : ngl_per_device_high[id].n_layer - 1; } if (ngl_per_device_high[id].n_layer > 0) { - std::vector mem_high = get_memory_for_layers(__func__, ngl_per_device_high, overflow_bufts, partial_moe); + std::vector mem_high = get_memory_for_layers(__func__, ngl_per_device_high, overflow_bufts); if (mem_high[id] > targets[id]) { assert(ngl_per_device_high[id].n_layer > ngl_per_device[id].n_layer); uint32_t delta = ngl_per_device_high[id].n_layer - ngl_per_device[id].n_layer; @@ -526,7 +514,7 @@ static void llama_params_fit_impl( if (hp_nex) { ngl_per_device_test[id].n_part += step_size; } - const std::vector mem_test = get_memory_for_layers(__func__, ngl_per_device_test, overflow_bufts, partial_moe); + const std::vector mem_test = get_memory_for_layers(__func__, ngl_per_device_test, overflow_bufts); if (mem_test[id] <= targets[id]) { ngl_per_device = ngl_per_device_test; @@ -553,7 +541,7 @@ static void llama_params_fit_impl( __func__, dev_names[id].c_str(), ngl_per_device[id].n_layer, mem[id]/MiB, projected_margin/MiB); } if (hp_nex == 0 || global_surplus_cpu_moe <= 0) { - set_ngl_tensor_split_tbo(ngl_per_device, overflow_bufts, *mparams, partial_moe); + set_ngl_tensor_split_tbo(ngl_per_device, overflow_bufts, *mparams); return; } @@ -576,13 +564,13 @@ static void llama_params_fit_impl( for (size_t id = 0; id <= id_dense_start; id++) { std::vector ngl_per_device_high = ngl_per_device; for (size_t jd = id_dense_start; jd < nd; jd++) { - const uint32_t n_layer_move = ngl_per_device_high[jd].n_layer; + const uint32_t n_layer_move = jd < nd - 1 ? ngl_per_device_high[jd].n_layer : ngl_per_device_high[jd].n_layer - 1; ngl_per_device_high[id].n_layer += n_layer_move; ngl_per_device_high[jd].n_layer -= n_layer_move; ngl_per_device_high[jd].n_part = 0; } size_t id_dense_start_high = nd - 1; - std::vector mem_high = get_memory_for_layers(__func__, ngl_per_device_high, overflow_bufts, partial_moe); + std::vector mem_high = get_memory_for_layers(__func__, ngl_per_device_high, overflow_bufts); if (mem_high[id] > targets[id]) { assert(ngl_per_device_high[id].n_layer >= ngl_per_device_high[id].n_part); @@ -610,7 +598,7 @@ static void llama_params_fit_impl( break; } } - const std::vector mem_test = get_memory_for_layers(__func__, ngl_per_device_test, overflow_bufts, partial_moe); + const std::vector mem_test = get_memory_for_layers(__func__, ngl_per_device_test, overflow_bufts); if (mem_test[id] <= targets[id]) { ngl_per_device = ngl_per_device_test; @@ -637,7 +625,7 @@ static void llama_params_fit_impl( } // try to fit at least part of one more layer - if (ngl_per_device[id_dense_start].n_layer > 0) { + if (ngl_per_device[id_dense_start].n_layer > (id < nd - 1 ? 
0 : 1)) { std::vector ngl_per_device_test = ngl_per_device; size_t id_dense_start_test = id_dense_start; ngl_per_device_test[id_dense_start_test].n_layer--; @@ -649,7 +637,7 @@ static void llama_params_fit_impl( } ngl_per_device_test[id].overflow_type = LAYER_FRACTION_UP; LLAMA_LOG_DEBUG("%s: trying to fit one extra layer with overflow_type=LAYER_FRACTION_UP\n", __func__); - std::vector mem_test = get_memory_for_layers(__func__, ngl_per_device_test, overflow_bufts, partial_moe); + std::vector mem_test = get_memory_for_layers(__func__, ngl_per_device_test, overflow_bufts); if (mem_test[id] < targets[id]) { ngl_per_device = ngl_per_device_test; mem = mem_test; @@ -659,7 +647,7 @@ static void llama_params_fit_impl( ngl_per_device_test[id].overflow_type = LAYER_FRACTION_GATE; LLAMA_LOG_DEBUG("%s: trying to fit one extra layer with overflow_type=LAYER_FRACTION_GATE\n", __func__); - mem_test = get_memory_for_layers(__func__, ngl_per_device_test, overflow_bufts, partial_moe); + mem_test = get_memory_for_layers(__func__, ngl_per_device_test, overflow_bufts); if (mem_test[id] < targets[id]) { ngl_per_device = ngl_per_device_test; mem = mem_test; @@ -670,7 +658,7 @@ static void llama_params_fit_impl( } else { ngl_per_device_test[id].overflow_type = LAYER_FRACTION_ATTN; LLAMA_LOG_DEBUG("%s: trying to fit one extra layer with overflow_type=LAYER_FRACTION_ATTN\n", __func__); - mem_test = get_memory_for_layers(__func__, ngl_per_device_test, overflow_bufts, partial_moe); + mem_test = get_memory_for_layers(__func__, ngl_per_device_test, overflow_bufts); if (mem_test[id] < targets[id]) { ngl_per_device = ngl_per_device_test; mem = mem_test; @@ -687,7 +675,7 @@ static void llama_params_fit_impl( __func__, dev_names[id].c_str(), ngl_per_device[id].n_layer, ngl_per_device[id].n_part, mem[id]/MiB, projected_margin/MiB); } - set_ngl_tensor_split_tbo(ngl_per_device, overflow_bufts, *mparams, partial_moe); + set_ngl_tensor_split_tbo(ngl_per_device, overflow_bufts, *mparams); } bool llama_params_fit( diff --git a/tools/server/public/index.html.gz b/tools/server/public/index.html.gz index d6cc23ebfc..9e44f03260 100644 Binary files a/tools/server/public/index.html.gz and b/tools/server/public/index.html.gz differ diff --git a/tools/server/webui/package-lock.json b/tools/server/webui/package-lock.json index 4f37b308b1..0d1a03aca3 100644 --- a/tools/server/webui/package-lock.json +++ b/tools/server/webui/package-lock.json @@ -2109,9 +2109,9 @@ } }, "node_modules/@sveltejs/kit": { - "version": "2.48.5", - "resolved": "https://registry.npmjs.org/@sveltejs/kit/-/kit-2.48.5.tgz", - "integrity": "sha512-/rnwfSWS3qwUSzvHynUTORF9xSJi7PCR9yXkxUOnRrNqyKmCmh3FPHH+E9BbgqxXfTevGXBqgnlh9kMb+9T5XA==", + "version": "2.49.2", + "resolved": "https://registry.npmjs.org/@sveltejs/kit/-/kit-2.49.2.tgz", + "integrity": "sha512-Vp3zX/qlwerQmHMP6x0Ry1oY7eKKRcOWGc2P59srOp4zcqyn+etJyQpELgOi4+ZSUgteX8Y387NuwruLgGXLUQ==", "dev": true, "license": "MIT", "dependencies": { @@ -5797,9 +5797,9 @@ } }, "node_modules/mdast-util-to-hast": { - "version": "13.2.0", - "resolved": "https://registry.npmjs.org/mdast-util-to-hast/-/mdast-util-to-hast-13.2.0.tgz", - "integrity": "sha512-QGYKEuUsYT9ykKBCMOEDLsU5JRObWQusAolFMeko/tYPufNkRffBAQjIE+99jbA87xv6FgmjLtwjh9wBWajwAA==", + "version": "13.2.1", + "resolved": "https://registry.npmjs.org/mdast-util-to-hast/-/mdast-util-to-hast-13.2.1.tgz", + "integrity": "sha512-cctsq2wp5vTsLIcaymblUriiTcZd0CwWtCbLvrOzYCDZoWyMNV8sZ7krj09FSnsiJi3WVsHLM4k6Dq/yaPyCXA==", "license": "MIT", "dependencies": { "@types/hast": 
"^3.0.0", diff --git a/tools/server/webui/src/app.d.ts b/tools/server/webui/src/app.d.ts index 71976936ed..73287d91b6 100644 --- a/tools/server/webui/src/app.d.ts +++ b/tools/server/webui/src/app.d.ts @@ -124,3 +124,10 @@ declare global { SettingsConfigType }; } + +declare global { + interface Window { + idxThemeStyle?: number; + idxCodeBlock?: number; + } +} diff --git a/tools/server/webui/src/lib/components/app/chat/ChatMessages/ChatMessageAssistant.svelte b/tools/server/webui/src/lib/components/app/chat/ChatMessages/ChatMessageAssistant.svelte index 2c9a012eff..8997963f16 100644 --- a/tools/server/webui/src/lib/components/app/chat/ChatMessages/ChatMessageAssistant.svelte +++ b/tools/server/webui/src/lib/components/app/chat/ChatMessages/ChatMessageAssistant.svelte @@ -244,7 +244,7 @@
{#if displayedModel()} - +
{#if isRouter} {/if} - +
{/if} {#if config().showToolCalls} diff --git a/tools/server/webui/src/lib/components/app/chat/ChatMessages/ChatMessageStatistics.svelte b/tools/server/webui/src/lib/components/app/chat/ChatMessages/ChatMessageStatistics.svelte index a453a31010..a39acb1d75 100644 --- a/tools/server/webui/src/lib/components/app/chat/ChatMessages/ChatMessageStatistics.svelte +++ b/tools/server/webui/src/lib/components/app/chat/ChatMessages/ChatMessageStatistics.svelte @@ -1,20 +1,122 @@ - +
+
+ {#if hasPromptStats} + + + + + +

Reading (prompt processing)

+
+
+ {/if} + + + + + +

Generation (token output)

+
+
+
- - - +
+ {#if activeView === ChatMessageStatsView.GENERATION} + + + + {:else if hasPromptStats} + + + + {/if} +
+
diff --git a/tools/server/webui/src/lib/components/app/chat/ChatScreen/ChatScreen.svelte b/tools/server/webui/src/lib/components/app/chat/ChatScreen/ChatScreen.svelte index 57a2edac58..ae40b35d33 100644 --- a/tools/server/webui/src/lib/components/app/chat/ChatScreen/ChatScreen.svelte +++ b/tools/server/webui/src/lib/components/app/chat/ChatScreen/ChatScreen.svelte @@ -587,7 +587,7 @@ &::after { content: ''; - position: fixed; + position: absolute; bottom: 0; z-index: -1; left: 0; diff --git a/tools/server/webui/src/lib/components/app/misc/BadgeChatStatistic.svelte b/tools/server/webui/src/lib/components/app/misc/BadgeChatStatistic.svelte index 9e5339cab5..a2b28d2057 100644 --- a/tools/server/webui/src/lib/components/app/misc/BadgeChatStatistic.svelte +++ b/tools/server/webui/src/lib/components/app/misc/BadgeChatStatistic.svelte @@ -1,5 +1,6 @@ - - {#snippet icon()} - - {/snippet} +{#if tooltipLabel} + + + + {#snippet icon()} + + {/snippet} - {value} - + {value} + + + +

{tooltipLabel}

+
+ +{:else} + + {#snippet icon()} + + {/snippet} + + {value} + +{/if} diff --git a/tools/server/webui/src/lib/components/app/misc/MarkdownContent.svelte b/tools/server/webui/src/lib/components/app/misc/MarkdownContent.svelte index 2a4a39535e..cb3ae17a63 100644 --- a/tools/server/webui/src/lib/components/app/misc/MarkdownContent.svelte +++ b/tools/server/webui/src/lib/components/app/misc/MarkdownContent.svelte @@ -7,15 +7,19 @@ import remarkRehype from 'remark-rehype'; import rehypeKatex from 'rehype-katex'; import rehypeStringify from 'rehype-stringify'; - import { copyCodeToClipboard, preprocessLaTeX } from '$lib/utils'; - import { rehypeRestoreTableHtml } from '$lib/markdown/table-html-restorer'; + import type { Root as HastRoot, RootContent as HastRootContent } from 'hast'; + import type { Root as MdastRoot } from 'mdast'; import { browser } from '$app/environment'; + import { onDestroy, tick } from 'svelte'; + import { rehypeRestoreTableHtml } from '$lib/markdown/table-html-restorer'; + import { rehypeEnhanceLinks } from '$lib/markdown/enhance-links'; + import { rehypeEnhanceCodeBlocks } from '$lib/markdown/enhance-code-blocks'; + import { remarkLiteralHtml } from '$lib/markdown/literal-html'; + import { copyCodeToClipboard, preprocessLaTeX } from '$lib/utils'; import '$styles/katex-custom.scss'; - import githubDarkCss from 'highlight.js/styles/github-dark.css?inline'; import githubLightCss from 'highlight.js/styles/github.css?inline'; import { mode } from 'mode-watcher'; - import { remarkLiteralHtml } from '$lib/markdown/literal-html'; import CodePreviewDialog from './CodePreviewDialog.svelte'; interface Props { @@ -23,33 +27,24 @@ class?: string; } + interface MarkdownBlock { + id: string; + html: string; + } + let { content, class: className = '' }: Props = $props(); let containerRef = $state(); - let processedHtml = $state(''); + let renderedBlocks = $state([]); + let unstableBlockHtml = $state(''); let previewDialogOpen = $state(false); let previewCode = $state(''); let previewLanguage = $state('text'); - function loadHighlightTheme(isDark: boolean) { - if (!browser) return; + let pendingMarkdown: string | null = null; + let isProcessing = false; - const existingThemes = document.querySelectorAll('style[data-highlight-theme]'); - existingThemes.forEach((style) => style.remove()); - - const style = document.createElement('style'); - style.setAttribute('data-highlight-theme', 'true'); - style.textContent = isDark ? githubDarkCss : githubLightCss; - - document.head.appendChild(style); - } - - $effect(() => { - const currentMode = mode.current; - const isDark = currentMode === 'dark'; - - loadHighlightTheme(isDark); - }); + const themeStyleId = `highlight-theme-${(window.idxThemeStyle = (window.idxThemeStyle ?? 0) + 1)}`; let processor = $derived(() => { return remark() @@ -61,139 +56,64 @@ .use(rehypeKatex) // Render math using KaTeX .use(rehypeHighlight) // Add syntax highlighting .use(rehypeRestoreTableHtml) // Restore limited HTML (e.g.,
,
    ) inside Markdown tables - .use(rehypeStringify); // Convert to HTML string + .use(rehypeEnhanceLinks) // Add target="_blank" to links + .use(rehypeEnhanceCodeBlocks) // Wrap code blocks with header and actions + .use(rehypeStringify, { allowDangerousHtml: true }); // Convert to HTML string }); - function enhanceLinks(html: string): string { - if (!html.includes('('.copy-code-btn'); + const previewButtons = containerRef.querySelectorAll('.preview-code-btn'); + + for (const button of copyButtons) { + button.removeEventListener('click', handleCopyClick); } - const tempDiv = document.createElement('div'); - tempDiv.innerHTML = html; - - // Make all links open in new tabs - const linkElements = tempDiv.querySelectorAll('a[href]'); - let mutated = false; - - for (const link of linkElements) { - const target = link.getAttribute('target'); - const rel = link.getAttribute('rel'); - - if (target !== '_blank' || rel !== 'noopener noreferrer') { - mutated = true; - } - - link.setAttribute('target', '_blank'); - link.setAttribute('rel', 'noopener noreferrer'); - } - - return mutated ? tempDiv.innerHTML : html; - } - - function enhanceCodeBlocks(html: string): string { - if (!html.includes(' - `; - - const actions = document.createElement('div'); - actions.className = 'code-block-actions'; - - actions.appendChild(copyButton); - - if (language.toLowerCase() === 'html') { - const previewButton = document.createElement('button'); - previewButton.className = 'preview-code-btn'; - previewButton.setAttribute('data-code-id', codeId); - previewButton.setAttribute('title', 'Preview code'); - previewButton.setAttribute('type', 'button'); - - previewButton.innerHTML = ` - - `; - - actions.appendChild(previewButton); - } - - header.appendChild(languageLabel); - header.appendChild(actions); - wrapper.appendChild(header); - - const clonedPre = pre.cloneNode(true) as HTMLElement; - wrapper.appendChild(clonedPre); - - pre.parentNode?.replaceChild(wrapper, pre); - } - - return mutated ? tempDiv.innerHTML : html; - } - - async function processMarkdown(text: string): Promise { - try { - let normalized = preprocessLaTeX(text); - const result = await processor().process(normalized); - const html = String(result); - const enhancedLinks = enhanceLinks(html); - - return enhanceCodeBlocks(enhancedLinks); - } catch (error) { - console.error('Markdown processing error:', error); - - // Fallback to plain text with line breaks - return text.replace(/\n/g, '
    '); + for (const button of previewButtons) { + button.removeEventListener('click', handlePreviewClick); } } + /** + * Removes this component's highlight.js theme style from the document head. + * Called on component destroy to clean up injected styles. + */ + function cleanupHighlightTheme() { + if (!browser) return; + + const existingTheme = document.getElementById(themeStyleId); + existingTheme?.remove(); + } + + /** + * Loads the appropriate highlight.js theme based on dark/light mode. + * Injects a scoped style element into the document head. + * @param isDark - Whether to load the dark theme (true) or light theme (false) + */ + function loadHighlightTheme(isDark: boolean) { + if (!browser) return; + + const existingTheme = document.getElementById(themeStyleId); + existingTheme?.remove(); + + const style = document.createElement('style'); + style.id = themeStyleId; + style.textContent = isDark ? githubDarkCss : githubLightCss; + + document.head.appendChild(style); + } + + /** + * Extracts code information from a button click target within a code block. + * @param target - The clicked button element + * @returns Object with rawCode and language, or null if extraction fails + */ function getCodeInfoFromTarget(target: HTMLElement) { const wrapper = target.closest('.code-block-wrapper'); @@ -209,12 +129,7 @@ return null; } - const rawCode = codeElement.getAttribute('data-raw-code'); - - if (rawCode === null) { - console.error('No raw code found'); - return null; - } + const rawCode = codeElement.textContent ?? ''; const languageLabel = wrapper.querySelector('.code-language'); const language = languageLabel?.textContent?.trim() || 'text'; @@ -222,6 +137,28 @@ return { rawCode, language }; } + /** + * Generates a unique identifier for a HAST node based on its position. + * Used for stable block identification during incremental rendering. + * @param node - The HAST root content node + * @param indexFallback - Fallback index if position is unavailable + * @returns Unique string identifier for the node + */ + function getHastNodeId(node: HastRootContent, indexFallback: number): string { + const position = node.position; + + if (position?.start?.offset != null && position?.end?.offset != null) { + return `hast-${position.start.offset}-${position.end.offset}`; + } + + return `${node.type}-${indexFallback}`; + } + + /** + * Handles click events on copy buttons within code blocks. + * Copies the raw code content to the clipboard. + * @param event - The click event from the copy button + */ async function handleCopyClick(event: Event) { event.preventDefault(); event.stopPropagation(); @@ -245,6 +182,25 @@ } } + /** + * Handles preview dialog open state changes. + * Clears preview content when dialog is closed. + * @param open - Whether the dialog is being opened or closed + */ + function handlePreviewDialogOpenChange(open: boolean) { + previewDialogOpen = open; + + if (!open) { + previewCode = ''; + previewLanguage = 'text'; + } + } + + /** + * Handles click events on preview buttons within HTML code blocks. + * Opens a preview dialog with the rendered HTML content. + * @param event - The click event from the preview button + */ function handlePreviewClick(event: Event) { event.preventDefault(); event.stopPropagation(); @@ -266,6 +222,61 @@ previewDialogOpen = true; } + /** + * Processes markdown content into stable and unstable HTML blocks. + * Uses incremental rendering: stable blocks are cached, unstable block is re-rendered. 
+ * @param markdown - The raw markdown string to process + */ + async function processMarkdown(markdown: string) { + if (!markdown) { + renderedBlocks = []; + unstableBlockHtml = ''; + return; + } + + const normalized = preprocessLaTeX(markdown); + const processorInstance = processor(); + const ast = processorInstance.parse(normalized) as MdastRoot; + const processedRoot = (await processorInstance.run(ast)) as HastRoot; + const processedChildren = processedRoot.children ?? []; + const stableCount = Math.max(processedChildren.length - 1, 0); + const nextBlocks: MarkdownBlock[] = []; + + for (let index = 0; index < stableCount; index++) { + const hastChild = processedChildren[index]; + const id = getHastNodeId(hastChild, index); + const existing = renderedBlocks[index]; + + if (existing && existing.id === id) { + nextBlocks.push(existing); + continue; + } + + const html = stringifyProcessedNode( + processorInstance, + processedRoot, + processedChildren[index] + ); + + nextBlocks.push({ id, html }); + } + + let unstableHtml = ''; + + if (processedChildren.length > stableCount) { + const unstableChild = processedChildren[stableCount]; + unstableHtml = stringifyProcessedNode(processorInstance, processedRoot, unstableChild); + } + + renderedBlocks = nextBlocks; + await tick(); // Force DOM sync before updating unstable HTML block + unstableBlockHtml = unstableHtml; + } + + /** + * Attaches click event listeners to copy and preview buttons in code blocks. + * Uses data-listener-bound attribute to prevent duplicate bindings. + */ function setupCodeBlockActions() { if (!containerRef) return; @@ -287,40 +298,97 @@ } } - function handlePreviewDialogOpenChange(open: boolean) { - previewDialogOpen = open; + /** + * Converts a single HAST node to an enhanced HTML string. + * Applies link and code block enhancements to the output. + * @param processorInstance - The remark/rehype processor instance + * @param processedRoot - The full processed HAST root (for context) + * @param child - The specific HAST child node to stringify + * @returns Enhanced HTML string representation of the node + */ + function stringifyProcessedNode( + processorInstance: ReturnType, + processedRoot: HastRoot, + child: unknown + ) { + const root: HastRoot = { + ...(processedRoot as HastRoot), + children: [child as never] + }; - if (!open) { - previewCode = ''; - previewLanguage = 'text'; + return processorInstance.stringify(root); + } + + /** + * Queues markdown for processing with coalescing support. + * Only processes the latest markdown when multiple updates arrive quickly. + * @param markdown - The markdown content to render + */ + async function updateRenderedBlocks(markdown: string) { + pendingMarkdown = markdown; + + if (isProcessing) { + return; + } + + isProcessing = true; + + try { + while (pendingMarkdown !== null) { + const nextMarkdown = pendingMarkdown; + pendingMarkdown = null; + + await processMarkdown(nextMarkdown); + } + } catch (error) { + console.error('Failed to process markdown:', error); + renderedBlocks = []; + unstableBlockHtml = markdown.replace(/\n/g, '
<br>'); + } finally { + isProcessing = false; } } $effect(() => { - if (content) { - processMarkdown(content) - .then((result) => { - processedHtml = result; - }) - .catch((error) => { - console.error('Failed to process markdown:', error); - processedHtml = content.replace(/\n/g, '
<br>'); - }); - } else { - processedHtml = ''; - } + const currentMode = mode.current; + const isDark = currentMode === 'dark'; + + loadHighlightTheme(isDark); }); $effect(() => { - if (containerRef && processedHtml) { + updateRenderedBlocks(content); + }); + + $effect(() => { + const hasRenderedBlocks = renderedBlocks.length > 0; + const hasUnstableBlock = Boolean(unstableBlockHtml); + + if ((hasRenderedBlocks || hasUnstableBlock) && containerRef) { setupCodeBlockActions(); } }); + + onDestroy(() => { + cleanupEventListeners(); + cleanupHighlightTheme(); + });
    - - {@html processedHtml} + {#each renderedBlocks as block (block.id)} +
    + + {@html block.html} +
    + {/each} + + {#if unstableBlockHtml} +
    + + {@html unstableBlockHtml} +
    + {/if}