Use fp32 in cuBLAS V100 to avoid overflows, env variables to override cuBLAS compute type (#19959)
* Update ggml-cuda.cu * Update ggml-cuda.cu * Update build.md * Update build.md * Update ggml/src/ggml-cuda/ggml-cuda.cu Co-authored-by: Johannes Gäßler <johannesg@5d6.de> * Update ggml-cuda.cu * Update build.md * Update ggml/src/ggml-cuda/ggml-cuda.cu Co-authored-by: Johannes Gäßler <johannesg@5d6.de> * Update build.md * Update ggml-cuda.cu * Update ggml-cuda.cu --------- Co-authored-by: Johannes Gäßler <johannesg@5d6.de>
This commit is contained in:
parent
9789c4ecdc
commit
f2c0dfb739
|
|
@ -269,6 +269,14 @@ The environment variable [`CUDA_SCALE_LAUNCH_QUEUES`](https://docs.nvidia.com/cu
|
|||
|
||||
Consider setting `CUDA_SCALE_LAUNCH_QUEUES=4x`, which increases the CUDA command buffer to 4 times its default size. This optimization is particularly beneficial for **Multi-GPU setups with pipeline parallelism**, where it significantly improves prompt processing throughput by allowing more operations to be enqueued across GPUs.
|
||||
|
||||
#### GGML_CUDA_FORCE_CUBLAS_COMPUTE_32F
|
||||
|
||||
Set the `GGML_CUDA_FORCE_CUBLAS_COMPUTE_32F` environment variable to force the FP32 compute type for FP16 cuBLAS on all GPUs. This prevents possible numerical overflows at the cost of slower prompt processing (the impact is small on RTX PRO/datacenter products and significant on GeForce products).
|
||||
|
||||
#### GGML_CUDA_FORCE_CUBLAS_COMPUTE_16F
|
||||
|
||||
Set the `GGML_CUDA_FORCE_CUBLAS_COMPUTE_16F` environment variable to force the use of the FP16 compute type (instead of the default FP32) in FP16 cuBLAS on V100, CDNA and RDNA4.
|
||||
|
||||
### Unified Memory
|
||||
|
||||
The environment variable `GGML_CUDA_ENABLE_UNIFIED_MEMORY=1` can be used to enable unified memory in Linux. This allows swapping to system RAM instead of crashing when the GPU VRAM is exhausted. In Windows this setting is available in the NVIDIA control panel as `System Memory Fallback`.
|
||||
|
|
@ -280,7 +288,7 @@ The following compilation options are also available to tweak performance:
|
|||
| Option | Legal values | Default | Description |
|
||||
|-------------------------------|------------------------|---------|----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|
|
||||
| GGML_CUDA_FORCE_MMQ | Boolean | false | Force the use of custom matrix multiplication kernels for quantized models instead of FP16 cuBLAS even if there is no int8 tensor core implementation available (affects V100, CDNA and RDNA3+). MMQ kernels are enabled by default on GPUs with int8 tensor core support. With MMQ force enabled, speed for large batch sizes will be worse but VRAM consumption will be lower. |
|
||||
| GGML_CUDA_FORCE_CUBLAS | Boolean | false | Force the use of FP16 cuBLAS instead of custom matrix multiplication kernels for quantized models. There may be issues with numerical overflows (except for CDNA and RDNA4) and memory use will be higher. Prompt processing may become faster on recent datacenter GPUs (the custom kernels were tuned primarily for RTX 3000/4000). |
|
||||
| GGML_CUDA_FORCE_CUBLAS | Boolean | false | Force the use of FP16 cuBLAS instead of custom matrix multiplication kernels for quantized models. There may be issues with numerical overflows (except for V100, CDNA and RDNA4 which use FP32 compute type by default) and memory use will be higher. Prompt processing may become faster on recent datacenter GPUs (the custom kernels were tuned primarily for RTX 3000/4000). |
|
||||
| GGML_CUDA_PEER_MAX_BATCH_SIZE | Positive integer | 128 | Maximum batch size for which to enable peer access between multiple GPUs. Peer access requires either Linux or NVLink. When using NVLink enabling peer access for larger batch sizes is potentially beneficial. |
|
||||
| GGML_CUDA_FA_ALL_QUANTS | Boolean | false | Compile support for all KV cache quantization type (combinations) for the FlashAttention CUDA kernels. More fine-grained control over KV cache size but compilation takes much longer. |
|
||||
|
||||
|
|
|
|||
|
|
@ -1242,6 +1242,34 @@ static cudaError_t ggml_cuda_cpy_tensor_2d(
|
|||
}
|
||||
}
|
||||
|
||||
// Parsed state of the cuBLAS compute-type override environment variables.
// At most one flag is set; both default to false when no override is present.
struct cublas_force_compute_type {
    bool fp32 = false;  // GGML_CUDA_FORCE_CUBLAS_COMPUTE_32F was set in the environment
    bool fp16 = false;  // GGML_CUDA_FORCE_CUBLAS_COMPUTE_16F was set in the environment
};
|
||||
|
||||
// Returns the user-requested cuBLAS compute-type override, parsed from the
// GGML_CUDA_FORCE_CUBLAS_COMPUTE_32F / _16F environment variables.
// The environment is read exactly once (thread-safe via the function-local
// static initializer); subsequent calls return the cached result.
// Aborts (GGML_ASSERT) if both variables are set, since the overrides are
// mutually exclusive.
static const cublas_force_compute_type & ggml_cuda_cublas_get_force_compute_type() {
    static const cublas_force_compute_type compute_type = [] {
        cublas_force_compute_type result;

        // Presence of the variable is what matters; its value is ignored.
        const bool ggml_cuda_force_cublas_compute_32f_env = getenv("GGML_CUDA_FORCE_CUBLAS_COMPUTE_32F") != nullptr;
        const bool ggml_cuda_force_cublas_compute_16f_env = getenv("GGML_CUDA_FORCE_CUBLAS_COMPUTE_16F") != nullptr;

        // Setting both overrides is contradictory — refuse to continue.
        GGML_ASSERT(ggml_cuda_force_cublas_compute_16f_env == false || ggml_cuda_force_cublas_compute_32f_env == false);

        if (ggml_cuda_force_cublas_compute_32f_env) {
            GGML_LOG_INFO("Detected GGML_CUDA_FORCE_CUBLAS_COMPUTE_32F\n");
            result.fp32 = true;
        } else if (ggml_cuda_force_cublas_compute_16f_env) {
            GGML_LOG_INFO("Detected GGML_CUDA_FORCE_CUBLAS_COMPUTE_16F\n");
            result.fp16 = true;
        }

        return result;
    }();

    return compute_type;
}
|
||||
|
||||
static void ggml_cuda_op_mul_mat_cublas(
|
||||
ggml_backend_cuda_context & ctx,
|
||||
const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, const char * src0_dd_i, const float * src1_ddf_i,
|
||||
|
|
@ -1324,7 +1352,13 @@ static void ggml_cuda_op_mul_mat_cublas(
|
|||
|
||||
CUBLAS_CHECK(cublasSetStream(ctx.cublas_handle(id), stream));
|
||||
|
||||
if (GGML_CUDA_CC_IS_CDNA(cc) || GGML_CUDA_CC_IS_RDNA4(cc)) {
|
||||
const auto & force_compute_type = ggml_cuda_cublas_get_force_compute_type();
|
||||
|
||||
if (!force_compute_type.fp16 && (GGML_CUDA_CC_IS_CDNA(cc)
|
||||
|| GGML_CUDA_CC_IS_RDNA4(cc)
|
||||
|| cc == GGML_CUDA_CC_VOLTA
|
||||
|| force_compute_type.fp32))
|
||||
{
|
||||
const float alpha = 1.0f;
|
||||
const float beta = 0.0f;
|
||||
CUBLAS_CHECK(
|
||||
|
|
@ -1923,10 +1957,23 @@ static void ggml_cuda_mul_mat_batched_cublas_impl(ggml_backend_cuda_context & ct
|
|||
cudaDataType_t cu_data_type_b = traits::data_type;
|
||||
const void * alpha = traits::get_alpha();
|
||||
const void * beta = traits::get_beta();
|
||||
const float alpha_f32 = 1.0f;
|
||||
const float beta_f32 = 0.0f;
|
||||
|
||||
if (dst->op_params[0] == GGML_PREC_DEFAULT) {
|
||||
const auto & force_compute_type = ggml_cuda_cublas_get_force_compute_type();
|
||||
|
||||
int id = ggml_cuda_get_device();
|
||||
const int cc = ggml_cuda_info().devices[id].cc;
|
||||
static constexpr bool is_src0_type_f16 = src0_type == GGML_TYPE_F16;
|
||||
|
||||
// bf16 and fp32 are already being computed in fp32 (ensure it using static_assert),
|
||||
// so checking necessity of forced fp32 only for fp16 src0_type
|
||||
static_assert(is_src0_type_f16 || traits::compute_type == CUBLAS_COMPUTE_32F);
|
||||
|
||||
const bool need_compute_32f = is_src0_type_f16 && !force_compute_type.fp16 && (GGML_CUDA_CC_IS_CDNA(cc)
|
||||
|| GGML_CUDA_CC_IS_RDNA4(cc)
|
||||
|| cc == GGML_CUDA_CC_VOLTA
|
||||
|| force_compute_type.fp32);
|
||||
|
||||
if (dst->op_params[0] == GGML_PREC_DEFAULT && !need_compute_32f) {
|
||||
if constexpr (src0_type == GGML_TYPE_F32) {
|
||||
dst_t = (char *) dst_ddf; // Direct F32 output
|
||||
} else {
|
||||
|
|
@ -1936,18 +1983,10 @@ static void ggml_cuda_mul_mat_batched_cublas_impl(ggml_backend_cuda_context & ct
|
|||
}
|
||||
} else {
|
||||
dst_t = (char *) dst_ddf;
|
||||
cu_compute_type = CUBLAS_COMPUTE_32F;
|
||||
cu_data_type = CUDA_R_32F;
|
||||
alpha = &alpha_f32;
|
||||
beta = &beta_f32;
|
||||
}
|
||||
|
||||
int id = ggml_cuda_get_device();
|
||||
const int cc = ggml_cuda_info().devices[id].cc;
|
||||
if (GGML_CUDA_CC_IS_CDNA(cc) || GGML_CUDA_CC_IS_RDNA4(cc)) {
|
||||
cu_compute_type = CUBLAS_COMPUTE_32F;
|
||||
alpha = &alpha_f32;
|
||||
beta = &beta_f32;
|
||||
cu_compute_type = batched_mul_mat_traits<GGML_TYPE_F32>::compute_type;
|
||||
cu_data_type = batched_mul_mat_traits<GGML_TYPE_F32>::data_type;
|
||||
alpha = batched_mul_mat_traits<GGML_TYPE_F32>::get_alpha();
|
||||
beta = batched_mul_mat_traits<GGML_TYPE_F32>::get_beta();
|
||||
}
|
||||
|
||||
GGML_ASSERT(ne12 % ne02 == 0);
|
||||
|
|
|
|||
Loading…
Reference in New Issue