mirror of https://github.com/google/gemma.cpp.git
Address review comments
This commit is contained in:
parent
7e639856da
commit
8567978541
|
|
@ -60,6 +60,7 @@ struct ForwardPass {
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
// Owns activations and undoes the type erasure of AllocateAligned.
|
||||||
template<typename T, typename TConfig>
|
template<typename T, typename TConfig>
|
||||||
class ActivationsWrapper {
|
class ActivationsWrapper {
|
||||||
using WrappedT = ForwardPass<T, TConfig>;
|
using WrappedT = ForwardPass<T, TConfig>;
|
||||||
|
|
|
||||||
|
|
@ -13,6 +13,9 @@
|
||||||
// See the License for the specific language governing permissions and
|
// See the License for the specific language governing permissions and
|
||||||
// limitations under the License.
|
// limitations under the License.
|
||||||
|
|
||||||
|
// Implementation of the Vector-Jacobian Products (VJP) of the individual
|
||||||
|
// operations of the forward pass.
|
||||||
|
|
||||||
// Include guard for non-SIMD code.
|
// Include guard for non-SIMD code.
|
||||||
#ifndef THIRD_PARTY_GEMMA_CPP_GEMMA_BACKWARD_INL_H_
|
#ifndef THIRD_PARTY_GEMMA_CPP_GEMMA_BACKWARD_INL_H_
|
||||||
#define THIRD_PARTY_GEMMA_CPP_GEMMA_BACKWARD_INL_H_
|
#define THIRD_PARTY_GEMMA_CPP_GEMMA_BACKWARD_INL_H_
|
||||||
|
|
@ -56,7 +59,7 @@ void MatMulVJP(const std::array<float, kRows * kCols>& weights,
|
||||||
std::array<float, kRows * kCols>& grad_w,
|
std::array<float, kRows * kCols>& grad_w,
|
||||||
float* HWY_RESTRICT grad_x, // num_tokens * kCols
|
float* HWY_RESTRICT grad_x, // num_tokens * kCols
|
||||||
hwy::ThreadPool& pool) {
|
hwy::ThreadPool& pool) {
|
||||||
memset(grad_x, 0, num_tokens * kCols * sizeof(grad_x[0]));
|
hwy::ZeroBytes(grad_x, num_tokens * kCols * sizeof(grad_x[0]));
|
||||||
for (size_t pos = 0; pos < num_tokens; ++pos) {
|
for (size_t pos = 0; pos < num_tokens; ++pos) {
|
||||||
const size_t voffs = pos * kRows;
|
const size_t voffs = pos * kRows;
|
||||||
const size_t xoffs = pos * kCols;
|
const size_t xoffs = pos * kCols;
|
||||||
|
|
@ -77,7 +80,7 @@ void MultiHeadMatMulVJP(
|
||||||
std::array<float, kHeads * kRows * kCols>& grad_w,
|
std::array<float, kHeads * kRows * kCols>& grad_w,
|
||||||
float* HWY_RESTRICT grad_x, // num_tokens * kHeads * kCols
|
float* HWY_RESTRICT grad_x, // num_tokens * kHeads * kCols
|
||||||
hwy::ThreadPool& pool) {
|
hwy::ThreadPool& pool) {
|
||||||
memset(grad_x, 0, num_tokens * kHeads * kCols * sizeof(grad_x[0]));
|
hwy::ZeroBytes(grad_x, num_tokens * kHeads * kCols * sizeof(grad_x[0]));
|
||||||
for (size_t pos = 0; pos < num_tokens; ++pos) {
|
for (size_t pos = 0; pos < num_tokens; ++pos) {
|
||||||
for (size_t j = 0; j < kRows; ++j) {
|
for (size_t j = 0; j < kRows; ++j) {
|
||||||
for (size_t h = 0; h < kHeads; ++h) {
|
for (size_t h = 0; h < kHeads; ++h) {
|
||||||
|
|
@ -154,11 +157,12 @@ static HWY_NOINLINE void RMSNormVJP(
|
||||||
|
|
||||||
static HWY_NOINLINE void InputEmbeddingVJP(
|
static HWY_NOINLINE void InputEmbeddingVJP(
|
||||||
const float* weights, const std::vector<int>& prompt,
|
const float* weights, const std::vector<int>& prompt,
|
||||||
const float scaling, const float* HWY_RESTRICT backward,
|
const float scaling, const float* HWY_RESTRICT v,
|
||||||
float* HWY_RESTRICT grad, size_t model_dim) {
|
float* HWY_RESTRICT grad, size_t model_dim) {
|
||||||
for (size_t pos = 0; pos + 1 < prompt.size(); ++pos) {
|
HWY_ASSERT(!prompt.empty());
|
||||||
|
for (size_t pos = 0; pos < prompt.size() - 1; ++pos) {
|
||||||
int token = prompt[pos];
|
int token = prompt[pos];
|
||||||
MulByConstAndAdd(scaling, backward + pos * model_dim,
|
MulByConstAndAdd(scaling, v + pos * model_dim,
|
||||||
grad + token * model_dim, model_dim);
|
grad + token * model_dim, model_dim);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
@ -274,7 +278,7 @@ void LayerVJP(const Layer<float, TConfig>& weights,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
for (int pos = 0; pos < num_tokens; ++pos) {
|
for (int pos = 0; pos < static_cast<int>(num_tokens); ++pos) {
|
||||||
float* HWY_RESTRICT b_kv =
|
float* HWY_RESTRICT b_kv =
|
||||||
backward.qkv.data() + (pos * (kHeads + 2) + kHeads) * kQKVDim;
|
backward.qkv.data() + (pos * (kHeads + 2) + kHeads) * kQKVDim;
|
||||||
Rope(b_kv, kQKVDim, -pos);
|
Rope(b_kv, kQKVDim, -pos);
|
||||||
|
|
@ -328,16 +332,17 @@ static HWY_NOINLINE void SoftcapVJP(const float* HWY_RESTRICT forward,
|
||||||
static HWY_NOINLINE void CrossEntropyLossGrad(
|
static HWY_NOINLINE void CrossEntropyLossGrad(
|
||||||
const float* HWY_RESTRICT x, float* HWY_RESTRICT grad,
|
const float* HWY_RESTRICT x, float* HWY_RESTRICT grad,
|
||||||
const Prompt& prompt, size_t vocab_size) {
|
const Prompt& prompt, size_t vocab_size) {
|
||||||
|
HWY_ASSERT(!prompt.tokens.empty());
|
||||||
const float scaling = -1.0 / std::log(2.0);
|
const float scaling = -1.0 / std::log(2.0);
|
||||||
size_t num_tokens = prompt.tokens.size() - 1;
|
size_t num_tokens = prompt.tokens.size() - 1;
|
||||||
memset(grad, 0, num_tokens * vocab_size * sizeof(grad[0]));
|
hwy::ZeroBytes(grad, num_tokens * vocab_size * sizeof(grad[0]));
|
||||||
for (size_t i = 0; i + 1 < prompt.tokens.size(); ++i) {
|
for (size_t pos = 0; pos < num_tokens; ++pos) {
|
||||||
if (i + 1 < prompt.context_size) {
|
if (pos + 1 < prompt.context_size) {
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
const int next_token = prompt.tokens[i + 1];
|
const int next_token = prompt.tokens[pos + 1];
|
||||||
grad[i * vocab_size + next_token] =
|
grad[pos * vocab_size + next_token] =
|
||||||
scaling / x[i * vocab_size + next_token];
|
scaling / x[pos * vocab_size + next_token];
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -193,7 +193,8 @@ void MixByAttentionVJP(const T* qkv, const T* attention, const T* doutput,
|
||||||
template<typename T>
|
template<typename T>
|
||||||
void InputEmbeddingVJPT(const T* w, const std::vector<int>& tokens, T scaling,
|
void InputEmbeddingVJPT(const T* w, const std::vector<int>& tokens, T scaling,
|
||||||
const T* dy, T* dw, size_t N) {
|
const T* dy, T* dw, size_t N) {
|
||||||
for (size_t i = 0; i + 1 < tokens.size(); ++i) {
|
const size_t num_tokens = tokens.empty() ? 0 : tokens.size() - 1;
|
||||||
|
for (size_t i = 0; i < num_tokens; ++i) {
|
||||||
int token = tokens[i];
|
int token = tokens[i];
|
||||||
MulByConstAndAddT(scaling, dy + i * N, dw + token * N, N);
|
MulByConstAndAddT(scaling, dy + i * N, dw + token * N, N);
|
||||||
}
|
}
|
||||||
|
|
@ -287,13 +288,14 @@ void SoftcapVJPT(const T* y, T* dy, size_t N) {
|
||||||
template<typename T>
|
template<typename T>
|
||||||
void CrossEntropyLossGrad(const T* x, T* dx, const Prompt& prompt, size_t V) {
|
void CrossEntropyLossGrad(const T* x, T* dx, const Prompt& prompt, size_t V) {
|
||||||
T scaling = -1.0 / std::log(2.0);
|
T scaling = -1.0 / std::log(2.0);
|
||||||
size_t num_tokens = prompt.tokens.size() - 1;
|
const std::vector<int> tokens = prompt.tokens;
|
||||||
|
const size_t num_tokens = tokens.empty() ? 0 : tokens.size() - 1;
|
||||||
memset(dx, 0, V * num_tokens * sizeof(x[0]));
|
memset(dx, 0, V * num_tokens * sizeof(x[0]));
|
||||||
for (size_t i = 0; i + 1 < prompt.tokens.size(); ++i) {
|
for (size_t i = 0; i < num_tokens; ++i) {
|
||||||
if (i + 1 < prompt.context_size) {
|
if (i + 1 < prompt.context_size) {
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
const int next_token = prompt.tokens[i + 1];
|
const int next_token = tokens[i + 1];
|
||||||
dx[i * V + next_token] = scaling / x[i * V + next_token];
|
dx[i * V + next_token] = scaling / x[i * V + next_token];
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
@ -307,7 +309,8 @@ void CrossEntropyLossBackwardPass(const Prompt& prompt,
|
||||||
static constexpr size_t kModelDim = TConfig::kModelDim;
|
static constexpr size_t kModelDim = TConfig::kModelDim;
|
||||||
static constexpr size_t kVocabSize = TConfig::kVocabSize;
|
static constexpr size_t kVocabSize = TConfig::kVocabSize;
|
||||||
static constexpr size_t kLayers = TConfig::kLayers;
|
static constexpr size_t kLayers = TConfig::kLayers;
|
||||||
const size_t num_tokens = prompt.tokens.size() - 1;
|
const std::vector<int> tokens = prompt.tokens;
|
||||||
|
const size_t num_tokens = tokens.empty() ? 0 : tokens.size() - 1;
|
||||||
|
|
||||||
CrossEntropyLossGrad(forward.probs.data(), backward.logits.data(), prompt,
|
CrossEntropyLossGrad(forward.probs.data(), backward.logits.data(), prompt,
|
||||||
kVocabSize);
|
kVocabSize);
|
||||||
|
|
@ -341,8 +344,7 @@ void CrossEntropyLossBackwardPass(const Prompt& prompt,
|
||||||
|
|
||||||
const T kEmbScaling = EmbeddingScaling(kModelDim);
|
const T kEmbScaling = EmbeddingScaling(kModelDim);
|
||||||
InputEmbeddingVJPT(weights.embedder_input_embedding.data(),
|
InputEmbeddingVJPT(weights.embedder_input_embedding.data(),
|
||||||
prompt.tokens, kEmbScaling,
|
tokens, kEmbScaling, backward.layers[0].input.data(),
|
||||||
backward.layers[0].input.data(),
|
|
||||||
grad.embedder_input_embedding.data(), kModelDim);
|
grad.embedder_input_embedding.data(), kModelDim);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -83,13 +83,13 @@ void TestMatMulVJP() {
|
||||||
return DotT(dy.data(), c_y.data(), kTokens * kRows);
|
return DotT(dy.data(), c_y.data(), kTokens * kRows);
|
||||||
};
|
};
|
||||||
|
|
||||||
memset(&grad, 0, sizeof(grad));
|
hwy::ZeroBytes(&grad, sizeof(grad));
|
||||||
MatMulVJP<kCols, kRows>(weights, x.data(), dy.data(), kTokens,
|
MatMulVJP<kCols, kRows>(weights, x.data(), dy.data(), kTokens,
|
||||||
grad, dx.data(), pool);
|
grad, dx.data(), pool);
|
||||||
TestGradient(dx, c_x, func, 5e-5, 5e-5, __LINE__);
|
TestGradient(dx, c_x, func, 5e-5, 5e-5, __LINE__);
|
||||||
TestGradient(grad, c_weights, func, 5e-5, 5e-5, __LINE__);
|
TestGradient(grad, c_weights, func, 5e-5, 5e-5, __LINE__);
|
||||||
|
|
||||||
memset(&grad_scalar, 0, sizeof(grad_scalar));
|
hwy::ZeroBytes(&grad_scalar, sizeof(grad_scalar));
|
||||||
MatMulVJPT(weights.data(), x.data(), dy.data(), grad_scalar.data(),
|
MatMulVJPT(weights.data(), x.data(), dy.data(), grad_scalar.data(),
|
||||||
dx_scalar.data(), kRows, kCols, kTokens);
|
dx_scalar.data(), kRows, kCols, kTokens);
|
||||||
TestNear(dx, dx_scalar, 5e-5, 5e-5, __LINE__);
|
TestNear(dx, dx_scalar, 5e-5, 5e-5, __LINE__);
|
||||||
|
|
@ -128,13 +128,13 @@ void TestMultiHeadMatMulVJP() {
|
||||||
return DotT(dy.data(), c_y.data(), kTokens * kRows);
|
return DotT(dy.data(), c_y.data(), kTokens * kRows);
|
||||||
};
|
};
|
||||||
|
|
||||||
memset(&grad, 0, sizeof(grad));
|
hwy::ZeroBytes(&grad, sizeof(grad));
|
||||||
MultiHeadMatMulVJP<kHeads, kCols, kRows>(
|
MultiHeadMatMulVJP<kHeads, kCols, kRows>(
|
||||||
weights, x.data(), dy.data(), kTokens, grad, dx.data(), pool);
|
weights, x.data(), dy.data(), kTokens, grad, dx.data(), pool);
|
||||||
TestGradient(dx, c_x, func, 5e-5, 5e-5, __LINE__);
|
TestGradient(dx, c_x, func, 5e-5, 5e-5, __LINE__);
|
||||||
TestGradient(grad, c_weights, func, 5e-5, 5e-5, __LINE__);
|
TestGradient(grad, c_weights, func, 5e-5, 5e-5, __LINE__);
|
||||||
|
|
||||||
memset(&grad_scalar, 0, sizeof(grad_scalar));
|
hwy::ZeroBytes(&grad_scalar, sizeof(grad_scalar));
|
||||||
MultiHeadMatMulVJPT(weights.data(), x.data(), dy.data(), grad_scalar.data(),
|
MultiHeadMatMulVJPT(weights.data(), x.data(), dy.data(), grad_scalar.data(),
|
||||||
dx_scalar.data(), kHeads, kRows, kCols, kTokens);
|
dx_scalar.data(), kHeads, kRows, kCols, kTokens);
|
||||||
TestNear(dx, dx_scalar, 5e-5, 5e-5, __LINE__);
|
TestNear(dx, dx_scalar, 5e-5, 5e-5, __LINE__);
|
||||||
|
|
@ -170,13 +170,13 @@ void TestRMSNormVJP() {
|
||||||
return DotT(dy.data(), c_y.data(), K * N);
|
return DotT(dy.data(), c_y.data(), K * N);
|
||||||
};
|
};
|
||||||
|
|
||||||
memset(&grad, 0, sizeof(grad));
|
hwy::ZeroBytes(&grad, sizeof(grad));
|
||||||
RMSNormVJP(weights.data(), x.data(), dy.data(), N, K, grad.data(),
|
RMSNormVJP(weights.data(), x.data(), dy.data(), N, K, grad.data(),
|
||||||
dx.data(), pool);
|
dx.data(), pool);
|
||||||
TestGradient(dx, c_x, func, 5e-5, 5e-5, __LINE__);
|
TestGradient(dx, c_x, func, 5e-5, 5e-5, __LINE__);
|
||||||
TestGradient(grad, c_weights, func, 5e-5, 5e-5, __LINE__);
|
TestGradient(grad, c_weights, func, 5e-5, 5e-5, __LINE__);
|
||||||
|
|
||||||
memset(&grad_scalar, 0, sizeof(grad_scalar));
|
hwy::ZeroBytes(&grad_scalar, sizeof(grad_scalar));
|
||||||
RMSNormVJPT(weights.data(), x.data(), dy.data(), grad_scalar.data(),
|
RMSNormVJPT(weights.data(), x.data(), dy.data(), grad_scalar.data(),
|
||||||
dx_scalar.data(), N, K);
|
dx_scalar.data(), N, K);
|
||||||
TestNear(dx, dx_scalar, 0, 2e-5, __LINE__);
|
TestNear(dx, dx_scalar, 0, 2e-5, __LINE__);
|
||||||
|
|
@ -261,9 +261,7 @@ HWY_EXPORT_AND_TEST_P(BackwardTest, TestMatMulVJP);
|
||||||
HWY_EXPORT_AND_TEST_P(BackwardTest, TestMultiHeadMatMulVJP);
|
HWY_EXPORT_AND_TEST_P(BackwardTest, TestMultiHeadMatMulVJP);
|
||||||
HWY_EXPORT_AND_TEST_P(BackwardTest, TestRMSNormVJP);
|
HWY_EXPORT_AND_TEST_P(BackwardTest, TestRMSNormVJP);
|
||||||
HWY_EXPORT_AND_TEST_P(BackwardTest, TestEndToEnd);
|
HWY_EXPORT_AND_TEST_P(BackwardTest, TestEndToEnd);
|
||||||
#ifdef HWY_AFTER_TEST
|
|
||||||
HWY_AFTER_TEST();
|
HWY_AFTER_TEST();
|
||||||
#endif
|
|
||||||
|
|
||||||
} // namespace gcpp
|
} // namespace gcpp
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -28,9 +28,6 @@
|
||||||
#include "hwy/base.h"
|
#include "hwy/base.h"
|
||||||
#include "hwy/contrib/thread_pool/thread_pool.h"
|
#include "hwy/contrib/thread_pool/thread_pool.h"
|
||||||
|
|
||||||
namespace gcpp {
|
|
||||||
} // namespace gcpp
|
|
||||||
|
|
||||||
#endif // THIRD_PARTY_GEMMA_CPP_GEMMA_COMMON_INL_H_
|
#endif // THIRD_PARTY_GEMMA_CPP_GEMMA_COMMON_INL_H_
|
||||||
|
|
||||||
// Include guard for (potentially) SIMD code.
|
// Include guard for (potentially) SIMD code.
|
||||||
|
|
|
||||||
|
|
@ -19,8 +19,8 @@
|
||||||
#define HWY_TARGET_INCLUDE "gemma/common_scalar.cc" // NOLINT
|
#define HWY_TARGET_INCLUDE "gemma/common_scalar.cc" // NOLINT
|
||||||
#include "hwy/foreach_target.h" // IWYU pragma: keep
|
#include "hwy/foreach_target.h" // IWYU pragma: keep
|
||||||
|
|
||||||
|
#include "hwy/highway.h" // IWYU pragma: keep
|
||||||
#include "gemma/ops.h"
|
#include "gemma/ops.h"
|
||||||
#include "hwy/highway.h"
|
|
||||||
|
|
||||||
HWY_BEFORE_NAMESPACE();
|
HWY_BEFORE_NAMESPACE();
|
||||||
namespace gcpp {
|
namespace gcpp {
|
||||||
|
|
|
||||||
|
|
@ -51,7 +51,8 @@ template <typename ArrayT>
|
||||||
void InputEmbedding(const ArrayT& weights, const std::vector<int>& prompt,
|
void InputEmbedding(const ArrayT& weights, const std::vector<int>& prompt,
|
||||||
const float scaling, float* HWY_RESTRICT output,
|
const float scaling, float* HWY_RESTRICT output,
|
||||||
size_t model_dim) {
|
size_t model_dim) {
|
||||||
for (size_t pos = 0; pos + 1 < prompt.size(); ++pos) {
|
HWY_ASSERT(!prompt.empty());
|
||||||
|
for (size_t pos = 0; pos < prompt.size() - 1; ++pos) {
|
||||||
int token = prompt[pos];
|
int token = prompt[pos];
|
||||||
Decompress(weights, token * model_dim, output + pos * model_dim, model_dim);
|
Decompress(weights, token * model_dim, output + pos * model_dim, model_dim);
|
||||||
MulByConst(scaling, output + pos * model_dim, model_dim);
|
MulByConst(scaling, output + pos * model_dim, model_dim);
|
||||||
|
|
@ -74,8 +75,9 @@ static HWY_NOINLINE float CrossEntropyLoss(const float* HWY_RESTRICT probs,
|
||||||
size_t context_size,
|
size_t context_size,
|
||||||
size_t vocab_size,
|
size_t vocab_size,
|
||||||
hwy::ThreadPool& pool) {
|
hwy::ThreadPool& pool) {
|
||||||
|
HWY_ASSERT(!prompt.empty());
|
||||||
float loss = 0.0f;
|
float loss = 0.0f;
|
||||||
for (size_t pos = 0; pos + 1 < prompt.size(); ++pos) {
|
for (size_t pos = 0; pos < prompt.size() - 1; ++pos) {
|
||||||
if (pos + 1 < context_size) {
|
if (pos + 1 < context_size) {
|
||||||
continue; // next token is part of context, don't try to predict it
|
continue; // next token is part of context, don't try to predict it
|
||||||
}
|
}
|
||||||
|
|
@ -271,8 +273,8 @@ float CrossEntropyLossForwardPass(const std::vector<int>& prompt,
|
||||||
LogitsSoftCap(30.0f, forward.logits.data() + pos * kVocabSize, kVocabSize);
|
LogitsSoftCap(30.0f, forward.logits.data() + pos * kVocabSize, kVocabSize);
|
||||||
}
|
}
|
||||||
|
|
||||||
memcpy(forward.probs.data(), forward.logits.data(),
|
hwy::CopyBytes(forward.logits.data(), forward.probs.data(),
|
||||||
num_tokens * kVocabSize * sizeof(forward.logits[0]));
|
num_tokens * kVocabSize * sizeof(forward.logits[0]));
|
||||||
|
|
||||||
for (size_t pos = 0; pos < num_tokens; ++pos) {
|
for (size_t pos = 0; pos < num_tokens; ++pos) {
|
||||||
Softmax(forward.probs.data() + pos * kVocabSize, kVocabSize);
|
Softmax(forward.probs.data() + pos * kVocabSize, kVocabSize);
|
||||||
|
|
|
||||||
|
|
@ -21,9 +21,9 @@
|
||||||
#define HWY_TARGET_INCLUDE "gemma/forward.cc" // NOLINT
|
#define HWY_TARGET_INCLUDE "gemma/forward.cc" // NOLINT
|
||||||
#include "hwy/foreach_target.h" // IWYU pragma: keep
|
#include "hwy/foreach_target.h" // IWYU pragma: keep
|
||||||
|
|
||||||
|
#include "hwy/highway.h" // IWYU pragma: keep
|
||||||
#include "gemma/forward-inl.h"
|
#include "gemma/forward-inl.h"
|
||||||
#include "gemma/weights.h"
|
#include "gemma/weights.h"
|
||||||
#include "hwy/highway.h"
|
|
||||||
|
|
||||||
HWY_BEFORE_NAMESPACE();
|
HWY_BEFORE_NAMESPACE();
|
||||||
namespace gcpp {
|
namespace gcpp {
|
||||||
|
|
|
||||||
|
|
@ -116,7 +116,8 @@ void GatedGelu(const T* in, T* out, size_t N, size_t K) {
|
||||||
template<typename T>
|
template<typename T>
|
||||||
void InputEmbedding(const T* w, const std::vector<int>& tokens, T scaling,
|
void InputEmbedding(const T* w, const std::vector<int>& tokens, T scaling,
|
||||||
T* y, size_t N) {
|
T* y, size_t N) {
|
||||||
for (size_t i = 0; i + 1 < tokens.size(); ++i) {
|
const size_t num_tokens = tokens.empty() ? 0 : tokens.size() - 1;
|
||||||
|
for (size_t i = 0; i < num_tokens; ++i) {
|
||||||
int token = tokens[i];
|
int token = tokens[i];
|
||||||
memcpy(y + i * N, w + token * N, N * sizeof(y[0]));
|
memcpy(y + i * N, w + token * N, N * sizeof(y[0]));
|
||||||
MulByConstT(scaling, y + i * N, N);
|
MulByConstT(scaling, y + i * N, N);
|
||||||
|
|
@ -230,11 +231,13 @@ void ApplyLayer(const Layer<T, TConfig>& weights,
|
||||||
template<typename T>
|
template<typename T>
|
||||||
T CrossEntropyLoss(const T* x, const Prompt& prompt, size_t V) {
|
T CrossEntropyLoss(const T* x, const Prompt& prompt, size_t V) {
|
||||||
T loss = {};
|
T loss = {};
|
||||||
for (size_t i = 0; i + 1 < prompt.tokens.size(); ++i) {
|
const std::vector<int> tokens = prompt.tokens;
|
||||||
|
const size_t num_tokens = tokens.empty() ? 0 : tokens.size() - 1;
|
||||||
|
for (size_t i = 0; i < num_tokens; ++i) {
|
||||||
if (i + 1 < prompt.context_size) {
|
if (i + 1 < prompt.context_size) {
|
||||||
continue; // next token is part of context, don't try to predict it
|
continue; // next token is part of context, don't try to predict it
|
||||||
}
|
}
|
||||||
const int next_token = prompt.tokens[i + 1];
|
const int next_token = tokens[i + 1];
|
||||||
loss += std::log(x[i * V + next_token]);
|
loss += std::log(x[i * V + next_token]);
|
||||||
}
|
}
|
||||||
T scaling = -1.0 / std::log(2.0);
|
T scaling = -1.0 / std::log(2.0);
|
||||||
|
|
@ -248,10 +251,11 @@ T CrossEntropyLossForwardPass(const Prompt& prompt,
|
||||||
static constexpr size_t kModelDim = TConfig::kModelDim;
|
static constexpr size_t kModelDim = TConfig::kModelDim;
|
||||||
static constexpr size_t kVocabSize = TConfig::kVocabSize;
|
static constexpr size_t kVocabSize = TConfig::kVocabSize;
|
||||||
static constexpr size_t kLayers = TConfig::kLayers;
|
static constexpr size_t kLayers = TConfig::kLayers;
|
||||||
const size_t num_tokens = prompt.tokens.size() - 1;
|
const std::vector<int> tokens = prompt.tokens;
|
||||||
|
const size_t num_tokens = tokens.empty() ? 0 : tokens.size() - 1;
|
||||||
|
|
||||||
const T kEmbScaling = EmbeddingScaling(kModelDim);
|
const T kEmbScaling = EmbeddingScaling(kModelDim);
|
||||||
InputEmbedding(weights.embedder_input_embedding.data(), prompt.tokens,
|
InputEmbedding(weights.embedder_input_embedding.data(), tokens,
|
||||||
kEmbScaling, forward.layers[0].input.data(), kModelDim);
|
kEmbScaling, forward.layers[0].input.data(), kModelDim);
|
||||||
|
|
||||||
for (size_t layer = 0; layer < kLayers; ++layer) {
|
for (size_t layer = 0; layer < kLayers; ++layer) {
|
||||||
|
|
|
||||||
|
|
@ -874,6 +874,10 @@ static HWY_NOINLINE HWY_MAYBE_UNUSED void AddAbsolutePositionalEmbeddings(
|
||||||
consecutive pair of dimensions of v i.e. v_{2i} and v_{2i+1} by an angle
|
consecutive pair of dimensions of v i.e. v_{2i} and v_{2i+1} by an angle
|
||||||
m*theta_i. However in the Gemma implementation we choose to rotate
|
m*theta_i. However in the Gemma implementation we choose to rotate
|
||||||
the pairs of dimensions v_{i} and v_{i + d//2} instead.
|
the pairs of dimensions v_{i} and v_{i + d//2} instead.
|
||||||
|
|
||||||
|
The pos parameter is deliberately an int because the backward pass calls
|
||||||
|
this with negative values (the VJP calculation needs the transpose of this
|
||||||
|
rotation matrix, which is simply the same matrix with pos negated).
|
||||||
*/
|
*/
|
||||||
|
|
||||||
static HWY_NOINLINE HWY_MAYBE_UNUSED void Rope(float* HWY_RESTRICT x,
|
static HWY_NOINLINE HWY_MAYBE_UNUSED void Rope(float* HWY_RESTRICT x,
|
||||||
|
|
|
||||||
|
|
@ -107,7 +107,7 @@ void TestNear(const std::array<T, N>& actual, const std::array<U, N>& expected,
|
||||||
}
|
}
|
||||||
|
|
||||||
// Compute gradient with the finite difference method in the complex plane.
|
// Compute gradient with the finite difference method in the complex plane.
|
||||||
// If f : R->R is the tested function and F : C->C is its extenstion on the
|
// If f : R->R is the tested function and F : C->C is its extension on the
|
||||||
// complex plane so that F is complex differentiable in x, then
|
// complex plane so that F is complex differentiable in x, then
|
||||||
//
|
//
|
||||||
// F(x + ih) = F(x) + ih F'(x) + O(h^2) F''(x)
|
// F(x + ih) = F(x) + ih F'(x) + O(h^2) F''(x)
|
||||||
|
|
@ -117,7 +117,7 @@ void TestNear(const std::array<T, N>& actual, const std::array<U, N>& expected,
|
||||||
// F'(x) ~= Imag(F(x + ih)) / h
|
// F'(x) ~= Imag(F(x + ih)) / h
|
||||||
//
|
//
|
||||||
// This method is more numerically stable than the real-valued finite difference
|
// This method is more numerically stable than the real-valued finite difference
|
||||||
// method since we don't need to substract floating point numbers that are near
|
// method since we don't need to subtract floating point numbers that are near
|
||||||
// to each other.
|
// to each other.
|
||||||
template<typename T, typename U, size_t N, typename FUNC>
|
template<typename T, typename U, size_t N, typename FUNC>
|
||||||
void TestGradient(const std::array<T, N>& grad,
|
void TestGradient(const std::array<T, N>& grad,
|
||||||
|
|
|
||||||
|
|
@ -18,6 +18,7 @@
|
||||||
#include "gemma/common.h"
|
#include "gemma/common.h"
|
||||||
#include "gemma/configs.h"
|
#include "gemma/configs.h"
|
||||||
#include "hwy/contrib/thread_pool/thread_pool.h"
|
#include "hwy/contrib/thread_pool/thread_pool.h"
|
||||||
|
#include "hwy/stats.h"
|
||||||
|
|
||||||
namespace gcpp {
|
namespace gcpp {
|
||||||
|
|
||||||
|
|
@ -66,17 +67,12 @@ void ZeroInitWeights(Model model, ByteStorageT& weights,
|
||||||
|
|
||||||
namespace {
|
namespace {
|
||||||
void LogVec(const char* name, const float* data, size_t len) {
|
void LogVec(const char* name, const float* data, size_t len) {
|
||||||
float minval = std::numeric_limits<float>::max();
|
hwy::Stats stats;
|
||||||
float maxval = std::numeric_limits<float>::min();
|
|
||||||
double sum = 0.0f;
|
|
||||||
for (size_t i = 0; i < len; ++i) {
|
for (size_t i = 0; i < len; ++i) {
|
||||||
minval = std::min(minval, data[i]);
|
stats.Notify(data[i]);
|
||||||
maxval = std::max(maxval, data[i]);
|
|
||||||
sum += data[i];
|
|
||||||
}
|
}
|
||||||
float avg = sum / len;
|
|
||||||
printf("%-20s %12zu %13.10f %8.5f %13.10f\n",
|
printf("%-20s %12zu %13.10f %8.5f %13.10f\n",
|
||||||
name, len, minval, avg, maxval);
|
name, len, stats.Min(), stats.Mean(), stats.Max());
|
||||||
}
|
}
|
||||||
|
|
||||||
class WeightLogger {
|
class WeightLogger {
|
||||||
|
|
|
||||||
134
gemma/weights.h
134
gemma/weights.h
|
|
@ -127,136 +127,138 @@ ByteStorageT AllocateWeights(hwy::ThreadPool& pool) {
|
||||||
return weights_u8;
|
return weights_u8;
|
||||||
}
|
}
|
||||||
|
|
||||||
#define CALL_TOP_FUNC1(name, member) func(name, weights1.member)
|
#define GEMMA_CALL_TOP_FUNC1(name, member) func(name, weights1.member)
|
||||||
#define CALL_TOP_FUNC2(name, member) \
|
#define GEMMA_CALL_TOP_FUNC2(name, member) \
|
||||||
func(name, weights1.member, weights2.member)
|
func(name, weights1.member, weights2.member)
|
||||||
#define CALL_TOP_FUNC3(name, member) \
|
#define GEMMA_CALL_TOP_FUNC3(name, member) \
|
||||||
func(name, weights1.member, weights2.member, weights3.member)
|
func(name, weights1.member, weights2.member, weights3.member)
|
||||||
#define CALL_TOP_FUNC4(name, member) \
|
#define GEMMA_CALL_TOP_FUNC4(name, member) \
|
||||||
func(name, weights1.member, weights2.memeber, \
|
func(name, weights1.member, weights2.memeber, \
|
||||||
weights3.member, weights4.member)
|
weights3.member, weights4.member)
|
||||||
|
|
||||||
#define CALL_LAYER_FUNC1(name, member) \
|
#define GEMMA_CALL_LAYER_FUNC1(name, member) \
|
||||||
snprintf(name_buf, sizeof(name_buf), name "_%d", layer_idx); \
|
snprintf(name_buf, sizeof(name_buf), name "_%d", layer_idx); \
|
||||||
func(name_buf, layer1.member)
|
func(name_buf, layer1.member)
|
||||||
|
|
||||||
#define CALL_LAYER_FUNC2(name, member) \
|
#define GEMMA_CALL_LAYER_FUNC2(name, member) \
|
||||||
snprintf(name_buf, sizeof(name_buf), name "_%d", layer_idx); \
|
snprintf(name_buf, sizeof(name_buf), name "_%d", layer_idx); \
|
||||||
func(name_buf, layer1.member, layer2.member)
|
func(name_buf, layer1.member, layer2.member)
|
||||||
|
|
||||||
#define CALL_LAYER_FUNC3(name, member) \
|
#define GEMMA_CALL_LAYER_FUNC3(name, member) \
|
||||||
snprintf(name_buf, sizeof(name_buf), name "_%d", layer_idx); \
|
snprintf(name_buf, sizeof(name_buf), name "_%d", layer_idx); \
|
||||||
func(name_buf, layer1.member, layer2.member, layer3.member)
|
func(name_buf, layer1.member, layer2.member, layer3.member)
|
||||||
|
|
||||||
#define CALL_LAYER_FUNC4(name, member) \
|
#define GEMMA_CALL_LAYER_FUNC4(name, member) \
|
||||||
snprintf(name_buf, sizeof(name_buf), name "_%d", layer_idx); \
|
snprintf(name_buf, sizeof(name_buf), name "_%d", layer_idx); \
|
||||||
func(name_buf, layer1.member, layer2.member, layer4.member)
|
func(name_buf, layer1.member, layer2.member, layer4.member)
|
||||||
|
|
||||||
#define CALL_ALL_LAYER_FUNC(N) \
|
#define GEMMA_CALL_ALL_LAYER_FUNC(N) \
|
||||||
if (type == LayerAttentionType::kGemma) { \
|
if (type == LayerAttentionType::kGemma) { \
|
||||||
CALL_LAYER_FUNC ## N("att_ein", attn_vec_einsum_w); \
|
GEMMA_CALL_LAYER_FUNC ## N("att_ein", attn_vec_einsum_w); \
|
||||||
CALL_LAYER_FUNC ## N("qkv_ein", qkv_einsum_w); \
|
GEMMA_CALL_LAYER_FUNC ## N("qkv_ein", qkv_einsum_w); \
|
||||||
} else { \
|
} else { \
|
||||||
CALL_LAYER_FUNC ## N("gr_lin_x_w", griffin.linear_x_w); \
|
GEMMA_CALL_LAYER_FUNC ## N("gr_lin_x_w", griffin.linear_x_w); \
|
||||||
CALL_LAYER_FUNC ## N("gr_lin_x_b", griffin.linear_x_biases); \
|
GEMMA_CALL_LAYER_FUNC ## N("gr_lin_x_b", griffin.linear_x_biases); \
|
||||||
CALL_LAYER_FUNC ## N("gr_lin_y_w", griffin.linear_y_w); \
|
GEMMA_CALL_LAYER_FUNC ## N("gr_lin_y_w", griffin.linear_y_w); \
|
||||||
CALL_LAYER_FUNC ## N("gr_lin_y_b", griffin.linear_y_biases); \
|
GEMMA_CALL_LAYER_FUNC ## N("gr_lin_y_b", griffin.linear_y_biases); \
|
||||||
CALL_LAYER_FUNC ## N("gr_lin_out_w", griffin.linear_out_w); \
|
GEMMA_CALL_LAYER_FUNC ## N("gr_lin_out_w", griffin.linear_out_w); \
|
||||||
CALL_LAYER_FUNC ## N("gr_lin_out_b", griffin.linear_out_biases); \
|
GEMMA_CALL_LAYER_FUNC ## N("gr_lin_out_b", griffin.linear_out_biases); \
|
||||||
CALL_LAYER_FUNC ## N("gr_conv_w", griffin.conv_w); \
|
GEMMA_CALL_LAYER_FUNC ## N("gr_conv_w", griffin.conv_w); \
|
||||||
CALL_LAYER_FUNC ## N("gr_conv_b", griffin.conv_biases); \
|
GEMMA_CALL_LAYER_FUNC ## N("gr_conv_b", griffin.conv_biases); \
|
||||||
CALL_LAYER_FUNC ## N("gr_gate_w", griffin.gate_w); \
|
GEMMA_CALL_LAYER_FUNC ## N("gr_gate_w", griffin.gate_w); \
|
||||||
CALL_LAYER_FUNC ## N("gr_gate_b", griffin.gate_biases); \
|
GEMMA_CALL_LAYER_FUNC ## N("gr_gate_b", griffin.gate_biases); \
|
||||||
CALL_LAYER_FUNC ## N("gr_a", griffin.a); \
|
GEMMA_CALL_LAYER_FUNC ## N("gr_a", griffin.a); \
|
||||||
} \
|
} \
|
||||||
CALL_LAYER_FUNC ## N("gating_ein", gating_einsum_w); \
|
GEMMA_CALL_LAYER_FUNC ## N("gating_ein", gating_einsum_w); \
|
||||||
CALL_LAYER_FUNC ## N("linear_w", linear_w); \
|
GEMMA_CALL_LAYER_FUNC ## N("linear_w", linear_w); \
|
||||||
CALL_LAYER_FUNC ## N("pre_att_ns", pre_attention_norm_scale); \
|
GEMMA_CALL_LAYER_FUNC ## N("pre_att_ns", pre_attention_norm_scale); \
|
||||||
if (TConfig::kPostNormScale) { \
|
if (TConfig::kPostNormScale) { \
|
||||||
CALL_LAYER_FUNC ## N("post_att_ns", post_attention_norm_scale); \
|
GEMMA_CALL_LAYER_FUNC ## N("post_att_ns", post_attention_norm_scale); \
|
||||||
CALL_LAYER_FUNC ## N("post_ff_ns", post_ffw_norm_scale); \
|
GEMMA_CALL_LAYER_FUNC ## N("post_ff_ns", post_ffw_norm_scale); \
|
||||||
} \
|
} \
|
||||||
CALL_LAYER_FUNC ## N("pre_ff_ns", pre_ffw_norm_scale); \
|
GEMMA_CALL_LAYER_FUNC ## N("pre_ff_ns", pre_ffw_norm_scale); \
|
||||||
if (TConfig::kFFBiases) { \
|
if (TConfig::kFFBiases) { \
|
||||||
CALL_LAYER_FUNC ## N("ffw_gat_b", ffw_gating_biases); \
|
GEMMA_CALL_LAYER_FUNC ## N("ffw_gat_b", ffw_gating_biases); \
|
||||||
CALL_LAYER_FUNC ## N("ffw_out_b", ffw_output_biases); \
|
GEMMA_CALL_LAYER_FUNC ## N("ffw_out_b", ffw_output_biases); \
|
||||||
} \
|
} \
|
||||||
if (TConfig::kSoftmaxAttnOutputBiases && \
|
if (TConfig::kSoftmaxAttnOutputBiases && \
|
||||||
type == LayerAttentionType::kGemma) { \
|
type == LayerAttentionType::kGemma) { \
|
||||||
CALL_LAYER_FUNC ## N("attn_ob", attention_output_biases); \
|
GEMMA_CALL_LAYER_FUNC ## N("attn_ob", attention_output_biases); \
|
||||||
}
|
}
|
||||||
|
|
||||||
template <typename T, typename TConfig, class Func>
|
template <typename T, typename TConfig, class Func>
|
||||||
void ForEachTensor1(Func& func, const Weights<T, TConfig>& weights1) {
|
void ForEachTensor1(Func& func, const Weights<T, TConfig>& weights1) {
|
||||||
CALL_TOP_FUNC1("embedding", embedder_input_embedding);
|
GEMMA_CALL_TOP_FUNC1("embedding", embedder_input_embedding);
|
||||||
CALL_TOP_FUNC1("final_norm", final_norm_scale);
|
GEMMA_CALL_TOP_FUNC1("final_norm", final_norm_scale);
|
||||||
char name_buf[16];
|
char name_buf[16];
|
||||||
for (int layer_idx = 0; layer_idx < TConfig::kLayers; ++layer_idx) {
|
for (int layer_idx = 0; layer_idx < TConfig::kLayers; ++layer_idx) {
|
||||||
auto type = TConfig::kLayerConfig[layer_idx];
|
auto type = TConfig::kLayerConfig[layer_idx];
|
||||||
const size_t idx = static_cast<size_t>(layer_idx);
|
const size_t idx = static_cast<size_t>(layer_idx);
|
||||||
const LayerF<TConfig>& layer1 = *weights1.GetLayer(idx);
|
const LayerF<TConfig>& layer1 = *weights1.GetLayer(idx);
|
||||||
CALL_ALL_LAYER_FUNC(1)
|
GEMMA_CALL_ALL_LAYER_FUNC(1)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
template <typename T, typename TConfig, class Func>
|
template <typename T, typename TConfig, class Func>
|
||||||
void ForEachTensor1(Func& func, Weights<T, TConfig>& weights1) {
|
void ForEachTensor1(Func& func, Weights<T, TConfig>& weights1) {
|
||||||
CALL_TOP_FUNC1("embedding", embedder_input_embedding);
|
GEMMA_CALL_TOP_FUNC1("embedding", embedder_input_embedding);
|
||||||
CALL_TOP_FUNC1("final_norm", final_norm_scale);
|
GEMMA_CALL_TOP_FUNC1("final_norm", final_norm_scale);
|
||||||
char name_buf[16];
|
char name_buf[16];
|
||||||
for (int layer_idx = 0; layer_idx < TConfig::kLayers; ++layer_idx) {
|
for (int layer_idx = 0; layer_idx < TConfig::kLayers; ++layer_idx) {
|
||||||
auto type = TConfig::kLayerConfig[layer_idx];
|
auto type = TConfig::kLayerConfig[layer_idx];
|
||||||
const size_t idx = static_cast<size_t>(layer_idx);
|
const size_t idx = static_cast<size_t>(layer_idx);
|
||||||
LayerF<TConfig>& layer1 = *weights1.GetLayer(idx);
|
LayerF<TConfig>& layer1 = *weights1.GetLayer(idx);
|
||||||
CALL_ALL_LAYER_FUNC(1)
|
GEMMA_CALL_ALL_LAYER_FUNC(1)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
template <typename T, typename TConfig, class Func>
|
template <typename T, typename TConfig, class Func>
|
||||||
void ForEachTensor2(Func& func, const Weights<T, TConfig>& weights1,
|
void ForEachTensor2(Func& func, const Weights<T, TConfig>& weights1,
|
||||||
Weights<T, TConfig>& weights2) {
|
Weights<T, TConfig>& weights2) {
|
||||||
CALL_TOP_FUNC2("embedding", embedder_input_embedding);
|
GEMMA_CALL_TOP_FUNC2("embedding", embedder_input_embedding);
|
||||||
CALL_TOP_FUNC2("final_norm", final_norm_scale);
|
GEMMA_CALL_TOP_FUNC2("final_norm", final_norm_scale);
|
||||||
char name_buf[16];
|
char name_buf[16];
|
||||||
for (int layer_idx = 0; layer_idx < TConfig::kLayers; ++layer_idx) {
|
for (int layer_idx = 0; layer_idx < TConfig::kLayers; ++layer_idx) {
|
||||||
auto type = TConfig::kLayerConfig[layer_idx];
|
auto type = TConfig::kLayerConfig[layer_idx];
|
||||||
const size_t idx = static_cast<size_t>(layer_idx);
|
const size_t idx = static_cast<size_t>(layer_idx);
|
||||||
const LayerF<TConfig>& layer1 = *weights1.GetLayer(idx);
|
const LayerF<TConfig>& layer1 = *weights1.GetLayer(idx);
|
||||||
LayerF<TConfig>& layer2 = *weights2.GetLayer(idx);
|
LayerF<TConfig>& layer2 = *weights2.GetLayer(idx);
|
||||||
CALL_ALL_LAYER_FUNC(2)
|
GEMMA_CALL_ALL_LAYER_FUNC(2)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#undef CALL_TOP_FUNC1
|
#undef GEMMA_CALL_TOP_FUNC1
|
||||||
#undef CALL_TOP_FUNC2
|
#undef GEMMA_CALL_TOP_FUNC2
|
||||||
#undef CALL_TOP_FUNC3
|
#undef GEMMA_CALL_TOP_FUNC3
|
||||||
#undef CALL_TOP_FUNC4
|
#undef GEMMA_CALL_TOP_FUNC4
|
||||||
#undef CALL_LAYER_FUNC1
|
#undef GEMMA_CALL_LAYER_FUNC1
|
||||||
#undef CALL_LAYER_FUNC2
|
#undef GEMMA_CALL_LAYER_FUNC2
|
||||||
#undef CALL_LAYER_FUNC3
|
#undef GEMMA_CALL_LAYER_FUNC3
|
||||||
#undef CALL_LAYER_FUNC4
|
#undef GEMMA_CALL_LAYER_FUNC4
|
||||||
#undef CALL_ALL_LAYER_FUNC
|
#undef GEMMA_CALL_ALL_LAYER_FUNC
|
||||||
|
|
||||||
template<typename T, typename TConfig>
|
template<typename T, typename TConfig>
|
||||||
void ZeroInit(Weights<T, TConfig>& w) {
|
void ZeroInit(Weights<T, TConfig>& w) {
|
||||||
memset(&w.embedder_input_embedding, 0, sizeof(w.embedder_input_embedding));
|
hwy::ZeroBytes(&w.embedder_input_embedding,
|
||||||
memset(&w.final_norm_scale, 0, sizeof(w.final_norm_scale));
|
sizeof(w.embedder_input_embedding));
|
||||||
|
hwy::ZeroBytes(&w.final_norm_scale, sizeof(w.final_norm_scale));
|
||||||
for (int i = 0; i < TConfig::kLayers; ++i) {
|
for (int i = 0; i < TConfig::kLayers; ++i) {
|
||||||
memset(w.GetLayer(i), 0, sizeof(*w.GetLayer(i)));
|
hwy::ZeroBytes(w.GetLayer(i), sizeof(*w.GetLayer(i)));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
template<typename T, typename TConfig>
|
template<typename T, typename TConfig>
|
||||||
void Copy(Weights<T, TConfig>& dst, const Weights<T, TConfig>& src) {
|
void Copy(Weights<T, TConfig>& dst, const Weights<T, TConfig>& src) {
|
||||||
memcpy(&dst.embedder_input_embedding, &src.embedder_input_embedding,
|
hwy::CopyBytes(&src.embedder_input_embedding, &dst.embedder_input_embedding,
|
||||||
sizeof(src.embedder_input_embedding));
|
sizeof(src.embedder_input_embedding));
|
||||||
memcpy(&dst.final_norm_scale, &src.final_norm_scale,
|
hwy::CopyBytes(&src.final_norm_scale, &dst.final_norm_scale,
|
||||||
sizeof(src.final_norm_scale));
|
sizeof(src.final_norm_scale));
|
||||||
for (int i = 0; i < TConfig::kLayers; ++i) {
|
for (int i = 0; i < TConfig::kLayers; ++i) {
|
||||||
memcpy(dst.GetLayer(i), src.GetLayer(i), sizeof(*dst.GetLayer(i)));
|
hwy::CopyBytes(src.GetLayer(i), dst.GetLayer(i), sizeof(*dst.GetLayer(i)));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Owns weights and undoes the type erasure of AllocateWeights.
|
||||||
template<typename T, typename TConfig>
|
template<typename T, typename TConfig>
|
||||||
class WeightsWrapper {
|
class WeightsWrapper {
|
||||||
public:
|
public:
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue