mirror of https://github.com/google/gemma.cpp.git
Merge branch 'google:main' into main
This commit is contained in:
commit
5d72c911fc
|
|
@ -0,0 +1 @@
|
|||
BasedOnStyle: Google
|
||||
|
|
@ -1,22 +1,30 @@
|
|||
name: Build
|
||||
name: build
|
||||
|
||||
# Trigger on push or via manual dispatch.
|
||||
on: [push, workflow_dispatch]
|
||||
# Trigger on push, pull request, or via manual dispatch.
|
||||
on: [push, pull_request, workflow_dispatch]
|
||||
|
||||
jobs:
|
||||
build:
|
||||
runs-on: ${{matrix.os}}
|
||||
name: ${{ matrix.os }} ${{ matrix.type }}
|
||||
runs-on: ${{ matrix.os }}
|
||||
name: ${{ matrix.os }} (${{ matrix.preset }}) ${{ matrix.build_type }}
|
||||
timeout-minutes: 30
|
||||
|
||||
strategy:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
type: ['Release']
|
||||
os: ['ubuntu-latest']
|
||||
os: ['ubuntu-latest', 'macos-latest', 'windows-latest']
|
||||
build_type: ['Release']
|
||||
preset: ['make', 'windows']
|
||||
exclude:
|
||||
- os: ubuntu-latest
|
||||
preset: windows
|
||||
- os: macos-latest
|
||||
preset: windows
|
||||
- os: windows-latest
|
||||
preset: make
|
||||
|
||||
concurrency:
|
||||
group: ${{ github.workflow }}-${{ github.ref }}
|
||||
group: ${{ github.workflow }}-${{ github.ref }}-${{ matrix.os }}-${{ matrix.preset }}-${{ matrix.build_type }}
|
||||
cancel-in-progress: true
|
||||
|
||||
steps:
|
||||
|
|
@ -26,20 +34,23 @@ jobs:
|
|||
- name: ccache
|
||||
uses: hendrikmuhs/ccache-action@v1.2
|
||||
|
||||
# Install CMake
|
||||
- uses: lukka/get-cmake@latest
|
||||
- name: Configure CMake
|
||||
run: >
|
||||
cmake --preset ${{ matrix.preset }}
|
||||
-S ${{ github.workspace }} -B ${{ github.workspace }}/build
|
||||
-D CMAKE_BUILD_TYPE=${{ matrix.build_type }}
|
||||
-D CMAKE_C_COMPILER_LAUNCHER=ccache
|
||||
-D CMAKE_CXX_COMPILER_LAUNCHER=ccache
|
||||
|
||||
# Build via CMake
|
||||
# Reference: https://github.com/lukka/run-cmake/blob/v3/action.yml
|
||||
- name: Build via cmake
|
||||
uses: lukka/run-cmake@v3
|
||||
- name: Build
|
||||
run: cmake --build ${{ github.workspace }}/build --preset ${{ matrix.preset }} --config ${{ matrix.build_type }}
|
||||
|
||||
- name: Archive production artifacts
|
||||
uses: actions/upload-artifact@v4
|
||||
with:
|
||||
cmakeListsOrSettingsJson: CMakeListsTxtAdvanced
|
||||
cmakeAppendedArgs: >
|
||||
-D CMAKE_C_COMPILER_LAUNCHER=ccache
|
||||
-D CMAKE_CXX_COMPILER_LAUNCHER=ccache
|
||||
buildWithCMake: true
|
||||
# Explicitly list build targets here.
|
||||
# Building "all" includes test executables and takes much longer.
|
||||
buildWithCMakeArgs: "-- gemma"
|
||||
buildDirectory: '${{ github.workspace }}/build'
|
||||
name: gemma-${{ matrix.os }}-${{ matrix.preset }}-${{ matrix.build_type }}
|
||||
path: |
|
||||
${{ github.workspace }}/build/${{ matrix.build_type }}/gemma.exe
|
||||
${{ github.workspace }}/build/${{ matrix.build_type }}/libgemma.lib
|
||||
${{ github.workspace }}/build/gemma
|
||||
${{ github.workspace }}/build/libgemma.a
|
||||
|
|
|
|||
51
BUILD.bazel
51
BUILD.bazel
|
|
@ -114,54 +114,3 @@ cc_binary(
|
|||
"//:thread_pool",
|
||||
],
|
||||
)
|
||||
|
||||
# copybara:strip_begin
|
||||
cc_binary(
|
||||
name = "run_csv",
|
||||
srcs = [
|
||||
"run_csv.cc",
|
||||
],
|
||||
deps = [
|
||||
":app",
|
||||
":args",
|
||||
":gemma_lib",
|
||||
"//compression:compress",
|
||||
# copybara:import_next_line:hwy
|
||||
"//:hwy",
|
||||
# copybara:import_next_line:hwy
|
||||
"//:nanobenchmark",
|
||||
# copybara:import_next_line:hwy
|
||||
"//:profiler",
|
||||
# copybara:import_next_line:hwy
|
||||
"//:thread_pool",
|
||||
"//third_party/riegeli/bytes:file_reader",
|
||||
"//third_party/riegeli/bytes:file_writer",
|
||||
"//third_party/riegeli/csv:csv_reader",
|
||||
"//third_party/riegeli/csv:csv_writer",
|
||||
],
|
||||
)
|
||||
|
||||
gensignature(
|
||||
name = "gemma_sign",
|
||||
srcs = [":gemma"],
|
||||
)
|
||||
|
||||
cc_test(
|
||||
name = "benchmarks",
|
||||
size = "large",
|
||||
srcs = [
|
||||
"benchmarks.cc",
|
||||
],
|
||||
tags = ["notap"],
|
||||
deps = [
|
||||
":app",
|
||||
":gemma_lib",
|
||||
"//third_party/benchmark",
|
||||
# copybara:import_next_line:hwy
|
||||
"//:hwy",
|
||||
# copybara:import_next_line:hwy
|
||||
"//:thread_pool",
|
||||
],
|
||||
)
|
||||
|
||||
# copybara:strip_end
|
||||
|
|
|
|||
|
|
@ -24,7 +24,7 @@ set(CMAKE_CXX_STANDARD_REQUIRED ON)
|
|||
FetchContent_Declare(highway GIT_REPOSITORY https://github.com/google/highway.git GIT_TAG da250571a45826b21eebbddc1e50d0c1137dee5f)
|
||||
FetchContent_MakeAvailable(highway)
|
||||
|
||||
## Note: absl meeds tp be installed by sentencepiece. This will only happen if
|
||||
## Note: absl needs to be installed by sentencepiece. This will only happen if
|
||||
## cmake is invoked with -DSPM_ENABLE_SHARED=OFF and -DSPM_ABSL_PROVIDER=module
|
||||
FetchContent_Declare(sentencepiece GIT_REPOSITORY https://github.com/google/sentencepiece GIT_TAG 53de76561cfc149d3c01037f0595669ad32a5e7c)
|
||||
FetchContent_MakeAvailable(sentencepiece)
|
||||
|
|
@ -43,14 +43,13 @@ set(SOURCES
|
|||
util/args.h
|
||||
)
|
||||
|
||||
add_compile_options($<$<CONFIG:Release>:-O2>)
|
||||
if(NOT CMAKE_BUILD_TYPE)
|
||||
set(CMAKE_BUILD_TYPE "Release")
|
||||
endif()
|
||||
|
||||
# Allowable types for WEIGHT_TYPE:
|
||||
# float - slow, not recommended
|
||||
# hwy::bfloat16_t - bfloat16 as impemented by https://github.com/google/highway
|
||||
# hwy::bfloat16_t - bfloat16 as implemented by https://github.com/google/highway
|
||||
# SfpStream - 8-bit switched floating point (recommended)
|
||||
# NuqStream - experimental, work-in-progress
|
||||
option(WEIGHT_TYPE "Set weight type" "")
|
||||
|
|
@ -68,12 +67,17 @@ target_link_libraries(gemma hwy hwy_contrib sentencepiece)
|
|||
target_include_directories(gemma PRIVATE ./)
|
||||
FetchContent_GetProperties(sentencepiece)
|
||||
target_include_directories(gemma PRIVATE ${sentencepiece_SOURCE_DIR})
|
||||
target_compile_definitions(gemma PRIVATE $<$<PLATFORM_ID:Windows>:_CRT_SECURE_NO_WARNINGS NOMINMAX>)
|
||||
target_compile_options(gemma PRIVATE $<$<PLATFORM_ID:Windows>:-Wno-deprecated-declarations>)
|
||||
|
||||
## Library Target
|
||||
|
||||
add_library(libgemma ${SOURCES})
|
||||
set_property(TARGET libgemma PROPERTY CXX_STANDARD 17)
|
||||
set_target_properties(libgemma PROPERTIES PREFIX "")
|
||||
set_property(TARGET libgemma PROPERTY POSITION_INDEPENDENT_CODE ON)
|
||||
target_include_directories(libgemma PUBLIC ./)
|
||||
target_link_libraries(libgemma hwy hwy_contrib sentencepiece)
|
||||
target_include_directories(libgemma PRIVATE ${sentencepiece_SOURCE_DIR})
|
||||
target_compile_definitions(libgemma PRIVATE $<$<PLATFORM_ID:Windows>:_CRT_SECURE_NO_WARNINGS NOMINMAX>)
|
||||
target_compile_options(libgemma PRIVATE $<$<PLATFORM_ID:Windows>:-Wno-deprecated-declarations>)
|
||||
|
|
|
|||
|
|
@ -0,0 +1,59 @@
|
|||
{
|
||||
"version": 3,
|
||||
"cmakeMinimumRequired": {
|
||||
"major": 3,
|
||||
"minor": 11,
|
||||
"patch": 0
|
||||
},
|
||||
"configurePresets": [
|
||||
{
|
||||
"name": "__defaults__",
|
||||
"hidden": true,
|
||||
"binaryDir": "${sourceDir}/build"
|
||||
},
|
||||
{
|
||||
"name": "make",
|
||||
"inherits": "__defaults__",
|
||||
"displayName": "Make",
|
||||
"description": "Unix Makefiles",
|
||||
"generator": "Unix Makefiles",
|
||||
"binaryDir": "${sourceDir}/build"
|
||||
},
|
||||
{
|
||||
"name": "windows",
|
||||
"inherits": "__defaults__",
|
||||
"displayName": "Windows",
|
||||
"description": "Visual Studio 2022 with Clang/LLVM frontend",
|
||||
"generator": "Visual Studio 17 2022",
|
||||
"toolset": "ClangCL",
|
||||
"condition": {
|
||||
"type": "equals",
|
||||
"lhs": "${hostSystemName}",
|
||||
"rhs": "Windows"
|
||||
}
|
||||
}
|
||||
],
|
||||
"buildPresets": [
|
||||
{
|
||||
"name": "__defaults__",
|
||||
"hidden": true,
|
||||
"targets": [
|
||||
"gemma",
|
||||
"libgemma"
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "make",
|
||||
"inherits": "__defaults__",
|
||||
"displayName": "Unix Makefiles",
|
||||
"configurePreset": "make"
|
||||
},
|
||||
{
|
||||
"name": "windows",
|
||||
"inherits": "__defaults__",
|
||||
"displayName": "Windows",
|
||||
"configuration": "Release",
|
||||
"configurePreset": "windows"
|
||||
}
|
||||
]
|
||||
}
|
||||
|
|
@ -70,3 +70,33 @@ The implementation code is roughly split into 4 layers, from high to low level:
|
|||
|
||||
4. Backend (`highway`) - Low-level hardware interface (SIMD in the case of
|
||||
highway) supporting the implementations in (3).
|
||||
|
||||
Besides these layers, supporting utilities are:
|
||||
|
||||
- `compression/` - model compression operations. The 8-bit switched floating
|
||||
point model conversion is here.
|
||||
- `util/` - command line argument handling and any other utilities.
|
||||
|
||||
## Style and Formatting
|
||||
|
||||
A `.clang-format` configuration is provided with our defaults, please run source
|
||||
files through `clang-format` (or a formatter that produces equivalent behavior)
|
||||
before finalizing PR for submission.
|
||||
|
||||
## Compile-Time Flags (Advanced)
|
||||
|
||||
There are several compile-time flags to be aware of (note these may or may not
|
||||
be exposed to the build system):
|
||||
|
||||
- `GEMMA_WEIGHT_T` : Sets the level of compression for weights (surfaced as
|
||||
WEIGHT_TYPE in CMakeLists.txt). Currently this should be set to `SfpStream`
|
||||
(default, if no flag is specified) for 8-bit SFP, or `hwy::bfloat16_t` to
|
||||
enable higher-fidelity (but slower) bfloat16 support. This is defined in
|
||||
`gemma.h`.
|
||||
- `GEMMA_MAX_SEQLEN` : Sets maximum sequence length to preallocate for the KV
|
||||
Cache. The default is 4096 tokens but can be overridden. This is not exposed
|
||||
through `CMakeLists.txt` yet.
|
||||
|
||||
In the medium term both of these will likely be deprecated in favor of handling
|
||||
options at runtime - allowing for multiple weight compression schemes in a single
|
||||
build and dynamically resizing the KV cache as needed.
|
||||
|
|
|
|||
80
README.md
80
README.md
|
|
@ -55,6 +55,16 @@ Before starting, you should have installed:
|
|||
least C++17.
|
||||
- `tar` for extracting archives from Kaggle.
|
||||
|
||||
Building natively on Windows requires the Visual Studio 2022 Build Tools with the
|
||||
optional Clang/LLVM C++ frontend (`clang-cl`). This can be installed from the
|
||||
command line with
|
||||
[`winget`](https://learn.microsoft.com/en-us/windows/package-manager/winget/):
|
||||
|
||||
```sh
|
||||
winget install --id Kitware.CMake
|
||||
winget install --id Microsoft.VisualStudio.2022.BuildTools --force --override "--passive --wait --add Microsoft.VisualStudio.Workload.VCTools;installRecommended --add Microsoft.VisualStudio.Component.VC.Llvm.Clang --add Microsoft.VisualStudio.Component.VC.Llvm.ClangToolset"
|
||||
```
|
||||
|
||||
### Step 1: Obtain model weights and tokenizer from Kaggle
|
||||
|
||||
Visit [the Gemma model page on
|
||||
|
|
@ -82,7 +92,7 @@ weights enable faster inference. In general, we recommend starting with the
|
|||
| `7b-pt` | 7 billion parameter pre-trained model, bfloat16 |
|
||||
| `7b-pt-sfp` | 7 billion parameter pre-trained model, 8-bit switched floating point |
|
||||
|
||||
> [!NOTE]
|
||||
> [!NOTE]
|
||||
> **Important**: We strongly recommend starting off with the `2b-it-sfp` model to
|
||||
> get up and running.
|
||||
|
||||
|
|
@ -104,9 +114,14 @@ convenient directory location (e.g. the `build/` directory in this repo).
|
|||
|
||||
The build system uses [CMake](https://cmake.org/). To build the gemma inference
|
||||
runtime, create a build directory and generate the build files using `cmake`
|
||||
from the top-level project directory. For the 8-bit switched floating point
|
||||
weights (sfp), run cmake with no options:
|
||||
from the top-level project directory. Note if you previously ran `cmake` and are
|
||||
re-running with a different setting, be sure to clean out the `build/` directory
|
||||
with `rm -rf build/*` (warning this will delete any other files in the `build/`
|
||||
directory).
|
||||
|
||||
For the 8-bit switched floating point weights (sfp), run cmake with no options:
|
||||
|
||||
#### Unix-like Platforms
|
||||
```sh
|
||||
cmake -B build
|
||||
```
|
||||
|
|
@ -126,17 +141,18 @@ your weights, you can enter the `build/` directory and run `make` to build the
|
|||
`./gemma` executable:
|
||||
|
||||
```sh
|
||||
cd build
|
||||
make -j [number of parallel threads to use] gemma
|
||||
# Configure `build` directory
|
||||
cmake --preset make
|
||||
|
||||
# Build project using make
|
||||
cmake --build --preset make -j [number of parallel threads to use]
|
||||
```
|
||||
|
||||
Replace `[number of parallel threads to use]` with a number - the number of
|
||||
cores available on your system is a reasonable heuristic.
|
||||
|
||||
For example, `make -j4 gemma` will build using 4 threads. If this is successful,
|
||||
you should now have a `gemma` executable in the `build/` directory. If the
|
||||
`nproc` command is available, you can use `make -j$(nproc) gemma` as a
|
||||
reasonable default for the number of threads.
|
||||
cores available on your system is a reasonable heuristic. For example,
|
||||
`make -j4 gemma` will build using 4 threads. If the `nproc` command is
|
||||
available, you can use `make -j$(nproc) gemma` as a reasonable default
|
||||
for the number of threads.
|
||||
|
||||
If you aren't sure of the right value for the `-j` flag, you can simply run
|
||||
`make gemma` instead and it should still build the `./gemma` executable.
|
||||
|
|
@ -145,6 +161,20 @@ If you aren't sure of the right value for the `-j` flag, you can simply run
|
|||
> On Windows Subsystem for Linux (WSL), users should set the number of
|
||||
> parallel threads to 1. Using a larger number may result in errors.
|
||||
|
||||
If the build is successful, you should now have a `gemma` executable in the `build/` directory.
|
||||
|
||||
#### Windows
|
||||
|
||||
```sh
|
||||
# Configure `build` directory
|
||||
cmake --preset windows
|
||||
|
||||
# Build project using Visual Studio Build Tools
|
||||
cmake --build --preset windows -j [number of parallel threads to use]
|
||||
```
|
||||
|
||||
If the build is successful, you should now have a `gemma.exe` executable in the `build/` directory.
|
||||
|
||||
### Step 4: Run
|
||||
|
||||
You can now run `gemma` from inside the `build/` directory.
|
||||
|
|
@ -212,6 +242,21 @@ We're working on a python script to convert a standard model format to `.sbs`,
|
|||
and hope to have it available in the next week or so. Follow [this
|
||||
issue](https://github.com/google/gemma.cpp/issues/11) for updates.
|
||||
|
||||
**What are some easy ways to make the model run faster?**
|
||||
|
||||
1. Make sure you are using the 8-bit switched floating point `-sfp` models.
|
||||
2. If you're on a laptop, make sure power mode is set to maximize performance
|
||||
and saving mode is **off**. For most laptops, the power saving modes get
|
||||
activated automatically if the computer is not plugged in.
|
||||
3. Close other unused cpu-intensive applications.
|
||||
4. On macs, anecdotally we observe a "warm-up" ramp-up in speed as performance
|
||||
cores get engaged.
|
||||
5. Experiment with the `--num_threads` argument value. Depending on the device,
|
||||
larger numbers don't always mean better performance.
|
||||
|
||||
We're also working on algorithmic and optimization approaches for faster
|
||||
inference, stay tuned.
|
||||
|
||||
## Usage
|
||||
|
||||
`gemma` has different usage modes, controlled by the verbosity flag.
|
||||
|
|
@ -247,7 +292,7 @@ max_tokens : 3072
|
|||
max_generated_tokens : 2048
|
||||
|
||||
*Usage*
|
||||
Enter an instruction and press enter (%Q quits).
|
||||
Enter an instruction and press enter (%C reset conversation, %Q quits).
|
||||
|
||||
*Examples*
|
||||
- Write an email to grandma thanking her for the cookies.
|
||||
|
|
@ -385,6 +430,17 @@ make -j [number of parallel threads to use] libgemma
|
|||
If this is successful, you should now have a `libgemma` library file in the
|
||||
`build/` directory. On Unix platforms, the filename is `libgemma.a`.
|
||||
|
||||
## Independent Projects Using gemma.cpp
|
||||
|
||||
Some independent projects using gemma.cpp:
|
||||
|
||||
- [gemma-cpp-python - Python bindings](https://github.com/namtranase/gemma-cpp-python)
|
||||
- [lua-cgemma - Lua bindings](https://github.com/ufownl/lua-cgemma)
|
||||
- [Godot engine demo project](https://github.com/Rliop913/Gemma-godot-demo-project)
|
||||
|
||||
If you would like to have your project included, feel free to get in touch or
|
||||
submit a PR with a `README.md` edit.
|
||||
|
||||
## Acknowledgements and Contacts
|
||||
|
||||
gemma.cpp was started in fall 2023 by [Austin Huang](mailto:austinvhuang@google.com)
|
||||
|
|
|
|||
|
|
@ -13,14 +13,32 @@
|
|||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
// Request POSIX 2008, including `pread()` and `posix_fadvise()`.
|
||||
#if !defined(_XOPEN_SOURCE) || _XOPEN_SOURCE < 700
|
||||
#undef _XOPEN_SOURCE
|
||||
#define _XOPEN_SOURCE 700
|
||||
#endif
|
||||
#if !defined(_POSIX_C_SOURCE) || _POSIX_C_SOURCE < 200809
|
||||
#define _POSIX_C_SOURCE 200809
|
||||
#endif
|
||||
|
||||
// Make `off_t` 64-bit even on 32-bit systems. Works for Android >= r15c.
|
||||
#undef _FILE_OFFSET_BITS
|
||||
#define _FILE_OFFSET_BITS 64
|
||||
|
||||
// copybara:import_next_line:gemma_cpp
|
||||
#include "compression/blob_store.h"
|
||||
|
||||
#include <fcntl.h> // open
|
||||
#include <stdint.h>
|
||||
#include <stdio.h> // SEEK_END - unistd isn't enough for IDE.
|
||||
#include <sys/stat.h> // O_RDONLY
|
||||
#include <unistd.h> // read, close
|
||||
#include <fcntl.h> // open
|
||||
#if HWY_OS_WIN
|
||||
#include <io.h> // read, write, close
|
||||
#include <fileapi.h>
|
||||
#else
|
||||
#include <unistd.h> // read, write, close
|
||||
#endif
|
||||
|
||||
#include <atomic>
|
||||
#include <vector>
|
||||
|
|
@ -30,6 +48,54 @@
|
|||
#include "hwy/contrib/thread_pool/thread_pool.h"
|
||||
#include "hwy/detect_compiler_arch.h"
|
||||
|
||||
namespace {
|
||||
#if HWY_OS_WIN
|
||||
|
||||
// pread is not supported on Windows
|
||||
static int64_t pread(int fd, void* buf, uint64_t size, uint64_t offset) {
|
||||
HANDLE file = reinterpret_cast<HANDLE>(_get_osfhandle(fd));
|
||||
if (file == INVALID_HANDLE_VALUE) {
|
||||
return -1;
|
||||
}
|
||||
|
||||
OVERLAPPED overlapped = {0};
|
||||
overlapped.Offset = offset & 0xFFFFFFFF;
|
||||
overlapped.OffsetHigh = (offset >> 32) & 0xFFFFFFFF;
|
||||
|
||||
DWORD bytes_read;
|
||||
if (!ReadFile(file, buf, size, &bytes_read, &overlapped)) {
|
||||
if (GetLastError() != ERROR_HANDLE_EOF) {
|
||||
return -1;
|
||||
}
|
||||
}
|
||||
|
||||
return bytes_read;
|
||||
}
|
||||
|
||||
// pwrite is not supported on Windows
|
||||
static int64_t pwrite(int fd, const void* buf, uint64_t size, uint64_t offset) {
|
||||
HANDLE file = reinterpret_cast<HANDLE>(_get_osfhandle(fd));
|
||||
if (file == INVALID_HANDLE_VALUE) {
|
||||
return -1;
|
||||
}
|
||||
|
||||
OVERLAPPED overlapped = {0};
|
||||
overlapped.Offset = offset & 0xFFFFFFFF;
|
||||
overlapped.OffsetHigh = (offset >> 32) & 0xFFFFFFFF;
|
||||
|
||||
DWORD bytes_written;
|
||||
if (!WriteFile(file, buf, size, &bytes_written, &overlapped)) {
|
||||
if (GetLastError() != ERROR_HANDLE_EOF) {
|
||||
return -1;
|
||||
}
|
||||
}
|
||||
|
||||
return bytes_written;
|
||||
}
|
||||
|
||||
#endif
|
||||
} // namespace
|
||||
|
||||
namespace gcpp {
|
||||
|
||||
hwy::uint128_t MakeKey(const char* string) {
|
||||
|
|
@ -64,19 +130,31 @@ static void EnqueueChunkRequests(uint64_t offset, uint64_t size, uint8_t* data,
|
|||
}
|
||||
}
|
||||
|
||||
|
||||
struct IO {
|
||||
// Returns size in bytes or 0.
|
||||
static uint64_t FileSize(const char* filename) {
|
||||
int fd = open(filename, O_RDONLY);
|
||||
if (fd >= 0) {
|
||||
const off_t size = lseek(fd, 0, SEEK_END);
|
||||
HWY_ASSERT(close(fd) != -1);
|
||||
if (size != static_cast<off_t>(-1)) {
|
||||
return static_cast<uint64_t>(size);
|
||||
}
|
||||
if (fd < 0) {
|
||||
return 0;
|
||||
}
|
||||
|
||||
return 0;
|
||||
#if HWY_OS_WIN
|
||||
const int64_t size = _lseeki64(fd, 0, SEEK_END);
|
||||
HWY_ASSERT(close(fd) != -1);
|
||||
if (size < 0) {
|
||||
return 0;
|
||||
}
|
||||
#else
|
||||
static_assert(sizeof(off_t) == 8, "64-bit off_t required");
|
||||
const off_t size = lseek(fd, 0, SEEK_END);
|
||||
HWY_ASSERT(close(fd) != -1);
|
||||
if (size == static_cast<off_t>(-1)) {
|
||||
return 0;
|
||||
}
|
||||
#endif
|
||||
|
||||
return static_cast<uint64_t>(size);
|
||||
}
|
||||
|
||||
static bool Read(int fd, uint64_t offset, uint64_t size, void* to) {
|
||||
|
|
@ -252,10 +330,18 @@ class BlobStore {
|
|||
#pragma pack(pop)
|
||||
|
||||
BlobError BlobReader::Open(const char* filename) {
|
||||
#if HWY_OS_WIN
|
||||
DWORD flags = FILE_ATTRIBUTE_NORMAL | FILE_FLAG_SEQUENTIAL_SCAN;
|
||||
HANDLE file = CreateFileA(filename, GENERIC_READ, FILE_SHARE_READ, nullptr,
|
||||
OPEN_EXISTING, flags, nullptr);
|
||||
if (file == INVALID_HANDLE_VALUE) return __LINE__;
|
||||
fd_ = _open_osfhandle(reinterpret_cast<intptr_t>(file), _O_RDONLY);
|
||||
#else
|
||||
fd_ = open(filename, O_RDONLY);
|
||||
#endif
|
||||
if (fd_ < 0) return __LINE__;
|
||||
|
||||
#if _POSIX_C_SOURCE >= 200112L
|
||||
#if HWY_OS_LINUX
|
||||
// Doubles the readahead window, which seems slightly faster when cached.
|
||||
(void)posix_fadvise(fd_, 0, 0, POSIX_FADV_SEQUENTIAL);
|
||||
#endif
|
||||
|
|
@ -330,7 +416,15 @@ BlobError BlobWriter::WriteAll(hwy::ThreadPool& pool,
|
|||
keys_.data(), blobs_.data(), keys_.size());
|
||||
|
||||
// Create/replace existing file.
|
||||
#if HWY_OS_WIN
|
||||
DWORD flags = FILE_ATTRIBUTE_NORMAL;
|
||||
HANDLE file = CreateFileA(filename, GENERIC_WRITE, 0, nullptr, CREATE_ALWAYS,
|
||||
flags, nullptr);
|
||||
if (file == INVALID_HANDLE_VALUE) return __LINE__;
|
||||
const int fd = _open_osfhandle(reinterpret_cast<intptr_t>(file), _O_WRONLY);
|
||||
#else
|
||||
const int fd = open(filename, O_CREAT | O_RDWR | O_TRUNC, 0644);
|
||||
#endif
|
||||
if (fd < 0) return __LINE__;
|
||||
|
||||
std::atomic_flag err = ATOMIC_FLAG_INIT;
|
||||
|
|
@ -341,6 +435,7 @@ BlobError BlobWriter::WriteAll(hwy::ThreadPool& pool,
|
|||
err.test_and_set();
|
||||
}
|
||||
});
|
||||
HWY_ASSERT(close(fd) != -1);
|
||||
if (err.test_and_set()) return __LINE__;
|
||||
return 0;
|
||||
}
|
||||
|
|
|
|||
|
|
@ -25,6 +25,8 @@ namespace gcpp {
|
|||
class DistortionStats {
|
||||
public:
|
||||
void Notify(float original, float distorted) {
|
||||
(void)padding_; // prevent unused member warning
|
||||
|
||||
const double l1 = hwy::ScalarAbs(original - distorted);
|
||||
|
||||
if (l1 > max_l1_) {
|
||||
|
|
|
|||
|
|
@ -114,7 +114,7 @@ std::string Stats::ToString(int exclude) const {
|
|||
pos += ret;
|
||||
}
|
||||
|
||||
HWY_ASSERT(pos < sizeof(buf));
|
||||
HWY_ASSERT(pos < static_cast<int>(sizeof(buf)));
|
||||
return buf;
|
||||
}
|
||||
|
||||
|
|
|
|||
15
configs.h
15
configs.h
|
|
@ -18,11 +18,16 @@
|
|||
#ifndef THIRD_PARTY_GEMMA_CPP_CONFIGS_H_
|
||||
#define THIRD_PARTY_GEMMA_CPP_CONFIGS_H_
|
||||
|
||||
// Allow changing pre-allocated kv cache size as a compiler flag
|
||||
#ifndef GEMMA_MAX_SEQLEN
|
||||
#define GEMMA_MAX_SEQLEN 4096
|
||||
#endif // !GEMMA_MAX_SEQLEN
|
||||
|
||||
#include <stddef.h>
|
||||
|
||||
namespace gcpp {
|
||||
|
||||
static constexpr size_t kSeqLen = 7168;
|
||||
static constexpr size_t kSeqLen = GEMMA_MAX_SEQLEN;
|
||||
|
||||
struct ConfigGemma7B {
|
||||
static constexpr int kSeqLen = gcpp::kSeqLen;
|
||||
|
|
@ -31,8 +36,8 @@ struct ConfigGemma7B {
|
|||
static constexpr int kModelDim = 3072;
|
||||
static constexpr int kFFHiddenDim = 16 * 3072 / 2; // = 24576
|
||||
static constexpr int kHeads = 16;
|
||||
static constexpr int kKVHeads = 16; // standard MHA, no GQA or MQA
|
||||
static constexpr int kQKVDim = 256; // query size == key size == value size
|
||||
static constexpr int kKVHeads = 16; // standard MHA
|
||||
static constexpr int kQKVDim = 256; // query size == key size == value size
|
||||
static constexpr int kTopK = 1;
|
||||
};
|
||||
|
||||
|
|
@ -43,8 +48,8 @@ struct ConfigGemma2B {
|
|||
static constexpr int kModelDim = 2048;
|
||||
static constexpr int kFFHiddenDim = 16 * 2048 / 2; // = 16384
|
||||
static constexpr int kHeads = 8;
|
||||
static constexpr int kKVHeads = 8; // TODO(austinvhuang): add MQA support
|
||||
static constexpr int kQKVDim = 256; // query size == key size == value size
|
||||
static constexpr int kKVHeads = 8; // TODO(austinvhuang): add MQA support
|
||||
static constexpr int kQKVDim = 256; // query size == key size == value size
|
||||
static constexpr int kTopK = 1;
|
||||
};
|
||||
|
||||
|
|
|
|||
20
gemma.cc
20
gemma.cc
|
|
@ -633,30 +633,32 @@ void ForEachTensor(const Weights<TConfig>* weights,
|
|||
c_weights.c_final_norm_scale);
|
||||
|
||||
char name[16];
|
||||
for (size_t layer_idx = 0; layer_idx < TConfig::kLayers; ++layer_idx) {
|
||||
Layer<TConfig>* layer = weights ? &weights->layers[layer_idx] : nullptr;
|
||||
CompressedLayer<TConfig>* c_layer = c_weights.CLayer(layer_idx);
|
||||
for (int layer_idx = 0; layer_idx < static_cast<int>(TConfig::kLayers);
|
||||
++layer_idx) {
|
||||
const size_t idx = static_cast<size_t>(layer_idx);
|
||||
Layer<TConfig>* layer = weights ? &weights->layers[idx] : nullptr;
|
||||
CompressedLayer<TConfig>* c_layer = c_weights.CLayer(idx);
|
||||
|
||||
snprintf(name, sizeof(name), "pre_ff_ns_%lu", layer_idx);
|
||||
snprintf(name, sizeof(name), "pre_ff_ns_%d", layer_idx);
|
||||
func(name, layer ? layer->pre_ffw_norm_scale.data() : nullptr,
|
||||
c_layer->c_pre_ffw_norm_scale);
|
||||
|
||||
snprintf(name, sizeof(name), "gating_ein_%lu", layer_idx);
|
||||
snprintf(name, sizeof(name), "gating_ein_%d", layer_idx);
|
||||
func(name, layer ? layer->gating_einsum_w.data() : nullptr,
|
||||
c_layer->c_gating_einsum_w);
|
||||
|
||||
snprintf(name, sizeof(name), "linear_w_%lu", layer_idx);
|
||||
snprintf(name, sizeof(name), "linear_w_%d", layer_idx);
|
||||
func(name, layer ? layer->linear_w.data() : nullptr, c_layer->c_linear_w);
|
||||
snprintf(name, sizeof(name), "qkv_ein_%lu", layer_idx);
|
||||
snprintf(name, sizeof(name), "qkv_ein_%d", layer_idx);
|
||||
|
||||
func(name, layer ? layer->qkv_einsum_w.data() : nullptr,
|
||||
c_layer->c_qkv_einsum_w);
|
||||
snprintf(name, sizeof(name), "att_ein_%lu", layer_idx);
|
||||
snprintf(name, sizeof(name), "att_ein_%d", layer_idx);
|
||||
|
||||
func(name, layer ? layer->attn_vec_einsum_w.data() : nullptr,
|
||||
c_layer->c_attn_vec_einsum_w);
|
||||
|
||||
snprintf(name, sizeof(name), "pre_att_ns_%lu", layer_idx);
|
||||
snprintf(name, sizeof(name), "pre_att_ns_%d", layer_idx);
|
||||
func(name, layer ? layer->pre_attention_norm_scale.data() : nullptr,
|
||||
c_layer->c_pre_attention_norm_scale);
|
||||
}
|
||||
|
|
|
|||
29
gemma.h
29
gemma.h
|
|
@ -26,15 +26,19 @@
|
|||
|
||||
// copybara:import_next_line:gemma_cpp
|
||||
#include "compression/compress.h" // SfpStream/NuqStream
|
||||
// copybara:end
|
||||
// copybara:import_next_line:gemma_cpp
|
||||
#include "configs.h" // kSeqLen
|
||||
#include "configs.h" // kSeqLen
|
||||
// copybara:end
|
||||
// copybara:import_next_line:gemma_cpp
|
||||
#include "util/args.h" // ArgsBase
|
||||
#include "util/args.h" // ArgsBase
|
||||
// copybara:end
|
||||
#include "hwy/aligned_allocator.h"
|
||||
#include "hwy/base.h" // hwy::bfloat16_t
|
||||
#include "hwy/contrib/thread_pool/thread_pool.h"
|
||||
// copybara:import_next_line:sentencepiece
|
||||
#include "src/sentencepiece_processor.h"
|
||||
// copybara:end
|
||||
|
||||
namespace gcpp {
|
||||
|
||||
|
|
@ -118,21 +122,22 @@ struct LoaderArgs : public ArgsBase<LoaderArgs> {
|
|||
template <class Visitor>
|
||||
void ForEach(const Visitor& visitor) {
|
||||
visitor(tokenizer, "tokenizer", Path(),
|
||||
"Path name of tokenizer model file. (required)");
|
||||
"Path name of tokenizer model file.\n Required argument.");
|
||||
visitor(
|
||||
cache, "compressed_weights", Path(),
|
||||
"Path name of compressed weights file, regenerated from `--weights` "
|
||||
"file if "
|
||||
"the compressed weights file does not exist. (required)");
|
||||
"the compressed weights file does not exist.\n Required argument.");
|
||||
visitor(model_type, "model", std::string(),
|
||||
"Model type - can be 2b-it (2B parameters, instruction-tuned), "
|
||||
"2b-pt (2B parameters, pretrained), 7b-it (7B parameters, "
|
||||
"instruction-tuned), or 7b-pt (7B parameters, pretrained). "
|
||||
"(required)");
|
||||
"Model type\n 2b-it (2B parameters, instruction-tuned)\n "
|
||||
"2b-pt (2B parameters, pretrained)\n 7b-it (7B parameters "
|
||||
"instruction-tuned)\n 7b-pt (7B parameters, pretrained)\n"
|
||||
" Required argument.");
|
||||
visitor(model, "weights", Path(),
|
||||
"Path name of model weights (.sbs) file. Only required if "
|
||||
"compressed_weights file is not present and needs to be "
|
||||
"regenerated. Otherwise, not needed");
|
||||
"regenerated. This parameter is only required for compressing"
|
||||
"new model weight exports, otherwise it is not needed.");
|
||||
}
|
||||
};
|
||||
|
||||
|
|
@ -186,10 +191,10 @@ struct InferenceArgs : public ArgsBase<InferenceArgs> {
|
|||
visitor(temperature, "temperature", 1.0f, "Temperature for top-K", 2);
|
||||
visitor(deterministic, "deterministic", false,
|
||||
"Make top-k sampling deterministic", 2);
|
||||
visitor(multiturn, "multiturn", true,
|
||||
visitor(multiturn, "multiturn", false,
|
||||
"Multiturn mode (if 0, this clears the KV cache after every "
|
||||
"interaction without quitting)",
|
||||
2);
|
||||
"interaction without quitting)\n Default : 0 (conversation "
|
||||
"resets every turn)");
|
||||
}
|
||||
};
|
||||
|
||||
|
|
|
|||
55
ops.h
55
ops.h
|
|
@ -214,7 +214,8 @@ static HWY_NOINLINE HWY_MAYBE_UNUSED void Gelu(float* HWY_RESTRICT x,
|
|||
size_t size) {
|
||||
namespace hn = hwy::HWY_NAMESPACE;
|
||||
using D = hn::ScalableTag<float>;
|
||||
hn::Transform(D(), x, size, [](D d, hn::Vec<D> v) { return Gelu(d, v); });
|
||||
hn::Transform(D(), x, size,
|
||||
[](D d, hn::Vec<D> v) HWY_ATTR { return Gelu(d, v); });
|
||||
}
|
||||
|
||||
// out[i] = BF(mul[i] * Gelu(gelu_in[i]))
|
||||
|
|
@ -567,22 +568,41 @@ static HWY_NOINLINE void Softmax(float* HWY_RESTRICT x, size_t size,
|
|||
namespace hn = hwy::HWY_NAMESPACE;
|
||||
using D = hn::ScalableTag<float>;
|
||||
const D d;
|
||||
using V = hn::Vec<D>;
|
||||
const size_t N = hn::Lanes(d);
|
||||
|
||||
// Find max so we can subtract it below.
|
||||
const V vmin = hn::Set(d, hwy::LowestValue<float>());
|
||||
V max = vmin;
|
||||
hn::Foreach(d, x, mask_pos, vmin,
|
||||
[&max](D d, V v) { max = hn::Max(max, v); });
|
||||
max = hn::MaxOfLanes(d, max); // broadcast
|
||||
// Find max so we can subtract it below. Avoid hn::Foreach because SVE vectors
|
||||
// cannot be lambda-captured.
|
||||
// TODO(janwas): could be replaced with an hn::Accumulate algo.
|
||||
const hn::Vec<D> vmin = hn::Set(d, hwy::LowestValue<float>());
|
||||
hn::Vec<D> vmax = vmin;
|
||||
size_t idx = 0;
|
||||
if (mask_pos >= N) {
|
||||
for (; idx <= mask_pos - N; idx += N) {
|
||||
vmax = hn::Max(vmax, LoadU(d, x + idx));
|
||||
}
|
||||
}
|
||||
vmax = hn::Max(vmax, LoadNOr(vmin, d, x + idx, mask_pos - idx));
|
||||
vmax = hn::MaxOfLanes(d, vmax); // broadcast
|
||||
|
||||
// Subtract max (avoid precision loss for large exponents) and exponentiate.
|
||||
V sum = hn::Zero(d);
|
||||
hn::Transform(d, x, mask_pos, [&sum, max](D d, V v) {
|
||||
const V out = hn::Exp(d, hn::Sub(v, max));
|
||||
// Also avoid hn::Transform because the additional `sum` output vector cannot
|
||||
// be captured by a lambda.
|
||||
hn::Vec<D> sum = hn::Zero(d);
|
||||
idx = 0;
|
||||
if (mask_pos >= N) {
|
||||
for (; idx <= mask_pos - N; idx += N) {
|
||||
const hn::Vec<D> out = hn::Exp(d, hn::Sub(hn::LoadU(d, x + idx), vmax));
|
||||
sum = hn::Add(sum, out);
|
||||
hn::StoreU(out, d, x + idx);
|
||||
}
|
||||
}
|
||||
if (mask_pos > idx) {
|
||||
const size_t remaining = mask_pos - idx;
|
||||
const hn::Vec<D> out =
|
||||
hn::Exp(d, hn::Sub(hn::LoadN(d, x + idx, remaining), vmax));
|
||||
sum = hn::Add(sum, out);
|
||||
return out;
|
||||
});
|
||||
hn::StoreN(out, d, x + idx, remaining);
|
||||
}
|
||||
|
||||
// Normalize to probability distribution
|
||||
const float mul = 1.0f / hn::ReduceSum(d, sum);
|
||||
|
|
@ -601,13 +621,12 @@ static HWY_NOINLINE void LogitsSoftCap(const float cap, float* HWY_RESTRICT x,
|
|||
namespace hn = hwy::HWY_NAMESPACE;
|
||||
using D = hn::ScalableTag<float>;
|
||||
const D d;
|
||||
using V = hn::Vec<D>;
|
||||
|
||||
const V inv_cap = hn::Set(d, 1.0f / cap);
|
||||
const V vcap = hn::Set(d, cap);
|
||||
const float inv_cap = 1.0f / cap;
|
||||
|
||||
hn::Transform(d, x, size, [vcap, inv_cap](D d, hn::Vec<D> v) {
|
||||
return hn::Mul(vcap, hn::Tanh(d, hn::Mul(inv_cap, v)));
|
||||
hn::Transform(d, x, size, [cap, inv_cap](D d, hn::Vec<D> v) HWY_ATTR {
|
||||
return hn::Mul(hn::Set(d, cap),
|
||||
hn::Tanh(d, hn::Mul(v, hn::Set(d, inv_cap))));
|
||||
});
|
||||
}
|
||||
|
||||
|
|
|
|||
81
run.cc
81
run.cc
|
|
@ -24,12 +24,16 @@
|
|||
|
||||
// copybara:import_next_line:gemma_cpp
|
||||
#include "compression/compress.h"
|
||||
// copybara:end
|
||||
// copybara:import_next_line:gemma_cpp
|
||||
#include "gemma.h" // Gemma
|
||||
#include "gemma.h" // Gemma
|
||||
// copybara:end
|
||||
// copybara:import_next_line:gemma_cpp
|
||||
#include "util/app.h"
|
||||
// copybara:end
|
||||
// copybara:import_next_line:gemma_cpp
|
||||
#include "util/args.h" // HasHelp
|
||||
// copybara:end
|
||||
#include "hwy/base.h"
|
||||
#include "hwy/contrib/thread_pool/thread_pool.h"
|
||||
#include "hwy/highway.h"
|
||||
|
|
@ -39,20 +43,13 @@
|
|||
|
||||
namespace gcpp {
|
||||
|
||||
// Writes the command-line help to stderr: a short introduction naming the
// required model-loading flags, then a heading and the Help() output of the
// loader, inference and application argument groups, in that order.
void ShowHelp(gcpp::LoaderArgs& loader, gcpp::InferenceArgs& inference,
              gcpp::AppArgs& app) {
  // Introduction, immediately followed by the heading of the first section.
  static const char kIntro[] =
      "\ngemma.cpp\n---------\n\nTo run gemma.cpp, you need to "
      "specify 3 required model loading arguments: --tokenizer, "
      "--compressed_weights, "
      "and --model.\n\nModel Loading Arguments\n\n";
  fputs(kIntro, stderr);
  loader.Help();

  fputs("\nInference Arguments\n\n", stderr);
  inference.Help();

  fputs("\nApplication Arguments\n\n", stderr);
  app.Help();

  // Trailing blank lines separate the help from whatever is printed next.
  fputs("\n\n", stderr);
}
|
||||
// ASCII-art banner. It is streamed at the top of the ShowHelp() output and
// right after the clear-screen escape sequence in Run().
// NOTE(review): the column alignment of the art depends on the exact run of
// spaces inside each literal - confirm against the rendered output before
// touching any of these lines.
static constexpr std::string_view kAsciiArtBanner =
|
||||
" __ _ ___ _ __ ___ _ __ ___ __ _ ___ _ __ _ __\n"
|
||||
" / _` |/ _ \\ '_ ` _ \\| '_ ` _ \\ / _` | / __| '_ \\| '_ \\\n"
|
||||
"| (_| | __/ | | | | | | | | | | (_| || (__| |_) | |_) |\n"
|
||||
" \\__, |\\___|_| |_| |_|_| |_| |_|\\__,_(_)___| .__/| .__/\n"
|
||||
" __/ | | | | |\n"
|
||||
" |___/ |_| |_|";
|
||||
|
||||
void ShowConfig(LoaderArgs& loader, InferenceArgs& inference, AppArgs& app) {
|
||||
loader.Print(app.verbosity);
|
||||
|
|
@ -69,7 +66,8 @@ void ShowConfig(LoaderArgs& loader, InferenceArgs& inference, AppArgs& app) {
|
|||
<< std::thread::hardware_concurrency() << std::endl
|
||||
<< "Instruction set : "
|
||||
<< hwy::TargetName(hwy::DispatchedTarget()) << " ("
|
||||
<< hwy::VectorBytes() * 8 << " bits)" << "\n"
|
||||
<< hwy::VectorBytes() * 8 << " bits)"
|
||||
<< "\n"
|
||||
<< "Weight Type : "
|
||||
<< gcpp::TypeName(gcpp::WeightT()) << "\n"
|
||||
<< "EmbedderInput Type : "
|
||||
|
|
@ -77,9 +75,31 @@ void ShowConfig(LoaderArgs& loader, InferenceArgs& inference, AppArgs& app) {
|
|||
}
|
||||
}
|
||||
|
||||
// Prints the full command-line help to stderr: the ASCII-art banner, a
// reminder of the three required model-loading flags, an example invocation,
// and then a heading plus Help() output for each argument group.
void ShowHelp(gcpp::LoaderArgs& loader, gcpp::InferenceArgs& inference,
              gcpp::AppArgs& app) {
  std::ostream& out = std::cerr;

  // Banner and title.
  out << kAsciiArtBanner;
  out << "\n\ngemma.cpp : a lightweight, standalone C++ inference engine\n"
         "==========================================================\n\n";

  // Required flags, one per line.
  out << "To run gemma.cpp, you need to "
         "specify 3 required model loading arguments:\n --tokenizer\n "
         "--compressed_weights\n"
         " --model.\n";

  // Copy-pasteable example.
  out << "\n*Example Usage*\n\n./gemma --tokenizer tokenizer.spm "
         "--compressed_weights 2b-it-sfp.sbs --model 2b-it\n";

  // Per-group help; each group prints its own flags via Help().
  out << "\n*Model Loading Arguments*\n\n";
  loader.Help();

  out << "\n*Inference Arguments*\n\n";
  inference.Help();

  out << "\n*Application Arguments*\n\n";
  app.Help();

  out << "\n";
}
|
||||
|
||||
void ReplGemma(gcpp::Gemma& model, hwy::ThreadPool& pool,
|
||||
hwy::ThreadPool& inner_pool, const InferenceArgs& args,
|
||||
int verbosity, const gcpp::AcceptFunc& accept_token) {
|
||||
int verbosity, const gcpp::AcceptFunc& accept_token,
|
||||
std::string& eot_line) {
|
||||
PROFILER_ZONE("Gen.misc");
|
||||
int abs_pos = 0; // absolute token index over all turns
|
||||
int current_pos = 0; // token index within the current turn
|
||||
|
|
@ -137,7 +157,18 @@ void ReplGemma(gcpp::Gemma& model, hwy::ThreadPool& pool,
|
|||
if (verbosity >= 1) {
|
||||
std::cout << "> " << std::flush;
|
||||
}
|
||||
std::getline(std::cin, prompt_string);
|
||||
|
||||
if (eot_line.size() == 0) {
|
||||
std::getline(std::cin, prompt_string);
|
||||
} else {
|
||||
std::string line;
|
||||
while (std::getline(std::cin, line)) {
|
||||
if (line == eot_line) {
|
||||
break;
|
||||
}
|
||||
prompt_string += line + "\n";
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (std::cin.fail() || prompt_string == "%q" || prompt_string == "%Q") {
|
||||
|
|
@ -221,7 +252,12 @@ void Run(LoaderArgs& loader, InferenceArgs& inference, AppArgs& app) {
|
|||
|
||||
const std::string instructions =
|
||||
"*Usage*\n"
|
||||
" Enter an instruction and press enter (%Q quits).\n\n"
|
||||
" Enter an instruction and press enter (%C resets conversation, "
|
||||
"%Q quits).\n" +
|
||||
(inference.multiturn == 0
|
||||
? std::string(" Since multiturn is set to 0, conversation will "
|
||||
"automatically reset every turn.\n\n")
|
||||
: "\n") +
|
||||
"*Examples*\n"
|
||||
" - Write an email to grandma thanking her for the cookies.\n"
|
||||
" - What are some historical attractions to visit around "
|
||||
|
|
@ -230,13 +266,14 @@ void Run(LoaderArgs& loader, InferenceArgs& inference, AppArgs& app) {
|
|||
" - Write a standup comedy bit about GPU programming.\n";
|
||||
|
||||
std::cout << "\033[2J\033[1;1H" // clear screen
|
||||
<< banner_ascii_art << "\n\n";
|
||||
<< kAsciiArtBanner << "\n\n";
|
||||
ShowConfig(loader, inference, app);
|
||||
std::cout << "\n" << instructions << "\n";
|
||||
}
|
||||
|
||||
ReplGemma(model, pool, inner_pool, inference, app.verbosity,
|
||||
/*accept_token=*/[](int) { return true; });
|
||||
ReplGemma(
|
||||
model, pool, inner_pool, inference, app.verbosity,
|
||||
/*accept_token=*/[](int) { return true; }, app.eot_line);
|
||||
}
|
||||
|
||||
} // namespace gcpp
|
||||
|
|
|
|||
28
util/app.h
28
util/app.h
|
|
@ -18,14 +18,20 @@
|
|||
#ifndef THIRD_PARTY_GEMMA_CPP_UTIL_APP_H_
|
||||
#define THIRD_PARTY_GEMMA_CPP_UTIL_APP_H_
|
||||
|
||||
#if HWY_OS_LINUX
|
||||
#include <sched.h>
|
||||
|
||||
#include <cerrno> // IDE does not recognize errno.h as providing errno.
|
||||
#endif
|
||||
#include <stddef.h>
|
||||
#include <stdio.h>
|
||||
|
||||
#include <algorithm> // std::clamp
|
||||
#include <thread> // NOLINT
|
||||
|
||||
// copybara:import_next_line:gemma_cpp
|
||||
#include "util/args.h"
|
||||
// copybara:end
|
||||
#include "hwy/base.h" // HWY_ASSERT
|
||||
|
||||
namespace gcpp {
|
||||
|
|
@ -36,7 +42,13 @@ static inline void PinThreadToCore(size_t cpu_index) {
|
|||
cpu_set_t cset; // bit array
|
||||
CPU_ZERO(&cset); // clear all
|
||||
CPU_SET(cpu_index, &cset); // set bit indicating which processor to run on.
|
||||
HWY_ASSERT(0 == sched_setaffinity(0, sizeof(cset), &cset));
|
||||
const int err = sched_setaffinity(0, sizeof(cset), &cset);
|
||||
if (err != 0) {
|
||||
fprintf(stderr,
|
||||
"sched_setaffinity returned %d, errno %d. Can happen if running in "
|
||||
"a container; this warning is safe to ignore.\n",
|
||||
err, errno);
|
||||
}
|
||||
#else
|
||||
(void)cpu_index;
|
||||
#endif
|
||||
|
|
@ -62,10 +74,10 @@ class AppArgs : public ArgsBase<AppArgs> {
|
|||
Path log; // output
|
||||
int verbosity;
|
||||
size_t num_threads;
|
||||
std::string eot_line;
|
||||
|
||||
template <class Visitor>
|
||||
void ForEach(const Visitor& visitor) {
|
||||
visitor(log, "log", Path{"/tmp/log.txt"}, "Logging file", 2);
|
||||
visitor(verbosity, "verbosity", 1,
|
||||
"Show verbose developer information\n 0 = only print generation "
|
||||
"output\n 1 = standard user-facing terminal ui\n 2 = show "
|
||||
|
|
@ -73,10 +85,16 @@ class AppArgs : public ArgsBase<AppArgs> {
|
|||
2);
|
||||
visitor(num_threads, "num_threads",
|
||||
kDefaultNumThreads, // see ChooseNumThreads
|
||||
"Number of threads to use. Default value is set based on an "
|
||||
"estimate of "
|
||||
"how many concurrent threads are supported.",
|
||||
"Number of threads to use.\n Default = Estimate of the "
|
||||
"number of suupported concurrent threads.",
|
||||
2);
|
||||
visitor(
|
||||
eot_line, "eot_line", std::string(""),
|
||||
"End of turn line. "
|
||||
"When you specify this, the prompt will be all lines "
|
||||
"before the line where only the given string appears.\n Default = "
|
||||
"When a newline is encountered, that signals the end of the turn.",
|
||||
2);
|
||||
}
|
||||
};
|
||||
|
||||
|
|
|
|||
|
|
@ -204,7 +204,7 @@ class ArgsBase {
|
|||
}
|
||||
};
|
||||
|
||||
static bool HasHelp(int argc, char* argv[]) {
|
||||
static inline HWY_MAYBE_UNUSED bool HasHelp(int argc, char* argv[]) {
|
||||
// TODO(austinvhuang): handle case insensitivity
|
||||
if (argc == 1) {
|
||||
// no arguments - print help
|
||||
|
|
|
|||
Loading…
Reference in New Issue