mirror of https://github.com/google/gemma.cpp.git
Rewrote flash attention to use BF16, transposed k and v, rewrote the task distribution, increased parallelism on decode, and used double the registers for the core of flash attention.
PiperOrigin-RevId: 868146247
parent 76d7951242
commit a814aa411e
@@ -547,6 +547,7 @@ cc_library(
     deps = [
         ":basics",
         ":configs",
+        ":flash_structs",
         ":gemma_args",
         ":kv_cache",
         ":mat",
@@ -594,6 +595,11 @@ cc_test(
 
 INTERNAL_DEPS = []
 
+cc_library(
+    name = "flash_structs",
+    hdrs = ["gemma/flash_structs.h"],
+)
+
 cc_library(
     name = "attention",
     srcs = [
@@ -603,7 +609,6 @@ cc_library(
     hdrs = [
         "gemma/attention.h",
         "gemma/flash_attention.h",
-        "gemma/flash_structs.h",
     ],
     textual_hdrs = [
        "gemma/gemma-inl.h",
@@ -612,6 +617,7 @@ cc_library(
        ":activations",
        ":basics",
        ":configs",
+        ":flash_structs",
        ":kv_cache",
        ":mat",
        ":matmul",
@@ -24,6 +24,7 @@
 #include <vector>
 
 #include "gemma/configs.h"  // ModelConfig
+#include "gemma/flash_structs.h"
 #include "gemma/gemma_args.h"  // AttentionImpl
 #include "gemma/kv_cache.h"
 #include "gemma/tensor_stats.h"
@@ -52,10 +53,13 @@ struct AttentionActivations {
   AttentionActivations(
       const ModelConfig& config, const LayerConfig& layer_config,
       size_t batch_size, size_t seq_len, const RuntimeConfig& runtime_config,
-      const Allocator& allocator,
+      size_t max_workers, const Allocator& allocator,
       std::vector<hwy::AlignedFreeUniquePtr<uint8_t*[]>>& row_ptrs)
-      : // `vocab_size == 0` means it is for Vit part, VitAttention is still
-        // MHA and does not use an external KV cache.
+      : rep_factor(max_workers *
+                       AttentionActivations::kThreadReplicationFactor /
+                   layer_config.heads),
+        // `vocab_size == 0` means it is for Vit part, VitAttention
+        // is still MHA and does not use an external KV cache.
         q(MatFactory("q", batch_size,
                      config.vocab_size == 0
                          ? layer_config.heads * 3 * layer_config.qkv_dim
@@ -86,6 +90,9 @@ struct AttentionActivations {
         att_out(MatFactory("att_out", batch_size,
                            layer_config.heads * layer_config.qkv_dim,
                            allocator)),
+        att_out_reps(MatFactory("att_out", batch_size * rep_factor,
+                                layer_config.heads * layer_config.qkv_dim,
+                                allocator)),
         softmax_max(MatFactory("softmax_max", batch_size, layer_config.heads,
                                allocator)),
         softmax_d(
@@ -107,6 +114,11 @@ struct AttentionActivations {
       }
       return;
     }
+    // This is a guess at the maximum number of params we might need to avoid
+    // reallocations. The actual number of params is determined by the number of
+    // query tiles, which is not known here.
+    flash_params.reserve(batch_size * layer_config.heads);
+    split_flash_params.reserve(batch_size * layer_config.heads);
 
     // For MatMul outputs, precompute their row pointers.
     // If we forget any MatMul outputs here, debug builds print a warning but
@@ -130,6 +142,7 @@ struct AttentionActivations {
     pre_att_rms_out.OverrideRows(batch_size);
     att.OverrideRows(batch_size);
     att_out.OverrideRows(batch_size);
+    att_out_reps.OverrideRows(batch_size * rep_factor);
     softmax_max.OverrideRows(batch_size);
     softmax_d.OverrideRows(batch_size);
     att_sums.OverrideRows(batch_size);
@@ -137,6 +150,15 @@ struct AttentionActivations {
     // `inv_timescale*` are not batched.
   }
 
+  // Maximum factor by which we might scale up work to maximize parallelism.
+  size_t rep_factor = 1;
+  // Parameters for flash attention. The size of the vector is somewhere between
+  // the number of query rows and 1/8th of that.
+  std::vector<FlashAttentionParams> flash_params;
+  // Parameters for flash attention, split by k-position. May be significantly
+  // larger than flash_params in decode mode, when the number of query rows is
+  // small.
+  std::vector<FlashAttentionParams> split_flash_params;
   MatStorageT<float> q;  // query
   MatStorageT<BF16> q_bf;
   MatStorageT<BF16> q_T;  // Transposed to maximize attention speed.
@@ -148,6 +170,7 @@ struct AttentionActivations {
   MatStorageT<float> pre_att_rms_out;
   MatStorageT<float> att;      // attention vector
   MatStorageT<float> att_out;  // attention output
+  MatStorageT<float> att_out_reps;  // attention output for each thread.
   MatStorageT<float> softmax_max;  // see OnlineSoftmaxState
   MatStorageT<float> softmax_d;    // see OnlineSoftmaxState
   // Accumulation of attention outputs over heads
@@ -156,19 +179,27 @@ struct AttentionActivations {
   // Rope
   MatStorageT<float> inv_timescale;
   MatStorageT<float> inv_timescale_global;
+  // Replication factor to help evenly share work over threads.
+  static constexpr size_t kThreadReplicationFactor = 4;
 };
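A quick worked example of the `rep_factor` initializer above, with hypothetical numbers not taken from this commit (32 workers, 8 heads per layer):

  // rep_factor = max_workers * kThreadReplicationFactor / heads
  //            = 32 * 4 / 8 = 16,
  // i.e. each head's work may be replicated up to 16x so that all workers
  // can be kept busy, which matters most for small decode batches.
  static_assert(32 * 4 / 8 == 16, "worked example");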
 
 // A non-owning view of AttentionActivations.
 struct AttentionActivationsPtrs {
-  AttentionActivationsPtrs(const ModelConfig& config, size_t seq_len)
+  AttentionActivationsPtrs(
+      const ModelConfig& config, size_t seq_len,
+      std::vector<FlashAttentionParams>& flash_params,
+      std::vector<FlashAttentionParams>& split_flash_params)
       : config(config),
+        flash_params(flash_params),
+        split_flash_params(split_flash_params),
         div_seq_len(static_cast<uint32_t>(seq_len)),
         div_heads(static_cast<uint32_t>(config.layer_configs[0].heads)),
         query_scale(ChooseQueryScale(config)) {}
 
   AttentionActivationsPtrs(const ModelConfig& config, size_t seq_len,
-                           const AttentionActivations& activations)
-      : AttentionActivationsPtrs(config, seq_len) {
+                           AttentionActivations& activations)
+      : AttentionActivationsPtrs(config, seq_len, activations.flash_params,
+                                 activations.split_flash_params) {
     q = activations.q;
     q_bf = activations.q_bf;
     q_T = activations.q_T;
@@ -178,6 +209,7 @@ struct AttentionActivationsPtrs {
     pre_att_rms_out = activations.pre_att_rms_out;
     att = activations.att;
     att_out = activations.att_out;
+    att_out_reps = activations.att_out_reps;
     softmax_max = activations.softmax_max;
     softmax_d = activations.softmax_d;
     att_sums = activations.att_sums;
@@ -208,6 +240,9 @@ struct AttentionActivationsPtrs {
   }
 
   const ModelConfig& config;
+  // Parameters for flash attention.
+  std::vector<FlashAttentionParams>& flash_params;
+  std::vector<FlashAttentionParams>& split_flash_params;
 
   // For the matrices below, the batch_size dimension is really qbatch.Size() *
   // token_batch_size, but in all known uses, one of those is 1. Specifically,
@@ -233,6 +268,7 @@ struct AttentionActivationsPtrs {
   // Attention output computed from att * V, size batch_size x (q_heads *
   // qkv_dim).
   MatPtrT<float> att_out;
+  MatPtrT<float> att_out_reps;
   // The maximum logit value encountered when computing att_out from att,
   // size batch_size x q_heads. See OnlineSoftmaxState for details.
   // WARNING: Only filled in for AttentionImpl::kOld.
@@ -287,7 +323,8 @@ struct Activations {
       s_w_linear_w(config.num_layers, max_workers),
       attention_impl(runtime_config.attention_impl),
       attention_storage(config, layer_config, batch_size, seq_len,
-                        runtime_config, ctx.allocator, row_ptrs),
+                        runtime_config, ctx.pools.MaxWorkers(), ctx.allocator,
+                        row_ptrs),
       attention(config, seq_len, attention_storage) {
     HWY_ASSERT(batch_size != 0);
 
@@ -49,6 +49,39 @@ HWY_BEFORE_NAMESPACE();
 namespace gcpp {
 namespace HWY_NAMESPACE {
 
+// Returns the number of floats per vector (aka NF).
+size_t FloatsPerVector() {
+  using DF = hn::ScalableTag<float>;
+  const DF df;
+  return hn::Lanes(df);
+}
+
+// The k-cache and v-cache are set up without knowing NF. So if it hasn't been
+// done already, reshape it to take NF into account.
+void MaybeReshapeCache(const MatPtrT<KV_t>& kv, MatPtrT<KV_t>& cache) {
+  if (kv.Cols() > cache.Cols()) {
+    cache.ReshapePackedRowsToCols(2 * FloatsPerVector());
+  }
+}
+
+// Transposes a single row of the kv cache into the k-cache and v-cache.
+void TransposeKVCacheRow(const KV_t* HWY_RESTRICT kv, KV_t* HWY_RESTRICT k,
+                         KV_t* HWY_RESTRICT v, size_t qkv_dim) {
+  // This is inefficient, as the writes are scattered over cache lines, but it
+  // is a tiny fraction of the overall computation, and it is linear in the
+  // token length.
+  const size_t kFloatsPerTile = 2 * FloatsPerVector();
+  for (size_t i = 0; i < qkv_dim; i += 2) {
+    k[i * kFloatsPerTile] = kv[i];
+    k[i * kFloatsPerTile + 1] = kv[i + 1];
+  }
+  for (size_t i = 0; i < qkv_dim; i += kFloatsPerTile) {
+    for (size_t j = 0; j < kFloatsPerTile; j++) {
+      v[i * kFloatsPerTile + j] = kv[i + j + qkv_dim];
+    }
+  }
+}
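A small worked illustration of the k layout this produces, assuming NF = 4 (so kFloatsPerTile = 8) purely for concreteness; the helper name is the editor's, not part of the commit:

  #include <cstddef>

  // Offset of qkv pair j for sequence position p within one k-cache head
  // block, combining the stride in TransposeKVCacheRow above with the
  // caller's (p % 2NF) * 2 position offset.
  constexpr std::size_t KPairOffset(std::size_t j, std::size_t p) {
    return j * 2 * 8 + (p % 8) * 2;
  }
  static_assert(KPairOffset(0, 0) == 0, "first pair, first position");
  static_assert(KPairOffset(1, 0) == 16, "pairs are 2 * kFloatsPerTile apart");
  static_assert(KPairOffset(0, 3) == 6, "8 positions interleave per block");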
+
 // Computes Q.K scores, which are "logits" (or scores) stored to att.
 // `k` is a strided view of the kv cache with dimensions [seq_len, qkv_dim].
 static HWY_INLINE void QDotK(const size_t start_pos, const size_t last_pos,
@@ -280,6 +313,11 @@ static HWY_INLINE void ComputeQKV(size_t num_tokens, const size_t layer_idx,
   kv_rows.AttachRowPtrs(env.row_ptrs[0].get());
   CallMatMul(activations.pre_att_rms_out, layer.qkv_einsum_w2,
              /*add=*/nullptr, env, kv_rows);
+  for (size_t qi = 0; qi < qbatch.Size(); ++qi) {
+    MaybeReshapeCache(qbatch.KV(qi).kv_cache, qbatch.KV(qi).k_cache);
+    MaybeReshapeCache(qbatch.KV(qi).kv_cache, qbatch.KV(qi).v_cache);
+  }
+  const size_t kFloatsPerVector = FloatsPerVector();
 
   // Apply positional encodings for K.
   // Note that 2D parallelism is not worth the fork/join overhead because the
@@ -299,6 +337,26 @@ static HWY_INLINE void ComputeQKV(size_t num_tokens, const size_t layer_idx,
         KV_t* HWY_RESTRICT kv = kv_cache.Row(cache_pos) +
                                 layer_idx * cache_layer_size +
                                 head * qkv_dim * 2;
+        // Note that k_cache and v_cache are different shapes.
+        // The innermost dimension of k is 2 values from qkv_dim because they
+        // are going to be used in a BF16 dot product involving pairs of
+        // values over NF k positions.
+        // The innermost dimension of v is 2NF values from qkv_dim because they
+        // will be loaded into a BF16 vector to be scaled and added to the
+        // cached attention output in 2 NF-sized registers.
+        // TODO(rays): factor out these calculations into functions.
+        auto& k_cache = qbatch.KV(qi).k_cache;
+        KV_t* HWY_RESTRICT k =
+            k_cache.Row(cache_pos / (2 * kFloatsPerVector)) +
+            (layer_idx * cache_layer_size + head * qkv_dim * 2) *
+                kFloatsPerVector +
+            (cache_pos % (2 * kFloatsPerVector)) * 2;
+        auto& v_cache = qbatch.KV(qi).v_cache;
+        KV_t* HWY_RESTRICT v =
+            v_cache.Row(cache_pos / (2 * kFloatsPerVector)) +
+            (layer_idx * cache_layer_size + head * qkv_dim * 2) *
+                kFloatsPerVector +
+            (cache_pos % (2 * kFloatsPerVector)) * 2 * kFloatsPerVector;
 
         HWY_ALIGN float kv_f32[2 * kMaxQKVDim];
         const hn::ScalableTag<float> df;
@@ -319,6 +377,10 @@ static HWY_INLINE void ComputeQKV(size_t num_tokens, const size_t layer_idx,
                              /*mul=*/1.0f);
         CompressPerThread tls;
         Compress(kv_f32, 2 * qkv_dim, tls, MakeSpan(kv, 2 * qkv_dim), 0);
+        // This is inefficient, as multiple threads are writing the same K
+        // cache line, but the input is generated by a matmul, so it is
+        // difficult to change, and it probably isn't significant.
+        TransposeKVCacheRow(kv, k, v, qkv_dim);
       });
 }
 
@@ -341,7 +403,8 @@ void GemmaAttention(size_t num_tokens, const size_t layer_idx,
   } else {
     // * 2 does not help on Turin.
     FlashAttention(num_tokens,
-                   /*target_parallelism=*/env.ctx.pools.MaxWorkers() * 1,
+                   /*target_parallelism=*/env.ctx.pools.MaxWorkers() *
+                       AttentionActivations::kThreadReplicationFactor,
                    layer_idx, layer.query_norm_scale, activations, qbatch,
                    env.ctx, attention_impl);
   }
@@ -31,6 +31,13 @@ namespace gcpp {
 // Passed to HWY_VISIT_TARGETS; declares for one target.
 #define GEMMA_DECL_ATTENTION(TARGET, NAMESPACE) \
   namespace NAMESPACE { \
+  size_t FloatsPerVector(); \
+ \
+  void MaybeReshapeCache(const MatPtrT<KV_t>& kv, MatPtrT<KV_t>& cache); \
+ \
+  void TransposeKVCacheRow(const KV_t* HWY_RESTRICT kv, KV_t* HWY_RESTRICT k, \
+                           KV_t* HWY_RESTRICT v, size_t qkv_dim); \
+ \
   void PositionalEncodingQK(float* qk, size_t layer_idx, \
                             const AttentionActivationsPtrs& activations, \
                             ThreadingContext& ctx, size_t worker, size_t pos, \
@@ -1,8 +1,10 @@
 #include <cstddef>
 #include <cstdlib>
 #include <cstring>  // strcmp
 #include <memory>
 #include <numeric>
 #include <optional>
 #include <string>
 #include <vector>
 
 #include "gtest/gtest.h"
@@ -105,7 +107,8 @@ struct TestAttentionState {
         tokens(num_tokens),
         attention_storage_(model_state.config, model_state.layer_config,
                            batch_size, num_tokens, runtime_config,
-                           state.ctx.allocator, row_ptrs_),
+                           state.ctx.pools.MaxWorkers(), state.ctx.allocator,
+                           row_ptrs_),
         attention(model_state.config, num_tokens, attention_storage_) {
     for (size_t i = 0; i < qbatch_size; ++i) {
       kv_caches.emplace_back(model_state.config, inference_args,
@@ -143,6 +146,7 @@ struct TestAttentionState {
 };
 
 double GetTolerance() {
+  if (IsBF16<KV_t>()) return 1e-2;
   const char* target_name = hwy::TargetName(HWY_TARGET);
   if (strncmp(target_name, "AVX2", 4) == 0) {
     return 2e-2;
@@ -155,6 +159,57 @@ double GetTolerance() {
   }
 }
 
+// Comparison function for computations that used BF16, whether the result is
+// stored in BF16 or F32.
+// Compare with absolute tolerance for values with small magnitudes.
+// Compare with relative tolerance for values with larger magnitudes.
+template <typename T>
+bool CompareArraySimilarBF16(const T* expected, const T* actual, size_t count,
+                             const char* target_name, const char* filename,
+                             int line) {
+  constexpr double kTolerance = 3e-2;
+  for (size_t i = 0; i < count; ++i) {
+    const double exp = hwy::ConvertScalarTo<double>(expected[i]);
+    const double act = hwy::ConvertScalarTo<double>(actual[i]);
+    const double l1 = std::abs(act - exp);
+    // Cannot divide, so check absolute error.
+    if (std::abs(exp) <= 1.0) {
+      if (l1 > kTolerance) {
+        std::string array_values = hwy::detail::FormatMismatchedArrays(
+            expected, actual, count, kTolerance);
+        HWY_WARN("%s %s:%d %s mismatch %zu of %zu: %E %E l1 %E tol %E%s\n",
+                 target_name, filename, line, "BF16", i, count, exp, act, l1,
+                 kTolerance, array_values.c_str());
+        return false;
+      }
+    } else {  // relative
+      const double rel = l1 / exp;
+      if (rel > kTolerance) {
+        std::string array_values = hwy::detail::FormatMismatchedArrays(
+            expected, actual, count, kTolerance);
+        HWY_WARN("%s %s:%d %s mismatch %zu of %zu: %E %E rel %E tol %E%s\n",
+                 target_name, filename, line, "BF16", i, count, exp, act, rel,
+                 kTolerance, array_values.c_str());
+        return false;
+      }
+    }
+  }
+  return true;
+}
+
+template <typename T>
+bool CompareArraySimilar(const T* expected, const T* actual, size_t count,
+                         const char* target_name, const char* filename,
+                         int line) {
+  if constexpr (IsBF16<KV_t>()) {
+    return CompareArraySimilarBF16(expected, actual, count, target_name,
+                                   filename, line);
+  } else {
+    return hwy::CompareArraySimilar(expected, actual, count, GetTolerance(),
+                                    target_name, filename, line);
+  }
+}
+
 template <size_t kNumTokens, size_t kQBatchSize, size_t kDims>
 void CompareAttSumsWithGolden(
     const AttentionActivationsPtrs& attention,
@@ -170,9 +225,9 @@ void CompareAttSumsWithGolden(
       for (size_t j = 0; j < kDims; ++j) {
        actual_row[j] = hwy::F32FromBF16(attention.att_sums.Row(i)[j]);
      }
-      EXPECT_TRUE(hwy::CompareArraySimilar(
-          golden[token_idx][qi], actual_row.get(), kDims, GetTolerance(),
-          hwy::TargetName(HWY_TARGET), __FILE__, __LINE__))
+      EXPECT_TRUE(CompareArraySimilar(golden[token_idx][qi], actual_row.get(),
+                                      kDims, hwy::TargetName(HWY_TARGET),
+                                      __FILE__, __LINE__))
          << "att_sums mismatch for token_idx=" << token_idx << " qi=" << qi;
    }
  }
@@ -200,19 +255,20 @@ void CompareKVCacheWithGolden(
 
  for (size_t token_idx = 0; token_idx < kNumTokens; ++token_idx) {
    for (size_t qi = 0; qi < kQBatchSize; ++qi) {
-      const float* cache_row =
+      const BF16* cache_row =
          kv_caches[qi].kv_cache.Row(start_offset + token_idx);
      for (size_t j = 0; j < kDims; ++j) {
-        actual_k_row[j] = cache_row[kv_offset + j];
-        actual_v_row[j] = cache_row[kv_offset + qkv_dim + j];
+        actual_k_row[j] = hwy::ConvertScalarTo<float>(cache_row[kv_offset + j]);
+        actual_v_row[j] =
+            hwy::ConvertScalarTo<float>(cache_row[kv_offset + qkv_dim + j]);
      }
-      EXPECT_TRUE(hwy::CompareArraySimilar(
-          k_golden[token_idx][qi], actual_k_row.get(), kDims, GetTolerance(),
+      EXPECT_TRUE(CompareArraySimilar(
+          k_golden[token_idx][qi], actual_k_row.get(), kDims,
          hwy::TargetName(HWY_TARGET), __FILE__, __LINE__))
          << "K cache mismatch for token_idx=" << token_idx << " qi=" << qi
          << " kv_head=" << kv_head;
-      EXPECT_TRUE(hwy::CompareArraySimilar(
-          v_golden[token_idx][qi], actual_v_row.get(), kDims, GetTolerance(),
+      EXPECT_TRUE(CompareArraySimilar(
+          v_golden[token_idx][qi], actual_v_row.get(), kDims,
          hwy::TargetName(HWY_TARGET), __FILE__, __LINE__))
          << "V cache mismatch for token_idx=" << token_idx << " qi=" << qi
          << " kv_head=" << kv_head;
@@ -238,8 +294,8 @@ void CompareQVecsWithGolden(
      for (size_t j = 0; j < kDims; ++j) {
        actual_q_row[j] = q_row[head_offset + j];
      }
-      EXPECT_TRUE(hwy::CompareArraySimilar(
-          q_golden[token_idx][qi], actual_q_row.get(), kDims, GetTolerance(),
+      EXPECT_TRUE(CompareArraySimilar(
+          q_golden[token_idx][qi], actual_q_row.get(), kDims,
          hwy::TargetName(HWY_TARGET), __FILE__, __LINE__))
          << "Q vec mismatch for token_idx=" << token_idx << " qi=" << qi
          << " q_head=" << q_head;
@@ -267,42 +323,42 @@ const float kGoldenAttSums[kNumTokens][kQBatchSize][kDimsToCompare] = {
      26.875, 63, 3.34375, -67.5, 31.125, -190, 125},
     {-30.375, -17.875, 51.75, -78, -84, 6.40625, 15.375, 70, -22.875, 20.125,
      -14.9375, -109.5, 76, 9.25, -142, 29.5, -105}},
-    {{-32.75, 38.25, 78.5, 107.5, 20.25, 197, -136, 42.5, -84, 25.625, 4.96875,
+    {{-32.75, 38.25, 78.5, 107.5, 20.25, 197, -136, 42.5, -84, 25.625, 5.35875,
      128, 27.25, -161, 19.125, -58, 97.5},
     {-18.5, -18, 135, -13.4375, -6.625, -45.75, 29.625, 93, 18.625, 75.5,
      102.5, -184, 52.75, 83.5, -71, 46.5, -52}},
-    {{-16.375, -61.5, -58.25, -27.375, -28, 71, -109.5, 60.25, 3.125, -29.125,
-      6.90625, 150, 144, -155, -47.25, -98.5, 3.5625},
-     {-19, -16.75, 129, 0.59765625, -82, 123.5, 60.75, -36.75, -77, 26.625, 51,
-      -66.5, -0.84765625, -46.5, -152, -2.9375, -81}},
-    {{3.984375, 83, -41.75, 39.5, -203, 110, -76, 131, 0.4609375, -44.5, -63.75,
+    {{-16.375, -61.5, -58.25, -27.375, -28, 71, -109.5, 60.25, 3.625, -29.125,
+      6.4625, 150, 144, -155, -47.25, -98.5, 3.5625},
+     {-19, -16.75, 129, 0.628925, -82, 123.5, 60.75, -36.75, -77, 26.625, 51,
+      -66.5, -0.62165625, -46.5, -152, -2.9375, -81}},
+    {{3.684375, 83, -41.75, 39.5, -203, 110, -76, 131, 1.0069375, -44.5, -63.75,
      -46, -22, -19.375, -16.125, -148, 20.875},
-     {-47, -19.5, 58, 81.5, 21.75, -30, -118, 44.25, -149, 22.5, 188, -66.5, 33,
+     {-47, -19.5, 58, 81.5, 23.35, -30, -118, 44.25, -149, 22.5, 188, -66.5, 33,
      10.9375, -52.5, 23.25, 75}},
-    {{64, -31, -89, -92.5, -11.1875, -54.75, -302, 3.453125, -108, 39.25,
+    {{64, -31, -89, -92.5, -11.1875, -54.75, -302, 4.213125, -108, 39.25,
      -34.75, 18, -52, 100, -186, -75.5, 50.75},
-     {7.6875, -80, -40, 32.25, -30.25, 90, -41, 44.25, -140, -2.4375, 82.5,
+     {7.1875, -80, -40, 32.25, -30.25, 90, -41, 44.25, -140, -2.4375, 82.5,
      39.25, 65, 47.25, -89.5, -34.25, 137}},
     {{39.75, 17.875, 115, 38.75, -44, 139, -53.25, -23.875, -13.0625, 38.5,
-      32.5, 53.75, 109, 4.09375, 57.5, -20.5, 132},
-     {143, 249, 5.09375, 0.83984375, 27.875, -5.84375, 30.25, -101.5, 65.5,
-      13.5, 195, -10.0625, 97.5, 2.203125, -97.5, -100, -19.25}},
+      32.5, 53.75, 109, 4.62375, 57.5, -20.5, 132},
+     {143, 249, 4.9375, 1.33984375, 27.875, -5.84375, 30.25, -101.5, 65.5, 13.5,
+      195, -10.0625, 97.5, 1.903125, -97.5, -100, -19.25}},
     {{-30.125, -169, -150, 58, -35.75, 22.75, 36.5, -32.25, -8.9375, 55.25,
      -117, 26.375, 39.5, 125, 66, 48.75, 20.75},
-     {137, 5.25, 61.25, 37, -42.75, 240, 62, -164, 11.3125, 173, 174, 23.5,
+     {137, 3.85, 61.25, 37, -42.75, 240, 62, -164, 10.3125, 173, 174, 23.5,
      88.5, 48.5, -46.25, -36.75, 101.5}},
-    {{-103, -47.5, 39, -48, -67.5, 121, -136, 99, 80, -47.5, 107.5, 48.75, 97.5,
+    {{-103, -47.5, 39, -48, -67.5, 121, -136, 99, 80, -47.5, 107.5, 43.75, 97.5,
      125, -53.5, -14.625, 262},
-     {29.875, 7.34375, -36.75, -14.5, -27.5, 44.75, -67.5, -40.75, 71.5, 172,
+     {28.075, 6.64375, -36.75, -14.5, -27.5, 44.75, -67.5, -40.75, 71.5, 172,
      81, -27.25, -3.03125, 111, -167, 59, 176}},
     {{-37.25, 109.5, -26.125, -115.5, 108, 57.25, 1.3671875, 72, -122.5, 59.25,
      -52, -12.625, 43.25, 16.25, -41.75, 26.5, 70.5},
-     {40.25, 53.25, -142, 78.5, 38, 4.3125, -27.75, -134, -85, 107.5, 2.5, 93.5,
+     {40.25, 53.25, -142, 78.5, 38, 4.625, -27.75, -134, -85, 107.5, 2.5, 93.5,
      58.25, 173, -53.5, 25.125, 4.8125}},
     {{-8.4375, -35, -35.5, 131, -33.25, 106, 109.5, -92, -135, 80, 21.5,
      -17.125, 15.25, 143, -27, 103, 101},
     {-77, 40.75, -10.125, 33.25, -33, 104, -7.6875, 85.5, -40, 93, 61, 14.5625,
-      8.125, -99.5, 13.6875, -11.6875, 33}},
+      8.55, -99.5, 14.6875, -11.6875, 33}},
 };
 
 // Layer 0, *K*V Head 0
@@ -81,8 +81,8 @@ static inline bool EnumValid(LayerAttentionType type) {
 }
 
 enum class AttentionImpl {
-  kOld,
-  kFlash,
+  kOld,    // Previous Attention implementation
+  kFlash,  // Flash Attention (default)
   kSentinel,
 };
 
(File diff suppressed because it is too large.)
@@ -44,15 +44,6 @@ namespace gcpp {
                             float* HWY_RESTRICT att_out, \
                             ThreadingContext& ctx, size_t worker); \
 \
-  Tile4FlashState TileFlashAttention4( \
-      const MatPtrT<BF16>& q, const uint32_t* HWY_RESTRICT q_offsets, \
-      const MatPtrT<KV_t>& k, size_t start_pos, \
-      const uint32_t* HWY_RESTRICT last_pos, size_t min_last_pos, \
-      size_t max_last_pos, const MatPtrT<KV_t>& v, size_t layer_idx, \
-      const LayerWeightsPtrs& layer, const AttentionActivations& activations, \
-      MatPtrT<float>& att_out, const uint32_t* HWY_RESTRICT out_offsets, \
-      ThreadingContext& ctx, const size_t worker); \
- \
   size_t GetVTileSize(size_t kNF, size_t num_head_groups, size_t num_tokens, \
                       size_t total_tasks, size_t target_parallelism); \
 \
@@ -62,16 +62,17 @@ namespace HWY_NAMESPACE {
 
 using FloatPtr = hwy::AlignedFreeUniquePtr<float[]>;
 
-void SetMat(const size_t offset, MatPtrT<float>& mat) {
+template <typename T>
+void SetMat(const size_t offset, MatPtrT<T>& mat) {
   const size_t kOuter = mat.Extents().rows;
   const size_t kInner = mat.Extents().cols;
   const float i_scale = 1.0f / kInner;
   const float j_scale = 1.0f / kOuter;
   for (size_t i = 0; i < kOuter; ++i) {
-    float* row = mat.Row(i);
+    T* row = mat.Row(i);
     for (size_t j = 0; j < kInner; ++j) {
-      row[j] =
-          static_cast<float>((i * kInner * i_scale + (j + offset) * j_scale));
+      row[j] = hwy::ConvertScalarTo<T>(
+          static_cast<float>((i * kInner * i_scale + (j + offset) * j_scale)));
     }
   }
 }
@@ -94,14 +95,15 @@ void AssertClose(const MatPtrT<float>& a, const MatPtrT<float>& b) {
       if (rel_abs_delta > 0.0f) {
        rel_abs_delta /= std::max(std::abs(a_row[c]), std::abs(b_row[c]));
      }
-      EXPECT_LT(rel_abs_delta, 1e-5)
+      EXPECT_LT(rel_abs_delta, 1e-3)
          << "a[" << r << "," << c << "]=" << a_row[c] << ", b[" << r << ","
          << c << "]=" << b_row[c];
    }
  }
 }
 
-void TestFlashAttention(size_t target_parallelism) {
+void TestFlashAttention(size_t target_parallelism,
+                        AttentionImpl attention_impl) {
  ThreadingArgs threading_args;
  ThreadingContext ctx(threading_args);
  constexpr size_t kOuter = 1024;
@@ -131,7 +133,8 @@ void TestFlashAttention(size_t target_parallelism) {
  const size_t batch_size = kOuter;
  std::vector<hwy::AlignedFreeUniquePtr<uint8_t*[]>> row_ptrs;
  AttentionActivations attention_storage(config, layer_config, batch_size,
-                                         kOuter, runtime_config, ctx.allocator,
+                                         kOuter, runtime_config,
+                                         ctx.pools.MaxWorkers(), ctx.allocator,
                                         row_ptrs);
  AttentionActivationsPtrs attention(config, kOuter, attention_storage);
  const size_t qkv_dim = layer_config.qkv_dim;
@@ -142,7 +145,10 @@ void TestFlashAttention(size_t target_parallelism) {
  const size_t kHeadGroups = layer_config.heads / layer_config.kv_heads;
  const size_t seq_len =
      static_cast<size_t>(attention.div_seq_len.GetDivisor());
+  MaybeReshapeCache(qbatch.KV(0).kv_cache, qbatch.KV(0).k_cache);
+  MaybeReshapeCache(qbatch.KV(0).kv_cache, qbatch.KV(0).v_cache);
  auto& kvc = qbatch.KV(0).kv_cache;
+  const size_t kFloatsPerTile = 2 * FloatsPerVector();
  for (size_t h = 0; h < layer_config.heads; ++h) {
    // Make strided views into the kv cache for
    // this query and head.
@@ -153,6 +159,17 @@ void TestFlashAttention(size_t target_parallelism) {
    v.SetPtr(kvc.Row(0) + head_offset + qkv_dim, kvc.Stride());
    SetMat(h + layer_config.heads, k);
    SetMat(h + layer_config.heads * 2, v);
+    for (size_t p = 0; p < tokens.size(); ++p) {
+      KV_t* HWY_RESTRICT k_src = k.Row(p);
+      KV_t* HWY_RESTRICT k_dest = qbatch.KV(0).k_cache.Row(p / kFloatsPerTile) +
+                                  head_offset * kFloatsPerTile / 2 +
+                                  p % kFloatsPerTile * 2;
+      KV_t* HWY_RESTRICT v_dest = qbatch.KV(0).v_cache.Row(p / kFloatsPerTile) +
+                                  head_offset * kFloatsPerTile / 2 +
+                                  p % kFloatsPerTile * kFloatsPerTile;
+
+      TransposeKVCacheRow(k_src, k_dest, v_dest, qkv_dim);
+    }
  }
  SetMat(1, attention.q);
  DotSoftmaxWeightedSum(tokens.size(), 0, layers.query_norm_scale, attention,
@@ -167,18 +184,19 @@ void TestFlashAttention(size_t target_parallelism) {
      tokens.size() * div_qbatch.GetDivisor() * layer_config.heads;
  const size_t kVTileSize = GetVTileSize(kNF, kHeadGroups, tokens.size(),
                                         total_tasks, target_parallelism);
-  printf("FlashAttention: target_parallelism=%zu, kNF=%zu, kVTileSize=%zu\n",
-         target_parallelism, kNF, kVTileSize);
+  printf("FlashAttention: parallelism=%zu, kNF=%zu, kVTileSize=%zu, mode %s\n",
+         target_parallelism, kNF, kVTileSize,
+         GetAttentionImplName(attention_impl).c_str());
  FlashAttention(tokens.size(), target_parallelism, 0, layers.query_norm_scale,
-                 attention, qbatch, ctx, AttentionImpl::kFlash);
+                 attention, qbatch, ctx, attention_impl);
  AssertClose(attention.att_out, *saved_att);
  ctx.profiler.PrintResults();
 }
 
 void TestAttention() {
-  TestFlashAttention(8192);
-  TestFlashAttention(2048);
-  TestFlashAttention(256);
+  TestFlashAttention(8192, AttentionImpl::kFlash);
+  TestFlashAttention(2048, AttentionImpl::kFlash);
+  TestFlashAttention(256, AttentionImpl::kFlash);
 }
 
 // NOLINTNEXTLINE(google-readability-namespace-comments)
@@ -2,11 +2,19 @@
 #define THIRD_PARTY_GEMMA_CPP_GEMMA_FLASH_STRUCTS_H_
 
 #include <stddef.h>
+#include <stdint.h>
+
+#include <limits>
 
 namespace gcpp {
 
+// The vertical tile size in flash attention when register lanes correspond to
+// K-timesteps, and the number of registers is 4 for 4 Q-rows.
+static constexpr size_t k4xNFVTileSize = 4;
+// The vertical tile size in flash attention when register lanes correspond to
+// K-timesteps, and the number of registers is 8 for 8 Q-rows.
+static constexpr size_t k8xNFVTileSize = 8;
+
 // State for computing softmax in a streaming ("online") manner,
 // avoiding large intermediate values by subtracting the running maximum.
 // For a sequence x_1, ..., x_n:
@@ -20,10 +28,44 @@ struct OnlineSoftmaxState {
   float d = 0.0f;
 };
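For reference, a minimal scalar sketch of the standard online-softmax recurrence this state supports (the editor's illustration, assuming the struct also carries the running maximum m alongside d, as the comment above implies):

  #include <algorithm>
  #include <cmath>

  struct OnlineSoftmaxSketch {  // standalone stand-in, not the library type
    float m = -INFINITY;  // running maximum of the logits seen so far
    float d = 0.0f;       // running sum of exp(x_i - m)
    void Update(float x) {
      const float m_new = std::max(m, x);
      // Rescale the existing sum to the new maximum, then add the new term.
      d = d * std::exp(m - m_new) + std::exp(x - m_new);
      m = m_new;
    }
  };
  // After feeding x_1..x_n: softmax(x_k) = exp(x_k - m) / d.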
 
-static constexpr size_t kVTileSize4 = 4;
-
 struct Tile4FlashState {
-  OnlineSoftmaxState row_states[kVTileSize4];
+  OnlineSoftmaxState row_states[k8xNFVTileSize];
 };
 
+// Parameters for a strip of tiles of flash attention. For processing a strip
+// of tiles, each of 1, k4xNFVTileSize, or k8xNFVTileSize Q-rows, by NF
+// k-positions. The total width of the strip might cover the entire sequence,
+// or a part of it, depending on whether the strip has been split.
+struct FlashAttentionParams {
+  // The vertical tile size, i.e. the number of Q rows in the tile; it is the
+  // number of valid entries in the k8xNFVTileSize-sized arrays below.
+  uint32_t v_tile_size = 0;
+  // Min start position across all rows in the tile; determines the mask used
+  // for the tile.
+  uint32_t min_start_pos = std::numeric_limits<uint32_t>::max();
+  // Max last position across all rows in the tile; determines the mask used
+  // for the tile.
+  uint32_t max_last_pos = 0;
+  // Index into qbatch.KV; the same for each row in the tile.
+  uint32_t qi_index;
+  // Index into the kv_cache; the same for each row in the tile.
+  uint32_t kv_offset;
+  // For the original (unsplit) task, the index of its first split task.
+  uint32_t split_index = 0;
+  // The index of the split for running split attention.
+  uint32_t i_of_n = 0;
+  // Offsets into original Q for each row in the tile.
+  uint32_t q_offsets[k8xNFVTileSize];
+  // Offsets into att_out for each row in the tile.
+  uint32_t out_offsets[k8xNFVTileSize];
+  // Start k-positions for each row in the tile.
+  uint32_t start_pos[k8xNFVTileSize];
+  // Last k-positions for each row in the tile. Inclusive.
+  uint32_t last_pos[k8xNFVTileSize];
+  // Row index to att_out.
+  uint32_t tq_idx[k8xNFVTileSize];
+  // Flash attention state for the tile.
+  Tile4FlashState end_state;
+};
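A small illustrative helper (the editor's, with a hypothetical name, not part of the commit) showing how the per-row bounds relate to the tile-wide extremes above:

  // A k-position can lie inside [min_start_pos, max_last_pos] for the tile
  // yet still be masked for an individual row whose own range is narrower.
  inline bool RowNeedsMask(const FlashAttentionParams& p, size_t row,
                           uint32_t k_pos) {
    return k_pos < p.start_pos[row] || k_pos > p.last_pos[row];
  }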
+
 }  // namespace gcpp
 
@@ -43,6 +43,17 @@ static size_t CappedSeqLen(const ModelConfig& config,
 
 KVCache::KVCache(const Extents2D& kv_extents, const Allocator& allocator)
     : kv_cache("kv", kv_extents, allocator, MatPadding::kOdd),
+      // WARNING: the rows and cols of k_cache and v_cache will be modified
+      // before use!
+      // The rows will be reduced by a factor of 2 x kFloatsPerVector, and the
+      // cols will be increased by 2 x kFloatsPerVector on first use. This is to
+      // avoid making KVCache another class that has to be duplicated for each
+      // machine architecture, since kFloatsPerVector is architecture dependent.
+      // The change in shape is safe only if the padding is kPacked.
+      k_cache("k", Extents2D(kv_extents.rows, kv_extents.cols / 2), allocator,
+              MatPadding::kPacked),
+      v_cache("v", Extents2D(kv_extents.rows, kv_extents.cols / 2), allocator,
+              MatPadding::kPacked),
       allocator_(allocator) {}
 
 KVCache::KVCache(const ModelConfig& config, const InferenceArgs& inference_args,
@@ -55,6 +66,8 @@ KVCache KVCache::Copy() {
   KVCache copy(kv_cache.Extents(), allocator_);
 
   CopyMat(kv_cache, copy.kv_cache);
+  CopyMat(k_cache, copy.k_cache);
+  CopyMat(v_cache, copy.v_cache);
   return copy;
 }
 
@@ -30,7 +30,7 @@
 
 namespace gcpp {
 
-using KV_t = float;
+using KV_t = BF16;
 
 // A non-owning view of a KVCache.
 struct KVCachePtr {
@@ -38,6 +38,8 @@ struct KVCachePtr {
   size_t SeqLen() const;
 
   MatPtrT<KV_t> kv_cache;
+  MatPtrT<KV_t> k_cache;
+  MatPtrT<KV_t> v_cache;
 };
 
 struct KVCache {
@@ -52,10 +54,33 @@ struct KVCache {
   }
 
   MatStorageT<KV_t> kv_cache;  // [seq_len, layers * kv_heads * qkv_dim * 2]
+  // The format of k_cache is: pairs of values from qkv_dim, in groups of
+  // 2 x kFloatsPerVector (= 2NF, with NF = kFloatsPerVector) elements from the
+  // sequence, in groups of qkv_dim/2 elements, in groups of kv_heads elements.
+  // This enables sequential loading of the data when filling 2 vectors with
+  // NF sequence elements of pairs of BF16 qkv values. The next vector then
+  // continues reading the rest of qkv.
+  // [seq_len / 2NF, layers * kv_heads * qkv_dim/2 * 2NF * 2]
+  MatStorageT<KV_t> k_cache;
+  // v_cache is formatted to allow sequential access to V during scaling and
+  // update of att_out.
+  // Originally [seq_len, layers * kv_heads * qkv_dim],
+  // v_cache is transposed to:
+  // [layers, kv_heads, seq_len, qkv_dim], reshaped to:
+  // [layers, kv_heads, seq_len/(2NF), 2NF, qkv_dim/(2NF), 2NF]
+  // then transposed to:
+  // [seq_len/(2NF), layers, kv_heads, qkv_dim/(2NF), 2NF, 2NF]
+  // and finally packed in a 2D MatStorageT as:
+  // [seq_len/(2NF), layers * kv_heads * qkv_dim/(2NF) * 2NF * 2NF]
+  // This allows sequential reads of 2NF registers, each of 2NF BF16 values,
+  // repeatedly until all of qkv_dim is read.
+  MatStorageT<KV_t> v_cache;
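To make the two layouts concrete, a sketch (editor's helper names, hypothetical) of the within-row offsets used when writing one position's k and v, mirroring the index arithmetic in ComputeQKV in this commit, with NF = kFloatsPerVector:

  // The row index is cache_pos / (2 * NF): each cache row covers 2NF positions.
  size_t KOffsetInRow(size_t cache_pos, size_t layer_idx,
                      size_t cache_layer_size, size_t head, size_t qkv_dim,
                      size_t NF) {
    return (layer_idx * cache_layer_size + head * qkv_dim * 2) * NF +
           (cache_pos % (2 * NF)) * 2;  // k strides by 2 per position
  }
  size_t VOffsetInRow(size_t cache_pos, size_t layer_idx,
                      size_t cache_layer_size, size_t head, size_t qkv_dim,
                      size_t NF) {
    return (layer_idx * cache_layer_size + head * qkv_dim * 2) * NF +
           (cache_pos % (2 * NF)) * 2 * NF;  // v strides by 2NF per position
  }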
|
||||
|
||||
KVCachePtr ToPtr() {
|
||||
return KVCachePtr{
|
||||
.kv_cache = kv_cache,
|
||||
.k_cache = k_cache,
|
||||
.v_cache = v_cache,
|
||||
};
|
||||
}
|
||||
|
||||
|
|
|
|||
587
ops/ops-inl.h
587
ops/ops-inl.h
|
|
@ -614,267 +614,6 @@ HWY_NOINLINE HWY_MAYBE_UNUSED void MulByConstAndAdd(const float c,
|
|||
});
|
||||
}
|
||||
|
||||
template <class DF, class VF = hn::Vec<DF>, HWY_IF_V_SIZE_GT_D(DF, 63)>
|
||||
HWY_INLINE HWY_MAYBE_UNUSED void Mul16(DF df, const VF scale, VF& sum0,
|
||||
VF& sum1, VF& sum2, VF& sum3, VF& sum4,
|
||||
VF& sum5, VF& sum6, VF& sum7, VF& sum8,
|
||||
VF& sum9, VF& sum10, VF& sum11,
|
||||
VF& sum12, VF& sum13, VF& sum14,
|
||||
VF& sum15) {
|
||||
sum0 = hn::Mul(sum0, hn::BroadcastLane<0>(scale));
|
||||
sum1 = hn::Mul(sum1, hn::BroadcastLane<1>(scale));
|
||||
sum2 = hn::Mul(sum2, hn::BroadcastLane<2>(scale));
|
||||
sum3 = hn::Mul(sum3, hn::BroadcastLane<3>(scale));
|
||||
sum4 = hn::Mul(sum4, hn::BroadcastLane<4>(scale));
|
||||
sum5 = hn::Mul(sum5, hn::BroadcastLane<5>(scale));
|
||||
sum6 = hn::Mul(sum6, hn::BroadcastLane<6>(scale));
|
||||
sum7 = hn::Mul(sum7, hn::BroadcastLane<7>(scale));
|
||||
sum8 = hn::Mul(sum8, hn::BroadcastLane<8>(scale));
|
||||
sum9 = hn::Mul(sum9, hn::BroadcastLane<9>(scale));
|
||||
sum10 = hn::Mul(sum10, hn::BroadcastLane<10>(scale));
|
||||
sum11 = hn::Mul(sum11, hn::BroadcastLane<11>(scale));
|
||||
sum12 = hn::Mul(sum12, hn::BroadcastLane<12>(scale));
|
||||
sum13 = hn::Mul(sum13, hn::BroadcastLane<13>(scale));
|
||||
sum14 = hn::Mul(sum14, hn::BroadcastLane<14>(scale));
|
||||
sum15 = hn::Mul(sum15, hn::BroadcastLane<15>(scale));
|
||||
}
|
||||
|
||||
template <class DF, class VF = hn::Vec<DF>, HWY_IF_V_SIZE_LE_D(DF, 63)>
|
||||
HWY_INLINE HWY_MAYBE_UNUSED void Mul16(DF df, const VF scale, VF& sum0,
|
||||
VF& sum1, VF& sum2, VF& sum3, VF& sum4,
|
||||
VF& sum5, VF& sum6, VF& sum7, VF& sum8,
|
||||
VF& sum9, VF& sum10, VF& sum11,
|
||||
VF& sum12, VF& sum13, VF& sum14,
|
||||
VF& sum15) {}
|
||||
|
||||
template <class DF, class VF = hn::Vec<DF>, HWY_IF_V_SIZE_GT_D(DF, 31)>
|
||||
HWY_INLINE HWY_MAYBE_UNUSED void Mul8(DF df, const VF scale, VF& sum0, VF& sum1,
|
||||
VF& sum2, VF& sum3, VF& sum4, VF& sum5,
|
||||
VF& sum6, VF& sum7) {
|
||||
sum0 = hn::Mul(sum0, hn::BroadcastLane<0>(scale));
|
||||
sum1 = hn::Mul(sum1, hn::BroadcastLane<1>(scale));
|
||||
sum2 = hn::Mul(sum2, hn::BroadcastLane<2>(scale));
|
||||
sum3 = hn::Mul(sum3, hn::BroadcastLane<3>(scale));
|
||||
sum4 = hn::Mul(sum4, hn::BroadcastLane<4>(scale));
|
||||
sum5 = hn::Mul(sum5, hn::BroadcastLane<5>(scale));
|
||||
sum6 = hn::Mul(sum6, hn::BroadcastLane<6>(scale));
|
||||
sum7 = hn::Mul(sum7, hn::BroadcastLane<7>(scale));
|
||||
}
|
||||
|
||||
template <class DF, class VF = hn::Vec<DF>, HWY_IF_V_SIZE_LE_D(DF, 31)>
|
||||
HWY_INLINE HWY_MAYBE_UNUSED void Mul8(DF df, const VF scale, VF& sum0, VF& sum1,
|
||||
VF& sum2, VF& sum3, VF& sum4, VF& sum5,
|
||||
VF& sum6, VF& sum7) {}
|
||||
|
||||
template <class DF, class VF = hn::Vec<DF>, HWY_IF_V_SIZE_GT_D(DF, 63)>
|
||||
HWY_INLINE HWY_MAYBE_UNUSED void MulAdd16(
|
||||
DF df, const VF common, const VF split, VF& sum0, VF& sum1, VF& sum2,
|
||||
VF& sum3, VF& sum4, VF& sum5, VF& sum6, VF& sum7, VF& sum8, VF& sum9,
|
||||
VF& sum10, VF& sum11, VF& sum12, VF& sum13, VF& sum14, VF& sum15) {
|
||||
sum0 = hn::MulAdd(common, hn::BroadcastLane<0>(split), sum0);
|
||||
sum1 = hn::MulAdd(common, hn::BroadcastLane<1>(split), sum1);
|
||||
sum2 = hn::MulAdd(common, hn::BroadcastLane<2>(split), sum2);
|
||||
sum3 = hn::MulAdd(common, hn::BroadcastLane<3>(split), sum3);
|
||||
sum4 = hn::MulAdd(common, hn::BroadcastLane<4>(split), sum4);
|
||||
sum5 = hn::MulAdd(common, hn::BroadcastLane<5>(split), sum5);
|
||||
sum6 = hn::MulAdd(common, hn::BroadcastLane<6>(split), sum6);
|
||||
sum7 = hn::MulAdd(common, hn::BroadcastLane<7>(split), sum7);
|
||||
sum8 = hn::MulAdd(common, hn::BroadcastLane<8>(split), sum8);
|
||||
sum9 = hn::MulAdd(common, hn::BroadcastLane<9>(split), sum9);
|
||||
sum10 = hn::MulAdd(common, hn::BroadcastLane<10>(split), sum10);
|
||||
sum11 = hn::MulAdd(common, hn::BroadcastLane<11>(split), sum11);
|
||||
sum12 = hn::MulAdd(common, hn::BroadcastLane<12>(split), sum12);
|
||||
sum13 = hn::MulAdd(common, hn::BroadcastLane<13>(split), sum13);
|
||||
sum14 = hn::MulAdd(common, hn::BroadcastLane<14>(split), sum14);
|
||||
sum15 = hn::MulAdd(common, hn::BroadcastLane<15>(split), sum15);
|
||||
}
|
||||
|
||||
template <class DF, class VF = hn::Vec<DF>, HWY_IF_V_SIZE_LE_D(DF, 63)>
|
||||
HWY_INLINE HWY_MAYBE_UNUSED void MulAdd16(
|
||||
DF df, const VF common, const VF split, VF& sum0, VF& sum1, VF& sum2,
|
||||
VF& sum3, VF& sum4, VF& sum5, VF& sum6, VF& sum7, VF& sum8, VF& sum9,
|
||||
VF& sum10, VF& sum11, VF& sum12, VF& sum13, VF& sum14, VF& sum15) {}
|
||||
|
||||
template <class DF, class VF = hn::Vec<DF>, HWY_IF_V_SIZE_GT_D(DF, 31)>
|
||||
HWY_INLINE HWY_MAYBE_UNUSED void MulAdd8(DF df, const VF common, const VF split,
|
||||
VF& sum0, VF& sum1, VF& sum2, VF& sum3,
|
||||
VF& sum4, VF& sum5, VF& sum6,
|
||||
VF& sum7) {
|
||||
sum0 = hn::MulAdd(common, hn::BroadcastLane<0>(split), sum0);
|
||||
sum1 = hn::MulAdd(common, hn::BroadcastLane<1>(split), sum1);
|
||||
sum2 = hn::MulAdd(common, hn::BroadcastLane<2>(split), sum2);
|
||||
sum3 = hn::MulAdd(common, hn::BroadcastLane<3>(split), sum3);
|
||||
sum4 = hn::MulAdd(common, hn::BroadcastLane<4>(split), sum4);
|
||||
sum5 = hn::MulAdd(common, hn::BroadcastLane<5>(split), sum5);
|
||||
sum6 = hn::MulAdd(common, hn::BroadcastLane<6>(split), sum6);
|
||||
sum7 = hn::MulAdd(common, hn::BroadcastLane<7>(split), sum7);
|
||||
}
|
||||
|
||||
template <class DF, class VF = hn::Vec<DF>, HWY_IF_V_SIZE_LE_D(DF, 31)>
|
||||
HWY_INLINE HWY_MAYBE_UNUSED void MulAdd8(DF df, const VF common, const VF split,
|
||||
VF& sum0, VF& sum1, VF& sum2, VF& sum3,
|
||||
VF& sum4, VF& sum5, VF& sum6,
|
||||
VF& sum7) {}
|
||||
|
||||
template <class DF, class VF = hn::Vec<DF>>
|
||||
HWY_INLINE HWY_MAYBE_UNUSED void MulAdd4(DF df, const VF common, const VF split,
|
||||
VF& sum0, VF& sum1, VF& sum2,
|
||||
VF& sum3) {
|
||||
sum0 = hn::MulAdd(common, hn::BroadcastLane<0>(split), sum0);
|
||||
sum1 = hn::MulAdd(common, hn::BroadcastLane<1>(split), sum1);
|
||||
sum2 = hn::MulAdd(common, hn::BroadcastLane<2>(split), sum2);
|
||||
sum3 = hn::MulAdd(common, hn::BroadcastLane<3>(split), sum3);
|
||||
}
|
||||
|
||||
// For an 8xNF tile of float values in 8xNF-lane registers, multiplies 8 rows
|
||||
// of V by the corresponding values in c0-c7 and adds them to NF rows of out,
|
||||
// after first prescaling out by scale.
|
||||
// The depth (size) must be a multiple of NF.
|
||||
template <class DF, class VF = hn::Vec<DF>>
|
||||
HWY_NOINLINE HWY_MAYBE_UNUSED void MulByConstAndAddTile(
|
||||
DF df, const VF scale, const VF c0, const VF c1, const VF c2, const VF c3,
|
||||
const VF c4, const VF c5, const VF c6, const VF c7, const MatPtrT<float>& v,
|
||||
const size_t* HWY_RESTRICT pos, float* HWY_RESTRICT out,
|
||||
const uint32_t* HWY_RESTRICT out_offsets, const size_t size) {
|
||||
namespace hn = hwy::HWY_NAMESPACE;
|
||||
HWY_LANES_CONSTEXPR size_t NF = hn::Lanes(df);
|
||||
|
||||
size_t i = 0;
|
||||
while (i + NF <= size) {
|
||||
if HWY_LANES_CONSTEXPR (NF == 16) {
|
||||
VF out0, out1, out2, out3, out4, out5, out6, out7;
|
||||
VF out8, out9, out10, out11, out12, out13, out14, out15;
|
||||
out0 = hn::Load(df, out + i + out_offsets[0]);
|
||||
out1 = hn::Load(df, out + i + out_offsets[1]);
|
||||
out2 = hn::Load(df, out + i + out_offsets[2]);
|
||||
out3 = hn::Load(df, out + i + out_offsets[3]);
|
||||
out4 = hn::Load(df, out + i + out_offsets[4]);
|
||||
out5 = hn::Load(df, out + i + out_offsets[5]);
|
||||
out6 = hn::Load(df, out + i + out_offsets[6]);
|
||||
out7 = hn::Load(df, out + i + out_offsets[7]);
|
||||
out8 = hn::Load(df, out + i + out_offsets[8]);
|
||||
out9 = hn::Load(df, out + i + out_offsets[9]);
|
||||
out10 = hn::Load(df, out + i + out_offsets[10]);
|
||||
out11 = hn::Load(df, out + i + out_offsets[11]);
|
||||
out12 = hn::Load(df, out + i + out_offsets[12]);
|
||||
out13 = hn::Load(df, out + i + out_offsets[13]);
|
||||
out14 = hn::Load(df, out + i + out_offsets[14]);
|
||||
out15 = hn::Load(df, out + i + out_offsets[15]);
|
||||
Mul16(df, scale, out0, out1, out2, out3, out4, out5, out6, out7, out8,
|
||||
out9, out10, out11, out12, out13, out14, out15);
|
||||
VF x0 = hn::Load(df, v.Row(pos[0]) + i);
|
||||
MulAdd16(df, x0, c0, out0, out1, out2, out3, out4, out5, out6, out7, out8,
|
||||
out9, out10, out11, out12, out13, out14, out15);
|
||||
VF x1 = hn::Load(df, v.Row(pos[1]) + i);
|
||||
MulAdd16(df, x1, c1, out0, out1, out2, out3, out4, out5, out6, out7, out8,
|
||||
out9, out10, out11, out12, out13, out14, out15);
|
||||
VF x2 = hn::Load(df, v.Row(pos[2]) + i);
|
||||
MulAdd16(df, x2, c2, out0, out1, out2, out3, out4, out5, out6, out7, out8,
|
||||
out9, out10, out11, out12, out13, out14, out15);
|
||||
VF x3 = hn::Load(df, v.Row(pos[3]) + i);
|
||||
MulAdd16(df, x3, c3, out0, out1, out2, out3, out4, out5, out6, out7, out8,
|
||||
out9, out10, out11, out12, out13, out14, out15);
|
||||
VF x4 = hn::Load(df, v.Row(pos[4]) + i);
|
||||
MulAdd16(df, x4, c4, out0, out1, out2, out3, out4, out5, out6, out7, out8,
|
||||
out9, out10, out11, out12, out13, out14, out15);
|
||||
VF x5 = hn::Load(df, v.Row(pos[5]) + i);
|
||||
MulAdd16(df, x5, c5, out0, out1, out2, out3, out4, out5, out6, out7, out8,
|
||||
out9, out10, out11, out12, out13, out14, out15);
|
||||
VF x6 = hn::Load(df, v.Row(pos[6]) + i);
|
||||
MulAdd16(df, x6, c6, out0, out1, out2, out3, out4, out5, out6, out7, out8,
|
||||
out9, out10, out11, out12, out13, out14, out15);
|
||||
VF x7 = hn::Load(df, v.Row(pos[7]) + i);
|
||||
MulAdd16(df, x7, c7, out0, out1, out2, out3, out4, out5, out6, out7, out8,
|
||||
out9, out10, out11, out12, out13, out14, out15);
|
||||
hn::Store(out0, df, out + i + out_offsets[0]);
|
||||
hn::Store(out1, df, out + i + out_offsets[1]);
|
||||
hn::Store(out2, df, out + i + out_offsets[2]);
|
||||
hn::Store(out3, df, out + i + out_offsets[3]);
|
||||
hn::Store(out4, df, out + i + out_offsets[4]);
|
||||
hn::Store(out5, df, out + i + out_offsets[5]);
|
||||
hn::Store(out6, df, out + i + out_offsets[6]);
|
||||
hn::Store(out7, df, out + i + out_offsets[7]);
|
||||
hn::Store(out8, df, out + i + out_offsets[8]);
|
||||
hn::Store(out9, df, out + i + out_offsets[9]);
|
||||
hn::Store(out10, df, out + i + out_offsets[10]);
|
||||
hn::Store(out11, df, out + i + out_offsets[11]);
|
||||
hn::Store(out12, df, out + i + out_offsets[12]);
|
||||
hn::Store(out13, df, out + i + out_offsets[13]);
|
||||
hn::Store(out14, df, out + i + out_offsets[14]);
|
||||
hn::Store(out15, df, out + i + out_offsets[15]);
|
||||
}
|
||||
if HWY_LANES_CONSTEXPR (NF == 8) {
|
||||
VF out0, out1, out2, out3, out4, out5, out6, out7;
|
||||
out0 = hn::Load(df, out + i + out_offsets[0]);
|
||||
out1 = hn::Load(df, out + i + out_offsets[1]);
|
||||
out2 = hn::Load(df, out + i + out_offsets[2]);
|
||||
out3 = hn::Load(df, out + i + out_offsets[3]);
|
||||
out4 = hn::Load(df, out + i + out_offsets[4]);
|
||||
out5 = hn::Load(df, out + i + out_offsets[5]);
|
||||
out6 = hn::Load(df, out + i + out_offsets[6]);
|
||||
out7 = hn::Load(df, out + i + out_offsets[7]);
|
||||
Mul8(df, scale, out0, out1, out2, out3, out4, out5, out6, out7);
|
||||
VF x0 = hn::Load(df, v.Row(pos[0]) + i);
|
||||
MulAdd8(df, x0, c0, out0, out1, out2, out3, out4, out5, out6, out7);
|
||||
VF x1 = hn::Load(df, v.Row(pos[1]) + i);
|
||||
MulAdd8(df, x1, c1, out0, out1, out2, out3, out4, out5, out6, out7);
|
||||
VF x2 = hn::Load(df, v.Row(pos[2]) + i);
|
||||
MulAdd8(df, x2, c2, out0, out1, out2, out3, out4, out5, out6, out7);
|
||||
VF x3 = hn::Load(df, v.Row(pos[3]) + i);
|
||||
MulAdd8(df, x3, c3, out0, out1, out2, out3, out4, out5, out6, out7);
|
||||
VF x4 = hn::Load(df, v.Row(pos[4]) + i);
|
||||
MulAdd8(df, x4, c4, out0, out1, out2, out3, out4, out5, out6, out7);
|
||||
VF x5 = hn::Load(df, v.Row(pos[5]) + i);
|
||||
MulAdd8(df, x5, c5, out0, out1, out2, out3, out4, out5, out6, out7);
|
||||
VF x6 = hn::Load(df, v.Row(pos[6]) + i);
|
||||
MulAdd8(df, x6, c6, out0, out1, out2, out3, out4, out5, out6, out7);
|
||||
VF x7 = hn::Load(df, v.Row(pos[7]) + i);
|
||||
MulAdd8(df, x7, c7, out0, out1, out2, out3, out4, out5, out6, out7);
|
||||
hn::Store(out0, df, out + i + out_offsets[0]);
|
||||
hn::Store(out1, df, out + i + out_offsets[1]);
|
||||
hn::Store(out2, df, out + i + out_offsets[2]);
|
||||
hn::Store(out3, df, out + i + out_offsets[3]);
|
||||
hn::Store(out4, df, out + i + out_offsets[4]);
|
||||
hn::Store(out5, df, out + i + out_offsets[5]);
|
||||
hn::Store(out6, df, out + i + out_offsets[6]);
|
||||
hn::Store(out7, df, out + i + out_offsets[7]);
|
||||
}
|
||||
if HWY_LANES_CONSTEXPR (NF == 4) {
|
||||
VF out0, out1, out2, out3;
|
||||
out0 = hn::Load(df, out + i + out_offsets[0]);
|
||||
out1 = hn::Load(df, out + i + out_offsets[1]);
|
||||
out2 = hn::Load(df, out + i + out_offsets[2]);
|
||||
out3 = hn::Load(df, out + i + out_offsets[3]);
|
||||
out0 = hn::Mul(out0, hn::BroadcastLane<0>(scale));
|
||||
out1 = hn::Mul(out1, hn::BroadcastLane<1>(scale));
|
||||
out2 = hn::Mul(out2, hn::BroadcastLane<2>(scale));
|
||||
out3 = hn::Mul(out3, hn::BroadcastLane<3>(scale));
|
||||
VF x0 = hn::Load(df, v.Row(pos[0]) + i);
|
||||
MulAdd4(df, x0, c0, out0, out1, out2, out3);
|
||||
VF x1 = hn::Load(df, v.Row(pos[1]) + i);
|
||||
MulAdd4(df, x1, c1, out0, out1, out2, out3);
|
||||
VF x2 = hn::Load(df, v.Row(pos[2]) + i);
|
||||
MulAdd4(df, x2, c2, out0, out1, out2, out3);
|
||||
VF x3 = hn::Load(df, v.Row(pos[3]) + i);
|
||||
MulAdd4(df, x3, c3, out0, out1, out2, out3);
|
||||
VF x4 = hn::Load(df, v.Row(pos[4]) + i);
|
||||
MulAdd4(df, x4, c4, out0, out1, out2, out3);
|
||||
VF x5 = hn::Load(df, v.Row(pos[5]) + i);
|
||||
MulAdd4(df, x5, c5, out0, out1, out2, out3);
|
||||
VF x6 = hn::Load(df, v.Row(pos[6]) + i);
|
||||
MulAdd4(df, x6, c6, out0, out1, out2, out3);
|
||||
VF x7 = hn::Load(df, v.Row(pos[7]) + i);
|
||||
MulAdd4(df, x7, c7, out0, out1, out2, out3);
|
||||
hn::Store(out0, df, out + i + out_offsets[0]);
|
||||
hn::Store(out1, df, out + i + out_offsets[1]);
|
||||
hn::Store(out2, df, out + i + out_offsets[2]);
|
||||
hn::Store(out3, df, out + i + out_offsets[3]);
|
||||
}
|
||||
i += NF;
|
||||
}
|
||||
HWY_DASSERT(size == i);
|
||||
}
|
||||
|
||||
template <class DF, class VF = hn::Vec<DF>>
|
||||
HWY_INLINE HWY_MAYBE_UNUSED void MulAdd4(DF df, const VF common, const VF c0,
|
||||
const VF c1, const VF c2, const VF c3,
|
||||
|
|
@ -887,240 +626,134 @@ HWY_INLINE HWY_MAYBE_UNUSED void MulAdd4(DF df, const VF common, const VF c0,
|
|||
}
|
||||
|
||||
template <class DF, class VF = hn::Vec<DF>>
|
||||
HWY_INLINE HWY_MAYBE_UNUSED void MulAdd4Lanes(DF df, const MatPtrT<float>& v,
|
||||
const size_t* HWY_RESTRICT pos,
|
||||
const size_t offset, const VF c0,
|
||||
const VF c1, const VF c2,
|
||||
const VF c3, VF& sum0, VF& sum1,
|
||||
VF& sum2, VF& sum3) {
|
||||
// TODO(rays): Check whether a transpose of c0-c3 is applicable and faster.
|
||||
VF x0 = hn::Load(df, v.Row(pos[0]) + offset);
|
||||
MulAdd4(df, x0, hn::BroadcastLane<0>(c0), hn::BroadcastLane<0>(c1),
|
||||
hn::BroadcastLane<0>(c2), hn::BroadcastLane<0>(c3), sum0, sum1, sum2,
|
||||
sum3);
|
||||
VF x1 = hn::Load(df, v.Row(pos[1]) + offset);
|
||||
MulAdd4(df, x1, hn::BroadcastLane<1>(c0), hn::BroadcastLane<1>(c1),
|
||||
hn::BroadcastLane<1>(c2), hn::BroadcastLane<1>(c3), sum0, sum1, sum2,
|
||||
sum3);
|
||||
VF x2 = hn::Load(df, v.Row(pos[2]) + offset);
|
||||
MulAdd4(df, x2, hn::BroadcastLane<2>(c0), hn::BroadcastLane<2>(c1),
|
||||
hn::BroadcastLane<2>(c2), hn::BroadcastLane<2>(c3), sum0, sum1, sum2,
|
||||
sum3);
|
||||
VF x3 = hn::Load(df, v.Row(pos[3]) + offset);
|
||||
MulAdd4(df, x3, hn::BroadcastLane<3>(c0), hn::BroadcastLane<3>(c1),
|
||||
hn::BroadcastLane<3>(c2), hn::BroadcastLane<3>(c3), sum0, sum1, sum2,
|
||||
sum3);
|
||||
HWY_INLINE HWY_MAYBE_UNUSED void MulAddNLanesVT4(
|
||||
DF df, const BF16* HWY_RESTRICT v, const float* HWY_RESTRICT c,
|
||||
const size_t num_lanes, VF& sum0a, VF& sum1a, VF& sum2a, VF& sum3a,
|
||||
VF& sum0b, VF& sum1b, VF& sum2b, VF& sum3b) {
|
||||
using DBF = hn::ScalableTag<BF16>;
|
||||
const DBF dbf;
|
||||
using VBF = hn::Vec<DBF>;
|
||||
const size_t kNF = hn::Lanes(df);
|
||||
for (size_t lane = 0; lane < num_lanes; ++lane, v += 2 * kNF) {
|
||||
VBF v0 = hn::Load(dbf, v);
|
||||
VF c0 = hn::Set(df, *c++);
|
||||
VF c1 = hn::Set(df, *c++);
|
||||
VF c2 = hn::Set(df, *c++);
|
||||
VF c3 = hn::Set(df, *c++);
|
||||
VF v0a = hn::PromoteLowerTo(df, v0);
|
||||
VF v0b = hn::PromoteUpperTo(df, v0);
|
||||
MulAdd4(df, v0a, c0, c1, c2, c3, sum0a, sum1a, sum2a, sum3a);
|
||||
MulAdd4(df, v0b, c0, c1, c2, c3, sum0b, sum1b, sum2b, sum3b);
|
||||
}
|
||||
}
|
||||
|
||||
template <class DF, class VF = hn::Vec<DF>, HWY_IF_V_SIZE_GT_D(DF, 31)>
|
||||
HWY_INLINE HWY_MAYBE_UNUSED void MulAddSecond4Lanes(
|
||||
DF df, const MatPtrT<float>& v, const size_t* HWY_RESTRICT pos,
|
||||
const size_t offset, const VF c0, const VF c1, const VF c2, const VF c3,
|
||||
VF& sum0, VF& sum1, VF& sum2, VF& sum3) {
|
||||
VF x4 = hn::Load(df, v.Row(pos[4]) + offset);
|
||||
MulAdd4(df, x4, hn::BroadcastLane<4>(c0), hn::BroadcastLane<4>(c1),
|
||||
hn::BroadcastLane<4>(c2), hn::BroadcastLane<4>(c3), sum0, sum1, sum2,
|
||||
sum3);
|
||||
VF x5 = hn::Load(df, v.Row(pos[5]) + offset);
|
||||
MulAdd4(df, x5, hn::BroadcastLane<5>(c0), hn::BroadcastLane<5>(c1),
|
||||
hn::BroadcastLane<5>(c2), hn::BroadcastLane<5>(c3), sum0, sum1, sum2,
|
||||
sum3);
|
||||
VF x6 = hn::Load(df, v.Row(pos[6]) + offset);
|
||||
MulAdd4(df, x6, hn::BroadcastLane<6>(c0), hn::BroadcastLane<6>(c1),
|
||||
hn::BroadcastLane<6>(c2), hn::BroadcastLane<6>(c3), sum0, sum1, sum2,
|
||||
sum3);
|
||||
VF x7 = hn::Load(df, v.Row(pos[7]) + offset);
|
||||
MulAdd4(df, x7, hn::BroadcastLane<7>(c0), hn::BroadcastLane<7>(c1),
|
||||
hn::BroadcastLane<7>(c2), hn::BroadcastLane<7>(c3), sum0, sum1, sum2,
|
||||
sum3);
|
||||
}
|
||||
|
||||
template <class DF, class VF = hn::Vec<DF>, HWY_IF_V_SIZE_LE_D(DF, 31)>
|
||||
HWY_INLINE HWY_MAYBE_UNUSED void MulAddSecond4Lanes(
|
||||
DF df, const MatPtrT<float>& v, const size_t* HWY_RESTRICT pos,
|
||||
const size_t offset, const VF c0, const VF c1, const VF c2, const VF c3,
|
||||
VF& sum0, VF& sum1, VF& sum2, VF& sum3) {}
|
||||
|
||||
template <class DF, class VF = hn::Vec<DF>, HWY_IF_V_SIZE_GT_D(DF, 63)>
|
||||
HWY_INLINE HWY_MAYBE_UNUSED void MulAddSecond8Lanes(
|
||||
DF df, const MatPtrT<float>& v, const size_t* HWY_RESTRICT pos,
|
||||
const size_t offset, const VF c0, const VF c1, const VF c2, const VF c3,
|
||||
VF& sum0, VF& sum1, VF& sum2, VF& sum3) {
|
||||
VF x8 = hn::Load(df, v.Row(pos[8]) + offset);
|
||||
MulAdd4(df, x8, hn::BroadcastLane<8>(c0), hn::BroadcastLane<8>(c1),
|
||||
hn::BroadcastLane<8>(c2), hn::BroadcastLane<8>(c3), sum0, sum1, sum2,
|
||||
sum3);
|
||||
VF x9 = hn::Load(df, v.Row(pos[9]) + offset);
|
||||
MulAdd4(df, x9, hn::BroadcastLane<9>(c0), hn::BroadcastLane<9>(c1),
|
||||
hn::BroadcastLane<9>(c2), hn::BroadcastLane<9>(c3), sum0, sum1, sum2,
|
||||
sum3);
|
||||
VF x10 = hn::Load(df, v.Row(pos[10]) + offset);
|
||||
MulAdd4(df, x10, hn::BroadcastLane<10>(c0), hn::BroadcastLane<10>(c1),
|
||||
hn::BroadcastLane<10>(c2), hn::BroadcastLane<10>(c3), sum0, sum1,
|
||||
sum2, sum3);
|
||||
VF x11 = hn::Load(df, v.Row(pos[11]) + offset);
|
||||
MulAdd4(df, x11, hn::BroadcastLane<11>(c0), hn::BroadcastLane<11>(c1),
|
||||
hn::BroadcastLane<11>(c2), hn::BroadcastLane<11>(c3), sum0, sum1,
|
||||
sum2, sum3);
|
||||
VF x12 = hn::Load(df, v.Row(pos[12]) + offset);
|
||||
MulAdd4(df, x12, hn::BroadcastLane<12>(c0), hn::BroadcastLane<12>(c1),
|
||||
hn::BroadcastLane<12>(c2), hn::BroadcastLane<12>(c3), sum0, sum1,
|
||||
sum2, sum3);
|
||||
VF x13 = hn::Load(df, v.Row(pos[13]) + offset);
|
||||
MulAdd4(df, x13, hn::BroadcastLane<13>(c0), hn::BroadcastLane<13>(c1),
|
||||
hn::BroadcastLane<13>(c2), hn::BroadcastLane<13>(c3), sum0, sum1,
|
||||
sum2, sum3);
|
||||
VF x14 = hn::Load(df, v.Row(pos[14]) + offset);
|
||||
MulAdd4(df, x14, hn::BroadcastLane<14>(c0), hn::BroadcastLane<14>(c1),
|
||||
hn::BroadcastLane<14>(c2), hn::BroadcastLane<14>(c3), sum0, sum1,
|
||||
sum2, sum3);
|
||||
VF x15 = hn::Load(df, v.Row(pos[15]) + offset);
|
||||
MulAdd4(df, x15, hn::BroadcastLane<15>(c0), hn::BroadcastLane<15>(c1),
|
||||
hn::BroadcastLane<15>(c2), hn::BroadcastLane<15>(c3), sum0, sum1,
|
||||
sum2, sum3);
|
||||
}
|
||||
|
||||
template <class DF, class VF = hn::Vec<DF>, HWY_IF_V_SIZE_LE_D(DF, 63)>
|
||||
HWY_INLINE HWY_MAYBE_UNUSED void MulAddSecond8Lanes(
|
||||
DF df, const MatPtrT<float>& v, const size_t* HWY_RESTRICT pos,
|
||||
const size_t offset, const VF c0, const VF c1, const VF c2, const VF c3,
|
||||
VF& sum0, VF& sum1, VF& sum2, VF& sum3) {}

// For a 2NFx4 tile of float values in 8xNF-lane registers, multiplies 2NF rows
// of V by the corresponding values in c00-c31 and adds them to 2NF rows of
// out, after first prescaling out by scales.
// The depth (size) must be a multiple of 2NF.
template <class DF, class VF = hn::Vec<DF>>
HWY_INLINE HWY_MAYBE_UNUSED void MulByConstAndAddVT4Mem(
    DF df, const float* HWY_RESTRICT scales, const VF c00, const VF c01,
    const VF c10, const VF c11, const VF c20, const VF c21, const VF c30,
    const VF c31, const MatPtrT<BF16>& v, const size_t* HWY_RESTRICT pos,
    size_t num_lanes, float* HWY_RESTRICT out,
    const uint32_t* HWY_RESTRICT out_offsets, const size_t size) {
  namespace hn = hwy::HWY_NAMESPACE;
  HWY_LANES_CONSTEXPR size_t NF = hn::Lanes(df);
  constexpr size_t kMaxNF = hn::MaxLanes(df);
  const BF16* HWY_RESTRICT v_bf = v.Row(pos[0] / (2 * NF));
  HWY_DASSERT(pos[0] % (2 * NF) == 0);
  HWY_ALIGN float c_mem[8 * kMaxNF];
  hn::StoreInterleaved4(c00, c10, c20, c30, df, c_mem);
  hn::StoreInterleaved4(c01, c11, c21, c31, df, c_mem + 4 * NF);

  size_t i = 0;
  while (i + NF * 2 <= size) {
    VF out0a, out1a, out2a, out3a, out0b, out1b, out2b, out3b;
    out0a = hn::Load(df, out + i + out_offsets[0]);
    out1a = hn::Load(df, out + i + out_offsets[1]);
    out2a = hn::Load(df, out + i + out_offsets[2]);
    out3a = hn::Load(df, out + i + out_offsets[3]);
    VF scale0 = hn::Set(df, scales[0]);
    VF scale1 = hn::Set(df, scales[1]);
    VF scale2 = hn::Set(df, scales[2]);
    VF scale3 = hn::Set(df, scales[3]);
    out0a = hn::Mul(out0a, scale0);
    out1a = hn::Mul(out1a, scale1);
    out2a = hn::Mul(out2a, scale2);
    out3a = hn::Mul(out3a, scale3);
    out0b = hn::Load(df, out + i + NF + out_offsets[0]);
    out1b = hn::Load(df, out + i + NF + out_offsets[1]);
    out2b = hn::Load(df, out + i + NF + out_offsets[2]);
    out3b = hn::Load(df, out + i + NF + out_offsets[3]);
    out0b = hn::Mul(out0b, scale0);
    out1b = hn::Mul(out1b, scale1);
    out2b = hn::Mul(out2b, scale2);
    out3b = hn::Mul(out3b, scale3);
    MulAddNLanesVT4(df, v_bf, c_mem, HWY_MIN(num_lanes, 2 * NF), out0a, out1a,
                    out2a, out3a, out0b, out1b, out2b, out3b);
    hn::Store(out0a, df, out + i + out_offsets[0]);
    hn::Store(out1a, df, out + i + out_offsets[1]);
    hn::Store(out2a, df, out + i + out_offsets[2]);
    hn::Store(out3a, df, out + i + out_offsets[3]);
    hn::Store(out0b, df, out + i + NF + out_offsets[0]);
    hn::Store(out1b, df, out + i + NF + out_offsets[1]);
    hn::Store(out2b, df, out + i + NF + out_offsets[2]);
    hn::Store(out3b, df, out + i + NF + out_offsets[3]);
    i += NF * 2;
    v_bf += 4 * NF * NF;
  }
  HWY_DASSERT(size == i);
}
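
For reference, setting aside the tiled BF16 layout, the arithmetic above amounts to the following scalar sketch (illustrative only; it assumes V is available as plain row-major float rows, and the coefficient indexing c[r][k] is a simplification of the interleaved c_mem). In flash-attention terms, scales would be the running-softmax rescaling factors and c the per-position probability weights.

#include <cstddef>
#include <cstdint>

// Scalar model of MulByConstAndAddVT4Mem: each of the four output rows is
// prescaled by scales[r], then num_lanes rows of V are accumulated into it,
// each weighted by the per-row coefficient c[r * num_lanes + k].
void MulByConstAndAddVT4Ref(const float* scales, const float* c,
                            const float* const* v_rows, size_t num_lanes,
                            float* out, const uint32_t* out_offsets,
                            size_t size) {
  for (size_t r = 0; r < 4; ++r) {
    float* out_row = out + out_offsets[r];
    for (size_t i = 0; i < size; ++i) out_row[i] *= scales[r];
    for (size_t k = 0; k < num_lanes; ++k) {
      for (size_t i = 0; i < size; ++i) {
        out_row[i] += c[r * num_lanes + k] * v_rows[k][i];
      }
    }
  }
}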

// Multiplies num_lanes rows of transposed BF16 V (each 2*NF values wide) by
// the corresponding coefficients in c, accumulating into the two float
// accumulators sum0a (lower NF lanes) and sum0b (upper NF lanes).
template <class DF, class VF = hn::Vec<DF>>
HWY_INLINE HWY_MAYBE_UNUSED void MulAddNLanesVT1(DF df,
                                                 const BF16* HWY_RESTRICT v,
                                                 const float* HWY_RESTRICT c,
                                                 const size_t num_lanes,
                                                 VF& sum0a, VF& sum0b) {
  using DBF = hn::ScalableTag<BF16>;
  const DBF dbf;
  using VBF = hn::Vec<DBF>;
  const size_t kNF = hn::Lanes(df);
  for (size_t lane = 0; lane < num_lanes; ++lane, v += 2 * kNF) {
    VBF v0 = hn::Load(dbf, v);
    VF c0 = hn::Set(df, *c++);
    VF v0a = hn::PromoteLowerTo(df, v0);
    VF v0b = hn::PromoteUpperTo(df, v0);
    sum0a = hn::MulAdd(v0a, c0, sum0a);
    sum0b = hn::MulAdd(v0b, c0, sum0b);
  }
}

template <class DF, class VF = hn::Vec<DF>>
HWY_INLINE HWY_MAYBE_UNUSED void MulByConstAndAddVT1Mem(
    DF df, const float* HWY_RESTRICT scales, const VF c00, const VF c01,
    const MatPtrT<BF16>& v, const size_t* HWY_RESTRICT pos, size_t num_lanes,
    float* HWY_RESTRICT out, const uint32_t* HWY_RESTRICT out_offsets,
    const size_t size) {
  namespace hn = hwy::HWY_NAMESPACE;
  HWY_LANES_CONSTEXPR size_t NF = hn::Lanes(df);
  constexpr size_t kMaxNF = hn::MaxLanes(df);
  const BF16* HWY_RESTRICT v_bf = v.Row(pos[0] / (2 * NF));
  HWY_DASSERT(pos[0] % (2 * NF) == 0);
  HWY_ALIGN float c_mem[2 * kMaxNF];
  hn::Store(c00, df, c_mem);
  hn::Store(c01, df, c_mem + NF);

  size_t i = 0;
  while (i + NF * 2 <= size) {
    VF out0a, out0b;
    out0a = hn::Load(df, out + i + out_offsets[0]);
    VF scale0 = hn::Set(df, scales[0]);
    out0a = hn::Mul(out0a, scale0);
    out0b = hn::Load(df, out + i + NF + out_offsets[0]);
    out0b = hn::Mul(out0b, scale0);
    MulAddNLanesVT1(df, v_bf, c_mem, HWY_MIN(num_lanes, 2 * NF), out0a, out0b);
    hn::Store(out0a, df, out + i + out_offsets[0]);
    hn::Store(out0b, df, out + i + NF + out_offsets[0]);
    i += NF * 2;
    v_bf += 4 * NF * NF;
  }
  HWY_DASSERT(size == i);
}
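
The pointer math above (v.Row(pos[0] / (2 * NF)), then v_bf += 4 * NF * NF per depth chunk, with each lane 2*NF values wide) suggests the transposed V cache is stored in depth-major tiles of 2*NF x 2*NF values. A sketch of the implied addressing, purely as a guess for illustration:

#include <cstddef>

// Hypothetical flat index of element (position k, depth d) in the transposed
// BF16 V, inferred from the pointer arithmetic above: depth is tiled in
// chunks of 2*NF, and within a chunk each of the 2*NF positions owns 2*NF
// consecutive values.
size_t TransposedVIndex(size_t k, size_t d, size_t NF) {
  const size_t chunk = d / (2 * NF);   // which depth tile
  return chunk * (2 * NF) * (2 * NF)   // full tiles before this one
         + (k % (2 * NF)) * (2 * NF)   // this position's row in the tile
         + d % (2 * NF);               // offset within the row
}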

util/mat.h (11 changes)

@ -202,6 +202,17 @@ class MatPtr : public IFields {
    override_rows_ = static_cast<uint32_t>(rows);
  }

  // Changes the number of rows and columns without reallocating the memory.
  // Increases cols by factor and reduces rows by factor.
  // The rows must be divisible by factor and the matrix must be packed.
  void ReshapePackedRowsToCols(size_t factor) {
    HWY_ASSERT(IsPacked());
    HWY_ASSERT(private_rows_ % factor == 0);
    private_rows_ /= factor;
    cols_ *= factor;
    stride_ *= factor;
  }

  // Offset by which to advance pointers to the next row.
  size_t Stride() const { return stride_; }
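
A quick illustration of the new method (matrix sizes invented for the example): because the matrix is packed (stride == cols), fusing rows into wider rows is purely a metadata change.

// Illustration only: reinterpret a packed 8x4 view as 4x8.
void ReshapeExample(MatPtrT<float>& m) {  // precondition: 8x4, stride == 4
  m.ReshapePackedRowsToCols(2);
  // m is now 4x8 with stride 8; no bytes moved. The new m.Row(1) begins
  // where the old m.Row(2) did, since each new row fuses two old rows.
}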

@ -106,7 +106,8 @@ template <typename T>
void FillMatPtrT(MatPtrT<T>& mat) {
  for (int i = 0; i < mat.Rows(); ++i) {
    for (int j = 0; j < mat.Cols(); ++j) {
      mat.Row(i)[j] =
          hwy::ConvertScalarTo<T>(hwy::Unpredictable1() * 0.01f * (i + j + 1));
    }
  }
}
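
This change to FillMatPtrT is presumably what lets the generic test helper fill BF16 matrices: hwy::ConvertScalarTo performs the float-to-T narrowing explicitly, which T = BF16 requires. A small sketch of the distinction (illustrative, assuming hwy::bfloat16_t provides no implicit conversion from float):

#include "hwy/base.h"

void ConvertExample() {
  // Explicit narrowing: works for BF16 and is a no-op for float.
  hwy::bfloat16_t b = hwy::ConvertScalarTo<hwy::bfloat16_t>(0.01f);
  float f = hwy::ConvertScalarTo<float>(0.01f);
  (void)b;
  (void)f;
  // By contrast, `hwy::bfloat16_t b2 = 0.01f;` would rely on an implicit
  // float conversion that the type is assumed not to provide.
}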

@ -17,14 +17,14 @@ const char* ZoneName(Zones zone) {
      return "FlashAttention.Inclusive";
    case Zones::kFlashAttentionRmsNormAndPositionalEncoding:
      return "FlashAttention.RMSNormAndPositionalEncoding";
    case Zones::kFlashAttentionSingleFlashAttention:
      return "FlashAttention.SingleFlashAttention";
    case Zones::kFlashAttentionTileFlashAttention:
      return "FlashAttention.TileFlashAttention";
    case Zones::kFlashAttentionTileFlashAttention1:
      return "FlashAttention.TileFlashAttention1";
    case Zones::kFlashAttentionTileFlashAttention4:
      return "FlashAttention.TileFlashAttention4";
    case Zones::kFlashAttentionTransposeQ:
      return "FlashAttention.TransposeQ";
    case Zones::kFlashAttentionTileFlashAttention8:
      return "FlashAttention.TileFlashAttention8";
    case Zones::kFlashAttentionCombineSplit:
      return "FlashAttention.CombineSplit";
    case Zones::kGenActivation:
      return "Gen.Activation";
    case Zones::kGenActivationFused:

@ -14,10 +14,10 @@ enum class Zones {  // Keep sorted
  kFlashAttentionFlashAttention,
  kFlashAttentionInclusive,
  kFlashAttentionRmsNormAndPositionalEncoding,
  kFlashAttentionSingleFlashAttention,
  kFlashAttentionTileFlashAttention,
  kFlashAttentionTileFlashAttention1,
  kFlashAttentionTileFlashAttention4,
  kFlashAttentionTransposeQ,
  kFlashAttentionTileFlashAttention8,
  kFlashAttentionCombineSplit,
  kGenActivation,
  kGenActivationFused,
  kGenAttention,