diff --git a/compression/BUILD.bazel b/compression/BUILD.bazel
index f3767bf..c826d43 100644
--- a/compression/BUILD.bazel
+++ b/compression/BUILD.bazel
@@ -89,6 +89,7 @@ cc_library(
     textual_hdrs = ["nuq-inl.h"],
     deps = [
         ":sfp",
+        "//:allocator",
         "@hwy//:hwy",
         "@hwy//hwy/contrib/sort:vqsort",
     ],
diff --git a/compression/nuq-inl.h b/compression/nuq-inl.h
index f8fa467..faa5ba7 100644
--- a/compression/nuq-inl.h
+++ b/compression/nuq-inl.h
@@ -22,12 +22,9 @@
 #include

 #include "compression/shared.h"
+#include "util/allocator.h"
 #include "hwy/base.h"

-#if HWY_IS_MSAN
-#include <sanitizer/msan_interface.h>
-#endif
-
 #endif  // THIRD_PARTY_GEMMA_CPP_COMPRESSION_NUQ_INL_H_

 // Actual per-target include guard.
@@ -49,15 +46,6 @@ namespace gcpp {
 namespace HWY_NAMESPACE {
 namespace hn = hwy::HWY_NAMESPACE;

-static inline void MaybeCheckInitialized(const void* ptr, size_t size) {
-#if HWY_IS_MSAN
-  __msan_check_mem_is_initialized(ptr, size);
-#else
-  (void)ptr;
-  (void)size;
-#endif
-}
-
 // For internal use by NuqCodec.
 class NuqClustering {
   static constexpr size_t kGroupSize = NuqStream::kGroupSize;
@@ -756,7 +744,6 @@ class NuqCodec {
     const hn::Half> d8q;
     using V8Q = hn::Vec;
     using V16 = hn::Vec;
-    using VF = hn::Vec;

     const size_t within_group = packed_ofs % kGroupSize;
     HWY_DASSERT(within_group % (2 * hn::Lanes(df)) == 0);
diff --git a/gemma/gemma-inl.h b/gemma/gemma-inl.h
index 6a9573d..d2d9cd0 100644
--- a/gemma/gemma-inl.h
+++ b/gemma/gemma-inl.h
@@ -909,7 +909,8 @@ HWY_NOINLINE void Prefill(
     // the last token, too. However, we need to rewind this for the generation
     // of the first token. So we need to keep track of this.
     // TODO: consider implementing masking instead of this logic?
-    bool attend_to_last_token = (prefill_this_query < prefix_end_this_query);
+    const bool attend_to_last_token =
+        (prefill_this_query < prefix_end_this_query);
     if (attend_to_last_token) {
       // The difference can be at most 1.
       prefill_this_query += 1;
@@ -1219,8 +1220,21 @@ void GenerateT(const ByteStorageT& weights_u8, Activations& activations,
   // Copy so we can increment without requiring users to pass in a mutable span.
   std::vector<size_t> queries_pos_copy(queries_pos_in.cbegin(),
                                        queries_pos_in.cend());
-  const QueriesMutablePos queries_mutable_pos(queries_pos_copy.data(),
-                                              queries_pos_copy.size());
+  QueriesMutablePos queries_mutable_pos(queries_pos_copy.data(),
+                                        queries_pos_copy.size());
+  // For the first turn, qpos remains 0. Otherwise, rewind the previous EOS.
+  // Background: for multiturn, Gemma 2 expects only <end_of_turn>, not EOS. The
+  // previous `Generate` called `StreamToken` for the last token (EOS), hence
+  // our caller's qpos is 1 too high. This must be corrected because we didn't
+  // write to the KV cache at that position, so MSAN would complain.
+  for (size_t& qpos : queries_mutable_pos) {
+    qpos = qpos == 0 ? 0 : qpos - 1;
+  }
+  // Sanity check: prompts should not be empty, nor start with EOS.
+  for (size_t query_idx = 0; query_idx < queries_prompt.size(); ++query_idx) {
+    const PromptTokens& prompt = queries_prompt[query_idx];
+    HWY_ASSERT(prompt.size() != 0 && prompt[0] != runtime_config.eos_id);
+  }

   const size_t num_queries = queries_prompt.size();
   HWY_ASSERT(num_queries <= 4096);  // TokenStreamer uses BitSet4096.
diff --git a/util/allocator.h b/util/allocator.h
index 38ff84e..9e664b5 100644
--- a/util/allocator.h
+++ b/util/allocator.h
@@ -22,8 +22,21 @@
 #include "hwy/aligned_allocator.h"
 #include "hwy/base.h"

+#if HWY_IS_MSAN
+#include <sanitizer/msan_interface.h>
+#endif
+
 namespace gcpp {

+static inline void MaybeCheckInitialized(const void* ptr, size_t size) {
+#if HWY_IS_MSAN
+  __msan_check_mem_is_initialized(ptr, size);
+#else
+  (void)ptr;
+  (void)size;
+#endif
+}
+
 // Shared between gemma.h and ops-inl.h.
 struct TokenAndProb {
   int token;
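
For context, a minimal sketch of how a caller might use the relocated MaybeCheckInitialized helper. This is illustrative only and not part of the patch; the SumRow function and its buffer are hypothetical. Under an MSAN build the helper reports any uninitialized bytes in the given range; in other builds it compiles away to nothing.

#include <cstddef>
#include <vector>

#include "util/allocator.h"  // gcpp::MaybeCheckInitialized

// Hypothetical caller: verify a buffer is fully initialized before reading it.
float SumRow(const std::vector<float>& row) {
  // Aborts with an MSAN report if any byte of `row` was never written;
  // a no-op in non-MSAN builds.
  gcpp::MaybeCheckInitialized(row.data(), row.size() * sizeof(float));
  float sum = 0.0f;
  for (const float v : row) sum += v;
  return sum;
}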
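
Similarly, the effect of the new qpos rewind in GenerateT can be shown in isolation. This is a sketch only; RewindPreviousEos is a hypothetical stand-in for the inline loop over queries_mutable_pos. A first-turn position of 0 stays 0, while any later position is decremented by one to reclaim the slot of the EOS that the previous Generate streamed but never wrote to the KV cache.

#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical stand-in for the loop over queries_mutable_pos in GenerateT.
void RewindPreviousEos(std::vector<size_t>& qpos) {
  for (size_t& p : qpos) {
    p = (p == 0) ? 0 : p - 1;
  }
}

int main() {
  // Query 0 is on its first turn and stays at 0. Query 1 resumes after a
  // previous turn whose streamed EOS left its position one too high (42),
  // so it is rewound to 41.
  std::vector<size_t> qpos = {0, 42};
  RewindPreviousEos(qpos);
  assert(qpos[0] == 0 && qpos[1] == 41);
  return 0;
}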