diff --git a/loopy/target/c/__init__.py b/loopy/target/c/__init__.py index a2885b7bf..9a1fc821f 100644 --- a/loopy/target/c/__init__.py +++ b/loopy/target/c/__init__.py @@ -846,6 +846,36 @@ def generate_preambles(self, target): yield ("08_c_math", "#include ") +class GSLCallable(ScalarCallable): + @override + def with_types(self, + arg_id_to_dtype: Mapping[int | str, LoopyType], + clbl_inf_ctx: CallablesInferenceContext, + ) -> tuple[Self, CallablesInferenceContext]: + name = self.name + + for id in arg_id_to_dtype: + if not isinstance(id, int): + raise LoopyError(f"'{name}' can take only positional arguments") + + arg_num_to_dtype = {cast("int", id): t for id, t in arg_id_to_dtype.items()} + arg_num_to_dtype[-1] = arg_num_to_dtype[max(arg_num_to_dtype)] + + name_in_target = f"gsl_sf_{name}" + return (self.copy(name_in_target=name_in_target, + arg_id_to_dtype=constantdict(arg_num_to_dtype)), + clbl_inf_ctx) + + def generate_preambles(self, target): + # Base GSL math header + yield ("40_c_gsl_math", "#include ") + + # Function-specific headers for GSL special functions + if self.name == "hyperg_2F1": + # gsl_sf_hyperg_2F1 is declared in + yield ("41_c_gsl_sf_hyperg", "#include ") + + def get_c_callables(): """ Returns an instance of :class:`InKernelCallable` if the function @@ -866,6 +896,14 @@ def get_gnu_libc_callables(): func_ids = ["bessel_jn", "bessel_yn"] return {id_: GNULibcCallable(id_) for id_ in func_ids} + +def get_gsl_callables(): + # Support special functions from + # https://www.gnu.org/software/gsl/doc/html/specfunc.html + func_ids = ["hyperg_2F1"] + return {id_: GSLCallable(id_) for id_ in func_ids} + + # }}} @@ -1612,6 +1650,7 @@ class CWithGNULibcASTBuilder(CASTBuilder): def known_callables(self): callables = super().known_callables callables.update(get_gnu_libc_callables()) + callables.update(get_gsl_callables()) return callables