mirror of https://github.com/google/gemma.cpp.git
Add config for att/final cap, skip max-subtract. Fixes #278
Also update includes/deps for backprop/. PiperOrigin-RevId: 648399222
parent da7507e6f0
commit e588a7f45d
@@ -361,6 +361,8 @@ cc_test(
    ],
    deps = [
        ":backprop_scalar",
        ":common",
        ":gemma_lib",
        ":prompt",
        ":sampler",
        ":weights_raw",
@@ -382,6 +384,7 @@ cc_test(
    deps = [
        ":backprop",
        ":backprop_scalar",
        ":common",
        ":gemma_lib",
        ":ops",
        ":sampler",
@@ -307,9 +307,9 @@ void LayerVJP(const LayerT<TConfig>& weights,
   }
 }
 
-static HWY_NOINLINE void SoftcapVJP(const float* HWY_RESTRICT forward,
+static HWY_NOINLINE void SoftcapVJP(const float cap,
+                                    const float* HWY_RESTRICT forward,
                                     float* HWY_RESTRICT backward,
-                                    const float cap,
                                     const size_t size) {
   namespace hn = hwy::HWY_NAMESPACE;
   using D = hn::ScalableTag<float>;
@@ -318,25 +318,11 @@ static HWY_NOINLINE void SoftcapVJP(const float* HWY_RESTRICT forward,
   const auto one = hn::Set(d, 1.0f);
   const auto vcap = hn::Set(d, cap);
   const auto vinv_cap = hn::Div(hn::Set(d, 1.0f), vcap);
 
-  // TODO(szabadka): Investigate what to do when the argmax is not unique.
-  // TODO(szabadka): Use IndexOfMax from hwy when it is available.
-  size_t imax = std::max_element(forward, forward + size) - forward;
-
-  hn::Transform1(
-      d, backward, size, forward,
-      [&](const auto d, const auto v, const auto y) HWY_ATTR {
-        const auto scaled = hn::Mul(vinv_cap, y);
-        return hn::Mul(v, hn::Sub(one, hn::Mul(scaled, scaled)));
-      });
-
-  backward[imax] = 0;
-  auto sum = hn::Zero(d);
-  Foreach(d, backward, size, sum,
-          [&sum](const auto d, const auto value) HWY_ATTR {
-            sum = hn::Add(sum, value);
-          });
-  backward[imax] = -hn::ReduceSum(d, sum);
+  hn::Transform1(d, backward, size, forward,
+                 [&](const auto d, const auto v, const auto y) HWY_ATTR {
+                   const auto scaled = hn::Mul(vinv_cap, y);  // = tanh
+                   return hn::Mul(v, hn::Sub(one, hn::Mul(scaled, scaled)));
+                 });
 }
 
 static HWY_NOINLINE void CrossEntropyLossGrad(
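Note (not part of the diff): the surviving Transform1 is the complete VJP of the soft-cap. With $y = c\,\tanh(x/c)$ we have $\partial y/\partial x = 1 - \tanh^2(x/c) = 1 - (y/c)^2$, so each backward[i] is scaled by $1 - (\text{forward}[i]/\text{cap})^2$. The deleted backward[imax] bookkeeping accounted for gradient flowing through the subtracted maximum in the old forward pass; once the forward pass no longer subtracts the max, that term vanishes.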
@@ -385,9 +371,11 @@ void CrossEntropyLossBackwardPass(const Prompt& prompt,
                kVocabSize);
   }
 
-  for (size_t pos = 0; pos < num_tokens; ++pos) {
-    SoftcapVJP(forward.logits.data() + pos * kVocabSize,
-               backward.logits.data() + pos * kVocabSize, 30.0f, kVocabSize);
+  if constexpr (TConfig::kFinalCap > 0.0f) {
+    for (size_t pos = 0; pos < num_tokens; ++pos) {
+      SoftcapVJP(TConfig::kFinalCap, forward.logits.data() + pos * kVocabSize,
+                 backward.logits.data() + pos * kVocabSize, kVocabSize);
+    }
   }
 
   MatMulVJP<kModelDim, kVocabSize>(
@@ -274,20 +274,12 @@ void LayerVJP(const Layer<T, TConfig>& weights,
              num_tokens * kModelDim);
 }
 
-template<typename T>
-void SoftcapVJPT(const T* y, T* dy, size_t N) {
-  size_t imax = std::max_element(y, y + N) - y;
-  T cap = 30.0;
-  T inv_cap = T(1.0) / cap;
+template <typename T>
+void SoftcapVJPT(float cap, const T* y, T* dy, size_t N) {
+  const T inv_cap = T{1.0} / static_cast<T>(cap);
   for (size_t i = 0; i < N; ++i) {
-    T scaled = y[i] * inv_cap;
-    dy[i] *= (T(1.0) - scaled * scaled);
-  }
-  dy[imax] = T(0.0);
-  for (size_t i = 0; i < N; ++i) {
-    if (i != imax) {
-      dy[imax] -= dy[i];
-    }
+    T scaled = y[i] * inv_cap;  // tanh
+    dy[i] *= (T{1.0} - scaled * scaled);
   }
 }
 
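A minimal standalone sketch (not from the commit) showing how the reworked scalar helpers fit together; it assumes Softcap lives in backprop/forward_scalar.h and SoftcapVJPT in backprop/backward_scalar.h under namespace gcpp, as the hunks and the test includes below suggest:

#include <cmath>
#include <cstddef>
#include <cstdio>
#include <vector>

#include "backprop/backward_scalar.h"  // SoftcapVJPT (assumed location)
#include "backprop/forward_scalar.h"   // Softcap (assumed location)

int main() {
  const float kCap = 30.0f;
  std::vector<float> x = {-5.0f, 0.5f, 40.0f};
  std::vector<float> y = x;
  gcpp::Softcap(kCap, y.data(), y.size());  // y[i] = kCap * tanh(x[i] / kCap)

  std::vector<float> dy(x.size(), 1.0f);    // upstream gradient of ones
  gcpp::SoftcapVJPT(kCap, y.data(), dy.data(), dy.size());

  // Each dy[i] should now equal 1 - tanh(x[i]/kCap)^2.
  for (size_t i = 0; i < x.size(); ++i) {
    const float t = std::tanh(x[i] / kCap);
    std::printf("%zu: got %f, expected %f\n", i, dy[i], 1.0f - t * t);
  }
  return 0;
}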
@@ -324,10 +316,11 @@ void CrossEntropyLossBackwardPass(const Prompt& prompt,
   SoftmaxVJPT(forward.probs.data(), backward.logits.data(),
               kVocabSize, num_tokens);
 
-  for (size_t i = 0; i < num_tokens; ++i) {
-    SoftcapVJPT(forward.logits.data() + i * kVocabSize,
-                backward.logits.data() + i * kVocabSize,
-                kVocabSize);
+  if constexpr (TConfig::kFinalCap > 0.0f) {
+    for (size_t i = 0; i < num_tokens; ++i) {
+      SoftcapVJPT(TConfig::kFinalCap, forward.logits.data() + i * kVocabSize,
+                  backward.logits.data() + i * kVocabSize, kVocabSize);
+    }
   }
 
   MatMulVJPT(weights.embedder_input_embedding.data(),
@@ -16,16 +16,23 @@
#include "backprop/backward_scalar.h"

#include <stddef.h>
#include <stdio.h>
#include <string.h>  // memset

#include <array>
#include <complex>
#include <limits>
#include <random>
#include <vector>

#include "gtest/gtest.h"
#include "backprop/common_scalar.h"
#include "backprop/forward_scalar.h"
#include "backprop/prompt.h"
#include "backprop/sampler.h"
#include "backprop/test_util.h"
#include "gemma/activations.h"
#include "gemma/configs.h"
#include "gemma/weights_raw.h"

namespace gcpp {
@@ -202,18 +209,19 @@ TEST(BackPropTest, SoftcapVJP) {
   std::array<TC, N> c_x;
   std::array<TC, N> c_y;
 
+  constexpr float kCap = 30.0f;
   for (int iter = 0; iter < 10; ++iter) {
     RandInit(x, 1.0 * (1 << iter), gen);
     Complexify(x, c_x);
     RandInit(dy, 1.0, gen);
     auto func = [&]() {
       memcpy(c_y.data(), c_x.data(), N * sizeof(c_x[0]));
-      Softcap(c_y.data(), N);
+      Softcap(kCap, c_y.data(), N);
       return DotT(dy.data(), c_y.data(), N);
     };
-    Softcap(x.data(), N);
+    Softcap(kCap, x.data(), N);
     memcpy(dx.data(), dy.data(), N * sizeof(dx[0]));
-    SoftcapVJPT(x.data(), dx.data(), N);
+    SoftcapVJPT(kCap, x.data(), dx.data(), N);
     TestGradient(dx, c_x, func, 1e-15, 1e-14, __LINE__);
   }
 }
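Note: the tight 1e-15 / 1e-14 tolerances work because TestGradient appears to compare the analytic gradient against a complex-step estimate (an assumption based on the Complexify / std::complex machinery in this test file, not stated in the diff): $f'(x) \approx \operatorname{Im}[f(x + ih)]/h$ for small $h$, which is free of subtractive cancellation.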
@@ -230,10 +238,11 @@ TEST(BackPropTest, CrossEntropyLossGrad) {
   Prompt prompt;
   prompt.tokens = { 0, 1, 2, 3, 0, 3, 2, 1, 0 };
 
+  const float kCap = 30.0f;
   for (int iter = 0; iter < 10; ++iter) {
     prompt.context_size = 1 + (iter % 6);
     RandInit(x, 1.0 * (1 << iter), gen);
-    Softcap(x.data(), V * K);
+    Softcap(kCap, x.data(), V * K);
     Softmax(x.data(), V, K);
     CrossEntropyLossGrad(x.data(), dx.data(), prompt, V);
     Complexify(x, c_x);
@@ -368,7 +377,7 @@ TEST(BackPropTest, InputEmbeddingVJP) {
   }
 }
 
-struct TestConfig {
+struct TestConfig : ConfigCapNoSSM {
   static constexpr int kSeqLen = 18;
   static constexpr int kVocabSize = 12;
   static constexpr int kModelDim = 32;
@@ -382,12 +391,7 @@ struct TestConfig {
   static constexpr bool kPostNormScale = false;
 
   static constexpr int kKVHeads = 1;
-  static constexpr int kConv1dWidth = 0;
-  static constexpr bool kFFBiases = false;
-  static constexpr bool kSoftmaxAttnOutputBiases = false;
   static constexpr int kGemmaLayers = kLayers;
-  static constexpr int kGriffinLayers = 0;
-  static constexpr int kNumTensorScales = 0;
 };
 
 TEST(BackPropTest, LayerVJP) {
@@ -25,10 +25,12 @@
#include <vector>

#include "backprop/backward_scalar.h"
#include "backprop/common_scalar.h"
#include "backprop/forward_scalar.h"
#include "backprop/sampler.h"
#include "backprop/test_util.h"
#include "gemma/activations.h"
#include "gemma/configs.h"
#include "gemma/weights_raw.h"
#include "hwy/base.h"
#include "hwy/contrib/thread_pool/thread_pool.h"
@@ -180,7 +182,7 @@ void TestRMSNormVJP() {
   }
 }
 
-struct TestConfig {
+struct TestConfig : public ConfigCapNoSSM {
   static constexpr int kSeqLen = 24;
   static constexpr int kVocabSize = 16;
   static constexpr int kModelDim = 32;
@@ -194,12 +196,7 @@ struct TestConfig {
   static constexpr bool kPostNormScale = false;
 
   static constexpr int kKVHeads = 1;
-  static constexpr int kConv1dWidth = 0;
-  static constexpr bool kFFBiases = false;
-  static constexpr bool kSoftmaxAttnOutputBiases = false;
   static constexpr int kGemmaLayers = kLayers;
-  static constexpr int kGriffinLayers = 0;
-  static constexpr int kNumTensorScales = 0;
 };
 
 void TestEndToEnd() {
@@ -266,8 +266,11 @@ float CrossEntropyLossForwardPass(const std::vector<int>& prompt,
         forward.logits.data() + pos * kVocabSize, pool);
   }
 
-  for (size_t pos = 0; pos < num_tokens; ++pos) {
-    LogitsSoftCap(30.0f, forward.logits.data() + pos * kVocabSize, kVocabSize);
+  if constexpr (TConfig::kFinalCap > 0.0f) {
+    for (size_t pos = 0; pos < num_tokens; ++pos) {
+      LogitsSoftCap(TConfig::kFinalCap,
+                    forward.logits.data() + pos * kVocabSize, kVocabSize);
+    }
   }
 
   hwy::CopyBytes(forward.logits.data(), forward.probs.data(),
@@ -16,8 +16,6 @@
#ifndef THIRD_PARTY_GEMMA_CPP_GEMMA_FORWARD_H_
#define THIRD_PARTY_GEMMA_CPP_GEMMA_FORWARD_H_

#include <vector>

#include "backprop/prompt.h"
#include "gemma/common.h"
#include "hwy/contrib/thread_pool/thread_pool.h"
@@ -93,21 +93,11 @@ void Softmax(T* x, size_t N, size_t K) {
     Softmax(x + i * N, N);
   }
 }
-template<typename T>
-void Softcap(T* x, size_t N) {
-  auto maxreal = std::real(x[0]);
-  size_t imax = 0;
-  for (size_t i = 1; i < N; ++i) {
-    if (std::real(x[i]) > maxreal) {
-      maxreal = std::real(x[i]);
-      imax = i;
-    }
-  }
-  T cap = 30.0;
-  T inv_cap = T(1.0) / cap;
-  T xmax = x[imax];
+template <typename T>
+void Softcap(float cap, T* x, size_t N) {
+  const T inv_cap = T{1.0} / static_cast<T>(cap);
   for (size_t i = 0; i < N; ++i) {
-    x[i] = cap * std::tanh((x[i] - xmax) * inv_cap);
+    x[i] = static_cast<T>(cap) * std::tanh(x[i] * inv_cap);
   }
 }
 
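In effect, the scalar forward soft-cap changes from $x_i \leftarrow c\,\tanh((x_i - x_{\max})/c)$ with a hard-coded $c = 30$ to $x_i \leftarrow c\,\tanh(x_i/c)$ with $c$ supplied by the caller (ultimately TConfig::kFinalCap or TConfig::kAttCap).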
@@ -285,7 +275,10 @@ T CrossEntropyLossForwardPass(const Prompt& prompt,
              forward.logits.data(), kVocabSize, kModelDim, num_tokens);
 
   for (size_t pos = 0; pos < num_tokens; ++pos) {
-    Softcap(forward.logits.data() + pos * kVocabSize, kVocabSize);
+    if constexpr (TConfig::kFinalCap > 0.0f) {
+      Softcap(TConfig::kFinalCap, forward.logits.data() + pos * kVocabSize,
+              kVocabSize);
+    }
   }
 
   memcpy(forward.probs.data(), forward.logits.data(),
|
|||
#ifndef THIRD_PARTY_GEMMA_CPP_GEMMA_SAMPLER_H_
|
||||
#define THIRD_PARTY_GEMMA_CPP_GEMMA_SAMPLER_H_
|
||||
|
||||
#include <stddef.h>
|
||||
#include <stdio.h>
|
||||
|
||||
#include <random>
|
||||
#include <vector>
|
||||
|
||||
|
|
|
|||
|
|
@@ -19,6 +19,7 @@
#include <stddef.h>

#include <array>
#include <cmath>
#include <complex>

#include "gtest/gtest.h"
@@ -99,8 +99,19 @@ struct ConfigNoSSM {
   static constexpr int kNumTensorScales = 0;
 };
 
+struct ConfigNoCapNoSSM : ConfigNoSSM {
+  static constexpr float kAttCap = 0.0f;
+  static constexpr float kFinalCap = 0.0f;
+};
+
+// For Gemma2 with SoftCap
+struct ConfigCapNoSSM : ConfigNoSSM {
+  static constexpr float kAttCap = 50.0f;
+  static constexpr float kFinalCap = 30.0f;
+};
+
 template <typename TWeight>
-struct ConfigGemma27B : public ConfigNoSSM {
+struct ConfigGemma27B : public ConfigCapNoSSM {
   using Weight = TWeight;  // make accessible where we only have a TConfig
 
   static constexpr int kSeqLen = gcpp::kSeqLen;
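Sketch (not from the commit) of how these compile-time traits are consumed at the call sites changed elsewhere in this commit; CapFinalLogits is a hypothetical wrapper name, LogitsSoftCap is the ops.h helper:

// For configs deriving from ConfigNoCapNoSSM (kFinalCap == 0) the branch is
// discarded at compile time; for ConfigCapNoSSM-derived configs it calls the
// SIMD soft-cap with the config's constant.
template <class TConfig>
void CapFinalLogits(float* logits, size_t vocab_size) {
  if constexpr (TConfig::kFinalCap > 0.0f) {
    LogitsSoftCap(TConfig::kFinalCap, logits, vocab_size);
  }
}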
@@ -120,7 +131,7 @@ struct ConfigGemma27B : public ConfigNoSSM {
 };
 
 template <typename TWeight>
-struct ConfigGemma9B : public ConfigNoSSM {
+struct ConfigGemma9B : public ConfigCapNoSSM {
   using Weight = TWeight;  // make accessible where we only have a TConfig
 
   static constexpr int kSeqLen = gcpp::kSeqLen;
@@ -140,7 +151,7 @@ struct ConfigGemma9B : public ConfigNoSSM {
 };
 
 template <typename TWeight>
-struct ConfigGemma7B : public ConfigNoSSM {
+struct ConfigGemma7B : public ConfigNoCapNoSSM {
   using Weight = TWeight;  // make accessible where we only have a TConfig
 
   static constexpr int kSeqLen = gcpp::kSeqLen;
@@ -160,7 +171,7 @@ struct ConfigGemma7B : public ConfigNoSSM {
 };
 
 template <typename TWeight>
-struct ConfigGemma2B : public ConfigNoSSM {
+struct ConfigGemma2B : public ConfigNoCapNoSSM {
   using Weight = TWeight;  // make accessible where we only have a TConfig
 
   static constexpr int kSeqLen = gcpp::kSeqLen;
@@ -197,6 +208,10 @@ struct ConfigGemmaTiny : public ConfigNoSSM {
   static constexpr int kTopK = gcpp::kTopK;
   static constexpr bool kAbsolutePE = false;
   static constexpr bool kPostNormScale = false;
+
+  static constexpr float kAttCap = 0.0f;
+  // This is required for optimize_test to pass.
+  static constexpr float kFinalCap = 30.0f;
 };
 
 template <typename TWeight>
@@ -251,6 +266,10 @@ struct ConfigGriffin2B {
   static constexpr bool kAbsolutePE = false;
   static constexpr bool kPostNormScale = false;
 
+  // No SoftCap.
+  static constexpr float kAttCap = 0.0f;
+  static constexpr float kFinalCap = 0.0f;
+
   // SSM config.
   static constexpr int kConv1dWidth = 4;
   static constexpr bool kFFBiases = true;
@@ -484,7 +484,11 @@ HWY_NOINLINE void Attention(
         const float score = Dot(q, k2, kQKVDim);
         head_att[pos2 % kSeqLen] = score;
       }
-      Softmax(head_att, std::min(pos + 1, kSeqLen));
+      const size_t head_att_len = std::min(pos + 1, kSeqLen);
+      if constexpr (TConfig::kAttCap > 0.0f) {
+        LogitsSoftCap(TConfig::kAttCap, head_att, head_att_len);
+      }
+      Softmax(head_att, head_att_len);
 
       // Weighted summation
       float* HWY_RESTRICT att_out = activations.att_out.data() + head * kQKVDim +
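A scalar sketch (plain C++ rather than the Highway path above; the helper name is invented) of what the new attention branch computes when kAttCap > 0: cap the attention logits, then take a numerically stable softmax.

#include <algorithm>
#include <cmath>
#include <cstddef>

void SoftCapThenSoftmax(float cap, float* att, size_t len) {
  for (size_t i = 0; i < len; ++i) {
    att[i] = cap * std::tanh(att[i] / cap);  // bounded to (-cap, cap)
  }
  float max = att[0];
  for (size_t i = 1; i < len; ++i) max = std::max(max, att[i]);
  float sum = 0.0f;
  for (size_t i = 0; i < len; ++i) {
    att[i] = std::exp(att[i] - max);  // max-subtraction stays in the softmax
    sum += att[i];
  }
  for (size_t i = 0; i < len; ++i) att[i] /= sum;
}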
@@ -979,7 +983,10 @@ void GenerateT(const ByteStorageT& weights_u8, const ByteStorageT& prefill_u8,
       MatVec<kVocabSize, TConfig::kModelDim>(
           weights.embedder_input_embedding, 0, x, activations.even_odd.data(),
           logits, pool);
-      LogitsSoftCap(30.0f, logits, kVocabSize);
+      if constexpr (TConfig::kFinalCap > 0.0f) {
+        LogitsSoftCap(TConfig::kFinalCap, activations.logits.data(),
+                      kVocabSize);
+      }
       // Barrier: must have all logits so we can subtract max.
       Softmax(logits, kVocabSize);
       token = sample_token(logits, kVocabSize);
gemma/ops.h (21 changed lines)
@@ -1773,24 +1773,9 @@ static HWY_NOINLINE void LogitsSoftCap(const float cap, float* HWY_RESTRICT x,
   const V vcap = hn::Set(d, cap);
   const V vinv_cap = hn::Div(hn::Set(d, 1.0f), vcap);
 
-  // If we do not subtract the max as in softmax, values > 100 (which do occur)
-  // will all saturate to the cap, and then this function is no longer
-  // monotonic, which would change the results on TopK.
-  const V vmin = hn::Set(d, hwy::LowestValue<float>());
-  V vmax = vmin;
-  Foreach(d, x, max_pos, vmin,
-          [&vmax](const auto d, const auto value)
-              HWY_ATTR { vmax = hn::Max(vmax, value); });
-  vmax = hn::MaxOfLanes(d, vmax);
-
-  // We want (v-vmax) * vinv_cap. To take advantage of FMA, multiply this out to
-  // v * vinv_cap + (-vmax*vinv_cap).
-  const V add = hn::Neg(hn::Mul(vmax, vinv_cap));
-
-  hn::Transform(
-      d, x, size, [&vcap, &vinv_cap, &add](D d, hn::Vec<D> v) HWY_ATTR {
-        return hn::Mul(vcap, hn::Tanh(d, hn::MulAdd(v, vinv_cap, add)));
-      });
+  hn::Transform(d, x, max_pos, [&vcap, &vinv_cap](D d, hn::Vec<D> v) HWY_ATTR {
+    return hn::Mul(vcap, hn::Tanh(d, hn::Mul(v, vinv_cap)));
+  });
 }
 
 static HWY_INLINE HWY_MAYBE_UNUSED void LogitsSoftCap(const float cap,
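Scalar equivalent of the vectorized body above (a sketch, not the Highway code path): after this change the cap is applied directly to the first max_pos elements, without first finding and subtracting their maximum.

#include <cmath>
#include <cstddef>

void LogitsSoftCapScalar(float cap, float* x, size_t max_pos) {
  const float inv_cap = 1.0f / cap;
  for (size_t i = 0; i < max_pos; ++i) {
    x[i] = cap * std::tanh(x[i] * inv_cap);
  }
}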