diff --git a/ggml/src/ggml-opencl/kernels/embed_kernel.py b/ggml/src/ggml-opencl/kernels/embed_kernel.py index b5d1d7242b..81d9868a4b 100644 --- a/ggml/src/ggml-opencl/kernels/embed_kernel.py +++ b/ggml/src/ggml-opencl/kernels/embed_kernel.py @@ -2,8 +2,25 @@ import sys import logging +import re +import os + logger = logging.getLogger("opencl-embed-kernel") +INCLUDE_PATTERN = re.compile(r'#include\s+"(.*)"') + + +def parse_file_line(ifile, ofile, base_path: str): + for i in ifile: + i = i.rstrip() + if m := INCLUDE_PATTERN.match(i): + include_file = os.path.join(base_path, m.group(1)) + logger.info(f"Embedding file: {include_file}") + with open(include_file, "r") as incf: + parse_file_line(incf, ofile, base_path) + else: + ofile.write('R"({})"\n'.format(i)) + def main(): logging.basicConfig(level=logging.INFO) @@ -12,14 +29,9 @@ def main(): logger.info("Usage: python embed_kernel.py ") sys.exit(1) - ifile = open(sys.argv[1], "r") - ofile = open(sys.argv[2], "w") - - for i in ifile: - ofile.write('R"({})"\n'.format(i)) - - ifile.close() - ofile.close() + ipath = os.path.dirname(sys.argv[1]) + with open(sys.argv[1], "r") as ifile, open(sys.argv[2], "w") as ofile: + parse_file_line(ifile, ofile, ipath) if __name__ == "__main__":