Internal change plus add U8 type, check MatPtrT type at compile time

PiperOrigin-RevId: 867582875
This commit is contained in:
Jan Wassenberg 2026-02-09 06:53:40 -08:00 committed by Copybara-Service
parent 7c19b31c66
commit 56fa6e4839
2 changed files with 11 additions and 5 deletions

View File

@ -229,12 +229,14 @@ enum class Type {
kU32,
kU64,
kI8,
kU16
kU16,
kU8,
};
// These are used in `ModelConfig.Specifier`, hence the strings will not
// change, though new ones may be added.
static constexpr const char* kTypeStrings[] = {
"unknown", "f32", "bf16", "sfp", "nuq", "f64", "u32", "u64", "i8", "u16"};
static constexpr const char* kTypeStrings[] = {"unknown", "f32", "bf16", "sfp",
"nuq", "f64", "u32", "u64",
"i8", "u16", "u8"};
static constexpr size_t kNumTypes =
sizeof(kTypeStrings) / sizeof(kTypeStrings[0]);
static constexpr size_t kTypeBits[] = {
@ -248,6 +250,7 @@ static constexpr size_t kTypeBits[] = {
8 * sizeof(uint64_t),
8 * sizeof(I8Stream),
8 * sizeof(uint16_t),
8 * sizeof(uint8_t),
};
static inline bool EnumValid(Type type) {
@ -256,7 +259,7 @@ static inline bool EnumValid(Type type) {
// Returns a Type enum for the type of the template parameter.
template <typename PackedT>
Type TypeEnum() {
constexpr Type TypeEnum() {
using Packed = hwy::RemoveCvRef<PackedT>;
if constexpr (hwy::IsSame<Packed, float>()) {
return Type::kF32;
@ -276,8 +279,9 @@ Type TypeEnum() {
return Type::kI8;
} else if constexpr (hwy::IsSame<Packed, uint16_t>()) {
return Type::kU16;
} else if constexpr (hwy::IsSame<Packed, uint8_t>()) {
return Type::kU8;
} else {
HWY_DASSERT(false);
return Type::kUnknown;
}
}

View File

@ -291,6 +291,8 @@ template <typename MatT>
class MatPtrT : public MatPtr {
public:
using T = MatT;
static_assert(TypeEnum<MatT>() != Type::kUnknown,
"Must only use with supported MatT.");
// Default constructor for use with uninitialized views.
MatPtrT() = default;