mirror of https://github.com/google/gemma.cpp.git
Improve weight handling.
- Allow scaling of SFP weights
- Allow using uncompressed weights
- Do not try to compress weights in the main model calls
- Reduce code duplication in weight handling with some macros

Co-authored-by: Eugene Kliuchnikov <eustas@google.com>
Co-authored-by: Thomas Fischbacher <tfish@google.com>
Co-authored-by: Zoltan Szabadka <szabadka@google.com>
This commit is contained in:
parent 280b8cb8a1
commit 4c23932289
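For readers skimming the diff below, here is a minimal, self-contained sketch of the per-tensor scaling scheme this commit introduces. `ScaleWeights` is copied from the new gemma.cc code; `ScaledDot` is a hypothetical plain-float stand-in for the real `CompressedArray::scale()` / `Dot()` path (which lives in compression/compress.h and compress-inl.h) and is not part of the commit.

```cpp
#include <algorithm>
#include <cmath>
#include <cstddef>
#include <vector>

// SfpStream can only represent values in roughly [-1.875, 1.875], so a tensor
// with larger magnitude is divided by `scale` before compression; the scale is
// stored next to the tensor and folded back in when computing dot products.
float ScaleWeights(float* data, size_t len) {
  float maxabs = 0.0f;
  for (size_t i = 0; i < len; ++i) maxabs = std::max(maxabs, std::abs(data[i]));
  const float kMaxRange = 1.875f;
  if (maxabs <= kMaxRange) return 1.0f;
  const float scale = maxabs / kMaxRange;
  const float inv_scale = 1.0f / scale;
  for (size_t i = 0; i < len; ++i) data[i] *= inv_scale;
  return scale;
}

// Hypothetical illustration of how the stored scale is consumed: compute the
// dot product against the rescaled weights, then multiply by the scale. This
// mirrors Dot() returning compressed.scale() * Traits::Dot(...) in the diff.
float ScaledDot(const std::vector<float>& weights_rescaled, float scale,
                const std::vector<float>& vec) {
  float dot = 0.0f;
  for (size_t i = 0; i < vec.size(); ++i) dot += weights_rescaled[i] * vec[i];
  return scale * dot;
}
```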
@@ -29,13 +29,13 @@
 // copybara:import_next_line:gemma_cpp
 #include "compression/blob_store.h"
 
+#include <fcntl.h>     // open
 #include <stdint.h>
 #include <stdio.h>     // SEEK_END - unistd isn't enough for IDE.
 #include <sys/stat.h>  // O_RDONLY
-#include <fcntl.h>     // open
 
 #if HWY_OS_WIN
-#include <io.h>  // read, write, close
 #include <fileapi.h>
+#include <io.h>  // read, write, close
 #else
 #include <unistd.h>  // read, write, close
 #endif
@@ -113,7 +113,8 @@ hwy::uint128_t MakeKey(const char* string) {
   return ret;
 }
 
-static void EnqueueChunkRequests(uint64_t offset, uint64_t size, uint8_t* data,
+namespace {
+void EnqueueChunkRequests(uint64_t offset, uint64_t size, uint8_t* data,
                           std::vector<BlobIO>& requests) {
   // Split into chunks for load-balancing even if blob sizes vary.
   constexpr size_t kChunkSize = 4 * 1024 * 1024;
@@ -129,7 +130,7 @@ static void EnqueueChunkRequests(uint64_t offset, uint64_t size, uint8_t* data,
     requests.emplace_back(offset + pos, size - pos, data + pos, 0);
   }
 }
-
+}  // namespace
 
 struct IO {
   // Returns size in bytes or 0.
@@ -197,12 +198,6 @@ static_assert(HWY_IS_LITTLE_ENDIAN, "Assumes little endian");
 class BlobStore {
   static constexpr uint32_t kMagic = 0x0A534253;  // SBS\n
 
-  // Blob offsets on disk and memory addresses are a multiple of this, because
-  // we pad the header and each blob's size. This matches CUDA alignment and the
-  // maximum SVE vector size, and exceeds typical x86 cache line sizes (64 or
-  // 128), which can help performance.
-  static constexpr size_t kAlign = 256;
-
  public:
   // NOT including padding, so that we can also use ZeroFillPadding after
   // copying the header.
@@ -215,13 +210,13 @@ class BlobStore {
   // blobs. Requires num_blobs_ to already be set, typically by reading
   // sizeof(BlobStore) bytes from disk.
   size_t PaddedHeaderSize() const {
-    return hwy::RoundUpTo(HeaderSize(num_blobs_), kAlign);
+    return hwy::RoundUpTo(HeaderSize(num_blobs_), kBlobAlign);
   }
 
   // Returns aligned offset and zero-fills between that and `offset`.
   uint64_t ZeroFillPadding(uint64_t offset) {
     uint8_t* const bytes = reinterpret_cast<uint8_t*>(this);
-    const uint64_t padded = hwy::RoundUpTo(offset, kAlign);
+    const uint64_t padded = hwy::RoundUpTo(offset, kBlobAlign);
     hwy::ZeroBytes(bytes + offset, padded - offset);
     return padded;
   }
@@ -236,7 +231,7 @@ class BlobStore {
     for (size_t i = 0; i < num_blobs_; ++i) {
       const hwy::uint128_t val = keys_[num_blobs_ + i];
       if (val.lo != offset) return __LINE__;
-      offset = ZeroFillPadding(offset + val.hi);
+      offset = hwy::RoundUpTo(offset + val.hi, kBlobAlign);
     }
 
     if (offset != file_size_) return __LINE__;
@@ -253,25 +248,24 @@ class BlobStore {
 
   static std::vector<BlobIO> PrepareWriteRequests(
       const hwy::uint128_t keys[], const hwy::Span<uint8_t> blobs[],
-      size_t num_blobs) {
+      size_t num_blobs, BlobStore* bs) {
     // Sanity check and ensure the cast below is safe.
     HWY_ASSERT(num_blobs < (1ULL << 20));
 
     // Allocate var-length header.
     const size_t header_size = HeaderSize(num_blobs);
-    const size_t padded_header_size = hwy::RoundUpTo(header_size, kAlign);
-    BlobStorePtr bs = Allocate(padded_header_size);
+    const size_t padded_header_size = hwy::RoundUpTo(header_size, kBlobAlign);
     const uint64_t padded_header_end = bs->ZeroFillPadding(header_size);
     HWY_ASSERT(padded_header_end == padded_header_size);
 
     // All-zero buffer used to write padding to the file without copying the
     // input blobs.
-    static uint8_t zeros[kAlign] = {0};
+    static uint8_t zeros[kBlobAlign] = {0};
 
     // Total file size will be the header plus all padded blobs.
     uint64_t payload = 0;
     for (size_t i = 0; i < num_blobs; ++i) {
-      payload += hwy::RoundUpTo(blobs[i].size(), kAlign);
+      payload += hwy::RoundUpTo(blobs[i].size(), kBlobAlign);
     }
     const size_t total_size = padded_header_size + payload;
 
@@ -285,7 +279,7 @@ class BlobStore {
     std::vector<BlobIO> requests;
     requests.reserve(1 + 2 * num_blobs);
     requests.emplace_back(/*offset=*/0, padded_header_size,
-                          reinterpret_cast<uint8_t*>(bs.get()), 0);
+                          reinterpret_cast<uint8_t*>(bs), 0);
 
     // Fill second half of keys_ with offset/size and prepare IO requests.
     uint64_t offset = padded_header_end;
@@ -295,10 +289,10 @@ class BlobStore {
 
       EnqueueChunkRequests(offset, blobs[i].size(), blobs[i].data(), requests);
       offset += blobs[i].size();
-      const size_t padded_size = hwy::RoundUpTo(blobs[i].size(), kAlign);
+      const size_t padded_size = hwy::RoundUpTo(blobs[i].size(), kBlobAlign);
       if (padded_size != blobs[i].size()) {
         const size_t padding = padded_size - blobs[i].size();
-        HWY_ASSERT(padding <= kAlign);
+        HWY_ASSERT(padding <= kBlobAlign);
         requests.emplace_back(offset, padding, zeros, 0);
         offset += padding;
       }
@@ -418,8 +412,11 @@ BlobError BlobWriter::WriteAll(hwy::ThreadPool& pool,
   HWY_ASSERT(keys_.size() == blobs_.size());
 
   // Concatenate blobs in memory.
+  const size_t header_size = BlobStore::HeaderSize(keys_.size());
+  const size_t padded_header_size = hwy::RoundUpTo(header_size, kBlobAlign);
+  BlobStorePtr bs = BlobStore::Allocate(padded_header_size);
   std::vector<BlobIO> requests = BlobStore::PrepareWriteRequests(
-      keys_.data(), blobs_.data(), keys_.size());
+      keys_.data(), blobs_.data(), keys_.size(), bs.get());
 
   // Create/replace existing file.
 #if HWY_OS_WIN
@@ -40,6 +40,12 @@ using BlobStorePtr = hwy::AlignedFreeUniquePtr<BlobStore>;
 // 0 if successful, otherwise the line number of the failing check.
 using BlobError = int;
 
+// Blob offsets on disk and memory addresses are a multiple of this, because
+// we pad the header and each blob's size. This matches CUDA alignment and the
+// maximum SVE vector size, and exceeds typical x86 cache line sizes (64 or
+// 128), which can help performance.
+static constexpr size_t kBlobAlign = 256;
+
 struct BlobIO {
   BlobIO(uint64_t offset, size_t size, void* data, uint64_t padding)
       : offset(offset), size(size), data(data), padding(padding) {}
@@ -381,13 +381,14 @@ HWY_INLINE void Compress(const std::array<float, kCapacity>& in,
 }
 
 // Decompresses `num` values from `compressed` starting at `compressed_ofs`.
-template <typename MatT, size_t kCapacity, typename OutT>
-HWY_NOINLINE void Decompress(const CompressedArray<MatT, kCapacity>& compressed,
-                             size_t compressed_ofs, OutT* out, size_t num) {
-  HWY_DASSERT(compressed_ofs + num <= compressed.NumElements());
+template <typename ArrayT, typename OutT>
+HWY_NOINLINE void Decompress(const ArrayT& compressed, size_t compressed_ofs,
+                             OutT* out, size_t num) {
+  HWY_DASSERT(compressed_ofs + num <= compressed.size());
   const hn::ScalableTag<OutT> d;
-  using Traits = CompressTraits<MatT>;
-  Traits::Decompress(d, kCapacity, compressed.data(), compressed_ofs, out, num);
+  using Traits = CompressTraits<typename ArrayT::value_type>;
+  Traits::Decompress(d, compressed.size(), compressed.data(), compressed_ofs,
+                     out, num);
 }
 
 // As above, but with threading and benchmarking.
@@ -395,7 +396,7 @@ template <typename MatT, size_t kCapacity, typename OutT>
 HWY_INLINE void Decompress(const CompressedArray<MatT, kCapacity>& compressed,
                            size_t compressed_ofs, OutT* out, size_t num,
                            hwy::ThreadPool& pool) {
-  HWY_DASSERT(compressed_ofs + num <= compressed.NumElements());
+  HWY_DASSERT(compressed_ofs + num <= compressed.size());
   const double t0 = hwy::platform::Now();
 
   using Traits = CompressTraits<MatT>;
@@ -407,7 +408,7 @@ HWY_INLINE void Decompress(const CompressedArray<MatT, kCapacity>& compressed,
 
     const size_t ofs = idx_batch * kBatch;
     const size_t num = idx_batch == num_batches - 1 ? (num - ofs) : kBatch;
-    Traits::Decompress(d, compressed.NumElements(), compressed.data(),
+    Traits::Decompress(d, compressed.size(), compressed.data(),
                        compressed_ofs + ofs, out + ofs, num);
   });
 
@@ -417,16 +418,28 @@ HWY_INLINE void Decompress(const CompressedArray<MatT, kCapacity>& compressed,
   fprintf(stderr, "Decompress %.1f MB/s\n", mbps);
 }
 
+// Returns dot product with `vec_aligned` of length `num`.
+template <class DF, typename ArrayT, typename VecT>
+HWY_INLINE float Dot(DF df, const ArrayT& compressed, size_t compressed_ofs,
+                     const VecT* vec_aligned, size_t num) {
+  HWY_DASSERT(compressed_ofs + num <= compressed.size());
+  HWY_DASSERT(hn::IsAligned(df, vec_aligned));
+  using Traits = CompressTraits<typename ArrayT::value_type>;
+  return Traits::Dot(df, compressed.size(), compressed.data(), compressed_ofs,
+                     vec_aligned, num);
+}
+
 // Returns dot product with `vec_aligned` of length `num`.
 template <class DF, typename MatT, size_t kCapacity, typename VecT>
 HWY_INLINE float Dot(DF df, const CompressedArray<MatT, kCapacity>& compressed,
                      size_t compressed_ofs, const VecT* vec_aligned,
                      size_t num) {
-  HWY_DASSERT(compressed_ofs + num <= compressed.NumElements());
+  HWY_DASSERT(compressed_ofs + num <= compressed.size());
   HWY_DASSERT(hn::IsAligned(df, vec_aligned));
   using Traits = CompressTraits<MatT>;
-  return Traits::Dot(df, kCapacity, compressed.data(), compressed_ofs,
-                     vec_aligned, num);
+  return (compressed.scale() * Traits::Dot(df, compressed.size(),
                                            compressed.data(), compressed_ofs,
+                                           vec_aligned, num));
 }
 
 // Callback used by ForeachTensor.
@@ -445,6 +458,12 @@ class Compressor {
                 compressed.CompressedSize());
   }
 
+  void AddScales(float* scales, size_t len) {
+    if (len) {
+      writer_.Add(CacheKey<float>("scales"), scales, len * sizeof(scales[0]));
+    }
+  }
+
   void WriteAll(hwy::ThreadPool& pool, const char* blob_filename) {
     const BlobError err = writer_.WriteAll(pool, blob_filename);
     if (err != 0) {
@@ -71,10 +71,15 @@ class CompressedArray {
   }
 
  public:
+  using value_type = MatT;
+
   MatT* data() { return data_.data(); }
   const MatT* data() const { return data_.data(); }
 
-  constexpr size_t NumElements() const { return kCapacity; }
+  float scale() const { return scale_[0]; }
+  void set_scale(float scale) { scale_[0] = scale; }
+
+  constexpr size_t size() const { return kCapacity; }
 
   constexpr size_t CompressedSize() const {
     return NumCompressed() * sizeof(MatT);
@@ -82,6 +87,7 @@ class CompressedArray {
 
  private:
   std::array<MatT, NumCompressed()> data_;
+  float scale_[kBlobAlign / sizeof(float)];
 };
 
 #if COMPRESS_STATS
@@ -187,11 +193,21 @@ class CacheLoader {
 
     err_ = reader_.Enqueue(CacheKey<MatT>(name), compressed.data(),
                            compressed.CompressedSize());
+    compressed.set_scale(1.0f);
     if (err_ != 0) {
       fprintf(stderr, "Failed to read cache %s (error %d)\n", name, err_);
     }
   }
 
+  void LoadScales(float* scales, size_t len) {
+    if (0 != reader_.Enqueue(CacheKey<float>("scales"), scales,
+                             len * sizeof(scales[0]))) {
+      for (size_t i = 0; i < len; ++i) {
+        scales[i] = 1.0f;
+      }
+    }
+  }
+
   // Returns whether all tensors are successfully loaded from cache.
   bool ReadAll(hwy::ThreadPool& pool) {
     // reader_ invalid or any Enqueue failed
configs.h (14 changes)

@@ -30,6 +30,16 @@
 
 #include <stddef.h>
 
+// copybara:import_next_line:gemma_cpp
+#include "compression/sfp.h"
+#include "hwy/base.h"  // hwy::bfloat16_t
+
+// Allowable types for GEMMA_WEIGHT_T (can be specified at compilation time):
+// float, hwy::bfloat16_t, SfpStream, NuqStream
+#ifndef GEMMA_WEIGHT_T
+#define GEMMA_WEIGHT_T SfpStream
+#endif  // !GEMMA_WEIGHT_T
+
 namespace gcpp {
 
 static constexpr size_t kSeqLen = GEMMA_MAX_SEQLEN;
@@ -45,6 +55,8 @@ struct ConfigGemma7B {
   static constexpr int kKVHeads = 16;  // standard MHA
   static constexpr int kQKVDim = 256;  // query size == key size == value size
   static constexpr int kTopK = gcpp::kTopK;
+  static constexpr int kNumTensorScales = 0;
+  using WeightT = GEMMA_WEIGHT_T;
 };
 
 struct ConfigGemma2B {
@@ -57,6 +69,8 @@ struct ConfigGemma2B {
   static constexpr int kKVHeads = 1;
   static constexpr int kQKVDim = 256;  // query size == key size == value size
   static constexpr int kTopK = gcpp::kTopK;
+  static constexpr int kNumTensorScales = 0;
+  using WeightT = GEMMA_WEIGHT_T;
 };
 
 }  // namespace gcpp
@@ -19,9 +19,9 @@
 #include "gemma.h"
 // copybara:import_next_line:gemma_cpp
 #include "util/app.h"  // LoaderArgs
+#include "hwy/contrib/thread_pool/thread_pool.h"
 // copybara:import_next_line:gemma_cpp
 #include "util/args.h"
-#include "hwy/contrib/thread_pool/thread_pool.h"
 
 std::vector<int> tokenize(
     const std::string& prompt_string,
@@ -43,8 +43,7 @@ int main(int argc, char** argv) {
   hwy::ThreadPool pool(num_threads);
 
   // Instantiate model and KV Cache
-  gcpp::Gemma model(loader.tokenizer, loader.compressed_weights,
-                    loader.ModelType(), pool);
+  gcpp::Gemma model(loader.tokenizer, loader.weights, loader.ModelType(), pool);
   auto kv_cache = CreateKVCache(loader.ModelType());
   size_t pos = 0;  // KV Cache position
gemma.cc (465 changes)

@@ -25,12 +25,12 @@
 #include "compression/compress-inl.h"
 // copybara:import_next_line:gemma_cpp
 #include "ops.h"
-// copybara:import_next_line:gemma_cpp
-#include "util/args.h"  // Path
 #include "hwy/contrib/matvec/matvec-inl.h"
 #include "hwy/highway.h"
 #include "hwy/profiler.h"
 #include "hwy/timer.h"
+// copybara:import_next_line:gemma_cpp
+#include "util/args.h"  // Path
 
 // Non-SIMD includes and types. Note that HWY_ONCE is only true on the last
 // compile pass, whereas we want this defined in the first.
@@ -64,6 +64,12 @@
 // copybara:import_next_line:sentencepiece
 #include "src/sentencepiece_processor.h"
 
+// Setting this to true disables fread() calls that read the model file.
+constexpr bool kDryRunFread = false;
+
+// Setting this to false will load and use uncompressed weights.
+constexpr bool kWeightsAreCompressed = true;
+
 namespace gcpp {
 
 template <class TConfig>
@@ -88,70 +94,145 @@ struct Layer {
   std::array<float, kModelDim> pre_ffw_norm_scale;
 };
 
+float ScaleWeights(float* data, size_t len) {
+  float maxabs = 0.0;
+  for (size_t i = 0; i < len; ++i) {
+    maxabs = std::max(maxabs, std::abs(data[i]));
+  }
+  const float kMaxRange = 1.875f;
+  if (maxabs <= kMaxRange) {
+    return 1.0f;
+  }
+  const float scale = maxabs / kMaxRange;
+  const float inv_scale = 1.0f / scale;
+  for (size_t i = 0; i < len; ++i) {
+    data[i] *= inv_scale;
+  }
+  return scale;
+}
+
+// Array instead of single large allocation for parallel mem init. Split out of
+// Weights so that only these pointers are initialized.
+template <class TConfig>
+struct LayerPointers {
+  explicit LayerPointers(hwy::ThreadPool& pool) {
+    pool.Run(0, TConfig::kLayers, [this](uint64_t task, size_t /*thread*/) {
+      this->layers[task] = hwy::AllocateAligned<Layer<TConfig>>(1);
+    });
+  }
+
+  using TLayer = Layer<TConfig>;
+  std::array<hwy::AlignedFreeUniquePtr<TLayer[]>, TConfig::kLayers> layers;
+};
+
 template <class TConfig>
 struct Weights {
-  Weights() = default;
-
-  hwy::AlignedUniquePtr<Layer<TConfig>[]> layers;  // kLayers
+  // No ctor/dtor, allocated via AllocateAligned.
 
   std::array<float, TConfig::kVocabSize * TConfig::kModelDim>
       embedder_input_embedding;
 
   std::array<float, TConfig::kModelDim> final_norm_scale;
+
+  LayerPointers<TConfig> layer_ptrs;
+
+  std::array<float, TConfig::kNumTensorScales> scales;
+
+  const Layer<TConfig>* GetLayer(size_t layer) const {
+    return layer_ptrs.layers[layer].get();
+  }
+  Layer<TConfig>* GetLayer(size_t layer) {
+    return layer_ptrs.layers[layer].get();
+  }
 };
 
-// Only called if cached loading fails.
 template <typename TConfig>
-hwy::AlignedUniquePtr<Weights<TConfig>> LoadWeights(const Path& checkpoint) {
+hwy::AlignedFreeUniquePtr<uint8_t[]> LoadWeights(
+    const Path& checkpoint, hwy::ThreadPool& pool,
+    bool scale_for_compression = false) {
   PROFILER_ZONE("Startup.LoadWeights");
-  using TWeights = Weights<TConfig>;
-  hwy::AlignedUniquePtr<TWeights> weights = hwy::MakeUniqueAligned<TWeights>();
-  weights->layers =
-      hwy::MakeUniqueAlignedArray<Layer<TConfig>>(TConfig::kLayers);
-
-  if (checkpoint.path.empty()) {
-    HWY_ABORT(
-        "Loading --compressed_weights failed; we require a --weights argument. "
-        "Please see issue #11 on how to create this file.\n");
+  if (!std::filesystem::exists(checkpoint.path)) {
+    HWY_ABORT("The model weights file '%s' does not exist.",
+              checkpoint.path.c_str());
   }
 
+  using TWeights = Weights<TConfig>;
+  hwy::AlignedFreeUniquePtr<uint8_t[]> weights_u8 =
+      hwy::AllocateAligned<uint8_t>(sizeof(TWeights));
+  TWeights* weights = reinterpret_cast<TWeights*>(weights_u8.get());
+  new (&weights->layer_ptrs) LayerPointers<TConfig>(pool);
+
+  size_t scale_pos = 0;
   FILE* fptr;
+  if constexpr (kDryRunFread) {
+    fprintf(stderr, "Dry-Run, not reading model-file.\n");
+  } else {
     fptr = fopen(checkpoint.path.c_str(), "rb");
     if (fptr == nullptr) {
       HWY_ABORT("Failed to open model file %s - does it exist?",
                 checkpoint.path.c_str());
     }
+  }
   bool ok = true;
   uint64_t total_size = 0;
-  ok &= 1 == fread(&(weights->embedder_input_embedding),
-                   sizeof(weights->embedder_input_embedding), 1, fptr);
-  ok &= 1 == fread(&(weights->final_norm_scale),
-                   sizeof(weights->final_norm_scale), 1, fptr);
-  total_size += sizeof(weights->embedder_input_embedding) +
-                sizeof(weights->final_norm_scale);
+  auto do_fread = [&](void* var, int layer, const char* name, size_t size) {
+    if (layer == -1) {
+      fprintf(stderr, "Loading Parameters (size %zu): %s\n", size, name);
+    } else {
+      fprintf(stderr, "Loading Parameters (layer=%d, size %zu): %s\n", layer,
+              size, name);
+    }
+    if constexpr (!kDryRunFread) {
+      ok &= 1 == fread(var, size, 1, fptr);
+      total_size += size;
+    }
+  };
+  do_fread(&(weights->embedder_input_embedding), -1, "embedder_input_embedding",
+           sizeof(weights->embedder_input_embedding));
+  do_fread(&(weights->final_norm_scale), -1, "final_norm_scale",
+           sizeof(weights->final_norm_scale));
  for (size_t layer = 0; layer < TConfig::kLayers; ++layer) {
-    Layer<TConfig>* layer_view = &weights->layers[layer];
-    ok &= 1 == fread(&layer_view->attn_vec_einsum_w,
-                     sizeof(layer_view->attn_vec_einsum_w), 1, fptr);
-    ok &= 1 == fread(&layer_view->qkv_einsum_w,
-                     sizeof(layer_view->qkv_einsum_w), 1, fptr);
-    ok &= 1 == fread(&layer_view->gating_einsum_w,
-                     sizeof(layer_view->gating_einsum_w), 1, fptr);
-    ok &= 1 ==
-          fread(&layer_view->linear_w, sizeof(layer_view->linear_w), 1, fptr);
-    ok &= 1 == fread(&layer_view->pre_attention_norm_scale,
-                     sizeof(layer_view->pre_attention_norm_scale), 1, fptr);
-    ok &= 1 == fread(&layer_view->pre_ffw_norm_scale,
-                     sizeof(layer_view->pre_ffw_norm_scale), 1, fptr);
-    total_size += sizeof(*layer_view);
+    Layer<TConfig>* layer_view = weights->GetLayer(layer);
+
+#define READ_WEIGHTS(name)                                                 \
+  do {                                                                     \
+    do_fread(&(layer_view->name), layer, #name, sizeof(layer_view->name)); \
+  } while (0)
+
+#define SCALE_WEIGHTS(name)                                                \
+  do {                                                                     \
+    if (ok && !kDryRunFread && scale_for_compression) {                    \
+      weights->scales[scale_pos++] =                                       \
+          ScaleWeights(layer_view->name.data(), layer_view->name.size());  \
+    }                                                                      \
+  } while (0)
+
+    // Make sure we don't have uninitialized memory.
+    hwy::ZeroBytes(layer_view, sizeof(*layer_view));
+    READ_WEIGHTS(attn_vec_einsum_w);
+    READ_WEIGHTS(qkv_einsum_w);
+    SCALE_WEIGHTS(attn_vec_einsum_w);
+    SCALE_WEIGHTS(qkv_einsum_w);
+    READ_WEIGHTS(gating_einsum_w);
+    READ_WEIGHTS(linear_w);
+    SCALE_WEIGHTS(gating_einsum_w);
+    SCALE_WEIGHTS(linear_w);
+    READ_WEIGHTS(pre_attention_norm_scale);
+    READ_WEIGHTS(pre_ffw_norm_scale);
+#undef READ_WEIGHTS
  }
  if (!ok) {
-    HWY_ABORT("Failed to read from %s - might be a directory, or too small? "
-              "expected size: %d kB", checkpoint.path.c_str(),
-              static_cast<uint32_t>(total_size >> 10));
+    HWY_ABORT(
+        "Failed to read from %s - might be a directory, or too small? "
+        "expected size: %d kB",
+        checkpoint.path.c_str(), static_cast<uint32_t>(total_size >> 10));
  }
+  if (!kDryRunFread) {
    HWY_ASSERT(0 == fclose(fptr));
-  return weights;
+    if (scale_for_compression) {
+      HWY_ASSERT(scale_pos == TConfig::kNumTensorScales);
+    }
+  }
+  return weights_u8;
 }
 
 template <class TConfig>
@@ -159,18 +240,19 @@ struct CompressedLayer {
   // No ctor/dtor, allocated via AllocateAligned.
 
   using TLayer = gcpp::Layer<TConfig>;
+  using WeightT = typename TConfig::WeightT;
 
   static constexpr size_t kModelDim = TConfig::kModelDim;
   static constexpr size_t kFFHiddenDim = TConfig::kFFHiddenDim;
 
   // Compressed Parameters
   // We don't yet have an RMSNorm that accepts all WeightT.
-  CompressedArray<hwy::bfloat16_t, kModelDim> c_pre_attention_norm_scale;
-  CompressedArray<hwy::bfloat16_t, kModelDim> c_pre_ffw_norm_scale;
-  CompressedArray<WeightT, TLayer::kGatingEinsumWSize> c_gating_einsum_w;
-  CompressedArray<WeightT, kModelDim * kFFHiddenDim> c_linear_w;
-  CompressedArray<WeightT, TLayer::kQKVEinsumWSize> c_qkv_einsum_w;
-  CompressedArray<WeightT, TLayer::kAttVecEinsumWSize> c_attn_vec_einsum_w;
+  CompressedArray<hwy::bfloat16_t, kModelDim> pre_attention_norm_scale;
+  CompressedArray<hwy::bfloat16_t, kModelDim> pre_ffw_norm_scale;
+  CompressedArray<WeightT, TLayer::kGatingEinsumWSize> gating_einsum_w;
+  CompressedArray<WeightT, kModelDim * kFFHiddenDim> linear_w;
+  CompressedArray<WeightT, TLayer::kQKVEinsumWSize> qkv_einsum_w;
+  CompressedArray<WeightT, TLayer::kAttVecEinsumWSize> attn_vec_einsum_w;
 };
 
 // Array instead of single large allocation for parallel mem init. Split out of
@@ -193,21 +275,25 @@ struct CompressedWeights {
   // No ctor/dtor, allocated via AllocateAligned.
 
   CompressedArray<EmbedderInputT, TConfig::kVocabSize * TConfig::kModelDim>
-      c_embedder_input_embedding;
+      embedder_input_embedding;
 
-  CompressedArray<hwy::bfloat16_t, TConfig::kModelDim> c_final_norm_scale;
+  CompressedArray<hwy::bfloat16_t, TConfig::kModelDim> final_norm_scale;
 
   // Must be last so that the other arrays remain aligned.
   CompressedLayerPointers<TConfig> c_layer_ptrs;
 
-  const CompressedLayer<TConfig>* CLayer(size_t layer) const {
+  const CompressedLayer<TConfig>* GetLayer(size_t layer) const {
     return c_layer_ptrs.c_layers[layer].get();
   }
-  CompressedLayer<TConfig>* CLayer(size_t layer) {
+  CompressedLayer<TConfig>* GetLayer(size_t layer) {
     return c_layer_ptrs.c_layers[layer].get();
   }
 };
 
+template <class TConfig>
+using WeightsT = hwy::If<kWeightsAreCompressed, CompressedWeights<TConfig>,
+                         Weights<TConfig>>;
+
 // Aligned.
 template <class TConfig, size_t TBatchSize>
 struct Activations {
@@ -272,16 +358,27 @@ KVCache CreateKVCache(Model type) {
   }
 }
 
+namespace {
+template <class Config>
+void DeleteLayersPtrs(CompressedWeights<Config>* c_weights) {
+  c_weights->c_layer_ptrs.~CompressedLayerPointers<Config>();
+}
+template <class Config>
+void DeleteLayersPtrs(Weights<Config>* weights) {
+  weights->layer_ptrs.~LayerPointers<Config>();
+}
+}  // namespace
+
 template <class Config>
 struct GemmaImpl : public GemmaInterface {
   GemmaImpl(std::unique_ptr<sentencepiece::SentencePieceProcessor>& tokenizer,
-            hwy::AlignedFreeUniquePtr<uint8_t[]>& compressed_weights,
+            hwy::AlignedFreeUniquePtr<uint8_t[]>& weights_u8,
             hwy::ThreadPool& pool);
 
   ~GemmaImpl() {
-    using CWeights = CompressedWeights<Config>;
-    CWeights* c_weights = reinterpret_cast<CWeights*>(compressed_weights.get());
-    c_weights->c_layer_ptrs.~CompressedLayerPointers<Config>();
+    WeightsT<Config>* weights =
+        reinterpret_cast<WeightsT<Config>*>(weights_u8.get());
+    DeleteLayersPtrs(weights);
   }
 
   const sentencepiece::SentencePieceProcessor* Tokenizer() const override {
@@ -296,7 +393,7 @@ struct GemmaImpl : public GemmaInterface {
           int verbosity) override;
 
   std::unique_ptr<sentencepiece::SentencePieceProcessor> tokenizer;
-  hwy::AlignedFreeUniquePtr<uint8_t[]> compressed_weights;
+  hwy::AlignedFreeUniquePtr<uint8_t[]> weights_u8;
   hwy::AlignedUniquePtr<Activations<Config, kPrefillBatchSize>> prefill;
   hwy::AlignedUniquePtr<Activations<Config, 1>> state;
 };
@@ -309,11 +406,11 @@ HWY_BEFORE_NAMESPACE();
 namespace gcpp {
 namespace HWY_NAMESPACE {
 
-template <class TConfig, size_t kBatchSize>
+template <size_t kBatchSize, typename LayerT, class TConfig>
 HWY_NOINLINE void Attention(size_t batch_start, size_t batch_idx, size_t layer,
                             Activations<TConfig, kBatchSize>& activations,
-                            const CompressedLayer<TConfig>* c_layer,
-                            KVCache& kv_cache, hwy::ThreadPool& pool) {
+                            const LayerT* layer_weights, KVCache& kv_cache,
+                            hwy::ThreadPool& pool) {
   PROFILER_ZONE("Gen.Attention");
   const size_t pos = batch_start + batch_idx;
   HWY_DASSERT(batch_idx < kBatchSize);
@@ -329,26 +426,25 @@ HWY_NOINLINE void Attention(size_t batch_start, size_t batch_idx, size_t layer,
   static const float kQueryScale =
       static_cast<float>(1.0 / sqrt(static_cast<double>(kQKVDim)));
 
-  const size_t batch_offset = batch_idx * kModelDim;
+  float* x = activations.pre_att_rms_out.data() + batch_idx * kModelDim;
 
   auto ProjQ = [&](uint64_t head, size_t head_offset) HWY_ATTR {
     float* HWY_RESTRICT q =
         activations.q.data() + head * kQKVDim + batch_idx * kHeads * kQKVDim;
 
-    MatVecLoop<kQKVDim, kModelDim>(
-        c_layer->c_qkv_einsum_w, head_offset + 0 * kQKVDim * kModelDim,
-        activations.pre_att_rms_out.data() + batch_offset, q);
+    MatVecLoop<kQKVDim, kModelDim>(layer_weights->qkv_einsum_w,
+                                   head_offset + 0 * kQKVDim * kModelDim, x, q);
   };
 
-  auto ProjKV =
-      [&](size_t k_offset, size_t v_offset, size_t kv_offset) HWY_ATTR {
-        TwoOfsMatVecLoop<kQKVDim, kModelDim>(
-            c_layer->c_qkv_einsum_w, k_offset, v_offset,
-            activations.pre_att_rms_out.data() + batch_offset,
-            kv_cache.key_cache.get() + kv_offset,
-            kv_cache.value_cache.get() + kv_offset);
-
-        Rope(kv_cache.key_cache.get() + kv_offset, kQKVDim, pos);
-      };
+  auto ProjKV = [&](size_t k_offset, size_t v_offset,
+                    size_t kv_offset) HWY_ATTR {
+    float* HWY_RESTRICT k = kv_cache.key_cache.get() + kv_offset;
+    float* HWY_RESTRICT v = kv_cache.value_cache.get() + kv_offset;
+
+    TwoOfsMatVecLoop<kQKVDim, kModelDim>(layer_weights->qkv_einsum_w, k_offset,
+                                         v_offset, x, k, v);
+
+    Rope(k, kQKVDim, pos);
+  };
 
   auto Attn = [&](uint64_t head, size_t head_offset) HWY_ATTR {
@@ -388,7 +484,7 @@ HWY_NOINLINE void Attention(size_t batch_start, size_t batch_idx, size_t layer,
         head == 0
             ? activations.att_post2.data() + batch_idx * kModelDim
            : activations.att_post1.data() + head * kBatchSize * kModelDim;
-    MatVecLoop<kModelDim, kQKVDim>(c_layer->c_attn_vec_einsum_w,
+    MatVecLoop<kModelDim, kQKVDim>(layer_weights->attn_vec_einsum_w,
                                    head * kModelDim * kQKVDim, att_out,
                                    head_out);
   };
@@ -431,9 +527,9 @@ HWY_NOINLINE void Attention(size_t batch_start, size_t batch_idx, size_t layer,
   }
 }
 
-template <typename TConfig, size_t kBatchSize>
+template <size_t kBatchSize, typename LayerT, typename TConfig>
 HWY_NOINLINE void FFW(Activations<TConfig, kBatchSize>& activations,
-                      size_t batch_idx, const CompressedLayer<TConfig>* c_layer,
+                      size_t batch_idx, const LayerT* layer_weights,
                       hwy::ThreadPool& pool) {
   HWY_DASSERT(batch_idx < kBatchSize);
   static constexpr size_t kModelDim = TConfig::kModelDim;
@@ -449,12 +545,12 @@ HWY_NOINLINE void FFW(Activations<TConfig, kBatchSize>& activations,
 
   // Same matrix, first and second half of rows. Could fuse into one MatVec,
   // but separating them could help on NUMA e.g. multiple sockets.
-  MatVec<kFFHiddenDim, kModelDim>(c_layer->c_gating_einsum_w,
+  MatVec<kFFHiddenDim, kModelDim>(layer_weights->gating_einsum_w,
                                   kFFHiddenDim * kModelDim, vec, out_mul,
                                   pool);
 
   // Gate, will go through the nonlinearity.
-  MatVec<kFFHiddenDim, kModelDim>(c_layer->c_gating_einsum_w, 0, vec, out,
+  MatVec<kFFHiddenDim, kModelDim>(layer_weights->gating_einsum_w, 0, vec, out,
                                   pool);
 
   namespace hn = hwy::HWY_NAMESPACE;
@@ -467,7 +563,7 @@ HWY_NOINLINE void FFW(Activations<TConfig, kBatchSize>& activations,
 
   PROFILER_ZONE("Gen.FFW\\GatedGELU");
   MatVec<kModelDim, kFFHiddenDim>(
-      c_layer->c_linear_w, 0, activations.ffw_hidden.data() + hidden_offset,
+      layer_weights->linear_w, 0, activations.ffw_hidden.data() + hidden_offset,
      activations.ffw_out.data() + batch_idx * kModelDim, pool);
 }
 
@@ -486,9 +582,9 @@ GEMMA_CONSTEXPR_EMBSCALING float EmbeddingScaling() {
       Sqrt(static_cast<float>(TConfig::kModelDim))));
 }
 
-template <typename TConfig, size_t kBatchSize>
+template <size_t kBatchSize, typename WeightArrayT, typename TConfig>
 HWY_NOINLINE void Prefill(const int* tokens, size_t num_tokens, size_t pos,
-                          const CompressedWeights<TConfig>& c_weights,
+                          const WeightArrayT& weights,
                           Activations<TConfig, kBatchSize>& activations,
                           KVCache& kv_cache, hwy::ThreadPool& pool,
                           hwy::ThreadPool& inner_pool) {
@@ -500,22 +596,22 @@ HWY_NOINLINE void Prefill(const int* tokens, size_t num_tokens, size_t pos,
   pool.Run(
       0, num_tokens, [&](const uint64_t token_idx, size_t /*thread*/) HWY_ATTR {
         const int token = tokens[token_idx];
-        Decompress(c_weights.c_embedder_input_embedding, token * kModelDim,
+        Decompress(weights.embedder_input_embedding, token * kModelDim,
                    activations.x.data() + token_idx * kModelDim, kModelDim);
         MulByConst(kEmbScaling, activations.x.data() + token_idx * kModelDim,
                    kModelDim);
       });
 
   for (size_t layer = 0; layer < TConfig::kLayers; ++layer) {
-    const CompressedLayer<TConfig>* c_layer = c_weights.CLayer(layer);
+    const auto* layer_weights = weights.GetLayer(layer);
 
     for (size_t token_idx = 0; token_idx < num_tokens; ++token_idx) {
       RMSNorm(activations.x.data() + token_idx * kModelDim,
-              c_layer->c_pre_attention_norm_scale.data(),
+              layer_weights->pre_attention_norm_scale.data(),
              activations.pre_att_rms_out.data() + token_idx * kModelDim,
              kModelDim);
-      Attention<TConfig, kBatchSize>(pos, token_idx, layer, activations,
-                                     c_layer, kv_cache, pool);
+      Attention<kBatchSize>(pos, token_idx, layer, activations, layer_weights,
+                            kv_cache, pool);
    }
 
    // TODO: sink the loop into these functions, i.e. make them matmuls.
@@ -525,10 +621,10 @@ HWY_NOINLINE void Prefill(const int* tokens, size_t num_tokens, size_t pos,
          AddFrom(activations.att_post2.data() + token_idx * kModelDim,
                  activations.x.data() + token_idx * kModelDim, kModelDim);
          RMSNorm(activations.x.data() + token_idx * kModelDim,
-                  c_layer->c_pre_ffw_norm_scale.data(),
+                  layer_weights->pre_ffw_norm_scale.data(),
                  activations.bf_pre_ffw_rms_out.data() + token_idx * kModelDim,
                  kModelDim);
-          FFW<TConfig, kBatchSize>(activations, token_idx, c_layer, inner_pool);
+          FFW<kBatchSize>(activations, token_idx, layer_weights, inner_pool);
          AddFrom(activations.ffw_out.data() + token_idx * kModelDim,
                  activations.x.data() + token_idx * kModelDim, kModelDim);
        });
@@ -536,21 +632,20 @@ HWY_NOINLINE void Prefill(const int* tokens, size_t num_tokens, size_t pos,
 
   pool.Run(
       0, num_tokens, [&](const uint64_t token_idx, size_t /*thread*/) HWY_ATTR {
-        RMSNormInplace(c_weights.c_final_norm_scale.data(),
+        RMSNormInplace(weights.final_norm_scale.data(),
                        activations.x.data() + token_idx * kModelDim, kModelDim);
      });
 }
 
 // n = 1 specialization
-template <class TConfig>
-void Transformer(int token, size_t pos,
-                 const CompressedWeights<TConfig>& c_weights,
+template <typename WeightArrayT, class TConfig>
+void Transformer(int token, size_t pos, const WeightArrayT& weights,
                  Activations<TConfig, 1>& activations, KVCache& kv_cache,
                  hwy::ThreadPool& pool, hwy::ThreadPool& inner_pool) {
   static constexpr size_t kLayers = TConfig::kLayers;
   static constexpr size_t kModelDim = TConfig::kModelDim;
 
-  Decompress(c_weights.c_embedder_input_embedding, token * kModelDim,
+  Decompress(weights.embedder_input_embedding, token * kModelDim,
              activations.x.data(), kModelDim);
 
   GEMMA_CONSTEXPR_EMBSCALING const float kEmbScaling =
@@ -558,17 +653,18 @@ void Transformer(int token, size_t pos,
   MulByConst(kEmbScaling, activations.x.data(), kModelDim);
 
   for (size_t layer = 0; layer < kLayers; ++layer) {
-    const CompressedLayer<TConfig>* c_layer = c_weights.CLayer(layer);
-    RMSNorm(activations.x.data(), c_layer->c_pre_attention_norm_scale.data(),
+    const auto* layer_weights = weights.GetLayer(layer);
+    RMSNorm(activations.x.data(),
+            layer_weights->pre_attention_norm_scale.data(),
            activations.pre_att_rms_out.data(), kModelDim);
-    Attention<TConfig, 1>(pos, 0, layer, activations, c_layer, kv_cache, pool);
+    Attention<1>(pos, 0, layer, activations, layer_weights, kv_cache, pool);
    AddFrom(activations.att_post2.data(), activations.x.data(), kModelDim);
-    RMSNorm(activations.x.data(), c_layer->c_pre_ffw_norm_scale.data(),
+    RMSNorm(activations.x.data(), layer_weights->pre_ffw_norm_scale.data(),
            activations.bf_pre_ffw_rms_out.data(), kModelDim);
-    FFW<TConfig, 1>(activations, /* batch_idx = */ 0, c_layer, pool);
+    FFW<1>(activations, /* batch_idx = */ 0, layer_weights, pool);
    AddFrom(activations.ffw_out.data(), activations.x.data(), kModelDim);
  }
-  RMSNormInplace(c_weights.c_final_norm_scale.data(), activations.x.data(),
+  RMSNormInplace(weights.final_norm_scale.data(), activations.x.data(),
                 kModelDim);
 }
 
@ -609,9 +705,9 @@ void GenerateImpl(GemmaImpl<TConfig>& gemma, size_t max_tokens,
|
||||||
Activations<TConfig, 1>& activations = *gemma.state.get();
|
Activations<TConfig, 1>& activations = *gemma.state.get();
|
||||||
Activations<TConfig, kPrefillBatchSize>& prefill_activations =
|
Activations<TConfig, kPrefillBatchSize>& prefill_activations =
|
||||||
*gemma.prefill.get();
|
*gemma.prefill.get();
|
||||||
const CompressedWeights<TConfig>& c_weights =
|
|
||||||
*reinterpret_cast<CompressedWeights<TConfig>*>(
|
const WeightsT<TConfig>& weights =
|
||||||
gemma.compressed_weights.get());
|
*reinterpret_cast<WeightsT<TConfig>*>(gemma.weights_u8.get());
|
||||||
|
|
||||||
size_t prompt_size = prompt.size();
|
size_t prompt_size = prompt.size();
|
||||||
RangeChecks<TConfig>(max_tokens, max_generated_tokens, prompt_size);
|
RangeChecks<TConfig>(max_tokens, max_generated_tokens, prompt_size);
|
||||||
|
|
@ -643,9 +739,8 @@ void GenerateImpl(GemmaImpl<TConfig>& gemma, size_t max_tokens,
|
||||||
HWY_DASSERT(batch_size <= kPrefillBatchSize);
|
HWY_DASSERT(batch_size <= kPrefillBatchSize);
|
||||||
HWY_DASSERT(pos_offset + batch_size <= prompt_size - 1);
|
HWY_DASSERT(pos_offset + batch_size <= prompt_size - 1);
|
||||||
const int* batch_tokens = prompt.data() + pos_offset;
|
const int* batch_tokens = prompt.data() + pos_offset;
|
||||||
Prefill<TConfig, kPrefillBatchSize>(batch_tokens, batch_size, pos,
|
Prefill<kPrefillBatchSize>(batch_tokens, batch_size, pos, weights,
|
||||||
c_weights, prefill_activations,
|
prefill_activations, kv_cache, pool, inner_pool);
|
||||||
kv_cache, pool, inner_pool);
|
|
||||||
for (size_t idx = 0; idx < batch_size; ++idx) {
|
for (size_t idx = 0; idx < batch_size; ++idx) {
|
||||||
stream_token(batch_tokens[idx], 0.0f);
|
stream_token(batch_tokens[idx], 0.0f);
|
||||||
}
|
}
|
||||||
|
|
@ -672,7 +767,7 @@ void GenerateImpl(GemmaImpl<TConfig>& gemma, size_t max_tokens,
|
||||||
for (size_t generate_pos = 0;
|
for (size_t generate_pos = 0;
|
||||||
pos < max_tokens && generate_pos < max_generated_tokens;
|
pos < max_tokens && generate_pos < max_generated_tokens;
|
||||||
++pos, ++pos_offset, ++generate_pos) {
|
++pos, ++pos_offset, ++generate_pos) {
|
||||||
Transformer(token, pos, c_weights, activations, kv_cache, pool, inner_pool);
|
Transformer(token, pos, weights, activations, kv_cache, pool, inner_pool);
|
||||||
float* final_activation = activations.x.data();
|
float* final_activation = activations.x.data();
|
||||||
// The condition below is always true if we are doing Prefill above.
|
// The condition below is always true if we are doing Prefill above.
|
||||||
// We keep it here for clarity so that the code is correct even if Prefill
|
// We keep it here for clarity so that the code is correct even if Prefill
|
||||||
|
|
@ -680,8 +775,8 @@ void GenerateImpl(GemmaImpl<TConfig>& gemma, size_t max_tokens,
|
||||||
if (pos_offset >= prompt_size - 1) {
|
if (pos_offset >= prompt_size - 1) {
|
||||||
PROFILER_ZONE("Gen.Embedding");
|
PROFILER_ZONE("Gen.Embedding");
|
||||||
// Generation phase
|
// Generation phase
|
||||||
MatVec<kVocabSize, TConfig::kModelDim>(
|
MatVec<kVocabSize, TConfig::kModelDim>(weights.embedder_input_embedding,
|
||||||
c_weights.c_embedder_input_embedding, 0, final_activation,
|
0, final_activation,
|
||||||
activations.logits.data(), pool);
|
activations.logits.data(), pool);
|
||||||
// Barrier: must have all logits so we can subtract max.
|
// Barrier: must have all logits so we can subtract max.
|
||||||
Softmax(activations.logits.data(), kVocabSize);
|
Softmax(activations.logits.data(), kVocabSize);
|
||||||
|
|
@@ -743,52 +838,37 @@ void ForEachTensor(const Weights<TConfig>* weights,
                    CompressedWeights<TConfig>& c_weights, Func& func) {
   func("c_embedding",
        weights ? weights->embedder_input_embedding.data() : nullptr,
-       c_weights.c_embedder_input_embedding);
+       c_weights.embedder_input_embedding);
   func("c_final_norm", weights ? weights->final_norm_scale.data() : nullptr,
-       c_weights.c_final_norm_scale);
+       c_weights.final_norm_scale);

-  char name[16];
-  for (int layer_idx = 0; layer_idx < static_cast<int>(TConfig::kLayers);
-       ++layer_idx) {
+  char name_buf[16];
+  for (int layer_idx = 0; layer_idx < TConfig::kLayers; ++layer_idx) {
     const size_t idx = static_cast<size_t>(layer_idx);
-    Layer<TConfig>* layer = weights ? &weights->layers[idx] : nullptr;
-    CompressedLayer<TConfig>* c_layer = c_weights.CLayer(idx);
+    const Layer<TConfig>* layer = weights ? weights->GetLayer(idx) : nullptr;
+    CompressedLayer<TConfig>* layer_weights = c_weights.GetLayer(idx);

-    snprintf(name, sizeof(name), "pre_ff_ns_%d", layer_idx);
-    func(name, layer ? layer->pre_ffw_norm_scale.data() : nullptr,
-         c_layer->c_pre_ffw_norm_scale);
-
-    snprintf(name, sizeof(name), "gating_ein_%d", layer_idx);
-    func(name, layer ? layer->gating_einsum_w.data() : nullptr,
-         c_layer->c_gating_einsum_w);
-
-    snprintf(name, sizeof(name), "linear_w_%d", layer_idx);
-    func(name, layer ? layer->linear_w.data() : nullptr, c_layer->c_linear_w);
-
-    snprintf(name, sizeof(name), "qkv_ein_%d", layer_idx);
-    func(name, layer ? layer->qkv_einsum_w.data() : nullptr,
-         c_layer->c_qkv_einsum_w);
-
-    snprintf(name, sizeof(name), "att_ein_%d", layer_idx);
-    func(name, layer ? layer->attn_vec_einsum_w.data() : nullptr,
-         c_layer->c_attn_vec_einsum_w);
-
-    snprintf(name, sizeof(name), "pre_att_ns_%d", layer_idx);
-    func(name, layer ? layer->pre_attention_norm_scale.data() : nullptr,
-         c_layer->c_pre_attention_norm_scale);
+#define CALL_FUNC(name, member)                                       \
+  snprintf(name_buf, sizeof(name_buf), name "_%d", layer_idx);        \
+  func(name_buf, layer ? layer->member.data() : nullptr, layer_weights->member)
+
+    CALL_FUNC("pre_ff_ns", pre_ffw_norm_scale);
+    CALL_FUNC("gating_ein", gating_einsum_w);
+    CALL_FUNC("linear_w", linear_w);
+    CALL_FUNC("qkv_ein", qkv_einsum_w);
+    CALL_FUNC("att_ein", attn_vec_einsum_w);
+    CALL_FUNC("pre_att_ns", pre_attention_norm_scale);
+#undef CALL_FUNC
   }
 }
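The CALL_FUNC macro above relies on string-literal concatenation, so its name argument must be a literal; it prints the per-layer suffix into name_buf and forwards the matching tensor from both weight structs. A rough sketch of what a single invocation expands to (illustrative substitution into the macro body, not separate code in the tree):

    // CALL_FUNC("gating_ein", gating_einsum_w); expands roughly to:
    snprintf(name_buf, sizeof(name_buf), "gating_ein" "_%d", layer_idx);
    func(name_buf, layer ? layer->gating_einsum_w.data() : nullptr,
         layer_weights->gating_einsum_w);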
 template <class TConfig>
-hwy::AlignedFreeUniquePtr<uint8_t[]> GetCompressedWeights(
-    const Path& weights_path, const Path& cache, hwy::ThreadPool& pool) {
+hwy::AlignedFreeUniquePtr<uint8_t[]> LoadCompressedWeights(
+    const Path& weights, hwy::ThreadPool& pool) {
   PROFILER_ZONE("Startup.LoadCache");
-
-  if (!std::filesystem::exists(weights_path.path) &&
-      !std::filesystem::exists(cache.path)) {
-    HWY_ABORT(
-        "Either the model weights (--weights) or cached compressed weights "
-        "(--compressed_weights) must exist.");
+  if (!std::filesystem::exists(weights.path)) {
+    HWY_ABORT("The model weights file '%s' does not exist.",
+              weights.path.c_str());
   }

   // Allocate compressed weights.

@@ -798,32 +878,49 @@ hwy::AlignedFreeUniquePtr<uint8_t[]> GetCompressedWeights(
   CWeights* c_weights = reinterpret_cast<CWeights*>(c_weights_u8.get());
   new (&c_weights->c_layer_ptrs) CompressedLayerPointers<TConfig>(pool);

-  // First attempt to load them from cache, without requiring weights.
-  CacheLoader loader(cache.path.c_str());
+  std::array<float, TConfig::kNumTensorScales> scales;
+  CacheLoader loader(weights.path.c_str());
   ForEachTensor<TConfig>(nullptr, *c_weights, loader);
-  if (loader.ReadAll(pool)) return c_weights_u8;
-
-  // Get weights, compress, and store in cache.
-  const hwy::AlignedUniquePtr<Weights<TConfig>> weights =
-      LoadWeights<TConfig>(weights_path);
-  Compressor compressor(pool);
-  ForEachTensor<TConfig>(weights.get(), *c_weights, compressor);
-  compressor.WriteAll(pool, cache.path.c_str());
+  loader.LoadScales(scales.data(), scales.size());
+  if (!loader.ReadAll(pool)) {
+    HWY_ABORT("Failed to load model weights.");
+  }
+  if (TConfig::kNumTensorScales > 0) {
+    size_t scale_pos = 0;
+    for (int layer_idx = 0; layer_idx < TConfig::kLayers; ++layer_idx) {
+      const size_t idx = static_cast<size_t>(layer_idx);
+      CompressedLayer<TConfig>* layer_weights = c_weights->GetLayer(idx);
+      layer_weights->attn_vec_einsum_w.set_scale(scales[scale_pos++]);
+      layer_weights->qkv_einsum_w.set_scale(scales[scale_pos++]);
+      layer_weights->gating_einsum_w.set_scale(scales[scale_pos++]);
+      layer_weights->linear_w.set_scale(scales[scale_pos++]);
+    }
+    HWY_ASSERT(scale_pos == TConfig::kNumTensorScales);
+  }
   return c_weights_u8;
 }
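The loop above consumes exactly four scales per layer, in the order attn_vec_einsum_w, qkv_einsum_w, gating_einsum_w, linear_w, presumably matching the order in which they were written at compression time. A minimal sketch of the implied invariant, assuming only the config constants used in this diff:

    // Sketch only: either no per-tensor scaling, or four scaled tensors per layer,
    // consumed in the fixed order shown above.
    template <class TConfig>
    constexpr bool ScaleCountMatchesLayers() {
      return TConfig::kNumTensorScales == 0 ||
             TConfig::kNumTensorScales == 4 * TConfig::kLayers;
    }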

 // Type-erased because this function is called via a function pointer.
-hwy::AlignedFreeUniquePtr<uint8_t[]> GetCompressedWeightsT(
-    gcpp::Model model, const Path& weights, const Path& compressed_weights,
+hwy::AlignedFreeUniquePtr<uint8_t[]> LoadCompressedWeightsT(
+    gcpp::Model model, const Path& weights, hwy::ThreadPool& pool) {
+  switch (model) {
+    case Model::GEMMA_2B:
+      return LoadCompressedWeights<ConfigGemma2B>(weights, pool);
+    case Model::GEMMA_7B:
+      return LoadCompressedWeights<ConfigGemma7B>(weights, pool);
+    default:
+      HWY_ABORT("Model type %d unknown.", static_cast<int>(model));
+  }
+}
+
+hwy::AlignedFreeUniquePtr<uint8_t[]> LoadWeightsT(gcpp::Model model,
+                                                  const Path& weights,
                                                   hwy::ThreadPool& pool) {
   switch (model) {
     case Model::GEMMA_2B:
-      return GetCompressedWeights<ConfigGemma2B>(weights, compressed_weights,
-                                                 pool);
+      return LoadWeights<ConfigGemma2B>(weights, pool);
     case Model::GEMMA_7B:
-      return GetCompressedWeights<ConfigGemma7B>(weights, compressed_weights,
-                                                 pool);
+      return LoadWeights<ConfigGemma7B>(weights, pool);
     default:
       HWY_ABORT("Model type %d unknown.", static_cast<int>(model));
   }
 }

@@ -846,18 +943,22 @@ void CompressWeights(const Path& weights_path,
   new (&c_weights->c_layer_ptrs) CompressedLayerPointers<TConfig>(pool);

   // Get weights, compress, and store.
-  const hwy::AlignedUniquePtr<Weights<TConfig>> weights =
-      LoadWeights<TConfig>(weights_path);
+  const bool scale_for_compression = TConfig::kNumTensorScales > 0;
+  const hwy::AlignedFreeUniquePtr<uint8_t[]> weights_u8 =
+      LoadWeights<TConfig>(weights_path, pool, scale_for_compression);
+  Weights<TConfig>* weights =
+      reinterpret_cast<Weights<TConfig>*>(weights_u8.get());
   Compressor compressor(pool);
-  ForEachTensor<TConfig>(weights.get(), *c_weights, compressor);
+  ForEachTensor<TConfig>(weights, *c_weights, compressor);
+  compressor.AddScales(weights->scales.data(), weights->scales.size());
   compressor.WriteAll(pool, compressed_weights_path.path.c_str());

+  weights->layer_ptrs.~LayerPointers<TConfig>();
   c_weights->c_layer_ptrs.~CompressedLayerPointers<TConfig>();
 }

 void CompressWeightsT(gcpp::Model model, const Path& weights,
-                      const Path& compressed_weights,
-                      hwy::ThreadPool& pool) {
+                      const Path& compressed_weights, hwy::ThreadPool& pool) {
   switch (model) {
     case Model::GEMMA_2B:
       CompressWeights<ConfigGemma2B>(weights, compressed_weights, pool);

@@ -877,7 +978,8 @@ HWY_AFTER_NAMESPACE();
 #if HWY_ONCE
 namespace gcpp {

-HWY_EXPORT(GetCompressedWeightsT);
+HWY_EXPORT(LoadCompressedWeightsT);
+HWY_EXPORT(LoadWeightsT);
 HWY_EXPORT(CompressWeightsT);
 HWY_EXPORT(Generate2B);
 HWY_EXPORT(Generate7B);

@@ -892,10 +994,9 @@ KVCache CreateKVCache(size_t size_cache_pos, size_t seq_len) {
 template <class Config>
 GemmaImpl<Config>::GemmaImpl(
     std::unique_ptr<sentencepiece::SentencePieceProcessor>& tokenizer,
-    hwy::AlignedFreeUniquePtr<uint8_t[]>& compressed_weights,
-    hwy::ThreadPool& pool)
+    hwy::AlignedFreeUniquePtr<uint8_t[]>& weights_u8, hwy::ThreadPool& pool)
     : tokenizer(std::move(tokenizer)),
-      compressed_weights(std::move(compressed_weights)),
+      weights_u8(std::move(weights_u8)),
       prefill(hwy::MakeUniqueAligned<Activations<Config, kPrefillBatchSize>>()),
       state(hwy::MakeUniqueAligned<Activations<Config, 1>>()) {}

@@ -922,10 +1023,8 @@ void GemmaImpl<ConfigGemma7B>::Generate(
       kv_cache, pool, inner_pool, stream_token, accept_token, gen, verbosity);
 }

-Gemma::Gemma(const Path& tokenizer_path, const Path& compressed_weights_path,
-             const Path& weights_path, Model model_type, ModelTraining training,
-             hwy::ThreadPool& pool)
-    : model_training(training) {
+Gemma::Gemma(const Path& tokenizer_path, const Path& weights, Model model_type,
+             hwy::ThreadPool& pool) {
   std::unique_ptr<sentencepiece::SentencePieceProcessor> tokenizer;
   {
     PROFILER_ZONE("Startup.tokenizer");

@@ -934,16 +1033,21 @@ Gemma::Gemma(const Path& tokenizer_path, const Path& compressed_weights_path,
       HWY_ABORT("Failed to load the tokenizer file.");
     }
   }
-  auto compressed_weights = HWY_DYNAMIC_DISPATCH(GetCompressedWeightsT)(
-      model_type, weights_path, compressed_weights_path, pool);
+  hwy::AlignedFreeUniquePtr<uint8_t[]> weights_u8;
+  if constexpr (kWeightsAreCompressed) {
+    weights_u8 =
+        HWY_DYNAMIC_DISPATCH(LoadCompressedWeightsT)(model_type, weights, pool);
+  } else {
+    weights_u8 = HWY_DYNAMIC_DISPATCH(LoadWeightsT)(model_type, weights, pool);
+  }

   switch (model_type) {
     case Model::GEMMA_2B:
-      impl_.reset(
-          new GemmaImpl<ConfigGemma2B>(tokenizer, compressed_weights, pool));
+      impl_.reset(new GemmaImpl<ConfigGemma2B>(tokenizer, weights_u8, pool));
       break;
     case Model::GEMMA_7B:
-      impl_.reset(
-          new GemmaImpl<ConfigGemma7B>(tokenizer, compressed_weights, pool));
+      impl_.reset(new GemmaImpl<ConfigGemma7B>(tokenizer, weights_u8, pool));
       break;
     default:
       HWY_ABORT("Model type %d unknown.", static_cast<int>(model_type));

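After this change the Gemma constructor takes a single weights path and no ModelTraining argument; whether the file is treated as compressed is decided at compile time via kWeightsAreCompressed. A rough usage sketch, with placeholder file names that are not part of this change:

    // Sketch: constructing a model from a single weights file.
    gcpp::Path tokenizer, weights;
    tokenizer.path = "tokenizer.spm";   // placeholder
    weights.path = "2b-it-sfp.sbs";     // placeholder
    hwy::ThreadPool pool(8);
    gcpp::Gemma model(tokenizer, weights, gcpp::Model::GEMMA_2B, pool);
    auto kv_cache = gcpp::CreateKVCache(gcpp::Model::GEMMA_2B);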
@@ -981,10 +1085,9 @@ void GenerateGemma(Gemma& gemma, RuntimeConfig runtime_config,
 }

 void CompressWeights(gcpp::Model model, const Path& weights,
-                     const Path& compressed_weights,
-                     hwy::ThreadPool& pool) {
-  HWY_DYNAMIC_DISPATCH(CompressWeightsT)(
-      model, weights, compressed_weights, pool);
+                     const Path& compressed_weights, hwy::ThreadPool& pool) {
+  HWY_DYNAMIC_DISPATCH(CompressWeightsT)
+  (model, weights, compressed_weights, pool);
 }

 }  // namespace gcpp

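The public CompressWeights wrapper keeps both paths, so converting an uncompressed export into an .sbs file remains a single call that dispatches to the per-config implementation above. A minimal usage sketch, assuming hypothetical file names and a four-thread pool:

    // Sketch: compressing uncompressed weights into an .sbs file.
    // The file names are placeholders, not taken from this change.
    gcpp::Path raw_weights, compressed;
    raw_weights.path = "/tmp/gemma-2b-it.raw";     // uncompressed export (placeholder)
    compressed.path = "/tmp/gemma-2b-it-sfp.sbs";  // output (placeholder)
    hwy::ThreadPool pool(4);
    gcpp::CompressWeights(gcpp::Model::GEMMA_2B, raw_weights, compressed, pool);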
19 gemma.h

@@ -24,22 +24,18 @@
 // copybara:import_next_line:gemma_cpp
 #include "compression/compress.h"  // SfpStream/NuqStream
 // copybara:import_next_line:gemma_cpp
-#include "util/args.h"  // Path
+#include "configs.h"
 #include "hwy/aligned_allocator.h"
 #include "hwy/base.h"  // hwy::bfloat16_t
 #include "hwy/contrib/thread_pool/thread_pool.h"
+// copybara:import_next_line:gemma_cpp
+#include "util/args.h"  // Path
 // copybara:import_next_line:sentencepiece
 #include "src/sentencepiece_processor.h"

 namespace gcpp {

-// Allowable types for GEMMA_WEIGHT_T (can be specified at compilation time):
-// float, hwy::bfloat16_t, SfpStream, NuqStream
-#ifndef GEMMA_WEIGHT_T
-#define GEMMA_WEIGHT_T SfpStream
-#endif  // !GEMMA_WEIGHT_T
-using WeightT = GEMMA_WEIGHT_T;
+using GemmaWeightT = GEMMA_WEIGHT_T;

 using EmbedderInputT = hwy::bfloat16_t;
 constexpr size_t kPrefillBatchSize = 16;
 constexpr bool kSystemPrompt = false;

@@ -65,13 +61,11 @@ struct RuntimeConfig {
 struct GemmaInterface;

 struct Gemma {
-  Gemma(const Path& tokenizer_path, const Path& compressed_weights_path,
-        const Path& weights_path, Model model_type, ModelTraining training,
+  Gemma(const Path& tokenizer_path, const Path& weights, Model model_type,
         hwy::ThreadPool& pool);
   ~Gemma();  // must be defined after GemmaInterface's dtor is defined.
   const sentencepiece::SentencePieceProcessor* Tokenizer() const;
   std::unique_ptr<GemmaInterface> impl_;
-  gcpp::ModelTraining model_training;
 };

 KVCache CreateKVCache(Model type);  // convenient workaround for now

@@ -99,8 +93,7 @@ void GenerateGemma(Gemma& gemma, RuntimeConfig runtime_config,
                    const StreamFunc& stream_token, std::mt19937& gen);

 void CompressWeights(gcpp::Model model, const Path& weights,
-                     const Path& compressed_weights,
-                     hwy::ThreadPool& pool);
+                     const Path& compressed_weights, hwy::ThreadPool& pool);

 constexpr int EOS_ID = 1;

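gemma.h now only aliases the weight type; the default GEMMA_WEIGHT_T definition and the list of allowed types (float, hwy::bfloat16_t, SfpStream, NuqStream) move out of this header, presumably into the newly included configs.h. The kWeightsAreCompressed constant used by the Gemma constructor is not shown in this diff; one plausible definition, stated purely as an assumption and relying on std::is_same_v from <type_traits>, would be:

    // Assumption only, not part of this change: plain float weights are the
    // uncompressed path, any other GEMMA_WEIGHT_T uses the compressed loader.
    constexpr bool kWeightsAreCompressed = !std::is_same_v<GemmaWeightT, float>;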

@@ -369,6 +369,7 @@ CompressedArray<float, kOuter * kInner> GenerateMat(size_t offset) {
     }
   }
   Compress(content, ws, mat, pool);
+  mat.set_scale(1.0f);
   return mat;
 }

23 run.cc

@@ -29,14 +29,14 @@
 #include "gemma.h"  // Gemma
 // copybara:import_next_line:gemma_cpp
 #include "util/app.h"
-// copybara:import_next_line:gemma_cpp
-#include "util/args.h"  // HasHelp
 #include "hwy/base.h"
 #include "hwy/contrib/thread_pool/thread_pool.h"
 #include "hwy/highway.h"
 #include "hwy/per_target.h"
 #include "hwy/profiler.h"
 #include "hwy/timer.h"
+// copybara:import_next_line:gemma_cpp
+#include "util/args.h"  // HasHelp

 namespace gcpp {

@@ -66,7 +66,7 @@ void ShowConfig(LoaderArgs& loader, InferenceArgs& inference, AppArgs& app) {
             << hwy::VectorBytes() * 8 << " bits)" << "\n"
             << "Compiled config : " << CompiledConfig() << "\n"
             << "Weight Type : "
-            << gcpp::TypeName(gcpp::WeightT()) << "\n"
+            << gcpp::TypeName(gcpp::GemmaWeightT()) << "\n"
             << "EmbedderInput Type : "
             << gcpp::TypeName(gcpp::EmbedderInputT()) << "\n";
 }

@@ -93,10 +93,11 @@ void ShowHelp(gcpp::LoaderArgs& loader, gcpp::InferenceArgs& inference,
   std::cerr << "\n";
 }

-void ReplGemma(gcpp::Gemma& model, gcpp::KVCache& kv_cache,
-               hwy::ThreadPool& pool, hwy::ThreadPool& inner_pool,
-               const InferenceArgs& args, int verbosity,
-               const gcpp::AcceptFunc& accept_token, std::string& eot_line) {
+void ReplGemma(gcpp::Gemma& model, ModelTraining training,
+               gcpp::KVCache& kv_cache, hwy::ThreadPool& pool,
+               hwy::ThreadPool& inner_pool, const InferenceArgs& args,
+               int verbosity, const gcpp::AcceptFunc& accept_token,
+               std::string& eot_line) {
   PROFILER_ZONE("Gen.misc");
   int abs_pos = 0;      // absolute token index over all turns
   int current_pos = 0;  // token index within the current turn

@@ -177,7 +178,7 @@ void ReplGemma(gcpp::Gemma& model, gcpp::KVCache& kv_cache,
       continue;
     }

-    if (model.model_training == ModelTraining::GEMMA_IT) {
+    if (training == ModelTraining::GEMMA_IT) {
       // For instruction-tuned models: add control tokens.
       prompt_string = "<start_of_turn>user\n" + prompt_string +
                       "<end_of_turn>\n<start_of_turn>model\n";

@@ -232,8 +233,7 @@ void Run(LoaderArgs& loader, InferenceArgs& inference, AppArgs& app) {
         [](uint64_t /*task*/, size_t thread) { PinThreadToCore(thread); });
   }

-  gcpp::Gemma model(loader.tokenizer, loader.compressed_weights, loader.weights,
-                    loader.ModelType(), loader.ModelTraining(), pool);
+  gcpp::Gemma model(loader.tokenizer, loader.weights, loader.ModelType(), pool);

   auto kv_cache = CreateKVCache(loader.ModelType());

@@ -265,7 +265,8 @@ void Run(LoaderArgs& loader, InferenceArgs& inference, AppArgs& app) {
   }

   ReplGemma(
-      model, kv_cache, pool, inner_pool, inference, app.verbosity,
+      model, loader.ModelTraining(), kv_cache, pool, inner_pool, inference,
+      app.verbosity,
       /*accept_token=*/[](int) { return true; }, app.eot_line);
 }

39 util/app.h

@@ -36,9 +36,9 @@
 #include "configs.h"
 // copybara:import_next_line:gemma_cpp
 #include "gemma.h"
+#include "hwy/base.h"  // HWY_ASSERT
 // copybara:import_next_line:gemma_cpp
 #include "util/args.h"
-#include "hwy/base.h"  // HWY_ASSERT

 namespace gcpp {

@@ -151,7 +151,7 @@ struct LoaderArgs : public ArgsBase<LoaderArgs> {
   }

   // Returns error string or nullptr if OK.
-  const char* Validate() const {
+  const char* Validate() {
     const std::string model_type_lc = ToLower(model_type);
     if (model_type.empty()) {
       return "Missing --model flag, need to specify either 2b-pt, 7b-pt, "

@@ -165,37 +165,42 @@ struct LoaderArgs : public ArgsBase<LoaderArgs> {
     if (tokenizer.path.empty()) {
       return "Missing --tokenizer flag, a file for the tokenizer is required.";
     }
-    if (compressed_weights.path.empty()) {
-      return "Missing --compressed_weights flag, a file for the compressed "
-             "model.";
+    if (!compressed_weights.path.empty()) {
+      if (weights.path.empty()) {
+        weights = compressed_weights;
+      } else {
+        return "Only one of --weights and --compressed_weights can be "
+               "specified. To create compressed weights use the compress_weights "
+               "tool.";
+      }
+    }
+    if (weights.path.empty()) {
+      return "Missing --weights flag, a file for the model weights.";
+    }
+    if (!weights.exists()) {
+      return "Can't open file specified with --weights flag.";
     }
     return nullptr;
   }

   Path tokenizer;
-  Path weights;  // uncompressed weights file location
-  Path compressed_weights;  // compressed weights file location
+  Path weights;  // weights file location
+  Path compressed_weights;
   std::string model_type;

   template <class Visitor>
   void ForEach(const Visitor& visitor) {
     visitor(tokenizer, "tokenizer", Path(),
             "Path name of tokenizer model file.\n Required argument.");
-    visitor(
-        compressed_weights, "compressed_weights", Path(),
-        "Path name of compressed weights file, regenerated from `--weights` "
-        "file if "
-        "the compressed weights file does not exist.\n Required argument.");
+    visitor(weights, "weights", Path(),
+            "Path name of model weights (.sbs) file.\n Required argument.");
+    visitor(compressed_weights, "compressed_weights", Path(),
+            "Alias for --weights.");
     visitor(model_type, "model", std::string(),
             "Model type\n 2b-it = 2B parameters, instruction-tuned\n "
             "2b-pt = 2B parameters, pretrained\n 7b-it = 7B parameters "
             "instruction-tuned\n 7b-pt = 7B parameters, pretrained\n"
             " Required argument.");
-    visitor(weights, "weights", Path(),
-            "Path name of model weights (.sbs) file. Only required if "
-            "compressed_weights file is not present and needs to be "
-            "regenerated. This parameter is only required for compressing "
-            "new model weight exports, otherwise it is not needed.");
   }
 };

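Net effect of the LoaderArgs changes: --weights is the primary flag, --compressed_weights is accepted as an alias, and passing both is rejected; Validate() drops its const qualifier because it may now copy the alias into weights. A small caller-side sketch, assuming the LoaderArgs(argc, argv) constructor that run.cc's main already uses:

    // Sketch: parsing and validating loader flags.
    gcpp::LoaderArgs loader(argc, argv);
    if (const char* error = loader.Validate()) {
      HWY_ABORT("Invalid args: %s", error);
    }
    // loader.weights now holds the .sbs path, whether it was given as
    // --weights or via the --compressed_weights alias.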