diff --git a/src/llama-ext.h b/src/llama-ext.h index 13ced783b4..a29dcf1536 100644 --- a/src/llama-ext.h +++ b/src/llama-ext.h @@ -1,8 +1,14 @@ #pragma once -#include "llama-context.h" -#include "ggml.h" -#include "stdint.h" +#include "llama.h" + +// TODO: try to remove these headers +#include "llama-arch.h" +#include "llama-model.h" +#include "llama-quant.h" + +#include <cstdint> +#include <vector> // Reserve a new compute graph. It is valid until the next call to llama_graph_reserve. LLAMA_API struct ggml_cgraph * llama_graph_reserve( @@ -10,3 +16,29 @@ LLAMA_API struct ggml_cgraph * llama_graph_reserve( uint32_t n_tokens, uint32_t n_seqs, uint32_t n_outputs); + +LLAMA_API ggml_type llama_ftype_get_default_type(llama_ftype ftype); + +// TODO: use llama_quant_ prefix to name these consistently: + +// Returns true if this tensor should be quantized (based on name, dims, params). +LLAMA_API bool tensor_allows_quantization(const llama_model_quantize_params * params, llm_arch arch, const ggml_tensor * tensor); + +// TODO: add: +// LLAMA_API llama_quant * llama_quant_init(...); +// LLAMA_API void llama_quant_free(llama_quant * qnt); + +// TODO: become member function of llama_quant +LLAMA_API ggml_type llama_tensor_get_type( + llama_quant & qs, + const llama_model_quantize_params * params, + const ggml_tensor * tensor, + ggml_type default_type, + const tensor_metadata & tm); + +// Initialize llama_quant counters and populate tensor_metadata categories. +// metadata: vector with name fields already set, will have category field populated. 
+// TODO: become member function of llama_quant +LLAMA_API void init_quantize_state_counters( + llama_quant & qs, + std::vector<tensor_metadata> & metadata); diff --git a/src/llama-quant.cpp b/src/llama-quant.cpp index a103b15a4c..d1a9d6d15c 100644 --- a/src/llama-quant.cpp +++ b/src/llama-quant.cpp @@ -2,6 +2,7 @@ #include "llama-impl.h" #include "llama-model.h" #include "llama-model-loader.h" +#include "llama-ext.h" #include <algorithm> #include <cmath> @@ -138,7 +139,7 @@ struct compiled_tensor_type_patterns { std::vector<std::pair<std::regex, ggml_type>> patterns; }; -quantize_state_impl::quantize_state_impl(const llama_model & model, const llama_model_quantize_params * params) +llama_quant::llama_quant(const llama_model & model, const llama_model_quantize_params * params) : model(model), params(params) { if (params->tensor_types) { @@ -152,7 +153,7 @@ quantize_state_impl::quantize_state_impl(const llama_model & model, const llama_ } } -quantize_state_impl::~quantize_state_impl() = default; +llama_quant::~llama_quant() = default; // // dequantization @@ -302,7 +303,7 @@ bool tensor_allows_quantization(const llama_model_quantize_params * params, llm_ // // incompatible tensor shapes are handled here - fallback to a compatible type -static ggml_type tensor_type_fallback(quantize_state_impl & qs, const ggml_tensor * t, const ggml_type target_type) { +static ggml_type tensor_type_fallback(llama_quant & qs, const ggml_tensor * t, const ggml_type target_type) { ggml_type return_type = target_type; const int64_t ncols = t->ne[0]; @@ -351,7 +352,7 @@ static ggml_type tensor_type_fallback(quantize_state_impl & qs, const ggml_tenso } // internal standard logic for selecting the target tensor type based on tensor category, ftype, and model arch -static ggml_type llama_tensor_get_type_impl(quantize_state_impl & qs, ggml_type new_type, const ggml_tensor * tensor, llama_ftype ftype, 
tensor_category category) { const std::string name = ggml_get_name(tensor); // TODO: avoid hardcoded tensor names - use the TN_* constants @@ -601,7 +602,7 @@ } // outer wrapper: determine the ggml_type that this tensor should be quantized to -ggml_type llama_tensor_get_type(quantize_state_impl & qs, const llama_model_quantize_params * params, const ggml_tensor * tensor, ggml_type default_type, const tensor_metadata & tm) { +ggml_type llama_tensor_get_type(llama_quant & qs, const llama_model_quantize_params * params, const ggml_tensor * tensor, ggml_type default_type, const tensor_metadata & tm) { if (!tensor_allows_quantization(params, qs.model.arch, tensor)) { return tensor->type; } @@ -776,7 +777,7 @@ ggml_type llama_ftype_get_default_type(llama_ftype ftype) { } -void init_quantize_state_counters(quantize_state_impl & qs, std::vector<tensor_metadata> & metadata) { +void init_quantize_state_counters(llama_quant & qs, std::vector<tensor_metadata> & metadata) { for (auto & tm : metadata) { tensor_category cat = tensor_get_category(tm.name); tm.category = cat; @@ -835,7 +836,7 @@ static void llama_model_quantize_impl(const std::string & fname_inp, const std:: model.load_hparams(ml); model.load_stats (ml); - quantize_state_impl qs(model, params); + llama_quant qs(model, params); if (params->only_copy) { ftype = ml.ftype; diff --git a/src/llama-quant.h b/src/llama-quant.h index b32ce2d92d..4a7800f98d 100644 --- a/src/llama-quant.h +++ b/src/llama-quant.h @@ -2,16 +2,13 @@ #include "llama.h" -#include "ggml.h" - -#include "llama-arch.h" - #include <memory> #include <string> -#include <vector> struct llama_model; +// TODO: use llama_quant_ prefix to name these consistently: + // tensor categorization - used to avoid repeated string matching in quantization logic. // this is different from LLM_TN - we want broad categories, not specific tensor names per arch. 
enum class tensor_category { @@ -30,6 +27,7 @@ enum class tensor_category { }; // per-tensor metadata, computed in the preliminary loop and used in the main loop +// TODO: probably should belong to llama_quant struct tensor_metadata { std::string name; ggml_type target_type; @@ -48,7 +46,7 @@ struct tensor_type_option { struct compiled_tensor_type_patterns; -struct quantize_state_impl { +struct llama_quant { const llama_model & model; const llama_model_quantize_params * params; @@ -72,16 +70,6 @@ struct quantize_state_impl { // tensor type override patterns (compiled once, used in llama_tensor_get_type) std::unique_ptr<compiled_tensor_type_patterns> tensor_type_patterns; - quantize_state_impl(const llama_model & model, const llama_model_quantize_params * params); - ~quantize_state_impl(); + llama_quant(const llama_model & model, const llama_model_quantize_params * params); + ~llama_quant(); }; - -ggml_type llama_tensor_get_type(quantize_state_impl & qs, const llama_model_quantize_params * params, const ggml_tensor * tensor, ggml_type default_type, const tensor_metadata & tm); -ggml_type llama_ftype_get_default_type(llama_ftype ftype); - -// Initialize quantize_state_impl counters and populate tensor_metadata categories. -// metadata: vector with name fields already set, will have category field populated. -void init_quantize_state_counters(quantize_state_impl & qs, std::vector<tensor_metadata> & metadata); - -// Returns true if this tensor should be quantized (based on name, dims, params). 
-bool tensor_allows_quantization(const llama_model_quantize_params * params, llm_arch arch, const ggml_tensor * tensor); diff --git a/tests/test-quant-type-selection.cpp b/tests/test-quant-type-selection.cpp index 7dac3b6659..7fbc6ce611 100644 --- a/tests/test-quant-type-selection.cpp +++ b/tests/test-quant-type-selection.cpp @@ -1,11 +1,9 @@ -#include "../src/llama-arch.h" -#include "../src/llama-model.h" -#include "../src/llama-quant.h" -#include "ggml-cpp.h" -#include "ggml.h" -#include "gguf-model-data.h" #include "llama.h" +#include "../src/llama-ext.h" + +#include "gguf-model-data.h" + #include <string> #include <utility> #include <vector> @@ -323,13 +321,15 @@ static std::string read_file_contents(const std::string & path) { // --------------------------------------------------------------------------- // Returns {tensor_name, assigned_type} for each tensor, in order. +// TODO: should likely be moved as a member function of llama_quant and exposed through the `llama-ext.h` interface static std::vector<std::pair<std::string, ggml_type>> compute_quant_types(llama_model & mdl, const std::vector<ggml_tensor *> & tensors, llama_ftype ftype) { llama_model_quantize_params qparams = llama_model_quantize_default_params(); qparams.ftype = ftype; - quantize_state_impl qs(mdl, &qparams); + // TODO: call llama_quant_init(...) + llama_quant qs(mdl, &qparams); std::vector<tensor_metadata> metadata(tensors.size()); for (size_t i = 0; i < tensors.size(); ++i) {