gemma.cpp/util/allocator.h


// Copyright 2024 Google LLC
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef THIRD_PARTY_GEMMA_CPP_UTIL_ALLOCATOR_H_
#define THIRD_PARTY_GEMMA_CPP_UTIL_ALLOCATOR_H_
#include <stddef.h>
#include <stdint.h>
// IWYU pragma: begin_exports
#include <memory>
#include "util/basics.h"
#include "util/threading.h"
#include "hwy/base.h"
#include "hwy/contrib/thread_pool/thread_pool.h"
// IWYU pragma: end_exports
#include "hwy/aligned_allocator.h"
namespace gcpp {
// Points to an adapter lambda that calls `FreeAlignedBytes` or `munmap`. The
// `bytes` argument is required for the latter.
using FreeFunc = void (*)(void* mem, size_t bytes);
// Custom deleter for std::unique_ptr that calls `FreeFunc`.
class Deleter {
public:
// `MatStorageT` requires this to be default-constructible.
Deleter() : free_func_(nullptr), bytes_(0) {}
Deleter(FreeFunc free_func, size_t bytes)
: free_func_(free_func), bytes_(bytes) {}
template <typename T>
void operator()(T* p) const {
free_func_(p, bytes_);
}
private:
FreeFunc free_func_;
size_t bytes_;
};
// Unique (move-only) pointer to an aligned array of POD T.
template <typename T>
using AlignedPtr = std::unique_ptr<T[], Deleter>;
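// Usage sketch for the above (illustrative, not part of this API): a
// `FreeFunc` adapter around `munmap` and the matching `AlignedPtr`, assuming
// `mem` and `bytes` came from an `mmap` call and <sys/mman.h> is included:
//   FreeFunc unmap = [](void* mem, size_t bytes) { munmap(mem, bytes); };
//   AlignedPtr<float> p(static_cast<float*>(mem), Deleter(unmap, bytes));
//   // `unmap(mem, bytes)` runs automatically when `p` is destroyed.
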
// Allocation, binding, and row accessors all depend on the sizes of memory
// pages and cache lines. To avoid having to pass `Allocator&` everywhere, we
// use monostate (static members).
class Allocator {
public:
// Must be called at least once before any other function. Not thread-safe,
// hence only call this from the main thread.
static void Init(const BoundedTopology& topology);
// Bytes per cache line, or a reasonable guess if unknown. Used to choose
// ranges such that there will be no false sharing.
static size_t LineBytes();
// Bytes per full vector. Used to compute loop steps.
static size_t VectorBytes();
// Granularity of regions processed by different threads. The start and
// length of each region should be divisible by this, which is at least
// `HWY_MAX(LineBytes(), VectorBytes())`.
static size_t QuantumBytes();
static size_t L1Bytes();
static size_t L2Bytes();
// Returns pointer aligned to `QuantumBytes()`.
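// Worked example of the overflow check below (illustrative): for
// sizeof(T) == 4 (a power of two, so kBits == 2) and num == (1ull << 62) + 1,
// `num << kBits` wraps around to 4, so shifting back yields a value != num
// and Alloc returns an empty AlignedPtr instead of a too-small buffer.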
template <typename T>
static AlignedPtr<T> Alloc(size_t num) {
constexpr size_t kSize = sizeof(T);
constexpr bool kIsPow2 = (kSize & (kSize - 1)) == 0;
constexpr size_t kBits = hwy::detail::ShiftCount(kSize);
static_assert(!kIsPow2 || (1ull << kBits) == kSize, "ShiftCount has a bug");
const size_t bytes = kIsPow2 ? num << kBits : num * kSize;
// Fail if the `bytes = num * kSize` computation overflowed.
const size_t check = kIsPow2 ? bytes >> kBits : bytes / kSize;
if (check != num) return AlignedPtr<T>();
PtrAndDeleter pd = AllocBytes(bytes);
return AlignedPtr<T>(static_cast<T*>(pd.p), pd.deleter);
}
// Returns whether `BindMemory` can/should be called, i.e. whether we have
// page-level control over memory placement and there are multiple packages
// and NUMA nodes.
static bool ShouldBind();
// Attempts to move(!) `[p, p + bytes)` to the given NUMA node, which is
// typically `BoundedTopology::GetCluster(package_idx, cluster_idx).node`.
// Writes zeros to SOME of the memory. Only call if `ShouldBind()`.
// `p` and `bytes` must be multiples of `QuantumBytes()`.
static bool BindMemory(void* p, size_t bytes, size_t node);
private:
// Type-erased so this can be implemented in allocator.cc.
struct PtrAndDeleter {
void* p;
Deleter deleter;
};
static PtrAndDeleter AllocBytes(size_t bytes);
};
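
// Usage sketch (illustrative; `topology`, `num`, `bytes` and `node` are
// placeholders, with `bytes` rounded up to a multiple of `QuantumBytes()`):
//   Allocator::Init(topology);  // once, from the main thread
//   AlignedPtr<float> w = Allocator::Alloc<float>(num);
//   if (Allocator::ShouldBind()) {
//     Allocator::BindMemory(w.get(), bytes, node);
//   }
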
// Owns dynamically-allocated aligned memory for a batch of row vectors.
// This can be seen as a (batch_size x cols) matrix. Unlike `RowPtr`, this owns
// the memory.
template <typename T>
class RowVectorBatch {
public:
// Default ctor for Activations ctor.
RowVectorBatch() = default;
// Main ctor, called from `Activations::Allocate`. If `stride` is zero (the
// default), rows are tightly packed (`stride = cols`).
// WARNING: not all call sites support `stride` != cols.
RowVectorBatch(Extents2D extents, size_t stride = 0) : extents_(extents) {
if (stride == 0) {
stride_ = extents_.cols;
} else {
HWY_ASSERT(stride >= extents_.cols);
stride_ = stride;
}
mem_ = Allocator::Alloc<T>(extents_.rows * stride_);
}
// Move-only
RowVectorBatch(const RowVectorBatch&) = delete;
RowVectorBatch& operator=(const RowVectorBatch&) = delete;
RowVectorBatch(RowVectorBatch&&) noexcept = default;
RowVectorBatch& operator=(RowVectorBatch&&) noexcept = default;
size_t BatchSize() const { return extents_.rows; }
size_t Cols() const { return extents_.cols; }
size_t Stride() const { return stride_; }
Extents2D Extents() const { return extents_; }
// Returns the given row vector of length `Cols()`.
T* Batch(size_t batch_idx) {
HWY_DASSERT(batch_idx < BatchSize());
return mem_.get() + batch_idx * stride_;
}
const T* Batch(size_t batch_idx) const {
HWY_DASSERT(batch_idx < BatchSize());
return mem_.get() + batch_idx * stride_;
}
// For MatMul or other operations that process the entire batch at once.
// TODO: remove once we only use Mat.
T* All() { return mem_.get(); }
const T* Const() const { return mem_.get(); }
size_t NumBytes() const { return BatchSize() * stride_ * sizeof(T); }
private:
AlignedPtr<T> mem_;
Extents2D extents_;
size_t stride_;
};
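
// Usage sketch (illustrative; assumes `Extents2D` can be constructed from
// (rows, cols), see util/basics.h):
//   RowVectorBatch<float> x(Extents2D(4, 2048));
//   for (size_t i = 0; i < x.BatchSize(); ++i) {
//     float* row = x.Batch(i);  // length x.Cols(); rows are x.Stride() apart
//   }
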
// Returns `num` rounded up to an odd number of cache lines. This is used to
// compute strides. An odd number of cache lines prevents 2K aliasing and is
// coprime with the cache associativity, which reduces conflict misses.
template <typename T>
static HWY_INLINE size_t RoundUpToOddLines(size_t num, size_t line_bytes) {
HWY_DASSERT(line_bytes >= 32);
HWY_DASSERT(line_bytes % sizeof(T) == 0);
const size_t lines = hwy::DivCeil(num * sizeof(T), line_bytes);
const size_t padded_num = (lines | 1) * line_bytes / sizeof(T);
HWY_DASSERT(padded_num >= num);
return padded_num;
}
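
// Worked example (illustrative): for T = float and line_bytes = 64, num = 256
// occupies 1024 bytes = 16 lines; 16 | 1 rounds up to 17 lines, so the
// returned element count is 17 * 64 / 4 = 272.
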
// Lightweight version of `MatPtr` used for the C argument of `MatMul`, because
// it is always float and does not support compressed T, but does support an
// arbitrary stride >= cols.
#pragma pack(push, 1) // power of two size
template <typename T>
class RowPtr {
public:
RowPtr() = default; // for `MMPtrs`.
RowPtr(T* HWY_RESTRICT row0, size_t cols, size_t stride)
: row0_(row0),
stride_(stride),
step_(static_cast<uint32_t>(
HWY_MAX(Allocator::LineBytes(), Allocator::VectorBytes()))),
cols_(static_cast<uint32_t>(cols)),
row_mask_(Allocator::QuantumBytes() / step_ - 1) {
HWY_DASSERT(stride >= cols);
HWY_DASSERT(row_mask_ != ~size_t{0});
row_mask_ = 0; // TODO: remove
}
RowPtr(T* HWY_RESTRICT row0, size_t cols) : RowPtr(row0, cols, cols) {}
T* HWY_RESTRICT Row(size_t r) const {
// How much of the previous row's padding to consume.
const size_t pad_bytes = (r & row_mask_) * step_;
HWY_DASSERT(pad_bytes < Allocator::QuantumBytes());
return row0_ + stride_ * r - pad_bytes;
}
size_t Cols() const { return cols_; }
size_t Stride() const { return stride_; }
void SetStride(size_t stride) {
HWY_DASSERT(stride >= Cols());
stride_ = stride;
// The caller might not have padded enough, so disable the padding in Row().
// Rows will now be exactly `stride` elements apart. This is used when
// writing to the KV cache via MatMul.
row_mask_ = 0;
}
// Returns a 2D subrange whose top-left is `(r, c)` and whose width is `cols`.
RowPtr<T> View(size_t r, size_t c, size_t cols) const {
HWY_DASSERT(c < cols_);
HWY_DASSERT(cols <= cols_ - c);
return RowPtr<T>(Row(r) + c, cols, stride_);
}
private:
T* HWY_RESTRICT row0_;
size_t stride_;
uint32_t step_; // Copy of HWY_MAX(LineBytes(), VectorBytes()), cached for locality.
uint32_t cols_;
size_t row_mask_;
};
#pragma pack(pop)
using RowPtrBF = RowPtr<BF16>;
using RowPtrF = RowPtr<float>;
using RowPtrD = RowPtr<double>;
// For the C argument of `MatMul`.
template <typename T>
RowPtr<T> RowPtrFromBatch(RowVectorBatch<T>& row_vectors) {
return RowPtr<T>(row_vectors.All(), row_vectors.Cols(), row_vectors.Stride());
}
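
// Usage sketch (illustrative): viewing a `RowVectorBatch` as the C argument
// of `MatMul`:
//   RowVectorBatch<float> c(Extents2D(batch_size, cols));
//   RowPtrF c_rows = RowPtrFromBatch(c);
//   float* row1 = c_rows.Row(1);  // same address as c.Batch(1) because the
//                                 // RowPtr ctor currently disables padding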
} // namespace gcpp
#endif // THIRD_PARTY_GEMMA_CPP_UTIL_ALLOCATOR_H_