// Copyright 2024 Google LLC
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "backprop/backward_scalar.h"

#include <stddef.h>
#include <stdio.h>
#include <string.h>  // memcpy

#include <complex>
#include <limits>
#include <random>
#include <vector>

#include "gtest/gtest.h"
#include "backprop/activations.h"
#include "backprop/common_scalar.h"
#include "backprop/forward_scalar.h"
#include "backprop/prompt.h"
#include "backprop/sampler.h"
#include "backprop/test_util.h"
#include "compression/compress.h"
#include "gemma/configs.h"
#include "gemma/weights.h"

namespace gcpp {

TEST(BackPropTest, MatMulVJP) {
  static const size_t kRows = 8;
  static const size_t kCols = 64;
  static const size_t kTokens = 5;
  std::mt19937 gen(42);
  using T = double;
  using TC = std::complex<T>;
  MatStorageT<T> weights("weights", kRows, kCols);
  MatStorageT<T> x("x", kTokens, kCols);
  MatStorageT<T> grad("grad", kRows, kCols);
  MatStorageT<T> dx("dx", kTokens, kCols);
  MatStorageT<TC> c_weights("c_weights", kRows, kCols);
  MatStorageT<TC> c_x("c_x", kTokens, kCols);
  MatStorageT<TC> c_y("c_y", kTokens, kRows);
  MatStorageT<T> dy("dy", kTokens, kRows);

  for (int iter = 0; iter < 10; ++iter) {
    RandInit(weights, 1.0 * (1 << iter), gen);
    RandInit(x, 1.0 * (1 << iter), gen);
    RandInit(dy, 1.0, gen);
    Complexify(weights, c_weights);
    Complexify(x, c_x);
    auto func = [&]() {
      MatMulT(c_weights.data(), c_x.data(), c_y.data(), kRows, kCols, kTokens);
      return DotT(dy.data(), c_y.data(), kTokens * kRows);
    };
    grad.ZeroInit();
    MatMulVJPT(weights.data(), x.data(), dy.data(), grad.data(), dx.data(),
               kRows, kCols, kTokens);
    TestGradient(dx, c_x, func, 1e-11, 1e-12, __LINE__);
    TestGradient(grad, c_weights, func, 1e-14, 1e-12, __LINE__);
  }
}

TEST(BackPropTest, MultiHeadMatMulVJP) {
  static const size_t kRows = 2;
  static const size_t kCols = 16;
  static const size_t kHeads = 4;
  static const size_t kTokens = 3;
  std::mt19937 gen(42);
  using T = double;
  using TC = std::complex<T>;
  MatStorageT<T> weights("weights", kRows, kCols * kHeads);
  MatStorageT<T> x("x", kTokens, kCols * kHeads);
  MatStorageT<T> grad("grad", kRows, kCols * kHeads);
  MatStorageT<T> dx("dx", kTokens, kCols * kHeads);
  MatStorageT<TC> c_weights("c_weights", kRows, kCols * kHeads);
  MatStorageT<TC> c_x("c_x", kTokens, kCols * kHeads);
  MatStorageT<TC> c_y("c_y", kTokens, kRows);
  MatStorageT<T> dy("dy", kTokens, kRows);

  for (int iter = 0; iter < 10; ++iter) {
    RandInit(weights, 1.0 * (1 << iter), gen);
    RandInit(x, 1.0 * (1 << iter), gen);
    RandInit(dy, 1.0, gen);
    Complexify(weights, c_weights);
    Complexify(x, c_x);
    auto func = [&]() {
      MultiHeadMatMul(c_weights.data(), c_x.data(), c_y.data(), kHeads, kRows,
                      kCols, kTokens);
      return DotT(dy.data(), c_y.data(), kTokens * kRows);
    };
    grad.ZeroInit();
    MultiHeadMatMulVJPT(weights.data(), x.data(), dy.data(), grad.data(),
                        dx.data(), kHeads, kRows, kCols, kTokens);
    TestGradient(dx, c_x, func, 1e-15, 1e-13, __LINE__);
    TestGradient(grad, c_weights, func, 1e-15, 1e-13, __LINE__);
  }
}
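// The VJP tests below all follow the same pattern as the two MatMul tests
// above: `func` runs the scalar forward op on complexified copies of the
// inputs and reduces the output with DotT(dy, .), the *VJPT routine computes
// the analytic gradient of that same objective, and TestGradient compares the
// two. (The complexified copies presumably enable a complex-step style
// numerical derivative, f'(x) ~= Im f(x + ih) / h, which avoids the
// cancellation error of finite differences.)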
weights("weights", N, 1); MatStorageT grad("grad", N, 1); MatStorageT x("x", K, N); MatStorageT dx("dx", K, N); MatStorageT dy("dy", K, N); MatStorageT c_weights("c_weights", N, 1); MatStorageT c_x("c_x", K, N); MatStorageT c_y("c_y", K, N); for (int iter = 0; iter < 10; ++iter) { RandInit(weights, 1.0 * (1 << iter), gen); RandInit(x, 1.0 * (1 << iter), gen); Complexify(weights, c_weights); Complexify(x, c_x); RandInit(dy, 1.0, gen); auto func = [&]() { RMSNormT(c_weights.data(), c_x.data(), c_y.data(), N, K); return DotT(dy.data(), c_y.data(), K * N); }; grad.ZeroInit(); RMSNormVJPT(weights.data(), x.data(), dy.data(), grad.data(), dx.data(), N, K); TestGradient(dx, c_x, func, 1e-15, 1e-14, __LINE__); TestGradient(grad, c_weights, func, 1e-15, 1e-14, __LINE__); } } TEST(BackPropTest, SoftmaxVJP) { static const size_t N = 64; std::mt19937 gen(42); using T = double; using TC = std::complex; MatStorageT x("x", N, 1); MatStorageT dx("dx", N, 1); MatStorageT dy("dy", N, 1); MatStorageT c_x("c_x", N, 1); MatStorageT c_y("c_y", N, 1); for (int iter = 0; iter < 10; ++iter) { RandInit(x, 1.0 * (1 << iter), gen); Complexify(x, c_x); RandInit(dy, 1.0, gen); auto func = [&]() { memcpy(c_y.data(), c_x.data(), c_x.SizeBytes()); Softmax(c_y.data(), N); return DotT(dy.data(), c_y.data(), N); }; Softmax(x.data(), N); memcpy(dx.data(), dy.data(), dx.SizeBytes()); SoftmaxVJPT(x.data(), dx.data(), N); TestGradient(dx, c_x, func, 1e-15, 1e-15, __LINE__); } } TEST(BackPropTest, MaskedSoftmaxVJP) { static const size_t kSeqLen = 16; static const size_t kHeads = 2; static const size_t kTokens = 14; static const size_t N = kTokens * kHeads * kSeqLen; std::mt19937 gen(42); using T = double; using TC = std::complex; MatStorageT x("x", N, 1); MatStorageT dy("dy", N, 1); MatStorageT dx("dx", N, 1); MatStorageT c_x("c_x", N, 1); MatStorageT c_y("c_y", N, 1); dx.ZeroInit(); for (int iter = 0; iter < 10; ++iter) { RandInit(x, 1.0 * (1 << iter), gen); Complexify(x, c_x); RandInit(dy, 1.0, gen); auto func = [&]() { memcpy(c_y.data(), c_x.data(), kTokens * kHeads * kSeqLen * sizeof(c_x.At(0))); MaskedSoftmax(c_y.data(), kTokens, kHeads, kSeqLen); return DotT(dy.data(), c_y.data(), N); }; MaskedSoftmax(x.data(), kTokens, kHeads, kSeqLen); memcpy(dx.data(), dy.data(), kTokens * kHeads * kSeqLen * sizeof(dx.At(0))); MaskedSoftmaxVJPT(x.data(), dx.data(), kTokens, kHeads, kSeqLen); TestGradient(dx, c_x, func, 1e-14, 1e-15, __LINE__); } } TEST(BackPropTest, SoftcapVJP) { static const size_t N = 64; std::mt19937 gen(42); using T = double; using TC = std::complex; MatStorageT x("x", N, 1); MatStorageT dx("dx", N, 1); MatStorageT dy("dy", N, 1); MatStorageT c_x("c_x", N, 1); MatStorageT c_y("c_y", N, 1); constexpr float kCap = 30.0f; for (int iter = 0; iter < 10; ++iter) { RandInit(x, 1.0 * (1 << iter), gen); Complexify(x, c_x); RandInit(dy, 1.0, gen); auto func = [&]() { memcpy(c_y.data(), c_x.data(), N * sizeof(c_x.At(0))); Softcap(kCap, c_y.data(), N); return DotT(dy.data(), c_y.data(), N); }; Softcap(kCap, x.data(), N); memcpy(dx.data(), dy.data(), dx.SizeBytes()); SoftcapVJPT(kCap, x.data(), dx.data(), N); TestGradient(dx, c_x, func, 1e-15, 1e-14, __LINE__); } } TEST(BackPropTest, CrossEntropyLossGrad) { static const size_t K = 8; static const size_t V = 64; std::mt19937 gen(42); using T = double; using TC = std::complex; MatStorageT x("x", K, V); MatStorageT dx("dx", K, V); MatStorageT c_x("c_x", K, V); Prompt prompt; prompt.tokens = { 0, 1, 2, 3, 0, 3, 2, 1, 0 }; const float kCap = 30.0f; for (int iter = 0; iter < 10; 
TEST(BackPropTest, CrossEntropyLossGrad) {
  static const size_t K = 8;
  static const size_t V = 64;
  std::mt19937 gen(42);
  using T = double;
  using TC = std::complex<T>;
  MatStorageT<T> x("x", K, V);
  MatStorageT<T> dx("dx", K, V);
  MatStorageT<TC> c_x("c_x", K, V);
  Prompt prompt;
  prompt.tokens = { 0, 1, 2, 3, 0, 3, 2, 1, 0 };
  const float kCap = 30.0f;

  for (int iter = 0; iter < 10; ++iter) {
    prompt.context_size = 1 + (iter % 6);
    RandInit(x, 1.0 * (1 << iter), gen);
    Softcap(kCap, x.data(), V * K);
    Softmax(x.data(), V, K);
    CrossEntropyLossGrad(x.data(), dx.data(), prompt, V);
    Complexify(x, c_x);
    auto func = [&]() {
      return CrossEntropyLoss(c_x.data(), prompt, V);
    };
    TestGradient(dx, c_x, func, 1e-100, 1e-15, __LINE__);
  }
}

TEST(BackPropTest, GatedGeluVJP) {
  static const size_t K = 2;
  static const size_t N = 64;
  std::mt19937 gen(42);
  using T = double;
  using TC = std::complex<T>;
  MatStorageT<T> x("x", K, 2 * N);
  MatStorageT<T> dx("dx", K, 2 * N);
  MatStorageT<T> dy("dy", K, N);
  MatStorageT<TC> c_x("c_x", K, 2 * N);
  MatStorageT<TC> c_y("c_y", K, N);

  for (int iter = 0; iter < 10; ++iter) {
    RandInit(x, 1.0, gen);
    Complexify(x, c_x);
    RandInit(dy, 1.0, gen);
    auto func = [&]() {
      GatedGelu(c_x.data(), c_y.data(), N, K);
      return DotT(dy.data(), c_y.data(), N * K);
    };
    GatedGeluVJP(x.data(), dy.data(), dx.data(), N, K);
    TestGradient(dx, c_x, func, 1e-15, 1e-15, __LINE__);
  }
}

TEST(BackPropTest, MaskedAttentionVJP) {
  static const size_t kSeqLen = 16;
  static const size_t kHeads = 2;
  static const size_t kQKVDim = 8;
  static const size_t kTokens = 14;
  static const size_t kQKVSize = kSeqLen * (kHeads + 2) * kQKVDim;
  static const size_t kOutSize = kTokens * kHeads * kSeqLen;
  std::mt19937 gen(42);
  using T = double;
  using TC = std::complex<T>;
  MatStorageT<T> x("x", kQKVSize, 1);
  MatStorageT<T> dx("dx", kQKVSize, 1);
  MatStorageT<T> dy("dy", kOutSize, 1);
  MatStorageT<TC> c_x("c_x", kQKVSize, 1);
  MatStorageT<TC> c_y("c_y", kOutSize, 1);
  dx.ZeroInit();
  c_y.ZeroInit();

  for (int iter = 0; iter < 10; ++iter) {
    RandInit(x, 1.0, gen);
    Complexify(x, c_x);
    RandInit(dy, 1.0, gen);
    auto func = [&]() {
      MaskedAttention(c_x.data(), c_y.data(), kTokens, kHeads, kQKVDim,
                      kSeqLen);
      return DotT(dy.data(), c_y.data(), kOutSize);
    };
    MaskedAttentionVJP(x.data(), dy.data(), dx.data(), kTokens, kHeads,
                       kQKVDim, kSeqLen);
    TestGradient(dx, c_x, func, 1e-14, 1e-15, __LINE__);
  }
}

TEST(BackPropTest, MixByAttentionVJP) {
  static const size_t kSeqLen = 16;
  static const size_t kHeads = 2;
  static const size_t kQKVDim = 8;
  static const size_t kTokens = 14;
  static const size_t kQKVSize = kSeqLen * (kHeads + 2) * kQKVDim;
  static const size_t kAttnSize = kSeqLen * kHeads * kSeqLen;
  static const size_t kOutSize = kSeqLen * kHeads * kQKVDim;
  std::mt19937 gen(42);
  using T = double;
  using TC = std::complex<T>;
  MatStorageT<T> qkv("qkv", kQKVSize, 1);
  MatStorageT<T> dqkv("dqkv", kQKVSize, 1);
  MatStorageT<T> attn("attn", kAttnSize, 1);
  MatStorageT<T> dattn("dattn", kAttnSize, 1);
  MatStorageT<T> dy("dy", kOutSize, 1);
  MatStorageT<TC> c_qkv("c_qkv", kQKVSize, 1);
  MatStorageT<TC> c_attn("c_attn", kAttnSize, 1);
  MatStorageT<TC> c_y("c_y", kOutSize, 1);
  dqkv.ZeroInit();
  dattn.ZeroInit();
  c_y.ZeroInit();

  for (int iter = 0; iter < 10; ++iter) {
    RandInit(qkv, 1.0, gen);
    RandInit(attn, 1.0, gen);
    Complexify(qkv, c_qkv);
    Complexify(attn, c_attn);
    RandInit(dy, 1.0, gen);
    auto func = [&]() {
      MixByAttention(c_qkv.data(), c_attn.data(), c_y.data(), kTokens, kHeads,
                     kQKVDim, kSeqLen);
      return DotT(dy.data(), c_y.data(), kOutSize);
    };
    MixByAttentionVJP(qkv.data(), attn.data(), dy.data(), dqkv.data(),
                      dattn.data(), kTokens, kHeads, kQKVDim, kSeqLen);
    TestGradient(dqkv, c_qkv, func, 1e-14, 1e-15, __LINE__);
    TestGradient(dattn, c_attn, func, 1e-14, 1e-15, __LINE__);
  }
}
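// Checks InputEmbeddingVJPT against derivatives of InputEmbedding (embedding
// scale 3.0). Only the first tokens.size() - 1 positions enter the objective,
// matching num_tokens below.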
TEST(BackPropTest, InputEmbeddingVJP) {
  static const size_t kSeqLen = 8;
  static const size_t kVocabSize = 4;
  static const size_t kModelDim = 16;
  std::mt19937 gen(42);
  using T = double;
  using TC = std::complex<T>;
  MatStorageT<T> weights("weights", kVocabSize, kModelDim);
  MatStorageT<T> grad("grad", kVocabSize, kModelDim);
  MatStorageT<T> dy("dy", kSeqLen, kModelDim);
  MatStorageT<TC> c_weights("c_weights", kVocabSize, kModelDim);
  MatStorageT<TC> c_y("c_y", kSeqLen, kModelDim);
  std::vector<int> tokens = { 0, 1, 2, 3, 0, 1, 2 };
  size_t num_tokens = tokens.size() - 1;

  for (size_t iter = 0; iter < 10; ++iter) {
    RandInit(weights, 1.0, gen);
    RandInit(dy, 1.0, gen);
    Complexify(weights, c_weights);
    auto func = [&]() {
      InputEmbedding(c_weights.data(), tokens, TC(3.0), c_y.data(), kModelDim);
      return DotT(dy.data(), c_y.data(), num_tokens * kModelDim);
    };
    grad.ZeroInit();
    InputEmbeddingVJPT(weights.data(), tokens, 3.0, dy.data(), grad.data(),
                       kModelDim);
    TestGradient(grad, c_weights, func, 1e-16, 1e-14, __LINE__);
  }
}

static ModelConfig TestConfig() {
  ModelConfig config;
  config.scale_names = {"att_ein", "qkv_ein", "gr_lin_x_w", "gr_lin_y_w",
                        "gr_lin_out_w", "gr_gate_w", "gating_ein", "linear_w"};
  config.model_dim = 32;
  config.vocab_size = 12;
  config.seq_len = 18;
  LayerConfig layer_config = {
      .model_dim = config.model_dim,
      .ff_hidden_dim = 48,
      .heads = 3,
      .kv_heads = 1,
      .qkv_dim = 12,
  };
  config.layer_configs = {2, layer_config};
  config.num_tensor_scales = 4 * config.layer_configs.size();
  config.query_scale = QueryScaleType::SqrtKeySize;
  config.attention_window_sizes = FixedAttentionWindowSizes<2>(32);
  // This is required for optimize_test to pass.
  config.final_cap = 30.0f;
  return config;
}

TEST(BackPropTest, LayerVJP) {
  std::mt19937 gen(42);
  using T = double;
  using TC = std::complex<T>;
  ModelConfig config = TestConfig();
  const size_t kOutputSize = config.seq_len * config.model_dim;
  LayerWeightsPtrs<T> weights(config.layer_configs[0]);
  LayerWeightsPtrs<T> grad(config.layer_configs[0]);
  ForwardLayer<T> forward(config.layer_configs[0], config.seq_len);
  ForwardLayer<T> backward(config.layer_configs[0], config.seq_len);
  LayerWeightsPtrs<TC> c_weights(config.layer_configs[0]);
  ForwardLayer<TC> c_forward(config.layer_configs[0], config.seq_len);
  MatStorageT<T> y("y", kOutputSize, 1);
  MatStorageT<T> dy("dy", kOutputSize, 1);
  MatStorageT<TC> c_y("c_y", kOutputSize, 1);
  const size_t num_tokens = 3;
  std::vector<MatStorage> layer_storage;
  weights.Allocate(layer_storage);
  grad.Allocate(layer_storage);
  c_weights.Allocate(layer_storage);
  backward.input.ZeroInit();

  for (size_t iter = 0; iter < 10; ++iter) {
    RandInit(weights, 1.0, gen);
    RandInit(forward.input, 1.0, gen);
    RandInit(dy, 1.0, gen);
    Complexify(weights, c_weights);
    Complexify(forward.input, c_forward.input);
    auto func = [&]() {
      ApplyLayer(c_weights, c_forward, num_tokens, c_y.data());
      return DotT(dy.data(), c_y.data(), num_tokens * config.model_dim);
    };
    grad.ZeroInit(/*layer_idx=*/0);
    ApplyLayer(weights, forward, num_tokens, y.data());
    LayerVJP(weights, forward, dy.data(), grad, backward, num_tokens);
    TestGradient(backward.input, c_forward.input, func, 1e-11, 5e-11,
                 __LINE__);
    TestGradient(grad, c_weights, func, 1e-11);
  }
}
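// End-to-end check: the gradient produced by CrossEntropyLossBackwardPass for
// all model weights is compared against derivatives of
// CrossEntropyLossForwardPass on prompts sampled from the reverse-sequence
// task.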
TEST(BackPropTest, EndToEnd) {
  std::mt19937 gen(42);
  using T = double;
  using TC = std::complex<T>;
  ModelConfig config = TestConfig();
  WeightsWrapper<T> weights(config);
  WeightsWrapper<T> grad(config);
  ForwardPass<T> forward(config);
  ForwardPass<T> backward(config);
  WeightsWrapper<TC> c_weights(config);
  ForwardPass<TC> c_forward(config);
  ReverseSequenceSampler training_task({0, 0, 1, 1});
  std::vector<Prompt> batch = training_task.SampleBatch(3, gen);

  for (const Prompt& prompt : batch) {
    ReverseSequenceSampler::LogPrompt(prompt);
    RandInit(weights.get(), 1.0, gen);
    CrossEntropyLossForwardPass(prompt, weights.get(), forward);
    grad.ZeroInit();
    CrossEntropyLossBackwardPass(prompt, weights.get(), forward, grad.get(),
                                 backward);
    Complexify(weights.get(), c_weights.get());
    auto func = [&]() {
      return CrossEntropyLossForwardPass(prompt, c_weights.get(), c_forward);
    };
    TestGradient(grad.get(), c_weights.get(), func, 1e-11);
  }
}

template <typename T>
void MulByConstAndAddT(T c, const LayerWeightsPtrs<T>& x,
                       LayerWeightsPtrs<T>& out) {
  MulByConstAndAddT(c, x.pre_attention_norm_scale,
                    out.pre_attention_norm_scale);
  MulByConstAndAddT(c, x.attn_vec_einsum_w, out.attn_vec_einsum_w);
  MulByConstAndAddT(c, x.qkv_einsum_w, out.qkv_einsum_w);
  MulByConstAndAddT(c, x.pre_ffw_norm_scale, out.pre_ffw_norm_scale);
  MulByConstAndAddT(c, x.gating_einsum_w, out.gating_einsum_w);
  MulByConstAndAddT(c, x.linear_w, out.linear_w);
}

template <typename T>
void MulByConstAndAddT(T c, const ModelWeightsPtrs<T>& x,
                       ModelWeightsPtrs<T>& out) {
  const size_t layers = x.c_layers.size();
  MulByConstAndAddT(c, x.embedder_input_embedding,
                    out.embedder_input_embedding);
  MulByConstAndAddT(c, x.final_norm_scale, out.final_norm_scale);
  for (size_t i = 0; i < layers; ++i) {
    MulByConstAndAddT(c, *x.GetLayer(i), *out.GetLayer(i));
  }
}

// Evaluates forward pass on a batch.
template <typename T>
T CrossEntropyLossForwardPass(const std::vector<Prompt>& batch,
                              const WeightsWrapper<T>& weights,
                              ForwardPass<T>& forward) {
  T loss = 0.0;
  for (const Prompt& prompt : batch) {
    loss += CrossEntropyLossForwardPass(prompt, weights.get(), forward);
  }
  T scale = 1.0 / batch.size();
  return loss * scale;
}

// Evaluates the forward pass on a batch after applying the gradient with the
// given learning rate. Does not update `weights`; the updated weights are
// written to `tmp` instead.
template <typename T>
T CrossEntropyLossForwardPass(T learning_rate,
                              const std::vector<Prompt>& batch,
                              const WeightsWrapper<T>& weights,
                              const WeightsWrapper<T>& grad,
                              WeightsWrapper<T>& tmp,
                              ForwardPass<T>& forward) {
  tmp.CopyFrom(weights);
  const T scale = -learning_rate / batch.size();
  MulByConstAndAddT(scale, grad.get(), tmp.get());
  return CrossEntropyLossForwardPass(batch, tmp, forward);
}

// Updates weights via line search in the negative gradient direction. We do
// this so that we can verify that each gradient-descent step decreases the
// objective function value.
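// The search below first keeps halving the learning rate until the loss
// improves on the incoming `loss` and further halving no longer helps, then
// keeps doubling it while the loss continues to decrease, and finally applies
// the best step found to `weights`.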
template <typename T>
T FindOptimalUpdate(const WeightsWrapper<T>& grad, WeightsWrapper<T>& weights,
                    WeightsWrapper<T>& tmp, ForwardPass<T>& forward,
                    const std::vector<Prompt>& batch, T loss,
                    T initial_learning_rate) {
  T lr0 = initial_learning_rate;
  T loss0 = CrossEntropyLossForwardPass(lr0, batch, weights, grad, tmp,
                                        forward);
  for (size_t iter = 0; iter < 30; ++iter) {
    T lr1 = lr0 * 0.5;
    T loss1 = CrossEntropyLossForwardPass(lr1, batch, weights, grad, tmp,
                                          forward);
    if (loss0 < loss && loss1 >= loss0) {
      break;
    }
    loss0 = loss1;
    lr0 = lr1;
  }
  for (size_t iter = 0; iter < 30; ++iter) {
    T lr1 = lr0 * 2.0;
    T loss1 = CrossEntropyLossForwardPass(lr1, batch, weights, grad, tmp,
                                          forward);
    if (loss1 >= loss0) {
      break;
    }
    loss0 = loss1;
    lr0 = lr1;
  }
  const T scale = -lr0 / batch.size();
  MulByConstAndAddT(scale, grad.get(), weights.get());
  return lr0;
}

TEST(BackPropTest, Convergence) {
  std::mt19937 gen(42);
  using T = float;
  using TC = std::complex<T>;
  ModelConfig config = TestConfig();
  WeightsWrapper<T> weights(config);
  WeightsWrapper<T> grad(config);
  WeightsWrapper<T> tmp(config);
  ForwardPass<T> forward(config);
  ForwardPass<T> backward(config);
  WeightsWrapper<TC> c_weights(config);
  ForwardPass<TC> c_forward(config);
  constexpr size_t kBatchSize = 5;
  ReverseSequenceSampler training_task({0, 0, 0, 1, 1});
  T learning_rate = 0.01;

  RandInit(weights.get(), T(1.0), gen);
  printf("Sample batch:\n");
  for (size_t i = 0; i < 10; ++i) {
    ReverseSequenceSampler::LogPrompt(training_task.Sample(gen));
  }

  T prev_loss = std::numeric_limits<T>::max();
  bool stop = false;
  size_t step = 0;
  while (!stop) {
    T loss = 0.0;
    grad.ZeroInit();
    std::mt19937 sgen(42);
    std::vector<Prompt> batch = training_task.SampleBatch(kBatchSize, sgen);
    for (const Prompt& prompt : batch) {
      loss += CrossEntropyLossForwardPass(prompt, weights.get(), forward);
      CrossEntropyLossBackwardPass(prompt, weights.get(), forward, grad.get(),
                                   backward);
    }
    if (step % 250 == 0) {
      printf("Checking gradient...\n");
      Complexify(weights.get(), c_weights.get());
      auto func = [&]() {
        TC scale = batch.size();
        return CrossEntropyLossForwardPass(batch, c_weights, c_forward) *
               scale;
      };
      TestGradient(grad.get(), c_weights.get(), func, 5e-3f);
    }
    loss /= batch.size();
    EXPECT_LT(loss, prev_loss);
    stop = step >= 10000 || loss < 1e-2;
    if (step % 10 == 0 || stop) {
      printf("step: %5zu loss: %.15f learning_rate: %.15f\n", step, loss,
             learning_rate);
    }
    if (!stop) {
      learning_rate = FindOptimalUpdate(grad, weights, tmp, forward, batch,
                                        loss, learning_rate);
      ++step;
    }
    prev_loss = loss;
  }
  EXPECT_LT(step, 1000);
}

}  // namespace gcpp