This commit is contained in:
Rohanjames1997 2025-12-17 05:51:06 +02:00 committed by GitHub
commit 673d283baf
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
1 changed file with 13 additions and 3 deletions

View File

@ -79,7 +79,7 @@ static float dot_product(const float * a1, const float * a2, size_t test_size) {
}
// Total dot product error
static float dot_product_error(const ggml_type_traits * qfns, const ggml_type_traits_cpu * qfns_cpu, size_t test_size, const float * test_data1, const float * test_data2) {
static float dot_product_error(const ggml_type_traits * qfns, const ggml_type_traits_cpu * qfns_cpu, size_t test_size, const float * test_data1, const float * test_data2, const int nrc) {
GGML_UNUSED(qfns);
std::vector<uint8_t> tmp_q1(2*test_size);
@ -91,7 +91,7 @@ static float dot_product_error(const ggml_type_traits * qfns, const ggml_type_tr
vdot->from_float(test_data2, tmp_q2.data(), test_size);
float result = INFINITY;
qfns_cpu->vec_dot(test_size, &result, 0, tmp_q1.data(), 0, tmp_q2.data(), 0, 1);
qfns_cpu->vec_dot(test_size, &result, 0, tmp_q1.data(), 0, tmp_q2.data(), 0, nrc);
const float dot_ref = dot_product(test_data1, test_data2, test_size);
@ -163,7 +163,7 @@ int main(int argc, char * argv[]) {
printf("%5s reference implementation error: %s (%f)\n", ggml_type_name(type), RESULT_STR[failed], reference_error);
}
const float vec_dot_error = dot_product_error(qfns, qfns_cpu, test_size, test_data.data(), test_data2.data());
const float vec_dot_error = dot_product_error(qfns, qfns_cpu, test_size, test_data.data(), test_data2.data(), 1);
const float max_allowed_error = type == GGML_TYPE_Q2_K || type == GGML_TYPE_IQ2_XS || type == GGML_TYPE_IQ2_XXS ||
type == GGML_TYPE_IQ3_XXS || type == GGML_TYPE_IQ3_S || type == GGML_TYPE_IQ2_S
? MAX_DOT_PRODUCT_ERROR_LOWBIT
@ -175,6 +175,16 @@ int main(int argc, char * argv[]) {
if (failed || verbose) {
printf("%5s dot product error: %s (%f)\n", ggml_type_name(type), RESULT_STR[failed], vec_dot_error);
}
// Test nrc=2 path for types that support it
if (qfns_cpu->nrows == 2) {
const float vec_dot_error_nrc2 = dot_product_error(qfns, qfns_cpu, test_size, test_data.data(), test_data2.data(), 2);
failed = !(vec_dot_error_nrc2 < max_allowed_error);
num_failed += failed;
if (failed || verbose) {
printf("%5s dot product error (nrc=2): %s (%f)\n", ggml_type_name(type), RESULT_STR[failed], vec_dot_error_nrc2);
}
}
}
}