Memory use reduction: smaller/single MMStorage

PiperOrigin-RevId: 804865029
Authored by Jan Wassenberg on 2025-09-09 05:32:20 -07:00, committed by Copybara-Service.
parent 06e5da1e22
commit a5ab99e4ba
7 changed files with 13 additions and 15 deletions


@@ -179,7 +179,7 @@ struct Activations {
   MatStorageT<float> x;    // input
   MatStorageT<BF16> x_bf;  // output of final RMSNorm, input to EmbeddingMatmul
-  MatStorageT<float> logits;
+  MatStorageT<float> logits;  // TODO: BF16 after Softmax supports that.
   MatStorageT<uint32_t> sampled;  // batch_size x 3 (padded)

   // Gated FFW


@@ -127,7 +127,7 @@ class QBatch {
         max_size_(max_size),
         queries_(queries),
         size_(HWY_MIN(max_size_, queries_.NumQueries() - start_)) {
-    HWY_ASSERT(max_size_ <= 4096);  // non_eos uses `BitSet4096`.
+    HWY_ASSERT(max_size_ <= kMaxBatchSize);
     HWY_DASSERT(size_ != 0);
     HWY_DASSERT(start_ + size_ <= queries_.NumQueries());
   }
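For context: the replaced literal comes from `non_eos` being a hwy::BitSet4096, whose fixed capacity caps the number of in-flight queries. A minimal sketch of that pattern, assuming the Set/Any API of hwy/bit_set.h; the helper and its token handling are illustrative, not the repo's actual code:

  #include <cstddef>

  #include "hwy/bit_set.h"  // hwy::BitSet4096

  // Illustrative helper: true once every query in the batch has emitted EOS.
  // `tokens` holds the latest sampled token per query.
  static bool AllQueriesDone(const int* tokens, size_t batch_size, int eos_id) {
    hwy::BitSet4096<4096> non_eos;  // fixed capacity, hence the batch cap
    for (size_t qi = 0; qi < batch_size; ++qi) {
      if (tokens[qi] != eos_id) non_eos.Set(qi);  // still generating
    }
    return !non_eos.Any();
  }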


@@ -797,9 +797,10 @@ class MMImpl {
       return View(A, 0, 0, A.Cols());
     } else {
       // Always decompress. To reduce code size/compile time, we no longer
-      // support a separate F32 kernel; most A are already BF16.
-      const StridedViewBF A_view =
-          args.env->storage[args.options.cluster_idx].A(A.Extents());
+      // support a separate F32 kernel; most A are already BF16. We also only
+      // have a single MMStorage.
+      HWY_ASSERT(args.options.cluster_idx == 0);
+      const StridedViewBF A_view = args.env->storage.A(A.Extents());
       DecompressA(A, A_view, args);
       return A_view;
     }
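With a single MMStorage, all MatMul calls share one BF16 scratch, so this path always narrows f32 inputs before the BF16 kernel runs. A rough scalar sketch of that conversion, assuming hwy::BF16FromF32 from hwy/base.h; the helper name is hypothetical, and the real DecompressA is vectorized and handles row padding:

  #include <cstddef>

  #include "hwy/base.h"  // hwy::bfloat16_t, hwy::BF16FromF32

  // Illustrative scalar version: narrow one f32 row of A into the shared
  // BF16 scratch so a single BF16 kernel serves both input types.
  static void DecompressRowToBF16(const float* HWY_RESTRICT in,
                                  hwy::bfloat16_t* HWY_RESTRICT out,
                                  size_t cols) {
    for (size_t c = 0; c < cols; ++c) out[c] = hwy::BF16FromF32(in[c]);
  }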


@@ -346,14 +346,10 @@ std::vector<MMConfig> MMCandidates(const CacheInfo& cache, size_t M, size_t K,
   return GenerateCandidates(cache, M, K, N, sizeof_TC, print_config)();
 }

-MatMulEnv::MatMulEnv(ThreadingContext& ctx) : ctx(ctx) {
-  // Create storage per cluster. This only applies to in-cluster parallelism.
-  // For nested and sequential parallelism, a single MMStorage is used.
+MatMulEnv::MatMulEnv(ThreadingContext& ctx) : ctx(ctx), storage(ctx.allocator) {
   const size_t num_clusters = ctx.pools.AllClusters(/*pkg_idx=*/0).NumWorkers();
   per_cluster.resize(num_clusters);
-  storage.reserve(num_clusters);
   for (size_t cluster_idx = 0; cluster_idx < num_clusters; ++cluster_idx) {
-    storage.push_back(MMStorage(ctx.allocator));
     row_ptrs.push_back(hwy::AllocateAligned<uint8_t*>(kMaxBatchSize));  // C
   }


@@ -306,11 +306,11 @@ using StridedViewD = StridedView<double>;
 class MMStorage {
  public:
   // Compile-time bounds on matrix columns to enable pre-allocating storage
-  // and reusing it across `MatMul` calls.
-  static constexpr size_t kMaxK = 64 * 1024;
+  // and reusing it across `MatMul` calls. Sufficient for Gemma 2 27B.
+  static constexpr size_t kMaxK = 36 * 1024;

   MMStorage(const Allocator& allocator)
-      // 0.5 GiB. Must be padded, see `DoDecompressA`.
+      // 288 MiB. Must be padded, see `DoDecompressA`.
       : A_("A_bf", Extents2D(kMaxBatchSize, kMaxK), allocator,
            MatPadding::kOdd) {}
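The size comments follow from the scratch holding kMaxBatchSize x kMaxK BF16 elements: 4096 x 64Ki x 2 B = 512 MiB before, 4096 x 36Ki x 2 B = 288 MiB after, padding aside. A compile-time check of that arithmetic, assuming kMaxBatchSize = 4096 and 2-byte BF16:

  #include <cstddef>

  // Standalone check; the constants mirror the diff, padding is ignored.
  constexpr size_t kMaxBatchSize = 4096;
  constexpr size_t kBF16Bytes = 2;
  static_assert(kMaxBatchSize * (64 * 1024) * kBF16Bytes == (size_t{512} << 20),
                "old bound was 0.5 GiB per MMStorage");
  static_assert(kMaxBatchSize * (36 * 1024) * kBF16Bytes == (size_t{288} << 20),
                "new bound is 288 MiB for the single MMStorage");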
@@ -673,7 +673,7 @@ struct MatMulEnv {
   // Whether to print the best config immediately after autotuning finished.
   bool print_best = false;

-  std::vector<MMStorage> storage;
+  MMStorage storage;

   struct PerCluster {
     MMKeys keys;


@@ -614,6 +614,7 @@ HWY_NOINLINE HWY_MAYBE_UNUSED void MulByConstAndAdd(
 }

 // See below for a specialized version for top-1 sampling.
+// TODO: support bf16 logits using Decompress2.
 static HWY_NOINLINE void Softmax(Logits logits, hwy::Profiler& p,
                                  const size_t worker,
                                  float temperature = 1.0f) {
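For reference, Softmax converts the logits to probabilities in place. A scalar sketch of the usual max-subtracted, temperature-scaled form; the repo's version is vectorized via Highway and may differ in detail:

  #include <cmath>
  #include <cstddef>

  // Scalar reference: in-place exp-normalize with temperature. Subtracting
  // the max first keeps exp() from overflowing.
  static void SoftmaxRef(float* logits, size_t n, float temperature = 1.0f) {
    float max = logits[0];
    for (size_t i = 1; i < n; ++i) max = logits[i] > max ? logits[i] : max;
    double sum = 0.0;
    for (size_t i = 0; i < n; ++i) {
      logits[i] = std::exp((logits[i] - max) / temperature);
      sum += logits[i];
    }
    for (size_t i = 0; i < n; ++i) {
      logits[i] = static_cast<float>(logits[i] / sum);
    }
  }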


@@ -30,7 +30,7 @@
 namespace gcpp {

-// TODO: extend to 16k after updating non_eos.
+// For hwy::BitSet4096. Note that KVs are extremely large for such batches.
 HWY_INLINE_VAR constexpr size_t kMaxBatchSize = 4096;

 enum class Tristate : int32_t { kFalse = 0, kTrue = 1, kDefault = -1 };
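To make "KVs are extremely large" concrete, a back-of-envelope sketch; every model dimension below is a placeholder, not Gemma's actual configuration:

  #include <cstdint>

  // Illustrative only: KV-cache bytes for a full batch; the x2 covers K and V.
  constexpr uint64_t KVCacheBytes(uint64_t batch, uint64_t seq_len,
                                  uint64_t layers, uint64_t kv_heads,
                                  uint64_t head_dim, uint64_t bytes_per_elem) {
    return batch * seq_len * layers * kv_heads * head_dim * 2 * bytes_per_elem;
  }
  // E.g. 4096 queries x 4096 tokens x 32 layers x 8 KV heads x 256 dims in
  // BF16 comes to ~4 TiB, which is why larger batches are impractical.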