Add tensor type and depth heuristics
This commit is contained in:
parent
b7911f1431
commit
a6853ea2ae
|
|
@ -16,6 +16,7 @@
|
||||||
#include <thread>
|
#include <thread>
|
||||||
#include <unordered_map>
|
#include <unordered_map>
|
||||||
#include <optional>
|
#include <optional>
|
||||||
|
#include <unordered_set>
|
||||||
|
|
||||||
// Quantization types. Changes to this struct must be replicated in quantize.cpp
|
// Quantization types. Changes to this struct must be replicated in quantize.cpp
|
||||||
struct tensor_quantization {
|
struct tensor_quantization {
|
||||||
|
|
@ -685,13 +686,6 @@ static std::unordered_map<std::string, ggml_type> target_bpw_type(
|
||||||
GGML_TYPE_F16
|
GGML_TYPE_F16
|
||||||
};
|
};
|
||||||
|
|
||||||
const char * important_tensors[] = {
|
|
||||||
".output.weight",
|
|
||||||
".attn_output.weight",
|
|
||||||
".ffn_down.weight",
|
|
||||||
".ffn_down_shexp.weight"
|
|
||||||
};
|
|
||||||
|
|
||||||
constexpr double epsilon = 1e-12;
|
constexpr double epsilon = 1e-12;
|
||||||
constexpr double infinity = std::numeric_limits<double>::infinity();
|
constexpr double infinity = std::numeric_limits<double>::infinity();
|
||||||
constexpr uint32_t file_magic = 0x42505731; // BPW1
|
constexpr uint32_t file_magic = 0x42505731; // BPW1
|
||||||
|
|
@ -1544,11 +1538,89 @@ static std::unordered_map<std::string, ggml_type> target_bpw_type(
|
||||||
return emit_overrides();
|
return emit_overrides();
|
||||||
}
|
}
|
||||||
|
|
||||||
auto is_important = [&](const std::string & tensor_name) -> bool {
|
auto tensor_importance = [&](const std::vector<tensor_info> & all_vec) -> std::unordered_map<std::string, float> {
|
||||||
return std::any_of(std::begin(important_tensors), std::end(important_tensors), [&](const char* imp) {
|
std::unordered_map<std::string, float> scores;
|
||||||
return tensor_name.find(imp) != std::string::npos;
|
for (const auto & ti : all_vec) {
|
||||||
|
const std::string name = ggml_get_name(ti.w->tensor);
|
||||||
|
float total_score = 0.0f;
|
||||||
|
float depth_score = 0.0f;
|
||||||
|
float type_score = 0.0f;
|
||||||
|
|
||||||
|
// Depth component: output, embeddings & early/late layers are important
|
||||||
|
if (name.find("output.weight") != std::string::npos ||
|
||||||
|
name.find("token_embd.weight") != std::string::npos) {
|
||||||
|
depth_score = 1.0f;
|
||||||
}
|
}
|
||||||
);
|
else if (name.find(".attn_output.weight") != std::string::npos) {
|
||||||
|
depth_score = 0.9f;
|
||||||
|
} else {
|
||||||
|
static const std::regex layer_pattern(R"(blk\.(\d+)\.)");
|
||||||
|
std::smatch match;
|
||||||
|
if (std::regex_search(name, match, layer_pattern)) {
|
||||||
|
const int layer = std::stoi(match[1]);
|
||||||
|
const float normalized_layer = (float)layer / (float)std::max(1, (int)model.hparams.n_layer - 1);
|
||||||
|
const float center_dist = std::abs(normalized_layer - 0.5f) * 2.0f;
|
||||||
|
depth_score = 0.2f + 0.6f * center_dist;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Type component: certain tensor types are more important
|
||||||
|
if (name.find("output.weight") != std::string::npos) {
|
||||||
|
type_score = 1.0f;
|
||||||
|
} else if (name.find(".attn_output.weight") != std::string::npos) {
|
||||||
|
type_score = 0.9f;
|
||||||
|
} else if (name.find(".ffn_down.weight") != std::string::npos ||
|
||||||
|
name.find(".ffn_down_shexp.weight") != std::string::npos ||
|
||||||
|
name.find(".ffn_down_exps.weight") != std::string::npos) {
|
||||||
|
type_score = 0.8f;
|
||||||
|
} else if (name.find(".attn_q.weight") != std::string::npos ||
|
||||||
|
name.find(".attn_k.weight") != std::string::npos ||
|
||||||
|
name.find(".attn_v.weight") != std::string::npos ||
|
||||||
|
name.find(".attn_qkv.weight") != std::string::npos) {
|
||||||
|
type_score = 0.7f;
|
||||||
|
} else if (name.find(".ffn_up.weight") != std::string::npos ||
|
||||||
|
name.find(".ffn_gate.weight") != std::string::npos ||
|
||||||
|
name.find(".ffn_up_shexp.weight") != std::string::npos ||
|
||||||
|
name.find(".ffn_gate_shexp.weight") != std::string::npos ||
|
||||||
|
name.find(".ffn_up_exps.weight") != std::string::npos ||
|
||||||
|
name.find(".ffn_gate_exps.weight") != std::string::npos) {
|
||||||
|
type_score = 0.6f;
|
||||||
|
} else if (name.find("token_embd.weight") != std::string::npos) {
|
||||||
|
type_score = 0.5f;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Weighted combination
|
||||||
|
total_score = 0.80f * type_score + 0.20f * depth_score; // 80% type + 20% depth
|
||||||
|
scores[name] = total_score;
|
||||||
|
}
|
||||||
|
|
||||||
|
return scores;
|
||||||
|
};
|
||||||
|
|
||||||
|
auto select_tensors = [&](const std::vector<tensor_info> & all_vec) -> std::unordered_set<std::string> {
|
||||||
|
const auto scores = tensor_importance(all_vec);
|
||||||
|
|
||||||
|
// Sort by score
|
||||||
|
std::vector<std::pair<std::string, float>> sorted_scores(scores.begin(), scores.end());
|
||||||
|
std::sort(sorted_scores.begin(), sorted_scores.end(), [](const auto & a, const auto & b) { return a.second > b.second; });
|
||||||
|
|
||||||
|
// Select top percentile
|
||||||
|
const size_t n_important = std::max<size_t>(1, std::llround((double)sorted_scores.size() * 0.25f)); // top 25%
|
||||||
|
|
||||||
|
std::unordered_set<std::string> important;
|
||||||
|
for (size_t i = 0; i < std::min(n_important, sorted_scores.size()); ++i) {
|
||||||
|
important.insert(sorted_scores[i].first);
|
||||||
|
//LLAMA_LOG_DEBUG("\t%s: important tensor %s (score %.4f)\n", func, sorted_scores[i].first.c_str(), sorted_scores[i].second);
|
||||||
|
}
|
||||||
|
|
||||||
|
LLAMA_LOG_INFO("%s: prioritizing %zu out off %zu tensors\n", func, important.size(), sorted_scores.size());
|
||||||
|
return important;
|
||||||
|
};
|
||||||
|
|
||||||
|
const auto important_set = select_tensors(all);
|
||||||
|
|
||||||
|
auto is_important = [&](const std::string & tensor_name) -> bool {
|
||||||
|
return important_set.count(tensor_name) > 0;
|
||||||
};
|
};
|
||||||
|
|
||||||
// Lagrangian relaxation to minimise error subject to a bpw target constraint
|
// Lagrangian relaxation to minimise error subject to a bpw target constraint
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue