mirror of https://github.com/google/gemma.cpp.git
Fix MSAN issue for multiturn. Rewind the prior EOS token.
Also move MaybeCheckInitialized to allocator.h PiperOrigin-RevId: 683187458
This commit is contained in:
parent
5a71d819cb
commit
bd53b0f7c3
|
|
@ -89,6 +89,7 @@ cc_library(
|
|||
textual_hdrs = ["nuq-inl.h"],
|
||||
deps = [
|
||||
":sfp",
|
||||
"//:allocator",
|
||||
"@hwy//:hwy",
|
||||
"@hwy//hwy/contrib/sort:vqsort",
|
||||
],
|
||||
|
|
|
|||
|
|
@ -22,12 +22,9 @@
|
|||
#include <stdio.h>
|
||||
|
||||
#include "compression/shared.h"
|
||||
#include "util/allocator.h"
|
||||
#include "hwy/base.h"
|
||||
|
||||
#if HWY_IS_MSAN
|
||||
#include <sanitizer/msan_interface.h>
|
||||
#endif
|
||||
|
||||
#endif // THIRD_PARTY_GEMMA_CPP_COMPRESSION_NUQ_INL_H_
|
||||
|
||||
// Actual per-target include guard.
|
||||
|
|
@ -49,15 +46,6 @@ namespace gcpp {
|
|||
namespace HWY_NAMESPACE {
|
||||
namespace hn = hwy::HWY_NAMESPACE;
|
||||
|
||||
static inline void MaybeCheckInitialized(const void* ptr, size_t size) {
|
||||
#if HWY_IS_MSAN
|
||||
__msan_check_mem_is_initialized(ptr, size);
|
||||
#else
|
||||
(void)ptr;
|
||||
(void)size;
|
||||
#endif
|
||||
}
|
||||
|
||||
// For internal use by NuqCodec.
|
||||
class NuqClustering {
|
||||
static constexpr size_t kGroupSize = NuqStream::kGroupSize;
|
||||
|
|
@ -756,7 +744,6 @@ class NuqCodec {
|
|||
const hn::Half<D8HFromD16<decltype(d16)>> d8q;
|
||||
using V8Q = hn::Vec<decltype(d8q)>;
|
||||
using V16 = hn::Vec<decltype(d16)>;
|
||||
using VF = hn::Vec<decltype(df)>;
|
||||
|
||||
const size_t within_group = packed_ofs % kGroupSize;
|
||||
HWY_DASSERT(within_group % (2 * hn::Lanes(df)) == 0);
|
||||
|
|
|
|||
|
|
@ -909,7 +909,8 @@ HWY_NOINLINE void Prefill(
|
|||
// the last token, too. However, we need to rewind this for the generation
|
||||
// of the first token. So we need to keep track of this.
|
||||
// TODO: consider implementing masking instead of this logic?
|
||||
bool attend_to_last_token = (prefill_this_query < prefix_end_this_query);
|
||||
const bool attend_to_last_token =
|
||||
(prefill_this_query < prefix_end_this_query);
|
||||
if (attend_to_last_token) {
|
||||
// The difference can be at most 1.
|
||||
prefill_this_query += 1;
|
||||
|
|
@ -1219,8 +1220,21 @@ void GenerateT(const ByteStorageT& weights_u8, Activations& activations,
|
|||
// Copy so we can increment without requiring users to pass in a mutable span.
|
||||
std::vector<size_t> queries_pos_copy(queries_pos_in.cbegin(),
|
||||
queries_pos_in.cend());
|
||||
const QueriesMutablePos queries_mutable_pos(queries_pos_copy.data(),
|
||||
queries_pos_copy.size());
|
||||
QueriesMutablePos queries_mutable_pos(queries_pos_copy.data(),
|
||||
queries_pos_copy.size());
|
||||
// For the first turn, qpos remains 0. Otherwise, rewind the previous EOS.
|
||||
// Background: for multiturn, Gemma 2 expects only <end_of_turn>, not EOS. The
|
||||
// previous `Generate` called `StreamToken` for the last token (EOS), hence
|
||||
// our caller's qpos is 1 too high. This must be corrected because we didn't
|
||||
// write to the KV cache at that position, so MSAN would complain.
|
||||
for (size_t& qpos : queries_mutable_pos) {
|
||||
qpos = qpos == 0 ? 0 : qpos - 1;
|
||||
}
|
||||
// Sanity check: prompts should not be empty, nor start with EOS.
|
||||
for (size_t query_idx = 0; query_idx < queries_prompt.size(); ++query_idx) {
|
||||
const PromptTokens& prompt = queries_prompt[query_idx];
|
||||
HWY_ASSERT(prompt.size() != 0 && prompt[0] != runtime_config.eos_id);
|
||||
}
|
||||
|
||||
const size_t num_queries = queries_prompt.size();
|
||||
HWY_ASSERT(num_queries <= 4096); // TokenStreamer uses BitSet4096.
|
||||
|
|
|
|||
|
|
@ -22,8 +22,21 @@
|
|||
#include "hwy/aligned_allocator.h"
|
||||
#include "hwy/base.h"
|
||||
|
||||
#if HWY_IS_MSAN
|
||||
#include <sanitizer/msan_interface.h>
|
||||
#endif
|
||||
|
||||
namespace gcpp {
|
||||
|
||||
static inline void MaybeCheckInitialized(const void* ptr, size_t size) {
|
||||
#if HWY_IS_MSAN
|
||||
__msan_check_mem_is_initialized(ptr, size);
|
||||
#else
|
||||
(void)ptr;
|
||||
(void)size;
|
||||
#endif
|
||||
}
|
||||
|
||||
// Shared between gemma.h and ops-inl.h.
|
||||
struct TokenAndProb {
|
||||
int token;
|
||||
|
|
|
|||
Loading…
Reference in New Issue