From 6ca4a8e345e3babd981e6e172214818fa150f25a Mon Sep 17 00:00:00 2001 From: Zoltan Szabadka Date: Mon, 10 Jun 2024 15:27:22 +0000 Subject: [PATCH] Address review comments --- backprop/optimizer.cc | 21 +++++++++++++-------- compression/compress.h | 3 --- 2 files changed, 13 insertions(+), 11 deletions(-) diff --git a/backprop/optimizer.cc b/backprop/optimizer.cc index 93d3164..f004446 100644 --- a/backprop/optimizer.cc +++ b/backprop/optimizer.cc @@ -33,8 +33,9 @@ class WeightInitializer { template void operator()(const char* name, CompressedArray& tensor) { + float* data = tensor.data(); for (size_t i = 0; i < N; ++i) { - tensor[i] = dist_(gen_); + data[i] = dist_(gen_); } tensor.set_scale(1.0f); } @@ -70,14 +71,18 @@ class AdamUpdater { CompressedArray& weights, CompressedArray& grad_m, CompressedArray& grad_v) { + const float* HWY_RESTRICT g = grad.data(); + float* HWY_RESTRICT w = weights.data(); + float* HWY_RESTRICT m = grad_m.data(); + float* HWY_RESTRICT v = grad_v.data(); for (size_t i = 0; i < kCapacity; ++i) { - grad_m[i] *= beta1_; - grad_m[i] += cbeta1_ * grad[i]; - grad_v[i] *= beta2_; - grad_v[i] += cbeta2_ * grad[i] * grad[i]; - const float mhat = grad_m[i] * norm1_; - const float vhat = grad_v[i] * norm2_; - weights[i] -= alpha_ * mhat / (std::sqrt(vhat) + epsilon_); + m[i] *= beta1_; + m[i] += cbeta1_ * g[i]; + v[i] *= beta2_; + v[i] += cbeta2_ * g[i] * g[i]; + const float mhat = m[i] * norm1_; + const float vhat = v[i] * norm2_; + w[i] -= alpha_ * mhat / (std::sqrt(vhat) + epsilon_); } } diff --git a/compression/compress.h b/compression/compress.h index 344cabc..edb7fdb 100644 --- a/compression/compress.h +++ b/compression/compress.h @@ -79,9 +79,6 @@ class CompressedArray { MatT* data() { return data_.data(); } const MatT* data() const { return data_.data(); } - MatT& operator[](size_t pos) { return data_[pos]; } - const MatT& operator[](size_t pos) const { return data_[pos]; } - float scale() const { return scale_[0]; } void set_scale(float scale) { scale_[0] = scale; }