Fix NibbleCodec for AVX3_{ZEN4,DL,SPR}

PiperOrigin-RevId: 831002073
This commit is contained in:
The gemma.cpp Authors 2025-11-11 11:30:45 -08:00 committed by Copybara-Service
parent 3e18db17f4
commit 7c1656f2fc
1 changed file with 5 additions and 2 deletions

View File

@ -480,9 +480,12 @@ class NibbleCodec {
static_assert(kHalf <= 1); static_assert(kHalf <= 1);
const size_t N = hn::Lanes(d8); const size_t N = hn::Lanes(d8);
constexpr size_t kMaxN = hn::MaxLanes(d8); constexpr size_t kMaxN = hn::MaxLanes(d8);
constexpr bool kPermuteAcrossBlocks =
HWY_TARGET <= HWY_AVX3_DL || !HWY_ARCH_X86;
// For kHalf=1 and 512-bit vectors, kAdd would be 16, which is out of // For kHalf=1 and 512-bit vectors, kAdd would be 16, which is out of
// bounds for TableLookupBytes. We instead BroadcastBlock<1> there. // bounds for TableLookupBytes. We instead BroadcastBlock<1> there.
constexpr uint8_t kAdd = kMaxN < 64 ? kHalf * kMaxN / 4 : 0; constexpr uint8_t kAdd =
kMaxN < 64 || kPermuteAcrossBlocks ? kHalf * kMaxN / 4 : 0;
// The only performance-portable op to replicate bytes is TableLookupBytes, // The only performance-portable op to replicate bytes is TableLookupBytes,
// but this only works if vectors are 128-bit or we first BroadcastBlock, // but this only works if vectors are 128-bit or we first BroadcastBlock,
// which only works for <= 512-bit vectors. For scalable vectors, we // which only works for <= 512-bit vectors. For scalable vectors, we
@ -506,7 +509,7 @@ class NibbleCodec {
} else if constexpr (kMaxN <= 16) { // <= 128-bit } else if constexpr (kMaxN <= 16) { // <= 128-bit
// No BroadcastBlock, we anyway only have one block. // No BroadcastBlock, we anyway only have one block.
return hn::TableLookupBytes(bytes, hn::Load(d8, kRep4)); return hn::TableLookupBytes(bytes, hn::Load(d8, kRep4));
} else if constexpr (HWY_TARGET <= HWY_AVX3_DL || !HWY_ARCH_X86) { } else if constexpr (kPermuteAcrossBlocks) {
// No BroadcastBlock, can directly permute across blocks. // No BroadcastBlock, can directly permute across blocks.
return hn::TableLookupLanes(bytes, hn::SetTableIndices(d8, kRep4)); return hn::TableLookupLanes(bytes, hn::SetTableIndices(d8, kRep4));
} else { // 256..512-bit, no efficient TableLookupLanes } else { // 256..512-bit, no efficient TableLookupLanes