mirror of https://github.com/google/gemma.cpp.git
123 lines
4.2 KiB
C++
123 lines
4.2 KiB
C++
#include <stddef.h>
|
|
|
|
#include <algorithm>
|
|
#include <cmath>
|
|
#include <iostream>
|
|
#include <memory>
|
|
#include <optional>
|
|
#include <tuple>
|
|
#include <vector>
|
|
|
|
#include "compression/types.h" // GEMMA_DISABLED_TARGETS
|
|
#include "gemma/activations.h"
|
|
#include "gemma/configs.h"
|
|
#include "gemma/gemma.h"
|
|
#include "gemma/gemma_args.h"
|
|
#include "gemma/kv_cache.h"
|
|
#include "gemma/kv_transcoding.h"
|
|
#include "gemma/weights.h"
|
|
#include "util/mat.h"
|
|
#include "util/threading_context.h"
|
|
#ifndef HWY_DISABLED_TARGETS
|
|
#define HWY_DISABLED_TARGETS GEMMA_DISABLED_TARGETS
|
|
#endif // HWY_DISABLED_TARGETS
|
|
|
|
// clang-format off
|
|
#undef HWY_TARGET_INCLUDE
|
|
#define HWY_TARGET_INCLUDE "gemma/tiled_attention_test.cc" // NOLINT
|
|
// clang-format on
|
|
#include "hwy/foreach_target.h" // IWYU pragma: keep
|
|
#include "hwy/highway.h"
|
|
// After highway.h
|
|
#include "gemma/tiled_attention.h"
|
|
#include "util/test_util.h"
|
|
#include "hwy/aligned_allocator.h"
|
|
#include "hwy/tests/test_util-inl.h"
|
|
|
|
HWY_BEFORE_NAMESPACE();
|
|
namespace gcpp {
|
|
namespace HWY_NAMESPACE {
|
|
|
|
using ::testing::FloatNear;
|
|
using ::testing::Pointwise;
|
|
|
|
struct AttentionTestEnv {
|
|
AttentionTestEnv(
|
|
size_t qkv_dim, size_t kv_seq_len, size_t attention_window_size,
|
|
size_t num_kv_heads, size_t num_heads, size_t num_tokens, size_t last_pos,
|
|
float att_cap, size_t layer_idx, size_t layers_total, size_t qbatch_size,
|
|
AttentionImpl attention_impl,
|
|
std::optional<Type> kv_cache_type = {} )
|
|
: ctx(threading_args), env(ctx) {
|
|
layer_config.heads = num_heads;
|
|
layer_config.kv_heads = num_kv_heads;
|
|
layer_config.qkv_dim = qkv_dim;
|
|
layer_config.model_dim = qkv_dim * num_heads;
|
|
|
|
model_config.attention_window_sizes = {
|
|
static_cast<uint32_t>(attention_window_size)};
|
|
model_config.att_cap = att_cap;
|
|
model_config.max_seq_len = kv_seq_len;
|
|
model_config.num_layers = layers_total;
|
|
model_config.model_dim = layer_config.model_dim;
|
|
model_config.vocab_size = 1; // not vit
|
|
|
|
for (size_t i = 0; i < model_config.num_layers; ++i) {
|
|
model_config.layer_configs.push_back(layer_config);
|
|
}
|
|
tensor_info_registry = std::make_unique<TensorInfoRegistry>(model_config);
|
|
layer = std::make_unique<LayerWeightsPtrs>(layer_idx, layer_config,
|
|
*tensor_info_registry);
|
|
|
|
runtime_config.attention_impl = attention_impl;
|
|
runtime_config.kv_cache_type = kv_cache_type;
|
|
inference_args.seq_len = kv_seq_len;
|
|
|
|
all_queries.Reserve(qbatch_size);
|
|
kv_caches.reserve(qbatch_size);
|
|
float unpredictable = hwy::Unpredictable1() * 0.01f;
|
|
for (size_t q = 0; q < qbatch_size; ++q) {
|
|
kv_caches.emplace_back(model_config, inference_args, runtime_config,
|
|
ctx.allocator);
|
|
if (kv_caches.back().compact_kv_cache_ptr.HasPtr()) {
|
|
const size_t tile_size = gcpp::KVCache::kTileSize;
|
|
gcpp::DecodedTile decoded(qkv_dim, tile_size);
|
|
for (size_t i = 0; i < kv_caches.back().compact_kv_cache_ptr.Rows();
|
|
++i) {
|
|
for (size_t token = 0; token < tile_size; ++token) {
|
|
for (size_t dim = 0; dim < qkv_dim; ++dim) {
|
|
size_t j_k = dim * tile_size + token;
|
|
decoded.k_elem(token, dim) = unpredictable * (i + j_k + 1);
|
|
|
|
size_t j_v = qkv_dim * tile_size + token * qkv_dim + dim;
|
|
decoded.v_elem(token, dim) = unpredictable * (i + j_v + 1);
|
|
}
|
|
}
|
|
|
|
bool transposed =
|
|
attention_impl == AttentionImpl::kFlashTransposedQsBF16
|
|
|
|
// NOLINTNEXTLINE(google-readability-namespace-comments)
|
|
} // namespace HWY_NAMESPACE
|
|
} // namespace gcpp
|
|
HWY_AFTER_NAMESPACE();
|
|
|
|
#if HWY_ONCE
|
|
|
|
namespace gcpp {
|
|
HWY_BEFORE_TEST(TiledAttentionTest);
|
|
HWY_EXPORT_AND_TEST_P(TiledAttentionTest, TestTransposeStridedQueries);
|
|
// TODO() Fix the goldens for the change in KV_t to BF16
|
|
// HWY_EXPORT_AND_TEST_P(TiledAttentionTest,
|
|
// TestLocalAttentionForAllHeadsTokensAndBatch);
|
|
HWY_EXPORT_AND_TEST_P(TiledAttentionTest, TestAttentionMultipleTokens);
|
|
HWY_EXPORT_AND_TEST_P(TiledAttentionTest, TestAttentionMultipleTokensBF16);
|
|
// HWY_EXPORT_AND_TEST_P(TiledAttentionTest,
|
|
// TestAttentionMultipleTokensAttentionWindowSizeEdgeCase);
|
|
|
|
HWY_AFTER_TEST();
|
|
|
|
} // namespace gcpp
|
|
|
|
#endif
|