333 lines
12 KiB
Python
Executable File
333 lines
12 KiB
Python
Executable File
#!/usr/bin/env python3
|
||
"""
|
||
# Generated by Claude AI
|
||
|
||
Script to completely regenerate the GGML remoting codebase from YAML configuration.
|
||
|
||
This script reads ggmlremoting_functions.yaml and regenerates all of the
generated header files (*.gen.h) for the GGML remoting layer.
|
||
|
||
Usage:
|
||
python regenerate_remoting.py
|
||
|
||
The script will:
|
||
1. Read ggmlremoting_functions.yaml configuration
|
||
2. Generate updated header files
|
||
3. Format the generated files with clang-format (if it is installed)
|
||
4. Show a summary of what was generated
|
||
"""
|
||
|
||
import yaml
|
||
from typing import Dict, List, Any
|
||
from pathlib import Path
|
||
import os
|
||
import subprocess
|
||
import shutil
|
||
import logging
|
||
|
||
# Newline constant for building text inside f-string replacement fields.
# Before Python 3.12 (PEP 701) a backslash is not allowed inside the braces
# of an f-string, so f"{'\n'.join(lines)}" is a SyntaxError there.
NL = '\n'
|
||
|
||
|
||
class RemotingCodebaseGenerator:
    """Regenerate the GGML remoting headers from a YAML API description.

    The YAML file must contain three top-level sections:

    * ``functions``: function groups. Each group has a ``group_description``
      and a ``functions`` mapping from function name to optional metadata
      (``enabled``, ``deprecated``, ``frontend_return``,
      ``frontend_extra_params``).
    * ``naming_patterns``: prefixes and overrides used to derive the C enum,
      backend and frontend identifiers.
    * ``config``: general settings (only ``base_path`` is read here).

    Three headers are produced: ``backend/shared/apir_backend.gen.h``,
    ``backend/backend-dispatched.gen.h`` and ``virtgpu-forward.gen.h``, all
    relative to the ggml-virtgpu frontend directory.
    """

    # Name of the frontend directory under which the headers are generated.
    FRONTEND_DIR_NAME = "ggml-virtgpu"

    def __init__(self, yaml_path: str = "ggmlremoting_functions.yaml"):
        """Load and validate the YAML configuration.

        Args:
            yaml_path: path to the YAML file describing the remoted API.

        Raises:
            FileNotFoundError: if ``yaml_path`` does not exist.
            ValueError: if the file is empty or its top level is not a mapping.
            KeyError: if a required top-level section is missing.
        """
        self.yaml_path = yaml_path

        if not Path(yaml_path).exists():
            raise FileNotFoundError(f"Configuration file {yaml_path} not found")

        # Decode as UTF-8 explicitly so non-ASCII descriptions parse the same
        # way regardless of the platform's locale encoding.
        with open(yaml_path, 'r', encoding='utf-8') as f:
            self.config = yaml.safe_load(f)

        # An empty file loads as None; report that clearly instead of failing
        # with "'NoneType' object is not subscriptable".
        if not isinstance(self.config, dict):
            raise ValueError(
                f"Configuration file {yaml_path} is empty or is not a YAML mapping")

        self.functions = self.config['functions']
        self.naming_patterns = self.config['naming_patterns']
        self.config_data = self.config['config']

        # Probe once; formatting is skipped (with a warning) when unavailable.
        self.clang_format_available = self._check_clang_format_available()

    def _check_clang_format_available(self) -> bool:
        """Return True if a ``clang-format`` executable is on PATH."""
        return shutil.which("clang-format") is not None

    def _format_file_with_clang_format(self, file_path: Path) -> bool:
        """Format ``file_path`` in place with ``clang-format -i``.

        This is best effort: failures are logged and reported through the
        return value rather than raised.

        Returns:
            True if clang-format ran successfully; False if it is unavailable
            or failed.
        """
        if not self.clang_format_available:
            return False

        try:
            subprocess.run(
                ["clang-format", "-i", str(file_path)],
                check=True,
                capture_output=True,
                text=True,
            )
        except subprocess.CalledProcessError as e:
            # The useful diagnostic is clang-format's captured stderr; a Python
            # traceback of the CalledProcessError itself adds nothing.
            details = (e.stderr or "").strip() or f"exit status {e.returncode}"
            logging.error(f" ⚠️ clang-format failed for {file_path}: {details}")
            return False
        except Exception as e:
            logging.exception(f" ⚠️ Unexpected error formatting {file_path}: {e}")
            return False
        return True

    def generate_enum_name(self, group_name: str, function_name: str) -> str:
        """Return the command enum name: ``<enum_prefix>GROUP_FUNCTION``."""
        prefix = self.naming_patterns['enum_prefix']
        return f"{prefix}{group_name.upper()}_{function_name.upper()}"

    def generate_backend_function_name(self, group_name: str, function_name: str) -> str:
        """Return the backend handler name for a function.

        An entry for ``<group>_<function>`` in
        ``naming_patterns.backend_function_overrides`` wins; otherwise the
        name is ``<backend_function_prefix><group>_<function>``.
        """
        function_key = f"{group_name}_{function_name}"
        overrides = self.naming_patterns.get('backend_function_overrides', {})

        if function_key in overrides:
            return overrides[function_key]

        prefix = self.naming_patterns['backend_function_prefix']
        return f"{prefix}{group_name}_{function_name}"

    def generate_frontend_function_name(self, group_name: str, function_name: str) -> str:
        """Return the frontend name: ``<frontend_function_prefix><group>_<function>``."""
        prefix = self.naming_patterns['frontend_function_prefix']
        return f"{prefix}{group_name}_{function_name}"

    def get_enabled_functions(self) -> List[Dict[str, Any]]:
        """Return metadata dicts for every enabled function, in YAML order.

        Enum values are assigned sequentially over the enabled functions only,
        so they are dense (0..N-1). Deprecated functions are still enabled
        unless ``enabled: false`` is also given, and keep their slot.
        """
        functions = []
        enum_value = 0

        for group_name, group_data in self.functions.items():
            group_description = group_data['group_description']
            # A group whose ``functions:`` key is left empty loads as None.
            group_functions = group_data['functions'] or {}

            for function_name, func_metadata in group_functions.items():
                # Entries with no keys (e.g. only a comment) load as None.
                if func_metadata is None:
                    func_metadata = {}

                # Functions are enabled by default unless explicitly disabled.
                if not func_metadata.get('enabled', True):
                    continue

                functions.append({
                    'group_name': group_name,
                    'function_name': function_name,
                    'enum_name': self.generate_enum_name(group_name, function_name),
                    'enum_value': enum_value,
                    'backend_function': self.generate_backend_function_name(group_name, function_name),
                    'frontend_function': self.generate_frontend_function_name(group_name, function_name),
                    'frontend_return': func_metadata.get('frontend_return', 'void'),
                    'frontend_extra_params': func_metadata.get('frontend_extra_params', []),
                    'group_description': group_description,
                    'deprecated': func_metadata.get('deprecated', False),
                })
                enum_value += 1

        return functions

    def generate_apir_backend_header(self) -> str:
        """Return the text of apir_backend.gen.h (the command-type enum)."""
        functions = self.get_enabled_functions()

        enum_lines = ["typedef enum ApirBackendCommandType {"]
        current_group = None

        for func in functions:
            # Start a new commented section whenever the group changes.
            if func['group_name'] != current_group:
                enum_lines.append("")
                enum_lines.append(f" /* {func['group_description']} */")
                current_group = func['group_name']

            enum_lines.append(f" {func['enum_name']} = {func['enum_value']},")

        # Enum values are dense, so the dispatch-table size is the count.
        enum_lines.append("\n // last command_type index + 1")
        enum_lines.append(f" APIR_BACKEND_DISPATCH_TABLE_COUNT = {len(functions)},")
        enum_lines.append("} ApirBackendCommandType;")

        return "\n".join(enum_lines) + "\n"

    def generate_backend_dispatched_header(self) -> str:
        """Return the text of backend-dispatched.gen.h.

        It contains the handler declarations, a command-name lookup function
        and the dispatch table indexed by command type.
        """
        functions = self.get_enabled_functions()
        params = "apir_encoder *enc, apir_decoder *dec, virgl_apir_context *ctx"

        decl_lines = []
        switch_lines = []
        table_lines = []
        current_group = None

        for func in functions:
            if func['group_name'] != current_group:
                description = func['group_description']
                decl_lines.append(f"\n/* {description} */")
                switch_lines.append(f" /* {description} */")
                table_lines.append(f"\n /* {description} */")
                table_lines.append("")
                current_group = func['group_name']

            # Deprecated commands keep their handler so that the table indices
            # of later commands do not shift.
            if func['deprecated']:
                decl_lines.append(f"/* {func['enum_name']} is deprecated. Keeping the handler for backward compatibility. */")
            decl_lines.append(f"uint32_t {func['backend_function']}({params});")

            switch_suffix = " (DEPRECATED)" if func['deprecated'] else ""
            switch_lines.append(f" case {func['enum_name']}: return \"{func['backend_function']}{switch_suffix}\";")

            table_suffix = " /* DEPRECATED */" if func['deprecated'] else ""
            table_lines.append(f" /* {func['enum_name']} = */ {func['backend_function']}{table_suffix},")

        # Join outside the f-string: backslashes are not allowed inside
        # replacement fields before Python 3.12.
        decl_text = "\n".join(decl_lines)
        switch_text = "\n".join(switch_lines)
        table_text = "\n".join(table_lines)

        return f'''\
#pragma once

{decl_text}

static inline const char *backend_dispatch_command_name(ApirBackendCommandType type)
{{
switch (type) {{
{switch_text}

default: return "unknown";
}}
}}

extern "C" {{
static const backend_dispatch_t apir_backend_dispatch_table[APIR_BACKEND_DISPATCH_TABLE_COUNT] = {{
{table_text}
}};
}}
'''

    def generate_virtgpu_forward_header(self) -> str:
        """Return the text of virtgpu-forward.gen.h (frontend prototypes).

        Deprecated functions get only a comment, not a prototype.
        """
        functions = self.get_enabled_functions()

        decl_lines = []
        current_group = None

        for func in functions:
            if func['group_name'] != current_group:
                decl_lines.append("")
                decl_lines.append(f"/* {func['group_description']} */")
                current_group = func['group_name']

            if func['deprecated']:
                decl_lines.append(f"/* {func['frontend_function']} is deprecated. */")
                continue

            params = [self.naming_patterns['frontend_base_param'], *func['frontend_extra_params']]
            param_str = ', '.join(params)
            decl_lines.append(f"{func['frontend_return']} {func['frontend_function']}({param_str});")

        decl_text = "\n".join(decl_lines)
        return f'''\
#pragma once
{decl_text}
'''

    def _resolve_frontend_base(self, current_dir: str) -> Path:
        """Return the ggml-virtgpu directory for a given working directory.

        If ``current_dir`` is itself named ``ggml-virtgpu``, the result is
        ``.``. Otherwise it is ``<config.base_path or 'ggml/src'>/ggml-virtgpu``.
        """
        # Compare the last path component exactly; endswith() would also
        # match unrelated directories such as "my-ggml-virtgpu".
        if Path(current_dir).name == self.FRONTEND_DIR_NAME:
            # Running from ggml/src/ggml-virtgpu itself
            logging.info("📍 Detected frontend directory execution")
            return Path(".")

        # Running from project root (fallback to original behavior)
        logging.info("📍 Detected project root execution")
        base_path = (self.config_data or {}).get('base_path', 'ggml/src')
        return Path(base_path) / self.FRONTEND_DIR_NAME

    def _write_header(self, path: Path, content: str) -> None:
        """Create ``path``'s parent directories, write ``content`` as UTF-8, and log it."""
        path.parent.mkdir(parents=True, exist_ok=True)
        path.write_text(content, encoding='utf-8')
        logging.info(f" ✅ {path.resolve()}")

    def regenerate_codebase(self) -> None:
        """Regenerate, format and summarize all generated header files.

        Output paths depend on the current working directory; see
        ``_resolve_frontend_base``.
        """
        logging.info("🔄 Regenerating GGML Remoting Codebase...")
        logging.info("=" * 50)

        current_dir = os.getcwd()
        frontend_base = self._resolve_frontend_base(current_dir)

        backend_base = frontend_base / "backend"
        apir_backend_path = backend_base / "shared" / "apir_backend.gen.h"
        backend_dispatched_path = backend_base / "backend-dispatched.gen.h"
        virtgpu_forward_path = frontend_base / "virtgpu-forward.gen.h"

        logging.info("📁 Generating header files...")
        self._write_header(apir_backend_path, self.generate_apir_backend_header())
        self._write_header(backend_dispatched_path, self.generate_backend_dispatched_header())
        self._write_header(virtgpu_forward_path, self.generate_virtgpu_forward_header())

        generated_files = [apir_backend_path, backend_dispatched_path, virtgpu_forward_path]

        if not self.clang_format_available:
            logging.warning("\n⚠️clang-format not found in PATH. Generated files will not be formatted.\n"
                            " Install clang-format to enable automatic code formatting.")
        else:
            logging.info("\n🎨 Formatting files with clang-format...")
            for file_path in generated_files:
                if self._format_file_with_clang_format(file_path):
                    logging.info(f" ✅ Formatted {file_path.name}")
                else:
                    logging.warning(f" ❌ Failed to format {file_path.name}")

        total_functions = len(self.get_enabled_functions())

        logging.info("\n📊 Generation Summary:")
        logging.info("=" * 50)
        logging.info(f" Total functions: {total_functions}")
        logging.info(f" Function groups: {len(self.functions)}")
        logging.info(f" Header files: {len(generated_files)}")
        logging.info(f" Working directory: {current_dir}")
|
||
|
||
|
||
def main():
    """Command-line entry point: regenerate the headers, exiting 1 on any error."""
    # All progress and summary output goes through logging.info(), which the
    # unconfigured root logger drops (default level WARNING). Configure it so
    # the messages are actually shown; this is a no-op if logging is already
    # configured by an embedding caller.
    logging.basicConfig(level=logging.INFO, format="%(message)s")

    try:
        generator = RemotingCodebaseGenerator()
        generator.regenerate_codebase()
    except Exception as e:
        logging.exception(f"❌ Error: {e}")
        # SystemExit does not depend on the site-injected exit() helper,
        # which is absent e.g. under `python -S`.
        raise SystemExit(1)


if __name__ == "__main__":
    main()
|