llama.cpp/ggml/src/ggml-virtgpu/regenerate_remoting.py

333 lines
12 KiB
Python
Executable File
Raw Blame History

This file contains invisible Unicode characters

This file contains invisible Unicode characters that are indistinguishable to humans but may be processed differently by a computer. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

#!/usr/bin/env python3
"""
# Generated by Claude AI
Script to regenerate the generated (*.gen.h) headers of the GGML remoting layer
from its YAML configuration.
This script reads ggmlremoting_functions.yaml and regenerates the command enum,
the backend dispatch table, and the frontend forward declarations.
Usage:
python regenerate_remoting.py
The script will:
1. Read the ggmlremoting_functions.yaml configuration
2. Generate the updated header files
3. Format them with clang-format, if it is available
4. Show a summary of what was generated
"""
import yaml
from typing import Dict, List, Any
from pathlib import Path
import os
import subprocess
import shutil
import logging
# Newline for use inside f-string replacement fields: before Python 3.12
# (PEP 701) a backslash escape is not allowed there, so f"{'\n'.join(x)}"
# is a syntax error on older interpreters.
NL: str = "\n"
class RemotingCodebaseGenerator:
    """Regenerate the GGML remoting *.gen.h headers from a YAML description.

    The YAML file must provide three top-level mappings: ``functions``
    (groups of remoted functions), ``naming_patterns`` (prefixes/overrides
    used to derive C identifiers) and ``config`` (settings such as
    ``base_path``).
    """

    def __init__(self, yaml_path: str = "ggmlremoting_functions.yaml"):
        """Load the YAML configuration from *yaml_path*.

        Raises:
            FileNotFoundError: if *yaml_path* does not exist.
            KeyError: if ``functions``, ``naming_patterns`` or ``config`` is
                missing from the YAML document.
        """
        self.yaml_path = yaml_path
        if not Path(yaml_path).exists():
            raise FileNotFoundError(f"Configuration file {yaml_path} not found")
        # Decode explicitly as UTF-8: the default text encoding is platform dependent.
        with open(yaml_path, 'r', encoding='utf-8') as f:
            self.config = yaml.safe_load(f)
        self.functions = self.config['functions']
        self.naming_patterns = self.config['naming_patterns']
        self.config_data = self.config['config']
        # Check if clang-format is available
        self.clang_format_available = self._check_clang_format_available()

    def _check_clang_format_available(self) -> bool:
        """Return True if a ``clang-format`` executable is on the system PATH."""
        return shutil.which("clang-format") is not None

    def _format_file_with_clang_format(self, file_path: Path) -> bool:
        """Format *file_path* in place with ``clang-format -i``.

        Formatting is best-effort: failures are logged, never raised.

        Returns:
            True if clang-format ran successfully; False if it is unavailable
            or failed.
        """
        if not self.clang_format_available:
            return False
        try:
            subprocess.run(
                ["clang-format", "-i", str(file_path)],
                check=True,
                capture_output=True,
                text=True,
            )
        except subprocess.CalledProcessError as e:
            # clang-format's own stderr explains the failure; the Python
            # traceback through subprocess internals does not.
            logging.error(" ⚠️ clang-format failed for %s: %s",
                          file_path, (e.stderr or "").strip())
            return False
        except Exception:
            logging.exception(" ⚠️ Unexpected error formatting %s", file_path)
            return False
        return True

    def generate_enum_name(self, group_name: str, function_name: str) -> str:
        """Return the command enum name: ``<enum_prefix><GROUP>_<FUNCTION>``."""
        prefix = self.naming_patterns['enum_prefix']
        return f"{prefix}{group_name.upper()}_{function_name.upper()}"

    def generate_backend_function_name(self, group_name: str, function_name: str) -> str:
        """Return the backend handler name.

        An entry ``<group>_<function>`` in ``backend_function_overrides`` wins;
        otherwise the name is ``<backend_function_prefix><group>_<function>``.
        """
        function_key = f"{group_name}_{function_name}"
        overrides = self.naming_patterns.get('backend_function_overrides', {})
        if function_key in overrides:
            return overrides[function_key]
        prefix = self.naming_patterns['backend_function_prefix']
        return f"{prefix}{group_name}_{function_name}"

    def generate_frontend_function_name(self, group_name: str, function_name: str) -> str:
        """Return the frontend name: ``<frontend_function_prefix><group>_<function>``."""
        prefix = self.naming_patterns['frontend_function_prefix']
        return f"{prefix}{group_name}_{function_name}"

    def get_enabled_functions(self) -> List[Dict[str, Any]]:
        """Return metadata dicts for every enabled function, in YAML order.

        ``enum_value`` is assigned consecutively from 0 and counts enabled
        functions only; disabled entries (``enabled: false``) get no slot.
        Deprecated functions stay enabled and keep their slot.
        """
        functions = []
        enum_value = 0
        for group_name, group_data in self.functions.items():
            group_description = group_data['group_description']
            # A group whose entries are all commented out parses as None.
            group_functions = group_data['functions'] or {}
            for function_name, func_metadata in group_functions.items():
                # Handle case where func_metadata is None or empty (functions with only comments)
                if func_metadata is None:
                    func_metadata = {}
                # Functions are enabled by default unless explicitly disabled
                if not func_metadata.get('enabled', True):
                    continue
                functions.append({
                    'group_name': group_name,
                    'function_name': function_name,
                    'enum_name': self.generate_enum_name(group_name, function_name),
                    'enum_value': enum_value,
                    'backend_function': self.generate_backend_function_name(group_name, function_name),
                    'frontend_function': self.generate_frontend_function_name(group_name, function_name),
                    'frontend_return': func_metadata.get('frontend_return', 'void'),
                    'frontend_extra_params': func_metadata.get('frontend_extra_params', []),
                    'group_description': group_description,
                    'deprecated': func_metadata.get('deprecated', False),
                })
                enum_value += 1
        return functions

    def generate_apir_backend_header(self) -> str:
        """Return the contents of apir_backend.gen.h.

        This is only the ``ApirBackendCommandType`` enum, with no include
        guard. It is presumably included from a hand-written header; confirm
        before adding one.
        """
        functions = self.get_enabled_functions()
        # Generate the enum section
        enum_lines = ["typedef enum ApirBackendCommandType {"]
        current_group = None
        for func in functions:
            # Add comment for new group
            if func['group_name'] != current_group:
                enum_lines.append("")
                enum_lines.append(f" /* {func['group_description']} */")
                current_group = func['group_name']
            enum_lines.append(f" {func['enum_name']} = {func['enum_value']},")
        # Add the count
        total_count = len(functions)
        enum_lines.append("\n // last command_type index + 1")
        enum_lines.append(f" APIR_BACKEND_DISPATCH_TABLE_COUNT = {total_count},")
        enum_lines.append("} ApirBackendCommandType;")
        return "\n".join(enum_lines) + "\n"

    def generate_backend_dispatched_header(self) -> str:
        """Return the contents of backend-dispatched.gen.h.

        It contains the backend handler declarations, a command-name lookup
        function, and the dispatch table indexed by ApirBackendCommandType.
        """
        functions = self.get_enabled_functions()
        # Function declarations
        decl_lines = []
        current_group = None
        for func in functions:
            if func['group_name'] != current_group:
                decl_lines.append(f"\n/* {func['group_description']} */")
                current_group = func['group_name']
            signature = "uint32_t"
            params = "apir_encoder *enc, apir_decoder *dec, virgl_apir_context *ctx"
            if func['deprecated']:
                decl_lines.append(f"/* {func['enum_name']} is deprecated. Keeping the handler for backward compatibility. */")
            decl_lines.append(f"{signature} {func['backend_function']}({params});")
        # Switch cases
        switch_lines = []
        current_group = None
        for func in functions:
            if func['group_name'] != current_group:
                switch_lines.append(f" /* {func['group_description']} */")
                current_group = func['group_name']
            deprecated = " (DEPRECATED)" if func['deprecated'] else ""
            switch_lines.append(f" case {func['enum_name']}: return \"{func['backend_function']}{deprecated}\";")
        # Dispatch table
        table_lines = []
        current_group = None
        for func in functions:
            if func['group_name'] != current_group:
                table_lines.append(f"\n /* {func['group_description']} */")
                table_lines.append("")
                current_group = func['group_name']
            deprecated = " /* DEPRECATED */" if func['deprecated'] else ""
            table_lines.append(f" /* {func['enum_name']} = */ {func['backend_function']}{deprecated},")
        # Pre-join each section so the template below needs no '\n' inside
        # its replacement fields.
        declarations = "\n".join(decl_lines)
        cases = "\n".join(switch_lines)
        table = "\n".join(table_lines)
        header_content = f'''\
#pragma once
{declarations}
static inline const char *backend_dispatch_command_name(ApirBackendCommandType type)
{{
switch (type) {{
{cases}
default: return "unknown";
}}
}}
extern "C" {{
static const backend_dispatch_t apir_backend_dispatch_table[APIR_BACKEND_DISPATCH_TABLE_COUNT] = {{
{table}
}};
}}
'''
        return header_content

    def generate_virtgpu_forward_header(self) -> str:
        """Return the contents of virtgpu-forward.gen.h (frontend prototypes).

        Deprecated functions get only a comment, with no prototype.
        """
        functions = self.get_enabled_functions()
        decl_lines = []
        current_group = None
        for func in functions:
            if func['group_name'] != current_group:
                decl_lines.append("")
                decl_lines.append(f"/* {func['group_description']} */")
                current_group = func['group_name']
            if func['deprecated']:
                decl_lines.append(f"/* {func['frontend_function']} is deprecated. */")
                continue
            # Build parameter list
            params = [self.naming_patterns['frontend_base_param']]
            params.extend(func['frontend_extra_params'])
            param_str = ', '.join(params)
            decl_lines.append(f"{func['frontend_return']} {func['frontend_function']}({param_str});")
        declarations = "\n".join(decl_lines)
        header_content = f'''\
#pragma once
{declarations}
'''
        return header_content

    def _write_generated_file(self, path: Path, content: str) -> Path:
        """Create *path*'s parent directories, write *content* as UTF-8, and log the path."""
        path.parent.mkdir(parents=True, exist_ok=True)
        # Explicit encoding: group descriptions from YAML may be non-ASCII.
        path.write_text(content, encoding='utf-8')
        logging.info(f"{path.resolve()}")
        return path

    def regenerate_codebase(self) -> None:
        """Regenerate all *.gen.h headers, format them, and log a summary.

        When run from inside a directory named ``ggml-virtgpu``, files are
        written relative to it. Otherwise they go under
        ``<config.base_path, default 'ggml/src'>/ggml-virtgpu``.
        """
        logging.info("🔄 Regenerating GGML Remoting Codebase...")
        logging.info("=" * 50)
        # Detect if we're running from frontend directory
        current_dir = os.getcwd()
        # Compare the last path component exactly; str.endswith() would also
        # match unrelated directories such as "old-ggml-virtgpu".
        is_frontend_dir = Path(current_dir).name == 'ggml-virtgpu'
        if is_frontend_dir:
            # Running from ggml/src/ggml-virtgpu
            logging.info("📍 Detected frontend directory execution")
            frontend_base = Path(".")
        else:
            # Running from project root (fallback to original behavior)
            logging.info("📍 Detected project root execution")
            base_path = self.config_data.get('base_path', 'ggml/src')
            frontend_base = Path(base_path) / "ggml-virtgpu"
        # Compute final file paths
        backend_base = frontend_base / "backend"
        apir_backend_path = backend_base / "shared" / "apir_backend.gen.h"
        backend_dispatched_path = backend_base / "backend-dispatched.gen.h"
        virtgpu_forward_path = frontend_base / "virtgpu-forward.gen.h"
        # Generate header files
        logging.info("📁 Generating header files...")
        generated_files = [
            self._write_generated_file(apir_backend_path, self.generate_apir_backend_header()),
            self._write_generated_file(backend_dispatched_path, self.generate_backend_dispatched_header()),
            self._write_generated_file(virtgpu_forward_path, self.generate_virtgpu_forward_header()),
        ]
        # Format generated files with clang-format
        if not self.clang_format_available:
            logging.warning("\nclang-format not found in PATH. Generated files will not be formatted.\n"
                            " Install clang-format to enable automatic code formatting.")
        else:
            logging.info("\n🎨 Formatting files with clang-format...")
            for file_path in generated_files:
                if self._format_file_with_clang_format(file_path):
                    logging.info(f" ✅ Formatted {file_path.name}")
                else:
                    logging.warning(f" ❌ Failed to format {file_path.name}")
        # Generate summary
        total_functions = len(self.get_enabled_functions())
        logging.info("\n📊 Generation Summary:")
        logging.info("=" * 50)
        logging.info(f" Total functions: {total_functions}")
        logging.info(f" Function groups: {len(self.functions)}")
        logging.info(f" Header files: {len(generated_files)}")
        logging.info(f" Working directory: {current_dir}")
def main() -> None:
    """Regenerate the remoting headers; exit with status 1 on any error."""
    # Without a configured handler the root logger only emits WARNING and
    # above, so every logging.info() progress/summary line would be dropped.
    logging.basicConfig(level=logging.INFO, format="%(message)s")
    try:
        generator = RemotingCodebaseGenerator()
        generator.regenerate_codebase()
    except Exception as e:
        logging.exception(f"❌ Error: {e}")
        # Equivalent to sys.exit(1) without relying on the site-provided exit() builtin.
        raise SystemExit(1)


if __name__ == "__main__":
    main()