Internal experiment

PiperOrigin-RevId: 641915024
This commit is contained in:
Jan Wassenberg 2024-06-10 08:45:30 -07:00 committed by Copybara-Service
parent 95fd7263ae
commit c1c6714ad4
2 changed files with 12 additions and 0 deletions

View File

@ -93,6 +93,10 @@ cc_library(
"gemma/activations.h", "gemma/activations.h",
"gemma/gemma.h", "gemma/gemma.h",
], ],
textual_hdrs = [
# Placeholder for internal file1, do not remove,
# Placeholder for internal file2, do not remove,
],
deps = [ deps = [
":common", ":common",
":ops", ":ops",

View File

@ -49,6 +49,7 @@
#include "gemma/configs.h" #include "gemma/configs.h"
#include "gemma/gemma.h" #include "gemma/gemma.h"
#include "gemma/weights.h" #include "gemma/weights.h"
// Placeholder for internal test1, do not remove
#include "hwy/aligned_allocator.h" #include "hwy/aligned_allocator.h"
#include "hwy/base.h" #include "hwy/base.h"
#include "hwy/contrib/thread_pool/thread_pool.h" #include "hwy/contrib/thread_pool/thread_pool.h"
@ -216,6 +217,8 @@ bool GemmaTokenizer::Decode(const std::vector<int>& ids,
return impl_->Decode(ids, detokenized); return impl_->Decode(ids, detokenized);
} }
// Placeholder for internal test2, do not remove
} // namespace gcpp } // namespace gcpp
#endif // GEMMA_ONCE #endif // GEMMA_ONCE
@ -543,6 +546,8 @@ HWY_NOINLINE void FFW(Activations<TConfig, kBatchSize>& activations,
} }
} }
// Placeholder for internal test3, do not remove
template <size_t kBatchSize, typename WeightArrayT, typename TConfig> template <size_t kBatchSize, typename WeightArrayT, typename TConfig>
HWY_NOINLINE void Prefill(const int* tokens, size_t num_tokens, size_t pos, HWY_NOINLINE void Prefill(const int* tokens, size_t num_tokens, size_t pos,
const WeightArrayT& weights, const WeightArrayT& weights,
@ -677,6 +682,9 @@ HWY_NOINLINE void Transformer(int token, size_t pos,
(*layers_output)(pos, block_name, activations.x.data(), kModelDim); (*layers_output)(pos, block_name, activations.x.data(), kModelDim);
} }
} }
// Placeholder for internal test4, do not remove
RMSNormInplace(weights.final_norm_scale.data(), activations.x.data(), RMSNormInplace(weights.final_norm_scale.data(), activations.x.data(),
kModelDim); kModelDim);
if (layers_output != nullptr) { if (layers_output != nullptr) {