Merge branch 'master' into compilade/mamba2

Francis Couture-Harpin 2025-06-23 10:40:16 -04:00
commit afdb669206
99 changed files with 5168 additions and 4226 deletions


@ -683,7 +683,7 @@ jobs:
env: env:
OPENBLAS_VERSION: 0.3.23 OPENBLAS_VERSION: 0.3.23
SDE_VERSION: 9.33.0-2024-01-07 SDE_VERSION: 9.33.0-2024-01-07
VULKAN_VERSION: 1.4.309.0 VULKAN_VERSION: 1.4.313.2
strategy: strategy:
matrix: matrix:
@ -736,7 +736,7 @@ jobs:
id: get_vulkan id: get_vulkan
if: ${{ matrix.build == 'kompute-x64' || matrix.build == 'vulkan-x64' }} if: ${{ matrix.build == 'kompute-x64' || matrix.build == 'vulkan-x64' }}
run: | run: |
curl.exe -o $env:RUNNER_TEMP/VulkanSDK-Installer.exe -L "https://sdk.lunarg.com/sdk/download/${env:VULKAN_VERSION}/windows/VulkanSDK-${env:VULKAN_VERSION}-Installer.exe" curl.exe -o $env:RUNNER_TEMP/VulkanSDK-Installer.exe -L "https://sdk.lunarg.com/sdk/download/${env:VULKAN_VERSION}/windows/vulkansdk-windows-X64-${env:VULKAN_VERSION}.exe"
& "$env:RUNNER_TEMP\VulkanSDK-Installer.exe" --accept-licenses --default-answer --confirm-command install & "$env:RUNNER_TEMP\VulkanSDK-Installer.exe" --accept-licenses --default-answer --confirm-command install
Add-Content $env:GITHUB_ENV "VULKAN_SDK=C:\VulkanSDK\${env:VULKAN_VERSION}" Add-Content $env:GITHUB_ENV "VULKAN_SDK=C:\VulkanSDK\${env:VULKAN_VERSION}"
Add-Content $env:GITHUB_PATH "C:\VulkanSDK\${env:VULKAN_VERSION}\bin" Add-Content $env:GITHUB_PATH "C:\VulkanSDK\${env:VULKAN_VERSION}\bin"


@ -302,7 +302,7 @@ jobs:
env: env:
OPENBLAS_VERSION: 0.3.23 OPENBLAS_VERSION: 0.3.23
VULKAN_VERSION: 1.4.309.0 VULKAN_VERSION: 1.4.313.2
strategy: strategy:
matrix: matrix:
@ -332,7 +332,7 @@ jobs:
id: get_vulkan id: get_vulkan
if: ${{ matrix.backend == 'vulkan' }} if: ${{ matrix.backend == 'vulkan' }}
run: | run: |
curl.exe -o $env:RUNNER_TEMP/VulkanSDK-Installer.exe -L "https://sdk.lunarg.com/sdk/download/${env:VULKAN_VERSION}/windows/VulkanSDK-${env:VULKAN_VERSION}-Installer.exe" curl.exe -o $env:RUNNER_TEMP/VulkanSDK-Installer.exe -L "https://sdk.lunarg.com/sdk/download/${env:VULKAN_VERSION}/windows/vulkansdk-windows-X64-${env:VULKAN_VERSION}.exe"
& "$env:RUNNER_TEMP\VulkanSDK-Installer.exe" --accept-licenses --default-answer --confirm-command install & "$env:RUNNER_TEMP\VulkanSDK-Installer.exe" --accept-licenses --default-answer --confirm-command install
Add-Content $env:GITHUB_ENV "VULKAN_SDK=C:\VulkanSDK\${env:VULKAN_VERSION}" Add-Content $env:GITHUB_ENV "VULKAN_SDK=C:\VulkanSDK\${env:VULKAN_VERSION}"
Add-Content $env:GITHUB_PATH "C:\VulkanSDK\${env:VULKAN_VERSION}\bin" Add-Content $env:GITHUB_PATH "C:\VulkanSDK\${env:VULKAN_VERSION}\bin"


@ -779,7 +779,7 @@ function gg_run_rerank_tiny {
model_f16="${path_models}/ggml-model-f16.gguf" model_f16="${path_models}/ggml-model-f16.gguf"
# for this model, the SEP token is "</s>" # for this model, the SEP token is "</s>"
(time ./bin/llama-embedding --model ${model_f16} -p "what is panda?</s></s>hi\nwhat is panda?</s></s>it's a bear\nwhat is panda?</s></s>The giant panda (Ailuropoda melanoleuca), sometimes called a panda bear or simply panda, is a bear species endemic to China." -ngl 99 -c 0 --pooling rank --embd-normalize -1 --verbose-prompt) 2>&1 | tee -a $OUT/${ci}-rk-f16.log (time ./bin/llama-embedding --model ${model_f16} -p "what is panda?\thi\nwhat is panda?\tit's a bear\nwhat is panda?\tThe giant panda (Ailuropoda melanoleuca), sometimes called a panda bear or simply panda, is a bear species endemic to China." -ngl 99 -c 0 --pooling rank --embd-normalize -1 --verbose-prompt) 2>&1 | tee -a $OUT/${ci}-rk-f16.log
# sample output # sample output
# rerank score 0: 0.029 # rerank score 0: 0.029


@ -2706,6 +2706,13 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
params.embd_sep = value; params.embd_sep = value;
} }
).set_examples({LLAMA_EXAMPLE_EMBEDDING})); ).set_examples({LLAMA_EXAMPLE_EMBEDDING}));
add_opt(common_arg(
{"--cls-separator"}, "STRING",
"separator of classification sequences (default \\t) for example \"<#seq#>\"",
[](common_params & params, const std::string & value) {
params.cls_sep = value;
}
).set_examples({LLAMA_EXAMPLE_EMBEDDING}));
add_opt(common_arg( add_opt(common_arg(
{"--host"}, "HOST", {"--host"}, "HOST",
string_format("ip address to listen, or bind to an UNIX socket if the address ends with .sock (default: %s)", params.hostname.c_str()), string_format("ip address to listen, or bind to an UNIX socket if the address ends with .sock (default: %s)", params.hostname.c_str()),
@ -3210,6 +3217,32 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
params.speculative.model.path = value; params.speculative.model.path = value;
} }
).set_examples({LLAMA_EXAMPLE_SPECULATIVE, LLAMA_EXAMPLE_SERVER}).set_env("LLAMA_ARG_MODEL_DRAFT")); ).set_examples({LLAMA_EXAMPLE_SPECULATIVE, LLAMA_EXAMPLE_SERVER}).set_env("LLAMA_ARG_MODEL_DRAFT"));
add_opt(common_arg(
{"-ctkd", "--cache-type-k-draft"}, "TYPE",
string_format(
"KV cache data type for K for the draft model\n"
"allowed values: %s\n"
"(default: %s)",
get_all_kv_cache_types().c_str(),
ggml_type_name(params.speculative.cache_type_k)
),
[](common_params & params, const std::string & value) {
params.speculative.cache_type_k = kv_cache_type_from_str(value);
}
).set_env("LLAMA_ARG_CACHE_TYPE_K_DRAFT"));
add_opt(common_arg(
{"-ctvd", "--cache-type-v-draft"}, "TYPE",
string_format(
"KV cache data type for V for the draft model\n"
"allowed values: %s\n"
"(default: %s)",
get_all_kv_cache_types().c_str(),
ggml_type_name(params.speculative.cache_type_v)
),
[](common_params & params, const std::string & value) {
params.speculative.cache_type_v = kv_cache_type_from_str(value);
}
).set_env("LLAMA_ARG_CACHE_TYPE_V_DRAFT"));
add_opt(common_arg( add_opt(common_arg(
{"-mv", "--model-vocoder"}, "FNAME", {"-mv", "--model-vocoder"}, "FNAME",


@ -706,11 +706,17 @@ bool fs_validate_filename(const std::string & filename) {
// disable C++17 deprecation warning for std::codecvt_utf8 // disable C++17 deprecation warning for std::codecvt_utf8
# pragma clang diagnostic push # pragma clang diagnostic push
# pragma clang diagnostic ignored "-Wdeprecated-declarations" # pragma clang diagnostic ignored "-Wdeprecated-declarations"
#elif defined(__GNUC__)
# pragma GCC diagnostic push
# pragma GCC diagnostic ignored "-Wdeprecated-declarations"
#endif #endif
std::wstring_convert<std::codecvt_utf8<char32_t>, char32_t> converter; std::wstring_convert<std::codecvt_utf8<char32_t>, char32_t> converter;
#if defined(__clang__) #if defined(__clang__)
# pragma clang diagnostic pop # pragma clang diagnostic pop
#elif defined(__GNUC__)
# pragma GCC diagnostic pop
#endif #endif
filename_utf32 = converter.from_bytes(filename); filename_utf32 = converter.from_bytes(filename);
@ -1284,6 +1290,9 @@ std::vector<llama_token> common_tokenize(
int n_tokens = text.length() + 2 * add_special; int n_tokens = text.length() + 2 * add_special;
std::vector<llama_token> result(n_tokens); std::vector<llama_token> result(n_tokens);
n_tokens = llama_tokenize(vocab, text.data(), text.length(), result.data(), result.size(), add_special, parse_special); n_tokens = llama_tokenize(vocab, text.data(), text.length(), result.data(), result.size(), add_special, parse_special);
if (n_tokens == std::numeric_limits<int32_t>::min()) {
throw std::runtime_error("Tokenization failed: input text too large, tokenization result exceeds int32_t limit");
}
if (n_tokens < 0) { if (n_tokens < 0) {
result.resize(-n_tokens); result.resize(-n_tokens);
int check = llama_tokenize(vocab, text.data(), text.length(), result.data(), result.size(), add_special, parse_special); int check = llama_tokenize(vocab, text.data(), text.length(), result.data(), result.size(), add_special, parse_special);


@ -199,6 +199,9 @@ struct common_params_speculative {
float p_split = 0.1f; // speculative decoding split probability float p_split = 0.1f; // speculative decoding split probability
float p_min = 0.75f; // minimum speculative decoding probability (greedy) float p_min = 0.75f; // minimum speculative decoding probability (greedy)
ggml_type cache_type_k = GGML_TYPE_F16; // KV cache data type for the K
ggml_type cache_type_v = GGML_TYPE_F16; // KV cache data type for the V
struct cpu_params cpuparams; struct cpu_params cpuparams;
struct cpu_params cpuparams_batch; struct cpu_params cpuparams_batch;
@ -355,6 +358,7 @@ struct common_params {
int32_t embd_normalize = 2; // normalisation for embeddings (-1=none, 0=max absolute int16, 1=taxicab, 2=euclidean, >2=p-norm) int32_t embd_normalize = 2; // normalisation for embeddings (-1=none, 0=max absolute int16, 1=taxicab, 2=euclidean, >2=p-norm)
std::string embd_out = ""; // empty = default, "array" = [[],[]...], "json" = openai style, "json+" = same "json" + cosine similarity matrix std::string embd_out = ""; // empty = default, "array" = [[],[]...], "json" = openai style, "json+" = same "json" + cosine similarity matrix
std::string embd_sep = "\n"; // separator of embeddings std::string embd_sep = "\n"; // separator of embeddings
std::string cls_sep = "\t"; // separator of classification sequences
// server params // server params
int32_t port = 8080; // server listens on this network port int32_t port = 8080; // server listens on this network port


@ -41,49 +41,6 @@ static std::string build_repetition(const std::string & item_rule, int min_items
return result; return result;
} }
/* Minimalistic replacement for std::string_view, which is only available from C++17 onwards */
class string_view {
const std::string & _str;
const size_t _start;
const size_t _end;
public:
string_view(const std::string & str, size_t start = 0, size_t end = std::string::npos) : _str(str), _start(start), _end(end == std::string::npos ? str.length() : end) {}
size_t size() const {
return _end - _start;
}
size_t length() const {
return size();
}
operator std::string() const {
return str();
}
std::string str() const {
return _str.substr(_start, _end - _start);
}
string_view substr(size_t pos, size_t len = std::string::npos) const {
return string_view(_str, _start + pos, len == std::string::npos ? _end : _start + pos + len);
}
char operator[](size_t pos) const {
auto index = _start + pos;
if (index >= _end) {
throw std::out_of_range("string_view index out of range");
}
return _str[_start + pos];
}
bool operator==(const string_view & other) const {
std::string this_str = *this;
std::string other_str = other;
return this_str == other_str;
}
};
static void _build_min_max_int(int min_value, int max_value, std::stringstream & out, int decimals_left = 16, bool top_level = true) { static void _build_min_max_int(int min_value, int max_value, std::stringstream & out, int decimals_left = 16, bool top_level = true) {
auto has_min = min_value != std::numeric_limits<int>::min(); auto has_min = min_value != std::numeric_limits<int>::min();
auto has_max = max_value != std::numeric_limits<int>::max(); auto has_max = max_value != std::numeric_limits<int>::max();
@ -112,14 +69,14 @@ static void _build_min_max_int(int min_value, int max_value, std::stringstream &
} }
out << "}"; out << "}";
}; };
std::function<void(const string_view &, const string_view &)> uniform_range = std::function<void(const std::string_view &, const std::string_view &)> uniform_range =
[&](const string_view & from, const string_view & to) { [&](const std::string_view & from, const std::string_view & to) {
size_t i = 0; size_t i = 0;
while (i < from.length() && i < to.length() && from[i] == to[i]) { while (i < from.length() && i < to.length() && from[i] == to[i]) {
i++; i++;
} }
if (i > 0) { if (i > 0) {
out << "\"" << from.substr(0, i).str() << "\""; out << "\"" << from.substr(0, i) << "\"";
} }
if (i < from.length() && i < to.length()) { if (i < from.length() && i < to.length()) {
if (i > 0) { if (i > 0) {


@ -2145,7 +2145,6 @@ class Llama4Model(LlamaModel):
def set_vocab(self): def set_vocab(self):
self._set_vocab_gpt2() self._set_vocab_gpt2()
self.gguf_writer.add_add_bos_token(True)
def set_gguf_parameters(self): def set_gguf_parameters(self):
super().set_gguf_parameters() super().set_gguf_parameters()
@ -2194,7 +2193,7 @@ class Llama4VisionModel(MmprojModel):
name += ".weight" name += ".weight"
if "multi_modal_projector.linear_1" in name: if "multi_modal_projector.linear_1" in name:
# despite the name with number postfix, this is a single fully connected layer # despite the name with number postfix, this is a single fully connected layer
return [(gguf.TENSOR_NAMES[gguf.MODEL_TENSOR.V_MMPROJ_FC], data_torch)] return [(gguf.TENSOR_NAMES[gguf.MODEL_TENSOR.V_MMPROJ_FC] + '.weight', data_torch)]
return [(self.map_tensor_name(name), data_torch)] return [(self.map_tensor_name(name), data_torch)]
return [] return []
@ -3918,9 +3917,6 @@ class BertModel(TextModel):
special_vocab = gguf.SpecialVocab(self.dir_model, n_vocab=len(tokens)) special_vocab = gguf.SpecialVocab(self.dir_model, n_vocab=len(tokens))
special_vocab.add_to_gguf(self.gguf_writer) special_vocab.add_to_gguf(self.gguf_writer)
self.gguf_writer.add_add_bos_token(True)
self.gguf_writer.add_add_eos_token(True)
@ModelBase.register("DistilBertModel", "DistilBertForMaskedLM", "DistilBertForSequenceClassification") @ModelBase.register("DistilBertModel", "DistilBertForMaskedLM", "DistilBertForSequenceClassification")
class DistilBertModel(BertModel): class DistilBertModel(BertModel):
@ -3962,8 +3958,6 @@ class RobertaModel(BertModel):
bpe_tok_path = self.dir_model / "tokenizer.json" bpe_tok_path = self.dir_model / "tokenizer.json"
if bpe_tok_path.exists(): if bpe_tok_path.exists():
self._set_vocab_gpt2() self._set_vocab_gpt2()
self.gguf_writer.add_add_bos_token(True)
self.gguf_writer.add_add_eos_token(True)
# we need this to validate the size of the token_type embeddings # we need this to validate the size of the token_type embeddings
# though currently we are passing all zeros to the token_type embeddings # though currently we are passing all zeros to the token_type embeddings
@ -4950,8 +4944,6 @@ class JinaBertV2Model(BertModel):
self.gguf_writer.add_token_type_count(2) self.gguf_writer.add_token_type_count(2)
else: else:
raise NotImplementedError(f'Tokenizer {tokenizer_class} is not supported for JinaBertModel') raise NotImplementedError(f'Tokenizer {tokenizer_class} is not supported for JinaBertModel')
self.gguf_writer.add_add_bos_token(True)
self.gguf_writer.add_add_eos_token(True)
@ModelBase.register("OpenELMForCausalLM") @ModelBase.register("OpenELMForCausalLM")
@ -5553,9 +5545,6 @@ class T5Model(TextModel):
special_vocab = gguf.SpecialVocab(self.dir_model, n_vocab=len(tokens)) special_vocab = gguf.SpecialVocab(self.dir_model, n_vocab=len(tokens))
special_vocab.add_to_gguf(self.gguf_writer) special_vocab.add_to_gguf(self.gguf_writer)
self.gguf_writer.add_add_bos_token(False)
self.gguf_writer.add_add_eos_token(True)
def set_gguf_parameters(self): def set_gguf_parameters(self):
if (n_ctx := self.find_hparam(["n_positions"], optional=True)) is None: if (n_ctx := self.find_hparam(["n_positions"], optional=True)) is None:
logger.warning("Couldn't find context length in config.json, assuming default value of 512") logger.warning("Couldn't find context length in config.json, assuming default value of 512")
@ -5693,9 +5682,6 @@ class T5EncoderModel(TextModel):
special_vocab = gguf.SpecialVocab(self.dir_model, n_vocab=len(tokens)) special_vocab = gguf.SpecialVocab(self.dir_model, n_vocab=len(tokens))
special_vocab.add_to_gguf(self.gguf_writer) special_vocab.add_to_gguf(self.gguf_writer)
self.gguf_writer.add_add_bos_token(False)
self.gguf_writer.add_add_eos_token(True)
def set_gguf_parameters(self): def set_gguf_parameters(self):
if (n_ctx := self.find_hparam(["n_positions"], optional=True)) is None: if (n_ctx := self.find_hparam(["n_positions"], optional=True)) is None:
logger.warning("Couldn't find context length in config.json, assuming default value of 512") logger.warning("Couldn't find context length in config.json, assuming default value of 512")
@ -6491,8 +6477,8 @@ def parse_args() -> argparse.Namespace:
help="model is executed on big endian machine", help="model is executed on big endian machine",
) )
parser.add_argument( parser.add_argument(
"model", type=Path, "model", type=str,
help="directory containing model file", help="directory containing model file or huggingface repository ID (if --remote)",
nargs="?", nargs="?",
) )
parser.add_argument( parser.add_argument(
@ -6603,18 +6589,20 @@ def main() -> None:
else: else:
logging.basicConfig(level=logging.INFO) logging.basicConfig(level=logging.INFO)
dir_model = args.model
if args.remote: if args.remote:
hf_repo_id = args.model
from huggingface_hub import snapshot_download from huggingface_hub import snapshot_download
local_dir = snapshot_download( local_dir = snapshot_download(
repo_id=str(dir_model), repo_id=hf_repo_id,
allow_patterns=["LICENSE", "*.json", "*.md", "*.txt", "tokenizer.model"]) allow_patterns=["LICENSE", "*.json", "*.md", "*.txt", "tokenizer.model"])
dir_model = Path(local_dir) dir_model = Path(local_dir)
logger.info(f"Downloaded config and tokenizer to {local_dir}") logger.info(f"Downloaded config and tokenizer to {local_dir}")
else:
hf_repo_id = None
dir_model = Path(args.model)
if not dir_model.is_dir(): if not dir_model.is_dir():
logger.error(f'Error: {args.model} is not a directory') logger.error(f'Error: {dir_model} is not a directory')
sys.exit(1) sys.exit(1)
ftype_map: dict[str, gguf.LlamaFileType] = { ftype_map: dict[str, gguf.LlamaFileType] = {
@ -6634,9 +6622,9 @@ def main() -> None:
if args.outfile is not None: if args.outfile is not None:
fname_out = args.outfile fname_out = args.outfile
elif args.remote: elif hf_repo_id:
# if remote, use the model ID as the output file name # if remote, use the model ID as the output file name
fname_out = Path("./" + str(args.model).replace("/", "-") + "-{ftype}.gguf") fname_out = Path("./" + hf_repo_id.replace("/", "-") + "-{ftype}.gguf")
else: else:
fname_out = dir_model fname_out = dir_model
@ -6665,7 +6653,7 @@ def main() -> None:
split_max_tensors=args.split_max_tensors, split_max_tensors=args.split_max_tensors,
split_max_size=split_str_to_n_bytes(args.split_max_size), dry_run=args.dry_run, split_max_size=split_str_to_n_bytes(args.split_max_size), dry_run=args.dry_run,
small_first_shard=args.no_tensor_first_split, small_first_shard=args.no_tensor_first_split,
remote_hf_model_id=str(args.model) if args.remote else None) remote_hf_model_id=hf_repo_id)
if args.vocab_only: if args.vocab_only:
logger.info("Exporting model vocab...") logger.info("Exporting model vocab...")


@ -1,6 +1,6 @@
# Build llama.cpp locally # Build llama.cpp locally
The main product of this project is the `llama` library. Its C-style interface can be found in [include/llama.h](include/llama.h). The main product of this project is the `llama` library. Its C-style interface can be found in [include/llama.h](../include/llama.h).
The project also includes many example programs and tools using the `llama` library. The examples range from simple, minimal code snippets to sophisticated sub-projects such as an OpenAI-compatible HTTP server. The project also includes many example programs and tools using the `llama` library. The examples range from simple, minimal code snippets to sophisticated sub-projects such as an OpenAI-compatible HTTP server.


@ -133,10 +133,36 @@ int main(int argc, char ** argv) {
// max batch size // max batch size
const uint64_t n_batch = params.n_batch; const uint64_t n_batch = params.n_batch;
// get added sep and eos token, if any
const std::string added_sep_token = llama_vocab_get_add_sep(vocab) ? llama_vocab_get_text(vocab, llama_vocab_sep(vocab)) : "";
const std::string added_eos_token = llama_vocab_get_add_eos(vocab) ? llama_vocab_get_text(vocab, llama_vocab_eos(vocab)) : "";
// tokenize the prompts and trim // tokenize the prompts and trim
std::vector<std::vector<int32_t>> inputs; std::vector<std::vector<int32_t>> inputs;
for (const auto & prompt : prompts) { for (const auto & prompt : prompts) {
auto inp = common_tokenize(ctx, prompt, true, true); std::vector<llama_token> inp;
// split classification pairs and insert expected separator tokens
if (pooling_type == LLAMA_POOLING_TYPE_RANK && prompt.find(params.cls_sep) != std::string::npos) {
std::vector<std::string> pairs = split_lines(prompt, params.cls_sep);
std::string final_prompt;
for (size_t i = 0; i < pairs.size(); i++) {
final_prompt += pairs[i];
if (i != pairs.size() - 1) {
if (!added_eos_token.empty()) {
final_prompt += added_eos_token;
}
if (!added_sep_token.empty()) {
final_prompt += added_sep_token;
}
}
}
inp = common_tokenize(ctx, final_prompt, true, true);
} else {
inp = common_tokenize(ctx, prompt, true, true);
}
if (inp.size() > n_batch) { if (inp.size() > n_batch) {
LOG_ERR("%s: number of tokens in input line (%lld) exceeds batch size (%lld), increase batch size and re-run\n", LOG_ERR("%s: number of tokens in input line (%lld) exceeds batch size (%lld), increase batch size and re-run\n",
__func__, (long long int) inp.size(), (long long int) n_batch); __func__, (long long int) inp.size(), (long long int) n_batch);
@ -145,11 +171,11 @@ int main(int argc, char ** argv) {
inputs.push_back(inp); inputs.push_back(inp);
} }
// check if the last token is SEP // check if the last token is SEP/EOS
// it should be automatically added by the tokenizer when 'tokenizer.ggml.add_eos_token' is set to 'true' // it should be automatically added by the tokenizer when 'tokenizer.ggml.add_eos_token' is set to 'true'
for (auto & inp : inputs) { for (auto & inp : inputs) {
if (inp.empty() || inp.back() != llama_vocab_sep(vocab)) { if (inp.empty() || (inp.back() != llama_vocab_sep(vocab) && inp.back() != llama_vocab_eos(vocab))) {
LOG_WRN("%s: last token in the prompt is not SEP\n", __func__); LOG_WRN("%s: last token in the prompt is not SEP or EOS\n", __func__);
LOG_WRN("%s: 'tokenizer.ggml.add_eos_token' should be set to 'true' in the GGUF header\n", __func__); LOG_WRN("%s: 'tokenizer.ggml.add_eos_token' should be set to 'true' in the GGUF header\n", __func__);
} }
} }
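
(Illustration, not part of the patch: how the RANK branch above expands a classification pair before tokenization. The token strings are assumptions taken from the rerank CI test, where both the added EOS and SEP strings are "</s>"; only the standard library is used.)

#include <string>
#include <vector>

int main() {
    const std::string cls_sep = "\t";   // new default separator (--cls-separator)
    const std::string eos     = "</s>"; // added_eos_token (model-dependent assumption)
    const std::string sep     = "</s>"; // added_sep_token (model-dependent assumption)

    const std::string prompt = "what is panda?\thi";

    // split on cls_sep, then re-join with EOS + SEP, mirroring the RANK branch above
    std::vector<std::string> pairs;
    size_t pos = 0, next;
    while ((next = prompt.find(cls_sep, pos)) != std::string::npos) {
        pairs.push_back(prompt.substr(pos, next - pos));
        pos = next + cls_sep.size();
    }
    pairs.push_back(prompt.substr(pos));

    std::string final_prompt;
    for (size_t i = 0; i < pairs.size(); i++) {
        final_prompt += pairs[i];
        if (i != pairs.size() - 1) {
            final_prompt += eos + sep;
        }
    }
    // final_prompt == "what is panda?</s></s>hi", the same layout the CI script used to hard-code
    return 0;
}

This matches the updated rerank CI prompt above, which now joins the pairs with "\t" and relies on llama-embedding to insert the separator tokens.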


@ -98,7 +98,7 @@ int main(int argc, char ** argv) {
auto generate = [&](const std::string & prompt) { auto generate = [&](const std::string & prompt) {
std::string response; std::string response;
const bool is_first = llama_memory_seq_pos_max(llama_get_memory(ctx), 0) == 0; const bool is_first = llama_memory_seq_pos_max(llama_get_memory(ctx), 0) == -1;
// tokenize the prompt // tokenize the prompt
const int n_prompt_tokens = -llama_tokenize(vocab, prompt.c_str(), prompt.size(), NULL, 0, is_first, true); const int n_prompt_tokens = -llama_tokenize(vocab, prompt.c_str(), prompt.size(), NULL, 0, is_first, true);


@ -489,6 +489,7 @@ extern "C" {
GGML_OP_UPSCALE, // nearest interpolate GGML_OP_UPSCALE, // nearest interpolate
GGML_OP_PAD, GGML_OP_PAD,
GGML_OP_PAD_REFLECT_1D, GGML_OP_PAD_REFLECT_1D,
GGML_OP_ROLL,
GGML_OP_ARANGE, GGML_OP_ARANGE,
GGML_OP_TIMESTEP_EMBEDDING, GGML_OP_TIMESTEP_EMBEDDING,
GGML_OP_ARGSORT, GGML_OP_ARGSORT,
@ -1801,6 +1802,17 @@ extern "C" {
int p0, int p0,
int p1); int p1);
// Move tensor elements by an offset given for each dimension. Elements that
// are shifted beyond the last position are wrapped around to the beginning.
GGML_API struct ggml_tensor * ggml_roll(
struct ggml_context * ctx,
struct ggml_tensor * a,
int shift0,
int shift1,
int shift2,
int shift3);
// Ref: https://github.com/CompVis/stable-diffusion/blob/main/ldm/modules/diffusionmodules/util.py#L151 // Ref: https://github.com/CompVis/stable-diffusion/blob/main/ldm/modules/diffusionmodules/util.py#L151
// timesteps: [N,] // timesteps: [N,]
// return: [N, dim] // return: [N, dim]
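
(A minimal usage sketch for the new ggml_roll operator, not taken from the patch. It assumes the current split between ggml.h and the CPU backend header for graph execution, and writes tensor data directly since the context is created with no_alloc = false.)

#include "ggml.h"
#include "ggml-cpu.h"   // assumed location of ggml_graph_compute_with_ctx
#include <stdio.h>

int main(void) {
    struct ggml_init_params ip = {
        /*.mem_size   =*/ 16*1024*1024,
        /*.mem_buffer =*/ NULL,
        /*.no_alloc   =*/ false,
    };
    struct ggml_context * ctx = ggml_init(ip);

    struct ggml_tensor * a = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 4);
    for (int i = 0; i < 4; ++i) {
        ((float *) a->data)[i] = (float) i;            // a = [0, 1, 2, 3]
    }

    // shift the first dimension by 1; elements pushed past the end wrap to the front
    struct ggml_tensor * r = ggml_roll(ctx, a, 1, 0, 0, 0);

    struct ggml_cgraph * gf = ggml_new_graph(ctx);
    ggml_build_forward_expand(gf, r);
    ggml_graph_compute_with_ctx(ctx, gf, /*n_threads =*/ 1);

    for (int i = 0; i < 4; ++i) {
        printf("%g ", ((float *) r->data)[i]);         // expected: 3 0 1 2
    }
    printf("\n");

    ggml_free(ctx);
    return 0;
}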


@ -286,6 +286,10 @@ function(ggml_add_cpu_backend_variant tag_name)
foreach (feat ${ARGN}) foreach (feat ${ARGN})
set(GGML_INTERNAL_${feat} ON) set(GGML_INTERNAL_${feat} ON)
endforeach() endforeach()
elseif (GGML_SYSTEM_ARCH STREQUAL "PowerPC")
foreach (feat ${ARGN})
set(GGML_INTERNAL_${feat} ON)
endforeach()
endif() endif()
ggml_add_cpu_backend_variant_impl(${tag_name}) ggml_add_cpu_backend_variant_impl(${tag_name})
@ -337,6 +341,19 @@ if (GGML_CPU_ALL_VARIANTS)
else() else()
message(FATAL_ERROR "Unsupported ARM target OS: ${CMAKE_SYSTEM_NAME}") message(FATAL_ERROR "Unsupported ARM target OS: ${CMAKE_SYSTEM_NAME}")
endif() endif()
elseif (GGML_SYSTEM_ARCH STREQUAL "PowerPC")
if (CMAKE_SYSTEM_NAME MATCHES "Linux")
ggml_add_cpu_backend_variant(power0)
ggml_add_cpu_backend_variant(power7_1 POWER7)
ggml_add_cpu_backend_variant(power7_2 POWER7 VSX)
ggml_add_cpu_backend_variant(power8_1 POWER8)
ggml_add_cpu_backend_variant(power8_2 POWER8 VSX)
ggml_add_cpu_backend_variant(power9 POWER9 VSX)
ggml_add_cpu_backend_variant(power10 POWER10 VSX)
ggml_add_cpu_backend_variant(power11 POWER11 VSX)
else()
message(FATAL_ERROR "Unsupported PowerPC target OS: ${CMAKE_SYSTEM_NAME}")
endif()
else() else()
message(FATAL_ERROR "GGML_CPU_ALL_VARIANTS not yet supported with ${GGML_SYSTEM_ARCH} on ${CMAKE_SYSTEM_NAME}") message(FATAL_ERROR "GGML_CPU_ALL_VARIANTS not yet supported with ${GGML_SYSTEM_ARCH} on ${CMAKE_SYSTEM_NAME}")
endif() endif()


@ -69,6 +69,9 @@
#if defined(__clang__) #if defined(__clang__)
# pragma clang diagnostic push # pragma clang diagnostic push
# pragma clang diagnostic ignored "-Wdeprecated-declarations" # pragma clang diagnostic ignored "-Wdeprecated-declarations"
#elif defined(__GNUC__)
# pragma GCC diagnostic push
# pragma GCC diagnostic ignored "-Wdeprecated-declarations"
#endif #endif
namespace fs = std::filesystem; namespace fs = std::filesystem;
@ -91,6 +94,8 @@ static std::string path_str(const fs::path & path) {
#if defined(__clang__) #if defined(__clang__)
# pragma clang diagnostic pop # pragma clang diagnostic pop
#elif defined(__GNUC__)
# pragma GCC diagnostic pop
#endif #endif
#ifdef _WIN32 #ifdef _WIN32


@ -388,6 +388,27 @@ function(ggml_add_cpu_backend_variant_impl tag_name)
else() else()
list(APPEND ARCH_FLAGS -mcpu=native -mtune=native -mpowerpc64) list(APPEND ARCH_FLAGS -mcpu=native -mtune=native -mpowerpc64)
endif() endif()
elseif(GGML_CPU_ALL_VARIANTS)
# Begin with the lowest baseline
set(ARCH_DEFINITIONS "")
# When a feature is selected, bump the MCPU to the first
# version that supported it
foreach(PVER RANGE 7 11)
if(DEFINED GGML_INTERNAL_POWER${PVER})
set(POWERPC_MCPU "power${PVER}")
list(APPEND ARCH_DEFINITIONS GGML_USE_POWER${PVER})
endif()
endforeach()
if (GGML_INTERNAL_VSX)
list(APPEND ARCH_DEFINITIONS GGML_USE_VSX)
list(APPEND ARCH_FLAGS -mvsx)
endif()
if (DEFINED POWERPC_MCPU)
list(APPEND ARCH_FLAGS -mcpu=${POWERPC_MCPU})
endif()
ggml_add_cpu_backend_features(${GGML_CPU_NAME} powerpc ${ARCH_DEFINITIONS})
else() else()
if (GGML_CPU_POWERPC_CPUTYPE) if (GGML_CPU_POWERPC_CPUTYPE)
list(APPEND ARCH_FLAGS -mcpu=${GGML_CPU_POWERPC_CPUTYPE}) list(APPEND ARCH_FLAGS -mcpu=${GGML_CPU_POWERPC_CPUTYPE})
@ -465,9 +486,9 @@ function(ggml_add_cpu_backend_variant_impl tag_name)
# Fetch KleidiAI sources: # Fetch KleidiAI sources:
include(FetchContent) include(FetchContent)
set(KLEIDIAI_COMMIT_TAG "v1.6.0") set(KLEIDIAI_COMMIT_TAG "v1.9.0")
set(KLEIDIAI_DOWNLOAD_URL "https://github.com/ARM-software/kleidiai/archive/refs/tags/${KLEIDIAI_COMMIT_TAG}.tar.gz") set(KLEIDIAI_DOWNLOAD_URL "https://github.com/ARM-software/kleidiai/archive/refs/tags/${KLEIDIAI_COMMIT_TAG}.tar.gz")
set(KLEIDIAI_ARCHIVE_MD5 "75b4ad68f25ab673dcc01065e5a0b05f") set(KLEIDIAI_ARCHIVE_MD5 "2a8e1bb55d201557553545536489a017")
if (POLICY CMP0135) if (POLICY CMP0135)
cmake_policy(SET CMP0135 NEW) cmake_policy(SET CMP0135 NEW)


@ -256,7 +256,6 @@ void ggml_gemv_q4_0_4x4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const vo
UNUSED(blocklen); UNUSED(blocklen);
#if ! ((defined(_MSC_VER)) && ! defined(__clang__)) && defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_DOTPROD) #if ! ((defined(_MSC_VER)) && ! defined(__clang__)) && defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_DOTPROD)
if (ggml_cpu_has_neon() && ggml_cpu_has_dotprod()) {
const block_q4_0x4 * b_ptr = (const block_q4_0x4 *) vx; const block_q4_0x4 * b_ptr = (const block_q4_0x4 *) vx;
for (int c = 0; c < nc; c += ncols_interleaved) { for (int c = 0; c < nc; c += ncols_interleaved) {
@ -294,7 +293,6 @@ void ggml_gemv_q4_0_4x4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const vo
s += ncols_interleaved; s += ncols_interleaved;
} }
return; return;
}
#endif // #if ! ((defined(_MSC_VER)) && ! defined(__clang__)) && defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_DOTPROD) #endif // #if ! ((defined(_MSC_VER)) && ! defined(__clang__)) && defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_DOTPROD)
float sumf[4]; float sumf[4];
int sumi; int sumi;
@ -341,7 +339,6 @@ void ggml_gemv_q4_0_4x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const vo
UNUSED(blocklen); UNUSED(blocklen);
#if ! ((defined(_MSC_VER)) && ! defined(__clang__)) && defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_DOTPROD) #if ! ((defined(_MSC_VER)) && ! defined(__clang__)) && defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_DOTPROD)
if (ggml_cpu_has_neon() && ggml_cpu_has_dotprod()) {
const block_q4_0x4 * b_ptr = (const block_q4_0x4 *) vx; const block_q4_0x4 * b_ptr = (const block_q4_0x4 *) vx;
for (int c = 0; c < nc; c += ncols_interleaved) { for (int c = 0; c < nc; c += ncols_interleaved) {
@ -384,7 +381,6 @@ void ggml_gemv_q4_0_4x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const vo
s += ncols_interleaved; s += ncols_interleaved;
} }
return; return;
}
#endif // #if ! ((defined(_MSC_VER)) && ! defined(__clang__)) && defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_DOTPROD) #endif // #if ! ((defined(_MSC_VER)) && ! defined(__clang__)) && defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_DOTPROD)
float sumf[4]; float sumf[4];
int sumi; int sumi;
@ -432,7 +428,7 @@ void ggml_gemv_q4_0_8x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const vo
#if ! ((defined(_MSC_VER)) && ! defined(__clang__)) && defined(__aarch64__) #if ! ((defined(_MSC_VER)) && ! defined(__clang__)) && defined(__aarch64__)
#if defined(__ARM_FEATURE_SVE) #if defined(__ARM_FEATURE_SVE)
if (ggml_cpu_has_sve() && ggml_cpu_get_sve_cnt() == QK8_0) { if (ggml_cpu_get_sve_cnt() == QK8_0) {
const void * b_ptr = vx; const void * b_ptr = vx;
const void * a_ptr = vy; const void * a_ptr = vy;
float * res_ptr = s; float * res_ptr = s;
@ -547,7 +543,6 @@ void ggml_gemv_iq4_nl_4x4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const
UNUSED(blocklen); UNUSED(blocklen);
#if ! ((defined(_MSC_VER)) && ! defined(__clang__)) && defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_DOTPROD) #if ! ((defined(_MSC_VER)) && ! defined(__clang__)) && defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_DOTPROD)
if (ggml_cpu_has_neon() && ggml_cpu_has_dotprod()) {
const int8x16_t kvalues = vld1q_s8(kvalues_iq4nl); const int8x16_t kvalues = vld1q_s8(kvalues_iq4nl);
const block_q8_0 * a_ptr = (const block_q8_0 *) vy; const block_q8_0 * a_ptr = (const block_q8_0 *) vy;
float * res_ptr = s; float * res_ptr = s;
@ -594,7 +589,6 @@ void ggml_gemv_iq4_nl_4x4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const
vst1q_f32(res_ptr + x * 4, sumf); vst1q_f32(res_ptr + x * 4, sumf);
} }
return; return;
}
#endif // #if ! ((defined(_MSC_VER)) && ! defined(__clang__)) && defined(__aarch64__) && defined(__ARM_NEON) #endif // #if ! ((defined(_MSC_VER)) && ! defined(__clang__)) && defined(__aarch64__) && defined(__ARM_NEON)
{ {
float sumf[4]; float sumf[4];
@ -643,8 +637,7 @@ void ggml_gemm_q4_0_4x4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const vo
UNUSED(ncols_interleaved); UNUSED(ncols_interleaved);
UNUSED(blocklen); UNUSED(blocklen);
#if ! ((defined(_MSC_VER)) && ! defined(__clang__)) && defined(__aarch64__) && defined(__ARM_NEON) #if ! ((defined(_MSC_VER)) && ! defined(__clang__)) && defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_DOTPROD)
if (ggml_cpu_has_neon() && ggml_cpu_has_dotprod()) {
const void * b_ptr = vx; const void * b_ptr = vx;
const void * a_ptr = vy; const void * a_ptr = vy;
float * res_ptr = s; float * res_ptr = s;
@ -1101,7 +1094,6 @@ void ggml_gemm_q4_0_4x4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const vo
: "cc", "memory", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31", "x9", "x10", "x20", "x21", "x22", "x23", "x24", "x25", "x26", "x27", "x28" : "cc", "memory", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31", "x9", "x10", "x20", "x21", "x22", "x23", "x24", "x25", "x26", "x27", "x28"
); );
return; return;
}
#endif // #if ! ((defined(_MSC_VER)) && ! defined(__clang__)) && defined(__aarch64__) && defined(__ARM_NEON) #endif // #if ! ((defined(_MSC_VER)) && ! defined(__clang__)) && defined(__aarch64__) && defined(__ARM_NEON)
{ {
float sumf[4][4]; float sumf[4][4];
@ -1160,7 +1152,6 @@ void ggml_gemm_q4_0_4x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const vo
UNUSED(blocklen); UNUSED(blocklen);
#if ! ((defined(_MSC_VER)) && ! defined(__clang__)) && defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_MATMUL_INT8) #if ! ((defined(_MSC_VER)) && ! defined(__clang__)) && defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_MATMUL_INT8)
if (ggml_cpu_has_neon() && ggml_cpu_has_matmul_int8()) {
const void * b_ptr = vx; const void * b_ptr = vx;
const void * a_ptr = vy; const void * a_ptr = vy;
float * res_ptr = s; float * res_ptr = s;
@ -1557,7 +1548,6 @@ void ggml_gemm_q4_0_4x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const vo
: "cc", "memory", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31", "x9", "x10", "x20", "x21", "x22", "x23", "x24", "x25", "x26", "x27", "x28" : "cc", "memory", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31", "x9", "x10", "x20", "x21", "x22", "x23", "x24", "x25", "x26", "x27", "x28"
); );
return; return;
}
#endif // #if ! ((defined(_MSC_VER)) && ! defined(__clang__)) && defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_MATMUL_INT8) #endif // #if ! ((defined(_MSC_VER)) && ! defined(__clang__)) && defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_MATMUL_INT8)
float sumf[4][4]; float sumf[4][4];
int sumi; int sumi;
@ -1615,7 +1605,7 @@ void ggml_gemm_q4_0_8x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const vo
#if ! ((defined(_MSC_VER)) && ! defined(__clang__)) && defined(__aarch64__) #if ! ((defined(_MSC_VER)) && ! defined(__clang__)) && defined(__aarch64__)
#if defined(__ARM_FEATURE_SVE) && defined(__ARM_FEATURE_MATMUL_INT8) #if defined(__ARM_FEATURE_SVE) && defined(__ARM_FEATURE_MATMUL_INT8)
if (ggml_cpu_has_sve() && ggml_cpu_has_matmul_int8() && ggml_cpu_get_sve_cnt() == QK8_0) { if (ggml_cpu_get_sve_cnt() == QK8_0) {
const void * b_ptr = vx; const void * b_ptr = vx;
const void * a_ptr = vy; const void * a_ptr = vy;
float * res_ptr = s; float * res_ptr = s;
@ -2083,7 +2073,6 @@ void ggml_gemm_iq4_nl_4x4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const
UNUSED(blocklen); UNUSED(blocklen);
#if ! ((defined(_MSC_VER)) && ! defined(__clang__)) && defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_DOTPROD) #if ! ((defined(_MSC_VER)) && ! defined(__clang__)) && defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_DOTPROD)
if (ggml_cpu_has_neon() && ggml_cpu_has_dotprod()) {
const int8x16_t kvalues = vld1q_s8(kvalues_iq4nl); const int8x16_t kvalues = vld1q_s8(kvalues_iq4nl);
for (int y = 0; y < nr / 4; y++) { for (int y = 0; y < nr / 4; y++) {
@ -2135,7 +2124,6 @@ void ggml_gemm_iq4_nl_4x4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const
} }
} }
return; return;
}
#endif // #if ! ((defined(_MSC_VER)) && ! defined(__clang__)) && defined(__aarch64__) && defined(__ARM_NEON) #endif // #if ! ((defined(_MSC_VER)) && ! defined(__clang__)) && defined(__aarch64__) && defined(__ARM_NEON)
{ {
float sumf[4][4]; float sumf[4][4];


@ -0,0 +1,82 @@
# include "ggml-backend-impl.h"
#if defined(__powerpc64__) || defined(__ppc64__) || defined(__PPC64__)
#if defined(__linux__)
#include <sys/auxv.h>
#endif
#include <string>
struct powerpc_features {
std::string platform = "";
int power_version = -1;
bool has_vsx = false;
powerpc_features() {
#if defined(__linux__)
unsigned long auxval = getauxval(AT_PLATFORM);
if (auxval) {
platform = std::string(reinterpret_cast<const char*>(auxval));
// TBD: Do systems exist that return this in uppercase?
if (platform.substr(0, 5) == "power") {
// Extract a numeric suffix, if one exists
int vpos = -1;
for (int i = platform.length() - 1; i >= 0; i--) {
if (std::isdigit(platform[i])) {
vpos = i;
} else {
break;
}
}
if (vpos > -1) {
power_version = std::stoi(platform.substr(vpos));
}
}
}
#endif
if (power_version >= 9) {
has_vsx = true;
}
}
};
static int ggml_backend_cpu_powerpc_score() {
int score = 1;
powerpc_features pf;
// Platform scores
#if defined(GGML_USE_POWER7)
if (pf.power_version < 7) { return 0; }
score += 1<<1;
#endif
#if defined(GGML_USE_POWER8)
if (pf.power_version < 8) { return 0; }
score += 1<<2;
#endif
#if defined(GGML_USE_POWER9)
if (pf.power_version < 9) { return 0; }
score += 1<<3;
#endif
#if defined(GGML_USE_POWER10)
if (pf.power_version < 10) { return 0; }
score += 1<<4;
#endif
#if defined(GGML_USE_POWER11)
if (pf.power_version < 11) { return 0; }
score += 1<<5;
#endif
// Feature scores
#if defined(GGML_USE_VSX)
if (!pf.has_vsx) { return 0; }
score += 1<<6;
#endif
return score;
}
GGML_BACKEND_DL_SCORE_IMPL(ggml_backend_cpu_powerpc_score)
#endif // defined(__powerpc64__) || defined(__ppc64__) || defined(__PPC64__)
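
(A hedged worked example, assuming the dynamic backend loader keeps the variant with the highest nonzero score: on a Power10 host with VSX, both the power9 and power10 builds are loadable, and the power10 build wins.)

#include <cstdio>

int main() {
    // Scores on a hypothetical Power10 host with VSX, using the weights above:
    const int power9_score  = 1 + (1 << 3) + (1 << 6); // GGML_USE_POWER9  + GGML_USE_VSX -> 73
    const int power10_score = 1 + (1 << 4) + (1 << 6); // GGML_USE_POWER10 + GGML_USE_VSX -> 81
    const int power11_score = 0;                        // power_version < 11, variant rejected
    printf("%d %d %d\n", power9_score, power10_score, power11_score);
    return 0;
}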


@ -74,13 +74,8 @@
#if defined(__ARM_ARCH) #if defined(__ARM_ARCH)
struct ggml_arm_arch_features_type { struct ggml_arm_arch_features_type {
int has_neon;
int has_dotprod;
int has_i8mm;
int has_sve;
int sve_cnt; int sve_cnt;
int has_sme; } ggml_arm_arch_features = { 0 };
} ggml_arm_arch_features = {-1, -1, -1, -1, 0, -1};
#endif #endif
@ -678,87 +673,15 @@ bool ggml_is_numa(void) {
#if defined(__linux__) && defined(__aarch64__) #if defined(__linux__) && defined(__aarch64__)
#include <sys/auxv.h> #include <sys/auxv.h>
#elif defined(__APPLE__)
#include <sys/sysctl.h>
#endif
#if !defined(HWCAP2_I8MM)
#define HWCAP2_I8MM (1 << 13)
#endif
#if !defined(HWCAP2_SME)
#define HWCAP2_SME (1 << 23)
#endif #endif
static void ggml_init_arm_arch_features(void) { static void ggml_init_arm_arch_features(void) {
#if defined(__linux__) && defined(__aarch64__) #if defined(__linux__) && defined(__aarch64__) && defined(__ARM_FEATURE_SVE)
uint32_t hwcap = getauxval(AT_HWCAP);
uint32_t hwcap2 = getauxval(AT_HWCAP2);
ggml_arm_arch_features.has_neon = !!(hwcap & HWCAP_ASIMD);
ggml_arm_arch_features.has_dotprod = !!(hwcap & HWCAP_ASIMDDP);
ggml_arm_arch_features.has_i8mm = !!(hwcap2 & HWCAP2_I8MM);
ggml_arm_arch_features.has_sve = !!(hwcap & HWCAP_SVE);
ggml_arm_arch_features.has_sme = !!(hwcap2 & HWCAP2_SME);
#if defined(__ARM_FEATURE_SVE)
ggml_arm_arch_features.sve_cnt = PR_SVE_VL_LEN_MASK & prctl(PR_SVE_GET_VL); ggml_arm_arch_features.sve_cnt = PR_SVE_VL_LEN_MASK & prctl(PR_SVE_GET_VL);
#endif #endif
#elif defined(__APPLE__)
int oldp = 0;
size_t size = sizeof(oldp);
if (sysctlbyname("hw.optional.AdvSIMD", &oldp, &size, NULL, 0) != 0) {
oldp = 0;
}
ggml_arm_arch_features.has_neon = oldp;
if (sysctlbyname("hw.optional.arm.FEAT_DotProd", &oldp, &size, NULL, 0) != 0) {
oldp = 0;
}
ggml_arm_arch_features.has_dotprod = oldp;
if (sysctlbyname("hw.optional.arm.FEAT_I8MM", &oldp, &size, NULL, 0) != 0) {
oldp = 0;
}
ggml_arm_arch_features.has_i8mm = oldp;
if (sysctlbyname("hw.optional.arm.FEAT_SME", &oldp, &size, NULL, 0) != 0) {
oldp = 0;
}
ggml_arm_arch_features.has_sme = oldp;
ggml_arm_arch_features.has_sve = 0;
ggml_arm_arch_features.sve_cnt = 0;
#else
// Run-time CPU feature detection not implemented for this platform, fallback to compile time
#if defined(__ARM_NEON)
ggml_arm_arch_features.has_neon = 1;
#else
ggml_arm_arch_features.has_neon = 0;
#endif
#if defined(__ARM_FEATURE_MATMUL_INT8)
ggml_arm_arch_features.has_i8mm = 1;
#else
ggml_arm_arch_features.has_i8mm = 0;
#endif
#if defined(__ARM_FEATURE_SVE)
ggml_arm_arch_features.has_sve = 1;
ggml_arm_arch_features.sve_cnt = 16;
#else
ggml_arm_arch_features.has_sve = 0;
ggml_arm_arch_features.sve_cnt = 0;
#endif
#if defined(__ARM_FEATURE_SME) || defined(__ARM_FEATURE_SME2)
ggml_arm_arch_features.has_sme = 1;
#else
ggml_arm_arch_features.has_sme = 0;
#endif
#endif
} }
#endif
#endif // __ARM_ARCH
struct ggml_tensor * ggml_new_i32(struct ggml_context * ctx, int32_t value) { struct ggml_tensor * ggml_new_i32(struct ggml_context * ctx, int32_t value) {
GGML_ASSERT(!ggml_get_no_alloc(ctx)); GGML_ASSERT(!ggml_get_no_alloc(ctx));
@ -1967,6 +1890,10 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
{ {
ggml_compute_forward_pad_reflect_1d(params, tensor); ggml_compute_forward_pad_reflect_1d(params, tensor);
} break; } break;
case GGML_OP_ROLL:
{
ggml_compute_forward_roll(params, tensor);
} break;
case GGML_OP_ARANGE: case GGML_OP_ARANGE:
{ {
ggml_compute_forward_arange(params, tensor); ggml_compute_forward_arange(params, tensor);
@ -2291,6 +2218,7 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads) {
case GGML_OP_UPSCALE: case GGML_OP_UPSCALE:
case GGML_OP_PAD: case GGML_OP_PAD:
case GGML_OP_PAD_REFLECT_1D: case GGML_OP_PAD_REFLECT_1D:
case GGML_OP_ROLL:
case GGML_OP_ARANGE: case GGML_OP_ARANGE:
case GGML_OP_TIMESTEP_EMBEDDING: case GGML_OP_TIMESTEP_EMBEDDING:
case GGML_OP_ARGSORT: case GGML_OP_ARGSORT:
@ -3443,7 +3371,7 @@ int ggml_cpu_has_vxe(void) {
int ggml_cpu_has_neon(void) { int ggml_cpu_has_neon(void) {
#if defined(__ARM_ARCH) && defined(__ARM_NEON) #if defined(__ARM_ARCH) && defined(__ARM_NEON)
return ggml_arm_arch_features.has_neon; return 1;
#else #else
return 0; return 0;
#endif #endif
@ -3451,7 +3379,7 @@ int ggml_cpu_has_neon(void) {
int ggml_cpu_has_dotprod(void) { int ggml_cpu_has_dotprod(void) {
#if defined(__ARM_ARCH) && defined(__ARM_FEATURE_DOTPROD) #if defined(__ARM_ARCH) && defined(__ARM_FEATURE_DOTPROD)
return ggml_arm_arch_features.has_dotprod; return 1;
#else #else
return 0; return 0;
#endif #endif
@ -3459,7 +3387,7 @@ int ggml_cpu_has_dotprod(void) {
int ggml_cpu_has_sve(void) { int ggml_cpu_has_sve(void) {
#if defined(__ARM_ARCH) && defined(__ARM_FEATURE_SVE) #if defined(__ARM_ARCH) && defined(__ARM_FEATURE_SVE)
return ggml_arm_arch_features.has_sve; return 1;
#else #else
return 0; return 0;
#endif #endif
@ -3467,7 +3395,7 @@ int ggml_cpu_has_sve(void) {
int ggml_cpu_has_matmul_int8(void) { int ggml_cpu_has_matmul_int8(void) {
#if defined(__ARM_ARCH) && defined(__ARM_FEATURE_MATMUL_INT8) #if defined(__ARM_ARCH) && defined(__ARM_FEATURE_MATMUL_INT8)
return ggml_arm_arch_features.has_i8mm; return 1;
#else #else
return 0; return 0;
#endif #endif
@ -3483,7 +3411,7 @@ int ggml_cpu_get_sve_cnt(void) {
int ggml_cpu_has_sme(void) { int ggml_cpu_has_sme(void) {
#if defined(__ARM_ARCH) && defined(__ARM_FEATURE_SME) #if defined(__ARM_ARCH) && defined(__ARM_FEATURE_SME)
return ggml_arm_arch_features.has_sme; return 1;
#else #else
return 0; return 0;
#endif #endif


@ -62,7 +62,7 @@
#define NOINLINE __attribute__((__noinline__)) #define NOINLINE __attribute__((__noinline__))
#endif #endif
#if defined(__ARM_NEON) || defined(__AVX512F__) #if defined(__ARM_NEON) || defined(__AVX512F__) || defined(__VXE__) || defined(__VXE2__)
#define VECTOR_REGISTERS 32 #define VECTOR_REGISTERS 32
#else #else
#define VECTOR_REGISTERS 16 #define VECTOR_REGISTERS 16
@ -109,6 +109,12 @@ inline float16x8_t sub(float16x8_t x, float16x8_t y) { return vsubq_f16(x, y); }
inline float16x8_t mul(float16x8_t x, float16x8_t y) { return vmulq_f16(x, y); } inline float16x8_t mul(float16x8_t x, float16x8_t y) { return vmulq_f16(x, y); }
#endif // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC #endif // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
#if defined(__VXE__) || defined(__VXE2__)
inline float32x4_t add(float32x4_t x, float32x4_t y) { return vec_add(x, y); }
inline float32x4_t sub(float32x4_t x, float32x4_t y) { return vec_sub(x, y); }
inline float32x4_t mul(float32x4_t x, float32x4_t y) { return vec_mul(x, y); }
#endif
#if defined(__MMA__) #if defined(__MMA__)
typedef vector unsigned char vec_t; typedef vector unsigned char vec_t;
typedef __vector_quad acc_t; typedef __vector_quad acc_t;
@ -162,6 +168,13 @@ inline float16x8_t madd(float16x8_t a, float16x8_t b, float16x8_t c) {
#endif #endif
#endif #endif
#if defined(__VXE__) || defined(__VXE2__)
template <>
inline float32x4_t madd(float32x4_t a, float32x4_t b, float32x4_t c) {
return vec_madd(a, b, c);
}
#endif
//////////////////////////////////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////////////////////////////////
// VECTORIZED HORIZONTAL SUM // VECTORIZED HORIZONTAL SUM
@ -178,6 +191,13 @@ inline float hsum(float16x8_t x) {
} }
#endif // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC #endif // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
#if defined(__VXE__) || defined(__VXE2__)
inline float hsum(float32x4_t x) {
float32x4_t tmp = x + vec_reve(x);
return tmp[0] + tmp[1];
}
#endif
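
(Scalar sanity check, not part of the patch, of why the vec_reve trick above yields the full horizontal sum.)

#include <cstdio>

int main() {
    float x[4]   = {1.0f, 2.0f, 3.0f, 4.0f};
    float rev[4] = {x[3], x[2], x[1], x[0]};            // vec_reve(x)
    float tmp[4];
    for (int i = 0; i < 4; ++i) {
        tmp[i] = x[i] + rev[i];                         // [x0+x3, x1+x2, x2+x1, x3+x0]
    }
    printf("%g\n", tmp[0] + tmp[1]);                    // 10 = x0 + x1 + x2 + x3
    return 0;
}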
#if defined(__SSE__) || defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__) #if defined(__SSE__) || defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__)
inline float hsum(__m128 x) { inline float hsum(__m128 x) {
#if defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__) #if defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__)
@ -227,6 +247,21 @@ template <> inline float32x4_t load(const ggml_fp16_t *p) {
#endif // _MSC_VER #endif // _MSC_VER
#endif // __ARM_NEON #endif // __ARM_NEON
#if defined(__VXE__) || defined(__VXE2__)
template <> inline float32x4_t load(const ggml_fp16_t * p) {
float tmp[4];
for (int i = 0; i < 4; i++) {
tmp[i] = GGML_FP16_TO_FP32(p[i]);
}
return vec_xl(0, (const float *)(tmp));
}
template <> inline float32x4_t load(const float * p) {
return vec_xl(0, p);
}
#endif
#if defined(__SSE__) || defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__) #if defined(__SSE__) || defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__)
template <> inline __m128 load(const float *p) { template <> inline __m128 load(const float *p) {
return _mm_loadu_ps(p); return _mm_loadu_ps(p);
@ -3319,6 +3354,14 @@ bool llamafile_sgemm(const struct ggml_compute_params * params, int64_t m, int64
(const float *)B, ldb, (const float *)B, ldb,
(float *)C, ldc}; (float *)C, ldc};
return tb.matmul(m, n); return tb.matmul(m, n);
#elif defined(__VXE__) || defined(__VXE2__)
if (n < 4)
return false;
tinyBLAS<4, float32x4_t, float32x4_t, float, float, float> tb{ params,
k, (const float *)A, lda,
(const float *)B, ldb,
(float *)C, ldc};
return tb.matmul(m, n);
#elif defined(__MMA__) #elif defined(__MMA__)
if (k % 8) if (k % 8)
return false; return false;
@ -3410,6 +3453,16 @@ bool llamafile_sgemm(const struct ggml_compute_params * params, int64_t m, int64
(float *)C, ldc}; (float *)C, ldc};
return tb.matmul(m, n); return tb.matmul(m, n);
} }
#elif defined(__VXE__) || defined(__VXE2__)
if (n < 4)
return false;
if (Btype == GGML_TYPE_F16) {
tinyBLAS<4, float32x4_t, float32x4_t, ggml_fp16_t, ggml_fp16_t, float> tb{ params,
k, (const ggml_fp16_t *)A, lda,
(const ggml_fp16_t *)B, ldb,
(float *)C, ldc};
return tb.matmul(m, n);
}
#endif #endif
return false; return false;
} }


@ -1,6 +1,11 @@
#pragma once #pragma once
#include <stdint.h> #include <stdint.h>
#include <stdbool.h> #include <stdbool.h>
#if defined(__VXE__) || defined(__VXE2__)
#include <vecintrin.h>
#endif
#ifdef __cplusplus #ifdef __cplusplus
extern "C" { extern "C" {
#endif #endif


@ -6793,6 +6793,73 @@ void ggml_compute_forward_pad_reflect_1d(
} }
} }
// ggml_compute_forward_roll
static int64_t ggml_wrap_index(int64_t i, int64_t ne) {
if (i < 0) {
return i + ne;
} else if (i >= ne) {
return i - ne;
}
return i;
}
static void ggml_compute_forward_roll_f32(
const ggml_compute_params * params,
ggml_tensor * dst) {
const ggml_tensor * src0 = dst->src[0];
const float * src_data = (const float *) src0->data;
float * dst_data = (float *) dst->data;
GGML_TENSOR_UNARY_OP_LOCALS
const int s0 = ggml_get_op_params_i32(dst, 0);
const int s1 = ggml_get_op_params_i32(dst, 1);
const int s2 = ggml_get_op_params_i32(dst, 2);
const int s3 = ggml_get_op_params_i32(dst, 3);
const int64_t total = ne1 * ne2 * ne3;
const int64_t per_thread = (total + params->nth) / params->nth;
const int64_t start = params->ith * per_thread;
const int64_t end = std::min(start + per_thread, total);
for (int64_t i = start; i < end; ++i) {
const int64_t i1 = i % ne1;
const int64_t i2 = (i / ne1) % ne2;
const int64_t i3 = i / (ne2 * ne1);
float * dst_row = dst_data + (i3*nb3 + i2*nb2 + i1*nb1) / sizeof(float);
const int64_t i01 = ggml_wrap_index(i1 - s1, ne01);
const int64_t i02 = ggml_wrap_index(i2 - s2, ne02);
const int64_t i03 = ggml_wrap_index(i3 - s3, ne03);
const float * src_row = src_data + (i03*nb03 + i02*nb02 + i01*nb01) / sizeof(float);
const int64_t s = ggml_wrap_index(-s0, ne00);
const int64_t n = ne00 - s;
ggml_vec_cpy_f32(n, dst_row, src_row + s);
ggml_vec_cpy_f32(s, dst_row + n, src_row);
}
}
void ggml_compute_forward_roll(
const ggml_compute_params * params,
ggml_tensor * dst) {
const ggml_tensor * src0 = dst->src[0];
switch (src0->type) {
case GGML_TYPE_F32:
{
ggml_compute_forward_roll_f32(params, dst);
} break;
default:
{
GGML_ABORT("fatal error");
}
}
}
// ggml_compute_forward_arange // ggml_compute_forward_arange
static void ggml_compute_forward_arange_f32( static void ggml_compute_forward_arange_f32(


@ -72,6 +72,7 @@ void ggml_compute_forward_pool_2d_back(const struct ggml_compute_params * params
void ggml_compute_forward_upscale(const struct ggml_compute_params * params, struct ggml_tensor * dst); void ggml_compute_forward_upscale(const struct ggml_compute_params * params, struct ggml_tensor * dst);
void ggml_compute_forward_pad(const struct ggml_compute_params * params, struct ggml_tensor * dst); void ggml_compute_forward_pad(const struct ggml_compute_params * params, struct ggml_tensor * dst);
void ggml_compute_forward_pad_reflect_1d(const struct ggml_compute_params * params, struct ggml_tensor * dst); void ggml_compute_forward_pad_reflect_1d(const struct ggml_compute_params * params, struct ggml_tensor * dst);
void ggml_compute_forward_roll(const struct ggml_compute_params * params, struct ggml_tensor * dst);
void ggml_compute_forward_arange(const struct ggml_compute_params * params, struct ggml_tensor * dst); void ggml_compute_forward_arange(const struct ggml_compute_params * params, struct ggml_tensor * dst);
void ggml_compute_forward_timestep_embedding(const struct ggml_compute_params * params, struct ggml_tensor * dst); void ggml_compute_forward_timestep_embedding(const struct ggml_compute_params * params, struct ggml_tensor * dst);
void ggml_compute_forward_argsort(const struct ggml_compute_params * params, struct ggml_tensor * dst); void ggml_compute_forward_argsort(const struct ggml_compute_params * params, struct ggml_tensor * dst);


@ -1163,13 +1163,24 @@ template <typename BLOC_TYPE, int64_t INTER_SIZE, int64_t NB_COLS, ggml_type PAR
// not really a GGML_TYPE_Q8_0 but same size. // not really a GGML_TYPE_Q8_0 but same size.
switch (op->op) { switch (op->op) {
case GGML_OP_MUL_MAT: case GGML_OP_MUL_MAT:
{
size = ggml_row_size(PARAM_TYPE, ggml_nelements(op->src[1])); size = ggml_row_size(PARAM_TYPE, ggml_nelements(op->src[1]));
return true; return true;
}
case GGML_OP_MUL_MAT_ID: case GGML_OP_MUL_MAT_ID:
{
size = ggml_row_size(PARAM_TYPE, ggml_nelements(op->src[1])); size = ggml_row_size(PARAM_TYPE, ggml_nelements(op->src[1]));
size = GGML_PAD(size, sizeof(int64_t)); // + padding for next block. size = GGML_PAD(size, sizeof(int64_t)); // + padding for next block.
size += sizeof(int64_t) * (1+op->src[0]->ne[2]) * op->src[1]->ne[2];
const int64_t ne02 = op->src[0]->ne[2]; // n_as, n_expert
const int64_t ne12 = op->src[1]->ne[2]; // n_tokens
const size_t sizeof_mmid_row_mapping = sizeof(int64_t);
size += sizeof_mmid_row_mapping*ne02*(ne12 + 1);
return true; return true;
}
default: default:
// GGML_ABORT("fatal error"); // GGML_ABORT("fatal error");
break; break;
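
(Worked example with illustrative numbers, not from the patch: the MUL_MAT_ID bookkeeping now reserves one int64 row count per expert plus an [n_as][ne12] row-mapping table, as spelled out in the size formula and assertion around this hunk.)

#include <cstdint>
#include <cstdio>

int main() {
    const int64_t n_as = 8;    // ne02: number of experts
    const int64_t ne12 = 512;  // number of tokens
    const size_t  sizeof_mmid_row_mapping = sizeof(int64_t); // 2 x int32_t packed together
    // n_as row counts + n_as*ne12 mappings = n_as*(ne12 + 1) elements of 8 bytes each
    const size_t bookkeeping = sizeof_mmid_row_mapping * n_as * (ne12 + 1);
    printf("%zu bytes\n", bookkeeping); // 32832
    return 0;
}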
@ -1305,13 +1316,16 @@ template <typename BLOC_TYPE, int64_t INTER_SIZE, int64_t NB_COLS, ggml_type PAR
int32_t i2; int32_t i2;
}; };
GGML_ASSERT(params->wsize >= (GGML_PAD(nbw3, sizeof(int64_t)) + n_as * sizeof(int64_t) + GGML_ASSERT(params->wsize >=
n_as * ne12 * sizeof(mmid_row_mapping))); (GGML_PAD(nbw3, sizeof(int64_t)) +
n_as*(ne12 + 1)*sizeof(mmid_row_mapping))
);
auto * wdata = (char *) params->wdata; auto * wdata = (char *)params->wdata;
auto * wdata_src1_end = (char *) wdata + GGML_PAD(nbw3, sizeof(int64_t)); auto * wdata_src1_end = (char *)wdata + GGML_PAD(nbw3, sizeof(int64_t));
// total of [n_as][ne12 + 1] elements of type mmid_row_mapping (2*int32_t = int64_t)
auto * matrix_row_counts = (int64_t *) (wdata_src1_end); // [n_as] auto * matrix_row_counts = (int64_t *) (wdata_src1_end); // [n_as]
struct mmid_row_mapping * matrix_rows = (struct mmid_row_mapping *) (matrix_row_counts + n_as); // [n_as][ne12] struct mmid_row_mapping * matrix_rows = (struct mmid_row_mapping *) (matrix_row_counts + n_as); // [n_as][ne12]
// src1: float32 => param type // src1: float32 => param type
@ -1397,44 +1411,45 @@ template <typename BLOC_TYPE, int64_t INTER_SIZE, int64_t NB_COLS, ggml_type PAR
} }
}; };
// instance for Q4
static const tensor_traits<block_q4_0, 4, 4, GGML_TYPE_Q8_0> q4_0_4x4_q8_0;
static const tensor_traits<block_q4_0, 8, 4, GGML_TYPE_Q8_0> q4_0_4x8_q8_0;
static const tensor_traits<block_q4_0, 8, 8, GGML_TYPE_Q8_0> q4_0_8x8_q8_0;
static const tensor_traits<block_q4_K, 8, 8, GGML_TYPE_Q8_K> q4_K_8x8_q8_K;
// instance for IQ4
static const tensor_traits<block_iq4_nl, 4, 4, GGML_TYPE_Q8_0> iq4_nl_4x4_q8_0;
} // namespace ggml::cpu::repack } // namespace ggml::cpu::repack
static const ggml::cpu::tensor_traits * ggml_repack_get_optimal_repack_type(const struct ggml_tensor * cur) { static const ggml::cpu::tensor_traits * ggml_repack_get_optimal_repack_type(const struct ggml_tensor * cur) {
// instance for Q4
static const ggml::cpu::repack::tensor_traits<block_q4_0, 4, 4, GGML_TYPE_Q8_0> q4_0_4x4_q8_0;
static const ggml::cpu::repack::tensor_traits<block_q4_0, 8, 4, GGML_TYPE_Q8_0> q4_0_4x8_q8_0;
static const ggml::cpu::repack::tensor_traits<block_q4_0, 8, 8, GGML_TYPE_Q8_0> q4_0_8x8_q8_0;
static const ggml::cpu::repack::tensor_traits<block_q4_K, 8, 8, GGML_TYPE_Q8_K> q4_K_8x8_q8_K;
// instance for IQ4
static const ggml::cpu::repack::tensor_traits<block_iq4_nl, 4, 4, GGML_TYPE_Q8_0> iq4_nl_4x4_q8_0;
if (cur->type == GGML_TYPE_Q4_0) { if (cur->type == GGML_TYPE_Q4_0) {
if (ggml_cpu_has_avx2() || (ggml_cpu_has_sve() && ggml_cpu_has_matmul_int8() && ggml_cpu_get_sve_cnt() == QK8_0)) { if (ggml_cpu_has_avx2() || (ggml_cpu_has_sve() && ggml_cpu_has_matmul_int8() && ggml_cpu_get_sve_cnt() == QK8_0)) {
if (cur->ne[1] % 8 == 0) { if (cur->ne[1] % 8 == 0) {
return &ggml::cpu::repack::q4_0_8x8_q8_0; return &q4_0_8x8_q8_0;
} }
} }
if (ggml_cpu_has_neon() && ggml_cpu_has_matmul_int8()) { if (ggml_cpu_has_neon() && ggml_cpu_has_matmul_int8()) {
if (cur->ne[1] % 4 == 0) { if (cur->ne[1] % 4 == 0) {
return &ggml::cpu::repack::q4_0_4x8_q8_0; return &q4_0_4x8_q8_0;
} }
} }
if (ggml_cpu_has_neon() && ggml_cpu_has_dotprod()) { if (ggml_cpu_has_neon() && ggml_cpu_has_dotprod()) {
if (cur->ne[1] % 4 == 0) { if (cur->ne[1] % 4 == 0) {
return &ggml::cpu::repack::q4_0_4x4_q8_0; return &q4_0_4x4_q8_0;
} }
} }
} else if (cur->type == GGML_TYPE_Q4_K) { } else if (cur->type == GGML_TYPE_Q4_K) {
if (ggml_cpu_has_avx2()) { if (ggml_cpu_has_avx2()) {
if (cur->ne[1] % 8 == 0) { if (cur->ne[1] % 8 == 0) {
return &ggml::cpu::repack::q4_K_8x8_q8_K; return &q4_K_8x8_q8_K;
} }
} }
} else if (cur->type == GGML_TYPE_IQ4_NL) { } else if (cur->type == GGML_TYPE_IQ4_NL) {
if (ggml_cpu_has_neon() && ggml_cpu_has_dotprod()) { if (ggml_cpu_has_neon() && ggml_cpu_has_dotprod()) {
if (cur->ne[1] % 4 == 0) { if (cur->ne[1] % 4 == 0) {
return &ggml::cpu::repack::iq4_nl_4x4_q8_0; return &iq4_nl_4x4_q8_0;
} }
} }
} }

View File

@ -19,10 +19,10 @@
#endif #endif
#include "ggml-common.h" #include "ggml-common.h"
#include <cstdio>
#include <array> #include <array>
#include <cassert> #include <cassert>
#include <cfloat> #include <cfloat>
#include <cstdio>
#include <string> #include <string>
#include <vector> #include <vector>
@ -241,8 +241,18 @@ static bool fp16_mma_available(const int cc) {
#if defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__) && !defined(GGML_HIP_ROCWMMA_FATTN) #if defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__) && !defined(GGML_HIP_ROCWMMA_FATTN)
return false; return false;
#else #else
return (GGML_CUDA_CC_IS_NVIDIA(cc) && ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_VOLTA) || GGML_CUDA_CC_IS_CDNA(cc) || GGML_CUDA_CC_IS_RDNA3(cc) || GGML_CUDA_CC_IS_RDNA4(cc);
if ((GGML_CUDA_CC_IS_NVIDIA(cc) && ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_VOLTA) ||
    GGML_CUDA_CC_IS_CDNA(cc) || GGML_CUDA_CC_IS_RDNA3(cc)) {
return true;
} else if (GGML_CUDA_CC_IS_RDNA4(cc)) {
#if defined(GGML_HIP_ROCWMMA_FATTN) && defined(GGML_HIP_ROCWMMA_FATTN_GFX12)
return true;
#else
return false;
#endif // defined(GGML_HIP_ROCWMMA_FATTN) && defined(GGML_HIP_ROCWMMA_FATTN_GFX12)
} else {
return false;
}
#endif // defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__) && !defined(GGML_HIP_ROCWMMA_FATTN) #endif // defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__) && !defined(GGML_HIP_ROCWMMA_FATTN)
} }
@ -252,6 +262,10 @@ static bool fp16_mma_hardware_available(const int cc) {
GGML_CUDA_CC_IS_CDNA(cc) || GGML_CUDA_CC_IS_RDNA3(cc) || GGML_CUDA_CC_IS_RDNA4(cc); GGML_CUDA_CC_IS_CDNA(cc) || GGML_CUDA_CC_IS_RDNA3(cc) || GGML_CUDA_CC_IS_RDNA4(cc);
} }
static bool bf16_mma_hardware_available(const int cc) {
return GGML_CUDA_CC_IS_NVIDIA(cc) && cc >= GGML_CUDA_CC_AMPERE;
}
// Volta technically had FP16 tensor cores but they work very differently compared to Turing and later. // Volta technically had FP16 tensor cores but they work very differently compared to Turing and later.
static bool new_mma_available(const int cc) { static bool new_mma_available(const int cc) {
return GGML_CUDA_CC_IS_NVIDIA(cc) && ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_TURING; return GGML_CUDA_CC_IS_NVIDIA(cc) && ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_TURING;
@ -362,6 +376,26 @@ static __device__ __forceinline__ half2 warp_reduce_sum(half2 a) {
#endif // FP16_AVAILABLE #endif // FP16_AVAILABLE
} }
// Row reduction kernel template - compute sum (norm=false) or mean (norm=true)
template<bool norm>
static __global__ void reduce_rows_f32(const float * x, float * dst, const int ncols) {
const int row = blockIdx.x;
const int col = threadIdx.x;
float sum = 0.0f;
for (int i = col; i < ncols; i += blockDim.x) {
sum += x[row * ncols + i];
}
sum = warp_reduce_sum(sum);
if (col != 0) {
return;
}
dst[row] = norm ? sum / ncols : sum;
}
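For illustration only — a minimal host-side launch sketch for the shared row-reduction kernel above (one block per row, a warp of threads per block), assuming the WARP_SIZE and reduce_rows_f32 declarations from this header; the launcher name is a stand-in.
// norm=false reproduces SUM_ROWS, norm=true reproduces MEAN (see sumrows.cu and mean.cu further below).
static void launch_reduce_rows_f32_sketch(const float * x, float * dst, int ncols, int nrows,
                                          bool mean, cudaStream_t stream) {
    const dim3 block_dims(WARP_SIZE, 1, 1);
    const dim3 block_nums(nrows, 1, 1);
    if (mean) {
        reduce_rows_f32</*norm=*/true ><<<block_nums, block_dims, 0, stream>>>(x, dst, ncols);
    } else {
        reduce_rows_f32</*norm=*/false><<<block_nums, block_dims, 0, stream>>>(x, dst, ncols);
    }
}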
template<int width = WARP_SIZE> template<int width = WARP_SIZE>
static __device__ __forceinline__ float warp_reduce_max(float x) { static __device__ __forceinline__ float warp_reduce_max(float x) {
#pragma unroll #pragma unroll
@ -767,21 +801,7 @@ struct ggml_backend_cuda_context {
name(GGML_CUDA_NAME + std::to_string(device)) { name(GGML_CUDA_NAME + std::to_string(device)) {
} }
~ggml_backend_cuda_context() { ~ggml_backend_cuda_context();
if (copy_event != nullptr) {
CUDA_CHECK(cudaEventDestroy(copy_event));
}
for (int i = 0; i < GGML_CUDA_MAX_DEVICES; ++i) {
for (int j = 0; j < GGML_CUDA_MAX_STREAMS; ++j) {
if (streams[i][j] != nullptr) {
CUDA_CHECK(cudaStreamDestroy(streams[i][j]));
}
}
if (cublas_handles[i] != nullptr) {
CUBLAS_CHECK(cublasDestroy(cublas_handles[i]));
}
}
}
cudaStream_t stream(int device, int stream) { cudaStream_t stream(int device, int stream) {
if (streams[device][stream] == nullptr) { if (streams[device][stream] == nullptr) {

View File

@ -0,0 +1,161 @@
#include "conv2d-dw.cuh"
struct conv_params {
int in_w, in_h;
int out_w, out_h;
int kernel_w, kernel_h;
int stride_x, stride_y;
int padding_x, padding_y;
int dilation_x, dilation_y;
int channels, batches;
};
struct kernel_bounds {
int y_min, y_max;
int x_min, x_max;
};
__device__ __forceinline__ kernel_bounds calculate_kernel_bounds(int out_x, int out_y, const conv_params & params) {
kernel_bounds bounds;
bounds.y_min = max(0, (params.padding_y - out_y * params.stride_y + params.dilation_y - 1) / params.dilation_y);
bounds.y_max =
min(params.kernel_h,
(params.in_h + params.padding_y - out_y * params.stride_y + params.dilation_y - 1) / params.dilation_y);
bounds.x_min = max(0, (params.padding_x - out_x * params.stride_x + params.dilation_x - 1) / params.dilation_x);
bounds.x_max =
min(params.kernel_w,
(params.in_w + params.padding_x - out_x * params.stride_x + params.dilation_x - 1) / params.dilation_x);
return bounds;
}
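For illustration only — the bounds above are the integer solution of 0 <= out*stride + k*dilation - padding < in_size for the kernel index k; a host-side sanity check with stand-in names.
#include <algorithm>
#include <cassert>

// Every tap admitted by the computed [k_min, k_max) range must map to a valid
// input coordinate for the given output position.
static void check_dw_kernel_bounds_sketch(int out, int stride, int dilation, int padding,
                                          int kernel, int in_size) {
    const int k_min = std::max(0, (padding - out * stride + dilation - 1) / dilation);
    const int k_max = std::min(kernel, (in_size + padding - out * stride + dilation - 1) / dilation);
    for (int k = k_min; k < k_max; ++k) {
        const int in = out * stride + k * dilation - padding; // calculate_input_coord
        assert(0 <= in && in < in_size);
    }
}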
__device__ __forceinline__ int calculate_input_coord(int out_coord, int kern_coord, int stride, int dilation, int padding) {
return out_coord * stride + kern_coord * dilation - padding;
}
struct whcn_layout {
__device__ static int input_index(int n, int c, int y, int x, const conv_params & params) {
return n * (params.channels * params.in_w * params.in_h) + c * params.in_w * params.in_h + y * params.in_w + x;
}
__device__ static int kernel_index(int c, int ky, int kx, const conv_params & params) {
return c * params.kernel_h * params.kernel_w + ky * params.kernel_w + kx;
}
__device__ static int output_index(int n, int c, int y, int x, const conv_params & params) {
return n * (params.channels * params.out_w * params.out_h) + c * params.out_w * params.out_h +
y * params.out_w + x;
}
__device__ static void unpack_indices(int global_idx, const conv_params & params, int & n, int & c, int & out_y,
int & out_x) {
out_x = global_idx % params.out_w;
out_y = (global_idx / params.out_w) % params.out_h;
c = (global_idx / (params.out_w * params.out_h)) % params.channels;
n = global_idx / (params.out_w * params.out_h * params.channels);
}
};
struct cwhn_layout {
__device__ static int input_index(int n, int c, int y, int x, const conv_params & params) {
return n * (params.channels * params.in_w * params.in_h) + (y * params.in_w + x) * params.channels + c;
}
__device__ static int kernel_index(int c, int ky, int kx, const conv_params & params) {
return (ky * params.kernel_w + kx) * params.channels + c;
}
__device__ static int output_index(int n, int c, int y, int x, const conv_params & params) {
return n * (params.channels * params.out_w * params.out_h) + y * (params.out_w * params.channels) +
x * params.channels + c;
}
__device__ static void unpack_indices(int global_idx, const conv_params & params, int & n, int & c, int & out_y,
int & out_x) {
c = global_idx % params.channels;
out_x = (global_idx / params.channels) % params.out_w;
out_y = (global_idx / (params.channels * params.out_w)) % params.out_h;
n = global_idx / (params.channels * params.out_w * params.out_h);
}
};
template <typename T, typename Layout>
__global__ void conv2d_dw_kernel(const T * __restrict__ input, const T * __restrict__ kernel, T * __restrict__ output,
const int in_w, const int in_h, const int out_w, const int out_h,
const int kernel_w, const int kernel_h, const int stride_x, const int stride_y,
const int padding_x, const int padding_y, const int dilation_x, const int dilation_y,
const int channels, const int batches) {
const int global_idx = blockIdx.x * blockDim.x + threadIdx.x;
const int total_elements = batches * channels * out_h * out_w;
if (global_idx >= total_elements) {
return;
}
conv_params params = { in_w, in_h, out_w, out_h, kernel_w, kernel_h, stride_x,
stride_y, padding_x, padding_y, dilation_x, dilation_y, channels, batches };
int batch_idx, channel_idx, out_y_idx, out_x_idx;
Layout::unpack_indices(global_idx, params, batch_idx, channel_idx, out_y_idx, out_x_idx);
T accumulator = 0;
kernel_bounds bounds = calculate_kernel_bounds(out_x_idx, out_y_idx, params);
for (int kern_y = bounds.y_min; kern_y < bounds.y_max; ++kern_y) {
int in_y_idx = calculate_input_coord(out_y_idx, kern_y, params.stride_y, params.dilation_y, params.padding_y);
for (int kern_x = bounds.x_min; kern_x < bounds.x_max; ++kern_x) {
int in_x_idx = calculate_input_coord(out_x_idx, kern_x, params.stride_x, params.dilation_x, params.padding_x);
const T input_val = input[Layout::input_index(batch_idx, channel_idx, in_y_idx, in_x_idx, params)];
const T kernel_val = kernel[Layout::kernel_index(channel_idx, kern_y, kern_x, params)];
accumulator += input_val * kernel_val;
}
}
output[Layout::output_index(batch_idx, channel_idx, out_y_idx, out_x_idx, params)] = accumulator;
}
void ggml_cuda_op_conv2d_dw(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
const ggml_tensor * kernel = dst->src[0];
const ggml_tensor * input = dst->src[1];
GGML_ASSERT(kernel->type == GGML_TYPE_F32 && input->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32);
const float * w_d = (const float *) kernel->data;
const float * x_d = (const float *) input->data;
float * y_d = (float *) dst->data;
const int32_t * p = (const int32_t *) dst->op_params;
const int stride_x = p[0];
const int stride_y = p[1];
const int padding_x = p[2];
const int padding_y = p[3];
const int dilation_x = p[4];
const int dilation_y = p[5];
const int in_w = input->ne[0];
const int in_h = input->ne[1];
const int kernel_w = kernel->ne[0];
const int kernel_h = kernel->ne[1];
const int out_w = dst->ne[0];
const int out_h = dst->ne[1];
const int channels = dst->ne[2];
const int batches = dst->ne[3];
cudaStream_t st = ctx.stream();
const int total = batches * channels * out_h * out_w;
const int blocks = (total + CUDA_CONV2D_DW_BLOCK_SIZE - 1) / CUDA_CONV2D_DW_BLOCK_SIZE;
if (ggml_is_contiguous(input)) {
conv2d_dw_kernel<float, whcn_layout><<<blocks, CUDA_CONV2D_DW_BLOCK_SIZE, 0, st>>>(
x_d, w_d, y_d, in_w, in_h, out_w, out_h, kernel_w, kernel_h, stride_x, stride_y, padding_x, padding_y,
dilation_x, dilation_y, channels, batches);
} else if (ggml_is_contiguous_channels(input)) {
conv2d_dw_kernel<float, cwhn_layout><<<blocks, CUDA_CONV2D_DW_BLOCK_SIZE, 0, st>>>(
x_d, w_d, y_d, in_w, in_h, out_w, out_h, kernel_w, kernel_h, stride_x, stride_y, padding_x, padding_y,
dilation_x, dilation_y, channels, batches);
} else {
GGML_ABORT("Unsupported memory layout for conv_2d_dw");
}
}
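For illustration only — the op_params unpacked above follow the usual (stride, padding, dilation) convention, so the expected output extent is the standard convolution size formula; a reference helper, assuming ggml's usual conv arithmetic (the kernel itself takes out_w/out_h directly from dst->ne).
static int conv2d_dw_out_size_sketch(int in_size, int kernel, int stride, int padding, int dilation) {
    return (in_size + 2 * padding - dilation * (kernel - 1) - 1) / stride + 1;
}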

View File

@ -0,0 +1,5 @@
#pragma once
#include "common.cuh"
#define CUDA_CONV2D_DW_BLOCK_SIZE 256
void ggml_cuda_op_conv2d_dw(ggml_backend_cuda_context & ctx, ggml_tensor * dst);

View File

@ -0,0 +1,91 @@
#include <algorithm>
#include "conv2d-transpose.cuh"
#include "ggml.h"
__global__ void conv2d_transpose_kernel(const float * __restrict__ input, const half * __restrict__ kernel,
float * __restrict__ output, const int in_w, const int in_h, const int out_w,
const int out_h, const int kernel_w, const int kernel_h, const int stride,
const int c_in, const int c_out, const int batches) {
const int global_idx = blockIdx.x * blockDim.x + threadIdx.x;
const int total_elements = out_w * out_h * c_out * batches;
if (global_idx >= total_elements) {
return;
}
const int out_x_idx = global_idx % out_w;
const int out_y_idx = (global_idx / out_w) % out_h;
const int c_idx = (global_idx / (out_w * out_h)) % c_out;
const int n_idx = global_idx / (out_w * out_h * c_out);
float accumulator = 0;
// For each output idx, find the inputs that contribute to it by checking stride alignment and bounds
for (int c_in_idx = 0; c_in_idx < c_in; c_in_idx++) {
for (int kh = 0; kh < kernel_h; ++kh) {
int in_y = out_y_idx - kh;
if (in_y < 0 || in_y % stride) continue;
in_y /= stride;
if (in_y >= in_h) continue;
for (int kw = 0; kw < kernel_w; ++kw) {
int in_x = out_x_idx - kw;
if (in_x < 0 || in_x % stride) continue;
in_x /= stride;
if (in_x >= in_w) continue;
const int input_idx = (in_w * in_h * c_in) * n_idx + (in_w * in_h) * c_in_idx + (in_w) *in_y + in_x;
const int kernel_idx =
(kernel_h * kernel_w * c_out) * c_in_idx + (kernel_h * kernel_w) * c_idx + (kernel_w) *kh + kw;
float input_val = input[input_idx];
half kern_val = kernel[kernel_idx];
accumulator += input_val * (float) kern_val;
}
}
}
output[(out_w * out_h * c_out) * n_idx + (out_w * out_h) * c_idx + (out_w) *out_y_idx + out_x_idx] = accumulator;
}
// input is (W, H, C_in, N), kernel is (W, H, C_out, C_in)
void ggml_cuda_conv_2d_transpose_p0(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
const ggml_tensor * kernel = dst->src[0];
const ggml_tensor * input = dst->src[1];
GGML_ASSERT(kernel->type == GGML_TYPE_F16 && input->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32);
const float * input_data = (const float *) input->data;
float * output_data = (float *) dst->data;
const half * kernel_data = (const half *) kernel->data;
const int input_w = input->ne[0];
const int input_h = input->ne[1];
const int output_w = dst->ne[0];
const int output_h = dst->ne[1];
const int channels_in = input->ne[2];
const int channels_out = kernel->ne[2];
const int kernel_w = kernel->ne[0];
const int kernel_h = kernel->ne[1];
const int stride = dst->op_params[0];
const int batches = input->ne[3];
GGML_ASSERT(channels_in == kernel->ne[3]);
GGML_ASSERT(stride > 0);
cudaStream_t st = ctx.stream();
GGML_ASSERT(ggml_is_contiguous(input));
GGML_ASSERT(ggml_is_contiguous(kernel));
GGML_ASSERT(ggml_is_contiguous(dst));
const int total = (output_w * output_h * channels_out * batches);
const int blocks = (total + CUDA_CONV2D_TRANSPOSE_BLOCK_SIZE - 1) / CUDA_CONV2D_TRANSPOSE_BLOCK_SIZE;
conv2d_transpose_kernel<<<blocks, CUDA_CONV2D_TRANSPOSE_BLOCK_SIZE, 0, st>>>(
input_data, kernel_data, output_data, input_w, input_h, output_w, output_h, kernel_w, kernel_h, stride,
channels_in, channels_out, batches);
}
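For illustration only — a plain CPU reference of the gather rule used by the kernel above (an output pixel accumulates input (in_y, in_x) only when out - k is non-negative, divisible by the stride, and in range), using the same (W, H, C_in, N) input and (W, H, C_out, C_in) kernel layouts; kernel weights are kept in F32 here for simplicity and the helper name is a stand-in.
#include <vector>

static float conv2d_transpose_ref_sketch(const std::vector<float> & input, const std::vector<float> & kernel,
                                         int in_w, int in_h, int c_in, int kernel_w, int kernel_h, int c_out,
                                         int stride, int n_idx, int c_idx, int out_y, int out_x) {
    float acc = 0.0f;
    for (int ci = 0; ci < c_in; ++ci) {
        for (int kh = 0; kh < kernel_h; ++kh) {
            const int ty = out_y - kh;
            if (ty < 0 || ty % stride != 0 || ty / stride >= in_h) continue;
            for (int kw = 0; kw < kernel_w; ++kw) {
                const int tx = out_x - kw;
                if (tx < 0 || tx % stride != 0 || tx / stride >= in_w) continue;
                const int in_y = ty / stride;
                const int in_x = tx / stride;
                acc += input[(in_w * in_h * c_in) * n_idx + (in_w * in_h) * ci + in_w * in_y + in_x]
                     * kernel[(kernel_h * kernel_w * c_out) * ci + (kernel_h * kernel_w) * c_idx + kernel_w * kh + kw];
            }
        }
    }
    return acc;
}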

View File

@ -0,0 +1,4 @@
#include "common.cuh"
#define CUDA_CONV2D_TRANSPOSE_BLOCK_SIZE 256
void ggml_cuda_conv_2d_transpose_p0(ggml_backend_cuda_context & ctx, ggml_tensor * dst);

View File

@ -11,6 +11,8 @@
#include "ggml-cuda/clamp.cuh" #include "ggml-cuda/clamp.cuh"
#include "ggml-cuda/concat.cuh" #include "ggml-cuda/concat.cuh"
#include "ggml-cuda/conv-transpose-1d.cuh" #include "ggml-cuda/conv-transpose-1d.cuh"
#include "ggml-cuda/conv2d-dw.cuh"
#include "ggml-cuda/conv2d-transpose.cuh"
#include "ggml-cuda/convert.cuh" #include "ggml-cuda/convert.cuh"
#include "ggml-cuda/count-equal.cuh" #include "ggml-cuda/count-equal.cuh"
#include "ggml-cuda/cpy.cuh" #include "ggml-cuda/cpy.cuh"
@ -35,6 +37,7 @@
#include "ggml-cuda/ssm-scan.cuh" #include "ggml-cuda/ssm-scan.cuh"
#include "ggml-cuda/sum.cuh" #include "ggml-cuda/sum.cuh"
#include "ggml-cuda/sumrows.cuh" #include "ggml-cuda/sumrows.cuh"
#include "ggml-cuda/mean.cuh"
#include "ggml-cuda/tsembd.cuh" #include "ggml-cuda/tsembd.cuh"
#include "ggml-cuda/unary.cuh" #include "ggml-cuda/unary.cuh"
#include "ggml-cuda/upscale.cuh" #include "ggml-cuda/upscale.cuh"
@ -47,6 +50,7 @@
#include <atomic> #include <atomic>
#include <charconv> #include <charconv>
#include <cinttypes> #include <cinttypes>
#include <condition_variable>
#include <cstddef> #include <cstddef>
#include <cstdint> #include <cstdint>
#include <float.h> #include <float.h>
@ -54,9 +58,8 @@
#include <map> #include <map>
#include <memory> #include <memory>
#include <mutex> #include <mutex>
#include <stdint.h>
#include <stdio.h>
#include <stdarg.h> #include <stdarg.h>
#include <stdio.h>
#include <stdlib.h> #include <stdlib.h>
#include <string> #include <string>
#include <vector> #include <vector>
@ -97,8 +100,7 @@ int ggml_cuda_get_device() {
static cudaError_t ggml_cuda_device_malloc(void ** ptr, size_t size, int device) { static cudaError_t ggml_cuda_device_malloc(void ** ptr, size_t size, int device) {
ggml_cuda_set_device(device); ggml_cuda_set_device(device);
cudaError_t err; cudaError_t err;
if (getenv("GGML_CUDA_ENABLE_UNIFIED_MEMORY") != nullptr) if (getenv("GGML_CUDA_ENABLE_UNIFIED_MEMORY") != nullptr) {
{
err = cudaMallocManaged(ptr, size); err = cudaMallocManaged(ptr, size);
#if defined(GGML_USE_HIP) #if defined(GGML_USE_HIP)
if (err == hipSuccess) { if (err == hipSuccess) {
@ -116,9 +118,7 @@ static cudaError_t ggml_cuda_device_malloc(void ** ptr, size_t size, int device)
err = cudaMalloc(ptr, size); err = cudaMalloc(ptr, size);
} }
#endif // defined(GGML_USE_HIP) #endif // defined(GGML_USE_HIP)
} } else {
else
{
err = cudaMalloc(ptr, size); err = cudaMalloc(ptr, size);
} }
return err; return err;
@ -514,6 +514,33 @@ std::unique_ptr<ggml_cuda_pool> ggml_backend_cuda_context::new_pool_for_device(i
return std::unique_ptr<ggml_cuda_pool>(new ggml_cuda_pool_leg(device)); return std::unique_ptr<ggml_cuda_pool>(new ggml_cuda_pool_leg(device));
} }
// destroying a cuBLAS handle while a graph is being captured in a different thread can result in a CUDA error
// this lock is used to ensure that no cuBLAS handle is destroyed while a graph is being captured
static std::mutex ggml_cuda_lock;
static std::condition_variable ggml_cuda_lock_cv;
static std::atomic<int> ggml_cuda_lock_counter;
ggml_backend_cuda_context::~ggml_backend_cuda_context() {
std::unique_lock<std::mutex> lock(ggml_cuda_lock);
ggml_cuda_lock_cv.wait(lock, []{ return ggml_cuda_lock_counter.load(std::memory_order_relaxed) == 0; });
if (copy_event != nullptr) {
CUDA_CHECK(cudaEventDestroy(copy_event));
}
for (int i = 0; i < GGML_CUDA_MAX_DEVICES; ++i) {
for (int j = 0; j < GGML_CUDA_MAX_STREAMS; ++j) {
if (streams[i][j] != nullptr) {
CUDA_CHECK(cudaStreamDestroy(streams[i][j]));
}
}
if (cublas_handles[i] != nullptr) {
CUBLAS_CHECK(cublasDestroy(cublas_handles[i]));
}
}
}
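For illustration only — the capture/teardown protocol introduced here, reduced to a standalone guard; the names are generic stand-ins for the ggml_cuda_lock globals.
#include <atomic>
#include <condition_variable>
#include <mutex>

// Capture threads call begin() before cudaStreamBeginCapture and end() after
// cudaStreamEndCapture; a context destructor calls wait_until_idle() so no
// cuBLAS handle or stream is destroyed while a graph capture is in flight.
struct capture_guard_sketch {
    std::mutex              m;
    std::condition_variable cv;
    std::atomic<int>        active{0};

    void begin() {
        std::lock_guard<std::mutex> lock(m);
        active.fetch_add(1, std::memory_order_relaxed);
    }
    void end() {
        std::lock_guard<std::mutex> lock(m);
        if (active.fetch_sub(1, std::memory_order_relaxed) == 1) {
            cv.notify_all();
        }
    }
    void wait_until_idle() {
        std::unique_lock<std::mutex> lock(m);
        cv.wait(lock, [this] { return active.load(std::memory_order_relaxed) == 0; });
    }
};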
// cuda buffer // cuda buffer
struct ggml_backend_cuda_buffer_context { struct ggml_backend_cuda_buffer_context {
@ -1916,8 +1943,7 @@ static void ggml_cuda_mul_mat(ggml_backend_cuda_context & ctx, const ggml_tensor
&& ggml_nbytes(src0) != ggml_backend_buffer_get_alloc_size(src0->buffer, src0) && src0->view_src; && ggml_nbytes(src0) != ggml_backend_buffer_get_alloc_size(src0->buffer, src0) && src0->view_src;
bool use_mul_mat_vec = (src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16 || src0->type == GGML_TYPE_BF16) bool use_mul_mat_vec = (src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16 || src0->type == GGML_TYPE_BF16)
&& src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32;
&& src0->ne[0] % 2 == 0 && src1->ne[1] == 1;
bool use_mul_mat_vec_q = ggml_is_quantized(src0->type) && !bad_padding_clear bool use_mul_mat_vec_q = ggml_is_quantized(src0->type) && !bad_padding_clear
&& src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32
&& src1->ne[1] <= MMVQ_MAX_BATCH_SIZE; && src1->ne[1] <= MMVQ_MAX_BATCH_SIZE;
@ -1925,7 +1951,6 @@ static void ggml_cuda_mul_mat(ggml_backend_cuda_context & ctx, const ggml_tensor
&& src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32; && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32;
bool any_gpus_with_slow_fp16 = false; bool any_gpus_with_slow_fp16 = false;
bool any_gpus_without_fp16_mma = false;
if (split) { if (split) {
ggml_backend_cuda_split_buffer_type_context * buft_ctx = (ggml_backend_cuda_split_buffer_type_context *) src0->buffer->buft->context; ggml_backend_cuda_split_buffer_type_context * buft_ctx = (ggml_backend_cuda_split_buffer_type_context *) src0->buffer->buft->context;
@ -1938,14 +1963,14 @@ static void ggml_cuda_mul_mat(ggml_backend_cuda_context & ctx, const ggml_tensor
const int cc = ggml_cuda_info().devices[id].cc; const int cc = ggml_cuda_info().devices[id].cc;
use_mul_mat_q = use_mul_mat_q && ggml_cuda_should_use_mmq(src0->type, cc, src1->ne[1]); use_mul_mat_q = use_mul_mat_q && ggml_cuda_should_use_mmq(src0->type, cc, src1->ne[1]);
use_mul_mat_vec = use_mul_mat_vec && ggml_cuda_should_use_mmv(src0->type, cc, src0->ne, src1->ne[1]);
any_gpus_with_slow_fp16 = any_gpus_with_slow_fp16 || !fast_fp16_hardware_available(cc); any_gpus_with_slow_fp16 = any_gpus_with_slow_fp16 || !fast_fp16_hardware_available(cc);
any_gpus_without_fp16_mma = any_gpus_without_fp16_mma || !fp16_mma_hardware_available(cc);
} }
} else { } else {
const int cc = ggml_cuda_info().devices[ctx.device].cc; const int cc = ggml_cuda_info().devices[ctx.device].cc;
use_mul_mat_q = use_mul_mat_q && ggml_cuda_should_use_mmq(src0->type, cc, src1->ne[1]); use_mul_mat_q = use_mul_mat_q && ggml_cuda_should_use_mmq(src0->type, cc, src1->ne[1]);
use_mul_mat_vec = use_mul_mat_vec && ggml_cuda_should_use_mmv(src0->type, cc, src0->ne, src1->ne[1]);
any_gpus_with_slow_fp16 = any_gpus_with_slow_fp16 || !fast_fp16_hardware_available(cc); any_gpus_with_slow_fp16 = any_gpus_with_slow_fp16 || !fast_fp16_hardware_available(cc);
any_gpus_without_fp16_mma = any_gpus_without_fp16_mma || !fp16_mma_hardware_available(cc);
} }
// debug helpers // debug helpers
@ -1956,7 +1981,7 @@ static void ggml_cuda_mul_mat(ggml_backend_cuda_context & ctx, const ggml_tensor
//printf("src0 is contiguous %d, transposed %d, type = %s, name = %s\n", ggml_is_contiguous(src0), ggml_is_transposed(src0), ggml_type_name(src0->type), src0->name); //printf("src0 is contiguous %d, transposed %d, type = %s, name = %s\n", ggml_is_contiguous(src0), ggml_is_transposed(src0), ggml_type_name(src0->type), src0->name);
//printf("src1 is contiguous %d, transposed %d, type = %s, name = %s\n", ggml_is_contiguous(src1), ggml_is_transposed(src1), ggml_type_name(src1->type), src1->name); //printf("src1 is contiguous %d, transposed %d, type = %s, name = %s\n", ggml_is_contiguous(src1), ggml_is_transposed(src1), ggml_type_name(src1->type), src1->name);
if (!split && use_mul_mat_vec && (src0->ne[1] <= MMV_MAX_ROWS || any_gpus_without_fp16_mma)) { if (!split && use_mul_mat_vec) {
// the custom F16 vector kernel can be used over batched cuBLAS GEMM // the custom F16 vector kernel can be used over batched cuBLAS GEMM
// but this is only faster for GPUs without tensor cores or with a thin src0 matrix (particularly KQV in attention) // but this is only faster for GPUs without tensor cores or with a thin src0 matrix (particularly KQV in attention)
ggml_cuda_mul_mat_vec(ctx, src0, src1, nullptr, dst); ggml_cuda_mul_mat_vec(ctx, src0, src1, nullptr, dst);
@ -2310,6 +2335,12 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
case GGML_OP_IM2COL: case GGML_OP_IM2COL:
ggml_cuda_op_im2col(ctx, dst); ggml_cuda_op_im2col(ctx, dst);
break; break;
case GGML_OP_CONV_2D_DW:
ggml_cuda_op_conv2d_dw(ctx, dst);
break;
case GGML_OP_CONV_TRANSPOSE_2D:
ggml_cuda_conv_2d_transpose_p0(ctx, dst);
break;
case GGML_OP_CONV_TRANSPOSE_1D: case GGML_OP_CONV_TRANSPOSE_1D:
ggml_cuda_op_conv_transpose_1d(ctx,dst); ggml_cuda_op_conv_transpose_1d(ctx,dst);
break; break;
@ -2322,6 +2353,9 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
case GGML_OP_SUM_ROWS: case GGML_OP_SUM_ROWS:
ggml_cuda_op_sum_rows(ctx, dst); ggml_cuda_op_sum_rows(ctx, dst);
break; break;
case GGML_OP_MEAN:
ggml_cuda_op_mean(ctx, dst);
break;
case GGML_OP_SSM_CONV: case GGML_OP_SSM_CONV:
ggml_cuda_op_ssm_conv(ctx, dst); ggml_cuda_op_ssm_conv(ctx, dst);
break; break;
@ -2685,6 +2719,11 @@ static void evaluate_and_capture_cuda_graph(ggml_backend_cuda_context * cuda_ctx
CUDA_CHECK(cudaStreamEndCapture(cuda_ctx->stream(), &cuda_ctx->cuda_graph->graph)); CUDA_CHECK(cudaStreamEndCapture(cuda_ctx->stream(), &cuda_ctx->cuda_graph->graph));
graph_evaluated_or_captured = true; // CUDA graph has been captured graph_evaluated_or_captured = true; // CUDA graph has been captured
std::lock_guard<std::mutex> lock(ggml_cuda_lock);
if (ggml_cuda_lock_counter.fetch_sub(1, std::memory_order_relaxed) == 1) {
ggml_cuda_lock_cv.notify_all();
}
} else { } else {
graph_evaluated_or_captured = true; // ggml graph has been directly evaluated graph_evaluated_or_captured = true; // ggml graph has been directly evaluated
} }
@ -2760,7 +2799,13 @@ static enum ggml_status ggml_backend_cuda_graph_compute(ggml_backend_t backend,
} }
} }
if (use_cuda_graph && cuda_graph_update_required) { // Start CUDA graph capture if (use_cuda_graph && cuda_graph_update_required) {
// Start CUDA graph capture
{
std::lock_guard<std::mutex> lock(ggml_cuda_lock);
ggml_cuda_lock_counter.fetch_add(1, std::memory_order_relaxed);
}
CUDA_CHECK(cudaStreamBeginCapture(cuda_ctx->stream(), cudaStreamCaptureModeRelaxed)); CUDA_CHECK(cudaStreamBeginCapture(cuda_ctx->stream(), cudaStreamCaptureModeRelaxed));
} }
@ -3220,9 +3265,12 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
return op->src[0]->nb[0] == ggml_type_size(op->src[0]->type) && ggml_is_contiguous_2(op->src[0]); return op->src[0]->nb[0] == ggml_type_size(op->src[0]->type) && ggml_is_contiguous_2(op->src[0]);
} }
case GGML_OP_IM2COL: case GGML_OP_IM2COL:
case GGML_OP_CONV_2D_DW:
case GGML_OP_CONV_TRANSPOSE_2D:
case GGML_OP_POOL_2D: case GGML_OP_POOL_2D:
case GGML_OP_SUM: case GGML_OP_SUM:
case GGML_OP_SUM_ROWS: case GGML_OP_SUM_ROWS:
case GGML_OP_MEAN:
case GGML_OP_ARGSORT: case GGML_OP_ARGSORT:
case GGML_OP_ACC: case GGML_OP_ACC:
return true; return true;

View File

@ -0,0 +1,19 @@
#include "mean.cuh"
void ggml_cuda_op_mean(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
const ggml_tensor * src0 = dst->src[0];
const float * src0_d = (const float *) src0->data;
float * dst_d = (float *) dst->data;
cudaStream_t stream = ctx.stream();
GGML_ASSERT(src0->type == GGML_TYPE_F32);
GGML_ASSERT(dst->type == GGML_TYPE_F32);
GGML_ASSERT(ggml_is_contiguous(src0));
const int64_t ncols = src0->ne[0];
const int64_t nrows = ggml_nrows(src0);
const dim3 block_dims(WARP_SIZE, 1, 1);
const dim3 block_nums(nrows, 1, 1);
reduce_rows_f32</*norm*/ true><<<block_nums, block_dims, 0, stream>>>(src0_d, dst_d, ncols);
}
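For illustration only — the semantics of the new GGML_OP_MEAN path: each row of src0 reduces to its arithmetic mean, i.e. the SUM_ROWS result divided by ne[0]; a trivial CPU reference with a stand-in name.
#include <cstdint>

static void mean_rows_ref_sketch(const float * x, float * dst, int64_t ncols, int64_t nrows) {
    for (int64_t r = 0; r < nrows; ++r) {
        float sum = 0.0f;
        for (int64_t c = 0; c < ncols; ++c) {
            sum += x[r * ncols + c];
        }
        dst[r] = sum / ncols;
    }
}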

View File

@ -0,0 +1,3 @@
#include "common.cuh"
void ggml_cuda_op_mean(ggml_backend_cuda_context & ctx, ggml_tensor * dst);

View File

@ -2,25 +2,26 @@
#include "common.cuh" #include "common.cuh"
#include "mmv.cuh" #include "mmv.cuh"
template <typename T, typename type_acc, int block_size> template <typename T, typename type_acc, int ncols_dst, int block_size>
static __global__ void mul_mat_vec( static __global__ void mul_mat_vec(
const T * __restrict__ x, const float * __restrict__ y, const int32_t * __restrict__ ids, float * __restrict__ dst, const T * __restrict__ x, const float * __restrict__ y, const int32_t * __restrict__ ids, float * __restrict__ dst,
const int64_t ncols2, const int64_t nchannels_y, const int64_t stride_row, const int ncols2, const int nchannels_y, const int stride_row, const int stride_col_y2, const int stride_col_dst,
const int64_t channel_ratio, const int64_t stride_channel_x, const int64_t stride_channel_y, const int64_t stride_channel_dst, const int channel_ratio, const int stride_channel_x, const int stride_channel_y, const int stride_channel_dst,
const int64_t sample_ratio, const int64_t stride_sample_x, const int64_t stride_sample_y, const int64_t stride_sample_dst) { const int sample_ratio, const int stride_sample_x, const int stride_sample_y, const int stride_sample_dst) {
const int64_t row = blockIdx.x; const int row = blockIdx.x;
const int64_t channel_dst = blockIdx.y; const int channel_dst = blockIdx.y;
const int64_t channel_x = ids ? ids[channel_dst] : channel_dst / channel_ratio; const int channel_x = ids ? ids[channel_dst] : channel_dst / channel_ratio;
const int64_t channel_y = ids ? channel_dst % nchannels_y : channel_dst; const int channel_y = ids ? channel_dst % nchannels_y : channel_dst;
const int64_t sample_dst = blockIdx.z; const int sample_dst = blockIdx.z;
const int64_t sample_x = sample_dst / sample_ratio; const int sample_x = sample_dst / sample_ratio;
const int64_t sample_y = sample_dst; const int sample_y = sample_dst;
const int tid = threadIdx.x; const int tid = threadIdx.x;
constexpr int warp_size = ggml_cuda_get_physical_warp_size(); constexpr int warp_size = ggml_cuda_get_physical_warp_size();
x += sample_x *stride_sample_x + channel_x *stride_channel_x + row*stride_row; x += int64_t(sample_x) *stride_sample_x + channel_x *stride_channel_x + row*stride_row;
y += sample_y *stride_sample_y + channel_y *stride_channel_y; y += int64_t(sample_y) *stride_sample_y + channel_y *stride_channel_y;
dst += sample_dst*stride_sample_dst + channel_dst*stride_channel_dst; dst += int64_t(sample_dst)*stride_sample_dst + channel_dst*stride_channel_dst;
const float2 * y2 = (const float2 *) y; const float2 * y2 = (const float2 *) y;
@ -34,81 +35,108 @@ static __global__ void mul_mat_vec(
__syncthreads(); __syncthreads();
} }
float sumf = 0.0f; float sumf[ncols_dst] = {0.0f};
if constexpr (std::is_same<T, float>::value) { if constexpr (std::is_same<T, float>::value) {
const float2 * x2 = (const float2 *) x; const float2 * x2 = (const float2 *) x;
for (int64_t col2 = tid; col2 < ncols2; col2 += block_size) { for (int col2 = tid; col2 < ncols2; col2 += block_size) {
const float2 tmpx = x2[col2]; const float2 tmpx = x2[col2];
const float2 tmpy = y2[col2];
sumf += tmpx.x*tmpy.x; #pragma unroll
sumf += tmpx.y*tmpy.y; for (int j = 0; j < ncols_dst; ++j) {
const float2 tmpy = y2[j*stride_col_y2 + col2];
sumf[j] += tmpx.x*tmpy.x;
sumf[j] += tmpx.y*tmpy.y;
}
} }
} else if constexpr (std::is_same<T, half>::value) { } else if constexpr (std::is_same<T, half>::value) {
const half2 * x2 = (const half2 *) x; const half2 * x2 = (const half2 *) x;
if (std::is_same<type_acc, float>::value) { if (std::is_same<type_acc, float>::value) {
for (int64_t col2 = tid; col2 < ncols2; col2 += block_size) { for (int col2 = tid; col2 < ncols2; col2 += block_size) {
const float2 tmpx = __half22float2(x2[col2]); const float2 tmpx = __half22float2(x2[col2]);
const float2 tmpy = y2[col2];
sumf += tmpx.x * tmpy.x; #pragma unroll
sumf += tmpx.y * tmpy.y; for (int j = 0; j < ncols_dst; ++j) {
const float2 tmpy = y2[j*stride_col_y2 + col2];
sumf[j] += tmpx.x * tmpy.x;
sumf[j] += tmpx.y * tmpy.y;
}
} }
} else { } else {
#ifdef FP16_AVAILABLE #ifdef FP16_AVAILABLE
half2 sumh2 = make_half2(0.0f, 0.0f); half2 sumh2[ncols_dst] = {{0.0f, 0.0f}};
for (int64_t col2 = tid; col2 < ncols2; col2 += block_size) { for (int col2 = tid; col2 < ncols2; col2 += block_size) {
const float2 tmp = y2[col2]; const half2 tmpx = x2[col2];
sumh2 += x2[col2] * make_half2(tmp.x, tmp.y);
#pragma unroll
for (int j = 0; j < ncols_dst; ++j) {
const float2 tmpy = y2[j*stride_col_y2 + col2];
sumh2[j] += tmpx * make_half2(tmpy.x, tmpy.y);
}
} }
sumf = __low2float(sumh2) + __high2float(sumh2); #pragma unroll
for (int j = 0; j < ncols_dst; ++j) {
sumf[j] = __low2float(sumh2[j]) + __high2float(sumh2[j]);
}
#else #else
NO_DEVICE_CODE; NO_DEVICE_CODE;
#endif // FP16_AVAILABLE #endif // FP16_AVAILABLE
} }
} else if constexpr (std::is_same<T, nv_bfloat16>::value) { } else if constexpr (std::is_same<T, nv_bfloat16>::value) {
const int * x2 = (const int *) x; const int * x2 = (const int *) x;
for (int64_t col2 = tid; col2 < ncols2; col2 += block_size) { for (int col2 = tid; col2 < ncols2; col2 += block_size) {
const int tmpx = x2[col2]; const int tmpx = x2[col2];
const float2 tmpy = y2[col2]; #pragma unroll
sumf += float(reinterpret_cast<const nv_bfloat16 *>(&tmpx)[0]) * tmpy.x; for (int j = 0; j < ncols_dst; ++j) {
sumf += float(reinterpret_cast<const nv_bfloat16 *>(&tmpx)[1]) * tmpy.y; const float2 tmpy = y2[j*stride_col_y2 + col2];
sumf[j] += float(reinterpret_cast<const nv_bfloat16 *>(&tmpx)[0]) * tmpy.x;
sumf[j] += float(reinterpret_cast<const nv_bfloat16 *>(&tmpx)[1]) * tmpy.y;
}
} }
} else { } else {
static_assert(std::is_same<T, void>::value, "unsupported type"); static_assert(std::is_same<T, void>::value, "unsupported type");
} }
sumf = warp_reduce_sum<warp_size>(sumf); #pragma unroll
for (int j = 0; j < ncols_dst; ++j) {
sumf[j] = warp_reduce_sum<warp_size>(sumf[j]);
if (block_size > warp_size) { if (block_size > warp_size) {
buf_iw[tid/warp_size] = sumf; buf_iw[tid/warp_size] = sumf[j];
__syncthreads(); __syncthreads();
if (tid >= warp_size) { if (tid < warp_size) {
return; sumf[j] = buf_iw[tid];
sumf[j] = warp_reduce_sum<warp_size>(sumf[j]);
}
if (j < ncols_dst) {
__syncthreads();
}
} }
sumf = buf_iw[tid];
sumf = warp_reduce_sum<warp_size>(sumf);
} }
if (tid != 0) { if (tid >= ncols_dst) {
return; return;
} }
dst[row] = sumf; dst[tid*stride_col_dst + row] = sumf[tid];
} }
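For illustration only — what the widened kernel now computes, as a plain CPU reference: ncols_dst dot products per row of x, with strides given in float elements; the helper name is a stand-in.
#include <cstdint>

static void mul_mat_vec_ref_sketch(const float * x, const float * y, float * dst,
                                   int64_t ncols, int64_t nrows, int64_t ncols_dst,
                                   int64_t stride_row, int64_t stride_col_y, int64_t stride_col_dst) {
    for (int64_t row = 0; row < nrows; ++row) {
        for (int64_t j = 0; j < ncols_dst; ++j) {
            float sum = 0.0f;
            for (int64_t col = 0; col < ncols; ++col) {
                sum += x[row * stride_row + col] * y[j * stride_col_y + col];
            }
            dst[j * stride_col_dst + row] = sum;
        }
    }
}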
template <typename T, typename type_acc> template <typename T, typename type_acc, int ncols_dst>
static void launch_mul_mat_vec_cuda( static void launch_mul_mat_vec_cuda(
const T * x, const float * y, const int32_t * ids, float * dst, const T * x, const float * y, const int32_t * ids, float * dst,
const int64_t ncols, const int64_t nrows, const int64_t stride_row, const int64_t nchannels_x, const int64_t nchannels_y, const int64_t nchannels_dst, const int64_t ncols, const int64_t nrows,
const int64_t stride_row, const int64_t stride_col_y, const int64_t stride_col_dst,
const int64_t nchannels_x, const int64_t nchannels_y, const int64_t nchannels_dst,
const int64_t stride_channel_x, const int64_t stride_channel_y, const int64_t stride_channel_dst, const int64_t nsamples_x, const int64_t stride_channel_x, const int64_t stride_channel_y, const int64_t stride_channel_dst, const int64_t nsamples_x,
const int64_t nsamples_dst, const int64_t stride_sample_x, const int64_t stride_sample_y, const int64_t stride_sample_dst, const int64_t nsamples_dst, const int64_t stride_sample_x, const int64_t stride_sample_y, const int64_t stride_sample_dst,
cudaStream_t stream) { cudaStream_t stream) {
GGML_ASSERT(ncols % 2 == 0); GGML_ASSERT(ncols % 2 == 0);
GGML_ASSERT(stride_row % 2 == 0); GGML_ASSERT(stride_row % 2 == 0);
GGML_ASSERT(stride_col_y % 2 == 0);
GGML_ASSERT(ids || nchannels_dst % nchannels_x == 0); GGML_ASSERT(ids || nchannels_dst % nchannels_x == 0);
GGML_ASSERT( nsamples_dst % nsamples_x == 0); GGML_ASSERT( nsamples_dst % nsamples_x == 0);
const int64_t channel_ratio = nchannels_dst / nchannels_x; const int64_t channel_ratio = nchannels_dst / nchannels_x;
@ -138,44 +166,52 @@ static void launch_mul_mat_vec_cuda(
const dim3 block_dims(block_size_best, 1, 1); const dim3 block_dims(block_size_best, 1, 1);
switch (block_size_best) { switch (block_size_best) {
case 32: { case 32: {
mul_mat_vec<T, type_acc, 32><<<block_nums, block_dims, smem, stream>>> mul_mat_vec<T, type_acc, ncols_dst, 32><<<block_nums, block_dims, smem, stream>>>
(x, y, ids, dst, ncols/2, nchannels_y, stride_row, channel_ratio, stride_channel_x, stride_channel_y, (x, y, ids, dst, ncols/2, nchannels_y, stride_row, stride_col_y/2, stride_col_dst,
stride_channel_dst, sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst); channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
} break; } break;
case 64: { case 64: {
mul_mat_vec<T, type_acc, 64><<<block_nums, block_dims, smem, stream>>> mul_mat_vec<T, type_acc, ncols_dst, 64><<<block_nums, block_dims, smem, stream>>>
(x, y, ids, dst, ncols/2, nchannels_y, stride_row, channel_ratio, stride_channel_x, stride_channel_y, (x, y, ids, dst, ncols/2, nchannels_y, stride_row, stride_col_y/2, stride_col_dst,
stride_channel_dst, sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst); channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
} break; } break;
case 96: { case 96: {
mul_mat_vec<T, type_acc, 96><<<block_nums, block_dims, smem, stream>>> mul_mat_vec<T, type_acc, ncols_dst, 96><<<block_nums, block_dims, smem, stream>>>
(x, y, ids, dst, ncols/2, nchannels_y, stride_row, channel_ratio, stride_channel_x, stride_channel_y, (x, y, ids, dst, ncols/2, nchannels_y, stride_row, stride_col_y/2, stride_col_dst,
stride_channel_dst, sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst); channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
} break; } break;
case 128: { case 128: {
mul_mat_vec<T, type_acc, 128><<<block_nums, block_dims, smem, stream>>> mul_mat_vec<T, type_acc, ncols_dst, 128><<<block_nums, block_dims, smem, stream>>>
(x, y, ids, dst, ncols/2, nchannels_y, stride_row, channel_ratio, stride_channel_x, stride_channel_y, (x, y, ids, dst, ncols/2, nchannels_y, stride_row, stride_col_y/2, stride_col_dst,
stride_channel_dst, sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst); channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
} break; } break;
case 160: { case 160: {
mul_mat_vec<T, type_acc, 160><<<block_nums, block_dims, smem, stream>>> mul_mat_vec<T, type_acc, ncols_dst, 160><<<block_nums, block_dims, smem, stream>>>
(x, y, ids, dst, ncols/2, nchannels_y, stride_row, channel_ratio, stride_channel_x, stride_channel_y, (x, y, ids, dst, ncols/2, nchannels_y, stride_row, stride_col_y/2, stride_col_dst,
stride_channel_dst, sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst); channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
} break; } break;
case 192: { case 192: {
mul_mat_vec<T, type_acc, 192><<<block_nums, block_dims, smem, stream>>> mul_mat_vec<T, type_acc, ncols_dst, 192><<<block_nums, block_dims, smem, stream>>>
(x, y, ids, dst, ncols/2, nchannels_y, stride_row, channel_ratio, stride_channel_x, stride_channel_y, (x, y, ids, dst, ncols/2, nchannels_y, stride_row, stride_col_y/2, stride_col_dst,
stride_channel_dst, sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst); channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
} break; } break;
case 224: { case 224: {
mul_mat_vec<T, type_acc, 224><<<block_nums, block_dims, smem, stream>>> mul_mat_vec<T, type_acc, ncols_dst, 224><<<block_nums, block_dims, smem, stream>>>
(x, y, ids, dst, ncols/2, nchannels_y, stride_row, channel_ratio, stride_channel_x, stride_channel_y, (x, y, ids, dst, ncols/2, nchannels_y, stride_row, stride_col_y/2, stride_col_dst,
stride_channel_dst, sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst); channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
} break; } break;
case 256: { case 256: {
mul_mat_vec<T, type_acc, 256><<<block_nums, block_dims, smem, stream>>> mul_mat_vec<T, type_acc, ncols_dst, 256><<<block_nums, block_dims, smem, stream>>>
(x, y, ids, dst, ncols/2, nchannels_y, stride_row, channel_ratio, stride_channel_x, stride_channel_y, (x, y, ids, dst, ncols/2, nchannels_y, stride_row, stride_col_y/2, stride_col_dst,
stride_channel_dst, sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst); channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
} break; } break;
default: { default: {
GGML_ABORT("fatal error"); GGML_ABORT("fatal error");
@ -183,23 +219,91 @@ static void launch_mul_mat_vec_cuda(
} }
} }
template <typename T, typename type_acc>
static void mul_mat_vec_cuda_switch_ncols_dst(
const T * x, const float * y, const int32_t * ids, float * dst,
const int64_t ncols, const int64_t nrows, const int64_t ncols_dst,
const int64_t stride_row, const int64_t stride_col_y, const int64_t stride_col_dst,
const int64_t nchannels_x, const int64_t nchannels_y, const int64_t nchannels_dst,
const int64_t stride_channel_x, const int64_t stride_channel_y, const int64_t stride_channel_dst, const int64_t nsamples_x,
const int64_t nsamples_dst, const int64_t stride_sample_x, const int64_t stride_sample_y, const int64_t stride_sample_dst,
cudaStream_t stream) {
switch (ncols_dst) {
case 1:
launch_mul_mat_vec_cuda<T, type_acc, 1>
(x, y, ids, dst, ncols, nrows, stride_row, stride_col_y, stride_col_dst,
nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y,
stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
break;
case 2:
launch_mul_mat_vec_cuda<T, type_acc, 2>
(x, y, ids, dst, ncols, nrows, stride_row, stride_col_y, stride_col_dst,
nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y,
stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
break;
case 3:
launch_mul_mat_vec_cuda<T, type_acc, 3>
(x, y, ids, dst, ncols, nrows, stride_row, stride_col_y, stride_col_dst,
nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y,
stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
break;
case 4:
launch_mul_mat_vec_cuda<T, type_acc, 4>
(x, y, ids, dst, ncols, nrows, stride_row, stride_col_y, stride_col_dst,
nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y,
stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
break;
case 5:
launch_mul_mat_vec_cuda<T, type_acc, 5>
(x, y, ids, dst, ncols, nrows, stride_row, stride_col_y, stride_col_dst,
nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y,
stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
break;
case 6:
launch_mul_mat_vec_cuda<T, type_acc, 6>
(x, y, ids, dst, ncols, nrows, stride_row, stride_col_y, stride_col_dst,
nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y,
stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
break;
case 7:
launch_mul_mat_vec_cuda<T, type_acc, 7>
(x, y, ids, dst, ncols, nrows, stride_row, stride_col_y, stride_col_dst,
nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y,
stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
break;
case 8:
launch_mul_mat_vec_cuda<T, type_acc, 8>
(x, y, ids, dst, ncols, nrows, stride_row, stride_col_y, stride_col_dst,
nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y,
stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
break;
default:
GGML_ABORT("fatal error");
break;
}
}
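For illustration only — the switch above exists to turn the runtime batch width into a compile-time constant, so the sumf[ncols_dst] accumulators and the #pragma unroll loops in the kernel can be fully unrolled; a minimal sketch of the same dispatch pattern with stand-in names.
template <int ncols_dst>
static int kernel_stub_sketch() {
    static_assert(ncols_dst >= 1 && ncols_dst <= 8, "unsupported column count");
    return ncols_dst; // stands in for launching mul_mat_vec<..., ncols_dst, ...>
}

static int dispatch_ncols_dst_sketch(int ncols_dst) {
    switch (ncols_dst) {
        case 1: return kernel_stub_sketch<1>();
        case 2: return kernel_stub_sketch<2>();
        case 3: return kernel_stub_sketch<3>();
        case 4: return kernel_stub_sketch<4>();
        case 5: return kernel_stub_sketch<5>();
        case 6: return kernel_stub_sketch<6>();
        case 7: return kernel_stub_sketch<7>();
        case 8: return kernel_stub_sketch<8>();
        default: return -1; // the real code aborts via GGML_ABORT
    }
}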
template<typename T> template<typename T>
static void mul_mat_vec_cuda( static void mul_mat_vec_cuda(
const T * x, const float * y, const int32_t * ids, float * dst, const T * x, const float * y, const int32_t * ids, float * dst,
const int64_t ncols, const int64_t nrows, const int64_t stride_row, const int64_t nchannels_x, const int64_t nchannels_y, const int64_t nchannels_dst, const int64_t ncols, const int64_t nrows, const int64_t ncols_dst,
const int64_t stride_row, const int64_t stride_col_y, const int stride_col_dst,
const int64_t nchannels_x, const int64_t nchannels_y, const int64_t nchannels_dst,
const int64_t stride_channel_x, const int64_t stride_channel_y, const int64_t stride_channel_dst, const int64_t nsamples_x, const int64_t stride_channel_x, const int64_t stride_channel_y, const int64_t stride_channel_dst, const int64_t nsamples_x,
const int64_t nsamples_dst, const int64_t stride_sample_x, const int64_t stride_sample_y, const int64_t stride_sample_dst, const int64_t nsamples_dst, const int64_t stride_sample_x, const int64_t stride_sample_y, const int64_t stride_sample_dst,
enum ggml_prec prec, cudaStream_t stream) { enum ggml_prec prec, cudaStream_t stream) {
if constexpr(std::is_same<T, half>::value) { if constexpr(std::is_same<T, half>::value) {
if (prec == GGML_PREC_DEFAULT) { if (prec == GGML_PREC_DEFAULT) {
launch_mul_mat_vec_cuda<T, half> mul_mat_vec_cuda_switch_ncols_dst<T, half>
(x, y, ids, dst, ncols, nrows, stride_row, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, (x, y, ids, dst, ncols, nrows, ncols_dst, stride_row, stride_col_y, stride_col_dst,
nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y,
stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream); stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
return; return;
} }
} }
launch_mul_mat_vec_cuda<T, float> mul_mat_vec_cuda_switch_ncols_dst<T, float>
(x, y, ids, dst, ncols, nrows, stride_row, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, (x, y, ids, dst, ncols, nrows, ncols_dst, stride_row, stride_col_y, stride_col_dst,
nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y,
stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream); stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
} }
@ -246,24 +350,24 @@ void ggml_cuda_mul_mat_vec(ggml_backend_cuda_context & ctx, const ggml_tensor *
const int64_t stride_channel_dst = ids ? s1 : s2; const int64_t stride_channel_dst = ids ? s1 : s2;
const int64_t stride_channel_y = ids ? s11 : s12; const int64_t stride_channel_y = ids ? s11 : s12;
GGML_ASSERT(ncols_dst == 1); GGML_ASSERT(!ids || ncols_dst == 1);
switch (src0->type) { switch (src0->type) {
case GGML_TYPE_F32: { case GGML_TYPE_F32: {
const float * src0_d = (const float *) src0->data; const float * src0_d = (const float *) src0->data;
mul_mat_vec_cuda(src0_d, src1_d, ids_d, dst_d, ne00, ne01, s01, mul_mat_vec_cuda(src0_d, src1_d, ids_d, dst_d, ne00, ne01, ncols_dst, s01, s11, s1,
ne02, nchannels_y, nchannels_dst, s02, stride_channel_y, stride_channel_dst, ne02, nchannels_y, nchannels_dst, s02, stride_channel_y, stride_channel_dst,
ne03, ne3, s03, s13, s3, prec, ctx.stream()); ne03, ne3, s03, s13, s3, prec, ctx.stream());
} break; } break;
case GGML_TYPE_F16: { case GGML_TYPE_F16: {
const half * src0_d = (const half *) src0->data; const half * src0_d = (const half *) src0->data;
mul_mat_vec_cuda(src0_d, src1_d, ids_d, dst_d, ne00, ne01, s01, mul_mat_vec_cuda(src0_d, src1_d, ids_d, dst_d, ne00, ne01, ncols_dst, s01, s11, s1,
ne02, nchannels_y, nchannels_dst, s02, stride_channel_y, stride_channel_dst, ne02, nchannels_y, nchannels_dst, s02, stride_channel_y, stride_channel_dst,
ne03, ne3, s03, s13, s3, prec, ctx.stream()); ne03, ne3, s03, s13, s3, prec, ctx.stream());
} break; } break;
case GGML_TYPE_BF16: { case GGML_TYPE_BF16: {
const nv_bfloat16 * src0_d = (const nv_bfloat16 *) src0->data; const nv_bfloat16 * src0_d = (const nv_bfloat16 *) src0->data;
mul_mat_vec_cuda(src0_d, src1_d, ids_d, dst_d, ne00, ne01, s01, mul_mat_vec_cuda(src0_d, src1_d, ids_d, dst_d, ne00, ne01, ncols_dst, s01, s11, s1,
ne02, nchannels_y, nchannels_dst, s02, stride_channel_y, stride_channel_dst, ne02, nchannels_y, nchannels_dst, s02, stride_channel_y, stride_channel_dst,
ne03, ne3, s03, s13, s3, prec, ctx.stream()); ne03, ne3, s03, s13, s3, prec, ctx.stream());
} break; } break;
@ -282,16 +386,19 @@ void ggml_cuda_op_mul_mat_vec(
GGML_ASSERT(dst->type == GGML_TYPE_F32); GGML_ASSERT(dst->type == GGML_TYPE_F32);
const int64_t ne00 = src0->ne[0]; const int64_t ne00 = src0->ne[0];
const int64_t ne10 = src1->ne[0];
const int64_t ne0 = dst->ne[0];
const int64_t row_diff = row_high - row_low; const int64_t row_diff = row_high - row_low;
GGML_ASSERT(src1_ncols == 1); const int id = ggml_cuda_get_device();
const int cc = ggml_cuda_info().devices[id].cc;
const int cc = ggml_cuda_info().devices[ggml_cuda_get_device()].cc;
const enum ggml_prec prec = fast_fp16_available(cc) ? ggml_prec(dst->op_params[0]) : GGML_PREC_F32; const enum ggml_prec prec = fast_fp16_available(cc) ? ggml_prec(dst->op_params[0]) : GGML_PREC_F32;
// ggml_cuda_op provides single, contiguous matrices // ggml_cuda_op provides single, contiguous matrices
const int64_t stride_row = ne00; const int64_t stride_row = ne00;
const int64_t stride_col_y = ne10;
const int64_t stride_col_dst = id == ctx.device ? ne0 : row_diff; // main device has larger memory buffer
const int64_t nchannels_x = 1; const int64_t nchannels_x = 1;
const int64_t nchannels_y = 1; const int64_t nchannels_y = 1;
const int64_t nchannels_dst = 1; const int64_t nchannels_dst = 1;
@ -307,19 +414,19 @@ void ggml_cuda_op_mul_mat_vec(
switch (src0->type) { switch (src0->type) {
case GGML_TYPE_F32: { case GGML_TYPE_F32: {
const float * src0_d = (const float *) src0_dd_i; const float * src0_d = (const float *) src0_dd_i;
mul_mat_vec_cuda(src0_d, src1_ddf_i, nullptr, dst_dd_i, ne00, row_diff, stride_row, mul_mat_vec_cuda(src0_d, src1_ddf_i, nullptr, dst_dd_i, ne00, row_diff, src1_ncols, stride_row, stride_col_y, stride_col_dst,
nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, prec, stream); nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, prec, stream);
} break; } break;
case GGML_TYPE_F16: { case GGML_TYPE_F16: {
const half * src0_d = (const half *) src0_dd_i; const half * src0_d = (const half *) src0_dd_i;
mul_mat_vec_cuda(src0_d, src1_ddf_i, nullptr, dst_dd_i, ne00, row_diff, stride_row, mul_mat_vec_cuda(src0_d, src1_ddf_i, nullptr, dst_dd_i, ne00, row_diff, src1_ncols, stride_row, stride_col_y, stride_col_dst,
nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, prec, stream); nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, prec, stream);
} break; } break;
case GGML_TYPE_BF16: { case GGML_TYPE_BF16: {
const nv_bfloat16 * src0_d = (const nv_bfloat16 *) src0_dd_i; const nv_bfloat16 * src0_d = (const nv_bfloat16 *) src0_dd_i;
mul_mat_vec_cuda(src0_d, src1_ddf_i, nullptr, dst_dd_i, ne00, row_diff, stride_row, mul_mat_vec_cuda(src0_d, src1_ddf_i, nullptr, dst_dd_i, ne00, row_diff, src1_ncols, stride_row, stride_col_y, stride_col_dst,
nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, prec, stream); nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, prec, stream);
} break; } break;
@ -334,3 +441,48 @@ void ggml_cuda_op_mul_mat_vec(
GGML_UNUSED(src1_ncols); GGML_UNUSED(src1_ncols);
GGML_UNUSED(src1_padded_row_size); GGML_UNUSED(src1_padded_row_size);
} }
bool ggml_cuda_should_use_mmv(enum ggml_type type, int cc, const int64_t * src0_ne, int64_t ne11) {
if (src0_ne[0] % 2 != 0) {
return false;
}
switch (type) {
case GGML_TYPE_F32:
if (GGML_CUDA_CC_IS_NVIDIA(cc)) {
if (cc >= GGML_CUDA_CC_ADA_LOVELACE) {
return ne11 <= 8;
}
if (cc >= GGML_CUDA_CC_TURING) {
return ne11 <= 4;
}
return ne11 <= 3;
}
return ne11 <= 8;
case GGML_TYPE_F16:
if (GGML_CUDA_CC_IS_NVIDIA(cc)) {
const bool src0_small = (src0_ne[1] <= 512 || src0_ne[2]*src0_ne[3] == 1);
if (cc >= GGML_CUDA_CC_ADA_LOVELACE) {
return src0_small && ne11 <= 4;
}
if (fp16_mma_hardware_available(cc)) {
return src0_small && ne11 <= 3;
}
return ne11 <= 8;
}
return ne11 <= 8;
case GGML_TYPE_BF16:
if (GGML_CUDA_CC_IS_NVIDIA(cc)) {
const bool src0_small = (src0_ne[1] <= 512 || src0_ne[2]*src0_ne[3] == 1);
if (cc >= GGML_CUDA_CC_ADA_LOVELACE) {
return src0_small && ne11 <= 4;
}
if (bf16_mma_hardware_available(cc)) {
return src0_small && ne11 <= 3;
}
return ne11 <= 8;
}
return ne11 <= 8;
default:
return false;
}
}
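For illustration only — how this gate is consulted by the mul_mat dispatch earlier in the diff: the caller first checks the tensor types, then lets ggml_cuda_should_use_mmv veto the custom kernel per device. A simplified helper with a stand-in name, omitting the dst-type and padding checks of the real code.
static bool want_custom_mmv_sketch(const ggml_tensor * src0, const ggml_tensor * src1, int cc) {
    const bool types_ok = (src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16 || src0->type == GGML_TYPE_BF16)
                       && src1->type == GGML_TYPE_F32;
    return types_ok && ggml_cuda_should_use_mmv(src0->type, cc, src0->ne, src1->ne[1]);
}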

View File

@ -1,8 +1,5 @@
#include "common.cuh" #include "common.cuh"
// maximum number of src0 rows with which to use mul_mat_vec over cuBLAS if FP16 tensor cores are available
#define MMV_MAX_ROWS 512
void ggml_cuda_mul_mat_vec(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * ids, ggml_tensor * dst); void ggml_cuda_mul_mat_vec(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * ids, ggml_tensor * dst);
void ggml_cuda_op_mul_mat_vec( void ggml_cuda_op_mul_mat_vec(
@ -10,3 +7,5 @@ void ggml_cuda_op_mul_mat_vec(
const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, const char * src0_dd_i, const float * src1_ddf_i, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, const char * src0_dd_i, const float * src1_ddf_i,
const char * src1_ddq_i, float * dst_dd_i, const int64_t row_low, const int64_t row_high, const int64_t src1_ncols, const char * src1_ddq_i, float * dst_dd_i, const int64_t row_low, const int64_t row_high, const int64_t src1_ncols,
const int64_t src1_padded_row_size, cudaStream_t stream); const int64_t src1_padded_row_size, cudaStream_t stream);
bool ggml_cuda_should_use_mmv(enum ggml_type type, int cc, const int64_t * src0_ne, int64_t ne11);

View File

@ -1,25 +1,9 @@
#include "sumrows.cuh" #include "sumrows.cuh"
static __global__ void k_sum_rows_f32(const float * x, float * dst, const int ncols) {
const int row = blockIdx.x;
const int col = threadIdx.x;
float sum = 0.0f;
for (int i = col; i < ncols; i += blockDim.x) {
sum += x[row * ncols + i];
}
sum = warp_reduce_sum(sum);
if (col == 0) {
dst[row] = sum;
}
}
void sum_rows_f32_cuda(const float * x, float * dst, const int ncols, const int nrows, cudaStream_t stream) { void sum_rows_f32_cuda(const float * x, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
const dim3 block_dims(WARP_SIZE, 1, 1); const dim3 block_dims(WARP_SIZE, 1, 1);
const dim3 block_nums(nrows, 1, 1); const dim3 block_nums(nrows, 1, 1);
k_sum_rows_f32<<<block_nums, block_dims, 0, stream>>>(x, dst, ncols); reduce_rows_f32</*norm*/false><<<block_nums, block_dims, 0, stream>>>(x, dst, ncols);
} }
void ggml_cuda_op_sum_rows(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { void ggml_cuda_op_sum_rows(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
@ -35,5 +19,8 @@ void ggml_cuda_op_sum_rows(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
const int64_t ncols = src0->ne[0]; const int64_t ncols = src0->ne[0];
const int64_t nrows = ggml_nrows(src0); const int64_t nrows = ggml_nrows(src0);
sum_rows_f32_cuda(src0_d, dst_d, ncols, nrows, stream); const dim3 block_dims(WARP_SIZE, 1, 1);
const dim3 block_nums(nrows, 1, 1);
reduce_rows_f32</*norm=*/false><<<block_nums, block_dims, 0, stream>>>(src0_d, dst_d, ncols);
} }

View File

@ -1,5 +1,4 @@
#include "common.cuh" #include "common.cuh"
void sum_rows_f32_cuda(const float * x, float * dst, const int ncols, const int nrows, cudaStream_t stream); void sum_rows_f32_cuda(const float * x, float * dst, const int ncols, const int nrows, cudaStream_t stream);
void ggml_cuda_op_sum_rows(ggml_backend_cuda_context & ctx, ggml_tensor * dst); void ggml_cuda_op_sum_rows(ggml_backend_cuda_context & ctx, ggml_tensor * dst);

View File

@ -48,22 +48,28 @@ static struct ggml_backend_metal_device_context {
int mtl_device_ref_count; int mtl_device_ref_count;
id<MTLLibrary> mtl_library; id<MTLLibrary> mtl_library;
NSLock * mtl_lock;
bool has_simdgroup_reduction; bool has_simdgroup_reduction;
bool has_simdgroup_mm; bool has_simdgroup_mm;
bool has_residency_sets; bool has_residency_sets;
bool has_bfloat; bool has_bfloat;
bool use_bfloat; bool use_bfloat;
size_t max_size;
char name[128]; char name[128];
} g_ggml_ctx_dev_main = { } g_ggml_ctx_dev_main = {
/*.mtl_device =*/ nil, /*.mtl_device =*/ nil,
/*.mtl_device_ref_count =*/ 0, /*.mtl_device_ref_count =*/ 0,
/*.mtl_library =*/ nil, /*.mtl_library =*/ nil,
/*.mtl_lock =*/ nil,
/*.has_simdgroup_reduction =*/ false, /*.has_simdgroup_reduction =*/ false,
/*.has_simdgroup_mm =*/ false, /*.has_simdgroup_mm =*/ false,
/*.has_residency_sets =*/ false, /*.has_residency_sets =*/ false,
/*.has_bfloat =*/ false, /*.has_bfloat =*/ false,
/*.use_bfloat =*/ false, /*.use_bfloat =*/ false,
/*.max_size =*/ 0,
/*.name =*/ "", /*.name =*/ "",
}; };
@ -71,6 +77,10 @@ static struct ggml_backend_metal_device_context {
static id<MTLDevice> ggml_backend_metal_device_acq(struct ggml_backend_metal_device_context * ctx) { static id<MTLDevice> ggml_backend_metal_device_acq(struct ggml_backend_metal_device_context * ctx) {
assert(ctx != NULL); assert(ctx != NULL);
if (ctx->mtl_lock == nil) {
ctx->mtl_lock = [[NSLock alloc] init];
}
if (ctx->mtl_device == nil) { if (ctx->mtl_device == nil) {
ctx->mtl_device = MTLCreateSystemDefaultDevice(); ctx->mtl_device = MTLCreateSystemDefaultDevice();
} }
@ -94,6 +104,8 @@ static id<MTLDevice> ggml_backend_metal_device_acq(struct ggml_backend_metal_dev
ctx->use_bfloat = false; ctx->use_bfloat = false;
#endif #endif
ctx->max_size = ctx->mtl_device.maxBufferLength;
strncpy(ctx->name, [[ctx->mtl_device name] UTF8String], sizeof(ctx->name) - 1); strncpy(ctx->name, [[ctx->mtl_device name] UTF8String], sizeof(ctx->name) - 1);
} }
@ -110,6 +122,11 @@ static void ggml_backend_metal_device_rel(struct ggml_backend_metal_device_conte
ctx->mtl_device_ref_count--; ctx->mtl_device_ref_count--;
if (ctx->mtl_device_ref_count == 0) { if (ctx->mtl_device_ref_count == 0) {
if (ctx->mtl_lock) {
[ctx->mtl_lock release];
ctx->mtl_lock = nil;
}
if (ctx->mtl_library) { if (ctx->mtl_library) {
[ctx->mtl_library release]; [ctx->mtl_library release];
ctx->mtl_library = nil; ctx->mtl_library = nil;
@ -978,7 +995,7 @@ static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t de
struct ggml_backend_metal_context * ctx = calloc(1, sizeof(struct ggml_backend_metal_context)); struct ggml_backend_metal_context * ctx = calloc(1, sizeof(struct ggml_backend_metal_context));
struct ggml_backend_metal_device_context * ctx_dev = dev->context; struct ggml_backend_metal_device_context * ctx_dev = dev->context;
id<MTLDevice> device = ggml_backend_metal_device_acq(ctx_dev); id<MTLDevice> device = ctx_dev->mtl_device;
GGML_LOG_INFO("%s: picking default device: %s\n", __func__, [[device name] UTF8String]); GGML_LOG_INFO("%s: picking default device: %s\n", __func__, [[device name] UTF8String]);
@ -992,9 +1009,16 @@ static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t de
ctx->d_queue = dispatch_queue_create("ggml-metal", DISPATCH_QUEUE_CONCURRENT); ctx->d_queue = dispatch_queue_create("ggml-metal", DISPATCH_QUEUE_CONCURRENT);
// load library // load library
{
[ctx_dev->mtl_lock lock];
if (ctx_dev->mtl_library == nil) { if (ctx_dev->mtl_library == nil) {
ctx_dev->mtl_library = ggml_metal_load_library(device, ctx_dev->use_bfloat); ctx_dev->mtl_library = ggml_metal_load_library(device, ctx_dev->use_bfloat);
} }
[ctx_dev->mtl_lock unlock];
}
id<MTLLibrary> metal_library = ctx_dev->mtl_library; id<MTLLibrary> metal_library = ctx_dev->mtl_library;
if (metal_library == nil) { if (metal_library == nil) {
GGML_LOG_ERROR("%s: error: metal library is nil\n", __func__); GGML_LOG_ERROR("%s: error: metal library is nil\n", __func__);
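The NSLock introduced above guards the lazily created MTLLibrary: the nil check has to happen inside the critical section, otherwise two concurrent ggml_metal_init calls could both see a nil library and build it twice. A minimal C++ analogue of the same check-inside-lock pattern (illustrative names only, not code from this tree):

#include <mutex>

struct device_ctx {
    std::mutex lock;                // plays the role of ctx_dev->mtl_lock
    void *     library = nullptr;   // plays the role of ctx_dev->mtl_library
};

static void * load_library_once(device_ctx & ctx, void * (*load)()) {
    std::lock_guard<std::mutex> guard(ctx.lock);
    if (ctx.library == nullptr) {   // checked while holding the lock
        ctx.library = load();
    }
    return ctx.library;
}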
@ -5313,7 +5337,6 @@ static void ggml_backend_metal_buffer_free_buffer(ggml_backend_buffer_t buffer)
} }
ggml_backend_metal_buffer_rset_free(ctx); ggml_backend_metal_buffer_rset_free(ctx);
ggml_backend_metal_device_rel(buffer->buft->device->context);
if (ctx->owned) { if (ctx->owned) {
#if TARGET_OS_OSX #if TARGET_OS_OSX
@ -5422,7 +5445,10 @@ static ggml_backend_buffer_t ggml_backend_metal_buffer_type_alloc_buffer(ggml_ba
} }
struct ggml_backend_metal_device_context * ctx_dev = (struct ggml_backend_metal_device_context *)buft->device->context; struct ggml_backend_metal_device_context * ctx_dev = (struct ggml_backend_metal_device_context *)buft->device->context;
id<MTLDevice> device = ggml_backend_metal_device_acq(ctx_dev);
GGML_ASSERT(ctx_dev->mtl_device != nil);
id<MTLDevice> device = ctx_dev->mtl_device;
ctx->all_data = ggml_metal_host_malloc(size_aligned); ctx->all_data = ggml_metal_host_malloc(size_aligned);
ctx->all_size = size_aligned; ctx->all_size = size_aligned;
@ -5445,14 +5471,12 @@ static ggml_backend_buffer_t ggml_backend_metal_buffer_type_alloc_buffer(ggml_ba
if (size_aligned > 0 && (ctx->all_data == NULL || ctx->buffers[0].metal == nil)) { if (size_aligned > 0 && (ctx->all_data == NULL || ctx->buffers[0].metal == nil)) {
GGML_LOG_ERROR("%s: error: failed to allocate buffer, size = %8.2f MiB\n", __func__, size_aligned / 1024.0 / 1024.0); GGML_LOG_ERROR("%s: error: failed to allocate buffer, size = %8.2f MiB\n", __func__, size_aligned / 1024.0 / 1024.0);
free(ctx); free(ctx);
ggml_backend_metal_device_rel(ctx_dev);
return NULL; return NULL;
} }
if (!ggml_backend_metal_buffer_rset_init(ctx, ctx_dev, device)) { if (!ggml_backend_metal_buffer_rset_init(ctx, ctx_dev, device)) {
GGML_LOG_ERROR("%s: error: failed to initialize residency set\n", __func__); GGML_LOG_ERROR("%s: error: failed to initialize residency set\n", __func__);
free(ctx); free(ctx);
ggml_backend_metal_device_rel(ctx_dev);
return NULL; return NULL;
} }
@ -5463,17 +5487,14 @@ static ggml_backend_buffer_t ggml_backend_metal_buffer_type_alloc_buffer(ggml_ba
static size_t ggml_backend_metal_buffer_type_get_alignment(ggml_backend_buffer_type_t buft) { static size_t ggml_backend_metal_buffer_type_get_alignment(ggml_backend_buffer_type_t buft) {
return 32; return 32;
GGML_UNUSED(buft); GGML_UNUSED(buft);
} }
static size_t ggml_backend_metal_buffer_type_get_max_size(ggml_backend_buffer_type_t buft) { static size_t ggml_backend_metal_buffer_type_get_max_size(ggml_backend_buffer_type_t buft) {
id<MTLDevice> device = ggml_backend_metal_device_acq(buft->device->context); const size_t max_size = ((struct ggml_backend_metal_device_context *)buft->device->context)->max_size;
const size_t max_size = device.maxBufferLength;
ggml_backend_metal_device_rel(buft->device->context);
return max_size; return max_size;
GGML_UNUSED(buft);
} }
static bool ggml_backend_metal_buffer_type_is_host(ggml_backend_buffer_type_t buft) { static bool ggml_backend_metal_buffer_type_is_host(ggml_backend_buffer_type_t buft) {
@ -5546,7 +5567,10 @@ ggml_backend_buffer_t ggml_backend_metal_buffer_from_ptr(void * data, size_t siz
} }
struct ggml_backend_metal_device_context * ctx_dev = &g_ggml_ctx_dev_main; struct ggml_backend_metal_device_context * ctx_dev = &g_ggml_ctx_dev_main;
id<MTLDevice> device = ggml_backend_metal_device_acq(ctx_dev);
GGML_ASSERT(ctx_dev->mtl_device != nil);
id<MTLDevice> device = ctx_dev->mtl_device;
// the buffer fits into the max buffer size allowed by the device // the buffer fits into the max buffer size allowed by the device
if (size_aligned <= device.maxBufferLength) { if (size_aligned <= device.maxBufferLength) {
@ -5602,7 +5626,6 @@ ggml_backend_buffer_t ggml_backend_metal_buffer_from_ptr(void * data, size_t siz
if (!ggml_backend_metal_buffer_rset_init(ctx, ctx_dev, device)) { if (!ggml_backend_metal_buffer_rset_init(ctx, ctx_dev, device)) {
GGML_LOG_ERROR("%s: error: failed to initialize residency set\n", __func__); GGML_LOG_ERROR("%s: error: failed to initialize residency set\n", __func__);
free(ctx); free(ctx);
ggml_backend_metal_device_rel(ctx_dev);
return NULL; return NULL;
} }
@ -5619,9 +5642,7 @@ static const char * ggml_backend_metal_name(ggml_backend_t backend) {
static void ggml_backend_metal_free(ggml_backend_t backend) { static void ggml_backend_metal_free(ggml_backend_t backend) {
struct ggml_backend_metal_context * ctx = backend->context; struct ggml_backend_metal_context * ctx = backend->context;
struct ggml_backend_metal_device_context * ctx_dev = backend->device->context;
ggml_backend_metal_device_rel(ctx_dev);
ggml_metal_free(ctx); ggml_metal_free(ctx);
free(backend); free(backend);
@ -5761,6 +5782,8 @@ bool ggml_backend_metal_supports_family(ggml_backend_t backend, int family) {
struct ggml_backend_metal_device_context * ctx_dev = backend->device->context; struct ggml_backend_metal_device_context * ctx_dev = backend->device->context;
GGML_ASSERT(ctx_dev->mtl_device != nil);
return [ctx_dev->mtl_device supportsFamily:(MTLGPUFamilyApple1 + family - 1)]; return [ctx_dev->mtl_device supportsFamily:(MTLGPUFamilyApple1 + family - 1)];
} }
@ -5780,10 +5803,7 @@ static const char * ggml_backend_metal_device_get_name(ggml_backend_dev_t dev) {
} }
static const char * ggml_backend_metal_device_get_description(ggml_backend_dev_t dev) { static const char * ggml_backend_metal_device_get_description(ggml_backend_dev_t dev) {
// acq/rel just to populate ctx->name in case it hasn't been done yet
struct ggml_backend_metal_device_context * ctx_dev = (struct ggml_backend_metal_device_context *)dev->context; struct ggml_backend_metal_device_context * ctx_dev = (struct ggml_backend_metal_device_context *)dev->context;
ggml_backend_metal_device_acq(ctx_dev);
ggml_backend_metal_device_rel(ctx_dev);
return ctx_dev->name; return ctx_dev->name;
} }
@ -5791,12 +5811,10 @@ static const char * ggml_backend_metal_device_get_description(ggml_backend_dev_t
static void ggml_backend_metal_device_get_memory(ggml_backend_dev_t dev, size_t * free, size_t * total) { static void ggml_backend_metal_device_get_memory(ggml_backend_dev_t dev, size_t * free, size_t * total) {
if (@available(macOS 10.12, iOS 16.0, *)) { if (@available(macOS 10.12, iOS 16.0, *)) {
struct ggml_backend_metal_device_context * ctx_dev = (struct ggml_backend_metal_device_context *)dev->context; struct ggml_backend_metal_device_context * ctx_dev = (struct ggml_backend_metal_device_context *)dev->context;
id<MTLDevice> device = ggml_backend_metal_device_acq(ctx_dev); id<MTLDevice> device = ctx_dev->mtl_device;
*total = device.recommendedMaxWorkingSetSize; *total = device.recommendedMaxWorkingSetSize;
*free = *total - device.currentAllocatedSize; *free = *total - device.currentAllocatedSize;
ggml_backend_metal_device_rel(ctx_dev);
} else { } else {
*free = 1; *free = 1;
*total = 1; *total = 1;
@ -5874,7 +5892,10 @@ static ggml_backend_buffer_t ggml_backend_metal_device_buffer_from_ptr(ggml_back
} }
struct ggml_backend_metal_device_context * ctx_dev = (struct ggml_backend_metal_device_context *)dev->context; struct ggml_backend_metal_device_context * ctx_dev = (struct ggml_backend_metal_device_context *)dev->context;
id<MTLDevice> device = ggml_backend_metal_device_acq(ctx_dev);
GGML_ASSERT(ctx_dev->mtl_device != nil);
id<MTLDevice> device = ctx_dev->mtl_device;
// the buffer fits into the max buffer size allowed by the device // the buffer fits into the max buffer size allowed by the device
if (size_aligned <= device.maxBufferLength) { if (size_aligned <= device.maxBufferLength) {
@ -5930,7 +5951,6 @@ static ggml_backend_buffer_t ggml_backend_metal_device_buffer_from_ptr(ggml_back
if (!ggml_backend_metal_buffer_rset_init(ctx, ctx_dev, device)) { if (!ggml_backend_metal_buffer_rset_init(ctx, ctx_dev, device)) {
GGML_LOG_ERROR("%s: error: failed to initialize residency set\n", __func__); GGML_LOG_ERROR("%s: error: failed to initialize residency set\n", __func__);
free(ctx); free(ctx);
ggml_backend_metal_device_rel(ctx_dev);
return NULL; return NULL;
} }
@ -5944,7 +5964,8 @@ static bool ggml_backend_metal_device_supports_op(ggml_backend_dev_t dev, const
} }
static bool ggml_backend_metal_device_supports_buft(ggml_backend_dev_t dev, ggml_backend_buffer_type_t buft) { static bool ggml_backend_metal_device_supports_buft(ggml_backend_dev_t dev, ggml_backend_buffer_type_t buft) {
return buft->iface.get_name == ggml_backend_metal_buffer_type_get_name || return
buft->iface.get_name == ggml_backend_metal_buffer_type_get_name ||
buft->iface.get_name == ggml_backend_metal_buffer_from_ptr_type_get_name; buft->iface.get_name == ggml_backend_metal_buffer_from_ptr_type_get_name;
GGML_UNUSED(dev); GGML_UNUSED(dev);
@ -6030,8 +6051,19 @@ static struct ggml_backend_reg_i ggml_backend_metal_reg_i = {
/* .get_proc_address = */ ggml_backend_metal_get_proc_address, /* .get_proc_address = */ ggml_backend_metal_get_proc_address,
}; };
// called upon program exit
static void ggml_metal_cleanup(void) {
ggml_backend_metal_device_rel(&g_ggml_ctx_dev_main);
}
// TODO: make thread-safe
ggml_backend_reg_t ggml_backend_metal_reg(void) { ggml_backend_reg_t ggml_backend_metal_reg(void) {
// TODO: make this thread-safe somehow? ggml_backend_metal_device_acq(&g_ggml_ctx_dev_main);
// register cleanup callback
// TODO: not ideal, but not sure if there is a better way to do this in Objective-C
atexit(ggml_metal_cleanup);
{ {
g_ggml_backend_metal_reg = (struct ggml_backend_reg) { g_ggml_backend_metal_reg = (struct ggml_backend_reg) {
/* .api_version = */ GGML_BACKEND_API_VERSION, /* .api_version = */ GGML_BACKEND_API_VERSION,
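Taken together, the changes in this file move the Metal device to a register-once lifetime: ggml_backend_metal_reg() acquires g_ggml_ctx_dev_main a single time and registers an atexit handler to release it, while buffer allocation, buffer free and backend free now assert that the device exists instead of pairing acq/rel calls. A reduced C++ sketch of that ownership model (illustrative only, not the actual Objective-C code):

#include <cassert>
#include <cstdlib>

struct device_context {
    int ref_count = 0;
};

static device_context g_dev;

static void device_acq(device_context & ctx) { ctx.ref_count++; }
static void device_rel(device_context & ctx) { ctx.ref_count--; }

// mirrors ggml_metal_cleanup(): drop the registration reference at process exit
static void cleanup() { device_rel(g_dev); }

static void register_backend() {
    device_acq(g_dev);      // one acquire for the lifetime of the process
    std::atexit(cleanup);   // released at exit instead of per-buffer rel calls
}

static void alloc_buffer() {
    assert(g_dev.ref_count > 0);   // buffers only assert, they no longer acq/rel
}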
View File
@ -225,9 +225,9 @@ struct bin_bcast_sycl {
dpct::has_capability_or_fail(stream->get_device(), dpct::has_capability_or_fail(stream->get_device(),
{sycl::aspect::fp16}); {sycl::aspect::fp16});
stream->parallel_for( sycl_parallel_for(
sycl::nd_range<3>(sycl::range<3>(1, 1, block_num) * stream,
sycl::range<3>(1, 1, block_size), sycl::nd_range<3>(sycl::range<3>(1, 1, block_num) * sycl::range<3>(1, 1, block_size),
sycl::range<3>(1, 1, block_size)), sycl::range<3>(1, 1, block_size)),
[=](sycl::nd_item<3> item_ct1) { [=](sycl::nd_item<3> item_ct1) {
k_bin_bcast_unravel<bin_op>( k_bin_bcast_unravel<bin_op>(
@ -246,9 +246,8 @@ struct bin_bcast_sycl {
dpct::has_capability_or_fail(stream->get_device(), dpct::has_capability_or_fail(stream->get_device(),
{sycl::aspect::fp16}); {sycl::aspect::fp16});
stream->parallel_for( sycl_parallel_for(
sycl::nd_range<3>(block_nums * block_dims, block_dims), stream, sycl::nd_range<3>(block_nums * block_dims, block_dims), [=](sycl::nd_item<3> item_ct1) {
[=](sycl::nd_item<3> item_ct1) {
k_bin_bcast<bin_op>(src0_dd, src1_dd, dst_dd, ne0, ne1, k_bin_bcast<bin_op>(src0_dd, src1_dd, dst_dd, ne0, ne1,
ne2, ne3, ne10, ne11, ne12, ne13, ne2, ne3, ne10, ne11, ne12, ne13,
s1, s2, s3, s01, s02, s03, s11, s12, s13, s1, s2, s3, s01, s02, s03, s11, s12, s13,
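The SYCL hunks above and in the files that follow are mostly mechanical: direct stream->parallel_for(...) and stream->submit(...) calls are replaced by sycl_parallel_for(...) and sycl_launch(...) helpers. Those helpers are not defined in this diff; the sketch below shows thin forwarding wrappers that would match the calling convention used at the call sites (an assumption — the real helpers may add error handling or profiling):

#include <utility>
#include <sycl/sycl.hpp>

// sycl_launch(stream, f): submit a command group on the queue
template <typename F>
static void sycl_launch(sycl::queue * stream, F && f) {
    stream->submit(std::forward<F>(f));
}

// sycl_parallel_for(stream, range, kernel) / sycl_parallel_for(cgh, range, kernel):
// forward to SYCL's parallel_for; the explicit <1> form seen later selects a 1-D nd_range
template <int Dims = 3, typename K>
static void sycl_parallel_for(sycl::queue * stream, const sycl::nd_range<Dims> & range, K && kernel) {
    stream->parallel_for(range, std::forward<K>(kernel));
}

template <int Dims = 3, typename K>
static void sycl_parallel_for(sycl::handler & cgh, const sycl::nd_range<Dims> & range, K && kernel) {
    cgh.parallel_for(range, std::forward<K>(kernel));
}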
View File
@ -89,32 +89,23 @@ static void concat_f32_sycl(const float *x, const float *y, float *dst,
sycl::range<3> gridDim(ne2, ne1, num_blocks); sycl::range<3> gridDim(ne2, ne1, num_blocks);
switch (dim) { switch (dim) {
case 0: case 0:
stream->parallel_for( sycl_parallel_for(stream,
sycl::nd_range<3>(gridDim * sycl::nd_range<3>(gridDim * sycl::range<3>(1, 1, SYCL_CONCAT_BLOCK_SIZE),
sycl::range<3>(1, 1, SYCL_CONCAT_BLOCK_SIZE),
sycl::range<3>(1, 1, SYCL_CONCAT_BLOCK_SIZE)), sycl::range<3>(1, 1, SYCL_CONCAT_BLOCK_SIZE)),
[=](sycl::nd_item<3> item_ct1) { [=](sycl::nd_item<3> item_ct1) { concat_f32_dim0(x, y, dst, ne0, ne00, item_ct1); });
concat_f32_dim0(x, y, dst, ne0, ne00, item_ct1);
});
break; break;
case 1: case 1:
stream->parallel_for( sycl_parallel_for(stream,
sycl::nd_range<3>(gridDim * sycl::nd_range<3>(gridDim * sycl::range<3>(1, 1, SYCL_CONCAT_BLOCK_SIZE),
sycl::range<3>(1, 1, SYCL_CONCAT_BLOCK_SIZE),
sycl::range<3>(1, 1, SYCL_CONCAT_BLOCK_SIZE)), sycl::range<3>(1, 1, SYCL_CONCAT_BLOCK_SIZE)),
[=](sycl::nd_item<3> item_ct1) { [=](sycl::nd_item<3> item_ct1) { concat_f32_dim1(x, y, dst, ne0, ne01, item_ct1); });
concat_f32_dim1(x, y, dst, ne0, ne01, item_ct1);
});
break; break;
// dim >=2 will be dispatched to the default path // dim >=2 will be dispatched to the default path
default: default:
stream->parallel_for( sycl_parallel_for(stream,
sycl::nd_range<3>(gridDim * sycl::nd_range<3>(gridDim * sycl::range<3>(1, 1, SYCL_CONCAT_BLOCK_SIZE),
sycl::range<3>(1, 1, SYCL_CONCAT_BLOCK_SIZE),
sycl::range<3>(1, 1, SYCL_CONCAT_BLOCK_SIZE)), sycl::range<3>(1, 1, SYCL_CONCAT_BLOCK_SIZE)),
[=](sycl::nd_item<3> item_ct1) { [=](sycl::nd_item<3> item_ct1) { concat_f32_dim2(x, y, dst, ne0, ne02, item_ct1); });
concat_f32_dim2(x, y, dst, ne0, ne02, item_ct1);
});
break; break;
} }
} }
@ -129,26 +120,22 @@ static void concat_f32_sycl_non_cont(
int64_t ne2, int64_t ne3, uint64_t nb0, uint64_t nb1, uint64_t nb2, int64_t ne2, int64_t ne3, uint64_t nb0, uint64_t nb1, uint64_t nb2,
uint64_t nb3, int32_t dim) { uint64_t nb3, int32_t dim) {
sycl::range<3> gridDim(ne3, ne2, ne1); sycl::range<3> gridDim(ne3, ne2, ne1);
stream->parallel_for( sycl_parallel_for(stream, sycl::nd_range<3>(gridDim, sycl::range<3>(1, 1, 1)), [=](sycl::nd_item<3> item_ct1) {
sycl::nd_range<3>(gridDim, sycl::range<3>(1, 1, 1)),
[=](sycl::nd_item<3> item_ct1) {
int64_t i3 = item_ct1.get_group(0); int64_t i3 = item_ct1.get_group(0);
int64_t i2 = item_ct1.get_group(1); int64_t i2 = item_ct1.get_group(1);
int64_t i1 = item_ct1.get_group(2); int64_t i1 = item_ct1.get_group(2);
int64_t o[4] = {0, 0, 0, 0}; int64_t o[4] = { 0, 0, 0, 0 };
o[dim] = dim == 0 ? ne00 : (dim == 1 ? ne01 : (dim == 2 ? ne02 : ne03)); o[dim] = dim == 0 ? ne00 : (dim == 1 ? ne01 : (dim == 2 ? ne02 : ne03));
const float *x; const float * x;
for (int i0 = item_ct1.get_local_id(2); i0 < ne0; for (int i0 = item_ct1.get_local_id(2); i0 < ne0; i0 += item_ct1.get_local_range(2)) {
i0 += item_ct1.get_local_range(2)) {
if (i0 < ne00 && i1 < ne01 && i2 < ne02 && i3 < ne03) { if (i0 < ne00 && i1 < ne01 && i2 < ne02 && i3 < ne03) {
x = (const float *)(src0 + (i3)*nb03 + (i2)*nb02 + (i1)*nb01 + x = (const float *) (src0 + (i3) *nb03 + (i2) *nb02 + (i1) *nb01 + (i0) *nb00);
(i0)*nb00);
} else { } else {
x = (const float *)(src1 + (i3 - o[3]) * nb13 + (i2 - o[2]) * nb12 + x = (const float *) (src1 + (i3 - o[3]) * nb13 + (i2 - o[2]) * nb12 + (i1 - o[1]) * nb11 +
(i1 - o[1]) * nb11 + (i0 - o[0]) * nb10); (i0 - o[0]) * nb10);
} }
float *y = (float *)(dst + i3 * nb3 + i2 * nb2 + i1 * nb1 + i0 * nb0); float *y = (float *)(dst + i3 * nb3 + i2 * nb2 + i1 * nb1 + i0 * nb0);
View File
@ -59,15 +59,9 @@ static void conv_transpose_1d_f32_f32_sycl(
const int num_blocks = (output_size + SYCL_CONV_TRANPOSE_1D_BLOCK_SIZE - 1) / SYCL_CONV_TRANPOSE_1D_BLOCK_SIZE; const int num_blocks = (output_size + SYCL_CONV_TRANPOSE_1D_BLOCK_SIZE - 1) / SYCL_CONV_TRANPOSE_1D_BLOCK_SIZE;
const sycl::range<3> block_dims(1, 1, SYCL_CONV_TRANPOSE_1D_BLOCK_SIZE); const sycl::range<3> block_dims(1, 1, SYCL_CONV_TRANPOSE_1D_BLOCK_SIZE);
const sycl::range<3> block_nums(1, 1, num_blocks); const sycl::range<3> block_nums(1, 1, num_blocks);
stream->parallel_for( sycl_parallel_for(stream, sycl::nd_range<3>(block_nums * block_dims, block_dims), [=](sycl::nd_item<3> item_ct1) {
sycl::nd_range<3>( conv_transpose_1d_kernel(s0, output_size, src0_ne0, src0_ne1, src0_ne2, src1_ne0, dst_ne0, src0, src1, dst,
block_nums * block_dims, block_dims), item_ct1);
[=](sycl::nd_item<3> item_ct1) {
conv_transpose_1d_kernel(
s0, output_size,
src0_ne0, src0_ne1, src0_ne2,
src1_ne0, dst_ne0,
src0, src1, dst, item_ct1);
}); });
} }
View File
@ -33,14 +33,11 @@ static void dequantize_block_sycl(const void *__restrict__ vx,
{ {
dpct::has_capability_or_fail(stream->get_device(), dpct::has_capability_or_fail(stream->get_device(),
{sycl::aspect::fp16}); {sycl::aspect::fp16});
stream->parallel_for( sycl_parallel_for(
sycl::nd_range<3>( stream,
sycl::range<3>(1, 1, num_blocks) * sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, SYCL_DEQUANTIZE_BLOCK_SIZE),
sycl::range<3>(1, 1, SYCL_DEQUANTIZE_BLOCK_SIZE),
sycl::range<3>(1, 1, SYCL_DEQUANTIZE_BLOCK_SIZE)), sycl::range<3>(1, 1, SYCL_DEQUANTIZE_BLOCK_SIZE)),
[=](sycl::nd_item<3> item_ct1) { [=](sycl::nd_item<3> item_ct1) { dequantize_block<qk, qr, dequantize_kernel>(vx, y, k, item_ct1); });
dequantize_block<qk, qr, dequantize_kernel>(vx, y, k, item_ct1);
});
} }
} }
@ -53,24 +50,18 @@ static void dequantize_row_q2_K_sycl(const void *vx, dst_t *y, const int64_t k,
dpct::has_capability_or_fail(stream->get_device(), dpct::has_capability_or_fail(stream->get_device(),
{sycl::aspect::fp16}); {sycl::aspect::fp16});
stream->parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, nb) * sycl_parallel_for(
sycl::range<3>(1, 1, 64), stream, sycl::nd_range<3>(sycl::range<3>(1, 1, nb) * sycl::range<3>(1, 1, 64), sycl::range<3>(1, 1, 64)),
sycl::range<3>(1, 1, 64)), [=](sycl::nd_item<3> item_ct1) { dequantize_block_q2_K(vx, y, item_ct1); });
[=](sycl::nd_item<3> item_ct1) {
dequantize_block_q2_K(vx, y, item_ct1);
});
} }
#else #else
{ {
dpct::has_capability_or_fail(stream->get_device(), dpct::has_capability_or_fail(stream->get_device(),
{sycl::aspect::fp16}); {sycl::aspect::fp16});
stream->parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, nb) * sycl_parallel_for(
sycl::range<3>(1, 1, 32), stream, sycl::nd_range<3>(sycl::range<3>(1, 1, nb) * sycl::range<3>(1, 1, 32), sycl::range<3>(1, 1, 32)),
sycl::range<3>(1, 1, 32)), [=](sycl::nd_item<3> item_ct1) { dequantize_block_q2_K(vx, y, item_ct1); });
[=](sycl::nd_item<3> item_ct1) {
dequantize_block_q2_K(vx, y, item_ct1);
});
} }
#endif #endif
@ -85,24 +76,18 @@ static void dequantize_row_q3_K_sycl(const void *vx, dst_t *y, const int64_t k,
dpct::has_capability_or_fail(stream->get_device(), dpct::has_capability_or_fail(stream->get_device(),
{sycl::aspect::fp16}); {sycl::aspect::fp16});
stream->parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, nb) * sycl_parallel_for(
sycl::range<3>(1, 1, 64), stream, sycl::nd_range<3>(sycl::range<3>(1, 1, nb) * sycl::range<3>(1, 1, 64), sycl::range<3>(1, 1, 64)),
sycl::range<3>(1, 1, 64)), [=](sycl::nd_item<3> item_ct1) { dequantize_block_q3_K(vx, y, item_ct1); });
[=](sycl::nd_item<3> item_ct1) {
dequantize_block_q3_K(vx, y, item_ct1);
});
} }
#else #else
{ {
dpct::has_capability_or_fail(stream->get_device(), dpct::has_capability_or_fail(stream->get_device(),
{sycl::aspect::fp16}); {sycl::aspect::fp16});
stream->parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, nb) * sycl_parallel_for(
sycl::range<3>(1, 1, 32), stream, sycl::nd_range<3>(sycl::range<3>(1, 1, nb) * sycl::range<3>(1, 1, 32), sycl::range<3>(1, 1, 32)),
sycl::range<3>(1, 1, 32)), [=](sycl::nd_item<3> item_ct1) { dequantize_block_q3_K(vx, y, item_ct1); });
[=](sycl::nd_item<3> item_ct1) {
dequantize_block_q3_K(vx, y, item_ct1);
});
} }
#endif #endif
} }
@ -116,12 +101,9 @@ static void dequantize_row_q4_0_sycl(const void *vx, dst_t *y, const int64_t k,
dpct::has_capability_or_fail(stream->get_device(), dpct::has_capability_or_fail(stream->get_device(),
{sycl::aspect::fp16}); {sycl::aspect::fp16});
stream->parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, nb) * sycl_parallel_for(
sycl::range<3>(1, 1, 32), stream, sycl::nd_range<3>(sycl::range<3>(1, 1, nb) * sycl::range<3>(1, 1, 32), sycl::range<3>(1, 1, 32)),
sycl::range<3>(1, 1, 32)), [=](sycl::nd_item<3> item_ct1) { dequantize_block_q4_0(vx, y, nb32, item_ct1); });
[=](sycl::nd_item<3> item_ct1) {
dequantize_block_q4_0(vx, y, nb32, item_ct1);
});
} }
} }
@ -135,13 +117,12 @@ static void dequantize_row_q4_0_sycl_reorder(const void *vx, dst_t *y, const int
int constexpr WARP_K = WARP_SIZE * QK4_0; int constexpr WARP_K = WARP_SIZE * QK4_0;
const int n_warp = (k + WARP_K - 1) / WARP_K; const int n_warp = (k + WARP_K - 1) / WARP_K;
GGML_ASSERT(k % 2 == 0); GGML_ASSERT(k % 2 == 0);
stream->parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, n_warp) * sycl_parallel_for(stream,
sycl::range<3>(1, 1, WARP_SIZE), sycl::nd_range<3>(sycl::range<3>(1, 1, n_warp) * sycl::range<3>(1, 1, WARP_SIZE),
sycl::range<3>(1, 1, WARP_SIZE)), sycl::range<3>(1, 1, WARP_SIZE)),
[=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]]{ [=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
dequantize_block_q4_0_reorder(vx, y, k, item_ct1); dequantize_block_q4_0_reorder(vx, y, k, item_ct1);
}); });
} }
template <typename dst_t> template <typename dst_t>
@ -153,12 +134,9 @@ static void dequantize_row_q4_1_sycl(const void *vx, dst_t *y, const int64_t k,
dpct::has_capability_or_fail(stream->get_device(), dpct::has_capability_or_fail(stream->get_device(),
{sycl::aspect::fp16}); {sycl::aspect::fp16});
stream->parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, nb) * sycl_parallel_for(
sycl::range<3>(1, 1, 32), stream, sycl::nd_range<3>(sycl::range<3>(1, 1, nb) * sycl::range<3>(1, 1, 32), sycl::range<3>(1, 1, 32)),
sycl::range<3>(1, 1, 32)), [=](sycl::nd_item<3> item_ct1) { dequantize_block_q4_1(vx, y, nb32, item_ct1); });
[=](sycl::nd_item<3> item_ct1) {
dequantize_block_q4_1(vx, y, nb32, item_ct1);
});
} }
} }
@ -171,11 +149,10 @@ static void dequantize_row_q4_K_sycl(const void *vx, dst_t *y, const int64_t k,
dpct::has_capability_or_fail(stream->get_device(), dpct::has_capability_or_fail(stream->get_device(),
{sycl::aspect::fp16}); {sycl::aspect::fp16});
stream->submit([&](sycl::handler &cgh) { sycl_launch(stream, [&](sycl::handler & cgh) {
sycl::local_accessor<uint8_t, 1> scale_local_acc(sycl::range<1>(12), cgh); sycl::local_accessor<uint8_t, 1> scale_local_acc(sycl::range<1>(12), cgh);
cgh.parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, nb) * sycl_parallel_for(
sycl::range<3>(1, 1, 32), cgh, sycl::nd_range<3>(sycl::range<3>(1, 1, nb) * sycl::range<3>(1, 1, 32), sycl::range<3>(1, 1, 32)),
sycl::range<3>(1, 1, 32)),
[=](sycl::nd_item<3> item_ct1) { [=](sycl::nd_item<3> item_ct1) {
dequantize_block_q4_K(vx, y, get_pointer(scale_local_acc), item_ct1); dequantize_block_q4_K(vx, y, get_pointer(scale_local_acc), item_ct1);
}); });
@ -191,10 +168,10 @@ static void dequantize_row_q4_K_sycl_reorder(const void * vx, dst_t * y, const i
dpct::has_capability_or_fail(stream->get_device(), { sycl::aspect::fp16 }); dpct::has_capability_or_fail(stream->get_device(), { sycl::aspect::fp16 });
stream->submit([&](sycl::handler & cgh) { sycl_launch(stream, [&](sycl::handler & cgh) {
sycl::local_accessor<uint8_t, 1> scale_local_acc(sycl::range<1>(12), cgh); sycl::local_accessor<uint8_t, 1> scale_local_acc(sycl::range<1>(12), cgh);
cgh.parallel_for(sycl::nd_range<1>(sycl::range<1>(global_size), sycl::range<1>(local_size)), sycl_parallel_for<1>(cgh, sycl::nd_range<1>(sycl::range<1>(global_size), sycl::range<1>(local_size)),
[=](sycl::nd_item<1> item_ct1) { [=](sycl::nd_item<1> item_ct1) {
dequantize_block_q4_K_reorder(vx, y, get_pointer(scale_local_acc), item_ct1, nb); dequantize_block_q4_K_reorder(vx, y, get_pointer(scale_local_acc), item_ct1, nb);
}); });
@ -210,24 +187,18 @@ static void dequantize_row_q5_K_sycl(const void *vx, dst_t *y, const int64_t k,
dpct::has_capability_or_fail(stream->get_device(), dpct::has_capability_or_fail(stream->get_device(),
{sycl::aspect::fp16}); {sycl::aspect::fp16});
stream->parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, nb) * sycl_parallel_for(
sycl::range<3>(1, 1, 64), stream, sycl::nd_range<3>(sycl::range<3>(1, 1, nb) * sycl::range<3>(1, 1, 64), sycl::range<3>(1, 1, 64)),
sycl::range<3>(1, 1, 64)), [=](sycl::nd_item<3> item_ct1) { dequantize_block_q5_K(vx, y, item_ct1); });
[=](sycl::nd_item<3> item_ct1) {
dequantize_block_q5_K(vx, y, item_ct1);
});
} }
#else #else
{ {
dpct::has_capability_or_fail(stream->get_device(), dpct::has_capability_or_fail(stream->get_device(),
{sycl::aspect::fp16}); {sycl::aspect::fp16});
stream->parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, nb) * sycl_parallel_for(
sycl::range<3>(1, 1, 32), stream, sycl::nd_range<3>(sycl::range<3>(1, 1, nb) * sycl::range<3>(1, 1, 32), sycl::range<3>(1, 1, 32)),
sycl::range<3>(1, 1, 32)), [=](sycl::nd_item<3> item_ct1) { dequantize_block_q5_K(vx, y, item_ct1); });
[=](sycl::nd_item<3> item_ct1) {
dequantize_block_q5_K(vx, y, item_ct1);
});
} }
#endif #endif
@ -242,24 +213,18 @@ static void dequantize_row_q6_K_sycl(const void *vx, dst_t *y, const int64_t k,
dpct::has_capability_or_fail(stream->get_device(), dpct::has_capability_or_fail(stream->get_device(),
{sycl::aspect::fp16}); {sycl::aspect::fp16});
stream->parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, nb) * sycl_parallel_for(
sycl::range<3>(1, 1, 64), stream, sycl::nd_range<3>(sycl::range<3>(1, 1, nb) * sycl::range<3>(1, 1, 64), sycl::range<3>(1, 1, 64)),
sycl::range<3>(1, 1, 64)), [=](sycl::nd_item<3> item_ct1) { dequantize_block_q6_K(vx, y, item_ct1); });
[=](sycl::nd_item<3> item_ct1) {
dequantize_block_q6_K(vx, y, item_ct1);
});
} }
#else #else
{ {
dpct::has_capability_or_fail(stream->get_device(), dpct::has_capability_or_fail(stream->get_device(),
{sycl::aspect::fp16}); {sycl::aspect::fp16});
stream->parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, nb) * sycl_parallel_for(
sycl::range<3>(1, 1, 32), stream, sycl::nd_range<3>(sycl::range<3>(1, 1, nb) * sycl::range<3>(1, 1, 32), sycl::range<3>(1, 1, 32)),
sycl::range<3>(1, 1, 32)), [=](sycl::nd_item<3> item_ct1) { dequantize_block_q6_K(vx, y, item_ct1); });
[=](sycl::nd_item<3> item_ct1) {
dequantize_block_q6_K(vx, y, item_ct1);
});
} }
#endif #endif
@ -271,7 +236,7 @@ static void dequantize_row_q6_K_sycl_reorder(const void * vx, dst_t * y, const i
dpct::has_capability_or_fail(stream->get_device(), { sycl::aspect::fp16 }); dpct::has_capability_or_fail(stream->get_device(), { sycl::aspect::fp16 });
stream->parallel_for( sycl_parallel_for(stream,
sycl::nd_range<3>(sycl::range<3>(1, 1, nb) * sycl::range<3>(1, 1, 64), sycl::range<3>(1, 1, 64)), sycl::nd_range<3>(sycl::range<3>(1, 1, nb) * sycl::range<3>(1, 1, 64), sycl::range<3>(1, 1, 64)),
[=](sycl::nd_item<3> item_ct1) { dequantize_block_q6_K_reorder(vx, y, item_ct1, nb); }); [=](sycl::nd_item<3> item_ct1) { dequantize_block_q6_K_reorder(vx, y, item_ct1, nb); });
} }
@ -284,15 +249,10 @@ static void dequantize_row_iq1_s_sycl(const void *vx, dst_t *y, const int64_t k,
dpct::has_capability_or_fail(stream->get_device(), dpct::has_capability_or_fail(stream->get_device(),
{sycl::aspect::fp16}); {sycl::aspect::fp16});
stream->submit([&](sycl::handler &cgh) { sycl_launch(stream, [&](sycl::handler & cgh) {
cgh.parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, nb) * sycl_parallel_for(
sycl::range<3>(1, 1, 32), cgh, sycl::nd_range<3>(sycl::range<3>(1, 1, nb) * sycl::range<3>(1, 1, 32), sycl::range<3>(1, 1, 32)),
sycl::range<3>(1, 1, 32)), [=](sycl::nd_item<3> item_ct1) { dequantize_block_iq1_s(vx, y, item_ct1, iq1s_grid_gpu); });
[=](sycl::nd_item<3> item_ct1) {
dequantize_block_iq1_s(
vx, y, item_ct1, iq1s_grid_gpu
);
});
}); });
} }
} }
@ -305,15 +265,10 @@ static void dequantize_row_iq1_m_sycl(const void *vx, dst_t *y, const int64_t k,
dpct::has_capability_or_fail(stream->get_device(), dpct::has_capability_or_fail(stream->get_device(),
{sycl::aspect::fp16}); {sycl::aspect::fp16});
stream->submit([&](sycl::handler &cgh) { sycl_launch(stream, [&](sycl::handler & cgh) {
cgh.parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, nb) * sycl_parallel_for(
sycl::range<3>(1, 1, 32), cgh, sycl::nd_range<3>(sycl::range<3>(1, 1, nb) * sycl::range<3>(1, 1, 32), sycl::range<3>(1, 1, 32)),
sycl::range<3>(1, 1, 32)), [=](sycl::nd_item<3> item_ct1) { dequantize_block_iq1_m(vx, y, item_ct1, iq1s_grid_gpu); });
[=](sycl::nd_item<3> item_ct1) {
dequantize_block_iq1_m(
vx, y, item_ct1, iq1s_grid_gpu
);
});
}); });
} }
} }
@ -326,14 +281,11 @@ static void dequantize_row_iq2_xxs_sycl(const void *vx, dst_t *y, const int64_t
dpct::has_capability_or_fail(stream->get_device(), dpct::has_capability_or_fail(stream->get_device(),
{sycl::aspect::fp16}); {sycl::aspect::fp16});
stream->submit([&](sycl::handler &cgh) { sycl_launch(stream, [&](sycl::handler & cgh) {
cgh.parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, nb) * sycl_parallel_for(
sycl::range<3>(1, 1, 32), cgh, sycl::nd_range<3>(sycl::range<3>(1, 1, nb) * sycl::range<3>(1, 1, 32), sycl::range<3>(1, 1, 32)),
sycl::range<3>(1, 1, 32)),
[=](sycl::nd_item<3> item_ct1) { [=](sycl::nd_item<3> item_ct1) {
dequantize_block_iq2_xxs( dequantize_block_iq2_xxs(vx, y, item_ct1, iq2xxs_grid, ksigns_iq2xs, kmask_iq2xs);
vx, y, item_ct1, iq2xxs_grid,
ksigns_iq2xs, kmask_iq2xs);
}); });
}); });
} }
@ -347,14 +299,11 @@ static void dequantize_row_iq2_xs_sycl(const void *vx, dst_t *y, const int64_t k
dpct::has_capability_or_fail(stream->get_device(), dpct::has_capability_or_fail(stream->get_device(),
{sycl::aspect::fp16}); {sycl::aspect::fp16});
stream->submit([&](sycl::handler &cgh) { sycl_launch(stream, [&](sycl::handler & cgh) {
cgh.parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, nb) * sycl_parallel_for(
sycl::range<3>(1, 1, 32), cgh, sycl::nd_range<3>(sycl::range<3>(1, 1, nb) * sycl::range<3>(1, 1, 32), sycl::range<3>(1, 1, 32)),
sycl::range<3>(1, 1, 32)),
[=](sycl::nd_item<3> item_ct1) { [=](sycl::nd_item<3> item_ct1) {
dequantize_block_iq2_xs( dequantize_block_iq2_xs(vx, y, item_ct1, iq2xs_grid, ksigns_iq2xs, kmask_iq2xs);
vx, y, item_ct1, iq2xs_grid,
ksigns_iq2xs, kmask_iq2xs);
}); });
}); });
} }
@ -368,13 +317,10 @@ static void dequantize_row_iq2_s_sycl(const void *vx, dst_t *y, const int64_t k,
dpct::has_capability_or_fail(stream->get_device(), dpct::has_capability_or_fail(stream->get_device(),
{sycl::aspect::fp16}); {sycl::aspect::fp16});
stream->submit([&](sycl::handler &cgh) { sycl_launch(stream, [&](sycl::handler & cgh) {
cgh.parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, nb) * sycl_parallel_for(
sycl::range<3>(1, 1, 32), cgh, sycl::nd_range<3>(sycl::range<3>(1, 1, nb) * sycl::range<3>(1, 1, 32), sycl::range<3>(1, 1, 32)),
sycl::range<3>(1, 1, 32)), [=](sycl::nd_item<3> item_ct1) { dequantize_block_iq2_s(vx, y, item_ct1); });
[=](sycl::nd_item<3> item_ct1) {
dequantize_block_iq2_s(vx, y, item_ct1);
});
}); });
} }
} }
@ -388,14 +334,11 @@ static void dequantize_row_iq3_xxs_sycl(const void *vx, dst_t *y, const int64_t
dpct::has_capability_or_fail(stream->get_device(), dpct::has_capability_or_fail(stream->get_device(),
{sycl::aspect::fp16}); {sycl::aspect::fp16});
stream->submit([&](sycl::handler &cgh) { sycl_launch(stream, [&](sycl::handler & cgh) {
cgh.parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, nb) * sycl_parallel_for(
sycl::range<3>(1, 1, 32), cgh, sycl::nd_range<3>(sycl::range<3>(1, 1, nb) * sycl::range<3>(1, 1, 32), sycl::range<3>(1, 1, 32)),
sycl::range<3>(1, 1, 32)),
[=](sycl::nd_item<3> item_ct1) { [=](sycl::nd_item<3> item_ct1) {
dequantize_block_iq3_xxs( dequantize_block_iq3_xxs(vx, y, item_ct1, iq3xxs_grid, ksigns_iq2xs, kmask_iq2xs);
vx, y, item_ct1, iq3xxs_grid,
ksigns_iq2xs, kmask_iq2xs);
}); });
}); });
} }
@ -409,14 +352,10 @@ static void dequantize_row_iq3_s_sycl(const void *vx, dst_t *y, const int64_t k,
dpct::has_capability_or_fail(stream->get_device(), dpct::has_capability_or_fail(stream->get_device(),
{sycl::aspect::fp16}); {sycl::aspect::fp16});
stream->submit([&](sycl::handler &cgh) { sycl_launch(stream, [&](sycl::handler & cgh) {
cgh.parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, nb) * sycl_parallel_for(
sycl::range<3>(1, 1, 32), cgh, sycl::nd_range<3>(sycl::range<3>(1, 1, nb) * sycl::range<3>(1, 1, 32), sycl::range<3>(1, 1, 32)),
sycl::range<3>(1, 1, 32)), [=](sycl::nd_item<3> item_ct1) { dequantize_block_iq3_s(vx, y, item_ct1, kmask_iq2xs, iq3s_grid); });
[=](sycl::nd_item<3> item_ct1) {
dequantize_block_iq3_s(
vx, y, item_ct1, kmask_iq2xs, iq3s_grid);
});
}); });
} }
} }
@ -432,14 +371,11 @@ static void dequantize_row_iq4_xs_sycl(const void *vx, dst_t *y, const int64_t k
dpct::has_capability_or_fail(stream->get_device(), dpct::has_capability_or_fail(stream->get_device(),
{sycl::aspect::fp16}); {sycl::aspect::fp16});
stream->submit([&](sycl::handler &cgh) { sycl_launch(stream, [&](sycl::handler & cgh) {
cgh.parallel_for( sycl_parallel_for(
sycl::nd_range<3>(sycl::range<3>(1, 1, nb) * cgh,
sycl::range<3>(1, 1, 32), sycl::nd_range<3>(sycl::range<3>(1, 1, nb) * sycl::range<3>(1, 1, 32), sycl::range<3>(1, 1, 32)),
sycl::range<3>(1, 1, 32)), [=](sycl::nd_item<3> item_ct1) { dequantize_block_iq4_xs(vx, y, item_ct1); });
[=](sycl::nd_item<3> item_ct1) {
dequantize_block_iq4_xs(vx, y, item_ct1);
});
}); });
} }
#endif #endif
@ -453,14 +389,11 @@ static void dequantize_row_iq4_nl_sycl(const void *vx, dst_t *y, const int64_t k
dpct::has_capability_or_fail(stream->get_device(), dpct::has_capability_or_fail(stream->get_device(),
{sycl::aspect::fp16}); {sycl::aspect::fp16});
stream->submit([&](sycl::handler &cgh) { sycl_launch(stream, [&](sycl::handler & cgh) {
cgh.parallel_for( sycl_parallel_for(
sycl::nd_range<3>(sycl::range<3>(1, 1, nb) * cgh,
sycl::range<3>(1, 1, 32), sycl::nd_range<3>(sycl::range<3>(1, 1, nb) * sycl::range<3>(1, 1, 32), sycl::range<3>(1, 1, 32)),
sycl::range<3>(1, 1, 32)), [=](sycl::nd_item<3> item_ct1) { dequantize_block_iq4_nl(vx, y, item_ct1); });
[=](sycl::nd_item<3> item_ct1) {
dequantize_block_iq4_nl(vx, y, item_ct1);
});
}); });
} }
} }
View File
@ -413,7 +413,8 @@ static void ggml_cpy_f16_f32_sycl(const char * cx, char * cdst, const int ne, co
{ {
dpct::has_capability_or_fail(stream->get_device(), { sycl::aspect::fp16 }); dpct::has_capability_or_fail(stream->get_device(), { sycl::aspect::fp16 });
stream->parallel_for( sycl_parallel_for(
stream,
sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, SYCL_CPY_BLOCK_SIZE), sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, SYCL_CPY_BLOCK_SIZE),
sycl::range<3>(1, 1, SYCL_CPY_BLOCK_SIZE)), sycl::range<3>(1, 1, SYCL_CPY_BLOCK_SIZE)),
[=](sycl::nd_item<3> item_ct1) { [=](sycl::nd_item<3> item_ct1) {
@ -431,7 +432,8 @@ static void ggml_cpy_f32_f32_sycl(const char * cx, char * cdst, const int ne, co
{ {
dpct::has_capability_or_fail(stream->get_device(), { sycl::aspect::fp16 }); dpct::has_capability_or_fail(stream->get_device(), { sycl::aspect::fp16 });
stream->parallel_for( sycl_parallel_for(
stream,
sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, SYCL_CPY_BLOCK_SIZE), sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, SYCL_CPY_BLOCK_SIZE),
sycl::range<3>(1, 1, SYCL_CPY_BLOCK_SIZE)), sycl::range<3>(1, 1, SYCL_CPY_BLOCK_SIZE)),
[=](sycl::nd_item<3> item_ct1) { [=](sycl::nd_item<3> item_ct1) {
@ -449,7 +451,8 @@ static void ggml_cpy_f32_f16_sycl(const char * cx, char * cdst, const int ne, co
{ {
dpct::has_capability_or_fail(stream->get_device(), { sycl::aspect::fp16 }); dpct::has_capability_or_fail(stream->get_device(), { sycl::aspect::fp16 });
stream->parallel_for( sycl_parallel_for(
stream,
sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, SYCL_CPY_BLOCK_SIZE), sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, SYCL_CPY_BLOCK_SIZE),
sycl::range<3>(1, 1, SYCL_CPY_BLOCK_SIZE)), sycl::range<3>(1, 1, SYCL_CPY_BLOCK_SIZE)),
[=](sycl::nd_item<3> item_ct1) { [=](sycl::nd_item<3> item_ct1) {
@ -465,7 +468,7 @@ static void ggml_cpy_f32_q8_0_sycl(const char * cx, char * cdst, const int ne, c
const int nb12, const int nb13, queue_ptr stream) { const int nb12, const int nb13, queue_ptr stream) {
GGML_ASSERT(ne % QK8_0 == 0); GGML_ASSERT(ne % QK8_0 == 0);
const int num_blocks = ne / QK8_0; const int num_blocks = ne / QK8_0;
stream->parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks), sycl::range<3>(1, 1, 1)), sycl_parallel_for(stream, sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks), sycl::range<3>(1, 1, 1)),
[=](sycl::nd_item<3> item_ct1) { [=](sycl::nd_item<3> item_ct1) {
cpy_f32_q<cpy_blck_f32_q8_0, QK8_0>(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, cpy_f32_q<cpy_blck_f32_q8_0, QK8_0>(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03,
ne10, ne11, ne12, nb10, nb11, nb12, nb13, item_ct1); ne10, ne11, ne12, nb10, nb11, nb12, nb13, item_ct1);
@ -477,7 +480,7 @@ static void ggml_cpy_q8_0_f32_sycl(const char * cx, char * cdst, const int ne, c
const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11,
const int nb12, const int nb13, queue_ptr stream) { const int nb12, const int nb13, queue_ptr stream) {
const int num_blocks = ne; const int num_blocks = ne;
stream->parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks), sycl::range<3>(1, 1, 1)), sycl_parallel_for(stream, sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks), sycl::range<3>(1, 1, 1)),
[=](sycl::nd_item<3> item_ct1) { [=](sycl::nd_item<3> item_ct1) {
cpy_q_f32<cpy_blck_q8_0_f32, QK8_0>(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, cpy_q_f32<cpy_blck_q8_0_f32, QK8_0>(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03,
ne10, ne11, ne12, nb10, nb11, nb12, nb13, item_ct1); ne10, ne11, ne12, nb10, nb11, nb12, nb13, item_ct1);
@ -490,7 +493,7 @@ static void ggml_cpy_f32_q4_0_sycl(const char * cx, char * cdst, const int ne, c
const int nb12, const int nb13, queue_ptr stream) { const int nb12, const int nb13, queue_ptr stream) {
GGML_ASSERT(ne % QK4_0 == 0); GGML_ASSERT(ne % QK4_0 == 0);
const int num_blocks = ne / QK4_0; const int num_blocks = ne / QK4_0;
stream->parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks), sycl::range<3>(1, 1, 1)), sycl_parallel_for(stream, sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks), sycl::range<3>(1, 1, 1)),
[=](sycl::nd_item<3> item_ct1) { [=](sycl::nd_item<3> item_ct1) {
cpy_f32_q<cpy_blck_f32_q4_0, QK4_0>(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, cpy_f32_q<cpy_blck_f32_q4_0, QK4_0>(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03,
ne10, ne11, ne12, nb10, nb11, nb12, nb13, item_ct1); ne10, ne11, ne12, nb10, nb11, nb12, nb13, item_ct1);
@ -502,8 +505,9 @@ static void ggml_cpy_q4_0_f32_sycl(const char * cx, char * cdst, const int ne, c
const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11,
const int nb12, const int nb13, queue_ptr stream) { const int nb12, const int nb13, queue_ptr stream) {
const int num_blocks = ne; const int num_blocks = ne;
stream->parallel_for( sycl_parallel_for(
sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks), sycl::range<3>(1, 1, 1)), [=](sycl::nd_item<3> item_ct1) { stream, sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks), sycl::range<3>(1, 1, 1)),
[=](sycl::nd_item<3> item_ct1) {
cpy_q_f32<cpy_blck_q_f32<dequantize_q4_0, QK4_0>, QK4_0>(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, cpy_q_f32<cpy_blck_q_f32<dequantize_q4_0, QK4_0>, QK4_0>(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02,
nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13,
item_ct1); item_ct1);
@ -516,7 +520,7 @@ static void ggml_cpy_f32_q4_1_sycl(const char * cx, char * cdst, const int ne, c
const int nb12, const int nb13, queue_ptr stream) { const int nb12, const int nb13, queue_ptr stream) {
GGML_ASSERT(ne % QK4_1 == 0); GGML_ASSERT(ne % QK4_1 == 0);
const int num_blocks = ne / QK4_1; const int num_blocks = ne / QK4_1;
stream->parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks), sycl::range<3>(1, 1, 1)), sycl_parallel_for(stream, sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks), sycl::range<3>(1, 1, 1)),
[=](sycl::nd_item<3> item_ct1) { [=](sycl::nd_item<3> item_ct1) {
cpy_f32_q<cpy_blck_f32_q4_1, QK4_1>(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, cpy_f32_q<cpy_blck_f32_q4_1, QK4_1>(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03,
ne10, ne11, ne12, nb10, nb11, nb12, nb13, item_ct1); ne10, ne11, ne12, nb10, nb11, nb12, nb13, item_ct1);
@ -528,8 +532,9 @@ static void ggml_cpy_q4_1_f32_sycl(const char * cx, char * cdst, const int ne, c
const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11,
const int nb12, const int nb13, queue_ptr stream) { const int nb12, const int nb13, queue_ptr stream) {
const int num_blocks = ne; const int num_blocks = ne;
stream->parallel_for( sycl_parallel_for(
sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks), sycl::range<3>(1, 1, 1)), [=](sycl::nd_item<3> item_ct1) { stream, sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks), sycl::range<3>(1, 1, 1)),
[=](sycl::nd_item<3> item_ct1) {
cpy_q_f32<cpy_blck_q_f32<dequantize_q4_1, QK4_1>, QK4_1>(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, cpy_q_f32<cpy_blck_q_f32<dequantize_q4_1, QK4_1>, QK4_1>(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02,
nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13,
item_ct1); item_ct1);
@ -542,7 +547,7 @@ static void ggml_cpy_f32_q5_0_sycl(const char * cx, char * cdst, const int ne, c
const int nb12, const int nb13, queue_ptr stream) { const int nb12, const int nb13, queue_ptr stream) {
GGML_ASSERT(ne % QK5_0 == 0); GGML_ASSERT(ne % QK5_0 == 0);
const int num_blocks = ne / QK5_0; const int num_blocks = ne / QK5_0;
stream->parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks), sycl::range<3>(1, 1, 1)), sycl_parallel_for(stream, sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks), sycl::range<3>(1, 1, 1)),
[=](sycl::nd_item<3> item_ct1) { [=](sycl::nd_item<3> item_ct1) {
cpy_f32_q<cpy_blck_f32_q5_0, QK5_0>(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, cpy_f32_q<cpy_blck_f32_q5_0, QK5_0>(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03,
ne10, ne11, ne12, nb10, nb11, nb12, nb13, item_ct1); ne10, ne11, ne12, nb10, nb11, nb12, nb13, item_ct1);
@ -554,8 +559,9 @@ static void ggml_cpy_q5_0_f32_sycl(const char * cx, char * cdst, const int ne, c
const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11,
const int nb12, const int nb13, queue_ptr stream) { const int nb12, const int nb13, queue_ptr stream) {
const int num_blocks = ne; const int num_blocks = ne;
stream->parallel_for( sycl_parallel_for(
sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks), sycl::range<3>(1, 1, 1)), [=](sycl::nd_item<3> item_ct1) { stream, sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks), sycl::range<3>(1, 1, 1)),
[=](sycl::nd_item<3> item_ct1) {
cpy_q_f32<cpy_blck_q_f32<dequantize_q5_0, QK5_0>, QK5_0>(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, cpy_q_f32<cpy_blck_q_f32<dequantize_q5_0, QK5_0>, QK5_0>(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02,
nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13,
item_ct1); item_ct1);
@ -568,7 +574,7 @@ static void ggml_cpy_f32_q5_1_sycl(const char * cx, char * cdst, const int ne, c
const int nb12, const int nb13, queue_ptr stream) { const int nb12, const int nb13, queue_ptr stream) {
GGML_ASSERT(ne % QK5_1 == 0); GGML_ASSERT(ne % QK5_1 == 0);
const int num_blocks = ne / QK5_1; const int num_blocks = ne / QK5_1;
stream->parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks), sycl::range<3>(1, 1, 1)), sycl_parallel_for(stream, sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks), sycl::range<3>(1, 1, 1)),
[=](sycl::nd_item<3> item_ct1) { [=](sycl::nd_item<3> item_ct1) {
cpy_f32_q<cpy_blck_f32_q5_1, QK5_1>(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, cpy_f32_q<cpy_blck_f32_q5_1, QK5_1>(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03,
ne10, ne11, ne12, nb10, nb11, nb12, nb13, item_ct1); ne10, ne11, ne12, nb10, nb11, nb12, nb13, item_ct1);
@ -580,8 +586,9 @@ static void ggml_cpy_q5_1_f32_sycl(const char * cx, char * cdst, const int ne, c
const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11,
const int nb12, const int nb13, queue_ptr stream) { const int nb12, const int nb13, queue_ptr stream) {
const int num_blocks = ne; const int num_blocks = ne;
stream->parallel_for( sycl_parallel_for(
sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks), sycl::range<3>(1, 1, 1)), [=](sycl::nd_item<3> item_ct1) { stream, sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks), sycl::range<3>(1, 1, 1)),
[=](sycl::nd_item<3> item_ct1) {
cpy_q_f32<cpy_blck_q_f32<dequantize_q5_1, QK5_1>, QK5_1>(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, cpy_q_f32<cpy_blck_q_f32<dequantize_q5_1, QK5_1>, QK5_1>(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02,
nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13,
item_ct1); item_ct1);
@ -594,10 +601,10 @@ static void ggml_cpy_f32_iq4_nl_sycl(const char * cx, char * cdst, const int ne,
const int nb12, const int nb13, queue_ptr stream) { const int nb12, const int nb13, queue_ptr stream) {
GGML_ASSERT(ne % QK4_NL == 0); GGML_ASSERT(ne % QK4_NL == 0);
const int num_blocks = ne / QK4_NL; const int num_blocks = ne / QK4_NL;
stream->parallel_for( sycl_parallel_for(stream, sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks), sycl::range<3>(1, 1, 1)),
sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks), sycl::range<3>(1, 1, 1)), [=](sycl::nd_item<3> item_ct1) { [=](sycl::nd_item<3> item_ct1) {
cpy_f32_q<cpy_blck_f32_iq4_nl, QK4_NL>(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, cpy_f32_q<cpy_blck_f32_iq4_nl, QK4_NL>(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03,
ne12, nb10, nb11, nb12, nb13, item_ct1); ne10, ne11, ne12, nb10, nb11, nb12, nb13, item_ct1);
}); });
} }
@ -609,7 +616,8 @@ static void ggml_cpy_f16_f16_sycl(const char * cx, char * cdst, const int ne, co
{ {
dpct::has_capability_or_fail(stream->get_device(), { sycl::aspect::fp16 }); dpct::has_capability_or_fail(stream->get_device(), { sycl::aspect::fp16 });
stream->parallel_for( sycl_parallel_for(
stream,
sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, SYCL_CPY_BLOCK_SIZE), sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, SYCL_CPY_BLOCK_SIZE),
sycl::range<3>(1, 1, SYCL_CPY_BLOCK_SIZE)), sycl::range<3>(1, 1, SYCL_CPY_BLOCK_SIZE)),
[=](sycl::nd_item<3> item_ct1) { [=](sycl::nd_item<3> item_ct1) {
@ -628,7 +636,8 @@ static void ggml_cpy_i16_i16_sycl(const char * cx, char * cdst, const int ne, co
// dpct::has_capability_or_fail(stream->get_device(), // dpct::has_capability_or_fail(stream->get_device(),
// {sycl::aspect::fp16}); // {sycl::aspect::fp16});
stream->parallel_for( sycl_parallel_for(
stream,
sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, SYCL_CPY_BLOCK_SIZE), sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, SYCL_CPY_BLOCK_SIZE),
sycl::range<3>(1, 1, SYCL_CPY_BLOCK_SIZE)), sycl::range<3>(1, 1, SYCL_CPY_BLOCK_SIZE)),
[=](sycl::nd_item<3> item_ct1) { [=](sycl::nd_item<3> item_ct1) {
@ -647,7 +656,8 @@ static void ggml_cpy_i32_i32_sycl(const char * cx, char * cdst, const int ne, co
// dpct::has_capability_or_fail(stream->get_device(), // dpct::has_capability_or_fail(stream->get_device(),
// {sycl::aspect::fp16}); // {sycl::aspect::fp16});
stream->parallel_for( sycl_parallel_for(
stream,
sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, SYCL_CPY_BLOCK_SIZE), sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, SYCL_CPY_BLOCK_SIZE),
sycl::range<3>(1, 1, SYCL_CPY_BLOCK_SIZE)), sycl::range<3>(1, 1, SYCL_CPY_BLOCK_SIZE)),
[=](sycl::nd_item<3> item_ct1) { [=](sycl::nd_item<3> item_ct1) {
@ -662,10 +672,12 @@ static void ggml_cpy_q8_0_q8_0(const char * cx, char * cdst, const int ne, const
const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11,
const int nb12, const int nb13, queue_ptr stream) { const int nb12, const int nb13, queue_ptr stream) {
const int num_blocks = ceil_div(ne, SYCL_CPY_BLOCK_SIZE); const int num_blocks = ceil_div(ne, SYCL_CPY_BLOCK_SIZE);
stream->parallel_for( sycl_parallel_for(stream,
sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, SYCL_CPY_BLOCK_SIZE), sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, SYCL_CPY_BLOCK_SIZE),
sycl::range<3>(1, 1, SYCL_CPY_BLOCK_SIZE)), [=](sycl::nd_item<3> item_ct1) { sycl::range<3>(1, 1, SYCL_CPY_BLOCK_SIZE)),
cpy_q_q<block_q8_0, QK8_0>(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, item_ct1); [=](sycl::nd_item<3> item_ct1) {
cpy_q_q<block_q8_0, QK8_0>(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11,
ne12, nb10, nb11, nb12, nb13, item_ct1);
}); });
} }
@ -675,10 +687,12 @@ static void ggml_cpy_q5_0_q5_0(const char * cx, char * cdst, const int ne, const
const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11,
const int nb12, const int nb13, queue_ptr stream) { const int nb12, const int nb13, queue_ptr stream) {
const int num_blocks = ceil_div(ne, SYCL_CPY_BLOCK_SIZE); const int num_blocks = ceil_div(ne, SYCL_CPY_BLOCK_SIZE);
stream->parallel_for( sycl_parallel_for(stream,
sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, SYCL_CPY_BLOCK_SIZE), sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, SYCL_CPY_BLOCK_SIZE),
sycl::range<3>(1, 1, SYCL_CPY_BLOCK_SIZE)), [=](sycl::nd_item<3> item_ct1) { sycl::range<3>(1, 1, SYCL_CPY_BLOCK_SIZE)),
cpy_q_q<block_q5_0, QK5_0>(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, item_ct1); [=](sycl::nd_item<3> item_ct1) {
cpy_q_q<block_q5_0, QK5_0>(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11,
ne12, nb10, nb11, nb12, nb13, item_ct1);
}); });
} }
@ -689,10 +703,12 @@ static void ggml_cpy_q5_1_q5_1(const char * cx, char * cdst, const int ne, const
const int nb12, const int nb13, queue_ptr stream) { const int nb12, const int nb13, queue_ptr stream) {
const int num_blocks = ceil_div(ne, SYCL_CPY_BLOCK_SIZE); const int num_blocks = ceil_div(ne, SYCL_CPY_BLOCK_SIZE);
stream->parallel_for( sycl_parallel_for(stream,
sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, SYCL_CPY_BLOCK_SIZE), sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, SYCL_CPY_BLOCK_SIZE),
sycl::range<3>(1, 1, SYCL_CPY_BLOCK_SIZE)), [=](sycl::nd_item<3> item_ct1) { sycl::range<3>(1, 1, SYCL_CPY_BLOCK_SIZE)),
cpy_q_q<block_q5_1, QK5_1>(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, item_ct1); [=](sycl::nd_item<3> item_ct1) {
cpy_q_q<block_q5_1, QK5_1>(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11,
ne12, nb10, nb11, nb12, nb13, item_ct1);
}); });
} }
@ -702,9 +718,12 @@ static void ggml_cpy_q4_0_q4_0(const char * cx, char * cdst, const int ne, const
const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11,
const int nb12, const int nb13, queue_ptr stream) { const int nb12, const int nb13, queue_ptr stream) {
const int num_blocks = ceil_div(ne, SYCL_CPY_BLOCK_SIZE); const int num_blocks = ceil_div(ne, SYCL_CPY_BLOCK_SIZE);
stream->parallel_for( sycl_parallel_for(stream,
sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, SYCL_CPY_BLOCK_SIZE), sycl::range<3>(1, 1, SYCL_CPY_BLOCK_SIZE)), [=](sycl::nd_item<3> item_ct1) { sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, SYCL_CPY_BLOCK_SIZE),
cpy_q_q<block_q4_0, QK4_0>(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, item_ct1); sycl::range<3>(1, 1, SYCL_CPY_BLOCK_SIZE)),
[=](sycl::nd_item<3> item_ct1) {
cpy_q_q<block_q4_0, QK4_0>(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11,
ne12, nb10, nb11, nb12, nb13, item_ct1);
}); });
} }
@ -715,9 +734,12 @@ static void ggml_cpy_q4_1_q4_1(const char * cx, char * cdst, const int ne, const
const int nb12, const int nb13, queue_ptr stream) { const int nb12, const int nb13, queue_ptr stream) {
const int num_blocks = ceil_div(ne, SYCL_CPY_BLOCK_SIZE); const int num_blocks = ceil_div(ne, SYCL_CPY_BLOCK_SIZE);
stream->parallel_for( sycl_parallel_for(stream,
sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, SYCL_CPY_BLOCK_SIZE), sycl::range<3>(1, 1, SYCL_CPY_BLOCK_SIZE)), [=](sycl::nd_item<3> item_ct1) { sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, SYCL_CPY_BLOCK_SIZE),
cpy_q_q<block_q4_1, QK4_1>(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, item_ct1); sycl::range<3>(1, 1, SYCL_CPY_BLOCK_SIZE)),
[=](sycl::nd_item<3> item_ct1) {
cpy_q_q<block_q4_1, QK4_1>(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11,
ne12, nb10, nb11, nb12, nb13, item_ct1);
}); });
} }
View File
@ -208,11 +208,9 @@ static void convert_mul_mat_vec_f16_sycl(const void *vx, const dfloat *y,
dpct::has_capability_or_fail(stream->get_device(), dpct::has_capability_or_fail(stream->get_device(),
{sycl::aspect::fp16}); {sycl::aspect::fp16});
stream->parallel_for( sycl_parallel_for(stream, sycl::nd_range<3>(block_nums * block_dims, block_dims),
sycl::nd_range<3>(block_nums * block_dims, block_dims),
[=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] { [=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
dequantize_mul_mat_vec<1, 1, convert_f16>(vx, y, dst, ncols, dequantize_mul_mat_vec<1, 1, convert_f16>(vx, y, dst, ncols, nrows, item_ct1);
nrows, item_ct1);
}); });
} }
} }
@ -877,11 +875,10 @@ static void dequantize_mul_mat_vec_q4_0_sycl_reorder(const void *vx, const dfloa
dpct::has_capability_or_fail(stream->get_device(), dpct::has_capability_or_fail(stream->get_device(),
{sycl::aspect::fp16}); {sycl::aspect::fp16});
stream->parallel_for( sycl_parallel_for(stream, sycl::nd_range<3>(block_nums * block_dims, block_dims),
sycl::nd_range<3>(block_nums * block_dims, block_dims),
[=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] { [=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
dequantize_mul_mat_vec_reorder<QK4_0, QR4_0, dequantize_q4_0_reorder>( dequantize_mul_mat_vec_reorder<QK4_0, QR4_0, dequantize_q4_0_reorder>(vx, y, dst, ncols,
vx, y, dst, ncols, nrows, item_ct1); nrows, item_ct1);
}); });
} }
} }
@ -900,11 +897,9 @@ static void dequantize_mul_mat_vec_q4_0_sycl(const void *vx, const dfloat *y,
dpct::has_capability_or_fail(stream->get_device(), dpct::has_capability_or_fail(stream->get_device(),
{sycl::aspect::fp16}); {sycl::aspect::fp16});
stream->parallel_for( sycl_parallel_for(stream, sycl::nd_range<3>(block_nums * block_dims, block_dims),
sycl::nd_range<3>(block_nums * block_dims, block_dims),
[=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] { [=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
dequantize_mul_mat_vec<QK4_0, QR4_0, dequantize_q4_0>( dequantize_mul_mat_vec<QK4_0, QR4_0, dequantize_q4_0>(vx, y, dst, ncols, nrows, item_ct1);
vx, y, dst, ncols, nrows, item_ct1);
}); });
} }
} }
@ -921,11 +916,9 @@ static void dequantize_mul_mat_vec_q4_1_sycl(const void *vx, const dfloat *y,
dpct::has_capability_or_fail(stream->get_device(), dpct::has_capability_or_fail(stream->get_device(),
{sycl::aspect::fp16}); {sycl::aspect::fp16});
stream->parallel_for( sycl_parallel_for(stream, sycl::nd_range<3>(block_nums * block_dims, block_dims),
sycl::nd_range<3>(block_nums * block_dims, block_dims),
[=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] { [=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
dequantize_mul_mat_vec<QK4_1, QR4_1, dequantize_q4_1>( dequantize_mul_mat_vec<QK4_1, QR4_1, dequantize_q4_1>(vx, y, dst, ncols, nrows, item_ct1);
vx, y, dst, ncols, nrows, item_ct1);
}); });
} }
} }
@ -942,11 +935,9 @@ static void dequantize_mul_mat_vec_q5_0_sycl(const void *vx, const dfloat *y,
dpct::has_capability_or_fail(stream->get_device(), dpct::has_capability_or_fail(stream->get_device(),
{sycl::aspect::fp16}); {sycl::aspect::fp16});
stream->parallel_for( sycl_parallel_for(stream, sycl::nd_range<3>(block_nums * block_dims, block_dims),
sycl::nd_range<3>(block_nums * block_dims, block_dims),
[=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] { [=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
dequantize_mul_mat_vec<QK5_0, QR5_0, dequantize_q5_0>( dequantize_mul_mat_vec<QK5_0, QR5_0, dequantize_q5_0>(vx, y, dst, ncols, nrows, item_ct1);
vx, y, dst, ncols, nrows, item_ct1);
}); });
} }
} }
@ -963,11 +954,9 @@ static void dequantize_mul_mat_vec_q5_1_sycl(const void *vx, const dfloat *y,
dpct::has_capability_or_fail(stream->get_device(), dpct::has_capability_or_fail(stream->get_device(),
{sycl::aspect::fp16}); {sycl::aspect::fp16});
stream->parallel_for( sycl_parallel_for(stream, sycl::nd_range<3>(block_nums * block_dims, block_dims),
sycl::nd_range<3>(block_nums * block_dims, block_dims),
[=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] { [=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
dequantize_mul_mat_vec<QK5_1, QR5_1, dequantize_q5_1>( dequantize_mul_mat_vec<QK5_1, QR5_1, dequantize_q5_1>(vx, y, dst, ncols, nrows, item_ct1);
vx, y, dst, ncols, nrows, item_ct1);
}); });
} }
} }
@ -984,11 +973,9 @@ static void dequantize_mul_mat_vec_q8_0_sycl(const void *vx, const dfloat *y,
dpct::has_capability_or_fail(stream->get_device(), dpct::has_capability_or_fail(stream->get_device(),
{sycl::aspect::fp16}); {sycl::aspect::fp16});
stream->parallel_for( sycl_parallel_for(stream, sycl::nd_range<3>(block_nums * block_dims, block_dims),
sycl::nd_range<3>(block_nums * block_dims, block_dims),
[=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] { [=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
dequantize_mul_mat_vec<QK8_0, QR8_0, dequantize_q8_0>( dequantize_mul_mat_vec<QK8_0, QR8_0, dequantize_q8_0>(vx, y, dst, ncols, nrows, item_ct1);
vx, y, dst, ncols, nrows, item_ct1);
}); });
} }
} }
@ -1002,8 +989,7 @@ static void dequantize_mul_mat_vec_q2_K_sycl(const void *vx, const float *y,
const int block_num_y = (nrows + ny - 1) / ny; const int block_num_y = (nrows + ny - 1) / ny;
const sycl::range<3> block_nums(1, 1, block_num_y); const sycl::range<3> block_nums(1, 1, block_num_y);
const sycl::range<3> block_dims(1, ny, QK_WARP_SIZE); const sycl::range<3> block_dims(1, ny, QK_WARP_SIZE);
stream->parallel_for( sycl_parallel_for(stream, sycl::nd_range<3>(block_nums * block_dims, block_dims),
sycl::nd_range<3>(block_nums * block_dims, block_dims),
[=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(QK_WARP_SIZE)]] { [=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(QK_WARP_SIZE)]] {
dequantize_mul_mat_vec_q2_k(vx, y, dst, ncols, nrows, item_ct1); dequantize_mul_mat_vec_q2_k(vx, y, dst, ncols, nrows, item_ct1);
}); });
@ -1018,8 +1004,7 @@ static void dequantize_mul_mat_vec_q3_K_sycl(const void *vx, const float *y,
const int block_num_y = (nrows + ny - 1) / ny; const int block_num_y = (nrows + ny - 1) / ny;
const sycl::range<3> block_nums(1, 1, block_num_y); const sycl::range<3> block_nums(1, 1, block_num_y);
const sycl::range<3> block_dims(1, ny, QK_WARP_SIZE); const sycl::range<3> block_dims(1, ny, QK_WARP_SIZE);
stream->parallel_for( sycl_parallel_for(stream, sycl::nd_range<3>(block_nums * block_dims, block_dims),
sycl::nd_range<3>(block_nums * block_dims, block_dims),
[=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(QK_WARP_SIZE)]] { [=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(QK_WARP_SIZE)]] {
dequantize_mul_mat_vec_q3_k(vx, y, dst, ncols, nrows, item_ct1); dequantize_mul_mat_vec_q3_k(vx, y, dst, ncols, nrows, item_ct1);
}); });
@ -1034,8 +1019,7 @@ static void dequantize_mul_mat_vec_q4_K_sycl(const void *vx, const float *y,
const int block_num_y = (nrows + ny - 1) / ny; const int block_num_y = (nrows + ny - 1) / ny;
const sycl::range<3> block_nums(1, 1, block_num_y); const sycl::range<3> block_nums(1, 1, block_num_y);
const sycl::range<3> block_dims(1, ny, QK_WARP_SIZE); const sycl::range<3> block_dims(1, ny, QK_WARP_SIZE);
stream->parallel_for( sycl_parallel_for(stream, sycl::nd_range<3>(block_nums * block_dims, block_dims),
sycl::nd_range<3>(block_nums * block_dims, block_dims),
[=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(QK_WARP_SIZE)]] { [=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(QK_WARP_SIZE)]] {
dequantize_mul_mat_vec_q4_k(vx, y, dst, ncols, nrows, item_ct1); dequantize_mul_mat_vec_q4_k(vx, y, dst, ncols, nrows, item_ct1);
}); });
@ -1047,8 +1031,7 @@ static void dequantize_mul_mat_vec_q5_K_sycl(const void *vx, const float *y,
dpct::queue_ptr stream) { dpct::queue_ptr stream) {
GGML_ASSERT(ncols % QK_K == 0); GGML_ASSERT(ncols % QK_K == 0);
const sycl::range<3> block_dims(1, 1, QK_WARP_SIZE); const sycl::range<3> block_dims(1, 1, QK_WARP_SIZE);
stream->parallel_for( sycl_parallel_for(stream, sycl::nd_range<3>(sycl::range<3>(1, 1, nrows) * block_dims, block_dims),
sycl::nd_range<3>(sycl::range<3>(1, 1, nrows) * block_dims, block_dims),
[=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(QK_WARP_SIZE)]] { [=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(QK_WARP_SIZE)]] {
dequantize_mul_mat_vec_q5_k(vx, y, dst, ncols, item_ct1); dequantize_mul_mat_vec_q5_k(vx, y, dst, ncols, item_ct1);
}); });
@ -1063,8 +1046,7 @@ static void dequantize_mul_mat_vec_q6_K_sycl(const void *vx, const float *y,
const int block_num_y = (nrows + ny - 1) / ny; const int block_num_y = (nrows + ny - 1) / ny;
const sycl::range<3> block_nums(1, 1, block_num_y); const sycl::range<3> block_nums(1, 1, block_num_y);
const sycl::range<3> block_dims(1, ny, QK_WARP_SIZE); const sycl::range<3> block_dims(1, ny, QK_WARP_SIZE);
stream->parallel_for( sycl_parallel_for(stream, sycl::nd_range<3>(block_nums * block_dims, block_dims),
sycl::nd_range<3>(block_nums * block_dims, block_dims),
[=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(QK_WARP_SIZE)]] { [=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(QK_WARP_SIZE)]] {
dequantize_mul_mat_vec_q6_k(vx, y, dst, ncols, nrows, item_ct1); dequantize_mul_mat_vec_q6_k(vx, y, dst, ncols, nrows, item_ct1);
}); });


@ -13,10 +13,10 @@
#ifndef GGML_SYCL_DPCT_HELPER_HPP #ifndef GGML_SYCL_DPCT_HELPER_HPP
#define GGML_SYCL_DPCT_HELPER_HPP #define GGML_SYCL_DPCT_HELPER_HPP
#include <map>
#include <sycl/sycl.hpp> #include <sycl/sycl.hpp>
#include <sycl/half_type.hpp> #include <sycl/half_type.hpp>
#include <syclcompat/math.hpp> #include <syclcompat/math.hpp>
#include <map>
#ifdef GGML_SYCL_USE_INTEL_ONEMKL #ifdef GGML_SYCL_USE_INTEL_ONEMKL
#include <oneapi/mkl.hpp> #include <oneapi/mkl.hpp>
@ -118,6 +118,36 @@ inline auto get_onemath_backend(sycl::queue& queue)
#endif #endif
} }
#ifdef SYCL_EXT_ONEAPI_ENQUEUE_FUNCTIONS
namespace syclex = sycl::ext::oneapi::experimental;
#endif
template <int NR, typename Func>
__dpct_inline__ void sycl_parallel_for(sycl::handler & cgh, sycl::nd_range<NR> nd_range, Func && func) {
#ifdef SYCL_EXT_ONEAPI_ENQUEUE_FUNCTIONS
syclex::nd_launch(cgh, nd_range, func);
#else
cgh.parallel_for(nd_range, func);
#endif
}
template <int NR, typename Func>
__dpct_inline__ void sycl_parallel_for(sycl::queue * q, sycl::nd_range<NR> nd_range, Func && func) {
#ifdef SYCL_EXT_ONEAPI_ENQUEUE_FUNCTIONS
syclex::nd_launch(*q, nd_range, func);
#else
q->parallel_for(nd_range, func);
#endif
}
template <typename Func> __dpct_inline__ void sycl_launch(sycl::queue * stream, Func && func) {
#ifdef SYCL_EXT_ONEAPI_ENQUEUE_FUNCTIONS
syclex::submit(*stream, func);
#else
stream->submit(func);
#endif
}
namespace dpct namespace dpct
{ {
typedef sycl::queue *queue_ptr; typedef sycl::queue *queue_ptr;
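For reference, a minimal sketch (not part of this diff) of the call-site migration applied throughout the SYCL backend below; the kernel name scale_f32 and its arguments are illustrative, and the sketch assumes the sycl_parallel_for wrapper declared above is in scope:

    #include <sycl/sycl.hpp>

    // hypothetical element-wise kernel, used only to show the launch pattern
    static void scale_f32(float * x, float v, int k, const sycl::nd_item<3> & item_ct1) {
        const int i = item_ct1.get_global_id(2);
        if (i < k) {
            x[i] *= v;
        }
    }

    static void scale_f32_sycl(float * x, float v, int k, sycl::queue * stream) {
        const int block_size = 256;  // illustrative block size
        const int num_blocks = (k + block_size - 1) / block_size;
        // before: stream->parallel_for(nd_range, lambda);
        // after : sycl_parallel_for(stream, nd_range, lambda);
        // the wrapper forwards to syclex::nd_launch when
        // SYCL_EXT_ONEAPI_ENQUEUE_FUNCTIONS is defined and to parallel_for otherwise.
        sycl_parallel_for(stream,
                          sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, block_size),
                                            sycl::range<3>(1, 1, block_size)),
                          [=](sycl::nd_item<3> item_ct1) { scale_f32(x, v, k, item_ct1); });
    }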


@ -329,13 +329,11 @@ static void acc_f32_sycl(const float *x, const float *y, float *dst,
const int ne12, const int nb1, const int nb2, const int ne12, const int nb1, const int nb2,
const int offset, queue_ptr stream) { const int offset, queue_ptr stream) {
int num_blocks = (n_elements + SYCL_ACC_BLOCK_SIZE - 1) / SYCL_ACC_BLOCK_SIZE; int num_blocks = (n_elements + SYCL_ACC_BLOCK_SIZE - 1) / SYCL_ACC_BLOCK_SIZE;
stream->parallel_for( sycl_parallel_for(stream,
sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, SYCL_ACC_BLOCK_SIZE),
sycl::range<3>(1, 1, SYCL_ACC_BLOCK_SIZE),
sycl::range<3>(1, 1, SYCL_ACC_BLOCK_SIZE)), sycl::range<3>(1, 1, SYCL_ACC_BLOCK_SIZE)),
[=](sycl::nd_item<3> item_ct1) { [=](sycl::nd_item<3> item_ct1) {
acc_f32(x, y, dst, n_elements, ne10, ne11, ne12, nb1, nb2, offset, acc_f32(x, y, dst, n_elements, ne10, ne11, ne12, nb1, nb2, offset, item_ct1);
item_ct1);
}); });
} }
@ -343,46 +341,39 @@ template<typename T>
static void gelu_sycl(const T *x, T *dst, const int k, static void gelu_sycl(const T *x, T *dst, const int k,
queue_ptr stream) { queue_ptr stream) {
const int num_blocks = (k + SYCL_GELU_BLOCK_SIZE - 1) / SYCL_GELU_BLOCK_SIZE; const int num_blocks = (k + SYCL_GELU_BLOCK_SIZE - 1) / SYCL_GELU_BLOCK_SIZE;
stream->parallel_for( sycl_parallel_for(stream,
sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, SYCL_GELU_BLOCK_SIZE),
sycl::range<3>(1, 1, SYCL_GELU_BLOCK_SIZE),
sycl::range<3>(1, 1, SYCL_GELU_BLOCK_SIZE)), sycl::range<3>(1, 1, SYCL_GELU_BLOCK_SIZE)),
[=](sycl::nd_item<3> item_ct1) { [=](sycl::nd_item<3> item_ct1) { gelu(x, dst, k, item_ct1); });
gelu(x, dst, k, item_ct1);
});
} }
template<typename T> template<typename T>
static void silu_sycl(const T *x, T *dst, const int k, static void silu_sycl(const T *x, T *dst, const int k,
queue_ptr stream) { queue_ptr stream) {
const int num_blocks = (k + SYCL_SILU_BLOCK_SIZE - 1) / SYCL_SILU_BLOCK_SIZE; const int num_blocks = (k + SYCL_SILU_BLOCK_SIZE - 1) / SYCL_SILU_BLOCK_SIZE;
stream->parallel_for( sycl_parallel_for(stream,
sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, SYCL_SILU_BLOCK_SIZE),
sycl::range<3>(1, 1, SYCL_SILU_BLOCK_SIZE),
sycl::range<3>(1, 1, SYCL_SILU_BLOCK_SIZE)), sycl::range<3>(1, 1, SYCL_SILU_BLOCK_SIZE)),
[=](sycl::nd_item<3> item_ct1) { [=](sycl::nd_item<3> item_ct1) { silu(x, dst, k, item_ct1); });
silu(x, dst, k, item_ct1);
});
} }
template<typename T> template<typename T>
static void sgn_sycl(const T * x, T * dst, const int k, queue_ptr stream) { static void sgn_sycl(const T * x, T * dst, const int k, queue_ptr stream) {
// hard code for now // hard code for now
const int num_blocks = ceil_div(k, 256); const int num_blocks = ceil_div(k, 256);
stream->parallel_for( sycl_parallel_for(
sycl::nd_range<3>((sycl::range<3>(1, 1, num_blocks) * sycl::range(1, 1, 256)), sycl::range(1, 1, 256)), [=](sycl::nd_item<3> item_ct1) { stream, sycl::nd_range<3>((sycl::range<3>(1, 1, num_blocks) * sycl::range(1, 1, 256)), sycl::range(1, 1, 256)),
sgn(x, dst, k, item_ct1); [=](sycl::nd_item<3> item_ct1) { sgn(x, dst, k, item_ct1); });
});
} }
template<typename T> template<typename T>
static void abs_sycl(const T * x, T * dst, const int k, queue_ptr stream) { static void abs_sycl(const T * x, T * dst, const int k, queue_ptr stream) {
// hard code for now // hard code for now
const int num_blocks = ceil_div(k, 256); const int num_blocks = ceil_div(k, 256);
stream->parallel_for( sycl_parallel_for(
sycl::nd_range<3>((sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, 256)), sycl::range<3>(1, 1, 256)), [=](sycl::nd_item<3> item_ct1) { stream,
abs_op(x, dst, k, item_ct1); sycl::nd_range<3>((sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, 256)), sycl::range<3>(1, 1, 256)),
}); [=](sycl::nd_item<3> item_ct1) { abs_op(x, dst, k, item_ct1); });
} }
@ -390,23 +381,20 @@ template<typename T>
static void elu_sycl(const T * x, T * dst, const int k, queue_ptr stream) { static void elu_sycl(const T * x, T * dst, const int k, queue_ptr stream) {
// hard code for now // hard code for now
const int num_blocks = ceil_div(k, 256); const int num_blocks = ceil_div(k, 256);
stream->parallel_for( sycl_parallel_for(
sycl::nd_range<3>((sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, 256)), sycl::range<3>(1, 1, 256)), [=](sycl::nd_item<3> item_ct1) { stream,
elu_op(x, dst, k, item_ct1); sycl::nd_range<3>((sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, 256)), sycl::range<3>(1, 1, 256)),
}); [=](sycl::nd_item<3> item_ct1) { elu_op(x, dst, k, item_ct1); });
} }
template<typename T> template<typename T>
static void gelu_quick_sycl(const T *x, T *dst, const int k, static void gelu_quick_sycl(const T *x, T *dst, const int k,
queue_ptr stream) { queue_ptr stream) {
const int num_blocks = (k + SYCL_GELU_BLOCK_SIZE - 1) / SYCL_GELU_BLOCK_SIZE; const int num_blocks = (k + SYCL_GELU_BLOCK_SIZE - 1) / SYCL_GELU_BLOCK_SIZE;
stream->parallel_for( sycl_parallel_for(stream,
sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, SYCL_GELU_BLOCK_SIZE),
sycl::range<3>(1, 1, SYCL_GELU_BLOCK_SIZE),
sycl::range<3>(1, 1, SYCL_GELU_BLOCK_SIZE)), sycl::range<3>(1, 1, SYCL_GELU_BLOCK_SIZE)),
[=](sycl::nd_item<3> item_ct1) { [=](sycl::nd_item<3> item_ct1) { gelu_quick(x, dst, k, item_ct1); });
gelu_quick(x, dst, k, item_ct1);
});
} }
@ -414,169 +402,133 @@ template<typename T>
static void gelu_erf_sycl(const T *x, T *dst, const int k, static void gelu_erf_sycl(const T *x, T *dst, const int k,
queue_ptr stream) { queue_ptr stream) {
const int num_blocks = ceil_div(k, SYCL_GELU_BLOCK_SIZE); const int num_blocks = ceil_div(k, SYCL_GELU_BLOCK_SIZE);
stream->parallel_for( sycl_parallel_for(stream,
sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, SYCL_GELU_BLOCK_SIZE),
sycl::range<3>(1, 1, SYCL_GELU_BLOCK_SIZE),
sycl::range<3>(1, 1, SYCL_GELU_BLOCK_SIZE)), sycl::range<3>(1, 1, SYCL_GELU_BLOCK_SIZE)),
[=](sycl::nd_item<3> item_ct1) { [=](sycl::nd_item<3> item_ct1) { gelu_erf(x, dst, k, item_ct1); });
gelu_erf(x, dst, k, item_ct1);
});
} }
template<typename T> template<typename T>
static void tanh_sycl(const T *x, T *dst, const int k, static void tanh_sycl(const T *x, T *dst, const int k,
queue_ptr stream) { queue_ptr stream) {
const int num_blocks = (k + SYCL_TANH_BLOCK_SIZE - 1) / SYCL_TANH_BLOCK_SIZE; const int num_blocks = (k + SYCL_TANH_BLOCK_SIZE - 1) / SYCL_TANH_BLOCK_SIZE;
stream->parallel_for( sycl_parallel_for(stream,
sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, SYCL_TANH_BLOCK_SIZE),
sycl::range<3>(1, 1, SYCL_TANH_BLOCK_SIZE),
sycl::range<3>(1, 1, SYCL_TANH_BLOCK_SIZE)), sycl::range<3>(1, 1, SYCL_TANH_BLOCK_SIZE)),
[=](sycl::nd_item<3> item_ct1) { [=](sycl::nd_item<3> item_ct1) { tanh(x, dst, k, item_ct1); });
tanh(x, dst, k, item_ct1);
});
} }
template<typename T> template<typename T>
static void relu_sycl(const T *x, T *dst, const int k, static void relu_sycl(const T *x, T *dst, const int k,
queue_ptr stream) { queue_ptr stream) {
const int num_blocks = (k + SYCL_RELU_BLOCK_SIZE - 1) / SYCL_RELU_BLOCK_SIZE; const int num_blocks = (k + SYCL_RELU_BLOCK_SIZE - 1) / SYCL_RELU_BLOCK_SIZE;
stream->parallel_for( sycl_parallel_for(stream,
sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, SYCL_RELU_BLOCK_SIZE),
sycl::range<3>(1, 1, SYCL_RELU_BLOCK_SIZE),
sycl::range<3>(1, 1, SYCL_RELU_BLOCK_SIZE)), sycl::range<3>(1, 1, SYCL_RELU_BLOCK_SIZE)),
[=](sycl::nd_item<3> item_ct1) { [=](sycl::nd_item<3> item_ct1) { relu(x, dst, k, item_ct1); });
relu(x, dst, k, item_ct1);
});
} }
template<typename T> template<typename T>
static void hardsigmoid_sycl(const T *x, T *dst, const int k, static void hardsigmoid_sycl(const T *x, T *dst, const int k,
queue_ptr stream) { queue_ptr stream) {
const int num_blocks = (k + SYCL_HARDSIGMOID_BLOCK_SIZE - 1) / SYCL_HARDSIGMOID_BLOCK_SIZE; const int num_blocks = (k + SYCL_HARDSIGMOID_BLOCK_SIZE - 1) / SYCL_HARDSIGMOID_BLOCK_SIZE;
stream->parallel_for( sycl_parallel_for(
sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * stream,
sycl::range<3>(1, 1, SYCL_HARDSIGMOID_BLOCK_SIZE), sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, SYCL_HARDSIGMOID_BLOCK_SIZE),
sycl::range<3>(1, 1, SYCL_HARDSIGMOID_BLOCK_SIZE)), sycl::range<3>(1, 1, SYCL_HARDSIGMOID_BLOCK_SIZE)),
[=](sycl::nd_item<3> item_ct1) { [=](sycl::nd_item<3> item_ct1) { hardsigmoid(x, dst, k, item_ct1); });
hardsigmoid(x, dst, k, item_ct1);
});
} }
template<typename T> template<typename T>
static void hardswish_sycl(const T *x, T *dst, const int k, static void hardswish_sycl(const T *x, T *dst, const int k,
queue_ptr stream) { queue_ptr stream) {
const int num_blocks = (k + SYCL_HARDSWISH_BLOCK_SIZE - 1) / SYCL_HARDSWISH_BLOCK_SIZE; const int num_blocks = (k + SYCL_HARDSWISH_BLOCK_SIZE - 1) / SYCL_HARDSWISH_BLOCK_SIZE;
stream->parallel_for( sycl_parallel_for(
sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * stream,
sycl::range<3>(1, 1, SYCL_HARDSWISH_BLOCK_SIZE), sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, SYCL_HARDSWISH_BLOCK_SIZE),
sycl::range<3>(1, 1, SYCL_HARDSWISH_BLOCK_SIZE)), sycl::range<3>(1, 1, SYCL_HARDSWISH_BLOCK_SIZE)),
[=](sycl::nd_item<3> item_ct1) { [=](sycl::nd_item<3> item_ct1) { hardswish(x, dst, k, item_ct1); });
hardswish(x, dst, k, item_ct1);
});
} }
template<typename T> template<typename T>
static void exp_sycl(const T *x, T *dst, const int k, static void exp_sycl(const T *x, T *dst, const int k,
queue_ptr stream) { queue_ptr stream) {
const int num_blocks = (k + SYCL_EXP_BLOCK_SIZE - 1) / SYCL_EXP_BLOCK_SIZE; const int num_blocks = (k + SYCL_EXP_BLOCK_SIZE - 1) / SYCL_EXP_BLOCK_SIZE;
stream->parallel_for( sycl_parallel_for(stream,
sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, SYCL_EXP_BLOCK_SIZE),
sycl::range<3>(1, 1, SYCL_EXP_BLOCK_SIZE),
sycl::range<3>(1, 1, SYCL_EXP_BLOCK_SIZE)), sycl::range<3>(1, 1, SYCL_EXP_BLOCK_SIZE)),
[=](sycl::nd_item<3> item_ct1) { [=](sycl::nd_item<3> item_ct1) { exp(x, dst, k, item_ct1); });
exp(x, dst, k, item_ct1);
});
} }
template<typename T> template<typename T>
static void log_sycl(const T *x, T *dst, const int k, static void log_sycl(const T *x, T *dst, const int k,
queue_ptr stream) { queue_ptr stream) {
const int num_blocks = (k + SYCL_EXP_BLOCK_SIZE - 1) / SYCL_EXP_BLOCK_SIZE; const int num_blocks = (k + SYCL_EXP_BLOCK_SIZE - 1) / SYCL_EXP_BLOCK_SIZE;
stream->parallel_for( sycl_parallel_for(stream,
sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, SYCL_EXP_BLOCK_SIZE),
sycl::range<3>(1, 1, SYCL_EXP_BLOCK_SIZE),
sycl::range<3>(1, 1, SYCL_EXP_BLOCK_SIZE)), sycl::range<3>(1, 1, SYCL_EXP_BLOCK_SIZE)),
[=](sycl::nd_item<3> item_ct1) { [=](sycl::nd_item<3> item_ct1) { log(x, dst, k, item_ct1); });
log(x, dst, k, item_ct1);
});
} }
template<typename T> template<typename T>
static void neg_sycl(const T *x, T *dst, const int k, static void neg_sycl(const T *x, T *dst, const int k,
queue_ptr stream) { queue_ptr stream) {
const int num_blocks = (k + SYCL_NEG_BLOCK_SIZE - 1) / SYCL_NEG_BLOCK_SIZE; const int num_blocks = (k + SYCL_NEG_BLOCK_SIZE - 1) / SYCL_NEG_BLOCK_SIZE;
stream->parallel_for( sycl_parallel_for(stream,
sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, SYCL_NEG_BLOCK_SIZE),
sycl::range<3>(1, 1, SYCL_NEG_BLOCK_SIZE),
sycl::range<3>(1, 1, SYCL_NEG_BLOCK_SIZE)), sycl::range<3>(1, 1, SYCL_NEG_BLOCK_SIZE)),
[=](sycl::nd_item<3> item_ct1) { [=](sycl::nd_item<3> item_ct1) { neg(x, dst, k, item_ct1); });
neg(x, dst, k, item_ct1);
});
} }
template<typename T> template<typename T>
static void step_sycl(const T *x, T *dst, const int k, static void step_sycl(const T *x, T *dst, const int k,
queue_ptr stream) { queue_ptr stream) {
const int num_blocks = (k + SYCL_NEG_BLOCK_SIZE - 1) / SYCL_NEG_BLOCK_SIZE; const int num_blocks = (k + SYCL_NEG_BLOCK_SIZE - 1) / SYCL_NEG_BLOCK_SIZE;
stream->parallel_for( sycl_parallel_for(stream,
sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, SYCL_NEG_BLOCK_SIZE),
sycl::range<3>(1, 1, SYCL_NEG_BLOCK_SIZE),
sycl::range<3>(1, 1, SYCL_NEG_BLOCK_SIZE)), sycl::range<3>(1, 1, SYCL_NEG_BLOCK_SIZE)),
[=](sycl::nd_item<3> item_ct1) { [=](sycl::nd_item<3> item_ct1) { step(x, dst, k, item_ct1); });
step(x, dst, k, item_ct1);
});
} }
template<typename T> template<typename T>
static void sigmoid_sycl(const T *x, T *dst, const int k, static void sigmoid_sycl(const T *x, T *dst, const int k,
queue_ptr stream) { queue_ptr stream) {
const int num_blocks = (k + SYCL_SIGMOID_BLOCK_SIZE - 1) / SYCL_SIGMOID_BLOCK_SIZE; const int num_blocks = (k + SYCL_SIGMOID_BLOCK_SIZE - 1) / SYCL_SIGMOID_BLOCK_SIZE;
stream->parallel_for( sycl_parallel_for(
sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * stream,
sycl::range<3>(1, 1, SYCL_SIGMOID_BLOCK_SIZE), sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, SYCL_SIGMOID_BLOCK_SIZE),
sycl::range<3>(1, 1, SYCL_SIGMOID_BLOCK_SIZE)), sycl::range<3>(1, 1, SYCL_SIGMOID_BLOCK_SIZE)),
[=](sycl::nd_item<3> item_ct1) { [=](sycl::nd_item<3> item_ct1) { sigmoid(x, dst, k, item_ct1); });
sigmoid(x, dst, k, item_ct1);
});
} }
template<typename T> template<typename T>
static void sqrt_sycl(const T *x, T *dst, const int k, static void sqrt_sycl(const T *x, T *dst, const int k,
queue_ptr stream) { queue_ptr stream) {
const int num_blocks = (k + SYCL_SQRT_BLOCK_SIZE - 1) / SYCL_SQRT_BLOCK_SIZE; const int num_blocks = (k + SYCL_SQRT_BLOCK_SIZE - 1) / SYCL_SQRT_BLOCK_SIZE;
stream->parallel_for( sycl_parallel_for(stream,
sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, SYCL_SQRT_BLOCK_SIZE),
sycl::range<3>(1, 1, SYCL_SQRT_BLOCK_SIZE),
sycl::range<3>(1, 1, SYCL_SQRT_BLOCK_SIZE)), sycl::range<3>(1, 1, SYCL_SQRT_BLOCK_SIZE)),
[=](sycl::nd_item<3> item_ct1) { [=](sycl::nd_item<3> item_ct1) { sqrt(x, dst, k, item_ct1); });
sqrt(x, dst, k, item_ct1);
});
} }
template<typename T> template<typename T>
static void sin_sycl(const T *x, T *dst, const int k, static void sin_sycl(const T *x, T *dst, const int k,
queue_ptr stream) { queue_ptr stream) {
const int num_blocks = (k + SYCL_SIN_BLOCK_SIZE - 1) / SYCL_SIN_BLOCK_SIZE; const int num_blocks = (k + SYCL_SIN_BLOCK_SIZE - 1) / SYCL_SIN_BLOCK_SIZE;
stream->parallel_for( sycl_parallel_for(stream,
sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, SYCL_SIN_BLOCK_SIZE),
sycl::range<3>(1, 1, SYCL_SIN_BLOCK_SIZE),
sycl::range<3>(1, 1, SYCL_SIN_BLOCK_SIZE)), sycl::range<3>(1, 1, SYCL_SIN_BLOCK_SIZE)),
[=](sycl::nd_item<3> item_ct1) { [=](sycl::nd_item<3> item_ct1) { sin(x, dst, k, item_ct1); });
sin(x, dst, k, item_ct1);
});
} }
template<typename T> template<typename T>
static void cos_sycl(const T *x, T *dst, const int k, static void cos_sycl(const T *x, T *dst, const int k,
queue_ptr stream) { queue_ptr stream) {
const int num_blocks = (k + SYCL_SIN_BLOCK_SIZE - 1) / SYCL_SIN_BLOCK_SIZE; const int num_blocks = (k + SYCL_SIN_BLOCK_SIZE - 1) / SYCL_SIN_BLOCK_SIZE;
stream->parallel_for( sycl_parallel_for(stream,
sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, SYCL_SIN_BLOCK_SIZE),
sycl::range<3>(1, 1, SYCL_SIN_BLOCK_SIZE),
sycl::range<3>(1, 1, SYCL_SIN_BLOCK_SIZE)), sycl::range<3>(1, 1, SYCL_SIN_BLOCK_SIZE)),
[=](sycl::nd_item<3> item_ct1) { [=](sycl::nd_item<3> item_ct1) { cos(x, dst, k, item_ct1); });
cos(x, dst, k, item_ct1);
});
} }
template<typename T> template<typename T>
@ -584,26 +536,20 @@ static void leaky_relu_sycl(const T *x, T *dst, const int k,
const float negative_slope, const float negative_slope,
queue_ptr stream) { queue_ptr stream) {
const int num_blocks = (k + SYCL_RELU_BLOCK_SIZE - 1) / SYCL_RELU_BLOCK_SIZE; const int num_blocks = (k + SYCL_RELU_BLOCK_SIZE - 1) / SYCL_RELU_BLOCK_SIZE;
stream->parallel_for( sycl_parallel_for(stream,
sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, SYCL_RELU_BLOCK_SIZE),
sycl::range<3>(1, 1, SYCL_RELU_BLOCK_SIZE),
sycl::range<3>(1, 1, SYCL_RELU_BLOCK_SIZE)), sycl::range<3>(1, 1, SYCL_RELU_BLOCK_SIZE)),
[=](sycl::nd_item<3> item_ct1) { [=](sycl::nd_item<3> item_ct1) { leaky_relu(x, dst, k, negative_slope, item_ct1); });
leaky_relu(x, dst, k, negative_slope, item_ct1);
});
} }
template<typename T> template<typename T>
static void sqr_sycl(const T *x, T *dst, const int k, static void sqr_sycl(const T *x, T *dst, const int k,
queue_ptr stream) { queue_ptr stream) {
const int num_blocks = (k + SYCL_SQR_BLOCK_SIZE - 1) / SYCL_SQR_BLOCK_SIZE; const int num_blocks = (k + SYCL_SQR_BLOCK_SIZE - 1) / SYCL_SQR_BLOCK_SIZE;
stream->parallel_for( sycl_parallel_for(stream,
sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, SYCL_SQR_BLOCK_SIZE),
sycl::range<3>(1, 1, SYCL_SQR_BLOCK_SIZE),
sycl::range<3>(1, 1, SYCL_SQR_BLOCK_SIZE)), sycl::range<3>(1, 1, SYCL_SQR_BLOCK_SIZE)),
[=](sycl::nd_item<3> item_ct1) { [=](sycl::nd_item<3> item_ct1) { sqr(x, dst, k, item_ct1); });
sqr(x, dst, k, item_ct1);
});
} }
template<typename T> template<typename T>
@ -614,9 +560,8 @@ static void upscale_sycl(const T *x, T *dst, const int nb00, const int nb01,
int dst_size = ne10 * ne11 * ne12 * ne13; int dst_size = ne10 * ne11 * ne12 * ne13;
int num_blocks = (dst_size + SYCL_UPSCALE_BLOCK_SIZE - 1) / SYCL_UPSCALE_BLOCK_SIZE; int num_blocks = (dst_size + SYCL_UPSCALE_BLOCK_SIZE - 1) / SYCL_UPSCALE_BLOCK_SIZE;
sycl::range<1> gridDim(num_blocks * SYCL_UPSCALE_BLOCK_SIZE); sycl::range<1> gridDim(num_blocks * SYCL_UPSCALE_BLOCK_SIZE);
stream->parallel_for( sycl_parallel_for<1>(
sycl::nd_range<1>(gridDim, sycl::range<1>(SYCL_UPSCALE_BLOCK_SIZE)), stream, sycl::nd_range<1>(gridDim, sycl::range<1>(SYCL_UPSCALE_BLOCK_SIZE)), [=](sycl::nd_item<1> item_ct1) {
[=](sycl::nd_item<1> item_ct1) {
upscale(x, dst, nb00, nb01, nb02, nb03, ne10, ne11, ne12, ne13, sf0, sf1, sf2, sf3, item_ct1); upscale(x, dst, nb00, nb01, nb02, nb03, ne10, ne11, ne12, ne13, sf0, sf1, sf2, sf3, item_ct1);
}); });
} }
@ -627,12 +572,10 @@ static void pad_sycl(const T *x, T *dst, const int ne00,
const int ne1, const int ne2, queue_ptr stream) { const int ne1, const int ne2, queue_ptr stream) {
int num_blocks = (ne0 + SYCL_PAD_BLOCK_SIZE - 1) / SYCL_PAD_BLOCK_SIZE; int num_blocks = (ne0 + SYCL_PAD_BLOCK_SIZE - 1) / SYCL_PAD_BLOCK_SIZE;
sycl::range<3> gridDim(ne2, ne1, num_blocks); sycl::range<3> gridDim(ne2, ne1, num_blocks);
stream->parallel_for( sycl_parallel_for(stream,
sycl::nd_range<3>(gridDim * sycl::range<3>(1, 1, SYCL_PAD_BLOCK_SIZE), sycl::nd_range<3>(gridDim * sycl::range<3>(1, 1, SYCL_PAD_BLOCK_SIZE),
sycl::range<3>(1, 1, SYCL_PAD_BLOCK_SIZE)), sycl::range<3>(1, 1, SYCL_PAD_BLOCK_SIZE)),
[=](sycl::nd_item<3> item_ct1) { [=](sycl::nd_item<3> item_ct1) { pad(x, dst, ne0, ne00, ne01, ne02, item_ct1); });
pad(x, dst, ne0, ne00, ne01, ne02, item_ct1);
});
} }
template<typename T> template<typename T>
@ -640,13 +583,10 @@ static void clamp_sycl(const T *x, T *dst, const float min,
const float max, const int k, const float max, const int k,
queue_ptr stream) { queue_ptr stream) {
const int num_blocks = (k + SYCL_CLAMP_BLOCK_SIZE - 1) / SYCL_CLAMP_BLOCK_SIZE; const int num_blocks = (k + SYCL_CLAMP_BLOCK_SIZE - 1) / SYCL_CLAMP_BLOCK_SIZE;
stream->parallel_for( sycl_parallel_for(stream,
sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, SYCL_CLAMP_BLOCK_SIZE),
sycl::range<3>(1, 1, SYCL_CLAMP_BLOCK_SIZE),
sycl::range<3>(1, 1, SYCL_CLAMP_BLOCK_SIZE)), sycl::range<3>(1, 1, SYCL_CLAMP_BLOCK_SIZE)),
[=](sycl::nd_item<3> item_ct1) { [=](sycl::nd_item<3> item_ct1) { clamp(x, dst, min, max, k, item_ct1); });
clamp(x, dst, min, max, k, item_ct1);
});
} }
inline void ggml_sycl_op_sgn(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { inline void ggml_sycl_op_sgn(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {


@ -60,54 +60,6 @@ static void k_get_rows(
dst_row[iybs + iqs + y_offset] = v.y(); dst_row[iybs + iqs + y_offset] = v.y();
} }
template<int qk, int qr, dequantize_kernel_t_reorder dequantize_kernel_recorder, typename dst_t>
static void k_get_rows_reorder(
const void * src0, const void *src0_dq, const int32_t * src1, dst_t * dst,
int64_t ne00, /*int64_t ne01, int64_t ne02, int64_t ne03,*/
/*int64_t ne10, int64_t ne11,*/ int64_t ne12, /*int64_t ne13,*/
/*size_t s0,*/ size_t s1, size_t s2, size_t s3,
/*size_t nb00,*/ size_t nb01, size_t nb02, size_t nb03,
size_t s10, size_t s11, size_t s12,
const sycl::nd_item<3> &item_ct1/*, size_t s13*/) {
const int i00 = (item_ct1.get_group(2) * item_ct1.get_local_range(2) +
item_ct1.get_local_id(2)) *
2;
const int i10 = item_ct1.get_local_range(1) * item_ct1.get_group(1) +
item_ct1.get_local_id(1);
const int i11 = (item_ct1.get_group(0) * item_ct1.get_local_range(0) +
item_ct1.get_local_id(0)) /
ne12;
const int i12 = (item_ct1.get_group(0) * item_ct1.get_local_range(0) +
item_ct1.get_local_id(0)) %
ne12;
if (i00 >= ne00) {
return;
}
auto ncols = ne00;
const int i01 = src1[i10*s10 + i11*s11 + i12*s12];
dst_t * dst_row = dst + i10*s1 + i11*s2 + i12*s3;
const int src0_off = i01 * ncols + i00;
const int ib = src0_off / QK4_0; // block index
const int iqs = (i00%qk)/qr; // x quant index
const int iybs = i00 - i00%qk; // dst block start index
const int y_offset = qr == 1 ? 1 : qk/2;
// dequantize
dfloat2 v;
dequantize_kernel_recorder((const void *)src0_dq, ib, (const void *)src0, src0_off/2, v);
dst_row[iybs + iqs + 0] = v.x();
dst_row[iybs + iqs + y_offset] = v.y();
GGML_UNUSED(nb01);
GGML_UNUSED(nb02);
GGML_UNUSED(nb03);
}
template<typename src0_t, typename dst_t> template<typename src0_t, typename dst_t>
static void k_get_rows_float( static void k_get_rows_float(
const src0_t * src0, const int32_t * src1, dst_t * dst, const src0_t * src0, const int32_t * src1, dst_t * dst,
@ -166,58 +118,15 @@ static void get_rows_sycl(ggml_backend_sycl_context & ctx, const ggml_tensor *sr
GGML_ASSERT(ne00 % 2 == 0); GGML_ASSERT(ne00 % 2 == 0);
stream->parallel_for(sycl::nd_range<3>(block_nums * block_dims, block_dims), sycl_parallel_for(stream, sycl::nd_range<3>(block_nums * block_dims, block_dims), [=](sycl::nd_item<3> item_ct1) {
[=](sycl::nd_item<3> item_ct1) { k_get_rows<qk, qr, dq>(src0_dd, src1_dd, dst_dd, ne00, ne12, s1, s2, s3, nb01, nb02, nb03, s10, s11, s12,
k_get_rows<qk, qr, dq>( item_ct1);
src0_dd, src1_dd, dst_dd, ne00, ne12, s1, s2,
s3, nb01, nb02, nb03, s10, s11, s12, item_ct1);
}); });
GGML_UNUSED(dst); GGML_UNUSED(dst);
GGML_UNUSED(ctx); GGML_UNUSED(ctx);
} }
template <int qk, int qr, dequantize_kernel_t_reorder dq_reorder>
static void get_rows_sycl_reorder(ggml_backend_sycl_context & ctx, const ggml_tensor *src0, const ggml_tensor *src1,
ggml_tensor *dst, const void *src0_dd,
const int32_t *src1_dd, float *dst_dd,
queue_ptr stream) {
GGML_TENSOR_BINARY_OP_LOCALS
const sycl::range<3> block_dims(1, 1, SYCL_GET_ROWS_BLOCK_SIZE);
const int block_num_x = (ne00 + 2*SYCL_GET_ROWS_BLOCK_SIZE - 1) / (2*SYCL_GET_ROWS_BLOCK_SIZE);
const sycl::range<3> block_nums(ne11 * ne12, ne10, block_num_x);
// strides in elements
//const size_t s0 = nb0 / ggml_element_size(dst);
const size_t s1 = nb1 / ggml_element_size(dst);
const size_t s2 = nb2 / ggml_element_size(dst);
const size_t s3 = nb3 / ggml_element_size(dst);
const size_t s10 = nb10 / ggml_element_size(src1);
const size_t s11 = nb11 / ggml_element_size(src1);
const size_t s12 = nb12 / ggml_element_size(src1);
//const size_t s13 = nb13 / ggml_element_size(src1);
GGML_ASSERT(ne00 % 2 == 0);
const uint8_t* src0_q = (const uint8_t*)src0_dd;
const size_t ncols = ne00;
const size_t nrows = ne01;
const sycl::half* src0_dq = (const sycl::half*)(src0_q + nrows * ncols / 2);
stream->parallel_for(sycl::nd_range<3>(block_nums * block_dims, block_dims),
[=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]]{
k_get_rows_reorder<qk, qr, dq_reorder>(
src0_dd, src0_dq, src1_dd, dst_dd, ne00, ne12, s1, s2,
s3, nb01, nb02, nb03, s10, s11, s12, item_ct1);
});
GGML_UNUSED(dst);
GGML_UNUSED(ctx);
}
template <typename src0_t> template <typename src0_t>
static void get_rows_sycl_float(ggml_backend_sycl_context & ctx, const ggml_tensor *src0, static void get_rows_sycl_float(ggml_backend_sycl_context & ctx, const ggml_tensor *src0,
const ggml_tensor *src1, ggml_tensor *dst, const ggml_tensor *src1, ggml_tensor *dst,
@ -245,9 +154,8 @@ static void get_rows_sycl_float(ggml_backend_sycl_context & ctx, const ggml_tens
dpct::has_capability_or_fail(stream->get_device(), dpct::has_capability_or_fail(stream->get_device(),
{sycl::aspect::fp16}); {sycl::aspect::fp16});
stream->parallel_for( sycl_parallel_for(
sycl::nd_range<3>(block_nums * block_dims, block_dims), stream, sycl::nd_range<3>(block_nums * block_dims, block_dims), [=](sycl::nd_item<3> item_ct1) {
[=](sycl::nd_item<3> item_ct1) {
k_get_rows_float(src0_dd, src1_dd, dst_dd, ne00, ne12, s1, s2, k_get_rows_float(src0_dd, src1_dd, dst_dd, ne00, ne12, s1, s2,
s3, nb01, nb02, nb03, s10, s11, s12, item_ct1); s3, nb01, nb02, nb03, s10, s11, s12, item_ct1);
}); });
@ -277,13 +185,8 @@ void ggml_sycl_op_get_rows(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
src1_i32, (float *)dst->data, ctx.stream()); src1_i32, (float *)dst->data, ctx.stream());
break; break;
case GGML_TYPE_Q4_0: case GGML_TYPE_Q4_0:
if (ctx.opt_feature.reorder && dst->op == GGML_OP_MUL_MAT) {
get_rows_sycl_reorder<QK4_0, QR4_0, dequantize_q4_0_reorder>(ctx, dst->src[0], dst->src[1], dst, (const float *)dst->src[0]->data,
src1_i32, (float *)dst->data, ctx.stream());
} else {
get_rows_sycl<QK4_0, QR4_0, dequantize_q4_0>(ctx, dst->src[0], dst->src[1], dst, (const float *)dst->src[0]->data, get_rows_sycl<QK4_0, QR4_0, dequantize_q4_0>(ctx, dst->src[0], dst->src[1], dst, (const float *)dst->src[0]->data,
src1_i32, (float *)dst->data, ctx.stream()); src1_i32, (float *)dst->data, ctx.stream());
}
break; break;
case GGML_TYPE_Q4_1: case GGML_TYPE_Q4_1:
get_rows_sycl<QK4_1, QR4_1, dequantize_q4_1>(ctx, dst->src[0], dst->src[1], dst, (const float *)dst->src[0]->data, get_rows_sycl<QK4_1, QR4_1, dequantize_q4_1>(ctx, dst->src[0], dst->src[1], dst, (const float *)dst->src[0]->data,


@ -1887,13 +1887,12 @@ static void argsort_f32_i32_sycl(const float *x, int *dst, const int ncols,
const size_t shared_mem = ncols_pad * sizeof(int); const size_t shared_mem = ncols_pad * sizeof(int);
if (order == GGML_SORT_ORDER_ASC) { if (order == GGML_SORT_ORDER_ASC) {
stream->submit([&](sycl::handler &cgh) { sycl_launch(stream, [&](sycl::handler & cgh) {
sycl::local_accessor<uint8_t, 1> dpct_local_acc_ct1( sycl::local_accessor<uint8_t, 1> dpct_local_acc_ct1(
sycl::range<1>(shared_mem), cgh); sycl::range<1>(shared_mem), cgh);
cgh.parallel_for( sycl_parallel_for(
sycl::nd_range<3>(block_nums * block_dims, block_dims), cgh, sycl::nd_range<3>(block_nums * block_dims, block_dims), [=](sycl::nd_item<3> item_ct1) {
[=](sycl::nd_item<3> item_ct1) {
k_argsort_f32_i32<GGML_SORT_ORDER_ASC>( k_argsort_f32_i32<GGML_SORT_ORDER_ASC>(
x, dst, ncols, ncols_pad, item_ct1, x, dst, ncols, ncols_pad, item_ct1,
dpct_local_acc_ct1.get_multi_ptr<sycl::access::decorated::no>() dpct_local_acc_ct1.get_multi_ptr<sycl::access::decorated::no>()
@ -1901,13 +1900,12 @@ static void argsort_f32_i32_sycl(const float *x, int *dst, const int ncols,
}); });
}); });
} else if (order == GGML_SORT_ORDER_DESC) { } else if (order == GGML_SORT_ORDER_DESC) {
stream->submit([&](sycl::handler &cgh) { sycl_launch(stream, [&](sycl::handler & cgh) {
sycl::local_accessor<uint8_t, 1> dpct_local_acc_ct1( sycl::local_accessor<uint8_t, 1> dpct_local_acc_ct1(
sycl::range<1>(shared_mem), cgh); sycl::range<1>(shared_mem), cgh);
cgh.parallel_for( sycl_parallel_for(
sycl::nd_range<3>(block_nums * block_dims, block_dims), cgh, sycl::nd_range<3>(block_nums * block_dims, block_dims), [=](sycl::nd_item<3> item_ct1) {
[=](sycl::nd_item<3> item_ct1) {
k_argsort_f32_i32<GGML_SORT_ORDER_DESC>( k_argsort_f32_i32<GGML_SORT_ORDER_DESC>(
x, dst, ncols, ncols_pad, item_ct1, x, dst, ncols, ncols_pad, item_ct1,
dpct_local_acc_ct1.get_multi_ptr<sycl::access::decorated::no>() dpct_local_acc_ct1.get_multi_ptr<sycl::access::decorated::no>()
@ -1925,15 +1923,13 @@ static void argmax_f32_i32_sycl(const float *x, int *dst, const int ncols,
const sycl::range<3> block_nums(1, nrows, 1); const sycl::range<3> block_nums(1, nrows, 1);
const size_t shared_mem = 256 * sizeof(float); const size_t shared_mem = 256 * sizeof(float);
stream->submit([&](sycl::handler &cgh) { sycl_launch(stream, [&](sycl::handler & cgh) {
sycl::local_accessor<float, 1> shared_data( sycl::local_accessor<float, 1> shared_data(
sycl::range<1>(shared_mem/sizeof(float)), cgh); sycl::range<1>(shared_mem/sizeof(float)), cgh);
sycl::local_accessor<int, 1> shared_indices( sycl::local_accessor<int, 1> shared_indices(
sycl::range<1>(shared_mem/sizeof(float)), cgh); sycl::range<1>(shared_mem/sizeof(float)), cgh);
cgh.parallel_for( sycl_parallel_for(cgh, sycl::nd_range<3>(block_nums * block_dims, block_dims), [=](sycl::nd_item<3> item_ct1) {
sycl::nd_range<3>(block_nums * block_dims, block_dims),
[=](sycl::nd_item<3> item_ct1) {
const int tid = item_ct1.get_local_id(2); const int tid = item_ct1.get_local_id(2);
const int row = item_ct1.get_global_id(1); const int row = item_ct1.get_global_id(1);
@ -1952,7 +1948,7 @@ static void argmax_f32_i32_sycl(const float *x, int *dst, const int ncols,
shared_indices[tid] = max_idx; shared_indices[tid] = max_idx;
item_ct1.barrier(sycl::access::fence_space::local_space); item_ct1.barrier(sycl::access::fence_space::local_space);
for (int stride = 256/2; stride > 0; stride >>= 1) { for (int stride = 256 / 2; stride > 0; stride >>= 1) {
if (tid < stride) { if (tid < stride) {
float val1 = shared_data[tid]; float val1 = shared_data[tid];
float val2 = shared_data[tid + stride]; float val2 = shared_data[tid + stride];
@ -1964,7 +1960,6 @@ static void argmax_f32_i32_sycl(const float *x, int *dst, const int ncols,
item_ct1.barrier(sycl::access::fence_space::local_space); item_ct1.barrier(sycl::access::fence_space::local_space);
} }
if (tid == 0) { if (tid == 0) {
dst[row] = shared_indices[0]; dst[row] = shared_indices[0];
} }
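The argsort and argmax launches above show how the two wrappers compose when a kernel needs local memory: sycl_launch stands in for stream->submit so local accessors can still be created on the handler, and sycl_parallel_for(cgh, ...) stands in for cgh.parallel_for. A minimal sketch of that pattern, assuming both wrappers are in scope; row_sum_f32_sycl and its reduction are illustrative, not code from this diff:

    #include <sycl/sycl.hpp>

    // hypothetical per-row sum with a local-memory tree reduction
    static void row_sum_f32_sycl(const float * x, float * dst, int ncols, int nrows, sycl::queue * stream) {
        constexpr int block_size = 256;  // illustrative work-group size
        const sycl::range<3> block_dims(1, 1, block_size);
        const sycl::range<3> block_nums(1, nrows, 1);
        sycl_launch(stream, [&](sycl::handler & cgh) {
            // scratch space is still requested through the handler
            sycl::local_accessor<float, 1> shared(sycl::range<1>(block_size), cgh);
            sycl_parallel_for(cgh, sycl::nd_range<3>(block_nums * block_dims, block_dims),
                              [=](sycl::nd_item<3> item_ct1) {
                const int tid = item_ct1.get_local_id(2);
                const int row = item_ct1.get_global_id(1);
                float acc = 0.0f;
                for (int col = tid; col < ncols; col += block_size) {
                    acc += x[row * ncols + col];
                }
                shared[tid] = acc;
                item_ct1.barrier(sycl::access::fence_space::local_space);
                // pairwise reduction in local memory, mirroring the argmax kernel above
                for (int stride = block_size / 2; stride > 0; stride >>= 1) {
                    if (tid < stride) {
                        shared[tid] += shared[tid + stride];
                    }
                    item_ct1.barrier(sycl::access::fence_space::local_space);
                }
                if (tid == 0) {
                    dst[row] = shared[0];
                }
            });
        });
    }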
@ -2952,7 +2947,7 @@ static void ggml_sycl_mul_mat_batched_sycl(ggml_backend_sycl_context & ctx, cons
void ** ptrs_dst_get = ptrs_dst.get(); void ** ptrs_dst_get = ptrs_dst.get();
size_t nb12_scaled = src1->type == GGML_TYPE_F16 ? nb12 : s12 * sizeof(sycl::half); size_t nb12_scaled = src1->type == GGML_TYPE_F16 ? nb12 : s12 * sizeof(sycl::half);
size_t nb13_scaled = src1->type == GGML_TYPE_F16 ? nb13 : s13 * sizeof(sycl::half); size_t nb13_scaled = src1->type == GGML_TYPE_F16 ? nb13 : s13 * sizeof(sycl::half);
cgh.parallel_for(sycl::nd_range<3>(block_dims, block_dims), [=](sycl::nd_item<3> item_ct1) { sycl_parallel_for(cgh, sycl::nd_range<3>(block_dims, block_dims), [=](sycl::nd_item<3> item_ct1) {
k_compute_batched_ptrs(src0_f16, src1_f16, dst_ddf, ptrs_src_get, ptrs_dst_get, ne12, ne13, ne23, nb02, k_compute_batched_ptrs(src0_f16, src1_f16, dst_ddf, ptrs_src_get, ptrs_dst_get, ne12, ne13, ne23, nb02,
nb03, nb12_scaled, nb13_scaled, nbd2, nbd3, r2, r3, item_ct1); nb03, nb12_scaled, nb13_scaled, nbd2, nbd3, r2, r3, item_ct1);
}); });
@ -3456,7 +3451,7 @@ static void ggml_sycl_mul_mat_id(ggml_backend_sycl_context & ctx,
{ {
sycl::range<3> block_dims(1, 1, std::min((unsigned int)ne10, 768u)); sycl::range<3> block_dims(1, 1, std::min((unsigned int)ne10, 768u));
sycl::range<3> grid_dims(1, n_ids, ids->ne[1]); sycl::range<3> grid_dims(1, n_ids, ids->ne[1]);
stream->submit([&](sycl::handler &cgh) { sycl_launch(stream, [&](sycl::handler & cgh) {
sycl::local_accessor<int, 0> src1_row_acc(cgh); sycl::local_accessor<int, 0> src1_row_acc(cgh);
char *__restrict src1_contiguous_get = char *__restrict src1_contiguous_get =
@ -3468,9 +3463,8 @@ static void ggml_sycl_mul_mat_id(ggml_backend_sycl_context & ctx,
size_t ids_nb_ct6 = ids->nb[1]; size_t ids_nb_ct6 = ids->nb[1];
size_t ids_nb_ct7 = ids->nb[0]; size_t ids_nb_ct7 = ids->nb[0];
cgh.parallel_for( sycl_parallel_for(
sycl::nd_range<3>(grid_dims * block_dims, block_dims), cgh, sycl::nd_range<3>(grid_dims * block_dims, block_dims), [=](sycl::nd_item<3> item_ct1) {
[=](sycl::nd_item<3> item_ct1) {
k_copy_src1_to_contiguous( k_copy_src1_to_contiguous(
src1_original, src1_contiguous_get, src1_original, src1_contiguous_get,
dev_cur_src1_row_get, dev_cur_src1_row_get,
@ -3501,15 +3495,14 @@ static void ggml_sycl_mul_mat_id(ggml_backend_sycl_context & ctx,
{ {
sycl::range<3> block_dims(1, 1, std::min((unsigned int)ne0, 768u)); sycl::range<3> block_dims(1, 1, std::min((unsigned int)ne0, 768u));
sycl::range<3> grid_dims(1, 1, num_src1_rows); sycl::range<3> grid_dims(1, 1, num_src1_rows);
stream->submit([&](sycl::handler &cgh) { sycl_launch(stream, [&](sycl::handler & cgh) {
const char *__restrict dst_contiguous_get = const char *__restrict dst_contiguous_get =
dst_contiguous.get(); dst_contiguous.get();
const mmid_row_mapping *__restrict dev_row_mapping_get = const mmid_row_mapping *__restrict dev_row_mapping_get =
dev_row_mapping.get(); dev_row_mapping.get();
cgh.parallel_for( sycl_parallel_for(
sycl::nd_range<3>(grid_dims * block_dims, block_dims), cgh, sycl::nd_range<3>(grid_dims * block_dims, block_dims), [=](sycl::nd_item<3> item_ct1) {
[=](sycl::nd_item<3> item_ct1) {
k_copy_dst_from_contiguous(dst_original, k_copy_dst_from_contiguous(dst_original,
dst_contiguous_get, dst_contiguous_get,
dev_row_mapping_get, dev_row_mapping_get,


@ -11,13 +11,13 @@ static void gated_linear_attn_f32_kernel(const dpct::queue_ptr stream, u_int B,
const u_int n_seq_tokens = T / B; const u_int n_seq_tokens = T / B;
sycl::range<1> block_dims((C / H)); sycl::range<1> block_dims((C / H));
sycl::range<1> grid_dims((B * H)); sycl::range<1> grid_dims((B * H));
stream->submit([&](sycl::handler & cgh) { sycl_launch(stream, [&](sycl::handler & cgh) {
/* local memory accessors*/ /* local memory accessors*/
auto _k = sycl::local_accessor<float, 1>(sycl::range<1>(head_size), cgh); auto _k = sycl::local_accessor<float, 1>(sycl::range<1>(head_size), cgh);
auto _r = sycl::local_accessor<float, 1>(sycl::range<1>(head_size), cgh); auto _r = sycl::local_accessor<float, 1>(sycl::range<1>(head_size), cgh);
auto _td = sycl::local_accessor<float, 1>(sycl::range<1>(head_size), cgh); auto _td = sycl::local_accessor<float, 1>(sycl::range<1>(head_size), cgh);
cgh.parallel_for(sycl::nd_range<1>(grid_dims * block_dims, block_dims), [=](sycl::nd_item<1> item) { sycl_parallel_for<1>(cgh, sycl::nd_range<1>(grid_dims * block_dims, block_dims), [=](sycl::nd_item<1> item) {
u_int tid = item.get_local_id(0); u_int tid = item.get_local_id(0);
u_int bid = item.get_group(0); u_int bid = item.get_group(0);


@ -70,7 +70,7 @@ static void im2col_sycl_internal(const float * x, T * dst, int64_t IW, int64_t I
const int64_t CHW = IC * KH * KW; const int64_t CHW = IC * KH * KW;
stream->parallel_for(sycl::nd_range<3>(block_nums * local_range, local_range), [=](sycl::nd_item<3> item_ct1) { sycl_parallel_for(stream, sycl::nd_range<3>(block_nums * local_range, local_range), [=](sycl::nd_item<3> item_ct1) {
im2col_kernel<T>(x, dst, batch_offset, offset_delta, IC, IW, IH, OH, OW, KW, KH, parallel_elements, CHW, s0, s1, im2col_kernel<T>(x, dst, batch_offset, offset_delta, IC, IW, IH, OH, OW, KW, KH, parallel_elements, CHW, s0, s1,
p0, p1, d0, d1, item_ct1); p0, p1, d0, d1, item_ct1);
}); });


@ -1818,7 +1818,7 @@ static void ggml_mul_mat_q4_0_q8_1_sycl(const void *vx, const void *vy,
dpct::has_capability_or_fail(stream->get_device(), dpct::has_capability_or_fail(stream->get_device(),
{sycl::aspect::fp16}); {sycl::aspect::fp16});
stream->submit([&](sycl::handler &cgh) { sycl_launch(stream, [&](sycl::handler & cgh) {
sycl::local_accessor<int, 1> tile_x_qs_q4_0_acc_ct1( sycl::local_accessor<int, 1> tile_x_qs_q4_0_acc_ct1(
sycl::range<1>(mmq_y * (WARP_SIZE) + mmq_y), cgh); sycl::range<1>(mmq_y * (WARP_SIZE) + mmq_y), cgh);
sycl::local_accessor<float, 1> tile_x_d_q4_0_acc_ct1( sycl::local_accessor<float, 1> tile_x_d_q4_0_acc_ct1(
@ -1829,9 +1829,8 @@ static void ggml_mul_mat_q4_0_q8_1_sycl(const void *vx, const void *vy,
sycl::local_accessor<sycl::half2, 1> tile_y_ds_acc_ct1( sycl::local_accessor<sycl::half2, 1> tile_y_ds_acc_ct1(
sycl::range<1>(mmq_x * WARP_SIZE / QI8_1), cgh); sycl::range<1>(mmq_x * WARP_SIZE / QI8_1), cgh);
cgh.parallel_for( sycl_parallel_for(
sycl::nd_range<3>(block_nums * block_dims, block_dims), cgh, sycl::nd_range<3>(block_nums * block_dims, block_dims), [=](sycl::nd_item<3> item_ct1) {
[=](sycl::nd_item<3> item_ct1) {
mul_mat_q4_0<need_check>( mul_mat_q4_0<need_check>(
vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y,
nrows_dst, item_ct1, nrows_dst, item_ct1,
@ -1853,7 +1852,7 @@ static void ggml_mul_mat_q4_0_q8_1_sycl(const void *vx, const void *vy,
dpct::has_capability_or_fail(stream->get_device(), dpct::has_capability_or_fail(stream->get_device(),
{sycl::aspect::fp16}); {sycl::aspect::fp16});
stream->submit([&](sycl::handler &cgh) { sycl_launch(stream, [&](sycl::handler & cgh) {
sycl::local_accessor<int, 1> tile_x_qs_q4_0_acc_ct1( sycl::local_accessor<int, 1> tile_x_qs_q4_0_acc_ct1(
sycl::range<1>(mmq_y * (WARP_SIZE) + mmq_y), cgh); sycl::range<1>(mmq_y * (WARP_SIZE) + mmq_y), cgh);
sycl::local_accessor<float, 1> tile_x_d_q4_0_acc_ct1( sycl::local_accessor<float, 1> tile_x_d_q4_0_acc_ct1(
@ -1864,9 +1863,8 @@ static void ggml_mul_mat_q4_0_q8_1_sycl(const void *vx, const void *vy,
sycl::local_accessor<sycl::half2, 1> tile_y_ds_acc_ct1( sycl::local_accessor<sycl::half2, 1> tile_y_ds_acc_ct1(
sycl::range<1>(mmq_x * WARP_SIZE / QI8_1), cgh); sycl::range<1>(mmq_x * WARP_SIZE / QI8_1), cgh);
cgh.parallel_for( sycl_parallel_for(
sycl::nd_range<3>(block_nums * block_dims, block_dims), cgh, sycl::nd_range<3>(block_nums * block_dims, block_dims), [=](sycl::nd_item<3> item_ct1) {
[=](sycl::nd_item<3> item_ct1) {
mul_mat_q4_0<need_check>( mul_mat_q4_0<need_check>(
vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y,
nrows_dst, item_ct1, nrows_dst, item_ct1,
@ -1933,7 +1931,7 @@ static void ggml_mul_mat_q4_1_q8_1_sycl(const void *vx, const void *vy,
dpct::has_capability_or_fail(stream->get_device(), dpct::has_capability_or_fail(stream->get_device(),
{sycl::aspect::fp16}); {sycl::aspect::fp16});
stream->submit([&](sycl::handler &cgh) { sycl_launch(stream, [&](sycl::handler & cgh) {
sycl::local_accessor<int, 1> tile_x_qs_q4_1_acc_ct1( sycl::local_accessor<int, 1> tile_x_qs_q4_1_acc_ct1(
sycl::range<1>(mmq_y * (WARP_SIZE) + +mmq_y), cgh); sycl::range<1>(mmq_y * (WARP_SIZE) + +mmq_y), cgh);
sycl::local_accessor<sycl::half2, 1> tile_x_dm_q4_1_acc_ct1( sycl::local_accessor<sycl::half2, 1> tile_x_dm_q4_1_acc_ct1(
@ -1944,9 +1942,8 @@ static void ggml_mul_mat_q4_1_q8_1_sycl(const void *vx, const void *vy,
sycl::local_accessor<sycl::half2, 1> tile_y_ds_acc_ct1( sycl::local_accessor<sycl::half2, 1> tile_y_ds_acc_ct1(
sycl::range<1>(mmq_x * WARP_SIZE / QI8_1), cgh); sycl::range<1>(mmq_x * WARP_SIZE / QI8_1), cgh);
cgh.parallel_for( sycl_parallel_for(
sycl::nd_range<3>(block_nums * block_dims, block_dims), cgh, sycl::nd_range<3>(block_nums * block_dims, block_dims), [=](sycl::nd_item<3> item_ct1) {
[=](sycl::nd_item<3> item_ct1) {
mul_mat_q4_1<need_check>( mul_mat_q4_1<need_check>(
vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y,
nrows_dst, item_ct1, nrows_dst, item_ct1,
@ -1968,7 +1965,7 @@ static void ggml_mul_mat_q4_1_q8_1_sycl(const void *vx, const void *vy,
dpct::has_capability_or_fail(stream->get_device(), dpct::has_capability_or_fail(stream->get_device(),
{sycl::aspect::fp16}); {sycl::aspect::fp16});
stream->submit([&](sycl::handler &cgh) { sycl_launch(stream, [&](sycl::handler & cgh) {
sycl::local_accessor<int, 1> tile_x_qs_q4_1_acc_ct1( sycl::local_accessor<int, 1> tile_x_qs_q4_1_acc_ct1(
sycl::range<1>(mmq_y * (WARP_SIZE) + +mmq_y), cgh); sycl::range<1>(mmq_y * (WARP_SIZE) + +mmq_y), cgh);
sycl::local_accessor<sycl::half2, 1> tile_x_dm_q4_1_acc_ct1( sycl::local_accessor<sycl::half2, 1> tile_x_dm_q4_1_acc_ct1(
@ -1979,9 +1976,8 @@ static void ggml_mul_mat_q4_1_q8_1_sycl(const void *vx, const void *vy,
sycl::local_accessor<sycl::half2, 1> tile_y_ds_acc_ct1( sycl::local_accessor<sycl::half2, 1> tile_y_ds_acc_ct1(
sycl::range<1>(mmq_x * WARP_SIZE / QI8_1), cgh); sycl::range<1>(mmq_x * WARP_SIZE / QI8_1), cgh);
cgh.parallel_for( sycl_parallel_for(
sycl::nd_range<3>(block_nums * block_dims, block_dims), cgh, sycl::nd_range<3>(block_nums * block_dims, block_dims), [=](sycl::nd_item<3> item_ct1) {
[=](sycl::nd_item<3> item_ct1) {
mul_mat_q4_1<need_check>( mul_mat_q4_1<need_check>(
vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y,
nrows_dst, item_ct1, nrows_dst, item_ct1,
@ -2048,7 +2044,7 @@ static void ggml_mul_mat_q5_0_q8_1_sycl(const void *vx, const void *vy,
dpct::has_capability_or_fail(stream->get_device(), dpct::has_capability_or_fail(stream->get_device(),
{sycl::aspect::fp16}); {sycl::aspect::fp16});
stream->submit([&](sycl::handler &cgh) { sycl_launch(stream, [&](sycl::handler & cgh) {
sycl::local_accessor<int, 1> tile_x_ql_q5_0_acc_ct1( sycl::local_accessor<int, 1> tile_x_ql_q5_0_acc_ct1(
sycl::range<1>(mmq_y * (2 * WARP_SIZE) + mmq_y), cgh); sycl::range<1>(mmq_y * (2 * WARP_SIZE) + mmq_y), cgh);
sycl::local_accessor<float, 1> tile_x_d_q5_0_acc_ct1( sycl::local_accessor<float, 1> tile_x_d_q5_0_acc_ct1(
@ -2059,9 +2055,8 @@ static void ggml_mul_mat_q5_0_q8_1_sycl(const void *vx, const void *vy,
sycl::local_accessor<sycl::half2, 1> tile_y_ds_acc_ct1( sycl::local_accessor<sycl::half2, 1> tile_y_ds_acc_ct1(
sycl::range<1>(mmq_x * WARP_SIZE / QI8_1), cgh); sycl::range<1>(mmq_x * WARP_SIZE / QI8_1), cgh);
cgh.parallel_for( sycl_parallel_for(
sycl::nd_range<3>(block_nums * block_dims, block_dims), cgh, sycl::nd_range<3>(block_nums * block_dims, block_dims), [=](sycl::nd_item<3> item_ct1) {
[=](sycl::nd_item<3> item_ct1) {
mul_mat_q5_0<need_check>( mul_mat_q5_0<need_check>(
vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y,
nrows_dst, item_ct1, nrows_dst, item_ct1,
@ -2083,7 +2078,7 @@ static void ggml_mul_mat_q5_0_q8_1_sycl(const void *vx, const void *vy,
dpct::has_capability_or_fail(stream->get_device(), dpct::has_capability_or_fail(stream->get_device(),
{sycl::aspect::fp16}); {sycl::aspect::fp16});
stream->submit([&](sycl::handler &cgh) { sycl_launch(stream, [&](sycl::handler & cgh) {
sycl::local_accessor<int, 1> tile_x_ql_q5_0_acc_ct1( sycl::local_accessor<int, 1> tile_x_ql_q5_0_acc_ct1(
sycl::range<1>(mmq_y * (2 * WARP_SIZE) + mmq_y), cgh); sycl::range<1>(mmq_y * (2 * WARP_SIZE) + mmq_y), cgh);
sycl::local_accessor<float, 1> tile_x_d_q5_0_acc_ct1( sycl::local_accessor<float, 1> tile_x_d_q5_0_acc_ct1(
@ -2094,9 +2089,8 @@ static void ggml_mul_mat_q5_0_q8_1_sycl(const void *vx, const void *vy,
sycl::local_accessor<sycl::half2, 1> tile_y_ds_acc_ct1( sycl::local_accessor<sycl::half2, 1> tile_y_ds_acc_ct1(
sycl::range<1>(mmq_x * WARP_SIZE / QI8_1), cgh); sycl::range<1>(mmq_x * WARP_SIZE / QI8_1), cgh);
cgh.parallel_for( sycl_parallel_for(
sycl::nd_range<3>(block_nums * block_dims, block_dims), cgh, sycl::nd_range<3>(block_nums * block_dims, block_dims), [=](sycl::nd_item<3> item_ct1) {
[=](sycl::nd_item<3> item_ct1) {
mul_mat_q5_0<need_check>( mul_mat_q5_0<need_check>(
vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y,
nrows_dst, item_ct1, nrows_dst, item_ct1,
@ -2163,7 +2157,7 @@ static void ggml_mul_mat_q5_1_q8_1_sycl(const void *vx, const void *vy,
dpct::has_capability_or_fail(stream->get_device(), dpct::has_capability_or_fail(stream->get_device(),
{sycl::aspect::fp16}); {sycl::aspect::fp16});
stream->submit([&](sycl::handler &cgh) { sycl_launch(stream, [&](sycl::handler & cgh) {
sycl::local_accessor<int, 1> tile_x_ql_q5_1_acc_ct1( sycl::local_accessor<int, 1> tile_x_ql_q5_1_acc_ct1(
sycl::range<1>(mmq_y * (2 * WARP_SIZE) + mmq_y), cgh); sycl::range<1>(mmq_y * (2 * WARP_SIZE) + mmq_y), cgh);
sycl::local_accessor<sycl::half2, 1> tile_x_dm_q5_1_acc_ct1( sycl::local_accessor<sycl::half2, 1> tile_x_dm_q5_1_acc_ct1(
@ -2174,9 +2168,8 @@ static void ggml_mul_mat_q5_1_q8_1_sycl(const void *vx, const void *vy,
sycl::local_accessor<sycl::half2, 1> tile_y_ds_acc_ct1( sycl::local_accessor<sycl::half2, 1> tile_y_ds_acc_ct1(
sycl::range<1>(mmq_x * WARP_SIZE / QI8_1), cgh); sycl::range<1>(mmq_x * WARP_SIZE / QI8_1), cgh);
cgh.parallel_for( sycl_parallel_for(
sycl::nd_range<3>(block_nums * block_dims, block_dims), cgh, sycl::nd_range<3>(block_nums * block_dims, block_dims), [=](sycl::nd_item<3> item_ct1) {
[=](sycl::nd_item<3> item_ct1) {
mul_mat_q5_1<need_check>( mul_mat_q5_1<need_check>(
vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y,
nrows_dst, item_ct1, nrows_dst, item_ct1,
@ -2198,7 +2191,7 @@ static void ggml_mul_mat_q5_1_q8_1_sycl(const void *vx, const void *vy,
dpct::has_capability_or_fail(stream->get_device(), dpct::has_capability_or_fail(stream->get_device(),
{sycl::aspect::fp16}); {sycl::aspect::fp16});
stream->submit([&](sycl::handler &cgh) { sycl_launch(stream, [&](sycl::handler & cgh) {
sycl::local_accessor<int, 1> tile_x_ql_q5_1_acc_ct1( sycl::local_accessor<int, 1> tile_x_ql_q5_1_acc_ct1(
sycl::range<1>(mmq_y * (2 * WARP_SIZE) + mmq_y), cgh); sycl::range<1>(mmq_y * (2 * WARP_SIZE) + mmq_y), cgh);
sycl::local_accessor<sycl::half2, 1> tile_x_dm_q5_1_acc_ct1( sycl::local_accessor<sycl::half2, 1> tile_x_dm_q5_1_acc_ct1(
@ -2209,9 +2202,8 @@ static void ggml_mul_mat_q5_1_q8_1_sycl(const void *vx, const void *vy,
sycl::local_accessor<sycl::half2, 1> tile_y_ds_acc_ct1( sycl::local_accessor<sycl::half2, 1> tile_y_ds_acc_ct1(
sycl::range<1>(mmq_x * WARP_SIZE / QI8_1), cgh); sycl::range<1>(mmq_x * WARP_SIZE / QI8_1), cgh);
cgh.parallel_for( sycl_parallel_for(
sycl::nd_range<3>(block_nums * block_dims, block_dims), cgh, sycl::nd_range<3>(block_nums * block_dims, block_dims), [=](sycl::nd_item<3> item_ct1) {
[=](sycl::nd_item<3> item_ct1) {
mul_mat_q5_1<need_check>( mul_mat_q5_1<need_check>(
vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y,
nrows_dst, item_ct1, nrows_dst, item_ct1,
@ -2278,7 +2270,7 @@ static void ggml_mul_mat_q8_0_q8_1_sycl(const void *vx, const void *vy,
dpct::has_capability_or_fail(stream->get_device(), dpct::has_capability_or_fail(stream->get_device(),
{sycl::aspect::fp16}); {sycl::aspect::fp16});
stream->submit([&](sycl::handler &cgh) { sycl_launch(stream, [&](sycl::handler & cgh) {
sycl::local_accessor<int, 1> tile_x_qs_q8_0_acc_ct1( sycl::local_accessor<int, 1> tile_x_qs_q8_0_acc_ct1(
sycl::range<1>(mmq_y * (WARP_SIZE) + mmq_y), cgh); sycl::range<1>(mmq_y * (WARP_SIZE) + mmq_y), cgh);
sycl::local_accessor<float, 1> tile_x_d_q8_0_acc_ct1( sycl::local_accessor<float, 1> tile_x_d_q8_0_acc_ct1(
@ -2289,9 +2281,8 @@ static void ggml_mul_mat_q8_0_q8_1_sycl(const void *vx, const void *vy,
sycl::local_accessor<sycl::half2, 1> tile_y_ds_acc_ct1( sycl::local_accessor<sycl::half2, 1> tile_y_ds_acc_ct1(
sycl::range<1>(mmq_x * WARP_SIZE / QI8_1), cgh); sycl::range<1>(mmq_x * WARP_SIZE / QI8_1), cgh);
cgh.parallel_for( sycl_parallel_for(
sycl::nd_range<3>(block_nums * block_dims, block_dims), cgh, sycl::nd_range<3>(block_nums * block_dims, block_dims), [=](sycl::nd_item<3> item_ct1) {
[=](sycl::nd_item<3> item_ct1) {
mul_mat_q8_0<need_check>( mul_mat_q8_0<need_check>(
vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y,
nrows_dst, item_ct1, nrows_dst, item_ct1,
@ -2313,7 +2304,7 @@ static void ggml_mul_mat_q8_0_q8_1_sycl(const void *vx, const void *vy,
dpct::has_capability_or_fail(stream->get_device(), dpct::has_capability_or_fail(stream->get_device(),
{sycl::aspect::fp16}); {sycl::aspect::fp16});
stream->submit([&](sycl::handler &cgh) { sycl_launch(stream, [&](sycl::handler & cgh) {
sycl::local_accessor<int, 1> tile_x_qs_q8_0_acc_ct1( sycl::local_accessor<int, 1> tile_x_qs_q8_0_acc_ct1(
sycl::range<1>(mmq_y * (WARP_SIZE) + mmq_y), cgh); sycl::range<1>(mmq_y * (WARP_SIZE) + mmq_y), cgh);
sycl::local_accessor<float, 1> tile_x_d_q8_0_acc_ct1( sycl::local_accessor<float, 1> tile_x_d_q8_0_acc_ct1(
@ -2324,9 +2315,8 @@ static void ggml_mul_mat_q8_0_q8_1_sycl(const void *vx, const void *vy,
sycl::local_accessor<sycl::half2, 1> tile_y_ds_acc_ct1( sycl::local_accessor<sycl::half2, 1> tile_y_ds_acc_ct1(
sycl::range<1>(mmq_x * WARP_SIZE / QI8_1), cgh); sycl::range<1>(mmq_x * WARP_SIZE / QI8_1), cgh);
cgh.parallel_for( sycl_parallel_for(
sycl::nd_range<3>(block_nums * block_dims, block_dims), cgh, sycl::nd_range<3>(block_nums * block_dims, block_dims), [=](sycl::nd_item<3> item_ct1) {
[=](sycl::nd_item<3> item_ct1) {
mul_mat_q8_0<need_check>( mul_mat_q8_0<need_check>(
vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y,
nrows_dst, item_ct1, nrows_dst, item_ct1,
@ -2393,7 +2383,7 @@ static void ggml_mul_mat_q2_K_q8_1_sycl(const void *vx, const void *vy,
dpct::has_capability_or_fail(stream->get_device(), dpct::has_capability_or_fail(stream->get_device(),
{sycl::aspect::fp16}); {sycl::aspect::fp16});
stream->submit([&](sycl::handler &cgh) { sycl_launch(stream, [&](sycl::handler & cgh) {
sycl::local_accessor<int, 1> tile_x_ql_q2_K_acc_ct1( sycl::local_accessor<int, 1> tile_x_ql_q2_K_acc_ct1(
sycl::range<1>(mmq_y * (WARP_SIZE) + mmq_y), cgh); sycl::range<1>(mmq_y * (WARP_SIZE) + mmq_y), cgh);
sycl::local_accessor<sycl::half2, 1> tile_x_dm_q2_K_acc_ct1( sycl::local_accessor<sycl::half2, 1> tile_x_dm_q2_K_acc_ct1(
@ -2406,9 +2396,8 @@ static void ggml_mul_mat_q2_K_q8_1_sycl(const void *vx, const void *vy,
sycl::local_accessor<sycl::half2, 1> tile_y_ds_acc_ct1( sycl::local_accessor<sycl::half2, 1> tile_y_ds_acc_ct1(
sycl::range<1>(mmq_x * WARP_SIZE / QI8_1), cgh); sycl::range<1>(mmq_x * WARP_SIZE / QI8_1), cgh);
cgh.parallel_for( sycl_parallel_for(
sycl::nd_range<3>(block_nums * block_dims, block_dims), cgh, sycl::nd_range<3>(block_nums * block_dims, block_dims), [=](sycl::nd_item<3> item_ct1) {
[=](sycl::nd_item<3> item_ct1) {
mul_mat_q2_K<need_check>( mul_mat_q2_K<need_check>(
vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y,
nrows_dst, item_ct1, nrows_dst, item_ct1,
@ -2431,7 +2420,7 @@ static void ggml_mul_mat_q2_K_q8_1_sycl(const void *vx, const void *vy,
dpct::has_capability_or_fail(stream->get_device(), dpct::has_capability_or_fail(stream->get_device(),
{sycl::aspect::fp16}); {sycl::aspect::fp16});
stream->submit([&](sycl::handler &cgh) { sycl_launch(stream, [&](sycl::handler & cgh) {
sycl::local_accessor<int, 1> tile_x_ql_q2_K_acc_ct1( sycl::local_accessor<int, 1> tile_x_ql_q2_K_acc_ct1(
sycl::range<1>(mmq_y * (WARP_SIZE) + mmq_y), cgh); sycl::range<1>(mmq_y * (WARP_SIZE) + mmq_y), cgh);
sycl::local_accessor<sycl::half2, 1> tile_x_dm_q2_K_acc_ct1( sycl::local_accessor<sycl::half2, 1> tile_x_dm_q2_K_acc_ct1(
@ -2444,9 +2433,8 @@ static void ggml_mul_mat_q2_K_q8_1_sycl(const void *vx, const void *vy,
sycl::local_accessor<sycl::half2, 1> tile_y_ds_acc_ct1( sycl::local_accessor<sycl::half2, 1> tile_y_ds_acc_ct1(
sycl::range<1>(mmq_x * WARP_SIZE / QI8_1), cgh); sycl::range<1>(mmq_x * WARP_SIZE / QI8_1), cgh);
cgh.parallel_for( sycl_parallel_for(
sycl::nd_range<3>(block_nums * block_dims, block_dims), cgh, sycl::nd_range<3>(block_nums * block_dims, block_dims), [=](sycl::nd_item<3> item_ct1) {
[=](sycl::nd_item<3> item_ct1) {
mul_mat_q2_K<need_check>( mul_mat_q2_K<need_check>(
vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y,
nrows_dst, item_ct1, nrows_dst, item_ct1,
@ -2516,7 +2504,7 @@ static void ggml_mul_mat_q3_K_q8_1_sycl(const void *vx, const void *vy,
dpct::has_capability_or_fail(stream->get_device(), dpct::has_capability_or_fail(stream->get_device(),
{sycl::aspect::fp16}); {sycl::aspect::fp16});
stream->submit([&](sycl::handler &cgh) { sycl_launch(stream, [&](sycl::handler & cgh) {
sycl::local_accessor<int, 1> tile_x_ql_q3_K_acc_ct1( sycl::local_accessor<int, 1> tile_x_ql_q3_K_acc_ct1(
sycl::range<1>(mmq_y * (WARP_SIZE) + mmq_y), cgh); sycl::range<1>(mmq_y * (WARP_SIZE) + mmq_y), cgh);
sycl::local_accessor<sycl::half2, 1> tile_x_dm_q3_K_acc_ct1( sycl::local_accessor<sycl::half2, 1> tile_x_dm_q3_K_acc_ct1(
@ -2531,9 +2519,8 @@ static void ggml_mul_mat_q3_K_q8_1_sycl(const void *vx, const void *vy,
sycl::local_accessor<sycl::half2, 1> tile_y_ds_acc_ct1( sycl::local_accessor<sycl::half2, 1> tile_y_ds_acc_ct1(
sycl::range<1>(mmq_x * WARP_SIZE / QI8_1), cgh); sycl::range<1>(mmq_x * WARP_SIZE / QI8_1), cgh);
cgh.parallel_for( sycl_parallel_for(
sycl::nd_range<3>(block_nums * block_dims, block_dims), cgh, sycl::nd_range<3>(block_nums * block_dims, block_dims), [=](sycl::nd_item<3> item_ct1) {
[=](sycl::nd_item<3> item_ct1) {
mul_mat_q3_K<need_check>( mul_mat_q3_K<need_check>(
vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y,
nrows_dst, item_ct1, nrows_dst, item_ct1,
@ -2557,7 +2544,7 @@ static void ggml_mul_mat_q3_K_q8_1_sycl(const void *vx, const void *vy,
dpct::has_capability_or_fail(stream->get_device(), dpct::has_capability_or_fail(stream->get_device(),
{sycl::aspect::fp16}); {sycl::aspect::fp16});
stream->submit([&](sycl::handler &cgh) { sycl_launch(stream, [&](sycl::handler & cgh) {
sycl::local_accessor<int, 1> tile_x_ql_q3_K_acc_ct1( sycl::local_accessor<int, 1> tile_x_ql_q3_K_acc_ct1(
sycl::range<1>(mmq_y * (WARP_SIZE) + mmq_y), cgh); sycl::range<1>(mmq_y * (WARP_SIZE) + mmq_y), cgh);
sycl::local_accessor<sycl::half2, 1> tile_x_dm_q3_K_acc_ct1( sycl::local_accessor<sycl::half2, 1> tile_x_dm_q3_K_acc_ct1(
@ -2572,9 +2559,8 @@ static void ggml_mul_mat_q3_K_q8_1_sycl(const void *vx, const void *vy,
sycl::local_accessor<sycl::half2, 1> tile_y_ds_acc_ct1( sycl::local_accessor<sycl::half2, 1> tile_y_ds_acc_ct1(
sycl::range<1>(mmq_x * WARP_SIZE / QI8_1), cgh); sycl::range<1>(mmq_x * WARP_SIZE / QI8_1), cgh);
cgh.parallel_for( sycl_parallel_for(
sycl::nd_range<3>(block_nums * block_dims, block_dims), cgh, sycl::nd_range<3>(block_nums * block_dims, block_dims), [=](sycl::nd_item<3> item_ct1) {
[=](sycl::nd_item<3> item_ct1) {
mul_mat_q3_K<need_check>( mul_mat_q3_K<need_check>(
vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y,
nrows_dst, item_ct1, nrows_dst, item_ct1,
@ -2644,7 +2630,7 @@ static void ggml_mul_mat_q4_K_q8_1_sycl(const void *vx, const void *vy,
dpct::has_capability_or_fail(stream->get_device(), dpct::has_capability_or_fail(stream->get_device(),
{sycl::aspect::fp16}); {sycl::aspect::fp16});
stream->submit([&](sycl::handler &cgh) { sycl_launch(stream, [&](sycl::handler & cgh) {
sycl::local_accessor<int, 1> tile_x_ql_q4_K_acc_ct1( sycl::local_accessor<int, 1> tile_x_ql_q4_K_acc_ct1(
sycl::range<1>(mmq_y * (WARP_SIZE) + mmq_y), cgh); sycl::range<1>(mmq_y * (WARP_SIZE) + mmq_y), cgh);
sycl::local_accessor<sycl::half2, 1> tile_x_dm_q4_K_acc_ct1( sycl::local_accessor<sycl::half2, 1> tile_x_dm_q4_K_acc_ct1(
@ -2657,9 +2643,8 @@ static void ggml_mul_mat_q4_K_q8_1_sycl(const void *vx, const void *vy,
sycl::local_accessor<sycl::half2, 1> tile_y_ds_acc_ct1( sycl::local_accessor<sycl::half2, 1> tile_y_ds_acc_ct1(
sycl::range<1>(mmq_x * WARP_SIZE / QI8_1), cgh); sycl::range<1>(mmq_x * WARP_SIZE / QI8_1), cgh);
cgh.parallel_for( sycl_parallel_for(
sycl::nd_range<3>(block_nums * block_dims, block_dims), cgh, sycl::nd_range<3>(block_nums * block_dims, block_dims), [=](sycl::nd_item<3> item_ct1) {
[=](sycl::nd_item<3> item_ct1) {
mul_mat_q4_K<need_check>( mul_mat_q4_K<need_check>(
vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y,
nrows_dst, item_ct1, nrows_dst, item_ct1,
@ -2682,7 +2667,7 @@ static void ggml_mul_mat_q4_K_q8_1_sycl(const void *vx, const void *vy,
dpct::has_capability_or_fail(stream->get_device(), dpct::has_capability_or_fail(stream->get_device(),
{sycl::aspect::fp16}); {sycl::aspect::fp16});
stream->submit([&](sycl::handler &cgh) { sycl_launch(stream, [&](sycl::handler & cgh) {
sycl::local_accessor<int, 1> tile_x_ql_q4_K_acc_ct1( sycl::local_accessor<int, 1> tile_x_ql_q4_K_acc_ct1(
sycl::range<1>(mmq_y * (WARP_SIZE) + mmq_y), cgh); sycl::range<1>(mmq_y * (WARP_SIZE) + mmq_y), cgh);
sycl::local_accessor<sycl::half2, 1> tile_x_dm_q4_K_acc_ct1( sycl::local_accessor<sycl::half2, 1> tile_x_dm_q4_K_acc_ct1(
@ -2695,9 +2680,8 @@ static void ggml_mul_mat_q4_K_q8_1_sycl(const void *vx, const void *vy,
sycl::local_accessor<sycl::half2, 1> tile_y_ds_acc_ct1( sycl::local_accessor<sycl::half2, 1> tile_y_ds_acc_ct1(
sycl::range<1>(mmq_x * WARP_SIZE / QI8_1), cgh); sycl::range<1>(mmq_x * WARP_SIZE / QI8_1), cgh);
cgh.parallel_for( sycl_parallel_for(
sycl::nd_range<3>(block_nums * block_dims, block_dims), cgh, sycl::nd_range<3>(block_nums * block_dims, block_dims), [=](sycl::nd_item<3> item_ct1) {
[=](sycl::nd_item<3> item_ct1) {
mul_mat_q4_K<need_check>( mul_mat_q4_K<need_check>(
vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y,
nrows_dst, item_ct1, nrows_dst, item_ct1,
@ -2765,7 +2749,7 @@ static void ggml_mul_mat_q5_K_q8_1_sycl(const void *vx, const void *vy,
dpct::has_capability_or_fail(stream->get_device(), dpct::has_capability_or_fail(stream->get_device(),
{sycl::aspect::fp16}); {sycl::aspect::fp16});
stream->submit([&](sycl::handler &cgh) { sycl_launch(stream, [&](sycl::handler & cgh) {
sycl::local_accessor<int, 1> tile_x_ql_q5_K_acc_ct1( sycl::local_accessor<int, 1> tile_x_ql_q5_K_acc_ct1(
sycl::range<1>(mmq_y * (2 * WARP_SIZE) + mmq_y), cgh); sycl::range<1>(mmq_y * (2 * WARP_SIZE) + mmq_y), cgh);
sycl::local_accessor<sycl::half2, 1> tile_x_dm_q5_K_acc_ct1( sycl::local_accessor<sycl::half2, 1> tile_x_dm_q5_K_acc_ct1(
@ -2778,9 +2762,8 @@ static void ggml_mul_mat_q5_K_q8_1_sycl(const void *vx, const void *vy,
sycl::local_accessor<sycl::half2, 1> tile_y_ds_acc_ct1( sycl::local_accessor<sycl::half2, 1> tile_y_ds_acc_ct1(
sycl::range<1>(mmq_x * WARP_SIZE / QI8_1), cgh); sycl::range<1>(mmq_x * WARP_SIZE / QI8_1), cgh);
cgh.parallel_for( sycl_parallel_for(
sycl::nd_range<3>(block_nums * block_dims, block_dims), cgh, sycl::nd_range<3>(block_nums * block_dims, block_dims), [=](sycl::nd_item<3> item_ct1) {
[=](sycl::nd_item<3> item_ct1) {
mul_mat_q5_K<need_check>( mul_mat_q5_K<need_check>(
vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y,
nrows_dst, item_ct1, nrows_dst, item_ct1,
@ -2803,7 +2786,7 @@ static void ggml_mul_mat_q5_K_q8_1_sycl(const void *vx, const void *vy,
dpct::has_capability_or_fail(stream->get_device(), dpct::has_capability_or_fail(stream->get_device(),
{sycl::aspect::fp16}); {sycl::aspect::fp16});
stream->submit([&](sycl::handler &cgh) { sycl_launch(stream, [&](sycl::handler & cgh) {
sycl::local_accessor<int, 1> tile_x_ql_q5_K_acc_ct1( sycl::local_accessor<int, 1> tile_x_ql_q5_K_acc_ct1(
sycl::range<1>(mmq_y * (2 * WARP_SIZE) + mmq_y), cgh); sycl::range<1>(mmq_y * (2 * WARP_SIZE) + mmq_y), cgh);
sycl::local_accessor<sycl::half2, 1> tile_x_dm_q5_K_acc_ct1( sycl::local_accessor<sycl::half2, 1> tile_x_dm_q5_K_acc_ct1(
@ -2816,9 +2799,8 @@ static void ggml_mul_mat_q5_K_q8_1_sycl(const void *vx, const void *vy,
sycl::local_accessor<sycl::half2, 1> tile_y_ds_acc_ct1( sycl::local_accessor<sycl::half2, 1> tile_y_ds_acc_ct1(
sycl::range<1>(mmq_x * WARP_SIZE / QI8_1), cgh); sycl::range<1>(mmq_x * WARP_SIZE / QI8_1), cgh);
cgh.parallel_for( sycl_parallel_for(
sycl::nd_range<3>(block_nums * block_dims, block_dims), cgh, sycl::nd_range<3>(block_nums * block_dims, block_dims), [=](sycl::nd_item<3> item_ct1) {
[=](sycl::nd_item<3> item_ct1) {
mul_mat_q5_K<need_check>( mul_mat_q5_K<need_check>(
vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y,
nrows_dst, item_ct1, nrows_dst, item_ct1,
@ -2886,7 +2868,7 @@ static void ggml_mul_mat_q6_K_q8_1_sycl(const void *vx, const void *vy,
dpct::has_capability_or_fail(stream->get_device(), dpct::has_capability_or_fail(stream->get_device(),
{sycl::aspect::fp16}); {sycl::aspect::fp16});
stream->submit([&](sycl::handler &cgh) { sycl_launch(stream, [&](sycl::handler & cgh) {
sycl::local_accessor<int, 1> tile_x_ql_acc_ct1( sycl::local_accessor<int, 1> tile_x_ql_acc_ct1(
sycl::range<1>(mmq_y * (2 * WARP_SIZE) + mmq_y), cgh); sycl::range<1>(mmq_y * (2 * WARP_SIZE) + mmq_y), cgh);
sycl::local_accessor<sycl::half2, 1> tile_x_dm_acc_ct1( sycl::local_accessor<sycl::half2, 1> tile_x_dm_acc_ct1(
@ -2899,9 +2881,8 @@ static void ggml_mul_mat_q6_K_q8_1_sycl(const void *vx, const void *vy,
sycl::local_accessor<sycl::half2, 1> tile_y_ds_acc_ct1( sycl::local_accessor<sycl::half2, 1> tile_y_ds_acc_ct1(
sycl::range<1>(mmq_x * WARP_SIZE / QI8_1), cgh); sycl::range<1>(mmq_x * WARP_SIZE / QI8_1), cgh);
cgh.parallel_for( sycl_parallel_for(
sycl::nd_range<3>(block_nums * block_dims, block_dims), cgh, sycl::nd_range<3>(block_nums * block_dims, block_dims), [=](sycl::nd_item<3> item_ct1) {
[=](sycl::nd_item<3> item_ct1) {
mul_mat_q6_K<need_check>( mul_mat_q6_K<need_check>(
vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y,
nrows_dst, item_ct1, nrows_dst, item_ct1,
@ -2924,7 +2905,7 @@ static void ggml_mul_mat_q6_K_q8_1_sycl(const void *vx, const void *vy,
dpct::has_capability_or_fail(stream->get_device(), dpct::has_capability_or_fail(stream->get_device(),
{sycl::aspect::fp16}); {sycl::aspect::fp16});
stream->submit([&](sycl::handler &cgh) { sycl_launch(stream, [&](sycl::handler & cgh) {
sycl::local_accessor<int, 1> tile_x_ql_acc_ct1( sycl::local_accessor<int, 1> tile_x_ql_acc_ct1(
sycl::range<1>(mmq_y * (2 * WARP_SIZE) + mmq_y), cgh); sycl::range<1>(mmq_y * (2 * WARP_SIZE) + mmq_y), cgh);
sycl::local_accessor<sycl::half2, 1> tile_x_dm_acc_ct1( sycl::local_accessor<sycl::half2, 1> tile_x_dm_acc_ct1(
@ -2937,9 +2918,8 @@ static void ggml_mul_mat_q6_K_q8_1_sycl(const void *vx, const void *vy,
sycl::local_accessor<sycl::half2, 1> tile_y_ds_acc_ct1( sycl::local_accessor<sycl::half2, 1> tile_y_ds_acc_ct1(
sycl::range<1>(mmq_x * WARP_SIZE / QI8_1), cgh); sycl::range<1>(mmq_x * WARP_SIZE / QI8_1), cgh);
cgh.parallel_for( sycl_parallel_for(
sycl::nd_range<3>(block_nums * block_dims, block_dims), cgh, sycl::nd_range<3>(block_nums * block_dims, block_dims), [=](sycl::nd_item<3> item_ct1) {
[=](sycl::nd_item<3> item_ct1) {
mul_mat_q6_K<need_check>( mul_mat_q6_K<need_check>(
vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y,
nrows_dst, item_ct1, nrows_dst, item_ct1,
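Note on the pattern above: every launcher in this file is rewritten from a raw stream->submit([&](sycl::handler & cgh) { cgh.parallel_for(...); }) pair to the sycl_launch / sycl_parallel_for helpers. The helpers themselves are not part of this hunk; as a hedged sketch only (assuming they are thin forwarding wrappers, which is all the rewritten call sites require, and that queue_ptr is a plain sycl::queue *), they could look roughly like this:

#include <sycl/sycl.hpp>
#include <utility>

// Hypothetical forwarding wrappers -- the real definitions live elsewhere in
// the tree and may be macros or add instrumentation; this only mirrors the
// shape implied by the call sites in this diff.
template <typename CommandGroupFunc>
static void sycl_launch(sycl::queue * stream, CommandGroupFunc && cgf) {
    stream->submit(std::forward<CommandGroupFunc>(cgf));        // was: stream->submit(...)
}

template <int Dims, typename Kernel>
static void sycl_parallel_for(sycl::handler & cgh, sycl::nd_range<Dims> range, Kernel && kernel) {
    cgh.parallel_for(range, std::forward<Kernel>(kernel));      // was: cgh.parallel_for(...)
}

template <int Dims, typename Kernel>
static void sycl_parallel_for(sycl::queue * stream, sycl::nd_range<Dims> range, Kernel && kernel) {
    stream->parallel_for(range, std::forward<Kernel>(kernel));  // was: stream->parallel_for(...)
}

Under that assumption the only change at the call sites is the spelling; the local_accessor setup and the kernel bodies are untouched.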

View File

@ -544,8 +544,8 @@ static void reorder_mul_mat_vec_q4_0_q8_1_sycl(const void * vx, const void * vy,
const sycl::range<3> global_size(1, GGML_SYCL_MMV_Y, (block_num_y * WARP_SIZE)); const sycl::range<3> global_size(1, GGML_SYCL_MMV_Y, (block_num_y * WARP_SIZE));
const sycl::range<3> workgroup_size(1, GGML_SYCL_MMV_Y, num_subgroups * WARP_SIZE); const sycl::range<3> workgroup_size(1, GGML_SYCL_MMV_Y, num_subgroups * WARP_SIZE);
stream->submit([&](sycl::handler & cgh) { sycl_launch(stream, [&](sycl::handler & cgh) {
cgh.parallel_for(sycl::nd_range<3>(global_size, workgroup_size), sycl_parallel_for(cgh, sycl::nd_range<3>(global_size, workgroup_size),
[=](sycl::nd_item<3> nd_item) [[sycl::reqd_sub_group_size(WARP_SIZE)]] { [=](sycl::nd_item<3> nd_item) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
mul_mat_vec_q_reorder<reorder_vec_dot_q_sycl<GGML_TYPE_Q4_0>>(vx, vy, dst, ncols, nrows, mul_mat_vec_q_reorder<reorder_vec_dot_q_sycl<GGML_TYPE_Q4_0>>(vx, vy, dst, ncols, nrows,
nd_item); nd_item);
@ -561,8 +561,8 @@ static void mul_mat_vec_q4_0_q8_1_sycl(const void * vx, const void * vy, float *
const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE); const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE);
{ {
stream->submit([&](sycl::handler & cgh) { sycl_launch(stream, [&](sycl::handler & cgh) {
cgh.parallel_for(sycl::nd_range<3>(block_nums * block_dims, block_dims), sycl_parallel_for(cgh, sycl::nd_range<3>(block_nums * block_dims, block_dims),
[=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] { [=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
mul_mat_vec_q<QK4_0, QI4_0, block_q4_0, VDR_Q4_0_Q8_1_MMVQ, vec_dot_q4_0_q8_1>( mul_mat_vec_q<QK4_0, QI4_0, block_q4_0, VDR_Q4_0_Q8_1_MMVQ, vec_dot_q4_0_q8_1>(
vx, vy, dst, ncols, nrows, item_ct1); vx, vy, dst, ncols, nrows, item_ct1);
@ -580,15 +580,10 @@ static void mul_mat_vec_q4_1_q8_1_sycl(const void *vx, const void *vy,
const sycl::range<3> block_nums(1, 1, block_num_y); const sycl::range<3> block_nums(1, 1, block_num_y);
const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE); const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE);
{ {
sycl_launch(stream, [&](sycl::handler & cgh) {
stream->submit([&](sycl::handler &cgh) { sycl_parallel_for(cgh, sycl::nd_range<3>(block_nums * block_dims, block_dims),
[=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
cgh.parallel_for( mul_mat_vec_q<QK4_0, QI4_1, block_q4_1, VDR_Q4_1_Q8_1_MMVQ, vec_dot_q4_1_q8_1>(
sycl::nd_range<3>(block_nums * block_dims, block_dims),
[=](sycl::nd_item<3> item_ct1)
[[sycl::reqd_sub_group_size(WARP_SIZE)]] {
mul_mat_vec_q<QK4_0, QI4_1, block_q4_1,
VDR_Q4_1_Q8_1_MMVQ, vec_dot_q4_1_q8_1>(
vx, vy, dst, ncols, nrows, item_ct1); vx, vy, dst, ncols, nrows, item_ct1);
}); });
}); });
@ -604,15 +599,10 @@ static void mul_mat_vec_q5_0_q8_1_sycl(const void *vx, const void *vy,
const sycl::range<3> block_nums(1, 1, block_num_y); const sycl::range<3> block_nums(1, 1, block_num_y);
const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE); const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE);
{ {
sycl_launch(stream, [&](sycl::handler & cgh) {
stream->submit([&](sycl::handler &cgh) { sycl_parallel_for(cgh, sycl::nd_range<3>(block_nums * block_dims, block_dims),
[=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
cgh.parallel_for( mul_mat_vec_q<QK5_0, QI5_0, block_q5_0, VDR_Q5_0_Q8_1_MMVQ, vec_dot_q5_0_q8_1>(
sycl::nd_range<3>(block_nums * block_dims, block_dims),
[=](sycl::nd_item<3> item_ct1)
[[sycl::reqd_sub_group_size(WARP_SIZE)]] {
mul_mat_vec_q<QK5_0, QI5_0, block_q5_0,
VDR_Q5_0_Q8_1_MMVQ, vec_dot_q5_0_q8_1>(
vx, vy, dst, ncols, nrows, item_ct1); vx, vy, dst, ncols, nrows, item_ct1);
}); });
}); });
@ -628,15 +618,10 @@ static void mul_mat_vec_q5_1_q8_1_sycl(const void *vx, const void *vy,
const sycl::range<3> block_nums(1, 1, block_num_y); const sycl::range<3> block_nums(1, 1, block_num_y);
const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE); const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE);
{ {
sycl_launch(stream, [&](sycl::handler & cgh) {
stream->submit([&](sycl::handler &cgh) { sycl_parallel_for(cgh, sycl::nd_range<3>(block_nums * block_dims, block_dims),
[=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
cgh.parallel_for( mul_mat_vec_q<QK5_1, QI5_1, block_q5_1, VDR_Q5_1_Q8_1_MMVQ, vec_dot_q5_1_q8_1>(
sycl::nd_range<3>(block_nums * block_dims, block_dims),
[=](sycl::nd_item<3> item_ct1)
[[sycl::reqd_sub_group_size(WARP_SIZE)]] {
mul_mat_vec_q<QK5_1, QI5_1, block_q5_1,
VDR_Q5_1_Q8_1_MMVQ, vec_dot_q5_1_q8_1>(
vx, vy, dst, ncols, nrows, item_ct1); vx, vy, dst, ncols, nrows, item_ct1);
}); });
}); });
@ -652,15 +637,10 @@ static void mul_mat_vec_q8_0_q8_1_sycl(const void *vx, const void *vy,
const sycl::range<3> block_nums(1, 1, block_num_y); const sycl::range<3> block_nums(1, 1, block_num_y);
const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE); const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE);
{ {
sycl_launch(stream, [&](sycl::handler & cgh) {
stream->submit([&](sycl::handler &cgh) { sycl_parallel_for(cgh, sycl::nd_range<3>(block_nums * block_dims, block_dims),
[=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
cgh.parallel_for( mul_mat_vec_q<QK8_0, QI8_0, block_q8_0, VDR_Q8_0_Q8_1_MMVQ, vec_dot_q8_0_q8_1>(
sycl::nd_range<3>(block_nums * block_dims, block_dims),
[=](sycl::nd_item<3> item_ct1)
[[sycl::reqd_sub_group_size(WARP_SIZE)]] {
mul_mat_vec_q<QK8_0, QI8_0, block_q8_0,
VDR_Q8_0_Q8_1_MMVQ, vec_dot_q8_0_q8_1>(
vx, vy, dst, ncols, nrows, item_ct1); vx, vy, dst, ncols, nrows, item_ct1);
}); });
}); });
@ -676,15 +656,10 @@ static void mul_mat_vec_q2_K_q8_1_sycl(const void *vx, const void *vy,
const sycl::range<3> block_nums(1, 1, block_num_y); const sycl::range<3> block_nums(1, 1, block_num_y);
const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE); const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE);
{ {
sycl_launch(stream, [&](sycl::handler & cgh) {
stream->submit([&](sycl::handler &cgh) { sycl_parallel_for(cgh, sycl::nd_range<3>(block_nums * block_dims, block_dims),
[=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
cgh.parallel_for( mul_mat_vec_q<QK_K, QI2_K, block_q2_K, VDR_Q2_K_Q8_1_MMVQ, vec_dot_q2_K_q8_1>(
sycl::nd_range<3>(block_nums * block_dims, block_dims),
[=](sycl::nd_item<3> item_ct1)
[[sycl::reqd_sub_group_size(WARP_SIZE)]] {
mul_mat_vec_q<QK_K, QI2_K, block_q2_K,
VDR_Q2_K_Q8_1_MMVQ, vec_dot_q2_K_q8_1>(
vx, vy, dst, ncols, nrows, item_ct1); vx, vy, dst, ncols, nrows, item_ct1);
}); });
}); });
@ -700,15 +675,10 @@ static void mul_mat_vec_q3_K_q8_1_sycl(const void *vx, const void *vy,
const sycl::range<3> block_nums(1, 1, block_num_y); const sycl::range<3> block_nums(1, 1, block_num_y);
const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE); const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE);
{ {
sycl_launch(stream, [&](sycl::handler & cgh) {
stream->submit([&](sycl::handler &cgh) { sycl_parallel_for(cgh, sycl::nd_range<3>(block_nums * block_dims, block_dims),
[=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
cgh.parallel_for( mul_mat_vec_q<QK_K, QI3_K, block_q3_K, VDR_Q3_K_Q8_1_MMVQ, vec_dot_q3_K_q8_1>(
sycl::nd_range<3>(block_nums * block_dims, block_dims),
[=](sycl::nd_item<3> item_ct1)
[[sycl::reqd_sub_group_size(WARP_SIZE)]] {
mul_mat_vec_q<QK_K, QI3_K, block_q3_K,
VDR_Q3_K_Q8_1_MMVQ, vec_dot_q3_K_q8_1>(
vx, vy, dst, ncols, nrows, item_ct1); vx, vy, dst, ncols, nrows, item_ct1);
}); });
}); });
@ -724,15 +694,10 @@ static void mul_mat_vec_q4_K_q8_1_sycl(const void *vx, const void *vy,
const sycl::range<3> block_nums(1, 1, block_num_y); const sycl::range<3> block_nums(1, 1, block_num_y);
const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE); const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE);
{ {
sycl_launch(stream, [&](sycl::handler & cgh) {
stream->submit([&](sycl::handler &cgh) { sycl_parallel_for(cgh, sycl::nd_range<3>(block_nums * block_dims, block_dims),
[=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
cgh.parallel_for( mul_mat_vec_q<QK_K, QI4_K, block_q4_K, VDR_Q4_K_Q8_1_MMVQ, vec_dot_q4_K_q8_1>(
sycl::nd_range<3>(block_nums * block_dims, block_dims),
[=](sycl::nd_item<3> item_ct1)
[[sycl::reqd_sub_group_size(WARP_SIZE)]] {
mul_mat_vec_q<QK_K, QI4_K, block_q4_K,
VDR_Q4_K_Q8_1_MMVQ, vec_dot_q4_K_q8_1>(
vx, vy, dst, ncols, nrows, item_ct1); vx, vy, dst, ncols, nrows, item_ct1);
}); });
}); });
@ -750,11 +715,11 @@ static void reorder_mul_mat_vec_q4_k_q8_1_sycl(const void * vx, const void * vy,
const sycl::range<3> global_size(1, GGML_SYCL_MMV_Y, block_num_y * WARP_SIZE); const sycl::range<3> global_size(1, GGML_SYCL_MMV_Y, block_num_y * WARP_SIZE);
const sycl::range<3> workgroup_size(1, GGML_SYCL_MMV_Y, num_subgroups * WARP_SIZE); const sycl::range<3> workgroup_size(1, GGML_SYCL_MMV_Y, num_subgroups * WARP_SIZE);
stream->submit([&](sycl::handler & cgh) { sycl_launch(stream, [&](sycl::handler & cgh) {
cgh.parallel_for(sycl::nd_range<3>(global_size, workgroup_size), sycl_parallel_for(cgh, sycl::nd_range<3>(global_size, workgroup_size),
[=](sycl::nd_item<3> nd_item) [[sycl::reqd_sub_group_size(WARP_SIZE)]] { [=](sycl::nd_item<3> nd_item) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
mul_mat_vec_q_reorder<reorder_vec_dot_q_sycl<GGML_TYPE_Q4_K>>(vx, vy, dst, ncols, mul_mat_vec_q_reorder<reorder_vec_dot_q_sycl<GGML_TYPE_Q4_K>>(vx, vy, dst, ncols, nrows,
nrows, nd_item); nd_item);
}); });
}); });
} }
@ -769,15 +734,10 @@ static void mul_mat_vec_q5_K_q8_1_sycl(const void *vx, const void *vy,
const sycl::range<3> block_nums(1, 1, block_num_y); const sycl::range<3> block_nums(1, 1, block_num_y);
const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE); const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE);
{ {
sycl_launch(stream, [&](sycl::handler & cgh) {
stream->submit([&](sycl::handler &cgh) { sycl_parallel_for(cgh, sycl::nd_range<3>(block_nums * block_dims, block_dims),
[=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
cgh.parallel_for( mul_mat_vec_q<QK_K, QI5_K, block_q5_K, VDR_Q5_K_Q8_1_MMVQ, vec_dot_q5_K_q8_1>(
sycl::nd_range<3>(block_nums * block_dims, block_dims),
[=](sycl::nd_item<3> item_ct1)
[[sycl::reqd_sub_group_size(WARP_SIZE)]] {
mul_mat_vec_q<QK_K, QI5_K, block_q5_K,
VDR_Q5_K_Q8_1_MMVQ, vec_dot_q5_K_q8_1>(
vx, vy, dst, ncols, nrows, item_ct1); vx, vy, dst, ncols, nrows, item_ct1);
}); });
}); });
@ -794,8 +754,8 @@ static void reorder_mul_mat_vec_q6_k_q8_1_sycl(const void * vx, const void * vy,
const sycl::range<3> global_size(1, GGML_SYCL_MMV_Y, block_num_y * WARP_SIZE); const sycl::range<3> global_size(1, GGML_SYCL_MMV_Y, block_num_y * WARP_SIZE);
const sycl::range<3> workgroup_size(1, GGML_SYCL_MMV_Y, num_subgroups * WARP_SIZE); const sycl::range<3> workgroup_size(1, GGML_SYCL_MMV_Y, num_subgroups * WARP_SIZE);
stream->submit([&](sycl::handler & cgh) { sycl_launch(stream, [&](sycl::handler & cgh) {
cgh.parallel_for(sycl::nd_range<3>(global_size, workgroup_size), sycl_parallel_for(cgh, sycl::nd_range<3>(global_size, workgroup_size),
[=](sycl::nd_item<3> nd_item) [[sycl::reqd_sub_group_size(WARP_SIZE)]] { [=](sycl::nd_item<3> nd_item) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
mul_mat_vec_q_reorder<reorder_vec_dot_q_sycl<GGML_TYPE_Q6_K>>(vx, vy, dst, ncols, nrows, mul_mat_vec_q_reorder<reorder_vec_dot_q_sycl<GGML_TYPE_Q6_K>>(vx, vy, dst, ncols, nrows,
nd_item); nd_item);
@ -811,15 +771,10 @@ static void mul_mat_vec_q6_K_q8_1_sycl(const void *vx, const void *vy,
const sycl::range<3> block_nums(1, 1, block_num_y); const sycl::range<3> block_nums(1, 1, block_num_y);
const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE); const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE);
{ {
sycl_launch(stream, [&](sycl::handler & cgh) {
stream->submit([&](sycl::handler &cgh) { sycl_parallel_for(cgh, sycl::nd_range<3>(block_nums * block_dims, block_dims),
[=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
cgh.parallel_for( mul_mat_vec_q<QK_K, QI6_K, block_q6_K, VDR_Q6_K_Q8_1_MMVQ, vec_dot_q6_K_q8_1>(
sycl::nd_range<3>(block_nums * block_dims, block_dims),
[=](sycl::nd_item<3> item_ct1)
[[sycl::reqd_sub_group_size(WARP_SIZE)]] {
mul_mat_vec_q<QK_K, QI6_K, block_q6_K,
VDR_Q6_K_Q8_1_MMVQ, vec_dot_q6_K_q8_1>(
vx, vy, dst, ncols, nrows, item_ct1); vx, vy, dst, ncols, nrows, item_ct1);
}); });
}); });
@ -836,13 +791,11 @@ static void mul_mat_vec_iq2_xxs_q8_1_sycl(const void *vx, const void *vy,
const sycl::range<3> block_nums(1, 1, block_num_y); const sycl::range<3> block_nums(1, 1, block_num_y);
const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE); const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE);
{ {
stream->submit([&](sycl::handler &cgh) { sycl_launch(stream, [&](sycl::handler & cgh) {
cgh.parallel_for( sycl_parallel_for(cgh, sycl::nd_range<3>(block_nums * block_dims, block_dims),
sycl::nd_range<3>(block_nums * block_dims, block_dims), [=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
[=](sycl::nd_item<3> item_ct1) mul_mat_vec_q_iq2_xxs_q8_1<QK_K, QI2_XXS / 2, block_iq2_xxs, 1>(vx, vy, dst, ncols,
[[sycl::reqd_sub_group_size(WARP_SIZE)]] { nrows, item_ct1);
mul_mat_vec_q_iq2_xxs_q8_1<QK_K, QI2_XXS/2, block_iq2_xxs, 1>(
vx, vy, dst, ncols, nrows, item_ct1);
}); });
}); });
} }
@ -857,13 +810,11 @@ static void mul_mat_vec_iq2_xs_q8_1_sycl(const void *vx, const void *vy,
const sycl::range<3> block_nums(1, 1, block_num_y); const sycl::range<3> block_nums(1, 1, block_num_y);
const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE); const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE);
{ {
stream->submit([&](sycl::handler & cgh) { sycl_launch(stream, [&](sycl::handler & cgh) {
cgh.parallel_for( sycl_parallel_for(cgh, sycl::nd_range<3>(block_nums * block_dims, block_dims),
sycl::nd_range<3>(block_nums * block_dims, block_dims), [=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
[=](sycl::nd_item<3> item_ct1) mul_mat_vec_q_iq2_xs_q8_1<QK_K, QI2_XS / 2, block_iq2_xs, 1>(vx, vy, dst, ncols,
[[sycl::reqd_sub_group_size(WARP_SIZE)]] { nrows, item_ct1);
mul_mat_vec_q_iq2_xs_q8_1<QK_K, QI2_XS/2, block_iq2_xs, 1>(
vx, vy, dst, ncols, nrows, item_ct1);
}); });
}); });
} }
@ -878,14 +829,11 @@ static void mul_mat_vec_iq2_s_q8_1_sycl(const void *vx, const void *vy,
const sycl::range<3> block_nums(1, 1, block_num_y); const sycl::range<3> block_nums(1, 1, block_num_y);
const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE); const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE);
{ {
sycl_launch(stream, [&](sycl::handler & cgh) {
stream->submit([&](sycl::handler &cgh) { sycl_parallel_for(cgh, sycl::nd_range<3>(block_nums * block_dims, block_dims),
cgh.parallel_for( [=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
sycl::nd_range<3>(block_nums * block_dims, block_dims), mul_mat_vec_q_iq2_s_q8_1<QK_K, QI2_S / 2, block_iq2_s, 1>(vx, vy, dst, ncols, nrows,
[=](sycl::nd_item<3> item_ct1) item_ct1);
[[sycl::reqd_sub_group_size(WARP_SIZE)]] {
mul_mat_vec_q_iq2_s_q8_1<QK_K, QI2_S/2, block_iq2_s, 1>(
vx, vy, dst, ncols, nrows, item_ct1);
}); });
}); });
} }
@ -900,14 +848,11 @@ static void mul_mat_vec_iq3_xxs_q8_1_sycl(const void *vx, const void *vy,
const sycl::range<3> block_nums(1, 1, block_num_y); const sycl::range<3> block_nums(1, 1, block_num_y);
const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE); const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE);
{ {
sycl_launch(stream, [&](sycl::handler & cgh) {
stream->submit([&](sycl::handler &cgh) { sycl_parallel_for(cgh, sycl::nd_range<3>(block_nums * block_dims, block_dims),
cgh.parallel_for( [=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
sycl::nd_range<3>(block_nums * block_dims, block_dims), mul_mat_vec_q_iq3_xxs_q8_1<QK_K, QI3_XXS / 2, block_iq3_xxs, 1>(vx, vy, dst, ncols,
[=](sycl::nd_item<3> item_ct1) nrows, item_ct1);
[[sycl::reqd_sub_group_size(WARP_SIZE)]] {
mul_mat_vec_q_iq3_xxs_q8_1<QK_K, QI3_XXS/2, block_iq3_xxs, 1>(
vx, vy, dst, ncols, nrows, item_ct1);
}); });
}); });
} }
@ -922,14 +867,11 @@ static void mul_mat_vec_iq3_s_q8_1_sycl(const void *vx, const void *vy,
const sycl::range<3> block_nums(1, 1, block_num_y); const sycl::range<3> block_nums(1, 1, block_num_y);
const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE); const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE);
{ {
sycl_launch(stream, [&](sycl::handler & cgh) {
stream->submit([&](sycl::handler &cgh) { sycl_parallel_for(cgh, sycl::nd_range<3>(block_nums * block_dims, block_dims),
cgh.parallel_for( [=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
sycl::nd_range<3>(block_nums * block_dims, block_dims), mul_mat_vec_q_iq3_s_q8_1<QK_K, QI3_S / 2, block_iq3_s, 1>(vx, vy, dst, ncols, nrows,
[=](sycl::nd_item<3> item_ct1) item_ct1);
[[sycl::reqd_sub_group_size(WARP_SIZE)]] {
mul_mat_vec_q_iq3_s_q8_1<QK_K, QI3_S/2, block_iq3_s, 1>(
vx, vy, dst, ncols, nrows, item_ct1);
}); });
}); });
} }
@ -944,14 +886,11 @@ static void mul_mat_vec_iq1_s_q8_1_sycl(const void *vx, const void *vy,
const sycl::range<3> block_nums(1, 1, block_num_y); const sycl::range<3> block_nums(1, 1, block_num_y);
const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE); const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE);
{ {
sycl_launch(stream, [&](sycl::handler & cgh) {
stream->submit([&](sycl::handler &cgh) { sycl_parallel_for(cgh, sycl::nd_range<3>(block_nums * block_dims, block_dims),
cgh.parallel_for( [=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
sycl::nd_range<3>(block_nums * block_dims, block_dims), mul_mat_vec_q_iq1_s_q8_1<QK_K, QI1_S, block_iq1_s, 1>(vx, vy, dst, ncols, nrows,
[=](sycl::nd_item<3> item_ct1) item_ct1);
[[sycl::reqd_sub_group_size(WARP_SIZE)]] {
mul_mat_vec_q_iq1_s_q8_1<QK_K, QI1_S, block_iq1_s, 1>(
vx, vy, dst, ncols, nrows, item_ct1);
}); });
}); });
} }
@ -966,13 +905,11 @@ static void mul_mat_vec_iq1_m_q8_1_sycl(const void *vx, const void *vy,
const sycl::range<3> block_nums(1, 1, block_num_y); const sycl::range<3> block_nums(1, 1, block_num_y);
const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE); const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE);
{ {
stream->submit([&](sycl::handler &cgh) { sycl_launch(stream, [&](sycl::handler & cgh) {
cgh.parallel_for( sycl_parallel_for(cgh, sycl::nd_range<3>(block_nums * block_dims, block_dims),
sycl::nd_range<3>(block_nums * block_dims, block_dims), [=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
[=](sycl::nd_item<3> item_ct1) mul_mat_vec_q_iq1_m_q8_1<QK_K, QI1_S, block_iq1_m, 1>(vx, vy, dst, ncols, nrows,
[[sycl::reqd_sub_group_size(WARP_SIZE)]] { item_ct1);
mul_mat_vec_q_iq1_m_q8_1<QK_K, QI1_S, block_iq1_m, 1>(
vx, vy, dst, ncols, nrows, item_ct1);
}); });
}); });
} }
@ -987,14 +924,11 @@ static void mul_mat_vec_iq4_nl_q8_1_sycl(const void *vx, const void *vy,
const sycl::range<3> block_nums(1, 1, block_num_y); const sycl::range<3> block_nums(1, 1, block_num_y);
const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE); const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE);
{ {
sycl_launch(stream, [&](sycl::handler & cgh) {
stream->submit([&](sycl::handler &cgh) { sycl_parallel_for(cgh, sycl::nd_range<3>(block_nums * block_dims, block_dims),
cgh.parallel_for( [=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
sycl::nd_range<3>(block_nums * block_dims, block_dims), mul_mat_vec_q_iq4_nl_q8_1<QK4_NL, QI4_NL, block_iq4_nl, 2>(vx, vy, dst, ncols, nrows,
[=](sycl::nd_item<3> item_ct1) item_ct1);
[[sycl::reqd_sub_group_size(WARP_SIZE)]] {
mul_mat_vec_q_iq4_nl_q8_1<QK4_NL, QI4_NL, block_iq4_nl, 2>(
vx, vy, dst, ncols, nrows, item_ct1);
}); });
}); });
} }
@ -1009,14 +943,11 @@ static void mul_mat_vec_iq4_xs_q8_1_sycl(const void *vx, const void *vy,
const sycl::range<3> block_nums(1, 1, block_num_y); const sycl::range<3> block_nums(1, 1, block_num_y);
const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE); const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE);
{ {
sycl_launch(stream, [&](sycl::handler & cgh) {
stream->submit([&](sycl::handler &cgh) { sycl_parallel_for(cgh, sycl::nd_range<3>(block_nums * block_dims, block_dims),
cgh.parallel_for( [=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
sycl::nd_range<3>(block_nums * block_dims, block_dims), mul_mat_vec_q_iq4_xs_q8_1<QK_K, QI4_XS / 4, block_iq4_xs, 1>(vx, vy, dst, ncols,
[=](sycl::nd_item<3> item_ct1) nrows, item_ct1);
[[sycl::reqd_sub_group_size(WARP_SIZE)]] {
mul_mat_vec_q_iq4_xs_q8_1<QK_K, QI4_XS/4, block_iq4_xs, 1>(
vx, vy, dst, ncols, nrows, item_ct1);
}); });
}); });
} }
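The mul_mat_vec launchers above all keep the same geometry: work-groups of (1, GGML_SYCL_MMV_Y, WARP_SIZE) with the lambda pinned to a fixed sub-group size so the warp-style reductions inside mul_mat_vec_q behave as intended. A self-contained toy version of that launch shape, written with plain SYCL calls so it compiles on its own (names below are illustrative, not ggml's):

#include <sycl/sycl.hpp>

constexpr int WARP_SIZE = 32;   // assumption for the sketch; ggml picks this per target

// Toy kernel: one sub-group accumulates one row, mirroring how one sub-group
// handles one row in the mul_mat_vec_q kernels.
static void toy_row_sum(const float * x, float * out, int ncols, sycl::nd_item<3> it) {
    const int row = (int) it.get_group(2);
    float sum = 0.0f;
    for (int c = (int) it.get_local_id(2); c < ncols; c += WARP_SIZE) {
        sum += x[row * ncols + c];
    }
    sum = sycl::reduce_over_group(it.get_sub_group(), sum, sycl::plus<float>());
    if (it.get_local_id(2) == 0) {
        out[row] = sum;
    }
}

static void toy_row_sum_sycl(const float * x, float * out, int ncols, int nrows, sycl::queue * stream) {
    const sycl::range<3> block_nums(1, 1, nrows);
    const sycl::range<3> block_dims(1, 1, WARP_SIZE);
    stream->submit([&](sycl::handler & cgh) {
        cgh.parallel_for(sycl::nd_range<3>(block_nums * block_dims, block_dims),
                         [=](sycl::nd_item<3> it) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
                             toy_row_sum(x, out, ncols, it);
                         });
    });
}

In the tree these launches go through sycl_launch / sycl_parallel_for as shown in the hunks above; plain submit/parallel_for is used here only to keep the sketch standalone.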

View File

@ -254,12 +254,11 @@ static void norm_f32_sycl(const float * x, float * dst, const int ncols, const i
GGML_ASSERT(ncols % WARP_SIZE == 0); GGML_ASSERT(ncols % WARP_SIZE == 0);
if (ncols < 1024) { if (ncols < 1024) {
const sycl::range<3> block_dims(1, 1, WARP_SIZE); const sycl::range<3> block_dims(1, 1, WARP_SIZE);
stream->submit([&](sycl::handler& cgh) { sycl_launch(stream, [&](sycl::handler & cgh) {
cgh.parallel_for( sycl_parallel_for(cgh, sycl::nd_range<3>(global_dims * block_dims, block_dims),
sycl::nd_range<3>(global_dims * block_dims, block_dims), [=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
[=](sycl::nd_item<3> item_ct1) norm_f32(x, dst, ncols, stride_row, stride_channel, stride_sample, eps, item_ct1,
[[sycl::reqd_sub_group_size(WARP_SIZE)]] { nullptr, WARP_SIZE);
norm_f32(x, dst, ncols, stride_row, stride_channel, stride_sample, eps, item_ct1, nullptr, WARP_SIZE);
}); });
}); });
} }
@ -272,14 +271,13 @@ static void norm_f32_sycl(const float * x, float * dst, const int ncols, const i
the limit. To get the device limit, query the limit. To get the device limit, query
info::device::max_work_group_size. Adjust the work-group size if needed. info::device::max_work_group_size. Adjust the work-group size if needed.
*/ */
stream->submit([&](sycl::handler& cgh) { sycl_launch(stream, [&](sycl::handler & cgh) {
sycl::local_accessor<sycl::float2, 1> s_sum_acc_ct1( sycl::local_accessor<sycl::float2, 1> s_sum_acc_ct1(
sycl::range<1>(work_group_size / WARP_SIZE), cgh); sycl::range<1>(work_group_size / WARP_SIZE), cgh);
cgh.parallel_for( sycl_parallel_for(cgh, sycl::nd_range<3>(global_dims * block_dims, block_dims),
sycl::nd_range<3>(global_dims * block_dims, block_dims), [=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
[=](sycl::nd_item<3> item_ct1) norm_f32(x, dst, ncols, stride_row, stride_channel, stride_sample, eps, item_ct1,
[[sycl::reqd_sub_group_size(WARP_SIZE)]] { get_pointer(s_sum_acc_ct1), work_group_size);
norm_f32(x, dst, ncols, stride_row, stride_channel, stride_sample, eps, item_ct1, get_pointer(s_sum_acc_ct1), work_group_size);
}); });
}); });
} }
@ -290,16 +288,12 @@ static void group_norm_f32_sycl(const float* x, float* dst,
const int ne_elements, queue_ptr stream, int device) { const int ne_elements, queue_ptr stream, int device) {
if (group_size < 1024) { if (group_size < 1024) {
const sycl::range<3> block_dims(1, 1, WARP_SIZE); const sycl::range<3> block_dims(1, 1, WARP_SIZE);
stream->submit([&](sycl::handler& cgh) { sycl_launch(stream, [&](sycl::handler & cgh) {
const float eps_ct4 = eps; const float eps_ct4 = eps;
cgh.parallel_for( sycl_parallel_for(cgh, sycl::nd_range<3>(sycl::range<3>(1, 1, num_groups) * block_dims, block_dims),
sycl::nd_range<3>(sycl::range<3>(1, 1, num_groups) * block_dims, [=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
block_dims), group_norm_f32(x, dst, group_size, ne_elements, eps_ct4, item_ct1, nullptr,
[=](sycl::nd_item<3> item_ct1) WARP_SIZE);
[[sycl::reqd_sub_group_size(WARP_SIZE)]] {
group_norm_f32(
x, dst, group_size, ne_elements, eps_ct4, item_ct1,
nullptr, WARP_SIZE);
}); });
}); });
} }
@ -313,19 +307,15 @@ static void group_norm_f32_sycl(const float* x, float* dst,
info::device::max_work_group_size. Adjust the work-group size if needed. info::device::max_work_group_size. Adjust the work-group size if needed.
*/ */
stream->submit([&](sycl::handler& cgh) { sycl_launch(stream, [&](sycl::handler & cgh) {
sycl::local_accessor<float, 1> s_sum_acc_ct1(sycl::range<1>(work_group_size / WARP_SIZE), sycl::local_accessor<float, 1> s_sum_acc_ct1(sycl::range<1>(work_group_size / WARP_SIZE),
cgh); cgh);
const float eps_ct4 = eps; const float eps_ct4 = eps;
cgh.parallel_for( sycl_parallel_for(cgh, sycl::nd_range<3>(sycl::range<3>(1, 1, num_groups) * block_dims, block_dims),
sycl::nd_range<3>(sycl::range<3>(1, 1, num_groups) * block_dims, [=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
block_dims), group_norm_f32(x, dst, group_size, ne_elements, eps_ct4, item_ct1,
[=](sycl::nd_item<3> item_ct1)
[[sycl::reqd_sub_group_size(WARP_SIZE)]] {
group_norm_f32(x, dst, group_size, ne_elements,
eps_ct4, item_ct1,
get_pointer(s_sum_acc_ct1), work_group_size); get_pointer(s_sum_acc_ct1), work_group_size);
}); });
}); });
@ -340,51 +330,10 @@ static void rms_norm_f32_sycl(const float* x, float* dst, const int ncols, const
const sycl::range<3> global_dims(nsamples, nchannels, nrows); const sycl::range<3> global_dims(nsamples, nchannels, nrows);
if (ncols < 1024) { if (ncols < 1024) {
const sycl::range<3> block_dims(1, 1, WARP_SIZE); const sycl::range<3> block_dims(1, 1, WARP_SIZE);
stream->submit([&](sycl::handler& cgh) { sycl_launch(stream, [&](sycl::handler & cgh) {
cgh.parallel_for( sycl_parallel_for(cgh, sycl::nd_range<3>(global_dims * block_dims, block_dims),
sycl::nd_range<3>(global_dims * block_dims, block_dims), [=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
[=](sycl::nd_item<3> item_ct1) rms_norm_f32(x, dst, ncols, stride_row, stride_channel, stride_sample, eps, item_ct1,
[[sycl::reqd_sub_group_size(WARP_SIZE)]] {
rms_norm_f32(x, dst, ncols, stride_row, stride_channel, stride_sample, eps, item_ct1, nullptr, WARP_SIZE);
});
});
}
else {
const int work_group_size = ggml_sycl_info().max_work_group_sizes[device];
assert(work_group_size % (WARP_SIZE * WARP_SIZE) == 0);
const sycl::range<3> block_dims(1, 1, work_group_size);
/*
DPCT1049:19: The work-group size passed to the SYCL kernel may exceed
the limit. To get the device limit, query
info::device::max_work_group_size. Adjust the work-group size if needed.
*/
stream->submit([&](sycl::handler& cgh) {
sycl::local_accessor<float, 1> s_sum_acc_ct1(sycl::range<1>(work_group_size / WARP_SIZE),
cgh);
cgh.parallel_for(
sycl::nd_range<3>(global_dims * block_dims, block_dims),
[=](sycl::nd_item<3> item_ct1)
[[sycl::reqd_sub_group_size(WARP_SIZE)]] {
rms_norm_f32(x, dst, ncols, stride_row, stride_channel, stride_sample, eps, item_ct1, get_pointer(s_sum_acc_ct1), work_group_size);
});
});
}
}
static void l2_norm_f32_sycl(const float* x, float* dst, const int ncols,
const int nrows, const float eps,
queue_ptr stream, int device) {
GGML_ASSERT(ncols % WARP_SIZE == 0);
// printf("%s ncols=%d, nrows=%d, WARP_SIZE=%d\n", __func__, ncols, nrows, WARP_SIZE);
if (ncols < 1024) {
const sycl::range<3> block_dims(1, 1, WARP_SIZE);
stream->submit([&](sycl::handler& cgh) {
cgh.parallel_for(
sycl::nd_range<3>(sycl::range<3>(1, 1, nrows) * block_dims,
block_dims),
[=](sycl::nd_item<3> item_ct1)
[[sycl::reqd_sub_group_size(WARP_SIZE)]] {
l2_norm_f32(x, dst, ncols, eps, item_ct1,
nullptr, WARP_SIZE); nullptr, WARP_SIZE);
}); });
}); });
@ -398,21 +347,53 @@ static void l2_norm_f32_sycl(const float* x, float* dst, const int ncols,
the limit. To get the device limit, query the limit. To get the device limit, query
info::device::max_work_group_size. Adjust the work-group size if needed. info::device::max_work_group_size. Adjust the work-group size if needed.
*/ */
stream->submit([&](sycl::handler& cgh) { sycl_launch(stream, [&](sycl::handler & cgh) {
sycl::local_accessor<float, 1> s_sum_acc_ct1(sycl::range<1>(work_group_size / WARP_SIZE), sycl::local_accessor<float, 1> s_sum_acc_ct1(sycl::range<1>(work_group_size / WARP_SIZE),
cgh); cgh);
cgh.parallel_for( sycl_parallel_for(cgh, sycl::nd_range<3>(global_dims * block_dims, block_dims),
sycl::nd_range<3>(sycl::range<3>(1, 1, nrows) * block_dims, [=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
block_dims), rms_norm_f32(x, dst, ncols, stride_row, stride_channel, stride_sample, eps, item_ct1,
[=](sycl::nd_item<3> item_ct1)
[[sycl::reqd_sub_group_size(WARP_SIZE)]] {
l2_norm_f32(x, dst, ncols, eps, item_ct1,
get_pointer(s_sum_acc_ct1), work_group_size); get_pointer(s_sum_acc_ct1), work_group_size);
}); });
}); });
} }
} }
static void l2_norm_f32_sycl(const float* x, float* dst, const int ncols,
const int nrows, const float eps,
queue_ptr stream, int device) {
GGML_ASSERT(ncols % WARP_SIZE == 0);
// printf("%s ncols=%d, nrows=%d, WARP_SIZE=%d\n", __func__, ncols, nrows, WARP_SIZE);
if (ncols < 1024) {
const sycl::range<3> block_dims(1, 1, WARP_SIZE);
sycl_launch(stream, [&](sycl::handler & cgh) {
sycl_parallel_for(cgh, sycl::nd_range<3>(sycl::range<3>(1, 1, nrows) * block_dims, block_dims),
[=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
l2_norm_f32(x, dst, ncols, eps, item_ct1, nullptr, WARP_SIZE);
});
});
}
else {
const int work_group_size = ggml_sycl_info().max_work_group_sizes[device];
assert(work_group_size % (WARP_SIZE * WARP_SIZE) == 0);
const sycl::range<3> block_dims(1, 1, work_group_size);
/*
DPCT1049:19: The work-group size passed to the SYCL kernel may exceed
the limit. To get the device limit, query
info::device::max_work_group_size. Adjust the work-group size if needed.
*/
sycl_launch(stream, [&](sycl::handler & cgh) {
sycl::local_accessor<float, 1> s_sum_acc_ct1(sycl::range<1>(work_group_size / WARP_SIZE),
cgh);
sycl_parallel_for(cgh, sycl::nd_range<3>(sycl::range<3>(1, 1, nrows) * block_dims, block_dims),
[=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
l2_norm_f32(x, dst, ncols, eps, item_ct1, get_pointer(s_sum_acc_ct1),
work_group_size);
});
});
}
}
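The large-row branches above all allocate a local_accessor of work_group_size / WARP_SIZE floats; its purpose is to stage one partial result per sub-group so the work-group can finish its reduction in two levels. A hedged, self-contained sketch of that pattern (illustrative names, not the actual norm_f32 / rms_norm_f32 / l2_norm_f32 kernels):

#include <sycl/sycl.hpp>

constexpr int WARP_SIZE = 32;  // assumption for the sketch

// Two-level work-group sum: reduce within each sub-group, park one partial per
// sub-group in local memory, then reduce the partials again. Only the first
// sub-group ends up holding the complete sum.
static float work_group_sum(float v, sycl::nd_item<3> it, float * s_partial) {
    const auto sg = it.get_sub_group();
    v = sycl::reduce_over_group(sg, v, sycl::plus<float>());

    const size_t n_sub_groups = it.get_local_range(2) / WARP_SIZE;
    if (n_sub_groups > 1) {
        if (sg.get_local_linear_id() == 0) {
            s_partial[sg.get_group_linear_id()] = v;   // one slot per sub-group
        }
        sycl::group_barrier(it.get_group());
        const size_t lid = it.get_local_linear_id();
        v = lid < n_sub_groups ? s_partial[lid] : 0.0f;
        v = sycl::reduce_over_group(sg, v, sycl::plus<float>());
    }
    return v;
}

This is presumably also why the small-row branches (ncols < 1024, one WARP_SIZE sub-group per row) can pass nullptr for the scratch pointer: the first sub-group reduce is already the full sum.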
void ggml_sycl_op_norm(ggml_backend_sycl_context& ctx, ggml_tensor* dst) { void ggml_sycl_op_norm(ggml_backend_sycl_context& ctx, ggml_tensor* dst) {
const ggml_tensor * src0 = dst->src[0]; const ggml_tensor * src0 = dst->src[0];

View File

@ -235,9 +235,10 @@ static void rope_norm_sycl(const T * x, T * dst, const int ne0, const int ne1, c
the limit. To get the device limit, query the limit. To get the device limit, query
info::device::max_work_group_size. Adjust the work-group size if needed. info::device::max_work_group_size. Adjust the work-group size if needed.
*/ */
stream->parallel_for(sycl::nd_range<3>(block_nums * block_dims, block_dims), [=](sycl::nd_item<3> item_ct1) { sycl_parallel_for(stream, sycl::nd_range<3>(block_nums * block_dims, block_dims),
rope_norm<T, false>(x, dst, ne0, ne1, s1, s2, n_dims, pos, freq_scale, ext_factor, attn_factor, corr_dims, [=](sycl::nd_item<3> item_ct1) {
theta_scale, freq_factors, item_ct1); rope_norm<T, false>(x, dst, ne0, ne1, s1, s2, n_dims, pos, freq_scale, ext_factor,
attn_factor, corr_dims, theta_scale, freq_factors, item_ct1);
}); });
} else { } else {
/* /*
@ -245,9 +246,10 @@ static void rope_norm_sycl(const T * x, T * dst, const int ne0, const int ne1, c
the limit. To get the device limit, query the limit. To get the device limit, query
info::device::max_work_group_size. Adjust the work-group size if needed. info::device::max_work_group_size. Adjust the work-group size if needed.
*/ */
stream->parallel_for(sycl::nd_range<3>(block_nums * block_dims, block_dims), [=](sycl::nd_item<3> item_ct1) { sycl_parallel_for(stream, sycl::nd_range<3>(block_nums * block_dims, block_dims),
rope_norm<T, true>(x, dst, ne0, ne1, s1, s2, n_dims, pos, freq_scale, ext_factor, attn_factor, corr_dims, [=](sycl::nd_item<3> item_ct1) {
theta_scale, freq_factors, item_ct1); rope_norm<T, true>(x, dst, ne0, ne1, s1, s2, n_dims, pos, freq_scale, ext_factor,
attn_factor, corr_dims, theta_scale, freq_factors, item_ct1);
}); });
} }
} }
@ -267,14 +269,16 @@ static void rope_neox_sycl(const T * x, T * dst, const int ne0, const int ne1, c
dpct::has_capability_or_fail(stream->get_device(), { sycl::aspect::fp16 }); dpct::has_capability_or_fail(stream->get_device(), { sycl::aspect::fp16 });
if (freq_factors == nullptr) { if (freq_factors == nullptr) {
stream->parallel_for(sycl::nd_range<3>(block_nums * block_dims, block_dims), [=](sycl::nd_item<3> item_ct1) { sycl_parallel_for(stream, sycl::nd_range<3>(block_nums * block_dims, block_dims),
rope_neox<T, false>(x, dst, ne0, ne1, s1, s2, n_dims, pos, freq_scale, ext_factor, attn_factor, corr_dims, [=](sycl::nd_item<3> item_ct1) {
theta_scale, freq_factors, item_ct1); rope_neox<T, false>(x, dst, ne0, ne1, s1, s2, n_dims, pos, freq_scale, ext_factor,
attn_factor, corr_dims, theta_scale, freq_factors, item_ct1);
}); });
} else { } else {
stream->parallel_for(sycl::nd_range<3>(block_nums * block_dims, block_dims), [=](sycl::nd_item<3> item_ct1) { sycl_parallel_for(stream, sycl::nd_range<3>(block_nums * block_dims, block_dims),
rope_neox<T, true>(x, dst, ne0, ne1, s1, s2, n_dims, pos, freq_scale, ext_factor, attn_factor, corr_dims, [=](sycl::nd_item<3> item_ct1) {
theta_scale, freq_factors, item_ct1); rope_neox<T, true>(x, dst, ne0, ne1, s1, s2, n_dims, pos, freq_scale, ext_factor,
attn_factor, corr_dims, theta_scale, freq_factors, item_ct1);
}); });
} }
} }
@ -298,12 +302,12 @@ static void rope_multi_sycl(const T * x, T * dst, const int ne0, const int ne1,
} }
// launch kernel // launch kernel
if (freq_factors == nullptr) { if (freq_factors == nullptr) {
stream->parallel_for(nd_range, [=](sycl::nd_item<3> item_ct1) { sycl_parallel_for(stream, nd_range, [=](sycl::nd_item<3> item_ct1) {
rope_multi<T, false>(x, dst, ne0, ne1, ne2, s1, s2, n_dims, pos, freq_scale, ext_factor, attn_factor, rope_multi<T, false>(x, dst, ne0, ne1, ne2, s1, s2, n_dims, pos, freq_scale, ext_factor, attn_factor,
corr_dims, theta_scale, freq_factors, sections, item_ct1); corr_dims, theta_scale, freq_factors, sections, item_ct1);
}); });
} else { } else {
stream->parallel_for(nd_range, [=](sycl::nd_item<3> item_ct1) { sycl_parallel_for(stream, nd_range, [=](sycl::nd_item<3> item_ct1) {
rope_multi<T, true>(x, dst, ne0, ne1, ne2, s1, s2, n_dims, pos, freq_scale, ext_factor, attn_factor, rope_multi<T, true>(x, dst, ne0, ne1, ne2, s1, s2, n_dims, pos, freq_scale, ext_factor, attn_factor,
corr_dims, theta_scale, freq_factors, sections, item_ct1); corr_dims, theta_scale, freq_factors, sections, item_ct1);
}); });
@ -333,12 +337,12 @@ static void rope_vision_sycl(const T * x, T * dst, const int ne0, const int ne1,
} }
// launch kernel // launch kernel
if (freq_factors == nullptr) { if (freq_factors == nullptr) {
stream->parallel_for(nd_range, [=](sycl::nd_item<3> item_ct1) { sycl_parallel_for(stream, nd_range, [=](sycl::nd_item<3> item_ct1) {
rope_vision<T, false>(x, dst, ne0, ne1, ne2, s1, s2, n_dims, pos, freq_scale, ext_factor, attn_factor, rope_vision<T, false>(x, dst, ne0, ne1, ne2, s1, s2, n_dims, pos, freq_scale, ext_factor, attn_factor,
corr_dims, theta_scale, freq_factors, sections, item_ct1); corr_dims, theta_scale, freq_factors, sections, item_ct1);
}); });
} else { } else {
stream->parallel_for(nd_range, [=](sycl::nd_item<3> item_ct1) { sycl_parallel_for(stream, nd_range, [=](sycl::nd_item<3> item_ct1) {
rope_vision<T, true>(x, dst, ne0, ne1, ne2, s1, s2, n_dims, pos, freq_scale, ext_factor, attn_factor, rope_vision<T, true>(x, dst, ne0, ne1, ne2, s1, s2, n_dims, pos, freq_scale, ext_factor, attn_factor,
corr_dims, theta_scale, freq_factors, sections, item_ct1); corr_dims, theta_scale, freq_factors, sections, item_ct1);
}); });
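A side note on the rope launchers: the freq_factors == nullptr check is made once on the host and baked into the kernel as a template parameter (rope_norm<T, false> vs rope_norm<T, true>), so the per-element branch disappears at compile time. A small illustrative sketch of that dispatch style, using made-up names and the queue-level parallel_for shortcut these launchers rely on:

#include <sycl/sycl.hpp>

// has_ff is resolved at compile time, so the generated kernel carries no
// runtime branch on the (uniform) freq_factors pointer. Names are illustrative.
template <bool has_ff>
static void toy_scale_kernel(const float * x, float * dst, const float * freq_factors,
                             int n, sycl::nd_item<3> it) {
    const int i = (int) it.get_global_id(2);
    if (i >= n) {
        return;
    }
    const float ff = has_ff ? freq_factors[i] : 1.0f;
    dst[i] = x[i] * ff;
}

static void toy_scale_sycl(const float * x, float * dst, const float * freq_factors,
                           int n, sycl::queue * stream) {
    const sycl::range<3> block_dims(1, 1, 256);
    const sycl::range<3> block_nums(1, 1, (n + 255) / 256);
    const sycl::nd_range<3> nd_range(block_nums * block_dims, block_dims);
    if (freq_factors == nullptr) {
        stream->parallel_for(nd_range, [=](sycl::nd_item<3> it) {
            toy_scale_kernel<false>(x, dst, freq_factors, n, it);
        });
    } else {
        stream->parallel_for(nd_range, [=](sycl::nd_item<3> it) {
            toy_scale_kernel<true>(x, dst, freq_factors, n, it);
        });
    }
}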

View File

@ -127,11 +127,11 @@ static void soft_max_f32_submitter(const float * x, const T * mask, float * dst,
const int nrows_y, const float scale, const float max_bias, const float m0, const int nrows_y, const float scale, const float max_bias, const float m0,
const float m1, uint32_t n_head_log2, sycl::range<3> block_nums, sycl::range<3> block_dims, const float m1, uint32_t n_head_log2, sycl::range<3> block_nums, sycl::range<3> block_dims,
const size_t n_local_scratch, queue_ptr stream) { const size_t n_local_scratch, queue_ptr stream) {
stream->submit([&](sycl::handler &cgh) { sycl_launch(stream, [&](sycl::handler & cgh) {
sycl::local_accessor<float, 1> local_buf_acc(n_local_scratch, cgh); sycl::local_accessor<float, 1> local_buf_acc(n_local_scratch, cgh);
cgh.parallel_for( sycl_parallel_for(
sycl::nd_range<3>(block_nums * block_dims, block_dims), cgh, sycl::nd_range<3>(block_nums * block_dims, block_dims),
[=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] { [=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
soft_max_f32<vals_smem, ncols_template, block_size_template>(x, mask, dst, ncols_par, soft_max_f32<vals_smem, ncols_template, block_size_template>(x, mask, dst, ncols_par,
nrows_y, scale, max_bias, m0, nrows_y, scale, max_bias, m0,

View File

@ -45,13 +45,8 @@ static void timestep_embedding_f32_sycl(
int num_blocks = (half_ceil + SYCL_TIMESTEP_EMBEDDING_BLOCK_SIZE - 1) / SYCL_TIMESTEP_EMBEDDING_BLOCK_SIZE; int num_blocks = (half_ceil + SYCL_TIMESTEP_EMBEDDING_BLOCK_SIZE - 1) / SYCL_TIMESTEP_EMBEDDING_BLOCK_SIZE;
sycl::range<3> block_dims(1, 1, SYCL_TIMESTEP_EMBEDDING_BLOCK_SIZE); sycl::range<3> block_dims(1, 1, SYCL_TIMESTEP_EMBEDDING_BLOCK_SIZE);
sycl::range<3> gridDim(1, ne00, num_blocks); sycl::range<3> gridDim(1, ne00, num_blocks);
stream->parallel_for( sycl_parallel_for(stream, sycl::nd_range<3>(gridDim * block_dims, block_dims), [=](sycl::nd_item<3> item_ct1) {
sycl::nd_range<3>( timestep_embedding_f32(x, dst, nb1, dim, max_period, item_ct1);
gridDim * block_dims, block_dims),
[=](sycl::nd_item<3> item_ct1) {
timestep_embedding_f32(
x, dst, nb1, dim, max_period, item_ct1
);
}); });
} }

View File

@ -207,12 +207,11 @@ void ggml_sycl_op_rwkv_wkv6(ggml_backend_sycl_context& ctx, ggml_tensor* dst) {
// Submit kernel // Submit kernel
if (C / H == WKV_BLOCK_SIZE) { if (C / H == WKV_BLOCK_SIZE) {
stream->submit([&](sycl::handler& cgh) { sycl_launch(stream, [&](sycl::handler & cgh) {
sycl::local_accessor<float, 1> shared_mem_acc(shared_mem_size, cgh); sycl::local_accessor<float, 1> shared_mem_acc(shared_mem_size, cgh);
cgh.parallel_for( sycl_parallel_for(
sycl::nd_range<3>(grid_dims * block_dims, block_dims), cgh, sycl::nd_range<3>(grid_dims * block_dims, block_dims), [=](sycl::nd_item<3> item_ct1) {
[=](sycl::nd_item<3> item_ct1) {
rwkv_wkv6_f32_kernel<WKV_BLOCK_SIZE>( rwkv_wkv6_f32_kernel<WKV_BLOCK_SIZE>(
B, T, C, H, k_d, v_d, r_d, tf_d, td_d, s_d, dst_d, B, T, C, H, k_d, v_d, r_d, tf_d, td_d, s_d, dst_d,
item_ct1, (float*)shared_mem_acc.get_multi_ptr<sycl::access::decorated::no>().get() item_ct1, (float*)shared_mem_acc.get_multi_ptr<sycl::access::decorated::no>().get()
@ -220,12 +219,11 @@ void ggml_sycl_op_rwkv_wkv6(ggml_backend_sycl_context& ctx, ggml_tensor* dst) {
}); });
}); });
} else { } else {
stream->submit([&](sycl::handler& cgh) { sycl_launch(stream, [&](sycl::handler & cgh) {
sycl::local_accessor<float, 1> shared_mem_acc(shared_mem_size, cgh); sycl::local_accessor<float, 1> shared_mem_acc(shared_mem_size, cgh);
cgh.parallel_for( sycl_parallel_for(
sycl::nd_range<3>(grid_dims * block_dims, block_dims), cgh, sycl::nd_range<3>(grid_dims * block_dims, block_dims), [=](sycl::nd_item<3> item_ct1) {
[=](sycl::nd_item<3> item_ct1) {
rwkv_wkv6_f32_kernel<WKV_BLOCK_SIZE * 2>( rwkv_wkv6_f32_kernel<WKV_BLOCK_SIZE * 2>(
B, T, C, H, k_d, v_d, r_d, tf_d, td_d, s_d, dst_d, B, T, C, H, k_d, v_d, r_d, tf_d, td_d, s_d, dst_d,
item_ct1, (float*)shared_mem_acc.get_multi_ptr<sycl::access::decorated::no>().get() item_ct1, (float*)shared_mem_acc.get_multi_ptr<sycl::access::decorated::no>().get()
@ -264,12 +262,11 @@ void ggml_sycl_op_rwkv_wkv7(ggml_backend_sycl_context& ctx, ggml_tensor* dst) {
// Submit kernel // Submit kernel
if (C / H == WKV_BLOCK_SIZE) { if (C / H == WKV_BLOCK_SIZE) {
stream->submit([&](sycl::handler& cgh) { sycl_launch(stream, [&](sycl::handler & cgh) {
sycl::local_accessor<float, 1> shared_mem_acc(shared_mem_size, cgh); sycl::local_accessor<float, 1> shared_mem_acc(shared_mem_size, cgh);
cgh.parallel_for( sycl_parallel_for(
sycl::nd_range<3>(grid_dims * block_dims, block_dims), cgh, sycl::nd_range<3>(grid_dims * block_dims, block_dims), [=](sycl::nd_item<3> item_ct1) {
[=](sycl::nd_item<3> item_ct1) {
rwkv_wkv7_f32_kernel<WKV_BLOCK_SIZE>( rwkv_wkv7_f32_kernel<WKV_BLOCK_SIZE>(
B, T, C, H, r_d, w_d, k_d, v_d, a_d, b_d, s_d, dst_d, B, T, C, H, r_d, w_d, k_d, v_d, a_d, b_d, s_d, dst_d,
item_ct1, (float*)shared_mem_acc.get_multi_ptr<sycl::access::decorated::no>().get() item_ct1, (float*)shared_mem_acc.get_multi_ptr<sycl::access::decorated::no>().get()
@ -277,12 +274,11 @@ void ggml_sycl_op_rwkv_wkv7(ggml_backend_sycl_context& ctx, ggml_tensor* dst) {
}); });
}); });
} else { } else {
stream->submit([&](sycl::handler& cgh) { sycl_launch(stream, [&](sycl::handler & cgh) {
sycl::local_accessor<float, 1> shared_mem_acc(shared_mem_size, cgh); sycl::local_accessor<float, 1> shared_mem_acc(shared_mem_size, cgh);
cgh.parallel_for( sycl_parallel_for(
sycl::nd_range<3>(grid_dims * block_dims, block_dims), cgh, sycl::nd_range<3>(grid_dims * block_dims, block_dims), [=](sycl::nd_item<3> item_ct1) {
[=](sycl::nd_item<3> item_ct1) {
rwkv_wkv7_f32_kernel<WKV_BLOCK_SIZE * 2>( rwkv_wkv7_f32_kernel<WKV_BLOCK_SIZE * 2>(
B, T, C, H, r_d, w_d, k_d, v_d, a_d, b_d, s_d, dst_d, B, T, C, H, r_d, w_d, k_d, v_d, a_d, b_d, s_d, dst_d,
item_ct1, (float*)shared_mem_acc.get_multi_ptr<sycl::access::decorated::no>().get() item_ct1, (float*)shared_mem_acc.get_multi_ptr<sycl::access::decorated::no>().get()

View File

@ -1041,6 +1041,14 @@ void vk_memory_logger::log_deallocation(vk_buffer_ref buf_ref) {
struct vk_instance_t { struct vk_instance_t {
vk::Instance instance; vk::Instance instance;
bool debug_utils_support = false; // VK_EXT_debug_utils enabled
PFN_vkSetDebugUtilsObjectNameEXT pfn_vkSetDebugUtilsObjectNameEXT = {};
PFN_vkQueueBeginDebugUtilsLabelEXT pfn_vkQueueBeginDebugUtilsLabelEXT = {};
PFN_vkQueueEndDebugUtilsLabelEXT pfn_vkQueueEndDebugUtilsLabelEXT = {};
PFN_vkCmdBeginDebugUtilsLabelEXT pfn_vkCmdBeginDebugUtilsLabelEXT = {};
PFN_vkCmdEndDebugUtilsLabelEXT pfn_vkCmdEndDebugUtilsLabelEXT = {};
PFN_vkCmdInsertDebugUtilsLabelEXT pfn_vkCmdInsertDebugUtilsLabelEXT = {};
std::vector<size_t> device_indices; std::vector<size_t> device_indices;
vk_device devices[GGML_VK_MAX_DEVICES]; vk_device devices[GGML_VK_MAX_DEVICES];
}; };
@ -1180,6 +1188,14 @@ static void ggml_vk_create_pipeline_func(vk_device& device, vk_pipeline& pipelin
} }
pipeline->compiled = true; pipeline->compiled = true;
if (vk_instance.debug_utils_support) {
vk::DebugUtilsObjectNameInfoEXT duoni;
duoni.objectType = vk::ObjectType::ePipeline;
duoni.pObjectName = pipeline->name.c_str();
duoni.objectHandle = reinterpret_cast<uint64_t>(static_cast<VkPipeline_T*>(pipeline->pipeline));
vk_instance.pfn_vkSetDebugUtilsObjectNameEXT(device->device, &static_cast<VkDebugUtilsObjectNameInfoEXT &>(duoni));
}
{ {
std::lock_guard<std::mutex> guard(device->mutex); std::lock_guard<std::mutex> guard(device->mutex);
device->pipelines.insert({ pipeline->name, pipeline }); device->pipelines.insert({ pipeline->name, pipeline });
@ -3561,6 +3577,8 @@ static void ggml_vk_print_gpu_info(size_t idx) {
static bool ggml_vk_instance_validation_ext_available(const std::vector<vk::ExtensionProperties>& instance_extensions); static bool ggml_vk_instance_validation_ext_available(const std::vector<vk::ExtensionProperties>& instance_extensions);
static bool ggml_vk_instance_portability_enumeration_ext_available(const std::vector<vk::ExtensionProperties>& instance_extensions); static bool ggml_vk_instance_portability_enumeration_ext_available(const std::vector<vk::ExtensionProperties>& instance_extensions);
static bool ggml_vk_instance_debug_utils_ext_available(const std::vector<vk::ExtensionProperties> & instance_extensions);
static void ggml_vk_instance_init() { static void ggml_vk_instance_init() {
if (vk_instance_initialized) { if (vk_instance_initialized) {
return; return;
@ -3581,7 +3599,7 @@ static void ggml_vk_instance_init() {
#ifdef __APPLE__ #ifdef __APPLE__
const bool portability_enumeration_ext = ggml_vk_instance_portability_enumeration_ext_available(instance_extensions); const bool portability_enumeration_ext = ggml_vk_instance_portability_enumeration_ext_available(instance_extensions);
#endif #endif
const bool debug_utils_ext = ggml_vk_instance_debug_utils_ext_available(instance_extensions) && getenv("GGML_VK_DEBUG_MARKERS") != nullptr;
std::vector<const char*> layers; std::vector<const char*> layers;
if (validation_ext) { if (validation_ext) {
@ -3596,6 +3614,9 @@ static void ggml_vk_instance_init() {
extensions.push_back("VK_KHR_portability_enumeration"); extensions.push_back("VK_KHR_portability_enumeration");
} }
#endif #endif
if (debug_utils_ext) {
extensions.push_back("VK_EXT_debug_utils");
}
vk::InstanceCreateInfo instance_create_info(vk::InstanceCreateFlags{}, &app_info, layers, extensions); vk::InstanceCreateInfo instance_create_info(vk::InstanceCreateFlags{}, &app_info, layers, extensions);
#ifdef __APPLE__ #ifdef __APPLE__
if (portability_enumeration_ext) { if (portability_enumeration_ext) {
@ -3619,6 +3640,18 @@ static void ggml_vk_instance_init() {
vk_instance.instance = vk::createInstance(instance_create_info); vk_instance.instance = vk::createInstance(instance_create_info);
vk_instance_initialized = true; vk_instance_initialized = true;
if (debug_utils_ext) {
vk_instance.debug_utils_support = true;
vk_instance.pfn_vkSetDebugUtilsObjectNameEXT = (PFN_vkSetDebugUtilsObjectNameEXT) vkGetInstanceProcAddr(vk_instance.instance, "vkSetDebugUtilsObjectNameEXT");
vk_instance.pfn_vkQueueBeginDebugUtilsLabelEXT = (PFN_vkQueueBeginDebugUtilsLabelEXT) vkGetInstanceProcAddr(vk_instance.instance, "vkQueueBeginDebugUtilsLabelEXT");
vk_instance.pfn_vkQueueEndDebugUtilsLabelEXT = (PFN_vkQueueEndDebugUtilsLabelEXT) vkGetInstanceProcAddr(vk_instance.instance, "vkQueueEndDebugUtilsLabelEXT");
vk_instance.pfn_vkCmdBeginDebugUtilsLabelEXT = (PFN_vkCmdBeginDebugUtilsLabelEXT) vkGetInstanceProcAddr(vk_instance.instance, "vkCmdBeginDebugUtilsLabelEXT");
vk_instance.pfn_vkCmdEndDebugUtilsLabelEXT = (PFN_vkCmdEndDebugUtilsLabelEXT) vkGetInstanceProcAddr(vk_instance.instance, "vkCmdEndDebugUtilsLabelEXT");
vk_instance.pfn_vkCmdInsertDebugUtilsLabelEXT = (PFN_vkCmdInsertDebugUtilsLabelEXT) vkGetInstanceProcAddr(vk_instance.instance, "vkCmdInsertDebugUtilsLabelEXT");
}
size_t num_available_devices = vk_instance.instance.enumeratePhysicalDevices().size();
vk_perf_logger_enabled = getenv("GGML_VK_PERF_LOGGER") != nullptr; vk_perf_logger_enabled = getenv("GGML_VK_PERF_LOGGER") != nullptr;
// Emulate behavior of CUDA_VISIBLE_DEVICES for Vulkan // Emulate behavior of CUDA_VISIBLE_DEVICES for Vulkan
@ -9495,6 +9528,12 @@ static size_t ggml_backend_vk_host_buffer_type_get_alignment(ggml_backend_buffer
UNUSED(buft); UNUSED(buft);
} }
static size_t ggml_backend_vk_host_buffer_type_get_max_size(ggml_backend_buffer_type_t buft) {
return vk_instance.devices[0]->suballocation_block_size;
UNUSED(buft);
}
// Should be changed to return device-specific host buffer type // Should be changed to return device-specific host buffer type
// but that probably requires changes in llama.cpp // but that probably requires changes in llama.cpp
ggml_backend_buffer_type_t ggml_backend_vk_host_buffer_type() { ggml_backend_buffer_type_t ggml_backend_vk_host_buffer_type() {
@ -9503,7 +9542,7 @@ ggml_backend_buffer_type_t ggml_backend_vk_host_buffer_type() {
/* .get_name = */ ggml_backend_vk_host_buffer_type_name, /* .get_name = */ ggml_backend_vk_host_buffer_type_name,
/* .alloc_buffer = */ ggml_backend_vk_host_buffer_type_alloc_buffer, /* .alloc_buffer = */ ggml_backend_vk_host_buffer_type_alloc_buffer,
/* .get_alignment = */ ggml_backend_vk_host_buffer_type_get_alignment, /* .get_alignment = */ ggml_backend_vk_host_buffer_type_get_alignment,
/* .get_max_size = */ NULL, // defaults to SIZE_MAX /* .get_max_size = */ ggml_backend_vk_host_buffer_type_get_max_size,
/* .get_alloc_size = */ ggml_backend_cpu_buffer_type()->iface.get_alloc_size, /* .get_alloc_size = */ ggml_backend_cpu_buffer_type()->iface.get_alloc_size,
/* .is_host = */ ggml_backend_cpu_buffer_type()->iface.is_host, /* .is_host = */ ggml_backend_cpu_buffer_type()->iface.is_host,
}, },
@ -9650,6 +9689,13 @@ static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cg
VK_LOG_DEBUG("ggml_backend_vk_graph_compute(" << cgraph->n_nodes << " nodes)"); VK_LOG_DEBUG("ggml_backend_vk_graph_compute(" << cgraph->n_nodes << " nodes)");
ggml_backend_vk_context * ctx = (ggml_backend_vk_context *)backend->context; ggml_backend_vk_context * ctx = (ggml_backend_vk_context *)backend->context;
if (vk_instance.debug_utils_support) {
vk::DebugUtilsLabelEXT dul = {};
dul.pLabelName = "ggml_backend_vk_graph_compute";
dul.color = std::array<float,4>{1.0f, 1.0f, 1.0f, 1.0f};
vk_instance.pfn_vkQueueBeginDebugUtilsLabelEXT(ctx->device->compute_queue.queue, reinterpret_cast<VkDebugUtilsLabelEXT*>(&dul));
}
uint64_t total_mat_mul_bytes = 0; uint64_t total_mat_mul_bytes = 0;
for (int i = 0; i < cgraph->n_nodes; i++) { for (int i = 0; i < cgraph->n_nodes; i++) {
ggml_vk_build_graph(ctx, cgraph->nodes[i], i, nullptr, 0, true, false, false, false); ggml_vk_build_graph(ctx, cgraph->nodes[i], i, nullptr, 0, true, false, false, false);
@ -10339,6 +10385,22 @@ static bool ggml_vk_instance_portability_enumeration_ext_available(const std::ve
UNUSED(instance_extensions); UNUSED(instance_extensions);
} }
// Extension availability
static bool ggml_vk_instance_debug_utils_ext_available(
const std::vector<vk::ExtensionProperties> & instance_extensions) {
// Check whether the VK_EXT_debug_utils instance extension is available
for (const auto & properties : instance_extensions) {
if (strcmp("VK_EXT_debug_utils", properties.extensionName) == 0) {
return true;
}
}
std::cerr << "ggml_vulkan: WARNING: Instance extension VK_EXT_debug_utils not found." << std::endl;
return false;
UNUSED(instance_extensions);
}
static bool ggml_vk_khr_cooperative_matrix_support(const vk::PhysicalDeviceProperties& props, const vk::PhysicalDeviceDriverProperties& driver_props, vk_device_architecture arch) { static bool ggml_vk_khr_cooperative_matrix_support(const vk::PhysicalDeviceProperties& props, const vk::PhysicalDeviceDriverProperties& driver_props, vk_device_architecture arch) {
switch (props.vendorID) { switch (props.vendorID) {
case VK_VENDOR_ID_INTEL: case VK_VENDOR_ID_INTEL:

View File

@ -955,6 +955,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
"UPSCALE", "UPSCALE",
"PAD", "PAD",
"PAD_REFLECT_1D", "PAD_REFLECT_1D",
"ROLL",
"ARANGE", "ARANGE",
"TIMESTEP_EMBEDDING", "TIMESTEP_EMBEDDING",
"ARGSORT", "ARGSORT",
@ -985,7 +986,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
"OPT_STEP_ADAMW", "OPT_STEP_ADAMW",
}; };
static_assert(GGML_OP_COUNT == 82, "GGML_OP_COUNT != 82"); static_assert(GGML_OP_COUNT == 83, "GGML_OP_COUNT != 83");
static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = { static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
"none", "none",
@ -1050,6 +1051,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
"upscale(x)", "upscale(x)",
"pad(x)", "pad(x)",
"pad_reflect_1d(x)", "pad_reflect_1d(x)",
"roll(x)",
"arange(start, stop, step)", "arange(start, stop, step)",
"timestep_embedding(timesteps, dim, max_period)", "timestep_embedding(timesteps, dim, max_period)",
"argsort(x)", "argsort(x)",
@ -1080,7 +1082,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
"adamw(x)", "adamw(x)",
}; };
static_assert(GGML_OP_COUNT == 82, "GGML_OP_COUNT != 82"); static_assert(GGML_OP_COUNT == 83, "GGML_OP_COUNT != 83");
static_assert(GGML_OP_POOL_COUNT == 2, "GGML_OP_POOL_COUNT != 2"); static_assert(GGML_OP_POOL_COUNT == 2, "GGML_OP_POOL_COUNT != 2");
@ -4341,6 +4343,34 @@ struct ggml_tensor * ggml_pad_reflect_1d(
return result; return result;
} }
// ggml_roll
struct ggml_tensor * ggml_roll(
struct ggml_context * ctx,
struct ggml_tensor * a,
int shift0,
int shift1,
int shift2,
int shift3) {
GGML_ASSERT(a->nb[0] == ggml_type_size(a->type));
GGML_ASSERT(abs(shift0) < a->ne[0]);
GGML_ASSERT(abs(shift1) < a->ne[1]);
GGML_ASSERT(abs(shift2) < a->ne[2]);
GGML_ASSERT(abs(shift3) < a->ne[3]);
struct ggml_tensor * result = ggml_dup_tensor(ctx, a);
ggml_set_op_params_i32(result, 0, shift0);
ggml_set_op_params_i32(result, 1, shift1);
ggml_set_op_params_i32(result, 2, shift2);
ggml_set_op_params_i32(result, 3, shift3);
result->op = GGML_OP_ROLL;
result->src[0] = a;
return result;
}
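Since the hunk above introduces a new public operator, here is a minimal, hypothetical usage sketch (not part of the patch; it assumes the standard ggml context API from ggml.h and fills in illustrative sizes):

#include "ggml.h"

// Sketch: roll an 8-element f32 tensor by +2 along dim 0 (|shift| must stay below ne for each dim).
struct ggml_init_params ip = {
    /*.mem_size   =*/ 16u*1024u*1024u,
    /*.mem_buffer =*/ NULL,
    /*.no_alloc   =*/ false,
};
struct ggml_context * ctx = ggml_init(ip);

struct ggml_tensor * a = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 8);
// ... fill a->data with 8 floats ...

struct ggml_tensor * r = ggml_roll(ctx, a, /*shift0=*/2, /*shift1=*/0, /*shift2=*/0, /*shift3=*/0);

// r is evaluated through the usual graph machinery, e.g.:
struct ggml_cgraph * gf = ggml_new_graph(ctx);
ggml_build_forward_expand(gf, r);
// ... run the graph with the backend of your choice, then read r->data ...

ggml_free(ctx);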
// ggml_arange // ggml_arange
struct ggml_tensor * ggml_arange( struct ggml_tensor * ggml_arange(

View File

@ -199,6 +199,7 @@ class Keys:
MASK_ID = "tokenizer.ggml.mask_token_id" MASK_ID = "tokenizer.ggml.mask_token_id"
ADD_BOS = "tokenizer.ggml.add_bos_token" ADD_BOS = "tokenizer.ggml.add_bos_token"
ADD_EOS = "tokenizer.ggml.add_eos_token" ADD_EOS = "tokenizer.ggml.add_eos_token"
ADD_SEP = "tokenizer.ggml.add_sep_token"
ADD_PREFIX = "tokenizer.ggml.add_space_prefix" ADD_PREFIX = "tokenizer.ggml.add_space_prefix"
REMOVE_EXTRA_WS = "tokenizer.ggml.remove_extra_whitespaces" REMOVE_EXTRA_WS = "tokenizer.ggml.remove_extra_whitespaces"
PRECOMPILED_CHARSMAP = "tokenizer.ggml.precompiled_charsmap" PRECOMPILED_CHARSMAP = "tokenizer.ggml.precompiled_charsmap"

View File

@ -894,6 +894,9 @@ class GGUFWriter:
def add_add_eos_token(self, value: bool) -> None: def add_add_eos_token(self, value: bool) -> None:
self.add_bool(Keys.Tokenizer.ADD_EOS, value) self.add_bool(Keys.Tokenizer.ADD_EOS, value)
def add_add_sep_token(self, value: bool) -> None:
self.add_bool(Keys.Tokenizer.ADD_SEP, value)
def add_add_space_prefix(self, value: bool) -> None: def add_add_space_prefix(self, value: bool) -> None:
self.add_bool(Keys.Tokenizer.ADD_PREFIX, value) self.add_bool(Keys.Tokenizer.ADD_PREFIX, value)

View File

@ -7,7 +7,10 @@ import os
from pathlib import Path from pathlib import Path
from typing import Any, Callable, Sequence, Mapping, Iterable, Protocol, ClassVar, runtime_checkable from typing import Any, Callable, Sequence, Mapping, Iterable, Protocol, ClassVar, runtime_checkable
from sentencepiece import SentencePieceProcessor try:
from sentencepiece import SentencePieceProcessor
except ImportError:
SentencePieceProcessor = None
import gguf import gguf
@ -116,6 +119,7 @@ class SpecialVocab:
logger.warning(f'Special token type {typ}, id {tid} out of range, must be under {self.n_vocab} - skipping') logger.warning(f'Special token type {typ}, id {tid} out of range, must be under {self.n_vocab} - skipping')
def _try_load_from_tokenizer_json(self, path: Path) -> bool: def _try_load_from_tokenizer_json(self, path: Path) -> bool:
tokenizer = None
tokenizer_file = path / 'tokenizer.json' tokenizer_file = path / 'tokenizer.json'
if tokenizer_file.is_file(): if tokenizer_file.is_file():
with open(tokenizer_file, encoding = 'utf-8') as f: with open(tokenizer_file, encoding = 'utf-8') as f:
@ -149,11 +153,97 @@ class SpecialVocab:
added_tokens = tokenizer.get('added_tokens', {}) added_tokens = tokenizer.get('added_tokens', {})
else: else:
added_tokens = {} added_tokens = {}
tokenizer_config = None
tokenizer_config_file = path / 'tokenizer_config.json' tokenizer_config_file = path / 'tokenizer_config.json'
if not tokenizer_config_file.is_file(): if tokenizer_config_file.is_file():
return True
with open(tokenizer_config_file, encoding = 'utf-8') as f: with open(tokenizer_config_file, encoding = 'utf-8') as f:
tokenizer_config = json.load(f) tokenizer_config = json.load(f)
if tokenizer:
special_bos = (tokenizer_config or {}).get('bos_token')
special_cls = (tokenizer_config or {}).get('cls_token')
special_eos = (tokenizer_config or {}).get('eos_token')
special_sep = (tokenizer_config or {}).get('sep_token')
if not special_bos and special_cls and tokenizer_config:
tokenizer_config['bos_token'] = special_bos = special_cls
if not special_eos and special_sep and tokenizer_config:
tokenizer_config['eos_token'] = special_eos = special_sep
if post_processor := tokenizer.get('post_processor'):
for processor in post_processor.get('processors', [post_processor]):
if processor.get('type') == 'RobertaProcessing':
self.add_special_token['bos'] = True
self.add_special_token['eos'] = True
self.add_special_token['sep'] = True
if not special_cls and tokenizer_config:
special_cls = processor.get('cls', [special_bos])[0]
tokenizer_config['cls_token'] = special_cls
if not special_sep and tokenizer_config:
special_sep = processor.get('sep', [special_eos])[0]
tokenizer_config['sep_token'] = special_sep
continue
# Crude parsing of TemplateProcessing to determine if BOS/SEP/EOS should be added
# Only works with simple templates, **will** get it wrong on unusual sequences
if processor.get('type') == 'TemplateProcessing':
tmpl_single = processor.get('single', [])
tmpl_pair = processor.get('pair', [])
special_first = None
special_last = None
if len(tmpl_single) > 1:
if special_first := tmpl_single[0].get('SpecialToken', {}).get('id'):
if not tokenizer_config:
special_bos = special_first
self.add_special_token['bos'] = True if special_first in (special_bos, special_cls) else False
if special_first not in (special_bos, special_cls):
logger.warning(f'Unknown leading special token {special_first!r} in TemplateProcessing<single>')
if special_last := tmpl_single[-1].get('SpecialToken', {}).get('id'):
if not tokenizer_config:
special_eos = special_last
elif special_last != special_eos:
if 'eot' not in self.special_token_types:
self.special_token_types = tuple(self.special_token_types) + ('eot', )
tokenizer_config['eot_token'] = special_eos
elif 'eom' not in self.special_token_types:
self.special_token_types = tuple(self.special_token_types) + ('eom', )
tokenizer_config['eom_token'] = special_eos
else:
logger.warning(f'Overriding EOS token {special_eos!r} with {special_last!r} without EOT/EOM fallback!')
tokenizer_config['eos_token'] = special_eos = special_last
self.add_special_token['eos'] = True if special_last == special_eos else False
if special_last != special_eos:
logger.warning(f'Unknown trailing special token {special_last!r} in TemplateProcessing<single>')
if tmpl_pair:
seq_start = 1 if special_first and tmpl_pair[0].get('SpecialToken', {}).get('id') == special_first else 0
seq_stop = -1 if special_last and tmpl_pair[-1].get('SpecialToken', {}).get('id') == special_last else None
if (special_first and seq_start == 0) or (special_last and seq_stop is None):
logger.warning('TemplateProcessing<single> leading/trailing special tokens do not match TemplateProcessing<pair>')
if tmpl_pair := tmpl_pair[slice(seq_start, seq_stop)]:
tmpl_a = tmpl_pair[0].get('Sequence', {}).get('id')
tmpl_b = tmpl_pair[-1].get('Sequence', {}).get('id')
if tmpl_a != 'A' or tmpl_b != 'B':
logger.warning(f'Unknown sequence {tmpl_a}...{tmpl_b} in TemplateProcessing<pair>')
# A [sep] [eos] B
if tmpl_a == 'A' and tmpl_b == 'B' and (tmpl_pair := tmpl_pair[1:-1]):
add_sep = False
if special_entry := tmpl_pair[0].get('SpecialToken', {}).get('id'):
if special_entry in (special_sep, special_eos) and not special_last:
add_sep = True
if special_entry not in (special_sep, special_eos):
logger.warning(f'Unknown separator token {special_entry!r} in TemplateProcessing<pair>')
else:
logger.warning(f'Unknown middle sequence {tmpl_pair[0]!r} in TemplateProcessing<pair>')
if len(tmpl_pair) == 2:
if special_entry := tmpl_pair[1].get('SpecialToken', {}).get('id'):
if special_entry in (special_sep, special_eos):
add_sep = True
if special_entry not in (special_sep, special_eos):
logger.warning(f'Unknown second separator token {special_entry!r} in TemplateProcessing<pair>')
else:
logger.warning(f'Unknown second middle sequence {tmpl_pair[1]!r} in TemplateProcessing<pair>')
self.add_special_token['sep'] = add_sep
if add_sep and not special_sep and tokenizer_config:
tokenizer_config['sep_token'] = special_eos
continue
if not tokenizer_config:
return True
chat_template_alt = None chat_template_alt = None
chat_template_file = path / 'chat_template.json' chat_template_file = path / 'chat_template.json'
if chat_template_file.is_file(): if chat_template_file.is_file():
@ -302,6 +392,9 @@ class SentencePieceVocab(Vocab):
name = "spm" name = "spm"
def __init__(self, base_path: Path): def __init__(self, base_path: Path):
if SentencePieceProcessor is None:
raise RuntimeError("sentencepiece is not installed")
added_tokens: dict[str, int] = {} added_tokens: dict[str, int] = {}
if (fname_tokenizer := base_path / 'tokenizer.model').exists(): if (fname_tokenizer := base_path / 'tokenizer.model').exists():
# normal location # normal location

View File

@ -1,6 +1,6 @@
[tool.poetry] [tool.poetry]
name = "gguf" name = "gguf"
version = "0.17.0" version = "0.17.1"
description = "Read and write ML models in GGUF for GGML" description = "Read and write ML models in GGUF for GGML"
authors = ["GGML <ggml@ggml.ai>"] authors = ["GGML <ggml@ggml.ai>"]
packages = [ packages = [
@ -22,7 +22,7 @@ python = ">=3.8"
numpy = ">=1.17" numpy = ">=1.17"
tqdm = ">=4.27" tqdm = ">=4.27"
pyyaml = ">=5.1" pyyaml = ">=5.1"
sentencepiece = ">=0.1.98,<=0.2.0" sentencepiece = { version = ">=0.1.98,<=0.2.0", optional = true }
PySide6 = { version = "^6.9", python = ">=3.9,<3.14", optional = true } PySide6 = { version = "^6.9", python = ">=3.9,<3.14", optional = true }
[tool.poetry.dev-dependencies] [tool.poetry.dev-dependencies]

View File

@ -390,6 +390,7 @@ extern "C" {
void * imatrix; // pointer to importance matrix data void * imatrix; // pointer to importance matrix data
void * kv_overrides; // pointer to vector containing overrides void * kv_overrides; // pointer to vector containing overrides
void * tensor_types; // pointer to vector containing tensor types void * tensor_types; // pointer to vector containing tensor types
void * prune_layers; // pointer to vector containing layer indices to prune
} llama_model_quantize_params; } llama_model_quantize_params;
typedef struct llama_logit_bias { typedef struct llama_logit_bias {
@ -943,12 +944,14 @@ extern "C" {
// Requires the context to have a memory. // Requires the context to have a memory.
// For encode-decoder contexts, processes the batch using the decoder. // For encode-decoder contexts, processes the batch using the decoder.
// Positive return values do not mean a fatal error, but rather a warning. // Positive return values do not mean a fatal error, but rather a warning.
// Upon non-zero return values, the memory state is restored to the state before this call // Upon fatal-error or abort, the ubatches that managed to be processed will remain in the memory state of the context
// To handle this correctly, query the memory state using llama_memory_seq_pos_min() and llama_memory_seq_pos_max()
// Upon other return values, the memory state is restored to the state before this call
// 0 - success // 0 - success
// 1 - could not find a KV slot for the batch (try reducing the size of the batch or increase the context) // 1 - could not find a KV slot for the batch (try reducing the size of the batch or increase the context)
// 2 - aborted // 2 - aborted (processed ubatches will remain in the context's memory)
// -1 - invalid input batch // -1 - invalid input batch
// < -1 - error // < -1 - fatal error (processed ubatches will remain in the context's memory)
LLAMA_API int32_t llama_decode( LLAMA_API int32_t llama_decode(
struct llama_context * ctx, struct llama_context * ctx,
struct llama_batch batch); struct llama_batch batch);
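As a hedged illustration of the revised contract (not part of the patch; `ctx` and `batch` stand for a valid llama_context and a prepared llama_batch):

// Sketch: interpreting llama_decode() return codes under the updated semantics.
const int32_t ret = llama_decode(ctx, batch);
if (ret == 0) {
    // success
} else if (ret == 1) {
    // no KV slot found: the memory state was restored;
    // retry with a smaller batch or a larger context
} else if (ret == 2 || ret < -1) {
    // aborted or fatal error: already-processed ubatches remain in the
    // context's memory; inspect llama_memory_seq_pos_min()/_max() to see
    // how far each sequence actually advanced before deciding how to recover
} else { // ret == -1
    // invalid input batch: the memory state was restored
}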
@ -1044,6 +1047,7 @@ extern "C" {
LLAMA_API bool llama_vocab_get_add_bos(const struct llama_vocab * vocab); LLAMA_API bool llama_vocab_get_add_bos(const struct llama_vocab * vocab);
LLAMA_API bool llama_vocab_get_add_eos(const struct llama_vocab * vocab); LLAMA_API bool llama_vocab_get_add_eos(const struct llama_vocab * vocab);
LLAMA_API bool llama_vocab_get_add_sep(const struct llama_vocab * vocab);
LLAMA_API llama_token llama_vocab_fim_pre(const struct llama_vocab * vocab); LLAMA_API llama_token llama_vocab_fim_pre(const struct llama_vocab * vocab);
LLAMA_API llama_token llama_vocab_fim_suf(const struct llama_vocab * vocab); LLAMA_API llama_token llama_vocab_fim_suf(const struct llama_vocab * vocab);
@ -1087,6 +1091,7 @@ extern "C" {
/// @param tokens The tokens pointer must be large enough to hold the resulting tokens. /// @param tokens The tokens pointer must be large enough to hold the resulting tokens.
/// @return Returns the number of tokens on success, no more than n_tokens_max /// @return Returns the number of tokens on success, no more than n_tokens_max
/// @return Returns a negative number on failure - the number of tokens that would have been returned /// @return Returns a negative number on failure - the number of tokens that would have been returned
/// @return Returns INT32_MIN on overflow (e.g., tokenization result size exceeds int32_t limit)
/// @param add_special Allow to add BOS and EOS tokens if model is configured to do so. /// @param add_special Allow to add BOS and EOS tokens if model is configured to do so.
/// @param parse_special Allow tokenizing special and/or control tokens which otherwise are not exposed and treated /// @param parse_special Allow tokenizing special and/or control tokens which otherwise are not exposed and treated
/// as plaintext. Does not insert a leading space. /// as plaintext. Does not insert a leading space.
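The negative-return and INT32_MIN conventions documented above suggest the usual size-then-fill pattern; a rough sketch follows (it assumes the public llama_tokenize(vocab, text, ...) signature from this header, which is not shown in this hunk, with `vocab`, `text` and `text_len` provided by the caller):

// Sketch: two-pass tokenization driven by the documented return values.
std::vector<llama_token> tokens(text_len + 2); // initial guess, resized below if too small
int32_t n = llama_tokenize(vocab, text, text_len, tokens.data(), (int32_t) tokens.size(),
                           /*add_special=*/true, /*parse_special=*/false);
if (n == INT32_MIN) {
    // overflow: the result would not fit in int32_t
} else if (n < 0) {
    tokens.resize((size_t) -n); // negative return is the required token count
    n = llama_tokenize(vocab, text, text_len, tokens.data(), (int32_t) tokens.size(),
                       /*add_special=*/true, /*parse_special=*/false);
}
if (n >= 0) {
    tokens.resize((size_t) n);
}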

View File

@ -1 +1 @@
8cda0a3c19f2c7dc493887353c42f6956bc268b1 9e4bee1c5afc2d677a5b32ecb90cbdb483e81fff

View File

@ -200,6 +200,7 @@ static const std::map<llm_kv, const char *> LLM_KV_NAMES = {
{ LLM_KV_TOKENIZER_MASK_ID, "tokenizer.ggml.mask_token_id" }, { LLM_KV_TOKENIZER_MASK_ID, "tokenizer.ggml.mask_token_id" },
{ LLM_KV_TOKENIZER_ADD_BOS, "tokenizer.ggml.add_bos_token" }, { LLM_KV_TOKENIZER_ADD_BOS, "tokenizer.ggml.add_bos_token" },
{ LLM_KV_TOKENIZER_ADD_EOS, "tokenizer.ggml.add_eos_token" }, { LLM_KV_TOKENIZER_ADD_EOS, "tokenizer.ggml.add_eos_token" },
{ LLM_KV_TOKENIZER_ADD_SEP, "tokenizer.ggml.add_sep_token" },
{ LLM_KV_TOKENIZER_ADD_PREFIX, "tokenizer.ggml.add_space_prefix" }, { LLM_KV_TOKENIZER_ADD_PREFIX, "tokenizer.ggml.add_space_prefix" },
{ LLM_KV_TOKENIZER_REMOVE_EXTRA_WS, "tokenizer.ggml.remove_extra_whitespaces" }, { LLM_KV_TOKENIZER_REMOVE_EXTRA_WS, "tokenizer.ggml.remove_extra_whitespaces" },
{ LLM_KV_TOKENIZER_PRECOMPILED_CHARSMAP, "tokenizer.ggml.precompiled_charsmap" }, { LLM_KV_TOKENIZER_PRECOMPILED_CHARSMAP, "tokenizer.ggml.precompiled_charsmap" },

View File

@ -196,6 +196,7 @@ enum llm_kv {
LLM_KV_TOKENIZER_MASK_ID, LLM_KV_TOKENIZER_MASK_ID,
LLM_KV_TOKENIZER_ADD_BOS, LLM_KV_TOKENIZER_ADD_BOS,
LLM_KV_TOKENIZER_ADD_EOS, LLM_KV_TOKENIZER_ADD_EOS,
LLM_KV_TOKENIZER_ADD_SEP,
LLM_KV_TOKENIZER_ADD_PREFIX, LLM_KV_TOKENIZER_ADD_PREFIX,
LLM_KV_TOKENIZER_REMOVE_EXTRA_WS, LLM_KV_TOKENIZER_REMOVE_EXTRA_WS,
LLM_KV_TOKENIZER_PRECOMPILED_CHARSMAP, LLM_KV_TOKENIZER_PRECOMPILED_CHARSMAP,

File diff suppressed because it is too large

View File

@ -2,86 +2,44 @@
#include "llama.h" #include "llama.h"
#include "llama-cparams.h"
#include <array> #include <array>
#include <vector> #include <vector>
#include <set> #include <set>
#include <bitset>
#include <unordered_map>
// very similar to llama_batch, // keep this struct lightweight
// but has more metadata about sequences // it points to data in `llama_batch_allocr`
struct llama_ubatch { struct llama_ubatch {
bool equal_seqs; bool equal_seqs;
// TODO: whole_seqs for embeddings? // TODO: whole_seqs for embeddings?
uint32_t n_tokens; // total tokens (n_seq_tokens * n_seqs) uint32_t n_tokens; // total tokens (n_seq_tokens * n_seqs)
uint32_t n_seq_tokens; // tokens per sequence uint32_t n_seq_tokens; // tokens per sequence set
uint32_t n_seqs; uint32_t n_seqs; // sequence sets in the ubatch
uint32_t n_seqs_unq; // unique sequence ids in the ubatch
llama_token * token; // [n_tokens] // seq_id_unq: unique sequence ids in the ubatch
float * embd; // [n_embd, n_tokens] // seq_idx: indices of the unique sequence ids in the ubatch in [0, n_seqs_unq)
llama_pos * pos; // [n_tokens] // used for extracting sequence pooled embeddings
int32_t * n_seq_id; // [n_seqs]
llama_seq_id ** seq_id; // [n_seqs] // // size | idx | val
int8_t * output; // [n_tokens] llama_token * token; // [n_tokens] | i | id, token
float * embd; // [n_embd, n_tokens] | i | embd
llama_pos * pos; // [n_tokens] | i | pos
int32_t * n_seq_id; // [n_tokens] | i | -
llama_seq_id ** seq_id; // [n_tokens] | s | s0, s1, seq_id
llama_seq_id * seq_id_unq; // [n_seqs_unq] | s | seq_id
int32_t * seq_idx; // [LLAMA_MAX_SEQ] | - | seq_idx
int8_t * output; // [n_tokens] | i | -
}; };
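For orientation, a short sketch of how the unique-sequence metadata above is meant to be consumed when extracting sequence-pooled outputs (it mirrors the consumer code later in this patch; `ubatch`, `n_embd` and the pooled output buffer are assumed):

// Sketch: walk the unique sequence sets of a ubatch.
for (uint32_t s = 0; s < ubatch.n_seqs_unq; ++s) {
    const llama_seq_id seq_id  = ubatch.seq_id_unq[s];   // unique sequence id in this ubatch
    const int32_t      seq_idx = ubatch.seq_idx[seq_id]; // its row in the pooled output, in [0, n_seqs_unq)
    // the pooled embedding for seq_id starts at offset n_embd*seq_idx in the output tensor
}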
struct llama_sbatch_seq { // a helper for sanitizing, fulfilling and splitting a batch
int32_t n_seq_id;
llama_seq_id * seq_id;
size_t offset;
size_t length;
};
// sequence-length-aware batch splitting
struct llama_sbatch {
// tokens left in this batch
size_t n_tokens;
size_t n_embd;
// sorted indices into the batch
std::vector<int64_t> ids;
// batch indices of the output
std::vector<int64_t> out_ids;
std::vector<llama_sbatch_seq> seq;
const llama_batch * batch = nullptr;
// buffers for the ubatches
// TODO: very hacky, this needs a complete rework
struct ubatch_data {
std::vector<llama_token> token;
std::vector<float> embd;
std::vector<llama_pos> pos;
std::vector<int32_t> n_seq_id;
std::vector<llama_seq_id *> seq_id;
std::vector<int8_t> output;
};
std::vector<ubatch_data> udatas;
llama_ubatch reserve_ubatch(size_t n_ubatch, bool has_embd = false);
void add_seq_to_ubatch(llama_ubatch & ubatch, llama_sbatch_seq & seq, size_t length);
// simple split, unknown number of sequences of unequal lengths
llama_ubatch split_simple(size_t n_ubatch);
// make batches of equal-length sequences
llama_ubatch split_equal(size_t n_ubatch);
// sequence-wise split
llama_ubatch split_seq(size_t n_ubatch);
llama_sbatch() = default;
llama_sbatch(const llama_batch & batch, size_t n_embd, bool simple_split = false);
};
// a helper for sanitizing and fulfilling a batch
class llama_batch_allocr { class llama_batch_allocr {
public: public:
llama_batch_allocr(); llama_batch_allocr(uint32_t n_pos_per_embd);
// sanitize and auto-gen missing data in the input batch // sanitize and auto-gen missing data in the input batch
// memory is optional. if provided will be used to check for sequence continuity and to determine the positions // memory is optional. if provided will be used to check for sequence continuity and to determine the positions
@ -89,20 +47,57 @@ public:
const llama_batch & batch_inp, const llama_batch & batch_inp,
const llama_vocab & vocab, const llama_vocab & vocab,
const llama_memory_i * memory, const llama_memory_i * memory,
bool embd_all); uint32_t n_embd,
bool output_all);
const llama_batch & get_batch() const; const llama_batch & get_batch() const;
uint32_t get_n_tokens() const;
uint32_t get_n_outputs() const; uint32_t get_n_outputs() const;
// the array of output indices in the order they were encountered during the ubatch splitting
std::vector<int32_t> & get_out_ids();
// min/max positions of each sequence in the current ubatch
llama_pos seq_pos_min(llama_seq_id seq_id) const; llama_pos seq_pos_min(llama_seq_id seq_id) const;
llama_pos seq_pos_max(llama_seq_id seq_id) const; llama_pos seq_pos_max(llama_seq_id seq_id) const;
// call once before splitting the batch to reset the internal state
void split_reset();
// simple split, unknown number of sequence sets of unequal lengths
llama_ubatch split_simple(uint32_t n_ubatch);
// make ubatches of equal-length sequence sets
llama_ubatch split_equal(uint32_t n_ubatch);
// sequence-set-wise split - each ubatch contains a single sequence-set
llama_ubatch split_seq(uint32_t n_ubatch);
// a helper method for creating a well-defined ubatch of tokens
// TODO: support embeddings if needed in the future
llama_ubatch ubatch_reserve(uint32_t n_seq_tokens, uint32_t n_seqs);
private: private:
void clear(); void clear();
// create the next ubatch based on the provided batch indices (idxs) and the number of sequence sets (n_seqs)
// return llama_ubatch.n_tokens == 0 if the entire batch was consumed
llama_ubatch ubatch_add(const std::vector<int32_t> & idxs, uint32_t n_seqs, bool equal_seqs);
// for debugging, start with LLAMA_BATCH_DEBUG=2
void ubatch_print(const llama_ubatch & ubatch, int debug);
llama_batch batch; llama_batch batch;
// only for debugging purposes
const llama_vocab * vocab;
// TODO: this is more of a temporary solution until we have a better way to handle multiple positions per token/embd
// ref: https://github.com/ggml-org/llama.cpp/issues/13694#issuecomment-2983871762
const uint32_t n_pos_per_embd;
uint32_t n_embd;
uint32_t n_outputs; uint32_t n_outputs;
std::array<llama_seq_id, 1> seq_id_0 = { 0 }; // default sequence id std::array<llama_seq_id, 1> seq_id_0 = { 0 }; // default sequence id
@ -110,10 +105,43 @@ private:
std::vector<llama_pos> pos; std::vector<llama_pos> pos;
std::vector<int32_t> n_seq_id; std::vector<int32_t> n_seq_id;
std::vector<llama_seq_id *> seq_id; std::vector<llama_seq_id *> seq_id;
std::vector<llama_seq_id> seq_id_unq;
std::vector<int32_t> seq_idx;
std::vector<int8_t> output; std::vector<int8_t> output;
std::vector<std::set<llama_pos>> seq_pos; // seq_pos[s]: the set of positions in sequence s using pos_set_t = std::set<llama_pos>;
std::vector<std::vector<bool>> seq_cpl; // seq_cpl[s0][s1]: if sequence s0 is coupled to sequence s1 using seq_cpl_t = std::vector<bool>;
std::vector<pos_set_t> seq_pos; // seq_pos[s]: the set of positions in sequence s
std::vector<seq_cpl_t> seq_cpl; // seq_cpl[s0][s1]: if sequence s0 is coupled to sequence s1
using idx_vec_t = std::vector<int32_t>;
using seq_set_t = std::bitset<LLAMA_MAX_SEQ>;
std::vector<seq_set_t> seq_set; // seq_set[i]: the sequence set of token i
std::unordered_map<seq_set_t, idx_vec_t> seq_set_map; // the indices at which the sequence set appears
// batch indices of the output
std::vector<int32_t> out_ids;
// used[i] indicates if token i has already been used in a previous ubatch
std::vector<bool> used;
// llama_ubatch points to this data:
struct ubatch {
std::vector<llama_token> token;
std::vector<float> embd;
std::vector<llama_pos> pos;
std::vector<int32_t> n_seq_id;
std::vector<llama_seq_id *> seq_id;
std::vector<llama_seq_id> seq_id_unq;
std::vector<int32_t> seq_idx;
std::vector<int8_t> output;
};
// current splitting state:
std::vector<ubatch> ubatches;
int debug; int debug;
}; };
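Taken together, the split API above is intended to be driven roughly like this (an illustrative sketch, not part of the patch; `balloc` is an initialized llama_batch_allocr and `n_ubatch` the micro-batch size):

// Sketch: reset once, then pull ubatches until the batch is fully consumed.
balloc.split_reset();
while (true) {
    llama_ubatch ubatch = balloc.split_simple(n_ubatch); // or split_equal / split_seq
    if (ubatch.n_tokens == 0) {
        break; // entire batch consumed
    }
    // ... apply the memory context for this ubatch and process it ...
}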

View File

@ -528,12 +528,17 @@ int32_t llm_chat_apply_template(
} }
} else if (tmpl == LLM_CHAT_TEMPLATE_RWKV_WORLD) { } else if (tmpl == LLM_CHAT_TEMPLATE_RWKV_WORLD) {
// this template requires the model to have "\n\n" as EOT token // this template requires the model to have "\n\n" as EOT token
for (auto message : chat) { for (size_t i = 0; i < chat.size(); i++) {
std::string role(message->role); std::string role(chat[i]->role);
if (role == "user") { if (role == "system") {
ss << "User: " << message->content << "\n\nAssistant:"; ss << "System: " << trim(chat[i]->content) << "\n\n";
} else { } else if (role == "user") {
ss << message->content << "\n\n"; ss << "User: " << trim(chat[i]->content) << "\n\n";
if (i == chat.size() - 1) {
ss << "Assistant:";
}
} else if (role == "assistant") {
ss << "Assistant: " << trim(chat[i]->content) << "\n\n";
} }
} }
} else if (tmpl == LLM_CHAT_TEMPLATE_GRANITE) { } else if (tmpl == LLM_CHAT_TEMPLATE_GRANITE) {

View File

@ -20,7 +20,7 @@ llama_context::llama_context(
const llama_model & model, const llama_model & model,
llama_context_params params) : llama_context_params params) :
model(model), model(model),
batch_allocr(std::make_unique<llama_batch_allocr>()) { balloc(std::make_unique<llama_batch_allocr>(model.hparams.n_pos_per_embd())) {
LLAMA_LOG_INFO("%s: constructing llama_context\n", __func__); LLAMA_LOG_INFO("%s: constructing llama_context\n", __func__);
t_start_us = model.t_start_us; t_start_us = model.t_start_us;
@ -280,8 +280,8 @@ llama_context::llama_context(
// simulate full KV cache // simulate full KV cache
const auto mstate = memory->init_full(); const auto mctx = memory->init_full();
if (!mstate) { if (!mctx) {
throw std::runtime_error("failed to initialize KV cache"); throw std::runtime_error("failed to initialize KV cache");
} }
@ -289,7 +289,7 @@ llama_context::llama_context(
// reserve pp graph first so that buffers are only allocated once // reserve pp graph first so that buffers are only allocated once
{ {
auto * gf = graph_reserve(n_tokens, n_seqs, n_tokens, mstate.get()); auto * gf = graph_reserve(n_tokens, n_seqs, n_tokens, mctx.get());
if (!gf) { if (!gf) {
throw std::runtime_error("failed to allocate compute pp buffers"); throw std::runtime_error("failed to allocate compute pp buffers");
} }
@ -300,7 +300,7 @@ llama_context::llama_context(
// reserve with tg graph to get the number of splits and nodes // reserve with tg graph to get the number of splits and nodes
{ {
auto * gf = graph_reserve(1, 1, 1, mstate.get()); auto * gf = graph_reserve(1, 1, 1, mctx.get());
if (!gf) { if (!gf) {
throw std::runtime_error("failed to allocate compute tg buffers"); throw std::runtime_error("failed to allocate compute tg buffers");
} }
@ -311,7 +311,7 @@ llama_context::llama_context(
// reserve again with pp graph to avoid ggml-alloc reallocations during inference // reserve again with pp graph to avoid ggml-alloc reallocations during inference
{ {
auto * gf = graph_reserve(n_tokens, n_seqs, n_tokens, mstate.get()); auto * gf = graph_reserve(n_tokens, n_seqs, n_tokens, mctx.get());
if (!gf) { if (!gf) {
throw std::runtime_error("failed to allocate compute pp buffers"); throw std::runtime_error("failed to allocate compute pp buffers");
} }
@ -444,8 +444,8 @@ bool llama_context::kv_self_update(bool optimize) {
optimize |= memory_force_optimize; optimize |= memory_force_optimize;
memory_force_optimize = false; memory_force_optimize = false;
const auto mstate = memory->init_update(this, optimize); const auto mctx = memory->init_update(this, optimize);
switch (mstate->get_status()) { switch (mctx->get_status()) {
case LLAMA_MEMORY_STATUS_SUCCESS: case LLAMA_MEMORY_STATUS_SUCCESS:
{ {
// noop // noop
@ -463,22 +463,22 @@ bool llama_context::kv_self_update(bool optimize) {
} }
} }
if (!mstate->apply()) { if (!mctx->apply()) {
LLAMA_LOG_ERROR("%s: failed to apply memory update\n", __func__); LLAMA_LOG_ERROR("%s: failed to apply memory update\n", __func__);
} }
} }
// if the memory module did any computation, we have to reserve a new worst-case graph // if the memory module did any computation, we have to reserve a new worst-case graph
{ {
const auto mstate = memory->init_full(); const auto mctx = memory->init_full();
if (!mstate) { if (!mctx) {
throw std::runtime_error("failed to initialize memory state"); throw std::runtime_error("failed to initialize memory context");
} }
const uint32_t n_seqs = cparams.n_seq_max; const uint32_t n_seqs = cparams.n_seq_max;
const uint32_t n_tokens = std::min(cparams.n_ctx, cparams.n_ubatch); const uint32_t n_tokens = std::min(cparams.n_ctx, cparams.n_ubatch);
auto * gf = graph_reserve(n_tokens, n_seqs, n_tokens, mstate.get()); auto * gf = graph_reserve(n_tokens, n_seqs, n_tokens, mctx.get());
if (!gf) { if (!gf) {
LLAMA_LOG_ERROR("%s: failed to reserve graph after the memory update\n", __func__); LLAMA_LOG_ERROR("%s: failed to reserve graph after the memory update\n", __func__);
} }
@ -678,9 +678,9 @@ bool llama_context::apply_adapter_cvec(
return cvec.apply(model, data, len, n_embd, il_start, il_end); return cvec.apply(model, data, len, n_embd, il_start, il_end);
} }
llm_graph_result_ptr llama_context::process_ubatch(const llama_ubatch & ubatch, llm_graph_type gtype, llama_memory_state_i * mstate, ggml_status & ret) { llm_graph_result_ptr llama_context::process_ubatch(const llama_ubatch & ubatch, llm_graph_type gtype, llama_memory_context_i * mctx, ggml_status & ret) {
if (mstate && !mstate->apply()) { if (mctx && !mctx->apply()) {
LLAMA_LOG_ERROR("%s: failed to apply memory state\n", __func__); LLAMA_LOG_ERROR("%s: failed to apply memory context\n", __func__);
ret = GGML_STATUS_FAILED; ret = GGML_STATUS_FAILED;
return nullptr; return nullptr;
} }
@ -692,7 +692,7 @@ llm_graph_result_ptr llama_context::process_ubatch(const llama_ubatch & ubatch,
return nullptr; return nullptr;
} }
auto res = graph_build(ctx_compute.get(), gf, ubatch, gtype, mstate); auto res = graph_build(ctx_compute.get(), gf, ubatch, gtype, mctx);
if (!res) { if (!res) {
LLAMA_LOG_ERROR("%s: failed to build graph\n", __func__); LLAMA_LOG_ERROR("%s: failed to build graph\n", __func__);
ret = GGML_STATUS_FAILED; ret = GGML_STATUS_FAILED;
@ -722,22 +722,26 @@ llm_graph_result_ptr llama_context::process_ubatch(const llama_ubatch & ubatch,
} }
int llama_context::encode(const llama_batch & batch_inp) { int llama_context::encode(const llama_batch & batch_inp) {
GGML_ASSERT((!batch_inp.token && batch_inp.embd) || (batch_inp.token && !batch_inp.embd)); // NOLINT
if (batch_inp.n_tokens == 0) { if (batch_inp.n_tokens == 0) {
LLAMA_LOG_ERROR("%s: n_tokens == 0\n", __func__); LLAMA_LOG_ERROR("%s: n_tokens == 0\n", __func__);
return -1; return -1;
} }
const auto & hparams = model.hparams;
const int64_t n_embd = hparams.n_embd;
// note: during encode, we always pass the full sequence starting from pos = 0 // note: during encode, we always pass the full sequence starting from pos = 0
if (!batch_allocr->init(batch_inp, model.vocab, nullptr, true)) { if (!balloc->init(batch_inp, model.vocab, nullptr, n_embd, true)) {
LLAMA_LOG_ERROR("%s: failed to initialize batch\n", __func__); LLAMA_LOG_ERROR("%s: failed to initialize batch\n", __func__);
return -1; return -1;
} }
const llama_batch & batch = batch_allocr->get_batch(); const uint32_t n_tokens = balloc->get_n_tokens();
const uint32_t n_tokens = batch.n_tokens; const llama_ubatch ubatch = balloc->split_simple(n_tokens);
GGML_ASSERT((!batch.token && batch.embd) || (batch.token && !batch.embd)); // NOLINT
// micro-batching is not possible for non-causal encoding, so we process the batch in a single shot // micro-batching is not possible for non-causal encoding, so we process the batch in a single shot
GGML_ASSERT(cparams.n_ubatch >= n_tokens && "encoder requires n_ubatch >= n_tokens"); GGML_ASSERT(cparams.n_ubatch >= n_tokens && "encoder requires n_ubatch >= n_tokens");
@ -751,14 +755,6 @@ int llama_context::encode(const llama_batch & batch_inp) {
n_queued_tokens += n_tokens; n_queued_tokens += n_tokens;
const auto & hparams = model.hparams;
const int64_t n_embd = hparams.n_embd;
llama_sbatch sbatch = llama_sbatch(batch, n_embd, /* simple_split */ true);
const llama_ubatch ubatch = sbatch.split_simple(n_tokens);
// reserve output buffer // reserve output buffer
if (output_reserve(n_tokens) < n_tokens) { if (output_reserve(n_tokens) < n_tokens) {
LLAMA_LOG_ERROR("%s: could not reserve space for batch with %u outputs\n", __func__, n_tokens); LLAMA_LOG_ERROR("%s: could not reserve space for batch with %u outputs\n", __func__, n_tokens);
@ -817,34 +813,28 @@ int llama_context::encode(const llama_batch & batch_inp) {
{ {
// extract sequence embeddings // extract sequence embeddings
auto & embd_seq_out = embd_seq; auto & embd_seq_out = embd_seq;
embd_seq_out.clear();
GGML_ASSERT(!ubatch.equal_seqs); // TODO: handle equal splits for (uint32_t s = 0; s < ubatch.n_seqs_unq; ++s) {
const llama_seq_id seq_id = ubatch.seq_id_unq[s];
const int32_t seq_idx = ubatch.seq_idx[seq_id];
// TODO: fix indexing [UBATCH_IDX]
for (uint32_t i = 0; i < n_tokens; i++) {
const llama_seq_id seq_id = ubatch.seq_id[i][0];
if (embd_seq_out.find(seq_id) != embd_seq_out.end()) {
continue;
}
embd_seq_out[seq_id].resize(n_embd); embd_seq_out[seq_id].resize(n_embd);
ggml_backend_tensor_get_async(backend_embd, t_embd, embd_seq_out[seq_id].data(), (n_embd*seq_id)*sizeof(float), n_embd*sizeof(float)); ggml_backend_tensor_get_async(backend_embd, t_embd, embd_seq_out[seq_id].data(), (n_embd*seq_idx)*sizeof(float), n_embd*sizeof(float));
} }
} break; } break;
case LLAMA_POOLING_TYPE_RANK: case LLAMA_POOLING_TYPE_RANK:
{ {
// extract the rerank score - n_cls_out floats per sequence // extract the rerank score - n_cls_out floats per sequence
auto & embd_seq_out = embd_seq; auto & embd_seq_out = embd_seq;
const uint32_t n_cls_out = hparams.n_cls_out; const uint32_t n_cls_out = hparams.n_cls_out;
// TODO: fix indexing [UBATCH_IDX] for (uint32_t s = 0; s < ubatch.n_seqs_unq; ++s) {
for (uint32_t s = 0; s < ubatch.n_seqs; ++s) { const llama_seq_id seq_id = ubatch.seq_id_unq[s];
const llama_seq_id seq_id = ubatch.seq_id[s][0]; const int32_t seq_idx = ubatch.seq_idx[seq_id];
if (embd_seq_out.find(seq_id) != embd_seq_out.end()) {
continue;
}
embd_seq_out[seq_id].resize(n_cls_out); embd_seq_out[seq_id].resize(n_cls_out);
ggml_backend_tensor_get_async(backend_embd, t_embd, embd_seq_out[seq_id].data(), (n_cls_out*seq_id)*sizeof(float), n_cls_out*sizeof(float)); ggml_backend_tensor_get_async(backend_embd, t_embd, embd_seq_out[seq_id].data(), (n_cls_out*seq_idx)*sizeof(float), n_cls_out*sizeof(float));
} }
} break; } break;
case LLAMA_POOLING_TYPE_UNSPECIFIED: case LLAMA_POOLING_TYPE_UNSPECIFIED:
@ -869,12 +859,16 @@ int llama_context::encode(const llama_batch & batch_inp) {
cross.v_embd.resize(cross.n_embd*cross.n_enc); cross.v_embd.resize(cross.n_embd*cross.n_enc);
memcpy(cross.v_embd.data(), embd, ggml_nbytes(t_embd)); memcpy(cross.v_embd.data(), embd, ggml_nbytes(t_embd));
const auto & batch = balloc->get_batch();
// remember the sequence ids used during the encoding - needed for cross attention later // remember the sequence ids used during the encoding - needed for cross attention later
cross.seq_ids_enc.resize(n_tokens); cross.seq_ids_enc.resize(n_tokens);
for (uint32_t i = 0; i < n_tokens; i++) { for (uint32_t i = 0; i < n_tokens; i++) {
cross.seq_ids_enc[i].clear(); cross.seq_ids_enc[i].clear();
for (int s = 0; s < batch.n_seq_id[i]; s++) { for (int s = 0; s < batch.n_seq_id[i]; s++) {
llama_seq_id seq_id = batch.seq_id[i][s]; const llama_seq_id seq_id = batch.seq_id[i][s];
cross.seq_ids_enc[i].insert(seq_id); cross.seq_ids_enc[i].insert(seq_id);
} }
} }
@ -884,6 +878,8 @@ int llama_context::encode(const llama_batch & batch_inp) {
} }
int llama_context::decode(const llama_batch & batch_inp) { int llama_context::decode(const llama_batch & batch_inp) {
GGML_ASSERT((!batch_inp.token && batch_inp.embd) || (batch_inp.token && !batch_inp.embd)); // NOLINT
if (!memory) { if (!memory) {
LLAMA_LOG_DEBUG("%s: cannot decode batches with this context (calling encode() instead)\n", __func__); LLAMA_LOG_DEBUG("%s: cannot decode batches with this context (calling encode() instead)\n", __func__);
return encode(batch_inp); return encode(batch_inp);
@ -894,29 +890,24 @@ int llama_context::decode(const llama_batch & batch_inp) {
return -1; return -1;
} }
// when computing embeddings, all tokens are output
const bool embd_all = cparams.embeddings;
if (!batch_allocr->init(batch_inp, model.vocab, memory.get(), embd_all)) {
LLAMA_LOG_ERROR("%s: failed to initialize batch\n", __func__);
return -1;
}
const llama_batch & batch = batch_allocr->get_batch();
const auto & vocab = model.vocab; const auto & vocab = model.vocab;
const auto & hparams = model.hparams; const auto & hparams = model.hparams;
const int32_t n_vocab = vocab.n_tokens(); const int32_t n_vocab = vocab.n_tokens();
const int64_t n_embd = hparams.n_embd; const int64_t n_embd = hparams.n_embd;
const uint32_t n_tokens_all = batch.n_tokens; // when computing embeddings, all tokens are output
const bool output_all = cparams.embeddings;
GGML_ASSERT((!batch.token && batch.embd) || (batch.token && !batch.embd)); // NOLINT if (!balloc->init(batch_inp, vocab, memory.get(), n_embd, output_all)) {
LLAMA_LOG_ERROR("%s: failed to initialize batch\n", __func__);
return -1;
}
const uint32_t n_outputs_all = batch_allocr->get_n_outputs(); const uint32_t n_tokens_all = balloc->get_n_tokens();
const uint32_t n_outputs_all = balloc->get_n_outputs();
if (embd_all) { if (output_all) {
// require that all tokens are output // require that all tokens are output
if (n_outputs_all != n_tokens_all) { if (n_outputs_all != n_tokens_all) {
LLAMA_LOG_ERROR("%s: pooled embedding requires that all tokens are output (n_outputs_all = %d, n_tokens_all = %d)\n", LLAMA_LOG_ERROR("%s: pooled embedding requires that all tokens are output (n_outputs_all = %d, n_tokens_all = %d)\n",
@ -942,21 +933,21 @@ int llama_context::decode(const llama_batch & batch_inp) {
// handle any pending defrags/shifts // handle any pending defrags/shifts
kv_self_update(false); kv_self_update(false);
llama_memory_state_ptr mstate; llama_memory_context_ptr mctx;
while (true) { while (true) {
mstate = memory->init_batch(batch, cparams.n_ubatch, embd_all); mctx = memory->init_batch(*balloc, cparams.n_ubatch, output_all);
if (!mstate) { if (!mctx) {
return -2; return -2;
} }
switch (mstate->get_status()) { switch (mctx->get_status()) {
case LLAMA_MEMORY_STATUS_SUCCESS: case LLAMA_MEMORY_STATUS_SUCCESS:
{ {
} break; } break;
case LLAMA_MEMORY_STATUS_NO_UPDATE: case LLAMA_MEMORY_STATUS_NO_UPDATE:
{ {
LLAMA_LOG_ERROR("%s: unexpected memory state status: %d\n", __func__, mstate->get_status()); LLAMA_LOG_ERROR("%s: unexpected memory context status: %d\n", __func__, mctx->get_status());
return -2; return -2;
} }
@ -966,19 +957,19 @@ int llama_context::decode(const llama_batch & batch_inp) {
did_optimize = true; did_optimize = true;
if (kv_self_update(true)) { if (kv_self_update(true)) {
LLAMA_LOG_DEBUG("%s: retrying batch size %d after cache optimization\n", __func__, batch.n_tokens); LLAMA_LOG_DEBUG("%s: retrying batch size %d after cache optimization\n", __func__, balloc->get_n_tokens());
continue; continue;
} }
} }
LLAMA_LOG_WARN("%s: failed to find a memory slot for batch of size %d\n", __func__, batch.n_tokens); LLAMA_LOG_WARN("%s: failed to find a memory slot for batch of size %d\n", __func__, balloc->get_n_tokens());
return 1; return 1;
} }
case LLAMA_MEMORY_STATUS_FAILED_COMPUTE: case LLAMA_MEMORY_STATUS_FAILED_COMPUTE:
{ {
LLAMA_LOG_ERROR("%s: compute failed while preparing batch of size %d\n", __func__, batch.n_tokens); LLAMA_LOG_ERROR("%s: compute failed while preparing batch of size %d\n", __func__, balloc->get_n_tokens());
return -2; return -2;
} }
@ -996,7 +987,7 @@ int llama_context::decode(const llama_batch & batch_inp) {
int64_t n_outputs_prev = 0; int64_t n_outputs_prev = 0;
do { do {
const auto & ubatch = mstate->get_ubatch(); const auto & ubatch = mctx->get_ubatch();
// count the outputs in this ubatch // count the outputs in this ubatch
{ {
@ -1005,7 +996,6 @@ int llama_context::decode(const llama_batch & batch_inp) {
if (n_outputs_all == n_tokens_all) { if (n_outputs_all == n_tokens_all) {
n_outputs_new = ubatch.n_tokens; n_outputs_new = ubatch.n_tokens;
} else { } else {
GGML_ASSERT(ubatch.output);
for (uint32_t i = 0; i < ubatch.n_tokens; i++) { for (uint32_t i = 0; i < ubatch.n_tokens; i++) {
n_outputs_new += (int32_t) (ubatch.output[i] != 0); n_outputs_new += (int32_t) (ubatch.output[i] != 0);
} }
@ -1019,7 +1009,7 @@ int llama_context::decode(const llama_batch & batch_inp) {
ggml_backend_sched_set_eval_callback(sched.get(), cparams.cb_eval, cparams.cb_eval_user_data); ggml_backend_sched_set_eval_callback(sched.get(), cparams.cb_eval, cparams.cb_eval_user_data);
ggml_status status; ggml_status status;
const auto res = process_ubatch(ubatch, LLM_GRAPH_TYPE_DECODER, mstate.get(), status); const auto res = process_ubatch(ubatch, LLM_GRAPH_TYPE_DECODER, mctx.get(), status);
if (!res) { if (!res) {
// the last ubatch failed or was aborted -> remove all positions of that ubatch from the KV cache // the last ubatch failed or was aborted -> remove all positions of that ubatch from the KV cache
@ -1028,7 +1018,6 @@ int llama_context::decode(const llama_batch & batch_inp) {
pos_min[s] = std::numeric_limits<llama_pos>::max(); pos_min[s] = std::numeric_limits<llama_pos>::max();
} }
// TODO: fix sequence indexing
for (uint32_t i = 0; i < ubatch.n_tokens; ++i) { for (uint32_t i = 0; i < ubatch.n_tokens; ++i) {
const auto & seq_id = ubatch.seq_id[i][0]; const auto & seq_id = ubatch.seq_id[i][0];
@ -1105,27 +1094,27 @@ int llama_context::decode(const llama_batch & batch_inp) {
// extract sequence embeddings (cleared before processing each batch) // extract sequence embeddings (cleared before processing each batch)
auto & embd_seq_out = embd_seq; auto & embd_seq_out = embd_seq;
for (uint32_t s = 0; s < ubatch.n_seqs; ++s) { for (uint32_t s = 0; s < ubatch.n_seqs_unq; ++s) {
const llama_seq_id seq_id = ubatch.seq_id[s][0]; const llama_seq_id seq_id = ubatch.seq_id_unq[s];
if (embd_seq_out.find(seq_id) != embd_seq_out.end()) { const int32_t seq_idx = ubatch.seq_idx[seq_id];
continue;
}
embd_seq_out[seq_id].resize(n_embd); embd_seq_out[seq_id].resize(n_embd);
ggml_backend_tensor_get_async(backend_embd, t_embd, embd_seq_out[seq_id].data(), (n_embd*seq_id)*sizeof(float), n_embd*sizeof(float)); ggml_backend_tensor_get_async(backend_embd, t_embd, embd_seq_out[seq_id].data(), (n_embd*seq_idx)*sizeof(float), n_embd*sizeof(float));
} }
} break; } break;
case LLAMA_POOLING_TYPE_RANK: case LLAMA_POOLING_TYPE_RANK:
{ {
// extract the rerank score - a single float per sequence // extract the rerank score - n_cls_out floats per sequence
auto & embd_seq_out = embd_seq; auto & embd_seq_out = embd_seq;
for (uint32_t s = 0; s < ubatch.n_seqs; ++s) { const uint32_t n_cls_out = hparams.n_cls_out;
const llama_seq_id seq_id = ubatch.seq_id[s][0];
if (embd_seq_out.find(seq_id) != embd_seq_out.end()) { for (uint32_t s = 0; s < ubatch.n_seqs_unq; ++s) {
continue; const llama_seq_id seq_id = ubatch.seq_id_unq[s];
} const int32_t seq_idx = ubatch.seq_idx[seq_id];
embd_seq_out[seq_id].resize(1);
ggml_backend_tensor_get_async(backend_embd, t_embd, embd_seq_out[seq_id].data(), (seq_id)*sizeof(float), sizeof(float)); embd_seq_out[seq_id].resize(n_cls_out);
ggml_backend_tensor_get_async(backend_embd, t_embd, embd_seq_out[seq_id].data(), (n_cls_out*seq_idx)*sizeof(float), n_cls_out*sizeof(float));
} }
} break; } break;
case LLAMA_POOLING_TYPE_UNSPECIFIED: case LLAMA_POOLING_TYPE_UNSPECIFIED:
@ -1136,7 +1125,7 @@ int llama_context::decode(const llama_batch & batch_inp) {
} }
n_outputs_prev += n_outputs; n_outputs_prev += n_outputs;
} while (mstate->next()); } while (mctx->next());
// set to total number of outputs in the batch, for use in llama_get_logits_ith // set to total number of outputs in the batch, for use in llama_get_logits_ith
n_outputs = n_outputs_all; n_outputs = n_outputs_all;
@ -1145,7 +1134,7 @@ int llama_context::decode(const llama_batch & batch_inp) {
if (n_outputs > 0) { if (n_outputs > 0) {
bool sorted_output = true; bool sorted_output = true;
auto & out_ids = mstate->out_ids(); auto & out_ids = balloc->get_out_ids();
GGML_ASSERT(out_ids.size() == (size_t) n_outputs); GGML_ASSERT(out_ids.size() == (size_t) n_outputs);
@ -1302,7 +1291,7 @@ ggml_cgraph * llama_context::graph_init() {
return ggml_new_graph_custom(ctx_compute.get(), graph_max_nodes(), false); return ggml_new_graph_custom(ctx_compute.get(), graph_max_nodes(), false);
} }
ggml_cgraph * llama_context::graph_reserve(uint32_t n_tokens, uint32_t n_seqs, uint32_t n_outputs, const llama_memory_state_i * mstate) { ggml_cgraph * llama_context::graph_reserve(uint32_t n_tokens, uint32_t n_seqs, uint32_t n_outputs, const llama_memory_context_i * mctx) {
LLAMA_LOG_DEBUG("%s: reserving a graph for ubatch with n_tokens = %4u, n_seqs = %2u, n_outputs = %4u\n", __func__, n_tokens, n_seqs, n_outputs); LLAMA_LOG_DEBUG("%s: reserving a graph for ubatch with n_tokens = %4u, n_seqs = %2u, n_outputs = %4u\n", __func__, n_tokens, n_seqs, n_outputs);
if (n_tokens % n_seqs != 0) { if (n_tokens % n_seqs != 0) {
@ -1318,11 +1307,11 @@ ggml_cgraph * llama_context::graph_reserve(uint32_t n_tokens, uint32_t n_seqs, u
this->n_outputs = n_outputs; this->n_outputs = n_outputs;
llama_token token = model.vocab.token_bos(); // not actually used by llama_build_graph, but required to choose between token and embedding inputs graph llama_batch_allocr balloc(model.hparams.n_pos_per_embd());
llama_ubatch ubatch = { true, n_tokens, n_tokens / n_seqs, n_seqs, &token, nullptr, nullptr, nullptr, nullptr, nullptr}; llama_ubatch ubatch = balloc.ubatch_reserve(n_tokens/n_seqs, n_seqs);
auto * gf = graph_init(); auto * gf = graph_init();
auto res = graph_build(ctx_compute.get(), gf, ubatch, LLM_GRAPH_TYPE_DEFAULT, mstate); auto res = graph_build(ctx_compute.get(), gf, ubatch, LLM_GRAPH_TYPE_DEFAULT, mctx);
this->n_outputs = save_n_outputs; this->n_outputs = save_n_outputs;
@ -1347,7 +1336,7 @@ llm_graph_result_ptr llama_context::graph_build(
ggml_cgraph * gf, ggml_cgraph * gf,
const llama_ubatch & ubatch, const llama_ubatch & ubatch,
llm_graph_type gtype, llm_graph_type gtype,
const llama_memory_state_i * mstate) { const llama_memory_context_i * mctx) {
return model.build_graph( return model.build_graph(
{ {
/*.ctx =*/ ctx, /*.ctx =*/ ctx,
@ -1359,7 +1348,7 @@ llm_graph_result_ptr llama_context::graph_build(
/*.backend_cpu =*/ backend_cpu, /*.backend_cpu =*/ backend_cpu,
/*.cvec =*/ &cvec, /*.cvec =*/ &cvec,
/*.loras =*/ &loras, /*.loras =*/ &loras,
/*.mstate =*/ mstate, /*.mctx =*/ mctx,
/*.cross =*/ &cross, /*.cross =*/ &cross,
/*.n_outputs =*/ n_outputs, /*.n_outputs =*/ n_outputs,
/*.cb =*/ graph_get_cb(), /*.cb =*/ graph_get_cb(),
@ -2039,7 +2028,12 @@ void llama_context::opt_epoch_iter(
batch.logits [pos_batch] = true; batch.logits [pos_batch] = true;
} }
const auto n_tokens_all = batch.n_tokens; if (!balloc->init(batch, model.vocab, nullptr, model.hparams.n_embd, true)) {
LLAMA_LOG_ERROR("%s: failed to initialize batch\n", __func__);
return;
}
const uint32_t n_tokens_all = balloc->get_n_tokens();
n_queued_tokens += n_tokens_all; n_queued_tokens += n_tokens_all;
@ -2047,8 +2041,8 @@ void llama_context::opt_epoch_iter(
uint32_t n_outputs_all = n_tokens_all; uint32_t n_outputs_all = n_tokens_all;
auto mstate = memory->init_batch(batch, cparams.n_ubatch, true); auto mctx = memory->init_batch(*balloc, cparams.n_ubatch, true);
if (!mstate || mstate->get_status() != LLAMA_MEMORY_STATUS_SUCCESS) { if (!mctx || mctx->get_status() != LLAMA_MEMORY_STATUS_SUCCESS) {
LLAMA_LOG_ERROR("%s: could not initialize batch\n", __func__); LLAMA_LOG_ERROR("%s: could not initialize batch\n", __func__);
break; break;
} }
@ -2061,17 +2055,17 @@ void llama_context::opt_epoch_iter(
uint32_t pos_batch = 0; uint32_t pos_batch = 0;
do { do {
const auto & ubatch = mstate->get_ubatch(); const auto & ubatch = mctx->get_ubatch();
n_outputs = ubatch.n_tokens; n_outputs = ubatch.n_tokens;
if (!mstate->apply()) { if (!mctx->apply()) {
LLAMA_LOG_ERROR("%s: failed to update the memory state\n", __func__); LLAMA_LOG_ERROR("%s: failed to update the memory context\n", __func__);
break; break;
} }
auto * gf = graph_init(); auto * gf = graph_init();
auto res = graph_build(ctx_compute.get(), gf, ubatch, LLM_GRAPH_TYPE_DEFAULT, mstate.get()); auto res = graph_build(ctx_compute.get(), gf, ubatch, LLM_GRAPH_TYPE_DEFAULT, mctx.get());
struct ggml_context * ctx_compute_opt; struct ggml_context * ctx_compute_opt;
{ {
@ -2106,7 +2100,7 @@ void llama_context::opt_epoch_iter(
ggml_free(ctx_compute_opt); ggml_free(ctx_compute_opt);
pos_batch += ubatch.n_tokens; pos_batch += ubatch.n_tokens;
} while (mstate->next()); } while (mctx->next());
} }
} }
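
For orientation, the decode() and opt_epoch_iter() changes above replace the old sbatch/mstate flow with a memory context that is queried and advanced per ubatch. A minimal standalone sketch of that loop shape, using mock types rather than the real llama.cpp classes (only the get_status/get_ubatch/apply/next shape is taken from the diff, everything else is illustrative):

#include <cstddef>
#include <cstdio>
#include <vector>

enum memory_status { MEMORY_STATUS_SUCCESS, MEMORY_STATUS_FAILED_PREPARE };

struct ubatch { int n_tokens; };

// minimal stand-in for the memory context interface used above
struct memory_context {
    std::vector<ubatch> ubatches;
    std::size_t i_next = 0;

    memory_status get_status() const {
        return ubatches.empty() ? MEMORY_STATUS_FAILED_PREPARE : MEMORY_STATUS_SUCCESS;
    }
    const ubatch & get_ubatch() const { return ubatches[i_next]; }
    bool apply() { return true; }                       // would commit this ubatch to the KV cache / recurrent state
    bool next()  { return ++i_next < ubatches.size(); } // advance to the following ubatch
};

int main() {
    memory_context mctx;
    mctx.ubatches = {{4}, {4}, {2}}; // three micro-batches of 4, 4 and 2 tokens

    if (mctx.get_status() != MEMORY_STATUS_SUCCESS) {
        std::fprintf(stderr, "could not initialize batch\n");
        return 1;
    }

    do {
        const auto & ub = mctx.get_ubatch();
        if (!mctx.apply()) {
            std::fprintf(stderr, "failed to update the memory context\n");
            break;
        }
        std::printf("processing ubatch with %d tokens\n", ub.n_tokens); // graph build + compute would go here
    } while (mctx.next());

    return 0;
}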

View File

@ -18,7 +18,7 @@ class llama_io_read_i;
class llama_io_write_i; class llama_io_write_i;
struct llama_memory_i; struct llama_memory_i;
struct llama_memory_state_i; struct llama_memory_context_i;
struct llama_context { struct llama_context {
// init scheduler and compute buffers, reserve worst-case graphs // init scheduler and compute buffers, reserve worst-case graphs
@ -93,13 +93,13 @@ struct llama_context {
int32_t il_end); int32_t il_end);
// process a single ubatch with a specific graph type // process a single ubatch with a specific graph type
// if memory_state is provided, it will be applied first to the context's memory // if memory_context is provided, it will be applied first to the context's memory
// ret contains the status of the graph computation // ret contains the status of the graph computation
// returns nullptr only if ret != GGML_STATUS_SUCCESS // returns nullptr only if ret != GGML_STATUS_SUCCESS
llm_graph_result_ptr process_ubatch( llm_graph_result_ptr process_ubatch(
const llama_ubatch & ubatch, const llama_ubatch & ubatch,
llm_graph_type gtype, llm_graph_type gtype,
llama_memory_state_i * mstate, llama_memory_context_i * mctx,
ggml_status & ret); ggml_status & ret);
int encode(const llama_batch & batch_inp); int encode(const llama_batch & batch_inp);
@ -197,7 +197,7 @@ public:
ggml_status graph_compute(ggml_cgraph * gf, bool batched); ggml_status graph_compute(ggml_cgraph * gf, bool batched);
// reserve a graph with a dummy ubatch of the specified size // reserve a graph with a dummy ubatch of the specified size
ggml_cgraph * graph_reserve(uint32_t n_tokens, uint32_t n_seqs, uint32_t n_outputs, const llama_memory_state_i * mstate); ggml_cgraph * graph_reserve(uint32_t n_tokens, uint32_t n_seqs, uint32_t n_outputs, const llama_memory_context_i * mctx);
private: private:
llm_graph_result_ptr graph_build( llm_graph_result_ptr graph_build(
@ -205,7 +205,7 @@ private:
ggml_cgraph * gf, ggml_cgraph * gf,
const llama_ubatch & ubatch, const llama_ubatch & ubatch,
llm_graph_type gtype, llm_graph_type gtype,
const llama_memory_state_i * mstate); const llama_memory_context_i * mctx);
llm_graph_cb graph_get_cb() const; llm_graph_cb graph_get_cb() const;
@ -247,7 +247,7 @@ private:
std::map<llama_seq_id, std::vector<float>> embd_seq; std::map<llama_seq_id, std::vector<float>> embd_seq;
// reuse the batch_allocr to avoid unnecessary memory allocations // reuse the batch_allocr to avoid unnecessary memory allocations
std::unique_ptr<llama_batch_allocr> batch_allocr; std::unique_ptr<llama_batch_allocr> balloc;
uint32_t n_outputs = 0; // number of actually-used outputs in the current ubatch or last logical batch uint32_t n_outputs = 0; // number of actually-used outputs in the current ubatch or last logical batch

View File

@ -87,17 +87,13 @@ void llm_graph_input_pos_bucket::set_input(const llama_ubatch * ubatch) {
void llm_graph_input_pos_bucket_kv::set_input(const llama_ubatch * ubatch) { void llm_graph_input_pos_bucket_kv::set_input(const llama_ubatch * ubatch) {
if (pos_bucket) { if (pos_bucket) {
kv_state->set_input_pos_bucket(pos_bucket, ubatch); mctx->set_input_pos_bucket(pos_bucket, ubatch);
} }
} }
void llm_graph_input_out_ids::set_input(const llama_ubatch * ubatch) { void llm_graph_input_out_ids::set_input(const llama_ubatch * ubatch) {
if (hparams.causal_attn || cparams.pooling_type == LLAMA_POOLING_TYPE_NONE) { GGML_ASSERT(out_ids);
//GGML_ASSERT(out_ids && "every model that can must skip unused outputs");
if (!out_ids) {
LLAMA_LOG_WARN("%s: 'out_ids' is not created\n", __func__);
} else {
const int64_t n_tokens = ubatch->n_tokens; const int64_t n_tokens = ubatch->n_tokens;
GGML_ASSERT(ggml_backend_buffer_is_host(out_ids->buffer)); GGML_ASSERT(ggml_backend_buffer_is_host(out_ids->buffer));
@ -107,133 +103,116 @@ void llm_graph_input_out_ids::set_input(const llama_ubatch * ubatch) {
for (int i = 0; i < n_tokens; ++i) { for (int i = 0; i < n_tokens; ++i) {
data[i] = i; data[i] = i;
} }
} else if (ubatch->output) {
int32_t n_outputs = 0; return;
}
GGML_ASSERT(ubatch->output);
int n_outputs = 0;
for (int i = 0; i < n_tokens; ++i) { for (int i = 0; i < n_tokens; ++i) {
if (ubatch->output[i]) { if (ubatch->output[i]) {
data[n_outputs++] = i; data[n_outputs++] = i;
} }
} }
// the graph needs to have been passed the correct number of outputs
GGML_ASSERT(n_outputs == n_outputs);
} else if (n_outputs == 1) {
// only keep last output
data[0] = n_tokens - 1;
} else {
GGML_ASSERT(n_outputs == 0);
}
}
}
} }
void llm_graph_input_mean::set_input(const llama_ubatch * ubatch) { void llm_graph_input_mean::set_input(const llama_ubatch * ubatch) {
if (cparams.embeddings && cparams.pooling_type == LLAMA_POOLING_TYPE_MEAN) { if (cparams.embeddings && cparams.pooling_type == LLAMA_POOLING_TYPE_MEAN) {
const int64_t n_tokens = ubatch->n_tokens; const int64_t n_tokens = ubatch->n_tokens;
const int64_t n_seq_tokens = ubatch->n_seq_tokens; const int64_t n_seq_tokens = ubatch->n_seq_tokens;
const int64_t n_seqs = ubatch->n_seqs; const int64_t n_seqs_unq = ubatch->n_seqs_unq;
GGML_ASSERT(mean); GGML_ASSERT(mean);
GGML_ASSERT(ggml_backend_buffer_is_host(mean->buffer)); GGML_ASSERT(ggml_backend_buffer_is_host(mean->buffer));
float * data = (float *) mean->data; float * data = (float *) mean->data;
memset(mean->data, 0, n_tokens * n_tokens * ggml_element_size(mean)); memset(mean->data, 0, n_tokens*n_seqs_unq*ggml_element_size(mean));
std::vector<uint64_t> sum(n_tokens, 0); std::vector<uint64_t> sums(n_seqs_unq, 0);
for (int i = 0; i < n_tokens; i += n_seq_tokens) {
for (int s = 0; s < ubatch->n_seq_id[i]; ++s) {
const llama_seq_id seq_id = ubatch->seq_id[i][s];
const int32_t seq_idx = ubatch->seq_idx[seq_id];
// TODO: fix indexing [UBATCH_IDX] sums[seq_idx] += ubatch->n_seq_tokens;
for (int s = 0; s < n_seqs; ++s) {
const llama_seq_id seq_id = ubatch->seq_id[s][0];
// TODO: adapt limits to n_seqs when ubatch->equal_seqs is true
GGML_ASSERT(seq_id < n_tokens && "seq_id cannot be larger than n_tokens with pooling_type == MEAN");
sum[seq_id] += ubatch->n_seq_tokens;
}
std::vector<float> div(n_tokens, 0.0f);
for (int i = 0; i < n_tokens; ++i) {
const uint64_t s = sum[i];
if (s > 0) {
div[i] = 1.0f/float(s);
} }
} }
// TODO: fix indexing [UBATCH_IDX] std::vector<float> div(n_seqs_unq, 0.0f);
for (int s = 0; s < n_seqs; ++s) { for (int s = 0; s < n_seqs_unq; ++s) {
const llama_seq_id seq_id = ubatch->seq_id[s][0]; const uint64_t sum = sums[s];
if (sum > 0) {
div[s] = 1.0f/float(sum);
}
}
for (int i = 0; i < n_seq_tokens; ++i) { for (int i = 0; i < n_tokens; i += n_seq_tokens) {
data[seq_id*n_tokens + s*n_seq_tokens + i] = div[seq_id]; for (int s = 0; s < ubatch->n_seq_id[i]; ++s) {
const llama_seq_id seq_id = ubatch->seq_id[i][s];
const int32_t seq_idx = ubatch->seq_idx[seq_id];
for (int j = 0; j < n_seq_tokens; ++j) {
data[seq_idx*n_tokens + i + j] = div[seq_idx];
}
} }
} }
} }
} }
void llm_graph_input_cls::set_input(const llama_ubatch * ubatch) { void llm_graph_input_cls::set_input(const llama_ubatch * ubatch) {
if (cparams.embeddings && (
cparams.pooling_type == LLAMA_POOLING_TYPE_CLS ||
cparams.pooling_type == LLAMA_POOLING_TYPE_RANK)) {
const int64_t n_tokens = ubatch->n_tokens; const int64_t n_tokens = ubatch->n_tokens;
const int64_t n_seq_tokens = ubatch->n_seq_tokens; const int64_t n_seq_tokens = ubatch->n_seq_tokens;
const int64_t n_seqs = ubatch->n_seqs; const int64_t n_seqs_unq = ubatch->n_seqs_unq;
if (cparams.embeddings && (
cparams.pooling_type == LLAMA_POOLING_TYPE_CLS ||
cparams.pooling_type == LLAMA_POOLING_TYPE_RANK
)) {
GGML_ASSERT(cls); GGML_ASSERT(cls);
GGML_ASSERT(ggml_backend_buffer_is_host(cls->buffer)); GGML_ASSERT(ggml_backend_buffer_is_host(cls->buffer));
uint32_t * data = (uint32_t *) cls->data; uint32_t * data = (uint32_t *) cls->data;
memset(cls->data, 0, n_tokens * ggml_element_size(cls)); memset(cls->data, 0, n_seqs_unq*ggml_element_size(cls));
// TODO: fix indexing [UBATCH_IDX] for (int i = 0; i < n_tokens; i += n_seq_tokens) {
for (int s = 0; s < n_seqs; ++s) { for (int s = 0; s < ubatch->n_seq_id[i]; ++s) {
const llama_seq_id seq_id = ubatch->seq_id[s][0]; const llama_seq_id seq_id = ubatch->seq_id[i][s];
const int32_t seq_idx = ubatch->seq_idx[seq_id];
// TODO: adapt limits to n_seqs when ubatch->equal_seqs is true data[seq_idx] = i;
GGML_ASSERT(seq_id < n_tokens && "seq_id cannot be larger than n_tokens with pooling_type == CLS or RANK");
for (int i = 0; i < n_seq_tokens; ++i) {
const llama_pos pos = ubatch->pos[s*n_seq_tokens + i];
if (pos == 0) {
data[seq_id] = s*n_seq_tokens + i;
}
} }
} }
} }
if (cparams.embeddings && cparams.pooling_type == LLAMA_POOLING_TYPE_LAST) { if (cparams.embeddings && cparams.pooling_type == LLAMA_POOLING_TYPE_LAST) {
const int64_t n_tokens = ubatch->n_tokens;
const int64_t n_seq_tokens = ubatch->n_seq_tokens;
const int64_t n_seqs = ubatch->n_seqs;
GGML_ASSERT(cls); GGML_ASSERT(cls);
GGML_ASSERT(ggml_backend_buffer_is_host(cls->buffer)); GGML_ASSERT(ggml_backend_buffer_is_host(cls->buffer));
uint32_t * data = (uint32_t *) cls->data; uint32_t * data = (uint32_t *) cls->data;
memset(cls->data, 0, n_tokens * ggml_element_size(cls)); memset(cls->data, 0, n_seqs_unq*ggml_element_size(cls));
std::vector<int> last_pos(n_tokens, -1); std::vector<int> last_pos(n_seqs_unq, -1);
std::vector<int> last_row(n_tokens, -1); std::vector<int> last_row(n_seqs_unq, -1);
// TODO: fix indexing [UBATCH_IDX]
for (int s = 0; s < n_seqs; ++s) {
const llama_seq_id seq_id = ubatch->seq_id[s][0];
// TODO: adapt limits to n_seqs when ubatch->equal_seqs is true
GGML_ASSERT(seq_id < n_tokens && "seq_id cannot be larger than n_tokens with pooling_type == LAST");
for (int i = 0; i < n_seq_tokens; ++i) {
const llama_pos pos = ubatch->pos[s*n_seq_tokens + i];
if (pos >= last_pos[seq_id]) {
last_pos[seq_id] = pos;
last_row[seq_id] = s*n_seq_tokens + i;
}
}
}
for (int i = 0; i < n_tokens; ++i) { for (int i = 0; i < n_tokens; ++i) {
if (last_row[i] >= 0) { const llama_pos pos = ubatch->pos[i];
data[i] = last_row[i];
for (int s = 0; s < ubatch->n_seq_id[i]; ++s) {
const llama_seq_id seq_id = ubatch->seq_id[i][s];
const int32_t seq_idx = ubatch->seq_idx[seq_id];
if (pos >= last_pos[seq_idx]) {
last_pos[seq_idx] = pos;
last_row[seq_idx] = i;
}
}
}
for (int s = 0; s < n_seqs_unq; ++s) {
if (last_row[s] >= 0) {
data[s] = last_row[s];
} }
} }
} }
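
The pooling inputs above (MEAN, CLS/RANK, LAST) are now sized by ubatch->n_seqs_unq and addressed through the dense seq_idx mapping instead of raw token rows. A small self-contained sketch of the same bookkeeping for the mean matrix, with plain arrays and a hypothetical seq_idx map in place of ggml tensors and the real ubatch layout:

#include <cstdio>
#include <unordered_map>
#include <vector>

int main() {
    const std::vector<int> seq_id = {0, 0, 0, 5, 5}; // sequence id of each token in the ubatch (one per token here)

    // map each distinct sequence id to a dense index (analogue of ubatch.seq_idx)
    std::unordered_map<int, int> seq_idx;
    for (int id : seq_id) {
        seq_idx.emplace(id, (int) seq_idx.size());
    }
    const int n_tokens   = (int) seq_id.size();
    const int n_seqs_unq = (int) seq_idx.size();

    // count tokens per sequence, then fill each sequence row with 1/count for that sequence's tokens
    std::vector<int> count(n_seqs_unq, 0);
    for (int i = 0; i < n_tokens; ++i) {
        count[seq_idx[seq_id[i]]]++;
    }

    std::vector<float> mean(n_seqs_unq*n_tokens, 0.0f);
    for (int i = 0; i < n_tokens; ++i) {
        const int s = seq_idx[seq_id[i]];
        mean[s*n_tokens + i] = 1.0f/count[s];
    }

    // multiplying this [n_seqs_unq x n_tokens] matrix by [n_tokens x n_embd] hidden states
    // yields one mean embedding per sequence
    for (int s = 0; s < n_seqs_unq; ++s) {
        for (int i = 0; i < n_tokens; ++i) {
            std::printf("%.2f ", mean[s*n_tokens + i]);
        }
        std::printf("\n");
    }
    return 0;
}
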
@ -242,7 +221,7 @@ void llm_graph_input_cls::set_input(const llama_ubatch * ubatch) {
void llm_graph_input_rs::set_input(const llama_ubatch * ubatch) { void llm_graph_input_rs::set_input(const llama_ubatch * ubatch) {
GGML_UNUSED(ubatch); GGML_UNUSED(ubatch);
const int64_t n_rs = mem_state->get_n_rs(); const int64_t n_rs = mctx->get_n_rs();
if (s_copy) { if (s_copy) {
GGML_ASSERT(ggml_backend_buffer_is_host(s_copy->buffer)); GGML_ASSERT(ggml_backend_buffer_is_host(s_copy->buffer));
@ -250,7 +229,7 @@ void llm_graph_input_rs::set_input(const llama_ubatch * ubatch) {
// assuming copy destinations ALWAYS happen ONLY on the cells between head and head+n // assuming copy destinations ALWAYS happen ONLY on the cells between head and head+n
for (uint32_t i = 0; i < n_rs; ++i) { for (uint32_t i = 0; i < n_rs; ++i) {
data[i] = mem_state->s_copy(i); data[i] = mctx->s_copy(i);
} }
} }
} }
@ -266,33 +245,28 @@ void llm_graph_input_cross_embd::set_input(const llama_ubatch * ubatch) {
} }
void llm_graph_input_attn_no_cache::set_input(const llama_ubatch * ubatch) { void llm_graph_input_attn_no_cache::set_input(const llama_ubatch * ubatch) {
if (kq_mask) {
if (cparams.causal_attn) {
const int64_t n_kv = ubatch->n_tokens; const int64_t n_kv = ubatch->n_tokens;
const int64_t n_tokens = ubatch->n_tokens; const int64_t n_tokens = ubatch->n_tokens;
const int64_t n_seq_tokens = ubatch->n_seq_tokens;
const int64_t n_seqs = ubatch->n_seqs;
GGML_ASSERT(kq_mask);
GGML_ASSERT(ggml_backend_buffer_is_host(kq_mask->buffer)); GGML_ASSERT(ggml_backend_buffer_is_host(kq_mask->buffer));
float * data = (float *) kq_mask->data; float * data = (float *) kq_mask->data;
for (int h = 0; h < 1; ++h) { for (int h = 0; h < 1; ++h) {
for (int s1 = 0; s1 < n_seqs; ++s1) { for (int i1 = 0; i1 < n_tokens; ++i1) {
const llama_seq_id seq_id = ubatch->seq_id[s1][0]; const llama_seq_id s1 = ubatch->seq_id[i1][0];
for (int j = 0; j < n_seq_tokens; ++j) { for (int i0 = 0; i0 < n_tokens; ++i0) {
const int32_t tj = s1*n_seq_tokens + j;
for (int s0 = 0; s0 < n_seqs; ++s0) {
for (int i = 0; i < n_seq_tokens; ++i) {
const int32_t ti = s0*n_seq_tokens + i;
float f = -INFINITY; float f = -INFINITY;
// TODO: fix indexing [UBATCH_IDX] for (int s = 0; s < ubatch->n_seq_id[i0]; ++s) {
for (int s = 0; s < ubatch->n_seq_id[s0]; ++s) { const llama_seq_id s0 = ubatch->seq_id[i0][0];
if (ubatch->seq_id[s0][s] == seq_id && ubatch->pos[ti] <= ubatch->pos[tj]) {
// TODO: reimplement this like in llama_kv_cache_unified
if (s0 == s1 && (!cparams.causal_attn || ubatch->pos[i0] <= ubatch->pos[i1])) {
if (hparams.use_alibi) { if (hparams.use_alibi) {
f = -std::abs(ubatch->pos[ti] - ubatch->pos[tj]); f = -std::abs(ubatch->pos[i0] - ubatch->pos[i1]);
} else { } else {
f = 0.0f; f = 0.0f;
} }
@ -300,55 +274,7 @@ void llm_graph_input_attn_no_cache::set_input(const llama_ubatch * ubatch) {
} }
} }
data[h*(n_kv*n_tokens) + tj*n_kv + ti] = f; data[h*(n_kv*n_tokens) + i1*n_kv + i0] = f;
}
}
}
}
}
} else {
const int64_t n_tokens = ubatch->n_tokens;
const int64_t n_seq_tokens = ubatch->n_seq_tokens;
const int64_t n_seqs = ubatch->n_seqs;
const int64_t n_stride = ubatch->n_tokens;
GGML_ASSERT(ggml_backend_buffer_is_host(kq_mask->buffer));
float * data = (float *) kq_mask->data;
for (int h = 0; h < 1; ++h) {
for (int s1 = 0; s1 < n_seqs; ++s1) {
const llama_seq_id seq_id = ubatch->seq_id[s1][0];
for (int j = 0; j < n_seq_tokens; ++j) {
const int32_t tj = s1*n_seq_tokens + j;
for (int s0 = 0; s0 < n_seqs; ++s0) {
for (int i = 0; i < n_seq_tokens; ++i) {
const int32_t ti = s0*n_seq_tokens + i;
float f = -INFINITY;
// TODO: fix indexing [UBATCH_IDX]
for (int s = 0; s < ubatch->n_seq_id[s0]; ++s) {
if (ubatch->seq_id[s0][s] == seq_id) {
if (hparams.use_alibi) {
f = -std::abs(ubatch->pos[ti] - ubatch->pos[tj]);
} else {
f = 0.0f;
}
break;
}
}
data[h*(n_tokens*n_tokens) + tj*n_stride + ti] = f;
}
}
for (int i = n_tokens; i < n_stride; ++i) {
data[h*(n_tokens*n_tokens) + tj*n_stride + i] = -INFINITY;
}
}
}
} }
} }
} }
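
The rewritten llm_graph_input_attn_no_cache::set_input above folds the causal and non-causal branches into a single per-token-pair loop. A compact sketch of the same masking rule (shared sequence required, and pos[key] <= pos[query] only when causal), simplified to one sequence id per token and plain arrays rather than the ggml mask tensor:

#include <cmath>
#include <cstdio>
#include <limits>
#include <vector>

int main() {
    const std::vector<int> seq_id = {0, 0, 1, 1}; // one sequence id per token (simplified)
    const std::vector<int> pos    = {0, 1, 0, 1}; // position of each token
    const bool causal_attn = true;
    const bool use_alibi   = false;

    const int n_tokens = (int) seq_id.size();
    std::vector<float> mask(n_tokens*n_tokens, -std::numeric_limits<float>::infinity());

    for (int i1 = 0; i1 < n_tokens; ++i1) {     // query token
        for (int i0 = 0; i0 < n_tokens; ++i0) { // key token
            if (seq_id[i0] != seq_id[i1]) {
                continue; // tokens from different sequences never attend to each other
            }
            if (causal_attn && pos[i0] > pos[i1]) {
                continue; // causal attention: keys must not be in the future
            }
            // 0.0f allows attention; with ALiBi the bias is the negative position distance
            mask[i1*n_tokens + i0] = use_alibi ? -std::fabs((float)(pos[i0] - pos[i1])) : 0.0f;
        }
    }

    for (int i1 = 0; i1 < n_tokens; ++i1) {
        for (int i0 = 0; i0 < n_tokens; ++i0) {
            std::printf("%6.1f ", mask[i1*n_tokens + i0]);
        }
        std::printf("\n");
    }
    return 0;
}
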
@ -356,22 +282,23 @@ void llm_graph_input_attn_no_cache::set_input(const llama_ubatch * ubatch) {
void llm_graph_input_attn_kv_unified::set_input(const llama_ubatch * ubatch) { void llm_graph_input_attn_kv_unified::set_input(const llama_ubatch * ubatch) {
if (self_kq_mask) { if (self_kq_mask) {
kv_state->set_input_kq_mask(self_kq_mask, ubatch, cparams.causal_attn); mctx->set_input_kq_mask(self_kq_mask, ubatch, cparams.causal_attn);
} }
} }
void llm_graph_input_attn_kv_unified_iswa::set_input(const llama_ubatch * ubatch) { void llm_graph_input_attn_kv_unified_iswa::set_input(const llama_ubatch * ubatch) {
if (self_kq_mask) { if (self_kq_mask) {
kv_state->get_base()->set_input_kq_mask(self_kq_mask, ubatch, cparams.causal_attn); mctx->get_base()->set_input_kq_mask(self_kq_mask, ubatch, cparams.causal_attn);
} }
if (self_kq_mask_swa) { if (self_kq_mask_swa) {
kv_state->get_swa()->set_input_kq_mask(self_kq_mask_swa, ubatch, cparams.causal_attn); mctx->get_swa()->set_input_kq_mask(self_kq_mask_swa, ubatch, cparams.causal_attn);
} }
} }
void llm_graph_input_attn_cross::set_input(const llama_ubatch * ubatch) { void llm_graph_input_attn_cross::set_input(const llama_ubatch * ubatch) {
if (cross_kq_mask) { GGML_ASSERT(cross_kq_mask);
const int64_t n_enc = cross_kq_mask->ne[0]; const int64_t n_enc = cross_kq_mask->ne[0];
const int64_t n_tokens = ubatch->n_tokens; const int64_t n_tokens = ubatch->n_tokens;
@ -381,17 +308,19 @@ void llm_graph_input_attn_cross::set_input(const llama_ubatch * ubatch) {
float * data = (float *) cross_kq_mask->data; float * data = (float *) cross_kq_mask->data;
for (int h = 0; h < 1; ++h) { for (int h = 0; h < 1; ++h) {
for (int j = 0; j < n_tokens; ++j) { for (int i = 0; i < n_tokens; ++i) {
for (int i = 0; i < n_enc; ++i) { for (int j = 0; j < n_enc; ++j) {
float f = -INFINITY; float f = -INFINITY;
// TODO: fix indexing [UBATCH_IDX]
for (int s = 0; s < ubatch->n_seq_id[j]; ++s) { for (int s = 0; s < ubatch->n_seq_id[i]; ++s) {
const llama_seq_id seq_id = ubatch->seq_id[j][s]; const llama_seq_id seq_id = ubatch->seq_id[i][s];
if (cross->seq_ids_enc[i].find(seq_id) != cross->seq_ids_enc[i].end()) {
if (cross->seq_ids_enc[j].find(seq_id) != cross->seq_ids_enc[j].end()) {
f = 0.0f; f = 0.0f;
} }
} }
data[h*(n_enc*n_tokens) + j*n_enc + i] = f;
data[h*(n_enc*n_tokens) + i*n_enc + j] = f;
} }
} }
@ -401,15 +330,14 @@ void llm_graph_input_attn_cross::set_input(const llama_ubatch * ubatch) {
} }
} }
} }
}
} }
void llm_graph_input_mem_hybrid::set_input(const llama_ubatch * ubatch) { void llm_graph_input_mem_hybrid::set_input(const llama_ubatch * ubatch) {
if (self_kq_mask) { if (self_kq_mask) {
mem_state->get_state_attn()->set_input_kq_mask(self_kq_mask, ubatch, cparams.causal_attn); mctx->get_attn()->set_input_kq_mask(self_kq_mask, ubatch, cparams.causal_attn);
} }
const int64_t n_rs = mem_state->get_state_recr()->get_n_rs(); const int64_t n_rs = mctx->get_recr()->get_n_rs();
if (s_copy) { if (s_copy) {
GGML_ASSERT(ggml_backend_buffer_is_host(s_copy->buffer)); GGML_ASSERT(ggml_backend_buffer_is_host(s_copy->buffer));
@ -417,7 +345,7 @@ void llm_graph_input_mem_hybrid::set_input(const llama_ubatch * ubatch) {
// assuming copy destinations ALWAYS happen ONLY on the cells between head and head+n // assuming copy destinations ALWAYS happen ONLY on the cells between head and head+n
for (uint32_t i = 0; i < n_rs; ++i) { for (uint32_t i = 0; i < n_rs; ++i) {
data[i] = mem_state->get_state_recr()->s_copy(i); data[i] = mctx->get_recr()->s_copy(i);
} }
} }
} }
@ -461,16 +389,12 @@ llm_graph_context::llm_graph_context(const llm_graph_params & params) :
backend_cpu (params.backend_cpu), backend_cpu (params.backend_cpu),
cvec (params.cvec), cvec (params.cvec),
loras (params.loras), loras (params.loras),
mstate (params.mstate), mctx (params.mctx),
cross (params.cross), cross (params.cross),
cb_func (params.cb), cb_func (params.cb),
res (std::make_unique<llm_graph_result>()) { res (std::make_unique<llm_graph_result>()) {
} }
int64_t llm_graph_context::n_pos_per_embd() const {
return hparams.rope_type == LLAMA_ROPE_TYPE_MROPE ? 4 : 1;
}
void llm_graph_context::cb(ggml_tensor * cur, const char * name, int il) const { void llm_graph_context::cb(ggml_tensor * cur, const char * name, int il) const {
if (cb_func) { if (cb_func) {
cb_func(ubatch, cur, name, il); cb_func(ubatch, cur, name, il);
@ -915,11 +839,11 @@ ggml_tensor * llm_graph_context::build_inp_embd(ggml_tensor * tok_embd) const {
} }
ggml_tensor * llm_graph_context::build_inp_pos() const { ggml_tensor * llm_graph_context::build_inp_pos() const {
auto inp = std::make_unique<llm_graph_input_pos>(n_pos_per_embd()); auto inp = std::make_unique<llm_graph_input_pos>(hparams.n_pos_per_embd());
auto & cur = inp->pos; auto & cur = inp->pos;
cur = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_tokens*n_pos_per_embd()); cur = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, (int64_t)n_tokens*hparams.n_pos_per_embd());
ggml_set_input(cur); ggml_set_input(cur);
res->add_input(std::move(inp)); res->add_input(std::move(inp));
@ -942,6 +866,14 @@ ggml_tensor * llm_graph_context::build_inp_attn_scale() const {
} }
ggml_tensor * llm_graph_context::build_inp_out_ids() const { ggml_tensor * llm_graph_context::build_inp_out_ids() const {
// note: when all tokens are output, we could skip this optimization to spare the ggml_get_rows() calls,
// but this would make the graph topology depend on the number of output tokens, which can interfere with
// features that require constant topology such as pipeline parallelism
// ref: https://github.com/ggml-org/llama.cpp/pull/14275#issuecomment-2987424471
//if (n_outputs < n_tokens) {
// return nullptr;
//}
auto inp = std::make_unique<llm_graph_input_out_ids>(hparams, cparams, n_outputs); auto inp = std::make_unique<llm_graph_input_out_ids>(hparams, cparams, n_outputs);
auto & cur = inp->out_ids; auto & cur = inp->out_ids;
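
The comment above notes that build_inp_out_ids() keeps the ggml_get_rows() selection even when every token is an output, so the graph topology does not depend on the output count. A tiny sketch of that row-gather idea with plain vectors; the names and the 5x3 example matrix are made up for illustration:

#include <cstdio>
#include <vector>

int main() {
    // per-token "output?" flags, as in ubatch->output
    const std::vector<int> output = {0, 0, 1, 0, 1};

    // out_ids: indices of the tokens whose rows we keep
    std::vector<int> out_ids;
    for (int i = 0; i < (int) output.size(); ++i) {
        if (output[i]) {
            out_ids.push_back(i);
        }
    }

    // gather only those rows from a [n_tokens x n_embd] activation matrix
    const int n_embd = 3;
    const std::vector<float> x = {
        0, 1, 2,   10, 11, 12,   20, 21, 22,   30, 31, 32,   40, 41, 42,
    };

    std::vector<float> y(out_ids.size()*n_embd);
    for (std::size_t r = 0; r < out_ids.size(); ++r) {
        for (int c = 0; c < n_embd; ++c) {
            y[r*n_embd + c] = x[out_ids[r]*n_embd + c]; // analogue of ggml_get_rows()
        }
    }

    for (std::size_t r = 0; r < y.size(); ++r) {
        std::printf("%g%c", y[r], (r + 1) % n_embd == 0 ? '\n' : ' ');
    }
    return 0;
}
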
@ -959,7 +891,7 @@ ggml_tensor * llm_graph_context::build_inp_mean() const {
auto & cur = inp->mean; auto & cur = inp->mean;
cur = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_tokens, n_tokens); cur = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_tokens, ubatch.n_seqs_unq);
ggml_set_input(cur); ggml_set_input(cur);
res->add_input(std::move(inp)); res->add_input(std::move(inp));
@ -972,7 +904,7 @@ ggml_tensor * llm_graph_context::build_inp_cls() const {
auto & cur = inp->cls; auto & cur = inp->cls;
cur = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_tokens); cur = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, ubatch.n_seqs_unq);
ggml_set_input(cur); ggml_set_input(cur);
res->add_input(std::move(inp)); res->add_input(std::move(inp));
@ -1018,11 +950,11 @@ ggml_tensor * llm_graph_context::build_inp_pos_bucket_enc() const {
} }
ggml_tensor * llm_graph_context::build_inp_pos_bucket_dec() const { ggml_tensor * llm_graph_context::build_inp_pos_bucket_dec() const {
const auto * kv_state = static_cast<const llama_kv_cache_unified_state *>(mstate); const auto * mctx_cur = static_cast<const llama_kv_cache_unified_context *>(mctx);
auto inp = std::make_unique<llm_graph_input_pos_bucket_kv>(hparams, kv_state); auto inp = std::make_unique<llm_graph_input_pos_bucket_kv>(hparams, mctx_cur);
const auto n_kv = kv_state->get_n_kv(); const auto n_kv = mctx_cur->get_n_kv();
auto & cur = inp->pos_bucket; auto & cur = inp->pos_bucket;
@ -1050,14 +982,14 @@ ggml_tensor * llm_graph_context::build_pos_bias(ggml_tensor * pos_bucket, ggml_t
} }
llm_graph_input_mem_hybrid * llm_graph_context::build_inp_mem_hybrid() const { llm_graph_input_mem_hybrid * llm_graph_context::build_inp_mem_hybrid() const {
const auto * mem_state = static_cast<const llama_memory_hybrid_state *>(mstate); const auto * mctx_cur = static_cast<const llama_memory_hybrid_context *>(mctx);
auto inp = std::make_unique<llm_graph_input_mem_hybrid>(hparams, cparams, mem_state); auto inp = std::make_unique<llm_graph_input_mem_hybrid>(hparams, cparams, mctx_cur);
{ {
GGML_ASSERT(hparams.swa_type == LLAMA_SWA_TYPE_NONE && "Hybrid recurrent is not supported with SWA attention layers"); GGML_ASSERT(hparams.swa_type == LLAMA_SWA_TYPE_NONE && "Hybrid recurrent is not supported with SWA attention layers");
const auto n_kv = inp->mem_state->get_state_attn()->get_n_kv(); const auto n_kv = inp->mctx->get_attn()->get_n_kv();
inp->self_kq_mask = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD)); inp->self_kq_mask = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
//cb(inp->self_kq_mask, "KQ_mask", -1); //cb(inp->self_kq_mask, "KQ_mask", -1);
@ -1067,7 +999,7 @@ llm_graph_input_mem_hybrid * llm_graph_context::build_inp_mem_hybrid() const {
} }
{ {
const auto n_rs = mem_state->get_state_recr()->get_n_rs(); const auto n_rs = mctx_cur->get_recr()->get_n_rs();
inp->s_copy = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_rs); inp->s_copy = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_rs);
ggml_set_input(inp->s_copy); ggml_set_input(inp->s_copy);
@ -1251,14 +1183,14 @@ ggml_tensor * llm_graph_context::build_attn(
} }
llm_graph_input_attn_kv_unified * llm_graph_context::build_attn_inp_kv_unified() const { llm_graph_input_attn_kv_unified * llm_graph_context::build_attn_inp_kv_unified() const {
const auto * kv_state = static_cast<const llama_kv_cache_unified_state *>(mstate); const auto * mctx_cur = static_cast<const llama_kv_cache_unified_context *>(mctx);
auto inp = std::make_unique<llm_graph_input_attn_kv_unified>(hparams, cparams, kv_state); auto inp = std::make_unique<llm_graph_input_attn_kv_unified>(hparams, cparams, mctx_cur);
{ {
GGML_ASSERT(hparams.swa_type == LLAMA_SWA_TYPE_NONE && "Use llama_kv_cache_unified_iswa for SWA"); GGML_ASSERT(hparams.swa_type == LLAMA_SWA_TYPE_NONE && "Use llama_kv_cache_unified_iswa for SWA");
const auto n_kv = kv_state->get_n_kv(); const auto n_kv = mctx_cur->get_n_kv();
inp->self_kq_mask = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD)); inp->self_kq_mask = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
//cb(inp->self_kq_mask, "KQ_mask", -1); //cb(inp->self_kq_mask, "KQ_mask", -1);
@ -1288,19 +1220,19 @@ ggml_tensor * llm_graph_context::build_attn(
ggml_build_forward_expand(gf, k_cur); ggml_build_forward_expand(gf, k_cur);
ggml_build_forward_expand(gf, v_cur); ggml_build_forward_expand(gf, v_cur);
const auto * kv_state = static_cast<const llama_kv_cache_unified_state *>(mstate); const auto * mctx_cur = static_cast<const llama_kv_cache_unified_context *>(mctx);
// store to KV cache // store to KV cache
{ {
ggml_build_forward_expand(gf, kv_state->cpy_k(ctx0, k_cur, il)); ggml_build_forward_expand(gf, mctx_cur->cpy_k(ctx0, k_cur, il));
ggml_build_forward_expand(gf, kv_state->cpy_v(ctx0, v_cur, il)); ggml_build_forward_expand(gf, mctx_cur->cpy_v(ctx0, v_cur, il));
} }
const auto & kq_mask = inp->get_kq_mask(); const auto & kq_mask = inp->get_kq_mask();
ggml_tensor * q = q_cur; ggml_tensor * q = q_cur;
ggml_tensor * k = kv_state->get_k(ctx0, il); ggml_tensor * k = mctx_cur->get_k(ctx0, il);
ggml_tensor * v = kv_state->get_v(ctx0, il); ggml_tensor * v = mctx_cur->get_v(ctx0, il);
ggml_tensor * cur = build_attn_mha(gf, q, k, v, kq_b, kq_mask, v_mla, kq_scale); ggml_tensor * cur = build_attn_mha(gf, q, k, v, kq_b, kq_mask, v_mla, kq_scale);
cb(cur, "kqv_out", il); cb(cur, "kqv_out", il);
@ -1338,23 +1270,23 @@ ggml_tensor * llm_graph_context::build_attn(
ggml_build_forward_expand(gf, k_cur); ggml_build_forward_expand(gf, k_cur);
ggml_build_forward_expand(gf, v_cur); ggml_build_forward_expand(gf, v_cur);
const auto * kv_state_iswa = static_cast<const llama_kv_cache_unified_iswa_state *>(mstate); const auto * mctx_iswa = static_cast<const llama_kv_cache_unified_iswa_context *>(mctx);
const bool is_swa = hparams.is_swa(il); const bool is_swa = hparams.is_swa(il);
const auto * kv_state = is_swa ? kv_state_iswa->get_swa() : kv_state_iswa->get_base(); const auto * mctx_cur = is_swa ? mctx_iswa->get_swa() : mctx_iswa->get_base();
// store to KV cache // store to KV cache
{ {
ggml_build_forward_expand(gf, kv_state->cpy_k(ctx0, k_cur, il)); ggml_build_forward_expand(gf, mctx_cur->cpy_k(ctx0, k_cur, il));
ggml_build_forward_expand(gf, kv_state->cpy_v(ctx0, v_cur, il)); ggml_build_forward_expand(gf, mctx_cur->cpy_v(ctx0, v_cur, il));
} }
const auto & kq_mask = is_swa ? inp->get_kq_mask_swa() : inp->get_kq_mask(); const auto & kq_mask = is_swa ? inp->get_kq_mask_swa() : inp->get_kq_mask();
ggml_tensor * q = q_cur; ggml_tensor * q = q_cur;
ggml_tensor * k = kv_state->get_k(ctx0, il); ggml_tensor * k = mctx_cur->get_k(ctx0, il);
ggml_tensor * v = kv_state->get_v(ctx0, il); ggml_tensor * v = mctx_cur->get_v(ctx0, il);
ggml_tensor * cur = build_attn_mha(gf, q, k, v, kq_b, kq_mask, v_mla, kq_scale); ggml_tensor * cur = build_attn_mha(gf, q, k, v, kq_b, kq_mask, v_mla, kq_scale);
cb(cur, "kqv_out", il); cb(cur, "kqv_out", il);
@ -1447,19 +1379,19 @@ ggml_tensor * llm_graph_context::build_attn(
ggml_build_forward_expand(gf, k_cur); ggml_build_forward_expand(gf, k_cur);
ggml_build_forward_expand(gf, v_cur); ggml_build_forward_expand(gf, v_cur);
const auto * kv_state = static_cast<const llama_memory_hybrid_state *>(mstate)->get_state_attn(); const auto * mctx_cur = static_cast<const llama_memory_hybrid_context *>(mctx)->get_attn();
// store to KV cache // store to KV cache
{ {
ggml_build_forward_expand(gf, kv_state->cpy_k(ctx0, k_cur, il)); ggml_build_forward_expand(gf, mctx_cur->cpy_k(ctx0, k_cur, il));
ggml_build_forward_expand(gf, kv_state->cpy_v(ctx0, v_cur, il)); ggml_build_forward_expand(gf, mctx_cur->cpy_v(ctx0, v_cur, il));
} }
const auto & kq_mask = inp->get_kq_mask(); const auto & kq_mask = inp->get_kq_mask();
ggml_tensor * q = q_cur; ggml_tensor * q = q_cur;
ggml_tensor * k = kv_state->get_k(ctx0, il); ggml_tensor * k = mctx_cur->get_k(ctx0, il);
ggml_tensor * v = kv_state->get_v(ctx0, il); ggml_tensor * v = mctx_cur->get_v(ctx0, il);
ggml_tensor * cur = build_attn_mha(gf, q, k, v, kq_b, kq_mask, v_mla, kq_scale); ggml_tensor * cur = build_attn_mha(gf, q, k, v, kq_b, kq_mask, v_mla, kq_scale);
cb(cur, "kqv_out", il); cb(cur, "kqv_out", il);
@ -1480,12 +1412,12 @@ ggml_tensor * llm_graph_context::build_attn(
} }
llm_graph_input_attn_kv_unified_iswa * llm_graph_context::build_attn_inp_kv_unified_iswa() const { llm_graph_input_attn_kv_unified_iswa * llm_graph_context::build_attn_inp_kv_unified_iswa() const {
const auto * kv_state = static_cast<const llama_kv_cache_unified_iswa_state *>(mstate); const auto * mctx_cur = static_cast<const llama_kv_cache_unified_iswa_context *>(mctx);
auto inp = std::make_unique<llm_graph_input_attn_kv_unified_iswa>(hparams, cparams, kv_state); auto inp = std::make_unique<llm_graph_input_attn_kv_unified_iswa>(hparams, cparams, mctx_cur);
{ {
const auto n_kv = kv_state->get_base()->get_n_kv(); const auto n_kv = mctx_cur->get_base()->get_n_kv();
inp->self_kq_mask = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD)); inp->self_kq_mask = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
//cb(inp->self_kq_mask, "KQ_mask", -1); //cb(inp->self_kq_mask, "KQ_mask", -1);
@ -1497,7 +1429,7 @@ llm_graph_input_attn_kv_unified_iswa * llm_graph_context::build_attn_inp_kv_unif
{ {
GGML_ASSERT(hparams.swa_type != LLAMA_SWA_TYPE_NONE && "Use llama_kv_cache_unified for non-SWA"); GGML_ASSERT(hparams.swa_type != LLAMA_SWA_TYPE_NONE && "Use llama_kv_cache_unified for non-SWA");
const auto n_kv = kv_state->get_swa()->get_n_kv(); const auto n_kv = mctx_cur->get_swa()->get_n_kv();
inp->self_kq_mask_swa = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD)); inp->self_kq_mask_swa = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
//cb(inp->self_kq_mask_swa, "KQ_mask_swa", -1); //cb(inp->self_kq_mask_swa, "KQ_mask_swa", -1);
@ -1545,11 +1477,11 @@ ggml_tensor * llm_graph_context::build_rs(
} }
llm_graph_input_rs * llm_graph_context::build_rs_inp() const { llm_graph_input_rs * llm_graph_context::build_rs_inp() const {
const auto * kv_state = static_cast<const llama_memory_recurrent_state *>(mstate); const auto * mctx_cur = static_cast<const llama_memory_recurrent_context *>(mctx);
auto inp = std::make_unique<llm_graph_input_rs>(kv_state); auto inp = std::make_unique<llm_graph_input_rs>(mctx_cur);
const auto n_rs = kv_state->get_n_rs(); const auto n_rs = mctx_cur->get_n_rs();
inp->s_copy = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_rs); inp->s_copy = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_rs);
ggml_set_input(inp->s_copy); ggml_set_input(inp->s_copy);
@ -1564,7 +1496,7 @@ ggml_tensor * llm_graph_context::build_rs(
int32_t state_size, int32_t state_size,
int32_t n_seqs, int32_t n_seqs,
const llm_graph_get_rows_fn & get_state_rows) const { const llm_graph_get_rows_fn & get_state_rows) const {
const auto * kv_state = static_cast<const llama_memory_recurrent_state *>(mstate); const auto * kv_state = static_cast<const llama_memory_recurrent_context *>(mctx);
return build_rs(gf, s, inp->s_copy, state_size, n_seqs, kv_state->get_n_rs(), kv_state->get_head(), kv_state->get_size(), kv_state->get_rs_z(), get_state_rows); return build_rs(gf, s, inp->s_copy, state_size, n_seqs, kv_state->get_n_rs(), kv_state->get_head(), kv_state->get_size(), kv_state->get_rs_z(), get_state_rows);
} }
@ -1576,7 +1508,7 @@ ggml_tensor * llm_graph_context::build_rs(
int32_t state_size, int32_t state_size,
int32_t n_seqs, int32_t n_seqs,
const llm_graph_get_rows_fn & get_state_rows) const { const llm_graph_get_rows_fn & get_state_rows) const {
const auto * kv_state = static_cast<const llama_memory_hybrid_state *>(mstate)->get_state_recr(); const auto * kv_state = static_cast<const llama_memory_hybrid_context *>(mctx)->get_recr();
return build_rs(gf, s, inp->s_copy, state_size, n_seqs, kv_state->get_n_rs(), kv_state->get_head(), kv_state->get_size(), kv_state->get_rs_z(), get_state_rows); return build_rs(gf, s, inp->s_copy, state_size, n_seqs, kv_state->get_n_rs(), kv_state->get_head(), kv_state->get_size(), kv_state->get_rs_z(), get_state_rows);
} }
@ -1586,13 +1518,13 @@ ggml_tensor * llm_graph_context::build_rwkv_token_shift_load(
ggml_cgraph * gf, ggml_cgraph * gf,
const llama_ubatch & ubatch, const llama_ubatch & ubatch,
int il) const { int il) const {
const auto * kv_state = static_cast<const llama_memory_recurrent_state *>(mstate); const auto * mctx_cur = static_cast<const llama_memory_recurrent_context *>(mctx);
const auto token_shift_count = hparams.token_shift_count; const auto token_shift_count = hparams.token_shift_count;
const int64_t n_seqs = ubatch.n_seqs; const int64_t n_seqs = ubatch.n_seqs;
ggml_tensor * token_shift_all = kv_state->get_r_l(il); ggml_tensor * token_shift_all = mctx_cur->get_r_l(il);
ggml_tensor * token_shift = build_rs( ggml_tensor * token_shift = build_rs(
inp, gf, token_shift_all, inp, gf, token_shift_all,
@ -1607,19 +1539,19 @@ ggml_tensor * llm_graph_context::build_rwkv_token_shift_store(
ggml_tensor * token_shift, ggml_tensor * token_shift,
const llama_ubatch & ubatch, const llama_ubatch & ubatch,
int il) const { int il) const {
const auto * kv_state = static_cast<const llama_memory_recurrent_state *>(mstate); const auto * mctx_cur = static_cast<const llama_memory_recurrent_context *>(mctx);
const auto token_shift_count = hparams.token_shift_count; const auto token_shift_count = hparams.token_shift_count;
const auto n_embd = hparams.n_embd; const auto n_embd = hparams.n_embd;
const int64_t n_seqs = ubatch.n_seqs; const int64_t n_seqs = ubatch.n_seqs;
const auto kv_head = kv_state->get_head(); const auto kv_head = mctx_cur->get_head();
return ggml_cpy( return ggml_cpy(
ctx0, ctx0,
ggml_view_1d(ctx0, token_shift, n_embd * n_seqs * token_shift_count, 0), ggml_view_1d(ctx0, token_shift, n_embd * n_seqs * token_shift_count, 0),
ggml_view_1d(ctx0, kv_state->get_r_l(il), hparams.n_embd_r()*n_seqs, hparams.n_embd_r()*kv_head*ggml_element_size(kv_state->get_r_l(il))) ggml_view_1d(ctx0, mctx_cur->get_r_l(il), hparams.n_embd_r()*n_seqs, hparams.n_embd_r()*kv_head*ggml_element_size(mctx_cur->get_r_l(il)))
); );
} }

View File

@ -17,12 +17,12 @@ struct ggml_tensor;
struct llama_ubatch; struct llama_ubatch;
struct llama_cparams; struct llama_cparams;
struct llama_memory_state_i; struct llama_memory_context_i;
class llama_kv_cache_unified_state; class llama_kv_cache_unified_context;
class llama_kv_cache_unified_iswa_state; class llama_kv_cache_unified_iswa_context;
class llama_memory_recurrent_state; class llama_memory_recurrent_context;
class llama_memory_hybrid_state; class llama_memory_hybrid_context;
// certain models (typically multi-modal) can produce different types of graphs // certain models (typically multi-modal) can produce different types of graphs
enum llm_graph_type { enum llm_graph_type {
@ -95,14 +95,14 @@ public:
class llm_graph_input_pos : public llm_graph_input_i { class llm_graph_input_pos : public llm_graph_input_i {
public: public:
llm_graph_input_pos(int64_t n_pos_per_embd) : n_pos_per_embd(n_pos_per_embd) {} llm_graph_input_pos(uint32_t n_pos_per_embd) : n_pos_per_embd(n_pos_per_embd) {}
virtual ~llm_graph_input_pos() = default; virtual ~llm_graph_input_pos() = default;
void set_input(const llama_ubatch * ubatch) override; void set_input(const llama_ubatch * ubatch) override;
ggml_tensor * pos = nullptr; // I32 [n_batch] ggml_tensor * pos = nullptr; // I32 [n_batch]
const int64_t n_pos_per_embd = 1; const uint32_t n_pos_per_embd = 1;
}; };
// temperature tuning, used by llama4 // temperature tuning, used by llama4
@ -136,7 +136,7 @@ class llm_graph_input_pos_bucket_kv : public llm_graph_input_i {
public: public:
llm_graph_input_pos_bucket_kv( llm_graph_input_pos_bucket_kv(
const llama_hparams & hparams, const llama_hparams & hparams,
const llama_kv_cache_unified_state * kv_state) : hparams(hparams), kv_state(kv_state) {} const llama_kv_cache_unified_context * mctx) : hparams(hparams), mctx(mctx) {}
virtual ~llm_graph_input_pos_bucket_kv() = default; virtual ~llm_graph_input_pos_bucket_kv() = default;
void set_input(const llama_ubatch * ubatch) override; void set_input(const llama_ubatch * ubatch) override;
@ -144,7 +144,8 @@ public:
ggml_tensor * pos_bucket = nullptr; // I32 [n_kv, n_batch] ggml_tensor * pos_bucket = nullptr; // I32 [n_kv, n_batch]
const llama_hparams & hparams; const llama_hparams & hparams;
const llama_kv_cache_unified_state * kv_state;
const llama_kv_cache_unified_context * mctx;
}; };
class llm_graph_input_out_ids : public llm_graph_input_i { class llm_graph_input_out_ids : public llm_graph_input_i {
@ -191,14 +192,14 @@ public:
class llm_graph_input_rs : public llm_graph_input_i { class llm_graph_input_rs : public llm_graph_input_i {
public: public:
llm_graph_input_rs(const llama_memory_recurrent_state * mem_state) : mem_state(mem_state) {} llm_graph_input_rs(const llama_memory_recurrent_context * mctx) : mctx(mctx) {}
virtual ~llm_graph_input_rs() = default; virtual ~llm_graph_input_rs() = default;
void set_input(const llama_ubatch * ubatch) override; void set_input(const llama_ubatch * ubatch) override;
ggml_tensor * s_copy; // I32 [kv_size] ggml_tensor * s_copy; // I32 [kv_size]
const llama_memory_recurrent_state * mem_state; const llama_memory_recurrent_context * mctx;
}; };
class llm_graph_input_cross_embd : public llm_graph_input_i { class llm_graph_input_cross_embd : public llm_graph_input_i {
@ -238,10 +239,10 @@ public:
llm_graph_input_attn_kv_unified( llm_graph_input_attn_kv_unified(
const llama_hparams & hparams, const llama_hparams & hparams,
const llama_cparams & cparams, const llama_cparams & cparams,
const llama_kv_cache_unified_state * kv_state) : const llama_kv_cache_unified_context * mctx) :
hparams(hparams), hparams(hparams),
cparams(cparams), cparams(cparams),
kv_state(kv_state) { mctx(mctx) {
} }
~llm_graph_input_attn_kv_unified() = default; ~llm_graph_input_attn_kv_unified() = default;
@ -255,7 +256,7 @@ public:
const llama_hparams & hparams; const llama_hparams & hparams;
const llama_cparams & cparams; const llama_cparams & cparams;
const llama_kv_cache_unified_state * kv_state; const llama_kv_cache_unified_context * mctx;
}; };
class llm_graph_input_attn_kv_unified_iswa : public llm_graph_input_i { class llm_graph_input_attn_kv_unified_iswa : public llm_graph_input_i {
@ -263,10 +264,10 @@ public:
llm_graph_input_attn_kv_unified_iswa( llm_graph_input_attn_kv_unified_iswa(
const llama_hparams & hparams, const llama_hparams & hparams,
const llama_cparams & cparams, const llama_cparams & cparams,
const llama_kv_cache_unified_iswa_state * kv_state) : const llama_kv_cache_unified_iswa_context * mctx) :
hparams(hparams), hparams(hparams),
cparams(cparams), cparams(cparams),
kv_state(kv_state) { mctx(mctx) {
} }
~llm_graph_input_attn_kv_unified_iswa() = default; ~llm_graph_input_attn_kv_unified_iswa() = default;
@ -283,7 +284,7 @@ public:
const llama_hparams & hparams; const llama_hparams & hparams;
const llama_cparams & cparams; const llama_cparams & cparams;
const llama_kv_cache_unified_iswa_state * kv_state; const llama_kv_cache_unified_iswa_context * mctx;
}; };
class llm_graph_input_attn_cross : public llm_graph_input_i { class llm_graph_input_attn_cross : public llm_graph_input_i {
@ -306,10 +307,10 @@ public:
llm_graph_input_mem_hybrid( llm_graph_input_mem_hybrid(
const llama_hparams & hparams, const llama_hparams & hparams,
const llama_cparams & cparams, const llama_cparams & cparams,
const llama_memory_hybrid_state * mem_state) : const llama_memory_hybrid_context * mctx) :
hparams(hparams), hparams(hparams),
cparams(cparams), cparams(cparams),
mem_state(mem_state) { mctx(mctx) {
} }
virtual ~llm_graph_input_mem_hybrid() = default; virtual ~llm_graph_input_mem_hybrid() = default;
@ -325,7 +326,7 @@ public:
const llama_hparams & hparams; const llama_hparams & hparams;
const llama_cparams & cparams; const llama_cparams & cparams;
const llama_memory_hybrid_state * mem_state; const llama_memory_hybrid_context * mctx;
}; };
// //
@ -403,7 +404,7 @@ struct llm_graph_params {
const llama_adapter_cvec * cvec; const llama_adapter_cvec * cvec;
const llama_adapter_loras * loras; const llama_adapter_loras * loras;
const llama_memory_state_i * mstate; const llama_memory_context_i * mctx;
const llama_cross * cross; const llama_cross * cross;
uint32_t n_outputs; uint32_t n_outputs;
@ -458,7 +459,7 @@ struct llm_graph_context {
const llama_adapter_cvec * cvec; const llama_adapter_cvec * cvec;
const llama_adapter_loras * loras; const llama_adapter_loras * loras;
const llama_memory_state_i * mstate; const llama_memory_context_i * mctx;
const llama_cross * cross; const llama_cross * cross;
const llm_graph_cb & cb_func; const llm_graph_cb & cb_func;
@ -467,8 +468,6 @@ struct llm_graph_context {
llm_graph_context(const llm_graph_params & params); llm_graph_context(const llm_graph_params & params);
int64_t n_pos_per_embd() const;
void cb(ggml_tensor * cur, const char * name, int il) const; void cb(ggml_tensor * cur, const char * name, int il) const;
// //

View File

@ -91,6 +91,10 @@ bool llama_hparams::is_recurrent(uint32_t il) const {
return recurrent_layer_arr[il]; return recurrent_layer_arr[il];
} }
uint32_t llama_hparams::n_pos_per_embd() const {
return rope_type == LLAMA_ROPE_TYPE_MROPE ? 4 : 1;
}
bool llama_hparams::is_swa(uint32_t il) const { bool llama_hparams::is_swa(uint32_t il) const {
if (il < n_layer) { if (il < n_layer) {
return swa_layers[il]; return swa_layers[il];
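
n_pos_per_embd() above returns 4 for M-RoPE and 1 otherwise, and build_inp_pos() in llama-graph.cpp sizes its I32 position tensor as n_tokens*n_pos_per_embd. A minimal sketch of just that sizing relationship (the internal layout of the four M-RoPE position components is not shown in this diff and is not assumed here):

#include <cstdint>
#include <cstdio>
#include <vector>

// analogue of llama_hparams::n_pos_per_embd(): M-RoPE carries 4 position components per token
static uint32_t n_pos_per_embd(bool is_mrope) {
    return is_mrope ? 4 : 1;
}

int main() {
    const int64_t n_tokens = 8;

    for (bool is_mrope : {false, true}) {
        // build_inp_pos() allocates an I32 tensor with this many elements
        std::vector<int32_t> pos(n_tokens*n_pos_per_embd(is_mrope));
        std::printf("mrope=%d -> pos buffer holds %zu int32 values\n", (int) is_mrope, pos.size());
    }
    return 0;
}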

View File

@ -193,6 +193,8 @@ struct llama_hparams {
// whether or not the given layer is recurrent (for hybrid models) // whether or not the given layer is recurrent (for hybrid models)
bool is_recurrent(uint32_t il) const; bool is_recurrent(uint32_t il) const;
uint32_t n_pos_per_embd() const;
bool is_swa(uint32_t il) const; bool is_swa(uint32_t il) const;
}; };

View File

@ -95,19 +95,22 @@ llama_pos llama_kv_cache_unified_iswa::seq_pos_max(llama_seq_id seq_id) const {
return kv_swa->seq_pos_max(seq_id); return kv_swa->seq_pos_max(seq_id);
} }
llama_memory_state_ptr llama_kv_cache_unified_iswa::init_batch(const llama_batch & batch, uint32_t n_ubatch, bool embd_all) { llama_memory_context_ptr llama_kv_cache_unified_iswa::init_batch(llama_batch_allocr & balloc, uint32_t n_ubatch, bool embd_all) {
GGML_UNUSED(embd_all); GGML_UNUSED(embd_all);
// first try simple split // first try simple split
do { do {
auto sbatch = llama_sbatch(batch, hparams.n_embd, true); balloc.split_reset();
std::vector<llama_ubatch> ubatches; std::vector<llama_ubatch> ubatches;
while (true) {
auto ubatch = balloc.split_simple(n_ubatch);
while (sbatch.n_tokens > 0) { if (ubatch.n_tokens == 0) {
auto ubatch = sbatch.split_simple(n_ubatch); break;
}
ubatches.push_back(ubatch); ubatches.push_back(std::move(ubatch)); // NOLINT
} }
auto heads_base = kv_base->prepare(ubatches); auto heads_base = kv_base->prepare(ubatches);
@ -122,20 +125,23 @@ llama_memory_state_ptr llama_kv_cache_unified_iswa::init_batch(const llama_batch
assert(heads_base.size() == heads_swa.size()); assert(heads_base.size() == heads_swa.size());
return std::make_unique<llama_kv_cache_unified_iswa_state>( return std::make_unique<llama_kv_cache_unified_iswa_context>(
this, std::move(sbatch), std::move(heads_base), std::move(heads_swa), std::move(ubatches)); this, std::move(heads_base), std::move(heads_swa), std::move(ubatches));
} while (false); } while (false);
// if it fails, try equal split // if it fails, try equal split
do { do {
auto sbatch = llama_sbatch(batch, hparams.n_embd, false); balloc.split_reset();
std::vector<llama_ubatch> ubatches; std::vector<llama_ubatch> ubatches;
while (true) {
auto ubatch = balloc.split_equal(n_ubatch);
while (sbatch.n_tokens > 0) { if (ubatch.n_tokens == 0) {
auto ubatch = sbatch.split_equal(n_ubatch); break;
}
ubatches.push_back(ubatch); ubatches.push_back(std::move(ubatch)); // NOLINT
} }
auto heads_base = kv_base->prepare(ubatches); auto heads_base = kv_base->prepare(ubatches);
@ -150,22 +156,22 @@ llama_memory_state_ptr llama_kv_cache_unified_iswa::init_batch(const llama_batch
assert(heads_base.size() == heads_swa.size()); assert(heads_base.size() == heads_swa.size());
return std::make_unique<llama_kv_cache_unified_iswa_state>( return std::make_unique<llama_kv_cache_unified_iswa_context>(
this, std::move(sbatch), std::move(heads_base), std::move(heads_swa), std::move(ubatches)); this, std::move(heads_base), std::move(heads_swa), std::move(ubatches));
} while (false); } while (false);
// TODO: if we fail again, we should attempt different splitting strategies // TODO: if we fail again, we should attempt different splitting strategies
// but to do that properly, we first have to refactor the batches to be more flexible // but to do that properly, we first have to refactor the batches to be more flexible
return std::make_unique<llama_kv_cache_unified_iswa_state>(LLAMA_MEMORY_STATUS_FAILED_PREPARE); return std::make_unique<llama_kv_cache_unified_iswa_context>(LLAMA_MEMORY_STATUS_FAILED_PREPARE);
} }
llama_memory_state_ptr llama_kv_cache_unified_iswa::init_full() { llama_memory_context_ptr llama_kv_cache_unified_iswa::init_full() {
return std::make_unique<llama_kv_cache_unified_iswa_state>(this); return std::make_unique<llama_kv_cache_unified_iswa_context>(this);
} }
llama_memory_state_ptr llama_kv_cache_unified_iswa::init_update(llama_context * lctx, bool optimize) { llama_memory_context_ptr llama_kv_cache_unified_iswa::init_update(llama_context * lctx, bool optimize) {
return std::make_unique<llama_kv_cache_unified_iswa_state>(this, lctx, optimize); return std::make_unique<llama_kv_cache_unified_iswa_context>(this, lctx, optimize);
} }
bool llama_kv_cache_unified_iswa::get_can_shift() const { bool llama_kv_cache_unified_iswa::get_can_shift() const {
@ -191,48 +197,46 @@ llama_kv_cache_unified * llama_kv_cache_unified_iswa::get_swa() const {
} }
// //
// llama_kv_cache_unified_iswa_state // llama_kv_cache_unified_iswa_context
// //
llama_kv_cache_unified_iswa_state::llama_kv_cache_unified_iswa_state(llama_memory_status status) : status(status) {} llama_kv_cache_unified_iswa_context::llama_kv_cache_unified_iswa_context(llama_memory_status status) : status(status) {}
llama_kv_cache_unified_iswa_state::llama_kv_cache_unified_iswa_state( llama_kv_cache_unified_iswa_context::llama_kv_cache_unified_iswa_context(
llama_kv_cache_unified_iswa * kv) : llama_kv_cache_unified_iswa * kv) :
state_base(kv->get_base()->init_full()), ctx_base(kv->get_base()->init_full()),
state_swa (kv->get_swa ()->init_full()), ctx_swa (kv->get_swa ()->init_full()),
status(llama_memory_status_combine(state_base->get_status(), state_swa->get_status())) { status(llama_memory_status_combine(ctx_base->get_status(), ctx_swa->get_status())) {
} }
llama_kv_cache_unified_iswa_state::llama_kv_cache_unified_iswa_state( llama_kv_cache_unified_iswa_context::llama_kv_cache_unified_iswa_context(
llama_kv_cache_unified_iswa * kv, llama_kv_cache_unified_iswa * kv,
llama_context * lctx, llama_context * lctx,
bool optimize) : bool optimize) :
state_base(kv->get_base()->init_update(lctx, optimize)), ctx_base(kv->get_base()->init_update(lctx, optimize)),
state_swa (kv->get_swa ()->init_update(lctx, optimize)), ctx_swa (kv->get_swa ()->init_update(lctx, optimize)),
status(llama_memory_status_combine(state_base->get_status(), state_swa->get_status())) { status(llama_memory_status_combine(ctx_base->get_status(), ctx_swa->get_status())) {
} }
llama_kv_cache_unified_iswa_state::llama_kv_cache_unified_iswa_state( llama_kv_cache_unified_iswa_context::llama_kv_cache_unified_iswa_context(
llama_kv_cache_unified_iswa * kv, llama_kv_cache_unified_iswa * kv,
llama_sbatch sbatch,
std::vector<uint32_t> heads_base, std::vector<uint32_t> heads_base,
std::vector<uint32_t> heads_swa, std::vector<uint32_t> heads_swa,
std::vector<llama_ubatch> ubatches) : std::vector<llama_ubatch> ubatches) :
sbatch(std::move(sbatch)),
ubatches(std::move(ubatches)), ubatches(std::move(ubatches)),
// note: here we copy the ubatches. not sure if this is ideal // note: here we copy the ubatches. not sure if this is ideal
state_base(new llama_kv_cache_unified_state(kv->get_base(), {}, std::move(heads_base), this->ubatches)), ctx_base(new llama_kv_cache_unified_context(kv->get_base(), std::move(heads_base), this->ubatches)),
state_swa (new llama_kv_cache_unified_state(kv->get_swa (), {}, std::move(heads_swa), this->ubatches)), ctx_swa (new llama_kv_cache_unified_context(kv->get_swa (), std::move(heads_swa), this->ubatches)),
status(llama_memory_status_combine(state_base->get_status(), state_swa->get_status())) { status(llama_memory_status_combine(ctx_base->get_status(), ctx_swa->get_status())) {
} }
llama_kv_cache_unified_iswa_state:: ~llama_kv_cache_unified_iswa_state() = default; llama_kv_cache_unified_iswa_context:: ~llama_kv_cache_unified_iswa_context() = default;
bool llama_kv_cache_unified_iswa_state::next() { bool llama_kv_cache_unified_iswa_context::next() {
assert(status == LLAMA_MEMORY_STATUS_SUCCESS); assert(status == LLAMA_MEMORY_STATUS_SUCCESS);
state_base->next(); ctx_base->next();
state_swa ->next(); ctx_swa ->next();
if (++i_next >= ubatches.size()) { if (++i_next >= ubatches.size()) {
return false; return false;
@ -241,41 +245,35 @@ bool llama_kv_cache_unified_iswa_state::next() {
return true; return true;
} }
bool llama_kv_cache_unified_iswa_state::apply() { bool llama_kv_cache_unified_iswa_context::apply() {
assert(status == LLAMA_MEMORY_STATUS_SUCCESS); assert(status == LLAMA_MEMORY_STATUS_SUCCESS);
bool res = true; bool res = true;
res = res & state_base->apply(); res = res & ctx_base->apply();
res = res & state_swa ->apply(); res = res & ctx_swa ->apply();
return res; return res;
} }
std::vector<int64_t> & llama_kv_cache_unified_iswa_state::out_ids() { llama_memory_status llama_kv_cache_unified_iswa_context::get_status() const {
assert(status == LLAMA_MEMORY_STATUS_SUCCESS);
return sbatch.out_ids;
}
llama_memory_status llama_kv_cache_unified_iswa_state::get_status() const {
return status; return status;
} }
const llama_ubatch & llama_kv_cache_unified_iswa_state::get_ubatch() const { const llama_ubatch & llama_kv_cache_unified_iswa_context::get_ubatch() const {
assert(status == LLAMA_MEMORY_STATUS_SUCCESS); assert(status == LLAMA_MEMORY_STATUS_SUCCESS);
return ubatches[i_next]; return ubatches[i_next];
} }
const llama_kv_cache_unified_state * llama_kv_cache_unified_iswa_state::get_base() const { const llama_kv_cache_unified_context * llama_kv_cache_unified_iswa_context::get_base() const {
assert(status == LLAMA_MEMORY_STATUS_SUCCESS); assert(status == LLAMA_MEMORY_STATUS_SUCCESS);
return static_cast<const llama_kv_cache_unified_state *>(state_base.get()); return static_cast<const llama_kv_cache_unified_context *>(ctx_base.get());
} }
const llama_kv_cache_unified_state * llama_kv_cache_unified_iswa_state::get_swa() const { const llama_kv_cache_unified_context * llama_kv_cache_unified_iswa_context::get_swa() const {
assert(status == LLAMA_MEMORY_STATUS_SUCCESS); assert(status == LLAMA_MEMORY_STATUS_SUCCESS);
return static_cast<const llama_kv_cache_unified_state *>(state_swa.get()); return static_cast<const llama_kv_cache_unified_context *>(ctx_swa.get());
} }
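
init_batch() above now drains the batch allocator with split_simple() until it returns an empty ubatch, and falls back to split_equal() if preparation fails. A standalone sketch of that control flow with a mock splitter; only the split_reset/split_simple names and the empty-ubatch termination come from the diff, the rest is illustrative:

#include <algorithm>
#include <cstdint>
#include <cstdio>
#include <vector>

// mock splitter: hands out up to n_ubatch tokens per call until the batch is exhausted
struct batch_splitter {
    uint32_t n_tokens;
    uint32_t consumed = 0;

    void split_reset() { consumed = 0; }

    uint32_t split_simple(uint32_t n_ubatch) {
        const uint32_t n = std::min(n_ubatch, n_tokens - consumed);
        consumed += n;
        return n; // 0 means "nothing left"
    }
};

// pretend preparation step (e.g. finding KV-cache slots); fails for over-sized ubatches
static bool prepare(const std::vector<uint32_t> & ubatches, uint32_t capacity) {
    for (uint32_t n : ubatches) {
        if (n > capacity) {
            return false;
        }
    }
    return true;
}

int main() {
    batch_splitter balloc = {37};
    const uint32_t n_ubatch = 16;

    balloc.split_reset();

    std::vector<uint32_t> ubatches;
    while (true) {
        const uint32_t n = balloc.split_simple(n_ubatch);
        if (n == 0) {
            break;
        }
        ubatches.push_back(n);
    }

    if (!prepare(ubatches, n_ubatch)) {
        std::printf("simple split failed, an equal split would be attempted next\n");
        return 1;
    }

    for (uint32_t n : ubatches) {
        std::printf("ubatch of %u tokens\n", n); // 16, 16, 5
    }
    return 0;
}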

View File

@ -31,14 +31,14 @@ public:
// llama_memory_i // llama_memory_i
// //
llama_memory_state_ptr init_batch( llama_memory_context_ptr init_batch(
const llama_batch & batch, llama_batch_allocr & balloc,
uint32_t n_ubatch, uint32_t n_ubatch,
bool embd_all) override; bool embd_all) override;
llama_memory_state_ptr init_full() override; llama_memory_context_ptr init_full() override;
llama_memory_state_ptr init_update(llama_context * lctx, bool optimize) override; llama_memory_context_ptr init_update(llama_context * lctx, bool optimize) override;
bool get_can_shift() const override; bool get_can_shift() const override;
@ -72,62 +72,57 @@ private:
std::unique_ptr<llama_kv_cache_unified> kv_swa; std::unique_ptr<llama_kv_cache_unified> kv_swa;
}; };
class llama_kv_cache_unified_iswa_state : public llama_memory_state_i { class llama_kv_cache_unified_iswa_context : public llama_memory_context_i {
public: public:
// used for errors // used for errors
llama_kv_cache_unified_iswa_state(llama_memory_status status); llama_kv_cache_unified_iswa_context(llama_memory_status status);
// used to create a full-cache state // used to create a full-cache context
llama_kv_cache_unified_iswa_state( llama_kv_cache_unified_iswa_context(
llama_kv_cache_unified_iswa * kv); llama_kv_cache_unified_iswa * kv);
// used to create an update state // used to create an update context
llama_kv_cache_unified_iswa_state( llama_kv_cache_unified_iswa_context(
llama_kv_cache_unified_iswa * kv, llama_kv_cache_unified_iswa * kv,
llama_context * lctx, llama_context * lctx,
bool optimize); bool optimize);
// used to create a state from a batch // used to create a batch processing context from a batch
llama_kv_cache_unified_iswa_state( llama_kv_cache_unified_iswa_context(
llama_kv_cache_unified_iswa * kv, llama_kv_cache_unified_iswa * kv,
llama_sbatch sbatch,
std::vector<uint32_t> heads_base, std::vector<uint32_t> heads_base,
std::vector<uint32_t> heads_swa, std::vector<uint32_t> heads_swa,
std::vector<llama_ubatch> ubatches); std::vector<llama_ubatch> ubatches);
virtual ~llama_kv_cache_unified_iswa_state(); virtual ~llama_kv_cache_unified_iswa_context();
// //
// llama_memory_state_i // llama_memory_context_i
// //
bool next() override; bool next() override;
bool apply() override; bool apply() override;
std::vector<int64_t> & out_ids() override;
llama_memory_status get_status() const override; llama_memory_status get_status() const override;
const llama_ubatch & get_ubatch() const override; const llama_ubatch & get_ubatch() const override;
// //
// llama_kv_cache_unified_iswa_state specific API // llama_kv_cache_unified_iswa_context specific API
// //
const llama_kv_cache_unified_state * get_base() const; const llama_kv_cache_unified_context * get_base() const;
const llama_kv_cache_unified_state * get_swa() const; const llama_kv_cache_unified_context * get_swa() const;
private: private:
//llama_kv_cache_unified_iswa * kv; //llama_kv_cache_unified_iswa * kv;
llama_sbatch sbatch;
// the index of the next ubatch to process // the index of the next ubatch to process
size_t i_next = 0; size_t i_next = 0;
std::vector<llama_ubatch> ubatches; std::vector<llama_ubatch> ubatches;
const llama_memory_state_ptr state_base; const llama_memory_context_ptr ctx_base;
const llama_memory_state_ptr state_swa; const llama_memory_context_ptr ctx_swa;
const llama_memory_status status; const llama_memory_status status;
}; };

View File

@ -307,18 +307,24 @@ llama_pos llama_kv_cache_unified::seq_pos_max(llama_seq_id seq_id) const {
return cells.seq_pos_max(seq_id); return cells.seq_pos_max(seq_id);
} }
llama_memory_state_ptr llama_kv_cache_unified::init_batch( llama_memory_context_ptr llama_kv_cache_unified::init_batch(
const llama_batch & batch, llama_batch_allocr & balloc,
uint32_t n_ubatch, uint32_t n_ubatch,
bool embd_all) { bool embd_all) {
GGML_UNUSED(embd_all); GGML_UNUSED(embd_all);
do { do {
auto sbatch = llama_sbatch(batch, hparams.n_embd, true); balloc.split_reset();
std::vector<llama_ubatch> ubatches; std::vector<llama_ubatch> ubatches;
while (sbatch.n_tokens > 0) { while (true) {
ubatches.push_back(sbatch.split_simple(n_ubatch)); auto ubatch = balloc.split_simple(n_ubatch);
if (ubatch.n_tokens == 0) {
break;
}
ubatches.push_back(std::move(ubatch)); // NOLINT
} }
auto heads = prepare(ubatches); auto heads = prepare(ubatches);
@ -326,18 +332,18 @@ llama_memory_state_ptr llama_kv_cache_unified::init_batch(
break; break;
} }
return std::make_unique<llama_kv_cache_unified_state>( return std::make_unique<llama_kv_cache_unified_context>(
this, std::move(sbatch), std::move(heads), std::move(ubatches)); this, std::move(heads), std::move(ubatches));
} while (false); } while (false);
return std::make_unique<llama_kv_cache_unified_state>(LLAMA_MEMORY_STATUS_FAILED_PREPARE); return std::make_unique<llama_kv_cache_unified_context>(LLAMA_MEMORY_STATUS_FAILED_PREPARE);
} }
llama_memory_state_ptr llama_kv_cache_unified::init_full() { llama_memory_context_ptr llama_kv_cache_unified::init_full() {
return std::make_unique<llama_kv_cache_unified_state>(this); return std::make_unique<llama_kv_cache_unified_context>(this);
} }
llama_memory_state_ptr llama_kv_cache_unified::init_update(llama_context * lctx, bool optimize) { llama_memory_context_ptr llama_kv_cache_unified::init_update(llama_context * lctx, bool optimize) {
bool do_shift = get_has_shift(); bool do_shift = get_has_shift();
defrag_info dinfo; defrag_info dinfo;
@ -367,7 +373,7 @@ llama_memory_state_ptr llama_kv_cache_unified::init_update(llama_context * lctx,
} }
} }
return std::make_unique<llama_kv_cache_unified_state>(this, lctx, do_shift, std::move(dinfo)); return std::make_unique<llama_kv_cache_unified_context>(this, lctx, do_shift, std::move(dinfo));
} }
llama_kv_cache_unified::ubatch_heads llama_kv_cache_unified::prepare(const std::vector<llama_ubatch> & ubatches) { llama_kv_cache_unified::ubatch_heads llama_kv_cache_unified::prepare(const std::vector<llama_ubatch> & ubatches) {
@ -644,12 +650,6 @@ int32_t llama_kv_cache_unified::find_slot(const llama_ubatch & ubatch) const {
} }
void llama_kv_cache_unified::apply_ubatch(uint32_t head_cur, const llama_ubatch & ubatch) { void llama_kv_cache_unified::apply_ubatch(uint32_t head_cur, const llama_ubatch & ubatch) {
if (debug > 0) {
LLAMA_LOG_DEBUG("%s: ubatch info:\n", __func__);
LLAMA_LOG_DEBUG("%s: n_tokens = %d, equal_seqs = %d\n", __func__, ubatch.n_tokens, ubatch.equal_seqs);
LLAMA_LOG_DEBUG("%s: n_seq_tokens = %d, n_seqs = %d\n", __func__, ubatch.n_seq_tokens, ubatch.n_seqs);
}
// keep track of the max sequence position that we would overwrite with this ubatch // keep track of the max sequence position that we would overwrite with this ubatch
// for non-SWA cache, this would be always empty // for non-SWA cache, this would be always empty
llama_seq_id seq_pos_max_rm[LLAMA_MAX_SEQ]; llama_seq_id seq_pos_max_rm[LLAMA_MAX_SEQ];
@ -657,27 +657,22 @@ void llama_kv_cache_unified::apply_ubatch(uint32_t head_cur, const llama_ubatch
seq_pos_max_rm[s] = -1; seq_pos_max_rm[s] = -1;
} }
for (uint32_t s = 0; s < ubatch.n_seqs; ++s) { for (uint32_t i = 0; i < ubatch.n_tokens; ++i) {
for (uint32_t j = 0; j < ubatch.n_seq_tokens; ++j) { if (!cells.is_empty(head_cur + i)) {
const uint32_t idx = s*ubatch.n_seq_tokens + j; assert(cells.seq_count(head_cur + i) == 1);
if (!cells.is_empty(head_cur + idx)) { const llama_seq_id seq_id = cells.seq_get(head_cur + i);
assert(cells.seq_count(head_cur + idx) == 1); const llama_pos pos = cells.pos_get(head_cur + i);
const llama_seq_id seq_id = cells.seq_get(head_cur + idx);
const llama_pos pos = cells.pos_get(head_cur + idx);
seq_pos_max_rm[seq_id] = std::max(seq_pos_max_rm[seq_id], pos); seq_pos_max_rm[seq_id] = std::max(seq_pos_max_rm[seq_id], pos);
cells.rm(head_cur + idx); cells.rm(head_cur + i);
} }
cells.pos_set(head_cur + idx, ubatch.pos[idx]); cells.pos_set(head_cur + i, ubatch.pos[i]);
// TODO: fix indexing [UBATCH_IDX] for (int32_t s = 0; s < ubatch.n_seq_id[i]; s++) {
for (int32_t i = 0; i < ubatch.n_seq_id[s]; i++) { cells.seq_add(head_cur + i, ubatch.seq_id[i][s]);
cells.seq_add(head_cur + idx, ubatch.seq_id[s][i]);
}
} }
} }
@ -696,6 +691,7 @@ void llama_kv_cache_unified::apply_ubatch(uint32_t head_cur, const llama_ubatch
seq_rm(s, cells.seq_pos_min(s), seq_pos_max_rm[s] + 1); seq_rm(s, cells.seq_pos_min(s), seq_pos_max_rm[s] + 1);
} }
} }
// move the head at the end of the slot // move the head at the end of the slot
head = head_cur + ubatch.n_tokens; head = head_cur + ubatch.n_tokens;
} }
@ -793,8 +789,6 @@ ggml_tensor * llama_kv_cache_unified::cpy_v(ggml_context * ctx, ggml_tensor * v_
void llama_kv_cache_unified::set_input_kq_mask(ggml_tensor * dst, const llama_ubatch * ubatch, bool causal_attn) const { void llama_kv_cache_unified::set_input_kq_mask(ggml_tensor * dst, const llama_ubatch * ubatch, bool causal_attn) const {
const uint32_t n_tokens = ubatch->n_tokens; const uint32_t n_tokens = ubatch->n_tokens;
const uint32_t n_seq_tokens = ubatch->n_seq_tokens;
const uint32_t n_seqs = ubatch->n_seqs;
GGML_ASSERT(ggml_backend_buffer_is_host(dst->buffer)); GGML_ASSERT(ggml_backend_buffer_is_host(dst->buffer));
float * data = (float *) dst->data; float * data = (float *) dst->data;
@ -814,26 +808,23 @@ void llama_kv_cache_unified::set_input_kq_mask(ggml_tensor * dst, const llama_ub
// xxxxx----- // xxxxx-----
// To visualize the mask, see https://github.com/ggml-org/llama.cpp/pull/12615 // To visualize the mask, see https://github.com/ggml-org/llama.cpp/pull/12615
for (uint32_t h = 0; h < 1; ++h) { for (uint32_t h = 0; h < 1; ++h) {
for (uint32_t s = 0; s < n_seqs; ++s) { for (uint32_t i = 0; i < n_tokens; ++i) {
const llama_seq_id seq_id = ubatch->seq_id[s][0]; const llama_seq_id seq_id = ubatch->seq_id[i][0];
for (uint32_t j = 0; j < n_seq_tokens; ++j) { const llama_pos p1 = ubatch->pos[i];
const uint32_t idx = s*n_seq_tokens + j;
const llama_pos p1 = ubatch->pos[idx]; for (uint32_t j = 0; j < n_kv; ++j) {
for (uint32_t i = 0; i < n_kv; ++i) {
float f = 0.0f; float f = 0.0f;
bool masked = false; bool masked = false;
if (cells.is_empty(i)) { if (cells.is_empty(j)) {
masked = true; masked = true;
} else { } else {
const llama_pos p0 = cells.pos_get(i); const llama_pos p0 = cells.pos_get(j);
// mask the token if not the same sequence // mask the token if not the same sequence
masked = masked || (!cells.seq_has(i, seq_id)); masked = masked || (!cells.seq_has(j, seq_id));
// mask future tokens // mask future tokens
masked = masked || (causal_attn && p0 > p1); masked = masked || (causal_attn && p0 > p1);
@ -850,16 +841,15 @@ void llama_kv_cache_unified::set_input_kq_mask(ggml_tensor * dst, const llama_ub
f = -INFINITY; f = -INFINITY;
} }
data[h*(n_kv*n_tokens) + idx*n_kv + i] = f; data[h*(n_kv*n_tokens) + i*n_kv + j] = f;
}
} }
} }
// mask padded tokens // mask padded tokens
if (data) { if (data) {
for (uint32_t j = n_tokens; j < GGML_PAD(n_tokens, GGML_KQ_MASK_PAD); ++j) { for (uint32_t i = n_tokens; i < GGML_PAD(n_tokens, GGML_KQ_MASK_PAD); ++i) {
for (uint32_t i = 0; i < n_kv; ++i) { for (uint32_t j = 0; j < n_kv; ++j) {
data[h*(n_kv*n_tokens) + j*n_kv + i] = -INFINITY; data[h*(n_kv*n_tokens) + i*n_kv + j] = -INFINITY;
} }
} }
} }
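
The rewritten loop above builds one mask row per token instead of per (sequence, token) pair, but the masking rule itself is unchanged: a KV cell is masked if it is empty, belongs to a different sequence, or, with causal attention, holds a future position. A self-contained toy restatement of that rule (illustrative names only; the real code additionally handles SWA and ALiBi, which are not shown in this hunk):

// toy restatement of the KQ-mask rule, independent of llama.cpp
#include <cmath>

struct kv_cell { bool empty; int seq_id; int pos; };

// mask value of one KV cell for a token of sequence tok_seq at position tok_pos
static float mask_value(const kv_cell & c, int tok_seq, int tok_pos, bool causal_attn) {
    const bool masked = c.empty                           // unused cell
                     || c.seq_id != tok_seq               // different sequence
                     || (causal_attn && c.pos > tok_pos); // future token
    return masked ? -INFINITY : 0.0f;
}

// e.g. cells at pos {0, 1, 2} of sequence 0 and a token of sequence 0 at pos 1,
// with causal_attn = true, give the row: 0, 0, -INF (pos 2 is in the future)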
@ -887,12 +877,12 @@ void llama_kv_cache_unified::set_input_pos_bucket(ggml_tensor * dst, const llama
const int32_t n_kv = dst->ne[0]; const int32_t n_kv = dst->ne[0];
for (int h = 0; h < 1; ++h) { for (int h = 0; h < 1; ++h) {
for (int j = 0; j < n_tokens; ++j) { for (int i = 0; i < n_tokens; ++i) {
for (int i = 0; i < n_kv; ++i) { for (int j = 0; j < n_kv; ++j) {
// the position when the cells is empty is irrelevant - it will be masked out later in the attention // the position when the cells is empty is irrelevant - it will be masked out later in the attention
const llama_pos p0 = cells.is_empty(i) ? -1 : cells.pos_get(i); const llama_pos p0 = cells.is_empty(j) ? -1 : cells.pos_get(j);
data[h*(n_kv*n_tokens) + j*n_kv + i] = llama_relative_position_bucket(p0, ubatch->pos[j], hparams.n_rel_attn_bkts, false); data[h*(n_kv*n_tokens) + i*n_kv + j] = llama_relative_position_bucket(p0, ubatch->pos[i], hparams.n_rel_attn_bkts, false);
} }
} }
} }
@ -1509,12 +1499,9 @@ bool llama_kv_cache_unified::state_read_meta(llama_io_read_i & io, uint32_t cell
seq_rm(dest_seq_id, -1, -1); seq_rm(dest_seq_id, -1, -1);
llama_sbatch sbatch; llama_batch_allocr balloc(hparams.n_pos_per_embd());
llama_ubatch ubatch = sbatch.reserve_ubatch(cell_count, /* has_embd */ false);
ubatch.n_tokens = cell_count; llama_ubatch ubatch = balloc.ubatch_reserve(cell_count, 1);
ubatch.n_seq_tokens = cell_count;
ubatch.n_seqs = 1;
for (uint32_t i = 0; i < cell_count; ++i) { for (uint32_t i = 0; i < cell_count; ++i) {
llama_pos pos; llama_pos pos;
@ -1723,18 +1710,18 @@ bool llama_kv_cache_unified::state_read_data(llama_io_read_i & io, uint32_t cell
} }
// //
// llama_kv_cache_unified_state // llama_kv_cache_unified_context
// //
llama_kv_cache_unified_state::llama_kv_cache_unified_state(llama_memory_status status) : status(status) {} llama_kv_cache_unified_context::llama_kv_cache_unified_context(llama_memory_status status) : status(status) {}
llama_kv_cache_unified_state::llama_kv_cache_unified_state( llama_kv_cache_unified_context::llama_kv_cache_unified_context(
llama_kv_cache_unified * kv) : status(LLAMA_MEMORY_STATUS_SUCCESS), kv(kv) { llama_kv_cache_unified * kv) : status(LLAMA_MEMORY_STATUS_SUCCESS), kv(kv) {
n_kv = kv->get_size(); n_kv = kv->get_size();
head = 0; head = 0;
} }
llama_kv_cache_unified_state::llama_kv_cache_unified_state( llama_kv_cache_unified_context::llama_kv_cache_unified_context(
llama_kv_cache_unified * kv, llama_kv_cache_unified * kv,
llama_context * lctx, llama_context * lctx,
bool do_shift, bool do_shift,
@ -1744,16 +1731,15 @@ llama_kv_cache_unified_state::llama_kv_cache_unified_state(
} }
} }
llama_kv_cache_unified_state::llama_kv_cache_unified_state( llama_kv_cache_unified_context::llama_kv_cache_unified_context(
llama_kv_cache_unified * kv, llama_kv_cache_unified * kv,
llama_sbatch sbatch,
llama_kv_cache_unified::ubatch_heads heads, llama_kv_cache_unified::ubatch_heads heads,
std::vector<llama_ubatch> ubatches) : status(LLAMA_MEMORY_STATUS_SUCCESS), kv(kv), sbatch(std::move(sbatch)), heads(std::move(heads)), ubatches(std::move(ubatches)) { std::vector<llama_ubatch> ubatches) : status(LLAMA_MEMORY_STATUS_SUCCESS), kv(kv), heads(std::move(heads)), ubatches(std::move(ubatches)) {
} }
llama_kv_cache_unified_state::~llama_kv_cache_unified_state() = default; llama_kv_cache_unified_context::~llama_kv_cache_unified_context() = default;
bool llama_kv_cache_unified_state::next() { bool llama_kv_cache_unified_context::next() {
assert(status == LLAMA_MEMORY_STATUS_SUCCESS); assert(status == LLAMA_MEMORY_STATUS_SUCCESS);
if (++i_next >= ubatches.size()) { if (++i_next >= ubatches.size()) {
@ -1763,7 +1749,7 @@ bool llama_kv_cache_unified_state::next() {
return true; return true;
} }
bool llama_kv_cache_unified_state::apply() { bool llama_kv_cache_unified_context::apply() {
assert(status == LLAMA_MEMORY_STATUS_SUCCESS); assert(status == LLAMA_MEMORY_STATUS_SUCCESS);
// no ubatches -> this is a KV cache update // no ubatches -> this is a KV cache update
@ -1781,51 +1767,45 @@ bool llama_kv_cache_unified_state::apply() {
return true; return true;
} }
std::vector<int64_t> & llama_kv_cache_unified_state::out_ids() { llama_memory_status llama_kv_cache_unified_context::get_status() const {
assert(status == LLAMA_MEMORY_STATUS_SUCCESS);
return sbatch.out_ids;
}
llama_memory_status llama_kv_cache_unified_state::get_status() const {
return status; return status;
} }
const llama_ubatch & llama_kv_cache_unified_state::get_ubatch() const { const llama_ubatch & llama_kv_cache_unified_context::get_ubatch() const {
assert(status == LLAMA_MEMORY_STATUS_SUCCESS); assert(status == LLAMA_MEMORY_STATUS_SUCCESS);
return ubatches[i_next]; return ubatches[i_next];
} }
uint32_t llama_kv_cache_unified_state::get_n_kv() const { uint32_t llama_kv_cache_unified_context::get_n_kv() const {
return n_kv; return n_kv;
} }
ggml_tensor * llama_kv_cache_unified_state::get_k(ggml_context * ctx, int32_t il) const { ggml_tensor * llama_kv_cache_unified_context::get_k(ggml_context * ctx, int32_t il) const {
return kv->get_k(ctx, il, n_kv); return kv->get_k(ctx, il, n_kv);
} }
ggml_tensor * llama_kv_cache_unified_state::get_v(ggml_context * ctx, int32_t il) const { ggml_tensor * llama_kv_cache_unified_context::get_v(ggml_context * ctx, int32_t il) const {
return kv->get_v(ctx, il, n_kv); return kv->get_v(ctx, il, n_kv);
} }
ggml_tensor * llama_kv_cache_unified_state::cpy_k(ggml_context * ctx, ggml_tensor * k_cur, int32_t il) const { ggml_tensor * llama_kv_cache_unified_context::cpy_k(ggml_context * ctx, ggml_tensor * k_cur, int32_t il) const {
return kv->cpy_k(ctx, k_cur, il, head); return kv->cpy_k(ctx, k_cur, il, head);
} }
ggml_tensor * llama_kv_cache_unified_state::cpy_v(ggml_context * ctx, ggml_tensor * v_cur, int32_t il) const { ggml_tensor * llama_kv_cache_unified_context::cpy_v(ggml_context * ctx, ggml_tensor * v_cur, int32_t il) const {
return kv->cpy_v(ctx, v_cur, il, head); return kv->cpy_v(ctx, v_cur, il, head);
} }
void llama_kv_cache_unified_state::set_input_k_shift(ggml_tensor * dst) const { void llama_kv_cache_unified_context::set_input_k_shift(ggml_tensor * dst) const {
kv->set_input_k_shift(dst); kv->set_input_k_shift(dst);
} }
void llama_kv_cache_unified_state::set_input_kq_mask(ggml_tensor * dst, const llama_ubatch * ubatch, bool causal_attn) const { void llama_kv_cache_unified_context::set_input_kq_mask(ggml_tensor * dst, const llama_ubatch * ubatch, bool causal_attn) const {
kv->set_input_kq_mask(dst, ubatch, causal_attn); kv->set_input_kq_mask(dst, ubatch, causal_attn);
} }
void llama_kv_cache_unified_state::set_input_pos_bucket(ggml_tensor * dst, const llama_ubatch * ubatch) const { void llama_kv_cache_unified_context::set_input_pos_bucket(ggml_tensor * dst, const llama_ubatch * ubatch) const {
kv->set_input_pos_bucket(dst, ubatch); kv->set_input_pos_bucket(dst, ubatch);
} }
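
The per-layer getters listed above are what the graph-building code consumes once a ubatch has been applied to the cache. A minimal hedged sketch of that consumption (the helper name is illustrative and not part of llama.cpp; it only exercises methods shown in this diff):

// illustrative helper: fetch the K/V cache views needed to assemble attention
// for layer il from a prepared llama_kv_cache_unified_context
static void get_layer_views(ggml_context * ctx0,
                            const llama_kv_cache_unified_context * mctx,
                            int32_t il,
                            ggml_tensor ** k_out,
                            ggml_tensor ** v_out) {
    // both views span mctx->get_n_kv() cells, i.e. only the used part of the cache
    *k_out = mctx->get_k(ctx0, il);
    *v_out = mctx->get_v(ctx0, il);
}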

View File

@ -56,14 +56,14 @@ public:
// llama_memory_i // llama_memory_i
// //
llama_memory_state_ptr init_batch( llama_memory_context_ptr init_batch(
const llama_batch & batch, llama_batch_allocr & balloc,
uint32_t n_ubatch, uint32_t n_ubatch,
bool embd_all) override; bool embd_all) override;
llama_memory_state_ptr init_full() override; llama_memory_context_ptr init_full() override;
llama_memory_state_ptr init_update(llama_context * lctx, bool optimize) override; llama_memory_context_ptr init_update(llama_context * lctx, bool optimize) override;
bool get_can_shift() const override; bool get_can_shift() const override;
@ -208,49 +208,46 @@ private:
bool state_read_data(llama_io_read_i & io, uint32_t cell_count); bool state_read_data(llama_io_read_i & io, uint32_t cell_count);
}; };
class llama_kv_cache_unified_state : public llama_memory_state_i { class llama_kv_cache_unified_context : public llama_memory_context_i {
public: public:
// some shorthands // some shorthands
using ubatch_heads = llama_kv_cache_unified::ubatch_heads; using ubatch_heads = llama_kv_cache_unified::ubatch_heads;
using defrag_info = llama_kv_cache_unified::defrag_info; using defrag_info = llama_kv_cache_unified::defrag_info;
// used for errors // used for errors
llama_kv_cache_unified_state(llama_memory_status status); llama_kv_cache_unified_context(llama_memory_status status);
// used to create a full-cache state // used to create a full-cache context
llama_kv_cache_unified_state( llama_kv_cache_unified_context(
llama_kv_cache_unified * kv); llama_kv_cache_unified * kv);
// used to create an update state // used to create an update context
llama_kv_cache_unified_state( llama_kv_cache_unified_context(
llama_kv_cache_unified * kv, llama_kv_cache_unified * kv,
llama_context * lctx, llama_context * lctx,
bool do_shift, bool do_shift,
defrag_info dinfo); defrag_info dinfo);
// used to create a decode state from a batch // used to create a batch processing context from a batch
llama_kv_cache_unified_state( llama_kv_cache_unified_context(
llama_kv_cache_unified * kv, llama_kv_cache_unified * kv,
llama_sbatch sbatch,
ubatch_heads heads, ubatch_heads heads,
std::vector<llama_ubatch> ubatches); std::vector<llama_ubatch> ubatches);
virtual ~llama_kv_cache_unified_state(); virtual ~llama_kv_cache_unified_context();
// //
// llama_memory_state_i // llama_memory_context_i
// //
bool next() override; bool next() override;
bool apply() override; bool apply() override;
std::vector<int64_t> & out_ids() override;
llama_memory_status get_status() const override; llama_memory_status get_status() const override;
const llama_ubatch & get_ubatch() const override; const llama_ubatch & get_ubatch() const override;
// //
// llama_kv_cache_unified_state specific API // llama_kv_cache_unified_context specific API
// //
uint32_t get_n_kv() const; uint32_t get_n_kv() const;
@ -275,7 +272,7 @@ private:
llama_context * lctx; llama_context * lctx;
// //
// update state // update context
// //
bool do_shift = false; bool do_shift = false;
@ -283,11 +280,9 @@ private:
defrag_info dinfo; defrag_info dinfo;
// //
// batch processing state // batch processing context
// //
llama_sbatch sbatch;
// the index of the next ubatch to process // the index of the next ubatch to process
size_t i_next = 0; size_t i_next = 0;

View File

@ -7,6 +7,7 @@
#include <cassert> #include <cassert>
#include <vector> #include <vector>
#include <set> #include <set>
#include <map>
// meta information about KV cells that can be part of multiple sequences at the same time // meta information about KV cells that can be part of multiple sequences at the same time
// TODO: add unit tests // TODO: add unit tests
@ -164,7 +165,7 @@ public:
assert(seq_id >= 0); assert(seq_id >= 0);
seq[i].reset(seq_id); seq[i].reset(seq_id);
seq_pos[seq_id].erase(pos[i]); seq_pos_dec(seq_id, pos[i]);
if (seq[i].none()) { if (seq[i].none()) {
pos[i] = -1; pos[i] = -1;
@ -187,7 +188,7 @@ public:
seq[i].reset(); seq[i].reset();
seq[i].set(seq_id); seq[i].set(seq_id);
seq_pos[seq_id].insert(pos[i]); seq_pos_inc(seq_id, pos[i]);
return false; return false;
} }
@ -232,7 +233,7 @@ public:
assert(!seq[i].test(seq_id)); assert(!seq[i].test(seq_id));
seq[i].set(seq_id); seq[i].set(seq_id);
seq_pos[seq_id].insert(pos[i]); seq_pos_inc(seq_id, pos[i]);
} }
// return the sequence id of this cell // return the sequence id of this cell
@ -259,7 +260,9 @@ public:
return -1; return -1;
} }
return *seq_pos[seq_id].begin(); assert(seq_pos[seq_id].begin()->second > 0);
return seq_pos[seq_id].begin()->first;
} }
// the maximum position of sequence seq_id currently present in any of the cells // the maximum position of sequence seq_id currently present in any of the cells
@ -272,7 +275,9 @@ public:
return -1; return -1;
} }
return *seq_pos[seq_id].rbegin(); assert(seq_pos[seq_id].rbegin()->second > 0);
return seq_pos[seq_id].rbegin()->first;
} }
// note: call only if the cell is not empty // note: call only if the cell is not empty
@ -384,22 +389,41 @@ private:
// //
std::vector<llama_pos> shift; std::vector<llama_pos> shift;
using bits_t = std::bitset<LLAMA_MAX_SEQ>; using seq_set_t = std::bitset<LLAMA_MAX_SEQ>;
// the bitset seq[i] tells us which sequences are currently occupying the i-th cell // the bitset seq[i] tells us which sequences are currently occupying the i-th cell
std::vector<bits_t> seq; std::vector<seq_set_t> seq;
// the set seq_pos[s] tells us which positions are currently present for sequence s // the set seq_pos[s][p] tells us how many times the position p is currently present for sequence s
// if the position p is not present, seq_pos[s][p] is not set
// this way seq_pos[s].begin() and seq_pos[s].rbegin() give us the min/max positions currently in the cache // this way seq_pos[s].begin() and seq_pos[s].rbegin() give us the min/max positions currently in the cache
std::set<llama_pos> seq_pos[LLAMA_MAX_SEQ]; //
// note that we cannot use an std::set because in some cases a position can occur more than once for the same seq:
// - when performing a cache reuse via (rm + add)
// - some vision models have input embeddings with repeating positions
//
std::map<llama_pos, int> seq_pos[LLAMA_MAX_SEQ];
// helper functions for updating `seq_pos`, once cell at a time: // helper functions for updating `seq_pos`, once cell at a time:
void seq_pos_dec(llama_seq_id s, llama_pos p) {
auto it = seq_pos[s].find(p);
assert(it != seq_pos[s].end());
if (--it->second == 0) {
seq_pos[s].erase(it);
}
}
void seq_pos_inc(llama_seq_id s, llama_pos p) {
seq_pos[s][p]++;
}
// remove cell i // remove cell i
void seq_pos_rm(uint32_t i) { void seq_pos_rm(uint32_t i) {
for (int s = 0; s < LLAMA_MAX_SEQ; ++s) { for (int s = 0; s < LLAMA_MAX_SEQ; ++s) {
if (seq[i].test(s)) { if (seq[i].test(s)) {
seq_pos[s].erase(pos[i]); seq_pos_dec(s, pos[i]);
} }
} }
} }
@ -408,7 +432,7 @@ private:
void seq_pos_add(uint32_t i) { void seq_pos_add(uint32_t i) {
for (int s = 0; s < LLAMA_MAX_SEQ; ++s) { for (int s = 0; s < LLAMA_MAX_SEQ; ++s) {
if (seq[i].test(s)) { if (seq[i].test(s)) {
seq_pos[s].insert(pos[i]); seq_pos_inc(s, pos[i]);
} }
} }
} }
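
The switch from std::set<llama_pos> to std::map<llama_pos, int> above effectively turns seq_pos[s] into a reference-counted multiset: a position stays visible as long as at least one cell still holds it, so begin()/rbegin() keep reporting the correct min/max even while a position is temporarily duplicated. A self-contained sketch of the same technique (illustrative names, standard library only, not llama.cpp code):

// standalone illustration of the refcounted-position idea used by seq_pos;
// inc/dec mirror seq_pos_inc/seq_pos_dec above
#include <cassert>
#include <map>

int main() {
    std::map<int, int> pos_count; // position -> number of cells holding it

    auto inc = [&](int p) { pos_count[p]++; };
    auto dec = [&](int p) {
        auto it = pos_count.find(p);
        assert(it != pos_count.end());
        if (--it->second == 0) {
            pos_count.erase(it);
        }
    };

    inc(5); inc(5); inc(7); // position 5 is held by two cells (e.g. during rm + add reuse)

    dec(5);                                  // one holder gone...
    assert(pos_count.begin()->first  == 5);  // ...but the min position is still 5
    assert(pos_count.rbegin()->first == 7);  // max position

    dec(5);                                  // last holder gone
    assert(pos_count.begin()->first == 7);   // 5 has disappeared

    return 0;
}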

View File

@ -32,7 +32,7 @@ llama_memory_hybrid::llama_memory_hybrid(
mem_attn(new llama_kv_cache_unified( mem_attn(new llama_kv_cache_unified(
model, model,
filter_attn == nullptr ? filter_attn == nullptr ?
[&](int32_t il) { return !model.hparams.is_recurrent(il); } [&](int32_t il) { return !hparams.is_recurrent(il); }
: filter_attn, : filter_attn,
type_k, type_k,
type_v, type_v,
@ -47,7 +47,7 @@ llama_memory_hybrid::llama_memory_hybrid(
mem_recr(new llama_memory_recurrent( mem_recr(new llama_memory_recurrent(
model, model,
filter_recr == nullptr ? filter_recr == nullptr ?
[&](int32_t il) { return model.hparams.is_recurrent(il); } [&](int32_t il) { return hparams.is_recurrent(il); }
: filter_recr, : filter_recr,
type_r, type_r,
type_s, type_s,
@ -56,50 +56,57 @@ llama_memory_hybrid::llama_memory_hybrid(
n_seq_max n_seq_max
)) {} )) {}
llama_memory_state_ptr llama_memory_hybrid::init_batch(const llama_batch & batch, uint32_t n_ubatch, bool embd_pooled) { llama_memory_context_ptr llama_memory_hybrid::init_batch(llama_batch_allocr & balloc, uint32_t n_ubatch, bool embd_all) {
do {
// since this includes a recurrent cache, we cannot use split_simple balloc.split_reset();
auto sbatch = llama_sbatch(batch, hparams.n_embd, false);
// follow the recurrent pattern for creating the ubatch splits // follow the recurrent pattern for creating the ubatch splits
std::vector<llama_ubatch> ubatches; std::vector<llama_ubatch> ubatches;
while (sbatch.n_tokens > 0) {
while (true) {
llama_ubatch ubatch; llama_ubatch ubatch;
if (embd_pooled) { if (embd_all) {
// Pooled embeddings cannot be split across ubatches (yet) // if all tokens are output, split by sequence
ubatch = sbatch.split_seq(n_ubatch); ubatch = balloc.split_seq(n_ubatch);
} else { } else {
ubatch = sbatch.split_equal(n_ubatch); ubatch = balloc.split_equal(n_ubatch);
} }
ubatches.push_back(ubatch); if (ubatch.n_tokens == 0) {
break;
}
ubatches.push_back(std::move(ubatch)); // NOLINT
} }
// prepare the recurrent batches first // prepare the recurrent batches first
if (!mem_recr->prepare(ubatches)) { if (!mem_recr->prepare(ubatches)) {
// TODO: will the recurrent cache be in an undefined state at this point? // TODO: will the recurrent cache be in an undefined context at this point?
LLAMA_LOG_ERROR("%s: failed to prepare recurrent ubatches\n", __func__); LLAMA_LOG_ERROR("%s: failed to prepare recurrent ubatches\n", __func__);
return std::make_unique<llama_memory_hybrid_state>(LLAMA_MEMORY_STATUS_FAILED_PREPARE); return std::make_unique<llama_memory_hybrid_context>(LLAMA_MEMORY_STATUS_FAILED_PREPARE);
} }
// prepare the attention cache // prepare the attention cache
auto heads_attn = mem_attn->prepare(ubatches); auto heads_attn = mem_attn->prepare(ubatches);
if (heads_attn.empty()) { if (heads_attn.empty()) {
LLAMA_LOG_ERROR("%s: failed to prepare attention ubatches\n", __func__); LLAMA_LOG_ERROR("%s: failed to prepare attention ubatches\n", __func__);
return std::make_unique<llama_memory_hybrid_state>(LLAMA_MEMORY_STATUS_FAILED_PREPARE); return std::make_unique<llama_memory_hybrid_context>(LLAMA_MEMORY_STATUS_FAILED_PREPARE);
} }
return std::make_unique<llama_memory_hybrid_state>( return std::make_unique<llama_memory_hybrid_context>(
this, std::move(sbatch), std::move(heads_attn), std::move(ubatches)); this, std::move(heads_attn), std::move(ubatches));
} while(false);
return std::make_unique<llama_memory_hybrid_context>(LLAMA_MEMORY_STATUS_FAILED_PREPARE);
} }
llama_memory_state_ptr llama_memory_hybrid::init_full() { llama_memory_context_ptr llama_memory_hybrid::init_full() {
return std::make_unique<llama_memory_hybrid_state>(this); return std::make_unique<llama_memory_hybrid_context>(this);
} }
llama_memory_state_ptr llama_memory_hybrid::init_update(llama_context * lctx, bool optimize) { llama_memory_context_ptr llama_memory_hybrid::init_update(llama_context * lctx, bool optimize) {
return std::make_unique<llama_memory_hybrid_state>(this, lctx, optimize); return std::make_unique<llama_memory_hybrid_context>(this, lctx, optimize);
} }
bool llama_memory_hybrid::get_can_shift() const { bool llama_memory_hybrid::get_can_shift() const {
@ -169,41 +176,39 @@ llama_memory_recurrent * llama_memory_hybrid::get_mem_recr() const {
return mem_recr.get(); return mem_recr.get();
} }
llama_memory_hybrid_state::llama_memory_hybrid_state(llama_memory_status status) : status(status) {} llama_memory_hybrid_context::llama_memory_hybrid_context(llama_memory_status status) : status(status) {}
llama_memory_hybrid_state::llama_memory_hybrid_state(llama_memory_hybrid * mem) : llama_memory_hybrid_context::llama_memory_hybrid_context(llama_memory_hybrid * mem) :
state_attn(mem->get_mem_attn()->init_full()), ctx_attn(mem->get_mem_attn()->init_full()),
state_recr(mem->get_mem_recr()->init_full()), ctx_recr(mem->get_mem_recr()->init_full()),
status(llama_memory_status_combine(state_attn->get_status(), state_recr->get_status())) { status(llama_memory_status_combine(ctx_attn->get_status(), ctx_recr->get_status())) {
} }
llama_memory_hybrid_state::llama_memory_hybrid_state( llama_memory_hybrid_context::llama_memory_hybrid_context(
llama_memory_hybrid * mem, llama_memory_hybrid * mem,
llama_context * lctx, llama_context * lctx,
bool optimize) : bool optimize) :
state_attn(mem->get_mem_attn()->init_update(lctx, optimize)), ctx_attn(mem->get_mem_attn()->init_update(lctx, optimize)),
state_recr(mem->get_mem_recr()->init_update(lctx, optimize)), ctx_recr(mem->get_mem_recr()->init_update(lctx, optimize)),
status(llama_memory_status_combine(state_attn->get_status(), state_recr->get_status())) { status(llama_memory_status_combine(ctx_attn->get_status(), ctx_recr->get_status())) {
} }
llama_memory_hybrid_state::llama_memory_hybrid_state( llama_memory_hybrid_context::llama_memory_hybrid_context(
llama_memory_hybrid * mem, llama_memory_hybrid * mem,
llama_sbatch sbatch,
std::vector<uint32_t> heads_attn, std::vector<uint32_t> heads_attn,
std::vector<llama_ubatch> ubatches) : std::vector<llama_ubatch> ubatches) :
sbatch(std::move(sbatch)),
ubatches(std::move(ubatches)), ubatches(std::move(ubatches)),
// note: here we copy the ubatches. not sure if this is ideal // note: here we copy the ubatches. not sure if this is ideal
state_attn(new llama_kv_cache_unified_state(mem->get_mem_attn(), {}, std::move(heads_attn), this->ubatches)), ctx_attn(new llama_kv_cache_unified_context(mem->get_mem_attn(), std::move(heads_attn), this->ubatches)),
state_recr(new llama_memory_recurrent_state(mem->get_mem_recr(), {}, this->ubatches)), ctx_recr(new llama_memory_recurrent_context(mem->get_mem_recr(), this->ubatches)),
status(LLAMA_MEMORY_STATUS_SUCCESS) { status(llama_memory_status_combine(ctx_attn->get_status(), ctx_recr->get_status())) {
} }
bool llama_memory_hybrid_state::next() { bool llama_memory_hybrid_context::next() {
assert(status == LLAMA_MEMORY_STATUS_SUCCESS); assert(status == LLAMA_MEMORY_STATUS_SUCCESS);
state_attn->next(); ctx_attn->next();
state_recr->next(); ctx_recr->next();
if (++i_next >= ubatches.size()) { if (++i_next >= ubatches.size()) {
return false; return false;
@ -212,36 +217,30 @@ bool llama_memory_hybrid_state::next() {
return true; return true;
} }
bool llama_memory_hybrid_state::apply() { bool llama_memory_hybrid_context::apply() {
assert(status == LLAMA_MEMORY_STATUS_SUCCESS); assert(status == LLAMA_MEMORY_STATUS_SUCCESS);
bool res = true; bool res = true;
res = res & state_attn->apply(); res = res & ctx_attn->apply();
res = res & state_recr->apply(); res = res & ctx_recr->apply();
return res; return res;
} }
std::vector<int64_t> & llama_memory_hybrid_state::out_ids() { llama_memory_status llama_memory_hybrid_context::get_status() const {
assert(status == LLAMA_MEMORY_STATUS_SUCCESS);
return sbatch.out_ids;
}
llama_memory_status llama_memory_hybrid_state::get_status() const {
return status; return status;
} }
const llama_ubatch & llama_memory_hybrid_state::get_ubatch() const { const llama_ubatch & llama_memory_hybrid_context::get_ubatch() const {
assert(status == LLAMA_MEMORY_STATUS_SUCCESS); assert(status == LLAMA_MEMORY_STATUS_SUCCESS);
return ubatches[i_next]; return ubatches[i_next];
} }
const llama_kv_cache_unified_state * llama_memory_hybrid_state::get_state_attn() const { const llama_kv_cache_unified_context * llama_memory_hybrid_context::get_attn() const {
return static_cast<const llama_kv_cache_unified_state *>(state_attn.get()); return static_cast<const llama_kv_cache_unified_context *>(ctx_attn.get());
} }
const llama_memory_recurrent_state * llama_memory_hybrid_state::get_state_recr() const { const llama_memory_recurrent_context * llama_memory_hybrid_context::get_recr() const {
return static_cast<const llama_memory_recurrent_state *>(state_recr.get()); return static_cast<const llama_memory_recurrent_context *>(ctx_recr.get());
} }

View File

@ -49,14 +49,14 @@ public:
// llama_memory_i // llama_memory_i
// //
llama_memory_state_ptr init_batch( llama_memory_context_ptr init_batch(
const llama_batch & batch, llama_batch_allocr & balloc,
uint32_t n_ubatch, uint32_t n_ubatch,
bool embd_pooled) override; bool embd_all) override;
llama_memory_state_ptr init_full() override; llama_memory_context_ptr init_full() override;
llama_memory_state_ptr init_update(llama_context * lctx, bool optimize) override; llama_memory_context_ptr init_update(llama_context * lctx, bool optimize) override;
bool get_can_shift() const override; bool get_can_shift() const override;
@ -90,54 +90,49 @@ private:
const std::unique_ptr<llama_memory_recurrent> mem_recr; const std::unique_ptr<llama_memory_recurrent> mem_recr;
}; };
class llama_memory_hybrid_state : public llama_memory_state_i { class llama_memory_hybrid_context : public llama_memory_context_i {
public: public:
// init failure // init failure
explicit llama_memory_hybrid_state(llama_memory_status status); explicit llama_memory_hybrid_context(llama_memory_status status);
// init full // init full
explicit llama_memory_hybrid_state(llama_memory_hybrid * mem); explicit llama_memory_hybrid_context(llama_memory_hybrid * mem);
// init update // init update
explicit llama_memory_hybrid_state( explicit llama_memory_hybrid_context(
llama_memory_hybrid * mem, llama_memory_hybrid * mem,
llama_context * lctx, llama_context * lctx,
bool optimize); bool optimize);
// init success // init success
llama_memory_hybrid_state( llama_memory_hybrid_context(
llama_memory_hybrid * mem, llama_memory_hybrid * mem,
llama_sbatch sbatch,
std::vector<uint32_t> heads_attn, std::vector<uint32_t> heads_attn,
std::vector<llama_ubatch> ubatches); std::vector<llama_ubatch> ubatches);
~llama_memory_hybrid_state() = default; ~llama_memory_hybrid_context() = default;
bool next() override; bool next() override;
bool apply() override; bool apply() override;
std::vector<int64_t> & out_ids() override;
llama_memory_status get_status() const override; llama_memory_status get_status() const override;
const llama_ubatch & get_ubatch() const override; const llama_ubatch & get_ubatch() const override;
// //
// llama_memory_hybrid_state // llama_memory_hybrid_context
// //
const llama_kv_cache_unified_state * get_state_attn() const; const llama_kv_cache_unified_context * get_attn() const;
const llama_memory_recurrent_state * get_state_recr() const; const llama_memory_recurrent_context * get_recr() const;
private: private:
llama_sbatch sbatch;
// the index of the next ubatch to process // the index of the next ubatch to process
size_t i_next = 0; size_t i_next = 0;
std::vector<llama_ubatch> ubatches; std::vector<llama_ubatch> ubatches;
const llama_memory_state_ptr state_attn; const llama_memory_context_ptr ctx_attn;
const llama_memory_state_ptr state_recr; const llama_memory_context_ptr ctx_recr;
const llama_memory_status status; const llama_memory_status status;
}; };

View File

@ -362,40 +362,42 @@ llama_pos llama_memory_recurrent::seq_pos_max(llama_seq_id seq_id) const {
return result; return result;
} }
llama_memory_state_ptr llama_memory_recurrent::init_batch(const llama_batch & batch, uint32_t n_ubatch, bool embd_all) { llama_memory_context_ptr llama_memory_recurrent::init_batch(llama_batch_allocr & balloc, uint32_t n_ubatch, bool embd_all) {
auto sbatch = llama_sbatch(batch, hparams.n_embd, false);
std::vector<llama_ubatch> ubatches; std::vector<llama_ubatch> ubatches;
while (sbatch.n_tokens > 0) { while (true) {
llama_ubatch ubatch; llama_ubatch ubatch;
if (embd_all) { if (embd_all) {
// if all tokens are output, split by sequence // if all tokens are output, split by sequence
ubatch = sbatch.split_seq(n_ubatch); ubatch = balloc.split_seq(n_ubatch);
} else { } else {
ubatch = sbatch.split_equal(n_ubatch); ubatch = balloc.split_equal(n_ubatch);
} }
ubatches.push_back(ubatch); if (ubatch.n_tokens == 0) {
break;
}
ubatches.push_back(std::move(ubatch)); // NOLINT
} }
if (!prepare(ubatches)) { if (!prepare(ubatches)) {
return std::make_unique<llama_memory_recurrent_state>(LLAMA_MEMORY_STATUS_FAILED_PREPARE); return std::make_unique<llama_memory_recurrent_context>(LLAMA_MEMORY_STATUS_FAILED_PREPARE);
} }
return std::make_unique<llama_memory_recurrent_state>(this, std::move(sbatch), std::move(ubatches)); return std::make_unique<llama_memory_recurrent_context>(this, std::move(ubatches));
} }
llama_memory_state_ptr llama_memory_recurrent::init_full() { llama_memory_context_ptr llama_memory_recurrent::init_full() {
return std::make_unique<llama_memory_recurrent_state>(this); return std::make_unique<llama_memory_recurrent_context>(this);
} }
llama_memory_state_ptr llama_memory_recurrent::init_update(llama_context * lctx, bool optimize) { llama_memory_context_ptr llama_memory_recurrent::init_update(llama_context * lctx, bool optimize) {
GGML_UNUSED(lctx); GGML_UNUSED(lctx);
GGML_UNUSED(optimize); GGML_UNUSED(optimize);
return std::make_unique<llama_memory_recurrent_state>(LLAMA_MEMORY_STATUS_NO_UPDATE); return std::make_unique<llama_memory_recurrent_context>(LLAMA_MEMORY_STATUS_NO_UPDATE);
} }
bool llama_memory_recurrent::prepare(const std::vector<llama_ubatch> & ubatches) { bool llama_memory_recurrent::prepare(const std::vector<llama_ubatch> & ubatches) {
@ -423,9 +425,8 @@ bool llama_memory_recurrent::prepare(const std::vector<llama_ubatch> & ubatches)
} }
bool llama_memory_recurrent::find_slot(const llama_ubatch & ubatch) { bool llama_memory_recurrent::find_slot(const llama_ubatch & ubatch) {
const uint32_t n_seqs = ubatch.n_seqs;
const uint32_t n_seq_tokens = ubatch.n_seq_tokens; const uint32_t n_seq_tokens = ubatch.n_seq_tokens;
const uint32_t n_seqs = ubatch.n_seqs;
// if we have enough unused cells before the current head -> // if we have enough unused cells before the current head ->
// better to start searching from the beginning of the cache, hoping to fill it // better to start searching from the beginning of the cache, hoping to fill it
@ -445,9 +446,11 @@ bool llama_memory_recurrent::find_slot(const llama_ubatch & ubatch) {
// everything should fit if all seq_ids are smaller than the max // everything should fit if all seq_ids are smaller than the max
for (uint32_t s = 0; s < n_seqs; ++s) { for (uint32_t s = 0; s < n_seqs; ++s) {
const uint32_t n_seq_id = ubatch.n_seq_id[s]; const uint32_t i = s*n_seq_tokens; // first token of sequence set s
const uint32_t n_seq_id = ubatch.n_seq_id[i];
for (uint32_t j = 0; j < n_seq_id; ++j) { for (uint32_t j = 0; j < n_seq_id; ++j) {
const llama_seq_id seq_id = ubatch.seq_id[s][j]; const llama_seq_id seq_id = ubatch.seq_id[i][j];
if (seq_id < 0 || (uint32_t) seq_id >= size) { if (seq_id < 0 || (uint32_t) seq_id >= size) {
// too big seq_id // too big seq_id
@ -506,7 +509,8 @@ bool llama_memory_recurrent::find_slot(const llama_ubatch & ubatch) {
// find usable cell range // find usable cell range
for (uint32_t s = 0; s < n_seqs; ++s) { for (uint32_t s = 0; s < n_seqs; ++s) {
const llama_seq_id seq_id = ubatch.seq_id[s][0]; const uint32_t i = s*n_seq_tokens;
const llama_seq_id seq_id = ubatch.seq_id[i][0];
auto & seq_meta = cells[seq_id]; auto & seq_meta = cells[seq_id];
bool has_cell = false; bool has_cell = false;
if (seq_meta.tail >= 0) { if (seq_meta.tail >= 0) {
@ -530,7 +534,7 @@ bool llama_memory_recurrent::find_slot(const llama_ubatch & ubatch) {
seq_meta.tail = next_empty_cell; seq_meta.tail = next_empty_cell;
// find next empty cell // find next empty cell
if (s + 1 < n_seqs) { if (s + 1 < n_seqs) {
for (uint32_t i = 0; i < size; ++i) { for (uint32_t j = 0; j < size; ++j) {
next_empty_cell += 1; next_empty_cell += 1;
if (next_empty_cell >= size) { next_empty_cell -= size; } if (next_empty_cell >= size) { next_empty_cell -= size; }
auto & cell = cells[next_empty_cell]; auto & cell = cells[next_empty_cell];
@ -544,8 +548,9 @@ bool llama_memory_recurrent::find_slot(const llama_ubatch & ubatch) {
// gather and re-order // gather and re-order
for (uint32_t s = 0; s < n_seqs; ++s) { for (uint32_t s = 0; s < n_seqs; ++s) {
const uint32_t i = s*n_seq_tokens;
const int32_t dst_id = s + min; const int32_t dst_id = s + min;
const int32_t src_id = cells[ubatch.seq_id[s][0]].tail; const int32_t src_id = cells[ubatch.seq_id[i][0]].tail;
if (dst_id != src_id) { if (dst_id != src_id) {
auto & dst_cell = cells[dst_id]; auto & dst_cell = cells[dst_id];
auto & src_cell = cells[src_id]; auto & src_cell = cells[src_id];
@ -555,8 +560,8 @@ bool llama_memory_recurrent::find_slot(const llama_ubatch & ubatch) {
std::swap(dst_cell.seq_id, src_cell.seq_id); std::swap(dst_cell.seq_id, src_cell.seq_id);
// swap tails // swap tails
for (uint32_t i = 0; i < size; ++i) { for (uint32_t j = 0; j < size; ++j) {
int32_t & tail = cells[i].tail; int32_t & tail = cells[j].tail;
if (tail == src_id) { if (tail == src_id) {
tail = dst_id; tail = dst_id;
} else if (tail == dst_id) { } else if (tail == dst_id) {
@ -568,7 +573,8 @@ bool llama_memory_recurrent::find_slot(const llama_ubatch & ubatch) {
// update the pos of the used seqs // update the pos of the used seqs
for (uint32_t s = 0; s < n_seqs; ++s) { for (uint32_t s = 0; s < n_seqs; ++s) {
const llama_pos last_pos = ubatch.pos[n_seq_tokens * s + n_seq_tokens - 1]; const uint32_t i = s*n_seq_tokens;
const llama_pos last_pos = ubatch.pos[i + n_seq_tokens - 1];
const int32_t cell_id = s + min; const int32_t cell_id = s + min;
auto & cell = cells[cell_id]; auto & cell = cells[cell_id];
@ -576,12 +582,12 @@ bool llama_memory_recurrent::find_slot(const llama_ubatch & ubatch) {
// What should happen when the pos backtracks or skips a value? // What should happen when the pos backtracks or skips a value?
// Clearing the state mid-batch would require special-casing which isn't done. // Clearing the state mid-batch would require special-casing which isn't done.
LLAMA_LOG_WARN("%s: non-consecutive token position %d after %d for sequence %d with %u new tokens\n", LLAMA_LOG_WARN("%s: non-consecutive token position %d after %d for sequence %d with %u new tokens\n",
__func__, last_pos, cell.pos, ubatch.seq_id[s][0], n_seq_tokens); __func__, last_pos, cell.pos, ubatch.seq_id[i][0], n_seq_tokens);
} }
cell.pos = last_pos; cell.pos = last_pos;
cell.seq_id.clear(); cell.seq_id.clear();
for (int32_t j = 0; j < ubatch.n_seq_id[s]; ++j) { for (int32_t j = 0; j < ubatch.n_seq_id[i]; ++j) {
const llama_seq_id seq_id = ubatch.seq_id[s][j]; const llama_seq_id seq_id = ubatch.seq_id[i][j];
cell.seq_id.insert(seq_id); cell.seq_id.insert(seq_id);
cells[seq_id].tail = cell_id; cells[seq_id].tail = cell_id;
} }
@ -827,12 +833,9 @@ bool llama_memory_recurrent::state_read_meta(llama_io_read_i & io, uint32_t cell
seq_rm(dest_seq_id, -1, -1); seq_rm(dest_seq_id, -1, -1);
llama_sbatch sbatch; llama_batch_allocr balloc(hparams.n_pos_per_embd());
llama_ubatch batch = sbatch.reserve_ubatch(cell_count, /* has_embd */ false);
batch.n_tokens = cell_count; llama_ubatch ubatch = balloc.ubatch_reserve(cell_count, 1);
batch.n_seq_tokens = cell_count;
batch.n_seqs = 1;
for (uint32_t i = 0; i < cell_count; ++i) { for (uint32_t i = 0; i < cell_count; ++i) {
llama_pos pos; llama_pos pos;
@ -846,12 +849,12 @@ bool llama_memory_recurrent::state_read_meta(llama_io_read_i & io, uint32_t cell
return false; return false;
} }
batch.pos[i] = pos; ubatch.pos[i] = pos;
} }
batch.n_seq_id[0] = 1; ubatch.n_seq_id[0] = 1;
batch.seq_id[0] = &dest_seq_id; ubatch.seq_id[0] = &dest_seq_id;
if (!find_slot(batch)) { if (!find_slot(ubatch)) {
LLAMA_LOG_ERROR("%s: failed to find available cells in kv cache\n", __func__); LLAMA_LOG_ERROR("%s: failed to find available cells in kv cache\n", __func__);
return false; return false;
} }
@ -859,8 +862,8 @@ bool llama_memory_recurrent::state_read_meta(llama_io_read_i & io, uint32_t cell
// DEBUG CHECK: kv.head should be our first cell, kv.head + cell_count - 1 should be our last cell (verify seq_id and pos values) // DEBUG CHECK: kv.head should be our first cell, kv.head + cell_count - 1 should be our last cell (verify seq_id and pos values)
// Assume that this is one contiguous block of cells // Assume that this is one contiguous block of cells
GGML_ASSERT(head + cell_count <= size); GGML_ASSERT(head + cell_count <= size);
GGML_ASSERT(cells[head].pos == batch.pos[0]); GGML_ASSERT(cells[head].pos == ubatch.pos[0]);
GGML_ASSERT(cells[head + cell_count - 1].pos == batch.pos[cell_count - 1]); GGML_ASSERT(cells[head + cell_count - 1].pos == ubatch.pos[cell_count - 1]);
GGML_ASSERT(cells[head].has_seq_id(dest_seq_id)); GGML_ASSERT(cells[head].has_seq_id(dest_seq_id));
GGML_ASSERT(cells[head + cell_count - 1].has_seq_id(dest_seq_id)); GGML_ASSERT(cells[head + cell_count - 1].has_seq_id(dest_seq_id));
} else { } else {
@ -1037,23 +1040,22 @@ bool llama_memory_recurrent::state_read_data(llama_io_read_i & io, uint32_t cell
} }
// //
// llama_memory_recurrent_state // llama_memory_recurrent_context
// //
llama_memory_recurrent_state::llama_memory_recurrent_state(llama_memory_status status) : status(status) {} llama_memory_recurrent_context::llama_memory_recurrent_context(llama_memory_status status) : status(status) {}
llama_memory_recurrent_state::llama_memory_recurrent_state( llama_memory_recurrent_context::llama_memory_recurrent_context(
llama_memory_recurrent * mem) : status(LLAMA_MEMORY_STATUS_SUCCESS), mem(mem), is_full(true) { llama_memory_recurrent * mem) : status(LLAMA_MEMORY_STATUS_SUCCESS), mem(mem), is_full(true) {
} }
llama_memory_recurrent_state::llama_memory_recurrent_state( llama_memory_recurrent_context::llama_memory_recurrent_context(
llama_memory_recurrent * mem, llama_memory_recurrent * mem,
llama_sbatch sbatch, std::vector<llama_ubatch> ubatches) : status(LLAMA_MEMORY_STATUS_SUCCESS), mem(mem), ubatches(std::move(ubatches)) {}
std::vector<llama_ubatch> ubatches) : status(LLAMA_MEMORY_STATUS_SUCCESS), mem(mem), sbatch(std::move(sbatch)), ubatches(std::move(ubatches)) {}
llama_memory_recurrent_state::~llama_memory_recurrent_state() = default; llama_memory_recurrent_context::~llama_memory_recurrent_context() = default;
bool llama_memory_recurrent_state::next() { bool llama_memory_recurrent_context::next() {
assert(status == LLAMA_MEMORY_STATUS_SUCCESS); assert(status == LLAMA_MEMORY_STATUS_SUCCESS);
if (++i_next >= ubatches.size()) { if (++i_next >= ubatches.size()) {
@ -1063,7 +1065,7 @@ bool llama_memory_recurrent_state::next() {
return true; return true;
} }
bool llama_memory_recurrent_state::apply() { bool llama_memory_recurrent_context::apply() {
assert(status == LLAMA_MEMORY_STATUS_SUCCESS); assert(status == LLAMA_MEMORY_STATUS_SUCCESS);
mem->find_slot(ubatches[i_next]); mem->find_slot(ubatches[i_next]);
@ -1071,46 +1073,40 @@ bool llama_memory_recurrent_state::apply() {
return true; return true;
} }
std::vector<int64_t> & llama_memory_recurrent_state::out_ids() { llama_memory_status llama_memory_recurrent_context::get_status() const {
assert(status == LLAMA_MEMORY_STATUS_SUCCESS);
return sbatch.out_ids;
}
llama_memory_status llama_memory_recurrent_state::get_status() const {
return status; return status;
} }
const llama_ubatch & llama_memory_recurrent_state::get_ubatch() const { const llama_ubatch & llama_memory_recurrent_context::get_ubatch() const {
assert(status == LLAMA_MEMORY_STATUS_SUCCESS); assert(status == LLAMA_MEMORY_STATUS_SUCCESS);
return ubatches[i_next]; return ubatches[i_next];
} }
uint32_t llama_memory_recurrent_state::get_n_rs() const { uint32_t llama_memory_recurrent_context::get_n_rs() const {
return is_full ? mem->size : mem->n; return is_full ? mem->size : mem->n;
} }
uint32_t llama_memory_recurrent_state::get_head() const { uint32_t llama_memory_recurrent_context::get_head() const {
return is_full ? 0 : mem->head; return is_full ? 0 : mem->head;
} }
int32_t llama_memory_recurrent_state::get_rs_z() const { int32_t llama_memory_recurrent_context::get_rs_z() const {
return is_full ? 0 : mem->rs_z; return is_full ? 0 : mem->rs_z;
} }
uint32_t llama_memory_recurrent_state::get_size() const { uint32_t llama_memory_recurrent_context::get_size() const {
return mem->size; return mem->size;
} }
ggml_tensor * llama_memory_recurrent_state::get_r_l(int32_t il) const { ggml_tensor * llama_memory_recurrent_context::get_r_l(int32_t il) const {
return mem->r_l[il]; return mem->r_l[il];
} }
ggml_tensor * llama_memory_recurrent_state::get_s_l(int32_t il) const { ggml_tensor * llama_memory_recurrent_context::get_s_l(int32_t il) const {
return mem->s_l[il]; return mem->s_l[il];
} }
int32_t llama_memory_recurrent_state::s_copy(int i) const { int32_t llama_memory_recurrent_context::s_copy(int i) const {
return mem->cells[i + mem->head].src0; return mem->cells[i + mem->head].src0;
} }

View File

@ -11,8 +11,8 @@
// llama_memory_recurrent // llama_memory_recurrent
// //
// TODO: extract the cache state used for graph computation into llama_memory_recurrent_state_i // TODO: extract the cache state used for graph computation into llama_memory_recurrent_context_i
// see the implementation of llama_kv_cache_unified_state_i for an example how to do it // see the implementation of llama_kv_cache_unified_context_i for an example how to do it
class llama_memory_recurrent : public llama_memory_i { class llama_memory_recurrent : public llama_memory_i {
public: public:
@ -34,14 +34,14 @@ public:
// llama_memory_i // llama_memory_i
// //
llama_memory_state_ptr init_batch( llama_memory_context_ptr init_batch(
const llama_batch & batch, llama_batch_allocr & balloc,
uint32_t n_ubatch, uint32_t n_ubatch,
bool embd_all) override; bool embd_all) override;
llama_memory_state_ptr init_full() override; llama_memory_context_ptr init_full() override;
llama_memory_state_ptr init_update(llama_context * lctx, bool optimize) override; llama_memory_context_ptr init_update(llama_context * lctx, bool optimize) override;
void clear(bool data) override; void clear(bool data) override;
@ -125,37 +125,34 @@ private:
bool state_read_data(llama_io_read_i & io, uint32_t cell_count); bool state_read_data(llama_io_read_i & io, uint32_t cell_count);
}; };
class llama_memory_recurrent_state : public llama_memory_state_i { class llama_memory_recurrent_context : public llama_memory_context_i {
public: public:
// used for errors // used for errors
llama_memory_recurrent_state(llama_memory_status status); llama_memory_recurrent_context(llama_memory_status status);
// used to create a full-cache state // used to create a full-cache or update context
llama_memory_recurrent_state( llama_memory_recurrent_context(
llama_memory_recurrent * mem); llama_memory_recurrent * mem);
// used to create a state from a batch // used to create a batch processing context from a batch
llama_memory_recurrent_state( llama_memory_recurrent_context(
llama_memory_recurrent * mem, llama_memory_recurrent * mem,
llama_sbatch sbatch,
std::vector<llama_ubatch> ubatches); std::vector<llama_ubatch> ubatches);
virtual ~llama_memory_recurrent_state(); virtual ~llama_memory_recurrent_context();
// //
// llama_memory_state_i // llama_memory_context_i
// //
bool next() override; bool next() override;
bool apply() override; bool apply() override;
std::vector<int64_t> & out_ids() override;
llama_memory_status get_status() const override; llama_memory_status get_status() const override;
const llama_ubatch & get_ubatch() const override; const llama_ubatch & get_ubatch() const override;
// //
// llama_memory_recurrent_state specific API // llama_memory_recurrent_context specific API
// //
uint32_t get_n_rs() const; uint32_t get_n_rs() const;
@ -173,8 +170,6 @@ private:
llama_memory_recurrent * mem; llama_memory_recurrent * mem;
llama_sbatch sbatch;
size_t i_next = 0; size_t i_next = 0;
std::vector<llama_ubatch> ubatches; std::vector<llama_ubatch> ubatches;

View File

@ -3,10 +3,11 @@
#include "llama.h" #include "llama.h"
#include <memory> #include <memory>
#include <vector>
struct llama_ubatch; struct llama_ubatch;
class llama_batch_allocr;
class llama_io_write_i; class llama_io_write_i;
class llama_io_read_i; class llama_io_read_i;
@ -26,23 +27,21 @@ enum llama_memory_status {
LLAMA_MEMORY_STATUS_FAILED_COMPUTE, LLAMA_MEMORY_STATUS_FAILED_COMPUTE,
}; };
// helper function for combining the status of two memory states // helper function for combining the status of two memory contexts
// useful for implementing hybrid memory types (e.g. iSWA) // useful for implementing hybrid memory types (e.g. iSWA)
llama_memory_status llama_memory_status_combine(llama_memory_status s0, llama_memory_status s1); llama_memory_status llama_memory_status_combine(llama_memory_status s0, llama_memory_status s1);
// the interface for managing the memory state during batch processing // the interface for managing the memory context during batch processing
// this interface is implemented per memory type. see: // this interface is implemented per memory type. see:
// - llama_kv_cache_unified_state // - llama_kv_cache_unified_context
// - llama_kv_cache_unified_iswa_state // - llama_kv_cache_unified_iswa_context
// ... // ...
// //
// the only method that can mutate the memory and the memory state is llama_memory_i::apply() // the only method that should mutate the memory and the memory context is llama_memory_i::apply()
// struct llama_memory_context_i {
// TODO: rename to llama_memory_context_i ? virtual ~llama_memory_context_i() = default;
struct llama_memory_state_i {
virtual ~llama_memory_state_i() = default;
// consume the current ubatch from the state and proceed to the next one // consume the current ubatch from the context and proceed to the next one
// return false if we are done // return false if we are done
virtual bool next() = 0; virtual bool next() = 0;
@ -50,17 +49,14 @@ struct llama_memory_state_i {
// return false on failure // return false on failure
virtual bool apply() = 0; virtual bool apply() = 0;
// TODO: this might get reworked in the future when refactoring llama_batch
virtual std::vector<int64_t> & out_ids() = 0;
// get the current ubatch // get the current ubatch
virtual const llama_ubatch & get_ubatch() const = 0; virtual const llama_ubatch & get_ubatch() const = 0;
// get the status of the memory state - used for error handling and checking if any updates would be applied // get the status of the memory context - used for error handling and checking if any updates would be applied
virtual llama_memory_status get_status() const = 0; virtual llama_memory_status get_status() const = 0;
}; };
using llama_memory_state_ptr = std::unique_ptr<llama_memory_state_i>; using llama_memory_context_ptr = std::unique_ptr<llama_memory_context_i>;
// general concept of LLM memory // general concept of LLM memory
// the KV cache is a type of LLM memory, but there can be other types // the KV cache is a type of LLM memory, but there can be other types
@ -68,19 +64,19 @@ struct llama_memory_i {
virtual ~llama_memory_i() = default; virtual ~llama_memory_i() = default;
// split the input batch into a set of ubatches and verify that they can fit into the cache // split the input batch into a set of ubatches and verify that they can fit into the cache
// return a state object containing the ubatches and KV cache state required to process them // return a context object containing the ubatches and memory state required to process them
// check the llama_memory_state_i::get_status() for the result // check the llama_memory_context_i::get_status() for the result
virtual llama_memory_state_ptr init_batch( virtual llama_memory_context_ptr init_batch(
const llama_batch & batch, llama_batch_allocr & balloc,
uint32_t n_ubatch, uint32_t n_ubatch,
bool embd_all) = 0; bool embd_all) = 0;
// simulate full cache, used for allocating worst-case compute buffers // simulate full cache, used for allocating worst-case compute buffers
virtual llama_memory_state_ptr init_full() = 0; virtual llama_memory_context_ptr init_full() = 0;
// prepare for any pending memory updates, such as shifts, defrags, etc. // prepare for any pending memory updates, such as shifts, defrags, etc.
// status == LLAMA_MEMORY_STATUS_NO_UPDATE if there is nothing to update // status == LLAMA_MEMORY_STATUS_NO_UPDATE if there is nothing to update
virtual llama_memory_state_ptr init_update(llama_context * lctx, bool optimize) = 0; virtual llama_memory_context_ptr init_update(llama_context * lctx, bool optimize) = 0;
// getters // getters
virtual bool get_can_shift() const = 0; virtual bool get_can_shift() const = 0;
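
Read together, the renamed interface is a small iterator-style protocol: init_batch() splits the batch into ubatches and hands back a context, and the caller then applies and consumes one ubatch at a time until next() returns false. A hedged sketch of such a driver loop, using only the methods declared above (process_ubatch is a hypothetical stand-in for the graph build and compute step, and the exact call order inside llama_context may differ):

// hedged sketch of a caller driving llama_memory_i / llama_memory_context_i;
// assumes the internal llama-memory.h / llama-batch.h headers are available
void process_ubatch(const llama_ubatch & ubatch); // hypothetical: build graph, run compute

static bool run_batch(llama_memory_i & mem, llama_batch_allocr & balloc, uint32_t n_ubatch) {
    llama_memory_context_ptr mctx = mem.init_batch(balloc, n_ubatch, /*embd_all=*/false);

    if (mctx->get_status() != LLAMA_MEMORY_STATUS_SUCCESS) {
        return false; // e.g. LLAMA_MEMORY_STATUS_FAILED_PREPARE
    }

    do {
        if (!mctx->apply()) { // the only call that should mutate the memory
            return false;
        }
        process_ubatch(mctx->get_ubatch());
    } while (mctx->next()); // false once all ubatches are consumed

    return true;
}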

View File

@ -228,6 +228,7 @@ void llama_model_saver::add_kv_from_model() {
// add_kv(LLM_KV_TOKENIZER_MASK_ID, ???); // add_kv(LLM_KV_TOKENIZER_MASK_ID, ???);
add_kv(LLM_KV_TOKENIZER_ADD_BOS, vocab.get_add_bos()); add_kv(LLM_KV_TOKENIZER_ADD_BOS, vocab.get_add_bos());
add_kv(LLM_KV_TOKENIZER_ADD_EOS, vocab.get_add_eos()); add_kv(LLM_KV_TOKENIZER_ADD_EOS, vocab.get_add_eos());
add_kv(LLM_KV_TOKENIZER_ADD_SEP, vocab.get_add_sep());
add_kv(LLM_KV_TOKENIZER_ADD_PREFIX, vocab.get_add_space_prefix()); add_kv(LLM_KV_TOKENIZER_ADD_PREFIX, vocab.get_add_space_prefix());
add_kv(LLM_KV_TOKENIZER_REMOVE_EXTRA_WS, vocab.get_remove_extra_whitespaces()); add_kv(LLM_KV_TOKENIZER_REMOVE_EXTRA_WS, vocab.get_remove_extra_whitespaces());
add_kv(LLM_KV_TOKENIZER_PRECOMPILED_CHARSMAP, vocab.get_precompiled_charsmap()); add_kv(LLM_KV_TOKENIZER_PRECOMPILED_CHARSMAP, vocab.get_precompiled_charsmap());

File diff suppressed because it is too large

View File

@ -1,5 +1,4 @@
#include "llama-quant.h" #include "llama-quant.h"
#include "llama-impl.h" #include "llama-impl.h"
#include "llama-model.h" #include "llama-model.h"
#include "llama-model-loader.h" #include "llama-model-loader.h"
@ -27,6 +26,56 @@ static void zeros(std::ofstream & file, size_t n) {
} }
} }
static std::string remap_layer(const std::string & orig_name, const std::vector<int> & prune, std::map<int, std::string> & mapped, int & next_id) {
if (prune.empty()) {
return orig_name;
}
static const std::regex pattern(R"(blk\.(\d+)\.)");
if (std::smatch match; std::regex_search(orig_name, match, pattern)) {
const int blk = std::stoi(match[1]);
std::string new_name = orig_name;
if (mapped.count(blk)) {
// Already mapped, do nothing
} else if (std::find(prune.begin(), prune.end(), blk) != prune.end()) {
mapped[blk] = "";
} else if (blk < prune.front()) {
mapped[blk] = std::to_string(blk);
next_id = blk + 1;
} else {
mapped[blk] = std::to_string(next_id);
++next_id;
}
return mapped[blk].empty() ? mapped[blk] : new_name.replace(match.position(1), match.length(1), mapped[blk]);
}
return orig_name;
}
static std::string remap_imatrix (const std::string & orig_name, const std::map<int, std::string> & mapped) {
if (mapped.empty()) {
return orig_name;
}
static const std::regex pattern(R"(blk\.(\d+)\.)");
if (std::smatch match; std::regex_search(orig_name, match, pattern)) {
const std::string blk(match[1]);
std::string new_name = orig_name;
for (const auto & p : mapped) {
if (p.second == blk) {
LLAMA_LOG_DEBUG("(blk.%d imatrix) ", p.first);
return new_name.replace(match.position(1), match.length(1), std::to_string(p.first));
}
}
GGML_ABORT("\n%s: imatrix mapping error for %s\n", __func__, orig_name.c_str());
}
return orig_name;
}
struct quantize_state_impl { struct quantize_state_impl {
const llama_model & model; const llama_model & model;
const llama_model_quantize_params * params; const llama_model_quantize_params * params;
@ -568,6 +617,11 @@ static void llama_model_quantize_impl(const std::string & fname_inp, const std::
const size_t align = GGUF_DEFAULT_ALIGNMENT; const size_t align = GGUF_DEFAULT_ALIGNMENT;
gguf_context_ptr ctx_out { gguf_init_empty() }; gguf_context_ptr ctx_out { gguf_init_empty() };
std::vector<int> prune_list = {};
if (params->prune_layers) {
prune_list = *static_cast<const std::vector<int> *>(params->prune_layers);
}
// copy the KV pairs from the input file // copy the KV pairs from the input file
gguf_set_kv (ctx_out.get(), ml.meta.get()); gguf_set_kv (ctx_out.get(), ml.meta.get());
gguf_set_val_u32(ctx_out.get(), "general.quantization_version", GGML_QNT_VERSION); // TODO: use LLM_KV gguf_set_val_u32(ctx_out.get(), "general.quantization_version", GGML_QNT_VERSION); // TODO: use LLM_KV
@ -597,12 +651,32 @@ static void llama_model_quantize_impl(const std::string & fname_inp, const std::
} }
} }
std::map<int, std::string> mapped;
int blk_id = 0;
int pruned_attention_w = 0;
// make a list of weights // make a list of weights
std::vector<const llama_model_loader::llama_tensor_weight *> tensors; std::vector<const llama_model_loader::llama_tensor_weight *> tensors;
tensors.reserve(ml.weights_map.size()); tensors.reserve(ml.weights_map.size());
for (const auto & it : ml.weights_map) { for (const auto & it : ml.weights_map) {
const std::string remapped_name(remap_layer(it.first, prune_list, mapped, blk_id));
if (remapped_name.empty()) {
if (it.first.find("attn_v.weight") != std::string::npos ||
it.first.find("attn_qkv.weight") != std::string::npos ||
it.first.find("attn_kv_b.weight") != std::string::npos) {
pruned_attention_w++;
}
LLAMA_LOG_DEBUG("%s: pruning tensor %s\n", __func__, it.first.c_str());
continue;
} else if (remapped_name != it.first) {
ggml_set_name(it.second.tensor, remapped_name.c_str());
LLAMA_LOG_DEBUG("%s: tensor %s remapped to %s\n", __func__, it.first.c_str(), ggml_get_name(it.second.tensor));
}
tensors.push_back(&it.second); tensors.push_back(&it.second);
} }
if (!prune_list.empty()) {
gguf_set_val_u32(ctx_out.get(), ml.llm_kv(LLM_KV_BLOCK_COUNT).c_str(), blk_id);
}
// keep_split requires that the weights are sorted by split index // keep_split requires that the weights are sorted by split index
if (params->keep_split) { if (params->keep_split) {
@ -640,7 +714,7 @@ static void llama_model_quantize_impl(const std::string & fname_inp, const std::
if (llama_model_has_encoder(&model)) { if (llama_model_has_encoder(&model)) {
n_attn_layer *= 3; n_attn_layer *= 3;
} }
GGML_ASSERT((qs.n_attention_wv == n_attn_layer) && "n_attention_wv is unexpected"); GGML_ASSERT((qs.n_attention_wv == n_attn_layer - pruned_attention_w) && "n_attention_wv is unexpected");
} }
size_t total_size_org = 0; size_t total_size_org = 0;
@ -681,7 +755,7 @@ static void llama_model_quantize_impl(const std::string & fname_inp, const std::
for (size_t i = 0; i < ctx_outs.size(); ++i) { for (size_t i = 0; i < ctx_outs.size(); ++i) {
gguf_set_val_u16(ctx_outs[i].get(), ml.llm_kv(LLM_KV_SPLIT_NO).c_str(), i); gguf_set_val_u16(ctx_outs[i].get(), ml.llm_kv(LLM_KV_SPLIT_NO).c_str(), i);
gguf_set_val_u16(ctx_outs[i].get(), ml.llm_kv(LLM_KV_SPLIT_COUNT).c_str(), n_split); gguf_set_val_u16(ctx_outs[i].get(), ml.llm_kv(LLM_KV_SPLIT_COUNT).c_str(), n_split);
gguf_set_val_i32(ctx_outs[i].get(), ml.llm_kv(LLM_KV_SPLIT_TENSORS_COUNT).c_str(), ml.n_tensors); gguf_set_val_i32(ctx_outs[i].get(), ml.llm_kv(LLM_KV_SPLIT_TENSORS_COUNT).c_str(), (int32_t)tensors.size());
} }
} }
@ -832,7 +906,7 @@ static void llama_model_quantize_impl(const std::string & fname_inp, const std::
const float * imatrix = nullptr; const float * imatrix = nullptr;
if (imatrix_data) { if (imatrix_data) {
auto it = imatrix_data->find(tensor->name); auto it = imatrix_data->find(remap_imatrix(tensor->name, mapped));
if (it == imatrix_data->end()) { if (it == imatrix_data->end()) {
LLAMA_LOG_INFO("\n====== %s: did not find weights for %s\n", __func__, tensor->name); LLAMA_LOG_INFO("\n====== %s: did not find weights for %s\n", __func__, tensor->name);
} else { } else {
@ -947,6 +1021,7 @@ llama_model_quantize_params llama_model_quantize_default_params() {
/*.imatrix =*/ nullptr, /*.imatrix =*/ nullptr,
/*.kv_overrides =*/ nullptr, /*.kv_overrides =*/ nullptr,
/*.tensor_type =*/ nullptr, /*.tensor_type =*/ nullptr,
/*.prune_layers =*/ nullptr
}; };
return result; return result;

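To make the renumbering performed by remap_layer() concrete, here is a small self-contained sketch (not part of the patch) that reproduces the same contiguous re-indexing for a hypothetical 6-layer model pruned with --prune-layers 2,3; pruned blocks map to an empty string and are skipped, the remaining blocks are renumbered so that block ids stay contiguous:

#include <algorithm>
#include <cstdio>
#include <map>
#include <string>
#include <vector>

int main() {
    const std::vector<int> prune = {2, 3};   // as given via --prune-layers 2,3
    std::map<int, std::string> mapped;       // original block id -> new block id ("" = dropped)
    int next_id = 0;
    for (int blk = 0; blk < 6; ++blk) {
        const bool pruned = std::find(prune.begin(), prune.end(), blk) != prune.end();
        mapped[blk] = pruned ? "" : std::to_string(next_id++);
        std::printf("blk.%d -> %s\n", blk, pruned ? "(dropped)" : mapped[blk].c_str());
    }
    // prints: blk.0 -> 0, blk.1 -> 1, blk.2 and blk.3 dropped, blk.4 -> 2, blk.5 -> 3
    // remap_imatrix() walks the same map in the opposite direction so that
    // renumbered tensors still find their imatrix entries under the original id.
    return 0;
}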
View File

@ -1269,6 +1269,7 @@ struct llama_vocab::impl {
bool add_space_prefix = false; bool add_space_prefix = false;
bool add_bos = false; bool add_bos = false;
bool add_eos = false; bool add_eos = false;
bool add_sep = false;
bool ignore_merges = false; bool ignore_merges = false;
bool clean_spaces = false; // clean_up_tokenization_spaces bool clean_spaces = false; // clean_up_tokenization_spaces
bool remove_extra_whitespaces = false; bool remove_extra_whitespaces = false;
@ -1421,6 +1422,8 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) {
special_sep_id = 102; special_sep_id = 102;
special_pad_id = 0; special_pad_id = 0;
special_mask_id = 103; special_mask_id = 103;
add_sep = true;
} else if (tokenizer_model == "gpt2") { } else if (tokenizer_model == "gpt2") {
type = LLAMA_VOCAB_TYPE_BPE; type = LLAMA_VOCAB_TYPE_BPE;
@ -1550,12 +1553,15 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) {
tokenizer_pre == "jina-es" || tokenizer_pre == "jina-es" ||
tokenizer_pre == "jina-de" || tokenizer_pre == "jina-de" ||
tokenizer_pre == "gigachat" || tokenizer_pre == "gigachat" ||
tokenizer_pre == "jina-v1-en" ||
tokenizer_pre == "jina-v2-es" || tokenizer_pre == "jina-v2-es" ||
tokenizer_pre == "jina-v2-de" || tokenizer_pre == "jina-v2-de") {
pre_type = LLAMA_VOCAB_PRE_TYPE_GPT2;
} else if (
tokenizer_pre == "jina-v1-en" ||
tokenizer_pre == "jina-v2-code" || tokenizer_pre == "jina-v2-code" ||
tokenizer_pre == "roberta-bpe") { tokenizer_pre == "roberta-bpe") {
pre_type = LLAMA_VOCAB_PRE_TYPE_GPT2; pre_type = LLAMA_VOCAB_PRE_TYPE_GPT2;
add_sep = true;
} else if ( } else if (
tokenizer_pre == "refact") { tokenizer_pre == "refact") {
pre_type = LLAMA_VOCAB_PRE_TYPE_REFACT; pre_type = LLAMA_VOCAB_PRE_TYPE_REFACT;
@ -1665,6 +1671,7 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) {
clean_spaces = true; clean_spaces = true;
add_bos = true; add_bos = true;
add_eos = false; add_eos = false;
add_sep = true;
} else if (type == LLAMA_VOCAB_TYPE_UGM) { } else if (type == LLAMA_VOCAB_TYPE_UGM) {
pre_type = LLAMA_VOCAB_PRE_TYPE_DEFAULT; pre_type = LLAMA_VOCAB_PRE_TYPE_DEFAULT;
add_bos = false; add_bos = false;
@ -1801,7 +1808,7 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) {
} }
} }
// Handle add_bos and add_eos // Handle add_bos, add_eos and add_sep
{ {
bool temp = true; bool temp = true;
@ -1811,6 +1818,9 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) {
if (ml.get_key(LLM_KV_TOKENIZER_ADD_EOS, temp, false)) { if (ml.get_key(LLM_KV_TOKENIZER_ADD_EOS, temp, false)) {
add_eos = temp; add_eos = temp;
} }
if (ml.get_key(LLM_KV_TOKENIZER_ADD_SEP, temp, false)) {
add_sep = temp;
}
} }
// auto-detect special tokens by text // auto-detect special tokens by text
@ -2060,9 +2070,9 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) {
//NOTE: Per token attributes are missing from the GGUF file. //NOTE: Per token attributes are missing from the GGUF file.
//TODO: Extract attributes from GGUF file. //TODO: Extract attributes from GGUF file.
{ {
auto _contains_any = [] (const std::string & str, const std::vector<std::string> & substrs) -> bool { auto _contains_any = [] (const std::string & str, const std::vector<std::string_view> & substrs) -> bool {
for (const auto & substr : substrs) { for (const auto & substr : substrs) {
if (str.find(substr) < std::string::npos) { if (str.find(substr) != std::string::npos) {
return true; return true;
} }
} }
@ -3000,6 +3010,10 @@ bool llama_vocab::get_add_eos() const {
return pimpl->add_eos; return pimpl->add_eos;
} }
bool llama_vocab::get_add_sep() const {
return pimpl->add_sep;
}
bool llama_vocab::get_ignore_merges() const { bool llama_vocab::get_ignore_merges() const {
return pimpl->ignore_merges; return pimpl->ignore_merges;
} }
@ -3060,6 +3074,11 @@ int32_t llama_vocab::tokenize(
bool add_special, bool add_special,
bool parse_special) const { bool parse_special) const {
auto res = tokenize(std::string(text, text_len), add_special, parse_special); auto res = tokenize(std::string(text, text_len), add_special, parse_special);
if (res.size() >= static_cast<size_t>(std::numeric_limits<int32_t>::max())) {
LLAMA_LOG_ERROR("%s: tokenization result size %zu exceeds int32_t limit\n", __func__, res.size());
return std::numeric_limits<int32_t>::min();
}
if (n_tokens_max < (int) res.size()) { if (n_tokens_max < (int) res.size()) {
// LLAMA_LOG_ERROR("%s: too many tokens\n", __func__); // LLAMA_LOG_ERROR("%s: too many tokens\n", __func__);
return -((int) res.size()); return -((int) res.size());
@ -3191,6 +3210,10 @@ bool llama_vocab_get_add_eos(const struct llama_vocab * vocab) {
return vocab->get_add_eos(); return vocab->get_add_eos();
} }
bool llama_vocab_get_add_sep(const struct llama_vocab * vocab) {
return vocab->get_add_sep();
}
llama_token llama_vocab_fim_pre(const struct llama_vocab * vocab) { llama_token llama_vocab_fim_pre(const struct llama_vocab * vocab) {
return vocab->token_fim_pre(); return vocab->token_fim_pre();
} }

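The new guard in llama_vocab::tokenize() adds a third return convention to the public llama_tokenize() call: INT32_MIN when the token count itself overflows int32_t, on top of the existing "-(required size)" signal for a too-small buffer. A hedged sketch of caller-side handling (vocab, text and text_len are assumed; the initial buffer size is illustrative):

std::vector<llama_token> tokens(256);
int32_t n = llama_tokenize(vocab, text, text_len,
                           tokens.data(), (int32_t) tokens.size(),
                           /*add_special=*/true, /*parse_special=*/true);
if (n == std::numeric_limits<int32_t>::min()) {
    // more tokens than int32_t can represent - treat as a hard error
} else if (n < 0) {
    tokens.resize(-n);   // buffer too small: -n is the required size, retry once
    n = llama_tokenize(vocab, text, text_len,
                       tokens.data(), (int32_t) tokens.size(),
                       /*add_special=*/true, /*parse_special=*/true);
}
if (n >= 0) {
    tokens.resize(n);
}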
View File

@ -74,6 +74,7 @@ struct llama_vocab {
bool get_add_space_prefix () const; bool get_add_space_prefix () const;
bool get_add_bos () const; bool get_add_bos () const;
bool get_add_eos () const; bool get_add_eos () const;
bool get_add_sep () const;
bool get_ignore_merges () const; bool get_ignore_merges () const;
bool get_clean_spaces () const; bool get_clean_spaces () const;
bool get_remove_extra_whitespaces () const; bool get_remove_extra_whitespaces () const;

View File

@ -204,12 +204,17 @@ static inline std::wstring unicode_wstring_from_utf8(const std::string & s) {
// disable C++17 deprecation warning for std::codecvt_utf8 // disable C++17 deprecation warning for std::codecvt_utf8
# pragma clang diagnostic push # pragma clang diagnostic push
# pragma clang diagnostic ignored "-Wdeprecated-declarations" # pragma clang diagnostic ignored "-Wdeprecated-declarations"
#elif defined(__GNUC__)
# pragma GCC diagnostic push
# pragma GCC diagnostic ignored "-Wdeprecated-declarations"
#endif #endif
std::wstring_convert<std::codecvt_utf8<wchar_t>> conv; std::wstring_convert<std::codecvt_utf8<wchar_t>> conv;
#if defined(__clang__) #if defined(__clang__)
# pragma clang diagnostic pop # pragma clang diagnostic pop
#elif defined(__GNUC__)
# pragma GCC diagnostic pop
#endif #endif
return conv.from_bytes(s); return conv.from_bytes(s);

View File

@ -2755,6 +2755,35 @@ struct test_conv_transpose_1d : public test_case {
} }
}; };
// GGML_OP_CONV_TRANSPOSE_2D
struct test_conv_transpose_2d : public test_case {
const std::array<int64_t, 4> ne_input;
const std::array<int64_t, 4> ne_kernel;
const int stride;
std::string vars() override {
return VARS_TO_STR3(ne_input, ne_kernel, stride);
}
test_conv_transpose_2d(std::array<int64_t, 4> ne_input = {10, 10, 3, 1}, // [input_width, input_height, input_channels, 1]
std::array<int64_t, 4> ne_kernel = {3, 3, 3, 1}, // [kernel_width, kernel_height, input_channels, 1]
int stride = 1)
: ne_input(ne_input), ne_kernel(ne_kernel), stride(stride){}
ggml_tensor * build_graph(ggml_context * ctx) override {
ggml_tensor * input = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne_input.data());
ggml_set_name(input, "input");
ggml_tensor * kernel = ggml_new_tensor(ctx, GGML_TYPE_F16, 4, ne_kernel.data());
ggml_set_name(kernel, "kernel");
ggml_tensor * out = ggml_conv_transpose_2d_p0(ctx, kernel, input, stride);
ggml_set_name(out, "out");
return out;
}
};
// GGML_OP_IM2COL // GGML_OP_IM2COL
struct test_im2col : public test_case { struct test_im2col : public test_case {
const ggml_type type_input; const ggml_type type_input;
@ -4080,6 +4109,9 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
test_cases.emplace_back(new test_conv_transpose_1d({3,2,1,1}, {3,1,2,1}, 1, 0, 1)); test_cases.emplace_back(new test_conv_transpose_1d({3,2,1,1}, {3,1,2,1}, 1, 0, 1));
test_cases.emplace_back(new test_conv_transpose_1d({2,1,1,1}, {3,1,1,1}, 1, 0, 1)); test_cases.emplace_back(new test_conv_transpose_1d({2,1,1,1}, {3,1,1,1}, 1, 0, 1));
test_cases.emplace_back(new test_conv_transpose_2d({3, 2, 3, 1}, {2, 2, 1, 3}, 1));
test_cases.emplace_back(new test_conv_transpose_2d({10, 10, 9, 1}, {3, 3, 1, 9}, 2));
test_cases.emplace_back(new test_count_equal(GGML_TYPE_F32, {4, 500, 1, 1})); test_cases.emplace_back(new test_count_equal(GGML_TYPE_F32, {4, 500, 1, 1}));
test_cases.emplace_back(new test_count_equal(GGML_TYPE_F32, {4, 5000, 1, 1})); test_cases.emplace_back(new test_count_equal(GGML_TYPE_F32, {4, 5000, 1, 1}));
@ -4649,6 +4681,10 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_perf() {
test_cases.emplace_back(new test_conv_2d_dw({512, 512, 256, 1}, {3, 3, 1, 256}, 1, 1, 1, false)); test_cases.emplace_back(new test_conv_2d_dw({512, 512, 256, 1}, {3, 3, 1, 256}, 1, 1, 1, false));
test_cases.emplace_back(new test_conv_2d_dw({512, 512, 256, 1}, {3, 3, 1, 256}, 1, 1, 1, true)); test_cases.emplace_back(new test_conv_2d_dw({512, 512, 256, 1}, {3, 3, 1, 256}, 1, 1, 1, true));
test_cases.emplace_back(new test_conv_transpose_2d({256, 256, 256, 1}, {3, 3, 16, 256}, 1));
test_cases.emplace_back(new test_mean(GGML_TYPE_F32, {256, 256, 3, 1}));
return test_cases; return test_cases;
} }

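As a sanity check on the new test cases: with zero padding (the _p0 variant) the output of a transposed 2D convolution is (in - 1) * stride + kernel along each spatial dimension. A minimal, hedged sketch building the first new eval case outside the test harness:

#include "ggml.h"

int main() {
    ggml_init_params ip = { /*mem_size =*/ 16*1024*1024, /*mem_buffer =*/ nullptr, /*no_alloc =*/ false };
    ggml_context * ctx = ggml_init(ip);

    // mirrors test_conv_transpose_2d({3, 2, 3, 1}, {2, 2, 1, 3}, 1)
    ggml_tensor * input  = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, 3, 2, 3, 1); // W, H, Cin, N
    ggml_tensor * kernel = ggml_new_tensor_4d(ctx, GGML_TYPE_F16, 2, 2, 1, 3); // KW, KH, Cout, Cin
    ggml_tensor * out    = ggml_conv_transpose_2d_p0(ctx, kernel, input, /*stride=*/1);

    // expected shape: out->ne = {4, 3, 1, 1}, i.e. (3-1)*1+2 by (2-1)*1+2 with one output channel
    ggml_free(ctx);
    return 0;
}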
View File

@ -267,6 +267,7 @@ struct cmd_params {
int delay; int delay;
bool verbose; bool verbose;
bool progress; bool progress;
bool no_warmup;
output_formats output_format; output_formats output_format;
output_formats output_format_stderr; output_formats output_format_stderr;
}; };
@ -303,6 +304,7 @@ static const cmd_params cmd_params_defaults = {
/* delay */ 0, /* delay */ 0,
/* verbose */ false, /* verbose */ false,
/* progress */ false, /* progress */ false,
/* no_warmup */ false,
/* output_format */ MARKDOWN, /* output_format */ MARKDOWN,
/* output_format_stderr */ NONE, /* output_format_stderr */ NONE,
}; };
@ -325,6 +327,7 @@ static void print_usage(int /* argc */, char ** argv) {
output_format_str(cmd_params_defaults.output_format_stderr)); output_format_str(cmd_params_defaults.output_format_stderr));
printf(" -v, --verbose verbose output\n"); printf(" -v, --verbose verbose output\n");
printf(" --progress print test progress indicators\n"); printf(" --progress print test progress indicators\n");
printf(" --no-warmup skip warmup runs before benchmarking\n");
printf("\n"); printf("\n");
printf("test parameters:\n"); printf("test parameters:\n");
printf(" -m, --model <filename> (default: %s)\n", join(cmd_params_defaults.model, ",").c_str()); printf(" -m, --model <filename> (default: %s)\n", join(cmd_params_defaults.model, ",").c_str());
@ -425,6 +428,7 @@ static cmd_params parse_cmd_params(int argc, char ** argv) {
params.prio = cmd_params_defaults.prio; params.prio = cmd_params_defaults.prio;
params.delay = cmd_params_defaults.delay; params.delay = cmd_params_defaults.delay;
params.progress = cmd_params_defaults.progress; params.progress = cmd_params_defaults.progress;
params.no_warmup = cmd_params_defaults.no_warmup;
for (int i = 1; i < argc; i++) { for (int i = 1; i < argc; i++) {
arg = argv[i]; arg = argv[i];
@ -798,6 +802,8 @@ static cmd_params parse_cmd_params(int argc, char ** argv) {
params.verbose = true; params.verbose = true;
} else if (arg == "--progress") { } else if (arg == "--progress") {
params.progress = true; params.progress = true;
} else if (arg == "--no-warmup") {
params.no_warmup = true;
} else { } else {
invalid_param = true; invalid_param = true;
break; break;
@ -1925,6 +1931,7 @@ int main(int argc, char ** argv) {
llama_attach_threadpool(ctx, threadpool, NULL); llama_attach_threadpool(ctx, threadpool, NULL);
// warmup run // warmup run
if (!params.no_warmup) {
if (t.n_prompt > 0) { if (t.n_prompt > 0) {
if (params.progress) { if (params.progress) {
fprintf(stderr, "llama-bench: benchmark %d/%zu: warmup prompt run\n", params_idx, params_count); fprintf(stderr, "llama-bench: benchmark %d/%zu: warmup prompt run\n", params_idx, params_count);
@ -1946,6 +1953,7 @@ int main(int argc, char ** argv) {
exit(1); exit(1);
} }
} }
}
for (int i = 0; i < params.reps; i++) { for (int i = 0; i < params.reps; i++) {
llama_memory_clear(llama_get_memory(ctx), false); llama_memory_clear(llama_get_memory(ctx), false);

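Usage note (hedged, not part of the patch): with the new flag a run such as llama-bench -m model.gguf -p 512 -n 128 --no-warmup skips the warmup prompt and generation passes, so the first measured repetition also includes cold-start effects instead of being primed by a hidden extra run.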
View File

@ -292,6 +292,7 @@ int main(int argc, char ** argv) {
if (!params.system_prompt.empty() || !params.prompt.empty()) { if (!params.system_prompt.empty() || !params.prompt.empty()) {
common_chat_templates_inputs inputs; common_chat_templates_inputs inputs;
inputs.use_jinja = g_params->use_jinja;
inputs.messages = chat_msgs; inputs.messages = chat_msgs;
inputs.add_generation_prompt = !params.prompt.empty(); inputs.add_generation_prompt = !params.prompt.empty();

View File

@ -2211,6 +2211,9 @@ struct clip_model_loader {
{ {
hparams.rope_theta = 10000.0f; hparams.rope_theta = 10000.0f;
hparams.warmup_image_size = hparams.patch_size * 8; hparams.warmup_image_size = hparams.patch_size * 8;
// Mistral Small 2506 needs 1024x1024 image size cap to prevent OOM
// ref: https://github.com/ggml-org/llama.cpp/issues/14310
hparams.image_size = 1024;
get_u32(KEY_SPATIAL_MERGE_SIZE, hparams.spatial_merge_size, false); get_u32(KEY_SPATIAL_MERGE_SIZE, hparams.spatial_merge_size, false);
} break; } break;
case PROJECTOR_TYPE_GEMMA3: case PROJECTOR_TYPE_GEMMA3:

View File

@ -107,13 +107,11 @@ static bool try_parse_ftype(const std::string & ftype_str_in, llama_ftype & ftyp
return false; return false;
} }
// usage:
// ./llama-quantize [--allow-requantize] [--leave-output-tensor] [--pure] models/llama/ggml-model.gguf [models/llama/ggml-model-quant.gguf] type [nthreads]
//
[[noreturn]] [[noreturn]]
static void usage(const char * executable) { static void usage(const char * executable) {
printf("usage: %s [--help] [--allow-requantize] [--leave-output-tensor] [--pure] [--imatrix] [--include-weights] [--exclude-weights] [--output-tensor-type]\n", executable); printf("usage: %s [--help] [--allow-requantize] [--leave-output-tensor] [--pure] [--imatrix] [--include-weights]\n", executable);
printf(" [--token-embedding-type] [--tensor-type] [--keep-split] [--override-kv] model-f32.gguf [model-quant.gguf] type [nthreads]\n\n"); printf(" [--exclude-weights] [--output-tensor-type] [--token-embedding-type] [--tensor-type] [--prune-layers] [--keep-split] [--override-kv]\n");
printf(" model-f32.gguf [model-quant.gguf] type [nthreads]\n\n");
printf(" --allow-requantize: Allows requantizing tensors that have already been quantized. Warning: This can severely reduce quality compared to quantizing from 16bit or 32bit\n"); printf(" --allow-requantize: Allows requantizing tensors that have already been quantized. Warning: This can severely reduce quality compared to quantizing from 16bit or 32bit\n");
printf(" --leave-output-tensor: Will leave output.weight un(re)quantized. Increases model size but may also increase quality, especially when requantizing\n"); printf(" --leave-output-tensor: Will leave output.weight un(re)quantized. Increases model size but may also increase quality, especially when requantizing\n");
printf(" --pure: Disable k-quant mixtures and quantize all tensors to the same type\n"); printf(" --pure: Disable k-quant mixtures and quantize all tensors to the same type\n");
@ -124,6 +122,8 @@ static void usage(const char * executable) {
printf(" --token-embedding-type ggml_type: use this ggml_type for the token embeddings tensor\n"); printf(" --token-embedding-type ggml_type: use this ggml_type for the token embeddings tensor\n");
printf(" --tensor-type TENSOR=TYPE: quantize this tensor to this ggml_type. example: --tensor-type attn_q=q8_0\n"); printf(" --tensor-type TENSOR=TYPE: quantize this tensor to this ggml_type. example: --tensor-type attn_q=q8_0\n");
printf(" Advanced option to selectively quantize tensors. May be specified multiple times.\n"); printf(" Advanced option to selectively quantize tensors. May be specified multiple times.\n");
printf(" --prune-layers L0,L1,L2...comma-separated list of layer numbers to prune from the model\n");
printf(" Advanced option to remove all tensors from the given layers\n");
printf(" --keep-split: will generate quantized model in the same shards as input\n"); printf(" --keep-split: will generate quantized model in the same shards as input\n");
printf(" --override-kv KEY=TYPE:VALUE\n"); printf(" --override-kv KEY=TYPE:VALUE\n");
printf(" Advanced option to override model metadata by key in the quantized model. May be specified multiple times.\n"); printf(" Advanced option to override model metadata by key in the quantized model. May be specified multiple times.\n");
@ -286,6 +286,32 @@ static bool parse_tensor_type(const char * data, std::vector<tensor_quantization
return true; return true;
} }
static bool parse_layer_prune(const char * data, std::vector<int> & prune_layers) {
if (!data) {
printf("\n%s: no layer pruning ids provided\n\n", __func__);
return false;
}
const auto block_ids = string_split<std::string>(data, ',');
for (const auto & block_id : block_ids) {
int id;
try {
id = std::stoi(block_id);
} catch (...) {
id = -1;
}
if (id < 0) {
printf("\n%s: invalid layer id '%s'\n\n", __func__, block_id.c_str());
return false;
}
prune_layers.emplace_back(id);
}
sort(prune_layers.begin(), prune_layers.end());
prune_layers.erase(std::unique(prune_layers.begin(), prune_layers.end()), prune_layers.end());
return true;
}
int main(int argc, char ** argv) { int main(int argc, char ** argv) {
if (argc < 3) { if (argc < 3) {
usage(argv[0]); usage(argv[0]);
@ -298,6 +324,7 @@ int main(int argc, char ** argv) {
std::vector<std::string> included_weights, excluded_weights; std::vector<std::string> included_weights, excluded_weights;
std::vector<llama_model_kv_override> kv_overrides; std::vector<llama_model_kv_override> kv_overrides;
std::vector<tensor_quantization> tensor_types; std::vector<tensor_quantization> tensor_types;
std::vector<int> prune_layers;
for (; arg_idx < argc && strncmp(argv[arg_idx], "--", 2) == 0; arg_idx++) { for (; arg_idx < argc && strncmp(argv[arg_idx], "--", 2) == 0; arg_idx++) {
if (strcmp(argv[arg_idx], "--leave-output-tensor") == 0) { if (strcmp(argv[arg_idx], "--leave-output-tensor") == 0) {
@ -324,6 +351,10 @@ int main(int argc, char ** argv) {
if (arg_idx == argc-1 || !parse_tensor_type(argv[++arg_idx], tensor_types)) { if (arg_idx == argc-1 || !parse_tensor_type(argv[++arg_idx], tensor_types)) {
usage(argv[0]); usage(argv[0]);
} }
} else if (strcmp(argv[arg_idx], "--prune-layers") == 0) {
if (arg_idx == argc-1 || !parse_layer_prune(argv[++arg_idx], prune_layers)) {
usage(argv[0]);
}
} else if (strcmp(argv[arg_idx], "--override-kv") == 0) { } else if (strcmp(argv[arg_idx], "--override-kv") == 0) {
if (arg_idx == argc-1 || !string_parse_kv_override(argv[++arg_idx], kv_overrides)) { if (arg_idx == argc-1 || !string_parse_kv_override(argv[++arg_idx], kv_overrides)) {
usage(argv[0]); usage(argv[0]);
@ -411,6 +442,9 @@ int main(int argc, char ** argv) {
if (!tensor_types.empty()) { if (!tensor_types.empty()) {
params.tensor_types = &tensor_types; params.tensor_types = &tensor_types;
} }
if (!prune_layers.empty()) {
params.prune_layers = &prune_layers;
}
llama_backend_init(); llama_backend_init();

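Usage note (illustrative file names, not from the patch): the new option composes with the existing positional arguments, e.g. llama-quantize --prune-layers 20,21,22 model-f32.gguf model-quant.gguf q4_k_m 8 drops every tensor of blocks 20-22, renumbers the surviving blocks, and writes the reduced block count back into the GGUF metadata as shown in the llama-quant.cpp hunk earlier.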
View File

@ -9,6 +9,9 @@
#include <nlohmann/json.hpp> #include <nlohmann/json.hpp>
#if defined(_WIN32) #if defined(_WIN32)
# ifndef NOMINMAX
# define NOMINMAX
# endif
# include <windows.h> # include <windows.h>
# include <io.h> # include <io.h>
#else #else
@ -939,17 +942,30 @@ static int apply_chat_template(const struct common_chat_templates * tmpls, Llama
// Function to tokenize the prompt // Function to tokenize the prompt
static int tokenize_prompt(const llama_vocab * vocab, const std::string & prompt, static int tokenize_prompt(const llama_vocab * vocab, const std::string & prompt,
std::vector<llama_token> & prompt_tokens, const LlamaData & llama_data) { std::vector<llama_token> & prompt_tokens, const LlamaData & llama_data) {
const bool is_first = llama_memory_seq_pos_max(llama_get_memory(llama_data.context.get()), 0) == 0; const bool is_first = llama_memory_seq_pos_max(llama_get_memory(llama_data.context.get()), 0) == -1;
int n_tokens = prompt.size() + 2 * is_first;
const int n_prompt_tokens = -llama_tokenize(vocab, prompt.c_str(), prompt.size(), NULL, 0, is_first, true); prompt_tokens.resize(n_tokens);
prompt_tokens.resize(n_prompt_tokens); n_tokens = llama_tokenize(vocab, prompt.c_str(), prompt.size(),
if (llama_tokenize(vocab, prompt.c_str(), prompt.size(), prompt_tokens.data(), prompt_tokens.size(), is_first, prompt_tokens.data(), prompt_tokens.size(),
true) < 0) { is_first, /*parse_special =*/true);
printe("failed to tokenize the prompt\n"); if (n_tokens == std::numeric_limits<int32_t>::min()) {
printe("tokenization failed: input too large\n");
return -1; return -1;
} }
if (n_tokens < 0) {
return n_prompt_tokens; prompt_tokens.resize(-n_tokens);
int check = llama_tokenize(vocab, prompt.c_str(), prompt.size(),
prompt_tokens.data(), prompt_tokens.size(),
is_first, /*parse_special =*/true);
if (check != -n_tokens) {
printe("failed to tokenize the prompt (size mismatch)\n");
return -1;
}
n_tokens = check;
} else {
prompt_tokens.resize(n_tokens);
}
return n_tokens;
} }
// Check if we have enough space in the context to evaluate this batch // Check if we have enough space in the context to evaluate this batch

View File

@ -187,6 +187,8 @@ The project is under active development, and we are [looking for feedback and co
| `-devd, --device-draft <dev1,dev2,..>` | comma-separated list of devices to use for offloading the draft model (none = don't offload)<br/>use --list-devices to see a list of available devices | | `-devd, --device-draft <dev1,dev2,..>` | comma-separated list of devices to use for offloading the draft model (none = don't offload)<br/>use --list-devices to see a list of available devices |
| `-ngld, --gpu-layers-draft, --n-gpu-layers-draft N` | number of layers to store in VRAM for the draft model<br/>(env: LLAMA_ARG_N_GPU_LAYERS_DRAFT) | | `-ngld, --gpu-layers-draft, --n-gpu-layers-draft N` | number of layers to store in VRAM for the draft model<br/>(env: LLAMA_ARG_N_GPU_LAYERS_DRAFT) |
| `-md, --model-draft FNAME` | draft model for speculative decoding (default: unused)<br/>(env: LLAMA_ARG_MODEL_DRAFT) | | `-md, --model-draft FNAME` | draft model for speculative decoding (default: unused)<br/>(env: LLAMA_ARG_MODEL_DRAFT) |
| `-ctkd, --cache-type-k-draft TYPE` | KV cache data type for K for speculative decoding model<br/>allowed values: f32, f16, bf16, q8_0, q4_0, q4_1, iq4_nl, q5_0, q5_1<br/>(default: f16)<br/>(env: LLAMA_ARG_CACHE_TYPE_K_DRAFT) |
| `-ctvd, --cache-type-v-draft TYPE` | KV cache data type for V for speculative decoding model<br/>allowed values: f32, f16, bf16, q8_0, q4_0, q4_1, iq4_nl, q5_0, q5_1<br/>(default: f16)<br/>(env: LLAMA_ARG_CACHE_TYPE_V_DRAFT) |
| `-mv, --model-vocoder FNAME` | vocoder model for audio generation (default: unused) | | `-mv, --model-vocoder FNAME` | vocoder model for audio generation (default: unused) |
| `--tts-use-guide-tokens` | Use guide tokens to improve TTS word recall | | `--tts-use-guide-tokens` | Use guide tokens to improve TTS word recall |
| `--embd-bge-small-en-default` | use default bge-small-en-v1.5 model (note: can download weights from the internet) | | `--embd-bge-small-en-default` | use default bge-small-en-v1.5 model (note: can download weights from the internet) |

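Usage note (hedged example): the draft-model KV cache types can now be chosen independently, e.g. llama-server -m target.gguf -md draft.gguf -ctkd q8_0 -ctvd q8_0, instead of the F16 draft cache that was previously hard-coded and is removed in the server.cpp hunk below.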
View File

@ -1969,10 +1969,8 @@ struct server_context {
params_dft.n_ctx = params_base.speculative.n_ctx == 0 ? params_base.n_ctx / params_base.n_parallel : params_base.speculative.n_ctx; params_dft.n_ctx = params_base.speculative.n_ctx == 0 ? params_base.n_ctx / params_base.n_parallel : params_base.speculative.n_ctx;
params_dft.n_gpu_layers = params_base.speculative.n_gpu_layers; params_dft.n_gpu_layers = params_base.speculative.n_gpu_layers;
params_dft.n_parallel = 1; params_dft.n_parallel = 1;
params_dft.cache_type_k = params_base.speculative.cache_type_k;
// force F16 KV cache for the draft model for extra performance params_dft.cache_type_v = params_base.speculative.cache_type_v;
params_dft.cache_type_k = GGML_TYPE_F16;
params_dft.cache_type_v = GGML_TYPE_F16;
llama_init_dft = common_init_from_params(params_dft); llama_init_dft = common_init_from_params(params_dft);
@ -3387,38 +3385,6 @@ struct server_context {
llama_set_embeddings(ctx, slot_batched->need_embd()); llama_set_embeddings(ctx, slot_batched->need_embd());
} }
// pad the batch so that batch.n_tokens >= n_slots
// TODO: temporary workaround for https://github.com/ggml-org/llama.cpp/issues/13689
if (slot_batched->need_embd()) {
const int n_slots = slots.size();
if (batch.n_tokens < n_slots) {
std::set<llama_seq_id> seq_ids;
for (int j = 0; j < batch.n_tokens; ++j) {
seq_ids.insert(batch.seq_id[j][0]);
}
// find unused sequence id
llama_seq_id seq_id = -1;
for (int i = 0; i < n_slots; ++i) {
if (seq_ids.find(i) == seq_ids.end()) {
seq_id = i;
}
}
const int n_add = n_slots - batch.n_tokens;
SRV_WRN("adding %d dummy tokens to the batch, seq_id = %d\n", n_add, seq_id);
for (int j = 0; j < n_add; ++j) {
common_batch_add(batch, 0, j, { seq_id }, true);
}
slots[seq_id].cache_tokens.clear();
llama_memory_seq_rm(llama_get_memory(ctx), seq_id, -1, -1);
}
}
int32_t i_next = 0; int32_t i_next = 0;
// process the created batch of tokens // process the created batch of tokens
@ -3452,9 +3418,12 @@ struct server_context {
} }
if (ret < -1) { if (ret < -1) {
// TODO: update slot state based on llama_memory_seq_pos_min() and llama_memory_seq_pos_max()
err = "Compute error."; err = "Compute error.";
} }
// TODO: handle ret == 2 (abort) when we start aborting
if (!err.empty()) { if (!err.empty()) {
SRV_ERR("%s, i = %d, n_batch = %d, ret = %d\n", err.c_str(), i, n_batch, ret); SRV_ERR("%s, i = %d, n_batch = %d, ret = %d\n", err.c_str(), i, n_batch, ret);
for (auto & slot : slots) { for (auto & slot : slots) {

View File

@ -271,12 +271,20 @@ static llama_tokens format_rerank(const struct llama_vocab * vocab, const llama_
} }
result.reserve(doc.size() + query.size() + 4); result.reserve(doc.size() + query.size() + 4);
if (llama_vocab_get_add_bos(vocab)) {
result.push_back(llama_vocab_bos(vocab)); result.push_back(llama_vocab_bos(vocab));
}
result.insert(result.end(), query.begin(), query.end()); result.insert(result.end(), query.begin(), query.end());
if (llama_vocab_get_add_eos(vocab)) {
result.push_back(eos_token); result.push_back(eos_token);
}
if (llama_vocab_get_add_sep(vocab)) {
result.push_back(llama_vocab_sep(vocab)); result.push_back(llama_vocab_sep(vocab));
}
result.insert(result.end(), doc.begin(), doc.end()); result.insert(result.end(), doc.begin(), doc.end());
if (llama_vocab_get_add_eos(vocab)) {
result.push_back(eos_token); result.push_back(eos_token);
}
return result; return result;
} }
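To make the effect of the new guards explicit, a hedged illustration of the token sequence format_rerank() produces for a BERT-style reranker vocabulary where add_bos, add_eos and add_sep are all set:

// [BOS] q_0 ... q_n [EOS] [SEP] d_0 ... d_m [EOS]
//
// For vocabularies that do not set one of these flags, the corresponding
// special token is simply omitted, which is why each push_back above is
// guarded individually rather than emitted unconditionally.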