diff --git a/backprop/backward-inl.h b/backprop/backward-inl.h
index 49d866e..d6c4d68 100644
--- a/backprop/backward-inl.h
+++ b/backprop/backward-inl.h
@@ -357,7 +357,7 @@ static HWY_NOINLINE void CrossEntropyLossGrad(
   }
 }
 
-template <typename T, typename TConfig, template<typename...> typename WeightsT,
+template <typename T, typename TConfig, template <typename...> typename WeightsT,
           template <typename...> typename LayerT>
 void CrossEntropyLossBackwardPass(const Prompt& prompt,
                                   const WeightsT<T, TConfig>& weights,
diff --git a/backprop/forward-inl.h b/backprop/forward-inl.h
index d9749c2..49efedd 100644
--- a/backprop/forward-inl.h
+++ b/backprop/forward-inl.h
@@ -222,7 +222,7 @@ void ApplyForwardLayer(const LayerT<T, TConfig>& weights,
   }
 }
 
-template <typename T, typename TConfig, template<typename...> typename WeightsT,
+template <typename T, typename TConfig, template <typename...> typename WeightsT,
           template <typename...> typename LayerT>
 float CrossEntropyLossForwardPass(const std::vector<int>& prompt,
                                   size_t context_size,
diff --git a/gemma/common.cc b/gemma/common.cc
index 7fff656..75d9282 100644
--- a/gemma/common.cc
+++ b/gemma/common.cc
@@ -30,15 +30,18 @@ namespace gcpp {
 
 const char* ParseModelTypeAndTraining(const std::string& model_flag,
                                       Model& model, ModelTraining& training) {
-  constexpr const char* kModelFlags[] = {"2b-pt", "7b-pt", "gr2b-pt", "2b-it",
-                                         "7b-it", "gr2b-it", "tiny"};
+  constexpr const char* kModelFlags[] = {
+      "2b-pt", "7b-pt", "gr2b-pt", "2b-it", "7b-it", "gr2b-it", "tiny",
+  };
   constexpr Model kModelTypes[] = {
-      Model::GEMMA_2B, Model::GEMMA_7B, Model::GRIFFIN_2B, Model::GEMMA_2B,
-      Model::GEMMA_7B, Model::GRIFFIN_2B, Model::GEMMA_TINY};
+      Model::GEMMA_2B, Model::GEMMA_7B, Model::GRIFFIN_2B, Model::GEMMA_2B,
+      Model::GEMMA_7B, Model::GRIFFIN_2B, Model::GEMMA_TINY,
+  };
   constexpr ModelTraining kModelTraining[] = {
       ModelTraining::GEMMA_PT, ModelTraining::GEMMA_PT, ModelTraining::GEMMA_PT,
       ModelTraining::GEMMA_IT, ModelTraining::GEMMA_IT, ModelTraining::GEMMA_IT,
-      ModelTraining::GEMMA_IT};
+      ModelTraining::GEMMA_IT,
+  };
   constexpr size_t kNum = std::end(kModelFlags) - std::begin(kModelFlags);
 
   static char kErrorMessageBuffer[kNum * 8 + 1024] =
diff --git a/gemma/weights.cc b/gemma/weights.cc
index a138474..d850cfd 100644
--- a/gemma/weights.cc
+++ b/gemma/weights.cc
@@ -43,7 +43,7 @@ struct LoadCompressedWeightsT {
     using CWeights = CompressedWeights<TConfig>;
     ByteStorageT c_weights_u8 = AllocateSizeof<CWeights>();
     CWeights* c_weights = reinterpret_cast<CWeights*>(c_weights_u8.get());
-    new (&c_weights->c_layer_ptrs) CompressedLayerPointers<TConfig>(pool);
+    new (c_weights) CWeights(pool);
 
     std::array<float, TConfig::kNumTensorScales> scales;
     CacheLoader loader(weights);
diff --git a/gemma/weights.h b/gemma/weights.h
index d9f5c0f..8d21cf3 100644
--- a/gemma/weights.h
+++ b/gemma/weights.h
@@ -207,9 +207,15 @@ struct CompressedLayerPointers {
 
   std::array<hwy::AlignedFreeUniquePtr<CLayer[]>, TConfig::kLayers> c_layers;
 };
 
-template<class TConfig>
+template <class TConfig>
 struct CompressedWeights {
-  // No ctor/dtor, allocated via AllocateAligned.
+  // Must be allocated via AllocateAligned and initialized with placement new.
+  void* operator new(size_t, void* addr) { return addr; }
+  void* operator new(size_t) = delete;
+  void* operator new[](size_t) = delete;
+  void operator delete(void*) = delete;
+  void operator delete[](void*) = delete;
+
   using Weight = typename TConfig::Weight;
   using WeightF32OrInputT =
@@ -224,6 +230,16 @@ struct CompressedWeights {
   // Must be last so that the other arrays remain aligned.
   CompressedLayerPointers<TConfig> c_layer_ptrs;
 
+  explicit CompressedWeights(hwy::ThreadPool& pool) : c_layer_ptrs(pool) {}
+
+  void ZeroInit() {
+    hwy::ZeroBytes(&embedder_input_embedding, sizeof(embedder_input_embedding));
+    hwy::ZeroBytes(&final_norm_scale, sizeof(final_norm_scale));
+    for (int i = 0; i < TConfig::kLayers; ++i) {
+      hwy::ZeroBytes(GetLayer(i), sizeof(*GetLayer(i)));
+    }
+  }
+
   const CompressedLayer<TConfig>* GetLayer(size_t layer) const {
     return c_layer_ptrs.c_layers[layer].get();
   }
@@ -260,7 +276,7 @@ struct AllocateCompressedWeights {
     using TWeights = CompressedWeights<TConfig>;
     ByteStorageT weights_u8 = AllocateSizeof<TWeights>();
     TWeights* weights = reinterpret_cast<TWeights*>(weights_u8.get());
-    new (&weights->c_layer_ptrs) CompressedLayerPointers<TConfig>(pool);
+    new (weights) TWeights(pool);
     return weights_u8;
   }
 };
@@ -291,12 +307,7 @@ struct ZeroInitCompressedWeights {
   void operator()(ByteStorageT& weights, hwy::ThreadPool& pool) const {
     CompressedWeights<TConfig>& w =
         *reinterpret_cast<CompressedWeights<TConfig>*>(weights.get());
-    hwy::ZeroBytes(&w.embedder_input_embedding,
-                   sizeof(w.embedder_input_embedding));
-    hwy::ZeroBytes(&w.final_norm_scale, sizeof(w.final_norm_scale));
-    for (int i = 0; i < TConfig::kLayers; ++i) {
-      hwy::ZeroBytes(w.GetLayer(i), sizeof(*w.GetLayer(i)));
-    }
+    w.ZeroInit();
   }
 };
 
@@ -320,7 +331,7 @@ struct DeleteLayersPtrs {
   void operator()(ByteStorageT& weights_u8) const {
     auto* weights =
         reinterpret_cast<CompressedWeights<TConfig>*>(weights_u8.get());
-    weights->c_layer_ptrs.~CompressedLayerPointers();
+    weights->~CompressedWeights();
   }
 };
 
@@ -358,10 +369,6 @@ void LogWeightStats(Model model, Type weight_type, const ByteStorageT& weights);
 
 // ----------------------------------------------------------------------------
 // Iterators
-#define GEMMA_CALL_FUNC(name, member)                                         \
-  snprintf(name_buf, sizeof(name_buf), name "_%d", layer_idx);                \
-  func(name_buf, layer ? layer->member.data() : nullptr, layer_weights->member)
-
 // Calls func(name, float*, CompressedArray&) for each tensor. float* is null
 // if weights = null, which happens during the first call where we attempt to
 // load from cache.
@@ -376,6 +383,10 @@ void ForEachTensor(const WeightsF<TConfig>* weights,
   func("c_final_norm", weights ? weights->final_norm_scale.data() : nullptr,
        c_weights.final_norm_scale);
 
+#define GEMMA_CALL_FUNC(name, member)                                         \
+  snprintf(name_buf, sizeof(name_buf), name "_%d", layer_idx);                \
+  func(name_buf, layer ? layer->member.data() : nullptr, layer_weights->member)
+
   char name_buf[16];
   for (int layer_idx = 0; layer_idx < TConfig::kLayers; ++layer_idx) {
     auto type = TConfig::kLayerConfig[layer_idx];
@@ -418,9 +429,8 @@ void ForEachTensor(const WeightsF<TConfig>* weights,
      GEMMA_CALL_FUNC("attn_ob", attention_output_biases);
    }
  }
-}
-
 #undef GEMMA_CALL_FUNC
+}  // ForEachTensor
 
 #define GEMMA_CALL_TOP_FUNC1(name, member) func(name, weights1.member)
 #define GEMMA_CALL_TOP_FUNC2(name, member) \
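
Note on the lifecycle this patch centralizes: CompressedWeights is allocated as raw aligned bytes, constructed in place with placement new (which now also constructs c_layer_ptrs via the new constructor), and torn down with an explicit destructor call in DeleteLayersPtrs; the deleted new/delete operators make any accidental `new CompressedWeights(...)` or `delete weights` a compile error. The following standalone sketch illustrates the same pattern; `Pool` and `Weights` are hypothetical stand-ins for hwy::ThreadPool and CompressedWeights<TConfig>, not code from this patch, and the simple byte buffer stands in for AllocateSizeof<CWeights>().

#include <cstddef>
#include <cstdio>
#include <cstring>
#include <memory>
#include <new>

struct Pool {};  // placeholder for hwy::ThreadPool

struct Weights {
  // Permit placement new only, mirroring the patched CompressedWeights.
  void* operator new(std::size_t, void* addr) { return addr; }
  void* operator new(std::size_t) = delete;
  void* operator new[](std::size_t) = delete;
  void operator delete(void*) = delete;
  void operator delete[](void*) = delete;

  explicit Weights(Pool& /*pool*/) {}  // would allocate per-layer storage

  // Same idea as the ZeroInit member the patch adds.
  void ZeroInit() { std::memset(final_norm_scale, 0, sizeof(final_norm_scale)); }

  float final_norm_scale[4];
};

int main() {
  Pool pool;
  // Raw byte allocation standing in for AllocateSizeof<CWeights>().
  auto storage = std::make_unique<unsigned char[]>(sizeof(Weights));
  Weights* w = new (storage.get()) Weights(pool);  // as in LoadCompressedWeightsT
  w->ZeroInit();
  std::printf("scale[0] = %g\n", w->final_norm_scale[0]);
  w->~Weights();  // as in DeleteLayersPtrs: explicit destructor call
  return 0;
}

Moving the zeroing loop into a ZeroInit member also removes the duplication risk: ZeroInitCompressedWeights can no longer drift out of sync with the struct's layout.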
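
The GEMMA_CALL_FUNC move is a scoping fix in the same spirit: the macro is now defined inside ForEachTensor, next to the name_buf and layer variables it expands against, and #undef'd before the function's closing brace, so it cannot leak into other translation-unit code. A small sketch of that convention, using hypothetical names (ForEachValue, CALL_FUNC):

#include <cstdio>

void ForEachValue(const int* values, int count) {
// Define the helper right where it is used...
#define CALL_FUNC(name, index) \
  std::printf("%s_%d = %d\n", name, index, values[index])

  for (int i = 0; i < count; ++i) {
    CALL_FUNC("value", i);
  }
// ...and retire it before leaving the function, as the patch now does.
#undef CALL_FUNC
}  // ForEachValue

int main() {
  const int values[] = {7, 11, 13};
  ForEachValue(values, 3);
  return 0;
}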