diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index 996f34ed82..fc26289aec 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -30,14 +30,19 @@ Before submitting your PR: - Search for existing PRs to prevent duplicating efforts - llama.cpp uses the ggml tensor library for model evaluation. If you are unfamiliar with ggml, consider taking a look at the [examples in the ggml repository](https://github.com/ggml-org/ggml/tree/master/examples/). [simple](https://github.com/ggml-org/ggml/tree/master/examples/simple) shows the bare minimum for using ggml. [gpt-2](https://github.com/ggml-org/ggml/tree/master/examples/gpt-2) has minimal implementations for language model inference using GPT-2. [mnist](https://github.com/ggml-org/ggml/tree/master/examples/mnist) demonstrates how to train and evaluate a simple image classifier - Test your changes: - - Execute [the full CI locally on your machine](ci/README.md) before publishing - - Verify that the perplexity and the performance are not affected negatively by your changes (use `llama-perplexity` and `llama-bench`) - - If you modified the `ggml` source, run the `test-backend-ops` tool to check whether different backend implementations of the `ggml` operators produce consistent results (this requires access to at least two different `ggml` backends) - - If you modified a `ggml` operator or added a new one, add the corresponding test cases to `test-backend-ops` + - Execute [the full CI locally on your machine](ci/README.md) before publishing + - Verify that the perplexity and the performance are not affected negatively by your changes (use `llama-perplexity` and `llama-bench`) + - If you modified the `ggml` source, run the `test-backend-ops` tool to check whether different backend implementations of the `ggml` operators produce consistent results (this requires access to at least two different `ggml` backends) + - If you modified a `ggml` operator or added a new one, add the corresponding test cases to `test-backend-ops` - Create separate PRs for each feature or fix: - - Avoid combining unrelated changes in a single PR - - For intricate features, consider opening a feature request first to discuss and align expectations - - When adding support for a new model or feature, focus on **CPU support only** in the initial PR unless you have a good reason not to. Add support for other backends like CUDA in follow-up PRs + - Avoid combining unrelated changes in a single PR + - For intricate features, consider opening a feature request first to discuss and align expectations + - When adding support for a new model or feature, focus on **CPU support only** in the initial PR unless you have a good reason not to. Add support for other backends like CUDA in follow-up PRs + - In particular, adding new data types (extension of the `ggml_type` enum) carries with it a disproportionate maintenance burden. As such, to add a new quantization type you will need to meet the following *additional* criteria *at minimum*: + - convert a small model to GGUF using the new type and upload it to HuggingFace + - provide [perplexity](https://github.com/ggml-org/llama.cpp/tree/master/tools/perplexity) comparisons to FP16/BF16 (whichever is the native precision) as well as to types of similar size + - provide KL divergence data calculated vs. the FP16/BF16 (whichever is the native precision) version for both the new type as well as types of similar size + - provide [performance data](https://github.com/ggml-org/llama.cpp/tree/master/tools/llama-bench) for the new type in comparison to types of similar size on pure CPU - Consider allowing write access to your branch for faster reviews, as reviewers can push commits directly - If you are a new contributor, limit your open PRs to 1. diff --git a/convert_hf_to_gguf.py b/convert_hf_to_gguf.py index 37834c78b8..eec0ea14e3 100755 --- a/convert_hf_to_gguf.py +++ b/convert_hf_to_gguf.py @@ -2194,6 +2194,8 @@ class GPTNeoXModel(TextModel): def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]: n_head = self.hparams.get("n_head", self.hparams.get("num_attention_heads")) n_embed = self.hparams.get("hidden_size", self.hparams.get("n_embed")) + assert n_head is not None + assert n_embed is not None if re.match(r"gpt_neox\.layers\.\d+\.attention\.query_key_value\.weight", name): # Map bloom-style qkv_linear to gpt-style qkv_linear @@ -2231,6 +2233,8 @@ class BloomModel(TextModel): def set_gguf_parameters(self): n_embed = self.hparams.get("hidden_size", self.hparams.get("n_embed")) n_head = self.hparams.get("n_head", self.hparams.get("num_attention_heads")) + assert n_head is not None + assert n_embed is not None self.gguf_writer.add_context_length(self.hparams.get("seq_length", n_embed)) self.gguf_writer.add_embedding_length(n_embed) self.gguf_writer.add_feed_forward_length(4 * n_embed) @@ -2243,6 +2247,8 @@ class BloomModel(TextModel): def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]: n_head = self.hparams.get("n_head", self.hparams.get("num_attention_heads")) n_embed = self.hparams.get("hidden_size", self.hparams.get("n_embed")) + assert n_head is not None + assert n_embed is not None name = re.sub(r'transformer\.', '', name) @@ -3853,6 +3859,7 @@ class LLaDAModel(TextModel): if (rope_dim := hparams.get("head_dim")) is None: n_heads = hparams.get("num_attention_heads", hparams.get("n_heads")) + assert n_heads is not None rope_dim = hparams.get("hidden_size", hparams.get("d_model")) // n_heads self.gguf_writer.add_rope_dimension_count(rope_dim) @@ -3884,6 +3891,7 @@ class LLaDAModel(TextModel): def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]: n_head = self.hparams.get("num_attention_heads", self.hparams.get("n_heads")) + assert n_head is not None n_kv_head = self.hparams.get("num_key_value_heads", self.hparams.get("n_kv_heads")) if self.undo_permute: @@ -9485,7 +9493,9 @@ class ChatGLMModel(TextModel): def set_gguf_parameters(self): n_embed = self.hparams.get("hidden_size", self.hparams.get("n_embed")) + assert n_embed is not None n_head = self.hparams.get("n_head", self.hparams.get("num_attention_heads")) + assert n_head is not None n_head_kv = self.hparams.get("multi_query_group_num", self.hparams.get("num_key_value_heads", n_head)) self.gguf_writer.add_context_length(self.hparams.get("seq_length", n_embed)) self.gguf_writer.add_embedding_length(n_embed) diff --git a/ggml/CMakeLists.txt b/ggml/CMakeLists.txt index 4323afe57b..8f679e2fd3 100644 --- a/ggml/CMakeLists.txt +++ b/ggml/CMakeLists.txt @@ -253,7 +253,7 @@ option(GGML_OPENCL_PROFILING "ggml: use OpenCL profiling (increas option(GGML_OPENCL_EMBED_KERNELS "ggml: embed kernels" ON) option(GGML_OPENCL_USE_ADRENO_KERNELS "ggml: use optimized kernels for Adreno" ON) set (GGML_OPENCL_TARGET_VERSION "300" CACHE STRING - "gmml: OpenCL API version to target") + "ggml: OpenCL API version to target") option(GGML_HEXAGON "ggml: enable Hexagon backend" OFF) set(GGML_HEXAGON_FP32_QUANTIZE_GROUP_SIZE 128 CACHE STRING "ggml: quantize group size (32, 64, or 128)") diff --git a/ggml/src/ggml-cpu/ops.cpp b/ggml/src/ggml-cpu/ops.cpp index fa9d27046b..85db02d92f 100644 --- a/ggml/src/ggml-cpu/ops.cpp +++ b/ggml/src/ggml-cpu/ops.cpp @@ -9624,7 +9624,7 @@ void ggml_compute_forward_win_unpart( } } -//gmml_compute_forward_unary +//ggml_compute_forward_unary void ggml_compute_forward_unary( const ggml_compute_params * params, diff --git a/ggml/src/ggml-metal/ggml-metal-device.m b/ggml/src/ggml-metal/ggml-metal-device.m index 05b826a61b..b7d587f3bd 100644 --- a/ggml/src/ggml-metal/ggml-metal-device.m +++ b/ggml/src/ggml-metal/ggml-metal-device.m @@ -1156,7 +1156,7 @@ bool ggml_metal_device_supports_op(ggml_metal_device_t dev, const struct ggml_te case GGML_OP_RWKV_WKV7: return true; case GGML_OP_GATED_DELTA_NET: - return op->src[2]->ne[0] % 32 == 0; + return has_simdgroup_reduction && op->src[2]->ne[0] % 32 == 0; case GGML_OP_SOLVE_TRI: case GGML_OP_MUL_MAT: case GGML_OP_MUL_MAT_ID: diff --git a/ggml/src/ggml-metal/ggml-metal.metal b/ggml/src/ggml-metal/ggml-metal.metal index 24a3092af2..107e7cf2ff 100644 --- a/ggml/src/ggml-metal/ggml-metal.metal +++ b/ggml/src/ggml-metal/ggml-metal.metal @@ -3006,7 +3006,7 @@ kernel void kernel_l2_norm_impl( sumf = shmem_f32[tiisg]; sumf = simd_sum(sumf); - const float scale = 1.0f/sqrt(max(sumf, args.eps)); + const float scale = 1.0f/max(sqrt(sumf), args.eps); for (int i00 = tpitg.x; i00 < args.ne00; i00 += ntg.x) { y[i00] = x[i00] * scale; diff --git a/tools/mtmd/mtmd-helper.cpp b/tools/mtmd/mtmd-helper.cpp index c75f90730f..5bcb7ec1bc 100644 --- a/tools/mtmd/mtmd-helper.cpp +++ b/tools/mtmd/mtmd-helper.cpp @@ -470,12 +470,12 @@ static bool decode_audio_from_buf(const unsigned char * buf_in, size_t len, int mtmd_bitmap * mtmd_helper_bitmap_init_from_buf(mtmd_context * ctx, const unsigned char * buf, size_t len) { if (audio_helpers::is_audio_file((const char *)buf, len)) { std::vector pcmf32; - int bitrate = mtmd_get_audio_bitrate(ctx); - if (bitrate < 0) { + const int sample_rate = mtmd_get_audio_sample_rate(ctx); + if (sample_rate < 0) { LOG_ERR("This model does not support audio input\n"); return nullptr; } - if (!audio_helpers::decode_audio_from_buf(buf, len, bitrate, pcmf32)) { + if (!audio_helpers::decode_audio_from_buf(buf, len, sample_rate, pcmf32)) { LOG_ERR("Unable to read WAV audio file from buffer\n"); return nullptr; } diff --git a/tools/mtmd/mtmd.cpp b/tools/mtmd/mtmd.cpp index ccafb80b2b..1a95acd439 100644 --- a/tools/mtmd/mtmd.cpp +++ b/tools/mtmd/mtmd.cpp @@ -912,7 +912,7 @@ bool mtmd_support_audio(mtmd_context * ctx) { return ctx->ctx_a != nullptr; } -int mtmd_get_audio_bitrate(mtmd_context * ctx) { +int mtmd_get_audio_sample_rate(mtmd_context * ctx) { if (!ctx->ctx_a) { return -1; } diff --git a/tools/mtmd/mtmd.h b/tools/mtmd/mtmd.h index ef25d32bbe..ebb4a18fb3 100644 --- a/tools/mtmd/mtmd.h +++ b/tools/mtmd/mtmd.h @@ -125,9 +125,9 @@ MTMD_API bool mtmd_support_vision(mtmd_context * ctx); // whether the current model supports audio input MTMD_API bool mtmd_support_audio(mtmd_context * ctx); -// get audio bitrate in Hz, for example 16000 for Whisper +// get audio sample rate in Hz, for example 16000 for Whisper // return -1 if audio is not supported -MTMD_API int mtmd_get_audio_bitrate(mtmd_context * ctx); +MTMD_API int mtmd_get_audio_sample_rate(mtmd_context * ctx); // mtmd_bitmap //