Memory use reduction: smaller/single MMStorage

PiperOrigin-RevId: 804865029
This commit is contained in:
Jan Wassenberg 2025-09-09 05:32:20 -07:00 committed by Copybara-Service
parent 06e5da1e22
commit a5ab99e4ba
7 changed files with 13 additions and 15 deletions

View File

@@ -179,7 +179,7 @@ struct Activations {
   MatStorageT<float> x;           // input
   MatStorageT<BF16> x_bf;         // output of final RMSNorm, input to EmbeddingMatmul
-  MatStorageT<float> logits;
+  MatStorageT<float> logits;      // TODO: BF16 after Softmax supports that.
   MatStorageT<uint32_t> sampled;  // batch_size x 3 (padded)
   // Gated FFW

View File

@@ -127,7 +127,7 @@ class QBatch {
         max_size_(max_size),
         queries_(queries),
         size_(HWY_MIN(max_size_, queries_.NumQueries() - start_)) {
-    HWY_ASSERT(max_size_ <= 4096);  // non_eos uses `BitSet4096`.
+    HWY_ASSERT(max_size_ <= kMaxBatchSize);
    HWY_DASSERT(size_ != 0);
    HWY_DASSERT(start_ + size_ <= queries_.NumQueries());
  }

View File

@@ -797,9 +797,10 @@ class MMImpl {
      return View(A, 0, 0, A.Cols());
    } else {
      // Always decompress. To reduce code size/compile time, we no longer
-      // support a separate F32 kernel; most A are already BF16.
-      const StridedViewBF A_view =
-          args.env->storage[args.options.cluster_idx].A(A.Extents());
+      // support a separate F32 kernel; most A are already BF16. We also only
+      // have a single MMStorage.
+      HWY_ASSERT(args.options.cluster_idx == 0);
+      const StridedViewBF A_view = args.env->storage.A(A.Extents());
      DecompressA(A, A_view, args);
      return A_view;
    }

View File

@ -346,14 +346,10 @@ std::vector<MMConfig> MMCandidates(const CacheInfo& cache, size_t M, size_t K,
return GenerateCandidates(cache, M, K, N, sizeof_TC, print_config)(); return GenerateCandidates(cache, M, K, N, sizeof_TC, print_config)();
} }
MatMulEnv::MatMulEnv(ThreadingContext& ctx) : ctx(ctx) { MatMulEnv::MatMulEnv(ThreadingContext& ctx) : ctx(ctx), storage(ctx.allocator) {
// Create storage per cluster. This only applies to in-cluster parallelism.
// For nested and sequential parallelism, a single MMStorage is used.
const size_t num_clusters = ctx.pools.AllClusters(/*pkg_idx=*/0).NumWorkers(); const size_t num_clusters = ctx.pools.AllClusters(/*pkg_idx=*/0).NumWorkers();
per_cluster.resize(num_clusters); per_cluster.resize(num_clusters);
storage.reserve(num_clusters);
for (size_t cluster_idx = 0; cluster_idx < num_clusters; ++cluster_idx) { for (size_t cluster_idx = 0; cluster_idx < num_clusters; ++cluster_idx) {
storage.push_back(MMStorage(ctx.allocator));
row_ptrs.push_back(hwy::AllocateAligned<uint8_t*>(kMaxBatchSize)); // C row_ptrs.push_back(hwy::AllocateAligned<uint8_t*>(kMaxBatchSize)); // C
} }

View File

@@ -306,11 +306,11 @@ using StridedViewD = StridedView<double>;
class MMStorage {
 public:
  // Compile-time bounds on matrix columns to enable pre-allocating storage
-  // and reusing it across `MatMul` calls.
-  static constexpr size_t kMaxK = 64 * 1024;
+  // and reusing it across `MatMul` calls. Sufficient for Gemma 2 27B.
+  static constexpr size_t kMaxK = 36 * 1024;

  MMStorage(const Allocator& allocator)
-      // 0.5 GiB. Must be padded, see `DoDecompressA`.
+      // 288 MiB. Must be padded, see `DoDecompressA`.
      : A_("A_bf", Extents2D(kMaxBatchSize, kMaxK), allocator,
           MatPadding::kOdd) {}

@@ -673,7 +673,7 @@ struct MatMulEnv {
  // Whether to print the best config immediately after autotuning finished.
  bool print_best = false;

-  std::vector<MMStorage> storage;
+  MMStorage storage;

  struct PerCluster {
    MMKeys keys;

View File

@@ -614,6 +614,7 @@ HWY_NOINLINE HWY_MAYBE_UNUSED void MulByConstAndAdd(
}

// See below for a specialized version for top-1 sampling.
+// TODO: support bf16 logits using Decompress2.
static HWY_NOINLINE void Softmax(Logits logits, hwy::Profiler& p,
                                 const size_t worker,
                                 float temperature = 1.0f) {

View File

@@ -30,7 +30,7 @@
namespace gcpp {

-// TODO: extend to 16k after updating non_eos.
+// For hwy::BitSet4096. Note that KVs are extremely large for such batches.
HWY_INLINE_VAR constexpr size_t kMaxBatchSize = 4096;

enum class Tristate : int32_t { kFalse = 0, kTrue = 1, kDefault = -1 };