diff --git a/ggml/src/ggml-cpu/arch/x86/cpu-feats.cpp b/ggml/src/ggml-cpu/arch/x86/cpu-feats.cpp index d775a03638..5ee27dbf7d 100644 --- a/ggml/src/ggml-cpu/arch/x86/cpu-feats.cpp +++ b/ggml/src/ggml-cpu/arch/x86/cpu-feats.cpp @@ -260,18 +260,52 @@ void test_x86_is() { } #endif -static int ggml_backend_cpu_x86_score() { - // FIXME: this does not check for OS support +struct os_support_x86 { + // Check if the OS supports the extended register states by testing bits of xcr0 (a separate check for XMM is superfluous) + bool YMM() { return (xcr0 & 0x06) == 0x06; } // FMA, F16C, AVX, AVX2 + bool ZMM() { return (xcr0 & 0xe6) == 0xe6; } // AVX512 + bool TMM() { return (xcr0 & 0x60000) == 0x60000; } // AMX +#ifdef _MSC_VER + static uint64_t xgetbv(uint32_t xcr) { + return _xgetbv(xcr); + } +#else + static uint64_t xgetbv(uint32_t xcr) { + uint32_t eax, edx; + __asm__ __volatile__("xgetbv" : "=a"(eax), "=d"(edx) : "c"(xcr)); + return ((uint64_t)edx << 32) | eax; + } +#endif + // osxsave flag is needed to check availability of the XGETBV instruction; without it xcr0 is left at 0, so YMM(), ZMM() and TMM() all return false + os_support_x86(bool has_osxsave) { + xcr0 = has_osxsave ? 
xgetbv(0) : 0; + } + + uint64_t xcr0; +}; + +#if 0 +void test_os_support_x86() { + cpuid_x86 is; + os_support_x86 os(is.OSXSAVE()); + printf("YMM support: %d", os.YMM()); + printf("ZMM support: %d", os.ZMM()); + printf("TMM support: %d", os.TMM()); +} +#endif + +static int ggml_backend_cpu_x86_score() { int score = 1; cpuid_x86 is; + os_support_x86 os(is.OSXSAVE()); #ifdef GGML_FMA - if (!is.FMA()) { return 0; } + if (!is.FMA() || !os.YMM()) { return 0; } score += 1; #endif #ifdef GGML_F16C - if (!is.F16C()) { return 0; } + if (!is.F16C() || !os.YMM()) { return 0; } score += 1<<1; #endif #ifdef GGML_SSE42 @@ -283,18 +317,19 @@ static int ggml_backend_cpu_x86_score() { score += 1<<3; #endif #ifdef GGML_AVX - if (!is.AVX()) { return 0; } + if (!is.AVX() || !os.YMM()) { return 0; } score += 1<<4; #endif #ifdef GGML_AVX2 - if (!is.AVX2()) { return 0; } + if (!is.AVX2() || !os.YMM()) { return 0; } score += 1<<5; #endif #ifdef GGML_AVX_VNNI - if (!is.AVX_VNNI()) { return 0; } + if (!is.AVX_VNNI() || !os.YMM()) { return 0; } score += 1<<6; #endif #ifdef GGML_AVX512 + if (!os.ZMM()) { return 0; } if (!is.AVX512F()) { return 0; } if (!is.AVX512CD()) { return 0; } if (!is.AVX512VL()) { return 0; } @@ -303,19 +338,19 @@ static int ggml_backend_cpu_x86_score() { score += 1<<7; #endif #ifdef GGML_AVX512_VBMI - if (!is.AVX512_VBMI()) { return 0; } + if (!is.AVX512_VBMI() || !os.ZMM()) { return 0; } score += 1<<8; #endif #ifdef GGML_AVX512_BF16 - if (!is.AVX512_BF16()) { return 0; } + if (!is.AVX512_BF16() || !os.ZMM()) { return 0; } score += 1<<9; #endif #ifdef GGML_AVX512_VNNI - if (!is.AVX512_VNNI()) { return 0; } + if (!is.AVX512_VNNI() || !os.ZMM()) { return 0; } score += 1<<10; #endif #ifdef GGML_AMX_INT8 - if (!is.AMX_INT8()) { return 0; } + if (!is.AMX_INT8() || !os.TMM()) { return 0; } score += 1<<11; #endif