diff --git a/backprop/optimizer.cc b/backprop/optimizer.cc
index 5890190..93e3efe 100644
--- a/backprop/optimizer.cc
+++ b/backprop/optimizer.cc
@@ -27,6 +27,8 @@ namespace gcpp {
 
 namespace {
 
+using MatPtrF = MatPtrT<float>;
+
 // Split into two classes so that ForEachTensor only requires two "other"
 // arguments. This is anyway useful for locality, because `grad` only feeds
 // into `grad_m` and `grad_v` here.
@@ -40,12 +42,11 @@ class AdamUpdateMV {
         norm1_(1.0 / (1.0 - std::pow(beta1, t))),
         norm2_(1.0 / (1.0 - std::pow(beta2, t))) {}
 
-  void operator()(const MatPtr& grad, const MatPtr& grad_m,
-                  const MatPtr& grad_v) {
+  void operator()(const MatPtrF& grad, MatPtrF& grad_m, MatPtrF& grad_v) {
     for (size_t r = 0; r < grad.Rows(); ++r) {
-      const float* HWY_RESTRICT g = grad.RowT<float>(r);
-      float* HWY_RESTRICT m = grad_m.MutableRowT<float>(r);
-      float* HWY_RESTRICT v = grad_v.MutableRowT<float>(r);
+      const float* HWY_RESTRICT g = grad.Row(r);
+      float* HWY_RESTRICT m = grad_m.Row(r);
+      float* HWY_RESTRICT v = grad_v.Row(r);
       for (size_t c = 0; c < grad.Cols(); ++c) {
         m[c] *= beta1_;
         m[c] += cbeta1_ * g[c];
@@ -73,11 +74,12 @@ class AdamUpdateW {
         norm2_(1.0 / (1.0 - std::pow(beta2, t))),
         epsilon_(epsilon) {}
 
-  void operator()(MatPtr& weights, const MatPtr& grad_m, const MatPtr& grad_v) {
+  void operator()(MatPtrF& weights, const MatPtrF& grad_m,
+                  const MatPtrF& grad_v) {
     for (size_t r = 0; r < weights.Rows(); ++r) {
-      float* HWY_RESTRICT w = weights.RowT<float>(r);
-      const float* HWY_RESTRICT m = grad_m.RowT<float>(r);
-      const float* HWY_RESTRICT v = grad_v.RowT<float>(r);
+      float* HWY_RESTRICT w = weights.Row(r);
+      const float* HWY_RESTRICT m = grad_m.Row(r);
+      const float* HWY_RESTRICT v = grad_v.Row(r);
       for (size_t c = 0; c < weights.Cols(); ++c) {
         const float mhat = m[c] * norm1_;
         const float vhat = v[c] * norm2_;
@@ -100,12 +102,18 @@ void AdamUpdate(ModelWeightsPtrs* grad, float alpha, float beta1,
                 ModelWeightsPtrs* grad_v, hwy::ThreadPool& pool) {
   AdamUpdateMV update_mv(beta1, beta2, t);
   grad->ForEachTensor(grad_m, grad_v, [&update_mv](const TensorArgs& t) {
-    update_mv(t.mat, *t.other_mat1, *t.other_mat2);
+    const MatPtrF grad_f(t.mat);
+    MatPtrF grad_m_f(*t.other_mat1);
+    MatPtrF grad_v_f(*t.other_mat2);
+    update_mv(grad_f, grad_m_f, grad_v_f);
   });
 
   AdamUpdateW update_w(alpha, beta1, beta2, epsilon, t);
   weights->ForEachTensor(grad_m, grad_v, [&update_w](const TensorArgs& t) {
-    update_w(t.mat, *t.other_mat1, *t.other_mat2);
+    MatPtrF weights_f(t.mat);
+    const MatPtrF grad_m_f(*t.other_mat1);
+    const MatPtrF grad_v_f(*t.other_mat2);
+    update_w(weights_f, grad_m_f, grad_v_f);
   });
 }
 
diff --git a/gemma/activations.h b/gemma/activations.h
index 5f19dd8..489d7f4 100644
--- a/gemma/activations.h
+++ b/gemma/activations.h
@@ -33,8 +33,7 @@ struct Activations {
         layer_config(config.layer_configs[0]),
         seq_len(config.seq_len),
         cache_pos_size(config.CachePosSize()),
-        is_griffin(layer_config.type ==
-                   LayerAttentionType::kGriffinRecurrentBlock),
+        is_griffin(config.model == Model::GRIFFIN_2B),
 
         x("x", Extents2D(batch_size, config.model_dim), pad_),
         q("q",
@@ -58,21 +57,18 @@ struct Activations {
         C2("C2", Extents2D(batch_size, layer_config.ff_hidden_dim), pad_),
         ffw_out("ffw_out", Extents2D(batch_size, config.model_dim), pad_),
 
-        // No padding for Griffin because it does not always use Row().
         griffin_x("griffin_x",
                   is_griffin ? Extents2D(batch_size, config.model_dim) : none_,
-                  MatPadding::kPacked),
+                  pad_),
         griffin_y("griffin_y",
                   is_griffin ? Extents2D(batch_size, config.model_dim) : none_,
-                  MatPadding::kPacked),
+                  pad_),
         griffin_gate_x(
             "griffin_gate_x",
-            is_griffin ? Extents2D(batch_size, config.model_dim) : none_,
-            MatPadding::kPacked),
+            is_griffin ? Extents2D(batch_size, config.model_dim) : none_, pad_),
         griffin_multiplier(
             "griffin_mul",
-            is_griffin ? Extents2D(batch_size, config.model_dim) : none_,
-            MatPadding::kPacked),
+            is_griffin ? Extents2D(batch_size, config.model_dim) : none_, pad_),
 
         inv_timescale(CreateInvTimescale(
             ThreadingContext::Get().allocator, layer_config.qkv_dim,
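For reference, AdamUpdateMV and AdamUpdateW together implement the standard Adam recurrence with the bias-correction factors norm1_ = 1/(1-beta1^t) and norm2_ = 1/(1-beta2^t) shown above. A minimal standalone sketch over plain float arrays; the final weight-update line follows the textbook formulation and is an assumption, since that part of the hunk is not shown:

// Illustrative sketch only, not part of the patch: the Adam step that
// AdamUpdateMV (m, v) and AdamUpdateW (w) split between them.
#include <cmath>
#include <cstddef>

void AdamStep(float* w, float* m, float* v, const float* g, size_t n,
              float alpha, float beta1, float beta2, float epsilon, int t) {
  const float norm1 = 1.0f / (1.0f - std::pow(beta1, t));  // bias correction
  const float norm2 = 1.0f / (1.0f - std::pow(beta2, t));
  for (size_t i = 0; i < n; ++i) {
    m[i] = beta1 * m[i] + (1.0f - beta1) * g[i];         // cbeta1_ = 1 - beta1
    v[i] = beta2 * v[i] + (1.0f - beta2) * g[i] * g[i];
    const float mhat = m[i] * norm1;
    const float vhat = v[i] * norm2;
    w[i] -= alpha * mhat / (std::sqrt(vhat) + epsilon);  // assumed update rule
  }
}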
diff --git a/gemma/gemma-inl.h b/gemma/gemma-inl.h
index 255d859..21e388f 100644
--- a/gemma/gemma-inl.h
+++ b/gemma/gemma-inl.h
@@ -21,7 +21,6 @@
 #include <stddef.h>
 
 #include <algorithm>  // std::min
-#include <memory>     // std::make_unique
 #include <vector>
 
 #include "gemma/activations.h"
@@ -69,30 +68,44 @@ namespace HWY_NAMESPACE {
 // Different functions use different naming conventions for the number of
 // tokens. Functions that are query-independent, such as RMSNorm*, call the
 // count `num_interleaved`. Functions that are query-dependent, such as
-// `Attention`, use separate `num_tokens` and `num_queries`.
+// `Attention`, use separate `num_tokens` and `num_queries`. `num_tokens` is the
+// number of tokens from one query: 1 for decode, otherwise prefill_tbatch_size.
 
-// TODO: add batch query support for Griffin (QueriesPos).
 template <typename T>
-HWY_NOINLINE void GriffinRecurrent(size_t batch_start, size_t num_tokens,
-                                   size_t layer, Activations& activations,
+HWY_NOINLINE void GriffinRecurrent(const QueriesPos& queries_pos,
+                                   size_t num_tokens, size_t griffin_layer,
+                                   Activations& activations,
                                    const LayerWeightsPtrs<T>* layer_weights,
                                    const KVCaches& kv_caches) {
   PROFILER_ZONE("Gen.Griffin");
-  KVCache& kv_cache = kv_caches[0];
   hwy::ThreadPool& pool = activations.env->ctx.pools.Pool(0);
   namespace hn = hwy::HWY_NAMESPACE;
   using D = hn::ScalableTag<float>;
+  const D df;
+
   const size_t model_dim = layer_weights->layer_config.model_dim;
-  const size_t conv_1d_width = layer_weights->layer_config.conv1d_width;
+  HWY_DASSERT(model_dim % hn::Lanes(df) == 0);
+  const size_t heads = layer_weights->layer_config.heads;
+  const size_t conv_1d_width = layer_weights->layer_config.conv1d_width;
+  HWY_ASSERT_M(conv_1d_width % 2 == 0, "Conv width must be even");
+  const size_t kHeadDim = model_dim / heads;
+  const size_t kMatrixSize = kHeadDim * kHeadDim;
+
+  const size_t num_queries = queries_pos.size();
+  const hwy::Divisor div_num_q(static_cast<uint32_t>(num_queries));
+  const size_t num_interleaved = num_tokens * num_queries;
 
   // X / Y linear layers.
-  for (size_t batch_idx = 0; batch_idx < num_tokens; ++batch_idx) {
-    float* HWY_RESTRICT y = activations.griffin_y.Row(batch_idx);
-    float* HWY_RESTRICT x = activations.griffin_x.Row(batch_idx);
+  // TODO: MatMul
+  HWY_DASSERT(activations.griffin_y.Rows() == activations.griffin_x.Rows());
+  HWY_DASSERT(num_interleaved == activations.griffin_y.Rows());
+  for (size_t r = 0; r < num_interleaved; ++r) {
+    float* HWY_RESTRICT y = activations.griffin_y.Row(r);
+    float* HWY_RESTRICT x = activations.griffin_x.Row(r);
     TwoMatVecAdd(layer_weights->griffin.linear_x_w,
                  layer_weights->griffin.linear_y_w, 0, model_dim, model_dim,
-                 activations.pre_att_rms_out.Row(batch_idx),
+                 activations.pre_att_rms_out.Row(r),
                  /*add0=*/layer_weights->griffin.linear_x_biases.PackedScale1(),
                  /*add1=*/layer_weights->griffin.linear_y_biases.PackedScale1(),
                  /*out0=*/x, /*out1=*/y, pool);
@@ -100,18 +113,19 @@ HWY_NOINLINE void GriffinRecurrent(size_t batch_start, size_t num_tokens,
   }
 
   // Conv1D.
-  for (size_t batch_idx = 0; batch_idx < num_tokens; ++batch_idx) {
-    const size_t pos = batch_start + batch_idx;
-    float* HWY_RESTRICT x = activations.griffin_x.Row(batch_idx);
-    HWY_FULL(float) df;
-    HWY_DASSERT(model_dim % hn::Lanes(df) == 0);
+  for (size_t interleaved_idx = 0; interleaved_idx < num_interleaved;
+       ++interleaved_idx) {
+    const size_t query_idx = div_num_q.Remainder(interleaved_idx);
+    const size_t batch_idx = div_num_q.Divide(interleaved_idx);
+    const size_t pos = queries_pos[query_idx] + batch_idx;
+    float* HWY_RESTRICT x = activations.griffin_x.Row(query_idx);
 
     // cache[i] = input at time t-i.
     float* HWY_RESTRICT cache[kMaxConv1DWidth];
     cache[0] = x;
     for (size_t i = 1; i < conv_1d_width; i++) {
       cache[i] =
-          kv_cache.conv1d_cache.Row(layer) +
+          kv_caches[query_idx].conv1d_cache.Row(griffin_layer) +
          ((pos + conv_1d_width - 1 - i) % (conv_1d_width - 1)) * model_dim;
     }
     for (size_t i = 0; i < model_dim; i += hn::Lanes(df)) {
@@ -119,7 +133,6 @@ HWY_NOINLINE void GriffinRecurrent(size_t batch_start, size_t num_tokens,
       auto accum0 =
          hn::Load(df, layer_weights->griffin.conv_biases.PackedScale1() + i);
       auto accum1 = hn::Zero(df);
-      HWY_ASSERT_M(conv_1d_width % 2 == 0, "Conv width must be even");
       for (size_t l = 0; 2 * l < conv_1d_width; l++) {
         auto wv0 =
            hn::Load(df, layer_weights->griffin.conv_w.PackedScale1() +
@@ -136,17 +149,20 @@ HWY_NOINLINE void GriffinRecurrent(size_t batch_start, size_t num_tokens,
   }
 
   // RGLRU
-  for (size_t batch_idx = 0; batch_idx < num_tokens; ++batch_idx) {
-    const size_t pos = batch_start + batch_idx;
-    float* HWY_RESTRICT y = activations.griffin_y.Row(batch_idx);
-    float* HWY_RESTRICT x = activations.griffin_x.Row(batch_idx);
-    float* HWY_RESTRICT gate_x = activations.griffin_gate_x.Row(batch_idx);
-    float* HWY_RESTRICT a = activations.griffin_multiplier.Row(batch_idx);
-    float* HWY_RESTRICT rnn_state = kv_cache.rglru_cache.Row(layer);
+  for (size_t interleaved_idx = 0; interleaved_idx < num_interleaved;
+       ++interleaved_idx) {
+    const size_t query_idx = div_num_q.Remainder(interleaved_idx);
+    const size_t batch_idx = div_num_q.Divide(interleaved_idx);
+    const size_t pos = queries_pos[query_idx] + batch_idx;
+
+    float* HWY_RESTRICT x = activations.griffin_x.Row(query_idx);
+    float* HWY_RESTRICT y = activations.griffin_y.Row(query_idx);
+    float* HWY_RESTRICT gate_x = activations.griffin_gate_x.Row(query_idx);
+    float* HWY_RESTRICT a = activations.griffin_multiplier.Row(query_idx);
+    float* HWY_RESTRICT rnn_state =
+        kv_caches[query_idx].rglru_cache.Row(griffin_layer);
 
     pool.Run(0, heads, [&](const uint64_t head, size_t /*thread*/) HWY_ATTR {
-      const size_t kHeadDim = model_dim / heads;
-      const size_t kMatrixSize = kHeadDim * kHeadDim;
       size_t head_offset = head * kHeadDim;
       TwoOfsMatVecAddLoop(
           layer_weights->griffin.gate_w, kMatrixSize * head,
@@ -166,7 +182,6 @@ HWY_NOINLINE void GriffinRecurrent(size_t batch_start, size_t num_tokens,
       hn::Transform1(D(), x + head_offset, kHeadDim, gate_x + head_offset,
                      fn_mul);
       // RNN scan
-      HWY_FULL(float) df;
       HWY_DASSERT(kHeadDim % hn::Lanes(df) == 0);
       for (size_t i = 0; i < kHeadDim; i += hn::Lanes(df)) {
         auto log_a = hn::Load(df, a + head_offset + i);
@@ -186,17 +201,18 @@ HWY_NOINLINE void GriffinRecurrent(size_t batch_start, size_t num_tokens,
         hn::Store(pre_out, df, x + head_offset + i);
       }
     });
-  }
+  }  // interleaved_idx
 
   // Final linear layer.
-  for (size_t batch_idx = 0; batch_idx < num_tokens; ++batch_idx) {
-    float* HWY_RESTRICT x = activations.griffin_x.Row(batch_idx);
-    float* out_ptr = activations.att_sums.Row(batch_idx);
+  // TODO: MatMul
+  for (size_t r = 0; r < num_interleaved; ++r) {
+    float* HWY_RESTRICT x = activations.griffin_x.Row(r);
+    float* out_ptr = activations.att_sums.Row(r);
     MatVecAdd(layer_weights->griffin.linear_out_w, 0, model_dim, model_dim, x,
               layer_weights->griffin.linear_out_biases.PackedScale1(), out_ptr,
               pool);
   }
-}
+}  // GriffinRecurrent
 
 // Wrapper class; holds arguments in member variables to shorten call sites.
 template <typename T>
 class GemmaAttention {
@@ -219,11 +235,7 @@ class GemmaAttention {
             activations_.weights_config.attention_window_sizes[layer] ==
                 activations_.seq_len;
         // TODO: add a config flag instead of hardcoding the model.
-        if (is_global_layer &&
-            (activations_.weights_config.model == Model::GEMMA3_4B ||
-             activations_.weights_config.model == Model::GEMMA3_12B ||
-             activations_.weights_config.model == Model::GEMMA3_27B ||
-             activations_.weights_config.model == Model::GEMMA3_1B)) {
+        if (is_global_layer && IsVLM(activations_.weights_config.model)) {
          inv_timescale = activations_.inv_timescale_global.Packed();
         }
         // PostQKType::Rope
@@ -454,7 +466,6 @@ class GemmaAttention {
 
           SingleDotSoftmaxWeightedSum(q, k, v, att, att_out, query_scale, pos,
                                       start_pos, last_pos);
-
         });
   }
 
@@ -587,14 +598,12 @@ HWY_NOINLINE void Attention(
     GemmaAttention(queries_pos, queries_prefix_end, num_tokens, layer,
                    activations, layer_weights, div_seq_len, kv_caches)();
   } else {
-    // Only reached if the model is Griffin.
-    // The kv_caches are allocated only for the griffin layers, so we need to
-    // map the layer index to the griffin layer index.
-    auto type = layer_weights->layer_config.type;
-    size_t layer_of_type =
+    HWY_DASSERT(type == LayerAttentionType::kGriffinRecurrentBlock);
+    // KVCache conv1d_cache and rglru_cache have one row per *Griffin* layer,
+    // so map `layer` to the Griffin layer index.
+    const size_t griffin_layer =
         activations.weights_config.NumLayersOfTypeBefore(type, layer);
-    HWY_ASSERT(queries_pos.size() == 1);
-    GriffinRecurrent(queries_pos[0], num_tokens, layer_of_type, activations,
+    GriffinRecurrent(queries_pos, num_tokens, griffin_layer, activations,
                      layer_weights, kv_caches);
   }
 }
@@ -1056,7 +1065,6 @@ HWY_NOINLINE void Prefill(
   // threads to parallelizing over queries, but for simplicity we assign them
   // all to MatMul.
   const size_t max_tbatch_size = runtime_config.prefill_tbatch_size;
-  HWY_DASSERT(max_tbatch_size <= activations.x.Rows());
 
   // For each query. `qi` is within the batch, not the global query index.
   for (size_t qi = 0; qi < num_queries; ++qi) {
@@ -1397,6 +1405,7 @@ void GenerateT(const ModelStore& model, const ModelWeightsPtrs& weights,
                const QueriesPos& queries_prefix_end,
                const size_t query_idx_start, const KVCaches& kv_caches,
                TimingInfo& timing_info) {
+  HWY_ASSERT(queries_pos_in.size() == kv_caches.size());
   // Griffin assumes that the recurrent block cache is zero-initialized.
   for (size_t i = 0; i < kv_caches.size(); ++i) {
     if (queries_pos_in[i] == 0) {
@@ -1510,17 +1519,10 @@ void GenerateBatchT(const ModelStore& model,
   const size_t num_queries = queries_prompt.size();
   HWY_ASSERT(queries_pos.size() == num_queries);
   HWY_ASSERT(kv_caches.size() >= num_queries);
-  // Griffin does not support query batching.
-  size_t max_qbatch_size = runtime_config.decode_qbatch_size;
-  for (const LayerConfig& layer_config : model.Config().layer_configs) {
-    if (layer_config.type == LayerAttentionType::kGriffinRecurrentBlock) {
-      max_qbatch_size = 1;
-      break;
-    }
-  }
-
+  const size_t max_qbatch_size = runtime_config.decode_qbatch_size;
   const size_t max_batch_size =
       HWY_MAX(max_qbatch_size, runtime_config.prefill_tbatch_size);
+
   Activations activations(model.Config(), max_batch_size, env);
 
   for (size_t qbatch_start = 0; qbatch_start < num_queries;
@@ -1528,7 +1530,6 @@ void GenerateBatchT(const ModelStore& model,
     // Generate one batch of tokens from `qbatch_size` queries.
     const size_t qbatch_size =
         HWY_MIN(num_queries - qbatch_start, max_qbatch_size);
-    activations.SetBatchSize(qbatch_size);
     const QueriesPromptTokens qbatch_prompts(&queries_prompt[qbatch_start],
                                              qbatch_size);
     QueriesPos qbatch_pos(&queries_pos[qbatch_start], qbatch_size);
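The rewritten Griffin loops above iterate over num_interleaved = num_tokens * num_queries rows and recover the query and token indices via div_num_q.Remainder and div_num_q.Divide. A small sketch of that mapping with plain % and /; hwy::Divisor only replaces the division by a precomputed multiply, the arithmetic is the same:

// Illustrative sketch only, not part of the patch.
#include <cstddef>
#include <cstdio>

int main() {
  const size_t num_queries = 3, num_tokens = 2;  // made-up sizes
  for (size_t i = 0; i < num_tokens * num_queries; ++i) {
    const size_t query_idx = i % num_queries;  // div_num_q.Remainder(i)
    const size_t batch_idx = i / num_queries;  // div_num_q.Divide(i)
    std::printf("row %zu -> query %zu, token offset %zu\n", i, query_idx,
                batch_idx);
  }
  return 0;
}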
diff --git a/gemma/kv_cache.cc b/gemma/kv_cache.cc
index 3992270..49dc31c 100644
--- a/gemma/kv_cache.cc
+++ b/gemma/kv_cache.cc
@@ -25,8 +25,9 @@ namespace gcpp {
 
 void KVCache::ZeroGriffinCache() {
-  if (conv1d_cache.HasPtr()) ZeroInit(conv1d_cache);
-  if (rglru_cache.HasPtr()) ZeroInit(rglru_cache);
+  if (griffin_layers == 0) return;
+  ZeroInit(conv1d_cache);
+  ZeroInit(rglru_cache);
 }
 
 static size_t GriffinConv1dCols(const ModelConfig& config) {
@@ -34,7 +35,9 @@ static size_t GriffinConv1dCols(const ModelConfig& config) {
   for (const auto& layer_config : config.layer_configs) {
     conv1d_width = HWY_MAX(conv1d_width, layer_config.conv1d_width);
   }
-  return conv1d_width == 0 ? 0 : conv1d_width - 1;
+  // The row offset, in blocks of model_dim is computed mod (conv1d_width - 1),
+  // hence allocate conv1d_width * model_dim total columns.
+  return conv1d_width * config.model_dim;
 }
 
 // prefill_tbatch_size is the maximum number of tokens from one query to
@@ -42,12 +45,9 @@ static size_t GriffinConv1dCols(const ModelConfig& config) {
 KVCache::KVCache(const ModelConfig& config, size_t prefill_tbatch_size)
     : griffin_layers(
          config.NumLayersOfType(LayerAttentionType::kGriffinRecurrentBlock)),
-      griffin_conv1d_cols(GriffinConv1dCols(config)),
-      // TODO(patrickms): Add query batching support for Griffin.
-      conv1d_cache(
-          "conv1d_cache",
-          Extents2D(griffin_layers, griffin_conv1d_cols * config.model_dim),
-          MatPadding::kOdd),
+      conv1d_cache("conv1d_cache",
+                   Extents2D(griffin_layers, GriffinConv1dCols(config)),
+                   MatPadding::kOdd),
       rglru_cache("rglru_cache", Extents2D(griffin_layers, config.model_dim),
                   MatPadding::kOdd) {
   // TODO: move to MatStorageT.
diff --git a/gemma/kv_cache.h b/gemma/kv_cache.h
index f9707c8..014e75d 100644
--- a/gemma/kv_cache.h
+++ b/gemma/kv_cache.h
@@ -31,7 +31,6 @@ struct KVCache {
   KVCache Copy(const ModelConfig& weights_config, size_t prefill_tbatch_size);
 
   size_t griffin_layers = 0;
-  size_t griffin_conv1d_cols = 0;
 
   // griffin_layers, griffin_conv1d_cols * config.model_dim
   MatStorageT<float> conv1d_cache;
   MatStorageT<float> rglru_cache;  // griffin_layers, config.model_dim
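Each conv1d_cache row acts as a ring buffer of conv_1d_width - 1 past inputs, one model_dim block each; the modulo in GriffinRecurrent selects the block holding input t-i, and GriffinConv1dCols sizes the row so the last block always fits. A sketch of the index arithmetic with made-up sizes (model_dim and conv_1d_width here are illustrative only):

// Illustrative sketch only, not part of the patch.
#include <cstddef>
#include <cstdio>

int main() {
  const size_t model_dim = 8;      // illustration only
  const size_t conv_1d_width = 4;  // must be even; ring holds width - 1 blocks
  for (size_t pos = 0; pos < 4; ++pos) {
    for (size_t i = 1; i < conv_1d_width; ++i) {
      const size_t block = (pos + conv_1d_width - 1 - i) % (conv_1d_width - 1);
      std::printf("pos %zu, input t-%zu -> columns [%zu, %zu)\n", pos, i,
                  block * model_dim, (block + 1) * model_dim);
    }
  }
  // The block index never exceeds conv_1d_width - 2, so the allocation of
  // conv_1d_width * model_dim columns is sufficient.
  return 0;
}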
diff --git a/gemma/weights.cc b/gemma/weights.cc
index 0269c60..4e8325d 100644
--- a/gemma/weights.cc
+++ b/gemma/weights.cc
@@ -97,21 +97,27 @@ void LayerWeightsPtrs::Fixup(std::vector<MatOwner>& mat_owners) {
   SplitW1NUQ(layer_config);
 }
 
+struct TensorToRead {
+  MatPtr* mat;
+  BlobRange range;
+  // Some tensors opt out of padding via kNoPad flags.
+  MatPadding padding;
+};
+
 // Allocates multiple in parallel and binds to NUMA nodes.
-static void AllocateAndBindAll(const std::vector<MatPtr*>& mats,
-                               MatPadding padding,
+static void AllocateAndBindAll(const std::vector<TensorToRead>& tensors,
                                std::vector<MatOwner>& owners,
                                hwy::ThreadPool& pool) {
   const size_t start = owners.size();
-  owners.resize(start + mats.size());
+  owners.resize(start + tensors.size());
 
   MMParallel parallel(ThreadingContext::Get());
 
   // Allocate in parallel because faulting in large tensors is slow.
-  pool.Run(0, mats.size(), [&](uint64_t task, size_t /*thread*/) {
-    owners[start + task].AllocateFor(*mats[task], padding);
+  pool.Run(0, tensors.size(), [&](uint64_t task, size_t /*thread*/) {
+    owners[start + task].AllocateFor(*tensors[task].mat, tensors[task].padding);
     // TODO(janwas): MatMul outputs will later also be BF16.
-    BindB(*mats[task], sizeof(float), parallel);
+    BindB(*tensors[task].mat, sizeof(float), parallel);
   });
 }
 
@@ -155,55 +161,54 @@ MapPtr MapFileOrNull(File& file, uint64_t file_bytes) {
   return MapPtr();
 }
 
-static void MapAll(const std::vector<MatPtr*>& mats,
-                   const std::vector<BlobRange>& ranges, const MapPtr& mapped) {
+static void MapAll(const std::vector<TensorToRead>& tensors,
+                   const MapPtr& mapped) {
   PROFILER_ZONE("Startup.Weights.Map");
-  for (size_t i = 0; i < mats.size(); ++i) {
+  for (size_t i = 0; i < tensors.size(); ++i) {
     // SetPtr does not change the stride, but it is expected to be packed
     // because that is what Compress() writes to the file.
-    const size_t mat_bytes = mats[i]->PackedBytes();
+    const size_t mat_bytes = tensors[i].mat->PackedBytes();
     // Ensure blob size matches that computed from metadata.
-    HWY_ASSERT_M(mat_bytes == ranges[i].bytes, mats[i]->Name());
+    HWY_ASSERT_M(mat_bytes == tensors[i].range.bytes, tensors[i].mat->Name());
 
-    mats[i]->SetPtr(const_cast<uint8_t*>(mapped.get() + ranges[i].offset),
-                    mats[i]->Stride());
+    tensors[i].mat->SetPtr(
+        const_cast<uint8_t*>(mapped.get() + tensors[i].range.offset),
+        tensors[i].mat->Stride());
   }
 }
 
-std::vector<IOBatch> MakeBatches(const std::vector<BlobRange>& ranges,
-                                 const std::vector<MatPtr*>& mats,
+std::vector<IOBatch> MakeBatches(const std::vector<TensorToRead>& tensors,
                                  const uint64_t file_bytes) {
   PROFILER_ZONE("Startup.Weights.MakeBatches");
   // Batches must be contiguous but blobs are padded, hence at least one
   // batch per tensor, and more when tensor rows exceed the batch size.
   std::vector<IOBatch> batches;
-  batches.reserve(mats.size());
+  batches.reserve(tensors.size());
 
-  for (size_t i = 0; i < mats.size(); ++i) {
-    uint64_t offset = ranges[i].offset;
-    HWY_ASSERT(ranges[i].End() <= file_bytes);
+  for (size_t i = 0; i < tensors.size(); ++i) {
+    const BlobRange& range = tensors[i].range;
+    MatPtr& mat = *tensors[i].mat;
+    uint64_t offset = range.offset;
+    HWY_ASSERT(range.End() <= file_bytes);
 
-    batches.emplace_back(offset, ranges[i].key_idx);
-    const size_t file_bytes_per_row = mats[i]->Cols() * mats[i]->ElementBytes();
-    // Caution, `RowT` requires knowledge of the actual type. We instead use
-    // the first row, which is the same for any type, and advance the *byte*
-    // pointer by the *byte* stride.
-    const size_t mem_stride_bytes = mats[i]->Stride() * mats[i]->ElementBytes();
-    uint8_t* row = mats[i]->RowT<uint8_t>(0);
-    for (size_t r = 0; r < mats[i]->Rows(); ++r) {
-      if (!batches.back().Add(row, file_bytes_per_row)) {  // Full batch.
-        batches.emplace_back(offset, ranges[i].key_idx);
+    batches.emplace_back(offset, range.key_idx);
+    const size_t file_bytes_per_row = mat.Cols() * mat.ElementBytes();
+    const size_t mem_stride_bytes = mat.Stride() * mat.ElementBytes();
+    uint8_t* row_bytes = mat.RowBytes(0);
+    for (size_t r = 0; r < mat.Rows(); ++r) {
+      if (!batches.back().Add(row_bytes, file_bytes_per_row)) {  // Full batch.
+        batches.emplace_back(offset, range.key_idx);
         // Adding to an empty batch is always successful.
-        HWY_ASSERT(batches.back().Add(row, file_bytes_per_row));
+        HWY_ASSERT(batches.back().Add(row_bytes, file_bytes_per_row));
       }
       offset += file_bytes_per_row;
-      row += mem_stride_bytes;
+      row_bytes += mem_stride_bytes;
       // Keep the in-memory row padding uninitialized so msan detects any use.
     }
-    HWY_ASSERT(offset == ranges[i].End());
+    HWY_ASSERT(offset == range.End());
   }
-  HWY_ASSERT(batches.size() >= mats.size());
+  HWY_ASSERT(batches.size() >= tensors.size());
   return batches;
 }
 
@@ -228,16 +233,14 @@ static void ReadBatches(const BlobReader& reader,
 }
 
 // Aborts on error.
-static void MapOrReadAll(const std::vector<MatPtr*>& mats, BlobReader& reader,
-                         const std::vector<BlobRange>& ranges, Tristate map,
+static void MapOrReadAll(const std::vector<TensorToRead>& tensors,
+                         BlobReader& reader, Tristate map,
                          std::vector<MatOwner>& mat_owners,
-                         const MatPadding padding, hwy::ThreadPool& pool) {
-  HWY_ASSERT(mats.size() == ranges.size());
-
+                         hwy::ThreadPool& pool) {
   if (ChooseMode(reader.file_bytes(), map) == Mode::kMap) {
     MapPtr mapped = MapFileOrNull(reader.file(), reader.file_bytes());
     if (mapped) {
-      MapAll(mats, ranges, mapped);
+      MapAll(tensors, mapped);
       return;
     }
   }  // otherwise fall through to read mode
@@ -245,32 +248,32 @@ static void MapOrReadAll(const std::vector<MatPtr*>& mats, BlobReader& reader,
  {
     PROFILER_ZONE("Startup.Weights.Allocate");
     // NOTE: this changes the stride of `mats`!
-    AllocateAndBindAll(mats, padding, mat_owners, pool);
+    AllocateAndBindAll(tensors, mat_owners, pool);
   }
 
   const std::vector<IOBatch> batches =
-      MakeBatches(ranges, mats, reader.file_bytes());
+      MakeBatches(tensors, reader.file_bytes());
   ReadBatches(reader, batches, pool);
 }
 
 void WeightsOwner::ReadFromBlobs(const ModelStore& model, BlobReader& reader,
                                  Tristate map, hwy::ThreadPool& pool) {
   // List of tensors to read/map, and where from.
-  std::vector<MatPtr*> mats;
-  std::vector<BlobRange> ranges;
-
-  // Padding is inserted when reading row by row, except for NUQ tensors.
-  const MatPadding padding = MatPadding::kOdd;
+  std::vector<TensorToRead> tensors;
 
   AllocatePointer(model.Config());
 
   // Enumerate all weights (negligible cost).
   CallT([&](const auto& weights) {
     weights->ForEachTensor(nullptr, nullptr, [&](const TensorArgs& t) {
+      const MatPadding padding = (t.flags & TensorArgs::kNoPad)
+                                     ? MatPadding::kPacked
+                                     : MatPadding::kOdd;
       size_t key_idx;
       if (model.FindAndUpdateMatPtr(t.mat, key_idx)) {
-        mats.push_back(&t.mat);
-        ranges.push_back(reader.Range(key_idx));
+        tensors.push_back({.mat = &t.mat,
+                           .range = reader.Range(key_idx),
+                           .padding = padding});
         return;
       }
       if (t.flags & TensorArgs::kMaybeRead) return;  // optional and not found.
@@ -278,7 +281,7 @@ void WeightsOwner::ReadFromBlobs(const ModelStore& model, BlobReader& reader,
     });
   });
 
-  MapOrReadAll(mats, reader, ranges, map, mat_owners_, padding, pool);
+  MapOrReadAll(tensors, reader, map, mat_owners_, pool);
 
   Fixup(pool);
 }
@@ -338,9 +341,9 @@ void WeightsOwner::LogWeightStatsF32() {
       printf("[scale=%f] ", t.mat.Scale());
     }
     hwy::Stats stats;
-    HWY_ASSERT(t.mat.GetType() == Type::kF32);
+    const MatPtrT<float> mat_f(t.mat);
     for (size_t r = 0; r < t.mat.Rows(); ++r) {
-      const float* HWY_RESTRICT row = t.mat.RowT<float>(r);
+      const float* HWY_RESTRICT row = mat_f.Row(r);
       for (size_t c = 0; c < t.mat.Cols(); ++c) {
         stats.Notify(row[c]);
       }
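MakeBatches above walks each tensor row by row because rows are packed in the file (Cols() * ElementBytes() apart) but padded in memory (Stride() * ElementBytes() apart). A minimal sketch of that layout conversion, using a plain memcpy per row instead of the batched reads the code actually issues:

// Illustrative sketch only, not part of the patch.
#include <cstddef>
#include <cstdint>
#include <cstring>

void CopyPackedRowsIntoPaddedMat(const uint8_t* file_data, uint8_t* mem,
                                 size_t rows, size_t cols, size_t element_bytes,
                                 size_t stride /* elements, >= cols */) {
  const size_t file_bytes_per_row = cols * element_bytes;
  const size_t mem_stride_bytes = stride * element_bytes;
  for (size_t r = 0; r < rows; ++r) {
    std::memcpy(mem + r * mem_stride_bytes, file_data + r * file_bytes_per_row,
                file_bytes_per_row);
    // Bytes beyond cols * element_bytes in each row are left uninitialized,
    // matching the msan note above.
  }
}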
diff --git a/gemma/weights.h b/gemma/weights.h
index aa184db..b0b24c2 100644
--- a/gemma/weights.h
+++ b/gemma/weights.h
@@ -43,24 +43,27 @@ struct TensorArgs {
   // name/type from another `LayerWeightsPtrs` for iterating over tensor pairs
   // (for copying) or triples (for `AdamUpdateMV`). Set by `TENSOR_ARGS`.
   // `flags` is a combination of zero or more `Flags`.
-  TensorArgs(MatPtr& mat, const MatPtr* other_mat1, const MatPtr* other_mat2,
-             int flags)
+  TensorArgs(MatPtr& mat, MatPtr* other_mat1, MatPtr* other_mat2, int flags)
       : mat(mat), other_mat1(other_mat1), other_mat2(other_mat2), flags(flags) {}
 
   MatPtr& mat;
-  const MatPtr* other_mat1;  // either/both can be nullptr.
-  const MatPtr* other_mat2;
+  MatPtr* other_mat1;  // either/both can be nullptr.
+  MatPtr* other_mat2;
 
-  // TODO: freestanding enum class instead? These are mutually exclusive.
   enum Flags {
-    // Read the tensor from the file and abort if it is not found.
+    // Default: Read the tensor from the file and abort if it is not found.
     kMustRead = 0,
+
     // Not an error if the tensor is not present in the file. For example,
     // the _w1/_w2 tensors are not always present.
     kMaybeRead = 1,
+
+    // Avoid padding tensor rows when reading. Used for some Griffin tensors
+    // whose index computations do not use Row() accessors.
+    kNoPad = 2,
   };
   const int flags;
 };
@@ -214,8 +217,8 @@ struct LayerWeightsPtrs {
   // can also iterate over pairs or triples of tensors for `AdamUpdateMV`.
   // Public because also called by `WeightsPtrs`.
   template <class Func>
-  void ForEachTensor(const LayerWeightsPtrs* other1,
-                     const LayerWeightsPtrs* other2, Func func) {
+  void ForEachTensor(LayerWeightsPtrs* other1,
+                     LayerWeightsPtrs* other2, Func func) {
     if (layer_config.type == LayerAttentionType::kVit) {
       // MHA.
       func(TENSOR_ARGS(vit.attn_out_w, kMustRead));
@@ -248,9 +251,9 @@ struct LayerWeightsPtrs {
       func(TENSOR_ARGS(griffin.linear_y_biases, kMustRead));
       func(TENSOR_ARGS(griffin.linear_out_w, kMustRead));
      func(TENSOR_ARGS(griffin.linear_out_biases, kMustRead));
-      func(TENSOR_ARGS(griffin.conv_w, kMustRead));
+      func(TENSOR_ARGS(griffin.conv_w, kMustRead | TensorArgs::kNoPad));
       func(TENSOR_ARGS(griffin.conv_biases, kMustRead));
-      func(TENSOR_ARGS(griffin.gate_w, kMustRead));
+      func(TENSOR_ARGS(griffin.gate_w, kMustRead | TensorArgs::kNoPad));
       func(TENSOR_ARGS(griffin.gate_biases, kMustRead));
       func(TENSOR_ARGS(griffin.a, kMustRead));
     }
@@ -363,8 +366,8 @@ struct LayerWeightsPtrs {
 
   // For FFN. Fast, only updates pointers.
   void SplitW1() {
-    // We only use this tensor for Gemma layers.
-    if (layer_config.type != LayerAttentionType::kGemma) return;
+    // Used for Gemma and Griffin layers; FFWVit uses different tensors.
+    if (layer_config.type == LayerAttentionType::kVit) return;
 
     // Files have both or neither of w1 and w2, and backprop/ allocates both.
     HWY_ASSERT(gating_einsum_w1.HasPtr() == gating_einsum_w2.HasPtr());
@@ -514,10 +517,10 @@ struct ModelWeightsPtrs {
   // used to copy from another set of weights. Public because called by tests
   // and `WeightsOwner`.
   template <class Func>
-  void ForEachTensor(const ModelWeightsPtrs* other1,
-                     const ModelWeightsPtrs* other2, Func func) {
-    const LayerWeightsPtrs* other_layer1 = nullptr;
-    const LayerWeightsPtrs* other_layer2 = nullptr;
+  void ForEachTensor(ModelWeightsPtrs* other1,
+                     ModelWeightsPtrs* other2, Func func) {
+    LayerWeightsPtrs* other_layer1 = nullptr;
+    LayerWeightsPtrs* other_layer2 = nullptr;
 
     func(TENSOR_ARGS(embedder_input_embedding, kMustRead));
     func(TENSOR_ARGS(final_norm_scale, kMustRead));
@@ -569,11 +572,12 @@ struct ModelWeightsPtrs {
 
   // Copies only the allocated tensors in `*this` from tensors in `other`.
   void CopyFrom(const ModelWeightsPtrs& other) {
-    ForEachTensor(&other, nullptr, [](const TensorArgs& t) {
-      if (!t.mat.HasPtr()) return;
-      HWY_ASSERT(t.other_mat1 && t.other_mat1->HasPtr());
-      CopyMat(*t.other_mat1, t.mat);
-    });
+    ForEachTensor(const_cast<ModelWeightsPtrs*>(&other), nullptr,
+                  [](const TensorArgs& t) {
+                    if (!t.mat.HasPtr()) return;
+                    HWY_ASSERT(t.other_mat1 && t.other_mat1->HasPtr());
+                    CopyMat(*t.other_mat1, t.mat);
+                  });
   }
 
   // Instead of reading, only allocates memory for all tensors. Used by
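TensorArgs flags are plain bit flags: they are combined with | (as in kMustRead | TensorArgs::kNoPad above) and tested with &, which is how ReadFromBlobs picks the per-tensor padding. A reduced sketch using the enum values from the patch and a stubbed MatPadding:

// Illustrative sketch only, not part of the patch.
#include <cstdio>

enum Flags { kMustRead = 0, kMaybeRead = 1, kNoPad = 2 };
enum class MatPadding { kPacked, kOdd };

MatPadding PaddingFor(int flags) {
  return (flags & kNoPad) ? MatPadding::kPacked : MatPadding::kOdd;
}

int main() {
  const int conv_w_flags = kMustRead | kNoPad;  // griffin.conv_w above
  std::printf("conv_w packed: %d\n",
              PaddingFor(conv_w_flags) == MatPadding::kPacked);
  std::printf("gate_biases packed: %d\n",
              PaddingFor(kMustRead) == MatPadding::kPacked);
  return 0;
}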
diff --git a/io/blob_store.cc b/io/blob_store.cc
index d0a0edb..6328a87 100644
--- a/io/blob_store.cc
+++ b/io/blob_store.cc
@@ -84,6 +84,14 @@ struct Header {  // standard layout class
 static_assert(sizeof(Header) == 16);
 }  // namespace
 
+// A write I/O request, each serviced by one thread in a pool.
+struct BlobIO {
+  BlobIO(BlobRange range, void* data) : range(range), data(data) {}
+
+  BlobRange range;
+  void* data;  // Read-only for writes.
+};
+
 // Little-endian on-disk representation: a fixed-size `Header`, then a padded
 // variable-length 'directory' of blob keys and their offset/sizes, then the
 // 'payload' of each blob's data with padding in between, followed by padding to
@@ -238,11 +246,11 @@ class BlobStore {
     return true;  // all OK
   }
 
-  void EnqueueWriteForHeaderAndDirectory(std::vector<BlobIO2>& writes) const {
+  void EnqueueWriteForHeaderAndDirectory(std::vector<BlobIO>& writes) const {
     const size_t key_idx = 0;  // not actually associated with a key/blob
     writes.emplace_back(
         BlobRange{.offset = 0, .bytes = sizeof(header_), .key_idx = key_idx},
-        // members are const and BlobIO2 requires non-const pointers, and they
+        // members are const and BlobIO requires non-const pointers, and they
        // are not modified by file writes.
        const_cast<Header*>(&header_));
     writes.emplace_back(
@@ -314,7 +322,7 @@ BlobReader::BlobReader(const Path& blob_path)
 
 // Split into chunks for load-balancing even if blob sizes vary.
 static void EnqueueChunks(size_t key_idx, uint64_t offset, uint64_t bytes,
-                          uint8_t* data, std::vector<BlobIO2>& writes) {
+                          uint8_t* data, std::vector<BlobIO>& writes) {
   constexpr size_t kChunkBytes = 4 * 1024 * 1024;
   const uint64_t end = offset + bytes;
   // Split into whole chunks and possibly one remainder.
@@ -336,7 +344,7 @@ static void EnqueueChunks(size_t key_idx, uint64_t offset, uint64_t bytes,
 static void EnqueueWritesForBlobs(const BlobStore& bs,
                                   const hwy::Span<const uint8_t> blobs[],
                                   std::vector& zeros,
-                                  std::vector<BlobIO2>& writes) {
+                                  std::vector<BlobIO>& writes) {
   // All-zero buffer used to write padding to the file without copying the
   // input blobs.
   static constexpr uint8_t kZeros[kBlobAlign] = {0};
@@ -388,7 +396,7 @@ void BlobWriter::WriteAll(hwy::ThreadPool& pool, const Path& filename) {
   HWY_ASSERT(num_blobs != 0);
   HWY_ASSERT(num_blobs == blobs_.size());
 
-  std::vector<BlobIO2> writes;
+  std::vector<BlobIO> writes;
   writes.reserve(16384);
 
   const BlobStore bs(num_blobs, keys_.data(), blobs_.data());
diff --git a/io/blob_store.h b/io/blob_store.h
index 78bc08e..19c6639 100644
--- a/io/blob_store.h
+++ b/io/blob_store.h
@@ -44,16 +44,6 @@ struct BlobRange {
   size_t key_idx;
 };
 
-// A read or write I/O request, each serviced by one thread in a pool.
-struct BlobIO2 {
-  BlobIO2(BlobRange range, void* data) : range(range), data(data) {}
-
-  BlobRange range;
-  void* data;  // Modified only if a read request. Read-only for writes.
-};
-
-class BlobStore;
-
 // Reads `BlobStore` header, converts keys to strings and creates a hash map for
 // faster lookups.
 // TODO(janwas): rename to BlobFinder or similar.
diff --git a/ops/matmul.cc b/ops/matmul.cc
index 2f2f795..18af14e 100644
--- a/ops/matmul.cc
+++ b/ops/matmul.cc
@@ -425,7 +425,7 @@ MatMulEnv::MatMulEnv(ThreadingContext& ctx)
   have_timer_stop = hwy::platform::HaveTimerStop(cpu100);
 }
 
-void BindB(const MatPtr& B, size_t sizeof_TC, MMParallel& parallel) {
+void BindB(MatPtr& B, size_t sizeof_TC, MMParallel& parallel) {
   Allocator& allocator = parallel.allocator();
   if (!allocator.ShouldBind()) return;
   if (B.Rows() == 1) return;
@@ -437,8 +437,7 @@ void BindB(const MatPtr& B, size_t sizeof_TC, MMParallel& parallel) {
   for (size_t pkg_idx = 0; pkg_idx < ranges_np.NumTasks(); ++pkg_idx) {
     const IndexRange& rows_b = ranges_np.Range(pkg_idx);
     const size_t node = parallel.Node(pkg_idx);
-    uintptr_t begin =
-        reinterpret_cast<uintptr_t>(B.RowT<uint8_t>(rows_b.begin()));
+    uintptr_t begin = reinterpret_cast<uintptr_t>(B.RowBytes(rows_b.begin()));
     uintptr_t end = begin + rows_b.Num() * B.Stride() * B.ElementBytes();
     // B row padding is less than the page size, so only bind the subset that
     // is page-aligned.
@@ -451,7 +450,7 @@ void BindB(const MatPtr& B, size_t sizeof_TC, MMParallel& parallel) {
 }
 
 // C is BF16/float, or double for partial
-void BindC(const MatPtr& C, MMParallel& parallel) {
+void BindC(MatPtr& C, MMParallel& parallel) {
   Allocator& allocator = parallel.allocator();
   if (!allocator.ShouldBind()) return;
 
@@ -470,8 +469,7 @@ void BindC(const MatPtr& C, MMParallel& parallel) {
     const size_t node = parallel.Node(pkg_idx);
 
     for (size_t im = 0; im < C.Rows(); ++im) {
-      ok &= allocator.BindMemory(C.MutableRowT<uint8_t>(im) + begin,
-                                 end - begin, node);
+      ok &= allocator.BindMemory(C.RowBytes(im) + begin, end - begin, node);
     }
   }
   if (HWY_UNLIKELY(!ok)) {
diff --git a/ops/matmul.h b/ops/matmul.h
index fc4e8c1..c2474cb 100644
--- a/ops/matmul.h
+++ b/ops/matmul.h
@@ -172,9 +172,9 @@ class MMParallel {
   ThreadingContext& ctx_;
 };
 
-void BindB(const MatPtr& B, size_t sizeof_TC, MMParallel& parallel);
+void BindB(MatPtr& B, size_t sizeof_TC, MMParallel& parallel);
 // C is BF16/float, or double for partial.
-void BindC(const MatPtr& C, MMParallel& parallel);
+void BindC(MatPtr& C, MMParallel& parallel);
 
 // Per-package storage for packed A, and one global C-shaped `partial` for
 // accumulating partial dot products (sections of K).
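EnqueueChunks splits each blob into fixed 4 MiB requests plus at most one remainder so the write pool stays load-balanced even when blob sizes vary. A standalone sketch of just the splitting arithmetic; Chunk stands in for the BlobIO/BlobRange pair:

// Illustrative sketch only, not part of the patch.
#include <cstdint>
#include <cstdio>
#include <vector>

struct Chunk {
  uint64_t offset;
  uint64_t bytes;
};

std::vector<Chunk> SplitIntoChunks(uint64_t offset, uint64_t bytes) {
  constexpr uint64_t kChunkBytes = 4 * 1024 * 1024;
  std::vector<Chunk> chunks;
  const uint64_t end = offset + bytes;
  while (offset + kChunkBytes <= end) {  // whole chunks
    chunks.push_back({offset, kChunkBytes});
    offset += kChunkBytes;
  }
  if (offset != end) chunks.push_back({offset, end - offset});  // remainder
  return chunks;
}

int main() {
  for (const Chunk& c : SplitIntoChunks(0, 10 * 1024 * 1024 + 123)) {
    std::printf("offset %llu bytes %llu\n",
                static_cast<unsigned long long>(c.offset),
                static_cast<unsigned long long>(c.bytes));
  }
  return 0;
}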
diff --git a/util/mat.cc b/util/mat.cc
index 86baaee..28763ba 100644
--- a/util/mat.cc
+++ b/util/mat.cc
@@ -41,8 +41,8 @@ void CopyMat(const MatPtr& from, MatPtr& to) {
   }
   const size_t row_bytes = to.Cols() * to.ElementBytes();
   for (size_t r = 0; r < to.Rows(); ++r) {
-    const uint8_t* from_row = from.RowT<uint8_t>(r);
-    uint8_t* to_row = to.RowT<uint8_t>(r);
+    const uint8_t* from_row = from.RowBytes(r);
+    uint8_t* to_row = to.RowBytes(r);
     hwy::CopyBytes(from_row, to_row, row_bytes);
   }
 }
@@ -58,7 +58,7 @@ void ZeroInit(MatPtr& mat) {
   }
   const size_t row_bytes = mat.Cols() * mat.ElementBytes();
   for (size_t r = 0; r < mat.Rows(); ++r) {
-    hwy::ZeroBytes(mat.RowT<uint8_t>(r), row_bytes);
+    hwy::ZeroBytes(mat.RowBytes(r), row_bytes);
   }
 }
 
@@ -71,15 +71,19 @@ void RandInit(MatPtr& mat, float stddev, std::mt19937& gen) {
   std::normal_distribution dist(0.0, stddev);
   if (mat.GetType() == Type::kF32) {
+    MatPtrT<float> mat_f(mat);
+
     for (size_t r = 0; r < mat.Rows(); ++r) {
-      float* HWY_RESTRICT row = mat.RowT<float>(r);
+      float* HWY_RESTRICT row = mat_f.Row(r);
       for (size_t c = 0; c < mat.Cols(); ++c) {
         row[c] = dist(gen);
       }
     }
   } else {
+    MatPtrT<double> mat_d(mat);
+
     for (size_t r = 0; r < mat.Rows(); ++r) {
-      double* HWY_RESTRICT row = mat.RowT<double>(r);
+      double* HWY_RESTRICT row = mat_d.Row(r);
       for (size_t c = 0; c < mat.Cols(); ++c) {
         row[c] = dist(gen);
       }
diff --git a/util/mat.h b/util/mat.h
index da662c9..0314e20 100644
--- a/util/mat.h
+++ b/util/mat.h
@@ -91,21 +91,14 @@ class MatPtr : public IFields {
     return num_elements_ * element_bytes_;
   }
 
-  // Works for any kind of padding.
-  template <typename T>
-  T* MutableRowT(size_t row) const {
+  // Works for any kind of padding and element type.
+  uint8_t* RowBytes(size_t row) {
     HWY_DASSERT(row < Rows());
-    return HWY_RCAST_ALIGNED(T*, ptr_) + row * stride_;
+    return static_cast<uint8_t*>(ptr_) + row * (stride_ * element_bytes_);
   }
-  template <typename T>
-  T* RowT(size_t row) {
+  const uint8_t* RowBytes(size_t row) const {
     HWY_DASSERT(row < Rows());
-    return HWY_RCAST_ALIGNED(T*, ptr_) + row * stride_;
-  }
-  template <typename T>
-  const T* RowT(size_t row) const {
-    HWY_DASSERT(row < Rows());
-    return HWY_RCAST_ALIGNED(const T*, ptr_) + row * stride_;
+    return static_cast<const uint8_t*>(ptr_) + row * (stride_ * element_bytes_);
   }
 
   Type GetType() const { return type_; }
@@ -209,9 +202,8 @@ class MatPtr : public IFields {
   float scale_ = 1.0f;  // multiplier for each value, for MatMul.
 };
 
-// Non-type erased version of `MatPtr`. Although `MatPtr` also provides
-// type-aware accessors (`RowT`), this class is more convenient when accessing
-// elements, and ensures the template argument and `Type` are consistent.
+// Non-type erased version of `MatPtr`: provides type-safe `Row()` and ensures
+// the template argument and `Type` are consistent.
 template <typename MatT>
 class MatPtrT : public MatPtr {
  public:
@@ -227,7 +219,9 @@ class MatPtrT : public MatPtr {
       : MatPtrT(name.c_str(), ExtentsFromInfo(info.Find(name))) {}
 
   // Copying allowed because the metadata is small.
-  MatPtrT(const MatPtr& other) : MatPtr(other) {}
+  MatPtrT(const MatPtr& other) : MatPtr(other) {
+    HWY_ASSERT(other.GetType() == TypeEnum<MatT>());
+  }
   MatPtrT& operator=(const MatPtr& other) {
     MatPtr::operator=(other);
     return *this;
@@ -252,22 +246,24 @@ class MatPtrT : public MatPtr {
     return Packed();
   }
 
-  const MatT* Row(size_t row) const { return this->RowT<MatT>(row); }
-  MatT* Row(size_t row) { return this->RowT<MatT>(row); }
+  MatT* Row(size_t row) { return HWY_RCAST_ALIGNED(T*, RowBytes(row)); }
+  const MatT* Row(size_t row) const {
+    return HWY_RCAST_ALIGNED(const T*, RowBytes(row));
+  }
 
   PackedSpan<const MatT> PaddedSpan() const {
-    return MakeConstSpan(Row(0), Rows() * Stride());
+    return MakeConstSpan(HWY_RCAST_ALIGNED(MatT*, ptr_), Rows() * Stride());
   }
 
   // For `compress-inl.h` functions, which assume contiguous streams and thus
   // require packed layout.
-  PackedSpan<const MatT> Span() const {
-    HWY_ASSERT(IsPacked());
-    return MakeConstSpan(Row(0), num_elements_);
-  }
   PackedSpan<MatT> Span() {
     HWY_ASSERT(IsPacked());
-    return MakeSpan(Row(0), num_elements_);
+    return MakeSpan(HWY_RCAST_ALIGNED(MatT*, ptr_), num_elements_);
+  }
+  PackedSpan<const MatT> Span() const {
+    HWY_ASSERT(IsPacked());
+    return MakeConstSpan(HWY_RCAST_ALIGNED(MatT*, ptr_), num_elements_);
   }
 };
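The new RowBytes accessor advances a byte pointer by Stride() * ElementBytes() per row, and MatPtrT<MatT>::Row is the same address viewed as MatT*. A small sketch of that equivalence over a padded float buffer, with simplified stand-ins rather than the real MatPtr classes:

// Illustrative sketch only, not part of the patch.
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

int main() {
  const size_t rows = 3, cols = 5, stride = 8;  // 3 elements of row padding
  std::vector<float> storage(rows * stride);
  uint8_t* bytes = reinterpret_cast<uint8_t*>(storage.data());

  for (size_t r = 0; r < rows; ++r) {
    uint8_t* row_bytes = bytes + r * (stride * sizeof(float));  // RowBytes(r)
    float* row = storage.data() + r * stride;                   // Row(r)
    assert(reinterpret_cast<float*>(row_bytes) == row);
  }
  return 0;
}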