Fix MSAN issue for multiturn. Rewind the prior EOS token.

Also move MaybeCheckInitialized to allocator.h

PiperOrigin-RevId: 683187458
This commit is contained in:
Jan Wassenberg 2024-10-07 08:07:15 -07:00 committed by Copybara-Service
parent 5a71d819cb
commit bd53b0f7c3
4 changed files with 32 additions and 17 deletions

View File

@ -89,6 +89,7 @@ cc_library(
textual_hdrs = ["nuq-inl.h"],
deps = [
":sfp",
"//:allocator",
"@hwy//:hwy",
"@hwy//hwy/contrib/sort:vqsort",
],

View File

@ -22,12 +22,9 @@
#include <stdio.h>
#include "compression/shared.h"
#include "util/allocator.h"
#include "hwy/base.h"
#if HWY_IS_MSAN
#include <sanitizer/msan_interface.h>
#endif
#endif // THIRD_PARTY_GEMMA_CPP_COMPRESSION_NUQ_INL_H_
// Actual per-target include guard.
@ -49,15 +46,6 @@ namespace gcpp {
namespace HWY_NAMESPACE {
namespace hn = hwy::HWY_NAMESPACE;
// Debugging aid: when building with MemorySanitizer (HWY_IS_MSAN is nonzero),
// passes the `size` bytes starting at `ptr` to
// `__msan_check_mem_is_initialized`. In every other build this is a no-op.
static inline void MaybeCheckInitialized(const void* ptr, size_t size) {
#if !HWY_IS_MSAN
  // No sanitizer available: only silence unused-parameter warnings.
  (void)size;
  (void)ptr;
#else
  __msan_check_mem_is_initialized(ptr, size);
#endif
}
// For internal use by NuqCodec.
class NuqClustering {
static constexpr size_t kGroupSize = NuqStream::kGroupSize;
@ -756,7 +744,6 @@ class NuqCodec {
const hn::Half<D8HFromD16<decltype(d16)>> d8q;
using V8Q = hn::Vec<decltype(d8q)>;
using V16 = hn::Vec<decltype(d16)>;
using VF = hn::Vec<decltype(df)>;
const size_t within_group = packed_ofs % kGroupSize;
HWY_DASSERT(within_group % (2 * hn::Lanes(df)) == 0);

View File

@ -909,7 +909,8 @@ HWY_NOINLINE void Prefill(
// the last token, too. However, we need to rewind this for the generation
// of the first token. So we need to keep track of this.
// TODO: consider implementing masking instead of this logic?
bool attend_to_last_token = (prefill_this_query < prefix_end_this_query);
const bool attend_to_last_token =
(prefill_this_query < prefix_end_this_query);
if (attend_to_last_token) {
// The difference can be at most 1.
prefill_this_query += 1;
@ -1219,8 +1220,21 @@ void GenerateT(const ByteStorageT& weights_u8, Activations& activations,
// Copy so we can increment without requiring users to pass in a mutable span.
std::vector<size_t> queries_pos_copy(queries_pos_in.cbegin(),
queries_pos_in.cend());
const QueriesMutablePos queries_mutable_pos(queries_pos_copy.data(),
QueriesMutablePos queries_mutable_pos(queries_pos_copy.data(),
queries_pos_copy.size());
// For the first turn, qpos remains 0. Otherwise, rewind the previous EOS.
// Background: for multiturn, Gemma 2 expects only <end_of_turn>, not EOS. The
// previous `Generate` called `StreamToken` for the last token (EOS), hence
// our caller's qpos is 1 too high. This must be corrected because we didn't
// write to the KV cache at that position, so MSAN would complain.
for (size_t& qpos : queries_mutable_pos) {
qpos = qpos == 0 ? 0 : qpos - 1;
}
// Sanity check: prompts should not be empty, nor start with EOS.
for (size_t query_idx = 0; query_idx < queries_prompt.size(); ++query_idx) {
const PromptTokens& prompt = queries_prompt[query_idx];
HWY_ASSERT(prompt.size() != 0 && prompt[0] != runtime_config.eos_id);
}
const size_t num_queries = queries_prompt.size();
HWY_ASSERT(num_queries <= 4096); // TokenStreamer uses BitSet4096.

View File

@ -22,8 +22,21 @@
#include "hwy/aligned_allocator.h"
#include "hwy/base.h"
#if HWY_IS_MSAN
#include <sanitizer/msan_interface.h>
#endif
namespace gcpp {
// Debugging aid: when building with MemorySanitizer (HWY_IS_MSAN is nonzero),
// passes the `size` bytes starting at `ptr` to
// `__msan_check_mem_is_initialized`. In every other build this is a no-op.
static inline void MaybeCheckInitialized(const void* ptr, size_t size) {
#if !HWY_IS_MSAN
  // No sanitizer available: only silence unused-parameter warnings.
  (void)size;
  (void)ptr;
#else
  __msan_check_mem_is_initialized(ptr, size);
#endif
}
// Shared between gemma.h and ops-inl.h.
struct TokenAndProb {
int token;