// Copyright 2024 Google LLC
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#ifndef THIRD_PARTY_GEMMA_CPP_UTIL_ALLOCATOR_H_
#define THIRD_PARTY_GEMMA_CPP_UTIL_ALLOCATOR_H_

#include <stddef.h>
#include <stdint.h>

// IWYU pragma: begin_exports
#include <memory>

#include "util/basics.h"
#include "util/threading.h"
#include "hwy/base.h"
#include "hwy/contrib/thread_pool/thread_pool.h"
// IWYU pragma: end_exports

#include "hwy/aligned_allocator.h"

namespace gcpp {

// Points to an adapter lambda that calls `FreeAlignedBytes` or `munmap`. The
// `bytes` argument is required for the latter.
using FreeFunc = void (*)(void* mem, size_t bytes);

// Custom deleter for std::unique_ptr that calls `FreeFunc`.
class Deleter {
 public:
  // `MatStorageT` requires this to be default-constructible.
  Deleter() : free_func_(nullptr), bytes_(0) {}
  Deleter(FreeFunc free_func, size_t bytes)
      : free_func_(free_func), bytes_(bytes) {}

  template <typename T>
  void operator()(T* p) const {
    free_func_(p, bytes_);
  }

 private:
  FreeFunc free_func_;
  size_t bytes_;
};

// Unique (move-only) pointer to an aligned array of POD T.
template <typename T>
using AlignedPtr = std::unique_ptr<T[], Deleter>;

// Allocation, binding, and row accessors all depend on the sizes of memory
// pages and cache lines. To avoid having to pass `Allocator&` everywhere, we
// use `Monostate` (static members).
class Allocator {
 public:
  // Must be called at least once before any other function. Not thread-safe,
  // hence only call this from the main thread.
  static void Init(const BoundedTopology& topology);

  // Bytes per cache line, or a reasonable guess if unknown. Used to choose
  // ranges such that there will be no false sharing.
  static size_t LineBytes();
  // Bytes per full vector. Used to compute loop steps.
  static size_t VectorBytes();
  // Granularity of regions processed by different threads. The start and
  // length of such regions should be divisible by this, which is at least
  // `HWY_MAX(LineBytes(), VectorBytes())`.
  static size_t QuantumBytes();
  static size_t L1Bytes();
  static size_t L2Bytes();

  // Returns pointer aligned to `QuantumBytes()`.
  template <typename T>
  static AlignedPtr<T> Alloc(size_t num) {
    constexpr size_t kSize = sizeof(T);
    constexpr bool kIsPow2 = (kSize & (kSize - 1)) == 0;
    constexpr size_t kBits = hwy::detail::ShiftCount(kSize);
    static_assert(!kIsPow2 || (1ull << kBits) == kSize, "ShiftCount has a bug");
    const size_t bytes = kIsPow2 ? num << kBits : num * kSize;
    // Fail if the `bytes = num * kSize` computation overflowed.
    const size_t check = kIsPow2 ? bytes >> kBits : bytes / kSize;
    if (check != num) return AlignedPtr<T>();
    PtrAndDeleter pd = AllocBytes(bytes);
    return AlignedPtr<T>(static_cast<T*>(pd.p), pd.deleter);
  }
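
  // Usage sketch (illustrative only; assumes `Init` has already been called
  // and uses an arbitrary element count):
  //   AlignedPtr<float> buf = Allocator::Alloc<float>(4096);
  //   if (!buf) return;  // allocation failed or the size computation overflowed
  //   buf[0] = 1.0f;     // start of `buf` is aligned to QuantumBytes()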

  // Returns whether `BindMemory` can/should be called, i.e. we have page-level
  // control over memory placement and multiple packages and NUMA nodes.
  static bool ShouldBind();

  // Attempts to move(!) `[p, p + bytes)` to the given NUMA node, which is
  // typically `BoundedTopology::GetCluster(package_idx, cluster_idx).node`.
  // Writes zeros to SOME of the memory. Only call if `ShouldBind()`.
  // `p` and `bytes` must be multiples of `QuantumBytes()`.
  static bool BindMemory(void* p, size_t bytes, size_t node);

 private:
  // Type-erased so this can be implemented in allocator.cc.
  struct PtrAndDeleter {
    void* p;
    Deleter deleter;
  };
  static PtrAndDeleter AllocBytes(size_t bytes);
};

// Owns dynamically-allocated aligned memory for a batch of row vectors.
// This can be seen as a (batch_size x cols) matrix. Unlike `RowPtr`, this owns
// the memory.
template <typename T>
class RowVectorBatch {
 public:
  // Default ctor for the `Activations` ctor.
  RowVectorBatch() = default;
  // Main ctor, called from `Activations::Allocate`. If `stride` is 0 (the
  // default), rows are tightly packed (`stride = cols`).
  // WARNING: not all call sites support `stride` != cols.
  RowVectorBatch(Extents2D extents, size_t stride = 0) : extents_(extents) {
    if (stride == 0) {
      stride_ = extents_.cols;
    } else {
      HWY_ASSERT(stride >= extents_.cols);
      stride_ = stride;
    }
    mem_ = Allocator::Alloc<T>(extents_.rows * stride_);
  }

  // Move-only
  RowVectorBatch(RowVectorBatch&) noexcept = delete;
  RowVectorBatch& operator=(RowVectorBatch&) noexcept = delete;
  RowVectorBatch(RowVectorBatch&&) noexcept = default;
  RowVectorBatch& operator=(RowVectorBatch&&) noexcept = default;

  size_t BatchSize() const { return extents_.rows; }
  size_t Cols() const { return extents_.cols; }
  size_t Stride() const { return stride_; }
  Extents2D Extents() const { return extents_; }

  // Returns the given row vector of length `Cols()`.
  T* Batch(size_t batch_idx) {
    HWY_DASSERT(batch_idx < BatchSize());
    return mem_.get() + batch_idx * stride_;
  }
  const T* Batch(size_t batch_idx) const {
    HWY_DASSERT(batch_idx < BatchSize());
    return mem_.get() + batch_idx * stride_;
  }

  // For MatMul or other operations that process the entire batch at once.
  // TODO: remove once we only use Mat.
  T* All() { return mem_.get(); }
  const T* Const() const { return mem_.get(); }
  size_t NumBytes() const { return BatchSize() * stride_ * sizeof(T); }

 private:
  AlignedPtr<T> mem_;
  Extents2D extents_;
  size_t stride_;
};

// Returns `num` rounded up to an odd number of cache lines. This is used to
// compute strides. An odd number of cache lines prevents 2K aliasing and is
// coprime with the cache associativity, which reduces conflict misses.
template <typename T>
static HWY_INLINE size_t RoundUpToOddLines(size_t num, size_t line_bytes) {
  HWY_DASSERT(line_bytes >= 32);
  HWY_DASSERT(line_bytes % sizeof(T) == 0);
  const size_t lines = hwy::DivCeil(num * sizeof(T), line_bytes);
  const size_t padded_num = (lines | 1) * line_bytes / sizeof(T);
  HWY_DASSERT(padded_num >= num);
  return padded_num;
}
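
// Worked example (illustrative values, assuming 64-byte cache lines):
//   RoundUpToOddLines<float>(/*num=*/320, /*line_bytes=*/64) returns 336,
// because 320 floats occupy 1280 bytes = 20 lines, 20 | 1 = 21 lines, and
// 21 lines hold 21 * 64 / 4 = 336 floats.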

// Lightweight version of `MatPtr` used for the C argument of `MatMul`, because
// it is always float and does not support compressed T, but does support an
// arbitrary stride >= cols.
#pragma pack(push, 1)  // power of two size
template <typename T>
class RowPtr {
 public:
  RowPtr() = default;  // for `MMPtrs`.
  RowPtr(T* HWY_RESTRICT row0, size_t cols, size_t stride)
      : row0_(row0),
        stride_(stride),
        step_(static_cast<uint32_t>(
            HWY_MAX(Allocator::LineBytes(), Allocator::VectorBytes()))),
        cols_(static_cast<uint32_t>(cols)),
        row_mask_(Allocator::QuantumBytes() / step_ - 1) {
    HWY_DASSERT(stride >= cols);
    HWY_DASSERT(row_mask_ != ~size_t{0});
    row_mask_ = 0;  // TODO: remove
  }
  RowPtr(T* HWY_RESTRICT row0, size_t cols) : RowPtr(row0, cols, cols) {}

  T* HWY_RESTRICT Row(size_t r) const {
    // How much of the previous row's padding to consume.
    const size_t pad_bytes = (r & row_mask_) * step_;
    HWY_DASSERT(pad_bytes < Allocator::QuantumBytes());
    return row0_ + stride_ * r - pad_bytes;
  }
  size_t Cols() const { return cols_; }

  size_t Stride() const { return stride_; }
  void SetStride(size_t stride) {
    HWY_DASSERT(stride >= Cols());
    stride_ = stride;
    // The caller might not have padded enough, so disable the padding in
    // Row(). Rows will now be exactly `stride` elements apart. This is used
    // when writing to the KV cache via MatMul.
    row_mask_ = 0;
  }

  // Returns 2D subrange whose top-left is `r, c` and width is `cols`.
  RowPtr View(size_t r, size_t c, size_t cols) const {
    HWY_DASSERT(c < cols_);
    HWY_DASSERT(cols <= cols_ - c);
    return RowPtr(Row(r) + c, cols, stride_);
  }

 private:
  T* HWY_RESTRICT row0_;
  size_t stride_;
  uint32_t step_;  // Copy from Allocator::LineBytes() to improve locality.
  uint32_t cols_;
  size_t row_mask_;
};
#pragma pack(pop)

using RowPtrBF = RowPtr<BF16>;
using RowPtrF = RowPtr<float>;
using RowPtrD = RowPtr<double>;

// For C argument to MatMul.
template <typename T>
RowPtr<T> RowPtrFromBatch(RowVectorBatch<T>& row_vectors) {
  return RowPtr<T>(row_vectors.All(), row_vectors.Cols(), row_vectors.Stride());
}

}  // namespace gcpp

#endif  // THIRD_PARTY_GEMMA_CPP_UTIL_ALLOCATOR_H_