mirror of https://github.com/google/gemma.cpp.git
Backprop test fixes and allocator cleanup
- Shorten backprop tests to prevent timeout - Add line number of failing test - matmul: remove unused enable_bind - allocator: we will retain enable_bind there - mat: disable cyclic padding optimization (broken) PiperOrigin-RevId: 752656068
This commit is contained in:
parent
160a5824fb
commit
fe80f10ed7
|
|
@ -56,7 +56,7 @@ TEST(BackPropTest, MatMulVJP) {
|
||||||
for (int iter = 0; iter < 10; ++iter) {
|
for (int iter = 0; iter < 10; ++iter) {
|
||||||
RandInit(weights, 1.0 * (1 << iter), gen);
|
RandInit(weights, 1.0 * (1 << iter), gen);
|
||||||
RandInit(x, 1.0 * (1 << iter), gen);
|
RandInit(x, 1.0 * (1 << iter), gen);
|
||||||
RandInit(dy, 1.0, gen);
|
RandInit(dy, 1.0f, gen);
|
||||||
Complexify(weights, c_weights);
|
Complexify(weights, c_weights);
|
||||||
Complexify(x, c_x);
|
Complexify(x, c_x);
|
||||||
auto func = [&]() {
|
auto func = [&]() {
|
||||||
|
|
@ -67,8 +67,8 @@ TEST(BackPropTest, MatMulVJP) {
|
||||||
ZeroInit(grad);
|
ZeroInit(grad);
|
||||||
MatMulVJPT(weights.Packed(), x.Packed(), dy.Packed(), grad.Packed(),
|
MatMulVJPT(weights.Packed(), x.Packed(), dy.Packed(), grad.Packed(),
|
||||||
dx.Packed(), kRows, kCols, kTokens);
|
dx.Packed(), kRows, kCols, kTokens);
|
||||||
TestGradient(dx, c_x, func, 1e-11, 1e-12, __LINE__);
|
TestGradient(dx, c_x, func, 1e-11, 1e-12, __LINE__, __LINE__);
|
||||||
TestGradient(grad, c_weights, func, 1e-14, 1e-12, __LINE__);
|
TestGradient(grad, c_weights, func, 1e-14, 1e-11, __LINE__, __LINE__);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -92,7 +92,7 @@ TEST(BackPropTest, MultiHeadMatMulVJP) {
|
||||||
for (int iter = 0; iter < 10; ++iter) {
|
for (int iter = 0; iter < 10; ++iter) {
|
||||||
RandInit(weights, 1.0 * (1 << iter), gen);
|
RandInit(weights, 1.0 * (1 << iter), gen);
|
||||||
RandInit(x, 1.0 * (1 << iter), gen);
|
RandInit(x, 1.0 * (1 << iter), gen);
|
||||||
RandInit(dy, 1.0, gen);
|
RandInit(dy, 1.0f, gen);
|
||||||
Complexify(weights, c_weights);
|
Complexify(weights, c_weights);
|
||||||
Complexify(x, c_x);
|
Complexify(x, c_x);
|
||||||
auto func = [&]() {
|
auto func = [&]() {
|
||||||
|
|
@ -104,8 +104,8 @@ TEST(BackPropTest, MultiHeadMatMulVJP) {
|
||||||
MultiHeadMatMulVJPT(weights.Packed(), x.Packed(), dy.Packed(),
|
MultiHeadMatMulVJPT(weights.Packed(), x.Packed(), dy.Packed(),
|
||||||
grad.Packed(), dx.Packed(), kHeads, kRows, kCols,
|
grad.Packed(), dx.Packed(), kHeads, kRows, kCols,
|
||||||
kTokens);
|
kTokens);
|
||||||
TestGradient(dx, c_x, func, 1e-15, 1e-13, __LINE__);
|
TestGradient(dx, c_x, func, 1e-15, 1e-13, __LINE__, __LINE__);
|
||||||
TestGradient(grad, c_weights, func, 1e-15, 1e-13, __LINE__);
|
TestGradient(grad, c_weights, func, 1e-15, 1e-13, __LINE__, __LINE__);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -129,7 +129,7 @@ TEST(BackPropTest, RMSNormVJP) {
|
||||||
RandInit(x, 1.0 * (1 << iter), gen);
|
RandInit(x, 1.0 * (1 << iter), gen);
|
||||||
Complexify(weights, c_weights);
|
Complexify(weights, c_weights);
|
||||||
Complexify(x, c_x);
|
Complexify(x, c_x);
|
||||||
RandInit(dy, 1.0, gen);
|
RandInit(dy, 1.0f, gen);
|
||||||
auto func = [&]() {
|
auto func = [&]() {
|
||||||
RMSNormT(c_weights.Packed(), c_x.Packed(), c_y.Packed(), N, K);
|
RMSNormT(c_weights.Packed(), c_x.Packed(), c_y.Packed(), N, K);
|
||||||
return DotT(dy.Packed(), c_y.Packed(), K * N);
|
return DotT(dy.Packed(), c_y.Packed(), K * N);
|
||||||
|
|
@ -137,8 +137,8 @@ TEST(BackPropTest, RMSNormVJP) {
|
||||||
ZeroInit(grad);
|
ZeroInit(grad);
|
||||||
RMSNormVJPT(weights.Packed(), x.Packed(), dy.Packed(), grad.Packed(),
|
RMSNormVJPT(weights.Packed(), x.Packed(), dy.Packed(), grad.Packed(),
|
||||||
dx.Packed(), N, K);
|
dx.Packed(), N, K);
|
||||||
TestGradient(dx, c_x, func, 1e-15, 1e-14, __LINE__);
|
TestGradient(dx, c_x, func, 1e-15, 1e-14, __LINE__, __LINE__);
|
||||||
TestGradient(grad, c_weights, func, 1e-15, 1e-14, __LINE__);
|
TestGradient(grad, c_weights, func, 1e-15, 1e-14, __LINE__, __LINE__);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -154,9 +154,9 @@ TEST(BackPropTest, SoftmaxVJP) {
|
||||||
auto c_y = MakePacked<TC>("c_y", N, 1);
|
auto c_y = MakePacked<TC>("c_y", N, 1);
|
||||||
|
|
||||||
for (int iter = 0; iter < 10; ++iter) {
|
for (int iter = 0; iter < 10; ++iter) {
|
||||||
RandInit(x, 1.0 * (1 << iter), gen);
|
RandInit(x, 1.0f * (1 << iter), gen);
|
||||||
Complexify(x, c_x);
|
Complexify(x, c_x);
|
||||||
RandInit(dy, 1.0, gen);
|
RandInit(dy, 1.0f, gen);
|
||||||
auto func = [&]() {
|
auto func = [&]() {
|
||||||
CopyMat(c_x, c_y);
|
CopyMat(c_x, c_y);
|
||||||
Softmax(c_y.Packed(), N);
|
Softmax(c_y.Packed(), N);
|
||||||
|
|
@ -165,7 +165,7 @@ TEST(BackPropTest, SoftmaxVJP) {
|
||||||
Softmax(x.Packed(), N);
|
Softmax(x.Packed(), N);
|
||||||
CopyMat(dy, dx);
|
CopyMat(dy, dx);
|
||||||
SoftmaxVJPT(x.Packed(), dx.Packed(), N);
|
SoftmaxVJPT(x.Packed(), dx.Packed(), N);
|
||||||
TestGradient(dx, c_x, func, 1e-15, 1e-15, __LINE__);
|
TestGradient(dx, c_x, func, 1e-15, 1e-15, __LINE__, __LINE__);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -187,7 +187,7 @@ TEST(BackPropTest, MaskedSoftmaxVJP) {
|
||||||
for (int iter = 0; iter < 10; ++iter) {
|
for (int iter = 0; iter < 10; ++iter) {
|
||||||
RandInit(x, 1.0 * (1 << iter), gen);
|
RandInit(x, 1.0 * (1 << iter), gen);
|
||||||
Complexify(x, c_x);
|
Complexify(x, c_x);
|
||||||
RandInit(dy, 1.0, gen);
|
RandInit(dy, 1.0f, gen);
|
||||||
auto func = [&]() {
|
auto func = [&]() {
|
||||||
CopyMat(c_x, c_y);
|
CopyMat(c_x, c_y);
|
||||||
MaskedSoftmax(c_y.Packed(), kTokens, kHeads, kSeqLen);
|
MaskedSoftmax(c_y.Packed(), kTokens, kHeads, kSeqLen);
|
||||||
|
|
@ -196,7 +196,7 @@ TEST(BackPropTest, MaskedSoftmaxVJP) {
|
||||||
MaskedSoftmax(x.Packed(), kTokens, kHeads, kSeqLen);
|
MaskedSoftmax(x.Packed(), kTokens, kHeads, kSeqLen);
|
||||||
CopyMat(dy, dx);
|
CopyMat(dy, dx);
|
||||||
MaskedSoftmaxVJPT(x.Packed(), dx.Packed(), kTokens, kHeads, kSeqLen);
|
MaskedSoftmaxVJPT(x.Packed(), dx.Packed(), kTokens, kHeads, kSeqLen);
|
||||||
TestGradient(dx, c_x, func, 1e-14, 1e-15, __LINE__);
|
TestGradient(dx, c_x, func, 1e-14, 1e-15, __LINE__, __LINE__);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -215,7 +215,7 @@ TEST(BackPropTest, SoftcapVJP) {
|
||||||
for (int iter = 0; iter < 10; ++iter) {
|
for (int iter = 0; iter < 10; ++iter) {
|
||||||
RandInit(x, 1.0 * (1 << iter), gen);
|
RandInit(x, 1.0 * (1 << iter), gen);
|
||||||
Complexify(x, c_x);
|
Complexify(x, c_x);
|
||||||
RandInit(dy, 1.0, gen);
|
RandInit(dy, 1.0f, gen);
|
||||||
auto func = [&]() {
|
auto func = [&]() {
|
||||||
CopyMat(c_x, c_y);
|
CopyMat(c_x, c_y);
|
||||||
Softcap(kCap, c_y.Packed(), N);
|
Softcap(kCap, c_y.Packed(), N);
|
||||||
|
|
@ -224,7 +224,7 @@ TEST(BackPropTest, SoftcapVJP) {
|
||||||
Softcap(kCap, x.Packed(), N);
|
Softcap(kCap, x.Packed(), N);
|
||||||
CopyMat(dy, dx);
|
CopyMat(dy, dx);
|
||||||
SoftcapVJPT(kCap, x.Packed(), dx.Packed(), N);
|
SoftcapVJPT(kCap, x.Packed(), dx.Packed(), N);
|
||||||
TestGradient(dx, c_x, func, 1e-15, 1e-14, __LINE__);
|
TestGradient(dx, c_x, func, 1e-15, 1e-14, __LINE__, __LINE__);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -249,7 +249,7 @@ TEST(BackPropTest, CrossEntropyLossGrad) {
|
||||||
CrossEntropyLossGrad(x.Packed(), dx.Packed(), prompt, V);
|
CrossEntropyLossGrad(x.Packed(), dx.Packed(), prompt, V);
|
||||||
Complexify(x, c_x);
|
Complexify(x, c_x);
|
||||||
auto func = [&]() { return CrossEntropyLoss(c_x.Packed(), prompt, V); };
|
auto func = [&]() { return CrossEntropyLoss(c_x.Packed(), prompt, V); };
|
||||||
TestGradient(dx, c_x, func, 1e-100, 1e-15, __LINE__);
|
TestGradient(dx, c_x, func, 1e-100, 1e-15, __LINE__, __LINE__);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -266,15 +266,15 @@ TEST(BackPropTest, GatedGeluVJP) {
|
||||||
auto c_y = MakePacked<TC>("c_y", K, N);
|
auto c_y = MakePacked<TC>("c_y", K, N);
|
||||||
|
|
||||||
for (int iter = 0; iter < 10; ++iter) {
|
for (int iter = 0; iter < 10; ++iter) {
|
||||||
RandInit(x, 1.0, gen);
|
RandInit(x, 1.0f, gen);
|
||||||
Complexify(x, c_x);
|
Complexify(x, c_x);
|
||||||
RandInit(dy, 1.0, gen);
|
RandInit(dy, 1.0f, gen);
|
||||||
auto func = [&]() {
|
auto func = [&]() {
|
||||||
GatedGelu(c_x.Packed(), c_y.Packed(), N, K);
|
GatedGelu(c_x.Packed(), c_y.Packed(), N, K);
|
||||||
return DotT(dy.Packed(), c_y.Packed(), N * K);
|
return DotT(dy.Packed(), c_y.Packed(), N * K);
|
||||||
};
|
};
|
||||||
GatedGeluVJP(x.Packed(), dy.Packed(), dx.Packed(), N, K);
|
GatedGeluVJP(x.Packed(), dy.Packed(), dx.Packed(), N, K);
|
||||||
TestGradient(dx, c_x, func, 1e-15, 1e-15, __LINE__);
|
TestGradient(dx, c_x, func, 1e-15, 1e-15, __LINE__, __LINE__);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -297,9 +297,9 @@ TEST(BackPropTest, MaskedAttentionVJP) {
|
||||||
ZeroInit(c_y);
|
ZeroInit(c_y);
|
||||||
|
|
||||||
for (int iter = 0; iter < 10; ++iter) {
|
for (int iter = 0; iter < 10; ++iter) {
|
||||||
RandInit(x, 1.0, gen);
|
RandInit(x, 1.0f, gen);
|
||||||
Complexify(x, c_x);
|
Complexify(x, c_x);
|
||||||
RandInit(dy, 1.0, gen);
|
RandInit(dy, 1.0f, gen);
|
||||||
auto func = [&]() {
|
auto func = [&]() {
|
||||||
MaskedAttention(c_x.Packed(), c_y.Packed(), kTokens, kHeads, kQKVDim,
|
MaskedAttention(c_x.Packed(), c_y.Packed(), kTokens, kHeads, kQKVDim,
|
||||||
kSeqLen);
|
kSeqLen);
|
||||||
|
|
@ -307,7 +307,7 @@ TEST(BackPropTest, MaskedAttentionVJP) {
|
||||||
};
|
};
|
||||||
MaskedAttentionVJP(x.Packed(), dy.Packed(), dx.Packed(), kTokens, kHeads,
|
MaskedAttentionVJP(x.Packed(), dy.Packed(), dx.Packed(), kTokens, kHeads,
|
||||||
kQKVDim, kSeqLen);
|
kQKVDim, kSeqLen);
|
||||||
TestGradient(dx, c_x, func, 1e-14, 1e-15, __LINE__);
|
TestGradient(dx, c_x, func, 1e-14, 1e-15, __LINE__, __LINE__);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -335,11 +335,11 @@ TEST(BackPropTest, MixByAttentionVJP) {
|
||||||
ZeroInit(c_y);
|
ZeroInit(c_y);
|
||||||
|
|
||||||
for (int iter = 0; iter < 10; ++iter) {
|
for (int iter = 0; iter < 10; ++iter) {
|
||||||
RandInit(qkv, 1.0, gen);
|
RandInit(qkv, 1.0f, gen);
|
||||||
RandInit(attn, 1.0, gen);
|
RandInit(attn, 1.0f, gen);
|
||||||
Complexify(qkv, c_qkv);
|
Complexify(qkv, c_qkv);
|
||||||
Complexify(attn, c_attn);
|
Complexify(attn, c_attn);
|
||||||
RandInit(dy, 1.0, gen);
|
RandInit(dy, 1.0f, gen);
|
||||||
auto func = [&]() {
|
auto func = [&]() {
|
||||||
MixByAttention(c_qkv.Packed(), c_attn.Packed(), c_y.Packed(), kTokens,
|
MixByAttention(c_qkv.Packed(), c_attn.Packed(), c_y.Packed(), kTokens,
|
||||||
kHeads, kQKVDim, kSeqLen);
|
kHeads, kQKVDim, kSeqLen);
|
||||||
|
|
@ -347,8 +347,8 @@ TEST(BackPropTest, MixByAttentionVJP) {
|
||||||
};
|
};
|
||||||
MixByAttentionVJP(qkv.Packed(), attn.Packed(), dy.Packed(), dqkv.Packed(),
|
MixByAttentionVJP(qkv.Packed(), attn.Packed(), dy.Packed(), dqkv.Packed(),
|
||||||
dattn.Packed(), kTokens, kHeads, kQKVDim, kSeqLen);
|
dattn.Packed(), kTokens, kHeads, kQKVDim, kSeqLen);
|
||||||
TestGradient(dqkv, c_qkv, func, 1e-14, 1e-15, __LINE__);
|
TestGradient(dqkv, c_qkv, func, 1e-14, 1e-15, __LINE__, __LINE__);
|
||||||
TestGradient(dattn, c_attn, func, 1e-14, 1e-15, __LINE__);
|
TestGradient(dattn, c_attn, func, 1e-14, 1e-15, __LINE__, __LINE__);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -368,8 +368,8 @@ TEST(BackPropTest, InputEmbeddingVJP) {
|
||||||
size_t num_tokens = tokens.size() - 1;
|
size_t num_tokens = tokens.size() - 1;
|
||||||
|
|
||||||
for (size_t iter = 0; iter < 10; ++iter) {
|
for (size_t iter = 0; iter < 10; ++iter) {
|
||||||
RandInit(weights, 1.0, gen);
|
RandInit(weights, 1.0f, gen);
|
||||||
RandInit(dy, 1.0, gen);
|
RandInit(dy, 1.0f, gen);
|
||||||
Complexify(weights, c_weights);
|
Complexify(weights, c_weights);
|
||||||
auto func = [&]() {
|
auto func = [&]() {
|
||||||
InputEmbedding(c_weights.Packed(), tokens, TC(3.0), c_y.Packed(),
|
InputEmbedding(c_weights.Packed(), tokens, TC(3.0), c_y.Packed(),
|
||||||
|
|
@ -379,7 +379,7 @@ TEST(BackPropTest, InputEmbeddingVJP) {
|
||||||
ZeroInit(grad);
|
ZeroInit(grad);
|
||||||
InputEmbeddingVJPT(weights.Packed(), tokens, 3.0, dy.Packed(),
|
InputEmbeddingVJPT(weights.Packed(), tokens, 3.0, dy.Packed(),
|
||||||
grad.Packed(), kModelDim);
|
grad.Packed(), kModelDim);
|
||||||
TestGradient(grad, c_weights, func, 1e-16, 1e-14, __LINE__);
|
TestGradient(grad, c_weights, func, 1e-14, 1e-14, __LINE__, __LINE__);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -441,9 +441,9 @@ TEST(BackPropTest, LayerVJP) {
|
||||||
grad.ZeroInit(/*layer_idx=*/0);
|
grad.ZeroInit(/*layer_idx=*/0);
|
||||||
ApplyLayer(weights, forward, num_tokens, y.Packed());
|
ApplyLayer(weights, forward, num_tokens, y.Packed());
|
||||||
LayerVJP(weights, forward, dy.Packed(), grad, backward, num_tokens);
|
LayerVJP(weights, forward, dy.Packed(), grad, backward, num_tokens);
|
||||||
TestGradient(backward.input, c_forward.input, func, 1e-11, 5e-11,
|
TestGradient(backward.input, c_forward.input, func, 1e-11, 5e-11, __LINE__,
|
||||||
__LINE__);
|
__LINE__);
|
||||||
TestGradient(grad, c_weights, func, 1e-11);
|
TestGradient(grad, c_weights, func, 2e-11, __LINE__);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -475,7 +475,7 @@ TEST(BackPropTest, EndToEnd) {
|
||||||
return CrossEntropyLossForwardPass(prompt, c_weights.get(), c_forward);
|
return CrossEntropyLossForwardPass(prompt, c_weights.get(), c_forward);
|
||||||
};
|
};
|
||||||
|
|
||||||
TestGradient(grad.get(), c_weights.get(), func, 1e-11);
|
TestGradient(grad.get(), c_weights.get(), func, 1e-11, __LINE__);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -611,12 +611,12 @@ TEST(BackProptest, Convergence) {
|
||||||
return CrossEntropyLossForwardPass(batch, c_weights, c_forward) * scale;
|
return CrossEntropyLossForwardPass(batch, c_weights, c_forward) * scale;
|
||||||
};
|
};
|
||||||
|
|
||||||
TestGradient(grad.get(), c_weights.get(), func, 5e-3f);
|
TestGradient(grad.get(), c_weights.get(), func, 5e-3f, __LINE__);
|
||||||
}
|
}
|
||||||
|
|
||||||
loss /= batch.size();
|
loss /= batch.size();
|
||||||
EXPECT_LT(loss, prev_loss);
|
EXPECT_LT(loss, prev_loss);
|
||||||
stop = step >= 10000 || loss < 1e-2;
|
stop = step >= 1000 || loss < T{1.0};
|
||||||
if (step % 10 == 0 || stop) {
|
if (step % 10 == 0 || stop) {
|
||||||
printf("step: %5zu loss: %.15f learning_rate: %.15f\n",
|
printf("step: %5zu loss: %.15f learning_rate: %.15f\n",
|
||||||
step, loss, learning_rate);
|
step, loss, learning_rate);
|
||||||
|
|
|
||||||
|
|
@ -103,14 +103,14 @@ void TestMatMulVJP() {
|
||||||
ZeroInit(grad);
|
ZeroInit(grad);
|
||||||
MatMulVJP(weights.Packed(), x.Packed(), dy.Packed(), kCols, kRows, kTokens,
|
MatMulVJP(weights.Packed(), x.Packed(), dy.Packed(), kCols, kRows, kTokens,
|
||||||
grad.Packed(), dx.Packed(), pool);
|
grad.Packed(), dx.Packed(), pool);
|
||||||
TestGradient(dx, c_x, func, 5e-5f, 5e-5f, __LINE__);
|
TestGradient(dx, c_x, func, 5e-5f, 5e-5f, __LINE__, __LINE__);
|
||||||
TestGradient(grad, c_weights, func, 5e-5f, 5e-5f, __LINE__);
|
TestGradient(grad, c_weights, func, 5e-5f, 5e-5f, __LINE__, __LINE__);
|
||||||
|
|
||||||
ZeroInit(grad_scalar);
|
ZeroInit(grad_scalar);
|
||||||
MatMulVJPT(weights.Packed(), x.Packed(), dy.Packed(), grad_scalar.Packed(),
|
MatMulVJPT(weights.Packed(), x.Packed(), dy.Packed(), grad_scalar.Packed(),
|
||||||
dx_scalar.Packed(), kRows, kCols, kTokens);
|
dx_scalar.Packed(), kRows, kCols, kTokens);
|
||||||
TestNear(dx, dx_scalar, 5e-5, 1e-4, __LINE__);
|
TestNear(dx, dx_scalar, 5e-5, 1e-4, __LINE__, __LINE__);
|
||||||
TestNear(grad, grad_scalar, 5e-5, 5e-5, __LINE__);
|
TestNear(grad, grad_scalar, 5e-5, 5e-5, __LINE__, __LINE__);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -148,15 +148,15 @@ void TestMultiHeadMatMulVJP() {
|
||||||
ZeroInit(grad);
|
ZeroInit(grad);
|
||||||
MultiHeadMatMulVJP(weights.Packed(), x.Packed(), dy.Packed(), kHeads, kCols,
|
MultiHeadMatMulVJP(weights.Packed(), x.Packed(), dy.Packed(), kHeads, kCols,
|
||||||
kRows, kTokens, grad.Packed(), dx.Packed(), pool);
|
kRows, kTokens, grad.Packed(), dx.Packed(), pool);
|
||||||
TestGradient(dx, c_x, func, 5e-5f, 5e-5f, __LINE__);
|
TestGradient(dx, c_x, func, 5e-5f, 5e-5f, __LINE__, __LINE__);
|
||||||
TestGradient(grad, c_weights, func, 5e-5f, 5e-5f, __LINE__);
|
TestGradient(grad, c_weights, func, 5e-5f, 5e-5f, __LINE__, __LINE__);
|
||||||
|
|
||||||
ZeroInit(grad_scalar);
|
ZeroInit(grad_scalar);
|
||||||
MultiHeadMatMulVJPT(weights.Packed(), x.Packed(), dy.Packed(),
|
MultiHeadMatMulVJPT(weights.Packed(), x.Packed(), dy.Packed(),
|
||||||
grad_scalar.Packed(), dx_scalar.Packed(), kHeads, kRows,
|
grad_scalar.Packed(), dx_scalar.Packed(), kHeads, kRows,
|
||||||
kCols, kTokens);
|
kCols, kTokens);
|
||||||
TestNear(dx, dx_scalar, 5e-5, 5e-5, __LINE__);
|
TestNear(dx, dx_scalar, 5e-5, 5e-5, __LINE__, __LINE__);
|
||||||
TestNear(grad, grad_scalar, 5e-5, 5e-5, __LINE__);
|
TestNear(grad, grad_scalar, 5e-5, 5e-5, __LINE__, __LINE__);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -191,14 +191,14 @@ void TestRMSNormVJP() {
|
||||||
ZeroInit(grad);
|
ZeroInit(grad);
|
||||||
RMSNormVJP(weights.Packed(), x.Packed(), dy.Packed(), N, K, grad.Packed(),
|
RMSNormVJP(weights.Packed(), x.Packed(), dy.Packed(), N, K, grad.Packed(),
|
||||||
dx.Packed(), pool);
|
dx.Packed(), pool);
|
||||||
TestGradient(dx, c_x, func, 5e-5f, 5e-5f, __LINE__);
|
TestGradient(dx, c_x, func, 5e-5f, 5e-5f, __LINE__, __LINE__);
|
||||||
TestGradient(grad, c_weights, func, 5e-5f, 5e-5f, __LINE__);
|
TestGradient(grad, c_weights, func, 5e-5f, 5e-5f, __LINE__, __LINE__);
|
||||||
|
|
||||||
ZeroInit(grad_scalar);
|
ZeroInit(grad_scalar);
|
||||||
RMSNormVJPT(weights.Packed(), x.Packed(), dy.Packed(), grad_scalar.Packed(),
|
RMSNormVJPT(weights.Packed(), x.Packed(), dy.Packed(), grad_scalar.Packed(),
|
||||||
dx_scalar.Packed(), N, K);
|
dx_scalar.Packed(), N, K);
|
||||||
TestNear(dx, dx_scalar, 0, 2e-5, __LINE__);
|
TestNear(dx, dx_scalar, 0, 2e-5, __LINE__, __LINE__);
|
||||||
TestNear(grad, grad_scalar, 0, 2e-5, __LINE__);
|
TestNear(grad, grad_scalar, 0, 2e-5, __LINE__, __LINE__);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -265,7 +265,7 @@ void TestEndToEnd() {
|
||||||
return CrossEntropyLossForwardPass(prompt, c_weights.get(), c_forward);
|
return CrossEntropyLossForwardPass(prompt, c_weights.get(), c_forward);
|
||||||
};
|
};
|
||||||
|
|
||||||
TestGradient(grad.get(), c_weights.get(), func, 2e-3f);
|
TestGradient(grad.get(), c_weights.get(), func, 2e-3f, __LINE__);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -109,16 +109,18 @@ TEST(OptimizeTest, GradientDescent) {
|
||||||
gemma.MutableWeights().LogWeightStats();
|
gemma.MutableWeights().LogWeightStats();
|
||||||
|
|
||||||
constexpr size_t kBatchSize = 8;
|
constexpr size_t kBatchSize = 8;
|
||||||
const float alpha = 0.001f;
|
constexpr float kAlpha = 0.001f;
|
||||||
const float beta1 = 0.9f;
|
constexpr float kBeta1 = 0.9f;
|
||||||
const float beta2 = 0.999f;
|
constexpr float kBeta2 = 0.999f;
|
||||||
const float epsilon = 1e-8f;
|
constexpr float kEpsilon = 1e-8f;
|
||||||
|
|
||||||
|
constexpr float kMaxLoss = 20.0f;
|
||||||
|
|
||||||
ReverseSequenceSampler training_task({
|
ReverseSequenceSampler training_task({
|
||||||
0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 1, 0, 1, 0, 1});
|
0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 1, 0, 1, 0, 1});
|
||||||
size_t steps = 0;
|
size_t steps = 0;
|
||||||
size_t num_ok;
|
size_t num_ok;
|
||||||
for (; steps < 1000000; ++steps) {
|
for (; steps < 1000; ++steps) {
|
||||||
std::mt19937 sgen(42);
|
std::mt19937 sgen(42);
|
||||||
grad.ZeroInit();
|
grad.ZeroInit();
|
||||||
float total_loss = 0.0f;
|
float total_loss = 0.0f;
|
||||||
|
|
@ -136,7 +138,7 @@ TEST(OptimizeTest, GradientDescent) {
|
||||||
}
|
}
|
||||||
total_loss /= kBatchSize;
|
total_loss /= kBatchSize;
|
||||||
|
|
||||||
AdamUpdate(info.weight, grad, alpha, beta1, beta2, epsilon, steps + 1,
|
AdamUpdate(info.weight, grad, kAlpha, kBeta1, kBeta2, kEpsilon, steps + 1,
|
||||||
gemma.Weights(), grad_m, grad_v, pool);
|
gemma.Weights(), grad_m, grad_v, pool);
|
||||||
printf("step: %zu total_loss: %.15f num_ok: %zu/%zu\n",
|
printf("step: %zu total_loss: %.15f num_ok: %zu/%zu\n",
|
||||||
steps, total_loss, num_ok, kBatchSize);
|
steps, total_loss, num_ok, kBatchSize);
|
||||||
|
|
@ -144,14 +146,12 @@ TEST(OptimizeTest, GradientDescent) {
|
||||||
printf("Batch gradient:\n");
|
printf("Batch gradient:\n");
|
||||||
grad.LogWeightStats();
|
grad.LogWeightStats();
|
||||||
}
|
}
|
||||||
if (total_loss < 0.5f) {
|
if (total_loss < kMaxLoss) break; // Done
|
||||||
break;
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
printf("Num steps: %zu\n", steps);
|
printf("Num steps: %zu\n", steps);
|
||||||
printf("Final weights:\n");
|
printf("Final weights:\n");
|
||||||
gemma.MutableWeights().LogWeightStats();
|
gemma.MutableWeights().LogWeightStats();
|
||||||
EXPECT_LT(steps, 300);
|
EXPECT_LT(steps, 50);
|
||||||
EXPECT_EQ(num_ok, kBatchSize);
|
EXPECT_EQ(num_ok, kBatchSize);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -27,13 +27,14 @@
|
||||||
#include "gemma/configs.h"
|
#include "gemma/configs.h"
|
||||||
#include "gemma/weights.h"
|
#include "gemma/weights.h"
|
||||||
#include "util/mat.h"
|
#include "util/mat.h"
|
||||||
|
#include "util/threading_context.h"
|
||||||
#include "hwy/contrib/thread_pool/thread_pool.h"
|
#include "hwy/contrib/thread_pool/thread_pool.h"
|
||||||
|
|
||||||
namespace gcpp {
|
namespace gcpp {
|
||||||
|
|
||||||
// TODO: make a member of Layer<T>.
|
// TODO: make a member of Layer<T>.
|
||||||
template <typename T>
|
template <typename T>
|
||||||
void RandInit(LayerWeightsPtrs<T>& w, T stddev, std::mt19937& gen) {
|
void RandInit(LayerWeightsPtrs<T>& w, float stddev, std::mt19937& gen) {
|
||||||
RandInit(w.pre_attention_norm_scale, stddev, gen);
|
RandInit(w.pre_attention_norm_scale, stddev, gen);
|
||||||
RandInit(w.attn_vec_einsum_w, stddev, gen);
|
RandInit(w.attn_vec_einsum_w, stddev, gen);
|
||||||
RandInit(w.qkv_einsum_w, stddev, gen);
|
RandInit(w.qkv_einsum_w, stddev, gen);
|
||||||
|
|
@ -43,7 +44,7 @@ void RandInit(LayerWeightsPtrs<T>& w, T stddev, std::mt19937& gen) {
|
||||||
}
|
}
|
||||||
|
|
||||||
template <typename T>
|
template <typename T>
|
||||||
void RandInit(ModelWeightsPtrs<T>& w, T stddev, std::mt19937& gen) {
|
void RandInit(ModelWeightsPtrs<T>& w, float stddev, std::mt19937& gen) {
|
||||||
const size_t kLayers = w.c_layers.size();
|
const size_t kLayers = w.c_layers.size();
|
||||||
RandInit(w.embedder_input_embedding, stddev, gen);
|
RandInit(w.embedder_input_embedding, stddev, gen);
|
||||||
RandInit(w.final_norm_scale, stddev, gen);
|
RandInit(w.final_norm_scale, stddev, gen);
|
||||||
|
|
@ -108,7 +109,9 @@ class WeightsWrapper {
|
||||||
|
|
||||||
template <typename T, typename U>
|
template <typename T, typename U>
|
||||||
void TestNear(const MatPtrT<T>& actual, const MatPtrT<U>& expected,
|
void TestNear(const MatPtrT<T>& actual, const MatPtrT<U>& expected,
|
||||||
double max_abs_err, double max_rel_err, int line) {
|
double max_abs_err, double max_rel_err, int line_test,
|
||||||
|
int line_util) {
|
||||||
|
// TODO: consider compensated sum.
|
||||||
double sum0 = 0;
|
double sum0 = 0;
|
||||||
double sum1 = 0;
|
double sum1 = 0;
|
||||||
double sum01 = 0;
|
double sum01 = 0;
|
||||||
|
|
@ -122,14 +125,15 @@ void TestNear(const MatPtrT<T>& actual, const MatPtrT<U>& expected,
|
||||||
ASSERT_NEAR(
|
ASSERT_NEAR(
|
||||||
actual_row[c], expected_row[c],
|
actual_row[c], expected_row[c],
|
||||||
std::max(max_abs_err, std::abs(expected_row[c]) * max_rel_err))
|
std::max(max_abs_err, std::abs(expected_row[c]) * max_rel_err))
|
||||||
<< "line: " << line << " r " << r << " c " << c;
|
<< "test line " << line_test << "test_util.h line " << line_util
|
||||||
|
<< " r " << r << " c " << c;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if (sum0 > 1e-40) {
|
if (sum0 > 1e-16) {
|
||||||
double norm_dot = sum01 / std::sqrt(sum0) / std::sqrt(sum1);
|
double norm_dot = sum01 / std::sqrt(sum0) / std::sqrt(sum1);
|
||||||
ASSERT_NEAR(norm_dot, 1.0, 1e-7)
|
ASSERT_NEAR(norm_dot, 1.0, 3e-6)
|
||||||
<< "line: " << line << " sum0: " << sum0 << " sum1: " << sum1
|
<< "test line " << line_test << " test_util.h line " << line_util
|
||||||
<< " sum01: " << sum01;
|
<< " sum0: " << sum0 << " sum1: " << sum1 << " sum01: " << sum01;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -148,7 +152,8 @@ void TestNear(const MatPtrT<T>& actual, const MatPtrT<U>& expected,
|
||||||
// to each other.
|
// to each other.
|
||||||
template <typename FUNC, typename T, typename U>
|
template <typename FUNC, typename T, typename U>
|
||||||
void TestGradient(const MatPtrT<T>& grad, MatPtrT<std::complex<U>>& x,
|
void TestGradient(const MatPtrT<T>& grad, MatPtrT<std::complex<U>>& x,
|
||||||
FUNC func, U step, T max_abs_err, T max_rel_err, int line) {
|
FUNC func, U step, T max_abs_err, T max_rel_err,
|
||||||
|
int line_test, int line_util) {
|
||||||
MatStorageT<T> exp_grad = MakePacked<T>("exp_grad", x.Rows(), x.Cols());
|
MatStorageT<T> exp_grad = MakePacked<T>("exp_grad", x.Rows(), x.Cols());
|
||||||
const U inv_step = 1.0 / step;
|
const U inv_step = 1.0 / step;
|
||||||
for (size_t r = 0; r < x.Rows(); ++r) {
|
for (size_t r = 0; r < x.Rows(); ++r) {
|
||||||
|
|
@ -163,49 +168,56 @@ void TestGradient(const MatPtrT<T>& grad, MatPtrT<std::complex<U>>& x,
|
||||||
x_row[c] = x0;
|
x_row[c] = x0;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
TestNear(grad, exp_grad, max_abs_err, max_rel_err, line);
|
TestNear(grad, exp_grad, max_abs_err, max_rel_err, line_test, line_util);
|
||||||
}
|
}
|
||||||
|
|
||||||
template <typename FUNC>
|
template <typename FUNC>
|
||||||
void TestGradient(const MatPtrT<float>& grad, MatPtrT<std::complex<float>>& x,
|
void TestGradient(const MatPtrT<float>& grad, MatPtrT<std::complex<float>>& x,
|
||||||
FUNC func, float max_abs_err, float max_rel_error, int line) {
|
FUNC func, float max_abs_err, float max_rel_error,
|
||||||
TestGradient(grad, x, func, 1e-30f, max_abs_err, max_rel_error, line);
|
int line_test, int line_util) {
|
||||||
|
TestGradient(grad, x, func, 1e-30f, max_abs_err, max_rel_error, line_test,
|
||||||
|
line_util);
|
||||||
}
|
}
|
||||||
|
|
||||||
template <typename FUNC, typename T>
|
template <typename FUNC, typename T>
|
||||||
void TestGradient(const MatPtrT<T>& grad, MatPtrT<std::complex<double>>& x,
|
void TestGradient(const MatPtrT<T>& grad, MatPtrT<std::complex<double>>& x,
|
||||||
FUNC func, T max_abs_err, T max_rel_error, int line) {
|
FUNC func, T max_abs_err, T max_rel_error, int line_test,
|
||||||
TestGradient(grad, x, func, 1e-50, max_abs_err, max_rel_error, line);
|
int line_util) {
|
||||||
|
TestGradient(grad, x, func, 1e-50, max_abs_err, max_rel_error, line_test,
|
||||||
|
line_util);
|
||||||
}
|
}
|
||||||
|
|
||||||
template <typename T, typename U, typename FUNC>
|
template <typename T, typename U, typename FUNC>
|
||||||
void TestGradient(const LayerWeightsPtrs<T>& grad,
|
void TestGradient(const LayerWeightsPtrs<T>& grad,
|
||||||
LayerWeightsPtrs<U>& c_weights, FUNC func, T max_err) {
|
LayerWeightsPtrs<U>& c_weights, FUNC func, T max_err,
|
||||||
|
int line_test) {
|
||||||
TestGradient(grad.pre_attention_norm_scale,
|
TestGradient(grad.pre_attention_norm_scale,
|
||||||
c_weights.pre_attention_norm_scale,
|
c_weights.pre_attention_norm_scale, func, max_err, max_err,
|
||||||
func, max_err, max_err, __LINE__);
|
line_test, __LINE__);
|
||||||
TestGradient(grad.attn_vec_einsum_w, c_weights.attn_vec_einsum_w,
|
TestGradient(grad.attn_vec_einsum_w, c_weights.attn_vec_einsum_w, func,
|
||||||
func, max_err, max_err, __LINE__);
|
max_err, max_err, line_test, __LINE__);
|
||||||
TestGradient(grad.qkv_einsum_w, c_weights.qkv_einsum_w,
|
TestGradient(grad.qkv_einsum_w, c_weights.qkv_einsum_w, func, max_err,
|
||||||
func, max_err, max_err, __LINE__);
|
max_err, line_test, __LINE__);
|
||||||
TestGradient(grad.pre_ffw_norm_scale, c_weights.pre_ffw_norm_scale,
|
TestGradient(grad.pre_ffw_norm_scale, c_weights.pre_ffw_norm_scale, func,
|
||||||
func, max_err, max_err, __LINE__);
|
max_err, max_err, line_test, __LINE__);
|
||||||
TestGradient(grad.gating_einsum_w, c_weights.gating_einsum_w,
|
TestGradient(grad.gating_einsum_w, c_weights.gating_einsum_w, func, max_err,
|
||||||
func, max_err, max_err, __LINE__);
|
max_err, line_test, __LINE__);
|
||||||
TestGradient(grad.linear_w, c_weights.linear_w,
|
TestGradient(grad.linear_w, c_weights.linear_w, func, max_err, max_err,
|
||||||
func, max_err, max_err, __LINE__);
|
line_test, __LINE__);
|
||||||
}
|
}
|
||||||
|
|
||||||
template <typename T, typename U, typename FUNC>
|
template <typename T, typename U, typename FUNC>
|
||||||
void TestGradient(const ModelWeightsPtrs<T>& grad,
|
void TestGradient(const ModelWeightsPtrs<T>& grad,
|
||||||
ModelWeightsPtrs<U>& c_weights, FUNC func, T max_err) {
|
ModelWeightsPtrs<U>& c_weights, FUNC func, T max_err,
|
||||||
|
int line_test) {
|
||||||
TestGradient(grad.embedder_input_embedding,
|
TestGradient(grad.embedder_input_embedding,
|
||||||
c_weights.embedder_input_embedding,
|
c_weights.embedder_input_embedding, func, 2 * max_err, max_err,
|
||||||
func, 2 * max_err, max_err, __LINE__);
|
line_test, __LINE__);
|
||||||
TestGradient(grad.final_norm_scale, c_weights.final_norm_scale,
|
TestGradient(grad.final_norm_scale, c_weights.final_norm_scale, func, max_err,
|
||||||
func, max_err, max_err, __LINE__);
|
max_err, line_test, __LINE__);
|
||||||
for (size_t i = 0; i < grad.c_layers.size(); ++i) {
|
for (size_t i = 0; i < grad.c_layers.size(); ++i) {
|
||||||
TestGradient(*grad.GetLayer(i), *c_weights.GetLayer(i), func, max_err);
|
TestGradient(*grad.GetLayer(i), *c_weights.GetLayer(i), func, max_err,
|
||||||
|
line_test);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -613,10 +613,6 @@ struct MatMulEnv {
|
||||||
ThreadingContext2& ctx;
|
ThreadingContext2& ctx;
|
||||||
bool have_timer_stop = false;
|
bool have_timer_stop = false;
|
||||||
|
|
||||||
// Enable binding: disabled in Gemma until tensors support it, enabled in
|
|
||||||
// bench_matmul.cc.
|
|
||||||
bool enable_bind = false;
|
|
||||||
|
|
||||||
// Whether `MMCandidates()` should print the set of parameters.
|
// Whether `MMCandidates()` should print the set of parameters.
|
||||||
bool print_config = false;
|
bool print_config = false;
|
||||||
// Whether to print each config's speed during autotuning.
|
// Whether to print each config's speed during autotuning.
|
||||||
|
|
|
||||||
|
|
@ -171,7 +171,7 @@ Allocator2::Allocator2(const BoundedTopology& topology, bool enable_bind) {
|
||||||
} else {
|
} else {
|
||||||
HWY_WARN(
|
HWY_WARN(
|
||||||
"Multiple sockets but binding disabled. This reduces speed; "
|
"Multiple sockets but binding disabled. This reduces speed; "
|
||||||
"set or remove enable_bind to avoid this warning.");
|
"set --bind 1 to avoid this warning.");
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
@ -209,7 +209,7 @@ AlignedPtr2<uint8_t[]> Allocator2::AllocBytes(size_t bytes) const {
|
||||||
if (HWY_ALIGNMENT < QuantumBytes()) {
|
if (HWY_ALIGNMENT < QuantumBytes()) {
|
||||||
HWY_WARN(
|
HWY_WARN(
|
||||||
"HWY_ALIGNMENT %d < QuantumBytes %zu: either vector or cache lines "
|
"HWY_ALIGNMENT %d < QuantumBytes %zu: either vector or cache lines "
|
||||||
"are huge, enable GEMMA_BIND to avoid this warning.",
|
"are huge, enable GEMMA_BIND and set --bind 1 to avoid this warning.",
|
||||||
HWY_ALIGNMENT, QuantumBytes());
|
HWY_ALIGNMENT, QuantumBytes());
|
||||||
}
|
}
|
||||||
auto p = hwy::AllocateAligned<uint8_t>(bytes);
|
auto p = hwy::AllocateAligned<uint8_t>(bytes);
|
||||||
|
|
|
||||||
|
|
@ -85,7 +85,6 @@ class Allocator2 {
|
||||||
public:
|
public:
|
||||||
// Must be called at least once before any other function. Not thread-safe,
|
// Must be called at least once before any other function. Not thread-safe,
|
||||||
// hence only call this from the main thread.
|
// hence only call this from the main thread.
|
||||||
// TODO: remove enable_bind once Gemma tensors support binding.
|
|
||||||
Allocator2(const BoundedTopology& topology, bool enable_bind);
|
Allocator2(const BoundedTopology& topology, bool enable_bind);
|
||||||
|
|
||||||
// Bytes per cache line, or a reasonable guess if unknown. Used to choose
|
// Bytes per cache line, or a reasonable guess if unknown. Used to choose
|
||||||
|
|
|
||||||
|
|
@ -281,7 +281,7 @@ void CopyMat(const MatPtr& from, MatPtr& to);
|
||||||
void ZeroInit(MatPtr& mat);
|
void ZeroInit(MatPtr& mat);
|
||||||
|
|
||||||
template <typename T>
|
template <typename T>
|
||||||
void RandInit(MatPtrT<T>& x, T stddev, std::mt19937& gen) {
|
void RandInit(MatPtrT<T>& x, float stddev, std::mt19937& gen) {
|
||||||
std::normal_distribution<T> dist(0.0, stddev);
|
std::normal_distribution<T> dist(0.0, stddev);
|
||||||
for (size_t r = 0; r < x.Rows(); ++r) {
|
for (size_t r = 0; r < x.Rows(); ++r) {
|
||||||
T* row = x.Row(r);
|
T* row = x.Row(r);
|
||||||
|
|
@ -401,8 +401,9 @@ class RowPtr {
|
||||||
size_t stride)
|
size_t stride)
|
||||||
: row0_(row0),
|
: row0_(row0),
|
||||||
stride_(stride),
|
stride_(stride),
|
||||||
row_mask_(
|
// TODO: disabled because otherwise we see non-deterministic results.
|
||||||
static_cast<uint32_t>(allocator.QuantumStepMask() & 0xFFFFFFFFu)),
|
row_mask_(0),
|
||||||
|
// static_cast<uint32_t>(allocator.QuantumStepMask() & 0xFFFFFFFFu)),
|
||||||
cols_(static_cast<uint32_t>(cols)),
|
cols_(static_cast<uint32_t>(cols)),
|
||||||
step_bytes_(static_cast<uint32_t>(allocator.StepBytes())),
|
step_bytes_(static_cast<uint32_t>(allocator.StepBytes())),
|
||||||
quantum_bytes_(allocator.QuantumBytes()) {
|
quantum_bytes_(allocator.QuantumBytes()) {
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue