mirror of https://github.com/google/gemma.cpp.git
Add config constants for the attention and final logit soft-caps; skip the max-subtract in LogitsSoftCap. Fixes #278
Also update includes and BUILD deps for backprop/. PiperOrigin-RevId: 648399222
parent da7507e6f0
commit e588a7f45d
@@ -361,6 +361,8 @@ cc_test(
     ],
     deps = [
         ":backprop_scalar",
+        ":common",
+        ":gemma_lib",
         ":prompt",
         ":sampler",
         ":weights_raw",
@@ -382,6 +384,7 @@ cc_test(
     deps = [
         ":backprop",
         ":backprop_scalar",
+        ":common",
         ":gemma_lib",
         ":ops",
         ":sampler",
@@ -307,9 +307,9 @@ void LayerVJP(const LayerT<TConfig>& weights,
   }
 }
 
-static HWY_NOINLINE void SoftcapVJP(const float* HWY_RESTRICT forward,
+static HWY_NOINLINE void SoftcapVJP(const float cap,
+                                    const float* HWY_RESTRICT forward,
                                     float* HWY_RESTRICT backward,
-                                    const float cap,
                                     const size_t size) {
   namespace hn = hwy::HWY_NAMESPACE;
   using D = hn::ScalableTag<float>;
@@ -318,25 +318,11 @@ static HWY_NOINLINE void SoftcapVJP(const float* HWY_RESTRICT forward,
   const auto one = hn::Set(d, 1.0f);
   const auto vcap = hn::Set(d, cap);
   const auto vinv_cap = hn::Div(hn::Set(d, 1.0f), vcap);
-  // TODO(szabadka): Investigate what to do when the argmax is not unique.
-  // TODO(szabadka): Use IndexOfMax from hwy when it is available.
-  size_t imax = std::max_element(forward, forward + size) - forward;
-
-  hn::Transform1(
-      d, backward, size, forward,
+  hn::Transform1(d, backward, size, forward,
                  [&](const auto d, const auto v, const auto y) HWY_ATTR {
-                   const auto scaled = hn::Mul(vinv_cap, y);
+                   const auto scaled = hn::Mul(vinv_cap, y);  // = tanh
                    return hn::Mul(v, hn::Sub(one, hn::Mul(scaled, scaled)));
                  });
-
-  backward[imax] = 0;
-  auto sum = hn::Zero(d);
-  Foreach(d, backward, size, sum,
-          [&sum](const auto d, const auto value) HWY_ATTR {
-            sum = hn::Add(sum, value);
-          });
-  backward[imax] = -hn::ReduceSum(d, sum);
 }
 
 static HWY_NOINLINE void CrossEntropyLossGrad(
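
The simplified VJP follows directly from the forward definition y = cap * tanh(x / cap): dy/dx = 1 - tanh^2(x / cap) = 1 - (y / cap)^2, so the backward pass can recover tanh from the stored forward output instead of recomputing it. A minimal standalone sketch of the same identity in scalar C++ (illustrative only, not part of this commit):

#include <cmath>
#include <cstdio>

int main() {
  const float cap = 30.0f, x = 12.0f;
  const float y = cap * std::tanh(x / cap);  // forward softcap output
  const float t = y / cap;                   // equals tanh(x / cap)
  // VJP factor applied to the incoming gradient, as in SoftcapVJP.
  const float vjp = 1.0f - t * t;
  // Analytic derivative for comparison.
  const float h = std::tanh(x / cap);
  std::printf("vjp=%f analytic=%f\n", vjp, 1.0f - h * h);
}
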
@@ -385,9 +371,11 @@ void CrossEntropyLossBackwardPass(const Prompt& prompt,
                          kVocabSize);
   }
 
+  if constexpr (TConfig::kFinalCap > 0.0f) {
     for (size_t pos = 0; pos < num_tokens; ++pos) {
-    SoftcapVJP(forward.logits.data() + pos * kVocabSize,
-               backward.logits.data() + pos * kVocabSize, 30.0f, kVocabSize);
+      SoftcapVJP(TConfig::kFinalCap, forward.logits.data() + pos * kVocabSize,
+                 backward.logits.data() + pos * kVocabSize, kVocabSize);
+    }
   }
 
   MatMulVJP<kModelDim, kVocabSize>(
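
Because kFinalCap is a constexpr member of the model config, this guard is resolved at compile time: for configs with kFinalCap == 0 the entire capping loop is discarded rather than branched over per token. A minimal sketch of the pattern, with hypothetical config structs standing in for the real ones:

#include <cmath>
#include <cstddef>
#include <cstdio>

struct CfgCapped   { static constexpr float kFinalCap = 30.0f; };
struct CfgUncapped { static constexpr float kFinalCap = 0.0f; };

template <typename TConfig>
void MaybeCap(float* logits, size_t n) {
  if constexpr (TConfig::kFinalCap > 0.0f) {
    for (size_t i = 0; i < n; ++i) {  // only instantiated for capped configs
      logits[i] = TConfig::kFinalCap * std::tanh(logits[i] / TConfig::kFinalCap);
    }
  }
}

int main() {
  float v[2] = {5.0f, 100.0f};
  MaybeCap<CfgCapped>(v, 2);    // caps the logits
  MaybeCap<CfgUncapped>(v, 2);  // compiles to a no-op
  std::printf("%f %f\n", v[0], v[1]);
}
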
@@ -274,20 +274,12 @@ void LayerVJP(const Layer<T, TConfig>& weights,
              num_tokens * kModelDim);
 }
 
-template<typename T>
-void SoftcapVJPT(const T* y, T* dy, size_t N) {
-  size_t imax = std::max_element(y, y + N) - y;
-  T cap = 30.0;
-  T inv_cap = T(1.0) / cap;
+template <typename T>
+void SoftcapVJPT(float cap, const T* y, T* dy, size_t N) {
+  const T inv_cap = T{1.0} / static_cast<T>(cap);
   for (size_t i = 0; i < N; ++i) {
-    T scaled = y[i] * inv_cap;
-    dy[i] *= (T(1.0) - scaled * scaled);
+    T scaled = y[i] * inv_cap;  // tanh
+    dy[i] *= (T{1.0} - scaled * scaled);
   }
-  dy[imax] = T(0.0);
-  for (size_t i = 0; i < N; ++i) {
-    if (i != imax) {
-      dy[imax] -= dy[i];
-    }
-  }
 }
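
With the max-subtraction gone from the forward Softcap, each output depends only on its own input, so the Jacobian is diagonal and the deleted argmax bookkeeping (the dy[imax] accumulation) is no longer needed. A standalone finite-difference spot check of the elementwise rule:

#include <cmath>
#include <cstdio>

double SoftcapFwd(double cap, double x) { return cap * std::tanh(x / cap); }

int main() {
  const double cap = 30.0, x = 7.0, upstream = 2.0, eps = 1e-6;
  const double y = SoftcapFwd(cap, x);
  // Elementwise VJP as in SoftcapVJPT: upstream * (1 - (y/cap)^2).
  const double grad = upstream * (1.0 - (y / cap) * (y / cap));
  // Central difference of upstream * softcap(x).
  const double fd = upstream *
      (SoftcapFwd(cap, x + eps) - SoftcapFwd(cap, x - eps)) / (2.0 * eps);
  std::printf("vjp=%.10f fd=%.10f\n", grad, fd);
}
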
@@ -324,10 +316,11 @@ void CrossEntropyLossBackwardPass(const Prompt& prompt,
   SoftmaxVJPT(forward.probs.data(), backward.logits.data(),
               kVocabSize, num_tokens);
 
+  if constexpr (TConfig::kFinalCap > 0.0f) {
     for (size_t i = 0; i < num_tokens; ++i) {
-    SoftcapVJPT(forward.logits.data() + i * kVocabSize,
-                backward.logits.data() + i * kVocabSize,
-                kVocabSize);
+      SoftcapVJPT(TConfig::kFinalCap, forward.logits.data() + i * kVocabSize,
+                  backward.logits.data() + i * kVocabSize, kVocabSize);
+    }
   }
 
   MatMulVJPT(weights.embedder_input_embedding.data(),
@@ -16,16 +16,23 @@
 #include "backprop/backward_scalar.h"
 
 #include <stddef.h>
+#include <stdio.h>
 #include <string.h>  // memset
 
 #include <array>
 #include <complex>
+#include <limits>
 #include <random>
+#include <vector>
 
 #include "gtest/gtest.h"
+#include "backprop/common_scalar.h"
 #include "backprop/forward_scalar.h"
+#include "backprop/prompt.h"
 #include "backprop/sampler.h"
 #include "backprop/test_util.h"
+#include "gemma/activations.h"
+#include "gemma/configs.h"
 #include "gemma/weights_raw.h"
 
 namespace gcpp {
@@ -202,18 +209,19 @@ TEST(BackPropTest, SoftcapVJP) {
   std::array<TC, N> c_x;
   std::array<TC, N> c_y;
 
+  constexpr float kCap = 30.0f;
   for (int iter = 0; iter < 10; ++iter) {
     RandInit(x, 1.0 * (1 << iter), gen);
     Complexify(x, c_x);
     RandInit(dy, 1.0, gen);
     auto func = [&]() {
       memcpy(c_y.data(), c_x.data(), N * sizeof(c_x[0]));
-      Softcap(c_y.data(), N);
+      Softcap(kCap, c_y.data(), N);
       return DotT(dy.data(), c_y.data(), N);
     };
-    Softcap(x.data(), N);
+    Softcap(kCap, x.data(), N);
     memcpy(dx.data(), dy.data(), N * sizeof(dx[0]));
-    SoftcapVJPT(x.data(), dx.data(), N);
+    SoftcapVJPT(kCap, x.data(), dx.data(), N);
     TestGradient(dx, c_x, func, 1e-15, 1e-14, __LINE__);
   }
 }
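
TestGradient works by complex-step differentiation: Complexify promotes the inputs to std::complex, and for an analytic f, Im(f(x + ih)) / h approximates f'(x) to machine precision with no subtractive cancellation, which is why tolerances as tight as 1e-15 are usable. A standalone illustration on the softcap itself:

#include <cmath>
#include <complex>
#include <cstdio>

int main() {
  const double cap = 30.0, x = 2.5, h = 1e-30;
  const std::complex<double> xc(x, h);
  const std::complex<double> yc = cap * std::tanh(xc / cap);  // complex softcap
  const double dydx = yc.imag() / h;  // complex-step derivative
  const double t = std::tanh(x / cap);
  std::printf("complex-step=%.17g analytic=%.17g\n", dydx, 1.0 - t * t);
}
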
@@ -230,10 +238,11 @@ TEST(BackPropTest, CrossEntropyLossGrad) {
   Prompt prompt;
   prompt.tokens = { 0, 1, 2, 3, 0, 3, 2, 1, 0 };
 
+  const float kCap = 30.0f;
   for (int iter = 0; iter < 10; ++iter) {
     prompt.context_size = 1 + (iter % 6);
     RandInit(x, 1.0 * (1 << iter), gen);
-    Softcap(x.data(), V * K);
+    Softcap(kCap, x.data(), V * K);
     Softmax(x.data(), V, K);
     CrossEntropyLossGrad(x.data(), dx.data(), prompt, V);
     Complexify(x, c_x);
@@ -368,7 +377,7 @@ TEST(BackPropTest, InputEmbeddingVJP) {
   }
 }
 
-struct TestConfig {
+struct TestConfig : ConfigCapNoSSM {
   static constexpr int kSeqLen = 18;
   static constexpr int kVocabSize = 12;
   static constexpr int kModelDim = 32;
@@ -382,12 +391,7 @@ struct TestConfig {
   static constexpr bool kPostNormScale = false;
 
   static constexpr int kKVHeads = 1;
-  static constexpr int kConv1dWidth = 0;
-  static constexpr bool kFFBiases = false;
-  static constexpr bool kSoftmaxAttnOutputBiases = false;
   static constexpr int kGemmaLayers = kLayers;
-  static constexpr int kGriffinLayers = 0;
-  static constexpr int kNumTensorScales = 0;
 };
 
 TEST(BackPropTest, LayerVJP) {
@@ -25,10 +25,12 @@
 #include <vector>
 
 #include "backprop/backward_scalar.h"
+#include "backprop/common_scalar.h"
 #include "backprop/forward_scalar.h"
 #include "backprop/sampler.h"
 #include "backprop/test_util.h"
 #include "gemma/activations.h"
+#include "gemma/configs.h"
 #include "gemma/weights_raw.h"
 #include "hwy/base.h"
 #include "hwy/contrib/thread_pool/thread_pool.h"
@@ -180,7 +182,7 @@ void TestRMSNormVJP() {
   }
 }
 
-struct TestConfig {
+struct TestConfig : public ConfigCapNoSSM {
   static constexpr int kSeqLen = 24;
   static constexpr int kVocabSize = 16;
   static constexpr int kModelDim = 32;
@@ -194,12 +196,7 @@ struct TestConfig {
   static constexpr bool kPostNormScale = false;
 
   static constexpr int kKVHeads = 1;
-  static constexpr int kConv1dWidth = 0;
-  static constexpr bool kFFBiases = false;
-  static constexpr bool kSoftmaxAttnOutputBiases = false;
   static constexpr int kGemmaLayers = kLayers;
-  static constexpr int kGriffinLayers = 0;
-  static constexpr int kNumTensorScales = 0;
 };
 
 void TestEndToEnd() {
@@ -266,8 +266,11 @@ float CrossEntropyLossForwardPass(const std::vector<int>& prompt,
                forward.logits.data() + pos * kVocabSize, pool);
   }
 
+  if constexpr (TConfig::kFinalCap > 0.0f) {
     for (size_t pos = 0; pos < num_tokens; ++pos) {
-    LogitsSoftCap(30.0f, forward.logits.data() + pos * kVocabSize, kVocabSize);
+      LogitsSoftCap(TConfig::kFinalCap,
+                    forward.logits.data() + pos * kVocabSize, kVocabSize);
+    }
   }
 
   hwy::CopyBytes(forward.logits.data(), forward.probs.data(),
@@ -16,8 +16,6 @@
 #ifndef THIRD_PARTY_GEMMA_CPP_GEMMA_FORWARD_H_
 #define THIRD_PARTY_GEMMA_CPP_GEMMA_FORWARD_H_
 
-#include <vector>
-
 #include "backprop/prompt.h"
 #include "gemma/common.h"
 #include "hwy/contrib/thread_pool/thread_pool.h"
@@ -93,21 +93,11 @@ void Softmax(T* x, size_t N, size_t K) {
     Softmax(x + i * N, N);
   }
 }
-template<typename T>
-void Softcap(T* x, size_t N) {
-  auto maxreal = std::real(x[0]);
-  size_t imax = 0;
-  for (size_t i = 1; i < N; ++i) {
-    if (std::real(x[i]) > maxreal) {
-      maxreal = std::real(x[i]);
-      imax = i;
-    }
-  }
-  T cap = 30.0;
-  T inv_cap = T(1.0) / cap;
-  T xmax = x[imax];
+template <typename T>
+void Softcap(float cap, T* x, size_t N) {
+  const T inv_cap = T{1.0} / static_cast<T>(cap);
   for (size_t i = 0; i < N; ++i) {
-    x[i] = cap * std::tanh((x[i] - xmax) * inv_cap);
+    x[i] = static_cast<T>(cap) * std::tanh(x[i] * inv_cap);
   }
 }
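
The old scalar Softcap subtracted the (real-part) argmax inside the tanh, which is not a harmless shift: tanh is nonlinear, so its outputs genuinely differed from cap * tanh(x / cap). The rewrite matches the SIMD LogitsSoftCap definition. A standalone numeric comparison of the two formulas:

#include <algorithm>
#include <cmath>
#include <cstdio>

int main() {
  const double cap = 30.0;
  const double x[3] = {1.0, 5.0, 40.0};
  const double xmax = *std::max_element(x, x + 3);
  for (double v : x) {
    const double old_way = cap * std::tanh((v - xmax) / cap);  // max-subtracting
    const double new_way = cap * std::tanh(v / cap);           // this commit
    std::printf("x=%5.1f old=%9.4f new=%9.4f\n", v, old_way, new_way);
  }
}
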
@@ -285,7 +275,10 @@ T CrossEntropyLossForwardPass(const Prompt& prompt,
              forward.logits.data(), kVocabSize, kModelDim, num_tokens);
 
   for (size_t pos = 0; pos < num_tokens; ++pos) {
-    Softcap(forward.logits.data() + pos * kVocabSize, kVocabSize);
+    if constexpr (TConfig::kFinalCap > 0.0f) {
+      Softcap(TConfig::kFinalCap, forward.logits.data() + pos * kVocabSize,
+              kVocabSize);
+    }
   }
 
   memcpy(forward.probs.data(), forward.logits.data(),
@@ -16,6 +16,9 @@
 #ifndef THIRD_PARTY_GEMMA_CPP_GEMMA_SAMPLER_H_
 #define THIRD_PARTY_GEMMA_CPP_GEMMA_SAMPLER_H_
 
+#include <stddef.h>
+#include <stdio.h>
+
 #include <random>
 #include <vector>
 
@@ -19,6 +19,7 @@
 #include <stddef.h>
 
 #include <array>
+#include <cmath>
 #include <complex>
 
 #include "gtest/gtest.h"
@@ -99,8 +99,19 @@ struct ConfigNoSSM {
   static constexpr int kNumTensorScales = 0;
 };
 
+struct ConfigNoCapNoSSM : ConfigNoSSM {
+  static constexpr float kAttCap = 0.0f;
+  static constexpr float kFinalCap = 0.0f;
+};
+
+// For Gemma2 with SoftCap
+struct ConfigCapNoSSM : ConfigNoSSM {
+  static constexpr float kAttCap = 50.0f;
+  static constexpr float kFinalCap = 30.0f;
+};
+
 template <typename TWeight>
-struct ConfigGemma27B : public ConfigNoSSM {
+struct ConfigGemma27B : public ConfigCapNoSSM {
   using Weight = TWeight;  // make accessible where we only have a TConfig
 
   static constexpr int kSeqLen = gcpp::kSeqLen;
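
The two mixins let every model inherit its pair of cap constants instead of restating them: Gemma-2 style configs derive from ConfigCapNoSSM (kAttCap = 50, kFinalCap = 30), older ones from ConfigNoCapNoSSM, and the if-constexpr guards elsewhere then compile the capping in or out per model. A reduced sketch of the inheritance (names abbreviated, not the real structs):

#include <cstdio>

struct NoCap { static constexpr float kAttCap = 0.0f;
               static constexpr float kFinalCap = 0.0f; };
struct Cap   { static constexpr float kAttCap = 50.0f;
               static constexpr float kFinalCap = 30.0f; };

struct Model7B : NoCap {};  // stands in for ConfigGemma7B
struct Model9B : Cap {};    // stands in for ConfigGemma9B

template <typename TConfig>
void Report(const char* name) {
  if constexpr (TConfig::kFinalCap > 0.0f) {
    std::printf("%s: final logits capped at %.0f\n", name, TConfig::kFinalCap);
  } else {
    std::printf("%s: final logits uncapped\n", name);
  }
}

int main() {
  Report<Model7B>("7B");  // uncapped
  Report<Model9B>("9B");  // capped at 30
}
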
@@ -120,7 +131,7 @@ struct ConfigGemma27B : public ConfigNoSSM {
 };
 
 template <typename TWeight>
-struct ConfigGemma9B : public ConfigNoSSM {
+struct ConfigGemma9B : public ConfigCapNoSSM {
   using Weight = TWeight;  // make accessible where we only have a TConfig
 
   static constexpr int kSeqLen = gcpp::kSeqLen;
@@ -140,7 +151,7 @@ struct ConfigGemma9B : public ConfigNoSSM {
 };
 
 template <typename TWeight>
-struct ConfigGemma7B : public ConfigNoSSM {
+struct ConfigGemma7B : public ConfigNoCapNoSSM {
   using Weight = TWeight;  // make accessible where we only have a TConfig
 
   static constexpr int kSeqLen = gcpp::kSeqLen;
@@ -160,7 +171,7 @@ struct ConfigGemma7B : public ConfigNoSSM {
 };
 
 template <typename TWeight>
-struct ConfigGemma2B : public ConfigNoSSM {
+struct ConfigGemma2B : public ConfigNoCapNoSSM {
   using Weight = TWeight;  // make accessible where we only have a TConfig
 
   static constexpr int kSeqLen = gcpp::kSeqLen;
@@ -197,6 +208,10 @@ struct ConfigGemmaTiny : public ConfigNoSSM {
   static constexpr int kTopK = gcpp::kTopK;
   static constexpr bool kAbsolutePE = false;
   static constexpr bool kPostNormScale = false;
+
+  static constexpr float kAttCap = 0.0f;
+  // This is required for optimize_test to pass.
+  static constexpr float kFinalCap = 30.0f;
 };
 
 template <typename TWeight>
@@ -251,6 +266,10 @@ struct ConfigGriffin2B {
   static constexpr bool kAbsolutePE = false;
   static constexpr bool kPostNormScale = false;
 
+  // No SoftCap.
+  static constexpr float kAttCap = 0.0f;
+  static constexpr float kFinalCap = 0.0f;
+
   // SSM config.
   static constexpr int kConv1dWidth = 4;
   static constexpr bool kFFBiases = true;
@@ -484,7 +484,11 @@ HWY_NOINLINE void Attention(
           const float score = Dot(q, k2, kQKVDim);
           head_att[pos2 % kSeqLen] = score;
         }
-        Softmax(head_att, std::min(pos + 1, kSeqLen));
+        const size_t head_att_len = std::min(pos + 1, kSeqLen);
+        if constexpr (TConfig::kAttCap > 0.0f) {
+          LogitsSoftCap(TConfig::kAttCap, head_att, head_att_len);
+        }
+        Softmax(head_att, head_att_len);
 
         // Weighted summation
         float* HWY_RESTRICT att_out = activations.att_out.data() + head * kQKVDim +
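
The attention cap is applied to the raw dot-product scores before Softmax. Softcap is strictly increasing, so the ranking of scores is unchanged, but bounding them to (-kAttCap, kAttCap) limits how peaked the resulting attention distribution can become. A standalone sketch of the order of operations:

#include <algorithm>
#include <cmath>
#include <cstddef>
#include <cstdio>

void SoftCap(float cap, float* x, size_t n) {
  for (size_t i = 0; i < n; ++i) x[i] = cap * std::tanh(x[i] / cap);
}

void Softmax(float* x, size_t n) {
  const float maxv = *std::max_element(x, x + n);
  float sum = 0.0f;
  for (size_t i = 0; i < n; ++i) { x[i] = std::exp(x[i] - maxv); sum += x[i]; }
  for (size_t i = 0; i < n; ++i) x[i] /= sum;
}

int main() {
  float head_att[4] = {3.0f, 80.0f, -2.0f, 10.0f};  // raw q.k scores
  SoftCap(50.0f, head_att, 4);  // cap first, as with TConfig::kAttCap == 50
  Softmax(head_att, 4);
  for (float p : head_att) std::printf("%.4f ", p);
  std::printf("\n");
}
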
@@ -979,7 +983,10 @@ void GenerateT(const ByteStorageT& weights_u8, const ByteStorageT& prefill_u8,
       MatVec<kVocabSize, TConfig::kModelDim>(
           weights.embedder_input_embedding, 0, x, activations.even_odd.data(),
           logits, pool);
-      LogitsSoftCap(30.0f, logits, kVocabSize);
+      if constexpr (TConfig::kFinalCap > 0.0f) {
+        LogitsSoftCap(TConfig::kFinalCap, activations.logits.data(),
+                      kVocabSize);
+      }
       // Barrier: must have all logits so we can subtract max.
       Softmax(logits, kVocabSize);
       token = sample_token(logits, kVocabSize);
gemma/ops.h
@@ -1773,23 +1773,8 @@ static HWY_NOINLINE void LogitsSoftCap(const float cap, float* HWY_RESTRICT x,
   const V vcap = hn::Set(d, cap);
   const V vinv_cap = hn::Div(hn::Set(d, 1.0f), vcap);
 
-  // If we do not subtract the max as in softmax, values > 100 (which do occur)
-  // will all saturate to the cap, and then this function is no longer
-  // monotonic, which would change the results on TopK.
-  const V vmin = hn::Set(d, hwy::LowestValue<float>());
-  V vmax = vmin;
-  Foreach(d, x, max_pos, vmin,
-          [&vmax](const auto d, const auto value)
-              HWY_ATTR { vmax = hn::Max(vmax, value); });
-  vmax = hn::MaxOfLanes(d, vmax);
-
-  // We want (v-vmax) * vinv_cap. To take advantage of FMA, multiply this out to
-  // v * vinv_cap + (-vmax*vinv_cap).
-  const V add = hn::Neg(hn::Mul(vmax, vinv_cap));
-
-  hn::Transform(
-      d, x, size, [&vcap, &vinv_cap, &add](D d, hn::Vec<D> v) HWY_ATTR {
-        return hn::Mul(vcap, hn::Tanh(d, hn::MulAdd(v, vinv_cap, add)));
-      });
+  hn::Transform(d, x, max_pos, [&vcap, &vinv_cap](D d, hn::Vec<D> v) HWY_ATTR {
+    return hn::Mul(vcap, hn::Tanh(d, hn::Mul(v, vinv_cap)));
+  });
 }
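
The rewritten LogitsSoftCap drops the max reduction, the FMA rearrangement, and the old comment's saturation workaround, leaving a single Transform that computes cap * tanh(x / cap) and thus agrees with the scalar Softcap above; per the commit title, skipping the max-subtract is the point of the change. A scalar mirror of the new vector loop (standalone sketch, not Highway code):

#include <cmath>
#include <cstddef>

// Scalar equivalent of the new hn::Transform body: v -> vcap * tanh(v * vinv_cap).
void LogitsSoftCapScalar(float cap, float* x, size_t n) {
  const float inv_cap = 1.0f / cap;
  for (size_t i = 0; i < n; ++i) {
    x[i] = cap * std::tanh(x[i] * inv_cap);
  }
}
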