Merge branch 'master' into modern-bert-support

This commit is contained in:
Ryan Mangeno 2025-09-15 17:29:06 -04:00 committed by GitHub
commit e0438154a8
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
25 changed files with 519 additions and 562 deletions

View File

@ -288,9 +288,9 @@ struct common_params {
float rope_freq_base = 0.0f; // RoPE base frequency
float rope_freq_scale = 0.0f; // RoPE frequency scaling factor
float yarn_ext_factor = -1.0f; // YaRN extrapolation mix factor
float yarn_attn_factor = 1.0f; // YaRN magnitude scaling factor
float yarn_beta_fast = 32.0f; // YaRN low correction dim
float yarn_beta_slow = 1.0f; // YaRN high correction dim
float yarn_attn_factor = -1.0f; // YaRN magnitude scaling factor
float yarn_beta_fast = -1.0f; // YaRN low correction dim
float yarn_beta_slow = -1.0f; // YaRN high correction dim
int32_t yarn_orig_ctx = 0; // YaRN original context length
// offload params

View File

@ -735,6 +735,9 @@ class TextModel(ModelBase):
if chkhsh == "d4540891389ea895b53b399da6ac824becc30f2fba0e9ddbb98f92e55ca0e97c":
# ref: https://huggingface.co/Qwen/Qwen3-Embedding-0.6B
res = "qwen2"
if chkhsh == "66b8d4e19ab16c3bfd89bce5d785fb7e0155e8648708a1f42077cb9fe002c273":
# ref: https://huggingface.co/alvarobartt/grok-2-tokenizer
res = "grok-2"
if chkhsh == "0ef9807a4087ebef797fc749390439009c3b9eda9ad1a097abbe738f486c01e5":
# ref: https://huggingface.co/meta-llama/Meta-Llama-3-8B
res = "llama-bpe"
@ -2685,12 +2688,20 @@ class BitnetModel(TextModel):
yield (new_name, data_torch)
@ModelBase.register("GrokForCausalLM")
@ModelBase.register("GrokForCausalLM", "Grok1ForCausalLM")
class GrokModel(TextModel):
model_arch = gguf.MODEL_ARCH.GROK
def set_vocab(self):
self._set_vocab_sentencepiece()
if (self.dir_model / 'tokenizer.model').is_file():
self._set_vocab_sentencepiece()
return
if not (self.dir_model / 'tokenizer.json').is_file() or not (self.dir_model / 'chat_template.jinja').is_file():
logger.error('Error: Missing vocab and chat template, download files from https://huggingface.co/alvarobartt/grok-2-tokenizer')
sys.exit(1)
self._set_vocab_gpt2()
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
@ -2698,11 +2709,46 @@ class GrokModel(TextModel):
def set_gguf_parameters(self):
super().set_gguf_parameters()
_experts: list[dict[str, Tensor]] | None = None
self.gguf_writer.add_attn_logit_softcapping(self.hparams.get("attn_logit_softcapping", 30.0))
self.gguf_writer.add_router_logit_softcapping(self.hparams.get("router_logit_softcapping", 30.0))
if (final_logit_softcap := self.hparams.get("final_logit_softcapping")):
self.gguf_writer.add_final_logit_softcapping(final_logit_softcap)
if (rope_dim := self.hparams.get("head_dim")) is None:
rope_dim = self.hparams["hidden_size"] // self.hparams["num_attention_heads"]
if (moe_intermediate_size := self.hparams.get("moe_intermediate_size")) is not None:
self.gguf_writer.add_expert_feed_forward_length(moe_intermediate_size)
# Treat "original" as "yarn", seems to have been a mistake
if self.hparams.get("rope_type") in ("yarn", "original"):
self.gguf_writer.add_rope_scaling_type(gguf.RopeScalingType.YARN)
self.gguf_writer.add_rope_scaling_factor(self.hparams["scaling_factor"])
self.gguf_writer.add_rope_scaling_orig_ctx_len(self.hparams["original_max_position_embeddings"])
self.gguf_writer.add_rope_scaling_yarn_ext_factor(self.hparams["extrapolation_factor"])
self.gguf_writer.add_rope_scaling_yarn_attn_factor(self.hparams["attn_factor"])
self.gguf_writer.add_rope_scaling_yarn_beta_fast(self.hparams["beta_fast"])
self.gguf_writer.add_rope_scaling_yarn_beta_slow(self.hparams["beta_slow"])
if temp_len := self.hparams.get("attn_temperature_len"):
self.gguf_writer.add_attn_temperature_length(temp_len)
self.gguf_writer.add_attn_output_scale(self.hparams.get("attn_output_multiplier", rope_dim**-0.5))
self.gguf_writer.add_embedding_scale(self.hparams["embedding_multiplier_scale"])
self.gguf_writer.add_logit_scale(self.hparams["output_multiplier_scale"])
_experts: list[dict[str, list[Tensor]]] | None = None
_cur_expert = ""
def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
tensors: list[tuple[str, Tensor]] = []
is_expert = ".moe." in name or ".block_sparse_moe.experts." in name
if not is_expert:
tensors.append((self.map_tensor_name(name), data_torch))
# process the experts separately
if name.find(".moe.") != -1:
if is_expert or self._cur_expert:
n_experts = self.hparams["num_local_experts"]
assert bid is not None
@ -2710,32 +2756,41 @@ class GrokModel(TextModel):
if self._experts is None:
self._experts = [{} for _ in range(self.block_count)]
self._experts[bid][name] = data_torch
if len(self._experts[bid]) >= n_experts * 3:
tensors: list[tuple[str, Tensor]] = []
# merge the experts into a single 3d tensor
for wid in ["linear", "linear_1", "linear_v"]:
datas: list[Tensor] = []
for xid in range(n_experts):
ename = f"transformer.decoder_layer.{bid}.moe.{xid}.{wid}.weight"
datas.append(self._experts[bid][ename])
del self._experts[bid][ename]
data_torch = torch.stack(datas, dim=0)
merged_name = f"transformer.decoder_layer.{bid}.moe.{wid}.weight"
new_name = self.map_tensor_name(merged_name)
tensors.append((new_name, data_torch))
return tensors
else:
# concatenate split tensors
if name in self._experts[bid]:
self._cur_expert = name
self._experts[bid][name].append(data_torch)
return []
elif is_expert:
self._cur_expert = name
self._experts[bid][name] = [data_torch]
return []
else:
self._cur_expert = ""
return [(self.map_tensor_name(name), data_torch)]
for bid in range(self.block_count):
if len(self._experts[bid]) >= n_experts * 3:
# merge the experts into a single 3d tensor
for wid in [("linear", "w1", 0), ("linear_1", "w2", 1), ("linear_v", "w3", 0)]:
datas: list[Tensor] = []
for xid in range(n_experts):
ename = f"transformer.decoder_layer.{bid}.moe.{xid}.{wid[0]}.weight"
if ename not in self._experts[bid]:
ename = f"model.layers.{bid}.block_sparse_moe.experts.{xid}.{wid[1]}.weight"
tensor_list = self._experts[bid][ename]
datas.append(torch.cat(tensor_list, dim=wid[2]) if len(tensor_list) > 1 else tensor_list[0])
del self._experts[bid][ename]
data_torch = torch.stack(datas, dim=0)
merged_name = f"transformer.decoder_layer.{bid}.moe.{wid[0]}.weight"
new_name = self.map_tensor_name(merged_name)
yield (new_name, data_torch)
yield from tensors
@ModelBase.register("DbrxForCausalLM")

View File

@ -159,6 +159,7 @@ pre_computed_hashes = [
{"name": "falcon-h1", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/tiiuae/Falcon-H1-34B-Base", "chkhsh": "48f8e02c0359c0bbdd82f26909171fac1c18a457bb47573ed1fe3bbb2c1cfd4b"},
{"name": "kimi-k2", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/moonshotai/Kimi-K2-Base", "chkhsh": "81212dc7cdb7e0c1074ca62c5aeab0d43c9f52b8a737be7b12a777c953027890"},
{"name": "qwen2", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/Qwen/Qwen3-Embedding-0.6B", "chkhsh": "d4540891389ea895b53b399da6ac824becc30f2fba0e9ddbb98f92e55ca0e97c"},
{"name": "grok-2", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/alvarobartt/grok-2-tokenizer", "chkhsh": "66b8d4e19ab16c3bfd89bce5d785fb7e0155e8648708a1f42077cb9fe002c273"},
]

View File

@ -57,31 +57,33 @@ static __global__ void mul_mat_f(
T * tile_xy = (T *) compute_base + threadIdx.y*(tile_A::I * tile_k_padded);
if constexpr (has_ids) {
__shared__ int has_any;
if (threadIdx.y == 0) {
int local_has_any = 0;
for (int j = threadIdx.x; j < cols_per_block; j += warp_size) {
int slot = -1;
for (int k = 0; k < nchannels_dst; ++k) {
const int idv = ids[j*stride_row_id + k*stride_col_id];
if (idv == expert_idx) {
slot = k;
break;
}
}
if (j < cols_per_block) {
local_has_any |= (slot >= 0);
slot_map[j] = slot;
int found = 0;
for (int j0 = 0; j0 < cols_per_block; j0 += nwarps) {
const int j = j0 + threadIdx.y;
const int32_t * __restrict__ id_row = ids + j*stride_row_id;
if (threadIdx.x == 0) {
slot_map[j] = -1;
}
for (int k = threadIdx.x; k < nchannels_dst; k += warp_size) {
int match = id_row[k*stride_col_id] == expert_idx;
if (match) {
slot_map[j] = k;
found = 1;
break;
}
}
has_any = warp_reduce_any(local_has_any);
}
__syncthreads();
if (has_any == 0) {
if (!__syncthreads_or(found)) {
return;
}
}
for (int col = threadIdx.y*warp_size + threadIdx.x; col < ncols; col += nwarps*warp_size) {
tile_A A[ntA][warp_size / tile_A::J];
#pragma unroll
@ -106,14 +108,7 @@ static __global__ void mul_mat_f(
if constexpr (!has_ids) {
tile_xy[j0*tile_k_padded + threadIdx.x] = j < cols_per_block ? y[j*stride_col_y + col] : 0.0f;
} else {
float val = 0.0f;
if (j < cols_per_block) {
const int slot = slot_map[j];
if (slot >= 0) {
val = y[slot*stride_channel_y + j*stride_col_y + col];
}
}
tile_xy[j0*tile_k_padded + threadIdx.x] = val;
tile_xy[j0*tile_k_padded + threadIdx.x] = j < cols_per_block ? y[slot_map[j]*stride_channel_y + j*stride_col_y + col] : 0.0f;
}
}
} else if constexpr (std::is_same_v<T, half2> || std::is_same_v<T, nv_bfloat162>) {
@ -125,14 +120,7 @@ static __global__ void mul_mat_f(
const float2 tmp = j < cols_per_block ? y2[j*stride_col_y + col] : make_float2(0.0f, 0.0f);
tile_xy[j0*tile_k_padded + threadIdx.x] = {tmp.x, tmp.y};
} else {
float2 tmp = make_float2(0.0f, 0.0f);
if (j < cols_per_block) {
const int slot = slot_map[j];
if (slot >= 0) {
const float2 * y2_slot = (const float2 *)(y + slot*stride_channel_y);
tmp = y2_slot[j*stride_col_y + col];
}
}
float2 tmp = j < cols_per_block && slot_map[j] >= 0 ? *(const float2*) &y[slot_map[j]*stride_channel_y + 2*(j*stride_col_y + col)] : make_float2(0.0f, 0.0f);
tile_xy[j0*tile_k_padded + threadIdx.x] = {tmp.x, tmp.y};
}
}
@ -221,7 +209,7 @@ static inline void mul_mat_f_switch_ids(
const dim3 & block_nums, const dim3 & block_dims, const int nbytes_shared_total, cudaStream_t stream) {
if (ids) {
mul_mat_f<T, MMF_ROWS_PER_BLOCK, cols_per_block, nwarps, true><<<block_nums, block_dims, nbytes_shared_total, stream>>>
(x, y, ids, dst, ncols_x, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
(x, y, ids, dst, ncols_x, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
stride_col_id, stride_row_id, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
} else {

View File

@ -1,9 +1,12 @@
#include "ggml-metal-common.h"
#include "ggml-impl.h"
#include "ggml-backend-impl.h"
#include <vector>
// represents a memory range (i.e. an interval from a starting address p0 to an ending address p1 in a given buffer pb)
// the type indicates whether it is a source range (i.e. ops read data from it) or a destination range (i.e. ops write data to it)
struct ggml_mem_range {
uint64_t pb; // buffer id
@ -36,8 +39,8 @@ void ggml_mem_ranges_reset(ggml_mem_ranges * mrs) {
mrs->ranges.clear();
}
static bool ggml_mem_ranges_add(ggml_mem_ranges * mrs, ggml_mem_range mrp) {
mrs->ranges.push_back(mrp);
static bool ggml_mem_ranges_add(ggml_mem_ranges * mrs, ggml_mem_range mr) {
mrs->ranges.push_back(mr);
return true;
}
@ -48,20 +51,24 @@ static ggml_mem_range ggml_mem_range_from_tensor(const ggml_tensor * tensor, ggm
GGML_ASSERT(!tensor->view_src);
ggml_mem_range mrp;
ggml_mem_range mr;
if (tensor->buffer) {
// when the tensor is allocated, use the actual memory address range of the buffer
mrp = {
// when the tensor is allocated, use the actual memory address range in the buffer
//
// take the actual allocated size with ggml_backend_buft_get_alloc_size()
// this can be larger than the tensor size if the buffer type allocates extra memory
// ref: https://github.com/ggml-org/llama.cpp/pull/15966
mr = {
/*.pb =*/ (uint64_t) tensor->buffer,
/*.p0 =*/ (uint64_t) tensor->data,
/*.p1 =*/ (uint64_t) tensor->data + ggml_nbytes(tensor),
/*.p1 =*/ (uint64_t) tensor->data + ggml_backend_buft_get_alloc_size(tensor->buffer->buft, tensor),
/*.pt =*/ pt,
};
} else {
// otherwise, the tensor ptr is used as an unique id of the memory ranges
// otherwise, the pointer address is used as an unique id of the memory ranges
// that the tensor will be using when it is allocated
mrp = {
mr = {
/*.pb =*/ (uint64_t) tensor,
/*.p0 =*/ 0, //
/*.p1 =*/ 1024, // [0, 1024) is a dummy range, not used
@ -69,7 +76,7 @@ static ggml_mem_range ggml_mem_range_from_tensor(const ggml_tensor * tensor, ggm
};
};
return mrp;
return mr;
}
static ggml_mem_range ggml_mem_range_from_tensor_src(const ggml_tensor * tensor) {
@ -83,25 +90,25 @@ static ggml_mem_range ggml_mem_range_from_tensor_dst(const ggml_tensor * tensor)
static bool ggml_mem_ranges_add_src(ggml_mem_ranges * mrs, const ggml_tensor * tensor) {
GGML_ASSERT(tensor);
ggml_mem_range mrp = ggml_mem_range_from_tensor_src(tensor);
ggml_mem_range mr = ggml_mem_range_from_tensor_src(tensor);
if (mrs->debug > 2) {
GGML_LOG_DEBUG("%s: add src range buf=%lld, [%lld, %lld)\n", __func__, mrp.pb, mrp.p0, mrp.p1);
GGML_LOG_DEBUG("%s: add src range buf=%lld, [%lld, %lld)\n", __func__, mr.pb, mr.p0, mr.p1);
}
return ggml_mem_ranges_add(mrs, mrp);
return ggml_mem_ranges_add(mrs, mr);
}
static bool ggml_mem_ranges_add_dst(ggml_mem_ranges * mrs, const ggml_tensor * tensor) {
GGML_ASSERT(tensor);
ggml_mem_range mrp = ggml_mem_range_from_tensor_dst(tensor);
ggml_mem_range mr = ggml_mem_range_from_tensor_dst(tensor);
if (mrs->debug > 2) {
GGML_LOG_DEBUG("%s: add dst range buf=%lld, [%lld, %lld)\n", __func__, mrp.pb, mrp.p0, mrp.p1);
GGML_LOG_DEBUG("%s: add dst range buf=%lld, [%lld, %lld)\n", __func__, mr.pb, mr.p0, mr.p1);
}
return ggml_mem_ranges_add(mrs, mrp);
return ggml_mem_ranges_add(mrs, mr);
}
bool ggml_mem_ranges_add(ggml_mem_ranges * mrs, const ggml_tensor * tensor) {
@ -114,24 +121,26 @@ bool ggml_mem_ranges_add(ggml_mem_ranges * mrs, const ggml_tensor * tensor) {
return ggml_mem_ranges_add_dst(mrs, tensor);
}
static bool ggml_mem_ranges_check(const ggml_mem_ranges * mrs, ggml_mem_range mrp) {
static bool ggml_mem_ranges_check(const ggml_mem_ranges * mrs, ggml_mem_range mr) {
for (size_t i = 0; i < mrs->ranges.size(); i++) {
const auto & cmp = mrs->ranges[i];
if (mrp.pb != cmp.pb) {
// two memory ranges cannot intersect if they are in different buffers
if (mr.pb != cmp.pb) {
continue;
}
if (mrp.pt == MEM_RANGE_TYPE_SRC && cmp.pt == MEM_RANGE_TYPE_SRC) {
// intersecting source ranges are allowed
if (mr.pt == MEM_RANGE_TYPE_SRC && cmp.pt == MEM_RANGE_TYPE_SRC) {
continue;
}
if (mrp.p0 < cmp.p1 && mrp.p1 >= cmp.p0) {
if (mr.p0 < cmp.p1 && mr.p1 >= cmp.p0) {
if (mrs->debug > 2) {
GGML_LOG_DEBUG("%s: the %s range buf=%lld, [%lld, %lld) overlaps with a previous %s range buf=%lld, [%lld, %lld)\n",
__func__,
mrp.pt == MEM_RANGE_TYPE_SRC ? "src" : "dst",
mrp.pb, mrp.p0, mrp.p1,
mr.pt == MEM_RANGE_TYPE_SRC ? "src" : "dst",
mr.pb, mr.p0, mr.p1,
cmp.pt == MEM_RANGE_TYPE_SRC ? "src" : "dst",
cmp.pb, cmp.p0, cmp.p1);
}
@ -146,9 +155,9 @@ static bool ggml_mem_ranges_check(const ggml_mem_ranges * mrs, ggml_mem_range mr
static bool ggml_mem_ranges_check_src(const ggml_mem_ranges * mrs, const ggml_tensor * tensor) {
GGML_ASSERT(tensor);
ggml_mem_range mrp = ggml_mem_range_from_tensor_src(tensor);
ggml_mem_range mr = ggml_mem_range_from_tensor_src(tensor);
const bool res = ggml_mem_ranges_check(mrs, mrp);
const bool res = ggml_mem_ranges_check(mrs, mr);
return res;
}
@ -156,9 +165,9 @@ static bool ggml_mem_ranges_check_src(const ggml_mem_ranges * mrs, const ggml_te
static bool ggml_mem_ranges_check_dst(const ggml_mem_ranges * mrs, const ggml_tensor * tensor) {
GGML_ASSERT(tensor);
ggml_mem_range mrp = ggml_mem_range_from_tensor_dst(tensor);
ggml_mem_range mr = ggml_mem_range_from_tensor_dst(tensor);
const bool res = ggml_mem_ranges_check(mrs, mrp);
const bool res = ggml_mem_ranges_check(mrs, mr);
return res;
}
@ -222,6 +231,7 @@ static std::vector<int> ggml_metal_graph_optimize_reorder(const std::vector<node
}
}
// keep track of the sources of the fused nodes as well
for (const auto * fused : node.fused) {
for (int i = 0; i < GGML_MAX_SRC; i++) {
if (fused->src[i]) {
@ -290,7 +300,10 @@ static std::vector<int> ggml_metal_graph_optimize_reorder(const std::vector<node
std::vector<bool> used(n, false);
// the memory ranges for the set of currently concurrent nodes
ggml_mem_ranges * mrs0 = ggml_mem_ranges_init(0);
// the memory ranges for the set of nodes that haven't been processed yet, when looking forward for a node to reorder
ggml_mem_ranges * mrs1 = ggml_mem_ranges_init(0);
for (int i0 = 0; i0 < n; i0++) {
@ -329,7 +342,7 @@ static std::vector<int> ggml_metal_graph_optimize_reorder(const std::vector<node
const bool is_empty = node1.is_empty();
// to add a concurrent node, it has to be:
// to reorder a node and add it to the concurrent set, it has to be:
// + empty or concurrent with all nodes in the existing concurrent set (mrs0)
// + concurrent with all nodes prior to it that haven't been processed yet (mrs1)
if ((is_empty || h_check(mrs0, node1)) && h_check(mrs1, node1)) {
@ -419,8 +432,8 @@ void ggml_metal_graph_optimize(ggml_cgraph * gf) {
nodes.push_back(std::move(node));
}
// reorder to improve concurrency
#if 1
// reorder to improve concurrency
const auto order = ggml_metal_graph_optimize_reorder(nodes);
#else
std::vector<int> order(nodes.size());

View File

@ -532,261 +532,9 @@ enum ggml_metal_kernel_type {
GGML_METAL_KERNEL_TYPE_COUNT
};
//
// ggml_metal_heap
//
struct ggml_metal_heap {
// number of times the heap was unused
int n_unused;
// total number of buffer allocations in this heap across all computes
int64_t n_alloc;
// current offset in the heap - we reset this after each node in order to reuse the memory
size_t offs;
// the currently allocated MTLBuffer objects in this heap
id<MTLHeap> obj;
NSMutableArray * bufs;
};
static struct ggml_metal_heap * ggml_metal_heap_init(id<MTLDevice> device, size_t size) {
struct ggml_metal_heap * heap = calloc(1, sizeof(struct ggml_metal_heap));
MTLHeapDescriptor * desc = [[MTLHeapDescriptor alloc] init];
desc.storageMode = MTLStorageModePrivate;
desc.cpuCacheMode = MTLCPUCacheModeDefaultCache;
desc.type = MTLHeapTypePlacement;
desc.size = size;
heap->n_unused = 0;
heap->n_alloc = 0;
heap->obj = [device newHeapWithDescriptor:desc];
if (!heap->obj) {
GGML_LOG_ERROR("%s: error: failed to create MTLHeap with size %zu\n", __func__, size);
free(heap);
return false;
}
[desc release];
heap->bufs = [[NSMutableArray alloc] init];
return heap;
}
static void ggml_metal_heap_reset(struct ggml_metal_heap * heap) {
heap->offs = 0;
// count how many graph computes the heap ended up being unused
if ([heap->bufs count] > 0) {
heap->n_unused = 0;
} else {
heap->n_unused++;
}
for (id<MTLBuffer> buf in heap->bufs) {
[buf release];
}
[heap->bufs removeAllObjects];
// tell the OS that it can reuse this memory if needed
// ref: https://developer.apple.com/documentation/metal/mtlpurgeablestate?language=objc
[heap->obj setPurgeableState:MTLPurgeableStateVolatile];
}
static void ggml_metal_heap_free(struct ggml_metal_heap * heap) {
if (heap == nil) {
return;
}
ggml_metal_heap_reset(heap);
[heap->obj release];
[heap->bufs release];
free(heap);
}
@interface ggml_metal_heap_ptr : NSObject
@property (nonatomic, assign) struct ggml_metal_heap * data;
@end
@implementation ggml_metal_heap_ptr
@end
//
// ggml_metal_mem_pool [TAG_MEM_POOL_REMOVE]
//
struct ggml_metal_mem_pool {
id<MTLDevice> device;
int n_heaps; // total number of heaps ever created (including those that were removed)
NSMutableArray * heaps;
NSMutableArray * heaps_to_remove;
};
static struct ggml_metal_mem_pool * ggml_metal_mem_pool_init(void) {
struct ggml_metal_mem_pool * mem_pool = calloc(1, sizeof(struct ggml_metal_mem_pool));
mem_pool->n_heaps = 0;
mem_pool->heaps = [[NSMutableArray alloc] init];
mem_pool->heaps_to_remove = [[NSMutableArray alloc] init];
return mem_pool;
}
static void ggml_metal_mem_pool_free(struct ggml_metal_mem_pool * mem_pool) {
GGML_LOG_DEBUG("%s: freeing memory pool, num heaps = %zu (total = %d)\n", __func__, [mem_pool->heaps count], mem_pool->n_heaps);
size_t size_all = 0;
size_t size_cur = 0;
for (ggml_metal_heap_ptr * ptr in mem_pool->heaps) {
GGML_LOG_DEBUG("%s: heap: %p\n", __func__, (void *) ptr.data);
GGML_LOG_DEBUG("%s: n_alloc: %" PRId64 "\n", __func__, ptr.data->n_alloc);
GGML_LOG_DEBUG("%s: n_unused: %d\n", __func__, ptr.data->n_unused);
GGML_LOG_DEBUG("%s: size: %.2f MiB\n", __func__, [ptr.data->obj size] / 1024.0 / 1024.0);
GGML_LOG_DEBUG("%s: bufs: %zu\n", __func__, [ptr.data->bufs count]);
if ([ptr.data->bufs count] > 0) {
size_cur += [ptr.data->obj size];
}
size_all += [ptr.data->obj size];
ggml_metal_heap_free(ptr.data);
[ptr release];
}
[mem_pool->heaps release];
[mem_pool->heaps_to_remove release];
if (size_all > 0) {
GGML_LOG_DEBUG("%s: size_all: %.2f MiB\n", __func__, size_all / 1024.0 / 1024.0);
GGML_LOG_DEBUG("%s: size_cur: %.2f MiB\n", __func__, size_cur / 1024.0 / 1024.0);
}
free(mem_pool);
}
static void ggml_metal_mem_pool_reset(struct ggml_metal_mem_pool * mem_pool) {
for (NSUInteger i = 0; i < [mem_pool->heaps count]; i++) {
ggml_metal_heap_ptr * ptr = [mem_pool->heaps objectAtIndex:i];
struct ggml_metal_heap * heap = ptr.data;
ggml_metal_heap_reset(heap);
// if the heap hasn't been used for a while, remove it
if (heap->n_unused >= 128) {
[mem_pool->heaps_to_remove addObject:@(i)];
}
}
if (mem_pool->heaps_to_remove.count > 0) {
// remove in reverse order
for (NSUInteger i = [mem_pool->heaps_to_remove count] - 1; ; --i) {
NSUInteger index = [[mem_pool->heaps_to_remove objectAtIndex:i] intValue];
ggml_metal_heap_ptr * ptr = [mem_pool->heaps objectAtIndex:index];
struct ggml_metal_heap * heap = ptr.data;
ggml_metal_heap_free(heap);
[mem_pool->heaps removeObjectAtIndex:index];
[ptr release];
if (i == 0) {
break;
}
}
[mem_pool->heaps_to_remove removeAllObjects];
}
}
static void ggml_metal_mem_pool_clear(struct ggml_metal_mem_pool * mem_pool) {
for (ggml_metal_heap_ptr * ptr in mem_pool->heaps) {
ptr.data->offs = 0;
}
}
static id<MTLBuffer> ggml_metal_mem_pool_alloc(struct ggml_metal_mem_pool * mem_pool, size_t size) {
const size_t alignment = 256;
const size_t size_aligned = GGML_PAD(size, alignment);
// try one of the existing heaps
for (ggml_metal_heap_ptr * ptr in mem_pool->heaps) {
struct ggml_metal_heap * heap = ptr.data;
if (heap->offs + size_aligned <= [heap->obj size]) {
// if this is the first buffer in the heap for the current command buffer, tell the OS that
// it cannot free the memory used by the heap
// ref: https://developer.apple.com/documentation/metal/mtlpurgeablestate?language=objc
if ([heap->bufs count] == 0) {
[heap->obj setPurgeableState:MTLPurgeableStateNonVolatile];
}
id<MTLBuffer> buf = [heap->obj newBufferWithLength:size_aligned options:MTLResourceStorageModePrivate offset:heap->offs];
if (buf == nil) {
GGML_LOG_ERROR("%s: error: failed to create MTLBuffer with size %zu\n", __func__, size_aligned);
return nil;
}
heap->n_alloc++;
heap->offs += size_aligned;
[heap->bufs addObject:buf];
return buf;
}
}
// create a new heap that can fit this buffer
ggml_metal_heap_ptr * heap_ptr = [ggml_metal_heap_ptr new];
struct ggml_metal_heap * heap = ggml_metal_heap_init(mem_pool->device, size_aligned);
if (heap == NULL) {
GGML_LOG_ERROR("%s: error: failed to create heap of size %zu\n", __func__, size_aligned);
return NULL;
}
//GGML_LOG_DEBUG("%s: creating new heap of size %zu, got %zu\n", __func__, size_aligned, [heap->obj size]);
heap_ptr.data = heap;
ggml_metal_heap_reset(heap);
[heap->obj setPurgeableState:MTLPurgeableStateNonVolatile];
id<MTLBuffer> buf = [heap->obj newBufferWithLength:size_aligned options:MTLResourceStorageModePrivate offset:heap->offs];
if (buf == nil) {
GGML_LOG_ERROR("%s: error: failed to create MTLBuffer with size %zu\n", __func__, size_aligned);
return NULL;
}
heap->n_alloc++;
heap->offs += size_aligned;
[heap->bufs addObject:buf];
[mem_pool->heaps addObject:heap_ptr];
mem_pool->n_heaps++;
return buf;
}
struct ggml_metal_command_buffer {
id<MTLCommandBuffer> obj;
// each command buffer has a memory pool from which it can allocate temporary buffers during the compute
struct ggml_metal_mem_pool * mem_pool;
// used to enable concurrent execution of ops in the command buffers
struct ggml_mem_ranges * mem_ranges;
};
@ -1103,9 +851,6 @@ static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t de
for (int i = 0; i < GGML_METAL_MAX_COMMAND_BUFFERS; ++i) {
ctx->cmd_bufs[i].obj = nil;
ctx->cmd_bufs[i].mem_pool = ggml_metal_mem_pool_init();
ctx->cmd_bufs[i].mem_pool->device = device;
if (ctx_dev->use_concurrency) {
ctx->cmd_bufs[i].mem_ranges = ggml_mem_ranges_init(ctx_dev->debug_graph);
}
@ -1510,6 +1255,52 @@ static id<MTLComputePipelineState> ggml_metal_compile_kernel(ggml_backend_t back
return res;
}
// tokens per expert
static size_t ggml_metal_mul_mat_id_extra_tpe(const struct ggml_tensor * op) {
assert(op->op == GGML_OP_MUL_MAT_ID);
const int64_t ne02 = op->src[0]->ne[2]; // n_expert
return ggml_type_size(GGML_TYPE_I32)*ne02;
}
// id map [n_tokens, n_expert]
static size_t ggml_metal_mul_mat_id_extra_ids(const struct ggml_tensor * op) {
assert(op->op == GGML_OP_MUL_MAT_ID);
const int64_t ne02 = op->src[0]->ne[2]; // n_expert
const int64_t ne21 = op->src[2]->ne[1]; // n_token
return ggml_type_size(GGML_TYPE_I32)*ne02*ne21;
}
// return true if we should use the FA vector kernel for this op
static bool ggml_metal_flash_attn_ext_use_vec(const struct ggml_tensor * op) {
assert(op->op == GGML_OP_FLASH_ATTN_EXT);
const int64_t ne00 = op->src[0]->ne[0]; // head size
const int64_t ne01 = op->src[0]->ne[1]; // batch size
// use vec kernel if the batch size is small and if the head size is supported
return (ne01 < 20) && (ne00 % 32 == 0);
}
static size_t ggml_metal_flash_attn_ext_extra_tmp(const struct ggml_tensor * op) {
assert(op->op == GGML_OP_FLASH_ATTN_EXT);
const int64_t nwg = 32;
const int64_t ne01 = op->src[0]->ne[1];
const int64_t ne02 = op->src[0]->ne[2];
const int64_t ne03 = op->src[0]->ne[3];
const int64_t ne20 = op->src[2]->ne[0];
// temp buffer for writing the results from each workgroup
// - ne20: the size of the Value head
// - + 2: the S and M values for each intermediate result
return ggml_type_size(GGML_TYPE_F32)*(ne01*ne02*ne03*nwg*(ne20 + 2));
}
static id<MTLComputePipelineState> ggml_metal_get_pipeline_flash_attn_ext(
ggml_backend_t backend, struct ggml_tensor * op,
bool has_mask,
@ -1760,8 +1551,6 @@ static void ggml_metal_free(struct ggml_backend_metal_context * ctx) {
[ctx->cmd_bufs[i].obj release];
}
ggml_metal_mem_pool_free(ctx->cmd_bufs[i].mem_pool);
if (ctx->cmd_bufs[i].mem_ranges) {
ggml_mem_ranges_free(ctx->cmd_bufs[i].mem_ranges);
}
@ -2127,8 +1916,6 @@ struct ggml_metal_encode_context {
id<MTLComputeCommandEncoder> encoder;
struct ggml_metal_mem_pool * mem_pool;
struct ggml_mem_ranges * mem_ranges;
};
@ -2165,8 +1952,6 @@ static int ggml_metal_encode_node(struct ggml_metal_encode_context * ctx_enc, in
id<MTLComputeCommandEncoder> encoder = ctx_enc->encoder;
struct ggml_metal_mem_pool * mem_pool = ctx_enc->mem_pool;
struct ggml_backend_metal_context * ctx = backend->context;
struct ggml_backend_metal_device_context * ctx_dev = backend->device->context;
@ -2207,8 +1992,6 @@ static int ggml_metal_encode_node(struct ggml_metal_encode_context * ctx_enc, in
GGML_ABORT("unsupported op");
}
ggml_metal_mem_pool_clear(mem_pool);
const int64_t ne00 = src0 ? src0->ne[0] : 0;
const int64_t ne01 = src0 ? src0->ne[1] : 0;
const int64_t ne02 = src0 ? src0->ne[2] : 0;
@ -2522,7 +2305,6 @@ static int ggml_metal_encode_node(struct ggml_metal_encode_context * ctx_enc, in
/*.nb02 =*/ nb02,
/*.nb11 =*/ nb11,
/*.nb21 =*/ nb21,
};
[encoder setComputePipelineState:pipeline];
@ -3167,54 +2949,8 @@ static int ggml_metal_encode_node(struct ggml_metal_encode_context * ctx_enc, in
const float m0 = powf(2.0f, -(max_bias ) / n_head_log2);
const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);
// use this branch to test the ggml_metal_mem_pool functionality
#if 0
// cpy to tmp buffer in MTLHeap
id<MTLBuffer> h_src0 = h_src0 = ggml_metal_mem_pool_alloc(mem_pool, ggml_nbytes(src0));
if (!h_src0) {
GGML_LOG_ERROR("%s: failed to allocate buffer from memory pool, size = %zu\n", __func__, ggml_nbytes(src0));
return 0;
}
offs_src0 = 0;
ggml_metal_kargs_cpy args_cpy = {
/*.ne00 =*/ ne00,
/*.ne01 =*/ ne01,
/*.ne02 =*/ ne02,
/*.ne03 =*/ ne03,
/*.nb00 =*/ nb00,
/*.nb01 =*/ nb01,
/*.nb02 =*/ nb02,
/*.nb03 =*/ nb03,
/*.ne0 =*/ ne00,
/*.ne1 =*/ ne01,
/*.ne2 =*/ ne02,
/*.ne3 =*/ ne03,
/*.nb0 =*/ nb00,
/*.nb1 =*/ nb01,
/*.nb2 =*/ nb02,
/*.nb3 =*/ nb03,
};
if (src0->type == GGML_TYPE_F16) {
[encoder setComputePipelineState:ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F16_F16].pipeline];
} else {
[encoder setComputePipelineState:ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_F32].pipeline];
}
[encoder setBytes:&args_cpy length:sizeof(args_cpy) atIndex:0];
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:1];
[encoder setBuffer:h_src0 offset:0 atIndex:2];
GGML_ASSERT(ne00 % ggml_blck_size(src0->type) == 0);
int nth_cpy = MIN(1024, ne00 / ggml_blck_size(src0->type));
[encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth_cpy, 1, 1)];
#else
id<MTLBuffer> h_src0 = id_src0;
#endif
// softmax
ggml_metal_kargs_soft_max args = {
@ -4093,28 +3829,9 @@ static int ggml_metal_encode_node(struct ggml_metal_encode_context * ctx_enc, in
default: break;
}
// TODO: using mem pool allocations with enabled concurrency is not safe because the mem pool
// reuses buffers. this can result in 2 concurrent MUL_MAT_ID ops using the same mem pool buffer.
// so we add this extra barrier to prevent the race.
// the correct solution is to remove mem pools and then remove this barrier [TAG_MEM_POOL_REMOVE]
ggml_metal_encode_concurrency_reset(ctx_enc);
// tokens per expert
const size_t s_tpe = ggml_type_size(GGML_TYPE_I32)*ne02;
id<MTLBuffer> h_tpe = ggml_metal_mem_pool_alloc(mem_pool, s_tpe);
if (!h_tpe) {
GGML_LOG_ERROR("%s: failed to allocate buffer from memory pool, size = %zu\n", __func__, s_tpe);
return 0;
}
// id map
// [n_tokens, n_expert]
const size_t s_ids = ggml_type_size(GGML_TYPE_I32)*ne21*ne02;
id<MTLBuffer> h_ids = ggml_metal_mem_pool_alloc(mem_pool, s_ids);
if (!h_ids) {
GGML_LOG_ERROR("%s: failed to allocate buffer from memory pool, size = %zu\n", __func__, s_ids);
return 0;
}
// extra buffers for intermediate id mapping
size_t offs_tpe = offs_dst + ggml_nbytes(dst);
size_t offs_ids = offs_tpe + ggml_metal_mul_mat_id_extra_tpe(dst);
{
ggml_metal_kargs_mul_mm_id_map0 args = {
@ -4152,8 +3869,8 @@ static int ggml_metal_encode_node(struct ggml_metal_encode_context * ctx_enc, in
[encoder setComputePipelineState:pipeline];
[encoder setBytes:&args length:sizeof(args) atIndex:0];
[encoder setBuffer:id_src2 offset:offs_src2 atIndex:1];
[encoder setBuffer: h_tpe offset:0 atIndex:2];
[encoder setBuffer: h_ids offset:0 atIndex:3];
[encoder setBuffer:id_dst offset:offs_tpe atIndex:2];
[encoder setBuffer:id_dst offset:offs_ids atIndex:3];
[encoder setThreadgroupMemoryLength:smem atIndex:0];
[encoder dispatchThreadgroups:MTLSizeMake(1, 1, 1) threadsPerThreadgroup:MTLSizeMake(ne02, 1, 1)];
@ -4215,8 +3932,8 @@ static int ggml_metal_encode_node(struct ggml_metal_encode_context * ctx_enc, in
[encoder setBytes:&args length:sizeof(args) atIndex:0];
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:1];
[encoder setBuffer:id_src1 offset:offs_src1 atIndex:2];
[encoder setBuffer: h_tpe offset:0 atIndex:3];
[encoder setBuffer: h_ids offset:0 atIndex:4];
[encoder setBuffer:id_dst offset:offs_tpe atIndex:3];
[encoder setBuffer:id_dst offset:offs_ids atIndex:4];
[encoder setBuffer:id_dst offset:offs_dst atIndex:5];
[encoder setThreadgroupMemoryLength:8192 atIndex:0];
@ -5306,8 +5023,7 @@ static int ggml_metal_encode_node(struct ggml_metal_encode_context * ctx_enc, in
GGML_ASSERT(ne01 < 65536);
// use non-vec kernel if the batch size is large or if the vec-kernel is not supported for this head size
if (ne01 >= 20 || (ne00 % 32 != 0)) {
if (!ggml_metal_flash_attn_ext_use_vec(dst)) {
// half8x8 kernel
const int64_t nqptg = 8; // queries per threadgroup !! sync with kernel template arguments !!
const int64_t ncpsg = 64; // cache values per simdgroup !! sync with kernel template arguments !!
@ -5532,34 +5248,20 @@ static int ggml_metal_encode_node(struct ggml_metal_encode_context * ctx_enc, in
GGML_ASSERT(ne01*ne02*ne03 == ne1*ne2*ne3);
GGML_ASSERT(ne1*ne2*ne3 <= (1u << 31));
// using mem pool allocations with enabled concurrency is not safe [TAG_MEM_POOL_REMOVE]
// still, we assume that concurrent FA won't happen before we do the refactor
//ggml_metal_encode_concurrency_reset(ctx_enc);
const int32_t nrows = ne1*ne2*ne3;
// temp buffer for writing the results from each workgroup
// - ne20: the size of the head vector
// - + 2: the S and M values for each intermediate result
const size_t s_tmp = ggml_type_size(GGML_TYPE_F32)*(nrows*nwg*(ne20 + 2));
id<MTLBuffer> h_tmp = ggml_metal_mem_pool_alloc(mem_pool, s_tmp);
if (!h_tmp) {
GGML_LOG_ERROR("%s: failed to allocate buffer from memory pool, size = %zu\n", __func__, s_tmp);
return 0;
}
//printf("ne01 = %d, ne02 = %d, ne03 = %d, ne20 = %d\n", ne01, ne02, ne03, ne20);
//printf("needed memory: %.3f MiB\n", (float) (ne01*ne02*ne03*ne20*sizeof(float))/1024.0f/1024.0f);
[encoder setBuffer:h_tmp offset:0 atIndex:6];
// write the results from each workgroup into a temp buffer
const size_t offs_tmp = offs_dst + ggml_nbytes(dst);
[encoder setBuffer:id_dst offset:offs_tmp atIndex:6];
[encoder setThreadgroupMemoryLength:smem atIndex:0];
[encoder dispatchThreadgroups:MTLSizeMake((ne01 + nqptg - 1)/nqptg, ne02, ne03*nwg) threadsPerThreadgroup:MTLSizeMake(32, nsg, 1)];
// sync the 2 kernels
ggml_metal_encode_concurrency_reset(ctx_enc);
// reduce the results from the workgroups
{
const int32_t nrows = ne1*ne2*ne3;
ggml_metal_kargs_flash_attn_ext_vec_reduce args0 = {
nrows,
};
@ -5568,7 +5270,7 @@ static int ggml_metal_encode_node(struct ggml_metal_encode_context * ctx_enc, in
[encoder setComputePipelineState:pipeline0];
[encoder setBytes:&args0 length:sizeof(args0) atIndex:0];
[encoder setBuffer:h_tmp offset:0 atIndex:1];
[encoder setBuffer:id_dst offset:offs_tmp atIndex:1];
[encoder setBuffer:id_dst offset:offs_dst atIndex:2];
//printf("ne1 = %d, ne2 = %d, ne3 = %d, ne20 = %d\n", ne1, ne2, ne3, ne20);
@ -5895,12 +5597,7 @@ static enum ggml_status ggml_metal_graph_compute(
// the main thread commits the first few commands immediately
// cmd_buf[n_cb]
{
// cannot use commandBufferWithUnretainedReferences because the buffers from the memory pool can get destroyed
// TODO: when the memory pools are removed, we can again use commandBufferWithUnretainedReferences
// https://github.com/ggml-org/llama.cpp/pull/15832#discussion_r2334215009
// [TAG_MEM_POOL_REMOVE]
//id<MTLCommandBuffer> cmd_buf = [ctx->queue commandBufferWithUnretainedReferences];
id<MTLCommandBuffer> cmd_buf = [ctx->queue commandBuffer];
id<MTLCommandBuffer> cmd_buf = [ctx->queue commandBufferWithUnretainedReferences];
[cmd_buf retain];
if (ctx->cmd_bufs[n_cb].obj) {
@ -5919,8 +5616,7 @@ static enum ggml_status ggml_metal_graph_compute(
// prepare the rest of the command buffers asynchronously (optional)
// cmd_buf[0.. n_cb)
for (int cb_idx = 0; cb_idx < n_cb; ++cb_idx) {
//id<MTLCommandBuffer> cmd_buf = [ctx->queue commandBufferWithUnretainedReferences];
id<MTLCommandBuffer> cmd_buf = [ctx->queue commandBuffer];
id<MTLCommandBuffer> cmd_buf = [ctx->queue commandBufferWithUnretainedReferences];
[cmd_buf retain];
if (ctx->cmd_bufs[cb_idx].obj) {
@ -6377,6 +6073,31 @@ static ggml_backend_buffer_t ggml_backend_metal_buffer_type_alloc_buffer(ggml_ba
return ggml_backend_buffer_init(buft, buf_i, ctx, size);
}
static size_t ggml_backend_metal_buffer_type_get_alloc_size(ggml_backend_buffer_type_t buft, const struct ggml_tensor * tensor) {
size_t res = ggml_nbytes(tensor);
// some operations require additional memory for fleeting data:
switch (tensor->op) {
case GGML_OP_MUL_MAT_ID:
{
res += ggml_metal_mul_mat_id_extra_tpe(tensor);
res += ggml_metal_mul_mat_id_extra_ids(tensor);
} break;
case GGML_OP_FLASH_ATTN_EXT:
{
if (ggml_metal_flash_attn_ext_use_vec(tensor)) {
res += ggml_metal_flash_attn_ext_extra_tmp(tensor);
}
} break;
default:
break;
}
return res;
GGML_UNUSED(buft);
}
// default (shared) buffer type
static const char * ggml_backend_metal_buffer_type_shared_get_name(ggml_backend_buffer_type_t buft) {
@ -6401,6 +6122,10 @@ static size_t ggml_backend_metal_buffer_type_shared_get_max_size(ggml_backend_bu
return max_size;
}
static size_t ggml_backend_metal_buffer_type_shared_get_alloc_size(ggml_backend_buffer_type_t buft, const struct ggml_tensor * tensor) {
return ggml_backend_metal_buffer_type_get_alloc_size(buft, tensor);
}
static bool ggml_backend_metal_buffer_type_shared_is_host(ggml_backend_buffer_type_t buft) {
return false;
@ -6414,7 +6139,7 @@ static ggml_backend_buffer_type_t ggml_backend_metal_buffer_type_shared(void) {
/* .alloc_buffer = */ ggml_backend_metal_buffer_type_shared_alloc_buffer,
/* .get_alignment = */ ggml_backend_metal_buffer_type_shared_get_alignment,
/* .get_max_size = */ ggml_backend_metal_buffer_type_shared_get_max_size,
/* .get_alloc_size = */ NULL, // defaults to ggml_nbytes
/* .get_alloc_size = */ ggml_backend_metal_buffer_type_shared_get_alloc_size,
/* .is_host = */ ggml_backend_metal_buffer_type_shared_is_host,
},
/* .device = */ &g_ggml_backend_metal_device,
@ -6448,6 +6173,10 @@ static size_t ggml_backend_metal_buffer_type_private_get_max_size(ggml_backend_b
return max_size;
}
static size_t ggml_backend_metal_buffer_type_private_get_alloc_size(ggml_backend_buffer_type_t buft, const struct ggml_tensor * tensor) {
return ggml_backend_metal_buffer_type_get_alloc_size(buft, tensor);
}
static bool ggml_backend_metal_buffer_type_private_is_host(ggml_backend_buffer_type_t buft) {
return false;
@ -6461,7 +6190,7 @@ static ggml_backend_buffer_type_t ggml_backend_metal_buffer_type_private(void) {
/* .alloc_buffer = */ ggml_backend_metal_buffer_type_private_alloc_buffer,
/* .get_alignment = */ ggml_backend_metal_buffer_type_private_get_alignment,
/* .get_max_size = */ ggml_backend_metal_buffer_type_private_get_max_size,
/* .get_alloc_size = */ NULL, // defaults to ggml_nbytes
/* .get_alloc_size = */ ggml_backend_metal_buffer_type_private_get_alloc_size,
/* .is_host = */ ggml_backend_metal_buffer_type_private_is_host,
},
/* .device = */ &g_ggml_backend_metal_device,
@ -6496,6 +6225,10 @@ static size_t ggml_backend_metal_buffer_type_mapped_get_max_size(ggml_backend_bu
return max_size;
}
static size_t ggml_backend_metal_buffer_type_mapped_get_alloc_size(ggml_backend_buffer_type_t buft, const struct ggml_tensor * tensor) {
return ggml_backend_metal_buffer_type_get_alloc_size(buft, tensor);
}
static bool ggml_backend_metal_buffer_type_mapped_is_host(ggml_backend_buffer_type_t buft) {
return false;
@ -6511,7 +6244,7 @@ static ggml_backend_buffer_type_t ggml_backend_metal_buffer_type_mapped(void) {
/* .alloc_buffer = */ ggml_backend_metal_buffer_type_mapped_alloc_buffer,
/* .get_alignment = */ ggml_backend_metal_buffer_type_mapped_get_alignment,
/* .get_max_size = */ ggml_backend_metal_buffer_type_mapped_get_max_size,
/* .get_alloc_size = */ NULL, // defaults to ggml_nbytes
/* .get_alloc_size = */ ggml_backend_metal_buffer_type_mapped_get_alloc_size,
/* .is_host = */ ggml_backend_metal_buffer_type_mapped_is_host,
},
/* .device = */ &g_ggml_backend_metal_device,
@ -6711,11 +6444,8 @@ static void ggml_backend_metal_set_n_cb(ggml_backend_t backend, int n_cb) {
const int n_nodes_per_cb = ctx->n_nodes_per_cb;
id<MTLCommandBuffer> cmd_buf = ctx->cmd_bufs[cb_idx].obj;
struct ggml_metal_mem_pool * mem_pool = ctx->cmd_bufs[cb_idx].mem_pool;
struct ggml_mem_ranges * mem_ranges = ctx->cmd_bufs[cb_idx].mem_ranges;
ggml_metal_mem_pool_reset(mem_pool);
if (mem_ranges) {
ggml_mem_ranges_reset(mem_ranges);
}
@ -6743,7 +6473,6 @@ static void ggml_backend_metal_set_n_cb(ggml_backend_t backend, int n_cb) {
struct ggml_metal_encode_context ctx_enc = {
/*.backend =*/ backend,
/*.encoder =*/ encoder,
/*.mem_pool =*/ mem_pool,
/*.mem_ranges =*/ mem_ranges,
};

View File

@ -303,6 +303,10 @@ inline void ggml_sycl_op_sub(ggml_backend_sycl_context & ctx, ggml_tensor *dst)
ggml_sycl_op_bin_bcast<bin_bcast_sycl<op_sub>>(ctx, dst->src[0], dst->src[1], dst);
}
inline void ggml_sycl_op_count_equal(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
ggml_sycl_op_bin_bcast<bin_bcast_sycl<op_count_equal>>(ctx, dst->src[0], dst->src[1], dst);
}
inline void ggml_sycl_op_mul(ggml_backend_sycl_context & ctx, ggml_tensor *dst) {
ggml_sycl_op_bin_bcast<bin_bcast_sycl<op_mul>>(ctx, dst->src[0], dst->src[1], dst);
@ -328,6 +332,11 @@ void ggml_sycl_sub(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
ggml_sycl_op_sub(ctx, dst);
}
void ggml_sycl_count_equal(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/2);
ggml_sycl_op_count_equal(ctx, dst);
}
void ggml_sycl_mul(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/2);
ggml_sycl_op_mul(ctx, dst);

View File

@ -16,6 +16,12 @@ static __dpct_inline__ float op_sub(const float a, const float b) {
return a - b;
}
static __dpct_inline__ float op_count_equal(const float a, const float b) {
return (a == b) ? 1.0f : 0.0f;
}
void ggml_sycl_count_equal(ggml_backend_sycl_context & ctx, ggml_tensor * dst);
static __dpct_inline__ float op_mul(const float a, const float b) {
return a * b;
}

View File

@ -3577,6 +3577,9 @@ static bool ggml_sycl_compute_forward(ggml_backend_sycl_context & ctx, struct gg
case GGML_OP_SUB:
ggml_sycl_sub(ctx, dst);
break;
case GGML_OP_COUNT_EQUAL:
ggml_sycl_count_equal(ctx, dst);
break;
case GGML_OP_ACC:
ggml_sycl_acc(ctx, dst);
break;
@ -4356,6 +4359,7 @@ static bool ggml_backend_sycl_device_supports_op(ggml_backend_dev_t dev, const g
case GGML_OP_ADD:
case GGML_OP_ADD1:
case GGML_OP_SUB:
case GGML_OP_COUNT_EQUAL:
case GGML_OP_MUL:
case GGML_OP_DIV:
case GGML_OP_REPEAT:

View File

@ -111,6 +111,7 @@ class Keys:
DECODER_START_TOKEN_ID = "{arch}.decoder_start_token_id"
DECODER_BLOCK_COUNT = "{arch}.decoder_block_count"
ATTN_LOGIT_SOFTCAPPING = "{arch}.attn_logit_softcapping"
ROUTER_LOGIT_SOFTCAPPING = "{arch}.router_logit_softcapping"
FINAL_LOGIT_SOFTCAPPING = "{arch}.final_logit_softcapping"
SWIN_NORM = "{arch}.swin_norm"
RESCALE_EVERY_N_LAYERS = "{arch}.rescale_every_n_layers"
@ -146,22 +147,28 @@ class Keys:
REL_BUCKETS_COUNT = "{arch}.attention.relative_buckets_count"
SLIDING_WINDOW = "{arch}.attention.sliding_window"
SCALE = "{arch}.attention.scale"
OUTPUT_SCALE = "{arch}.attention.output_scale"
TEMPERATURE_LENGTH = "{arch}.attention.temperature_length"
KEY_LENGTH_MLA = "{arch}.attention.key_length_mla"
VALUE_LENGTH_MLA = "{arch}.attention.value_length_mla"
SHARED_KV_LAYERS = "{arch}.attention.shared_kv_layers"
SLIDING_WINDOW_PATTERN = "{arch}.attention.sliding_window_pattern"
class Rope:
DIMENSION_COUNT = "{arch}.rope.dimension_count"
DIMENSION_SECTIONS = "{arch}.rope.dimension_sections"
FREQ_BASE = "{arch}.rope.freq_base"
FREQ_BASE_SWA = "{arch}.rope.freq_base_swa"
SCALING_TYPE = "{arch}.rope.scaling.type"
SCALING_FACTOR = "{arch}.rope.scaling.factor"
SCALING_ATTN_FACTOR = "{arch}.rope.scaling.attn_factor"
SCALING_ORIG_CTX_LEN = "{arch}.rope.scaling.original_context_length"
SCALING_FINETUNED = "{arch}.rope.scaling.finetuned"
SCALING_YARN_LOG_MUL = "{arch}.rope.scaling.yarn_log_multiplier"
DIMENSION_COUNT = "{arch}.rope.dimension_count"
DIMENSION_SECTIONS = "{arch}.rope.dimension_sections"
FREQ_BASE = "{arch}.rope.freq_base"
FREQ_BASE_SWA = "{arch}.rope.freq_base_swa"
SCALING_TYPE = "{arch}.rope.scaling.type"
SCALING_FACTOR = "{arch}.rope.scaling.factor"
SCALING_ATTN_FACTOR = "{arch}.rope.scaling.attn_factor"
SCALING_ORIG_CTX_LEN = "{arch}.rope.scaling.original_context_length"
SCALING_FINETUNED = "{arch}.rope.scaling.finetuned"
SCALING_YARN_LOG_MUL = "{arch}.rope.scaling.yarn_log_multiplier"
SCALING_YARN_EXT_FACTOR = "{arch}.rope.scaling.yarn_ext_factor"
SCALING_YARN_ATTN_FACTOR = "{arch}.rope.scaling.yarn_attn_factor"
SCALING_YARN_BETA_FAST = "{arch}.rope.scaling.yarn_beta_fast"
SCALING_YARN_BETA_SLOW = "{arch}.rope.scaling.yarn_beta_slow"
class Split:
LLM_KV_SPLIT_NO = "split.no"
@ -1117,6 +1124,7 @@ MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = {
MODEL_TENSOR.FFN_GATE_EXP,
MODEL_TENSOR.FFN_DOWN_EXP,
MODEL_TENSOR.FFN_UP_EXP,
MODEL_TENSOR.FFN_POST_NORM,
MODEL_TENSOR.LAYER_OUT_NORM,
],
MODEL_ARCH.GPTNEOX: [

View File

@ -733,6 +733,9 @@ class GGUFWriter:
def add_attn_logit_softcapping(self, value: float) -> None:
self.add_float32(Keys.LLM.ATTN_LOGIT_SOFTCAPPING.format(arch=self.arch), value)
def add_router_logit_softcapping(self, value: float) -> None:
self.add_float32(Keys.LLM.ROUTER_LOGIT_SOFTCAPPING.format(arch=self.arch), value)
def add_final_logit_softcapping(self, value: float) -> None:
self.add_float32(Keys.LLM.FINAL_LOGIT_SOFTCAPPING.format(arch=self.arch), value)
@ -832,6 +835,12 @@ class GGUFWriter:
def add_attention_scale(self, value: float) -> None:
self.add_float32(Keys.Attention.SCALE.format(arch=self.arch), value)
def add_attn_output_scale(self, value: float) -> None:
self.add_float32(Keys.Attention.OUTPUT_SCALE.format(arch=self.arch), value)
def add_attn_temperature_length(self, value: int) -> None:
self.add_uint32(Keys.Attention.TEMPERATURE_LENGTH.format(arch=self.arch), value)
def add_pooling_type(self, value: PoolingType) -> None:
self.add_uint32(Keys.LLM.POOLING_TYPE.format(arch=self.arch), value.value)
@ -862,6 +871,18 @@ class GGUFWriter:
def add_rope_scaling_yarn_log_mul(self, value: float) -> None:
self.add_float32(Keys.Rope.SCALING_YARN_LOG_MUL.format(arch=self.arch), value)
def add_rope_scaling_yarn_ext_factor(self, value: float) -> None:
self.add_float32(Keys.Rope.SCALING_YARN_EXT_FACTOR.format(arch=self.arch), value)
def add_rope_scaling_yarn_attn_factor(self, value: float) -> None:
self.add_float32(Keys.Rope.SCALING_YARN_ATTN_FACTOR.format(arch=self.arch), value)
def add_rope_scaling_yarn_beta_fast(self, value: float) -> None:
self.add_float32(Keys.Rope.SCALING_YARN_BETA_FAST.format(arch=self.arch), value)
def add_rope_scaling_yarn_beta_slow(self, value: float) -> None:
self.add_float32(Keys.Rope.SCALING_YARN_BETA_SLOW.format(arch=self.arch), value)
def add_ssm_conv_kernel(self, value: int) -> None:
self.add_uint32(Keys.SSM.CONV_KERNEL.format(arch=self.arch), value)

View File

@ -139,6 +139,7 @@ class TensorNameMap:
"model.layers.{bid}.norm", # mamba-qbert
"backbone.layers.{bid}.norm", # mamba
"transformer.decoder_layer.{bid}.rms_norm", # Grok
"model.layers.{bid}.pre_attn_norm", # grok-2
"transformer.blocks.{bid}.norm_attn_norm.norm_1", # dbrx
"encoder.layers.{bid}.input_layernorm", # chatglm
"transformer.layers.{bid}.attn_norm", # openelm
@ -284,6 +285,7 @@ class TensorNameMap:
"transformer.layer.{bid}.sa_layer_norm", # distillbert
"encoder.layers.{bid}.norm1", # nomic-bert
"transformer.decoder_layer.{bid}.rms_norm_1", # Grok
"model.layers.{bid}.post_attn_norm", # grok-2
"transformer.blocks.{bid}.norm_attn_norm.norm_2", # dbrx
),
@ -319,6 +321,7 @@ class TensorNameMap:
"h.{bid}.ln_2", # gpt2
"model.layers.{bid}.ffn_norm", # internlm2
"transformer.decoder_layer.{bid}.rms_norm_2", # Grok
"model.layers.{bid}.pre_moe_norm", # grok-2
"encoder.layers.{bid}.post_attention_layernorm", # chatglm
"transformer.layers.{bid}.ffn_norm", # openelm
"model.layers.{bid}.pre_ff_layernorm", # jamba granite-hybrid
@ -340,11 +343,12 @@ class TensorNameMap:
# Post feed-forward norm
MODEL_TENSOR.FFN_POST_NORM: (
"model.layers.{bid}.post_feedforward_layernorm", # gemma2 olmo2
"layers.{bid}.post_feedforward_layernorm", # embeddinggemma
"model.layers.{bid}.post_mlp_layernorm", # glm-4-0414
"model.layers.{bid}.post_feedforward_layernorm", # gemma2 olmo2
"layers.{bid}.post_feedforward_layernorm", # embeddinggemma
"model.layers.{bid}.post_mlp_layernorm", # glm-4-0414
"model.layers.layers.{bid}.post_mlp_norm.weight", # plamo2
"model.layers.{bid}.feed_forward.up_proj",
"model.layers.{bid}.post_moe_norm", # grok-2
),
MODEL_TENSOR.FFN_GATE_INP: (

View File

@ -140,6 +140,7 @@ static const std::map<llm_kv, const char *> LLM_KV_NAMES = {
{ LLM_KV_DECODER_START_TOKEN_ID, "%s.decoder_start_token_id" },
{ LLM_KV_DECODER_BLOCK_COUNT, "%s.decoder_block_count" },
{ LLM_KV_ATTN_LOGIT_SOFTCAPPING, "%s.attn_logit_softcapping" },
{ LLM_KV_ROUTER_LOGIT_SOFTCAPPING, "%s.router_logit_softcapping" },
{ LLM_KV_FINAL_LOGIT_SOFTCAPPING, "%s.final_logit_softcapping" },
{ LLM_KV_SWIN_NORM, "%s.swin_norm" },
{ LLM_KV_RESCALE_EVERY_N_LAYERS, "%s.rescale_every_n_layers" },
@ -170,20 +171,27 @@ static const std::map<llm_kv, const char *> LLM_KV_NAMES = {
{ LLM_KV_ATTENTION_RELATIVE_BUCKETS_COUNT, "%s.attention.relative_buckets_count" },
{ LLM_KV_ATTENTION_SLIDING_WINDOW, "%s.attention.sliding_window" },
{ LLM_KV_ATTENTION_SCALE, "%s.attention.scale" },
{ LLM_KV_ATTENTION_OUTPUT_SCALE, "%s.attention.output_scale" },
{ LLM_KV_ATTENTION_TEMPERATURE_LENGTH, "%s.attention.temperature_length" },
{ LLM_KV_ATTENTION_KEY_LENGTH_MLA, "%s.attention.key_length_mla" },
{ LLM_KV_ATTENTION_VALUE_LENGTH_MLA, "%s.attention.value_length_mla" },
{ LLM_KV_ROPE_DIMENSION_COUNT, "%s.rope.dimension_count" },
{ LLM_KV_ROPE_DIMENSION_SECTIONS, "%s.rope.dimension_sections" },
{ LLM_KV_ROPE_FREQ_BASE, "%s.rope.freq_base" },
{ LLM_KV_ROPE_FREQ_BASE_SWA, "%s.rope.freq_base_swa" },
{ LLM_KV_ROPE_SCALE_LINEAR, "%s.rope.scale_linear" },
{ LLM_KV_ROPE_SCALING_TYPE, "%s.rope.scaling.type" },
{ LLM_KV_ROPE_SCALING_FACTOR, "%s.rope.scaling.factor" },
{ LLM_KV_ROPE_SCALING_ATTN_FACTOR, "%s.rope.scaling.attn_factor" },
{ LLM_KV_ROPE_SCALING_ORIG_CTX_LEN, "%s.rope.scaling.original_context_length" },
{ LLM_KV_ROPE_SCALING_FINETUNED, "%s.rope.scaling.finetuned" },
{ LLM_KV_ROPE_SCALING_YARN_LOG_MUL, "%s.rope.scaling.yarn_log_multiplier" },
{ LLM_KV_ROPE_DIMENSION_COUNT, "%s.rope.dimension_count" },
{ LLM_KV_ROPE_DIMENSION_SECTIONS, "%s.rope.dimension_sections" },
{ LLM_KV_ROPE_FREQ_BASE, "%s.rope.freq_base" },
{ LLM_KV_ROPE_FREQ_BASE_SWA, "%s.rope.freq_base_swa" },
{ LLM_KV_ROPE_SCALE_LINEAR, "%s.rope.scale_linear" },
{ LLM_KV_ROPE_SCALING_TYPE, "%s.rope.scaling.type" },
{ LLM_KV_ROPE_SCALING_FACTOR, "%s.rope.scaling.factor" },
{ LLM_KV_ROPE_SCALING_ATTN_FACTOR, "%s.rope.scaling.attn_factor" },
{ LLM_KV_ROPE_SCALING_ORIG_CTX_LEN, "%s.rope.scaling.original_context_length" },
{ LLM_KV_ROPE_SCALING_FINETUNED, "%s.rope.scaling.finetuned" },
{ LLM_KV_ROPE_SCALING_YARN_LOG_MUL, "%s.rope.scaling.yarn_log_multiplier" },
{ LLM_KV_ROPE_SCALING_YARN_EXT_FACTOR, "%s.rope.scaling.yarn_ext_factor" },
{ LLM_KV_ROPE_SCALING_YARN_ATTN_FACTOR, "%s.rope.scaling.yarn_attn_factor" },
{ LLM_KV_ROPE_SCALING_YARN_BETA_FAST, "%s.rope.scaling.yarn_beta_fast" },
{ LLM_KV_ROPE_SCALING_YARN_BETA_SLOW, "%s.rope.scaling.yarn_beta_slow" },
{ LLM_KV_SPLIT_NO, "split.no" },
{ LLM_KV_SPLIT_COUNT, "split.count" },
@ -400,12 +408,16 @@ static const std::map<llm_arch, std::map<llm_tensor, const char *>> LLM_TENSOR_N
{ LLM_TENSOR_ATTN_ROT_EMBD, "blk.%d.attn_rot_embd" },
{ LLM_TENSOR_FFN_GATE_INP, "blk.%d.ffn_gate_inp" },
{ LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" },
{ LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" },
{ LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" },
{ LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" },
{ LLM_TENSOR_FFN_GATE_EXP, "blk.%d.ffn_gate.%d" },
{ LLM_TENSOR_FFN_DOWN_EXP, "blk.%d.ffn_down.%d" },
{ LLM_TENSOR_FFN_UP_EXP, "blk.%d.ffn_up.%d" },
{ LLM_TENSOR_FFN_GATE_EXPS, "blk.%d.ffn_gate_exps" },
{ LLM_TENSOR_FFN_DOWN_EXPS, "blk.%d.ffn_down_exps" },
{ LLM_TENSOR_FFN_UP_EXPS, "blk.%d.ffn_up_exps" },
{ LLM_TENSOR_FFN_POST_NORM, "blk.%d.post_ffw_norm" },
{ LLM_TENSOR_LAYER_OUT_NORM, "blk.%d.layer_output_norm" },
{ LLM_TENSOR_ATTN_OUT_NORM, "blk.%d.attn_output_norm" },
},

View File

@ -144,6 +144,7 @@ enum llm_kv {
LLM_KV_DECODER_START_TOKEN_ID,
LLM_KV_DECODER_BLOCK_COUNT,
LLM_KV_ATTN_LOGIT_SOFTCAPPING,
LLM_KV_ROUTER_LOGIT_SOFTCAPPING,
LLM_KV_FINAL_LOGIT_SOFTCAPPING,
LLM_KV_SWIN_NORM,
LLM_KV_RESCALE_EVERY_N_LAYERS,
@ -174,6 +175,8 @@ enum llm_kv {
LLM_KV_ATTENTION_RELATIVE_BUCKETS_COUNT,
LLM_KV_ATTENTION_SLIDING_WINDOW,
LLM_KV_ATTENTION_SCALE,
LLM_KV_ATTENTION_OUTPUT_SCALE,
LLM_KV_ATTENTION_TEMPERATURE_LENGTH,
LLM_KV_ATTENTION_KEY_LENGTH_MLA,
LLM_KV_ATTENTION_VALUE_LENGTH_MLA,
@ -188,6 +191,10 @@ enum llm_kv {
LLM_KV_ROPE_SCALING_ORIG_CTX_LEN,
LLM_KV_ROPE_SCALING_FINETUNED,
LLM_KV_ROPE_SCALING_YARN_LOG_MUL,
LLM_KV_ROPE_SCALING_YARN_EXT_FACTOR,
LLM_KV_ROPE_SCALING_YARN_ATTN_FACTOR,
LLM_KV_ROPE_SCALING_YARN_BETA_FAST,
LLM_KV_ROPE_SCALING_YARN_BETA_SLOW,
LLM_KV_SPLIT_NO,
LLM_KV_SPLIT_COUNT,

View File

@ -70,6 +70,7 @@ static const std::map<std::string, llm_chat_template> LLM_CHAT_TEMPLATES = {
{ "hunyuan-dense", LLM_CHAT_TEMPLATE_HUNYUAN_DENSE },
{ "kimi-k2", LLM_CHAT_TEMPLATE_KIMI_K2 },
{ "seed_oss", LLM_CHAT_TEMPLATE_SEED_OSS },
{ "grok-2", LLM_CHAT_TEMPLATE_GROK_2 },
};
llm_chat_template llm_chat_template_from_str(const std::string & name) {
@ -204,6 +205,8 @@ llm_chat_template llm_chat_detect_template(const std::string & tmpl) {
return LLM_CHAT_TEMPLATE_KIMI_K2;
} else if (tmpl_contains("<seed:bos>")) {
return LLM_CHAT_TEMPLATE_SEED_OSS;
} else if (tmpl_contains("'Assistant: ' + message['content'] + '<|separator|>")) {
return LLM_CHAT_TEMPLATE_GROK_2;
}
return LLM_CHAT_TEMPLATE_UNKNOWN;
}
@ -763,6 +766,20 @@ int32_t llm_chat_apply_template(
if (add_ass) {
ss << "<seed:bos>assistant\n";
}
} else if (tmpl == LLM_CHAT_TEMPLATE_GROK_2) {
for (auto message : chat) {
std::string role(message->role);
if (role == "system") {
ss << "System: " << trim(message->content) << "<|separator|>\n\n";
} else if (role == "user") {
ss << "Human: " << trim(message->content) << "<|separator|>\n\n";
} else if (role == "assistant") {
ss << "Assistant: " << message->content << "<|separator|>\n\n";
}
}
if (add_ass) {
ss << "Assistant:";
}
} else {
// template not supported
return -1;

View File

@ -50,6 +50,7 @@ enum llm_chat_template {
LLM_CHAT_TEMPLATE_HUNYUAN_DENSE,
LLM_CHAT_TEMPLATE_KIMI_K2,
LLM_CHAT_TEMPLATE_SEED_OSS,
LLM_CHAT_TEMPLATE_GROK_2,
LLM_CHAT_TEMPLATE_UNKNOWN,
};

View File

@ -35,10 +35,10 @@ llama_context::llama_context(
cparams.n_threads = params.n_threads;
cparams.n_threads_batch = params.n_threads_batch;
cparams.yarn_ext_factor = params.yarn_ext_factor;
cparams.yarn_attn_factor = params.yarn_attn_factor;
cparams.yarn_beta_fast = params.yarn_beta_fast;
cparams.yarn_beta_slow = params.yarn_beta_slow;
cparams.yarn_ext_factor = params.yarn_ext_factor >= 0.0f ? params.yarn_ext_factor : hparams.yarn_ext_factor;
cparams.yarn_attn_factor = params.yarn_attn_factor >= 0.0f ? params.yarn_attn_factor : hparams.yarn_attn_factor;
cparams.yarn_beta_fast = params.yarn_beta_fast >= 0.0f ? params.yarn_beta_fast : hparams.yarn_beta_fast;
cparams.yarn_beta_slow = params.yarn_beta_slow >= 0.0f ? params.yarn_beta_slow : hparams.yarn_beta_slow;
cparams.embeddings = params.embeddings;
cparams.offload_kqv = params.offload_kqv;
cparams.no_perf = params.no_perf;
@ -2263,9 +2263,9 @@ llama_context_params llama_context_default_params() {
/*.rope_freq_base =*/ 0.0f,
/*.rope_freq_scale =*/ 0.0f,
/*.yarn_ext_factor =*/ -1.0f,
/*.yarn_attn_factor =*/ 1.0f,
/*.yarn_beta_fast =*/ 32.0f,
/*.yarn_beta_slow =*/ 1.0f,
/*.yarn_attn_factor =*/ -1.0f,
/*.yarn_beta_fast =*/ -1.0f,
/*.yarn_beta_slow =*/ -1.0f,
/*.yarn_orig_ctx =*/ 0,
/*.defrag_thold =*/ -1.0f,
/*.cb_eval =*/ nullptr,

View File

@ -1336,14 +1336,14 @@ ggml_tensor * llm_graph_context::build_attn_mha(
if (arch == LLM_ARCH_GROK) {
// need to do the following:
// multiply by attn_output_multiplyer of 0.08838834764831845
// multiply by attn_output_multiplier
// and then :
// kq = 30 * tanh(kq / 30)
// before the softmax below
kq = ggml_tanh(ctx0, ggml_scale(ctx0, kq, 0.08838834764831845f/30.0f));
kq = ggml_tanh(ctx0, ggml_scale(ctx0, kq, hparams.f_attn_out_scale / hparams.f_attn_logit_softcapping));
cb(kq, "kq_tanh", il);
kq = ggml_scale(ctx0, kq, 30);
kq = ggml_scale(ctx0, kq, hparams.f_attn_logit_softcapping);
cb(kq, "kq_scaled", il);
}

View File

@ -83,8 +83,9 @@ struct llama_hparams {
float f_norm_rms_eps;
float f_norm_group_eps;
float f_attn_logit_softcapping = 50.0f;
float f_final_logit_softcapping = 30.0f;
float f_attn_logit_softcapping = 50.0f;
float f_router_logit_softcapping = 30.0f;
float f_final_logit_softcapping = 30.0f;
// for RWKV
uint32_t rescale_every_n_layers = 0;
@ -105,6 +106,11 @@ struct llama_hparams {
uint32_t n_ctx_orig_yarn;
float rope_yarn_log_mul = 0.0f;
float yarn_ext_factor = -1.0f;
float yarn_attn_factor = 1.0f;
float yarn_beta_fast = 32.0f;
float yarn_beta_slow = 1.0f;
std::array<int, 4> rope_sections;
// Sliding Window Attention (SWA)
@ -137,6 +143,10 @@ struct llama_hparams {
float f_embedding_scale = 0.0f;
float f_attention_scale = 0.0f;
// grok-2
float f_attn_out_scale = 0.0f;
uint32_t attn_temp_length = 0;
bool causal_attn = true;
bool use_alibi = false;
bool attn_soft_cap = false;

View File

@ -685,7 +685,30 @@ void llama_model::load_hparams(llama_model_loader & ml) {
} break;
case LLM_ARCH_GROK:
{
ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
// defaults for old GGUFs
hparams.yarn_beta_fast = 8.0f;
hparams.f_logit_scale = 0.5773502691896257f;
hparams.f_embedding_scale = 78.38367176906169f;
hparams.f_attn_out_scale = 0.08838834764831845f;
hparams.f_attn_logit_softcapping = 30.0f;
hparams.f_router_logit_softcapping = 30.0f;
// no final_logit_softcapping in grok-1
hparams.f_final_logit_softcapping = 0.0f;
ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
ml.get_key(LLM_KV_EXPERT_FEED_FORWARD_LENGTH, hparams.n_ff_exp, false);
ml.get_key(LLM_KV_LOGIT_SCALE, hparams.f_logit_scale, false);
ml.get_key(LLM_KV_EMBEDDING_SCALE, hparams.f_embedding_scale, false);
ml.get_key(LLM_KV_ATTENTION_OUTPUT_SCALE, hparams.f_attn_out_scale, false);
ml.get_key(LLM_KV_ATTN_LOGIT_SOFTCAPPING, hparams.f_attn_logit_softcapping, false);
ml.get_key(LLM_KV_ROUTER_LOGIT_SOFTCAPPING, hparams.f_router_logit_softcapping, false);
ml.get_key(LLM_KV_FINAL_LOGIT_SOFTCAPPING, hparams.f_final_logit_softcapping, false);
ml.get_key(LLM_KV_ATTENTION_TEMPERATURE_LENGTH, hparams.attn_temp_length, false);
ml.get_key(LLM_KV_ROPE_SCALING_YARN_EXT_FACTOR, hparams.yarn_ext_factor, false);
ml.get_key(LLM_KV_ROPE_SCALING_YARN_ATTN_FACTOR, hparams.yarn_attn_factor, false);
ml.get_key(LLM_KV_ROPE_SCALING_YARN_BETA_FAST, hparams.yarn_beta_fast, false);
ml.get_key(LLM_KV_ROPE_SCALING_YARN_BETA_SLOW, hparams.yarn_beta_slow, false);
switch (hparams.n_layer) {
case 64: type = LLM_TYPE_314B; break;
@ -2561,6 +2584,7 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED);
}
const int64_t n_ff_exp = hparams.n_ff_exp ? hparams.n_ff_exp : n_ff/* / n_expert_used*/; // grok-1 n_ff_exp == n_ff
for (int i = 0; i < n_layer; ++i) {
auto & layer = layers[i];
@ -2575,12 +2599,19 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0);
layer.ffn_gate_inp = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), {n_embd, n_expert}, 0);
layer.ffn_gate_exps = create_tensor(tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i), {n_embd, n_ff, n_expert}, TENSOR_NOT_REQUIRED);
layer.ffn_down_exps = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), { n_ff, n_embd, n_expert}, 0);
layer.ffn_up_exps = create_tensor(tn(LLM_TENSOR_FFN_UP_EXPS, "weight", i), {n_embd, n_ff, n_expert}, 0);
layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, TENSOR_NOT_REQUIRED);
layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff, n_embd}, TENSOR_NOT_REQUIRED);
layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, TENSOR_NOT_REQUIRED);
layer.layer_out_norm = create_tensor(tn(LLM_TENSOR_LAYER_OUT_NORM, "weight", i), {n_embd}, 0);
layer.ffn_gate_inp = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), {n_embd, n_expert}, 0);
layer.ffn_gate_exps = create_tensor(tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i), {n_embd, n_ff_exp, n_expert}, TENSOR_NOT_REQUIRED);
layer.ffn_down_exps = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), {n_ff_exp, n_embd, n_expert}, 0);
layer.ffn_up_exps = create_tensor(tn(LLM_TENSOR_FFN_UP_EXPS, "weight", i), {n_embd, n_ff_exp, n_expert}, 0);
layer.ffn_post_norm = create_tensor(tn(LLM_TENSOR_LAYER_OUT_NORM, "weight", i), {n_embd}, TENSOR_NOT_REQUIRED);
if (!layer.ffn_post_norm) {
layer.ffn_post_norm = create_tensor(tn(LLM_TENSOR_FFN_POST_NORM, "weight", i), {n_embd}, 0);
}
}
} break;
case LLM_ARCH_DBRX:
@ -7082,9 +7113,6 @@ struct llm_build_grok : public llm_graph_context {
inpL = build_inp_embd(model.tok_embd);
// multiply by embedding_multiplier_scale of 78.38367176906169
inpL = ggml_scale(ctx0, inpL, 78.38367176906169f);
// inp_pos - contains the positions
ggml_tensor * inp_pos = build_inp_pos();
@ -7156,26 +7184,22 @@ struct llm_build_grok : public llm_graph_context {
inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
}
// Grok
// if attn_out_norm is present then apply it before adding the input
if (model.layers[il].attn_out_norm) {
cur = build_norm(cur,
model.layers[il].attn_out_norm, NULL,
LLM_NORM_RMS, il);
cb(cur, "attn_out_norm", il);
}
cur = build_norm(cur,
model.layers[il].attn_out_norm, NULL,
LLM_NORM_RMS, il);
cb(cur, "attn_out_norm", il);
ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA);
cb(ffn_inp, "ffn_inp", il);
// feed-forward network
// MoE branch
cur = build_norm(ffn_inp,
model.layers[il].ffn_norm, NULL,
LLM_NORM_RMS, il);
cb(cur, "ffn_norm", il);
cur = build_moe_ffn(cur,
// MoE branch
ggml_tensor * moe_out = build_moe_ffn(cur,
model.layers[il].ffn_gate_inp,
model.layers[il].ffn_up_exps,
model.layers[il].ffn_gate_exps,
@ -7186,18 +7210,28 @@ struct llm_build_grok : public llm_graph_context {
false, 0.0,
LLAMA_EXPERT_GATING_FUNC_TYPE_SOFTMAX,
il);
cb(cur, "ffn_moe_out", il);
cb(moe_out, "ffn_moe_out", il);
// Grok
// if layer_out_norm is present then apply it before adding the input
// Idea: maybe ffn_out_norm is a better name
if (model.layers[il].layer_out_norm) {
cur = build_norm(cur,
model.layers[il].layer_out_norm, NULL,
LLM_NORM_RMS, il);
cb(cur, "layer_out_norm", il);
if (model.layers[il].ffn_up) {
ggml_tensor * ffn_out = build_ffn(cur,
model.layers[il].ffn_up, NULL, NULL,
model.layers[il].ffn_gate, NULL, NULL,
model.layers[il].ffn_down, NULL, NULL,
NULL,
LLM_FFN_GELU, LLM_FFN_PAR, il);
cb(ffn_out, "ffn_out", il);
cur = ggml_scale(ctx0, ggml_add(ctx0, ffn_out, moe_out), std::sqrt(2) / 2);
cb(cur, "ffn_out", il);
} else {
cur = moe_out;
}
cur = build_norm(cur,
model.layers[il].ffn_post_norm, NULL,
LLM_NORM_RMS, il);
cb(cur, "ffn_post_norm", il);
cur = ggml_add(ctx0, cur, ffn_inp);
cb(cur, "ffn_out", il);
@ -7220,10 +7254,14 @@ struct llm_build_grok : public llm_graph_context {
// lm_head
cur = build_lora_mm(model.output, cur);
// Grok
// multiply logits by output_multiplier_scale of 0.5773502691896257
cur = ggml_scale(ctx0, cur, hparams.f_logit_scale);
cur = ggml_scale(ctx0, cur, 0.5773502691896257f);
// final logit soft-capping
if (hparams.f_final_logit_softcapping) {
cur = ggml_scale(ctx0, cur, 1.0f / hparams.f_final_logit_softcapping);
cur = ggml_tanh(ctx0, cur);
cur = ggml_scale(ctx0, cur, hparams.f_final_logit_softcapping);
}
cb(cur, "result_output", -1);
res->t_logits = cur;

View File

@ -434,6 +434,13 @@ struct llm_tokenizer_bpe : llm_tokenizer {
"(?:'[sS]|'[tT]|'[rR][eE]|'[vV][eE]|'[mM]|'[lL][lL]|'[dD])|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}{1}| ?[^\\s\\p{L}\\p{N}\\r\\n]+|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+",
};
break;
case LLAMA_VOCAB_PRE_TYPE_GROK_2:
regex_exprs = {
// original regex from tokenizer.json
// "(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+"
"(?:'[sS]|'[tT]|'[rR][eE]|'[vV][eE]|'[mM]|'[lL][lL]|'[dD])|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+",
};
break;
default:
// default regex for BPE tokenization pre-processing
regex_exprs = {
@ -1975,6 +1982,10 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) {
tokenizer_pre == "kimi-k2") {
pre_type = LLAMA_VOCAB_PRE_TYPE_KIMI_K2;
clean_spaces = false;
} else if (
tokenizer_pre == "grok-2") {
pre_type = LLAMA_VOCAB_PRE_TYPE_GROK_2;
clean_spaces = false;
} else {
throw std::runtime_error(format("unknown pre-tokenizer type: '%s'", tokenizer_pre.c_str()));
}

View File

@ -47,6 +47,7 @@ enum llama_vocab_pre_type {
LLAMA_VOCAB_PRE_TYPE_HUNYUAN = 36,
LLAMA_VOCAB_PRE_TYPE_KIMI_K2 = 37,
LLAMA_VOCAB_PRE_TYPE_HUNYUAN_DENSE = 38,
LLAMA_VOCAB_PRE_TYPE_GROK_2 = 39,
};
struct LLM_KV;

View File

@ -1931,7 +1931,7 @@ static void kl_divergence(llama_context * ctx, const common_params & params) {
LOG("Maximum KLD: %10.6f\n", kld_values.back());
LOG("99.9%% KLD: %10.6f\n", percentile(kld_values, 0.999f));
LOG("99.0%% KLD: %10.6f\n", percentile(kld_values, 0.990f));
LOG("99.0%% KLD: %10.6f\n", percentile(kld_values, 0.990f));
LOG("90.0%% KLD: %10.6f\n", percentile(kld_values, 0.900f));
LOG("Median KLD: %10.6f\n", kld_median);
LOG("10.0%% KLD: %10.6f\n", percentile(kld_values, 0.100f));
LOG(" 5.0%% KLD: %10.6f\n", percentile(kld_values, 0.050f));

View File

@ -407,39 +407,22 @@ class HttpClient {
}
std::string output_file_partial;
curl = curl_easy_init();
if (!curl) {
return 1;
}
progress_data data;
File out;
if (!output_file.empty()) {
output_file_partial = output_file + ".partial";
if (!out.open(output_file_partial, "ab")) {
printe("Failed to open file for writing\n");
return 1;
}
if (out.lock()) {
printe("Failed to exclusively lock file\n");
return 1;
}
}
set_write_options(response_str, out);
data.file_size = set_resume_point(output_file_partial);
set_progress_options(progress, data);
set_headers(headers);
CURLcode res = perform(url);
if (res != CURLE_OK){
printe("Fetching resource '%s' failed: %s\n", url.c_str(), curl_easy_strerror(res));
if (download(url, headers, output_file_partial, progress, response_str)) {
return 1;
}
if (!output_file.empty()) {
std::filesystem::rename(output_file_partial, output_file);
try {
std::filesystem::rename(output_file_partial, output_file);
} catch (const std::filesystem::filesystem_error & e) {
printe("Failed to rename '%s' to '%s': %s\n", output_file_partial.c_str(), output_file.c_str(), e.what());
return 1;
}
}
return 0;
@ -459,6 +442,42 @@ class HttpClient {
CURL * curl = nullptr;
struct curl_slist * chunk = nullptr;
int download(const std::string & url, const std::vector<std::string> & headers, const std::string & output_file,
const bool progress, std::string * response_str = nullptr) {
curl = curl_easy_init();
if (!curl) {
return 1;
}
progress_data data;
File out;
if (!output_file.empty()) {
if (!out.open(output_file, "ab")) {
printe("Failed to open file for writing\n");
return 1;
}
if (out.lock()) {
printe("Failed to exclusively lock file\n");
return 1;
}
}
set_write_options(response_str, out);
data.file_size = set_resume_point(output_file);
set_progress_options(progress, data);
set_headers(headers);
CURLcode res = perform(url);
if (res != CURLE_OK){
printe("Fetching resource '%s' failed: %s\n", url.c_str(), curl_easy_strerror(res));
return 1;
}
return 0;
}
void set_write_options(std::string * response_str, const File & out) {
if (response_str) {
curl_easy_setopt(curl, CURLOPT_WRITEFUNCTION, capture_data);
@ -507,6 +526,9 @@ class HttpClient {
curl_easy_setopt(curl, CURLOPT_FOLLOWLOCATION, 1L);
curl_easy_setopt(curl, CURLOPT_DEFAULT_PROTOCOL, "https");
curl_easy_setopt(curl, CURLOPT_FAILONERROR, 1L);
#ifdef _WIN32
curl_easy_setopt(curl, CURLOPT_SSL_OPTIONS, CURLSSLOPT_NATIVE_CA);
#endif
return curl_easy_perform(curl);
}

View File

@ -2313,7 +2313,7 @@ struct server_context {
// thinking is enabled if:
// 1. It's not explicitly disabled (reasoning_budget == 0)
// 2. The chat template supports it
const bool enable_thinking = params_base.reasoning_budget != 0 && common_chat_templates_support_enable_thinking(chat_templates.get());
const bool enable_thinking = params_base.use_jinja && params_base.reasoning_budget != 0 && common_chat_templates_support_enable_thinking(chat_templates.get());
SRV_INF("Enable thinking? %d\n", enable_thinking);
oai_parser_opt = {