Add config for att/final cap, skip max-subtract. Fixes #278

Also update includes/deps for backprop/.

PiperOrigin-RevId: 648399222
Jan Wassenberg 2024-07-01 09:44:41 -07:00 committed by Copybara-Service
parent da7507e6f0
commit e588a7f45d
13 changed files with 94 additions and 100 deletions
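For reference, the soft-cap made configurable here is y = cap * tanh(x / cap), and its VJP follows from dy/dx = 1 - tanh^2(x / cap) = 1 - (y / cap)^2, which is why the backward pass below needs only the forward output. A minimal scalar sketch of both, mirroring the Softcap()/SoftcapVJPT() changes in the diffs (a standalone illustration with hypothetical names, not the library's SIMD path):

#include <cmath>
#include <cstddef>

// Forward: x[i] <- cap * tanh(x[i] / cap). The max-subtraction is gone, so
// tanh is applied to the raw value with no cross-element dependency; large
// inputs saturate toward +/-cap.
void SoftcapSketch(float cap, float* x, size_t n) {
  const float inv_cap = 1.0f / cap;
  for (size_t i = 0; i < n; ++i) x[i] = cap * std::tanh(x[i] * inv_cap);
}

// Backward: given the forward output y, dy/dx = 1 - (y / cap)^2, so the
// incoming gradient can be rescaled in place from y alone.
void SoftcapVJPSketch(float cap, const float* y, float* dy, size_t n) {
  const float inv_cap = 1.0f / cap;
  for (size_t i = 0; i < n; ++i) {
    const float scaled = y[i] * inv_cap;  // = tanh(x / cap)
    dy[i] *= 1.0f - scaled * scaled;
  }
}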

View File

@@ -361,6 +361,8 @@ cc_test(
     ],
     deps = [
         ":backprop_scalar",
+        ":common",
+        ":gemma_lib",
         ":prompt",
         ":sampler",
         ":weights_raw",
@@ -382,6 +384,7 @@ cc_test(
     deps = [
         ":backprop",
         ":backprop_scalar",
+        ":common",
        ":gemma_lib",
        ":ops",
        ":sampler",

View File

@@ -307,9 +307,9 @@ void LayerVJP(const LayerT<TConfig>& weights,
   }
 }
 
-static HWY_NOINLINE void SoftcapVJP(const float* HWY_RESTRICT forward,
+static HWY_NOINLINE void SoftcapVJP(const float cap,
+                                    const float* HWY_RESTRICT forward,
                                     float* HWY_RESTRICT backward,
-                                    const float cap,
                                     const size_t size) {
   namespace hn = hwy::HWY_NAMESPACE;
   using D = hn::ScalableTag<float>;
@@ -318,25 +318,11 @@ static HWY_NOINLINE void SoftcapVJP(const float* HWY_RESTRICT forward,
   const auto one = hn::Set(d, 1.0f);
   const auto vcap = hn::Set(d, cap);
   const auto vinv_cap = hn::Div(hn::Set(d, 1.0f), vcap);
-  // TODO(szabadka): Investigate what to do when the argmax is not unique.
-  // TODO(szabadka): Use IndexOfMax from hwy when it is available.
-  size_t imax = std::max_element(forward, forward + size) - forward;
-  hn::Transform1(
-      d, backward, size, forward,
+  hn::Transform1(d, backward, size, forward,
       [&](const auto d, const auto v, const auto y) HWY_ATTR {
-        const auto scaled = hn::Mul(vinv_cap, y);
+        const auto scaled = hn::Mul(vinv_cap, y);  // = tanh
         return hn::Mul(v, hn::Sub(one, hn::Mul(scaled, scaled)));
       });
-  backward[imax] = 0;
-
-  auto sum = hn::Zero(d);
-  Foreach(d, backward, size, sum,
-          [&sum](const auto d, const auto value) HWY_ATTR {
-            sum = hn::Add(sum, value);
-          });
-  backward[imax] = -hn::ReduceSum(d, sum);
 }
 
 static HWY_NOINLINE void CrossEntropyLossGrad(
@@ -385,9 +371,11 @@ void CrossEntropyLossBackwardPass(const Prompt& prompt,
                kVocabSize);
   }
 
+  if constexpr (TConfig::kFinalCap > 0.0f) {
     for (size_t pos = 0; pos < num_tokens; ++pos) {
-      SoftcapVJP(forward.logits.data() + pos * kVocabSize,
-                 backward.logits.data() + pos * kVocabSize, 30.0f, kVocabSize);
+      SoftcapVJP(TConfig::kFinalCap, forward.logits.data() + pos * kVocabSize,
+                 backward.logits.data() + pos * kVocabSize, kVocabSize);
+    }
   }
 
   MatMulVJP<kModelDim, kVocabSize>(

View File

@@ -275,19 +275,11 @@ void LayerVJP(const Layer<T, TConfig>& weights,
   }
 }
 
 template <typename T>
-void SoftcapVJPT(const T* y, T* dy, size_t N) {
-  size_t imax = std::max_element(y, y + N) - y;
-  T cap = 30.0;
-  T inv_cap = T(1.0) / cap;
+void SoftcapVJPT(float cap, const T* y, T* dy, size_t N) {
+  const T inv_cap = T{1.0} / static_cast<T>(cap);
   for (size_t i = 0; i < N; ++i) {
-    T scaled = y[i] * inv_cap;
-    dy[i] *= (T(1.0) - scaled * scaled);
-  }
-  dy[imax] = T(0.0);
-  for (size_t i = 0; i < N; ++i) {
-    if (i != imax) {
-      dy[imax] -= dy[i];
-    }
+    T scaled = y[i] * inv_cap;  // tanh
+    dy[i] *= (T{1.0} - scaled * scaled);
   }
 }
@@ -324,10 +316,11 @@ void CrossEntropyLossBackwardPass(const Prompt& prompt,
   SoftmaxVJPT(forward.probs.data(), backward.logits.data(),
               kVocabSize, num_tokens);
 
+  if constexpr (TConfig::kFinalCap > 0.0f) {
     for (size_t i = 0; i < num_tokens; ++i) {
-      SoftcapVJPT(forward.logits.data() + i * kVocabSize,
-                  backward.logits.data() + i * kVocabSize,
-                  kVocabSize);
+      SoftcapVJPT(TConfig::kFinalCap, forward.logits.data() + i * kVocabSize,
+                  backward.logits.data() + i * kVocabSize, kVocabSize);
+    }
   }
 
   MatMulVJPT(weights.embedder_input_embedding.data(),

View File

@@ -16,16 +16,23 @@
 #include "backprop/backward_scalar.h"
 
 #include <stddef.h>
+#include <stdio.h>
 #include <string.h>  // memset
 
 #include <array>
 #include <complex>
+#include <limits>
 #include <random>
+#include <vector>
 
 #include "gtest/gtest.h"
+#include "backprop/common_scalar.h"
 #include "backprop/forward_scalar.h"
+#include "backprop/prompt.h"
 #include "backprop/sampler.h"
 #include "backprop/test_util.h"
+#include "gemma/activations.h"
+#include "gemma/configs.h"
 #include "gemma/weights_raw.h"
 
 namespace gcpp {
@@ -202,18 +209,19 @@ TEST(BackPropTest, SoftcapVJP) {
   std::array<TC, N> c_x;
   std::array<TC, N> c_y;
+  constexpr float kCap = 30.0f;
   for (int iter = 0; iter < 10; ++iter) {
     RandInit(x, 1.0 * (1 << iter), gen);
     Complexify(x, c_x);
     RandInit(dy, 1.0, gen);
     auto func = [&]() {
       memcpy(c_y.data(), c_x.data(), N * sizeof(c_x[0]));
-      Softcap(c_y.data(), N);
+      Softcap(kCap, c_y.data(), N);
       return DotT(dy.data(), c_y.data(), N);
     };
-    Softcap(x.data(), N);
+    Softcap(kCap, x.data(), N);
     memcpy(dx.data(), dy.data(), N * sizeof(dx[0]));
-    SoftcapVJPT(x.data(), dx.data(), N);
+    SoftcapVJPT(kCap, x.data(), dx.data(), N);
     TestGradient(dx, c_x, func, 1e-15, 1e-14, __LINE__);
   }
 }
@@ -230,10 +238,11 @@ TEST(BackPropTest, CrossEntropyLossGrad) {
   Prompt prompt;
   prompt.tokens = { 0, 1, 2, 3, 0, 3, 2, 1, 0 };
+  const float kCap = 30.0f;
   for (int iter = 0; iter < 10; ++iter) {
     prompt.context_size = 1 + (iter % 6);
     RandInit(x, 1.0 * (1 << iter), gen);
-    Softcap(x.data(), V * K);
+    Softcap(kCap, x.data(), V * K);
     Softmax(x.data(), V, K);
     CrossEntropyLossGrad(x.data(), dx.data(), prompt, V);
     Complexify(x, c_x);
@@ -368,7 +377,7 @@ TEST(BackPropTest, InputEmbeddingVJP) {
   }
 }
 
-struct TestConfig {
+struct TestConfig : ConfigCapNoSSM {
   static constexpr int kSeqLen = 18;
   static constexpr int kVocabSize = 12;
   static constexpr int kModelDim = 32;
@@ -382,12 +391,7 @@ struct TestConfig {
   static constexpr bool kPostNormScale = false;
   static constexpr int kKVHeads = 1;
-  static constexpr int kConv1dWidth = 0;
-  static constexpr bool kFFBiases = false;
-  static constexpr bool kSoftmaxAttnOutputBiases = false;
   static constexpr int kGemmaLayers = kLayers;
-  static constexpr int kGriffinLayers = 0;
-  static constexpr int kNumTensorScales = 0;
 };
 
 TEST(BackPropTest, LayerVJP) {
TEST(BackPropTest, LayerVJP) { TEST(BackPropTest, LayerVJP) {

View File

@@ -25,10 +25,12 @@
 #include <vector>
 
 #include "backprop/backward_scalar.h"
+#include "backprop/common_scalar.h"
 #include "backprop/forward_scalar.h"
 #include "backprop/sampler.h"
 #include "backprop/test_util.h"
 #include "gemma/activations.h"
+#include "gemma/configs.h"
 #include "gemma/weights_raw.h"
 #include "hwy/base.h"
 #include "hwy/contrib/thread_pool/thread_pool.h"
@@ -180,7 +182,7 @@ void TestRMSNormVJP() {
   }
 }
 
-struct TestConfig {
+struct TestConfig : public ConfigCapNoSSM {
   static constexpr int kSeqLen = 24;
   static constexpr int kVocabSize = 16;
   static constexpr int kModelDim = 32;
@@ -194,12 +196,7 @@ struct TestConfig {
   static constexpr bool kPostNormScale = false;
   static constexpr int kKVHeads = 1;
-  static constexpr int kConv1dWidth = 0;
-  static constexpr bool kFFBiases = false;
-  static constexpr bool kSoftmaxAttnOutputBiases = false;
   static constexpr int kGemmaLayers = kLayers;
-  static constexpr int kGriffinLayers = 0;
-  static constexpr int kNumTensorScales = 0;
 };
 
 void TestEndToEnd() {

View File

@@ -266,8 +266,11 @@ float CrossEntropyLossForwardPass(const std::vector<int>& prompt,
                forward.logits.data() + pos * kVocabSize, pool);
   }
 
+  if constexpr (TConfig::kFinalCap > 0.0f) {
     for (size_t pos = 0; pos < num_tokens; ++pos) {
-      LogitsSoftCap(30.0f, forward.logits.data() + pos * kVocabSize, kVocabSize);
+      LogitsSoftCap(TConfig::kFinalCap,
+                    forward.logits.data() + pos * kVocabSize, kVocabSize);
+    }
   }
 
   hwy::CopyBytes(forward.logits.data(), forward.probs.data(),

View File

@@ -16,8 +16,6 @@
 #ifndef THIRD_PARTY_GEMMA_CPP_GEMMA_FORWARD_H_
 #define THIRD_PARTY_GEMMA_CPP_GEMMA_FORWARD_H_
 
-#include <vector>
-
 #include "backprop/prompt.h"
 #include "gemma/common.h"
 #include "hwy/contrib/thread_pool/thread_pool.h"

View File

@@ -94,20 +94,10 @@ void Softmax(T* x, size_t N, size_t K) {
   }
 }
 
 template <typename T>
-void Softcap(T* x, size_t N) {
-  auto maxreal = std::real(x[0]);
-  size_t imax = 0;
-  for (size_t i = 1; i < N; ++i) {
-    if (std::real(x[i]) > maxreal) {
-      maxreal = std::real(x[i]);
-      imax = i;
-    }
-  }
-  T cap = 30.0;
-  T inv_cap = T(1.0) / cap;
-  T xmax = x[imax];
+void Softcap(float cap, T* x, size_t N) {
+  const T inv_cap = T{1.0} / static_cast<T>(cap);
   for (size_t i = 0; i < N; ++i) {
-    x[i] = cap * std::tanh((x[i] - xmax) * inv_cap);
+    x[i] = static_cast<T>(cap) * std::tanh(x[i] * inv_cap);
   }
 }
@@ -285,7 +275,10 @@ T CrossEntropyLossForwardPass(const Prompt& prompt,
              forward.logits.data(), kVocabSize, kModelDim, num_tokens);
 
   for (size_t pos = 0; pos < num_tokens; ++pos) {
-    Softcap(forward.logits.data() + pos * kVocabSize, kVocabSize);
+    if constexpr (TConfig::kFinalCap > 0.0f) {
+      Softcap(TConfig::kFinalCap, forward.logits.data() + pos * kVocabSize,
+              kVocabSize);
+    }
   }
 
   memcpy(forward.probs.data(), forward.logits.data(),

View File

@@ -16,6 +16,9 @@
 #ifndef THIRD_PARTY_GEMMA_CPP_GEMMA_SAMPLER_H_
 #define THIRD_PARTY_GEMMA_CPP_GEMMA_SAMPLER_H_
 
+#include <stddef.h>
+#include <stdio.h>
+
 #include <random>
 #include <vector>

View File

@@ -19,6 +19,7 @@
 #include <stddef.h>
 
 #include <array>
+#include <cmath>
 #include <complex>
 
 #include "gtest/gtest.h"

View File

@@ -99,8 +99,19 @@ struct ConfigNoSSM {
   static constexpr int kNumTensorScales = 0;
 };
 
+struct ConfigNoCapNoSSM : ConfigNoSSM {
+  static constexpr float kAttCap = 0.0f;
+  static constexpr float kFinalCap = 0.0f;
+};
+
+// For Gemma2 with SoftCap
+struct ConfigCapNoSSM : ConfigNoSSM {
+  static constexpr float kAttCap = 50.0f;
+  static constexpr float kFinalCap = 30.0f;
+};
+
 template <typename TWeight>
-struct ConfigGemma27B : public ConfigNoSSM {
+struct ConfigGemma27B : public ConfigCapNoSSM {
   using Weight = TWeight;  // make accessible where we only have a TConfig
 
   static constexpr int kSeqLen = gcpp::kSeqLen;
@@ -120,7 +131,7 @@ struct ConfigGemma27B : public ConfigNoSSM {
 };
 
 template <typename TWeight>
-struct ConfigGemma9B : public ConfigNoSSM {
+struct ConfigGemma9B : public ConfigCapNoSSM {
   using Weight = TWeight;  // make accessible where we only have a TConfig
 
   static constexpr int kSeqLen = gcpp::kSeqLen;
@@ -140,7 +151,7 @@ struct ConfigGemma9B : public ConfigNoSSM {
 };
 
 template <typename TWeight>
-struct ConfigGemma7B : public ConfigNoSSM {
+struct ConfigGemma7B : public ConfigNoCapNoSSM {
   using Weight = TWeight;  // make accessible where we only have a TConfig
 
   static constexpr int kSeqLen = gcpp::kSeqLen;
@@ -160,7 +171,7 @@ struct ConfigGemma7B : public ConfigNoSSM {
 };
 
 template <typename TWeight>
-struct ConfigGemma2B : public ConfigNoSSM {
+struct ConfigGemma2B : public ConfigNoCapNoSSM {
   using Weight = TWeight;  // make accessible where we only have a TConfig
 
   static constexpr int kSeqLen = gcpp::kSeqLen;
@@ -197,6 +208,10 @@ struct ConfigGemmaTiny : public ConfigNoSSM {
   static constexpr int kTopK = gcpp::kTopK;
   static constexpr bool kAbsolutePE = false;
   static constexpr bool kPostNormScale = false;
+
+  static constexpr float kAttCap = 0.0f;
+  // This is required for optimize_test to pass.
+  static constexpr float kFinalCap = 30.0f;
 };
 
 template <typename TWeight>
@@ -251,6 +266,10 @@ struct ConfigGriffin2B {
   static constexpr bool kAbsolutePE = false;
   static constexpr bool kPostNormScale = false;
 
+  // No SoftCap.
+  static constexpr float kAttCap = 0.0f;
+  static constexpr float kFinalCap = 0.0f;
+
   // SSM config.
   static constexpr int kConv1dWidth = 4;
   static constexpr bool kFFBiases = true;
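
Because kAttCap/kFinalCap are constexpr members of each config, call sites guard the capping with if constexpr, so no-cap models compile to exactly the uncapped code. A minimal sketch of that dispatch pattern (HypotheticalNoCap/HypotheticalCap30 and MaybeCapLogits are illustrative names, not from this repo):

#include <cmath>
#include <cstddef>

struct HypotheticalNoCap { static constexpr float kFinalCap = 0.0f; };
struct HypotheticalCap30 { static constexpr float kFinalCap = 30.0f; };

template <class TConfig>
void MaybeCapLogits(float* logits, size_t n) {
  // Resolved at compile time: for kFinalCap == 0 the body is discarded
  // entirely, so uncapped models pay no runtime cost.
  if constexpr (TConfig::kFinalCap > 0.0f) {
    const float inv_cap = 1.0f / TConfig::kFinalCap;
    for (size_t i = 0; i < n; ++i) {
      logits[i] = TConfig::kFinalCap * std::tanh(logits[i] * inv_cap);
    }
  }
}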

View File

@@ -484,7 +484,11 @@ HWY_NOINLINE void Attention(
       const float score = Dot(q, k2, kQKVDim);
       head_att[pos2 % kSeqLen] = score;
     }
-    Softmax(head_att, std::min(pos + 1, kSeqLen));
+    const size_t head_att_len = std::min(pos + 1, kSeqLen);
+    if constexpr (TConfig::kAttCap > 0.0f) {
+      LogitsSoftCap(TConfig::kAttCap, head_att, head_att_len);
+    }
+    Softmax(head_att, head_att_len);
 
     // Weighted summation
     float* HWY_RESTRICT att_out = activations.att_out.data() + head * kQKVDim +
@@ -979,7 +983,10 @@ void GenerateT(const ByteStorageT& weights_u8, const ByteStorageT& prefill_u8,
         MatVec<kVocabSize, TConfig::kModelDim>(
             weights.embedder_input_embedding, 0, x, activations.even_odd.data(),
             logits, pool);
-        LogitsSoftCap(30.0f, logits, kVocabSize);
+        if constexpr (TConfig::kFinalCap > 0.0f) {
+          LogitsSoftCap(TConfig::kFinalCap, activations.logits.data(),
+                        kVocabSize);
+        }
         // Barrier: must have all logits so we can subtract max.
         Softmax(logits, kVocabSize);
         token = sample_token(logits, kVocabSize);
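
In the attention path above, the cap is applied to the raw scores before the softmax normalizes them. A scalar sketch of that ordering, simplified to one head with no kv-cache wraparound (CapThenSoftmax is a hypothetical helper, not the repo's SIMD code):

#include <cmath>
#include <cstddef>

// Cap attention scores with cap * tanh(s / cap), then softmax in place.
void CapThenSoftmax(float att_cap, float* scores, size_t len) {
  if (att_cap > 0.0f) {
    const float inv_cap = 1.0f / att_cap;
    for (size_t i = 0; i < len; ++i) {
      scores[i] = att_cap * std::tanh(scores[i] * inv_cap);
    }
  }
  // Standard max-subtracted softmax for numerical stability.
  float max_score = scores[0];
  for (size_t i = 1; i < len; ++i) max_score = std::fmax(max_score, scores[i]);
  float sum = 0.0f;
  for (size_t i = 0; i < len; ++i) {
    scores[i] = std::exp(scores[i] - max_score);
    sum += scores[i];
  }
  const float inv_sum = 1.0f / sum;
  for (size_t i = 0; i < len; ++i) scores[i] *= inv_sum;
}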

View File

@@ -1773,23 +1773,8 @@ static HWY_NOINLINE void LogitsSoftCap(const float cap, float* HWY_RESTRICT x,
   const V vcap = hn::Set(d, cap);
   const V vinv_cap = hn::Div(hn::Set(d, 1.0f), vcap);
 
-  // If we do not subtract the max as in softmax, values > 100 (which do occur)
-  // will all saturate to the cap, and then this function is no longer
-  // monotonic, which would change the results on TopK.
-  const V vmin = hn::Set(d, hwy::LowestValue<float>());
-  V vmax = vmin;
-  Foreach(d, x, max_pos, vmin,
-          [&vmax](const auto d, const auto value)
-              HWY_ATTR { vmax = hn::Max(vmax, value); });
-  vmax = hn::MaxOfLanes(d, vmax);
-
-  // We want (v-vmax) * vinv_cap. To take advantage of FMA, multiply this out to
-  // v * vinv_cap + (-vmax*vinv_cap).
-  const V add = hn::Neg(hn::Mul(vmax, vinv_cap));
-  hn::Transform(
-      d, x, size, [&vcap, &vinv_cap, &add](D d, hn::Vec<D> v) HWY_ATTR {
-        return hn::Mul(vcap, hn::Tanh(d, hn::MulAdd(v, vinv_cap, add)));
-      });
+  hn::Transform(d, x, max_pos, [&vcap, &vinv_cap](D d, hn::Vec<D> v) HWY_ATTR {
+    return hn::Mul(vcap, hn::Tanh(d, hn::Mul(v, vinv_cap)));
+  });
 }