From 108820b795bd663a285401950a01a47c1dbcc7da Mon Sep 17 00:00:00 2001 From: Osamaali313 <86572800+Osamaali313@users.noreply.github.com> Date: Sun, 14 Jun 2026 22:32:38 +0300 Subject: [PATCH] fix: index CUDA sources/headers when generating .pyi stubs `build_cpp_function_index` (and `extract_m_def_statements`) only scanned `.cpp/.cc/.cxx/.c/.hpp/.h`, omitting `.cu` and `.cuh`. As a result, any pybind-bound function whose C++ definition/declaration lives in a CUDA source or header is not found, and the generated stub falls back to a generic `(*args, **kwargs) -> Any` (the code even logs "... not found in any .cpp file"). The omission is also internally inconsistent: `build_cpp_function_index` already classifies `.cuh` as a header via its `is_header` check, but never reads `.cuh` files, so that branch is unreachable. Add `.cu`/`.cuh` to both extension allowlists so CUDA-located bindings get a typed stub. No behavior change for the current tree (all 38 resolvable signatures still resolve; the only `.cu`, `csrc/indexing/main.cu`, registers no bindings); this matters as more kernels move into `.cu`/`.cuh`. --- scripts/generate_pyi.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/scripts/generate_pyi.py b/scripts/generate_pyi.py index df7490d410..138dbe71ec 100644 --- a/scripts/generate_pyi.py +++ b/scripts/generate_pyi.py @@ -4,7 +4,11 @@ def build_cpp_function_index(root_path): func_index = {} - extensions = {'.cpp', '.cc', '.cxx', '.c', '.hpp', '.h'} + # Include CUDA sources/headers: bound functions in this project may be + # defined in `.cu` or declared in `.cuh`, and the `is_header` check below + # already treats `.cuh` as a header. Without them such functions are missed + # and fall back to a generic `(*args, **kwargs) -> Any` stub. + extensions = {'.cpp', '.cc', '.cxx', '.c', '.cu', '.hpp', '.h', '.cuh'} pattern = re.compile( r'([\w:\s*<&>,\[\]\(\)]+?)' @@ -153,7 +157,7 @@ def extract_m_def_statements(root_path): Scan all c files under root_path and extract all m.def(...) statements. """ results = [] - extensions = {'.hpp', '.cpp', '.h', '.cc'} + extensions = {'.hpp', '.cpp', '.h', '.cc', '.cu', '.cuh'} # Regex: match m.def( ... ), supports multi-line pattern = re.compile(r'm\.def\s*\(')