diff --git a/src/shambackends/include/shambackends/kernel_call.hpp b/src/shambackends/include/shambackends/kernel_call.hpp index 95ace0b3a..fab7ece7d 100644 --- a/src/shambackends/include/shambackends/kernel_call.hpp +++ b/src/shambackends/include/shambackends/kernel_call.hpp @@ -316,8 +316,50 @@ namespace sham { in_out.complete_event_state(e); } + template + struct expected_kernel_signature; + + template + struct expected_kernel_signature {}; + + template + struct tuple_to_signature; + + template + struct tuple_to_signature> { + using type = void(Ts...); + }; + + template + struct kernel_gen_args { + + using args_types = decltype(std::tuple_cat( + std::tuple{}, + + std::declval().get_read_access( + std::declval()))>(), + + std::declval().get_write_access( + std::declval()))>())); + }; + + template + using kernel_expected_signature = expected_kernel_signature::args_types>::type>; + + template + struct is_kernel_invocable; + + template + struct is_kernel_invocable> + : std::bool_constant> {}; + + template + concept kernel_invocable = is_kernel_invocable::value; + /// internal implementation of typed_index_kernel_call template + requires kernel_invocable> void typed_index_kernel_call( sham::DeviceQueue &q, RefIn in, diff --git a/src/tests/shambackends/kernel_call_tests.cpp b/src/tests/shambackends/kernel_call_tests.cpp index dc333a3f1..c655fa2b4 100644 --- a/src/tests/shambackends/kernel_call_tests.cpp +++ b/src/tests/shambackends/kernel_call_tests.cpp @@ -65,7 +65,18 @@ TestStart(Unittest, "shambackends/kernel_call", testing_func_kernel_call_base, 1 P[i] = r; cs[i] = u; }); + sham::kernel_call( + dev_sched->get_queue(), + sham::MultiRef{rho_field_const, uint_field_const}, + sham::MultiRef{P_field, cs_field}, + size, + [](u32 i, const T *__restrict rho, T *__restrict U, T *__restrict P, T *__restrict cs) { + T r = rho[i]; + T u = U[i]; + P[i] = r; + cs[i] = u; + }); REQUIRE_EQUAL(P_field.copy_to_stdvec(), P_ref); REQUIRE_EQUAL(cs_field.copy_to_stdvec(), cs_ref); }