mirror of https://github.com/google/gemma.cpp.git
Fix RowT issue and improve Griffin (currently still broken)

- Use type-safe MatPtrT via dynamic_cast, avoid/remove unsafe RowT
- activations: Griffin tensors are now padded
- Griffin: add batching support, fix conv1d_cache allocation
- weights: bundle to TensorToRead, add kNoPad flag, fix SplitW1
- const-correct fix for ForEachTensor
- blob_store: move BlobIO2 to .cc and rename BlobIO

PiperOrigin-RevId: 760610094
parent d6cfabc2c1
commit cb188d4a0e
@@ -27,6 +27,8 @@ namespace gcpp {

 namespace {

+using MatPtrF = MatPtrT<float>;
+
 // Split into two classes so that ForEachTensor only requires two "other"
 // arguments. This is anyway useful for locality, because `grad` only feeds
 // into `grad_m` and `grad_v` here.
@@ -40,12 +42,11 @@ class AdamUpdateMV {
         norm1_(1.0 / (1.0 - std::pow(beta1, t))),
         norm2_(1.0 / (1.0 - std::pow(beta2, t))) {}

-  void operator()(const MatPtr& grad, const MatPtr& grad_m,
-                  const MatPtr& grad_v) {
+  void operator()(const MatPtrF& grad, MatPtrF& grad_m, MatPtrF& grad_v) {
     for (size_t r = 0; r < grad.Rows(); ++r) {
-      const float* HWY_RESTRICT g = grad.RowT<float>(r);
-      float* HWY_RESTRICT m = grad_m.MutableRowT<float>(r);
-      float* HWY_RESTRICT v = grad_v.MutableRowT<float>(r);
+      const float* HWY_RESTRICT g = grad.Row(r);
+      float* HWY_RESTRICT m = grad_m.Row(r);
+      float* HWY_RESTRICT v = grad_v.Row(r);
       for (size_t c = 0; c < grad.Cols(); ++c) {
         m[c] *= beta1_;
         m[c] += cbeta1_ * g[c];
@@ -73,11 +74,12 @@ class AdamUpdateW {
         norm2_(1.0 / (1.0 - std::pow(beta2, t))),
         epsilon_(epsilon) {}

-  void operator()(MatPtr& weights, const MatPtr& grad_m, const MatPtr& grad_v) {
+  void operator()(MatPtrF& weights, const MatPtrF& grad_m,
+                  const MatPtrF& grad_v) {
     for (size_t r = 0; r < weights.Rows(); ++r) {
-      float* HWY_RESTRICT w = weights.RowT<float>(r);
-      const float* HWY_RESTRICT m = grad_m.RowT<float>(r);
-      const float* HWY_RESTRICT v = grad_v.RowT<float>(r);
+      float* HWY_RESTRICT w = weights.Row(r);
+      const float* HWY_RESTRICT m = grad_m.Row(r);
+      const float* HWY_RESTRICT v = grad_v.Row(r);
       for (size_t c = 0; c < weights.Cols(); ++c) {
         const float mhat = m[c] * norm1_;
         const float vhat = v[c] * norm2_;
@@ -100,12 +102,18 @@ void AdamUpdate(ModelWeightsPtrs<float>* grad, float alpha, float beta1,
                 ModelWeightsPtrs<float>* grad_v, hwy::ThreadPool& pool) {
   AdamUpdateMV update_mv(beta1, beta2, t);
   grad->ForEachTensor(grad_m, grad_v, [&update_mv](const TensorArgs& t) {
-    update_mv(t.mat, *t.other_mat1, *t.other_mat2);
+    const MatPtrF grad_f(t.mat);
+    MatPtrF grad_m_f(*t.other_mat1);
+    MatPtrF grad_v_f(*t.other_mat2);
+    update_mv(grad_f, grad_m_f, grad_v_f);
   });

   AdamUpdateW update_w(alpha, beta1, beta2, epsilon, t);
   weights->ForEachTensor(grad_m, grad_v, [&update_w](const TensorArgs& t) {
-    update_w(t.mat, *t.other_mat1, *t.other_mat2);
+    MatPtrF weights_f(t.mat);
+    const MatPtrF grad_m_f(*t.other_mat1);
+    const MatPtrF grad_v_f(*t.other_mat2);
+    update_w(weights_f, grad_m_f, grad_v_f);
   });
 }
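For readers new to the accessor change above, here is a minimal sketch of the pattern, assuming only the `MatPtr`/`MatPtrT` interface from `util/mat.h` further below; the helper function itself is hypothetical.

```cpp
#include "hwy/base.h"   // HWY_ASSERT, HWY_RESTRICT
#include "util/mat.h"   // gcpp::MatPtr, gcpp::MatPtrT

// Sketch only: view a type-erased MatPtr as MatPtrT<float> before touching
// rows. The MatPtrT(const MatPtr&) constructor asserts the element type.
void ScaleF32InPlace(gcpp::MatPtr& any_mat, float factor) {  // hypothetical
  HWY_ASSERT(any_mat.GetType() == gcpp::Type::kF32);
  gcpp::MatPtrT<float> mat_f(any_mat);  // copies only metadata, not elements
  for (size_t r = 0; r < mat_f.Rows(); ++r) {
    float* HWY_RESTRICT row = mat_f.Row(r);  // stride-aware, correctly typed
    for (size_t c = 0; c < mat_f.Cols(); ++c) row[c] *= factor;
  }
}
```

This mirrors what `AdamUpdate` now does: the `ForEachTensor` callback receives type-erased `TensorArgs` and wraps them in `MatPtrF` views before calling the updaters.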
@@ -33,8 +33,7 @@ struct Activations {
         layer_config(config.layer_configs[0]),
         seq_len(config.seq_len),
         cache_pos_size(config.CachePosSize()),
-        is_griffin(layer_config.type ==
-                   LayerAttentionType::kGriffinRecurrentBlock),
+        is_griffin(config.model == Model::GRIFFIN_2B),

         x("x", Extents2D(batch_size, config.model_dim), pad_),
         q("q",
@@ -58,21 +57,18 @@ struct Activations {
         C2("C2", Extents2D(batch_size, layer_config.ff_hidden_dim), pad_),
         ffw_out("ffw_out", Extents2D(batch_size, config.model_dim), pad_),

-        // No padding for Griffin because it does not always use Row().
         griffin_x("griffin_x",
                   is_griffin ? Extents2D(batch_size, config.model_dim) : none_,
-                  MatPadding::kPacked),
+                  pad_),
         griffin_y("griffin_y",
                   is_griffin ? Extents2D(batch_size, config.model_dim) : none_,
-                  MatPadding::kPacked),
+                  pad_),
         griffin_gate_x(
             "griffin_gate_x",
-            is_griffin ? Extents2D(batch_size, config.model_dim) : none_,
-            MatPadding::kPacked),
+            is_griffin ? Extents2D(batch_size, config.model_dim) : none_, pad_),
         griffin_multiplier(
             "griffin_mul",
-            is_griffin ? Extents2D(batch_size, config.model_dim) : none_,
-            MatPadding::kPacked),
+            is_griffin ? Extents2D(batch_size, config.model_dim) : none_, pad_),

         inv_timescale(CreateInvTimescale(
             ThreadingContext::Get().allocator, layer_config.qkv_dim,
@@ -21,7 +21,6 @@
 #include <stdio.h>

 #include <algorithm>  // std::min
 #include <memory>     // std::make_unique
 #include <vector>

 #include "gemma/activations.h"
@@ -69,30 +68,44 @@ namespace HWY_NAMESPACE {
 // Different functions use different naming conventions for the number of
 // tokens. Functions that are query-independent, such as RMSNorm*, call the
 // count `num_interleaved`. Functions that are query-dependent, such as
-// `Attention`, use separate `num_tokens` and `num_queries`.
+// `Attention`, use separate `num_tokens` and `num_queries`. `num_tokens` is the
+// number of tokens from one query: 1 for decode, otherwise prefill_tbatch_size.

-// TODO: add batch query support for Griffin (QueriesPos).
 template <typename T>
-HWY_NOINLINE void GriffinRecurrent(size_t batch_start, size_t num_tokens,
-                                   size_t layer, Activations& activations,
+HWY_NOINLINE void GriffinRecurrent(const QueriesPos& queries_pos,
+                                   size_t num_tokens, size_t griffin_layer,
+                                   Activations& activations,
                                    const LayerWeightsPtrs<T>* layer_weights,
                                    const KVCaches& kv_caches) {
   PROFILER_ZONE("Gen.Griffin");
-  KVCache& kv_cache = kv_caches[0];
   hwy::ThreadPool& pool = activations.env->ctx.pools.Pool(0);
+  namespace hn = hwy::HWY_NAMESPACE;
+  using D = hn::ScalableTag<float>;
+  const D df;

   const size_t model_dim = layer_weights->layer_config.model_dim;
-  const size_t conv_1d_width = layer_weights->layer_config.conv1d_width;
+  HWY_DASSERT(model_dim % hn::Lanes(df) == 0);

   const size_t heads = layer_weights->layer_config.heads;
+  const size_t conv_1d_width = layer_weights->layer_config.conv1d_width;
+  HWY_ASSERT_M(conv_1d_width % 2 == 0, "Conv width must be even");
+  const size_t kHeadDim = model_dim / heads;
+  const size_t kMatrixSize = kHeadDim * kHeadDim;
+
+  const size_t num_queries = queries_pos.size();
+  const hwy::Divisor div_num_q(static_cast<uint32_t>(num_queries));
+  const size_t num_interleaved = num_tokens * num_queries;

   // X / Y linear layers.
-  for (size_t batch_idx = 0; batch_idx < num_tokens; ++batch_idx) {
-    float* HWY_RESTRICT y = activations.griffin_y.Row(batch_idx);
-    float* HWY_RESTRICT x = activations.griffin_x.Row(batch_idx);
   // TODO: MatMul
+  HWY_DASSERT(activations.griffin_y.Rows() == activations.griffin_x.Rows());
+  HWY_DASSERT(num_interleaved == activations.griffin_y.Rows());
+  for (size_t r = 0; r < num_interleaved; ++r) {
+    float* HWY_RESTRICT y = activations.griffin_y.Row(r);
+    float* HWY_RESTRICT x = activations.griffin_x.Row(r);
     TwoMatVecAdd(layer_weights->griffin.linear_x_w,
                  layer_weights->griffin.linear_y_w, 0, model_dim, model_dim,
-                 activations.pre_att_rms_out.Row(batch_idx),
+                 activations.pre_att_rms_out.Row(r),
                  /*add0=*/layer_weights->griffin.linear_x_biases.PackedScale1(),
                  /*add1=*/layer_weights->griffin.linear_y_biases.PackedScale1(),
                  /*out0=*/x, /*out1=*/y, pool);
@@ -100,18 +113,19 @@ HWY_NOINLINE void GriffinRecurrent(size_t batch_start, size_t num_tokens,
   }

   // Conv1D.
-  for (size_t batch_idx = 0; batch_idx < num_tokens; ++batch_idx) {
-    const size_t pos = batch_start + batch_idx;
-    float* HWY_RESTRICT x = activations.griffin_x.Row(batch_idx);
-    HWY_FULL(float) df;
-    HWY_DASSERT(model_dim % hn::Lanes(df) == 0);
+  for (size_t interleaved_idx = 0; interleaved_idx < num_interleaved;
+       ++interleaved_idx) {
+    const size_t query_idx = div_num_q.Remainder(interleaved_idx);
+    const size_t batch_idx = div_num_q.Divide(interleaved_idx);
+    const size_t pos = queries_pos[query_idx] + batch_idx;
+    float* HWY_RESTRICT x = activations.griffin_x.Row(query_idx);

     // cache[i] = input at time t-i.
     float* HWY_RESTRICT cache[kMaxConv1DWidth];
     cache[0] = x;
     for (size_t i = 1; i < conv_1d_width; i++) {
       cache[i] =
-          kv_cache.conv1d_cache.Row(layer) +
+          kv_caches[query_idx].conv1d_cache.Row(griffin_layer) +
           ((pos + conv_1d_width - 1 - i) % (conv_1d_width - 1)) * model_dim;
     }
     for (size_t i = 0; i < model_dim; i += hn::Lanes(df)) {
@@ -119,7 +133,6 @@ HWY_NOINLINE void GriffinRecurrent(size_t batch_start, size_t num_tokens,
       auto accum0 =
           hn::Load(df, layer_weights->griffin.conv_biases.PackedScale1() + i);
       auto accum1 = hn::Zero(df);
-      HWY_ASSERT_M(conv_1d_width % 2 == 0, "Conv width must be even");
       for (size_t l = 0; 2 * l < conv_1d_width; l++) {
         auto wv0 =
             hn::Load(df, layer_weights->griffin.conv_w.PackedScale1() +
@@ -136,17 +149,20 @@ HWY_NOINLINE void GriffinRecurrent(size_t batch_start, size_t num_tokens,
   }

   // RGLRU
-  for (size_t batch_idx = 0; batch_idx < num_tokens; ++batch_idx) {
-    const size_t pos = batch_start + batch_idx;
-    float* HWY_RESTRICT y = activations.griffin_y.Row(batch_idx);
-    float* HWY_RESTRICT x = activations.griffin_x.Row(batch_idx);
-    float* HWY_RESTRICT gate_x = activations.griffin_gate_x.Row(batch_idx);
-    float* HWY_RESTRICT a = activations.griffin_multiplier.Row(batch_idx);
-    float* HWY_RESTRICT rnn_state = kv_cache.rglru_cache.Row(layer);
+  for (size_t interleaved_idx = 0; interleaved_idx < num_interleaved;
+       ++interleaved_idx) {
+    const size_t query_idx = div_num_q.Remainder(interleaved_idx);
+    const size_t batch_idx = div_num_q.Divide(interleaved_idx);
+    const size_t pos = queries_pos[query_idx] + batch_idx;
+
+    float* HWY_RESTRICT x = activations.griffin_x.Row(query_idx);
+    float* HWY_RESTRICT y = activations.griffin_y.Row(query_idx);
+    float* HWY_RESTRICT gate_x = activations.griffin_gate_x.Row(query_idx);
+    float* HWY_RESTRICT a = activations.griffin_multiplier.Row(query_idx);
+    float* HWY_RESTRICT rnn_state =
+        kv_caches[query_idx].rglru_cache.Row(griffin_layer);

     pool.Run(0, heads, [&](const uint64_t head, size_t /*thread*/) HWY_ATTR {
-      const size_t kHeadDim = model_dim / heads;
-      const size_t kMatrixSize = kHeadDim * kHeadDim;
       size_t head_offset = head * kHeadDim;
       TwoOfsMatVecAddLoop(
           layer_weights->griffin.gate_w, kMatrixSize * head,
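The two loops above replace per-token loops with a single interleaved loop over `num_tokens * num_queries` rows. A small sketch of the index mapping follows, using plain `/` and `%` where the real code uses `hwy::Divisor`; the struct and function names are illustrative.

```cpp
#include <cstddef>

// Sketch: map an interleaved row index back to (query, token-within-batch).
// Rows are query-major within each token step, matching the
// div_num_q.Remainder()/Divide() calls in GriffinRecurrent above.
struct Interleaved {
  size_t query_idx;  // which query's KV cache and position to use
  size_t batch_idx;  // token offset within this query's batch
};

inline Interleaved Split(size_t interleaved_idx, size_t num_queries) {
  return Interleaved{interleaved_idx % num_queries,   // Remainder()
                     interleaved_idx / num_queries};  // Divide()
}
// Example: num_queries = 3, num_tokens = 2 -> rows 0..5 map to
// (q0,t0) (q1,t0) (q2,t0) (q0,t1) (q1,t1) (q2,t1); pos = queries_pos[q] + t.
```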
@@ -166,7 +182,6 @@ HWY_NOINLINE void GriffinRecurrent(size_t batch_start, size_t num_tokens,
       hn::Transform1(D(), x + head_offset, kHeadDim, gate_x + head_offset,
                      fn_mul);
       // RNN scan
-      HWY_FULL(float) df;
       HWY_DASSERT(kHeadDim % hn::Lanes(df) == 0);
       for (size_t i = 0; i < kHeadDim; i += hn::Lanes(df)) {
         auto log_a = hn::Load(df, a + head_offset + i);
@@ -186,17 +201,18 @@ HWY_NOINLINE void GriffinRecurrent(size_t batch_start, size_t num_tokens,
         hn::Store(pre_out, df, x + head_offset + i);
       }
     });
-  }
+  }  // interleaved_idx

   // Final linear layer.
-  for (size_t batch_idx = 0; batch_idx < num_tokens; ++batch_idx) {
-    float* HWY_RESTRICT x = activations.griffin_x.Row(batch_idx);
-    float* out_ptr = activations.att_sums.Row(batch_idx);
+  // TODO: MatMul
+  for (size_t r = 0; r < num_interleaved; ++r) {
+    float* HWY_RESTRICT x = activations.griffin_x.Row(r);
+    float* out_ptr = activations.att_sums.Row(r);
     MatVecAdd(layer_weights->griffin.linear_out_w, 0, model_dim, model_dim, x,
               layer_weights->griffin.linear_out_biases.PackedScale1(), out_ptr,
               pool);
   }
-}
+}  // GriffinRecurrent

 // Wrapper class; holds arguments in member variables to shorten call sites.
 template <typename T>
@@ -219,11 +235,7 @@ class GemmaAttention {
         activations_.weights_config.attention_window_sizes[layer] ==
             activations_.seq_len;
-    // TODO: add a config flag instead of hardcoding the model.
-    if (is_global_layer &&
-        (activations_.weights_config.model == Model::GEMMA3_4B ||
-         activations_.weights_config.model == Model::GEMMA3_12B ||
-         activations_.weights_config.model == Model::GEMMA3_27B ||
-         activations_.weights_config.model == Model::GEMMA3_1B)) {
+    if (is_global_layer && IsVLM(activations_.weights_config.model)) {
       inv_timescale = activations_.inv_timescale_global.Packed();
     }
     // PostQKType::Rope
@@ -454,7 +466,6 @@ class GemmaAttention {

           SingleDotSoftmaxWeightedSum(q, k, v, att, att_out, query_scale,
                                       pos, start_pos, last_pos);
-
         });
   }

@@ -587,14 +598,12 @@ HWY_NOINLINE void Attention(
     GemmaAttention<T>(queries_pos, queries_prefix_end, num_tokens, layer,
                       activations, layer_weights, div_seq_len, kv_caches)();
   } else {
-    // Only reached if the model is Griffin.
-    // The kv_caches are allocated only for the griffin layers, so we need to
-    // map the layer index to the griffin layer index.
     auto type = layer_weights->layer_config.type;
-    size_t layer_of_type =
+    HWY_DASSERT(type == LayerAttentionType::kGriffinRecurrentBlock);
+    // KVCache conv1d_cache and rglru_cache have one row per *Griffin* layer,
+    // so map `layer` to the Griffin layer index.
+    const size_t griffin_layer =
         activations.weights_config.NumLayersOfTypeBefore(type, layer);
-    HWY_ASSERT(queries_pos.size() == 1);
-    GriffinRecurrent(queries_pos[0], num_tokens, layer_of_type, activations,
+    GriffinRecurrent(queries_pos, num_tokens, griffin_layer, activations,
                      layer_weights, kv_caches);
   }
 }
@@ -1056,7 +1065,6 @@ HWY_NOINLINE void Prefill(
   // threads to parallelizing over queries, but for simplicity we assign them
   // all to MatMul.
   const size_t max_tbatch_size = runtime_config.prefill_tbatch_size;
-  HWY_DASSERT(max_tbatch_size <= activations.x.Rows());

   // For each query. `qi` is within the batch, not the global query index.
   for (size_t qi = 0; qi < num_queries; ++qi) {
@@ -1397,6 +1405,7 @@ void GenerateT(const ModelStore& model, const ModelWeightsPtrs<T>& weights,
                const QueriesPos& queries_prefix_end,
                const size_t query_idx_start, const KVCaches& kv_caches,
                TimingInfo& timing_info) {
+  HWY_ASSERT(queries_pos_in.size() == kv_caches.size());
   // Griffin assumes that the recurrent block cache is zero-initialized.
   for (size_t i = 0; i < kv_caches.size(); ++i) {
     if (queries_pos_in[i] == 0) {
@@ -1510,17 +1519,10 @@ void GenerateBatchT(const ModelStore& model,
   const size_t num_queries = queries_prompt.size();
   HWY_ASSERT(queries_pos.size() == num_queries);
   HWY_ASSERT(kv_caches.size() >= num_queries);
-  // Griffin does not support query batching.
-  size_t max_qbatch_size = runtime_config.decode_qbatch_size;
-  for (const LayerConfig& layer_config : model.Config().layer_configs) {
-    if (layer_config.type == LayerAttentionType::kGriffinRecurrentBlock) {
-      max_qbatch_size = 1;
-      break;
-    }
-  }
-
+  const size_t max_qbatch_size = runtime_config.decode_qbatch_size;
   const size_t max_batch_size =
       HWY_MAX(max_qbatch_size, runtime_config.prefill_tbatch_size);

   Activations activations(model.Config(), max_batch_size, env);

   for (size_t qbatch_start = 0; qbatch_start < num_queries;
@@ -1528,7 +1530,6 @@ void GenerateBatchT(const ModelStore& model,
     // Generate one batch of tokens from `qbatch_size` queries.
     const size_t qbatch_size =
         HWY_MIN(num_queries - qbatch_start, max_qbatch_size);
-    activations.SetBatchSize(qbatch_size);
     const QueriesPromptTokens qbatch_prompts(&queries_prompt[qbatch_start],
                                              qbatch_size);
     QueriesPos qbatch_pos(&queries_pos[qbatch_start], qbatch_size);
@@ -25,8 +25,9 @@
 namespace gcpp {

 void KVCache::ZeroGriffinCache() {
-  if (conv1d_cache.HasPtr()) ZeroInit(conv1d_cache);
-  if (rglru_cache.HasPtr()) ZeroInit(rglru_cache);
+  if (griffin_layers == 0) return;
+  ZeroInit(conv1d_cache);
+  ZeroInit(rglru_cache);
 }

 static size_t GriffinConv1dCols(const ModelConfig& config) {
@@ -34,7 +35,9 @@ static size_t GriffinConv1dCols(const ModelConfig& config) {
   for (const auto& layer_config : config.layer_configs) {
     conv1d_width = HWY_MAX(conv1d_width, layer_config.conv1d_width);
   }
-  return conv1d_width == 0 ? 0 : conv1d_width - 1;
+  // The row offset, in blocks of model_dim is computed mod (conv1d_width - 1),
+  // hence allocate conv1d_width * model_dim total columns.
+  return conv1d_width * config.model_dim;
 }

 // prefill_tbatch_size is the maximum number of tokens from one query to
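To make the new allocation size concrete, here is a worked sketch of the circular Conv1D-cache addressing used by `GriffinRecurrent`; the function name is illustrative, but the formula is the one from the Conv1D loop above, and the model_dim value is just an example.

```cpp
#include <cstddef>

// Sketch: per Griffin layer, conv1d_cache is one row; "input at time pos - i"
// lives at a rotating offset, a multiple of model_dim chosen mod
// (conv_1d_width - 1), exactly as in the GriffinRecurrent Conv1D loop.
inline size_t Conv1dCacheOffset(size_t pos, size_t i, size_t conv_1d_width,
                                size_t model_dim) {
  return ((pos + conv_1d_width - 1 - i) % (conv_1d_width - 1)) * model_dim;
}
// Worked example with conv_1d_width = 4, model_dim = 2560, pos = 7:
//   i = 1: ((7 + 4 - 1 - 1) % 3) * 2560 = (9 % 3) * 2560 = 0
//   i = 2: ((7 + 4 - 1 - 2) % 3) * 2560 = (8 % 3) * 2560 = 2 * 2560
//   i = 3: ((7 + 4 - 1 - 3) % 3) * 2560 = (7 % 3) * 2560 = 1 * 2560
// All offsets fit within the conv1d_width * model_dim columns now allocated
// per layer row.
```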
@@ -42,11 +45,8 @@ static size_t GriffinConv1dCols(const ModelConfig& config) {
 KVCache::KVCache(const ModelConfig& config, size_t prefill_tbatch_size)
     : griffin_layers(
           config.NumLayersOfType(LayerAttentionType::kGriffinRecurrentBlock)),
-      griffin_conv1d_cols(GriffinConv1dCols(config)),
-      // TODO(patrickms): Add query batching support for Griffin.
-      conv1d_cache(
-          "conv1d_cache",
-          Extents2D(griffin_layers, griffin_conv1d_cols * config.model_dim),
+      conv1d_cache("conv1d_cache",
+                   Extents2D(griffin_layers, GriffinConv1dCols(config)),
                    MatPadding::kOdd),
       rglru_cache("rglru_cache", Extents2D(griffin_layers, config.model_dim),
                   MatPadding::kOdd) {
@@ -31,7 +31,6 @@ struct KVCache {
   KVCache Copy(const ModelConfig& weights_config, size_t prefill_tbatch_size);

   size_t griffin_layers = 0;
-  size_t griffin_conv1d_cols = 0;
   // griffin_layers, griffin_conv1d_cols * config.model_dim
   MatStorageT<float> conv1d_cache;
   MatStorageT<float> rglru_cache;  // griffin_layers, config.model_dim
gemma/weights.cc
@@ -97,21 +97,27 @@ void LayerWeightsPtrs<NuqStream>::Fixup(std::vector<MatOwner>& mat_owners) {
   SplitW1NUQ(layer_config);
 }

+struct TensorToRead {
+  MatPtr* mat;
+  BlobRange range;
+  // Some tensors opt out of padding via kNoPad flags.
+  MatPadding padding;
+};
+
 // Allocates multiple in parallel and binds to NUMA nodes.
-static void AllocateAndBindAll(const std::vector<MatPtr*>& mats,
-                               MatPadding padding,
+static void AllocateAndBindAll(const std::vector<TensorToRead>& tensors,
                                std::vector<MatOwner>& owners,
                                hwy::ThreadPool& pool) {
   const size_t start = owners.size();
-  owners.resize(start + mats.size());
+  owners.resize(start + tensors.size());

   MMParallel parallel(ThreadingContext::Get());

   // Allocate in parallel because faulting in large tensors is slow.
-  pool.Run(0, mats.size(), [&](uint64_t task, size_t /*thread*/) {
-    owners[start + task].AllocateFor(*mats[task], padding);
+  pool.Run(0, tensors.size(), [&](uint64_t task, size_t /*thread*/) {
+    owners[start + task].AllocateFor(*tensors[task].mat, tensors[task].padding);
     // TODO(janwas): MatMul outputs will later also be BF16.
-    BindB(*mats[task], sizeof(float), parallel);
+    BindB(*tensors[task].mat, sizeof(float), parallel);
   });
 }
@@ -155,55 +161,54 @@ MapPtr MapFileOrNull(File& file, uint64_t file_bytes) {
   return MapPtr();
 }

-static void MapAll(const std::vector<MatPtr*>& mats,
-                   const std::vector<BlobRange>& ranges, const MapPtr& mapped) {
+static void MapAll(const std::vector<TensorToRead>& tensors,
+                   const MapPtr& mapped) {
   PROFILER_ZONE("Startup.Weights.Map");
-  for (size_t i = 0; i < mats.size(); ++i) {
+  for (size_t i = 0; i < tensors.size(); ++i) {
     // SetPtr does not change the stride, but it is expected to be packed
     // because that is what Compress() writes to the file.
-    const size_t mat_bytes = mats[i]->PackedBytes();
+    const size_t mat_bytes = tensors[i].mat->PackedBytes();
     // Ensure blob size matches that computed from metadata.
-    HWY_ASSERT_M(mat_bytes == ranges[i].bytes, mats[i]->Name());
+    HWY_ASSERT_M(mat_bytes == tensors[i].range.bytes, tensors[i].mat->Name());

-    mats[i]->SetPtr(const_cast<uint8_t*>(mapped.get() + ranges[i].offset),
-                    mats[i]->Stride());
+    tensors[i].mat->SetPtr(
+        const_cast<uint8_t*>(mapped.get() + tensors[i].range.offset),
+        tensors[i].mat->Stride());
   }
 }

-std::vector<IOBatch> MakeBatches(const std::vector<BlobRange>& ranges,
-                                 const std::vector<MatPtr*>& mats,
+std::vector<IOBatch> MakeBatches(const std::vector<TensorToRead>& tensors,
                                  const uint64_t file_bytes) {
   PROFILER_ZONE("Startup.Weights.MakeBatches");
   // Batches must be contiguous but blobs are padded, hence at least one
   // batch per tensor, and more when tensor rows exceed the batch size.
   std::vector<IOBatch> batches;
-  batches.reserve(mats.size());
+  batches.reserve(tensors.size());

-  for (size_t i = 0; i < mats.size(); ++i) {
-    uint64_t offset = ranges[i].offset;
-    HWY_ASSERT(ranges[i].End() <= file_bytes);
+  for (size_t i = 0; i < tensors.size(); ++i) {
+    const BlobRange& range = tensors[i].range;
+    MatPtr& mat = *tensors[i].mat;
+    uint64_t offset = range.offset;
+    HWY_ASSERT(range.End() <= file_bytes);

-    batches.emplace_back(offset, ranges[i].key_idx);
-    const size_t file_bytes_per_row = mats[i]->Cols() * mats[i]->ElementBytes();
-    // Caution, `RowT` requires knowledge of the actual type. We instead use
-    // the first row, which is the same for any type, and advance the *byte*
-    // pointer by the *byte* stride.
-    const size_t mem_stride_bytes = mats[i]->Stride() * mats[i]->ElementBytes();
-    uint8_t* row = mats[i]->RowT<uint8_t>(0);
-    for (size_t r = 0; r < mats[i]->Rows(); ++r) {
-      if (!batches.back().Add(row, file_bytes_per_row)) {  // Full batch.
-        batches.emplace_back(offset, ranges[i].key_idx);
+    batches.emplace_back(offset, range.key_idx);
+    const size_t file_bytes_per_row = mat.Cols() * mat.ElementBytes();
+    const size_t mem_stride_bytes = mat.Stride() * mat.ElementBytes();
+    uint8_t* row_bytes = mat.RowBytes(0);
+    for (size_t r = 0; r < mat.Rows(); ++r) {
+      if (!batches.back().Add(row_bytes, file_bytes_per_row)) {  // Full batch.
+        batches.emplace_back(offset, range.key_idx);
         // Adding to an empty batch is always successful.
-        HWY_ASSERT(batches.back().Add(row, file_bytes_per_row));
+        HWY_ASSERT(batches.back().Add(row_bytes, file_bytes_per_row));
       }
       offset += file_bytes_per_row;
-      row += mem_stride_bytes;
+      row_bytes += mem_stride_bytes;
       // Keep the in-memory row padding uninitialized so msan detects any use.
     }
-    HWY_ASSERT(offset == ranges[i].End());
+    HWY_ASSERT(offset == range.End());
   }

-  HWY_ASSERT(batches.size() >= mats.size());
+  HWY_ASSERT(batches.size() >= tensors.size());
   return batches;
 }
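The row walk above bridges packed rows on disk and possibly padded rows in memory. A compact sketch of the same idea follows, with a hypothetical `read_at` callable standing in for the real `IOBatch`/`BlobRange` plumbing.

```cpp
#include <cstddef>
#include <cstdint>
#include <functional>

#include "util/mat.h"  // gcpp::MatPtr

// Sketch: blobs are stored packed (Cols * ElementBytes per row), while the
// in-memory tensor may be padded (Stride >= Cols). Read each file row to the
// start of its in-memory row, then advance the destination by the byte stride.
bool ReadRows(uint64_t blob_offset, gcpp::MatPtr& mat,
              const std::function<bool(uint64_t, size_t, uint8_t*)>& read_at) {
  const size_t file_bytes_per_row = mat.Cols() * mat.ElementBytes();
  const size_t mem_stride_bytes = mat.Stride() * mat.ElementBytes();
  uint8_t* row_bytes = mat.RowBytes(0);  // type-agnostic byte pointer
  for (size_t r = 0; r < mat.Rows(); ++r) {
    if (!read_at(blob_offset, file_bytes_per_row, row_bytes)) return false;
    blob_offset += file_bytes_per_row;  // file side advances by packed row size
    row_bytes += mem_stride_bytes;      // memory side advances by padded stride
  }
  return true;  // row padding stays uninitialized, as in MakeBatches above
}
```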
@@ -228,16 +233,14 @@ static void ReadBatches(const BlobReader& reader,
 }

 // Aborts on error.
-static void MapOrReadAll(const std::vector<MatPtr*>& mats, BlobReader& reader,
-                         const std::vector<BlobRange>& ranges, Tristate map,
+static void MapOrReadAll(const std::vector<TensorToRead>& tensors,
+                         BlobReader& reader, Tristate map,
                          std::vector<MatOwner>& mat_owners,
-                         const MatPadding padding, hwy::ThreadPool& pool) {
-  HWY_ASSERT(mats.size() == ranges.size());
-
+                         hwy::ThreadPool& pool) {
   if (ChooseMode(reader.file_bytes(), map) == Mode::kMap) {
     MapPtr mapped = MapFileOrNull(reader.file(), reader.file_bytes());
     if (mapped) {
-      MapAll(mats, ranges, mapped);
+      MapAll(tensors, mapped);
       return;
     }
   }  // otherwise fall through to read mode
@@ -245,32 +248,32 @@ static void MapOrReadAll(const std::vector<MatPtr*>& mats, BlobReader& reader,
   {
     PROFILER_ZONE("Startup.Weights.Allocate");
     // NOTE: this changes the stride of `mats`!
-    AllocateAndBindAll(mats, padding, mat_owners, pool);
+    AllocateAndBindAll(tensors, mat_owners, pool);
   }

   const std::vector<IOBatch> batches =
-      MakeBatches(ranges, mats, reader.file_bytes());
+      MakeBatches(tensors, reader.file_bytes());
   ReadBatches(reader, batches, pool);
 }

 void WeightsOwner::ReadFromBlobs(const ModelStore& model, BlobReader& reader,
                                  Tristate map, hwy::ThreadPool& pool) {
   // List of tensors to read/map, and where from.
-  std::vector<MatPtr*> mats;
-  std::vector<BlobRange> ranges;
-
-  // Padding is inserted when reading row by row, except for NUQ tensors.
-  const MatPadding padding = MatPadding::kOdd;
+  std::vector<TensorToRead> tensors;

   AllocatePointer(model.Config());

   // Enumerate all weights (negligible cost).
   CallT([&](const auto& weights) {
     weights->ForEachTensor(nullptr, nullptr, [&](const TensorArgs& t) {
+      const MatPadding padding = (t.flags & TensorArgs::kNoPad)
+                                     ? MatPadding::kPacked
+                                     : MatPadding::kOdd;
       size_t key_idx;
       if (model.FindAndUpdateMatPtr(t.mat, key_idx)) {
-        mats.push_back(&t.mat);
-        ranges.push_back(reader.Range(key_idx));
+        tensors.push_back({.mat = &t.mat,
+                           .range = reader.Range(key_idx),
+                           .padding = padding});
         return;
       }
       if (t.flags & TensorArgs::kMaybeRead) return;  // optional and not found.
@@ -278,7 +281,7 @@ void WeightsOwner::ReadFromBlobs(const ModelStore& model, BlobReader& reader,
     });
   });

-  MapOrReadAll(mats, reader, ranges, map, mat_owners_, padding, pool);
+  MapOrReadAll(tensors, reader, map, mat_owners_, pool);

   Fixup(pool);
 }
@@ -338,9 +341,9 @@ void WeightsOwner::LogWeightStatsF32() {
       printf("[scale=%f] ", t.mat.Scale());
     }
     hwy::Stats stats;
     HWY_ASSERT(t.mat.GetType() == Type::kF32);
+    const MatPtrT<float> mat_f(t.mat);
     for (size_t r = 0; r < t.mat.Rows(); ++r) {
-      const float* HWY_RESTRICT row = t.mat.RowT<float>(r);
+      const float* HWY_RESTRICT row = mat_f.Row(r);
       for (size_t c = 0; c < t.mat.Cols(); ++c) {
         stats.Notify(row[c]);
       }
@@ -43,24 +43,27 @@ struct TensorArgs {
   // name/type from another `LayerWeightsPtrs` for iterating over tensor pairs
   // (for copying) or triples (for `AdamUpdateMV`). Set by `TENSOR_ARGS`.
   // `flags` is a combination of zero or more `Flags`.
-  TensorArgs(MatPtr& mat, const MatPtr* other_mat1, const MatPtr* other_mat2,
-             int flags)
+  TensorArgs(MatPtr& mat, MatPtr* other_mat1, MatPtr* other_mat2, int flags)
       : mat(mat),
         other_mat1(other_mat1),
         other_mat2(other_mat2),
         flags(flags) {}

   MatPtr& mat;
-  const MatPtr* other_mat1;  // either/both can be nullptr.
-  const MatPtr* other_mat2;
+  MatPtr* other_mat1;  // either/both can be nullptr.
+  MatPtr* other_mat2;

   // TODO: freestanding enum class instead? These are mutually exclusive.
   enum Flags {
-    // Read the tensor from the file and abort if it is not found.
+    // Default: Read the tensor from the file and abort if it is not found.
     kMustRead = 0,

     // Not an error if the tensor is not present in the file. For example,
     // the _w1/_w2 tensors are not always present.
     kMaybeRead = 1,
+
+    // Avoid padding tensor rows when reading. Used for some Griffin tensors
+    // whose index computations do not use Row() accessors.
+    kNoPad = 2,
   };
   const int flags;
 };
@@ -214,8 +217,8 @@ struct LayerWeightsPtrs {
   // can also iterate over pairs or triples of tensors for `AdamUpdateMV`.
   // Public because also called by `WeightsPtrs`.
   template <class Func>
-  void ForEachTensor(const LayerWeightsPtrs<Weight>* other1,
-                     const LayerWeightsPtrs<Weight>* other2, Func func) {
+  void ForEachTensor(LayerWeightsPtrs<Weight>* other1,
+                     LayerWeightsPtrs<Weight>* other2, Func func) {
     if (layer_config.type == LayerAttentionType::kVit) {
       // MHA.
       func(TENSOR_ARGS(vit.attn_out_w, kMustRead));
@@ -248,9 +251,9 @@ struct LayerWeightsPtrs {
       func(TENSOR_ARGS(griffin.linear_y_biases, kMustRead));
       func(TENSOR_ARGS(griffin.linear_out_w, kMustRead));
       func(TENSOR_ARGS(griffin.linear_out_biases, kMustRead));
-      func(TENSOR_ARGS(griffin.conv_w, kMustRead));
+      func(TENSOR_ARGS(griffin.conv_w, kMustRead | TensorArgs::kNoPad));
       func(TENSOR_ARGS(griffin.conv_biases, kMustRead));
-      func(TENSOR_ARGS(griffin.gate_w, kMustRead));
+      func(TENSOR_ARGS(griffin.gate_w, kMustRead | TensorArgs::kNoPad));
       func(TENSOR_ARGS(griffin.gate_biases, kMustRead));
       func(TENSOR_ARGS(griffin.a, kMustRead));
     }
@@ -363,8 +366,8 @@ struct LayerWeightsPtrs {

   // For FFN. Fast, only updates pointers.
   void SplitW1() {
-    // We only use this tensor for Gemma layers.
-    if (layer_config.type != LayerAttentionType::kGemma) return;
+    // Used for Gemma and Griffin layers; FFWVit uses different tensors.
+    if (layer_config.type == LayerAttentionType::kVit) return;

     // Files have both or neither of w1 and w2, and backprop/ allocates both.
     HWY_ASSERT(gating_einsum_w1.HasPtr() == gating_einsum_w2.HasPtr());
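The new `kNoPad` flag (mapped to `MatPadding::kPacked` in `ReadFromBlobs` above) exists because some Griffin kernels compute flat element offsets instead of calling `Row()`. A sketch of why that only works for packed rows; the helper is hypothetical, and the real consumers are `griffin.conv_w` and `griffin.gate_w`.

```cpp
#include "hwy/base.h"  // HWY_ASSERT
#include "util/mat.h"  // gcpp::MatPtrT

// Sketch: flat offset arithmetic such as `base + matrix_size * head` (the
// pattern TwoOfsMatVecAddLoop uses for griffin.gate_w) only lands on the start
// of head's kHeadDim x kHeadDim block if rows are stored back to back, i.e.
// Stride() == Cols(). That is what kNoPad -> MatPadding::kPacked guarantees.
const float* HeadMatrix(const gcpp::MatPtrT<float>& w, size_t head,
                        size_t head_dim) {  // hypothetical helper
  const size_t matrix_size = head_dim * head_dim;
  HWY_ASSERT(w.IsPacked());  // would not hold for a padded tensor
  return w.Row(0) + matrix_size * head;
}
```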
@@ -514,10 +517,10 @@ struct ModelWeightsPtrs {
   // used to copy from another set of weights. Public because called by tests
   // and `WeightsOwner`.
   template <class Func>
-  void ForEachTensor(const ModelWeightsPtrs<Weight>* other1,
-                     const ModelWeightsPtrs<Weight>* other2, Func func) {
-    const LayerWeightsPtrs<Weight>* other_layer1 = nullptr;
-    const LayerWeightsPtrs<Weight>* other_layer2 = nullptr;
+  void ForEachTensor(ModelWeightsPtrs<Weight>* other1,
+                     ModelWeightsPtrs<Weight>* other2, Func func) {
+    LayerWeightsPtrs<Weight>* other_layer1 = nullptr;
+    LayerWeightsPtrs<Weight>* other_layer2 = nullptr;
     func(TENSOR_ARGS(embedder_input_embedding, kMustRead));
     func(TENSOR_ARGS(final_norm_scale, kMustRead));

@@ -569,7 +572,8 @@ struct ModelWeightsPtrs {

   // Copies only the allocated tensors in `*this` from tensors in `other`.
   void CopyFrom(const ModelWeightsPtrs<Weight>& other) {
-    ForEachTensor(&other, nullptr, [](const TensorArgs& t) {
+    ForEachTensor(const_cast<ModelWeightsPtrs<Weight>*>(&other), nullptr,
+                  [](const TensorArgs& t) {
       if (!t.mat.HasPtr()) return;
       HWY_ASSERT(t.other_mat1 && t.other_mat1->HasPtr());
       CopyMat(*t.other_mat1, t.mat);
@@ -84,6 +84,14 @@ struct Header {  // standard layout class
 static_assert(sizeof(Header) == 16);
 }  // namespace

+// A write I/O request, each serviced by one thread in a pool.
+struct BlobIO {
+  BlobIO(BlobRange range, void* data) : range(range), data(data) {}
+
+  BlobRange range;
+  void* data;  // Read-only for writes.
+};
+
 // Little-endian on-disk representation: a fixed-size `Header`, then a padded
 // variable-length 'directory' of blob keys and their offset/sizes, then the
 // 'payload' of each blob's data with padding in between, followed by padding to
@@ -238,11 +246,11 @@ class BlobStore {
     return true;  // all OK
   }

-  void EnqueueWriteForHeaderAndDirectory(std::vector<BlobIO2>& writes) const {
+  void EnqueueWriteForHeaderAndDirectory(std::vector<BlobIO>& writes) const {
     const size_t key_idx = 0;  // not actually associated with a key/blob
     writes.emplace_back(
         BlobRange{.offset = 0, .bytes = sizeof(header_), .key_idx = key_idx},
-        // members are const and BlobIO2 requires non-const pointers, and they
+        // members are const and BlobIO requires non-const pointers, and they
         // are not modified by file writes.
         const_cast<Header*>(&header_));
     writes.emplace_back(
@@ -314,7 +322,7 @@ BlobReader::BlobReader(const Path& blob_path)

 // Split into chunks for load-balancing even if blob sizes vary.
 static void EnqueueChunks(size_t key_idx, uint64_t offset, uint64_t bytes,
-                          uint8_t* data, std::vector<BlobIO2>& writes) {
+                          uint8_t* data, std::vector<BlobIO>& writes) {
   constexpr size_t kChunkBytes = 4 * 1024 * 1024;
   const uint64_t end = offset + bytes;
   // Split into whole chunks and possibly one remainder.
@@ -336,7 +344,7 @@ static void EnqueueChunks(size_t key_idx, uint64_t offset, uint64_t bytes,
 static void EnqueueWritesForBlobs(const BlobStore& bs,
                                   const hwy::Span<const uint8_t> blobs[],
                                   std::vector<uint8_t>& zeros,
-                                  std::vector<BlobIO2>& writes) {
+                                  std::vector<BlobIO>& writes) {
   // All-zero buffer used to write padding to the file without copying the
   // input blobs.
   static constexpr uint8_t kZeros[kBlobAlign] = {0};
@@ -388,7 +396,7 @@ void BlobWriter::WriteAll(hwy::ThreadPool& pool, const Path& filename) {
   HWY_ASSERT(num_blobs != 0);
   HWY_ASSERT(num_blobs == blobs_.size());

-  std::vector<BlobIO2> writes;
+  std::vector<BlobIO> writes;
   writes.reserve(16384);

   const BlobStore bs(num_blobs, keys_.data(), blobs_.data());
@@ -44,16 +44,6 @@ struct BlobRange {
   size_t key_idx;
 };

-// A read or write I/O request, each serviced by one thread in a pool.
-struct BlobIO2 {
-  BlobIO2(BlobRange range, void* data) : range(range), data(data) {}
-
-  BlobRange range;
-  void* data;  // Modified only if a read request. Read-only for writes.
-};
-
 class BlobStore;

 // Reads `BlobStore` header, converts keys to strings and creates a hash map for
 // faster lookups.
 // TODO(janwas): rename to BlobFinder or similar.
@@ -425,7 +425,7 @@ MatMulEnv::MatMulEnv(ThreadingContext& ctx)
   have_timer_stop = hwy::platform::HaveTimerStop(cpu100);
 }

-void BindB(const MatPtr& B, size_t sizeof_TC, MMParallel& parallel) {
+void BindB(MatPtr& B, size_t sizeof_TC, MMParallel& parallel) {
   Allocator& allocator = parallel.allocator();
   if (!allocator.ShouldBind()) return;
   if (B.Rows() == 1) return;
@@ -437,8 +437,7 @@ void BindB(const MatPtr& B, size_t sizeof_TC, MMParallel& parallel) {
   for (size_t pkg_idx = 0; pkg_idx < ranges_np.NumTasks(); ++pkg_idx) {
     const IndexRange& rows_b = ranges_np.Range(pkg_idx);
     const size_t node = parallel.Node(pkg_idx);
-    uintptr_t begin =
-        reinterpret_cast<uintptr_t>(B.RowT<uint8_t>(rows_b.begin()));
+    uintptr_t begin = reinterpret_cast<uintptr_t>(B.RowBytes(rows_b.begin()));
     uintptr_t end = begin + rows_b.Num() * B.Stride() * B.ElementBytes();
     // B row padding is less than the page size, so only bind the subset that
     // is page-aligned.
@@ -451,7 +450,7 @@ void BindB(const MatPtr& B, size_t sizeof_TC, MMParallel& parallel) {
 }

 // C is BF16/float, or double for partial
-void BindC(const MatPtr& C, MMParallel& parallel) {
+void BindC(MatPtr& C, MMParallel& parallel) {
   Allocator& allocator = parallel.allocator();
   if (!allocator.ShouldBind()) return;

@@ -470,8 +469,7 @@ void BindC(const MatPtr& C, MMParallel& parallel) {

     const size_t node = parallel.Node(pkg_idx);
     for (size_t im = 0; im < C.Rows(); ++im) {
-      ok &= allocator.BindMemory(C.MutableRowT<uint8_t>(im) + begin,
-                                 end - begin, node);
+      ok &= allocator.BindMemory(C.RowBytes(im) + begin, end - begin, node);
     }
   }
   if (HWY_UNLIKELY(!ok)) {
@@ -172,9 +172,9 @@ class MMParallel {
   ThreadingContext& ctx_;
 };

-void BindB(const MatPtr& B, size_t sizeof_TC, MMParallel& parallel);
+void BindB(MatPtr& B, size_t sizeof_TC, MMParallel& parallel);
 // C is BF16/float, or double for partial.
-void BindC(const MatPtr& C, MMParallel& parallel);
+void BindC(MatPtr& C, MMParallel& parallel);

 // Per-package storage for packed A, and one global C-shaped `partial` for
 // accumulating partial dot products (sections of K).
util/mat.cc
@@ -41,8 +41,8 @@ void CopyMat(const MatPtr& from, MatPtr& to) {
   }
   const size_t row_bytes = to.Cols() * to.ElementBytes();
   for (size_t r = 0; r < to.Rows(); ++r) {
-    const uint8_t* from_row = from.RowT<uint8_t>(r);
-    uint8_t* to_row = to.RowT<uint8_t>(r);
+    const uint8_t* from_row = from.RowBytes(r);
+    uint8_t* to_row = to.RowBytes(r);
     hwy::CopyBytes(from_row, to_row, row_bytes);
   }
 }
@@ -58,7 +58,7 @@ void ZeroInit(MatPtr& mat) {
   }
   const size_t row_bytes = mat.Cols() * mat.ElementBytes();
   for (size_t r = 0; r < mat.Rows(); ++r) {
-    hwy::ZeroBytes(mat.RowT<uint8_t>(r), row_bytes);
+    hwy::ZeroBytes(mat.RowBytes(r), row_bytes);
   }
 }

@@ -71,15 +71,19 @@ void RandInit(MatPtr& mat, float stddev, std::mt19937& gen) {

   std::normal_distribution<float> dist(0.0, stddev);
   if (mat.GetType() == Type::kF32) {
+    MatPtrT<float> mat_f(mat);
+
     for (size_t r = 0; r < mat.Rows(); ++r) {
-      float* HWY_RESTRICT row = mat.RowT<float>(r);
+      float* HWY_RESTRICT row = mat_f.Row(r);
       for (size_t c = 0; c < mat.Cols(); ++c) {
         row[c] = dist(gen);
       }
     }
   } else {
+    MatPtrT<double> mat_d(mat);
+
     for (size_t r = 0; r < mat.Rows(); ++r) {
-      double* HWY_RESTRICT row = mat.RowT<double>(r);
+      double* HWY_RESTRICT row = mat_d.Row(r);
       for (size_t c = 0; c < mat.Cols(); ++c) {
         row[c] = dist(gen);
       }
util/mat.h
@@ -91,21 +91,14 @@ class MatPtr : public IFields {
     return num_elements_ * element_bytes_;
   }

-  // Works for any kind of padding.
-  template <typename T>
-  T* MutableRowT(size_t row) const {
+  // Works for any kind of padding and element type.
+  uint8_t* RowBytes(size_t row) {
     HWY_DASSERT(row < Rows());
-    return HWY_RCAST_ALIGNED(T*, ptr_) + row * stride_;
+    return static_cast<uint8_t*>(ptr_) + row * (stride_ * element_bytes_);
   }
-  template <typename T>
-  T* RowT(size_t row) {
+  const uint8_t* RowBytes(size_t row) const {
     HWY_DASSERT(row < Rows());
-    return HWY_RCAST_ALIGNED(T*, ptr_) + row * stride_;
-  }
-  template <typename T>
-  const T* RowT(size_t row) const {
-    HWY_DASSERT(row < Rows());
-    return HWY_RCAST_ALIGNED(const T*, ptr_) + row * stride_;
+    return static_cast<const uint8_t*>(ptr_) + row * (stride_ * element_bytes_);
   }

   Type GetType() const { return type_; }
@@ -209,9 +202,8 @@ class MatPtr : public IFields {
   float scale_ = 1.0f;  // multiplier for each value, for MatMul.
 };

-// Non-type erased version of `MatPtr`. Although `MatPtr` also provides
-// type-aware accessors (`RowT`), this class is more convenient when accessing
-// elements, and ensures the template argument and `Type` are consistent.
+// Non-type erased version of `MatPtr`: provides type-safe `Row()` and ensures
+// the template argument and `Type` are consistent.
 template <typename MatT>
 class MatPtrT : public MatPtr {
  public:
@@ -227,7 +219,9 @@ class MatPtrT : public MatPtr {
       : MatPtrT<MatT>(name.c_str(), ExtentsFromInfo(info.Find(name))) {}

   // Copying allowed because the metadata is small.
-  MatPtrT(const MatPtr& other) : MatPtr(other) {}
+  MatPtrT(const MatPtr& other) : MatPtr(other) {
+    HWY_ASSERT(other.GetType() == TypeEnum<MatT>());
+  }
   MatPtrT& operator=(const MatPtr& other) {
     MatPtr::operator=(other);
     return *this;
@@ -252,22 +246,24 @@ class MatPtrT : public MatPtr {
     return Packed();
   }

-  const MatT* Row(size_t row) const { return this->RowT<MatT>(row); }
-  MatT* Row(size_t row) { return this->RowT<MatT>(row); }
+  MatT* Row(size_t row) { return HWY_RCAST_ALIGNED(T*, RowBytes(row)); }
+  const MatT* Row(size_t row) const {
+    return HWY_RCAST_ALIGNED(const T*, RowBytes(row));
+  }

   PackedSpan<const MatT> PaddedSpan() const {
-    return MakeConstSpan(Row(0), Rows() * Stride());
+    return MakeConstSpan(HWY_RCAST_ALIGNED(MatT*, ptr_), Rows() * Stride());
   }

   // For `compress-inl.h` functions, which assume contiguous streams and thus
   // require packed layout.
-  PackedSpan<const MatT> Span() const {
-    HWY_ASSERT(IsPacked());
-    return MakeConstSpan(Row(0), num_elements_);
-  }
   PackedSpan<MatT> Span() {
     HWY_ASSERT(IsPacked());
-    return MakeSpan(Row(0), num_elements_);
+    return MakeSpan(HWY_RCAST_ALIGNED(MatT*, ptr_), num_elements_);
+  }
+  PackedSpan<const MatT> Span() const {
+    HWY_ASSERT(IsPacked());
+    return MakeConstSpan(HWY_RCAST_ALIGNED(MatT*, ptr_), num_elements_);
   }
 };
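Finally, a standalone sketch of the address arithmetic behind `RowBytes()`/`Row()`, to make the stride handling explicit; this is an illustrative struct, not the production class.

```cpp
#include <cstddef>
#include <cstdint>

// Sketch of what MatPtr::RowBytes and MatPtrT::Row compute: a row starts at
// ptr + row * stride * element_bytes; only the first cols elements are valid,
// the rest of the stride is padding (absent when packed, i.e. stride == cols).
struct RowLayout {
  uint8_t* ptr;
  size_t cols, stride, element_bytes;  // stride >= cols

  uint8_t* RowBytes(size_t row) const {
    return ptr + row * (stride * element_bytes);
  }
  template <typename T>
  T* Row(size_t row) const {  // caller must know the element type is T
    return reinterpret_cast<T*>(RowBytes(row));
  }
};
// Example: cols = 3, stride = 4, element_bytes = 4 (float): row 1 starts at
// byte 16; bytes 12..15 of each row are padding and must not be read as data.
```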