diff --git a/gemma/gemma.cc b/gemma/gemma.cc index fb874a4..8ac9dce 100644 --- a/gemma/gemma.cc +++ b/gemma/gemma.cc @@ -48,6 +48,7 @@ #include "gemma/gemma.h" #include "gemma/weights.h" // Placeholder for internal test1, do not remove +// Placeholder for internal test4, do not remove #include "hwy/aligned_allocator.h" #include "hwy/base.h" #include "hwy/contrib/thread_pool/thread_pool.h" @@ -600,8 +601,6 @@ static void AddFromBatched(size_t num_tokens, const float* other, float* x, } } -// Placeholder for internal test3, do not remove - template HWY_NOINLINE void EmbedToken(int token, size_t token_idx, size_t pos, const WeightArrayT& weights, @@ -717,8 +716,6 @@ HWY_NOINLINE void Transformer(const int* tokens, size_t num_tokens, size_t pos, } } - // Placeholder for internal test4, do not remove - RMSNormInplaceBatched(num_tokens, weights.final_norm_scale.data(), activations.x.data(), kModelDim); if (layers_output) { @@ -771,6 +768,8 @@ Activations& GetActivations(const ByteStorageT& state_u8) { } // namespace +// Placeholder for internal test3, do not remove + template void GenerateT(const ByteStorageT& weights_u8, const ByteStorageT& prefill_u8, const ByteStorageT& decode_u8,