Refactor CompressedWeights.

PiperOrigin-RevId: 643934198
This commit is contained in:
The gemma.cpp Authors 2024-06-17 02:54:14 -07:00 committed by Copybara-Service
parent e0afdfa8fb
commit 7dbfa44794
5 changed files with 37 additions and 24 deletions

View File

@ -357,7 +357,7 @@ static HWY_NOINLINE void CrossEntropyLossGrad(
}
}
template <typename TConfig, template<typename> typename WeightsT,
template <typename TConfig, template<typename...> typename WeightsT,
template<typename> typename LayerT>
void CrossEntropyLossBackwardPass(const Prompt& prompt,
const WeightsT<TConfig>& weights,

View File

@ -222,7 +222,7 @@ void ApplyForwardLayer(const LayerT<TConfig>& weights,
}
}
template <typename TConfig, template<typename> typename WeightsT,
template <typename TConfig, template<typename...> typename WeightsT,
template<typename> typename LayerT>
float CrossEntropyLossForwardPass(const std::vector<int>& prompt,
size_t context_size,

View File

@ -30,15 +30,18 @@ namespace gcpp {
const char* ParseModelTypeAndTraining(const std::string& model_flag,
Model& model, ModelTraining& training) {
constexpr const char* kModelFlags[] = {"2b-pt", "7b-pt", "gr2b-pt", "2b-it",
"7b-it", "gr2b-it", "tiny"};
constexpr const char* kModelFlags[] = {
"2b-pt", "7b-pt", "gr2b-pt", "2b-it", "7b-it", "gr2b-it", "tiny",
};
constexpr Model kModelTypes[] = {
Model::GEMMA_2B, Model::GEMMA_7B, Model::GRIFFIN_2B, Model::GEMMA_2B,
Model::GEMMA_7B, Model::GRIFFIN_2B, Model::GEMMA_TINY};
Model::GEMMA_2B, Model::GEMMA_7B, Model::GRIFFIN_2B, Model::GEMMA_2B,
Model::GEMMA_7B, Model::GRIFFIN_2B, Model::GEMMA_TINY,
};
constexpr ModelTraining kModelTraining[] = {
ModelTraining::GEMMA_PT, ModelTraining::GEMMA_PT, ModelTraining::GEMMA_PT,
ModelTraining::GEMMA_IT, ModelTraining::GEMMA_IT, ModelTraining::GEMMA_IT,
ModelTraining::GEMMA_IT};
ModelTraining::GEMMA_IT,
};
constexpr size_t kNum = std::end(kModelFlags) - std::begin(kModelFlags);
static char kErrorMessageBuffer[kNum * 8 + 1024] =

View File

@ -43,7 +43,7 @@ struct LoadCompressedWeightsT {
using CWeights = CompressedWeights<TConfig>;
ByteStorageT c_weights_u8 = AllocateSizeof<CWeights>();
CWeights* c_weights = reinterpret_cast<CWeights*>(c_weights_u8.get());
new (&c_weights->c_layer_ptrs) CompressedLayerPointers<TConfig>(pool);
new (c_weights) CWeights(pool);
std::array<float, TConfig::kNumTensorScales> scales;
CacheLoader loader(weights);

View File

@ -207,9 +207,15 @@ struct CompressedLayerPointers {
std::array<hwy::AlignedFreeUniquePtr<CLayer[]>, TConfig::kLayers> c_layers;
};
template <class TConfig>
template <class TConfig, typename = void>
struct CompressedWeights {
// No ctor/dtor, allocated via AllocateAligned.
// Must be allocated via AllocateAligned and initialized with placement new.
void* operator new(size_t, void* addr) { return addr; }
void* operator new(size_t) = delete;
void* operator new[](size_t) = delete;
void operator delete(void*) = delete;
void operator delete[](void*) = delete;
using Weight = typename TConfig::Weight;
using WeightF32OrInputT =
@ -224,6 +230,16 @@ struct CompressedWeights {
// Must be last so that the other arrays remain aligned.
CompressedLayerPointers<TConfig> c_layer_ptrs;
explicit CompressedWeights(hwy::ThreadPool& pool) : c_layer_ptrs(pool) {}
void ZeroInit() {
hwy::ZeroBytes(&embedder_input_embedding, sizeof(embedder_input_embedding));
hwy::ZeroBytes(&final_norm_scale, sizeof(final_norm_scale));
for (int i = 0; i < TConfig::kLayers; ++i) {
hwy::ZeroBytes(GetLayer(i), sizeof(*GetLayer(i)));
}
}
const CompressedLayer<TConfig>* GetLayer(size_t layer) const {
return c_layer_ptrs.c_layers[layer].get();
}
@ -260,7 +276,7 @@ struct AllocateCompressedWeights {
using TWeights = CompressedWeights<TConfig>;
ByteStorageT weights_u8 = AllocateSizeof<TWeights>();
TWeights* weights = reinterpret_cast<TWeights*>(weights_u8.get());
new (&weights->c_layer_ptrs) CompressedLayerPointers<TConfig>(pool);
new (weights) TWeights(pool);
return weights_u8;
}
};
@ -291,12 +307,7 @@ struct ZeroInitCompressedWeights {
void operator()(ByteStorageT& weights, hwy::ThreadPool& pool) const {
CompressedWeights<TConfig>& w =
*reinterpret_cast<CompressedWeights<TConfig>*>(weights.get());
hwy::ZeroBytes(&w.embedder_input_embedding,
sizeof(w.embedder_input_embedding));
hwy::ZeroBytes(&w.final_norm_scale, sizeof(w.final_norm_scale));
for (int i = 0; i < TConfig::kLayers; ++i) {
hwy::ZeroBytes(w.GetLayer(i), sizeof(*w.GetLayer(i)));
}
w.ZeroInit();
}
};
@ -320,7 +331,7 @@ struct DeleteLayersPtrs {
void operator()(ByteStorageT& weights_u8) const {
auto* weights =
reinterpret_cast<CompressedWeights<TConfig>*>(weights_u8.get());
weights->c_layer_ptrs.~CompressedLayerPointers<TConfig>();
weights->~CompressedWeights<TConfig>();
}
};
@ -358,10 +369,6 @@ void LogWeightStats(Model model, Type weight_type, const ByteStorageT& weights);
// ----------------------------------------------------------------------------
// Iterators
#define GEMMA_CALL_FUNC(name, member) \
snprintf(name_buf, sizeof(name_buf), name "_%d", layer_idx); \
func(name_buf, layer ? layer->member.data() : nullptr, layer_weights->member)
// Calls func(name, float*, CompressedArray&) for each tensor. float* is null
// if weights = null, which happens during the first call where we attempt to
// load from cache.
@ -376,6 +383,10 @@ void ForEachTensor(const WeightsF<TConfig>* weights,
func("c_final_norm", weights ? weights->final_norm_scale.data() : nullptr,
c_weights.final_norm_scale);
#define GEMMA_CALL_FUNC(name, member) \
snprintf(name_buf, sizeof(name_buf), name "_%d", layer_idx); \
func(name_buf, layer ? layer->member.data() : nullptr, layer_weights->member)
char name_buf[16];
for (int layer_idx = 0; layer_idx < TConfig::kLayers; ++layer_idx) {
auto type = TConfig::kLayerConfig[layer_idx];
@ -418,9 +429,8 @@ void ForEachTensor(const WeightsF<TConfig>* weights,
GEMMA_CALL_FUNC("attn_ob", attention_output_biases);
}
}
}
#undef GEMMA_CALL_FUNC
} // ForEachTensor
#define GEMMA_CALL_TOP_FUNC1(name, member) func(name, weights1.member)
#define GEMMA_CALL_TOP_FUNC2(name, member) \