Address review comments

This commit is contained in:
Zoltan Szabadka 2024-06-04 08:35:22 +00:00
parent 7e639856da
commit 8567978541
13 changed files with 128 additions and 117 deletions

View File

@ -60,6 +60,7 @@ struct ForwardPass {
} }
}; };
// Owns activations and undoes the type erasure of AllocateAligned.
template<typename T, typename TConfig> template<typename T, typename TConfig>
class ActivationsWrapper { class ActivationsWrapper {
using WrappedT = ForwardPass<T, TConfig>; using WrappedT = ForwardPass<T, TConfig>;

View File

@ -13,6 +13,9 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
// Implementation of the Vector-Jacobian Products (VJP) of the individual
// operations of the forward pass.
// Include guard for non-SIMD code. // Include guard for non-SIMD code.
#ifndef THIRD_PARTY_GEMMA_CPP_GEMMA_BACKWARD_INL_H_ #ifndef THIRD_PARTY_GEMMA_CPP_GEMMA_BACKWARD_INL_H_
#define THIRD_PARTY_GEMMA_CPP_GEMMA_BACKWARD_INL_H_ #define THIRD_PARTY_GEMMA_CPP_GEMMA_BACKWARD_INL_H_
@ -56,7 +59,7 @@ void MatMulVJP(const std::array<float, kRows * kCols>& weights,
std::array<float, kRows * kCols>& grad_w, std::array<float, kRows * kCols>& grad_w,
float* HWY_RESTRICT grad_x, // num_tokens * kCols float* HWY_RESTRICT grad_x, // num_tokens * kCols
hwy::ThreadPool& pool) { hwy::ThreadPool& pool) {
memset(grad_x, 0, num_tokens * kCols * sizeof(grad_x[0])); hwy::ZeroBytes(grad_x, num_tokens * kCols * sizeof(grad_x[0]));
for (size_t pos = 0; pos < num_tokens; ++pos) { for (size_t pos = 0; pos < num_tokens; ++pos) {
const size_t voffs = pos * kRows; const size_t voffs = pos * kRows;
const size_t xoffs = pos * kCols; const size_t xoffs = pos * kCols;
@ -77,7 +80,7 @@ void MultiHeadMatMulVJP(
std::array<float, kHeads * kRows * kCols>& grad_w, std::array<float, kHeads * kRows * kCols>& grad_w,
float* HWY_RESTRICT grad_x, // num_tokens * kHeads * kCols float* HWY_RESTRICT grad_x, // num_tokens * kHeads * kCols
hwy::ThreadPool& pool) { hwy::ThreadPool& pool) {
memset(grad_x, 0, num_tokens * kHeads * kCols * sizeof(grad_x[0])); hwy::ZeroBytes(grad_x, num_tokens * kHeads * kCols * sizeof(grad_x[0]));
for (size_t pos = 0; pos < num_tokens; ++pos) { for (size_t pos = 0; pos < num_tokens; ++pos) {
for (size_t j = 0; j < kRows; ++j) { for (size_t j = 0; j < kRows; ++j) {
for (size_t h = 0; h < kHeads; ++h) { for (size_t h = 0; h < kHeads; ++h) {
@ -154,11 +157,12 @@ static HWY_NOINLINE void RMSNormVJP(
static HWY_NOINLINE void InputEmbeddingVJP( static HWY_NOINLINE void InputEmbeddingVJP(
const float* weights, const std::vector<int>& prompt, const float* weights, const std::vector<int>& prompt,
const float scaling, const float* HWY_RESTRICT backward, const float scaling, const float* HWY_RESTRICT v,
float* HWY_RESTRICT grad, size_t model_dim) { float* HWY_RESTRICT grad, size_t model_dim) {
for (size_t pos = 0; pos + 1 < prompt.size(); ++pos) { HWY_ASSERT(!prompt.empty());
for (size_t pos = 0; pos < prompt.size() - 1; ++pos) {
int token = prompt[pos]; int token = prompt[pos];
MulByConstAndAdd(scaling, backward + pos * model_dim, MulByConstAndAdd(scaling, v + pos * model_dim,
grad + token * model_dim, model_dim); grad + token * model_dim, model_dim);
} }
} }
@ -274,7 +278,7 @@ void LayerVJP(const Layer<float, TConfig>& weights,
} }
} }
for (int pos = 0; pos < num_tokens; ++pos) { for (int pos = 0; pos < static_cast<int>(num_tokens); ++pos) {
float* HWY_RESTRICT b_kv = float* HWY_RESTRICT b_kv =
backward.qkv.data() + (pos * (kHeads + 2) + kHeads) * kQKVDim; backward.qkv.data() + (pos * (kHeads + 2) + kHeads) * kQKVDim;
Rope(b_kv, kQKVDim, -pos); Rope(b_kv, kQKVDim, -pos);
@ -328,16 +332,17 @@ static HWY_NOINLINE void SoftcapVJP(const float* HWY_RESTRICT forward,
static HWY_NOINLINE void CrossEntropyLossGrad( static HWY_NOINLINE void CrossEntropyLossGrad(
const float* HWY_RESTRICT x, float* HWY_RESTRICT grad, const float* HWY_RESTRICT x, float* HWY_RESTRICT grad,
const Prompt& prompt, size_t vocab_size) { const Prompt& prompt, size_t vocab_size) {
HWY_ASSERT(!prompt.tokens.empty());
const float scaling = -1.0 / std::log(2.0); const float scaling = -1.0 / std::log(2.0);
size_t num_tokens = prompt.tokens.size() - 1; size_t num_tokens = prompt.tokens.size() - 1;
memset(grad, 0, num_tokens * vocab_size * sizeof(grad[0])); hwy::ZeroBytes(grad, num_tokens * vocab_size * sizeof(grad[0]));
for (size_t i = 0; i + 1 < prompt.tokens.size(); ++i) { for (size_t pos = 0; pos < num_tokens; ++pos) {
if (i + 1 < prompt.context_size) { if (pos + 1 < prompt.context_size) {
continue; continue;
} }
const int next_token = prompt.tokens[i + 1]; const int next_token = prompt.tokens[pos + 1];
grad[i * vocab_size + next_token] = grad[pos * vocab_size + next_token] =
scaling / x[i * vocab_size + next_token]; scaling / x[pos * vocab_size + next_token];
} }
} }

View File

@ -193,7 +193,8 @@ void MixByAttentionVJP(const T* qkv, const T* attention, const T* doutput,
template<typename T> template<typename T>
void InputEmbeddingVJPT(const T* w, const std::vector<int>& tokens, T scaling, void InputEmbeddingVJPT(const T* w, const std::vector<int>& tokens, T scaling,
const T* dy, T* dw, size_t N) { const T* dy, T* dw, size_t N) {
for (size_t i = 0; i + 1 < tokens.size(); ++i) { const size_t num_tokens = tokens.empty() ? 0 : tokens.size() - 1;
for (size_t i = 0; i < num_tokens; ++i) {
int token = tokens[i]; int token = tokens[i];
MulByConstAndAddT(scaling, dy + i * N, dw + token * N, N); MulByConstAndAddT(scaling, dy + i * N, dw + token * N, N);
} }
@ -287,13 +288,14 @@ void SoftcapVJPT(const T* y, T* dy, size_t N) {
template<typename T> template<typename T>
void CrossEntropyLossGrad(const T* x, T* dx, const Prompt& prompt, size_t V) { void CrossEntropyLossGrad(const T* x, T* dx, const Prompt& prompt, size_t V) {
T scaling = -1.0 / std::log(2.0); T scaling = -1.0 / std::log(2.0);
size_t num_tokens = prompt.tokens.size() - 1; const std::vector<int> tokens = prompt.tokens;
const size_t num_tokens = tokens.empty() ? 0 : tokens.size() - 1;
memset(dx, 0, V * num_tokens * sizeof(x[0])); memset(dx, 0, V * num_tokens * sizeof(x[0]));
for (size_t i = 0; i + 1 < prompt.tokens.size(); ++i) { for (size_t i = 0; i < num_tokens; ++i) {
if (i + 1 < prompt.context_size) { if (i + 1 < prompt.context_size) {
continue; continue;
} }
const int next_token = prompt.tokens[i + 1]; const int next_token = tokens[i + 1];
dx[i * V + next_token] = scaling / x[i * V + next_token]; dx[i * V + next_token] = scaling / x[i * V + next_token];
} }
} }
@ -307,7 +309,8 @@ void CrossEntropyLossBackwardPass(const Prompt& prompt,
static constexpr size_t kModelDim = TConfig::kModelDim; static constexpr size_t kModelDim = TConfig::kModelDim;
static constexpr size_t kVocabSize = TConfig::kVocabSize; static constexpr size_t kVocabSize = TConfig::kVocabSize;
static constexpr size_t kLayers = TConfig::kLayers; static constexpr size_t kLayers = TConfig::kLayers;
const size_t num_tokens = prompt.tokens.size() - 1; const std::vector<int> tokens = prompt.tokens;
const size_t num_tokens = tokens.empty() ? 0 : tokens.size() - 1;
CrossEntropyLossGrad(forward.probs.data(), backward.logits.data(), prompt, CrossEntropyLossGrad(forward.probs.data(), backward.logits.data(), prompt,
kVocabSize); kVocabSize);
@ -341,8 +344,7 @@ void CrossEntropyLossBackwardPass(const Prompt& prompt,
const T kEmbScaling = EmbeddingScaling(kModelDim); const T kEmbScaling = EmbeddingScaling(kModelDim);
InputEmbeddingVJPT(weights.embedder_input_embedding.data(), InputEmbeddingVJPT(weights.embedder_input_embedding.data(),
prompt.tokens, kEmbScaling, tokens, kEmbScaling, backward.layers[0].input.data(),
backward.layers[0].input.data(),
grad.embedder_input_embedding.data(), kModelDim); grad.embedder_input_embedding.data(), kModelDim);
} }

View File

@ -83,13 +83,13 @@ void TestMatMulVJP() {
return DotT(dy.data(), c_y.data(), kTokens * kRows); return DotT(dy.data(), c_y.data(), kTokens * kRows);
}; };
memset(&grad, 0, sizeof(grad)); hwy::ZeroBytes(&grad, sizeof(grad));
MatMulVJP<kCols, kRows>(weights, x.data(), dy.data(), kTokens, MatMulVJP<kCols, kRows>(weights, x.data(), dy.data(), kTokens,
grad, dx.data(), pool); grad, dx.data(), pool);
TestGradient(dx, c_x, func, 5e-5, 5e-5, __LINE__); TestGradient(dx, c_x, func, 5e-5, 5e-5, __LINE__);
TestGradient(grad, c_weights, func, 5e-5, 5e-5, __LINE__); TestGradient(grad, c_weights, func, 5e-5, 5e-5, __LINE__);
memset(&grad_scalar, 0, sizeof(grad_scalar)); hwy::ZeroBytes(&grad_scalar, sizeof(grad_scalar));
MatMulVJPT(weights.data(), x.data(), dy.data(), grad_scalar.data(), MatMulVJPT(weights.data(), x.data(), dy.data(), grad_scalar.data(),
dx_scalar.data(), kRows, kCols, kTokens); dx_scalar.data(), kRows, kCols, kTokens);
TestNear(dx, dx_scalar, 5e-5, 5e-5, __LINE__); TestNear(dx, dx_scalar, 5e-5, 5e-5, __LINE__);
@ -128,13 +128,13 @@ void TestMultiHeadMatMulVJP() {
return DotT(dy.data(), c_y.data(), kTokens * kRows); return DotT(dy.data(), c_y.data(), kTokens * kRows);
}; };
memset(&grad, 0, sizeof(grad)); hwy::ZeroBytes(&grad, sizeof(grad));
MultiHeadMatMulVJP<kHeads, kCols, kRows>( MultiHeadMatMulVJP<kHeads, kCols, kRows>(
weights, x.data(), dy.data(), kTokens, grad, dx.data(), pool); weights, x.data(), dy.data(), kTokens, grad, dx.data(), pool);
TestGradient(dx, c_x, func, 5e-5, 5e-5, __LINE__); TestGradient(dx, c_x, func, 5e-5, 5e-5, __LINE__);
TestGradient(grad, c_weights, func, 5e-5, 5e-5, __LINE__); TestGradient(grad, c_weights, func, 5e-5, 5e-5, __LINE__);
memset(&grad_scalar, 0, sizeof(grad_scalar)); hwy::ZeroBytes(&grad_scalar, sizeof(grad_scalar));
MultiHeadMatMulVJPT(weights.data(), x.data(), dy.data(), grad_scalar.data(), MultiHeadMatMulVJPT(weights.data(), x.data(), dy.data(), grad_scalar.data(),
dx_scalar.data(), kHeads, kRows, kCols, kTokens); dx_scalar.data(), kHeads, kRows, kCols, kTokens);
TestNear(dx, dx_scalar, 5e-5, 5e-5, __LINE__); TestNear(dx, dx_scalar, 5e-5, 5e-5, __LINE__);
@ -170,13 +170,13 @@ void TestRMSNormVJP() {
return DotT(dy.data(), c_y.data(), K * N); return DotT(dy.data(), c_y.data(), K * N);
}; };
memset(&grad, 0, sizeof(grad)); hwy::ZeroBytes(&grad, sizeof(grad));
RMSNormVJP(weights.data(), x.data(), dy.data(), N, K, grad.data(), RMSNormVJP(weights.data(), x.data(), dy.data(), N, K, grad.data(),
dx.data(), pool); dx.data(), pool);
TestGradient(dx, c_x, func, 5e-5, 5e-5, __LINE__); TestGradient(dx, c_x, func, 5e-5, 5e-5, __LINE__);
TestGradient(grad, c_weights, func, 5e-5, 5e-5, __LINE__); TestGradient(grad, c_weights, func, 5e-5, 5e-5, __LINE__);
memset(&grad_scalar, 0, sizeof(grad_scalar)); hwy::ZeroBytes(&grad_scalar, sizeof(grad_scalar));
RMSNormVJPT(weights.data(), x.data(), dy.data(), grad_scalar.data(), RMSNormVJPT(weights.data(), x.data(), dy.data(), grad_scalar.data(),
dx_scalar.data(), N, K); dx_scalar.data(), N, K);
TestNear(dx, dx_scalar, 0, 2e-5, __LINE__); TestNear(dx, dx_scalar, 0, 2e-5, __LINE__);
@ -261,9 +261,7 @@ HWY_EXPORT_AND_TEST_P(BackwardTest, TestMatMulVJP);
HWY_EXPORT_AND_TEST_P(BackwardTest, TestMultiHeadMatMulVJP); HWY_EXPORT_AND_TEST_P(BackwardTest, TestMultiHeadMatMulVJP);
HWY_EXPORT_AND_TEST_P(BackwardTest, TestRMSNormVJP); HWY_EXPORT_AND_TEST_P(BackwardTest, TestRMSNormVJP);
HWY_EXPORT_AND_TEST_P(BackwardTest, TestEndToEnd); HWY_EXPORT_AND_TEST_P(BackwardTest, TestEndToEnd);
#ifdef HWY_AFTER_TEST
HWY_AFTER_TEST(); HWY_AFTER_TEST();
#endif
} // namespace gcpp } // namespace gcpp

View File

@ -28,9 +28,6 @@
#include "hwy/base.h" #include "hwy/base.h"
#include "hwy/contrib/thread_pool/thread_pool.h" #include "hwy/contrib/thread_pool/thread_pool.h"
namespace gcpp {
} // namespace gcpp
#endif // THIRD_PARTY_GEMMA_CPP_GEMMA_COMMON_INL_H_ #endif // THIRD_PARTY_GEMMA_CPP_GEMMA_COMMON_INL_H_
// Include guard for (potentially) SIMD code. // Include guard for (potentially) SIMD code.

View File

@ -19,8 +19,8 @@
#define HWY_TARGET_INCLUDE "gemma/common_scalar.cc" // NOLINT #define HWY_TARGET_INCLUDE "gemma/common_scalar.cc" // NOLINT
#include "hwy/foreach_target.h" // IWYU pragma: keep #include "hwy/foreach_target.h" // IWYU pragma: keep
#include "hwy/highway.h" // IWYU pragma: keep
#include "gemma/ops.h" #include "gemma/ops.h"
#include "hwy/highway.h"
HWY_BEFORE_NAMESPACE(); HWY_BEFORE_NAMESPACE();
namespace gcpp { namespace gcpp {

View File

@ -51,7 +51,8 @@ template <typename ArrayT>
void InputEmbedding(const ArrayT& weights, const std::vector<int>& prompt, void InputEmbedding(const ArrayT& weights, const std::vector<int>& prompt,
const float scaling, float* HWY_RESTRICT output, const float scaling, float* HWY_RESTRICT output,
size_t model_dim) { size_t model_dim) {
for (size_t pos = 0; pos + 1 < prompt.size(); ++pos) { HWY_ASSERT(!prompt.empty());
for (size_t pos = 0; pos < prompt.size() - 1; ++pos) {
int token = prompt[pos]; int token = prompt[pos];
Decompress(weights, token * model_dim, output + pos * model_dim, model_dim); Decompress(weights, token * model_dim, output + pos * model_dim, model_dim);
MulByConst(scaling, output + pos * model_dim, model_dim); MulByConst(scaling, output + pos * model_dim, model_dim);
@ -74,8 +75,9 @@ static HWY_NOINLINE float CrossEntropyLoss(const float* HWY_RESTRICT probs,
size_t context_size, size_t context_size,
size_t vocab_size, size_t vocab_size,
hwy::ThreadPool& pool) { hwy::ThreadPool& pool) {
HWY_ASSERT(!prompt.empty());
float loss = 0.0f; float loss = 0.0f;
for (size_t pos = 0; pos + 1 < prompt.size(); ++pos) { for (size_t pos = 0; pos < prompt.size() - 1; ++pos) {
if (pos + 1 < context_size) { if (pos + 1 < context_size) {
continue; // next token is part of context, don't try to predict it continue; // next token is part of context, don't try to predict it
} }
@ -271,8 +273,8 @@ float CrossEntropyLossForwardPass(const std::vector<int>& prompt,
LogitsSoftCap(30.0f, forward.logits.data() + pos * kVocabSize, kVocabSize); LogitsSoftCap(30.0f, forward.logits.data() + pos * kVocabSize, kVocabSize);
} }
memcpy(forward.probs.data(), forward.logits.data(), hwy::CopyBytes(forward.logits.data(), forward.probs.data(),
num_tokens * kVocabSize * sizeof(forward.logits[0])); num_tokens * kVocabSize * sizeof(forward.logits[0]));
for (size_t pos = 0; pos < num_tokens; ++pos) { for (size_t pos = 0; pos < num_tokens; ++pos) {
Softmax(forward.probs.data() + pos * kVocabSize, kVocabSize); Softmax(forward.probs.data() + pos * kVocabSize, kVocabSize);

View File

@ -21,9 +21,9 @@
#define HWY_TARGET_INCLUDE "gemma/forward.cc" // NOLINT #define HWY_TARGET_INCLUDE "gemma/forward.cc" // NOLINT
#include "hwy/foreach_target.h" // IWYU pragma: keep #include "hwy/foreach_target.h" // IWYU pragma: keep
#include "hwy/highway.h" // IWYU pragma: keep
#include "gemma/forward-inl.h" #include "gemma/forward-inl.h"
#include "gemma/weights.h" #include "gemma/weights.h"
#include "hwy/highway.h"
HWY_BEFORE_NAMESPACE(); HWY_BEFORE_NAMESPACE();
namespace gcpp { namespace gcpp {

View File

@ -116,7 +116,8 @@ void GatedGelu(const T* in, T* out, size_t N, size_t K) {
template<typename T> template<typename T>
void InputEmbedding(const T* w, const std::vector<int>& tokens, T scaling, void InputEmbedding(const T* w, const std::vector<int>& tokens, T scaling,
T* y, size_t N) { T* y, size_t N) {
for (size_t i = 0; i + 1 < tokens.size(); ++i) { const size_t num_tokens = tokens.empty() ? 0 : tokens.size() - 1;
for (size_t i = 0; i < num_tokens; ++i) {
int token = tokens[i]; int token = tokens[i];
memcpy(y + i * N, w + token * N, N * sizeof(y[0])); memcpy(y + i * N, w + token * N, N * sizeof(y[0]));
MulByConstT(scaling, y + i * N, N); MulByConstT(scaling, y + i * N, N);
@ -230,11 +231,13 @@ void ApplyLayer(const Layer<T, TConfig>& weights,
template<typename T> template<typename T>
T CrossEntropyLoss(const T* x, const Prompt& prompt, size_t V) { T CrossEntropyLoss(const T* x, const Prompt& prompt, size_t V) {
T loss = {}; T loss = {};
for (size_t i = 0; i + 1 < prompt.tokens.size(); ++i) { const std::vector<int> tokens = prompt.tokens;
const size_t num_tokens = tokens.empty() ? 0 : tokens.size() - 1;
for (size_t i = 0; i < num_tokens; ++i) {
if (i + 1 < prompt.context_size) { if (i + 1 < prompt.context_size) {
continue; // next token is part of context, don't try to predict it continue; // next token is part of context, don't try to predict it
} }
const int next_token = prompt.tokens[i + 1]; const int next_token = tokens[i + 1];
loss += std::log(x[i * V + next_token]); loss += std::log(x[i * V + next_token]);
} }
T scaling = -1.0 / std::log(2.0); T scaling = -1.0 / std::log(2.0);
@ -248,10 +251,11 @@ T CrossEntropyLossForwardPass(const Prompt& prompt,
static constexpr size_t kModelDim = TConfig::kModelDim; static constexpr size_t kModelDim = TConfig::kModelDim;
static constexpr size_t kVocabSize = TConfig::kVocabSize; static constexpr size_t kVocabSize = TConfig::kVocabSize;
static constexpr size_t kLayers = TConfig::kLayers; static constexpr size_t kLayers = TConfig::kLayers;
const size_t num_tokens = prompt.tokens.size() - 1; const std::vector<int> tokens = prompt.tokens;
const size_t num_tokens = tokens.empty() ? 0 : tokens.size() - 1;
const T kEmbScaling = EmbeddingScaling(kModelDim); const T kEmbScaling = EmbeddingScaling(kModelDim);
InputEmbedding(weights.embedder_input_embedding.data(), prompt.tokens, InputEmbedding(weights.embedder_input_embedding.data(), tokens,
kEmbScaling, forward.layers[0].input.data(), kModelDim); kEmbScaling, forward.layers[0].input.data(), kModelDim);
for (size_t layer = 0; layer < kLayers; ++layer) { for (size_t layer = 0; layer < kLayers; ++layer) {

View File

@ -874,6 +874,10 @@ static HWY_NOINLINE HWY_MAYBE_UNUSED void AddAbsolutePositionalEmbeddings(
consecutive pair of dimensions of v i.e. v_{2i} and v_{2i+1} by an angle consecutive pair of dimensions of v i.e. v_{2i} and v_{2i+1} by an angle
m*theta_i. However in the Gemma implementation we choose to rotate m*theta_i. However in the Gemma implementation we choose to rotate
the pairs of dimensions v_{i} and v_{i + d//2} instead. the pairs of dimensions v_{i} and v_{i + d//2} instead.
pos parameter is deliberately an int because in the backward pass we
call this with negative values (for the VJP calculation we need the transpose
of this rotation matrix which is simply the same matrix with -pos parameter)
*/ */
static HWY_NOINLINE HWY_MAYBE_UNUSED void Rope(float* HWY_RESTRICT x, static HWY_NOINLINE HWY_MAYBE_UNUSED void Rope(float* HWY_RESTRICT x,

View File

@ -107,7 +107,7 @@ void TestNear(const std::array<T, N>& actual, const std::array<U, N>& expected,
} }
// Compute gradient with the finite difference method in the complex plane. // Compute gradient with the finite difference method in the complex plane.
// If f : R->R is the tested function and F : C->C is its extenstion on the // If f : R->R is the tested function and F : C->C is its extension on the
// complex plane so that F is complex differentiable in x, then // complex plane so that F is complex differentiable in x, then
// //
// F(x + ih) = F(x) + ih F'(x) + O(h^2) F''(x) // F(x + ih) = F(x) + ih F'(x) + O(h^2) F''(x)
@ -117,7 +117,7 @@ void TestNear(const std::array<T, N>& actual, const std::array<U, N>& expected,
// F'(x) ~= Imag(F(x + ih)) / h // F'(x) ~= Imag(F(x + ih)) / h
// //
// This method is more numerically stable than the real-valued finite difference // This method is more numerically stable than the real-valued finite difference
// method since we don't need to substract floating point numbers that are near // method since we don't need to subtract floating point numbers that are near
// to each other. // to each other.
template<typename T, typename U, size_t N, typename FUNC> template<typename T, typename U, size_t N, typename FUNC>
void TestGradient(const std::array<T, N>& grad, void TestGradient(const std::array<T, N>& grad,

View File

@ -18,6 +18,7 @@
#include "gemma/common.h" #include "gemma/common.h"
#include "gemma/configs.h" #include "gemma/configs.h"
#include "hwy/contrib/thread_pool/thread_pool.h" #include "hwy/contrib/thread_pool/thread_pool.h"
#include "hwy/stats.h"
namespace gcpp { namespace gcpp {
@ -66,17 +67,12 @@ void ZeroInitWeights(Model model, ByteStorageT& weights,
namespace { namespace {
void LogVec(const char* name, const float* data, size_t len) { void LogVec(const char* name, const float* data, size_t len) {
float minval = std::numeric_limits<float>::max(); hwy::Stats stats;
float maxval = std::numeric_limits<float>::min();
double sum = 0.0f;
for (size_t i = 0; i < len; ++i) { for (size_t i = 0; i < len; ++i) {
minval = std::min(minval, data[i]); stats.Notify(data[i]);
maxval = std::max(maxval, data[i]);
sum += data[i];
} }
float avg = sum / len;
printf("%-20s %12zu %13.10f %8.5f %13.10f\n", printf("%-20s %12zu %13.10f %8.5f %13.10f\n",
name, len, minval, avg, maxval); name, len, stats.Min(), stats.Mean(), stats.Max());
} }
class WeightLogger { class WeightLogger {

View File

@ -127,136 +127,138 @@ ByteStorageT AllocateWeights(hwy::ThreadPool& pool) {
return weights_u8; return weights_u8;
} }
#define CALL_TOP_FUNC1(name, member) func(name, weights1.member) #define GEMMA_CALL_TOP_FUNC1(name, member) func(name, weights1.member)
#define CALL_TOP_FUNC2(name, member) \ #define GEMMA_CALL_TOP_FUNC2(name, member) \
func(name, weights1.member, weights2.member) func(name, weights1.member, weights2.member)
#define CALL_TOP_FUNC3(name, member) \ #define GEMMA_CALL_TOP_FUNC3(name, member) \
func(name, weights1.member, weights2.member, weights3.member) func(name, weights1.member, weights2.member, weights3.member)
#define CALL_TOP_FUNC4(name, member) \ #define GEMMA_CALL_TOP_FUNC4(name, member) \
func(name, weights1.member, weights2.memeber, \ func(name, weights1.member, weights2.memeber, \
weights3.member, weights4.member) weights3.member, weights4.member)
#define CALL_LAYER_FUNC1(name, member) \ #define GEMMA_CALL_LAYER_FUNC1(name, member) \
snprintf(name_buf, sizeof(name_buf), name "_%d", layer_idx); \ snprintf(name_buf, sizeof(name_buf), name "_%d", layer_idx); \
func(name_buf, layer1.member) func(name_buf, layer1.member)
#define CALL_LAYER_FUNC2(name, member) \ #define GEMMA_CALL_LAYER_FUNC2(name, member) \
snprintf(name_buf, sizeof(name_buf), name "_%d", layer_idx); \ snprintf(name_buf, sizeof(name_buf), name "_%d", layer_idx); \
func(name_buf, layer1.member, layer2.member) func(name_buf, layer1.member, layer2.member)
#define CALL_LAYER_FUNC3(name, member) \ #define GEMMA_CALL_LAYER_FUNC3(name, member) \
snprintf(name_buf, sizeof(name_buf), name "_%d", layer_idx); \ snprintf(name_buf, sizeof(name_buf), name "_%d", layer_idx); \
func(name_buf, layer1.member, layer2.member, layer3.member) func(name_buf, layer1.member, layer2.member, layer3.member)
#define CALL_LAYER_FUNC4(name, member) \ #define GEMMA_CALL_LAYER_FUNC4(name, member) \
snprintf(name_buf, sizeof(name_buf), name "_%d", layer_idx); \ snprintf(name_buf, sizeof(name_buf), name "_%d", layer_idx); \
func(name_buf, layer1.member, layer2.member, layer4.member) func(name_buf, layer1.member, layer2.member, layer4.member)
#define CALL_ALL_LAYER_FUNC(N) \ #define GEMMA_CALL_ALL_LAYER_FUNC(N) \
if (type == LayerAttentionType::kGemma) { \ if (type == LayerAttentionType::kGemma) { \
CALL_LAYER_FUNC ## N("att_ein", attn_vec_einsum_w); \ GEMMA_CALL_LAYER_FUNC ## N("att_ein", attn_vec_einsum_w); \
CALL_LAYER_FUNC ## N("qkv_ein", qkv_einsum_w); \ GEMMA_CALL_LAYER_FUNC ## N("qkv_ein", qkv_einsum_w); \
} else { \ } else { \
CALL_LAYER_FUNC ## N("gr_lin_x_w", griffin.linear_x_w); \ GEMMA_CALL_LAYER_FUNC ## N("gr_lin_x_w", griffin.linear_x_w); \
CALL_LAYER_FUNC ## N("gr_lin_x_b", griffin.linear_x_biases); \ GEMMA_CALL_LAYER_FUNC ## N("gr_lin_x_b", griffin.linear_x_biases); \
CALL_LAYER_FUNC ## N("gr_lin_y_w", griffin.linear_y_w); \ GEMMA_CALL_LAYER_FUNC ## N("gr_lin_y_w", griffin.linear_y_w); \
CALL_LAYER_FUNC ## N("gr_lin_y_b", griffin.linear_y_biases); \ GEMMA_CALL_LAYER_FUNC ## N("gr_lin_y_b", griffin.linear_y_biases); \
CALL_LAYER_FUNC ## N("gr_lin_out_w", griffin.linear_out_w); \ GEMMA_CALL_LAYER_FUNC ## N("gr_lin_out_w", griffin.linear_out_w); \
CALL_LAYER_FUNC ## N("gr_lin_out_b", griffin.linear_out_biases); \ GEMMA_CALL_LAYER_FUNC ## N("gr_lin_out_b", griffin.linear_out_biases); \
CALL_LAYER_FUNC ## N("gr_conv_w", griffin.conv_w); \ GEMMA_CALL_LAYER_FUNC ## N("gr_conv_w", griffin.conv_w); \
CALL_LAYER_FUNC ## N("gr_conv_b", griffin.conv_biases); \ GEMMA_CALL_LAYER_FUNC ## N("gr_conv_b", griffin.conv_biases); \
CALL_LAYER_FUNC ## N("gr_gate_w", griffin.gate_w); \ GEMMA_CALL_LAYER_FUNC ## N("gr_gate_w", griffin.gate_w); \
CALL_LAYER_FUNC ## N("gr_gate_b", griffin.gate_biases); \ GEMMA_CALL_LAYER_FUNC ## N("gr_gate_b", griffin.gate_biases); \
CALL_LAYER_FUNC ## N("gr_a", griffin.a); \ GEMMA_CALL_LAYER_FUNC ## N("gr_a", griffin.a); \
} \ } \
CALL_LAYER_FUNC ## N("gating_ein", gating_einsum_w); \ GEMMA_CALL_LAYER_FUNC ## N("gating_ein", gating_einsum_w); \
CALL_LAYER_FUNC ## N("linear_w", linear_w); \ GEMMA_CALL_LAYER_FUNC ## N("linear_w", linear_w); \
CALL_LAYER_FUNC ## N("pre_att_ns", pre_attention_norm_scale); \ GEMMA_CALL_LAYER_FUNC ## N("pre_att_ns", pre_attention_norm_scale); \
if (TConfig::kPostNormScale) { \ if (TConfig::kPostNormScale) { \
CALL_LAYER_FUNC ## N("post_att_ns", post_attention_norm_scale); \ GEMMA_CALL_LAYER_FUNC ## N("post_att_ns", post_attention_norm_scale); \
CALL_LAYER_FUNC ## N("post_ff_ns", post_ffw_norm_scale); \ GEMMA_CALL_LAYER_FUNC ## N("post_ff_ns", post_ffw_norm_scale); \
} \ } \
CALL_LAYER_FUNC ## N("pre_ff_ns", pre_ffw_norm_scale); \ GEMMA_CALL_LAYER_FUNC ## N("pre_ff_ns", pre_ffw_norm_scale); \
if (TConfig::kFFBiases) { \ if (TConfig::kFFBiases) { \
CALL_LAYER_FUNC ## N("ffw_gat_b", ffw_gating_biases); \ GEMMA_CALL_LAYER_FUNC ## N("ffw_gat_b", ffw_gating_biases); \
CALL_LAYER_FUNC ## N("ffw_out_b", ffw_output_biases); \ GEMMA_CALL_LAYER_FUNC ## N("ffw_out_b", ffw_output_biases); \
} \ } \
if (TConfig::kSoftmaxAttnOutputBiases && \ if (TConfig::kSoftmaxAttnOutputBiases && \
type == LayerAttentionType::kGemma) { \ type == LayerAttentionType::kGemma) { \
CALL_LAYER_FUNC ## N("attn_ob", attention_output_biases); \ GEMMA_CALL_LAYER_FUNC ## N("attn_ob", attention_output_biases); \
} }
template <typename T, typename TConfig, class Func> template <typename T, typename TConfig, class Func>
void ForEachTensor1(Func& func, const Weights<T, TConfig>& weights1) { void ForEachTensor1(Func& func, const Weights<T, TConfig>& weights1) {
CALL_TOP_FUNC1("embedding", embedder_input_embedding); GEMMA_CALL_TOP_FUNC1("embedding", embedder_input_embedding);
CALL_TOP_FUNC1("final_norm", final_norm_scale); GEMMA_CALL_TOP_FUNC1("final_norm", final_norm_scale);
char name_buf[16]; char name_buf[16];
for (int layer_idx = 0; layer_idx < TConfig::kLayers; ++layer_idx) { for (int layer_idx = 0; layer_idx < TConfig::kLayers; ++layer_idx) {
auto type = TConfig::kLayerConfig[layer_idx]; auto type = TConfig::kLayerConfig[layer_idx];
const size_t idx = static_cast<size_t>(layer_idx); const size_t idx = static_cast<size_t>(layer_idx);
const LayerF<TConfig>& layer1 = *weights1.GetLayer(idx); const LayerF<TConfig>& layer1 = *weights1.GetLayer(idx);
CALL_ALL_LAYER_FUNC(1) GEMMA_CALL_ALL_LAYER_FUNC(1)
} }
} }
template <typename T, typename TConfig, class Func> template <typename T, typename TConfig, class Func>
void ForEachTensor1(Func& func, Weights<T, TConfig>& weights1) { void ForEachTensor1(Func& func, Weights<T, TConfig>& weights1) {
CALL_TOP_FUNC1("embedding", embedder_input_embedding); GEMMA_CALL_TOP_FUNC1("embedding", embedder_input_embedding);
CALL_TOP_FUNC1("final_norm", final_norm_scale); GEMMA_CALL_TOP_FUNC1("final_norm", final_norm_scale);
char name_buf[16]; char name_buf[16];
for (int layer_idx = 0; layer_idx < TConfig::kLayers; ++layer_idx) { for (int layer_idx = 0; layer_idx < TConfig::kLayers; ++layer_idx) {
auto type = TConfig::kLayerConfig[layer_idx]; auto type = TConfig::kLayerConfig[layer_idx];
const size_t idx = static_cast<size_t>(layer_idx); const size_t idx = static_cast<size_t>(layer_idx);
LayerF<TConfig>& layer1 = *weights1.GetLayer(idx); LayerF<TConfig>& layer1 = *weights1.GetLayer(idx);
CALL_ALL_LAYER_FUNC(1) GEMMA_CALL_ALL_LAYER_FUNC(1)
} }
} }
template <typename T, typename TConfig, class Func> template <typename T, typename TConfig, class Func>
void ForEachTensor2(Func& func, const Weights<T, TConfig>& weights1, void ForEachTensor2(Func& func, const Weights<T, TConfig>& weights1,
Weights<T, TConfig>& weights2) { Weights<T, TConfig>& weights2) {
CALL_TOP_FUNC2("embedding", embedder_input_embedding); GEMMA_CALL_TOP_FUNC2("embedding", embedder_input_embedding);
CALL_TOP_FUNC2("final_norm", final_norm_scale); GEMMA_CALL_TOP_FUNC2("final_norm", final_norm_scale);
char name_buf[16]; char name_buf[16];
for (int layer_idx = 0; layer_idx < TConfig::kLayers; ++layer_idx) { for (int layer_idx = 0; layer_idx < TConfig::kLayers; ++layer_idx) {
auto type = TConfig::kLayerConfig[layer_idx]; auto type = TConfig::kLayerConfig[layer_idx];
const size_t idx = static_cast<size_t>(layer_idx); const size_t idx = static_cast<size_t>(layer_idx);
const LayerF<TConfig>& layer1 = *weights1.GetLayer(idx); const LayerF<TConfig>& layer1 = *weights1.GetLayer(idx);
LayerF<TConfig>& layer2 = *weights2.GetLayer(idx); LayerF<TConfig>& layer2 = *weights2.GetLayer(idx);
CALL_ALL_LAYER_FUNC(2) GEMMA_CALL_ALL_LAYER_FUNC(2)
} }
} }
#undef CALL_TOP_FUNC1 #undef GEMMA_CALL_TOP_FUNC1
#undef CALL_TOP_FUNC2 #undef GEMMA_CALL_TOP_FUNC2
#undef CALL_TOP_FUNC3 #undef GEMMA_CALL_TOP_FUNC3
#undef CALL_TOP_FUNC4 #undef GEMMA_CALL_TOP_FUNC4
#undef CALL_LAYER_FUNC1 #undef GEMMA_CALL_LAYER_FUNC1
#undef CALL_LAYER_FUNC2 #undef GEMMA_CALL_LAYER_FUNC2
#undef CALL_LAYER_FUNC3 #undef GEMMA_CALL_LAYER_FUNC3
#undef CALL_LAYER_FUNC4 #undef GEMMA_CALL_LAYER_FUNC4
#undef CALL_ALL_LAYER_FUNC #undef GEMMA_CALL_ALL_LAYER_FUNC
template<typename T, typename TConfig> template<typename T, typename TConfig>
void ZeroInit(Weights<T, TConfig>& w) { void ZeroInit(Weights<T, TConfig>& w) {
memset(&w.embedder_input_embedding, 0, sizeof(w.embedder_input_embedding)); hwy::ZeroBytes(&w.embedder_input_embedding,
memset(&w.final_norm_scale, 0, sizeof(w.final_norm_scale)); sizeof(w.embedder_input_embedding));
hwy::ZeroBytes(&w.final_norm_scale, sizeof(w.final_norm_scale));
for (int i = 0; i < TConfig::kLayers; ++i) { for (int i = 0; i < TConfig::kLayers; ++i) {
memset(w.GetLayer(i), 0, sizeof(*w.GetLayer(i))); hwy::ZeroBytes(w.GetLayer(i), sizeof(*w.GetLayer(i)));
} }
} }
template<typename T, typename TConfig> template<typename T, typename TConfig>
void Copy(Weights<T, TConfig>& dst, const Weights<T, TConfig>& src) { void Copy(Weights<T, TConfig>& dst, const Weights<T, TConfig>& src) {
memcpy(&dst.embedder_input_embedding, &src.embedder_input_embedding, hwy::CopyBytes(&src.embedder_input_embedding, &dst.embedder_input_embedding,
sizeof(src.embedder_input_embedding)); sizeof(src.embedder_input_embedding));
memcpy(&dst.final_norm_scale, &src.final_norm_scale, hwy::CopyBytes(&src.final_norm_scale, &dst.final_norm_scale,
sizeof(src.final_norm_scale)); sizeof(src.final_norm_scale));
for (int i = 0; i < TConfig::kLayers; ++i) { for (int i = 0; i < TConfig::kLayers; ++i) {
memcpy(dst.GetLayer(i), src.GetLayer(i), sizeof(*dst.GetLayer(i))); hwy::CopyBytes(src.GetLayer(i), dst.GetLayer(i), sizeof(*dst.GetLayer(i)));
} }
} }
// Owns weights and undoes the type erasure of AllocateWeights.
template<typename T, typename TConfig> template<typename T, typename TConfig>
class WeightsWrapper { class WeightsWrapper {
public: public: