Address review comments

This commit is contained in:
Zoltan Szabadka 2024-06-10 15:27:22 +00:00
parent a3a75b77f9
commit 6ca4a8e345
2 changed files with 13 additions and 11 deletions

View File

@ -33,8 +33,9 @@ class WeightInitializer {
template <size_t N>
void operator()(const char* name, CompressedArray<float, N>& tensor) {
float* data = tensor.data();
for (size_t i = 0; i < N; ++i) {
tensor[i] = dist_(gen_);
data[i] = dist_(gen_);
}
tensor.set_scale(1.0f);
}
@ -70,14 +71,18 @@ class AdamUpdater {
CompressedArray<float, kCapacity>& weights,
CompressedArray<float, kCapacity>& grad_m,
CompressedArray<float, kCapacity>& grad_v) {
const float* HWY_RESTRICT g = grad.data();
float* HWY_RESTRICT w = weights.data();
float* HWY_RESTRICT m = grad_m.data();
float* HWY_RESTRICT v = grad_v.data();
for (size_t i = 0; i < kCapacity; ++i) {
grad_m[i] *= beta1_;
grad_m[i] += cbeta1_ * grad[i];
grad_v[i] *= beta2_;
grad_v[i] += cbeta2_ * grad[i] * grad[i];
const float mhat = grad_m[i] * norm1_;
const float vhat = grad_v[i] * norm2_;
weights[i] -= alpha_ * mhat / (std::sqrt(vhat) + epsilon_);
m[i] *= beta1_;
m[i] += cbeta1_ * g[i];
v[i] *= beta2_;
v[i] += cbeta2_ * g[i] * g[i];
const float mhat = m[i] * norm1_;
const float vhat = v[i] * norm2_;
w[i] -= alpha_ * mhat / (std::sqrt(vhat) + epsilon_);
}
}

View File

@ -79,9 +79,6 @@ class CompressedArray {
MatT* data() { return data_.data(); }
const MatT* data() const { return data_.data(); }
MatT& operator[](size_t pos) { return data_[pos]; }
const MatT& operator[](size_t pos) const { return data_[pos]; }
float scale() const { return scale_[0]; }
void set_scale(float scale) { scale_[0] = scale; }