Merge pull request #130 from veluca93:weight-handling

PiperOrigin-RevId: 622405491

commit 325ef06cf9 by Copybara-Service, 2024-04-06 02:22:00 -07:00
11 changed files with 424 additions and 270 deletions

View File

@@ -29,13 +29,13 @@
 // copybara:import_next_line:gemma_cpp
 #include "compression/blob_store.h"
+#include <fcntl.h>     // open
 #include <stdint.h>
 #include <stdio.h>     // SEEK_END - unistd isn't enough for IDE.
 #include <sys/stat.h>  // O_RDONLY
-#include <fcntl.h>     // open
 #if HWY_OS_WIN
-#include <io.h>  // read, write, close
 #include <fileapi.h>
+#include <io.h>  // read, write, close
 #else
 #include <unistd.h>  // read, write, close
 #endif
@@ -113,7 +113,8 @@ hwy::uint128_t MakeKey(const char* string) {
   return ret;
 }
-static void EnqueueChunkRequests(uint64_t offset, uint64_t size, uint8_t* data,
+namespace {
+void EnqueueChunkRequests(uint64_t offset, uint64_t size, uint8_t* data,
                                  std::vector<BlobIO>& requests) {
   // Split into chunks for load-balancing even if blob sizes vary.
   constexpr size_t kChunkSize = 4 * 1024 * 1024;
@@ -129,7 +130,7 @@ static void EnqueueChunkRequests(uint64_t offset, uint64_t size, uint8_t* data,
     requests.emplace_back(offset + pos, size - pos, data + pos, 0);
   }
 }
+}  // namespace
 struct IO {
   // Returns size in bytes or 0.
@@ -197,12 +198,6 @@ static_assert(HWY_IS_LITTLE_ENDIAN, "Assumes little endian");
 class BlobStore {
   static constexpr uint32_t kMagic = 0x0A534253;  // SBS\n
-  // Blob offsets on disk and memory addresses are a multiple of this, because
-  // we pad the header and each blob's size. This matches CUDA alignment and the
-  // maximum SVE vector size, and exceeds typical x86 cache line sizes (64 or
-  // 128), which can help performance.
-  static constexpr size_t kAlign = 256;
 public:
   // NOT including padding, so that we can also use ZeroFillPadding after
   // copying the header.
@@ -215,13 +210,13 @@ class BlobStore {
   // blobs. Requires num_blobs_ to already be set, typically by reading
   // sizeof(BlobStore) bytes from disk.
   size_t PaddedHeaderSize() const {
-    return hwy::RoundUpTo(HeaderSize(num_blobs_), kAlign);
+    return hwy::RoundUpTo(HeaderSize(num_blobs_), kBlobAlign);
   }
   // Returns aligned offset and zero-fills between that and `offset`.
   uint64_t ZeroFillPadding(uint64_t offset) {
     uint8_t* const bytes = reinterpret_cast<uint8_t*>(this);
-    const uint64_t padded = hwy::RoundUpTo(offset, kAlign);
+    const uint64_t padded = hwy::RoundUpTo(offset, kBlobAlign);
     hwy::ZeroBytes(bytes + offset, padded - offset);
     return padded;
   }
@@ -236,7 +231,7 @@ class BlobStore {
   for (size_t i = 0; i < num_blobs_; ++i) {
     const hwy::uint128_t val = keys_[num_blobs_ + i];
     if (val.lo != offset) return __LINE__;
-    offset = ZeroFillPadding(offset + val.hi);
+    offset = hwy::RoundUpTo(offset + val.hi, kBlobAlign);
   }
   if (offset != file_size_) return __LINE__;
@@ -253,25 +248,24 @@ class BlobStore {
   static std::vector<BlobIO> PrepareWriteRequests(
       const hwy::uint128_t keys[], const hwy::Span<uint8_t> blobs[],
-      size_t num_blobs) {
+      size_t num_blobs, BlobStore* bs) {
     // Sanity check and ensure the cast below is safe.
     HWY_ASSERT(num_blobs < (1ULL << 20));
     // Allocate var-length header.
     const size_t header_size = HeaderSize(num_blobs);
-    const size_t padded_header_size = hwy::RoundUpTo(header_size, kAlign);
-    BlobStorePtr bs = Allocate(padded_header_size);
+    const size_t padded_header_size = hwy::RoundUpTo(header_size, kBlobAlign);
    const uint64_t padded_header_end = bs->ZeroFillPadding(header_size);
    HWY_ASSERT(padded_header_end == padded_header_size);
    // All-zero buffer used to write padding to the file without copying the
    // input blobs.
-    static uint8_t zeros[kAlign] = {0};
+    static uint8_t zeros[kBlobAlign] = {0};
    // Total file size will be the header plus all padded blobs.
    uint64_t payload = 0;
    for (size_t i = 0; i < num_blobs; ++i) {
-      payload += hwy::RoundUpTo(blobs[i].size(), kAlign);
+      payload += hwy::RoundUpTo(blobs[i].size(), kBlobAlign);
    }
    const size_t total_size = padded_header_size + payload;
@@ -285,7 +279,7 @@ class BlobStore {
    std::vector<BlobIO> requests;
    requests.reserve(1 + 2 * num_blobs);
    requests.emplace_back(/*offset=*/0, padded_header_size,
-                          reinterpret_cast<uint8_t*>(bs.get()), 0);
+                          reinterpret_cast<uint8_t*>(bs), 0);
    // Fill second half of keys_ with offset/size and prepare IO requests.
    uint64_t offset = padded_header_end;
@@ -295,10 +289,10 @@ class BlobStore {
      EnqueueChunkRequests(offset, blobs[i].size(), blobs[i].data(), requests);
      offset += blobs[i].size();
-      const size_t padded_size = hwy::RoundUpTo(blobs[i].size(), kAlign);
+      const size_t padded_size = hwy::RoundUpTo(blobs[i].size(), kBlobAlign);
      if (padded_size != blobs[i].size()) {
        const size_t padding = padded_size - blobs[i].size();
-        HWY_ASSERT(padding <= kAlign);
+        HWY_ASSERT(padding <= kBlobAlign);
        requests.emplace_back(offset, padding, zeros, 0);
        offset += padding;
      }
@@ -418,8 +412,11 @@ BlobError BlobWriter::WriteAll(hwy::ThreadPool& pool,
   HWY_ASSERT(keys_.size() == blobs_.size());
   // Concatenate blobs in memory.
+  const size_t header_size = BlobStore::HeaderSize(keys_.size());
+  const size_t padded_header_size = hwy::RoundUpTo(header_size, kBlobAlign);
+  BlobStorePtr bs = BlobStore::Allocate(padded_header_size);
   std::vector<BlobIO> requests = BlobStore::PrepareWriteRequests(
-      keys_.data(), blobs_.data(), keys_.size());
+      keys_.data(), blobs_.data(), keys_.size(), bs.get());
   // Create/replace existing file.
 #if HWY_OS_WIN
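Note: the header buffer is now allocated by the caller and threaded through PrepareWriteRequests. A minimal sketch of the new write path (same steps as WriteAll above; `keys` and `blobs` are placeholder vectors, not names from the diff):

    // The caller owns the padded header; `bs` must outlive the I/O requests,
    // which point into it.
    const size_t header_size = BlobStore::HeaderSize(keys.size());
    const size_t padded_header_size = hwy::RoundUpTo(header_size, kBlobAlign);
    BlobStorePtr bs = BlobStore::Allocate(padded_header_size);
    std::vector<BlobIO> requests = BlobStore::PrepareWriteRequests(
        keys.data(), blobs.data(), keys.size(), bs.get());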

View File

@@ -40,6 +40,12 @@ using BlobStorePtr = hwy::AlignedFreeUniquePtr<BlobStore>;
 // 0 if successful, otherwise the line number of the failing check.
 using BlobError = int;
+// Blob offsets on disk and memory addresses are a multiple of this, because
+// we pad the header and each blob's size. This matches CUDA alignment and the
+// maximum SVE vector size, and exceeds typical x86 cache line sizes (64 or
+// 128), which can help performance.
+static constexpr size_t kBlobAlign = 256;
 struct BlobIO {
   BlobIO(uint64_t offset, size_t size, void* data, uint64_t padding)
       : offset(offset), size(size), data(data), padding(padding) {}
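Note: kBlobAlign is now shared between the on-disk format and consumers such as CompressedArray. A small worked example of the padding arithmetic (hwy::RoundUpTo is from hwy/base.h, as used in blob_store.cc):

    // A 1000-byte blob occupies hwy::RoundUpTo(1000, kBlobAlign) == 1024 bytes
    // on disk; the 24 trailing bytes are zero-filled so the next blob (and its
    // address once loaded) starts at a multiple of 256 again.
    static_assert(kBlobAlign == 256, "see comment above");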

View File

@@ -381,13 +381,14 @@ HWY_INLINE void Compress(const std::array<float, kCapacity>& in,
 }
 // Decompresses `num` values from `compressed` starting at `compressed_ofs`.
-template <typename MatT, size_t kCapacity, typename OutT>
-HWY_NOINLINE void Decompress(const CompressedArray<MatT, kCapacity>& compressed,
-                             size_t compressed_ofs, OutT* out, size_t num) {
-  HWY_DASSERT(compressed_ofs + num <= compressed.NumElements());
+template <typename ArrayT, typename OutT>
+HWY_NOINLINE void Decompress(const ArrayT& compressed, size_t compressed_ofs,
+                             OutT* out, size_t num) {
+  HWY_DASSERT(compressed_ofs + num <= compressed.size());
   const hn::ScalableTag<OutT> d;
-  using Traits = CompressTraits<MatT>;
-  Traits::Decompress(d, kCapacity, compressed.data(), compressed_ofs, out, num);
+  using Traits = CompressTraits<typename ArrayT::value_type>;
+  Traits::Decompress(d, compressed.size(), compressed.data(), compressed_ofs,
+                     out, num);
 }
 // As above, but with threading and benchmarking.
@@ -395,7 +396,7 @@ template <typename MatT, size_t kCapacity, typename OutT>
 HWY_INLINE void Decompress(const CompressedArray<MatT, kCapacity>& compressed,
                            size_t compressed_ofs, OutT* out, size_t num,
                            hwy::ThreadPool& pool) {
-  HWY_DASSERT(compressed_ofs + num <= compressed.NumElements());
+  HWY_DASSERT(compressed_ofs + num <= compressed.size());
   const double t0 = hwy::platform::Now();
   using Traits = CompressTraits<MatT>;
@@ -407,7 +408,7 @@ HWY_INLINE void Decompress(const CompressedArray<MatT, kCapacity>& compressed,
     const size_t ofs = idx_batch * kBatch;
     const size_t num = idx_batch == num_batches - 1 ? (num - ofs) : kBatch;
-    Traits::Decompress(d, compressed.NumElements(), compressed.data(),
+    Traits::Decompress(d, compressed.size(), compressed.data(),
                        compressed_ofs + ofs, out + ofs, num);
   });
@@ -417,16 +418,28 @@ HWY_INLINE void Decompress(const CompressedArray<MatT, kCapacity>& compressed,
   fprintf(stderr, "Decompress %.1f MB/s\n", mbps);
 }
+// Returns dot product with `vec_aligned` of length `num`.
+template <class DF, typename ArrayT, typename VecT>
+HWY_INLINE float Dot(DF df, const ArrayT& compressed, size_t compressed_ofs,
+                     const VecT* vec_aligned, size_t num) {
+  HWY_DASSERT(compressed_ofs + num <= compressed.size());
+  HWY_DASSERT(hn::IsAligned(df, vec_aligned));
+  using Traits = CompressTraits<typename ArrayT::value_type>;
+  return Traits::Dot(df, compressed.size(), compressed.data(), compressed_ofs,
+                     vec_aligned, num);
+}
 // Returns dot product with `vec_aligned` of length `num`.
 template <class DF, typename MatT, size_t kCapacity, typename VecT>
 HWY_INLINE float Dot(DF df, const CompressedArray<MatT, kCapacity>& compressed,
                      size_t compressed_ofs, const VecT* vec_aligned,
                      size_t num) {
-  HWY_DASSERT(compressed_ofs + num <= compressed.NumElements());
+  HWY_DASSERT(compressed_ofs + num <= compressed.size());
   HWY_DASSERT(hn::IsAligned(df, vec_aligned));
   using Traits = CompressTraits<MatT>;
-  return Traits::Dot(df, kCapacity, compressed.data(), compressed_ofs,
-                     vec_aligned, num);
+  return (compressed.scale() * Traits::Dot(df, compressed.size(),
+                                           compressed.data(), compressed_ofs,
+                                           vec_aligned, num));
 }
 // Callback used by ForeachTensor.
@@ -445,6 +458,12 @@ class Compressor {
                        compressed.CompressedSize());
   }
+  void AddScales(float* scales, size_t len) {
+    if (len) {
+      writer_.Add(CacheKey<float>("scales"), scales, len * sizeof(scales[0]));
+    }
+  }
   void WriteAll(hwy::ThreadPool& pool, const char* blob_filename) {
     const BlobError err = writer_.WriteAll(pool, blob_filename);
     if (err != 0) {
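Note: the new Dot/Decompress overloads only require value_type, size() and data(), and the CompressedArray overload now folds the per-tensor scale into the result. Illustrative caller's view, assuming `w` is a loaded CompressedArray and `x` a vector-aligned float buffer (both names are placeholders):

    const hn::ScalableTag<float> df;
    // w.scale() comes from the "scales" blob (1.0f when absent), so the result
    // is already in the tensor's original range:
    //   Dot(df, w, ofs, x, n) == w.scale() * <decompressed w[ofs..ofs+n), x>
    const float y = Dot(df, w, /*compressed_ofs=*/0, x, kModelDim);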

View File

@@ -71,10 +71,15 @@ class CompressedArray {
   }
 public:
+  using value_type = MatT;
   MatT* data() { return data_.data(); }
   const MatT* data() const { return data_.data(); }
-  constexpr size_t NumElements() const { return kCapacity; }
+  float scale() const { return scale_[0]; }
+  void set_scale(float scale) { scale_[0] = scale; }
+  constexpr size_t size() const { return kCapacity; }
   constexpr size_t CompressedSize() const {
     return NumCompressed() * sizeof(MatT);
@@ -82,6 +87,7 @@ class CompressedArray {
 private:
   std::array<MatT, NumCompressed()> data_;
+  float scale_[kBlobAlign / sizeof(float)];
 };
 #if COMPRESS_STATS
@@ -187,11 +193,21 @@ class CacheLoader {
     err_ = reader_.Enqueue(CacheKey<MatT>(name), compressed.data(),
                            compressed.CompressedSize());
+    compressed.set_scale(1.0f);
     if (err_ != 0) {
       fprintf(stderr, "Failed to read cache %s (error %d)\n", name, err_);
     }
   }
+  void LoadScales(float* scales, size_t len) {
+    if (0 != reader_.Enqueue(CacheKey<float>("scales"), scales,
+                             len * sizeof(scales[0]))) {
+      for (size_t i = 0; i < len; ++i) {
+        scales[i] = 1.0f;
+      }
+    }
+  }
   // Returns whether all tensors are successfully loaded from cache.
   bool ReadAll(hwy::ThreadPool& pool) {
     // reader_ invalid or any Enqueue failed
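Note: only scale_[0] carries information; sizing the member as kBlobAlign / sizeof(float) (64 floats) appears intended to keep sizeof(CompressedArray) a multiple of kBlobAlign so subsequent tensors stay aligned — an inference on my part, not something the diff states. The loader also keeps older cache files working, a sketch of the fallback:

    std::array<float, TConfig::kNumTensorScales> scales;
    loader.LoadScales(scales.data(), scales.size());
    // Missing "scales" blob => every entry is 1.0f, and set_scale(1.0f) was
    // already applied per tensor, so results for old .sbs files are unchanged.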

View File

@@ -30,6 +30,16 @@
 #include <stddef.h>
+// copybara:import_next_line:gemma_cpp
+#include "compression/sfp.h"
+#include "hwy/base.h"  // hwy::bfloat16_t
+// Allowable types for GEMMA_WEIGHT_T (can be specified at compilation time):
+// float, hwy::bfloat16_t, SfpStream, NuqStream
+#ifndef GEMMA_WEIGHT_T
+#define GEMMA_WEIGHT_T SfpStream
+#endif  // !GEMMA_WEIGHT_T
 namespace gcpp {
 static constexpr size_t kSeqLen = GEMMA_MAX_SEQLEN;
@@ -45,6 +55,8 @@ struct ConfigGemma7B {
   static constexpr int kKVHeads = 16;  // standard MHA
   static constexpr int kQKVDim = 256;  // query size == key size == value size
   static constexpr int kTopK = gcpp::kTopK;
+  static constexpr int kNumTensorScales = 0;
+  using WeightT = GEMMA_WEIGHT_T;
 };
 struct ConfigGemma2B {
@@ -57,6 +69,8 @@ struct ConfigGemma2B {
   static constexpr int kKVHeads = 1;
   static constexpr int kQKVDim = 256;  // query size == key size == value size
   static constexpr int kTopK = gcpp::kTopK;
+  static constexpr int kNumTensorScales = 0;
+  using WeightT = GEMMA_WEIGHT_T;
 };
 }  // namespace gcpp
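Note: because GEMMA_WEIGHT_T is guarded by #ifndef, the weight type can still be chosen per build, and each config now also exposes it as WeightT together with kNumTensorScales. Minimal sketch (the -D override is the ordinary preprocessor mechanism implied by the guard, not new API from this diff):

    // e.g. compile with -DGEMMA_WEIGHT_T=hwy::bfloat16_t to override SfpStream.
    using Weight2B = gcpp::ConfigGemma2B::WeightT;        // SfpStream by default
    static_assert(gcpp::ConfigGemma2B::kNumTensorScales == 0,
                  "no per-tensor scales are stored for this config yet");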

View File

@@ -19,9 +19,9 @@
 #include "gemma.h"
 // copybara:import_next_line:gemma_cpp
 #include "util/app.h"  // LoaderArgs
-#include "hwy/contrib/thread_pool/thread_pool.h"
 // copybara:import_next_line:gemma_cpp
 #include "util/args.h"
+#include "hwy/contrib/thread_pool/thread_pool.h"
 std::vector<int> tokenize(
     const std::string& prompt_string,
@@ -43,8 +43,7 @@ int main(int argc, char** argv) {
   hwy::ThreadPool pool(num_threads);
   // Instantiate model and KV Cache
-  gcpp::Gemma model(loader.tokenizer, loader.compressed_weights,
-                    loader.ModelType(), pool);
+  gcpp::Gemma model(loader.tokenizer, loader.weights, loader.ModelType(), pool);
   auto kv_cache = CreateKVCache(loader.ModelType());
   size_t pos = 0;  // KV Cache position

gemma.cc
View File

@@ -25,12 +25,12 @@
 #include "compression/compress-inl.h"
 // copybara:import_next_line:gemma_cpp
 #include "ops.h"
+// copybara:import_next_line:gemma_cpp
+#include "util/args.h"  // Path
 #include "hwy/contrib/matvec/matvec-inl.h"
 #include "hwy/highway.h"
 #include "hwy/profiler.h"
 #include "hwy/timer.h"
-// copybara:import_next_line:gemma_cpp
-#include "util/args.h"  // Path
 // Non-SIMD includes and types. Note that HWY_ONCE is only true on the last
 // compile pass, whereas we want this defined in the first.
@@ -64,6 +64,12 @@
 // copybara:import_next_line:sentencepiece
 #include "src/sentencepiece_processor.h"
+// Setting this to true disables fread() calls that read the model file.
+constexpr bool kDryRunFread = false;
+// Setting this to false will load and use uncompressed weights.
+constexpr bool kWeightsAreCompressed = true;
 namespace gcpp {
 template <class TConfig>
@@ -88,70 +94,145 @@ struct Layer {
   std::array<float, kModelDim> pre_ffw_norm_scale;
 };
+float ScaleWeights(float* data, size_t len) {
+  float maxabs = 0.0;
+  for (size_t i = 0; i < len; ++i) {
+    maxabs = std::max(maxabs, std::abs(data[i]));
+  }
+  const float kMaxRange = 1.875f;
+  if (maxabs <= kMaxRange) {
+    return 1.0f;
+  }
+  const float scale = maxabs / kMaxRange;
+  const float inv_scale = 1.0f / scale;
+  for (size_t i = 0; i < len; ++i) {
+    data[i] *= inv_scale;
+  }
+  return scale;
+}
+// Array instead of single large allocation for parallel mem init. Split out of
+// Weights so that only these pointers are initialized.
+template <class TConfig>
+struct LayerPointers {
+  explicit LayerPointers(hwy::ThreadPool& pool) {
+    pool.Run(0, TConfig::kLayers, [this](uint64_t task, size_t /*thread*/) {
+      this->layers[task] = hwy::AllocateAligned<Layer<TConfig>>(1);
+    });
+  }
+  using TLayer = Layer<TConfig>;
+  std::array<hwy::AlignedFreeUniquePtr<TLayer[]>, TConfig::kLayers> layers;
+};
 template <class TConfig>
 struct Weights {
-  Weights() = default;
-  hwy::AlignedUniquePtr<Layer<TConfig>[]> layers;  // kLayers
+  // No ctor/dtor, allocated via AllocateAligned.
   std::array<float, TConfig::kVocabSize * TConfig::kModelDim>
       embedder_input_embedding;
   std::array<float, TConfig::kModelDim> final_norm_scale;
+  LayerPointers<TConfig> layer_ptrs;
+  std::array<float, TConfig::kNumTensorScales> scales;
+  const Layer<TConfig>* GetLayer(size_t layer) const {
+    return layer_ptrs.layers[layer].get();
+  }
+  Layer<TConfig>* GetLayer(size_t layer) {
+    return layer_ptrs.layers[layer].get();
+  }
 };
-// Only called if cached loading fails.
 template <typename TConfig>
-hwy::AlignedUniquePtr<Weights<TConfig>> LoadWeights(const Path& checkpoint) {
+hwy::AlignedFreeUniquePtr<uint8_t[]> LoadWeights(
+    const Path& checkpoint, hwy::ThreadPool& pool,
+    bool scale_for_compression = false) {
   PROFILER_ZONE("Startup.LoadWeights");
-  using TWeights = Weights<TConfig>;
-  hwy::AlignedUniquePtr<TWeights> weights = hwy::MakeUniqueAligned<TWeights>();
-  weights->layers =
-      hwy::MakeUniqueAlignedArray<Layer<TConfig>>(TConfig::kLayers);
-  if (checkpoint.path.empty()) {
-    HWY_ABORT(
-        "Loading --compressed_weights failed; we require a --weights argument. "
-        "Please see issue #11 on how to create this file.\n");
+  if (!std::filesystem::exists(checkpoint.path)) {
+    HWY_ABORT("The model weights file '%s' does not exist.",
+              checkpoint.path.c_str());
   }
+  using TWeights = Weights<TConfig>;
+  hwy::AlignedFreeUniquePtr<uint8_t[]> weights_u8 =
+      hwy::AllocateAligned<uint8_t>(sizeof(TWeights));
+  TWeights* weights = reinterpret_cast<TWeights*>(weights_u8.get());
+  new (&weights->layer_ptrs) LayerPointers<TConfig>(pool);
+  size_t scale_pos = 0;
   FILE* fptr;
+  if constexpr (kDryRunFread) {
+    fprintf(stderr, "Dry-Run, not reading model-file.\n");
+  } else {
   fptr = fopen(checkpoint.path.c_str(), "rb");
   if (fptr == nullptr) {
     HWY_ABORT("Failed to open model file %s - does it exist?",
               checkpoint.path.c_str());
   }
+  }
   bool ok = true;
   uint64_t total_size = 0;
-  ok &= 1 == fread(&(weights->embedder_input_embedding),
-                   sizeof(weights->embedder_input_embedding), 1, fptr);
-  ok &= 1 == fread(&(weights->final_norm_scale),
-                   sizeof(weights->final_norm_scale), 1, fptr);
-  total_size += sizeof(weights->embedder_input_embedding) +
-                sizeof(weights->final_norm_scale);
+  auto do_fread = [&](void* var, int layer, const char* name, size_t size) {
+    if (layer == -1) {
+      fprintf(stderr, "Loading Parameters (size %zu): %s\n", size, name);
+    } else {
+      fprintf(stderr, "Loading Parameters (layer=%d, size %zu): %s\n", layer,
+              size, name);
+    }
+    if constexpr (!kDryRunFread) {
+      ok &= 1 == fread(var, size, 1, fptr);
+      total_size += size;
+    }
+  };
+  do_fread(&(weights->embedder_input_embedding), -1, "embedder_input_embedding",
+           sizeof(weights->embedder_input_embedding));
+  do_fread(&(weights->final_norm_scale), -1, "final_norm_scale",
+           sizeof(weights->final_norm_scale));
   for (size_t layer = 0; layer < TConfig::kLayers; ++layer) {
-    Layer<TConfig>* layer_view = &weights->layers[layer];
-    ok &= 1 == fread(&layer_view->attn_vec_einsum_w,
-                     sizeof(layer_view->attn_vec_einsum_w), 1, fptr);
-    ok &= 1 == fread(&layer_view->qkv_einsum_w,
-                     sizeof(layer_view->qkv_einsum_w), 1, fptr);
-    ok &= 1 == fread(&layer_view->gating_einsum_w,
-                     sizeof(layer_view->gating_einsum_w), 1, fptr);
-    ok &= 1 ==
-        fread(&layer_view->linear_w, sizeof(layer_view->linear_w), 1, fptr);
-    ok &= 1 == fread(&layer_view->pre_attention_norm_scale,
-                     sizeof(layer_view->pre_attention_norm_scale), 1, fptr);
-    ok &= 1 == fread(&layer_view->pre_ffw_norm_scale,
-                     sizeof(layer_view->pre_ffw_norm_scale), 1, fptr);
-    total_size += sizeof(*layer_view);
+    Layer<TConfig>* layer_view = weights->GetLayer(layer);
+#define READ_WEIGHTS(name) \
+  do { \
+    do_fread(&(layer_view->name), layer, #name, sizeof(layer_view->name)); \
+  } while (0)
+#define SCALE_WEIGHTS(name) \
+  do { \
+    if (ok && !kDryRunFread && scale_for_compression) { \
+      weights->scales[scale_pos++] = \
+          ScaleWeights(layer_view->name.data(), layer_view->name.size()); \
+    } \
+  } while (0)
+    // Make sure we don't have uninitialized memory.
+    hwy::ZeroBytes(layer_view, sizeof(*layer_view));
+    READ_WEIGHTS(attn_vec_einsum_w);
+    READ_WEIGHTS(qkv_einsum_w);
+    SCALE_WEIGHTS(attn_vec_einsum_w);
+    SCALE_WEIGHTS(qkv_einsum_w);
+    READ_WEIGHTS(gating_einsum_w);
+    READ_WEIGHTS(linear_w);
+    SCALE_WEIGHTS(gating_einsum_w);
+    SCALE_WEIGHTS(linear_w);
+    READ_WEIGHTS(pre_attention_norm_scale);
+    READ_WEIGHTS(pre_ffw_norm_scale);
+#undef READ_WEIGHTS
   }
   if (!ok) {
-    HWY_ABORT("Failed to read from %s - might be a directory, or too small? "
-              "expected size: %d kB", checkpoint.path.c_str(),
-              static_cast<uint32_t>(total_size >> 10));
+    HWY_ABORT(
+        "Failed to read from %s - might be a directory, or too small? "
+        "expected size: %d kB",
+        checkpoint.path.c_str(), static_cast<uint32_t>(total_size >> 10));
   }
+  if (!kDryRunFread) {
   HWY_ASSERT(0 == fclose(fptr));
-  return weights;
+    if (scale_for_compression) {
+      HWY_ASSERT(scale_pos == TConfig::kNumTensorScales);
+    }
+  }
+  return weights_u8;
 }
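Note: ScaleWeights squeezes each tensor into the representable SFP range before compression and records the factor. Worked example with the kMaxRange = 1.875f from the code above:

    // Largest magnitude 7.5 => scale = 7.5 / 1.875 = 4.0; every element is
    // multiplied by 1/4 before compression and 4.0 is stored in
    // weights->scales[]. At inference, Dot() multiplies the raw result by
    // CompressedArray::scale() == 4.0, restoring the original magnitude.
    float data[3] = {7.5f, -0.5f, 1.0f};
    const float s = ScaleWeights(data, 3);  // s == 4.0f, data == {1.875f, -0.125f, 0.25f}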
template <class TConfig> template <class TConfig>
@@ -159,18 +240,19 @@ struct CompressedLayer {
   // No ctor/dtor, allocated via AllocateAligned.
   using TLayer = gcpp::Layer<TConfig>;
+  using WeightT = typename TConfig::WeightT;
   static constexpr size_t kModelDim = TConfig::kModelDim;
   static constexpr size_t kFFHiddenDim = TConfig::kFFHiddenDim;
   // Compressed Parameters
   // We don't yet have an RMSNorm that accepts all WeightT.
-  CompressedArray<hwy::bfloat16_t, kModelDim> c_pre_attention_norm_scale;
-  CompressedArray<hwy::bfloat16_t, kModelDim> c_pre_ffw_norm_scale;
-  CompressedArray<WeightT, TLayer::kGatingEinsumWSize> c_gating_einsum_w;
-  CompressedArray<WeightT, kModelDim * kFFHiddenDim> c_linear_w;
-  CompressedArray<WeightT, TLayer::kQKVEinsumWSize> c_qkv_einsum_w;
-  CompressedArray<WeightT, TLayer::kAttVecEinsumWSize> c_attn_vec_einsum_w;
+  CompressedArray<hwy::bfloat16_t, kModelDim> pre_attention_norm_scale;
+  CompressedArray<hwy::bfloat16_t, kModelDim> pre_ffw_norm_scale;
+  CompressedArray<WeightT, TLayer::kGatingEinsumWSize> gating_einsum_w;
+  CompressedArray<WeightT, kModelDim * kFFHiddenDim> linear_w;
+  CompressedArray<WeightT, TLayer::kQKVEinsumWSize> qkv_einsum_w;
+  CompressedArray<WeightT, TLayer::kAttVecEinsumWSize> attn_vec_einsum_w;
 };
 // Array instead of single large allocation for parallel mem init. Split out of
@@ -193,21 +275,25 @@ struct CompressedWeights {
   // No ctor/dtor, allocated via AllocateAligned.
   CompressedArray<EmbedderInputT, TConfig::kVocabSize * TConfig::kModelDim>
-      c_embedder_input_embedding;
-  CompressedArray<hwy::bfloat16_t, TConfig::kModelDim> c_final_norm_scale;
+      embedder_input_embedding;
+  CompressedArray<hwy::bfloat16_t, TConfig::kModelDim> final_norm_scale;
   // Must be last so that the other arrays remain aligned.
   CompressedLayerPointers<TConfig> c_layer_ptrs;
-  const CompressedLayer<TConfig>* CLayer(size_t layer) const {
+  const CompressedLayer<TConfig>* GetLayer(size_t layer) const {
     return c_layer_ptrs.c_layers[layer].get();
   }
-  CompressedLayer<TConfig>* CLayer(size_t layer) {
+  CompressedLayer<TConfig>* GetLayer(size_t layer) {
     return c_layer_ptrs.c_layers[layer].get();
   }
 };
+template <class TConfig>
+using WeightsT = hwy::If<kWeightsAreCompressed, CompressedWeights<TConfig>,
+                         Weights<TConfig>>;
 // Aligned.
 template <class TConfig, size_t TBatchSize>
 struct Activations {
@@ -272,16 +358,27 @@ KVCache CreateKVCache(Model type) {
   }
 }
+namespace {
+template <class Config>
+void DeleteLayersPtrs(CompressedWeights<Config>* c_weights) {
+  c_weights->c_layer_ptrs.~CompressedLayerPointers<Config>();
+}
+template <class Config>
+void DeleteLayersPtrs(Weights<Config>* weights) {
+  weights->layer_ptrs.~LayerPointers<Config>();
+}
+}  // namespace
 template <class Config>
 struct GemmaImpl : public GemmaInterface {
   GemmaImpl(std::unique_ptr<sentencepiece::SentencePieceProcessor>& tokenizer,
-            hwy::AlignedFreeUniquePtr<uint8_t[]>& compressed_weights,
+            hwy::AlignedFreeUniquePtr<uint8_t[]>& weights_u8,
             hwy::ThreadPool& pool);
   ~GemmaImpl() {
-    using CWeights = CompressedWeights<Config>;
-    CWeights* c_weights = reinterpret_cast<CWeights*>(compressed_weights.get());
-    c_weights->c_layer_ptrs.~CompressedLayerPointers<Config>();
+    WeightsT<Config>* weights =
+        reinterpret_cast<WeightsT<Config>*>(weights_u8.get());
+    DeleteLayersPtrs(weights);
   }
   const sentencepiece::SentencePieceProcessor* Tokenizer() const override {
@@ -296,7 +393,7 @@ struct GemmaImpl : public GemmaInterface {
                 int verbosity) override;
   std::unique_ptr<sentencepiece::SentencePieceProcessor> tokenizer;
-  hwy::AlignedFreeUniquePtr<uint8_t[]> compressed_weights;
+  hwy::AlignedFreeUniquePtr<uint8_t[]> weights_u8;
   hwy::AlignedUniquePtr<Activations<Config, kPrefillBatchSize>> prefill;
   hwy::AlignedUniquePtr<Activations<Config, 1>> state;
 };
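Note: both weight structs are now allocated as raw aligned bytes ("No ctor/dtor, allocated via AllocateAligned"), so only the member that owns allocations is constructed with placement new and destroyed explicitly through the DeleteLayersPtrs overloads. A sketch of the pattern used by LoadWeights/LoadCompressedWeights (<new> assumed available through the existing includes):

    hwy::AlignedFreeUniquePtr<uint8_t[]> u8 =
        hwy::AllocateAligned<uint8_t>(sizeof(Weights<TConfig>));
    auto* w = reinterpret_cast<Weights<TConfig>*>(u8.get());
    new (&w->layer_ptrs) LayerPointers<TConfig>(pool);   // placement new
    // ... use w ...
    w->layer_ptrs.~LayerPointers<TConfig>();             // mirrored in ~GemmaImpl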
@@ -309,11 +406,11 @@ HWY_BEFORE_NAMESPACE();
 namespace gcpp {
 namespace HWY_NAMESPACE {
-template <class TConfig, size_t kBatchSize>
+template <size_t kBatchSize, typename LayerT, class TConfig>
 HWY_NOINLINE void Attention(size_t batch_start, size_t batch_idx, size_t layer,
                             Activations<TConfig, kBatchSize>& activations,
-                            const CompressedLayer<TConfig>* c_layer,
-                            KVCache& kv_cache, hwy::ThreadPool& pool) {
+                            const LayerT* layer_weights, KVCache& kv_cache,
+                            hwy::ThreadPool& pool) {
   PROFILER_ZONE("Gen.Attention");
   const size_t pos = batch_start + batch_idx;
   HWY_DASSERT(batch_idx < kBatchSize);
@@ -329,26 +426,25 @@ HWY_NOINLINE void Attention(size_t batch_start, size_t batch_idx, size_t layer,
   static const float kQueryScale =
       static_cast<float>(1.0 / sqrt(static_cast<double>(kQKVDim)));
-  const size_t batch_offset = batch_idx * kModelDim;
+  float* x = activations.pre_att_rms_out.data() + batch_idx * kModelDim;
   auto ProjQ = [&](uint64_t head, size_t head_offset) HWY_ATTR {
     float* HWY_RESTRICT q =
         activations.q.data() + head * kQKVDim + batch_idx * kHeads * kQKVDim;
-    MatVecLoop<kQKVDim, kModelDim>(
-        c_layer->c_qkv_einsum_w, head_offset + 0 * kQKVDim * kModelDim,
-        activations.pre_att_rms_out.data() + batch_offset, q);
+    MatVecLoop<kQKVDim, kModelDim>(layer_weights->qkv_einsum_w,
+                                   head_offset + 0 * kQKVDim * kModelDim, x, q);
   };
-  auto ProjKV =
-      [&](size_t k_offset, size_t v_offset, size_t kv_offset) HWY_ATTR {
-        TwoOfsMatVecLoop<kQKVDim, kModelDim>(
-            c_layer->c_qkv_einsum_w, k_offset, v_offset,
-            activations.pre_att_rms_out.data() + batch_offset,
-            kv_cache.key_cache.get() + kv_offset,
-            kv_cache.value_cache.get() + kv_offset);
-        Rope(kv_cache.key_cache.get() + kv_offset, kQKVDim, pos);
+  auto ProjKV = [&](size_t k_offset, size_t v_offset,
+                    size_t kv_offset) HWY_ATTR {
+    float* HWY_RESTRICT k = kv_cache.key_cache.get() + kv_offset;
+    float* HWY_RESTRICT v = kv_cache.value_cache.get() + kv_offset;
+    TwoOfsMatVecLoop<kQKVDim, kModelDim>(layer_weights->qkv_einsum_w, k_offset,
+                                         v_offset, x, k, v);
+    Rope(k, kQKVDim, pos);
   };
   auto Attn = [&](uint64_t head, size_t head_offset) HWY_ATTR {
@@ -388,7 +484,7 @@ HWY_NOINLINE void Attention(size_t batch_start, size_t batch_idx, size_t layer,
         head == 0
             ? activations.att_post2.data() + batch_idx * kModelDim
            : activations.att_post1.data() + head * kBatchSize * kModelDim;
-    MatVecLoop<kModelDim, kQKVDim>(c_layer->c_attn_vec_einsum_w,
+    MatVecLoop<kModelDim, kQKVDim>(layer_weights->attn_vec_einsum_w,
                                    head * kModelDim * kQKVDim, att_out,
                                    head_out);
   };
@@ -431,9 +527,9 @@ HWY_NOINLINE void Attention(size_t batch_start, size_t batch_idx, size_t layer,
   }
 }
-template <typename TConfig, size_t kBatchSize>
+template <size_t kBatchSize, typename LayerT, typename TConfig>
 HWY_NOINLINE void FFW(Activations<TConfig, kBatchSize>& activations,
-                      size_t batch_idx, const CompressedLayer<TConfig>* c_layer,
+                      size_t batch_idx, const LayerT* layer_weights,
                       hwy::ThreadPool& pool) {
   HWY_DASSERT(batch_idx < kBatchSize);
   static constexpr size_t kModelDim = TConfig::kModelDim;
@@ -449,12 +545,12 @@ HWY_NOINLINE void FFW(Activations<TConfig, kBatchSize>& activations,
   // Same matrix, first and second half of rows. Could fuse into one MatVec,
   // but separating them could help on NUMA e.g. multiple sockets.
-  MatVec<kFFHiddenDim, kModelDim>(c_layer->c_gating_einsum_w,
+  MatVec<kFFHiddenDim, kModelDim>(layer_weights->gating_einsum_w,
                                   kFFHiddenDim * kModelDim, vec, out_mul,
                                   pool);
   // Gate, will go through the nonlinearity.
-  MatVec<kFFHiddenDim, kModelDim>(c_layer->c_gating_einsum_w, 0, vec, out,
+  MatVec<kFFHiddenDim, kModelDim>(layer_weights->gating_einsum_w, 0, vec, out,
                                   pool);
   namespace hn = hwy::HWY_NAMESPACE;
@@ -467,7 +563,7 @@ HWY_NOINLINE void FFW(Activations<TConfig, kBatchSize>& activations,
   PROFILER_ZONE("Gen.FFW\\GatedGELU");
   MatVec<kModelDim, kFFHiddenDim>(
-      c_layer->c_linear_w, 0, activations.ffw_hidden.data() + hidden_offset,
+      layer_weights->linear_w, 0, activations.ffw_hidden.data() + hidden_offset,
       activations.ffw_out.data() + batch_idx * kModelDim, pool);
 }
@@ -486,9 +582,9 @@ GEMMA_CONSTEXPR_EMBSCALING float EmbeddingScaling() {
       Sqrt(static_cast<float>(TConfig::kModelDim))));
 }
-template <typename TConfig, size_t kBatchSize>
+template <size_t kBatchSize, typename WeightArrayT, typename TConfig>
 HWY_NOINLINE void Prefill(const int* tokens, size_t num_tokens, size_t pos,
-                          const CompressedWeights<TConfig>& c_weights,
+                          const WeightArrayT& weights,
                           Activations<TConfig, kBatchSize>& activations,
                           KVCache& kv_cache, hwy::ThreadPool& pool,
                           hwy::ThreadPool& inner_pool) {
@@ -500,22 +596,22 @@ HWY_NOINLINE void Prefill(const int* tokens, size_t num_tokens, size_t pos,
   pool.Run(
       0, num_tokens, [&](const uint64_t token_idx, size_t /*thread*/) HWY_ATTR {
         const int token = tokens[token_idx];
-        Decompress(c_weights.c_embedder_input_embedding, token * kModelDim,
+        Decompress(weights.embedder_input_embedding, token * kModelDim,
                    activations.x.data() + token_idx * kModelDim, kModelDim);
         MulByConst(kEmbScaling, activations.x.data() + token_idx * kModelDim,
                    kModelDim);
       });
   for (size_t layer = 0; layer < TConfig::kLayers; ++layer) {
-    const CompressedLayer<TConfig>* c_layer = c_weights.CLayer(layer);
+    const auto* layer_weights = weights.GetLayer(layer);
     for (size_t token_idx = 0; token_idx < num_tokens; ++token_idx) {
       RMSNorm(activations.x.data() + token_idx * kModelDim,
-              c_layer->c_pre_attention_norm_scale.data(),
+              layer_weights->pre_attention_norm_scale.data(),
              activations.pre_att_rms_out.data() + token_idx * kModelDim,
              kModelDim);
-      Attention<TConfig, kBatchSize>(pos, token_idx, layer, activations,
-                                     c_layer, kv_cache, pool);
+      Attention<kBatchSize>(pos, token_idx, layer, activations, layer_weights,
+                            kv_cache, pool);
     }
     // TODO: sink the loop into these functions, i.e. make them matmuls.
@@ -525,10 +621,10 @@ HWY_NOINLINE void Prefill(const int* tokens, size_t num_tokens, size_t pos,
           AddFrom(activations.att_post2.data() + token_idx * kModelDim,
                   activations.x.data() + token_idx * kModelDim, kModelDim);
           RMSNorm(activations.x.data() + token_idx * kModelDim,
-                  c_layer->c_pre_ffw_norm_scale.data(),
+                  layer_weights->pre_ffw_norm_scale.data(),
                   activations.bf_pre_ffw_rms_out.data() + token_idx * kModelDim,
                   kModelDim);
-          FFW<TConfig, kBatchSize>(activations, token_idx, c_layer, inner_pool);
+          FFW<kBatchSize>(activations, token_idx, layer_weights, inner_pool);
           AddFrom(activations.ffw_out.data() + token_idx * kModelDim,
                   activations.x.data() + token_idx * kModelDim, kModelDim);
         });
@@ -536,21 +632,20 @@ HWY_NOINLINE void Prefill(const int* tokens, size_t num_tokens, size_t pos,
   pool.Run(
       0, num_tokens, [&](const uint64_t token_idx, size_t /*thread*/) HWY_ATTR {
-        RMSNormInplace(c_weights.c_final_norm_scale.data(),
+        RMSNormInplace(weights.final_norm_scale.data(),
                        activations.x.data() + token_idx * kModelDim, kModelDim);
       });
 }
 // n = 1 specialization
-template <class TConfig>
-void Transformer(int token, size_t pos,
-                 const CompressedWeights<TConfig>& c_weights,
+template <typename WeightArrayT, class TConfig>
+void Transformer(int token, size_t pos, const WeightArrayT& weights,
                  Activations<TConfig, 1>& activations, KVCache& kv_cache,
                  hwy::ThreadPool& pool, hwy::ThreadPool& inner_pool) {
   static constexpr size_t kLayers = TConfig::kLayers;
   static constexpr size_t kModelDim = TConfig::kModelDim;
-  Decompress(c_weights.c_embedder_input_embedding, token * kModelDim,
+  Decompress(weights.embedder_input_embedding, token * kModelDim,
              activations.x.data(), kModelDim);
   GEMMA_CONSTEXPR_EMBSCALING const float kEmbScaling =
@@ -558,17 +653,18 @@ void Transformer(int token, size_t pos,
   MulByConst(kEmbScaling, activations.x.data(), kModelDim);
   for (size_t layer = 0; layer < kLayers; ++layer) {
-    const CompressedLayer<TConfig>* c_layer = c_weights.CLayer(layer);
-    RMSNorm(activations.x.data(), c_layer->c_pre_attention_norm_scale.data(),
+    const auto* layer_weights = weights.GetLayer(layer);
+    RMSNorm(activations.x.data(),
+            layer_weights->pre_attention_norm_scale.data(),
            activations.pre_att_rms_out.data(), kModelDim);
-    Attention<TConfig, 1>(pos, 0, layer, activations, c_layer, kv_cache, pool);
+    Attention<1>(pos, 0, layer, activations, layer_weights, kv_cache, pool);
    AddFrom(activations.att_post2.data(), activations.x.data(), kModelDim);
-    RMSNorm(activations.x.data(), c_layer->c_pre_ffw_norm_scale.data(),
+    RMSNorm(activations.x.data(), layer_weights->pre_ffw_norm_scale.data(),
            activations.bf_pre_ffw_rms_out.data(), kModelDim);
-    FFW<TConfig, 1>(activations, /* batch_idx = */ 0, c_layer, pool);
+    FFW<1>(activations, /* batch_idx = */ 0, layer_weights, pool);
    AddFrom(activations.ffw_out.data(), activations.x.data(), kModelDim);
   }
-  RMSNormInplace(c_weights.c_final_norm_scale.data(), activations.x.data(),
+  RMSNormInplace(weights.final_norm_scale.data(), activations.x.data(),
                  kModelDim);
 }
@@ -609,9 +705,9 @@ void GenerateImpl(GemmaImpl<TConfig>& gemma, size_t max_tokens,
   Activations<TConfig, 1>& activations = *gemma.state.get();
   Activations<TConfig, kPrefillBatchSize>& prefill_activations =
       *gemma.prefill.get();
-  const CompressedWeights<TConfig>& c_weights =
-      *reinterpret_cast<CompressedWeights<TConfig>*>(
-          gemma.compressed_weights.get());
+  const WeightsT<TConfig>& weights =
+      *reinterpret_cast<WeightsT<TConfig>*>(gemma.weights_u8.get());
   size_t prompt_size = prompt.size();
   RangeChecks<TConfig>(max_tokens, max_generated_tokens, prompt_size);
@@ -643,9 +739,8 @@ void GenerateImpl(GemmaImpl<TConfig>& gemma, size_t max_tokens,
     HWY_DASSERT(batch_size <= kPrefillBatchSize);
     HWY_DASSERT(pos_offset + batch_size <= prompt_size - 1);
     const int* batch_tokens = prompt.data() + pos_offset;
-    Prefill<TConfig, kPrefillBatchSize>(batch_tokens, batch_size, pos,
-                                        c_weights, prefill_activations,
-                                        kv_cache, pool, inner_pool);
+    Prefill<kPrefillBatchSize>(batch_tokens, batch_size, pos, weights,
+                               prefill_activations, kv_cache, pool, inner_pool);
     for (size_t idx = 0; idx < batch_size; ++idx) {
       stream_token(batch_tokens[idx], 0.0f);
     }
@@ -672,7 +767,7 @@ void GenerateImpl(GemmaImpl<TConfig>& gemma, size_t max_tokens,
   for (size_t generate_pos = 0;
        pos < max_tokens && generate_pos < max_generated_tokens;
        ++pos, ++pos_offset, ++generate_pos) {
-    Transformer(token, pos, c_weights, activations, kv_cache, pool, inner_pool);
+    Transformer(token, pos, weights, activations, kv_cache, pool, inner_pool);
    float* final_activation = activations.x.data();
    // The condition below is always true if we are doing Prefill above.
    // We keep it here for clarity so that the code is correct even if Prefill
@@ -680,8 +775,8 @@ void GenerateImpl(GemmaImpl<TConfig>& gemma, size_t max_tokens,
    if (pos_offset >= prompt_size - 1) {
      PROFILER_ZONE("Gen.Embedding");
      // Generation phase
-      MatVec<kVocabSize, TConfig::kModelDim>(
-          c_weights.c_embedder_input_embedding, 0, final_activation,
+      MatVec<kVocabSize, TConfig::kModelDim>(weights.embedder_input_embedding,
+                                             0, final_activation,
          activations.logits.data(), pool);
      // Barrier: must have all logits so we can subtract max.
      Softmax(activations.logits.data(), kVocabSize);
@@ -743,52 +838,37 @@ void ForEachTensor(const Weights<TConfig>* weights,
                    CompressedWeights<TConfig>& c_weights, Func& func) {
   func("c_embedding",
        weights ? weights->embedder_input_embedding.data() : nullptr,
-       c_weights.c_embedder_input_embedding);
+       c_weights.embedder_input_embedding);
   func("c_final_norm", weights ? weights->final_norm_scale.data() : nullptr,
-       c_weights.c_final_norm_scale);
-  char name[16];
-  for (int layer_idx = 0; layer_idx < static_cast<int>(TConfig::kLayers);
-       ++layer_idx) {
+       c_weights.final_norm_scale);
+  char name_buf[16];
+  for (int layer_idx = 0; layer_idx < TConfig::kLayers; ++layer_idx) {
     const size_t idx = static_cast<size_t>(layer_idx);
-    Layer<TConfig>* layer = weights ? &weights->layers[idx] : nullptr;
-    CompressedLayer<TConfig>* c_layer = c_weights.CLayer(idx);
-    snprintf(name, sizeof(name), "pre_ff_ns_%d", layer_idx);
-    func(name, layer ? layer->pre_ffw_norm_scale.data() : nullptr,
-         c_layer->c_pre_ffw_norm_scale);
-    snprintf(name, sizeof(name), "gating_ein_%d", layer_idx);
-    func(name, layer ? layer->gating_einsum_w.data() : nullptr,
-         c_layer->c_gating_einsum_w);
-    snprintf(name, sizeof(name), "linear_w_%d", layer_idx);
-    func(name, layer ? layer->linear_w.data() : nullptr, c_layer->c_linear_w);
-    snprintf(name, sizeof(name), "qkv_ein_%d", layer_idx);
-    func(name, layer ? layer->qkv_einsum_w.data() : nullptr,
-         c_layer->c_qkv_einsum_w);
-    snprintf(name, sizeof(name), "att_ein_%d", layer_idx);
-    func(name, layer ? layer->attn_vec_einsum_w.data() : nullptr,
-         c_layer->c_attn_vec_einsum_w);
-    snprintf(name, sizeof(name), "pre_att_ns_%d", layer_idx);
-    func(name, layer ? layer->pre_attention_norm_scale.data() : nullptr,
-         c_layer->c_pre_attention_norm_scale);
+    const Layer<TConfig>* layer = weights ? weights->GetLayer(idx) : nullptr;
+    CompressedLayer<TConfig>* layer_weights = c_weights.GetLayer(idx);
+#define CALL_FUNC(name, member) \
+  snprintf(name_buf, sizeof(name_buf), name "_%d", layer_idx); \
+  func(name_buf, layer ? layer->member.data() : nullptr, layer_weights->member)
+    CALL_FUNC("pre_ff_ns", pre_ffw_norm_scale);
+    CALL_FUNC("gating_ein", gating_einsum_w);
+    CALL_FUNC("linear_w", linear_w);
+    CALL_FUNC("qkv_ein", qkv_einsum_w);
+    CALL_FUNC("att_ein", attn_vec_einsum_w);
+    CALL_FUNC("pre_att_ns", pre_attention_norm_scale);
+#undef CALL_FUNC
   }
 }
 template <class TConfig>
-hwy::AlignedFreeUniquePtr<uint8_t[]> GetCompressedWeights(
-    const Path& weights_path, const Path& cache, hwy::ThreadPool& pool) {
+hwy::AlignedFreeUniquePtr<uint8_t[]> LoadCompressedWeights(
+    const Path& weights, hwy::ThreadPool& pool) {
   PROFILER_ZONE("Startup.LoadCache");
-  if (!std::filesystem::exists(weights_path.path) &&
-      !std::filesystem::exists(cache.path)) {
-    HWY_ABORT(
-        "Either the model weights (--weights) or cached compressed weights "
-        "(--compressed_weights) must exist.");
+  if (!std::filesystem::exists(weights.path)) {
+    HWY_ABORT("The model weights file '%s' does not exist.",
+              weights.path.c_str());
   }
   // Allocate compressed weights.
@@ -798,32 +878,49 @@ hwy::AlignedFreeUniquePtr<uint8_t[]> GetCompressedWeights(
   CWeights* c_weights = reinterpret_cast<CWeights*>(c_weights_u8.get());
   new (&c_weights->c_layer_ptrs) CompressedLayerPointers<TConfig>(pool);
-  // First attempt to load them from cache, without requiring weights.
-  CacheLoader loader(cache.path.c_str());
+  std::array<float, TConfig::kNumTensorScales> scales;
+  CacheLoader loader(weights.path.c_str());
   ForEachTensor<TConfig>(nullptr, *c_weights, loader);
-  if (loader.ReadAll(pool)) return c_weights_u8;
-  // Get weights, compress, and store in cache.
-  const hwy::AlignedUniquePtr<Weights<TConfig>> weights =
-      LoadWeights<TConfig>(weights_path);
-  Compressor compressor(pool);
-  ForEachTensor<TConfig>(weights.get(), *c_weights, compressor);
-  compressor.WriteAll(pool, cache.path.c_str());
+  loader.LoadScales(scales.data(), scales.size());
+  if (!loader.ReadAll(pool)) {
+    HWY_ABORT("Failed to load model weights.");
+  }
+  if (TConfig::kNumTensorScales > 0) {
+    size_t scale_pos = 0;
+    for (int layer_idx = 0; layer_idx < TConfig::kLayers; ++layer_idx) {
+      const size_t idx = static_cast<size_t>(layer_idx);
+      CompressedLayer<TConfig>* layer_weights = c_weights->GetLayer(idx);
+      layer_weights->attn_vec_einsum_w.set_scale(scales[scale_pos++]);
+      layer_weights->qkv_einsum_w.set_scale(scales[scale_pos++]);
+      layer_weights->gating_einsum_w.set_scale(scales[scale_pos++]);
+      layer_weights->linear_w.set_scale(scales[scale_pos++]);
+    }
+    HWY_ASSERT(scale_pos == TConfig::kNumTensorScales);
+  }
   return c_weights_u8;
 }
 // Type-erased because this function is called via a function pointer.
-hwy::AlignedFreeUniquePtr<uint8_t[]> GetCompressedWeightsT(
-    gcpp::Model model, const Path& weights, const Path& compressed_weights,
+hwy::AlignedFreeUniquePtr<uint8_t[]> LoadCompressedWeightsT(
+    gcpp::Model model, const Path& weights, hwy::ThreadPool& pool) {
+  switch (model) {
+    case Model::GEMMA_2B:
+      return LoadCompressedWeights<ConfigGemma2B>(weights, pool);
+    case Model::GEMMA_7B:
+      return LoadCompressedWeights<ConfigGemma7B>(weights, pool);
+    default:
+      HWY_ABORT("Model type %d unknown.", static_cast<int>(model));
+  }
+}
+hwy::AlignedFreeUniquePtr<uint8_t[]> LoadWeightsT(gcpp::Model model,
+                                                  const Path& weights,
                                                   hwy::ThreadPool& pool) {
   switch (model) {
     case Model::GEMMA_2B:
-      return GetCompressedWeights<ConfigGemma2B>(weights, compressed_weights,
-                                                 pool);
+      return LoadWeights<ConfigGemma2B>(weights, pool);
     case Model::GEMMA_7B:
-      return GetCompressedWeights<ConfigGemma7B>(weights, compressed_weights,
-                                                 pool);
+      return LoadWeights<ConfigGemma7B>(weights, pool);
     default:
       HWY_ABORT("Model type %d unknown.", static_cast<int>(model));
   }
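Note: the "scales" blob is consumed four entries per layer, in the same order as the SCALE_WEIGHTS calls in LoadWeights: attn_vec_einsum_w, qkv_einsum_w, gating_einsum_w, linear_w. A sketch of the implied invariant (my restatement, not code from the PR):

    template <class TConfig>
    constexpr bool ScalesCountConsistent() {
      // Either scales are disabled, or there is one per scaled tensor per layer.
      return TConfig::kNumTensorScales == 0 ||
             TConfig::kNumTensorScales == 4 * TConfig::kLayers;
    }
    static_assert(ScalesCountConsistent<ConfigGemma2B>(), "see LoadCompressedWeights");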
@@ -846,18 +943,22 @@ void CompressWeights(const Path& weights_path,
   new (&c_weights->c_layer_ptrs) CompressedLayerPointers<TConfig>(pool);
   // Get weights, compress, and store.
-  const hwy::AlignedUniquePtr<Weights<TConfig>> weights =
-      LoadWeights<TConfig>(weights_path);
+  const bool scale_for_compression = TConfig::kNumTensorScales > 0;
+  const hwy::AlignedFreeUniquePtr<uint8_t[]> weights_u8 =
+      LoadWeights<TConfig>(weights_path, pool, scale_for_compression);
+  Weights<TConfig>* weights =
+      reinterpret_cast<Weights<TConfig>*>(weights_u8.get());
   Compressor compressor(pool);
-  ForEachTensor<TConfig>(weights.get(), *c_weights, compressor);
+  ForEachTensor<TConfig>(weights, *c_weights, compressor);
+  compressor.AddScales(weights->scales.data(), weights->scales.size());
   compressor.WriteAll(pool, compressed_weights_path.path.c_str());
+  weights->layer_ptrs.~LayerPointers<TConfig>();
   c_weights->c_layer_ptrs.~CompressedLayerPointers<TConfig>();
 }
 void CompressWeightsT(gcpp::Model model, const Path& weights,
-                      const Path& compressed_weights,
-                      hwy::ThreadPool& pool) {
+                      const Path& compressed_weights, hwy::ThreadPool& pool) {
   switch (model) {
     case Model::GEMMA_2B:
       CompressWeights<ConfigGemma2B>(weights, compressed_weights, pool);
@@ -877,7 +978,8 @@ HWY_AFTER_NAMESPACE();
 #if HWY_ONCE
 namespace gcpp {
-HWY_EXPORT(GetCompressedWeightsT);
+HWY_EXPORT(LoadCompressedWeightsT);
+HWY_EXPORT(LoadWeightsT);
 HWY_EXPORT(CompressWeightsT);
 HWY_EXPORT(Generate2B);
 HWY_EXPORT(Generate7B);
@@ -892,10 +994,9 @@ KVCache CreateKVCache(size_t size_cache_pos, size_t seq_len) {
 template <class Config>
 GemmaImpl<Config>::GemmaImpl(
     std::unique_ptr<sentencepiece::SentencePieceProcessor>& tokenizer,
-    hwy::AlignedFreeUniquePtr<uint8_t[]>& compressed_weights,
-    hwy::ThreadPool& pool)
+    hwy::AlignedFreeUniquePtr<uint8_t[]>& weights_u8, hwy::ThreadPool& pool)
    : tokenizer(std::move(tokenizer)),
-      compressed_weights(std::move(compressed_weights)),
+      weights_u8(std::move(weights_u8)),
      prefill(hwy::MakeUniqueAligned<Activations<Config, kPrefillBatchSize>>()),
      state(hwy::MakeUniqueAligned<Activations<Config, 1>>()) {}
@@ -922,10 +1023,8 @@ void GemmaImpl<ConfigGemma7B>::Generate(
   kv_cache, pool, inner_pool, stream_token, accept_token, gen, verbosity);
 }
-Gemma::Gemma(const Path& tokenizer_path, const Path& compressed_weights_path,
-             const Path& weights_path, Model model_type, ModelTraining training,
-             hwy::ThreadPool& pool)
-    : model_training(training) {
+Gemma::Gemma(const Path& tokenizer_path, const Path& weights, Model model_type,
+             hwy::ThreadPool& pool) {
   std::unique_ptr<sentencepiece::SentencePieceProcessor> tokenizer;
   {
     PROFILER_ZONE("Startup.tokenizer");
@ -934,16 +1033,21 @@ Gemma::Gemma(const Path& tokenizer_path, const Path& compressed_weights_path,
HWY_ABORT("Failed to load the tokenizer file."); HWY_ABORT("Failed to load the tokenizer file.");
} }
} }
auto compressed_weights = HWY_DYNAMIC_DISPATCH(GetCompressedWeightsT)(
model_type, weights_path, compressed_weights_path, pool); hwy::AlignedFreeUniquePtr<uint8_t[]> weights_u8;
if constexpr (kWeightsAreCompressed) {
weights_u8 =
HWY_DYNAMIC_DISPATCH(LoadCompressedWeightsT)(model_type, weights, pool);
} else {
weights_u8 = HWY_DYNAMIC_DISPATCH(LoadWeightsT)(model_type, weights, pool);
}
switch (model_type) { switch (model_type) {
case Model::GEMMA_2B: case Model::GEMMA_2B:
impl_.reset( impl_.reset(new GemmaImpl<ConfigGemma2B>(tokenizer, weights_u8, pool));
new GemmaImpl<ConfigGemma2B>(tokenizer, compressed_weights, pool));
break; break;
case Model::GEMMA_7B: case Model::GEMMA_7B:
impl_.reset( impl_.reset(new GemmaImpl<ConfigGemma7B>(tokenizer, weights_u8, pool));
new GemmaImpl<ConfigGemma7B>(tokenizer, compressed_weights, pool));
break; break;
default: default:
HWY_ABORT("Model type %d unknown.", static_cast<int>(model_type)); HWY_ABORT("Model type %d unknown.", static_cast<int>(model_type));
@ -981,10 +1085,9 @@ void GenerateGemma(Gemma& gemma, RuntimeConfig runtime_config,
} }
void CompressWeights(gcpp::Model model, const Path& weights, void CompressWeights(gcpp::Model model, const Path& weights,
const Path& compressed_weights, const Path& compressed_weights, hwy::ThreadPool& pool) {
hwy::ThreadPool& pool) { HWY_DYNAMIC_DISPATCH(CompressWeightsT)
HWY_DYNAMIC_DISPATCH(CompressWeightsT)( (model, weights, compressed_weights, pool);
model, weights, compressed_weights, pool);
} }
} // namespace gcpp } // namespace gcpp
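For completeness, a hedged sketch of calling the CompressWeights wrapper directly; in practice the compress_weights tool drives this, and the file names below are placeholders:

    // copybara:import_next_line:gemma_cpp
    #include "gemma.h"
    #include "hwy/contrib/thread_pool/thread_pool.h"

    int main() {
      gcpp::Path raw_weights, sbs_weights;
      raw_weights.path = "2b-it.raw";      // placeholder: uncompressed export
      sbs_weights.path = "2b-it-sfp.sbs";  // placeholder: compressed output
      hwy::ThreadPool pool(4);             // worker threads; tune as needed
      gcpp::CompressWeights(gcpp::Model::GEMMA_2B, raw_weights, sbs_weights, pool);
      return 0;
    }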

19
gemma.h

@ -24,22 +24,18 @@
// copybara:import_next_line:gemma_cpp // copybara:import_next_line:gemma_cpp
#include "compression/compress.h" // SfpStream/NuqStream #include "compression/compress.h" // SfpStream/NuqStream
// copybara:import_next_line:gemma_cpp // copybara:import_next_line:gemma_cpp
#include "util/args.h" // Path #include "configs.h"
#include "hwy/aligned_allocator.h" #include "hwy/aligned_allocator.h"
#include "hwy/base.h" // hwy::bfloat16_t #include "hwy/base.h" // hwy::bfloat16_t
#include "hwy/contrib/thread_pool/thread_pool.h" #include "hwy/contrib/thread_pool/thread_pool.h"
// copybara:import_next_line:gemma_cpp
#include "util/args.h" // Path
// copybara:import_next_line:sentencepiece // copybara:import_next_line:sentencepiece
#include "src/sentencepiece_processor.h" #include "src/sentencepiece_processor.h"
namespace gcpp { namespace gcpp {
// Allowable types for GEMMA_WEIGHT_T (can be specified at compilation time): using GemmaWeightT = GEMMA_WEIGHT_T;
// float, hwy::bfloat16_t, SfpStream, NuqStream
#ifndef GEMMA_WEIGHT_T
#define GEMMA_WEIGHT_T SfpStream
#endif // !GEMMA_WEIGHT_T
using WeightT = GEMMA_WEIGHT_T;
using EmbedderInputT = hwy::bfloat16_t; using EmbedderInputT = hwy::bfloat16_t;
constexpr size_t kPrefillBatchSize = 16; constexpr size_t kPrefillBatchSize = 16;
constexpr bool kSystemPrompt = false; constexpr bool kSystemPrompt = false;
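The #ifndef default for GEMMA_WEIGHT_T has moved out of this header; configs.h is now expected to supply it, and (per the removed comment) float, hwy::bfloat16_t, SfpStream, or NuqStream can still be chosen at compile time, e.g. via -DGEMMA_WEIGHT_T=hwy::bfloat16_t. A small sketch of how the alias resolves under such an override; the fallback define here is an assumption for the sake of a self-contained example:

    #include <cstdio>
    #include "hwy/base.h"  // hwy::bfloat16_t

    #ifndef GEMMA_WEIGHT_T
    #define GEMMA_WEIGHT_T hwy::bfloat16_t  // assumed fallback; the real default lives in configs.h
    #endif

    using GemmaWeightT = GEMMA_WEIGHT_T;

    int main() {
      std::printf("sizeof(GemmaWeightT) = %zu bytes\n", sizeof(GemmaWeightT));
      return 0;
    }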
@ -65,13 +61,11 @@ struct RuntimeConfig {
struct GemmaInterface; struct GemmaInterface;
struct Gemma { struct Gemma {
Gemma(const Path& tokenizer_path, const Path& compressed_weights_path, Gemma(const Path& tokenizer_path, const Path& weights, Model model_type,
const Path& weights_path, Model model_type, ModelTraining training,
hwy::ThreadPool& pool); hwy::ThreadPool& pool);
~Gemma(); // must be defined after GemmaInterface's dtor is defined. ~Gemma(); // must be defined after GemmaInterface's dtor is defined.
const sentencepiece::SentencePieceProcessor* Tokenizer() const; const sentencepiece::SentencePieceProcessor* Tokenizer() const;
std::unique_ptr<GemmaInterface> impl_; std::unique_ptr<GemmaInterface> impl_;
gcpp::ModelTraining model_training;
}; };
KVCache CreateKVCache(Model type); // convenient workaround for now KVCache CreateKVCache(Model type); // convenient workaround for now
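A hedged usage sketch of the slimmed-down constructor; file names and the thread count are placeholders, and since ModelTraining is no longer stored on Gemma, instruction-tuning concerns stay with the caller:

    // copybara:import_next_line:gemma_cpp
    #include "gemma.h"
    #include "hwy/contrib/thread_pool/thread_pool.h"

    int main() {
      gcpp::Path tokenizer_path, weights_path;
      tokenizer_path.path = "tokenizer.spm";  // placeholder
      weights_path.path = "2b-it-sfp.sbs";    // placeholder (.sbs = compressed weights)

      hwy::ThreadPool pool(4);
      gcpp::Gemma model(tokenizer_path, weights_path, gcpp::Model::GEMMA_2B, pool);
      gcpp::KVCache kv_cache = gcpp::CreateKVCache(gcpp::Model::GEMMA_2B);
      (void)kv_cache;  // pass to GenerateGemma along with the model
      return 0;
    }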
@ -99,8 +93,7 @@ void GenerateGemma(Gemma& gemma, RuntimeConfig runtime_config,
const StreamFunc& stream_token, std::mt19937& gen); const StreamFunc& stream_token, std::mt19937& gen);
void CompressWeights(gcpp::Model model, const Path& weights, void CompressWeights(gcpp::Model model, const Path& weights,
const Path& compressed_weights, const Path& compressed_weights, hwy::ThreadPool& pool);
hwy::ThreadPool& pool);
constexpr int EOS_ID = 1; constexpr int EOS_ID = 1;


@ -369,6 +369,7 @@ CompressedArray<float, kOuter * kInner> GenerateMat(size_t offset) {
} }
} }
Compress(content, ws, mat, pool); Compress(content, ws, mat, pool);
mat.set_scale(1.0f);
return mat; return mat;
} }
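The added set_scale(1.0f) attaches a per-tensor scale to the generated test matrix. A toy, standalone illustration of why compressed weights carry such a scale; this is not the library's codec, only the idea of scaling before quantization and rescaling on read:

    #include <cmath>
    #include <cstdint>
    #include <cstdio>

    int main() {
      const float scale = 1.0f;    // what set_scale(1.0f) records above
      const float original = 0.37f;
      const int8_t quantized =
          static_cast<int8_t>(std::lround(original / scale * 127.0f));
      const float restored = quantized / 127.0f * scale;
      std::printf("original=%.3f restored=%.3f\n", original, restored);
      return 0;
    }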

23
run.cc

@ -29,14 +29,14 @@
#include "gemma.h" // Gemma #include "gemma.h" // Gemma
// copybara:import_next_line:gemma_cpp // copybara:import_next_line:gemma_cpp
#include "util/app.h" #include "util/app.h"
// copybara:import_next_line:gemma_cpp
#include "util/args.h" // HasHelp
#include "hwy/base.h" #include "hwy/base.h"
#include "hwy/contrib/thread_pool/thread_pool.h" #include "hwy/contrib/thread_pool/thread_pool.h"
#include "hwy/highway.h" #include "hwy/highway.h"
#include "hwy/per_target.h" #include "hwy/per_target.h"
#include "hwy/profiler.h" #include "hwy/profiler.h"
#include "hwy/timer.h" #include "hwy/timer.h"
// copybara:import_next_line:gemma_cpp
#include "util/args.h" // HasHelp
namespace gcpp { namespace gcpp {
@ -66,7 +66,7 @@ void ShowConfig(LoaderArgs& loader, InferenceArgs& inference, AppArgs& app) {
<< hwy::VectorBytes() * 8 << " bits)" << "\n" << hwy::VectorBytes() * 8 << " bits)" << "\n"
<< "Compiled config : " << CompiledConfig() << "\n" << "Compiled config : " << CompiledConfig() << "\n"
<< "Weight Type : " << "Weight Type : "
<< gcpp::TypeName(gcpp::WeightT()) << "\n" << gcpp::TypeName(gcpp::GemmaWeightT()) << "\n"
<< "EmbedderInput Type : " << "EmbedderInput Type : "
<< gcpp::TypeName(gcpp::EmbedderInputT()) << "\n"; << gcpp::TypeName(gcpp::EmbedderInputT()) << "\n";
} }
@ -93,10 +93,11 @@ void ShowHelp(gcpp::LoaderArgs& loader, gcpp::InferenceArgs& inference,
std::cerr << "\n"; std::cerr << "\n";
} }
void ReplGemma(gcpp::Gemma& model, gcpp::KVCache& kv_cache, void ReplGemma(gcpp::Gemma& model, ModelTraining training,
hwy::ThreadPool& pool, hwy::ThreadPool& inner_pool, gcpp::KVCache& kv_cache, hwy::ThreadPool& pool,
const InferenceArgs& args, int verbosity, hwy::ThreadPool& inner_pool, const InferenceArgs& args,
const gcpp::AcceptFunc& accept_token, std::string& eot_line) { int verbosity, const gcpp::AcceptFunc& accept_token,
std::string& eot_line) {
PROFILER_ZONE("Gen.misc"); PROFILER_ZONE("Gen.misc");
int abs_pos = 0; // absolute token index over all turns int abs_pos = 0; // absolute token index over all turns
int current_pos = 0; // token index within the current turn int current_pos = 0; // token index within the current turn
@ -177,7 +178,7 @@ void ReplGemma(gcpp::Gemma& model, gcpp::KVCache& kv_cache,
continue; continue;
} }
if (model.model_training == ModelTraining::GEMMA_IT) { if (training == ModelTraining::GEMMA_IT) {
// For instruction-tuned models: add control tokens. // For instruction-tuned models: add control tokens.
prompt_string = "<start_of_turn>user\n" + prompt_string + prompt_string = "<start_of_turn>user\n" + prompt_string +
"<end_of_turn>\n<start_of_turn>model\n"; "<end_of_turn>\n<start_of_turn>model\n";
@ -232,8 +233,7 @@ void Run(LoaderArgs& loader, InferenceArgs& inference, AppArgs& app) {
[](uint64_t /*task*/, size_t thread) { PinThreadToCore(thread); }); [](uint64_t /*task*/, size_t thread) { PinThreadToCore(thread); });
} }
gcpp::Gemma model(loader.tokenizer, loader.compressed_weights, loader.weights, gcpp::Gemma model(loader.tokenizer, loader.weights, loader.ModelType(), pool);
loader.ModelType(), loader.ModelTraining(), pool);
auto kv_cache = CreateKVCache(loader.ModelType()); auto kv_cache = CreateKVCache(loader.ModelType());
@ -265,7 +265,8 @@ void Run(LoaderArgs& loader, InferenceArgs& inference, AppArgs& app) {
} }
ReplGemma( ReplGemma(
model, kv_cache, pool, inner_pool, inference, app.verbosity, model, loader.ModelTraining(), kv_cache, pool, inner_pool, inference,
app.verbosity,
/*accept_token=*/[](int) { return true; }, app.eot_line); /*accept_token=*/[](int) { return true; }, app.eot_line);
} }


@ -36,9 +36,9 @@
#include "configs.h" #include "configs.h"
// copybara:import_next_line:gemma_cpp // copybara:import_next_line:gemma_cpp
#include "gemma.h" #include "gemma.h"
#include "hwy/base.h" // HWY_ASSERT
// copybara:import_next_line:gemma_cpp // copybara:import_next_line:gemma_cpp
#include "util/args.h" #include "util/args.h"
#include "hwy/base.h" // HWY_ASSERT
namespace gcpp { namespace gcpp {
@ -151,7 +151,7 @@ struct LoaderArgs : public ArgsBase<LoaderArgs> {
} }
// Returns error string or nullptr if OK. // Returns error string or nullptr if OK.
const char* Validate() const { const char* Validate() {
const std::string model_type_lc = ToLower(model_type); const std::string model_type_lc = ToLower(model_type);
if (model_type.empty()) { if (model_type.empty()) {
return "Missing --model flag, need to specify either 2b-pt, 7b-pt, " return "Missing --model flag, need to specify either 2b-pt, 7b-pt, "
@ -165,37 +165,42 @@ struct LoaderArgs : public ArgsBase<LoaderArgs> {
if (tokenizer.path.empty()) { if (tokenizer.path.empty()) {
return "Missing --tokenizer flag, a file for the tokenizer is required."; return "Missing --tokenizer flag, a file for the tokenizer is required.";
} }
if (compressed_weights.path.empty()) { if (!compressed_weights.path.empty()) {
return "Missing --compressed_weights flag, a file for the compressed " if (weights.path.empty()) {
"model."; weights = compressed_weights;
} else {
return "Only one of --weights and --compressed_weights can be "
"specified. To create compressed weights use the compress_weights "
"tool.";
}
}
if (weights.path.empty()) {
return "Missing --weights flag, a file for the model weights.";
}
if (!weights.exists()) {
return "Can't open file specified with --weights flag.";
} }
return nullptr; return nullptr;
} }
Path tokenizer; Path tokenizer;
Path weights; // uncompressed weights file location Path weights; // weights file location
Path compressed_weights; // compressed weights file location Path compressed_weights;
std::string model_type; std::string model_type;
template <class Visitor> template <class Visitor>
void ForEach(const Visitor& visitor) { void ForEach(const Visitor& visitor) {
visitor(tokenizer, "tokenizer", Path(), visitor(tokenizer, "tokenizer", Path(),
"Path name of tokenizer model file.\n Required argument."); "Path name of tokenizer model file.\n Required argument.");
visitor( visitor(weights, "weights", Path(),
compressed_weights, "compressed_weights", Path(), "Path name of model weights (.sbs) file.\n Required argument.");
"Path name of compressed weights file, regenerated from `--weights` " visitor(compressed_weights, "compressed_weights", Path(),
"file if " "Alias for --weights.");
"the compressed weights file does not exist.\n Required argument.");
visitor(model_type, "model", std::string(), visitor(model_type, "model", std::string(),
"Model type\n 2b-it = 2B parameters, instruction-tuned\n " "Model type\n 2b-it = 2B parameters, instruction-tuned\n "
"2b-pt = 2B parameters, pretrained\n 7b-it = 7B parameters " "2b-pt = 2B parameters, pretrained\n 7b-it = 7B parameters "
"instruction-tuned\n 7b-pt = 7B parameters, pretrained\n" "instruction-tuned\n 7b-pt = 7B parameters, pretrained\n"
" Required argument."); " Required argument.");
visitor(weights, "weights", Path(),
"Path name of model weights (.sbs) file. Only required if "
"compressed_weights file is not present and needs to be "
"regenerated. This parameter is only required for compressing "
"new model weight exports, otherwise it is not needed.");
} }
}; };
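To summarize the new flag handling: --compressed_weights is accepted as an alias for --weights, while passing both is rejected, so either --weights 2b-it-sfp.sbs or --compressed_weights 2b-it-sfp.sbs selects the same file. A simplified, standalone restatement of the Validate() logic above; the real code operates on gcpp::Path and returns the exact error strings shown in the diff:

    #include <cstdio>
    #include <string>

    const char* ResolveWeights(std::string& weights, std::string& compressed_weights) {
      if (!compressed_weights.empty()) {
        if (weights.empty()) {
          weights = compressed_weights;  // --compressed_weights acts as an alias
        } else {
          return "Only one of --weights and --compressed_weights can be specified.";
        }
      }
      if (weights.empty()) return "Missing --weights flag.";
      return nullptr;  // OK
    }

    int main() {
      std::string weights, compressed_weights = "2b-it-sfp.sbs";
      const char* err = ResolveWeights(weights, compressed_weights);
      std::printf("weights=%s error=%s\n", weights.c_str(), err ? err : "none");
      return 0;
    }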