diff --git a/CL_Wrapper/STFT.h b/CL_Wrapper/STFT.h index 8c6d8ba..d7b64da 100644 --- a/CL_Wrapper/STFT.h +++ b/CL_Wrapper/STFT.h @@ -1,7 +1,7 @@ #pragma once #include "cl_FACADE.h" #include "cl_global_custom.h" -#include "cl_inside.h" +#include "cl_embedded.h" #include #include #define to_big_radix_2(var) pow(2,std::ceil(log2(var))); @@ -28,7 +28,7 @@ private: const ma_uint64& core_size, const Args& ... args); - CL_INSIDE *CLS; + cl_embed *CLS; [[nodiscard]] cl_float2* bit_reverse(float* data_array, const int& data_length_radix_2); void butterfly_stage_radix_2(cl_float2* data, const int& data_length_radix_2, float* data_out); diff --git a/CL_Wrapper/cl_inside.h b/CL_Wrapper/cl_inside.h deleted file mode 100644 index e8e891d..0000000 --- a/CL_Wrapper/cl_inside.h +++ /dev/null @@ -1,341 +0,0 @@ -#pragma once - -#include -class CL_INSIDE { -public: - std::string bit_reverse_STFT = - " int reverseBits(int num, int radix_2_data) {\n" - " int reversed = 0;\n" - " for (int i = 0; i < radix_2_data; ++i) {\n" - " reversed = (reversed << 1) | (num & 1);\n" - " num >>= 1;\n" - " }\n" - " return reversed;\n" - " }\n" - " __kernel void entry_point(__global float2* frame, __global float2* out_frame, int radix_2)\n" - " {\n" - " int powed = (int)pow(2.0,radix_2);\n" - " long myid = get_global_id(0);\n" - " long id_quot = myid / powed;\n" - " int id_rem = myid%powed;\n" - " long calced_id = id_quot*powed + reverseBits(id_rem,radix_2);\n" - " out_frame[myid].x = frame[calced_id].x;\n" - " out_frame[myid].y = 0.0;\n" - " }"; - - std::string bit_reverse = - "int reverseBits(int num, int radix_2_data) {\n" - " int reversed = 0;\n" - " for (int i = 0; i < radix_2_data; ++i) {\n" - " reversed = (reversed << 1) | (num & 1);\n" - " num >>= 1;\n" - " }\n" - " return reversed;\n" - "}\n" - "\n" - "__kernel void entry_point(__global float* frame, __global float2* out_frame, int radix_2)\n" - "{\n" - " long myid = get_global_id(0);\n" - " out_frame[myid].x = frame[reverseBits(myid,radix_2)];\n" - " out_frame[myid].y = 0.0;\n" - "\n" - "}"; - - std::string butterfly_stage = - "typedef float2 cfloat;\n" - "#define I ((cfloat)(0.0, 1.0))\n" - "inline float real(cfloat a){\n" - " return a.x;\n" - "}\n" - "inline float imag(cfloat a){\n" - " return a.y;\n" - "}\n" - "inline float cmod(cfloat a){\n" - " return (sqrt(a.x*a.x + a.y*a.y));\n" - "}\n" - "inline float carg(cfloat a){\n" - " if(a.x > 0){\n" - " return atan(a.y / a.x);\n" - "\n" - " }else if(a.x < 0 && a.y >= 0){\n" - " return atan(a.y / a.x) + M_PI;\n" - "\n" - " }else if(a.x < 0 && a.y < 0){\n" - " return atan(a.y / a.x) - M_PI;\n" - "\n" - " }else if(a.x == 0 && a.y > 0){\n" - " return M_PI/2;\n" - "\n" - " }else if(a.x == 0 && a.y < 0){\n" - " return -M_PI/2;\n" - "\n" - " }else{\n" - " return 0;\n" - " }\n" - "}\n" - "inline cfloat cmult(cfloat a, cfloat b){\n" - " return (cfloat)( a.x*b.x - a.y*b.y, a.x*b.y + a.y*b.x);\n" - "}\n" - "\n" - "\n" - "inline cfloat cdiv(cfloat a, cfloat b){\n" - " return (cfloat)((a.x*b.x + a.y*b.y)/(b.x*b.x + b.y*b.y), (a.y*b.x - a.x*b.y)/(b.x*b.x + b.y*b.y));\n" - "}\n" - "\n" - "\n" - "cfloat twiddle(int high, int low)\n" - "{\n" - " cfloat temp;\n" - " temp.x = cos(2.0*M_PI*((float)high/(float)low));\n" - " temp.y = -1.0*sin(2.0*M_PI*((float)high/(float)low));\n" - " return temp;\n" - "}\n" - "\n" - "\n" - "long2 indexer(const long ID,const int stage)\n" - "{\n" - " long2 temp;\n" - " temp.x = (ID%((long)pow(2.0,stage)))+(long)pow(2.0,stage+1)*(ID/(long)pow(2.0,stage));\n" - " temp.y = temp.x+(long)pow(2.0,stage);\n" - " return temp;\n" - "}\n" - "__kernel void entry_point(__global float2* in_frame, __global float2* out_frame, int radix_2, int stage)\n" - "{\n" - " long powed_stage = (long)pow(2.0,stage);\n" - " long myid = get_global_id(0);\n" - " long2 origin_pair=indexer(myid,stage);\n" - " cfloat this_twiddle = twiddle(myid%powed_stage,powed_stage*2);\n" - " this_twiddle = cmult(in_frame[origin_pair.y],this_twiddle);\n" - " out_frame[origin_pair.x]=in_frame[origin_pair.x]+this_twiddle;\n" - " out_frame[origin_pair.y]=in_frame[origin_pair.x]-this_twiddle;\n" - "}"; - - std::string butterfly_STFT= - "typedef float2 cfloat;\n" - "\n" - "#define I ((cfloat)(0.0, 1.0))\n" - "\n" - "inline float real(cfloat a){\n" - " return a.x;\n" - "}\n" - "inline float imag(cfloat a){\n" - " return a.y;\n" - "}\n" - "\n" - "inline float cmod(cfloat a){\n" - " return (sqrt(a.x*a.x + a.y*a.y));\n" - "}\n" - "\n" - "inline float carg(cfloat a){\n" - " if(a.x > 0){\n" - " return atan(a.y / a.x);\n" - "\n" - " }else if(a.x < 0 && a.y >= 0){\n" - " return atan(a.y / a.x) + M_PI;\n" - "\n" - " }else if(a.x < 0 && a.y < 0){\n" - " return atan(a.y / a.x) - M_PI;\n" - "\n" - " }else if(a.x == 0 && a.y > 0){\n" - " return M_PI/2;\n" - "\n" - " }else if(a.x == 0 && a.y < 0){\n" - " return -M_PI/2;\n" - "\n" - " }else{\n" - " return 0;\n" - " }\n" - "}\n" - "\n" - "inline cfloat cmult(cfloat a, cfloat b){\n" - " return (cfloat)( a.x*b.x - a.y*b.y, a.x*b.y + a.y*b.x);\n" - "}\n" - "\n" - "\n" - "inline cfloat cdiv(cfloat a, cfloat b){\n" - " return (cfloat)((a.x*b.x + a.y*b.y)/(b.x*b.x + b.y*b.y), (a.y*b.x - a.x*b.y)/(b.x*b.x + b.y*b.y));\n" - "}\n" - "\n" - "\n" - "\n" - "cfloat twiddle(int high, int low)\n" - "{\n" - " cfloat temp;\n" - " temp.x = cos(2.0*M_PI*((float)high/(float)low));\n" - " temp.y = -1.0*sin(2.0*M_PI*((float)high/(float)low));\n" - " return temp;\n" - "}\n" - "\n" - "\n" - "long2 indexer(const long ID,const int stage)\n" - "{\n" - " long2 temp;\n" - " temp.x = (ID%((long)pow(2.0,stage)))+(long)pow(2.0,stage+1)*(ID/(long)pow(2.0,stage));\n" - " temp.y = temp.x+(long)pow(2.0,stage);\n" - " return temp;\n" - "}\n" - "\n" - "__kernel void entry_point(__global float2* in_frame, __global float2* out_frame, int radix_2, int stage)\n" - "{\n" - " long powed_stage = (long)pow(2.0,stage);\n" - " long myid = get_global_id(0);\n" - " long2 origin_pair=indexer(myid,stage);\n" - " cfloat this_twiddle = twiddle(myid%powed_stage,powed_stage*2);\n" - " this_twiddle = cmult(in_frame[origin_pair.y],this_twiddle);\n" - " out_frame[origin_pair.x]=in_frame[origin_pair.x]+this_twiddle;\n" - " out_frame[origin_pair.y]=in_frame[origin_pair.x]-this_twiddle;\n" - "}"; - std::string window_function = - "inline float window_func(const int powed, const int index, const int window_size)\n" - "{\n" - " return (0.5 - 0.5*cos(2.0*M_PI*(float)index/(float)(window_size-1)));\n" - "}\n" - "\n" - "__kernel void entry_point(__global float2* frame_in, __global float2* frame_out, int window_radix_2_size)\n" - "{\n" - " int powed = (int)pow(2.0,window_radix_2_size);\n" - " long myid = get_global_id(0);\n" - " frame_out[myid].x = frame_in[myid].x*window_func(powed,myid%powed,powed);\n" - " \n" - "}\n"; - - std::string overlap = - "__kernel void entry_point(__global float* frame_in, __global float2* frame_out, int window_frame,const int overlap_frame, int2 acc_able_frame, int front_side_zero_padding_size)\n" - "{\n" - "\n" - " unsigned long myid = get_global_id(0);\n" - " unsigned long quot =myid/window_frame;\n" - " int rem = myid%window_frame;\n" - " unsigned long my_index = quot*overlap_frame + rem;\n" - " float will_write;\n" - " unsigned long frame_limit = acc_able_frame.x*window_frame + acc_able_frame.y;\n" - " \n" - " will_write = frame_limit<=my_index?0:frame_in[my_index];\n" - " will_write = rem=low_mid?0:in_frame[my_index];\n" - " low_out[myid]=for_write;\n" - "}\n"; - std::string split_mid_band = - "__kernel void entry_point(__global float* in_frame, __global float* mid_out, int radix_2_half_size, int low_mid, int mid_high, int padded_size)\n" - "{\n" - " long myid = get_global_id(0);\n" - " long powed_limit = (int)pow(2.0,radix_2_half_size);\n" - " int my_locale_index = (myid % padded_size);\n" - " int my_global_index = myid / padded_size;\n" - " long my_index = powed_limit*my_global_index + my_locale_index + low_mid;\n" - " float for_write = my_locale_index>=(mid_high-low_mid)?0:in_frame[my_index];\n" - " mid_out[myid]=for_write;\n" - "}\n"; - std::string split_high_band = - "__kernel void entry_point(__global float* in_frame, __global float* high_out, int radix_2_half_size, int mid_high, int padded_size)\n" - "{\n" - " long myid = get_global_id(0);\n" - " long powed_limit = (int)pow(2.0,radix_2_half_size);\n" - " int my_locale_index = (myid % padded_size);\n" - " int my_global_index = myid / padded_size;\n" - " long my_index = powed_limit*my_global_index + my_locale_index + mid_high;\n" - " float for_write = my_locale_index>=mid_high?0:in_frame[my_index];\n" - " high_out[myid]=for_write;\n" - "}\n"; - std::string integrate_DaC = - "__kernel void entry_point(__global float* low_in, __global float* mid_in, __global float* high_in, __global float3* integ_out)\n" - "{\n" - " long myid = get_global_id(0);\n" - " integ_out[myid].x=low_in[myid];\n" - " integ_out[myid].y=mid_in[myid];\n" - " integ_out[myid].z=high_in[myid];\n" - "}\n"; - std::string to_dbfs = - "float dbfs(float powered, int window_origin_size, int added_size){\n" - " float result = 10.0 * log10(pow(powered,2) / (1.0 * (float)window_origin_size*(float)added_size));\n" - " \n" - " return result+20.0;\n" - "}\n" - "\n" - "__kernel void entry_point(__global float3* in_frame, __global float3* out_frame, int window_radix_2, int low_size, int mid_size, int high_size)\n" - "{\n" - " long myid = get_global_id(0);\n" - " out_frame[myid].x = dbfs(in_frame[myid].x,window_radix_2,low_size);\n" - " out_frame[myid].y = dbfs(in_frame[myid].y,window_radix_2,mid_size);\n" - " out_frame[myid].z = dbfs(in_frame[myid].z,window_radix_2,high_size);\n" - " \n" - " \n" - "}\n"; -}; - - - - - - - - - - - - - - diff --git a/CMakeLists.txt b/CMakeLists.txt index 308e258..f47856c 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -41,7 +41,9 @@ set(SOURCES compression/sfp-inl.h util/app.h util/args.h - ) + CL_Wrapper/ + CL_C_kernel_files/cl_embedded.h +) if(NOT CMAKE_BUILD_TYPE) set(CMAKE_BUILD_TYPE "Release") @@ -64,6 +66,11 @@ add_executable(gemma run.cc) target_sources(gemma PRIVATE ${SOURCES}) set_property(TARGET gemma PROPERTY CXX_STANDARD 17) target_link_libraries(gemma hwy hwy_contrib sentencepiece) +if(WIN32) + TARGET_LINK_LIBRARIES( gemma ${CMAKE_SOURCE_DIR}/Khronos_CL_SDK/win_x64_SDK/lib/OpenCL.lib) +else() + TARGET_LINK_LIBRARIES( gemma ${CMAKE_SOURCE_DIR}/Khronos_CL_SDK/linux/lib/OpenCL.a) +endif() target_include_directories(gemma PRIVATE ./) FetchContent_GetProperties(sentencepiece) target_include_directories(gemma PRIVATE ${sentencepiece_SOURCE_DIR}) @@ -78,6 +85,11 @@ set_target_properties(libgemma PROPERTIES PREFIX "") set_property(TARGET libgemma PROPERTY POSITION_INDEPENDENT_CODE ON) target_include_directories(libgemma PUBLIC ./) target_link_libraries(libgemma hwy hwy_contrib sentencepiece) +if(WIN32) + TARGET_LINK_LIBRARIES( gemma ${CMAKE_SOURCE_DIR}/Khronos_CL_SDK/win_x64_SDK/lib/OpenCL.lib) +else() + TARGET_LINK_LIBRARIES( gemma ${CMAKE_SOURCE_DIR}/Khronos_CL_SDK/linux/lib/OpenCL.a) +endif() target_include_directories(libgemma PRIVATE ${sentencepiece_SOURCE_DIR}) target_compile_definitions(libgemma PRIVATE $<$:_CRT_SECURE_NO_WARNINGS NOMINMAX>) target_compile_options(libgemma PRIVATE $<$:-Wno-deprecated-declarations>)