Fix RowT issue and improve Griffin (currently still broken)

Use type-safe MatPtrT via dynamic_cast, avoid/remove unsafe RowT
activations: Griffin tensors are now padded
Griffin: add batching support, fix conv1d_cache allocation
weights: bundle to TensorToRead, add kNoPad flag, fix SplitW1
const-correct fix for ForEachTensor
blob_store: move BlobIO2 to .cc and rename BlobIO
PiperOrigin-RevId: 760610094
This commit is contained in:
Jan Wassenberg 2025-05-19 07:01:29 -07:00 committed by Copybara-Service
parent d6cfabc2c1
commit cb188d4a0e
13 changed files with 218 additions and 211 deletions

View File

@ -27,6 +27,8 @@ namespace gcpp {
namespace { namespace {
using MatPtrF = MatPtrT<float>;
// Split into two classes so that ForEachTensor only requires two "other" // Split into two classes so that ForEachTensor only requires two "other"
// arguments. This is anyway useful for locality, because `grad` only feeds // arguments. This is anyway useful for locality, because `grad` only feeds
// into `grad_m` and `grad_v` here. // into `grad_m` and `grad_v` here.
@ -40,12 +42,11 @@ class AdamUpdateMV {
norm1_(1.0 / (1.0 - std::pow(beta1, t))), norm1_(1.0 / (1.0 - std::pow(beta1, t))),
norm2_(1.0 / (1.0 - std::pow(beta2, t))) {} norm2_(1.0 / (1.0 - std::pow(beta2, t))) {}
void operator()(const MatPtr& grad, const MatPtr& grad_m, void operator()(const MatPtrF& grad, MatPtrF& grad_m, MatPtrF& grad_v) {
const MatPtr& grad_v) {
for (size_t r = 0; r < grad.Rows(); ++r) { for (size_t r = 0; r < grad.Rows(); ++r) {
const float* HWY_RESTRICT g = grad.RowT<float>(r); const float* HWY_RESTRICT g = grad.Row(r);
float* HWY_RESTRICT m = grad_m.MutableRowT<float>(r); float* HWY_RESTRICT m = grad_m.Row(r);
float* HWY_RESTRICT v = grad_v.MutableRowT<float>(r); float* HWY_RESTRICT v = grad_v.Row(r);
for (size_t c = 0; c < grad.Cols(); ++c) { for (size_t c = 0; c < grad.Cols(); ++c) {
m[c] *= beta1_; m[c] *= beta1_;
m[c] += cbeta1_ * g[c]; m[c] += cbeta1_ * g[c];
@ -73,11 +74,12 @@ class AdamUpdateW {
norm2_(1.0 / (1.0 - std::pow(beta2, t))), norm2_(1.0 / (1.0 - std::pow(beta2, t))),
epsilon_(epsilon) {} epsilon_(epsilon) {}
void operator()(MatPtr& weights, const MatPtr& grad_m, const MatPtr& grad_v) { void operator()(MatPtrF& weights, const MatPtrF& grad_m,
const MatPtrF& grad_v) {
for (size_t r = 0; r < weights.Rows(); ++r) { for (size_t r = 0; r < weights.Rows(); ++r) {
float* HWY_RESTRICT w = weights.RowT<float>(r); float* HWY_RESTRICT w = weights.Row(r);
const float* HWY_RESTRICT m = grad_m.RowT<float>(r); const float* HWY_RESTRICT m = grad_m.Row(r);
const float* HWY_RESTRICT v = grad_v.RowT<float>(r); const float* HWY_RESTRICT v = grad_v.Row(r);
for (size_t c = 0; c < weights.Cols(); ++c) { for (size_t c = 0; c < weights.Cols(); ++c) {
const float mhat = m[c] * norm1_; const float mhat = m[c] * norm1_;
const float vhat = v[c] * norm2_; const float vhat = v[c] * norm2_;
@ -100,12 +102,18 @@ void AdamUpdate(ModelWeightsPtrs<float>* grad, float alpha, float beta1,
ModelWeightsPtrs<float>* grad_v, hwy::ThreadPool& pool) { ModelWeightsPtrs<float>* grad_v, hwy::ThreadPool& pool) {
AdamUpdateMV update_mv(beta1, beta2, t); AdamUpdateMV update_mv(beta1, beta2, t);
grad->ForEachTensor(grad_m, grad_v, [&update_mv](const TensorArgs& t) { grad->ForEachTensor(grad_m, grad_v, [&update_mv](const TensorArgs& t) {
update_mv(t.mat, *t.other_mat1, *t.other_mat2); const MatPtrF grad_f(t.mat);
MatPtrF grad_m_f(*t.other_mat1);
MatPtrF grad_v_f(*t.other_mat2);
update_mv(grad_f, grad_m_f, grad_v_f);
}); });
AdamUpdateW update_w(alpha, beta1, beta2, epsilon, t); AdamUpdateW update_w(alpha, beta1, beta2, epsilon, t);
weights->ForEachTensor(grad_m, grad_v, [&update_w](const TensorArgs& t) { weights->ForEachTensor(grad_m, grad_v, [&update_w](const TensorArgs& t) {
update_w(t.mat, *t.other_mat1, *t.other_mat2); MatPtrF weights_f(t.mat);
const MatPtrF grad_m_f(*t.other_mat1);
const MatPtrF grad_v_f(*t.other_mat2);
update_w(weights_f, grad_m_f, grad_v_f);
}); });
} }

View File

@ -33,8 +33,7 @@ struct Activations {
layer_config(config.layer_configs[0]), layer_config(config.layer_configs[0]),
seq_len(config.seq_len), seq_len(config.seq_len),
cache_pos_size(config.CachePosSize()), cache_pos_size(config.CachePosSize()),
is_griffin(layer_config.type == is_griffin(config.model == Model::GRIFFIN_2B),
LayerAttentionType::kGriffinRecurrentBlock),
x("x", Extents2D(batch_size, config.model_dim), pad_), x("x", Extents2D(batch_size, config.model_dim), pad_),
q("q", q("q",
@ -58,21 +57,18 @@ struct Activations {
C2("C2", Extents2D(batch_size, layer_config.ff_hidden_dim), pad_), C2("C2", Extents2D(batch_size, layer_config.ff_hidden_dim), pad_),
ffw_out("ffw_out", Extents2D(batch_size, config.model_dim), pad_), ffw_out("ffw_out", Extents2D(batch_size, config.model_dim), pad_),
// No padding for Griffin because it does not always use Row().
griffin_x("griffin_x", griffin_x("griffin_x",
is_griffin ? Extents2D(batch_size, config.model_dim) : none_, is_griffin ? Extents2D(batch_size, config.model_dim) : none_,
MatPadding::kPacked), pad_),
griffin_y("griffin_y", griffin_y("griffin_y",
is_griffin ? Extents2D(batch_size, config.model_dim) : none_, is_griffin ? Extents2D(batch_size, config.model_dim) : none_,
MatPadding::kPacked), pad_),
griffin_gate_x( griffin_gate_x(
"griffin_gate_x", "griffin_gate_x",
is_griffin ? Extents2D(batch_size, config.model_dim) : none_, is_griffin ? Extents2D(batch_size, config.model_dim) : none_, pad_),
MatPadding::kPacked),
griffin_multiplier( griffin_multiplier(
"griffin_mul", "griffin_mul",
is_griffin ? Extents2D(batch_size, config.model_dim) : none_, is_griffin ? Extents2D(batch_size, config.model_dim) : none_, pad_),
MatPadding::kPacked),
inv_timescale(CreateInvTimescale( inv_timescale(CreateInvTimescale(
ThreadingContext::Get().allocator, layer_config.qkv_dim, ThreadingContext::Get().allocator, layer_config.qkv_dim,

View File

@ -21,7 +21,6 @@
#include <stdio.h> #include <stdio.h>
#include <algorithm> // std::min #include <algorithm> // std::min
#include <memory> // std::make_unique
#include <vector> #include <vector>
#include "gemma/activations.h" #include "gemma/activations.h"
@ -69,30 +68,44 @@ namespace HWY_NAMESPACE {
// Different functions use different naming conventions for the number of // Different functions use different naming conventions for the number of
// tokens. Functions that are query-independent, such as RMSNorm*, call the // tokens. Functions that are query-independent, such as RMSNorm*, call the
// count `num_interleaved`. Functions that are query-dependent, such as // count `num_interleaved`. Functions that are query-dependent, such as
// `Attention`, use separate `num_tokens` and `num_queries`. // `Attention`, use separate `num_tokens` and `num_queries`. `num_tokens` is the
// number of tokens from one query: 1 for decode, otherwise prefill_tbatch_size.
// TODO: add batch query support for Griffin (QueriesPos).
template <typename T> template <typename T>
HWY_NOINLINE void GriffinRecurrent(size_t batch_start, size_t num_tokens, HWY_NOINLINE void GriffinRecurrent(const QueriesPos& queries_pos,
size_t layer, Activations& activations, size_t num_tokens, size_t griffin_layer,
Activations& activations,
const LayerWeightsPtrs<T>* layer_weights, const LayerWeightsPtrs<T>* layer_weights,
const KVCaches& kv_caches) { const KVCaches& kv_caches) {
PROFILER_ZONE("Gen.Griffin"); PROFILER_ZONE("Gen.Griffin");
KVCache& kv_cache = kv_caches[0];
hwy::ThreadPool& pool = activations.env->ctx.pools.Pool(0); hwy::ThreadPool& pool = activations.env->ctx.pools.Pool(0);
namespace hn = hwy::HWY_NAMESPACE; namespace hn = hwy::HWY_NAMESPACE;
using D = hn::ScalableTag<float>; using D = hn::ScalableTag<float>;
const D df;
const size_t model_dim = layer_weights->layer_config.model_dim; const size_t model_dim = layer_weights->layer_config.model_dim;
const size_t conv_1d_width = layer_weights->layer_config.conv1d_width; HWY_DASSERT(model_dim % hn::Lanes(df) == 0);
const size_t heads = layer_weights->layer_config.heads; const size_t heads = layer_weights->layer_config.heads;
const size_t conv_1d_width = layer_weights->layer_config.conv1d_width;
HWY_ASSERT_M(conv_1d_width % 2 == 0, "Conv width must be even");
const size_t kHeadDim = model_dim / heads;
const size_t kMatrixSize = kHeadDim * kHeadDim;
const size_t num_queries = queries_pos.size();
const hwy::Divisor div_num_q(static_cast<uint32_t>(num_queries));
const size_t num_interleaved = num_tokens * num_queries;
// X / Y linear layers. // X / Y linear layers.
for (size_t batch_idx = 0; batch_idx < num_tokens; ++batch_idx) { // TODO: MatMul
float* HWY_RESTRICT y = activations.griffin_y.Row(batch_idx); HWY_DASSERT(activations.griffin_y.Rows() == activations.griffin_x.Rows());
float* HWY_RESTRICT x = activations.griffin_x.Row(batch_idx); HWY_DASSERT(num_interleaved == activations.griffin_y.Rows());
for (size_t r = 0; r < num_interleaved; ++r) {
float* HWY_RESTRICT y = activations.griffin_y.Row(r);
float* HWY_RESTRICT x = activations.griffin_x.Row(r);
TwoMatVecAdd(layer_weights->griffin.linear_x_w, TwoMatVecAdd(layer_weights->griffin.linear_x_w,
layer_weights->griffin.linear_y_w, 0, model_dim, model_dim, layer_weights->griffin.linear_y_w, 0, model_dim, model_dim,
activations.pre_att_rms_out.Row(batch_idx), activations.pre_att_rms_out.Row(r),
/*add0=*/layer_weights->griffin.linear_x_biases.PackedScale1(), /*add0=*/layer_weights->griffin.linear_x_biases.PackedScale1(),
/*add1=*/layer_weights->griffin.linear_y_biases.PackedScale1(), /*add1=*/layer_weights->griffin.linear_y_biases.PackedScale1(),
/*out0=*/x, /*out1=*/y, pool); /*out0=*/x, /*out1=*/y, pool);
@ -100,18 +113,19 @@ HWY_NOINLINE void GriffinRecurrent(size_t batch_start, size_t num_tokens,
} }
// Conv1D. // Conv1D.
for (size_t batch_idx = 0; batch_idx < num_tokens; ++batch_idx) { for (size_t interleaved_idx = 0; interleaved_idx < num_interleaved;
const size_t pos = batch_start + batch_idx; ++interleaved_idx) {
float* HWY_RESTRICT x = activations.griffin_x.Row(batch_idx); const size_t query_idx = div_num_q.Remainder(interleaved_idx);
HWY_FULL(float) df; const size_t batch_idx = div_num_q.Divide(interleaved_idx);
HWY_DASSERT(model_dim % hn::Lanes(df) == 0); const size_t pos = queries_pos[query_idx] + batch_idx;
float* HWY_RESTRICT x = activations.griffin_x.Row(query_idx);
// cache[i] = input at time t-i. // cache[i] = input at time t-i.
float* HWY_RESTRICT cache[kMaxConv1DWidth]; float* HWY_RESTRICT cache[kMaxConv1DWidth];
cache[0] = x; cache[0] = x;
for (size_t i = 1; i < conv_1d_width; i++) { for (size_t i = 1; i < conv_1d_width; i++) {
cache[i] = cache[i] =
kv_cache.conv1d_cache.Row(layer) + kv_caches[query_idx].conv1d_cache.Row(griffin_layer) +
((pos + conv_1d_width - 1 - i) % (conv_1d_width - 1)) * model_dim; ((pos + conv_1d_width - 1 - i) % (conv_1d_width - 1)) * model_dim;
} }
for (size_t i = 0; i < model_dim; i += hn::Lanes(df)) { for (size_t i = 0; i < model_dim; i += hn::Lanes(df)) {
@ -119,7 +133,6 @@ HWY_NOINLINE void GriffinRecurrent(size_t batch_start, size_t num_tokens,
auto accum0 = auto accum0 =
hn::Load(df, layer_weights->griffin.conv_biases.PackedScale1() + i); hn::Load(df, layer_weights->griffin.conv_biases.PackedScale1() + i);
auto accum1 = hn::Zero(df); auto accum1 = hn::Zero(df);
HWY_ASSERT_M(conv_1d_width % 2 == 0, "Conv width must be even");
for (size_t l = 0; 2 * l < conv_1d_width; l++) { for (size_t l = 0; 2 * l < conv_1d_width; l++) {
auto wv0 = auto wv0 =
hn::Load(df, layer_weights->griffin.conv_w.PackedScale1() + hn::Load(df, layer_weights->griffin.conv_w.PackedScale1() +
@ -136,17 +149,20 @@ HWY_NOINLINE void GriffinRecurrent(size_t batch_start, size_t num_tokens,
} }
// RGLRU // RGLRU
for (size_t batch_idx = 0; batch_idx < num_tokens; ++batch_idx) { for (size_t interleaved_idx = 0; interleaved_idx < num_interleaved;
const size_t pos = batch_start + batch_idx; ++interleaved_idx) {
float* HWY_RESTRICT y = activations.griffin_y.Row(batch_idx); const size_t query_idx = div_num_q.Remainder(interleaved_idx);
float* HWY_RESTRICT x = activations.griffin_x.Row(batch_idx); const size_t batch_idx = div_num_q.Divide(interleaved_idx);
float* HWY_RESTRICT gate_x = activations.griffin_gate_x.Row(batch_idx); const size_t pos = queries_pos[query_idx] + batch_idx;
float* HWY_RESTRICT a = activations.griffin_multiplier.Row(batch_idx);
float* HWY_RESTRICT rnn_state = kv_cache.rglru_cache.Row(layer); float* HWY_RESTRICT x = activations.griffin_x.Row(query_idx);
float* HWY_RESTRICT y = activations.griffin_y.Row(query_idx);
float* HWY_RESTRICT gate_x = activations.griffin_gate_x.Row(query_idx);
float* HWY_RESTRICT a = activations.griffin_multiplier.Row(query_idx);
float* HWY_RESTRICT rnn_state =
kv_caches[query_idx].rglru_cache.Row(griffin_layer);
pool.Run(0, heads, [&](const uint64_t head, size_t /*thread*/) HWY_ATTR { pool.Run(0, heads, [&](const uint64_t head, size_t /*thread*/) HWY_ATTR {
const size_t kHeadDim = model_dim / heads;
const size_t kMatrixSize = kHeadDim * kHeadDim;
size_t head_offset = head * kHeadDim; size_t head_offset = head * kHeadDim;
TwoOfsMatVecAddLoop( TwoOfsMatVecAddLoop(
layer_weights->griffin.gate_w, kMatrixSize * head, layer_weights->griffin.gate_w, kMatrixSize * head,
@ -166,7 +182,6 @@ HWY_NOINLINE void GriffinRecurrent(size_t batch_start, size_t num_tokens,
hn::Transform1(D(), x + head_offset, kHeadDim, gate_x + head_offset, hn::Transform1(D(), x + head_offset, kHeadDim, gate_x + head_offset,
fn_mul); fn_mul);
// RNN scan // RNN scan
HWY_FULL(float) df;
HWY_DASSERT(kHeadDim % hn::Lanes(df) == 0); HWY_DASSERT(kHeadDim % hn::Lanes(df) == 0);
for (size_t i = 0; i < kHeadDim; i += hn::Lanes(df)) { for (size_t i = 0; i < kHeadDim; i += hn::Lanes(df)) {
auto log_a = hn::Load(df, a + head_offset + i); auto log_a = hn::Load(df, a + head_offset + i);
@ -186,17 +201,18 @@ HWY_NOINLINE void GriffinRecurrent(size_t batch_start, size_t num_tokens,
hn::Store(pre_out, df, x + head_offset + i); hn::Store(pre_out, df, x + head_offset + i);
} }
}); });
} } // interleaved_idx
// Final linear layer. // Final linear layer.
for (size_t batch_idx = 0; batch_idx < num_tokens; ++batch_idx) { // TODO: MatMul
float* HWY_RESTRICT x = activations.griffin_x.Row(batch_idx); for (size_t r = 0; r < num_interleaved; ++r) {
float* out_ptr = activations.att_sums.Row(batch_idx); float* HWY_RESTRICT x = activations.griffin_x.Row(r);
float* out_ptr = activations.att_sums.Row(r);
MatVecAdd(layer_weights->griffin.linear_out_w, 0, model_dim, model_dim, x, MatVecAdd(layer_weights->griffin.linear_out_w, 0, model_dim, model_dim, x,
layer_weights->griffin.linear_out_biases.PackedScale1(), out_ptr, layer_weights->griffin.linear_out_biases.PackedScale1(), out_ptr,
pool); pool);
} }
} } // GriffinRecurrent
// Wrapper class; holds arguments in member variables to shorten call sites. // Wrapper class; holds arguments in member variables to shorten call sites.
template <typename T> template <typename T>
@ -219,11 +235,7 @@ class GemmaAttention {
activations_.weights_config.attention_window_sizes[layer] == activations_.weights_config.attention_window_sizes[layer] ==
activations_.seq_len; activations_.seq_len;
// TODO: add a config flag instead of hardcoding the model. // TODO: add a config flag instead of hardcoding the model.
if (is_global_layer && if (is_global_layer && IsVLM(activations_.weights_config.model)) {
(activations_.weights_config.model == Model::GEMMA3_4B ||
activations_.weights_config.model == Model::GEMMA3_12B ||
activations_.weights_config.model == Model::GEMMA3_27B ||
activations_.weights_config.model == Model::GEMMA3_1B)) {
inv_timescale = activations_.inv_timescale_global.Packed(); inv_timescale = activations_.inv_timescale_global.Packed();
} }
// PostQKType::Rope // PostQKType::Rope
@ -454,7 +466,6 @@ class GemmaAttention {
SingleDotSoftmaxWeightedSum(q, k, v, att, att_out, query_scale, SingleDotSoftmaxWeightedSum(q, k, v, att, att_out, query_scale,
pos, start_pos, last_pos); pos, start_pos, last_pos);
}); });
} }
@ -587,14 +598,12 @@ HWY_NOINLINE void Attention(
GemmaAttention<T>(queries_pos, queries_prefix_end, num_tokens, layer, GemmaAttention<T>(queries_pos, queries_prefix_end, num_tokens, layer,
activations, layer_weights, div_seq_len, kv_caches)(); activations, layer_weights, div_seq_len, kv_caches)();
} else { } else {
// Only reached if the model is Griffin. HWY_DASSERT(type == LayerAttentionType::kGriffinRecurrentBlock);
// The kv_caches are allocated only for the griffin layers, so we need to // KVCache conv1d_cache and rglru_cache have one row per *Griffin* layer,
// map the layer index to the griffin layer index. // so map `layer` to the Griffin layer index.
auto type = layer_weights->layer_config.type; const size_t griffin_layer =
size_t layer_of_type =
activations.weights_config.NumLayersOfTypeBefore(type, layer); activations.weights_config.NumLayersOfTypeBefore(type, layer);
HWY_ASSERT(queries_pos.size() == 1); GriffinRecurrent(queries_pos, num_tokens, griffin_layer, activations,
GriffinRecurrent(queries_pos[0], num_tokens, layer_of_type, activations,
layer_weights, kv_caches); layer_weights, kv_caches);
} }
} }
@ -1056,7 +1065,6 @@ HWY_NOINLINE void Prefill(
// threads to parallelizing over queries, but for simplicity we assign them // threads to parallelizing over queries, but for simplicity we assign them
// all to MatMul. // all to MatMul.
const size_t max_tbatch_size = runtime_config.prefill_tbatch_size; const size_t max_tbatch_size = runtime_config.prefill_tbatch_size;
HWY_DASSERT(max_tbatch_size <= activations.x.Rows());
// For each query. `qi` is within the batch, not the global query index. // For each query. `qi` is within the batch, not the global query index.
for (size_t qi = 0; qi < num_queries; ++qi) { for (size_t qi = 0; qi < num_queries; ++qi) {
@ -1397,6 +1405,7 @@ void GenerateT(const ModelStore& model, const ModelWeightsPtrs<T>& weights,
const QueriesPos& queries_prefix_end, const QueriesPos& queries_prefix_end,
const size_t query_idx_start, const KVCaches& kv_caches, const size_t query_idx_start, const KVCaches& kv_caches,
TimingInfo& timing_info) { TimingInfo& timing_info) {
HWY_ASSERT(queries_pos_in.size() == kv_caches.size());
// Griffin assumes that the recurrent block cache is zero-initialized. // Griffin assumes that the recurrent block cache is zero-initialized.
for (size_t i = 0; i < kv_caches.size(); ++i) { for (size_t i = 0; i < kv_caches.size(); ++i) {
if (queries_pos_in[i] == 0) { if (queries_pos_in[i] == 0) {
@ -1510,17 +1519,10 @@ void GenerateBatchT(const ModelStore& model,
const size_t num_queries = queries_prompt.size(); const size_t num_queries = queries_prompt.size();
HWY_ASSERT(queries_pos.size() == num_queries); HWY_ASSERT(queries_pos.size() == num_queries);
HWY_ASSERT(kv_caches.size() >= num_queries); HWY_ASSERT(kv_caches.size() >= num_queries);
// Griffin does not support query batching. const size_t max_qbatch_size = runtime_config.decode_qbatch_size;
size_t max_qbatch_size = runtime_config.decode_qbatch_size;
for (const LayerConfig& layer_config : model.Config().layer_configs) {
if (layer_config.type == LayerAttentionType::kGriffinRecurrentBlock) {
max_qbatch_size = 1;
break;
}
}
const size_t max_batch_size = const size_t max_batch_size =
HWY_MAX(max_qbatch_size, runtime_config.prefill_tbatch_size); HWY_MAX(max_qbatch_size, runtime_config.prefill_tbatch_size);
Activations activations(model.Config(), max_batch_size, env); Activations activations(model.Config(), max_batch_size, env);
for (size_t qbatch_start = 0; qbatch_start < num_queries; for (size_t qbatch_start = 0; qbatch_start < num_queries;
@ -1528,7 +1530,6 @@ void GenerateBatchT(const ModelStore& model,
// Generate one batch of tokens from `qbatch_size` queries. // Generate one batch of tokens from `qbatch_size` queries.
const size_t qbatch_size = const size_t qbatch_size =
HWY_MIN(num_queries - qbatch_start, max_qbatch_size); HWY_MIN(num_queries - qbatch_start, max_qbatch_size);
activations.SetBatchSize(qbatch_size);
const QueriesPromptTokens qbatch_prompts(&queries_prompt[qbatch_start], const QueriesPromptTokens qbatch_prompts(&queries_prompt[qbatch_start],
qbatch_size); qbatch_size);
QueriesPos qbatch_pos(&queries_pos[qbatch_start], qbatch_size); QueriesPos qbatch_pos(&queries_pos[qbatch_start], qbatch_size);

View File

@ -25,8 +25,9 @@
namespace gcpp { namespace gcpp {
void KVCache::ZeroGriffinCache() { void KVCache::ZeroGriffinCache() {
if (conv1d_cache.HasPtr()) ZeroInit(conv1d_cache); if (griffin_layers == 0) return;
if (rglru_cache.HasPtr()) ZeroInit(rglru_cache); ZeroInit(conv1d_cache);
ZeroInit(rglru_cache);
} }
static size_t GriffinConv1dCols(const ModelConfig& config) { static size_t GriffinConv1dCols(const ModelConfig& config) {
@ -34,7 +35,9 @@ static size_t GriffinConv1dCols(const ModelConfig& config) {
for (const auto& layer_config : config.layer_configs) { for (const auto& layer_config : config.layer_configs) {
conv1d_width = HWY_MAX(conv1d_width, layer_config.conv1d_width); conv1d_width = HWY_MAX(conv1d_width, layer_config.conv1d_width);
} }
return conv1d_width == 0 ? 0 : conv1d_width - 1; // The row offset, in blocks of model_dim is computed mod (conv1d_width - 1),
// hence allocate conv1d_width * model_dim total columns.
return conv1d_width * config.model_dim;
} }
// prefill_tbatch_size is the maximum number of tokens from one query to // prefill_tbatch_size is the maximum number of tokens from one query to
@ -42,12 +45,9 @@ static size_t GriffinConv1dCols(const ModelConfig& config) {
KVCache::KVCache(const ModelConfig& config, size_t prefill_tbatch_size) KVCache::KVCache(const ModelConfig& config, size_t prefill_tbatch_size)
: griffin_layers( : griffin_layers(
config.NumLayersOfType(LayerAttentionType::kGriffinRecurrentBlock)), config.NumLayersOfType(LayerAttentionType::kGriffinRecurrentBlock)),
griffin_conv1d_cols(GriffinConv1dCols(config)), conv1d_cache("conv1d_cache",
// TODO(patrickms): Add query batching support for Griffin. Extents2D(griffin_layers, GriffinConv1dCols(config)),
conv1d_cache( MatPadding::kOdd),
"conv1d_cache",
Extents2D(griffin_layers, griffin_conv1d_cols * config.model_dim),
MatPadding::kOdd),
rglru_cache("rglru_cache", Extents2D(griffin_layers, config.model_dim), rglru_cache("rglru_cache", Extents2D(griffin_layers, config.model_dim),
MatPadding::kOdd) { MatPadding::kOdd) {
// TODO: move to MatStorageT. // TODO: move to MatStorageT.

View File

@ -31,7 +31,6 @@ struct KVCache {
KVCache Copy(const ModelConfig& weights_config, size_t prefill_tbatch_size); KVCache Copy(const ModelConfig& weights_config, size_t prefill_tbatch_size);
size_t griffin_layers = 0; size_t griffin_layers = 0;
size_t griffin_conv1d_cols = 0;
// griffin_layers, griffin_conv1d_cols * config.model_dim // griffin_layers, griffin_conv1d_cols * config.model_dim
MatStorageT<float> conv1d_cache; MatStorageT<float> conv1d_cache;
MatStorageT<float> rglru_cache; // griffin_layers, config.model_dim MatStorageT<float> rglru_cache; // griffin_layers, config.model_dim

View File

@ -97,21 +97,27 @@ void LayerWeightsPtrs<NuqStream>::Fixup(std::vector<MatOwner>& mat_owners) {
SplitW1NUQ(layer_config); SplitW1NUQ(layer_config);
} }
struct TensorToRead {
MatPtr* mat;
BlobRange range;
// Some tensors opt out of padding via kNoPad flags.
MatPadding padding;
};
// Allocates multiple in parallel and binds to NUMA nodes. // Allocates multiple in parallel and binds to NUMA nodes.
static void AllocateAndBindAll(const std::vector<MatPtr*>& mats, static void AllocateAndBindAll(const std::vector<TensorToRead>& tensors,
MatPadding padding,
std::vector<MatOwner>& owners, std::vector<MatOwner>& owners,
hwy::ThreadPool& pool) { hwy::ThreadPool& pool) {
const size_t start = owners.size(); const size_t start = owners.size();
owners.resize(start + mats.size()); owners.resize(start + tensors.size());
MMParallel parallel(ThreadingContext::Get()); MMParallel parallel(ThreadingContext::Get());
// Allocate in parallel because faulting in large tensors is slow. // Allocate in parallel because faulting in large tensors is slow.
pool.Run(0, mats.size(), [&](uint64_t task, size_t /*thread*/) { pool.Run(0, tensors.size(), [&](uint64_t task, size_t /*thread*/) {
owners[start + task].AllocateFor(*mats[task], padding); owners[start + task].AllocateFor(*tensors[task].mat, tensors[task].padding);
// TODO(janwas): MatMul outputs will later also be BF16. // TODO(janwas): MatMul outputs will later also be BF16.
BindB(*mats[task], sizeof(float), parallel); BindB(*tensors[task].mat, sizeof(float), parallel);
}); });
} }
@ -155,55 +161,54 @@ MapPtr MapFileOrNull(File& file, uint64_t file_bytes) {
return MapPtr(); return MapPtr();
} }
static void MapAll(const std::vector<MatPtr*>& mats, static void MapAll(const std::vector<TensorToRead>& tensors,
const std::vector<BlobRange>& ranges, const MapPtr& mapped) { const MapPtr& mapped) {
PROFILER_ZONE("Startup.Weights.Map"); PROFILER_ZONE("Startup.Weights.Map");
for (size_t i = 0; i < mats.size(); ++i) { for (size_t i = 0; i < tensors.size(); ++i) {
// SetPtr does not change the stride, but it is expected to be packed // SetPtr does not change the stride, but it is expected to be packed
// because that is what Compress() writes to the file. // because that is what Compress() writes to the file.
const size_t mat_bytes = mats[i]->PackedBytes(); const size_t mat_bytes = tensors[i].mat->PackedBytes();
// Ensure blob size matches that computed from metadata. // Ensure blob size matches that computed from metadata.
HWY_ASSERT_M(mat_bytes == ranges[i].bytes, mats[i]->Name()); HWY_ASSERT_M(mat_bytes == tensors[i].range.bytes, tensors[i].mat->Name());
mats[i]->SetPtr(const_cast<uint8_t*>(mapped.get() + ranges[i].offset), tensors[i].mat->SetPtr(
mats[i]->Stride()); const_cast<uint8_t*>(mapped.get() + tensors[i].range.offset),
tensors[i].mat->Stride());
} }
} }
std::vector<IOBatch> MakeBatches(const std::vector<BlobRange>& ranges, std::vector<IOBatch> MakeBatches(const std::vector<TensorToRead>& tensors,
const std::vector<MatPtr*>& mats,
const uint64_t file_bytes) { const uint64_t file_bytes) {
PROFILER_ZONE("Startup.Weights.MakeBatches"); PROFILER_ZONE("Startup.Weights.MakeBatches");
// Batches must be contiguous but blobs are padded, hence at least one // Batches must be contiguous but blobs are padded, hence at least one
// batch per tensor, and more when tensor rows exceed the batch size. // batch per tensor, and more when tensor rows exceed the batch size.
std::vector<IOBatch> batches; std::vector<IOBatch> batches;
batches.reserve(mats.size()); batches.reserve(tensors.size());
for (size_t i = 0; i < mats.size(); ++i) { for (size_t i = 0; i < tensors.size(); ++i) {
uint64_t offset = ranges[i].offset; const BlobRange& range = tensors[i].range;
HWY_ASSERT(ranges[i].End() <= file_bytes); MatPtr& mat = *tensors[i].mat;
uint64_t offset = range.offset;
HWY_ASSERT(range.End() <= file_bytes);
batches.emplace_back(offset, ranges[i].key_idx); batches.emplace_back(offset, range.key_idx);
const size_t file_bytes_per_row = mats[i]->Cols() * mats[i]->ElementBytes(); const size_t file_bytes_per_row = mat.Cols() * mat.ElementBytes();
// Caution, `RowT` requires knowledge of the actual type. We instead use const size_t mem_stride_bytes = mat.Stride() * mat.ElementBytes();
// the first row, which is the same for any type, and advance the *byte* uint8_t* row_bytes = mat.RowBytes(0);
// pointer by the *byte* stride. for (size_t r = 0; r < mat.Rows(); ++r) {
const size_t mem_stride_bytes = mats[i]->Stride() * mats[i]->ElementBytes(); if (!batches.back().Add(row_bytes, file_bytes_per_row)) { // Full batch.
uint8_t* row = mats[i]->RowT<uint8_t>(0); batches.emplace_back(offset, range.key_idx);
for (size_t r = 0; r < mats[i]->Rows(); ++r) {
if (!batches.back().Add(row, file_bytes_per_row)) { // Full batch.
batches.emplace_back(offset, ranges[i].key_idx);
// Adding to an empty batch is always successful. // Adding to an empty batch is always successful.
HWY_ASSERT(batches.back().Add(row, file_bytes_per_row)); HWY_ASSERT(batches.back().Add(row_bytes, file_bytes_per_row));
} }
offset += file_bytes_per_row; offset += file_bytes_per_row;
row += mem_stride_bytes; row_bytes += mem_stride_bytes;
// Keep the in-memory row padding uninitialized so msan detects any use. // Keep the in-memory row padding uninitialized so msan detects any use.
} }
HWY_ASSERT(offset == ranges[i].End()); HWY_ASSERT(offset == range.End());
} }
HWY_ASSERT(batches.size() >= mats.size()); HWY_ASSERT(batches.size() >= tensors.size());
return batches; return batches;
} }
@ -228,16 +233,14 @@ static void ReadBatches(const BlobReader& reader,
} }
// Aborts on error. // Aborts on error.
static void MapOrReadAll(const std::vector<MatPtr*>& mats, BlobReader& reader, static void MapOrReadAll(const std::vector<TensorToRead>& tensors,
const std::vector<BlobRange>& ranges, Tristate map, BlobReader& reader, Tristate map,
std::vector<MatOwner>& mat_owners, std::vector<MatOwner>& mat_owners,
const MatPadding padding, hwy::ThreadPool& pool) { hwy::ThreadPool& pool) {
HWY_ASSERT(mats.size() == ranges.size());
if (ChooseMode(reader.file_bytes(), map) == Mode::kMap) { if (ChooseMode(reader.file_bytes(), map) == Mode::kMap) {
MapPtr mapped = MapFileOrNull(reader.file(), reader.file_bytes()); MapPtr mapped = MapFileOrNull(reader.file(), reader.file_bytes());
if (mapped) { if (mapped) {
MapAll(mats, ranges, mapped); MapAll(tensors, mapped);
return; return;
} }
} // otherwise fall through to read mode } // otherwise fall through to read mode
@ -245,32 +248,32 @@ static void MapOrReadAll(const std::vector<MatPtr*>& mats, BlobReader& reader,
{ {
PROFILER_ZONE("Startup.Weights.Allocate"); PROFILER_ZONE("Startup.Weights.Allocate");
// NOTE: this changes the stride of `mats`! // NOTE: this changes the stride of `mats`!
AllocateAndBindAll(mats, padding, mat_owners, pool); AllocateAndBindAll(tensors, mat_owners, pool);
} }
const std::vector<IOBatch> batches = const std::vector<IOBatch> batches =
MakeBatches(ranges, mats, reader.file_bytes()); MakeBatches(tensors, reader.file_bytes());
ReadBatches(reader, batches, pool); ReadBatches(reader, batches, pool);
} }
void WeightsOwner::ReadFromBlobs(const ModelStore& model, BlobReader& reader, void WeightsOwner::ReadFromBlobs(const ModelStore& model, BlobReader& reader,
Tristate map, hwy::ThreadPool& pool) { Tristate map, hwy::ThreadPool& pool) {
// List of tensors to read/map, and where from. // List of tensors to read/map, and where from.
std::vector<MatPtr*> mats; std::vector<TensorToRead> tensors;
std::vector<BlobRange> ranges;
// Padding is inserted when reading row by row, except for NUQ tensors.
const MatPadding padding = MatPadding::kOdd;
AllocatePointer(model.Config()); AllocatePointer(model.Config());
// Enumerate all weights (negligible cost). // Enumerate all weights (negligible cost).
CallT([&](const auto& weights) { CallT([&](const auto& weights) {
weights->ForEachTensor(nullptr, nullptr, [&](const TensorArgs& t) { weights->ForEachTensor(nullptr, nullptr, [&](const TensorArgs& t) {
const MatPadding padding = (t.flags & TensorArgs::kNoPad)
? MatPadding::kPacked
: MatPadding::kOdd;
size_t key_idx; size_t key_idx;
if (model.FindAndUpdateMatPtr(t.mat, key_idx)) { if (model.FindAndUpdateMatPtr(t.mat, key_idx)) {
mats.push_back(&t.mat); tensors.push_back({.mat = &t.mat,
ranges.push_back(reader.Range(key_idx)); .range = reader.Range(key_idx),
.padding = padding});
return; return;
} }
if (t.flags & TensorArgs::kMaybeRead) return; // optional and not found. if (t.flags & TensorArgs::kMaybeRead) return; // optional and not found.
@ -278,7 +281,7 @@ void WeightsOwner::ReadFromBlobs(const ModelStore& model, BlobReader& reader,
}); });
}); });
MapOrReadAll(mats, reader, ranges, map, mat_owners_, padding, pool); MapOrReadAll(tensors, reader, map, mat_owners_, pool);
Fixup(pool); Fixup(pool);
} }
@ -338,9 +341,9 @@ void WeightsOwner::LogWeightStatsF32() {
printf("[scale=%f] ", t.mat.Scale()); printf("[scale=%f] ", t.mat.Scale());
} }
hwy::Stats stats; hwy::Stats stats;
HWY_ASSERT(t.mat.GetType() == Type::kF32); const MatPtrT<float> mat_f(t.mat);
for (size_t r = 0; r < t.mat.Rows(); ++r) { for (size_t r = 0; r < t.mat.Rows(); ++r) {
const float* HWY_RESTRICT row = t.mat.RowT<float>(r); const float* HWY_RESTRICT row = mat_f.Row(r);
for (size_t c = 0; c < t.mat.Cols(); ++c) { for (size_t c = 0; c < t.mat.Cols(); ++c) {
stats.Notify(row[c]); stats.Notify(row[c]);
} }

View File

@ -43,24 +43,27 @@ struct TensorArgs {
// name/type from another `LayerWeightsPtrs` for iterating over tensor pairs // name/type from another `LayerWeightsPtrs` for iterating over tensor pairs
// (for copying) or triples (for `AdamUpdateMV`). Set by `TENSOR_ARGS`. // (for copying) or triples (for `AdamUpdateMV`). Set by `TENSOR_ARGS`.
// `flags` is a combination of zero or more `Flags`. // `flags` is a combination of zero or more `Flags`.
TensorArgs(MatPtr& mat, const MatPtr* other_mat1, const MatPtr* other_mat2, TensorArgs(MatPtr& mat, MatPtr* other_mat1, MatPtr* other_mat2, int flags)
int flags)
: mat(mat), : mat(mat),
other_mat1(other_mat1), other_mat1(other_mat1),
other_mat2(other_mat2), other_mat2(other_mat2),
flags(flags) {} flags(flags) {}
MatPtr& mat; MatPtr& mat;
const MatPtr* other_mat1; // either/both can be nullptr. MatPtr* other_mat1; // either/both can be nullptr.
const MatPtr* other_mat2; MatPtr* other_mat2;
// TODO: freestanding enum class instead? These are mutually exclusive.
enum Flags { enum Flags {
// Read the tensor from the file and abort if it is not found. // Default: Read the tensor from the file and abort if it is not found.
kMustRead = 0, kMustRead = 0,
// Not an error if the tensor is not present in the file. For example, // Not an error if the tensor is not present in the file. For example,
// the _w1/_w2 tensors are not always present. // the _w1/_w2 tensors are not always present.
kMaybeRead = 1, kMaybeRead = 1,
// Avoid padding tensor rows when reading. Used for some Griffin tensors
// whose index computations do not use Row() accessors.
kNoPad = 2,
}; };
const int flags; const int flags;
}; };
@ -214,8 +217,8 @@ struct LayerWeightsPtrs {
// can also iterate over pairs or triples of tensors for `AdamUpdateMV`. // can also iterate over pairs or triples of tensors for `AdamUpdateMV`.
// Public because also called by `WeightsPtrs`. // Public because also called by `WeightsPtrs`.
template <class Func> template <class Func>
void ForEachTensor(const LayerWeightsPtrs<Weight>* other1, void ForEachTensor(LayerWeightsPtrs<Weight>* other1,
const LayerWeightsPtrs<Weight>* other2, Func func) { LayerWeightsPtrs<Weight>* other2, Func func) {
if (layer_config.type == LayerAttentionType::kVit) { if (layer_config.type == LayerAttentionType::kVit) {
// MHA. // MHA.
func(TENSOR_ARGS(vit.attn_out_w, kMustRead)); func(TENSOR_ARGS(vit.attn_out_w, kMustRead));
@ -248,9 +251,9 @@ struct LayerWeightsPtrs {
func(TENSOR_ARGS(griffin.linear_y_biases, kMustRead)); func(TENSOR_ARGS(griffin.linear_y_biases, kMustRead));
func(TENSOR_ARGS(griffin.linear_out_w, kMustRead)); func(TENSOR_ARGS(griffin.linear_out_w, kMustRead));
func(TENSOR_ARGS(griffin.linear_out_biases, kMustRead)); func(TENSOR_ARGS(griffin.linear_out_biases, kMustRead));
func(TENSOR_ARGS(griffin.conv_w, kMustRead)); func(TENSOR_ARGS(griffin.conv_w, kMustRead | TensorArgs::kNoPad));
func(TENSOR_ARGS(griffin.conv_biases, kMustRead)); func(TENSOR_ARGS(griffin.conv_biases, kMustRead));
func(TENSOR_ARGS(griffin.gate_w, kMustRead)); func(TENSOR_ARGS(griffin.gate_w, kMustRead | TensorArgs::kNoPad));
func(TENSOR_ARGS(griffin.gate_biases, kMustRead)); func(TENSOR_ARGS(griffin.gate_biases, kMustRead));
func(TENSOR_ARGS(griffin.a, kMustRead)); func(TENSOR_ARGS(griffin.a, kMustRead));
} }
@ -363,8 +366,8 @@ struct LayerWeightsPtrs {
// For FFN. Fast, only updates pointers. // For FFN. Fast, only updates pointers.
void SplitW1() { void SplitW1() {
// We only use this tensor for Gemma layers. // Used for Gemma and Griffin layers; FFWVit uses different tensors.
if (layer_config.type != LayerAttentionType::kGemma) return; if (layer_config.type == LayerAttentionType::kVit) return;
// Files have both or neither of w1 and w2, and backprop/ allocates both. // Files have both or neither of w1 and w2, and backprop/ allocates both.
HWY_ASSERT(gating_einsum_w1.HasPtr() == gating_einsum_w2.HasPtr()); HWY_ASSERT(gating_einsum_w1.HasPtr() == gating_einsum_w2.HasPtr());
@ -514,10 +517,10 @@ struct ModelWeightsPtrs {
// used to copy from another set of weights. Public because called by tests // used to copy from another set of weights. Public because called by tests
// and `WeightsOwner`. // and `WeightsOwner`.
template <class Func> template <class Func>
void ForEachTensor(const ModelWeightsPtrs<Weight>* other1, void ForEachTensor(ModelWeightsPtrs<Weight>* other1,
const ModelWeightsPtrs<Weight>* other2, Func func) { ModelWeightsPtrs<Weight>* other2, Func func) {
const LayerWeightsPtrs<Weight>* other_layer1 = nullptr; LayerWeightsPtrs<Weight>* other_layer1 = nullptr;
const LayerWeightsPtrs<Weight>* other_layer2 = nullptr; LayerWeightsPtrs<Weight>* other_layer2 = nullptr;
func(TENSOR_ARGS(embedder_input_embedding, kMustRead)); func(TENSOR_ARGS(embedder_input_embedding, kMustRead));
func(TENSOR_ARGS(final_norm_scale, kMustRead)); func(TENSOR_ARGS(final_norm_scale, kMustRead));
@ -569,11 +572,12 @@ struct ModelWeightsPtrs {
// Copies only the allocated tensors in `*this` from tensors in `other`. // Copies only the allocated tensors in `*this` from tensors in `other`.
void CopyFrom(const ModelWeightsPtrs<Weight>& other) { void CopyFrom(const ModelWeightsPtrs<Weight>& other) {
ForEachTensor(&other, nullptr, [](const TensorArgs& t) { ForEachTensor(const_cast<ModelWeightsPtrs<Weight>*>(&other), nullptr,
if (!t.mat.HasPtr()) return; [](const TensorArgs& t) {
HWY_ASSERT(t.other_mat1 && t.other_mat1->HasPtr()); if (!t.mat.HasPtr()) return;
CopyMat(*t.other_mat1, t.mat); HWY_ASSERT(t.other_mat1 && t.other_mat1->HasPtr());
}); CopyMat(*t.other_mat1, t.mat);
});
} }
// Instead of reading, only allocates memory for all tensors. Used by // Instead of reading, only allocates memory for all tensors. Used by

View File

@ -84,6 +84,14 @@ struct Header { // standard layout class
static_assert(sizeof(Header) == 16); static_assert(sizeof(Header) == 16);
} // namespace } // namespace
// A write I/O request, each serviced by one thread in a pool.
struct BlobIO {
BlobIO(BlobRange range, void* data) : range(range), data(data) {}
BlobRange range;
void* data; // Read-only for writes.
};
// Little-endian on-disk representation: a fixed-size `Header`, then a padded // Little-endian on-disk representation: a fixed-size `Header`, then a padded
// variable-length 'directory' of blob keys and their offset/sizes, then the // variable-length 'directory' of blob keys and their offset/sizes, then the
// 'payload' of each blob's data with padding in between, followed by padding to // 'payload' of each blob's data with padding in between, followed by padding to
@ -238,11 +246,11 @@ class BlobStore {
return true; // all OK return true; // all OK
} }
void EnqueueWriteForHeaderAndDirectory(std::vector<BlobIO2>& writes) const { void EnqueueWriteForHeaderAndDirectory(std::vector<BlobIO>& writes) const {
const size_t key_idx = 0; // not actually associated with a key/blob const size_t key_idx = 0; // not actually associated with a key/blob
writes.emplace_back( writes.emplace_back(
BlobRange{.offset = 0, .bytes = sizeof(header_), .key_idx = key_idx}, BlobRange{.offset = 0, .bytes = sizeof(header_), .key_idx = key_idx},
// members are const and BlobIO2 requires non-const pointers, and they // members are const and BlobIO requires non-const pointers, and they
// are not modified by file writes. // are not modified by file writes.
const_cast<Header*>(&header_)); const_cast<Header*>(&header_));
writes.emplace_back( writes.emplace_back(
@ -314,7 +322,7 @@ BlobReader::BlobReader(const Path& blob_path)
// Split into chunks for load-balancing even if blob sizes vary. // Split into chunks for load-balancing even if blob sizes vary.
static void EnqueueChunks(size_t key_idx, uint64_t offset, uint64_t bytes, static void EnqueueChunks(size_t key_idx, uint64_t offset, uint64_t bytes,
uint8_t* data, std::vector<BlobIO2>& writes) { uint8_t* data, std::vector<BlobIO>& writes) {
constexpr size_t kChunkBytes = 4 * 1024 * 1024; constexpr size_t kChunkBytes = 4 * 1024 * 1024;
const uint64_t end = offset + bytes; const uint64_t end = offset + bytes;
// Split into whole chunks and possibly one remainder. // Split into whole chunks and possibly one remainder.
@ -336,7 +344,7 @@ static void EnqueueChunks(size_t key_idx, uint64_t offset, uint64_t bytes,
static void EnqueueWritesForBlobs(const BlobStore& bs, static void EnqueueWritesForBlobs(const BlobStore& bs,
const hwy::Span<const uint8_t> blobs[], const hwy::Span<const uint8_t> blobs[],
std::vector<uint8_t>& zeros, std::vector<uint8_t>& zeros,
std::vector<BlobIO2>& writes) { std::vector<BlobIO>& writes) {
// All-zero buffer used to write padding to the file without copying the // All-zero buffer used to write padding to the file without copying the
// input blobs. // input blobs.
static constexpr uint8_t kZeros[kBlobAlign] = {0}; static constexpr uint8_t kZeros[kBlobAlign] = {0};
@ -388,7 +396,7 @@ void BlobWriter::WriteAll(hwy::ThreadPool& pool, const Path& filename) {
HWY_ASSERT(num_blobs != 0); HWY_ASSERT(num_blobs != 0);
HWY_ASSERT(num_blobs == blobs_.size()); HWY_ASSERT(num_blobs == blobs_.size());
std::vector<BlobIO2> writes; std::vector<BlobIO> writes;
writes.reserve(16384); writes.reserve(16384);
const BlobStore bs(num_blobs, keys_.data(), blobs_.data()); const BlobStore bs(num_blobs, keys_.data(), blobs_.data());

View File

@ -44,16 +44,6 @@ struct BlobRange {
size_t key_idx; size_t key_idx;
}; };
// A read or write I/O request, each serviced by one thread in a pool.
struct BlobIO2 {
BlobIO2(BlobRange range, void* data) : range(range), data(data) {}
BlobRange range;
void* data; // Modified only if a read request. Read-only for writes.
};
class BlobStore;
// Reads `BlobStore` header, converts keys to strings and creates a hash map for // Reads `BlobStore` header, converts keys to strings and creates a hash map for
// faster lookups. // faster lookups.
// TODO(janwas): rename to BlobFinder or similar. // TODO(janwas): rename to BlobFinder or similar.

View File

@ -425,7 +425,7 @@ MatMulEnv::MatMulEnv(ThreadingContext& ctx)
have_timer_stop = hwy::platform::HaveTimerStop(cpu100); have_timer_stop = hwy::platform::HaveTimerStop(cpu100);
} }
void BindB(const MatPtr& B, size_t sizeof_TC, MMParallel& parallel) { void BindB(MatPtr& B, size_t sizeof_TC, MMParallel& parallel) {
Allocator& allocator = parallel.allocator(); Allocator& allocator = parallel.allocator();
if (!allocator.ShouldBind()) return; if (!allocator.ShouldBind()) return;
if (B.Rows() == 1) return; if (B.Rows() == 1) return;
@ -437,8 +437,7 @@ void BindB(const MatPtr& B, size_t sizeof_TC, MMParallel& parallel) {
for (size_t pkg_idx = 0; pkg_idx < ranges_np.NumTasks(); ++pkg_idx) { for (size_t pkg_idx = 0; pkg_idx < ranges_np.NumTasks(); ++pkg_idx) {
const IndexRange& rows_b = ranges_np.Range(pkg_idx); const IndexRange& rows_b = ranges_np.Range(pkg_idx);
const size_t node = parallel.Node(pkg_idx); const size_t node = parallel.Node(pkg_idx);
uintptr_t begin = uintptr_t begin = reinterpret_cast<uintptr_t>(B.RowBytes(rows_b.begin()));
reinterpret_cast<uintptr_t>(B.RowT<uint8_t>(rows_b.begin()));
uintptr_t end = begin + rows_b.Num() * B.Stride() * B.ElementBytes(); uintptr_t end = begin + rows_b.Num() * B.Stride() * B.ElementBytes();
// B row padding is less than the page size, so only bind the subset that // B row padding is less than the page size, so only bind the subset that
// is page-aligned. // is page-aligned.
@ -451,7 +450,7 @@ void BindB(const MatPtr& B, size_t sizeof_TC, MMParallel& parallel) {
} }
// C is BF16/float, or double for partial // C is BF16/float, or double for partial
void BindC(const MatPtr& C, MMParallel& parallel) { void BindC(MatPtr& C, MMParallel& parallel) {
Allocator& allocator = parallel.allocator(); Allocator& allocator = parallel.allocator();
if (!allocator.ShouldBind()) return; if (!allocator.ShouldBind()) return;
@ -470,8 +469,7 @@ void BindC(const MatPtr& C, MMParallel& parallel) {
const size_t node = parallel.Node(pkg_idx); const size_t node = parallel.Node(pkg_idx);
for (size_t im = 0; im < C.Rows(); ++im) { for (size_t im = 0; im < C.Rows(); ++im) {
ok &= allocator.BindMemory(C.MutableRowT<uint8_t>(im) + begin, ok &= allocator.BindMemory(C.RowBytes(im) + begin, end - begin, node);
end - begin, node);
} }
} }
if (HWY_UNLIKELY(!ok)) { if (HWY_UNLIKELY(!ok)) {

View File

@ -172,9 +172,9 @@ class MMParallel {
ThreadingContext& ctx_; ThreadingContext& ctx_;
}; };
void BindB(const MatPtr& B, size_t sizeof_TC, MMParallel& parallel); void BindB(MatPtr& B, size_t sizeof_TC, MMParallel& parallel);
// C is BF16/float, or double for partial. // C is BF16/float, or double for partial.
void BindC(const MatPtr& C, MMParallel& parallel); void BindC(MatPtr& C, MMParallel& parallel);
// Per-package storage for packed A, and one global C-shaped `partial` for // Per-package storage for packed A, and one global C-shaped `partial` for
// accumulating partial dot products (sections of K). // accumulating partial dot products (sections of K).

View File

@ -41,8 +41,8 @@ void CopyMat(const MatPtr& from, MatPtr& to) {
} }
const size_t row_bytes = to.Cols() * to.ElementBytes(); const size_t row_bytes = to.Cols() * to.ElementBytes();
for (size_t r = 0; r < to.Rows(); ++r) { for (size_t r = 0; r < to.Rows(); ++r) {
const uint8_t* from_row = from.RowT<uint8_t>(r); const uint8_t* from_row = from.RowBytes(r);
uint8_t* to_row = to.RowT<uint8_t>(r); uint8_t* to_row = to.RowBytes(r);
hwy::CopyBytes(from_row, to_row, row_bytes); hwy::CopyBytes(from_row, to_row, row_bytes);
} }
} }
@ -58,7 +58,7 @@ void ZeroInit(MatPtr& mat) {
} }
const size_t row_bytes = mat.Cols() * mat.ElementBytes(); const size_t row_bytes = mat.Cols() * mat.ElementBytes();
for (size_t r = 0; r < mat.Rows(); ++r) { for (size_t r = 0; r < mat.Rows(); ++r) {
hwy::ZeroBytes(mat.RowT<uint8_t>(r), row_bytes); hwy::ZeroBytes(mat.RowBytes(r), row_bytes);
} }
} }
@ -71,15 +71,19 @@ void RandInit(MatPtr& mat, float stddev, std::mt19937& gen) {
std::normal_distribution<float> dist(0.0, stddev); std::normal_distribution<float> dist(0.0, stddev);
if (mat.GetType() == Type::kF32) { if (mat.GetType() == Type::kF32) {
MatPtrT<float> mat_f(mat);
for (size_t r = 0; r < mat.Rows(); ++r) { for (size_t r = 0; r < mat.Rows(); ++r) {
float* HWY_RESTRICT row = mat.RowT<float>(r); float* HWY_RESTRICT row = mat_f.Row(r);
for (size_t c = 0; c < mat.Cols(); ++c) { for (size_t c = 0; c < mat.Cols(); ++c) {
row[c] = dist(gen); row[c] = dist(gen);
} }
} }
} else { } else {
MatPtrT<double> mat_d(mat);
for (size_t r = 0; r < mat.Rows(); ++r) { for (size_t r = 0; r < mat.Rows(); ++r) {
double* HWY_RESTRICT row = mat.RowT<double>(r); double* HWY_RESTRICT row = mat_d.Row(r);
for (size_t c = 0; c < mat.Cols(); ++c) { for (size_t c = 0; c < mat.Cols(); ++c) {
row[c] = dist(gen); row[c] = dist(gen);
} }

View File

@ -91,21 +91,14 @@ class MatPtr : public IFields {
return num_elements_ * element_bytes_; return num_elements_ * element_bytes_;
} }
// Works for any kind of padding. // Works for any kind of padding and element type.
template <typename T> uint8_t* RowBytes(size_t row) {
T* MutableRowT(size_t row) const {
HWY_DASSERT(row < Rows()); HWY_DASSERT(row < Rows());
return HWY_RCAST_ALIGNED(T*, ptr_) + row * stride_; return static_cast<uint8_t*>(ptr_) + row * (stride_ * element_bytes_);
} }
template <typename T> const uint8_t* RowBytes(size_t row) const {
T* RowT(size_t row) {
HWY_DASSERT(row < Rows()); HWY_DASSERT(row < Rows());
return HWY_RCAST_ALIGNED(T*, ptr_) + row * stride_; return static_cast<const uint8_t*>(ptr_) + row * (stride_ * element_bytes_);
}
template <typename T>
const T* RowT(size_t row) const {
HWY_DASSERT(row < Rows());
return HWY_RCAST_ALIGNED(const T*, ptr_) + row * stride_;
} }
Type GetType() const { return type_; } Type GetType() const { return type_; }
@ -209,9 +202,8 @@ class MatPtr : public IFields {
float scale_ = 1.0f; // multiplier for each value, for MatMul. float scale_ = 1.0f; // multiplier for each value, for MatMul.
}; };
// Non-type erased version of `MatPtr`. Although `MatPtr` also provides // Non-type erased version of `MatPtr`: provides type-safe `Row()` and ensures
// type-aware accessors (`RowT`), this class is more convenient when accessing // the template argument and `Type` are consistent.
// elements, and ensures the template argument and `Type` are consistent.
template <typename MatT> template <typename MatT>
class MatPtrT : public MatPtr { class MatPtrT : public MatPtr {
public: public:
@ -227,7 +219,9 @@ class MatPtrT : public MatPtr {
: MatPtrT<MatT>(name.c_str(), ExtentsFromInfo(info.Find(name))) {} : MatPtrT<MatT>(name.c_str(), ExtentsFromInfo(info.Find(name))) {}
// Copying allowed because the metadata is small. // Copying allowed because the metadata is small.
MatPtrT(const MatPtr& other) : MatPtr(other) {} MatPtrT(const MatPtr& other) : MatPtr(other) {
HWY_ASSERT(other.GetType() == TypeEnum<MatT>());
}
MatPtrT& operator=(const MatPtr& other) { MatPtrT& operator=(const MatPtr& other) {
MatPtr::operator=(other); MatPtr::operator=(other);
return *this; return *this;
@ -252,22 +246,24 @@ class MatPtrT : public MatPtr {
return Packed(); return Packed();
} }
const MatT* Row(size_t row) const { return this->RowT<MatT>(row); } MatT* Row(size_t row) { return HWY_RCAST_ALIGNED(T*, RowBytes(row)); }
MatT* Row(size_t row) { return this->RowT<MatT>(row); } const MatT* Row(size_t row) const {
return HWY_RCAST_ALIGNED(const T*, RowBytes(row));
}
PackedSpan<const MatT> PaddedSpan() const { PackedSpan<const MatT> PaddedSpan() const {
return MakeConstSpan(Row(0), Rows() * Stride()); return MakeConstSpan(HWY_RCAST_ALIGNED(MatT*, ptr_), Rows() * Stride());
} }
// For `compress-inl.h` functions, which assume contiguous streams and thus // For `compress-inl.h` functions, which assume contiguous streams and thus
// require packed layout. // require packed layout.
PackedSpan<const MatT> Span() const {
HWY_ASSERT(IsPacked());
return MakeConstSpan(Row(0), num_elements_);
}
PackedSpan<MatT> Span() { PackedSpan<MatT> Span() {
HWY_ASSERT(IsPacked()); HWY_ASSERT(IsPacked());
return MakeSpan(Row(0), num_elements_); return MakeSpan(HWY_RCAST_ALIGNED(MatT*, ptr_), num_elements_);
}
PackedSpan<const MatT> Span() const {
HWY_ASSERT(IsPacked());
return MakeConstSpan(HWY_RCAST_ALIGNED(MatT*, ptr_), num_elements_);
} }
}; };