diff --git a/.github/workflows/hip-quality-check.yml b/.github/workflows/hip-quality-check.yml index 04ae96d648..474c0ad415 100644 --- a/.github/workflows/hip-quality-check.yml +++ b/.github/workflows/hip-quality-check.yml @@ -8,7 +8,8 @@ on: paths: [ '.github/workflows/hip-quality-check.yml', '**/*.cu', - '**/*.cuh' + '**/*.cuh', + 'scripts/hip/gcn-cdna-vgpr-check.py' ] pull_request: @@ -16,7 +17,8 @@ on: paths: [ '.github/workflows/hip-quality-check.yml', '**/*.cu', - '**/*.cuh' + '**/*.cuh', + 'scripts/hip/gcn-cdna-vgpr-check.py' ] concurrency: diff --git a/scripts/hip/gcn-cdna-vgpr-check.py b/scripts/hip/gcn-cdna-vgpr-check.py index 934728d4a6..38db47d3d1 100644 --- a/scripts/hip/gcn-cdna-vgpr-check.py +++ b/scripts/hip/gcn-cdna-vgpr-check.py @@ -2,37 +2,51 @@ import sys from collections import defaultdict +import re def parse_log_file(filepath): - """Parse log file and extract function VGPR usage.""" - import re - functions = defaultdict(lambda: {'vgprs': 0, 'spill': 0, 'location': ''}) + func_stack = [] try: with open(filepath, 'r') as f: - content = f.read() - # Find all function entries with VGPR usage including location - pattern = r'([^:]+:\d+):.*?Function Name: (\S+).*?VGPRs: (\d+).*?VGPRs Spill: (\d+)' - matches = re.findall(pattern, content, re.DOTALL) + for line in f: + # Match function name lines + func_match = re.search(r'remark: ([^:]+):(\d+):\d+: Function Name: (\S+)', line) + if func_match: + location = func_match.group(1) + ':' + func_match.group(2) + func_name = func_match.group(3) + # Extract just the filename and line number + parts = location.split('/') + short_location = parts[-1] if len(parts) > 0 else location + functions[func_name]['location'] = short_location + # Push function onto stack with its location + func_stack.append({'name': func_name, 'location': location}) + continue - for location, func_name, vgprs, spill in matches: - functions[func_name]['vgprs'] = int(vgprs) - functions[func_name]['spill'] = int(spill) - # Extract just the filename and line number - parts = location.split('/') - if len(parts) > 0: - short_location = parts[-1] # Get last part (filename) - # Check if there's a line number after filename - if ':' in short_location: - functions[func_name]['location'] = short_location - else: - functions[func_name]['location'] = location - else: - functions[func_name]['location'] = location + # Match VGPR usage lines (only if we have functions in stack) + vgpr_match = re.search(r'remark: ([^:]+):(\d+):\d+:\s+VGPRs: (\d+)', line) + if vgpr_match: + location = vgpr_match.group(1) + ':' + vgpr_match.group(2) + # Find the most recent function with matching location + for i in range(len(func_stack) - 1, -1, -1): + if func_stack[i]['location'] == location: + functions[func_stack[i]['name']]['vgprs'] = int(vgpr_match.group(3)) + break + continue + + spill_match = re.search(r'remark: ([^:]+):(\d+):\d+:\s+VGPRs Spill: (\d+)', line) + if spill_match: + location = spill_match.group(1) + ':' + spill_match.group(2) + # Find the most recent function with matching location + for i in range(len(func_stack) - 1, -1, -1): + if func_stack[i]['location'] == location: + functions[func_stack[i]['name']]['spill'] = int(spill_match.group(3)) + break + continue except FileNotFoundError: - print(f"Error: File {filepath} not found", file=sys.stderr) # noqa: NP100 + print(f"Error: File {filepath} not found", file=sys.stderr) # noqa: NP100 sys.exit(1) return functions @@ -40,7 +54,7 @@ def parse_log_file(filepath): def main(): if len(sys.argv) < 2: - print("Usage: ./vgpr_check.py ", file=sys.stderr) # noqa: NP100 + print("Usage: ./vgpr_check.py ", file=sys.stderr) # noqa: NP100 sys.exit(1) log_file = sys.argv[1] @@ -123,6 +137,9 @@ def main(): '_ZL18flash_attn_ext_f16ILi128ELi128ELi32ELi2ELb1ELb0EEvPKcS1_S1_S1_S1_PKiPfP15HIP_vector_typeIfLj2EEffffjfiS5_IjLj3EEiiiiiiiiiiiliiliiiiil', '_ZL18flash_attn_ext_f16ILi128ELi128ELi4ELi8ELb1ELb0EEvPKcS1_S1_S1_S1_PKiPfP15HIP_vector_typeIfLj2EEffffjfiS5_IjLj3EEiiiiiiiiiiiliiliiiiil', '_ZL18flash_attn_ext_f16ILi96ELi96ELi4ELi8ELb0ELb0EEvPKcS1_S1_S1_S1_PKiPfP15HIP_vector_typeIfLj2EEffffjfiS5_IjLj3EEiiiiiiiiiiiliiliiiiil', + '_ZL18flash_attn_ext_vecILi128ELi2EL9ggml_type2ELS0_2ELb0EEvPKcS2_S2_S2_S2_PKiPfP15HIP_vector_typeIfLj2EEffffjfiS6_IjLj3EEiiiiiiiiiiiliiliiiiil', + '_ZL9mul_mat_qIL9ggml_type10ELi16ELb1EEvPKcPKiS4_S4_PfS5_iiiiiiiiiiiiiiiii', + '_ZL9mul_mat_qIL9ggml_type12ELi128ELb1EEvPKcPKiS4_S4_PfS5_iiiiiiiiiiiiiiiii' } functions = parse_log_file(log_file) @@ -134,7 +151,7 @@ def main(): total_vgprs = int(data['vgprs']) + int(data['spill']) if total_vgprs > 256 and func_name in ignored and func_name not in printed_ignored: location = data.get('location', log_file) - print(f"{location}: {func_name} - Total VGPRs: {total_vgprs} ({data['vgprs']} + {data['spill']}) [IGNORED]") # noqa: NP100 + print(f"{location}: {func_name} - Total VGPRs: {total_vgprs} ({data['vgprs']} + {data['spill']}) [IGNORED]") # noqa: NP100 printed_ignored.add(func_name) # Then print new functions with issues in red @@ -146,7 +163,7 @@ def main(): # Print in red if not ignored color_code = "\033[91m" if func_name not in ignored else "" reset_code = "\033[0m" if func_name not in ignored else "" - print(f"{color_code}{location}: {func_name} - Total VGPRs: {total_vgprs} ({data['vgprs']} + {data['spill']}) {status}{reset_code}") # noqa: NP100 + print(f"{color_code}{location}: {func_name} - Total VGPRs: {total_vgprs} ({data['vgprs']} + {data['spill']}) {status}{reset_code}") # noqa: NP100 if func_name not in ignored: found_issues = True