Add config constants for the attention/final logit caps; skip the max-subtraction. Fixes #278

Also update includes/deps for backprop/.

PiperOrigin-RevId: 648399222
Jan Wassenberg 2024-07-01 09:44:41 -07:00 committed by Copybara-Service
parent da7507e6f0
commit e588a7f45d
13 changed files with 94 additions and 100 deletions
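For orientation, here is a minimal standalone sketch of the behavioral change to the soft-cap itself, distilled from the diffs below. SoftCapOld and SoftCapNew are illustrative names, not functions in the repository; the tanh formula and the kAttCap/kFinalCap guard come from ops.h, gemma.cc and configs.h in this change.

#include <algorithm>
#include <cmath>
#include <cstddef>
#include <cstdio>

// Before this change: the soft-cap subtracted the running max,
// i.e. cap * tanh((x - max) / cap), so the largest logit mapped to 0.
static void SoftCapOld(float cap, float* x, size_t n) {
  const float x_max = *std::max_element(x, x + n);
  for (size_t i = 0; i < n; ++i) x[i] = cap * std::tanh((x[i] - x_max) / cap);
}

// After this change: plain cap * tanh(x / cap). Call sites additionally guard
// on the new per-config constants (kAttCap, kFinalCap) and skip the call when
// the cap is zero.
static void SoftCapNew(float cap, float* x, size_t n) {
  for (size_t i = 0; i < n; ++i) x[i] = cap * std::tanh(x[i] / cap);
}

int main() {
  float a[3] = {10.0f, 60.0f, 120.0f};
  float b[3] = {10.0f, 60.0f, 120.0f};
  SoftCapOld(30.0f, a, 3);
  SoftCapNew(30.0f, b, 3);
  std::printf("old: %.3f %.3f %.3f\nnew: %.3f %.3f %.3f\n",
              a[0], a[1], a[2], b[0], b[1], b[2]);
  return 0;
}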

View File

@@ -361,6 +361,8 @@ cc_test(
],
deps = [
":backprop_scalar",
":common",
":gemma_lib",
":prompt",
":sampler",
":weights_raw",
@@ -382,6 +384,7 @@ cc_test(
deps = [
":backprop",
":backprop_scalar",
":common",
":gemma_lib",
":ops",
":sampler",

View File

@@ -307,9 +307,9 @@ void LayerVJP(const LayerT<TConfig>& weights,
}
}
static HWY_NOINLINE void SoftcapVJP(const float* HWY_RESTRICT forward,
static HWY_NOINLINE void SoftcapVJP(const float cap,
const float* HWY_RESTRICT forward,
float* HWY_RESTRICT backward,
const float cap,
const size_t size) {
namespace hn = hwy::HWY_NAMESPACE;
using D = hn::ScalableTag<float>;
@@ -318,25 +318,11 @@ static HWY_NOINLINE void SoftcapVJP(const float* HWY_RESTRICT forward,
const auto one = hn::Set(d, 1.0f);
const auto vcap = hn::Set(d, cap);
const auto vinv_cap = hn::Div(hn::Set(d, 1.0f), vcap);
// TODO(szabadka): Investigate what to do when the argmax is not unique.
// TODO(szabadka): Use IndexOfMax from hwy when it is available.
size_t imax = std::max_element(forward, forward + size) - forward;
hn::Transform1(
d, backward, size, forward,
[&](const auto d, const auto v, const auto y) HWY_ATTR {
const auto scaled = hn::Mul(vinv_cap, y);
return hn::Mul(v, hn::Sub(one, hn::Mul(scaled, scaled)));
});
backward[imax] = 0;
auto sum = hn::Zero(d);
Foreach(d, backward, size, sum,
[&sum](const auto d, const auto value) HWY_ATTR {
sum = hn::Add(sum, value);
});
backward[imax] = -hn::ReduceSum(d, sum);
hn::Transform1(d, backward, size, forward,
[&](const auto d, const auto v, const auto y) HWY_ATTR {
const auto scaled = hn::Mul(vinv_cap, y); // = tanh
return hn::Mul(v, hn::Sub(one, hn::Mul(scaled, scaled)));
});
}
static HWY_NOINLINE void CrossEntropyLossGrad(
@@ -385,9 +371,11 @@ void CrossEntropyLossBackwardPass(const Prompt& prompt,
kVocabSize);
}
for (size_t pos = 0; pos < num_tokens; ++pos) {
SoftcapVJP(forward.logits.data() + pos * kVocabSize,
backward.logits.data() + pos * kVocabSize, 30.0f, kVocabSize);
if constexpr (TConfig::kFinalCap > 0.0f) {
for (size_t pos = 0; pos < num_tokens; ++pos) {
SoftcapVJP(TConfig::kFinalCap, forward.logits.data() + pos * kVocabSize,
backward.logits.data() + pos * kVocabSize, kVocabSize);
}
}
MatMulVJP<kModelDim, kVocabSize>(

View File

@@ -274,20 +274,12 @@ void LayerVJP(const Layer<T, TConfig>& weights,
num_tokens * kModelDim);
}
template<typename T>
void SoftcapVJPT(const T* y, T* dy, size_t N) {
size_t imax = std::max_element(y, y + N) - y;
T cap = 30.0;
T inv_cap = T(1.0) / cap;
template <typename T>
void SoftcapVJPT(float cap, const T* y, T* dy, size_t N) {
const T inv_cap = T{1.0} / static_cast<T>(cap);
for (size_t i = 0; i < N; ++i) {
T scaled = y[i] * inv_cap;
dy[i] *= (T(1.0) - scaled * scaled);
}
dy[imax] = T(0.0);
for (size_t i = 0; i < N; ++i) {
if (i != imax) {
dy[imax] -= dy[i];
}
T scaled = y[i] * inv_cap; // tanh
dy[i] *= (T{1.0} - scaled * scaled);
}
}
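A quick editorial check of the factor used above (not part of the commit): with y = cap * tanh(x / cap), the chain rule gives dy/dx = 1 - tanh(x/cap)^2 = 1 - (y/cap)^2, which is exactly what SoftcapVJPT multiplies into the incoming gradient. The old argmax/sum bookkeeping appears to have existed only to differentiate through the max-subtraction, which the forward pass no longer performs.

#include <cmath>
#include <cstdio>

// Standalone sanity check of the SoftcapVJPT factor against a central finite
// difference (illustrative, not part of the repository's tests).
int main() {
  const double cap = 30.0, x = 2.5, eps = 1e-6;
  const double y = cap * std::tanh(x / cap);
  const double analytic = 1.0 - (y / cap) * (y / cap);
  const double numeric =
      (cap * std::tanh((x + eps) / cap) - cap * std::tanh((x - eps) / cap)) /
      (2.0 * eps);
  std::printf("analytic %.9f  central difference %.9f\n", analytic, numeric);
  return 0;
}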
@@ -324,10 +316,11 @@ void CrossEntropyLossBackwardPass(const Prompt& prompt,
SoftmaxVJPT(forward.probs.data(), backward.logits.data(),
kVocabSize, num_tokens);
for (size_t i = 0; i < num_tokens; ++i) {
SoftcapVJPT(forward.logits.data() + i * kVocabSize,
backward.logits.data() + i * kVocabSize,
kVocabSize);
if constexpr (TConfig::kFinalCap > 0.0f) {
for (size_t i = 0; i < num_tokens; ++i) {
SoftcapVJPT(TConfig::kFinalCap, forward.logits.data() + i * kVocabSize,
backward.logits.data() + i * kVocabSize, kVocabSize);
}
}
MatMulVJPT(weights.embedder_input_embedding.data(),

View File

@@ -16,16 +16,23 @@
#include "backprop/backward_scalar.h"
#include <stddef.h>
#include <stdio.h>
#include <string.h> // memset
#include <array>
#include <complex>
#include <limits>
#include <random>
#include <vector>
#include "gtest/gtest.h"
#include "backprop/common_scalar.h"
#include "backprop/forward_scalar.h"
#include "backprop/prompt.h"
#include "backprop/sampler.h"
#include "backprop/test_util.h"
#include "gemma/activations.h"
#include "gemma/configs.h"
#include "gemma/weights_raw.h"
namespace gcpp {
@@ -202,18 +209,19 @@ TEST(BackPropTest, SoftcapVJP) {
std::array<TC, N> c_x;
std::array<TC, N> c_y;
constexpr float kCap = 30.0f;
for (int iter = 0; iter < 10; ++iter) {
RandInit(x, 1.0 * (1 << iter), gen);
Complexify(x, c_x);
RandInit(dy, 1.0, gen);
auto func = [&]() {
memcpy(c_y.data(), c_x.data(), N * sizeof(c_x[0]));
Softcap(c_y.data(), N);
Softcap(kCap, c_y.data(), N);
return DotT(dy.data(), c_y.data(), N);
};
Softcap(x.data(), N);
Softcap(kCap, x.data(), N);
memcpy(dx.data(), dy.data(), N * sizeof(dx[0]));
SoftcapVJPT(x.data(), dx.data(), N);
SoftcapVJPT(kCap, x.data(), dx.data(), N);
TestGradient(dx, c_x, func, 1e-15, 1e-14, __LINE__);
}
}
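The tight 1e-15/1e-14 tolerances above are achievable because the test complexifies the inputs and compares the hand-written VJP against derivatives taken through the complex-valued forward pass. As a rough standalone illustration of that idea, here is complex-step differentiation; this is an assumption about how the harness works, not a copy of test_util.h:

#include <cmath>
#include <complex>
#include <cstdio>

// Complex-step differentiation: for an analytic f, Im(f(x + i*h)) / h equals
// df/dx to machine precision, with no subtractive cancellation, which is what
// makes tolerances on the order of 1e-15 attainable.
int main() {
  const double cap = 30.0, x = 2.5, h = 1e-30;
  const std::complex<double> y =
      cap * std::tanh(std::complex<double>(x, h) / cap);
  const double d_complex_step = y.imag() / h;
  const double d_closed_form = 1.0 - std::pow(std::tanh(x / cap), 2.0);
  std::printf("complex-step %.15f  closed-form %.15f\n", d_complex_step,
              d_closed_form);
  return 0;
}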
@@ -230,10 +238,11 @@ TEST(BackPropTest, CrossEntropyLossGrad) {
Prompt prompt;
prompt.tokens = { 0, 1, 2, 3, 0, 3, 2, 1, 0 };
const float kCap = 30.0f;
for (int iter = 0; iter < 10; ++iter) {
prompt.context_size = 1 + (iter % 6);
RandInit(x, 1.0 * (1 << iter), gen);
Softcap(x.data(), V * K);
Softcap(kCap, x.data(), V * K);
Softmax(x.data(), V, K);
CrossEntropyLossGrad(x.data(), dx.data(), prompt, V);
Complexify(x, c_x);
@@ -368,7 +377,7 @@ TEST(BackPropTest, InputEmbeddingVJP) {
}
}
struct TestConfig {
struct TestConfig : ConfigCapNoSSM {
static constexpr int kSeqLen = 18;
static constexpr int kVocabSize = 12;
static constexpr int kModelDim = 32;
@@ -382,12 +391,7 @@ struct TestConfig {
static constexpr bool kPostNormScale = false;
static constexpr int kKVHeads = 1;
static constexpr int kConv1dWidth = 0;
static constexpr bool kFFBiases = false;
static constexpr bool kSoftmaxAttnOutputBiases = false;
static constexpr int kGemmaLayers = kLayers;
static constexpr int kGriffinLayers = 0;
static constexpr int kNumTensorScales = 0;
};
TEST(BackPropTest, LayerVJP) {

View File

@@ -25,10 +25,12 @@
#include <vector>
#include "backprop/backward_scalar.h"
#include "backprop/common_scalar.h"
#include "backprop/forward_scalar.h"
#include "backprop/sampler.h"
#include "backprop/test_util.h"
#include "gemma/activations.h"
#include "gemma/configs.h"
#include "gemma/weights_raw.h"
#include "hwy/base.h"
#include "hwy/contrib/thread_pool/thread_pool.h"
@@ -180,7 +182,7 @@ void TestRMSNormVJP() {
}
}
struct TestConfig {
struct TestConfig : public ConfigCapNoSSM {
static constexpr int kSeqLen = 24;
static constexpr int kVocabSize = 16;
static constexpr int kModelDim = 32;
@@ -194,12 +196,7 @@ struct TestConfig {
static constexpr bool kPostNormScale = false;
static constexpr int kKVHeads = 1;
static constexpr int kConv1dWidth = 0;
static constexpr bool kFFBiases = false;
static constexpr bool kSoftmaxAttnOutputBiases = false;
static constexpr int kGemmaLayers = kLayers;
static constexpr int kGriffinLayers = 0;
static constexpr int kNumTensorScales = 0;
};
void TestEndToEnd() {

View File

@@ -266,8 +266,11 @@ float CrossEntropyLossForwardPass(const std::vector<int>& prompt,
forward.logits.data() + pos * kVocabSize, pool);
}
for (size_t pos = 0; pos < num_tokens; ++pos) {
LogitsSoftCap(30.0f, forward.logits.data() + pos * kVocabSize, kVocabSize);
if constexpr (TConfig::kFinalCap > 0.0f) {
for (size_t pos = 0; pos < num_tokens; ++pos) {
LogitsSoftCap(TConfig::kFinalCap,
forward.logits.data() + pos * kVocabSize, kVocabSize);
}
}
hwy::CopyBytes(forward.logits.data(), forward.probs.data(),

View File

@@ -16,8 +16,6 @@
#ifndef THIRD_PARTY_GEMMA_CPP_GEMMA_FORWARD_H_
#define THIRD_PARTY_GEMMA_CPP_GEMMA_FORWARD_H_
#include <vector>
#include "backprop/prompt.h"
#include "gemma/common.h"
#include "hwy/contrib/thread_pool/thread_pool.h"

View File

@@ -93,21 +93,11 @@ void Softmax(T* x, size_t N, size_t K) {
Softmax(x + i * N, N);
}
}
template<typename T>
void Softcap(T* x, size_t N) {
auto maxreal = std::real(x[0]);
size_t imax = 0;
for (size_t i = 1; i < N; ++i) {
if (std::real(x[i]) > maxreal) {
maxreal = std::real(x[i]);
imax = i;
}
}
T cap = 30.0;
T inv_cap = T(1.0) / cap;
T xmax = x[imax];
template <typename T>
void Softcap(float cap, T* x, size_t N) {
const T inv_cap = T{1.0} / static_cast<T>(cap);
for (size_t i = 0; i < N; ++i) {
x[i] = cap * std::tanh((x[i] - xmax) * inv_cap);
x[i] = static_cast<T>(cap) * std::tanh(x[i] * inv_cap);
}
}
@@ -285,7 +275,10 @@ T CrossEntropyLossForwardPass(const Prompt& prompt,
forward.logits.data(), kVocabSize, kModelDim, num_tokens);
for (size_t pos = 0; pos < num_tokens; ++pos) {
Softcap(forward.logits.data() + pos * kVocabSize, kVocabSize);
if constexpr (TConfig::kFinalCap > 0.0f) {
Softcap(TConfig::kFinalCap, forward.logits.data() + pos * kVocabSize,
kVocabSize);
}
}
memcpy(forward.probs.data(), forward.logits.data(),

View File

@@ -16,6 +16,9 @@
#ifndef THIRD_PARTY_GEMMA_CPP_GEMMA_SAMPLER_H_
#define THIRD_PARTY_GEMMA_CPP_GEMMA_SAMPLER_H_
#include <stddef.h>
#include <stdio.h>
#include <random>
#include <vector>

View File

@@ -19,6 +19,7 @@
#include <stddef.h>
#include <array>
#include <cmath>
#include <complex>
#include "gtest/gtest.h"

View File

@@ -99,8 +99,19 @@ struct ConfigNoSSM {
static constexpr int kNumTensorScales = 0;
};
struct ConfigNoCapNoSSM : ConfigNoSSM {
static constexpr float kAttCap = 0.0f;
static constexpr float kFinalCap = 0.0f;
};
// For Gemma2 with SoftCap
struct ConfigCapNoSSM : ConfigNoSSM {
static constexpr float kAttCap = 50.0f;
static constexpr float kFinalCap = 30.0f;
};
template <typename TWeight>
struct ConfigGemma27B : public ConfigNoSSM {
struct ConfigGemma27B : public ConfigCapNoSSM {
using Weight = TWeight; // make accessible where we only have a TConfig
static constexpr int kSeqLen = gcpp::kSeqLen;
@@ -120,7 +131,7 @@ struct ConfigGemma27B : public ConfigNoSSM {
};
template <typename TWeight>
struct ConfigGemma9B : public ConfigNoSSM {
struct ConfigGemma9B : public ConfigCapNoSSM {
using Weight = TWeight; // make accessible where we only have a TConfig
static constexpr int kSeqLen = gcpp::kSeqLen;
@@ -140,7 +151,7 @@ struct ConfigGemma9B : public ConfigNoSSM {
};
template <typename TWeight>
struct ConfigGemma7B : public ConfigNoSSM {
struct ConfigGemma7B : public ConfigNoCapNoSSM {
using Weight = TWeight; // make accessible where we only have a TConfig
static constexpr int kSeqLen = gcpp::kSeqLen;
@@ -160,7 +171,7 @@ struct ConfigGemma7B : public ConfigNoSSM {
};
template <typename TWeight>
struct ConfigGemma2B : public ConfigNoSSM {
struct ConfigGemma2B : public ConfigNoCapNoSSM {
using Weight = TWeight; // make accessible where we only have a TConfig
static constexpr int kSeqLen = gcpp::kSeqLen;
@@ -197,6 +208,10 @@ struct ConfigGemmaTiny : public ConfigNoSSM {
static constexpr int kTopK = gcpp::kTopK;
static constexpr bool kAbsolutePE = false;
static constexpr bool kPostNormScale = false;
static constexpr float kAttCap = 0.0f;
// This is required for optimize_test to pass.
static constexpr float kFinalCap = 30.0f;
};
template <typename TWeight>
@@ -251,6 +266,10 @@ struct ConfigGriffin2B {
static constexpr bool kAbsolutePE = false;
static constexpr bool kPostNormScale = false;
// No SoftCap.
static constexpr float kAttCap = 0.0f;
static constexpr float kFinalCap = 0.0f;
// SSM config.
static constexpr int kConv1dWidth = 4;
static constexpr bool kFFBiases = true;

View File

@@ -484,7 +484,11 @@ HWY_NOINLINE void Attention(
const float score = Dot(q, k2, kQKVDim);
head_att[pos2 % kSeqLen] = score;
}
Softmax(head_att, std::min(pos + 1, kSeqLen));
const size_t head_att_len = std::min(pos + 1, kSeqLen);
if constexpr (TConfig::kAttCap > 0.0f) {
LogitsSoftCap(TConfig::kAttCap, head_att, head_att_len);
}
Softmax(head_att, head_att_len);
// Weighted summation
float* HWY_RESTRICT att_out = activations.att_out.data() + head * kQKVDim +
@@ -979,7 +983,10 @@ void GenerateT(const ByteStorageT& weights_u8, const ByteStorageT& prefill_u8,
MatVec<kVocabSize, TConfig::kModelDim>(
weights.embedder_input_embedding, 0, x, activations.even_odd.data(),
logits, pool);
LogitsSoftCap(30.0f, logits, kVocabSize);
if constexpr (TConfig::kFinalCap > 0.0f) {
LogitsSoftCap(TConfig::kFinalCap, activations.logits.data(),
kVocabSize);
}
// Barrier: must have all logits so we can subtract max.
Softmax(logits, kVocabSize);
token = sample_token(logits, kVocabSize);

View File

@@ -1773,24 +1773,9 @@ static HWY_NOINLINE void LogitsSoftCap(const float cap, float* HWY_RESTRICT x,
const V vcap = hn::Set(d, cap);
const V vinv_cap = hn::Div(hn::Set(d, 1.0f), vcap);
// If we do not subtract the max as in softmax, values > 100 (which do occur)
// will all saturate to the cap, and then this function is no longer
// monotonic, which would change the results on TopK.
const V vmin = hn::Set(d, hwy::LowestValue<float>());
V vmax = vmin;
Foreach(d, x, max_pos, vmin,
[&vmax](const auto d, const auto value)
HWY_ATTR { vmax = hn::Max(vmax, value); });
vmax = hn::MaxOfLanes(d, vmax);
// We want (v-vmax) * vinv_cap. To take advantage of FMA, multiply this out to
// v * vinv_cap + (-vmax*vinv_cap).
const V add = hn::Neg(hn::Mul(vmax, vinv_cap));
hn::Transform(
d, x, size, [&vcap, &vinv_cap, &add](D d, hn::Vec<D> v) HWY_ATTR {
return hn::Mul(vcap, hn::Tanh(d, hn::MulAdd(v, vinv_cap, add)));
});
hn::Transform(d, x, max_pos, [&vcap, &vinv_cap](D d, hn::Vec<D> v) HWY_ATTR {
return hn::Mul(vcap, hn::Tanh(d, hn::Mul(v, vinv_cap)));
});
}
static HWY_INLINE HWY_MAYBE_UNUSED void LogitsSoftCap(const float cap,