mirror of https://github.com/google/gemma.cpp.git
Address review comments
This commit is contained in:
parent
a3a75b77f9
commit
6ca4a8e345
|
|
@ -33,8 +33,9 @@ class WeightInitializer {
|
|||
|
||||
template <size_t N>
|
||||
void operator()(const char* name, CompressedArray<float, N>& tensor) {
|
||||
float* data = tensor.data();
|
||||
for (size_t i = 0; i < N; ++i) {
|
||||
tensor[i] = dist_(gen_);
|
||||
data[i] = dist_(gen_);
|
||||
}
|
||||
tensor.set_scale(1.0f);
|
||||
}
|
||||
|
|
@ -70,14 +71,18 @@ class AdamUpdater {
|
|||
CompressedArray<float, kCapacity>& weights,
|
||||
CompressedArray<float, kCapacity>& grad_m,
|
||||
CompressedArray<float, kCapacity>& grad_v) {
|
||||
const float* HWY_RESTRICT g = grad.data();
|
||||
float* HWY_RESTRICT w = weights.data();
|
||||
float* HWY_RESTRICT m = grad_m.data();
|
||||
float* HWY_RESTRICT v = grad_v.data();
|
||||
for (size_t i = 0; i < kCapacity; ++i) {
|
||||
grad_m[i] *= beta1_;
|
||||
grad_m[i] += cbeta1_ * grad[i];
|
||||
grad_v[i] *= beta2_;
|
||||
grad_v[i] += cbeta2_ * grad[i] * grad[i];
|
||||
const float mhat = grad_m[i] * norm1_;
|
||||
const float vhat = grad_v[i] * norm2_;
|
||||
weights[i] -= alpha_ * mhat / (std::sqrt(vhat) + epsilon_);
|
||||
m[i] *= beta1_;
|
||||
m[i] += cbeta1_ * g[i];
|
||||
v[i] *= beta2_;
|
||||
v[i] += cbeta2_ * g[i] * g[i];
|
||||
const float mhat = m[i] * norm1_;
|
||||
const float vhat = v[i] * norm2_;
|
||||
w[i] -= alpha_ * mhat / (std::sqrt(vhat) + epsilon_);
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -79,9 +79,6 @@ class CompressedArray {
|
|||
MatT* data() { return data_.data(); }
|
||||
const MatT* data() const { return data_.data(); }
|
||||
|
||||
MatT& operator[](size_t pos) { return data_[pos]; }
|
||||
const MatT& operator[](size_t pos) const { return data_[pos]; }
|
||||
|
||||
float scale() const { return scale_[0]; }
|
||||
void set_scale(float scale) { scale_[0] = scale; }
|
||||
|
||||
|
|
|
|||
Loading…
Reference in New Issue