diff --git a/common/common.h b/common/common.h index cf57d48415..5063d73f96 100644 --- a/common/common.h +++ b/common/common.h @@ -288,9 +288,9 @@ struct common_params { float rope_freq_base = 0.0f; // RoPE base frequency float rope_freq_scale = 0.0f; // RoPE frequency scaling factor float yarn_ext_factor = -1.0f; // YaRN extrapolation mix factor - float yarn_attn_factor = 1.0f; // YaRN magnitude scaling factor - float yarn_beta_fast = 32.0f; // YaRN low correction dim - float yarn_beta_slow = 1.0f; // YaRN high correction dim + float yarn_attn_factor = -1.0f; // YaRN magnitude scaling factor + float yarn_beta_fast = -1.0f; // YaRN low correction dim + float yarn_beta_slow = -1.0f; // YaRN high correction dim int32_t yarn_orig_ctx = 0; // YaRN original context length // offload params diff --git a/convert_hf_to_gguf.py b/convert_hf_to_gguf.py index 6a86891b97..3468fc7757 100755 --- a/convert_hf_to_gguf.py +++ b/convert_hf_to_gguf.py @@ -735,6 +735,9 @@ class TextModel(ModelBase): if chkhsh == "d4540891389ea895b53b399da6ac824becc30f2fba0e9ddbb98f92e55ca0e97c": # ref: https://huggingface.co/Qwen/Qwen3-Embedding-0.6B res = "qwen2" + if chkhsh == "66b8d4e19ab16c3bfd89bce5d785fb7e0155e8648708a1f42077cb9fe002c273": + # ref: https://huggingface.co/alvarobartt/grok-2-tokenizer + res = "grok-2" if chkhsh == "0ef9807a4087ebef797fc749390439009c3b9eda9ad1a097abbe738f486c01e5": # ref: https://huggingface.co/meta-llama/Meta-Llama-3-8B res = "llama-bpe" @@ -2685,12 +2688,20 @@ class BitnetModel(TextModel): yield (new_name, data_torch) -@ModelBase.register("GrokForCausalLM") +@ModelBase.register("GrokForCausalLM", "Grok1ForCausalLM") class GrokModel(TextModel): model_arch = gguf.MODEL_ARCH.GROK def set_vocab(self): - self._set_vocab_sentencepiece() + if (self.dir_model / 'tokenizer.model').is_file(): + self._set_vocab_sentencepiece() + return + + if not (self.dir_model / 'tokenizer.json').is_file() or not (self.dir_model / 'chat_template.jinja').is_file(): + logger.error('Error: Missing vocab and chat template, download files from https://huggingface.co/alvarobartt/grok-2-tokenizer') + sys.exit(1) + + self._set_vocab_gpt2() def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) @@ -2698,11 +2709,46 @@ class GrokModel(TextModel): def set_gguf_parameters(self): super().set_gguf_parameters() - _experts: list[dict[str, Tensor]] | None = None + self.gguf_writer.add_attn_logit_softcapping(self.hparams.get("attn_logit_softcapping", 30.0)) + self.gguf_writer.add_router_logit_softcapping(self.hparams.get("router_logit_softcapping", 30.0)) + if (final_logit_softcap := self.hparams.get("final_logit_softcapping")): + self.gguf_writer.add_final_logit_softcapping(final_logit_softcap) + + if (rope_dim := self.hparams.get("head_dim")) is None: + rope_dim = self.hparams["hidden_size"] // self.hparams["num_attention_heads"] + + if (moe_intermediate_size := self.hparams.get("moe_intermediate_size")) is not None: + self.gguf_writer.add_expert_feed_forward_length(moe_intermediate_size) + + # Treat "original" as "yarn", seems to have been a mistake + if self.hparams.get("rope_type") in ("yarn", "original"): + self.gguf_writer.add_rope_scaling_type(gguf.RopeScalingType.YARN) + self.gguf_writer.add_rope_scaling_factor(self.hparams["scaling_factor"]) + self.gguf_writer.add_rope_scaling_orig_ctx_len(self.hparams["original_max_position_embeddings"]) + self.gguf_writer.add_rope_scaling_yarn_ext_factor(self.hparams["extrapolation_factor"]) + self.gguf_writer.add_rope_scaling_yarn_attn_factor(self.hparams["attn_factor"]) + self.gguf_writer.add_rope_scaling_yarn_beta_fast(self.hparams["beta_fast"]) + self.gguf_writer.add_rope_scaling_yarn_beta_slow(self.hparams["beta_slow"]) + + if temp_len := self.hparams.get("attn_temperature_len"): + self.gguf_writer.add_attn_temperature_length(temp_len) + + self.gguf_writer.add_attn_output_scale(self.hparams.get("attn_output_multiplier", rope_dim**-0.5)) + self.gguf_writer.add_embedding_scale(self.hparams["embedding_multiplier_scale"]) + self.gguf_writer.add_logit_scale(self.hparams["output_multiplier_scale"]) + + _experts: list[dict[str, list[Tensor]]] | None = None + _cur_expert = "" def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]: + tensors: list[tuple[str, Tensor]] = [] + is_expert = ".moe." in name or ".block_sparse_moe.experts." in name + + if not is_expert: + tensors.append((self.map_tensor_name(name), data_torch)) + # process the experts separately - if name.find(".moe.") != -1: + if is_expert or self._cur_expert: n_experts = self.hparams["num_local_experts"] assert bid is not None @@ -2710,32 +2756,41 @@ class GrokModel(TextModel): if self._experts is None: self._experts = [{} for _ in range(self.block_count)] - self._experts[bid][name] = data_torch - - if len(self._experts[bid]) >= n_experts * 3: - tensors: list[tuple[str, Tensor]] = [] - - # merge the experts into a single 3d tensor - for wid in ["linear", "linear_1", "linear_v"]: - datas: list[Tensor] = [] - - for xid in range(n_experts): - ename = f"transformer.decoder_layer.{bid}.moe.{xid}.{wid}.weight" - datas.append(self._experts[bid][ename]) - del self._experts[bid][ename] - - data_torch = torch.stack(datas, dim=0) - - merged_name = f"transformer.decoder_layer.{bid}.moe.{wid}.weight" - - new_name = self.map_tensor_name(merged_name) - - tensors.append((new_name, data_torch)) - return tensors - else: + # concatenate split tensors + if name in self._experts[bid]: + self._cur_expert = name + self._experts[bid][name].append(data_torch) return [] + elif is_expert: + self._cur_expert = name + self._experts[bid][name] = [data_torch] + return [] + else: + self._cur_expert = "" - return [(self.map_tensor_name(name), data_torch)] + for bid in range(self.block_count): + if len(self._experts[bid]) >= n_experts * 3: + # merge the experts into a single 3d tensor + for wid in [("linear", "w1", 0), ("linear_1", "w2", 1), ("linear_v", "w3", 0)]: + datas: list[Tensor] = [] + + for xid in range(n_experts): + ename = f"transformer.decoder_layer.{bid}.moe.{xid}.{wid[0]}.weight" + if ename not in self._experts[bid]: + ename = f"model.layers.{bid}.block_sparse_moe.experts.{xid}.{wid[1]}.weight" + tensor_list = self._experts[bid][ename] + datas.append(torch.cat(tensor_list, dim=wid[2]) if len(tensor_list) > 1 else tensor_list[0]) + del self._experts[bid][ename] + + data_torch = torch.stack(datas, dim=0) + + merged_name = f"transformer.decoder_layer.{bid}.moe.{wid[0]}.weight" + + new_name = self.map_tensor_name(merged_name) + + yield (new_name, data_torch) + + yield from tensors @ModelBase.register("DbrxForCausalLM") diff --git a/convert_hf_to_gguf_update.py b/convert_hf_to_gguf_update.py index 6eeb1f64e3..00486efb7b 100755 --- a/convert_hf_to_gguf_update.py +++ b/convert_hf_to_gguf_update.py @@ -159,6 +159,7 @@ pre_computed_hashes = [ {"name": "falcon-h1", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/tiiuae/Falcon-H1-34B-Base", "chkhsh": "48f8e02c0359c0bbdd82f26909171fac1c18a457bb47573ed1fe3bbb2c1cfd4b"}, {"name": "kimi-k2", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/moonshotai/Kimi-K2-Base", "chkhsh": "81212dc7cdb7e0c1074ca62c5aeab0d43c9f52b8a737be7b12a777c953027890"}, {"name": "qwen2", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/Qwen/Qwen3-Embedding-0.6B", "chkhsh": "d4540891389ea895b53b399da6ac824becc30f2fba0e9ddbb98f92e55ca0e97c"}, + {"name": "grok-2", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/alvarobartt/grok-2-tokenizer", "chkhsh": "66b8d4e19ab16c3bfd89bce5d785fb7e0155e8648708a1f42077cb9fe002c273"}, ] diff --git a/ggml/src/ggml-cuda/mmf.cuh b/ggml/src/ggml-cuda/mmf.cuh index bf724bc57b..61e3bf3015 100644 --- a/ggml/src/ggml-cuda/mmf.cuh +++ b/ggml/src/ggml-cuda/mmf.cuh @@ -57,31 +57,33 @@ static __global__ void mul_mat_f( T * tile_xy = (T *) compute_base + threadIdx.y*(tile_A::I * tile_k_padded); if constexpr (has_ids) { - __shared__ int has_any; - if (threadIdx.y == 0) { - int local_has_any = 0; - for (int j = threadIdx.x; j < cols_per_block; j += warp_size) { - int slot = -1; - for (int k = 0; k < nchannels_dst; ++k) { - const int idv = ids[j*stride_row_id + k*stride_col_id]; - if (idv == expert_idx) { - slot = k; - break; - } - } - if (j < cols_per_block) { - local_has_any |= (slot >= 0); - slot_map[j] = slot; + int found = 0; + + for (int j0 = 0; j0 < cols_per_block; j0 += nwarps) { + const int j = j0 + threadIdx.y; + const int32_t * __restrict__ id_row = ids + j*stride_row_id; + + if (threadIdx.x == 0) { + slot_map[j] = -1; + } + + for (int k = threadIdx.x; k < nchannels_dst; k += warp_size) { + int match = id_row[k*stride_col_id] == expert_idx; + + if (match) { + slot_map[j] = k; + found = 1; + break; } } - has_any = warp_reduce_any(local_has_any); } - __syncthreads(); - if (has_any == 0) { + + if (!__syncthreads_or(found)) { return; } } + for (int col = threadIdx.y*warp_size + threadIdx.x; col < ncols; col += nwarps*warp_size) { tile_A A[ntA][warp_size / tile_A::J]; #pragma unroll @@ -106,14 +108,7 @@ static __global__ void mul_mat_f( if constexpr (!has_ids) { tile_xy[j0*tile_k_padded + threadIdx.x] = j < cols_per_block ? y[j*stride_col_y + col] : 0.0f; } else { - float val = 0.0f; - if (j < cols_per_block) { - const int slot = slot_map[j]; - if (slot >= 0) { - val = y[slot*stride_channel_y + j*stride_col_y + col]; - } - } - tile_xy[j0*tile_k_padded + threadIdx.x] = val; + tile_xy[j0*tile_k_padded + threadIdx.x] = j < cols_per_block ? y[slot_map[j]*stride_channel_y + j*stride_col_y + col] : 0.0f; } } } else if constexpr (std::is_same_v || std::is_same_v) { @@ -125,14 +120,7 @@ static __global__ void mul_mat_f( const float2 tmp = j < cols_per_block ? y2[j*stride_col_y + col] : make_float2(0.0f, 0.0f); tile_xy[j0*tile_k_padded + threadIdx.x] = {tmp.x, tmp.y}; } else { - float2 tmp = make_float2(0.0f, 0.0f); - if (j < cols_per_block) { - const int slot = slot_map[j]; - if (slot >= 0) { - const float2 * y2_slot = (const float2 *)(y + slot*stride_channel_y); - tmp = y2_slot[j*stride_col_y + col]; - } - } + float2 tmp = j < cols_per_block && slot_map[j] >= 0 ? *(const float2*) &y[slot_map[j]*stride_channel_y + 2*(j*stride_col_y + col)] : make_float2(0.0f, 0.0f); tile_xy[j0*tile_k_padded + threadIdx.x] = {tmp.x, tmp.y}; } } @@ -221,7 +209,7 @@ static inline void mul_mat_f_switch_ids( const dim3 & block_nums, const dim3 & block_dims, const int nbytes_shared_total, cudaStream_t stream) { if (ids) { mul_mat_f<<>> - (x, y, ids, dst, ncols_x, nchannels_dst, stride_row, stride_col_y, stride_col_dst, + (x, y, ids, dst, ncols_x, nchannels_dst, stride_row, stride_col_y, stride_col_dst, stride_col_id, stride_row_id, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst, sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst); } else { diff --git a/ggml/src/ggml-metal/ggml-metal-common.cpp b/ggml/src/ggml-metal/ggml-metal-common.cpp index 6a869ff24c..cb39e5b2ab 100644 --- a/ggml/src/ggml-metal/ggml-metal-common.cpp +++ b/ggml/src/ggml-metal/ggml-metal-common.cpp @@ -1,9 +1,12 @@ #include "ggml-metal-common.h" #include "ggml-impl.h" +#include "ggml-backend-impl.h" #include +// represents a memory range (i.e. an interval from a starting address p0 to an ending address p1 in a given buffer pb) +// the type indicates whether it is a source range (i.e. ops read data from it) or a destination range (i.e. ops write data to it) struct ggml_mem_range { uint64_t pb; // buffer id @@ -36,8 +39,8 @@ void ggml_mem_ranges_reset(ggml_mem_ranges * mrs) { mrs->ranges.clear(); } -static bool ggml_mem_ranges_add(ggml_mem_ranges * mrs, ggml_mem_range mrp) { - mrs->ranges.push_back(mrp); +static bool ggml_mem_ranges_add(ggml_mem_ranges * mrs, ggml_mem_range mr) { + mrs->ranges.push_back(mr); return true; } @@ -48,20 +51,24 @@ static ggml_mem_range ggml_mem_range_from_tensor(const ggml_tensor * tensor, ggm GGML_ASSERT(!tensor->view_src); - ggml_mem_range mrp; + ggml_mem_range mr; if (tensor->buffer) { - // when the tensor is allocated, use the actual memory address range of the buffer - mrp = { + // when the tensor is allocated, use the actual memory address range in the buffer + // + // take the actual allocated size with ggml_backend_buft_get_alloc_size() + // this can be larger than the tensor size if the buffer type allocates extra memory + // ref: https://github.com/ggml-org/llama.cpp/pull/15966 + mr = { /*.pb =*/ (uint64_t) tensor->buffer, /*.p0 =*/ (uint64_t) tensor->data, - /*.p1 =*/ (uint64_t) tensor->data + ggml_nbytes(tensor), + /*.p1 =*/ (uint64_t) tensor->data + ggml_backend_buft_get_alloc_size(tensor->buffer->buft, tensor), /*.pt =*/ pt, }; } else { - // otherwise, the tensor ptr is used as an unique id of the memory ranges + // otherwise, the pointer address is used as an unique id of the memory ranges // that the tensor will be using when it is allocated - mrp = { + mr = { /*.pb =*/ (uint64_t) tensor, /*.p0 =*/ 0, // /*.p1 =*/ 1024, // [0, 1024) is a dummy range, not used @@ -69,7 +76,7 @@ static ggml_mem_range ggml_mem_range_from_tensor(const ggml_tensor * tensor, ggm }; }; - return mrp; + return mr; } static ggml_mem_range ggml_mem_range_from_tensor_src(const ggml_tensor * tensor) { @@ -83,25 +90,25 @@ static ggml_mem_range ggml_mem_range_from_tensor_dst(const ggml_tensor * tensor) static bool ggml_mem_ranges_add_src(ggml_mem_ranges * mrs, const ggml_tensor * tensor) { GGML_ASSERT(tensor); - ggml_mem_range mrp = ggml_mem_range_from_tensor_src(tensor); + ggml_mem_range mr = ggml_mem_range_from_tensor_src(tensor); if (mrs->debug > 2) { - GGML_LOG_DEBUG("%s: add src range buf=%lld, [%lld, %lld)\n", __func__, mrp.pb, mrp.p0, mrp.p1); + GGML_LOG_DEBUG("%s: add src range buf=%lld, [%lld, %lld)\n", __func__, mr.pb, mr.p0, mr.p1); } - return ggml_mem_ranges_add(mrs, mrp); + return ggml_mem_ranges_add(mrs, mr); } static bool ggml_mem_ranges_add_dst(ggml_mem_ranges * mrs, const ggml_tensor * tensor) { GGML_ASSERT(tensor); - ggml_mem_range mrp = ggml_mem_range_from_tensor_dst(tensor); + ggml_mem_range mr = ggml_mem_range_from_tensor_dst(tensor); if (mrs->debug > 2) { - GGML_LOG_DEBUG("%s: add dst range buf=%lld, [%lld, %lld)\n", __func__, mrp.pb, mrp.p0, mrp.p1); + GGML_LOG_DEBUG("%s: add dst range buf=%lld, [%lld, %lld)\n", __func__, mr.pb, mr.p0, mr.p1); } - return ggml_mem_ranges_add(mrs, mrp); + return ggml_mem_ranges_add(mrs, mr); } bool ggml_mem_ranges_add(ggml_mem_ranges * mrs, const ggml_tensor * tensor) { @@ -114,24 +121,26 @@ bool ggml_mem_ranges_add(ggml_mem_ranges * mrs, const ggml_tensor * tensor) { return ggml_mem_ranges_add_dst(mrs, tensor); } -static bool ggml_mem_ranges_check(const ggml_mem_ranges * mrs, ggml_mem_range mrp) { +static bool ggml_mem_ranges_check(const ggml_mem_ranges * mrs, ggml_mem_range mr) { for (size_t i = 0; i < mrs->ranges.size(); i++) { const auto & cmp = mrs->ranges[i]; - if (mrp.pb != cmp.pb) { + // two memory ranges cannot intersect if they are in different buffers + if (mr.pb != cmp.pb) { continue; } - if (mrp.pt == MEM_RANGE_TYPE_SRC && cmp.pt == MEM_RANGE_TYPE_SRC) { + // intersecting source ranges are allowed + if (mr.pt == MEM_RANGE_TYPE_SRC && cmp.pt == MEM_RANGE_TYPE_SRC) { continue; } - if (mrp.p0 < cmp.p1 && mrp.p1 >= cmp.p0) { + if (mr.p0 < cmp.p1 && mr.p1 >= cmp.p0) { if (mrs->debug > 2) { GGML_LOG_DEBUG("%s: the %s range buf=%lld, [%lld, %lld) overlaps with a previous %s range buf=%lld, [%lld, %lld)\n", __func__, - mrp.pt == MEM_RANGE_TYPE_SRC ? "src" : "dst", - mrp.pb, mrp.p0, mrp.p1, + mr.pt == MEM_RANGE_TYPE_SRC ? "src" : "dst", + mr.pb, mr.p0, mr.p1, cmp.pt == MEM_RANGE_TYPE_SRC ? "src" : "dst", cmp.pb, cmp.p0, cmp.p1); } @@ -146,9 +155,9 @@ static bool ggml_mem_ranges_check(const ggml_mem_ranges * mrs, ggml_mem_range mr static bool ggml_mem_ranges_check_src(const ggml_mem_ranges * mrs, const ggml_tensor * tensor) { GGML_ASSERT(tensor); - ggml_mem_range mrp = ggml_mem_range_from_tensor_src(tensor); + ggml_mem_range mr = ggml_mem_range_from_tensor_src(tensor); - const bool res = ggml_mem_ranges_check(mrs, mrp); + const bool res = ggml_mem_ranges_check(mrs, mr); return res; } @@ -156,9 +165,9 @@ static bool ggml_mem_ranges_check_src(const ggml_mem_ranges * mrs, const ggml_te static bool ggml_mem_ranges_check_dst(const ggml_mem_ranges * mrs, const ggml_tensor * tensor) { GGML_ASSERT(tensor); - ggml_mem_range mrp = ggml_mem_range_from_tensor_dst(tensor); + ggml_mem_range mr = ggml_mem_range_from_tensor_dst(tensor); - const bool res = ggml_mem_ranges_check(mrs, mrp); + const bool res = ggml_mem_ranges_check(mrs, mr); return res; } @@ -222,6 +231,7 @@ static std::vector ggml_metal_graph_optimize_reorder(const std::vectorsrc[i]) { @@ -290,7 +300,10 @@ static std::vector ggml_metal_graph_optimize_reorder(const std::vector used(n, false); + // the memory ranges for the set of currently concurrent nodes ggml_mem_ranges * mrs0 = ggml_mem_ranges_init(0); + + // the memory ranges for the set of nodes that haven't been processed yet, when looking forward for a node to reorder ggml_mem_ranges * mrs1 = ggml_mem_ranges_init(0); for (int i0 = 0; i0 < n; i0++) { @@ -329,7 +342,7 @@ static std::vector ggml_metal_graph_optimize_reorder(const std::vector order(nodes.size()); diff --git a/ggml/src/ggml-metal/ggml-metal.m b/ggml/src/ggml-metal/ggml-metal.m index 13f9de297e..2243c174fb 100644 --- a/ggml/src/ggml-metal/ggml-metal.m +++ b/ggml/src/ggml-metal/ggml-metal.m @@ -532,261 +532,9 @@ enum ggml_metal_kernel_type { GGML_METAL_KERNEL_TYPE_COUNT }; -// -// ggml_metal_heap -// - -struct ggml_metal_heap { - // number of times the heap was unused - int n_unused; - - // total number of buffer allocations in this heap across all computes - int64_t n_alloc; - - // current offset in the heap - we reset this after each node in order to reuse the memory - size_t offs; - - // the currently allocated MTLBuffer objects in this heap - id obj; - - NSMutableArray * bufs; -}; - -static struct ggml_metal_heap * ggml_metal_heap_init(id device, size_t size) { - struct ggml_metal_heap * heap = calloc(1, sizeof(struct ggml_metal_heap)); - - MTLHeapDescriptor * desc = [[MTLHeapDescriptor alloc] init]; - desc.storageMode = MTLStorageModePrivate; - desc.cpuCacheMode = MTLCPUCacheModeDefaultCache; - desc.type = MTLHeapTypePlacement; - desc.size = size; - - heap->n_unused = 0; - heap->n_alloc = 0; - - heap->obj = [device newHeapWithDescriptor:desc]; - if (!heap->obj) { - GGML_LOG_ERROR("%s: error: failed to create MTLHeap with size %zu\n", __func__, size); - - free(heap); - - return false; - } - - [desc release]; - - heap->bufs = [[NSMutableArray alloc] init]; - - return heap; -} - -static void ggml_metal_heap_reset(struct ggml_metal_heap * heap) { - heap->offs = 0; - - // count how many graph computes the heap ended up being unused - if ([heap->bufs count] > 0) { - heap->n_unused = 0; - } else { - heap->n_unused++; - } - - for (id buf in heap->bufs) { - [buf release]; - } - [heap->bufs removeAllObjects]; - - // tell the OS that it can reuse this memory if needed - // ref: https://developer.apple.com/documentation/metal/mtlpurgeablestate?language=objc - [heap->obj setPurgeableState:MTLPurgeableStateVolatile]; -} - -static void ggml_metal_heap_free(struct ggml_metal_heap * heap) { - if (heap == nil) { - return; - } - - ggml_metal_heap_reset(heap); - - [heap->obj release]; - [heap->bufs release]; - - free(heap); -} - -@interface ggml_metal_heap_ptr : NSObject - -@property (nonatomic, assign) struct ggml_metal_heap * data; - -@end - -@implementation ggml_metal_heap_ptr -@end - -// -// ggml_metal_mem_pool [TAG_MEM_POOL_REMOVE] -// - -struct ggml_metal_mem_pool { - id device; - - int n_heaps; // total number of heaps ever created (including those that were removed) - - NSMutableArray * heaps; - NSMutableArray * heaps_to_remove; -}; - -static struct ggml_metal_mem_pool * ggml_metal_mem_pool_init(void) { - struct ggml_metal_mem_pool * mem_pool = calloc(1, sizeof(struct ggml_metal_mem_pool)); - - mem_pool->n_heaps = 0; - - mem_pool->heaps = [[NSMutableArray alloc] init]; - mem_pool->heaps_to_remove = [[NSMutableArray alloc] init]; - - return mem_pool; -} - -static void ggml_metal_mem_pool_free(struct ggml_metal_mem_pool * mem_pool) { - GGML_LOG_DEBUG("%s: freeing memory pool, num heaps = %zu (total = %d)\n", __func__, [mem_pool->heaps count], mem_pool->n_heaps); - - size_t size_all = 0; - size_t size_cur = 0; - - for (ggml_metal_heap_ptr * ptr in mem_pool->heaps) { - GGML_LOG_DEBUG("%s: heap: %p\n", __func__, (void *) ptr.data); - GGML_LOG_DEBUG("%s: n_alloc: %" PRId64 "\n", __func__, ptr.data->n_alloc); - GGML_LOG_DEBUG("%s: n_unused: %d\n", __func__, ptr.data->n_unused); - GGML_LOG_DEBUG("%s: size: %.2f MiB\n", __func__, [ptr.data->obj size] / 1024.0 / 1024.0); - GGML_LOG_DEBUG("%s: bufs: %zu\n", __func__, [ptr.data->bufs count]); - - if ([ptr.data->bufs count] > 0) { - size_cur += [ptr.data->obj size]; - } - size_all += [ptr.data->obj size]; - - ggml_metal_heap_free(ptr.data); - [ptr release]; - } - [mem_pool->heaps release]; - [mem_pool->heaps_to_remove release]; - - if (size_all > 0) { - GGML_LOG_DEBUG("%s: size_all: %.2f MiB\n", __func__, size_all / 1024.0 / 1024.0); - GGML_LOG_DEBUG("%s: size_cur: %.2f MiB\n", __func__, size_cur / 1024.0 / 1024.0); - } - - free(mem_pool); -} - -static void ggml_metal_mem_pool_reset(struct ggml_metal_mem_pool * mem_pool) { - for (NSUInteger i = 0; i < [mem_pool->heaps count]; i++) { - ggml_metal_heap_ptr * ptr = [mem_pool->heaps objectAtIndex:i]; - - struct ggml_metal_heap * heap = ptr.data; - ggml_metal_heap_reset(heap); - - // if the heap hasn't been used for a while, remove it - if (heap->n_unused >= 128) { - [mem_pool->heaps_to_remove addObject:@(i)]; - } - } - - if (mem_pool->heaps_to_remove.count > 0) { - // remove in reverse order - for (NSUInteger i = [mem_pool->heaps_to_remove count] - 1; ; --i) { - NSUInteger index = [[mem_pool->heaps_to_remove objectAtIndex:i] intValue]; - ggml_metal_heap_ptr * ptr = [mem_pool->heaps objectAtIndex:index]; - - struct ggml_metal_heap * heap = ptr.data; - ggml_metal_heap_free(heap); - - [mem_pool->heaps removeObjectAtIndex:index]; - [ptr release]; - - if (i == 0) { - break; - } - } - - [mem_pool->heaps_to_remove removeAllObjects]; - } -} - -static void ggml_metal_mem_pool_clear(struct ggml_metal_mem_pool * mem_pool) { - for (ggml_metal_heap_ptr * ptr in mem_pool->heaps) { - ptr.data->offs = 0; - } -} - -static id ggml_metal_mem_pool_alloc(struct ggml_metal_mem_pool * mem_pool, size_t size) { - const size_t alignment = 256; - - const size_t size_aligned = GGML_PAD(size, alignment); - - // try one of the existing heaps - for (ggml_metal_heap_ptr * ptr in mem_pool->heaps) { - struct ggml_metal_heap * heap = ptr.data; - if (heap->offs + size_aligned <= [heap->obj size]) { - // if this is the first buffer in the heap for the current command buffer, tell the OS that - // it cannot free the memory used by the heap - // ref: https://developer.apple.com/documentation/metal/mtlpurgeablestate?language=objc - if ([heap->bufs count] == 0) { - [heap->obj setPurgeableState:MTLPurgeableStateNonVolatile]; - } - - id buf = [heap->obj newBufferWithLength:size_aligned options:MTLResourceStorageModePrivate offset:heap->offs]; - if (buf == nil) { - GGML_LOG_ERROR("%s: error: failed to create MTLBuffer with size %zu\n", __func__, size_aligned); - return nil; - } - - heap->n_alloc++; - heap->offs += size_aligned; - - [heap->bufs addObject:buf]; - - return buf; - } - } - - // create a new heap that can fit this buffer - ggml_metal_heap_ptr * heap_ptr = [ggml_metal_heap_ptr new]; - - struct ggml_metal_heap * heap = ggml_metal_heap_init(mem_pool->device, size_aligned); - if (heap == NULL) { - GGML_LOG_ERROR("%s: error: failed to create heap of size %zu\n", __func__, size_aligned); - return NULL; - } - - //GGML_LOG_DEBUG("%s: creating new heap of size %zu, got %zu\n", __func__, size_aligned, [heap->obj size]); - - heap_ptr.data = heap; - ggml_metal_heap_reset(heap); - - [heap->obj setPurgeableState:MTLPurgeableStateNonVolatile]; - id buf = [heap->obj newBufferWithLength:size_aligned options:MTLResourceStorageModePrivate offset:heap->offs]; - if (buf == nil) { - GGML_LOG_ERROR("%s: error: failed to create MTLBuffer with size %zu\n", __func__, size_aligned); - return NULL; - } - - heap->n_alloc++; - heap->offs += size_aligned; - - [heap->bufs addObject:buf]; - - [mem_pool->heaps addObject:heap_ptr]; - mem_pool->n_heaps++; - - return buf; -} - struct ggml_metal_command_buffer { id obj; - // each command buffer has a memory pool from which it can allocate temporary buffers during the compute - struct ggml_metal_mem_pool * mem_pool; - // used to enable concurrent execution of ops in the command buffers struct ggml_mem_ranges * mem_ranges; }; @@ -1103,9 +851,6 @@ static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t de for (int i = 0; i < GGML_METAL_MAX_COMMAND_BUFFERS; ++i) { ctx->cmd_bufs[i].obj = nil; - ctx->cmd_bufs[i].mem_pool = ggml_metal_mem_pool_init(); - ctx->cmd_bufs[i].mem_pool->device = device; - if (ctx_dev->use_concurrency) { ctx->cmd_bufs[i].mem_ranges = ggml_mem_ranges_init(ctx_dev->debug_graph); } @@ -1510,6 +1255,52 @@ static id ggml_metal_compile_kernel(ggml_backend_t back return res; } +// tokens per expert +static size_t ggml_metal_mul_mat_id_extra_tpe(const struct ggml_tensor * op) { + assert(op->op == GGML_OP_MUL_MAT_ID); + + const int64_t ne02 = op->src[0]->ne[2]; // n_expert + + return ggml_type_size(GGML_TYPE_I32)*ne02; +} + +// id map [n_tokens, n_expert] +static size_t ggml_metal_mul_mat_id_extra_ids(const struct ggml_tensor * op) { + assert(op->op == GGML_OP_MUL_MAT_ID); + + const int64_t ne02 = op->src[0]->ne[2]; // n_expert + const int64_t ne21 = op->src[2]->ne[1]; // n_token + + return ggml_type_size(GGML_TYPE_I32)*ne02*ne21; +} + +// return true if we should use the FA vector kernel for this op +static bool ggml_metal_flash_attn_ext_use_vec(const struct ggml_tensor * op) { + assert(op->op == GGML_OP_FLASH_ATTN_EXT); + + const int64_t ne00 = op->src[0]->ne[0]; // head size + const int64_t ne01 = op->src[0]->ne[1]; // batch size + + // use vec kernel if the batch size is small and if the head size is supported + return (ne01 < 20) && (ne00 % 32 == 0); +} + +static size_t ggml_metal_flash_attn_ext_extra_tmp(const struct ggml_tensor * op) { + assert(op->op == GGML_OP_FLASH_ATTN_EXT); + + const int64_t nwg = 32; + + const int64_t ne01 = op->src[0]->ne[1]; + const int64_t ne02 = op->src[0]->ne[2]; + const int64_t ne03 = op->src[0]->ne[3]; + const int64_t ne20 = op->src[2]->ne[0]; + + // temp buffer for writing the results from each workgroup + // - ne20: the size of the Value head + // - + 2: the S and M values for each intermediate result + return ggml_type_size(GGML_TYPE_F32)*(ne01*ne02*ne03*nwg*(ne20 + 2)); +} + static id ggml_metal_get_pipeline_flash_attn_ext( ggml_backend_t backend, struct ggml_tensor * op, bool has_mask, @@ -1760,8 +1551,6 @@ static void ggml_metal_free(struct ggml_backend_metal_context * ctx) { [ctx->cmd_bufs[i].obj release]; } - ggml_metal_mem_pool_free(ctx->cmd_bufs[i].mem_pool); - if (ctx->cmd_bufs[i].mem_ranges) { ggml_mem_ranges_free(ctx->cmd_bufs[i].mem_ranges); } @@ -2127,8 +1916,6 @@ struct ggml_metal_encode_context { id encoder; - struct ggml_metal_mem_pool * mem_pool; - struct ggml_mem_ranges * mem_ranges; }; @@ -2165,8 +1952,6 @@ static int ggml_metal_encode_node(struct ggml_metal_encode_context * ctx_enc, in id encoder = ctx_enc->encoder; - struct ggml_metal_mem_pool * mem_pool = ctx_enc->mem_pool; - struct ggml_backend_metal_context * ctx = backend->context; struct ggml_backend_metal_device_context * ctx_dev = backend->device->context; @@ -2207,8 +1992,6 @@ static int ggml_metal_encode_node(struct ggml_metal_encode_context * ctx_enc, in GGML_ABORT("unsupported op"); } - ggml_metal_mem_pool_clear(mem_pool); - const int64_t ne00 = src0 ? src0->ne[0] : 0; const int64_t ne01 = src0 ? src0->ne[1] : 0; const int64_t ne02 = src0 ? src0->ne[2] : 0; @@ -2522,7 +2305,6 @@ static int ggml_metal_encode_node(struct ggml_metal_encode_context * ctx_enc, in /*.nb02 =*/ nb02, /*.nb11 =*/ nb11, /*.nb21 =*/ nb21, - }; [encoder setComputePipelineState:pipeline]; @@ -3167,54 +2949,8 @@ static int ggml_metal_encode_node(struct ggml_metal_encode_context * ctx_enc, in const float m0 = powf(2.0f, -(max_bias ) / n_head_log2); const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2); -// use this branch to test the ggml_metal_mem_pool functionality -#if 0 - // cpy to tmp buffer in MTLHeap - - id h_src0 = h_src0 = ggml_metal_mem_pool_alloc(mem_pool, ggml_nbytes(src0)); - if (!h_src0) { - GGML_LOG_ERROR("%s: failed to allocate buffer from memory pool, size = %zu\n", __func__, ggml_nbytes(src0)); - return 0; - } - - offs_src0 = 0; - - ggml_metal_kargs_cpy args_cpy = { - /*.ne00 =*/ ne00, - /*.ne01 =*/ ne01, - /*.ne02 =*/ ne02, - /*.ne03 =*/ ne03, - /*.nb00 =*/ nb00, - /*.nb01 =*/ nb01, - /*.nb02 =*/ nb02, - /*.nb03 =*/ nb03, - /*.ne0 =*/ ne00, - /*.ne1 =*/ ne01, - /*.ne2 =*/ ne02, - /*.ne3 =*/ ne03, - /*.nb0 =*/ nb00, - /*.nb1 =*/ nb01, - /*.nb2 =*/ nb02, - /*.nb3 =*/ nb03, - }; - - if (src0->type == GGML_TYPE_F16) { - [encoder setComputePipelineState:ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F16_F16].pipeline]; - } else { - [encoder setComputePipelineState:ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_F32].pipeline]; - } - [encoder setBytes:&args_cpy length:sizeof(args_cpy) atIndex:0]; - [encoder setBuffer:id_src0 offset:offs_src0 atIndex:1]; - [encoder setBuffer:h_src0 offset:0 atIndex:2]; - - GGML_ASSERT(ne00 % ggml_blck_size(src0->type) == 0); - int nth_cpy = MIN(1024, ne00 / ggml_blck_size(src0->type)); - - [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth_cpy, 1, 1)]; - -#else id h_src0 = id_src0; -#endif + // softmax ggml_metal_kargs_soft_max args = { @@ -4093,28 +3829,9 @@ static int ggml_metal_encode_node(struct ggml_metal_encode_context * ctx_enc, in default: break; } - // TODO: using mem pool allocations with enabled concurrency is not safe because the mem pool - // reuses buffers. this can result in 2 concurrent MUL_MAT_ID ops using the same mem pool buffer. - // so we add this extra barrier to prevent the race. - // the correct solution is to remove mem pools and then remove this barrier [TAG_MEM_POOL_REMOVE] - ggml_metal_encode_concurrency_reset(ctx_enc); - - // tokens per expert - const size_t s_tpe = ggml_type_size(GGML_TYPE_I32)*ne02; - id h_tpe = ggml_metal_mem_pool_alloc(mem_pool, s_tpe); - if (!h_tpe) { - GGML_LOG_ERROR("%s: failed to allocate buffer from memory pool, size = %zu\n", __func__, s_tpe); - return 0; - } - - // id map - // [n_tokens, n_expert] - const size_t s_ids = ggml_type_size(GGML_TYPE_I32)*ne21*ne02; - id h_ids = ggml_metal_mem_pool_alloc(mem_pool, s_ids); - if (!h_ids) { - GGML_LOG_ERROR("%s: failed to allocate buffer from memory pool, size = %zu\n", __func__, s_ids); - return 0; - } + // extra buffers for intermediate id mapping + size_t offs_tpe = offs_dst + ggml_nbytes(dst); + size_t offs_ids = offs_tpe + ggml_metal_mul_mat_id_extra_tpe(dst); { ggml_metal_kargs_mul_mm_id_map0 args = { @@ -4152,8 +3869,8 @@ static int ggml_metal_encode_node(struct ggml_metal_encode_context * ctx_enc, in [encoder setComputePipelineState:pipeline]; [encoder setBytes:&args length:sizeof(args) atIndex:0]; [encoder setBuffer:id_src2 offset:offs_src2 atIndex:1]; - [encoder setBuffer: h_tpe offset:0 atIndex:2]; - [encoder setBuffer: h_ids offset:0 atIndex:3]; + [encoder setBuffer:id_dst offset:offs_tpe atIndex:2]; + [encoder setBuffer:id_dst offset:offs_ids atIndex:3]; [encoder setThreadgroupMemoryLength:smem atIndex:0]; [encoder dispatchThreadgroups:MTLSizeMake(1, 1, 1) threadsPerThreadgroup:MTLSizeMake(ne02, 1, 1)]; @@ -4215,8 +3932,8 @@ static int ggml_metal_encode_node(struct ggml_metal_encode_context * ctx_enc, in [encoder setBytes:&args length:sizeof(args) atIndex:0]; [encoder setBuffer:id_src0 offset:offs_src0 atIndex:1]; [encoder setBuffer:id_src1 offset:offs_src1 atIndex:2]; - [encoder setBuffer: h_tpe offset:0 atIndex:3]; - [encoder setBuffer: h_ids offset:0 atIndex:4]; + [encoder setBuffer:id_dst offset:offs_tpe atIndex:3]; + [encoder setBuffer:id_dst offset:offs_ids atIndex:4]; [encoder setBuffer:id_dst offset:offs_dst atIndex:5]; [encoder setThreadgroupMemoryLength:8192 atIndex:0]; @@ -5306,8 +5023,7 @@ static int ggml_metal_encode_node(struct ggml_metal_encode_context * ctx_enc, in GGML_ASSERT(ne01 < 65536); - // use non-vec kernel if the batch size is large or if the vec-kernel is not supported for this head size - if (ne01 >= 20 || (ne00 % 32 != 0)) { + if (!ggml_metal_flash_attn_ext_use_vec(dst)) { // half8x8 kernel const int64_t nqptg = 8; // queries per threadgroup !! sync with kernel template arguments !! const int64_t ncpsg = 64; // cache values per simdgroup !! sync with kernel template arguments !! @@ -5532,34 +5248,20 @@ static int ggml_metal_encode_node(struct ggml_metal_encode_context * ctx_enc, in GGML_ASSERT(ne01*ne02*ne03 == ne1*ne2*ne3); GGML_ASSERT(ne1*ne2*ne3 <= (1u << 31)); - // using mem pool allocations with enabled concurrency is not safe [TAG_MEM_POOL_REMOVE] - // still, we assume that concurrent FA won't happen before we do the refactor - //ggml_metal_encode_concurrency_reset(ctx_enc); - - const int32_t nrows = ne1*ne2*ne3; - - // temp buffer for writing the results from each workgroup - // - ne20: the size of the head vector - // - + 2: the S and M values for each intermediate result - const size_t s_tmp = ggml_type_size(GGML_TYPE_F32)*(nrows*nwg*(ne20 + 2)); - id h_tmp = ggml_metal_mem_pool_alloc(mem_pool, s_tmp); - if (!h_tmp) { - GGML_LOG_ERROR("%s: failed to allocate buffer from memory pool, size = %zu\n", __func__, s_tmp); - return 0; - } - - //printf("ne01 = %d, ne02 = %d, ne03 = %d, ne20 = %d\n", ne01, ne02, ne03, ne20); - //printf("needed memory: %.3f MiB\n", (float) (ne01*ne02*ne03*ne20*sizeof(float))/1024.0f/1024.0f); - - [encoder setBuffer:h_tmp offset:0 atIndex:6]; + // write the results from each workgroup into a temp buffer + const size_t offs_tmp = offs_dst + ggml_nbytes(dst); + [encoder setBuffer:id_dst offset:offs_tmp atIndex:6]; [encoder setThreadgroupMemoryLength:smem atIndex:0]; [encoder dispatchThreadgroups:MTLSizeMake((ne01 + nqptg - 1)/nqptg, ne02, ne03*nwg) threadsPerThreadgroup:MTLSizeMake(32, nsg, 1)]; + // sync the 2 kernels ggml_metal_encode_concurrency_reset(ctx_enc); // reduce the results from the workgroups { + const int32_t nrows = ne1*ne2*ne3; + ggml_metal_kargs_flash_attn_ext_vec_reduce args0 = { nrows, }; @@ -5568,7 +5270,7 @@ static int ggml_metal_encode_node(struct ggml_metal_encode_context * ctx_enc, in [encoder setComputePipelineState:pipeline0]; [encoder setBytes:&args0 length:sizeof(args0) atIndex:0]; - [encoder setBuffer:h_tmp offset:0 atIndex:1]; + [encoder setBuffer:id_dst offset:offs_tmp atIndex:1]; [encoder setBuffer:id_dst offset:offs_dst atIndex:2]; //printf("ne1 = %d, ne2 = %d, ne3 = %d, ne20 = %d\n", ne1, ne2, ne3, ne20); @@ -5895,12 +5597,7 @@ static enum ggml_status ggml_metal_graph_compute( // the main thread commits the first few commands immediately // cmd_buf[n_cb] { - // cannot use commandBufferWithUnretainedReferences because the buffers from the memory pool can get destroyed - // TODO: when the memory pools are removed, we can again use commandBufferWithUnretainedReferences - // https://github.com/ggml-org/llama.cpp/pull/15832#discussion_r2334215009 - // [TAG_MEM_POOL_REMOVE] - //id cmd_buf = [ctx->queue commandBufferWithUnretainedReferences]; - id cmd_buf = [ctx->queue commandBuffer]; + id cmd_buf = [ctx->queue commandBufferWithUnretainedReferences]; [cmd_buf retain]; if (ctx->cmd_bufs[n_cb].obj) { @@ -5919,8 +5616,7 @@ static enum ggml_status ggml_metal_graph_compute( // prepare the rest of the command buffers asynchronously (optional) // cmd_buf[0.. n_cb) for (int cb_idx = 0; cb_idx < n_cb; ++cb_idx) { - //id cmd_buf = [ctx->queue commandBufferWithUnretainedReferences]; - id cmd_buf = [ctx->queue commandBuffer]; + id cmd_buf = [ctx->queue commandBufferWithUnretainedReferences]; [cmd_buf retain]; if (ctx->cmd_bufs[cb_idx].obj) { @@ -6377,6 +6073,31 @@ static ggml_backend_buffer_t ggml_backend_metal_buffer_type_alloc_buffer(ggml_ba return ggml_backend_buffer_init(buft, buf_i, ctx, size); } +static size_t ggml_backend_metal_buffer_type_get_alloc_size(ggml_backend_buffer_type_t buft, const struct ggml_tensor * tensor) { + size_t res = ggml_nbytes(tensor); + + // some operations require additional memory for fleeting data: + switch (tensor->op) { + case GGML_OP_MUL_MAT_ID: + { + res += ggml_metal_mul_mat_id_extra_tpe(tensor); + res += ggml_metal_mul_mat_id_extra_ids(tensor); + } break; + case GGML_OP_FLASH_ATTN_EXT: + { + if (ggml_metal_flash_attn_ext_use_vec(tensor)) { + res += ggml_metal_flash_attn_ext_extra_tmp(tensor); + } + } break; + default: + break; + } + + return res; + + GGML_UNUSED(buft); +} + // default (shared) buffer type static const char * ggml_backend_metal_buffer_type_shared_get_name(ggml_backend_buffer_type_t buft) { @@ -6401,6 +6122,10 @@ static size_t ggml_backend_metal_buffer_type_shared_get_max_size(ggml_backend_bu return max_size; } +static size_t ggml_backend_metal_buffer_type_shared_get_alloc_size(ggml_backend_buffer_type_t buft, const struct ggml_tensor * tensor) { + return ggml_backend_metal_buffer_type_get_alloc_size(buft, tensor); +} + static bool ggml_backend_metal_buffer_type_shared_is_host(ggml_backend_buffer_type_t buft) { return false; @@ -6414,7 +6139,7 @@ static ggml_backend_buffer_type_t ggml_backend_metal_buffer_type_shared(void) { /* .alloc_buffer = */ ggml_backend_metal_buffer_type_shared_alloc_buffer, /* .get_alignment = */ ggml_backend_metal_buffer_type_shared_get_alignment, /* .get_max_size = */ ggml_backend_metal_buffer_type_shared_get_max_size, - /* .get_alloc_size = */ NULL, // defaults to ggml_nbytes + /* .get_alloc_size = */ ggml_backend_metal_buffer_type_shared_get_alloc_size, /* .is_host = */ ggml_backend_metal_buffer_type_shared_is_host, }, /* .device = */ &g_ggml_backend_metal_device, @@ -6448,6 +6173,10 @@ static size_t ggml_backend_metal_buffer_type_private_get_max_size(ggml_backend_b return max_size; } +static size_t ggml_backend_metal_buffer_type_private_get_alloc_size(ggml_backend_buffer_type_t buft, const struct ggml_tensor * tensor) { + return ggml_backend_metal_buffer_type_get_alloc_size(buft, tensor); +} + static bool ggml_backend_metal_buffer_type_private_is_host(ggml_backend_buffer_type_t buft) { return false; @@ -6461,7 +6190,7 @@ static ggml_backend_buffer_type_t ggml_backend_metal_buffer_type_private(void) { /* .alloc_buffer = */ ggml_backend_metal_buffer_type_private_alloc_buffer, /* .get_alignment = */ ggml_backend_metal_buffer_type_private_get_alignment, /* .get_max_size = */ ggml_backend_metal_buffer_type_private_get_max_size, - /* .get_alloc_size = */ NULL, // defaults to ggml_nbytes + /* .get_alloc_size = */ ggml_backend_metal_buffer_type_private_get_alloc_size, /* .is_host = */ ggml_backend_metal_buffer_type_private_is_host, }, /* .device = */ &g_ggml_backend_metal_device, @@ -6496,6 +6225,10 @@ static size_t ggml_backend_metal_buffer_type_mapped_get_max_size(ggml_backend_bu return max_size; } +static size_t ggml_backend_metal_buffer_type_mapped_get_alloc_size(ggml_backend_buffer_type_t buft, const struct ggml_tensor * tensor) { + return ggml_backend_metal_buffer_type_get_alloc_size(buft, tensor); +} + static bool ggml_backend_metal_buffer_type_mapped_is_host(ggml_backend_buffer_type_t buft) { return false; @@ -6511,7 +6244,7 @@ static ggml_backend_buffer_type_t ggml_backend_metal_buffer_type_mapped(void) { /* .alloc_buffer = */ ggml_backend_metal_buffer_type_mapped_alloc_buffer, /* .get_alignment = */ ggml_backend_metal_buffer_type_mapped_get_alignment, /* .get_max_size = */ ggml_backend_metal_buffer_type_mapped_get_max_size, - /* .get_alloc_size = */ NULL, // defaults to ggml_nbytes + /* .get_alloc_size = */ ggml_backend_metal_buffer_type_mapped_get_alloc_size, /* .is_host = */ ggml_backend_metal_buffer_type_mapped_is_host, }, /* .device = */ &g_ggml_backend_metal_device, @@ -6711,11 +6444,8 @@ static void ggml_backend_metal_set_n_cb(ggml_backend_t backend, int n_cb) { const int n_nodes_per_cb = ctx->n_nodes_per_cb; id cmd_buf = ctx->cmd_bufs[cb_idx].obj; - struct ggml_metal_mem_pool * mem_pool = ctx->cmd_bufs[cb_idx].mem_pool; struct ggml_mem_ranges * mem_ranges = ctx->cmd_bufs[cb_idx].mem_ranges; - ggml_metal_mem_pool_reset(mem_pool); - if (mem_ranges) { ggml_mem_ranges_reset(mem_ranges); } @@ -6743,7 +6473,6 @@ static void ggml_backend_metal_set_n_cb(ggml_backend_t backend, int n_cb) { struct ggml_metal_encode_context ctx_enc = { /*.backend =*/ backend, /*.encoder =*/ encoder, - /*.mem_pool =*/ mem_pool, /*.mem_ranges =*/ mem_ranges, }; diff --git a/ggml/src/ggml-sycl/binbcast.cpp b/ggml/src/ggml-sycl/binbcast.cpp index 0a3883ae1e..e0a1de0f32 100644 --- a/ggml/src/ggml-sycl/binbcast.cpp +++ b/ggml/src/ggml-sycl/binbcast.cpp @@ -303,6 +303,10 @@ inline void ggml_sycl_op_sub(ggml_backend_sycl_context & ctx, ggml_tensor *dst) ggml_sycl_op_bin_bcast>(ctx, dst->src[0], dst->src[1], dst); } +inline void ggml_sycl_op_count_equal(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { + ggml_sycl_op_bin_bcast>(ctx, dst->src[0], dst->src[1], dst); +} + inline void ggml_sycl_op_mul(ggml_backend_sycl_context & ctx, ggml_tensor *dst) { ggml_sycl_op_bin_bcast>(ctx, dst->src[0], dst->src[1], dst); @@ -328,6 +332,11 @@ void ggml_sycl_sub(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { ggml_sycl_op_sub(ctx, dst); } +void ggml_sycl_count_equal(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { + scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/2); + ggml_sycl_op_count_equal(ctx, dst); +} + void ggml_sycl_mul(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/2); ggml_sycl_op_mul(ctx, dst); diff --git a/ggml/src/ggml-sycl/binbcast.hpp b/ggml/src/ggml-sycl/binbcast.hpp index 9cce0f053a..34c4064f52 100644 --- a/ggml/src/ggml-sycl/binbcast.hpp +++ b/ggml/src/ggml-sycl/binbcast.hpp @@ -16,6 +16,12 @@ static __dpct_inline__ float op_sub(const float a, const float b) { return a - b; } +static __dpct_inline__ float op_count_equal(const float a, const float b) { + return (a == b) ? 1.0f : 0.0f; +} + +void ggml_sycl_count_equal(ggml_backend_sycl_context & ctx, ggml_tensor * dst); + static __dpct_inline__ float op_mul(const float a, const float b) { return a * b; } diff --git a/ggml/src/ggml-sycl/ggml-sycl.cpp b/ggml/src/ggml-sycl/ggml-sycl.cpp index e06ec613fc..9404e3ff4a 100644 --- a/ggml/src/ggml-sycl/ggml-sycl.cpp +++ b/ggml/src/ggml-sycl/ggml-sycl.cpp @@ -3577,6 +3577,9 @@ static bool ggml_sycl_compute_forward(ggml_backend_sycl_context & ctx, struct gg case GGML_OP_SUB: ggml_sycl_sub(ctx, dst); break; + case GGML_OP_COUNT_EQUAL: + ggml_sycl_count_equal(ctx, dst); + break; case GGML_OP_ACC: ggml_sycl_acc(ctx, dst); break; @@ -4356,6 +4359,7 @@ static bool ggml_backend_sycl_device_supports_op(ggml_backend_dev_t dev, const g case GGML_OP_ADD: case GGML_OP_ADD1: case GGML_OP_SUB: + case GGML_OP_COUNT_EQUAL: case GGML_OP_MUL: case GGML_OP_DIV: case GGML_OP_REPEAT: diff --git a/gguf-py/gguf/constants.py b/gguf-py/gguf/constants.py index d60e3ec238..7c02f611ed 100644 --- a/gguf-py/gguf/constants.py +++ b/gguf-py/gguf/constants.py @@ -111,6 +111,7 @@ class Keys: DECODER_START_TOKEN_ID = "{arch}.decoder_start_token_id" DECODER_BLOCK_COUNT = "{arch}.decoder_block_count" ATTN_LOGIT_SOFTCAPPING = "{arch}.attn_logit_softcapping" + ROUTER_LOGIT_SOFTCAPPING = "{arch}.router_logit_softcapping" FINAL_LOGIT_SOFTCAPPING = "{arch}.final_logit_softcapping" SWIN_NORM = "{arch}.swin_norm" RESCALE_EVERY_N_LAYERS = "{arch}.rescale_every_n_layers" @@ -146,22 +147,28 @@ class Keys: REL_BUCKETS_COUNT = "{arch}.attention.relative_buckets_count" SLIDING_WINDOW = "{arch}.attention.sliding_window" SCALE = "{arch}.attention.scale" + OUTPUT_SCALE = "{arch}.attention.output_scale" + TEMPERATURE_LENGTH = "{arch}.attention.temperature_length" KEY_LENGTH_MLA = "{arch}.attention.key_length_mla" VALUE_LENGTH_MLA = "{arch}.attention.value_length_mla" SHARED_KV_LAYERS = "{arch}.attention.shared_kv_layers" SLIDING_WINDOW_PATTERN = "{arch}.attention.sliding_window_pattern" class Rope: - DIMENSION_COUNT = "{arch}.rope.dimension_count" - DIMENSION_SECTIONS = "{arch}.rope.dimension_sections" - FREQ_BASE = "{arch}.rope.freq_base" - FREQ_BASE_SWA = "{arch}.rope.freq_base_swa" - SCALING_TYPE = "{arch}.rope.scaling.type" - SCALING_FACTOR = "{arch}.rope.scaling.factor" - SCALING_ATTN_FACTOR = "{arch}.rope.scaling.attn_factor" - SCALING_ORIG_CTX_LEN = "{arch}.rope.scaling.original_context_length" - SCALING_FINETUNED = "{arch}.rope.scaling.finetuned" - SCALING_YARN_LOG_MUL = "{arch}.rope.scaling.yarn_log_multiplier" + DIMENSION_COUNT = "{arch}.rope.dimension_count" + DIMENSION_SECTIONS = "{arch}.rope.dimension_sections" + FREQ_BASE = "{arch}.rope.freq_base" + FREQ_BASE_SWA = "{arch}.rope.freq_base_swa" + SCALING_TYPE = "{arch}.rope.scaling.type" + SCALING_FACTOR = "{arch}.rope.scaling.factor" + SCALING_ATTN_FACTOR = "{arch}.rope.scaling.attn_factor" + SCALING_ORIG_CTX_LEN = "{arch}.rope.scaling.original_context_length" + SCALING_FINETUNED = "{arch}.rope.scaling.finetuned" + SCALING_YARN_LOG_MUL = "{arch}.rope.scaling.yarn_log_multiplier" + SCALING_YARN_EXT_FACTOR = "{arch}.rope.scaling.yarn_ext_factor" + SCALING_YARN_ATTN_FACTOR = "{arch}.rope.scaling.yarn_attn_factor" + SCALING_YARN_BETA_FAST = "{arch}.rope.scaling.yarn_beta_fast" + SCALING_YARN_BETA_SLOW = "{arch}.rope.scaling.yarn_beta_slow" class Split: LLM_KV_SPLIT_NO = "split.no" @@ -1117,6 +1124,7 @@ MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = { MODEL_TENSOR.FFN_GATE_EXP, MODEL_TENSOR.FFN_DOWN_EXP, MODEL_TENSOR.FFN_UP_EXP, + MODEL_TENSOR.FFN_POST_NORM, MODEL_TENSOR.LAYER_OUT_NORM, ], MODEL_ARCH.GPTNEOX: [ diff --git a/gguf-py/gguf/gguf_writer.py b/gguf-py/gguf/gguf_writer.py index aa4f1461fe..12b56fc2ec 100644 --- a/gguf-py/gguf/gguf_writer.py +++ b/gguf-py/gguf/gguf_writer.py @@ -733,6 +733,9 @@ class GGUFWriter: def add_attn_logit_softcapping(self, value: float) -> None: self.add_float32(Keys.LLM.ATTN_LOGIT_SOFTCAPPING.format(arch=self.arch), value) + def add_router_logit_softcapping(self, value: float) -> None: + self.add_float32(Keys.LLM.ROUTER_LOGIT_SOFTCAPPING.format(arch=self.arch), value) + def add_final_logit_softcapping(self, value: float) -> None: self.add_float32(Keys.LLM.FINAL_LOGIT_SOFTCAPPING.format(arch=self.arch), value) @@ -832,6 +835,12 @@ class GGUFWriter: def add_attention_scale(self, value: float) -> None: self.add_float32(Keys.Attention.SCALE.format(arch=self.arch), value) + def add_attn_output_scale(self, value: float) -> None: + self.add_float32(Keys.Attention.OUTPUT_SCALE.format(arch=self.arch), value) + + def add_attn_temperature_length(self, value: int) -> None: + self.add_uint32(Keys.Attention.TEMPERATURE_LENGTH.format(arch=self.arch), value) + def add_pooling_type(self, value: PoolingType) -> None: self.add_uint32(Keys.LLM.POOLING_TYPE.format(arch=self.arch), value.value) @@ -862,6 +871,18 @@ class GGUFWriter: def add_rope_scaling_yarn_log_mul(self, value: float) -> None: self.add_float32(Keys.Rope.SCALING_YARN_LOG_MUL.format(arch=self.arch), value) + def add_rope_scaling_yarn_ext_factor(self, value: float) -> None: + self.add_float32(Keys.Rope.SCALING_YARN_EXT_FACTOR.format(arch=self.arch), value) + + def add_rope_scaling_yarn_attn_factor(self, value: float) -> None: + self.add_float32(Keys.Rope.SCALING_YARN_ATTN_FACTOR.format(arch=self.arch), value) + + def add_rope_scaling_yarn_beta_fast(self, value: float) -> None: + self.add_float32(Keys.Rope.SCALING_YARN_BETA_FAST.format(arch=self.arch), value) + + def add_rope_scaling_yarn_beta_slow(self, value: float) -> None: + self.add_float32(Keys.Rope.SCALING_YARN_BETA_SLOW.format(arch=self.arch), value) + def add_ssm_conv_kernel(self, value: int) -> None: self.add_uint32(Keys.SSM.CONV_KERNEL.format(arch=self.arch), value) diff --git a/gguf-py/gguf/tensor_mapping.py b/gguf-py/gguf/tensor_mapping.py index 16c725b41e..1196e74f97 100644 --- a/gguf-py/gguf/tensor_mapping.py +++ b/gguf-py/gguf/tensor_mapping.py @@ -139,6 +139,7 @@ class TensorNameMap: "model.layers.{bid}.norm", # mamba-qbert "backbone.layers.{bid}.norm", # mamba "transformer.decoder_layer.{bid}.rms_norm", # Grok + "model.layers.{bid}.pre_attn_norm", # grok-2 "transformer.blocks.{bid}.norm_attn_norm.norm_1", # dbrx "encoder.layers.{bid}.input_layernorm", # chatglm "transformer.layers.{bid}.attn_norm", # openelm @@ -284,6 +285,7 @@ class TensorNameMap: "transformer.layer.{bid}.sa_layer_norm", # distillbert "encoder.layers.{bid}.norm1", # nomic-bert "transformer.decoder_layer.{bid}.rms_norm_1", # Grok + "model.layers.{bid}.post_attn_norm", # grok-2 "transformer.blocks.{bid}.norm_attn_norm.norm_2", # dbrx ), @@ -319,6 +321,7 @@ class TensorNameMap: "h.{bid}.ln_2", # gpt2 "model.layers.{bid}.ffn_norm", # internlm2 "transformer.decoder_layer.{bid}.rms_norm_2", # Grok + "model.layers.{bid}.pre_moe_norm", # grok-2 "encoder.layers.{bid}.post_attention_layernorm", # chatglm "transformer.layers.{bid}.ffn_norm", # openelm "model.layers.{bid}.pre_ff_layernorm", # jamba granite-hybrid @@ -340,11 +343,12 @@ class TensorNameMap: # Post feed-forward norm MODEL_TENSOR.FFN_POST_NORM: ( - "model.layers.{bid}.post_feedforward_layernorm", # gemma2 olmo2 - "layers.{bid}.post_feedforward_layernorm", # embeddinggemma - "model.layers.{bid}.post_mlp_layernorm", # glm-4-0414 + "model.layers.{bid}.post_feedforward_layernorm", # gemma2 olmo2 + "layers.{bid}.post_feedforward_layernorm", # embeddinggemma + "model.layers.{bid}.post_mlp_layernorm", # glm-4-0414 "model.layers.layers.{bid}.post_mlp_norm.weight", # plamo2 "model.layers.{bid}.feed_forward.up_proj", + "model.layers.{bid}.post_moe_norm", # grok-2 ), MODEL_TENSOR.FFN_GATE_INP: ( diff --git a/src/llama-arch.cpp b/src/llama-arch.cpp index e0c07942bd..3e02a1be62 100644 --- a/src/llama-arch.cpp +++ b/src/llama-arch.cpp @@ -140,6 +140,7 @@ static const std::map LLM_KV_NAMES = { { LLM_KV_DECODER_START_TOKEN_ID, "%s.decoder_start_token_id" }, { LLM_KV_DECODER_BLOCK_COUNT, "%s.decoder_block_count" }, { LLM_KV_ATTN_LOGIT_SOFTCAPPING, "%s.attn_logit_softcapping" }, + { LLM_KV_ROUTER_LOGIT_SOFTCAPPING, "%s.router_logit_softcapping" }, { LLM_KV_FINAL_LOGIT_SOFTCAPPING, "%s.final_logit_softcapping" }, { LLM_KV_SWIN_NORM, "%s.swin_norm" }, { LLM_KV_RESCALE_EVERY_N_LAYERS, "%s.rescale_every_n_layers" }, @@ -170,20 +171,27 @@ static const std::map LLM_KV_NAMES = { { LLM_KV_ATTENTION_RELATIVE_BUCKETS_COUNT, "%s.attention.relative_buckets_count" }, { LLM_KV_ATTENTION_SLIDING_WINDOW, "%s.attention.sliding_window" }, { LLM_KV_ATTENTION_SCALE, "%s.attention.scale" }, + { LLM_KV_ATTENTION_OUTPUT_SCALE, "%s.attention.output_scale" }, + { LLM_KV_ATTENTION_TEMPERATURE_LENGTH, "%s.attention.temperature_length" }, { LLM_KV_ATTENTION_KEY_LENGTH_MLA, "%s.attention.key_length_mla" }, { LLM_KV_ATTENTION_VALUE_LENGTH_MLA, "%s.attention.value_length_mla" }, - { LLM_KV_ROPE_DIMENSION_COUNT, "%s.rope.dimension_count" }, - { LLM_KV_ROPE_DIMENSION_SECTIONS, "%s.rope.dimension_sections" }, - { LLM_KV_ROPE_FREQ_BASE, "%s.rope.freq_base" }, - { LLM_KV_ROPE_FREQ_BASE_SWA, "%s.rope.freq_base_swa" }, - { LLM_KV_ROPE_SCALE_LINEAR, "%s.rope.scale_linear" }, - { LLM_KV_ROPE_SCALING_TYPE, "%s.rope.scaling.type" }, - { LLM_KV_ROPE_SCALING_FACTOR, "%s.rope.scaling.factor" }, - { LLM_KV_ROPE_SCALING_ATTN_FACTOR, "%s.rope.scaling.attn_factor" }, - { LLM_KV_ROPE_SCALING_ORIG_CTX_LEN, "%s.rope.scaling.original_context_length" }, - { LLM_KV_ROPE_SCALING_FINETUNED, "%s.rope.scaling.finetuned" }, - { LLM_KV_ROPE_SCALING_YARN_LOG_MUL, "%s.rope.scaling.yarn_log_multiplier" }, + + { LLM_KV_ROPE_DIMENSION_COUNT, "%s.rope.dimension_count" }, + { LLM_KV_ROPE_DIMENSION_SECTIONS, "%s.rope.dimension_sections" }, + { LLM_KV_ROPE_FREQ_BASE, "%s.rope.freq_base" }, + { LLM_KV_ROPE_FREQ_BASE_SWA, "%s.rope.freq_base_swa" }, + { LLM_KV_ROPE_SCALE_LINEAR, "%s.rope.scale_linear" }, + { LLM_KV_ROPE_SCALING_TYPE, "%s.rope.scaling.type" }, + { LLM_KV_ROPE_SCALING_FACTOR, "%s.rope.scaling.factor" }, + { LLM_KV_ROPE_SCALING_ATTN_FACTOR, "%s.rope.scaling.attn_factor" }, + { LLM_KV_ROPE_SCALING_ORIG_CTX_LEN, "%s.rope.scaling.original_context_length" }, + { LLM_KV_ROPE_SCALING_FINETUNED, "%s.rope.scaling.finetuned" }, + { LLM_KV_ROPE_SCALING_YARN_LOG_MUL, "%s.rope.scaling.yarn_log_multiplier" }, + { LLM_KV_ROPE_SCALING_YARN_EXT_FACTOR, "%s.rope.scaling.yarn_ext_factor" }, + { LLM_KV_ROPE_SCALING_YARN_ATTN_FACTOR, "%s.rope.scaling.yarn_attn_factor" }, + { LLM_KV_ROPE_SCALING_YARN_BETA_FAST, "%s.rope.scaling.yarn_beta_fast" }, + { LLM_KV_ROPE_SCALING_YARN_BETA_SLOW, "%s.rope.scaling.yarn_beta_slow" }, { LLM_KV_SPLIT_NO, "split.no" }, { LLM_KV_SPLIT_COUNT, "split.count" }, @@ -400,12 +408,16 @@ static const std::map> LLM_TENSOR_N { LLM_TENSOR_ATTN_ROT_EMBD, "blk.%d.attn_rot_embd" }, { LLM_TENSOR_FFN_GATE_INP, "blk.%d.ffn_gate_inp" }, { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" }, + { LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" }, + { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" }, + { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" }, { LLM_TENSOR_FFN_GATE_EXP, "blk.%d.ffn_gate.%d" }, { LLM_TENSOR_FFN_DOWN_EXP, "blk.%d.ffn_down.%d" }, { LLM_TENSOR_FFN_UP_EXP, "blk.%d.ffn_up.%d" }, { LLM_TENSOR_FFN_GATE_EXPS, "blk.%d.ffn_gate_exps" }, { LLM_TENSOR_FFN_DOWN_EXPS, "blk.%d.ffn_down_exps" }, { LLM_TENSOR_FFN_UP_EXPS, "blk.%d.ffn_up_exps" }, + { LLM_TENSOR_FFN_POST_NORM, "blk.%d.post_ffw_norm" }, { LLM_TENSOR_LAYER_OUT_NORM, "blk.%d.layer_output_norm" }, { LLM_TENSOR_ATTN_OUT_NORM, "blk.%d.attn_output_norm" }, }, diff --git a/src/llama-arch.h b/src/llama-arch.h index 16bb486139..fe80ff5517 100644 --- a/src/llama-arch.h +++ b/src/llama-arch.h @@ -144,6 +144,7 @@ enum llm_kv { LLM_KV_DECODER_START_TOKEN_ID, LLM_KV_DECODER_BLOCK_COUNT, LLM_KV_ATTN_LOGIT_SOFTCAPPING, + LLM_KV_ROUTER_LOGIT_SOFTCAPPING, LLM_KV_FINAL_LOGIT_SOFTCAPPING, LLM_KV_SWIN_NORM, LLM_KV_RESCALE_EVERY_N_LAYERS, @@ -174,6 +175,8 @@ enum llm_kv { LLM_KV_ATTENTION_RELATIVE_BUCKETS_COUNT, LLM_KV_ATTENTION_SLIDING_WINDOW, LLM_KV_ATTENTION_SCALE, + LLM_KV_ATTENTION_OUTPUT_SCALE, + LLM_KV_ATTENTION_TEMPERATURE_LENGTH, LLM_KV_ATTENTION_KEY_LENGTH_MLA, LLM_KV_ATTENTION_VALUE_LENGTH_MLA, @@ -188,6 +191,10 @@ enum llm_kv { LLM_KV_ROPE_SCALING_ORIG_CTX_LEN, LLM_KV_ROPE_SCALING_FINETUNED, LLM_KV_ROPE_SCALING_YARN_LOG_MUL, + LLM_KV_ROPE_SCALING_YARN_EXT_FACTOR, + LLM_KV_ROPE_SCALING_YARN_ATTN_FACTOR, + LLM_KV_ROPE_SCALING_YARN_BETA_FAST, + LLM_KV_ROPE_SCALING_YARN_BETA_SLOW, LLM_KV_SPLIT_NO, LLM_KV_SPLIT_COUNT, diff --git a/src/llama-chat.cpp b/src/llama-chat.cpp index 9d8e57eac1..66e6c6a38f 100644 --- a/src/llama-chat.cpp +++ b/src/llama-chat.cpp @@ -70,6 +70,7 @@ static const std::map LLM_CHAT_TEMPLATES = { { "hunyuan-dense", LLM_CHAT_TEMPLATE_HUNYUAN_DENSE }, { "kimi-k2", LLM_CHAT_TEMPLATE_KIMI_K2 }, { "seed_oss", LLM_CHAT_TEMPLATE_SEED_OSS }, + { "grok-2", LLM_CHAT_TEMPLATE_GROK_2 }, }; llm_chat_template llm_chat_template_from_str(const std::string & name) { @@ -204,6 +205,8 @@ llm_chat_template llm_chat_detect_template(const std::string & tmpl) { return LLM_CHAT_TEMPLATE_KIMI_K2; } else if (tmpl_contains("")) { return LLM_CHAT_TEMPLATE_SEED_OSS; + } else if (tmpl_contains("'Assistant: ' + message['content'] + '<|separator|>")) { + return LLM_CHAT_TEMPLATE_GROK_2; } return LLM_CHAT_TEMPLATE_UNKNOWN; } @@ -763,6 +766,20 @@ int32_t llm_chat_apply_template( if (add_ass) { ss << "assistant\n"; } + } else if (tmpl == LLM_CHAT_TEMPLATE_GROK_2) { + for (auto message : chat) { + std::string role(message->role); + if (role == "system") { + ss << "System: " << trim(message->content) << "<|separator|>\n\n"; + } else if (role == "user") { + ss << "Human: " << trim(message->content) << "<|separator|>\n\n"; + } else if (role == "assistant") { + ss << "Assistant: " << message->content << "<|separator|>\n\n"; + } + } + if (add_ass) { + ss << "Assistant:"; + } } else { // template not supported return -1; diff --git a/src/llama-chat.h b/src/llama-chat.h index 21d53ed08b..5a87d9ab62 100644 --- a/src/llama-chat.h +++ b/src/llama-chat.h @@ -50,6 +50,7 @@ enum llm_chat_template { LLM_CHAT_TEMPLATE_HUNYUAN_DENSE, LLM_CHAT_TEMPLATE_KIMI_K2, LLM_CHAT_TEMPLATE_SEED_OSS, + LLM_CHAT_TEMPLATE_GROK_2, LLM_CHAT_TEMPLATE_UNKNOWN, }; diff --git a/src/llama-context.cpp b/src/llama-context.cpp index 289a32b6d3..e6f76421cf 100644 --- a/src/llama-context.cpp +++ b/src/llama-context.cpp @@ -35,10 +35,10 @@ llama_context::llama_context( cparams.n_threads = params.n_threads; cparams.n_threads_batch = params.n_threads_batch; - cparams.yarn_ext_factor = params.yarn_ext_factor; - cparams.yarn_attn_factor = params.yarn_attn_factor; - cparams.yarn_beta_fast = params.yarn_beta_fast; - cparams.yarn_beta_slow = params.yarn_beta_slow; + cparams.yarn_ext_factor = params.yarn_ext_factor >= 0.0f ? params.yarn_ext_factor : hparams.yarn_ext_factor; + cparams.yarn_attn_factor = params.yarn_attn_factor >= 0.0f ? params.yarn_attn_factor : hparams.yarn_attn_factor; + cparams.yarn_beta_fast = params.yarn_beta_fast >= 0.0f ? params.yarn_beta_fast : hparams.yarn_beta_fast; + cparams.yarn_beta_slow = params.yarn_beta_slow >= 0.0f ? params.yarn_beta_slow : hparams.yarn_beta_slow; cparams.embeddings = params.embeddings; cparams.offload_kqv = params.offload_kqv; cparams.no_perf = params.no_perf; @@ -2263,9 +2263,9 @@ llama_context_params llama_context_default_params() { /*.rope_freq_base =*/ 0.0f, /*.rope_freq_scale =*/ 0.0f, /*.yarn_ext_factor =*/ -1.0f, - /*.yarn_attn_factor =*/ 1.0f, - /*.yarn_beta_fast =*/ 32.0f, - /*.yarn_beta_slow =*/ 1.0f, + /*.yarn_attn_factor =*/ -1.0f, + /*.yarn_beta_fast =*/ -1.0f, + /*.yarn_beta_slow =*/ -1.0f, /*.yarn_orig_ctx =*/ 0, /*.defrag_thold =*/ -1.0f, /*.cb_eval =*/ nullptr, diff --git a/src/llama-graph.cpp b/src/llama-graph.cpp index 56f8fee784..04b76f92cf 100644 --- a/src/llama-graph.cpp +++ b/src/llama-graph.cpp @@ -1336,14 +1336,14 @@ ggml_tensor * llm_graph_context::build_attn_mha( if (arch == LLM_ARCH_GROK) { // need to do the following: - // multiply by attn_output_multiplyer of 0.08838834764831845 + // multiply by attn_output_multiplier // and then : // kq = 30 * tanh(kq / 30) // before the softmax below - kq = ggml_tanh(ctx0, ggml_scale(ctx0, kq, 0.08838834764831845f/30.0f)); + kq = ggml_tanh(ctx0, ggml_scale(ctx0, kq, hparams.f_attn_out_scale / hparams.f_attn_logit_softcapping)); cb(kq, "kq_tanh", il); - kq = ggml_scale(ctx0, kq, 30); + kq = ggml_scale(ctx0, kq, hparams.f_attn_logit_softcapping); cb(kq, "kq_scaled", il); } diff --git a/src/llama-hparams.h b/src/llama-hparams.h index 06f8931b2c..81325a9ccf 100644 --- a/src/llama-hparams.h +++ b/src/llama-hparams.h @@ -83,8 +83,9 @@ struct llama_hparams { float f_norm_rms_eps; float f_norm_group_eps; - float f_attn_logit_softcapping = 50.0f; - float f_final_logit_softcapping = 30.0f; + float f_attn_logit_softcapping = 50.0f; + float f_router_logit_softcapping = 30.0f; + float f_final_logit_softcapping = 30.0f; // for RWKV uint32_t rescale_every_n_layers = 0; @@ -105,6 +106,11 @@ struct llama_hparams { uint32_t n_ctx_orig_yarn; float rope_yarn_log_mul = 0.0f; + float yarn_ext_factor = -1.0f; + float yarn_attn_factor = 1.0f; + float yarn_beta_fast = 32.0f; + float yarn_beta_slow = 1.0f; + std::array rope_sections; // Sliding Window Attention (SWA) @@ -137,6 +143,10 @@ struct llama_hparams { float f_embedding_scale = 0.0f; float f_attention_scale = 0.0f; + // grok-2 + float f_attn_out_scale = 0.0f; + uint32_t attn_temp_length = 0; + bool causal_attn = true; bool use_alibi = false; bool attn_soft_cap = false; diff --git a/src/llama-model.cpp b/src/llama-model.cpp index 194462fb08..85458d3136 100644 --- a/src/llama-model.cpp +++ b/src/llama-model.cpp @@ -685,7 +685,30 @@ void llama_model::load_hparams(llama_model_loader & ml) { } break; case LLM_ARCH_GROK: { - ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + // defaults for old GGUFs + hparams.yarn_beta_fast = 8.0f; + hparams.f_logit_scale = 0.5773502691896257f; + hparams.f_embedding_scale = 78.38367176906169f; + hparams.f_attn_out_scale = 0.08838834764831845f; + hparams.f_attn_logit_softcapping = 30.0f; + hparams.f_router_logit_softcapping = 30.0f; + // no final_logit_softcapping in grok-1 + hparams.f_final_logit_softcapping = 0.0f; + + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + ml.get_key(LLM_KV_EXPERT_FEED_FORWARD_LENGTH, hparams.n_ff_exp, false); + ml.get_key(LLM_KV_LOGIT_SCALE, hparams.f_logit_scale, false); + ml.get_key(LLM_KV_EMBEDDING_SCALE, hparams.f_embedding_scale, false); + ml.get_key(LLM_KV_ATTENTION_OUTPUT_SCALE, hparams.f_attn_out_scale, false); + ml.get_key(LLM_KV_ATTN_LOGIT_SOFTCAPPING, hparams.f_attn_logit_softcapping, false); + ml.get_key(LLM_KV_ROUTER_LOGIT_SOFTCAPPING, hparams.f_router_logit_softcapping, false); + ml.get_key(LLM_KV_FINAL_LOGIT_SOFTCAPPING, hparams.f_final_logit_softcapping, false); + + ml.get_key(LLM_KV_ATTENTION_TEMPERATURE_LENGTH, hparams.attn_temp_length, false); + ml.get_key(LLM_KV_ROPE_SCALING_YARN_EXT_FACTOR, hparams.yarn_ext_factor, false); + ml.get_key(LLM_KV_ROPE_SCALING_YARN_ATTN_FACTOR, hparams.yarn_attn_factor, false); + ml.get_key(LLM_KV_ROPE_SCALING_YARN_BETA_FAST, hparams.yarn_beta_fast, false); + ml.get_key(LLM_KV_ROPE_SCALING_YARN_BETA_SLOW, hparams.yarn_beta_slow, false); switch (hparams.n_layer) { case 64: type = LLM_TYPE_314B; break; @@ -2561,6 +2584,7 @@ bool llama_model::load_tensors(llama_model_loader & ml) { output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED); } + const int64_t n_ff_exp = hparams.n_ff_exp ? hparams.n_ff_exp : n_ff/* / n_expert_used*/; // grok-1 n_ff_exp == n_ff for (int i = 0; i < n_layer; ++i) { auto & layer = layers[i]; @@ -2575,12 +2599,19 @@ bool llama_model::load_tensors(llama_model_loader & ml) { layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); - layer.ffn_gate_inp = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), {n_embd, n_expert}, 0); - layer.ffn_gate_exps = create_tensor(tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i), {n_embd, n_ff, n_expert}, TENSOR_NOT_REQUIRED); - layer.ffn_down_exps = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), { n_ff, n_embd, n_expert}, 0); - layer.ffn_up_exps = create_tensor(tn(LLM_TENSOR_FFN_UP_EXPS, "weight", i), {n_embd, n_ff, n_expert}, 0); + layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, TENSOR_NOT_REQUIRED); + layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff, n_embd}, TENSOR_NOT_REQUIRED); + layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, TENSOR_NOT_REQUIRED); - layer.layer_out_norm = create_tensor(tn(LLM_TENSOR_LAYER_OUT_NORM, "weight", i), {n_embd}, 0); + layer.ffn_gate_inp = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), {n_embd, n_expert}, 0); + layer.ffn_gate_exps = create_tensor(tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i), {n_embd, n_ff_exp, n_expert}, TENSOR_NOT_REQUIRED); + layer.ffn_down_exps = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), {n_ff_exp, n_embd, n_expert}, 0); + layer.ffn_up_exps = create_tensor(tn(LLM_TENSOR_FFN_UP_EXPS, "weight", i), {n_embd, n_ff_exp, n_expert}, 0); + + layer.ffn_post_norm = create_tensor(tn(LLM_TENSOR_LAYER_OUT_NORM, "weight", i), {n_embd}, TENSOR_NOT_REQUIRED); + if (!layer.ffn_post_norm) { + layer.ffn_post_norm = create_tensor(tn(LLM_TENSOR_FFN_POST_NORM, "weight", i), {n_embd}, 0); + } } } break; case LLM_ARCH_DBRX: @@ -7082,9 +7113,6 @@ struct llm_build_grok : public llm_graph_context { inpL = build_inp_embd(model.tok_embd); - // multiply by embedding_multiplier_scale of 78.38367176906169 - inpL = ggml_scale(ctx0, inpL, 78.38367176906169f); - // inp_pos - contains the positions ggml_tensor * inp_pos = build_inp_pos(); @@ -7156,26 +7184,22 @@ struct llm_build_grok : public llm_graph_context { inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids); } - // Grok - // if attn_out_norm is present then apply it before adding the input - if (model.layers[il].attn_out_norm) { - cur = build_norm(cur, - model.layers[il].attn_out_norm, NULL, - LLM_NORM_RMS, il); - cb(cur, "attn_out_norm", il); - } + cur = build_norm(cur, + model.layers[il].attn_out_norm, NULL, + LLM_NORM_RMS, il); + cb(cur, "attn_out_norm", il); ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA); cb(ffn_inp, "ffn_inp", il); // feed-forward network - // MoE branch cur = build_norm(ffn_inp, model.layers[il].ffn_norm, NULL, LLM_NORM_RMS, il); cb(cur, "ffn_norm", il); - cur = build_moe_ffn(cur, + // MoE branch + ggml_tensor * moe_out = build_moe_ffn(cur, model.layers[il].ffn_gate_inp, model.layers[il].ffn_up_exps, model.layers[il].ffn_gate_exps, @@ -7186,18 +7210,28 @@ struct llm_build_grok : public llm_graph_context { false, 0.0, LLAMA_EXPERT_GATING_FUNC_TYPE_SOFTMAX, il); - cb(cur, "ffn_moe_out", il); + cb(moe_out, "ffn_moe_out", il); - // Grok - // if layer_out_norm is present then apply it before adding the input - // Idea: maybe ffn_out_norm is a better name - if (model.layers[il].layer_out_norm) { - cur = build_norm(cur, - model.layers[il].layer_out_norm, NULL, - LLM_NORM_RMS, il); - cb(cur, "layer_out_norm", il); + if (model.layers[il].ffn_up) { + ggml_tensor * ffn_out = build_ffn(cur, + model.layers[il].ffn_up, NULL, NULL, + model.layers[il].ffn_gate, NULL, NULL, + model.layers[il].ffn_down, NULL, NULL, + NULL, + LLM_FFN_GELU, LLM_FFN_PAR, il); + cb(ffn_out, "ffn_out", il); + + cur = ggml_scale(ctx0, ggml_add(ctx0, ffn_out, moe_out), std::sqrt(2) / 2); + cb(cur, "ffn_out", il); + } else { + cur = moe_out; } + cur = build_norm(cur, + model.layers[il].ffn_post_norm, NULL, + LLM_NORM_RMS, il); + cb(cur, "ffn_post_norm", il); + cur = ggml_add(ctx0, cur, ffn_inp); cb(cur, "ffn_out", il); @@ -7220,10 +7254,14 @@ struct llm_build_grok : public llm_graph_context { // lm_head cur = build_lora_mm(model.output, cur); - // Grok - // multiply logits by output_multiplier_scale of 0.5773502691896257 + cur = ggml_scale(ctx0, cur, hparams.f_logit_scale); - cur = ggml_scale(ctx0, cur, 0.5773502691896257f); + // final logit soft-capping + if (hparams.f_final_logit_softcapping) { + cur = ggml_scale(ctx0, cur, 1.0f / hparams.f_final_logit_softcapping); + cur = ggml_tanh(ctx0, cur); + cur = ggml_scale(ctx0, cur, hparams.f_final_logit_softcapping); + } cb(cur, "result_output", -1); res->t_logits = cur; diff --git a/src/llama-vocab.cpp b/src/llama-vocab.cpp index 7d665f2099..d7d89d99de 100644 --- a/src/llama-vocab.cpp +++ b/src/llama-vocab.cpp @@ -434,6 +434,13 @@ struct llm_tokenizer_bpe : llm_tokenizer { "(?:'[sS]|'[tT]|'[rR][eE]|'[vV][eE]|'[mM]|'[lL][lL]|'[dD])|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}{1}| ?[^\\s\\p{L}\\p{N}\\r\\n]+|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+", }; break; + case LLAMA_VOCAB_PRE_TYPE_GROK_2: + regex_exprs = { + // original regex from tokenizer.json + // "(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+" + "(?:'[sS]|'[tT]|'[rR][eE]|'[vV][eE]|'[mM]|'[lL][lL]|'[dD])|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+", + }; + break; default: // default regex for BPE tokenization pre-processing regex_exprs = { @@ -1975,6 +1982,10 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) { tokenizer_pre == "kimi-k2") { pre_type = LLAMA_VOCAB_PRE_TYPE_KIMI_K2; clean_spaces = false; + } else if ( + tokenizer_pre == "grok-2") { + pre_type = LLAMA_VOCAB_PRE_TYPE_GROK_2; + clean_spaces = false; } else { throw std::runtime_error(format("unknown pre-tokenizer type: '%s'", tokenizer_pre.c_str())); } diff --git a/src/llama-vocab.h b/src/llama-vocab.h index 61b8124216..0d2f28c36c 100644 --- a/src/llama-vocab.h +++ b/src/llama-vocab.h @@ -47,6 +47,7 @@ enum llama_vocab_pre_type { LLAMA_VOCAB_PRE_TYPE_HUNYUAN = 36, LLAMA_VOCAB_PRE_TYPE_KIMI_K2 = 37, LLAMA_VOCAB_PRE_TYPE_HUNYUAN_DENSE = 38, + LLAMA_VOCAB_PRE_TYPE_GROK_2 = 39, }; struct LLM_KV; diff --git a/tools/perplexity/perplexity.cpp b/tools/perplexity/perplexity.cpp index 80cbb095da..c9fd082db9 100644 --- a/tools/perplexity/perplexity.cpp +++ b/tools/perplexity/perplexity.cpp @@ -1931,7 +1931,7 @@ static void kl_divergence(llama_context * ctx, const common_params & params) { LOG("Maximum KLD: %10.6f\n", kld_values.back()); LOG("99.9%% KLD: %10.6f\n", percentile(kld_values, 0.999f)); LOG("99.0%% KLD: %10.6f\n", percentile(kld_values, 0.990f)); - LOG("99.0%% KLD: %10.6f\n", percentile(kld_values, 0.990f)); + LOG("90.0%% KLD: %10.6f\n", percentile(kld_values, 0.900f)); LOG("Median KLD: %10.6f\n", kld_median); LOG("10.0%% KLD: %10.6f\n", percentile(kld_values, 0.100f)); LOG(" 5.0%% KLD: %10.6f\n", percentile(kld_values, 0.050f)); diff --git a/tools/run/run.cpp b/tools/run/run.cpp index 6fe728c685..772d66c921 100644 --- a/tools/run/run.cpp +++ b/tools/run/run.cpp @@ -407,39 +407,22 @@ class HttpClient { } std::string output_file_partial; - curl = curl_easy_init(); - if (!curl) { - return 1; - } - progress_data data; - File out; if (!output_file.empty()) { output_file_partial = output_file + ".partial"; - if (!out.open(output_file_partial, "ab")) { - printe("Failed to open file for writing\n"); - - return 1; - } - - if (out.lock()) { - printe("Failed to exclusively lock file\n"); - - return 1; - } } - set_write_options(response_str, out); - data.file_size = set_resume_point(output_file_partial); - set_progress_options(progress, data); - set_headers(headers); - CURLcode res = perform(url); - if (res != CURLE_OK){ - printe("Fetching resource '%s' failed: %s\n", url.c_str(), curl_easy_strerror(res)); + if (download(url, headers, output_file_partial, progress, response_str)) { return 1; } + if (!output_file.empty()) { - std::filesystem::rename(output_file_partial, output_file); + try { + std::filesystem::rename(output_file_partial, output_file); + } catch (const std::filesystem::filesystem_error & e) { + printe("Failed to rename '%s' to '%s': %s\n", output_file_partial.c_str(), output_file.c_str(), e.what()); + return 1; + } } return 0; @@ -459,6 +442,42 @@ class HttpClient { CURL * curl = nullptr; struct curl_slist * chunk = nullptr; + int download(const std::string & url, const std::vector & headers, const std::string & output_file, + const bool progress, std::string * response_str = nullptr) { + curl = curl_easy_init(); + if (!curl) { + return 1; + } + + progress_data data; + File out; + if (!output_file.empty()) { + if (!out.open(output_file, "ab")) { + printe("Failed to open file for writing\n"); + + return 1; + } + + if (out.lock()) { + printe("Failed to exclusively lock file\n"); + + return 1; + } + } + + set_write_options(response_str, out); + data.file_size = set_resume_point(output_file); + set_progress_options(progress, data); + set_headers(headers); + CURLcode res = perform(url); + if (res != CURLE_OK){ + printe("Fetching resource '%s' failed: %s\n", url.c_str(), curl_easy_strerror(res)); + return 1; + } + + return 0; + } + void set_write_options(std::string * response_str, const File & out) { if (response_str) { curl_easy_setopt(curl, CURLOPT_WRITEFUNCTION, capture_data); @@ -507,6 +526,9 @@ class HttpClient { curl_easy_setopt(curl, CURLOPT_FOLLOWLOCATION, 1L); curl_easy_setopt(curl, CURLOPT_DEFAULT_PROTOCOL, "https"); curl_easy_setopt(curl, CURLOPT_FAILONERROR, 1L); +#ifdef _WIN32 + curl_easy_setopt(curl, CURLOPT_SSL_OPTIONS, CURLSSLOPT_NATIVE_CA); +#endif return curl_easy_perform(curl); } diff --git a/tools/server/server.cpp b/tools/server/server.cpp index 160b97cf7d..519704fad7 100644 --- a/tools/server/server.cpp +++ b/tools/server/server.cpp @@ -2313,7 +2313,7 @@ struct server_context { // thinking is enabled if: // 1. It's not explicitly disabled (reasoning_budget == 0) // 2. The chat template supports it - const bool enable_thinking = params_base.reasoning_budget != 0 && common_chat_templates_support_enable_thinking(chat_templates.get()); + const bool enable_thinking = params_base.use_jinja && params_base.reasoning_budget != 0 && common_chat_templates_support_enable_thinking(chat_templates.get()); SRV_INF("Enable thinking? %d\n", enable_thinking); oai_parser_opt = {