mirror of https://github.com/google/gemma.cpp.git
Internal change plus add U8 type, check MatPtrT type at compile time
PiperOrigin-RevId: 867582875
This commit is contained in:
parent
7c19b31c66
commit
56fa6e4839
|
|
@ -229,12 +229,14 @@ enum class Type {
|
|||
kU32,
|
||||
kU64,
|
||||
kI8,
|
||||
kU16
|
||||
kU16,
|
||||
kU8,
|
||||
};
|
||||
// These are used in `ModelConfig.Specifier`, hence the strings will not
|
||||
// change, though new ones may be added.
|
||||
static constexpr const char* kTypeStrings[] = {
|
||||
"unknown", "f32", "bf16", "sfp", "nuq", "f64", "u32", "u64", "i8", "u16"};
|
||||
static constexpr const char* kTypeStrings[] = {"unknown", "f32", "bf16", "sfp",
|
||||
"nuq", "f64", "u32", "u64",
|
||||
"i8", "u16", "u8"};
|
||||
static constexpr size_t kNumTypes =
|
||||
sizeof(kTypeStrings) / sizeof(kTypeStrings[0]);
|
||||
static constexpr size_t kTypeBits[] = {
|
||||
|
|
@ -248,6 +250,7 @@ static constexpr size_t kTypeBits[] = {
|
|||
8 * sizeof(uint64_t),
|
||||
8 * sizeof(I8Stream),
|
||||
8 * sizeof(uint16_t),
|
||||
8 * sizeof(uint8_t),
|
||||
};
|
||||
|
||||
static inline bool EnumValid(Type type) {
|
||||
|
|
@ -256,7 +259,7 @@ static inline bool EnumValid(Type type) {
|
|||
|
||||
// Returns a Type enum for the type of the template parameter.
|
||||
template <typename PackedT>
|
||||
Type TypeEnum() {
|
||||
constexpr Type TypeEnum() {
|
||||
using Packed = hwy::RemoveCvRef<PackedT>;
|
||||
if constexpr (hwy::IsSame<Packed, float>()) {
|
||||
return Type::kF32;
|
||||
|
|
@ -276,8 +279,9 @@ Type TypeEnum() {
|
|||
return Type::kI8;
|
||||
} else if constexpr (hwy::IsSame<Packed, uint16_t>()) {
|
||||
return Type::kU16;
|
||||
} else if constexpr (hwy::IsSame<Packed, uint8_t>()) {
|
||||
return Type::kU8;
|
||||
} else {
|
||||
HWY_DASSERT(false);
|
||||
return Type::kUnknown;
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -291,6 +291,8 @@ template <typename MatT>
|
|||
class MatPtrT : public MatPtr {
|
||||
public:
|
||||
using T = MatT;
|
||||
static_assert(TypeEnum<MatT>() != Type::kUnknown,
|
||||
"Must only use with supported MatT.");
|
||||
|
||||
// Default constructor for use with uninitialized views.
|
||||
MatPtrT() = default;
|
||||
|
|
|
|||
Loading…
Reference in New Issue