Merge branch 'master' into HEAD
This commit is contained in:
commit
16451d6bc3
|
|
@ -19,6 +19,7 @@ The project differentiates between 3 levels of contributors:
|
|||
- If your PR becomes stale, don't hesitate to ping the maintainers in the comments
|
||||
- Maintainers will rely on your insights and approval when making a final decision to approve and merge a PR
|
||||
- Consider adding yourself to [CODEOWNERS](CODEOWNERS) to indicate your availability for reviewing related PRs
|
||||
- Using AI to generate PRs is permitted. However, you must (1) explicitly disclose how AI was used and (2) conduct a thorough manual review before publishing the PR. Note that trivial tab autocompletions do not require disclosure.
|
||||
|
||||
# Pull requests (for maintainers)
|
||||
|
||||
|
|
|
|||
|
|
@ -65,4 +65,6 @@ However, If you have discovered a security vulnerability in this project, please
|
|||
|
||||
Please disclose it as a private [security advisory](https://github.com/ggml-org/llama.cpp/security/advisories/new).
|
||||
|
||||
Please note that using AI to identify vulnerabilities and generate reports is permitted. However, you must (1) explicitly disclose how AI was used and (2) conduct a thorough manual review before submitting the report.
|
||||
|
||||
A team of volunteers on a reasonable-effort basis maintains this project. As such, please give us at least 90 days to work on a fix before public exposure.
|
||||
|
|
|
|||
|
|
@ -980,7 +980,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
|
|||
[](common_params & params) {
|
||||
params.kv_unified = true;
|
||||
}
|
||||
).set_env("LLAMA_ARG_KV_SPLIT"));
|
||||
).set_env("LLAMA_ARG_KV_UNIFIED"));
|
||||
add_opt(common_arg(
|
||||
{"--no-context-shift"},
|
||||
string_format("disables context shift on infinite text generation (default: %s)", params.ctx_shift ? "disabled" : "enabled"),
|
||||
|
|
@ -2646,7 +2646,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
|
|||
[](common_params &, const std::string & value) {
|
||||
common_log_set_file(common_log_main(), value.c_str());
|
||||
}
|
||||
));
|
||||
).set_env("LLAMA_LOG_FILE"));
|
||||
add_opt(common_arg(
|
||||
{"--log-colors"}, "[on|off|auto]",
|
||||
"Set colored logging ('on', 'off', or 'auto', default: 'auto')\n"
|
||||
|
|
|
|||
|
|
@ -517,16 +517,18 @@ static bool common_pull_file(httplib::Client & cli,
|
|||
headers.emplace("Range", "bytes=" + std::to_string(existing_size) + "-");
|
||||
}
|
||||
|
||||
std::atomic<size_t> downloaded{existing_size};
|
||||
const char * func = __func__; // avoid __func__ inside a lambda
|
||||
size_t downloaded = existing_size;
|
||||
size_t progress_step = 0;
|
||||
|
||||
auto res = cli.Get(resolve_path, headers,
|
||||
[&](const httplib::Response &response) {
|
||||
if (existing_size > 0 && response.status != 206) {
|
||||
LOG_WRN("%s: server did not respond with 206 Partial Content for a resume request. Status: %d\n", __func__, response.status);
|
||||
LOG_WRN("%s: server did not respond with 206 Partial Content for a resume request. Status: %d\n", func, response.status);
|
||||
return false;
|
||||
}
|
||||
if (existing_size == 0 && response.status != 200) {
|
||||
LOG_WRN("%s: download received non-successful status code: %d\n", __func__, response.status);
|
||||
LOG_WRN("%s: download received non-successful status code: %d\n", func, response.status);
|
||||
return false;
|
||||
}
|
||||
if (total_size == 0 && response.has_header("Content-Length")) {
|
||||
|
|
@ -534,7 +536,7 @@ static bool common_pull_file(httplib::Client & cli,
|
|||
size_t content_length = std::stoull(response.get_header_value("Content-Length"));
|
||||
total_size = existing_size + content_length;
|
||||
} catch (const std::exception &e) {
|
||||
LOG_WRN("%s: invalid Content-Length header: %s\n", __func__, e.what());
|
||||
LOG_WRN("%s: invalid Content-Length header: %s\n", func, e.what());
|
||||
}
|
||||
}
|
||||
return true;
|
||||
|
|
@ -542,11 +544,16 @@ static bool common_pull_file(httplib::Client & cli,
|
|||
[&](const char *data, size_t len) {
|
||||
ofs.write(data, len);
|
||||
if (!ofs) {
|
||||
LOG_ERR("%s: error writing to file: %s\n", __func__, path_tmp.c_str());
|
||||
LOG_ERR("%s: error writing to file: %s\n", func, path_tmp.c_str());
|
||||
return false;
|
||||
}
|
||||
downloaded += len;
|
||||
progress_step += len;
|
||||
|
||||
if (progress_step >= total_size / 1000 || downloaded == total_size) {
|
||||
print_progress(downloaded, total_size);
|
||||
progress_step = 0;
|
||||
}
|
||||
return true;
|
||||
},
|
||||
nullptr
|
||||
|
|
|
|||
|
|
@ -1581,10 +1581,27 @@ class MmprojModel(ModelBase):
|
|||
|
||||
# load preprocessor config
|
||||
self.preprocessor_config = {}
|
||||
if not self.is_mistral_format:
|
||||
with open(self.dir_model / "preprocessor_config.json", "r", encoding="utf-8") as f:
|
||||
|
||||
# prefer preprocessor_config.json if possible
|
||||
preprocessor_config_path = self.dir_model / "preprocessor_config.json"
|
||||
if preprocessor_config_path.is_file():
|
||||
with open(preprocessor_config_path, "r", encoding="utf-8") as f:
|
||||
self.preprocessor_config = json.load(f)
|
||||
|
||||
# prefer processor_config.json if possible
|
||||
processor_config_path = self.dir_model / "processor_config.json"
|
||||
if processor_config_path.is_file():
|
||||
with open(processor_config_path, "r", encoding="utf-8") as f:
|
||||
cfg = json.load(f)
|
||||
# move image_processor to root level for compat
|
||||
if "image_processor" in cfg:
|
||||
cfg = {
|
||||
**cfg,
|
||||
**cfg["image_processor"],
|
||||
}
|
||||
# merge configs
|
||||
self.preprocessor_config = {**self.preprocessor_config, **cfg}
|
||||
|
||||
def get_vision_config(self) -> dict[str, Any] | None:
|
||||
config_name = "vision_config" if not self.is_mistral_format else "vision_encoder"
|
||||
return self.global_config.get(config_name)
|
||||
|
|
@ -2797,7 +2814,32 @@ class Llama4VisionModel(MmprojModel):
|
|||
|
||||
@ModelBase.register("Mistral3ForConditionalGeneration")
|
||||
class Mistral3Model(LlamaModel):
|
||||
model_arch = gguf.MODEL_ARCH.LLAMA
|
||||
model_arch = gguf.MODEL_ARCH.MISTRAL3
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
# for compatibility, we use LLAMA arch for older models
|
||||
# TODO: remove this once everyone has migrated to newer version of llama.cpp
|
||||
if self.hparams.get("model_type") != "ministral3":
|
||||
self.model_arch = gguf.MODEL_ARCH.LLAMA
|
||||
self.gguf_writer.arch = gguf.MODEL_ARCH_NAMES[self.model_arch]
|
||||
self.gguf_writer.add_architecture()
|
||||
self.tensor_map = gguf.get_tensor_name_map(self.model_arch, self.block_count)
|
||||
|
||||
def set_gguf_parameters(self):
|
||||
super().set_gguf_parameters()
|
||||
rope_params = self.hparams.get("rope_parameters")
|
||||
if self.hparams.get("model_type") == "ministral3":
|
||||
assert rope_params is not None, "ministral3 must have 'rope_parameters' config"
|
||||
assert rope_params["rope_type"] == "yarn", "ministral3 rope_type must be 'yarn'"
|
||||
self.gguf_writer.add_rope_scaling_type(gguf.RopeScalingType.YARN)
|
||||
self.gguf_writer.add_rope_scaling_factor(rope_params["factor"])
|
||||
self.gguf_writer.add_rope_scaling_yarn_beta_fast(rope_params["beta_fast"])
|
||||
self.gguf_writer.add_rope_scaling_yarn_beta_slow(rope_params["beta_slow"])
|
||||
self.gguf_writer.add_rope_scaling_yarn_log_mul(rope_params["mscale_all_dim"])
|
||||
self.gguf_writer.add_rope_scaling_orig_ctx_len(rope_params["original_max_position_embeddings"])
|
||||
self.gguf_writer.add_rope_freq_base(rope_params["rope_theta"])
|
||||
self.gguf_writer.add_attn_temperature_scale(rope_params["llama_4_scaling_beta"])
|
||||
|
||||
def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None):
|
||||
name = name.replace("language_model.", "")
|
||||
|
|
@ -9809,12 +9851,22 @@ class ApertusModel(LlamaModel):
|
|||
|
||||
|
||||
class MistralModel(LlamaModel):
|
||||
model_arch = gguf.MODEL_ARCH.LLAMA
|
||||
model_arch = gguf.MODEL_ARCH.MISTRAL3
|
||||
model_name = "Mistral"
|
||||
hf_arch = ""
|
||||
is_mistral_format = True
|
||||
undo_permute = False
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
# for compatibility, we use LLAMA arch for older models
|
||||
# TODO: remove this once everyone migrates to newer version of llama.cpp
|
||||
if "llama_4_scaling" not in self.hparams:
|
||||
self.model_arch = gguf.MODEL_ARCH.LLAMA
|
||||
self.gguf_writer.arch = gguf.MODEL_ARCH_NAMES[self.model_arch]
|
||||
self.gguf_writer.add_architecture()
|
||||
self.tensor_map = gguf.get_tensor_name_map(self.model_arch, self.block_count)
|
||||
|
||||
@staticmethod
|
||||
def get_community_chat_template(vocab: MistralVocab, templates_dir: Path, is_mistral_format: bool):
|
||||
assert TokenizerVersion is not None and Tekkenizer is not None and SentencePieceTokenizer is not None, _mistral_import_error_msg
|
||||
|
|
@ -9854,6 +9906,20 @@ class MistralModel(LlamaModel):
|
|||
|
||||
return template
|
||||
|
||||
def set_gguf_parameters(self):
|
||||
super().set_gguf_parameters()
|
||||
if "yarn" in self.hparams:
|
||||
yarn_params = self.hparams["yarn"]
|
||||
self.gguf_writer.add_rope_scaling_type(gguf.RopeScalingType.YARN)
|
||||
self.gguf_writer.add_rope_scaling_factor(yarn_params["factor"])
|
||||
self.gguf_writer.add_rope_scaling_yarn_beta_fast(yarn_params["beta"])
|
||||
self.gguf_writer.add_rope_scaling_yarn_beta_slow(yarn_params["alpha"])
|
||||
self.gguf_writer.add_rope_scaling_yarn_log_mul(1.0) # mscale_all_dim
|
||||
self.gguf_writer.add_rope_scaling_orig_ctx_len(yarn_params["original_max_position_embeddings"])
|
||||
|
||||
if "llama_4_scaling" in self.hparams:
|
||||
self.gguf_writer.add_attn_temperature_scale(self.hparams["llama_4_scaling"]["beta"])
|
||||
|
||||
|
||||
class PixtralModel(LlavaVisionModel):
|
||||
model_name = "Pixtral"
|
||||
|
|
|
|||
|
|
@ -2148,7 +2148,8 @@ extern "C" {
|
|||
};
|
||||
|
||||
enum ggml_scale_flag {
|
||||
GGML_SCALE_FLAG_ALIGN_CORNERS = (1 << 8)
|
||||
GGML_SCALE_FLAG_ALIGN_CORNERS = (1 << 8),
|
||||
GGML_SCALE_FLAG_ANTIALIAS = (1 << 9),
|
||||
};
|
||||
|
||||
// interpolate
|
||||
|
|
|
|||
|
|
@ -274,10 +274,13 @@ function(ggml_add_backend_library backend)
|
|||
endif()
|
||||
|
||||
# Set versioning properties for all backend libraries
|
||||
# Building a MODULE library with a version is not supported on macOS (https://gitlab.kitware.com/cmake/cmake/-/issues/20782)
|
||||
if (NOT (APPLE AND GGML_BACKEND_DL))
|
||||
set_target_properties(${backend} PROPERTIES
|
||||
VERSION ${GGML_VERSION}
|
||||
SOVERSION ${GGML_VERSION_MAJOR}
|
||||
)
|
||||
endif()
|
||||
|
||||
if(NOT GGML_AVAILABLE_BACKENDS)
|
||||
set(GGML_AVAILABLE_BACKENDS "${backend}"
|
||||
|
|
|
|||
|
|
@ -723,6 +723,12 @@ struct ggml_backend_sched {
|
|||
bool op_offload;
|
||||
|
||||
int debug;
|
||||
|
||||
// used for debugging graph reallocations [GGML_SCHED_DEBUG_REALLOC]
|
||||
// ref: https://github.com/ggml-org/llama.cpp/pull/17617
|
||||
int debug_realloc;
|
||||
int debug_graph_size;
|
||||
int debug_prev_graph_size;
|
||||
};
|
||||
|
||||
#define hash_id(tensor) ggml_hash_find_or_insert(&sched->hash_set, tensor)
|
||||
|
|
@ -1289,6 +1295,11 @@ void ggml_backend_sched_split_graph(ggml_backend_sched_t sched, struct ggml_cgra
|
|||
}
|
||||
|
||||
int graph_size = std::max(graph->n_nodes, graph->n_leafs) + sched->n_splits*GGML_SCHED_MAX_SPLIT_INPUTS*2*sched->n_copies;
|
||||
|
||||
// remember the actual graph_size for performing reallocation checks later [GGML_SCHED_DEBUG_REALLOC]
|
||||
sched->debug_prev_graph_size = sched->debug_graph_size;
|
||||
sched->debug_graph_size = graph_size;
|
||||
|
||||
if (sched->graph.size < graph_size) {
|
||||
sched->graph.size = graph_size;
|
||||
sched->graph.nodes = (ggml_tensor **) realloc(sched->graph.nodes, graph_size * sizeof(struct ggml_tensor *));
|
||||
|
|
@ -1395,14 +1406,21 @@ static bool ggml_backend_sched_alloc_splits(ggml_backend_sched_t sched) {
|
|||
|
||||
// allocate graph
|
||||
if (backend_ids_changed || !ggml_gallocr_alloc_graph(sched->galloc, &sched->graph)) {
|
||||
#ifdef GGML_SCHED_NO_REALLOC
|
||||
GGML_ABORT("%s: failed to allocate graph, but graph re-allocation is disabled by GGML_SCHED_NO_REALLOC\n", __func__);
|
||||
#endif
|
||||
|
||||
#ifndef NDEBUG
|
||||
GGML_LOG_DEBUG("%s: failed to allocate graph, reserving (backend_ids_changed = %d)\n", __func__, backend_ids_changed);
|
||||
#endif
|
||||
|
||||
if (sched->debug_realloc > 0) {
|
||||
// we are interested only in situations where the graph was reallocated even though its size remained the same [GGML_SCHED_DEBUG_REALLOC]
|
||||
// example: https://github.com/ggml-org/llama.cpp/pull/17143
|
||||
const bool unexpected = !backend_ids_changed && sched->debug_prev_graph_size == sched->debug_graph_size;
|
||||
|
||||
if (unexpected || sched->debug_realloc > 1) {
|
||||
GGML_ABORT("%s: unexpected graph reallocation (graph size = %d, nodes = %d, leafs = %d), debug_realloc = %d\n", __func__,
|
||||
sched->debug_graph_size, sched->graph.n_nodes, sched->graph.n_leafs, sched->debug_realloc);
|
||||
}
|
||||
}
|
||||
|
||||
// the re-allocation may cause the split inputs to be moved to a different address
|
||||
// synchronize without ggml_backend_sched_synchronize to avoid changing cur_copy
|
||||
for (int i = 0; i < sched->n_backends; i++) {
|
||||
|
|
@ -1620,6 +1638,14 @@ ggml_backend_sched_t ggml_backend_sched_new(
|
|||
|
||||
const char * GGML_SCHED_DEBUG = getenv("GGML_SCHED_DEBUG");
|
||||
sched->debug = GGML_SCHED_DEBUG ? atoi(GGML_SCHED_DEBUG) : 0;
|
||||
|
||||
sched->debug_realloc = 0;
|
||||
#ifdef GGML_SCHED_NO_REALLOC
|
||||
sched->debug_realloc = 1;
|
||||
#endif
|
||||
const char * GGML_SCHED_DEBUG_REALLOC = getenv("GGML_SCHED_DEBUG_REALLOC");
|
||||
sched->debug_realloc = GGML_SCHED_DEBUG_REALLOC ? atoi(GGML_SCHED_DEBUG_REALLOC) : sched->debug_realloc;
|
||||
|
||||
sched->n_backends = n_backends;
|
||||
sched->n_copies = parallel ? GGML_SCHED_MAX_COPIES : 1;
|
||||
|
||||
|
|
@ -1636,6 +1662,9 @@ ggml_backend_sched_t ggml_backend_sched_new(
|
|||
sched->prev_node_backend_ids = (int *) calloc(nodes_size, sizeof(sched->prev_node_backend_ids[0]));
|
||||
sched->prev_leaf_backend_ids = (int *) calloc(nodes_size, sizeof(sched->prev_leaf_backend_ids[0]));
|
||||
|
||||
sched->debug_graph_size = 0;
|
||||
sched->debug_prev_graph_size = 0;
|
||||
|
||||
sched->context_buffer_size = ggml_sched_max_splits*GGML_SCHED_MAX_SPLIT_INPUTS*2*sizeof(struct ggml_tensor) + ggml_graph_overhead_custom(graph_size, false);
|
||||
sched->context_buffer = (char *) malloc(sched->context_buffer_size);
|
||||
|
||||
|
|
|
|||
|
|
@ -2500,6 +2500,9 @@ static bool ggml_backend_cann_supports_op(ggml_backend_dev_t dev, const ggml_ten
|
|||
if (op->op_params[0] != GGML_SCALE_MODE_NEAREST) {
|
||||
return false;
|
||||
}
|
||||
if (op->op_params[0] & GGML_SCALE_FLAG_ANTIALIAS) {
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
case GGML_OP_POOL_2D:
|
||||
|
|
|
|||
|
|
@ -7420,6 +7420,65 @@ static void ggml_compute_forward_upscale_f32(
|
|||
}
|
||||
}
|
||||
}
|
||||
} else if (mode == GGML_SCALE_MODE_BILINEAR && (mode_flags & GGML_SCALE_FLAG_ANTIALIAS)) {
|
||||
// Similar to F.interpolate(..., mode="bilinear", align_corners=False, antialias=True)
|
||||
// https://github.com/pytorch/pytorch/blob/8871ff29b743948d1225389d5b7068f37b22750b/aten/src/ATen/native/cpu/UpSampleKernel.cpp
|
||||
auto triangle_filter = [](float x) -> float {
|
||||
return std::max(1.0f - fabsf(x), 0.0f);
|
||||
};
|
||||
|
||||
// support and invscale, minimum 1 pixel for bilinear
|
||||
const float support1 = std::max(1.0f, 1.0f / sf1);
|
||||
const float invscale1 = 1.0f / support1;
|
||||
const float support0 = std::max(1.0f, 1.0f / sf0);
|
||||
const float invscale0 = 1.0f / support0;
|
||||
|
||||
for (int64_t i3 = 0; i3 < ne3; i3++) {
|
||||
const int64_t i03 = i3 / sf3;
|
||||
for (int64_t i2 = ith; i2 < ne2; i2 += nth) {
|
||||
const int64_t i02 = i2 / sf2;
|
||||
for (int64_t i1 = 0; i1 < ne1; i1++) {
|
||||
const float y = ((float) i1 + pixel_offset) / sf1;
|
||||
for (int64_t i0 = 0; i0 < ne0; i0++) {
|
||||
const float x = ((float) i0 + pixel_offset) / sf0;
|
||||
|
||||
// the range of source pixels that contribute
|
||||
const int64_t x_min = std::max<int64_t>(x - support0 + pixel_offset, 0);
|
||||
const int64_t x_max = std::min<int64_t>(x + support0 + pixel_offset, ne00);
|
||||
const int64_t y_min = std::max<int64_t>(y - support1 + pixel_offset, 0);
|
||||
const int64_t y_max = std::min<int64_t>(y + support1 + pixel_offset, ne01);
|
||||
|
||||
// bilinear filter with antialiasing
|
||||
float val = 0.0f;
|
||||
float total_weight = 0.0f;
|
||||
|
||||
for (int64_t sy = y_min; sy < y_max; sy++) {
|
||||
const float weight_y = triangle_filter((sy - y + pixel_offset) * invscale1);
|
||||
|
||||
for (int64_t sx = x_min; sx < x_max; sx++) {
|
||||
const float weight_x = triangle_filter((sx - x + pixel_offset) * invscale0);
|
||||
const float weight = weight_x * weight_y;
|
||||
|
||||
if (weight <= 0.0f) {
|
||||
continue;
|
||||
}
|
||||
|
||||
const float pixel = *(const float *)((const char *)src0->data + sx*nb00 + sy*nb01 + i02*nb02 + i03*nb03);
|
||||
val += pixel * weight;
|
||||
total_weight += weight;
|
||||
}
|
||||
}
|
||||
|
||||
if (total_weight > 0.0f) {
|
||||
val /= total_weight;
|
||||
}
|
||||
|
||||
float * dst_ptr = (float *)((char *)dst->data + i0*nb0 + i1*nb1 + i2*nb2 + i3*nb3);
|
||||
*dst_ptr = val;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
} else if (mode == GGML_SCALE_MODE_BILINEAR) {
|
||||
for (int64_t i3 = 0; i3 < ne3; i3++) {
|
||||
const int64_t i03 = i3 / sf3;
|
||||
|
|
|
|||
|
|
@ -44,7 +44,7 @@ void argsort_f32_i32_cuda_cub(ggml_cuda_pool & pool,
|
|||
const dim3 offset_grid((nrows + block_size - 1) / block_size);
|
||||
init_offsets<<<offset_grid, block_size, 0, stream>>>(d_offsets, ncols, nrows);
|
||||
|
||||
cudaMemcpyAsync(temp_keys, x, ncols * nrows * sizeof(float), cudaMemcpyDeviceToDevice, stream);
|
||||
CUDA_CHECK(cudaMemcpyAsync(temp_keys, x, ncols * nrows * sizeof(float), cudaMemcpyDeviceToDevice, stream));
|
||||
|
||||
size_t temp_storage_bytes = 0;
|
||||
|
||||
|
|
|
|||
|
|
@ -21,10 +21,12 @@
|
|||
#include "ggml-common.h"
|
||||
|
||||
#include <array>
|
||||
#include <algorithm>
|
||||
#include <cassert>
|
||||
#include <cfloat>
|
||||
#include <cstdio>
|
||||
#include <string>
|
||||
#include <unordered_map>
|
||||
#include <vector>
|
||||
|
||||
#if defined(GGML_USE_HIP)
|
||||
|
|
@ -980,6 +982,154 @@ struct ggml_cuda_graph {
|
|||
#endif
|
||||
};
|
||||
|
||||
struct ggml_cuda_concurrent_event {
|
||||
std::vector<cudaEvent_t> join_events;
|
||||
cudaEvent_t fork_event = nullptr;
|
||||
|
||||
int n_streams = 0;
|
||||
std::unordered_map<const ggml_tensor *, int> stream_mapping;
|
||||
|
||||
const ggml_tensor * join_node;
|
||||
|
||||
ggml_cuda_concurrent_event() = default;
|
||||
|
||||
ggml_cuda_concurrent_event(const ggml_cuda_concurrent_event &) = delete;
|
||||
ggml_cuda_concurrent_event & operator=(const ggml_cuda_concurrent_event &) = delete;
|
||||
|
||||
explicit ggml_cuda_concurrent_event(int n_streams) : n_streams(n_streams) {
|
||||
join_events.resize(n_streams);
|
||||
|
||||
for (size_t i = 0; i < join_events.size(); ++i) {
|
||||
CUDA_CHECK(cudaEventCreateWithFlags(&join_events[i], cudaEventDisableTiming));
|
||||
}
|
||||
|
||||
CUDA_CHECK(cudaEventCreateWithFlags(&fork_event, cudaEventDisableTiming));
|
||||
}
|
||||
|
||||
ggml_cuda_concurrent_event(ggml_cuda_concurrent_event && other) noexcept
|
||||
: join_events(std::move(other.join_events))
|
||||
, fork_event(other.fork_event)
|
||||
, n_streams(other.n_streams)
|
||||
, stream_mapping(std::move(other.stream_mapping))
|
||||
, join_node(other.join_node) {
|
||||
other.fork_event = nullptr;
|
||||
}
|
||||
|
||||
// 1. check if any branches write to overlapping memory ranges (except the join node)
|
||||
// 2. check whether all srcs are either within the branch or outside the nodes covered by ggml_cuda_concurrent_event
|
||||
// we assume all nodes have the same buffer
|
||||
bool is_valid() const {
|
||||
std::vector<std::vector<std::pair<int64_t, int64_t>>> write_ranges;
|
||||
write_ranges.resize(n_streams);
|
||||
|
||||
// get join_node's memory range to exclude from overlap checking.
|
||||
// multiple nodes can use join_node's buffer; we synchronize on the join node.
|
||||
const ggml_tensor * join_t = join_node->view_src ? join_node->view_src : join_node;
|
||||
const int64_t join_start = (int64_t) join_t->data;
|
||||
const int64_t join_end = join_start + ggml_nbytes(join_t);
|
||||
|
||||
for (const auto & [tensor, stream] : stream_mapping) {
|
||||
const ggml_tensor * t = tensor->view_src ? tensor->view_src : tensor;
|
||||
const int64_t t_start = (int64_t) t->data;
|
||||
const int64_t t_end = t_start + ggml_nbytes(t);
|
||||
|
||||
// skip tensors that overlap with join_node's buffer.
|
||||
if ((t_start <= join_start && join_start < t_end) || (join_start <= t_start && t_start < join_end)) {
|
||||
continue;
|
||||
}
|
||||
|
||||
// concurrent streams begin from 1
|
||||
write_ranges[stream - 1].emplace_back(t_start, t_end);
|
||||
}
|
||||
|
||||
for (int i = 0; i < n_streams; ++i) {
|
||||
// sorts first by start then by end of write range
|
||||
std::sort(write_ranges[i].begin(), write_ranges[i].end());
|
||||
}
|
||||
|
||||
bool writes_overlap = false;
|
||||
bool dependent_srcs = false;
|
||||
for (const auto & [tensor, stream] : stream_mapping) {
|
||||
const ggml_tensor * t = tensor->view_src ? tensor->view_src : tensor;
|
||||
const int64_t t_start = (int64_t) t->data;
|
||||
const int64_t t_end = t_start + ggml_nbytes(t);
|
||||
|
||||
// skip tensors that overlap with join_node's buffer
|
||||
if ((t_start <= join_start && join_start < t_end) || (join_start <= t_start && t_start < join_end)) {
|
||||
continue;
|
||||
}
|
||||
|
||||
// check if this buffer's write data overlaps with another stream's
|
||||
std::pair<int64_t, int64_t> data_range = std::make_pair(t_start, t_end);
|
||||
for (int i = 0; i < n_streams; ++i) {
|
||||
if (i == stream - 1) {
|
||||
continue;
|
||||
}
|
||||
auto it = std::lower_bound(write_ranges[i].begin(), write_ranges[i].end(), data_range);
|
||||
|
||||
if (it != write_ranges[i].end()) {
|
||||
const std::pair<int64_t, int64_t> & other = *it;
|
||||
|
||||
// std::lower_bound returns the first element where other >= data_range (lexicographically).
|
||||
// This guarantees other.first >= data_range.first.
|
||||
// Therefore, overlap occurs iff other.first < data_range.second
|
||||
// (i.e., the other range starts before this range ends).
|
||||
if (other.first < data_range.second) {
|
||||
GGML_LOG_DEBUG("Writes overlap for %s", tensor->name);
|
||||
writes_overlap = true;
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
//check if all srcs are either in branch or don't have a branch
|
||||
for (int i = 0; i < GGML_MAX_SRC; ++i) {
|
||||
if (!tensor->src[i]) {
|
||||
continue;
|
||||
}
|
||||
|
||||
auto it = stream_mapping.find(tensor->src[i]);
|
||||
|
||||
if (it == stream_mapping.end()) {
|
||||
continue;
|
||||
}
|
||||
|
||||
if (it->second != stream) {
|
||||
dependent_srcs = true;
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
if (dependent_srcs || writes_overlap) {
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
return !writes_overlap && !dependent_srcs;
|
||||
}
|
||||
|
||||
~ggml_cuda_concurrent_event() {
|
||||
if (fork_event != nullptr) {
|
||||
CUDA_CHECK(cudaEventDestroy(fork_event));
|
||||
}
|
||||
for (cudaEvent_t e : join_events) {
|
||||
if (e != nullptr) {
|
||||
CUDA_CHECK(cudaEventDestroy(e));
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
struct ggml_cuda_stream_context {
|
||||
std::vector<const ggml_tensor *> original_nodes;
|
||||
std::unordered_map<const ggml_tensor *, ggml_cuda_concurrent_event> concurrent_events;
|
||||
|
||||
void reset() {
|
||||
original_nodes.clear();
|
||||
concurrent_events.clear();
|
||||
}
|
||||
};
|
||||
|
||||
struct ggml_backend_cuda_context {
|
||||
int device;
|
||||
std::string name;
|
||||
|
|
@ -990,11 +1140,15 @@ struct ggml_backend_cuda_context {
|
|||
|
||||
std::unique_ptr<ggml_cuda_graph> cuda_graph;
|
||||
|
||||
int curr_stream_no = 0;
|
||||
|
||||
explicit ggml_backend_cuda_context(int device) :
|
||||
device(device),
|
||||
name(GGML_CUDA_NAME + std::to_string(device)) {
|
||||
}
|
||||
|
||||
ggml_cuda_stream_context concurrent_stream_context;
|
||||
|
||||
~ggml_backend_cuda_context();
|
||||
|
||||
cudaStream_t stream(int device, int stream) {
|
||||
|
|
@ -1005,9 +1159,9 @@ struct ggml_backend_cuda_context {
|
|||
return streams[device][stream];
|
||||
}
|
||||
|
||||
cudaStream_t stream() {
|
||||
return stream(device, 0);
|
||||
}
|
||||
cudaStream_t stream() { return stream(device, curr_stream_no); }
|
||||
|
||||
ggml_cuda_stream_context & stream_context() { return concurrent_stream_context; }
|
||||
|
||||
cublasHandle_t cublas_handle(int device) {
|
||||
if (cublas_handles[device] == nullptr) {
|
||||
|
|
@ -1023,15 +1177,15 @@ struct ggml_backend_cuda_context {
|
|||
}
|
||||
|
||||
// pool
|
||||
std::unique_ptr<ggml_cuda_pool> pools[GGML_CUDA_MAX_DEVICES];
|
||||
std::unique_ptr<ggml_cuda_pool> pools[GGML_CUDA_MAX_DEVICES][GGML_CUDA_MAX_STREAMS];
|
||||
|
||||
static std::unique_ptr<ggml_cuda_pool> new_pool_for_device(int device);
|
||||
static std::unique_ptr<ggml_cuda_pool> new_pool_for_device(int device, int stream_no);
|
||||
|
||||
ggml_cuda_pool & pool(int device) {
|
||||
if (pools[device] == nullptr) {
|
||||
pools[device] = new_pool_for_device(device);
|
||||
if (pools[device][curr_stream_no] == nullptr) {
|
||||
pools[device][curr_stream_no] = new_pool_for_device(device, curr_stream_no);
|
||||
}
|
||||
return *pools[device];
|
||||
return *pools[device][curr_stream_no];
|
||||
}
|
||||
|
||||
ggml_cuda_pool & pool() {
|
||||
|
|
|
|||
|
|
@ -524,7 +524,8 @@ struct ggml_cuda_pool_vmm : public ggml_cuda_pool {
|
|||
};
|
||||
#endif // defined(GGML_USE_VMM)
|
||||
|
||||
std::unique_ptr<ggml_cuda_pool> ggml_backend_cuda_context::new_pool_for_device(int device) {
|
||||
std::unique_ptr<ggml_cuda_pool> ggml_backend_cuda_context::new_pool_for_device(int device,
|
||||
[[maybe_unused]] int stream_no) {
|
||||
#if defined(GGML_USE_VMM)
|
||||
if (ggml_cuda_info().devices[device].vmm) {
|
||||
return std::unique_ptr<ggml_cuda_pool>(new ggml_cuda_pool_vmm(device));
|
||||
|
|
@ -3208,27 +3209,94 @@ static void evaluate_and_capture_cuda_graph(ggml_backend_cuda_context * cuda_ctx
|
|||
// flag used to determine whether it is an integrated_gpu
|
||||
const bool integrated = ggml_cuda_info().devices[cuda_ctx->device].integrated;
|
||||
|
||||
ggml_cuda_stream_context & stream_ctx = cuda_ctx->stream_context();
|
||||
bool is_concurrent_event_active = false;
|
||||
ggml_cuda_concurrent_event * concurrent_event = nullptr;
|
||||
bool should_launch_concurrent_events = false;
|
||||
|
||||
const auto try_launch_concurrent_event = [&](const ggml_tensor * node) {
|
||||
if (stream_ctx.concurrent_events.find(node) != stream_ctx.concurrent_events.end()) {
|
||||
concurrent_event = &stream_ctx.concurrent_events[node];
|
||||
|
||||
is_concurrent_event_active = true;
|
||||
|
||||
GGML_LOG_DEBUG("Launching %d streams at %s\n", concurrent_event->n_streams, node->name);
|
||||
|
||||
cudaStream_t main_stream = cuda_ctx->stream(); // this should be stream 0
|
||||
GGML_ASSERT(cuda_ctx->curr_stream_no == 0);
|
||||
CUDA_CHECK(cudaEventRecord(concurrent_event->fork_event, main_stream));
|
||||
|
||||
for (int i = 1; i <= concurrent_event->n_streams; ++i) {
|
||||
cudaStream_t stream = cuda_ctx->stream(cuda_ctx->device, i);
|
||||
CUDA_CHECK(cudaStreamWaitEvent(stream, concurrent_event->fork_event));
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
while (!graph_evaluated_or_captured) {
|
||||
// Only perform the graph execution if CUDA graphs are not enabled, or we are capturing the graph.
|
||||
// With the use of CUDA graphs, the execution will be performed by the graph launch.
|
||||
if (!use_cuda_graph || cuda_graph_update_required) {
|
||||
|
||||
[[maybe_unused]] int prev_i = 0;
|
||||
|
||||
if (stream_ctx.concurrent_events.size() > 0) {
|
||||
should_launch_concurrent_events = true;
|
||||
for (const auto & [tensor, event] : stream_ctx.concurrent_events) {
|
||||
should_launch_concurrent_events = should_launch_concurrent_events && event.is_valid();
|
||||
}
|
||||
}
|
||||
if (should_launch_concurrent_events) {
|
||||
//Restore the original graph to enable fusion within the streams
|
||||
cgraph->nodes = const_cast<ggml_tensor **>(stream_ctx.original_nodes.data());
|
||||
cgraph->n_nodes = (int) stream_ctx.original_nodes.size();
|
||||
}
|
||||
|
||||
for (int i = 0; i < cgraph->n_nodes; i++) {
|
||||
ggml_tensor * node = cgraph->nodes[i];
|
||||
if (is_concurrent_event_active) {
|
||||
GGML_ASSERT(concurrent_event);
|
||||
|
||||
if (node == concurrent_event->join_node) {
|
||||
cuda_ctx->curr_stream_no = 0;
|
||||
for (int i = 1; i <= concurrent_event->n_streams; ++i) {
|
||||
// Wait on join events of forked streams in the main stream
|
||||
CUDA_CHECK(cudaEventRecord(concurrent_event->join_events[i - 1],
|
||||
cuda_ctx->stream(cuda_ctx->device, i)));
|
||||
CUDA_CHECK(cudaStreamWaitEvent(cuda_ctx->stream(), concurrent_event->join_events[i - 1]));
|
||||
}
|
||||
|
||||
is_concurrent_event_active = false;
|
||||
concurrent_event = nullptr;
|
||||
} else {
|
||||
GGML_ASSERT (concurrent_event->stream_mapping.find(node) != concurrent_event->stream_mapping.end());
|
||||
cuda_ctx->curr_stream_no = concurrent_event->stream_mapping[node];
|
||||
GGML_LOG_DEBUG("Setting stream no to %d for node %s\n", cuda_ctx->curr_stream_no, node->name);
|
||||
}
|
||||
} else if (i - prev_i > 1) {
|
||||
//the previous node was fused
|
||||
const ggml_tensor * prev_node = cgraph->nodes[i - 1];
|
||||
try_launch_concurrent_event(prev_node);
|
||||
|
||||
if (is_concurrent_event_active) {
|
||||
cuda_ctx->curr_stream_no = concurrent_event->stream_mapping[node];
|
||||
GGML_LOG_DEBUG("Setting stream no to %d for node %s\n", cuda_ctx->curr_stream_no, node->name);
|
||||
}
|
||||
}
|
||||
|
||||
#ifdef GGML_CUDA_DEBUG
|
||||
const int nodes_fused = i - prev_i - 1;
|
||||
prev_i = i;
|
||||
if (nodes_fused > 0) {
|
||||
GGML_LOG_INFO("nodes_fused: %d\n", nodes_fused);
|
||||
}
|
||||
#endif
|
||||
prev_i = i;
|
||||
|
||||
if (ggml_is_empty(node) || node->op == GGML_OP_RESHAPE || node->op == GGML_OP_TRANSPOSE || node->op == GGML_OP_VIEW || node->op == GGML_OP_PERMUTE || node->op == GGML_OP_NONE) {
|
||||
continue;
|
||||
}
|
||||
|
||||
|
||||
// start of fusion operations
|
||||
static bool disable_fusion = (getenv("GGML_CUDA_DISABLE_FUSION") != nullptr);
|
||||
if (!disable_fusion) {
|
||||
|
||||
|
|
@ -3528,6 +3596,10 @@ static void evaluate_and_capture_cuda_graph(ggml_backend_cuda_context * cuda_ctx
|
|||
GGML_LOG_ERROR("%s: op not supported %s (%s)\n", __func__, node->name, ggml_op_name(node->op));
|
||||
}
|
||||
GGML_ASSERT(ok);
|
||||
|
||||
if (!is_concurrent_event_active) {
|
||||
try_launch_concurrent_event(node);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -3667,6 +3739,235 @@ static void ggml_backend_cuda_event_wait(ggml_backend_t backend, ggml_backend_ev
|
|||
}
|
||||
}
|
||||
|
||||
static void ggml_backend_cuda_graph_optimize(ggml_backend_t backend, ggml_cgraph * cgraph) {
|
||||
ggml_backend_cuda_context * cuda_ctx = (ggml_backend_cuda_context *) backend->context;
|
||||
|
||||
static bool enable_graph_optimization = [] {
|
||||
const char * env = getenv("GGML_CUDA_GRAPH_OPT");
|
||||
return env != nullptr && atoi(env) == 1;
|
||||
}();
|
||||
|
||||
if (!enable_graph_optimization) {
|
||||
return;
|
||||
}
|
||||
|
||||
GGML_ASSERT(ggml_backend_cuda_get_device_count() == 1 && "compute graph optimization is only supported on single GPU in the CUDA backend");
|
||||
GGML_LOG_DEBUG("Optimizing CUDA graph %p with %d nodes\n", cgraph->nodes, cgraph->n_nodes);
|
||||
|
||||
ggml_cuda_stream_context & stream_context = cuda_ctx->stream_context();
|
||||
stream_context.reset();
|
||||
|
||||
// number of out-degrees for a particular node
|
||||
std::unordered_map<const ggml_tensor *, int> fan_out;
|
||||
// reverse mapping of node to index in the cgraph
|
||||
std::unordered_map<const ggml_tensor *, int> node_indices;
|
||||
|
||||
const auto & is_noop = [](const ggml_tensor * node) -> bool {
|
||||
return ggml_is_empty(node) || node->op == GGML_OP_NONE || node->op == GGML_OP_RESHAPE ||
|
||||
node->op == GGML_OP_TRANSPOSE || node->op == GGML_OP_VIEW || node->op == GGML_OP_PERMUTE;
|
||||
};
|
||||
|
||||
const auto & depends_on = [](const ggml_tensor * dst, const ggml_tensor * src) -> bool {
|
||||
for (uint32_t s = 0; s < GGML_MAX_SRC; ++s) {
|
||||
if (dst->src[s] == src) {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
// implicit dependency if they view the same tensor
|
||||
const ggml_tensor * dst2 = dst->view_src ? dst->view_src : dst;
|
||||
const ggml_tensor * src2 = src->view_src ? src->view_src : src;
|
||||
if (dst2 == src2) {
|
||||
return true;
|
||||
}
|
||||
return false;
|
||||
};
|
||||
|
||||
for (int node_idx = 0; node_idx < cgraph->n_nodes; node_idx++) {
|
||||
const ggml_tensor * node = cgraph->nodes[node_idx];
|
||||
node_indices[node] = node_idx;
|
||||
|
||||
if (is_noop(node)) {
|
||||
continue;
|
||||
}
|
||||
for (int src_idx = 0; src_idx < GGML_MAX_SRC; ++src_idx) {
|
||||
const ggml_tensor * src = cgraph->nodes[node_idx]->src[src_idx];
|
||||
//TODO: check why nrows > 1 fails
|
||||
if (node && !is_noop(node) && ggml_nrows(node) <= 1) {
|
||||
fan_out[src] += 1;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Target Q, K, V for concurrency
|
||||
// this is a more general way to find nodes which can be candidates for concurrency (although it has not been tested for anything else):
|
||||
// 1. find fan-out (fork) nodes where the same input is used at least N times (in QKV, it would be "attn-norm")
|
||||
// 2. find the join node, where 2 or more of the outputs are required (in QKV, this would "KQ" or "flash-attn")
|
||||
// 3. account for all branches from the fork to the join
|
||||
// 4. To extend lifetimes of the tensors, we interleave the branches (see below for more details)
|
||||
// 5. save the original cgraph and restore it in graph_compute, to enable fusion within streams
|
||||
// See discussion: https://github.com/ggml-org/llama.cpp/pull/16991#issuecomment-3522620030
|
||||
|
||||
const int min_fan_out = 3;
|
||||
const int max_fan_out = 3;
|
||||
|
||||
// store {fork_idx, join_idx}
|
||||
std::vector<std::pair<int, int>> concurrent_node_ranges;
|
||||
|
||||
// save the original nodes
|
||||
std::vector<const ggml_tensor *> original_nodes;
|
||||
original_nodes.reserve(cgraph->n_nodes);
|
||||
for (int i = 0; i < cgraph->n_nodes; ++i) {
|
||||
original_nodes.push_back(cgraph->nodes[i]);
|
||||
}
|
||||
cuda_ctx->stream_context().original_nodes = std::move(original_nodes);
|
||||
|
||||
for (const auto & [root_node, count] : fan_out) {
|
||||
if (count >= min_fan_out && count <= max_fan_out) {
|
||||
const int root_node_idx = node_indices[root_node];
|
||||
|
||||
bool is_part_of_event = false;
|
||||
for (const auto & [start, end] : concurrent_node_ranges) {
|
||||
if (root_node_idx >= start && root_node_idx <= end) {
|
||||
is_part_of_event = true;
|
||||
}
|
||||
}
|
||||
|
||||
if (is_part_of_event) {
|
||||
continue;
|
||||
}
|
||||
|
||||
std::vector<std::vector<const ggml_tensor *>> nodes_per_branch;
|
||||
for (int i = root_node_idx + 1; i < cgraph->n_nodes; ++i) {
|
||||
const ggml_tensor * node = cgraph->nodes[i];
|
||||
if (!is_noop(node) && depends_on(node, root_node)) {
|
||||
nodes_per_branch.push_back({ node });
|
||||
}
|
||||
}
|
||||
|
||||
GGML_ASSERT(nodes_per_branch.size() == (size_t) count);
|
||||
|
||||
//find the join point
|
||||
const ggml_tensor * join_node = nullptr;
|
||||
|
||||
const auto & belongs_to_branch = [&](const ggml_tensor * node,
|
||||
const std::vector<const ggml_tensor *> & branch) -> bool {
|
||||
for (const ggml_tensor * n : branch) {
|
||||
if (depends_on(node, n)) {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
return false;
|
||||
};
|
||||
|
||||
for (int i = root_node_idx + 1; i < cgraph->n_nodes; ++i) {
|
||||
const ggml_tensor * curr_node = cgraph->nodes[i];
|
||||
|
||||
int num_joins = 0;
|
||||
for (size_t branch_idx = 0; branch_idx < nodes_per_branch.size(); branch_idx++) {
|
||||
if (belongs_to_branch(curr_node, nodes_per_branch[branch_idx])) {
|
||||
num_joins++;
|
||||
}
|
||||
}
|
||||
|
||||
if (num_joins >= 2) {
|
||||
join_node = curr_node;
|
||||
break;
|
||||
}
|
||||
|
||||
bool found_branch = false;
|
||||
for (size_t branch_idx = 0; branch_idx < nodes_per_branch.size(); branch_idx++) {
|
||||
std::vector<const ggml_tensor *> & branch_vec = nodes_per_branch[branch_idx];
|
||||
if (belongs_to_branch(curr_node, branch_vec)) {
|
||||
//continue accumulating
|
||||
if (std::find(branch_vec.begin(), branch_vec.end(), curr_node) == branch_vec.end()) {
|
||||
branch_vec.push_back(curr_node);
|
||||
}
|
||||
found_branch = true;
|
||||
}
|
||||
}
|
||||
|
||||
if (!found_branch && is_noop(curr_node)) {
|
||||
// we can put it in any branch because it will be ignored
|
||||
nodes_per_branch[0].push_back({ curr_node });
|
||||
}
|
||||
}
|
||||
|
||||
if (join_node) {
|
||||
//Create ggml_cuda_concurrent_event
|
||||
ggml_cuda_concurrent_event concurrent_event(nodes_per_branch.size());
|
||||
concurrent_event.join_node = join_node;
|
||||
|
||||
for (size_t branch_idx = 0; branch_idx < nodes_per_branch.size(); branch_idx++) {
|
||||
for (const ggml_tensor * n : nodes_per_branch[branch_idx]) {
|
||||
concurrent_event.stream_mapping[n] = branch_idx + 1;
|
||||
}
|
||||
}
|
||||
|
||||
int fork_node_idx = node_indices[root_node];
|
||||
int join_node_idx = node_indices[join_node];
|
||||
|
||||
int current_branch_idx = 0;
|
||||
int current_node_idx = fork_node_idx + 1;
|
||||
const int n_branches = nodes_per_branch.size();
|
||||
|
||||
int total_branch_nodes = 0;
|
||||
for (std::vector<const ggml_tensor *> branch_nodes : nodes_per_branch) {
|
||||
total_branch_nodes += branch_nodes.size();
|
||||
}
|
||||
|
||||
// there are other nodes in the middle which are unaccounted for
|
||||
// usually (cpy) nodes, then ignore this fork
|
||||
if (join_node_idx - fork_node_idx - 1 != total_branch_nodes) {
|
||||
GGML_LOG_DEBUG(
|
||||
"Skipping %s because the number of nodes in the middle is not equal to the total number of "
|
||||
"branch nodes %d != %d\n",
|
||||
root_node->name, join_node_idx - fork_node_idx - 1, total_branch_nodes);
|
||||
continue;
|
||||
}
|
||||
|
||||
std::unordered_map<const ggml_tensor *, ggml_cuda_concurrent_event> & concurrent_events = cuda_ctx->stream_context().concurrent_events;
|
||||
GGML_ASSERT(concurrent_events.find(root_node) == concurrent_events.end());
|
||||
concurrent_events.emplace(root_node, std::move(concurrent_event));
|
||||
GGML_LOG_DEBUG("Adding stream at node %s %p\n", root_node->name, root_node);
|
||||
concurrent_node_ranges.emplace_back(fork_node_idx, join_node_idx);
|
||||
|
||||
// interleave tensors to extend lifetimes so that ggml graph doesn't recycle them
|
||||
// example transformation:
|
||||
// [attn-norm, QMul, QNorm, QRope, KMul, KNorm, KRope, VMul, attn] ->
|
||||
// [attn-norm, QMul, KMul, VMul, QNorm, VNorm, QRope, KRope, attn]
|
||||
while (current_node_idx < join_node_idx) {
|
||||
std::vector<const ggml_tensor *> & branch_nodes = nodes_per_branch[current_branch_idx];
|
||||
|
||||
bool has_node = false;
|
||||
for (std::vector<const ggml_tensor *> branch_node : nodes_per_branch) {
|
||||
has_node |= branch_node.size() > 0;
|
||||
}
|
||||
|
||||
GGML_ASSERT(has_node);
|
||||
|
||||
if (branch_nodes.empty()) {
|
||||
current_branch_idx = (current_branch_idx + 1) % n_branches;
|
||||
continue;
|
||||
}
|
||||
|
||||
cgraph->nodes[current_node_idx] = const_cast<ggml_tensor *>(branch_nodes.front());
|
||||
current_node_idx++;
|
||||
branch_nodes.erase(branch_nodes.begin());
|
||||
|
||||
// append all empty nodes
|
||||
while (!branch_nodes.empty() && is_noop(branch_nodes.front())) {
|
||||
cgraph->nodes[current_node_idx] = const_cast<ggml_tensor *>(branch_nodes.front());
|
||||
current_node_idx++;
|
||||
branch_nodes.erase(branch_nodes.begin());
|
||||
}
|
||||
|
||||
current_branch_idx = (current_branch_idx + 1) % n_branches;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
static const ggml_backend_i ggml_backend_cuda_interface = {
|
||||
/* .get_name = */ ggml_backend_cuda_get_name,
|
||||
/* .free = */ ggml_backend_cuda_free,
|
||||
|
|
@ -3681,7 +3982,7 @@ static const ggml_backend_i ggml_backend_cuda_interface = {
|
|||
/* .graph_compute = */ ggml_backend_cuda_graph_compute,
|
||||
/* .event_record = */ ggml_backend_cuda_event_record,
|
||||
/* .event_wait = */ ggml_backend_cuda_event_wait,
|
||||
/* .graph_optimize = */ NULL,
|
||||
/* .graph_optimize = */ ggml_backend_cuda_graph_optimize,
|
||||
};
|
||||
|
||||
static ggml_guid_t ggml_backend_cuda_guid() {
|
||||
|
|
|
|||
|
|
@ -81,6 +81,76 @@ static __global__ void upscale_f32_bilinear(const float * x, float * dst,
|
|||
dst[index] = result;
|
||||
}
|
||||
|
||||
// Similar to F.interpolate(..., mode="bilinear", align_corners=False, antialias=True)
|
||||
// https://github.com/pytorch/pytorch/blob/8871ff29b743948d1225389d5b7068f37b22750b/aten/src/ATen/native/cpu/UpSampleKernel.cpp
|
||||
static __global__ void upscale_f32_bilinear_antialias(const float * src0, float * dst,
|
||||
const int nb00, const int nb01, const int nb02, const int nb03,
|
||||
const int ne00_src, const int ne01_src,
|
||||
const int ne10_dst, const int ne11_dst, const int ne12_dst, const int ne13_dst,
|
||||
const float sf0, const float sf1, const float sf2, const float sf3,
|
||||
const float pixel_offset) {
|
||||
const int64_t index = threadIdx.x + blockIdx.x * blockDim.x;
|
||||
const int64_t dst_total_elements = ne10_dst * ne11_dst * ne12_dst * ne13_dst;
|
||||
|
||||
if (index >= dst_total_elements) {
|
||||
return;
|
||||
}
|
||||
|
||||
const int i10_dst = index % ne10_dst;
|
||||
const int i11_dst = (index / ne10_dst) % ne11_dst;
|
||||
const int i12_dst = (index / (ne10_dst * ne11_dst)) % ne12_dst;
|
||||
const int i13_dst = index / (ne10_dst * ne11_dst * ne12_dst);
|
||||
|
||||
const int i02_src = (int)(i12_dst / sf2);
|
||||
const int i03_src = (int)(i13_dst / sf3);
|
||||
|
||||
const float y = ((float)i11_dst + pixel_offset) / sf1;
|
||||
const float x = ((float)i10_dst + pixel_offset) / sf0;
|
||||
|
||||
// support and invscale, minimum 1 pixel for bilinear
|
||||
const float support1 = max(1.0f / sf1, 1.0f);
|
||||
const float invscale1 = 1.0f / support1;
|
||||
const float support0 = max(1.0f / sf0, 1.0f);
|
||||
const float invscale0 = 1.0f / support0;
|
||||
|
||||
// the range of source pixels that contribute
|
||||
const int64_t x_min = max(int64_t(0), int64_t(x - support0 + pixel_offset));
|
||||
const int64_t x_max = min(int64_t(ne00_src), int64_t(x + support0 + pixel_offset));
|
||||
const int64_t y_min = max(int64_t(0), int64_t(y - support1 + pixel_offset));
|
||||
const int64_t y_max = min(int64_t(ne01_src), int64_t(y + support1 + pixel_offset));
|
||||
|
||||
// bilinear filter with antialiasing
|
||||
float val = 0.0f;
|
||||
float total_weight = 0.0f;
|
||||
|
||||
auto triangle_filter = [](float x) -> float {
|
||||
return max(1.0f - fabsf(x), 0.0f);
|
||||
};
|
||||
|
||||
for (int64_t sy = y_min; sy < y_max; sy++) {
|
||||
const float weight_y = triangle_filter((sy - y + pixel_offset) * invscale1);
|
||||
|
||||
for (int64_t sx = x_min; sx < x_max; sx++) {
|
||||
const float weight_x = triangle_filter((sx - x + pixel_offset) * invscale0);
|
||||
const float weight = weight_x * weight_y;
|
||||
|
||||
if (weight <= 0.0f) {
|
||||
continue;
|
||||
}
|
||||
|
||||
const float pixel = *(const float *)((const char *)src0 + sx*nb00 + sy*nb01 + i02_src*nb02 + i03_src*nb03);
|
||||
val += pixel * weight;
|
||||
total_weight += weight;
|
||||
}
|
||||
}
|
||||
|
||||
if (total_weight > 0.0f) {
|
||||
val /= total_weight;
|
||||
}
|
||||
|
||||
dst[index] = val;
|
||||
}
|
||||
|
||||
namespace bicubic_interpolation {
|
||||
// https://en.wikipedia.org/wiki/Bicubic_interpolation#Bicubic_convolution_algorithm
|
||||
__device__ const float a = -0.75f; // use alpha = -0.75 (same as PyTorch)
|
||||
|
|
@ -161,12 +231,16 @@ static void upscale_f32_bilinear_cuda(const float * x, float * dst,
|
|||
const int ne00_src, const int ne01_src,
|
||||
const int ne10_dst, const int ne11_dst, const int ne12_dst, const int ne13_dst,
|
||||
const float sf0, const float sf1, const float sf2, const float sf3,
|
||||
const float pixel_offset, cudaStream_t stream) {
|
||||
const float pixel_offset, bool antialias, cudaStream_t stream) {
|
||||
const int64_t dst_size = ne10_dst * ne11_dst * ne12_dst * ne13_dst;
|
||||
const int64_t num_blocks = (dst_size + CUDA_UPSCALE_BLOCK_SIZE - 1) / CUDA_UPSCALE_BLOCK_SIZE;
|
||||
|
||||
if (antialias) {
|
||||
upscale_f32_bilinear_antialias<<<num_blocks, CUDA_UPSCALE_BLOCK_SIZE,0,stream>>>(x, dst, nb00, nb01, nb02, nb03, ne00_src, ne01_src, ne10_dst, ne11_dst, ne12_dst, ne13_dst, sf0, sf1, sf2, sf3, pixel_offset);
|
||||
} else {
|
||||
upscale_f32_bilinear<<<num_blocks, CUDA_UPSCALE_BLOCK_SIZE,0,stream>>>(x, dst, nb00, nb01, nb02, nb03, ne00_src, ne01_src, ne10_dst, ne11_dst, ne12_dst, ne13_dst, sf0, sf1, sf2, sf3, pixel_offset);
|
||||
}
|
||||
}
|
||||
|
||||
static void upscale_f32_bicubic_cuda(const float * x, float * dst,
|
||||
const int nb00, const int nb01, const int nb02, const int nb03,
|
||||
|
|
@ -207,9 +281,10 @@ void ggml_cuda_op_upscale(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
|
|||
if (mode == GGML_SCALE_MODE_NEAREST) {
|
||||
upscale_f32_cuda(src0_d, dst_d, src0->nb[0], src0->nb[1], src0->nb[2], src0->nb[3], dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3], sf0, sf1, sf2, sf3, stream);
|
||||
} else if (mode == GGML_SCALE_MODE_BILINEAR) {
|
||||
const bool antialias = (mode_flags & GGML_SCALE_FLAG_ANTIALIAS);
|
||||
upscale_f32_bilinear_cuda(src0_d, dst_d, src0->nb[0], src0->nb[1], src0->nb[2], src0->nb[3],
|
||||
src0->ne[0], src0->ne[1], dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3],
|
||||
sf0, sf1, sf2, sf3, pixel_offset, stream);
|
||||
sf0, sf1, sf2, sf3, pixel_offset, antialias, stream);
|
||||
} else if (mode == GGML_SCALE_MODE_BICUBIC) {
|
||||
upscale_f32_bicubic_cuda(src0_d, dst_d, src0->nb[0], src0->nb[1], src0->nb[2], src0->nb[3],
|
||||
src0->ne[0], src0->ne[1], dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3],
|
||||
|
|
|
|||
|
|
@ -105,7 +105,7 @@
|
|||
#define cudaStreamNonBlocking hipStreamNonBlocking
|
||||
#define cudaStreamPerThread hipStreamPerThread
|
||||
#define cudaStreamSynchronize hipStreamSynchronize
|
||||
#define cudaStreamWaitEvent(stream, event, flags) hipStreamWaitEvent(stream, event, flags)
|
||||
#define cudaStreamWaitEvent hipStreamWaitEvent
|
||||
#define cudaGraphExec_t hipGraphExec_t
|
||||
#define cudaGraphNode_t hipGraphNode_t
|
||||
#define cudaKernelNodeParams hipKernelNodeParams
|
||||
|
|
|
|||
|
|
@ -894,7 +894,7 @@ bool ggml_metal_device_supports_op(ggml_metal_device_t dev, const struct ggml_te
|
|||
case GGML_OP_POOL_1D:
|
||||
return false;
|
||||
case GGML_OP_UPSCALE:
|
||||
return op->src[0]->type == GGML_TYPE_F32 && op->op_params[0] == GGML_SCALE_MODE_NEAREST;
|
||||
return op->src[0]->type == GGML_TYPE_F32 && op->op_params[0] == GGML_SCALE_MODE_NEAREST && !(op->op_params[0] & GGML_SCALE_FLAG_ANTIALIAS);
|
||||
case GGML_OP_POOL_2D:
|
||||
return op->src[0]->type == GGML_TYPE_F32;
|
||||
case GGML_OP_PAD:
|
||||
|
|
@ -912,6 +912,7 @@ bool ggml_metal_device_supports_op(ggml_metal_device_t dev, const struct ggml_te
|
|||
// for new head sizes, add checks here
|
||||
if (op->src[0]->ne[0] != 32 &&
|
||||
op->src[0]->ne[0] != 40 &&
|
||||
op->src[0]->ne[0] != 48 &&
|
||||
op->src[0]->ne[0] != 64 &&
|
||||
op->src[0]->ne[0] != 72 &&
|
||||
op->src[0]->ne[0] != 80 &&
|
||||
|
|
|
|||
|
|
@ -5757,6 +5757,7 @@ typedef decltype(kernel_flash_attn_ext<FA_TYPES, half4x4, 1, dequantize_f16, hal
|
|||
|
||||
template [[host_name("kernel_flash_attn_ext_f32_dk32_dv32" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_F32, float4x4, 1, dequantize_f32, float4x4, 1, dequantize_f32, 32, 32>;
|
||||
template [[host_name("kernel_flash_attn_ext_f32_dk40_dv40" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_F32, float4x4, 1, dequantize_f32, float4x4, 1, dequantize_f32, 40, 40>;
|
||||
template [[host_name("kernel_flash_attn_ext_f32_dk48_dv48" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_F32, float4x4, 1, dequantize_f32, float4x4, 1, dequantize_f32, 48, 48>;
|
||||
template [[host_name("kernel_flash_attn_ext_f32_dk64_dv64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_F32, float4x4, 1, dequantize_f32, float4x4, 1, dequantize_f32, 64, 64>;
|
||||
template [[host_name("kernel_flash_attn_ext_f32_dk72_dv72" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_F32, float4x4, 1, dequantize_f32, float4x4, 1, dequantize_f32, 72, 72>;
|
||||
template [[host_name("kernel_flash_attn_ext_f32_dk80_dv80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_F32, float4x4, 1, dequantize_f32, float4x4, 1, dequantize_f32, 80, 80>;
|
||||
|
|
@ -5770,6 +5771,7 @@ template [[host_name("kernel_flash_attn_ext_f32_dk576_dv512")]] kernel flash_at
|
|||
|
||||
template [[host_name("kernel_flash_attn_ext_f16_dk32_dv32" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, half4x4, 1, dequantize_f16, half4x4, 1, dequantize_f16, 32, 32>;
|
||||
template [[host_name("kernel_flash_attn_ext_f16_dk40_dv40" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, half4x4, 1, dequantize_f16, half4x4, 1, dequantize_f16, 40, 40>;
|
||||
template [[host_name("kernel_flash_attn_ext_f16_dk48_dv48" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, half4x4, 1, dequantize_f16, half4x4, 1, dequantize_f16, 48, 48>;
|
||||
template [[host_name("kernel_flash_attn_ext_f16_dk64_dv64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, half4x4, 1, dequantize_f16, half4x4, 1, dequantize_f16, 64, 64>;
|
||||
template [[host_name("kernel_flash_attn_ext_f16_dk72_dv72" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, half4x4, 1, dequantize_f16, half4x4, 1, dequantize_f16, 72, 72>;
|
||||
template [[host_name("kernel_flash_attn_ext_f16_dk80_dv80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, half4x4, 1, dequantize_f16, half4x4, 1, dequantize_f16, 80, 80>;
|
||||
|
|
@ -5784,6 +5786,7 @@ template [[host_name("kernel_flash_attn_ext_f16_dk576_dv512")]] kernel flash_at
|
|||
#if defined(GGML_METAL_HAS_BF16)
|
||||
template [[host_name("kernel_flash_attn_ext_bf16_dk32_dv32" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_BF, bfloat4x4, 1, dequantize_bf16, bfloat4x4, 1, dequantize_bf16, 32, 32>;
|
||||
template [[host_name("kernel_flash_attn_ext_bf16_dk40_dv40" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_BF, bfloat4x4, 1, dequantize_bf16, bfloat4x4, 1, dequantize_bf16, 40, 40>;
|
||||
template [[host_name("kernel_flash_attn_ext_bf16_dk48_dv48" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_BF, bfloat4x4, 1, dequantize_bf16, bfloat4x4, 1, dequantize_bf16, 48, 48>;
|
||||
template [[host_name("kernel_flash_attn_ext_bf16_dk64_dv64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_BF, bfloat4x4, 1, dequantize_bf16, bfloat4x4, 1, dequantize_bf16, 64, 64>;
|
||||
template [[host_name("kernel_flash_attn_ext_bf16_dk72_dv72" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_BF, bfloat4x4, 1, dequantize_bf16, bfloat4x4, 1, dequantize_bf16, 72, 72>;
|
||||
template [[host_name("kernel_flash_attn_ext_bf16_dk80_dv80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_BF, bfloat4x4, 1, dequantize_bf16, bfloat4x4, 1, dequantize_bf16, 80, 80>;
|
||||
|
|
@ -5798,6 +5801,7 @@ template [[host_name("kernel_flash_attn_ext_bf16_dk576_dv512")]] kernel flash_at
|
|||
|
||||
template [[host_name("kernel_flash_attn_ext_q4_0_dk32_dv32" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_0, 2, dequantize_q4_0, block_q4_0, 2, dequantize_q4_0, 32, 32>;
|
||||
template [[host_name("kernel_flash_attn_ext_q4_0_dk40_dv40" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_0, 2, dequantize_q4_0, block_q4_0, 2, dequantize_q4_0, 40, 40>;
|
||||
template [[host_name("kernel_flash_attn_ext_q4_0_dk48_dv48" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_0, 2, dequantize_q4_0, block_q4_0, 2, dequantize_q4_0, 48, 48>;
|
||||
template [[host_name("kernel_flash_attn_ext_q4_0_dk64_dv64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_0, 2, dequantize_q4_0, block_q4_0, 2, dequantize_q4_0, 64, 64>;
|
||||
template [[host_name("kernel_flash_attn_ext_q4_0_dk72_dv72" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_0, 2, dequantize_q4_0, block_q4_0, 2, dequantize_q4_0, 72, 72>;
|
||||
template [[host_name("kernel_flash_attn_ext_q4_0_dk80_dv80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_0, 2, dequantize_q4_0, block_q4_0, 2, dequantize_q4_0, 80, 80>;
|
||||
|
|
@ -5811,6 +5815,7 @@ template [[host_name("kernel_flash_attn_ext_q4_0_dk576_dv512")]] kernel flash_at
|
|||
|
||||
template [[host_name("kernel_flash_attn_ext_q4_1_dk32_dv32" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_1, 2, dequantize_q4_1, block_q4_1, 2, dequantize_q4_1, 32, 32>;
|
||||
template [[host_name("kernel_flash_attn_ext_q4_1_dk40_dv40" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_1, 2, dequantize_q4_1, block_q4_1, 2, dequantize_q4_1, 40, 40>;
|
||||
template [[host_name("kernel_flash_attn_ext_q4_1_dk48_dv48" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_1, 2, dequantize_q4_1, block_q4_1, 2, dequantize_q4_1, 48, 48>;
|
||||
template [[host_name("kernel_flash_attn_ext_q4_1_dk64_dv64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_1, 2, dequantize_q4_1, block_q4_1, 2, dequantize_q4_1, 64, 64>;
|
||||
template [[host_name("kernel_flash_attn_ext_q4_1_dk72_dv72" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_1, 2, dequantize_q4_1, block_q4_1, 2, dequantize_q4_1, 72, 72>;
|
||||
template [[host_name("kernel_flash_attn_ext_q4_1_dk80_dv80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_1, 2, dequantize_q4_1, block_q4_1, 2, dequantize_q4_1, 80, 80>;
|
||||
|
|
@ -5824,6 +5829,7 @@ template [[host_name("kernel_flash_attn_ext_q4_1_dk576_dv512")]] kernel flash_at
|
|||
|
||||
template [[host_name("kernel_flash_attn_ext_q5_0_dk32_dv32" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_0, 2, dequantize_q5_0, block_q5_0, 2, dequantize_q5_0, 32, 32>;
|
||||
template [[host_name("kernel_flash_attn_ext_q5_0_dk40_dv40" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_0, 2, dequantize_q5_0, block_q5_0, 2, dequantize_q5_0, 40, 40>;
|
||||
template [[host_name("kernel_flash_attn_ext_q5_0_dk48_dv48" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_0, 2, dequantize_q5_0, block_q5_0, 2, dequantize_q5_0, 48, 48>;
|
||||
template [[host_name("kernel_flash_attn_ext_q5_0_dk64_dv64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_0, 2, dequantize_q5_0, block_q5_0, 2, dequantize_q5_0, 64, 64>;
|
||||
template [[host_name("kernel_flash_attn_ext_q5_0_dk72_dv72" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_0, 2, dequantize_q5_0, block_q5_0, 2, dequantize_q5_0, 72, 72>;
|
||||
template [[host_name("kernel_flash_attn_ext_q5_0_dk80_dv80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_0, 2, dequantize_q5_0, block_q5_0, 2, dequantize_q5_0, 80, 80>;
|
||||
|
|
@ -5837,6 +5843,7 @@ template [[host_name("kernel_flash_attn_ext_q5_0_dk576_dv512")]] kernel flash_at
|
|||
|
||||
template [[host_name("kernel_flash_attn_ext_q5_1_dk32_dv32" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_1, 2, dequantize_q5_1, block_q5_1, 2, dequantize_q5_1, 32, 32>;
|
||||
template [[host_name("kernel_flash_attn_ext_q5_1_dk40_dv40" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_1, 2, dequantize_q5_1, block_q5_1, 2, dequantize_q5_1, 40, 40>;
|
||||
template [[host_name("kernel_flash_attn_ext_q5_1_dk48_dv48" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_1, 2, dequantize_q5_1, block_q5_1, 2, dequantize_q5_1, 48, 48>;
|
||||
template [[host_name("kernel_flash_attn_ext_q5_1_dk64_dv64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_1, 2, dequantize_q5_1, block_q5_1, 2, dequantize_q5_1, 64, 64>;
|
||||
template [[host_name("kernel_flash_attn_ext_q5_1_dk72_dv72" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_1, 2, dequantize_q5_1, block_q5_1, 2, dequantize_q5_1, 72, 72>;
|
||||
template [[host_name("kernel_flash_attn_ext_q5_1_dk80_dv80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_1, 2, dequantize_q5_1, block_q5_1, 2, dequantize_q5_1, 80, 80>;
|
||||
|
|
@ -5850,6 +5857,7 @@ template [[host_name("kernel_flash_attn_ext_q5_1_dk576_dv512")]] kernel flash_at
|
|||
|
||||
template [[host_name("kernel_flash_attn_ext_q8_0_dk32_dv32" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q8_0, 2, dequantize_q8_0, block_q8_0, 2, dequantize_q8_0, 32, 32>;
|
||||
template [[host_name("kernel_flash_attn_ext_q8_0_dk40_dv40" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q8_0, 2, dequantize_q8_0, block_q8_0, 2, dequantize_q8_0, 40, 40>;
|
||||
template [[host_name("kernel_flash_attn_ext_q8_0_dk48_dv48" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q8_0, 2, dequantize_q8_0, block_q8_0, 2, dequantize_q8_0, 48, 48>;
|
||||
template [[host_name("kernel_flash_attn_ext_q8_0_dk64_dv64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q8_0, 2, dequantize_q8_0, block_q8_0, 2, dequantize_q8_0, 64, 64>;
|
||||
template [[host_name("kernel_flash_attn_ext_q8_0_dk72_dv72" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q8_0, 2, dequantize_q8_0, block_q8_0, 2, dequantize_q8_0, 72, 72>;
|
||||
template [[host_name("kernel_flash_attn_ext_q8_0_dk80_dv80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q8_0, 2, dequantize_q8_0, block_q8_0, 2, dequantize_q8_0, 80, 80>;
|
||||
|
|
|
|||
|
|
@ -3086,8 +3086,9 @@ static bool ggml_opencl_supports_op(ggml_backend_dev_t dev, const struct ggml_te
|
|||
return op->src[0]->type == GGML_TYPE_F32 && op->type == GGML_TYPE_F32;
|
||||
case GGML_OP_UPSCALE: {
|
||||
ggml_scale_mode mode = (ggml_scale_mode)(ggml_get_op_params_i32(op, 0) & 0xFF);
|
||||
const bool antialias = (ggml_scale_mode)(ggml_get_op_params_i32(op, 0) & GGML_SCALE_FLAG_ANTIALIAS);
|
||||
return op->src[0]->type == GGML_TYPE_F32 && op->type == GGML_TYPE_F32 &&
|
||||
(mode == GGML_SCALE_MODE_NEAREST || mode == GGML_SCALE_MODE_BILINEAR);
|
||||
(mode == GGML_SCALE_MODE_NEAREST || mode == GGML_SCALE_MODE_BILINEAR) && !antialias;
|
||||
}
|
||||
case GGML_OP_CONV_2D:
|
||||
return (op->src[0]->type == GGML_TYPE_F16 && op->src[1]->type == GGML_TYPE_F16 && op->type == GGML_TYPE_F16) ||
|
||||
|
|
|
|||
|
|
@ -4597,7 +4597,7 @@ static bool ggml_backend_sycl_device_supports_op(ggml_backend_dev_t dev, const g
|
|||
case GGML_OP_IM2COL:
|
||||
return true;
|
||||
case GGML_OP_UPSCALE:
|
||||
return op->src[0]->type == GGML_TYPE_F32 && op->op_params[0] == GGML_SCALE_MODE_NEAREST;
|
||||
return op->src[0]->type == GGML_TYPE_F32 && op->op_params[0] == GGML_SCALE_MODE_NEAREST && !(op->op_params[0] & GGML_SCALE_FLAG_ANTIALIAS);
|
||||
case GGML_OP_SUM:
|
||||
case GGML_OP_SUM_ROWS:
|
||||
case GGML_OP_MEAN:
|
||||
|
|
|
|||
|
|
@ -14113,6 +14113,7 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
|
|||
}
|
||||
return true;
|
||||
case GGML_OP_UPSCALE:
|
||||
return op->src[0]->type == GGML_TYPE_F32 && !(op->op_params[0] & GGML_SCALE_FLAG_ANTIALIAS);
|
||||
case GGML_OP_ACC:
|
||||
return op->src[0]->type == GGML_TYPE_F32;
|
||||
case GGML_OP_CONCAT:
|
||||
|
|
|
|||
|
|
@ -156,7 +156,7 @@ void main() {
|
|||
tensorLayoutM = setTensorLayoutStrideNV(tensorLayoutM, m_stride, 1);
|
||||
tensorLayoutM = setTensorLayoutClampValueNV(tensorLayoutM, 0xfc00); // -inf in float16_t
|
||||
|
||||
coopmat<float16_t, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator> mv, mvmax;
|
||||
coopmat<float16_t, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator> mvmax;
|
||||
|
||||
coopMatLoadTensorNV(mv, data_m, m_offset, sliceTensorLayoutNV(tensorLayoutM, i * Br, Br, j * Bc, Bc));
|
||||
|
||||
|
|
|
|||
|
|
@ -4891,6 +4891,8 @@ static struct ggml_tensor * ggml_interpolate_impl(
|
|||
int64_t ne3,
|
||||
uint32_t mode) {
|
||||
GGML_ASSERT((mode & 0xFF) < GGML_SCALE_MODE_COUNT);
|
||||
// TODO: implement antialias for modes other than bilinear
|
||||
GGML_ASSERT(!(mode & GGML_SCALE_FLAG_ANTIALIAS) || (mode & 0xFF) == GGML_SCALE_MODE_BILINEAR);
|
||||
|
||||
struct ggml_tensor * result = ggml_new_tensor_4d(ctx, a->type, ne0, ne1, ne2, ne3);
|
||||
|
||||
|
|
|
|||
|
|
@ -175,6 +175,7 @@ class Keys:
|
|||
VALUE_LENGTH_MLA = "{arch}.attention.value_length_mla"
|
||||
SHARED_KV_LAYERS = "{arch}.attention.shared_kv_layers"
|
||||
SLIDING_WINDOW_PATTERN = "{arch}.attention.sliding_window_pattern"
|
||||
TEMPERATURE_SCALE = "{arch}.attention.temperature_scale"
|
||||
|
||||
class Rope:
|
||||
DIMENSION_COUNT = "{arch}.rope.dimension_count"
|
||||
|
|
@ -444,6 +445,7 @@ class MODEL_ARCH(IntEnum):
|
|||
MINIMAXM2 = auto()
|
||||
RND1 = auto()
|
||||
PANGU_EMBED = auto()
|
||||
MISTRAL3 = auto()
|
||||
|
||||
|
||||
class VISION_PROJECTOR_TYPE(IntEnum):
|
||||
|
|
@ -817,6 +819,7 @@ MODEL_ARCH_NAMES: dict[MODEL_ARCH, str] = {
|
|||
MODEL_ARCH.COGVLM: "cogvlm",
|
||||
MODEL_ARCH.RND1: "rnd1",
|
||||
MODEL_ARCH.PANGU_EMBED: "pangu-embedded",
|
||||
MODEL_ARCH.MISTRAL3: "mistral3",
|
||||
}
|
||||
|
||||
VISION_PROJECTOR_TYPE_NAMES: dict[VISION_PROJECTOR_TYPE, str] = {
|
||||
|
|
@ -3071,6 +3074,26 @@ MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = {
|
|||
MODEL_TENSOR.FFN_DOWN,
|
||||
MODEL_TENSOR.FFN_UP,
|
||||
],
|
||||
MODEL_ARCH.MISTRAL3: [
|
||||
MODEL_TENSOR.TOKEN_EMBD,
|
||||
MODEL_TENSOR.OUTPUT_NORM,
|
||||
MODEL_TENSOR.OUTPUT,
|
||||
MODEL_TENSOR.ROPE_FREQS,
|
||||
MODEL_TENSOR.ATTN_NORM,
|
||||
MODEL_TENSOR.ATTN_Q,
|
||||
MODEL_TENSOR.ATTN_K,
|
||||
MODEL_TENSOR.ATTN_V,
|
||||
MODEL_TENSOR.ATTN_OUT,
|
||||
MODEL_TENSOR.ATTN_ROT_EMBD,
|
||||
MODEL_TENSOR.FFN_GATE_INP,
|
||||
MODEL_TENSOR.FFN_NORM,
|
||||
MODEL_TENSOR.FFN_GATE,
|
||||
MODEL_TENSOR.FFN_DOWN,
|
||||
MODEL_TENSOR.FFN_UP,
|
||||
MODEL_TENSOR.FFN_GATE_EXP,
|
||||
MODEL_TENSOR.FFN_DOWN_EXP,
|
||||
MODEL_TENSOR.FFN_UP_EXP,
|
||||
],
|
||||
# TODO
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -904,6 +904,9 @@ class GGUFWriter:
|
|||
def add_attn_temperature_length(self, value: int) -> None:
|
||||
self.add_uint32(Keys.Attention.TEMPERATURE_LENGTH.format(arch=self.arch), value)
|
||||
|
||||
def add_attn_temperature_scale(self, value: float) -> None:
|
||||
self.add_float32(Keys.Attention.TEMPERATURE_SCALE.format(arch=self.arch), value)
|
||||
|
||||
def add_pooling_type(self, value: PoolingType) -> None:
|
||||
self.add_uint32(Keys.LLM.POOLING_TYPE.format(arch=self.arch), value.value)
|
||||
|
||||
|
|
|
|||
|
|
@ -132,6 +132,7 @@ add_library(llama
|
|||
models/t5-enc.cpp
|
||||
models/wavtokenizer-dec.cpp
|
||||
models/xverse.cpp
|
||||
models/mistral3.cpp
|
||||
models/graph-context-mamba.cpp
|
||||
)
|
||||
|
||||
|
|
|
|||
|
|
@ -111,6 +111,7 @@ static const std::map<llm_arch, const char *> LLM_ARCH_NAMES = {
|
|||
{ LLM_ARCH_COGVLM, "cogvlm" },
|
||||
{ LLM_ARCH_RND1, "rnd1" },
|
||||
{ LLM_ARCH_PANGU_EMBED, "pangu-embedded" },
|
||||
{ LLM_ARCH_MISTRAL3, "mistral3" },
|
||||
{ LLM_ARCH_UNKNOWN, "(unknown)" },
|
||||
};
|
||||
|
||||
|
|
@ -204,6 +205,7 @@ static const std::map<llm_kv, const char *> LLM_KV_NAMES = {
|
|||
{ LLM_KV_ATTENTION_SCALE, "%s.attention.scale" },
|
||||
{ LLM_KV_ATTENTION_OUTPUT_SCALE, "%s.attention.output_scale" },
|
||||
{ LLM_KV_ATTENTION_TEMPERATURE_LENGTH, "%s.attention.temperature_length" },
|
||||
{ LLM_KV_ATTENTION_TEMPERATURE_SCALE, "%s.attention.temperature_scale" },
|
||||
{ LLM_KV_ATTENTION_KEY_LENGTH_MLA, "%s.attention.key_length_mla" },
|
||||
{ LLM_KV_ATTENTION_VALUE_LENGTH_MLA, "%s.attention.value_length_mla" },
|
||||
|
||||
|
|
@ -2512,6 +2514,32 @@ static const std::map<llm_arch, std::map<llm_tensor, const char *>> LLM_TENSOR_N
|
|||
{ LLM_TENSOR_FFN_UP_EXPS, "blk.%d.ffn_up_exps" },
|
||||
},
|
||||
},
|
||||
{
|
||||
LLM_ARCH_MISTRAL3,
|
||||
{
|
||||
{ LLM_TENSOR_TOKEN_EMBD, "token_embd" },
|
||||
{ LLM_TENSOR_OUTPUT_NORM, "output_norm" },
|
||||
{ LLM_TENSOR_OUTPUT, "output" },
|
||||
{ LLM_TENSOR_ROPE_FREQS, "rope_freqs" },
|
||||
{ LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" },
|
||||
{ LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" },
|
||||
{ LLM_TENSOR_ATTN_K, "blk.%d.attn_k" },
|
||||
{ LLM_TENSOR_ATTN_V, "blk.%d.attn_v" },
|
||||
{ LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" },
|
||||
{ LLM_TENSOR_ATTN_ROT_EMBD, "blk.%d.attn_rot_embd" },
|
||||
{ LLM_TENSOR_FFN_GATE_INP, "blk.%d.ffn_gate_inp" },
|
||||
{ LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" },
|
||||
{ LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" },
|
||||
{ LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" },
|
||||
{ LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" },
|
||||
{ LLM_TENSOR_FFN_GATE_EXP, "blk.%d.ffn_gate.%d" },
|
||||
{ LLM_TENSOR_FFN_DOWN_EXP, "blk.%d.ffn_down.%d" },
|
||||
{ LLM_TENSOR_FFN_UP_EXP, "blk.%d.ffn_up.%d" },
|
||||
{ LLM_TENSOR_FFN_GATE_EXPS, "blk.%d.ffn_gate_exps" },
|
||||
{ LLM_TENSOR_FFN_DOWN_EXPS, "blk.%d.ffn_down_exps" },
|
||||
{ LLM_TENSOR_FFN_UP_EXPS, "blk.%d.ffn_up_exps" },
|
||||
},
|
||||
},
|
||||
{
|
||||
LLM_ARCH_UNKNOWN,
|
||||
{
|
||||
|
|
|
|||
|
|
@ -115,6 +115,7 @@ enum llm_arch {
|
|||
LLM_ARCH_COGVLM,
|
||||
LLM_ARCH_RND1,
|
||||
LLM_ARCH_PANGU_EMBED,
|
||||
LLM_ARCH_MISTRAL3,
|
||||
LLM_ARCH_UNKNOWN,
|
||||
};
|
||||
|
||||
|
|
@ -208,6 +209,7 @@ enum llm_kv {
|
|||
LLM_KV_ATTENTION_SCALE,
|
||||
LLM_KV_ATTENTION_OUTPUT_SCALE,
|
||||
LLM_KV_ATTENTION_TEMPERATURE_LENGTH,
|
||||
LLM_KV_ATTENTION_TEMPERATURE_SCALE,
|
||||
LLM_KV_ATTENTION_KEY_LENGTH_MLA,
|
||||
LLM_KV_ATTENTION_VALUE_LENGTH_MLA,
|
||||
|
||||
|
|
|
|||
|
|
@ -71,6 +71,9 @@ void llm_graph_input_attn_temp::set_input(const llama_ubatch * ubatch) {
|
|||
if (ubatch->pos && attn_scale) {
|
||||
const int64_t n_tokens = ubatch->n_tokens;
|
||||
|
||||
GGML_ASSERT(f_attn_temp_scale != 0.0f);
|
||||
GGML_ASSERT(n_attn_temp_floor_scale != 0);
|
||||
|
||||
std::vector<float> attn_scale_data(n_tokens, 0.0f);
|
||||
for (int i = 0; i < n_tokens; ++i) {
|
||||
const float pos = ubatch->pos[i];
|
||||
|
|
@ -837,9 +840,6 @@ ggml_tensor * llm_graph_context::build_ffn(
|
|||
GGML_ABORT("fatal error");
|
||||
}
|
||||
|
||||
//expand here so that we can fuse ffn gate
|
||||
ggml_build_forward_expand(gf, cur);
|
||||
|
||||
if (gate && type_gate == LLM_FFN_PAR) {
|
||||
cur = ggml_mul(ctx0, cur, tmp);
|
||||
cb(cur, "ffn_gate_par", il);
|
||||
|
|
@ -1120,9 +1120,6 @@ ggml_tensor * llm_graph_context::build_moe_ffn(
|
|||
GGML_ABORT("fatal error");
|
||||
}
|
||||
|
||||
//expand here so that we can fuse ffn gate
|
||||
ggml_build_forward_expand(gf, cur);
|
||||
|
||||
experts = build_lora_mm_id(down_exps, cur, selected_experts); // [n_embd, n_expert_used, n_tokens]
|
||||
cb(experts, "ffn_moe_down", il);
|
||||
|
||||
|
|
|
|||
|
|
@ -162,8 +162,8 @@ struct llama_hparams {
|
|||
// llama4 smallthinker
|
||||
uint32_t n_moe_layer_step = 0;
|
||||
uint32_t n_no_rope_layer_step = 4;
|
||||
uint32_t n_attn_temp_floor_scale = 8192;
|
||||
float f_attn_temp_scale = 0.1;
|
||||
uint32_t n_attn_temp_floor_scale = 0;
|
||||
float f_attn_temp_scale = 0.0f;
|
||||
|
||||
// gemma3n altup
|
||||
uint32_t n_altup = 4; // altup_num_inputs
|
||||
|
|
|
|||
|
|
@ -626,8 +626,6 @@ void llama_model::load_hparams(llama_model_loader & ml) {
|
|||
switch (arch) {
|
||||
case LLM_ARCH_LLAMA:
|
||||
{
|
||||
ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
|
||||
|
||||
if (hparams.n_expert == 8) {
|
||||
switch (hparams.n_layer) {
|
||||
case 32: type = LLM_TYPE_8x7B; break;
|
||||
|
|
@ -665,6 +663,8 @@ void llama_model::load_hparams(llama_model_loader & ml) {
|
|||
} else {
|
||||
hparams.swa_type = LLAMA_SWA_TYPE_CHUNKED;
|
||||
hparams.n_swa = 8192;
|
||||
hparams.n_attn_temp_floor_scale = 8192;
|
||||
hparams.f_attn_temp_scale = 0.1f;
|
||||
hparams.set_swa_pattern(4); // pattern: 3 chunked - 1 full
|
||||
}
|
||||
|
||||
|
|
@ -2247,6 +2247,42 @@ void llama_model::load_hparams(llama_model_loader & ml) {
|
|||
default: type = LLM_TYPE_UNKNOWN;
|
||||
}
|
||||
} break;
|
||||
case LLM_ARCH_MISTRAL3:
|
||||
{
|
||||
ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
|
||||
ml.get_key(LLM_KV_ATTENTION_TEMPERATURE_SCALE, hparams.f_attn_temp_scale, false);
|
||||
|
||||
ml.get_key(LLM_KV_ROPE_SCALING_YARN_BETA_FAST, hparams.yarn_beta_fast, false);
|
||||
ml.get_key(LLM_KV_ROPE_SCALING_YARN_BETA_SLOW, hparams.yarn_beta_slow, false);
|
||||
ml.get_key(LLM_KV_ROPE_SCALING_YARN_LOG_MUL, hparams.rope_yarn_log_mul, false);
|
||||
|
||||
// TODO: maybe add n_attn_temp_floor_scale as a separate KV?
|
||||
if (hparams.f_attn_temp_scale != 0.0f) {
|
||||
hparams.n_attn_temp_floor_scale = hparams.n_ctx_orig_yarn;
|
||||
if (hparams.n_attn_temp_floor_scale == 0) {
|
||||
throw std::runtime_error("invalid n_ctx_orig_yarn for attention temperature scaling");
|
||||
}
|
||||
}
|
||||
|
||||
// TODO: this seems to be correct with the case of mscale == mscale_all_dims == 1.0f
|
||||
// but may need further verification with other values
|
||||
if (hparams.rope_yarn_log_mul != 0.0f) {
|
||||
float factor = 1.0f / hparams.rope_freq_scale_train;
|
||||
float mscale = 1.0f;
|
||||
float mscale_all_dims = hparams.rope_yarn_log_mul;
|
||||
static auto get_mscale = [](float scale, float mscale) {
|
||||
return scale <= 1.0f ? 1.0f : (0.1f * mscale * logf(scale) + 1.0f);
|
||||
};
|
||||
hparams.yarn_attn_factor = get_mscale(factor, mscale) / get_mscale(factor, mscale_all_dims);
|
||||
}
|
||||
|
||||
switch (hparams.n_layer) {
|
||||
case 26: type = LLM_TYPE_3B; break;
|
||||
case 34: type = LLM_TYPE_8B; break;
|
||||
case 40: type = LLM_TYPE_14B; break;
|
||||
default: type = LLM_TYPE_UNKNOWN;
|
||||
}
|
||||
} break;
|
||||
default: throw std::runtime_error("unsupported model architecture");
|
||||
}
|
||||
|
||||
|
|
@ -2560,6 +2596,7 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
|
|||
case LLM_ARCH_MINICPM:
|
||||
case LLM_ARCH_GRANITE:
|
||||
case LLM_ARCH_GRANITE_MOE:
|
||||
case LLM_ARCH_MISTRAL3:
|
||||
{
|
||||
tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
|
||||
|
||||
|
|
@ -7522,6 +7559,10 @@ ggml_cgraph * llama_model::build_graph(const llm_graph_params & params) const {
|
|||
{
|
||||
llm = std::make_unique<llm_build_qwen3next>(*this, params);
|
||||
} break;
|
||||
case LLM_ARCH_MISTRAL3:
|
||||
{
|
||||
llm = std::make_unique<llm_build_mistral3>(*this, params);
|
||||
} break;
|
||||
default:
|
||||
GGML_ABORT("fatal error");
|
||||
}
|
||||
|
|
@ -7693,6 +7734,7 @@ llama_rope_type llama_model_rope_type(const llama_model * model) {
|
|||
case LLM_ARCH_ARCEE:
|
||||
case LLM_ARCH_ERNIE4_5:
|
||||
case LLM_ARCH_ERNIE4_5_MOE:
|
||||
case LLM_ARCH_MISTRAL3:
|
||||
return LLAMA_ROPE_TYPE_NORM;
|
||||
|
||||
// the pairs of head values are offset by n_rot/2
|
||||
|
|
|
|||
|
|
@ -0,0 +1,160 @@
|
|||
#include "models.h"
|
||||
|
||||
llm_build_mistral3::llm_build_mistral3(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) {
|
||||
const int64_t n_embd_head = hparams.n_embd_head_v;
|
||||
|
||||
GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
|
||||
GGML_ASSERT(n_embd_head == hparams.n_rot);
|
||||
|
||||
ggml_tensor * cur;
|
||||
ggml_tensor * inpL;
|
||||
|
||||
inpL = build_inp_embd(model.tok_embd);
|
||||
|
||||
// inp_pos - contains the positions
|
||||
ggml_tensor * inp_pos = build_inp_pos();
|
||||
|
||||
// (optional) temperature tuning
|
||||
ggml_tensor * inp_attn_scale = nullptr;
|
||||
if (hparams.f_attn_temp_scale != 0.0f) {
|
||||
inp_attn_scale = build_inp_attn_scale();
|
||||
}
|
||||
|
||||
auto * inp_attn = build_attn_inp_kv();
|
||||
|
||||
const float kq_scale = hparams.f_attention_scale == 0.0f ? 1.0f/sqrtf(float(n_embd_head)) : hparams.f_attention_scale;
|
||||
|
||||
ggml_tensor * inp_out_ids = build_inp_out_ids();
|
||||
|
||||
for (int il = 0; il < n_layer; ++il) {
|
||||
ggml_tensor * inpSA = inpL;
|
||||
|
||||
// norm
|
||||
cur = build_norm(inpL,
|
||||
model.layers[il].attn_norm, NULL,
|
||||
LLM_NORM_RMS, il);
|
||||
cb(cur, "attn_norm", il);
|
||||
|
||||
// self-attention
|
||||
{
|
||||
// rope freq factors for llama3; may return nullptr for llama2 and other models
|
||||
ggml_tensor * rope_factors = model.get_rope_factors(cparams, il);
|
||||
|
||||
// compute Q and K and RoPE them
|
||||
ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur);
|
||||
cb(Qcur, "Qcur", il);
|
||||
if (model.layers[il].bq) {
|
||||
Qcur = ggml_add(ctx0, Qcur, model.layers[il].bq);
|
||||
cb(Qcur, "Qcur", il);
|
||||
}
|
||||
ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur);
|
||||
cb(Kcur, "Kcur", il);
|
||||
if (model.layers[il].bk) {
|
||||
Kcur = ggml_add(ctx0, Kcur, model.layers[il].bk);
|
||||
cb(Kcur, "Kcur", il);
|
||||
}
|
||||
ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur);
|
||||
cb(Vcur, "Vcur", il);
|
||||
if (model.layers[il].bv) {
|
||||
Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv);
|
||||
cb(Vcur, "Vcur", il);
|
||||
}
|
||||
Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens);
|
||||
Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens);
|
||||
Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens);
|
||||
|
||||
Qcur = ggml_rope_ext(
|
||||
ctx0, Qcur, inp_pos, rope_factors,
|
||||
n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
|
||||
ext_factor, attn_factor, beta_fast, beta_slow
|
||||
);
|
||||
|
||||
Kcur = ggml_rope_ext(
|
||||
ctx0, Kcur, inp_pos, rope_factors,
|
||||
n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
|
||||
ext_factor, attn_factor, beta_fast, beta_slow
|
||||
);
|
||||
|
||||
cb(Qcur, "Qcur", il);
|
||||
cb(Kcur, "Kcur", il);
|
||||
cb(Vcur, "Vcur", il);
|
||||
|
||||
if (inp_attn_scale) {
|
||||
// apply llama 4 temperature scaling
|
||||
Qcur = ggml_mul(ctx0, Qcur, inp_attn_scale);
|
||||
cb(Qcur, "Qcur_attn_temp_scaled", il);
|
||||
}
|
||||
|
||||
cur = build_attn(inp_attn,
|
||||
model.layers[il].wo, model.layers[il].bo,
|
||||
Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, kq_scale, il);
|
||||
cb(cur, "attn_out", il);
|
||||
}
|
||||
if (il == n_layer - 1 && inp_out_ids) {
|
||||
cur = ggml_get_rows(ctx0, cur, inp_out_ids);
|
||||
inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
|
||||
}
|
||||
ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA);
|
||||
cb(ffn_inp, "ffn_inp", il);
|
||||
|
||||
// feed-forward network (non-MoE)
|
||||
if (model.layers[il].ffn_gate_inp == nullptr) {
|
||||
|
||||
cur = build_norm(ffn_inp,
|
||||
model.layers[il].ffn_norm, NULL,
|
||||
LLM_NORM_RMS, il);
|
||||
cb(cur, "ffn_norm", il);
|
||||
|
||||
cur = build_ffn(cur,
|
||||
model.layers[il].ffn_up, model.layers[il].ffn_up_b, NULL,
|
||||
model.layers[il].ffn_gate, model.layers[il].ffn_gate_b, NULL,
|
||||
model.layers[il].ffn_down, model.layers[il].ffn_down_b, NULL,
|
||||
NULL,
|
||||
LLM_FFN_SILU, LLM_FFN_PAR, il);
|
||||
cb(cur, "ffn_out", il);
|
||||
} else {
|
||||
// MoE branch
|
||||
cur = build_norm(ffn_inp,
|
||||
model.layers[il].ffn_norm, NULL,
|
||||
LLM_NORM_RMS, il);
|
||||
cb(cur, "ffn_norm", il);
|
||||
|
||||
cur = build_moe_ffn(cur,
|
||||
model.layers[il].ffn_gate_inp,
|
||||
model.layers[il].ffn_up_exps,
|
||||
model.layers[il].ffn_gate_exps,
|
||||
model.layers[il].ffn_down_exps,
|
||||
nullptr,
|
||||
n_expert, n_expert_used,
|
||||
LLM_FFN_SILU, true,
|
||||
false, 0.0,
|
||||
LLAMA_EXPERT_GATING_FUNC_TYPE_SOFTMAX,
|
||||
il);
|
||||
cb(cur, "ffn_moe_out", il);
|
||||
}
|
||||
cur = ggml_add(ctx0, cur, ffn_inp);
|
||||
cb(cur, "ffn_out", il);
|
||||
|
||||
cur = build_cvec(cur, il);
|
||||
cb(cur, "l_out", il);
|
||||
|
||||
// input for next layer
|
||||
inpL = cur;
|
||||
}
|
||||
cur = inpL;
|
||||
|
||||
cur = build_norm(cur,
|
||||
model.output_norm, NULL,
|
||||
LLM_NORM_RMS, -1);
|
||||
|
||||
cb(cur, "result_norm", -1);
|
||||
res->t_embd = cur;
|
||||
|
||||
// lm_head
|
||||
cur = build_lora_mm(model.output, cur);
|
||||
|
||||
cb(cur, "result_output", -1);
|
||||
res->t_logits = cur;
|
||||
|
||||
ggml_build_forward_expand(gf, cur);
|
||||
}
|
||||
|
|
@ -322,6 +322,10 @@ struct llm_build_minimax_m2 : public llm_graph_context {
|
|||
llm_build_minimax_m2(const llama_model & model, const llm_graph_params & params);
|
||||
};
|
||||
|
||||
struct llm_build_mistral3 : public llm_graph_context {
|
||||
llm_build_mistral3(const llama_model & model, const llm_graph_params & params);
|
||||
};
|
||||
|
||||
struct llm_build_mpt : public llm_graph_context {
|
||||
llm_build_mpt(const llama_model & model, const llm_graph_params & params);
|
||||
};
|
||||
|
|
|
|||
|
|
@ -7660,7 +7660,7 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
|
|||
// test_cases.emplace_back(new test_top_k(GGML_TYPE_F32, {i, 2, 1, 3}, rand() % i + 1));
|
||||
//}
|
||||
|
||||
for (ggml_scale_mode mode : {GGML_SCALE_MODE_NEAREST, GGML_SCALE_MODE_BILINEAR, GGML_SCALE_MODE_BICUBIC}) {
|
||||
for (ggml_scale_mode mode : {GGML_SCALE_MODE_NEAREST, GGML_SCALE_MODE_BILINEAR, GGML_SCALE_MODE_BICUBIC, ggml_scale_mode(GGML_SCALE_MODE_BILINEAR | GGML_SCALE_FLAG_ANTIALIAS)}) {
|
||||
test_cases.emplace_back(new test_upscale(GGML_TYPE_F32, {512, 512, 3, 2}, 2, mode));
|
||||
test_cases.emplace_back(new test_upscale(GGML_TYPE_F32, {512, 512, 3, 2}, 2, mode, true));
|
||||
test_cases.emplace_back(new test_interpolate(GGML_TYPE_F32, {2, 5, 7, 11}, {5, 7, 11, 13}, mode));
|
||||
|
|
|
|||
|
|
@ -517,6 +517,12 @@ int main(int argc, char ** argv) {
|
|||
is_interacting = params.interactive_first;
|
||||
}
|
||||
|
||||
LOG_WRN("*****************************\n");
|
||||
LOG_WRN("IMPORTANT: The current llama-cli will be moved to llama-completion in the near future\n");
|
||||
LOG_WRN(" New llama-cli will have enhanced features and improved user experience\n");
|
||||
LOG_WRN(" More info: https://github.com/ggml-org/llama.cpp/discussions/17618\n");
|
||||
LOG_WRN("*****************************\n");
|
||||
|
||||
bool is_antiprompt = false;
|
||||
bool input_echo = true;
|
||||
bool display = true;
|
||||
|
|
|
|||
|
|
@ -987,12 +987,20 @@ struct clip_graph {
|
|||
cur = ggml_mul_mat(ctx0, layer.qkv_w, cur);
|
||||
cur = ggml_add(ctx0, cur, layer.qkv_b);
|
||||
|
||||
ggml_tensor * Qcur = ggml_view_3d(ctx0, cur, d_head, n_head, n_pos, d_head*sizeof(float),
|
||||
cur->nb[1], 0);
|
||||
ggml_tensor * Kcur = ggml_view_3d(ctx0, cur, d_head, n_head, n_pos, d_head*sizeof(float),
|
||||
cur->nb[1], n_embd * sizeof(float));
|
||||
ggml_tensor * Vcur = ggml_view_3d(ctx0, cur, d_head, n_head, n_pos, d_head*sizeof(float),
|
||||
cur->nb[1], 2 * n_embd * sizeof(float));
|
||||
ggml_tensor * Qcur = ggml_view_3d(ctx0, cur, d_head, n_head, n_pos,
|
||||
/* nb1 */ ggml_row_size(cur->type, d_head),
|
||||
/* nb2 */ cur->nb[1],
|
||||
/* offset */ 0);
|
||||
|
||||
ggml_tensor * Kcur = ggml_view_3d(ctx0, cur, d_head, n_head, n_pos,
|
||||
/* nb1 */ ggml_row_size(cur->type, d_head),
|
||||
/* nb2 */ cur->nb[1],
|
||||
/* offset */ ggml_row_size(cur->type, n_embd));
|
||||
|
||||
ggml_tensor * Vcur = ggml_view_3d(ctx0, cur, d_head, n_head, n_pos,
|
||||
/* nb1 */ ggml_row_size(cur->type, d_head),
|
||||
/* nb2 */ cur->nb[1],
|
||||
/* offset */ ggml_row_size(cur->type, 2 * n_embd));
|
||||
|
||||
cb(Qcur, "Qcur", il);
|
||||
cb(Kcur, "Kcur", il);
|
||||
|
|
@ -2012,7 +2020,7 @@ private:
|
|||
ggml_tensor * pos_embd = model.position_embeddings;
|
||||
const int height = img.ny / patch_size;
|
||||
const int width = img.nx / patch_size;
|
||||
const uint32_t mode = GGML_SCALE_MODE_BILINEAR;
|
||||
const uint32_t mode = GGML_SCALE_MODE_BILINEAR | GGML_SCALE_FLAG_ANTIALIAS;
|
||||
const int n_per_side = (int)std::sqrt(pos_embd->ne[1]);
|
||||
|
||||
GGML_ASSERT(pos_embd);
|
||||
|
|
@ -2787,7 +2795,8 @@ struct clip_model_loader {
|
|||
{
|
||||
get_u32(KEY_PROJ_SCALE_FACTOR, hparams.n_merge, false);
|
||||
// ref: https://huggingface.co/LiquidAI/LFM2-VL-3B/blob/main/preprocessor_config.json
|
||||
hparams.set_limit_image_tokens(64, 256);
|
||||
// config above specifies number of tokens after downsampling, while here it is before, relax lowerbound to 64
|
||||
hparams.set_limit_image_tokens(64, 1024);
|
||||
} break;
|
||||
case PROJECTOR_TYPE_PIXTRAL:
|
||||
case PROJECTOR_TYPE_LIGHTONOCR:
|
||||
|
|
@ -3737,12 +3746,13 @@ struct img_tool {
|
|||
const int width = inp_size.width;
|
||||
const int height = inp_size.height;
|
||||
|
||||
auto round_by_factor = [f = align_size](float x) { return static_cast<int>(std::round(x / static_cast<float>(f))) * f; };
|
||||
auto ceil_by_factor = [f = align_size](float x) { return static_cast<int>(std::ceil(x / static_cast<float>(f))) * f; };
|
||||
auto floor_by_factor = [f = align_size](float x) { return static_cast<int>(std::floor(x / static_cast<float>(f))) * f; };
|
||||
|
||||
// always align up first
|
||||
int h_bar = std::max(align_size, ceil_by_factor(height));
|
||||
int w_bar = std::max(align_size, ceil_by_factor(width));
|
||||
int h_bar = std::max(align_size, round_by_factor(height));
|
||||
int w_bar = std::max(align_size, round_by_factor(width));
|
||||
|
||||
if (h_bar * w_bar > max_pixels) {
|
||||
const auto beta = std::sqrt(static_cast<float>(height * width) / max_pixels);
|
||||
|
|
@ -4357,7 +4367,8 @@ bool clip_image_preprocess(struct clip_ctx * ctx, const clip_image_u8 * img, str
|
|||
const std::array<uint8_t, 3> pad_color = {122, 116, 104};
|
||||
|
||||
clip_image_u8 resized_img;
|
||||
img_tool::resize(*img, resized_img, target_size, img_tool::RESIZE_ALGO_BILINEAR, true, pad_color);
|
||||
const bool pad = (ctx->proj_type() != PROJECTOR_TYPE_LFM2);
|
||||
img_tool::resize(*img, resized_img, target_size, img_tool::RESIZE_ALGO_BILINEAR, pad, pad_color);
|
||||
clip_image_f32_ptr res(clip_image_f32_init());
|
||||
normalize_image_u8_to_f32(resized_img, *res, params.image_mean, params.image_std);
|
||||
res_imgs->entries.push_back(std::move(res));
|
||||
|
|
|
|||
|
|
@ -304,6 +304,10 @@ struct mtmd_context {
|
|||
img_beg = "<|im_start|>";
|
||||
img_end = "<|im_end|>";
|
||||
|
||||
} else if (proj == PROJECTOR_TYPE_LFM2) {
|
||||
img_beg = "<|image_start|>";
|
||||
img_end = "<|image_end|>";
|
||||
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -22,8 +22,9 @@ target_compile_definitions(${TARGET} PRIVATE
|
|||
CPPHTTPLIB_TCP_NODELAY=1
|
||||
)
|
||||
|
||||
set(OPENSSL_NO_ASM ON CACHE BOOL "Disable OpenSSL ASM code when building BoringSSL or LibreSSL")
|
||||
|
||||
if (LLAMA_BUILD_BORINGSSL)
|
||||
set(OPENSSL_NO_ASM ON CACHE BOOL "Disable OpenSSL ASM code (BoringSSL)")
|
||||
set(FIPS OFF CACHE BOOL "Enable FIPS (BoringSSL)")
|
||||
|
||||
set(BORINGSSL_GIT "https://boringssl.googlesource.com/boringssl" CACHE STRING "BoringSSL git repository")
|
||||
|
|
@ -64,6 +65,47 @@ if (LLAMA_BUILD_BORINGSSL)
|
|||
set(CPPHTTPLIB_OPENSSL_SUPPORT TRUE)
|
||||
target_link_libraries(${TARGET} PUBLIC ssl crypto)
|
||||
|
||||
elseif (LLAMA_BUILD_LIBRESSL)
|
||||
set(LIBRESSL_VERSION "4.2.1" CACHE STRING "LibreSSL version")
|
||||
|
||||
message(STATUS "Fetching LibreSSL version ${LIBRESSL_VERSION}")
|
||||
|
||||
set(LIBRESSL_ARGS
|
||||
URL "https://cdn.openbsd.org/pub/OpenBSD/LibreSSL/libressl-${LIBRESSL_VERSION}.tar.gz"
|
||||
)
|
||||
if(CMAKE_VERSION VERSION_GREATER_EQUAL 3.24)
|
||||
list(APPEND LIBRESSL_ARGS DOWNLOAD_EXTRACT_TIMESTAMP TRUE)
|
||||
endif()
|
||||
|
||||
if(CMAKE_VERSION VERSION_GREATER_EQUAL 3.28)
|
||||
list(APPEND LIBRESSL_ARGS EXCLUDE_FROM_ALL)
|
||||
endif()
|
||||
|
||||
include(FetchContent)
|
||||
FetchContent_Declare(libressl ${LIBRESSL_ARGS})
|
||||
|
||||
set(SAVED_BUILD_SHARED_LIBS ${BUILD_SHARED_LIBS})
|
||||
set(SAVED_BUILD_TESTING ${BUILD_TESTING})
|
||||
|
||||
set(BUILD_SHARED_LIBS OFF)
|
||||
set(BUILD_TESTING OFF)
|
||||
|
||||
if(CMAKE_VERSION VERSION_GREATER_EQUAL 3.28)
|
||||
FetchContent_MakeAvailable(libressl)
|
||||
else()
|
||||
FetchContent_GetProperties(libressl)
|
||||
if(NOT libressl_POPULATED)
|
||||
FetchContent_Populate(libressl)
|
||||
add_subdirectory(${libressl_SOURCE_DIR} ${libressl_BINARY_DIR} EXCLUDE_FROM_ALL)
|
||||
endif()
|
||||
endif()
|
||||
|
||||
set(BUILD_SHARED_LIBS ${SAVED_BUILD_SHARED_LIBS})
|
||||
set(BUILD_TESTING ${SAVED_BUILD_TESTING})
|
||||
|
||||
set(CPPHTTPLIB_OPENSSL_SUPPORT TRUE)
|
||||
target_link_libraries(${TARGET} PUBLIC ssl crypto)
|
||||
|
||||
elseif (LLAMA_OPENSSL)
|
||||
find_package(OpenSSL)
|
||||
if (OpenSSL_FOUND)
|
||||
|
|
|
|||
Loading…
Reference in New Issue