diff --git a/gemma/activations.h b/gemma/activations.h index e72f3d6..9a43344 100644 --- a/gemma/activations.h +++ b/gemma/activations.h @@ -60,6 +60,7 @@ struct ForwardPass { } }; +// Owns activations and undoes the type erasure of AllocateAligned. template class ActivationsWrapper { using WrappedT = ForwardPass; diff --git a/gemma/backward-inl.h b/gemma/backward-inl.h index 53cba96..cbeb946 100644 --- a/gemma/backward-inl.h +++ b/gemma/backward-inl.h @@ -13,6 +13,9 @@ // See the License for the specific language governing permissions and // limitations under the License. +// Implementation of the Vector-Jacobian Products (VJP) of the individual +// operations of the forward pass. + // Include guard for non-SIMD code. #ifndef THIRD_PARTY_GEMMA_CPP_GEMMA_BACKWARD_INL_H_ #define THIRD_PARTY_GEMMA_CPP_GEMMA_BACKWARD_INL_H_ @@ -56,7 +59,7 @@ void MatMulVJP(const std::array& weights, std::array& grad_w, float* HWY_RESTRICT grad_x, // num_tokens * kCols hwy::ThreadPool& pool) { - memset(grad_x, 0, num_tokens * kCols * sizeof(grad_x[0])); + hwy::ZeroBytes(grad_x, num_tokens * kCols * sizeof(grad_x[0])); for (size_t pos = 0; pos < num_tokens; ++pos) { const size_t voffs = pos * kRows; const size_t xoffs = pos * kCols; @@ -77,7 +80,7 @@ void MultiHeadMatMulVJP( std::array& grad_w, float* HWY_RESTRICT grad_x, // num_tokens * kHeads * kCols hwy::ThreadPool& pool) { - memset(grad_x, 0, num_tokens * kHeads * kCols * sizeof(grad_x[0])); + hwy::ZeroBytes(grad_x, num_tokens * kHeads * kCols * sizeof(grad_x[0])); for (size_t pos = 0; pos < num_tokens; ++pos) { for (size_t j = 0; j < kRows; ++j) { for (size_t h = 0; h < kHeads; ++h) { @@ -154,11 +157,12 @@ static HWY_NOINLINE void RMSNormVJP( static HWY_NOINLINE void InputEmbeddingVJP( const float* weights, const std::vector& prompt, - const float scaling, const float* HWY_RESTRICT backward, + const float scaling, const float* HWY_RESTRICT v, float* HWY_RESTRICT grad, size_t model_dim) { - for (size_t pos = 0; pos + 1 < 
prompt.size(); ++pos) { + HWY_ASSERT(!prompt.empty()); + for (size_t pos = 0; pos < prompt.size() - 1; ++pos) { int token = prompt[pos]; - MulByConstAndAdd(scaling, backward + pos * model_dim, + MulByConstAndAdd(scaling, v + pos * model_dim, grad + token * model_dim, model_dim); } } @@ -274,7 +278,7 @@ void LayerVJP(const Layer& weights, } } - for (int pos = 0; pos < num_tokens; ++pos) { + for (int pos = 0; pos < static_cast(num_tokens); ++pos) { float* HWY_RESTRICT b_kv = backward.qkv.data() + (pos * (kHeads + 2) + kHeads) * kQKVDim; Rope(b_kv, kQKVDim, -pos); @@ -328,16 +332,17 @@ static HWY_NOINLINE void SoftcapVJP(const float* HWY_RESTRICT forward, static HWY_NOINLINE void CrossEntropyLossGrad( const float* HWY_RESTRICT x, float* HWY_RESTRICT grad, const Prompt& prompt, size_t vocab_size) { + HWY_ASSERT(!prompt.tokens.empty()); const float scaling = -1.0 / std::log(2.0); size_t num_tokens = prompt.tokens.size() - 1; - memset(grad, 0, num_tokens * vocab_size * sizeof(grad[0])); - for (size_t i = 0; i + 1 < prompt.tokens.size(); ++i) { - if (i + 1 < prompt.context_size) { + hwy::ZeroBytes(grad, num_tokens * vocab_size * sizeof(grad[0])); + for (size_t pos = 0; pos < num_tokens; ++pos) { + if (pos + 1 < prompt.context_size) { continue; } - const int next_token = prompt.tokens[i + 1]; - grad[i * vocab_size + next_token] = - scaling / x[i * vocab_size + next_token]; + const int next_token = prompt.tokens[pos + 1]; + grad[pos * vocab_size + next_token] = + scaling / x[pos * vocab_size + next_token]; } } diff --git a/gemma/backward_scalar.h b/gemma/backward_scalar.h index b7b42b4..5e5a93b 100644 --- a/gemma/backward_scalar.h +++ b/gemma/backward_scalar.h @@ -193,7 +193,8 @@ void MixByAttentionVJP(const T* qkv, const T* attention, const T* doutput, template void InputEmbeddingVJPT(const T* w, const std::vector& tokens, T scaling, const T* dy, T* dw, size_t N) { - for (size_t i = 0; i + 1 < tokens.size(); ++i) { + const size_t num_tokens = tokens.empty() ? 
0 : tokens.size() - 1; + for (size_t i = 0; i < num_tokens; ++i) { int token = tokens[i]; MulByConstAndAddT(scaling, dy + i * N, dw + token * N, N); } @@ -287,13 +288,14 @@ void SoftcapVJPT(const T* y, T* dy, size_t N) { template void CrossEntropyLossGrad(const T* x, T* dx, const Prompt& prompt, size_t V) { T scaling = -1.0 / std::log(2.0); - size_t num_tokens = prompt.tokens.size() - 1; + const auto& tokens = prompt.tokens; + const size_t num_tokens = tokens.empty() ? 0 : tokens.size() - 1; memset(dx, 0, V * num_tokens * sizeof(x[0])); - for (size_t i = 0; i + 1 < prompt.tokens.size(); ++i) { + for (size_t i = 0; i < num_tokens; ++i) { if (i + 1 < prompt.context_size) { continue; } - const int next_token = prompt.tokens[i + 1]; + const int next_token = tokens[i + 1]; dx[i * V + next_token] = scaling / x[i * V + next_token]; } } @@ -307,7 +309,8 @@ void CrossEntropyLossBackwardPass(const Prompt& prompt, static constexpr size_t kModelDim = TConfig::kModelDim; static constexpr size_t kVocabSize = TConfig::kVocabSize; static constexpr size_t kLayers = TConfig::kLayers; - const size_t num_tokens = prompt.tokens.size() - 1; + const auto& tokens = prompt.tokens; + const size_t num_tokens = tokens.empty() ? 
0 : tokens.size() - 1; CrossEntropyLossGrad(forward.probs.data(), backward.logits.data(), prompt, kVocabSize); @@ -341,8 +344,7 @@ void CrossEntropyLossBackwardPass(const Prompt& prompt, const T kEmbScaling = EmbeddingScaling(kModelDim); InputEmbeddingVJPT(weights.embedder_input_embedding.data(), - prompt.tokens, kEmbScaling, - backward.layers[0].input.data(), + tokens, kEmbScaling, backward.layers[0].input.data(), grad.embedder_input_embedding.data(), kModelDim); } diff --git a/gemma/backward_test.cc b/gemma/backward_test.cc index 50918a0..cdeac2e 100644 --- a/gemma/backward_test.cc +++ b/gemma/backward_test.cc @@ -83,13 +83,13 @@ void TestMatMulVJP() { return DotT(dy.data(), c_y.data(), kTokens * kRows); }; - memset(&grad, 0, sizeof(grad)); + hwy::ZeroBytes(&grad, sizeof(grad)); MatMulVJP(weights, x.data(), dy.data(), kTokens, grad, dx.data(), pool); TestGradient(dx, c_x, func, 5e-5, 5e-5, __LINE__); TestGradient(grad, c_weights, func, 5e-5, 5e-5, __LINE__); - memset(&grad_scalar, 0, sizeof(grad_scalar)); + hwy::ZeroBytes(&grad_scalar, sizeof(grad_scalar)); MatMulVJPT(weights.data(), x.data(), dy.data(), grad_scalar.data(), dx_scalar.data(), kRows, kCols, kTokens); TestNear(dx, dx_scalar, 5e-5, 5e-5, __LINE__); @@ -128,13 +128,13 @@ void TestMultiHeadMatMulVJP() { return DotT(dy.data(), c_y.data(), kTokens * kRows); }; - memset(&grad, 0, sizeof(grad)); + hwy::ZeroBytes(&grad, sizeof(grad)); MultiHeadMatMulVJP( weights, x.data(), dy.data(), kTokens, grad, dx.data(), pool); TestGradient(dx, c_x, func, 5e-5, 5e-5, __LINE__); TestGradient(grad, c_weights, func, 5e-5, 5e-5, __LINE__); - memset(&grad_scalar, 0, sizeof(grad_scalar)); + hwy::ZeroBytes(&grad_scalar, sizeof(grad_scalar)); MultiHeadMatMulVJPT(weights.data(), x.data(), dy.data(), grad_scalar.data(), dx_scalar.data(), kHeads, kRows, kCols, kTokens); TestNear(dx, dx_scalar, 5e-5, 5e-5, __LINE__); @@ -170,13 +170,13 @@ void TestRMSNormVJP() { return DotT(dy.data(), c_y.data(), K * N); }; - memset(&grad, 0, 
sizeof(grad)); + hwy::ZeroBytes(&grad, sizeof(grad)); RMSNormVJP(weights.data(), x.data(), dy.data(), N, K, grad.data(), dx.data(), pool); TestGradient(dx, c_x, func, 5e-5, 5e-5, __LINE__); TestGradient(grad, c_weights, func, 5e-5, 5e-5, __LINE__); - memset(&grad_scalar, 0, sizeof(grad_scalar)); + hwy::ZeroBytes(&grad_scalar, sizeof(grad_scalar)); RMSNormVJPT(weights.data(), x.data(), dy.data(), grad_scalar.data(), dx_scalar.data(), N, K); TestNear(dx, dx_scalar, 0, 2e-5, __LINE__); @@ -261,9 +261,7 @@ HWY_EXPORT_AND_TEST_P(BackwardTest, TestMatMulVJP); HWY_EXPORT_AND_TEST_P(BackwardTest, TestMultiHeadMatMulVJP); HWY_EXPORT_AND_TEST_P(BackwardTest, TestRMSNormVJP); HWY_EXPORT_AND_TEST_P(BackwardTest, TestEndToEnd); -#ifdef HWY_AFTER_TEST HWY_AFTER_TEST(); -#endif } // namespace gcpp diff --git a/gemma/common-inl.h b/gemma/common-inl.h index c7f53bc..ac39d73 100644 --- a/gemma/common-inl.h +++ b/gemma/common-inl.h @@ -28,9 +28,6 @@ #include "hwy/base.h" #include "hwy/contrib/thread_pool/thread_pool.h" -namespace gcpp { -} // namespace gcpp - #endif // THIRD_PARTY_GEMMA_CPP_GEMMA_COMMON_INL_H_ // Include guard for (potentially) SIMD code. 
diff --git a/gemma/common_scalar.cc b/gemma/common_scalar.cc index 9a82c1d..bbf201c 100644 --- a/gemma/common_scalar.cc +++ b/gemma/common_scalar.cc @@ -19,8 +19,8 @@ #define HWY_TARGET_INCLUDE "gemma/common_scalar.cc" // NOLINT #include "hwy/foreach_target.h" // IWYU pragma: keep +#include "hwy/highway.h" // IWYU pragma: keep #include "gemma/ops.h" -#include "hwy/highway.h" HWY_BEFORE_NAMESPACE(); namespace gcpp { diff --git a/gemma/forward-inl.h b/gemma/forward-inl.h index 19f603f..4b7cdf1 100644 --- a/gemma/forward-inl.h +++ b/gemma/forward-inl.h @@ -51,7 +51,8 @@ template void InputEmbedding(const ArrayT& weights, const std::vector& prompt, const float scaling, float* HWY_RESTRICT output, size_t model_dim) { - for (size_t pos = 0; pos + 1 < prompt.size(); ++pos) { + HWY_ASSERT(!prompt.empty()); + for (size_t pos = 0; pos < prompt.size() - 1; ++pos) { int token = prompt[pos]; Decompress(weights, token * model_dim, output + pos * model_dim, model_dim); MulByConst(scaling, output + pos * model_dim, model_dim); @@ -74,8 +75,9 @@ static HWY_NOINLINE float CrossEntropyLoss(const float* HWY_RESTRICT probs, size_t context_size, size_t vocab_size, hwy::ThreadPool& pool) { + HWY_ASSERT(!prompt.empty()); float loss = 0.0f; - for (size_t pos = 0; pos + 1 < prompt.size(); ++pos) { + for (size_t pos = 0; pos < prompt.size() - 1; ++pos) { if (pos + 1 < context_size) { continue; // next token is part of context, don't try to predict it } @@ -271,8 +273,8 @@ float CrossEntropyLossForwardPass(const std::vector& prompt, LogitsSoftCap(30.0f, forward.logits.data() + pos * kVocabSize, kVocabSize); } - memcpy(forward.probs.data(), forward.logits.data(), - num_tokens * kVocabSize * sizeof(forward.logits[0])); + hwy::CopyBytes(forward.logits.data(), forward.probs.data(), + num_tokens * kVocabSize * sizeof(forward.logits[0])); for (size_t pos = 0; pos < num_tokens; ++pos) { Softmax(forward.probs.data() + pos * kVocabSize, kVocabSize); diff --git a/gemma/forward.cc b/gemma/forward.cc 
index 1ebabfd..9699a34 100644 --- a/gemma/forward.cc +++ b/gemma/forward.cc @@ -21,9 +21,9 @@ #define HWY_TARGET_INCLUDE "gemma/forward.cc" // NOLINT #include "hwy/foreach_target.h" // IWYU pragma: keep +#include "hwy/highway.h" // IWYU pragma: keep #include "gemma/forward-inl.h" #include "gemma/weights.h" -#include "hwy/highway.h" HWY_BEFORE_NAMESPACE(); namespace gcpp { diff --git a/gemma/forward_scalar.h b/gemma/forward_scalar.h index 0a62107..59c05b6 100644 --- a/gemma/forward_scalar.h +++ b/gemma/forward_scalar.h @@ -116,7 +116,8 @@ void GatedGelu(const T* in, T* out, size_t N, size_t K) { template void InputEmbedding(const T* w, const std::vector& tokens, T scaling, T* y, size_t N) { - for (size_t i = 0; i + 1 < tokens.size(); ++i) { + const size_t num_tokens = tokens.empty() ? 0 : tokens.size() - 1; + for (size_t i = 0; i < num_tokens; ++i) { int token = tokens[i]; memcpy(y + i * N, w + token * N, N * sizeof(y[0])); MulByConstT(scaling, y + i * N, N); @@ -230,11 +231,13 @@ void ApplyLayer(const Layer& weights, template T CrossEntropyLoss(const T* x, const Prompt& prompt, size_t V) { T loss = {}; - for (size_t i = 0; i + 1 < prompt.tokens.size(); ++i) { + const auto& tokens = prompt.tokens; + const size_t num_tokens = tokens.empty() ? 0 : tokens.size() - 1; + for (size_t i = 0; i < num_tokens; ++i) { if (i + 1 < prompt.context_size) { continue; // next token is part of context, don't try to predict it } - const int next_token = prompt.tokens[i + 1]; + const int next_token = tokens[i + 1]; loss += std::log(x[i * V + next_token]); } T scaling = -1.0 / std::log(2.0); @@ -248,10 +251,11 @@ T CrossEntropyLossForwardPass(const Prompt& prompt, static constexpr size_t kModelDim = TConfig::kModelDim; static constexpr size_t kVocabSize = TConfig::kVocabSize; static constexpr size_t kLayers = TConfig::kLayers; - const size_t num_tokens = prompt.tokens.size() - 1; + const auto& tokens = prompt.tokens; + const size_t num_tokens = tokens.empty() ? 
0 : tokens.size() - 1; const T kEmbScaling = EmbeddingScaling(kModelDim); - InputEmbedding(weights.embedder_input_embedding.data(), prompt.tokens, + InputEmbedding(weights.embedder_input_embedding.data(), tokens, kEmbScaling, forward.layers[0].input.data(), kModelDim); for (size_t layer = 0; layer < kLayers; ++layer) { diff --git a/gemma/ops.h b/gemma/ops.h index 5bffa84..7ac73df 100644 --- a/gemma/ops.h +++ b/gemma/ops.h @@ -874,6 +874,10 @@ static HWY_NOINLINE HWY_MAYBE_UNUSED void AddAbsolutePositionalEmbeddings( consecutive pair of dimensions of v i.e. v_{2i} and v_{2i+1} by an angle m*theta_i. However in the Gemma implementation we choose to rotate the pairs of dimensions v_{i} and v_{i + d//2} instead. + + pos parameter is deliberately an int because in the backward pass we + call this with negative values (for the VJP calculation we need the transpose + of this rotation matrix which is simply the same matrix with -pos parameter) */ static HWY_NOINLINE HWY_MAYBE_UNUSED void Rope(float* HWY_RESTRICT x, diff --git a/gemma/test_util.h b/gemma/test_util.h index e9cd533..939411d 100644 --- a/gemma/test_util.h +++ b/gemma/test_util.h @@ -107,7 +107,7 @@ void TestNear(const std::array& actual, const std::array& expected, } // Compute gradient with the finite difference method in the complex plane. -// If f : R->R is the tested function and F : C->C is its extenstion on the +// If f : R->R is the tested function and F : C->C is its extension on the // complex plane so that F is complex differentiable in x, then // // F(x + ih) = F(x) + ih F'(x) + O(h^2) F''(x) @@ -117,7 +117,7 @@ void TestNear(const std::array& actual, const std::array& expected, // F'(x) ~= Imag(F(x + ih)) / h // // This method is more numerically stable than the real-valued finite difference -// method since we don't need to substract floating point numbers that are near +// method since we don't need to subtract floating point numbers that are near // to each other. 
template void TestGradient(const std::array& grad, diff --git a/gemma/weights.cc b/gemma/weights.cc index c6387f7..aa302e2 100644 --- a/gemma/weights.cc +++ b/gemma/weights.cc @@ -18,6 +18,7 @@ #include "gemma/common.h" #include "gemma/configs.h" #include "hwy/contrib/thread_pool/thread_pool.h" +#include "hwy/stats.h" namespace gcpp { @@ -66,17 +67,12 @@ void ZeroInitWeights(Model model, ByteStorageT& weights, namespace { void LogVec(const char* name, const float* data, size_t len) { - float minval = std::numeric_limits::max(); - float maxval = std::numeric_limits::min(); - double sum = 0.0f; + hwy::Stats stats; for (size_t i = 0; i < len; ++i) { - minval = std::min(minval, data[i]); - maxval = std::max(maxval, data[i]); - sum += data[i]; + stats.Notify(data[i]); } - float avg = sum / len; printf("%-20s %12zu %13.10f %8.5f %13.10f\n", - name, len, minval, avg, maxval); + name, len, stats.Min(), stats.Mean(), stats.Max()); } class WeightLogger { diff --git a/gemma/weights.h b/gemma/weights.h index 9777552..8d25759 100644 --- a/gemma/weights.h +++ b/gemma/weights.h @@ -127,136 +127,138 @@ ByteStorageT AllocateWeights(hwy::ThreadPool& pool) { return weights_u8; } -#define CALL_TOP_FUNC1(name, member) func(name, weights1.member) -#define CALL_TOP_FUNC2(name, member) \ +#define GEMMA_CALL_TOP_FUNC1(name, member) func(name, weights1.member) +#define GEMMA_CALL_TOP_FUNC2(name, member) \ func(name, weights1.member, weights2.member) -#define CALL_TOP_FUNC3(name, member) \ +#define GEMMA_CALL_TOP_FUNC3(name, member) \ func(name, weights1.member, weights2.member, weights3.member) -#define CALL_TOP_FUNC4(name, member) \ +#define GEMMA_CALL_TOP_FUNC4(name, member) \ - func(name, weights1.member, weights2.memeber, \ + func(name, weights1.member, weights2.member, \ weights3.member, weights4.member) -#define CALL_LAYER_FUNC1(name, member) \ +#define GEMMA_CALL_LAYER_FUNC1(name, member) \ snprintf(name_buf, sizeof(name_buf), name "_%d", layer_idx); \ func(name_buf, layer1.member) -#define CALL_LAYER_FUNC2(name, member) \ +#define 
GEMMA_CALL_LAYER_FUNC2(name, member) \ snprintf(name_buf, sizeof(name_buf), name "_%d", layer_idx); \ func(name_buf, layer1.member, layer2.member) -#define CALL_LAYER_FUNC3(name, member) \ +#define GEMMA_CALL_LAYER_FUNC3(name, member) \ snprintf(name_buf, sizeof(name_buf), name "_%d", layer_idx); \ func(name_buf, layer1.member, layer2.member, layer3.member) -#define CALL_LAYER_FUNC4(name, member) \ +#define GEMMA_CALL_LAYER_FUNC4(name, member) \ snprintf(name_buf, sizeof(name_buf), name "_%d", layer_idx); \ - func(name_buf, layer1.member, layer2.member, layer4.member) + func(name_buf, layer1.member, layer2.member, layer3.member, layer4.member) -#define CALL_ALL_LAYER_FUNC(N) \ - if (type == LayerAttentionType::kGemma) { \ - CALL_LAYER_FUNC ## N("att_ein", attn_vec_einsum_w); \ - CALL_LAYER_FUNC ## N("qkv_ein", qkv_einsum_w); \ - } else { \ - CALL_LAYER_FUNC ## N("gr_lin_x_w", griffin.linear_x_w); \ - CALL_LAYER_FUNC ## N("gr_lin_x_b", griffin.linear_x_biases); \ - CALL_LAYER_FUNC ## N("gr_lin_y_w", griffin.linear_y_w); \ - CALL_LAYER_FUNC ## N("gr_lin_y_b", griffin.linear_y_biases); \ - CALL_LAYER_FUNC ## N("gr_lin_out_w", griffin.linear_out_w); \ - CALL_LAYER_FUNC ## N("gr_lin_out_b", griffin.linear_out_biases); \ - CALL_LAYER_FUNC ## N("gr_conv_w", griffin.conv_w); \ - CALL_LAYER_FUNC ## N("gr_conv_b", griffin.conv_biases); \ - CALL_LAYER_FUNC ## N("gr_gate_w", griffin.gate_w); \ - CALL_LAYER_FUNC ## N("gr_gate_b", griffin.gate_biases); \ - CALL_LAYER_FUNC ## N("gr_a", griffin.a); \ - } \ - CALL_LAYER_FUNC ## N("gating_ein", gating_einsum_w); \ - CALL_LAYER_FUNC ## N("linear_w", linear_w); \ - CALL_LAYER_FUNC ## N("pre_att_ns", pre_attention_norm_scale); \ - if (TConfig::kPostNormScale) { \ - CALL_LAYER_FUNC ## N("post_att_ns", post_attention_norm_scale); \ - CALL_LAYER_FUNC ## N("post_ff_ns", post_ffw_norm_scale); \ - } \ - CALL_LAYER_FUNC ## N("pre_ff_ns", pre_ffw_norm_scale); \ - if (TConfig::kFFBiases) { \ - CALL_LAYER_FUNC ## N("ffw_gat_b", ffw_gating_biases); \ - CALL_LAYER_FUNC ## N("ffw_out_b", ffw_output_biases); \ - } \ - if 
(TConfig::kSoftmaxAttnOutputBiases && \ - type == LayerAttentionType::kGemma) { \ - CALL_LAYER_FUNC ## N("attn_ob", attention_output_biases); \ +#define GEMMA_CALL_ALL_LAYER_FUNC(N) \ + if (type == LayerAttentionType::kGemma) { \ + GEMMA_CALL_LAYER_FUNC ## N("att_ein", attn_vec_einsum_w); \ + GEMMA_CALL_LAYER_FUNC ## N("qkv_ein", qkv_einsum_w); \ + } else { \ + GEMMA_CALL_LAYER_FUNC ## N("gr_lin_x_w", griffin.linear_x_w); \ + GEMMA_CALL_LAYER_FUNC ## N("gr_lin_x_b", griffin.linear_x_biases); \ + GEMMA_CALL_LAYER_FUNC ## N("gr_lin_y_w", griffin.linear_y_w); \ + GEMMA_CALL_LAYER_FUNC ## N("gr_lin_y_b", griffin.linear_y_biases); \ + GEMMA_CALL_LAYER_FUNC ## N("gr_lin_out_w", griffin.linear_out_w); \ + GEMMA_CALL_LAYER_FUNC ## N("gr_lin_out_b", griffin.linear_out_biases); \ + GEMMA_CALL_LAYER_FUNC ## N("gr_conv_w", griffin.conv_w); \ + GEMMA_CALL_LAYER_FUNC ## N("gr_conv_b", griffin.conv_biases); \ + GEMMA_CALL_LAYER_FUNC ## N("gr_gate_w", griffin.gate_w); \ + GEMMA_CALL_LAYER_FUNC ## N("gr_gate_b", griffin.gate_biases); \ + GEMMA_CALL_LAYER_FUNC ## N("gr_a", griffin.a); \ + } \ + GEMMA_CALL_LAYER_FUNC ## N("gating_ein", gating_einsum_w); \ + GEMMA_CALL_LAYER_FUNC ## N("linear_w", linear_w); \ + GEMMA_CALL_LAYER_FUNC ## N("pre_att_ns", pre_attention_norm_scale); \ + if (TConfig::kPostNormScale) { \ + GEMMA_CALL_LAYER_FUNC ## N("post_att_ns", post_attention_norm_scale); \ + GEMMA_CALL_LAYER_FUNC ## N("post_ff_ns", post_ffw_norm_scale); \ + } \ + GEMMA_CALL_LAYER_FUNC ## N("pre_ff_ns", pre_ffw_norm_scale); \ + if (TConfig::kFFBiases) { \ + GEMMA_CALL_LAYER_FUNC ## N("ffw_gat_b", ffw_gating_biases); \ + GEMMA_CALL_LAYER_FUNC ## N("ffw_out_b", ffw_output_biases); \ + } \ + if (TConfig::kSoftmaxAttnOutputBiases && \ + type == LayerAttentionType::kGemma) { \ + GEMMA_CALL_LAYER_FUNC ## N("attn_ob", attention_output_biases); \ } template void ForEachTensor1(Func& func, const Weights& weights1) { - CALL_TOP_FUNC1("embedding", embedder_input_embedding); - 
CALL_TOP_FUNC1("final_norm", final_norm_scale); + GEMMA_CALL_TOP_FUNC1("embedding", embedder_input_embedding); + GEMMA_CALL_TOP_FUNC1("final_norm", final_norm_scale); char name_buf[16]; for (int layer_idx = 0; layer_idx < TConfig::kLayers; ++layer_idx) { auto type = TConfig::kLayerConfig[layer_idx]; const size_t idx = static_cast(layer_idx); const LayerF& layer1 = *weights1.GetLayer(idx); - CALL_ALL_LAYER_FUNC(1) + GEMMA_CALL_ALL_LAYER_FUNC(1) } } template void ForEachTensor1(Func& func, Weights& weights1) { - CALL_TOP_FUNC1("embedding", embedder_input_embedding); - CALL_TOP_FUNC1("final_norm", final_norm_scale); + GEMMA_CALL_TOP_FUNC1("embedding", embedder_input_embedding); + GEMMA_CALL_TOP_FUNC1("final_norm", final_norm_scale); char name_buf[16]; for (int layer_idx = 0; layer_idx < TConfig::kLayers; ++layer_idx) { auto type = TConfig::kLayerConfig[layer_idx]; const size_t idx = static_cast(layer_idx); LayerF& layer1 = *weights1.GetLayer(idx); - CALL_ALL_LAYER_FUNC(1) + GEMMA_CALL_ALL_LAYER_FUNC(1) } } template void ForEachTensor2(Func& func, const Weights& weights1, Weights& weights2) { - CALL_TOP_FUNC2("embedding", embedder_input_embedding); - CALL_TOP_FUNC2("final_norm", final_norm_scale); + GEMMA_CALL_TOP_FUNC2("embedding", embedder_input_embedding); + GEMMA_CALL_TOP_FUNC2("final_norm", final_norm_scale); char name_buf[16]; for (int layer_idx = 0; layer_idx < TConfig::kLayers; ++layer_idx) { auto type = TConfig::kLayerConfig[layer_idx]; const size_t idx = static_cast(layer_idx); const LayerF& layer1 = *weights1.GetLayer(idx); LayerF& layer2 = *weights2.GetLayer(idx); - CALL_ALL_LAYER_FUNC(2) + GEMMA_CALL_ALL_LAYER_FUNC(2) } } -#undef CALL_TOP_FUNC1 -#undef CALL_TOP_FUNC2 -#undef CALL_TOP_FUNC3 -#undef CALL_TOP_FUNC4 -#undef CALL_LAYER_FUNC1 -#undef CALL_LAYER_FUNC2 -#undef CALL_LAYER_FUNC3 -#undef CALL_LAYER_FUNC4 -#undef CALL_ALL_LAYER_FUNC +#undef GEMMA_CALL_TOP_FUNC1 +#undef GEMMA_CALL_TOP_FUNC2 +#undef GEMMA_CALL_TOP_FUNC3 +#undef GEMMA_CALL_TOP_FUNC4 
+#undef GEMMA_CALL_LAYER_FUNC1 +#undef GEMMA_CALL_LAYER_FUNC2 +#undef GEMMA_CALL_LAYER_FUNC3 +#undef GEMMA_CALL_LAYER_FUNC4 +#undef GEMMA_CALL_ALL_LAYER_FUNC template void ZeroInit(Weights& w) { - memset(&w.embedder_input_embedding, 0, sizeof(w.embedder_input_embedding)); - memset(&w.final_norm_scale, 0, sizeof(w.final_norm_scale)); + hwy::ZeroBytes(&w.embedder_input_embedding, + sizeof(w.embedder_input_embedding)); + hwy::ZeroBytes(&w.final_norm_scale, sizeof(w.final_norm_scale)); for (int i = 0; i < TConfig::kLayers; ++i) { - memset(w.GetLayer(i), 0, sizeof(*w.GetLayer(i))); + hwy::ZeroBytes(w.GetLayer(i), sizeof(*w.GetLayer(i))); } } template void Copy(Weights& dst, const Weights& src) { - memcpy(&dst.embedder_input_embedding, &src.embedder_input_embedding, - sizeof(src.embedder_input_embedding)); - memcpy(&dst.final_norm_scale, &src.final_norm_scale, - sizeof(src.final_norm_scale)); + hwy::CopyBytes(&src.embedder_input_embedding, &dst.embedder_input_embedding, + sizeof(src.embedder_input_embedding)); + hwy::CopyBytes(&src.final_norm_scale, &dst.final_norm_scale, + sizeof(src.final_norm_scale)); for (int i = 0; i < TConfig::kLayers; ++i) { - memcpy(dst.GetLayer(i), src.GetLayer(i), sizeof(*dst.GetLayer(i))); + hwy::CopyBytes(src.GetLayer(i), dst.GetLayer(i), sizeof(*dst.GetLayer(i))); } } +// Owns weights and undoes the type erasure of AllocateWeights. template class WeightsWrapper { public: