mirror of https://github.com/google/gemma.cpp.git
CL added
This commit is contained in:
parent
cf435b77f9
commit
27f23fdae0
|
|
@ -1,7 +1,7 @@
|
|||
#pragma once
|
||||
#include "cl_FACADE.h"
|
||||
#include "cl_global_custom.h"
|
||||
#include "cl_inside.h"
|
||||
#include "cl_embedded.h"
|
||||
#include <thread>
|
||||
#include <vector>
|
||||
#define to_big_radix_2(var) pow(2,std::ceil(log2(var)));
|
||||
|
|
@ -28,7 +28,7 @@ private:
|
|||
const ma_uint64& core_size,
|
||||
const Args& ... args);
|
||||
|
||||
CL_INSIDE *CLS;
|
||||
cl_embed *CLS;
|
||||
[[nodiscard]]
|
||||
cl_float2* bit_reverse(float* data_array, const int& data_length_radix_2);
|
||||
void butterfly_stage_radix_2(cl_float2* data, const int& data_length_radix_2, float* data_out);
|
||||
|
|
|
|||
|
|
@ -1,341 +0,0 @@
|
|||
#pragma once
|
||||
|
||||
#include <string>
|
||||
class CL_INSIDE {
|
||||
public:
|
||||
std::string bit_reverse_STFT =
|
||||
" int reverseBits(int num, int radix_2_data) {\n"
|
||||
" int reversed = 0;\n"
|
||||
" for (int i = 0; i < radix_2_data; ++i) {\n"
|
||||
" reversed = (reversed << 1) | (num & 1);\n"
|
||||
" num >>= 1;\n"
|
||||
" }\n"
|
||||
" return reversed;\n"
|
||||
" }\n"
|
||||
" __kernel void entry_point(__global float2* frame, __global float2* out_frame, int radix_2)\n"
|
||||
" {\n"
|
||||
" int powed = (int)pow(2.0,radix_2);\n"
|
||||
" long myid = get_global_id(0);\n"
|
||||
" long id_quot = myid / powed;\n"
|
||||
" int id_rem = myid%powed;\n"
|
||||
" long calced_id = id_quot*powed + reverseBits(id_rem,radix_2);\n"
|
||||
" out_frame[myid].x = frame[calced_id].x;\n"
|
||||
" out_frame[myid].y = 0.0;\n"
|
||||
" }";
|
||||
|
||||
std::string bit_reverse =
|
||||
"int reverseBits(int num, int radix_2_data) {\n"
|
||||
" int reversed = 0;\n"
|
||||
" for (int i = 0; i < radix_2_data; ++i) {\n"
|
||||
" reversed = (reversed << 1) | (num & 1);\n"
|
||||
" num >>= 1;\n"
|
||||
" }\n"
|
||||
" return reversed;\n"
|
||||
"}\n"
|
||||
"\n"
|
||||
"__kernel void entry_point(__global float* frame, __global float2* out_frame, int radix_2)\n"
|
||||
"{\n"
|
||||
" long myid = get_global_id(0);\n"
|
||||
" out_frame[myid].x = frame[reverseBits(myid,radix_2)];\n"
|
||||
" out_frame[myid].y = 0.0;\n"
|
||||
"\n"
|
||||
"}";
|
||||
|
||||
std::string butterfly_stage =
|
||||
"typedef float2 cfloat;\n"
|
||||
"#define I ((cfloat)(0.0, 1.0))\n"
|
||||
"inline float real(cfloat a){\n"
|
||||
" return a.x;\n"
|
||||
"}\n"
|
||||
"inline float imag(cfloat a){\n"
|
||||
" return a.y;\n"
|
||||
"}\n"
|
||||
"inline float cmod(cfloat a){\n"
|
||||
" return (sqrt(a.x*a.x + a.y*a.y));\n"
|
||||
"}\n"
|
||||
"inline float carg(cfloat a){\n"
|
||||
" if(a.x > 0){\n"
|
||||
" return atan(a.y / a.x);\n"
|
||||
"\n"
|
||||
" }else if(a.x < 0 && a.y >= 0){\n"
|
||||
" return atan(a.y / a.x) + M_PI;\n"
|
||||
"\n"
|
||||
" }else if(a.x < 0 && a.y < 0){\n"
|
||||
" return atan(a.y / a.x) - M_PI;\n"
|
||||
"\n"
|
||||
" }else if(a.x == 0 && a.y > 0){\n"
|
||||
" return M_PI/2;\n"
|
||||
"\n"
|
||||
" }else if(a.x == 0 && a.y < 0){\n"
|
||||
" return -M_PI/2;\n"
|
||||
"\n"
|
||||
" }else{\n"
|
||||
" return 0;\n"
|
||||
" }\n"
|
||||
"}\n"
|
||||
"inline cfloat cmult(cfloat a, cfloat b){\n"
|
||||
" return (cfloat)( a.x*b.x - a.y*b.y, a.x*b.y + a.y*b.x);\n"
|
||||
"}\n"
|
||||
"\n"
|
||||
"\n"
|
||||
"inline cfloat cdiv(cfloat a, cfloat b){\n"
|
||||
" return (cfloat)((a.x*b.x + a.y*b.y)/(b.x*b.x + b.y*b.y), (a.y*b.x - a.x*b.y)/(b.x*b.x + b.y*b.y));\n"
|
||||
"}\n"
|
||||
"\n"
|
||||
"\n"
|
||||
"cfloat twiddle(int high, int low)\n"
|
||||
"{\n"
|
||||
" cfloat temp;\n"
|
||||
" temp.x = cos(2.0*M_PI*((float)high/(float)low));\n"
|
||||
" temp.y = -1.0*sin(2.0*M_PI*((float)high/(float)low));\n"
|
||||
" return temp;\n"
|
||||
"}\n"
|
||||
"\n"
|
||||
"\n"
|
||||
"long2 indexer(const long ID,const int stage)\n"
|
||||
"{\n"
|
||||
" long2 temp;\n"
|
||||
" temp.x = (ID%((long)pow(2.0,stage)))+(long)pow(2.0,stage+1)*(ID/(long)pow(2.0,stage));\n"
|
||||
" temp.y = temp.x+(long)pow(2.0,stage);\n"
|
||||
" return temp;\n"
|
||||
"}\n"
|
||||
"__kernel void entry_point(__global float2* in_frame, __global float2* out_frame, int radix_2, int stage)\n"
|
||||
"{\n"
|
||||
" long powed_stage = (long)pow(2.0,stage);\n"
|
||||
" long myid = get_global_id(0);\n"
|
||||
" long2 origin_pair=indexer(myid,stage);\n"
|
||||
" cfloat this_twiddle = twiddle(myid%powed_stage,powed_stage*2);\n"
|
||||
" this_twiddle = cmult(in_frame[origin_pair.y],this_twiddle);\n"
|
||||
" out_frame[origin_pair.x]=in_frame[origin_pair.x]+this_twiddle;\n"
|
||||
" out_frame[origin_pair.y]=in_frame[origin_pair.x]-this_twiddle;\n"
|
||||
"}";
|
||||
|
||||
std::string butterfly_STFT=
|
||||
"typedef float2 cfloat;\n"
|
||||
"\n"
|
||||
"#define I ((cfloat)(0.0, 1.0))\n"
|
||||
"\n"
|
||||
"inline float real(cfloat a){\n"
|
||||
" return a.x;\n"
|
||||
"}\n"
|
||||
"inline float imag(cfloat a){\n"
|
||||
" return a.y;\n"
|
||||
"}\n"
|
||||
"\n"
|
||||
"inline float cmod(cfloat a){\n"
|
||||
" return (sqrt(a.x*a.x + a.y*a.y));\n"
|
||||
"}\n"
|
||||
"\n"
|
||||
"inline float carg(cfloat a){\n"
|
||||
" if(a.x > 0){\n"
|
||||
" return atan(a.y / a.x);\n"
|
||||
"\n"
|
||||
" }else if(a.x < 0 && a.y >= 0){\n"
|
||||
" return atan(a.y / a.x) + M_PI;\n"
|
||||
"\n"
|
||||
" }else if(a.x < 0 && a.y < 0){\n"
|
||||
" return atan(a.y / a.x) - M_PI;\n"
|
||||
"\n"
|
||||
" }else if(a.x == 0 && a.y > 0){\n"
|
||||
" return M_PI/2;\n"
|
||||
"\n"
|
||||
" }else if(a.x == 0 && a.y < 0){\n"
|
||||
" return -M_PI/2;\n"
|
||||
"\n"
|
||||
" }else{\n"
|
||||
" return 0;\n"
|
||||
" }\n"
|
||||
"}\n"
|
||||
"\n"
|
||||
"inline cfloat cmult(cfloat a, cfloat b){\n"
|
||||
" return (cfloat)( a.x*b.x - a.y*b.y, a.x*b.y + a.y*b.x);\n"
|
||||
"}\n"
|
||||
"\n"
|
||||
"\n"
|
||||
"inline cfloat cdiv(cfloat a, cfloat b){\n"
|
||||
" return (cfloat)((a.x*b.x + a.y*b.y)/(b.x*b.x + b.y*b.y), (a.y*b.x - a.x*b.y)/(b.x*b.x + b.y*b.y));\n"
|
||||
"}\n"
|
||||
"\n"
|
||||
"\n"
|
||||
"\n"
|
||||
"cfloat twiddle(int high, int low)\n"
|
||||
"{\n"
|
||||
" cfloat temp;\n"
|
||||
" temp.x = cos(2.0*M_PI*((float)high/(float)low));\n"
|
||||
" temp.y = -1.0*sin(2.0*M_PI*((float)high/(float)low));\n"
|
||||
" return temp;\n"
|
||||
"}\n"
|
||||
"\n"
|
||||
"\n"
|
||||
"long2 indexer(const long ID,const int stage)\n"
|
||||
"{\n"
|
||||
" long2 temp;\n"
|
||||
" temp.x = (ID%((long)pow(2.0,stage)))+(long)pow(2.0,stage+1)*(ID/(long)pow(2.0,stage));\n"
|
||||
" temp.y = temp.x+(long)pow(2.0,stage);\n"
|
||||
" return temp;\n"
|
||||
"}\n"
|
||||
"\n"
|
||||
"__kernel void entry_point(__global float2* in_frame, __global float2* out_frame, int radix_2, int stage)\n"
|
||||
"{\n"
|
||||
" long powed_stage = (long)pow(2.0,stage);\n"
|
||||
" long myid = get_global_id(0);\n"
|
||||
" long2 origin_pair=indexer(myid,stage);\n"
|
||||
" cfloat this_twiddle = twiddle(myid%powed_stage,powed_stage*2);\n"
|
||||
" this_twiddle = cmult(in_frame[origin_pair.y],this_twiddle);\n"
|
||||
" out_frame[origin_pair.x]=in_frame[origin_pair.x]+this_twiddle;\n"
|
||||
" out_frame[origin_pair.y]=in_frame[origin_pair.x]-this_twiddle;\n"
|
||||
"}";
|
||||
std::string window_function =
|
||||
"inline float window_func(const int powed, const int index, const int window_size)\n"
|
||||
"{\n"
|
||||
" return (0.5 - 0.5*cos(2.0*M_PI*(float)index/(float)(window_size-1)));\n"
|
||||
"}\n"
|
||||
"\n"
|
||||
"__kernel void entry_point(__global float2* frame_in, __global float2* frame_out, int window_radix_2_size)\n"
|
||||
"{\n"
|
||||
" int powed = (int)pow(2.0,window_radix_2_size);\n"
|
||||
" long myid = get_global_id(0);\n"
|
||||
" frame_out[myid].x = frame_in[myid].x*window_func(powed,myid%powed,powed);\n"
|
||||
" \n"
|
||||
"}\n";
|
||||
|
||||
std::string overlap =
|
||||
"__kernel void entry_point(__global float* frame_in, __global float2* frame_out, int window_frame,const int overlap_frame, int2 acc_able_frame, int front_side_zero_padding_size)\n"
|
||||
"{\n"
|
||||
"\n"
|
||||
" unsigned long myid = get_global_id(0);\n"
|
||||
" unsigned long quot =myid/window_frame;\n"
|
||||
" int rem = myid%window_frame;\n"
|
||||
" unsigned long my_index = quot*overlap_frame + rem;\n"
|
||||
" float will_write;\n"
|
||||
" unsigned long frame_limit = acc_able_frame.x*window_frame + acc_able_frame.y;\n"
|
||||
" \n"
|
||||
" will_write = frame_limit<=my_index?0:frame_in[my_index];\n"
|
||||
" will_write = rem<front_side_zero_padding_size?0:will_write;\n"
|
||||
"\n"
|
||||
" frame_out[myid].x=will_write;\n"
|
||||
" frame_out[myid].y=0;\n"
|
||||
"}\n"
|
||||
"\n";
|
||||
|
||||
std::string to_power =
|
||||
"inline float cmod(float2 a){\n"
|
||||
" return (sqrt(a.x*a.x + a.y*a.y));\n"
|
||||
"}\n"
|
||||
"\n"
|
||||
"__kernel void entry_point(__global float2* in_frame, __global float* out_frame, int origin_size)\n"
|
||||
"{\n"
|
||||
" long myid = get_global_id(0);\n"
|
||||
" long half_size = (long)origin_size / 2;\n"
|
||||
" long index = ((long)origin_size * ( myid / half_size )) + ( myid % half_size );\n"
|
||||
" float powered =cmod(in_frame[index]);\n"
|
||||
" powered = myid%half_size < 1?0:powered;\n"
|
||||
" out_frame[myid]=powered;\n"
|
||||
" \n"
|
||||
" \n"
|
||||
"}\n";
|
||||
|
||||
std::string to_three_band =
|
||||
"__kernel void entry_point(__global float* in_frame, __global float3* out_frame, int radix_2_half_size, int low_mid, int mid_high)\n"
|
||||
"{\n"
|
||||
" long myid = get_global_id(0);\n"
|
||||
" long my_index = myid*(long)radix_2_half_size;\n"
|
||||
" float temp_low=0.0;\n"
|
||||
" float temp_mid=0.0;\n"
|
||||
" float temp_high=0.0;\n"
|
||||
" for(int i=0; i<low_mid; ++i){\n"
|
||||
" temp_low+=in_frame[my_index+i];\n"
|
||||
" }\n"
|
||||
" for(int i=low_mid; i<mid_high; ++i){\n"
|
||||
" temp_mid+=in_frame[my_index+i];\n"
|
||||
" }\n"
|
||||
" for(int i=mid_high; i<radix_2_half_size; ++i){\n"
|
||||
" temp_high+=in_frame[my_index+i];\n"
|
||||
" }\n"
|
||||
" temp_low = 20.0*log10(temp_low/2000.0);\n"
|
||||
" temp_mid = 20.0*log10(temp_mid/2000.0);\n"
|
||||
" temp_high = 20.0*log10(temp_high/2000.0);\n"
|
||||
" out_frame[myid].xyz=(float3)(temp_low, temp_mid, temp_high);\n"
|
||||
"}\n";
|
||||
std::string DaC =
|
||||
"__kernel void entry_point(__global float* in_frame, __global float* out_frame)\n"
|
||||
"{\n"
|
||||
" long myid = get_global_id(0);\n"
|
||||
" int2 my_index;\n"
|
||||
" my_index.x = (myid*2);\n"
|
||||
" my_index.y = my_index.x+1;\n"
|
||||
" out_frame[myid]= in_frame[my_index.x]+in_frame[my_index.y];\n"
|
||||
" \n"
|
||||
"}\n";
|
||||
std::string split_low_band =
|
||||
"__kernel void entry_point(__global float* in_frame, __global float* low_out, int radix_2_half_size, int low_mid, int padded_size)\n"
|
||||
"{\n"
|
||||
" long myid = get_global_id(0);\n"
|
||||
" long powed_limit = (long)pow(2.0,radix_2_half_size);\n"
|
||||
" int my_locale_index = myid % padded_size;\n"
|
||||
" int my_global_index = myid / padded_size;\n"
|
||||
" long my_index = powed_limit * my_global_index + my_locale_index;\n"
|
||||
" float for_write = my_locale_index>=low_mid?0:in_frame[my_index];\n"
|
||||
" low_out[myid]=for_write;\n"
|
||||
"}\n";
|
||||
std::string split_mid_band =
|
||||
"__kernel void entry_point(__global float* in_frame, __global float* mid_out, int radix_2_half_size, int low_mid, int mid_high, int padded_size)\n"
|
||||
"{\n"
|
||||
" long myid = get_global_id(0);\n"
|
||||
" long powed_limit = (int)pow(2.0,radix_2_half_size);\n"
|
||||
" int my_locale_index = (myid % padded_size);\n"
|
||||
" int my_global_index = myid / padded_size;\n"
|
||||
" long my_index = powed_limit*my_global_index + my_locale_index + low_mid;\n"
|
||||
" float for_write = my_locale_index>=(mid_high-low_mid)?0:in_frame[my_index];\n"
|
||||
" mid_out[myid]=for_write;\n"
|
||||
"}\n";
|
||||
std::string split_high_band =
|
||||
"__kernel void entry_point(__global float* in_frame, __global float* high_out, int radix_2_half_size, int mid_high, int padded_size)\n"
|
||||
"{\n"
|
||||
" long myid = get_global_id(0);\n"
|
||||
" long powed_limit = (int)pow(2.0,radix_2_half_size);\n"
|
||||
" int my_locale_index = (myid % padded_size);\n"
|
||||
" int my_global_index = myid / padded_size;\n"
|
||||
" long my_index = powed_limit*my_global_index + my_locale_index + mid_high;\n"
|
||||
" float for_write = my_locale_index>=mid_high?0:in_frame[my_index];\n"
|
||||
" high_out[myid]=for_write;\n"
|
||||
"}\n";
|
||||
std::string integrate_DaC =
|
||||
"__kernel void entry_point(__global float* low_in, __global float* mid_in, __global float* high_in, __global float3* integ_out)\n"
|
||||
"{\n"
|
||||
" long myid = get_global_id(0);\n"
|
||||
" integ_out[myid].x=low_in[myid];\n"
|
||||
" integ_out[myid].y=mid_in[myid];\n"
|
||||
" integ_out[myid].z=high_in[myid];\n"
|
||||
"}\n";
|
||||
std::string to_dbfs =
|
||||
"float dbfs(float powered, int window_origin_size, int added_size){\n"
|
||||
" float result = 10.0 * log10(pow(powered,2) / (1.0 * (float)window_origin_size*(float)added_size));\n"
|
||||
" \n"
|
||||
" return result+20.0;\n"
|
||||
"}\n"
|
||||
"\n"
|
||||
"__kernel void entry_point(__global float3* in_frame, __global float3* out_frame, int window_radix_2, int low_size, int mid_size, int high_size)\n"
|
||||
"{\n"
|
||||
" long myid = get_global_id(0);\n"
|
||||
" out_frame[myid].x = dbfs(in_frame[myid].x,window_radix_2,low_size);\n"
|
||||
" out_frame[myid].y = dbfs(in_frame[myid].y,window_radix_2,mid_size);\n"
|
||||
" out_frame[myid].z = dbfs(in_frame[myid].z,window_radix_2,high_size);\n"
|
||||
" \n"
|
||||
" \n"
|
||||
"}\n";
|
||||
};
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
|
@ -41,7 +41,9 @@ set(SOURCES
|
|||
compression/sfp-inl.h
|
||||
util/app.h
|
||||
util/args.h
|
||||
)
|
||||
CL_Wrapper/
|
||||
CL_C_kernel_files/cl_embedded.h
|
||||
)
|
||||
|
||||
if(NOT CMAKE_BUILD_TYPE)
|
||||
set(CMAKE_BUILD_TYPE "Release")
|
||||
|
|
@ -64,6 +66,11 @@ add_executable(gemma run.cc)
|
|||
target_sources(gemma PRIVATE ${SOURCES})
|
||||
set_property(TARGET gemma PROPERTY CXX_STANDARD 17)
|
||||
target_link_libraries(gemma hwy hwy_contrib sentencepiece)
|
||||
if(WIN32)
|
||||
TARGET_LINK_LIBRARIES( gemma ${CMAKE_SOURCE_DIR}/Khronos_CL_SDK/win_x64_SDK/lib/OpenCL.lib)
|
||||
else()
|
||||
TARGET_LINK_LIBRARIES( gemma ${CMAKE_SOURCE_DIR}/Khronos_CL_SDK/linux/lib/OpenCL.a)
|
||||
endif()
|
||||
target_include_directories(gemma PRIVATE ./)
|
||||
FetchContent_GetProperties(sentencepiece)
|
||||
target_include_directories(gemma PRIVATE ${sentencepiece_SOURCE_DIR})
|
||||
|
|
@ -78,6 +85,11 @@ set_target_properties(libgemma PROPERTIES PREFIX "")
|
|||
set_property(TARGET libgemma PROPERTY POSITION_INDEPENDENT_CODE ON)
|
||||
target_include_directories(libgemma PUBLIC ./)
|
||||
target_link_libraries(libgemma hwy hwy_contrib sentencepiece)
|
||||
if(WIN32)
|
||||
TARGET_LINK_LIBRARIES( gemma ${CMAKE_SOURCE_DIR}/Khronos_CL_SDK/win_x64_SDK/lib/OpenCL.lib)
|
||||
else()
|
||||
TARGET_LINK_LIBRARIES( gemma ${CMAKE_SOURCE_DIR}/Khronos_CL_SDK/linux/lib/OpenCL.a)
|
||||
endif()
|
||||
target_include_directories(libgemma PRIVATE ${sentencepiece_SOURCE_DIR})
|
||||
target_compile_definitions(libgemma PRIVATE $<$<PLATFORM_ID:Windows>:_CRT_SECURE_NO_WARNINGS NOMINMAX>)
|
||||
target_compile_options(libgemma PRIVATE $<$<PLATFORM_ID:Windows>:-Wno-deprecated-declarations>)
|
||||
|
|
|
|||
Loading…
Reference in New Issue