Skip to content

Commit b03ae5c

Browse files
committed
[tmva][sofie] Add tests for Not, IsInf and IsNaN operators
1 parent ce2933b commit b03ae5c

4 files changed

Lines changed: 60 additions & 1 deletion

File tree

tmva/sofie/test/TestCustomModelsFromONNX.cxx

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2973,5 +2973,36 @@ TEST(ONNX, NonZero_Constant)
29732973
EXPECT_EQ(output[i] , correct_output[i]);
29742974
}
29752975
}
2976+
TEST(ONNX, IsInf)
2977+
{
2978+
// expected input
2979+
std::vector<float> input = { 1, static_cast<float>(1./0.), 2.};
2980+
std::vector<uint8_t> correct_output = { 0,1,0 };
2981+
2982+
// not cannot use input.size() in string because input symbol will not be visible when running inference
2983+
ASSERT_INCLUDE_AND_RUN_SESSION_ARGS(std::vector<uint8_t>, "IsInf",std::string("\"\", ") + std::to_string(input.size()), input.size(),input);
2984+
2985+
// Checking output size
2986+
EXPECT_EQ(output.size(), correct_output.size());
2987+
// Checking output
2988+
for (size_t i = 0; i < output.size(); ++i) {
2989+
EXPECT_EQ(output[i] , correct_output[i]);
2990+
}
2991+
}
29762992

2993+
TEST(ONNX, NotIsNaN)
2994+
{
2995+
// expected input
2996+
std::vector<float> input = { 1, static_cast<float>(0./0.), 2.};
2997+
std::vector<uint8_t> correct_output = { 1,0,1 };
2998+
2999+
ASSERT_INCLUDE_AND_RUN_SESSION_ARGS(std::vector<uint8_t>, "NotIsNaN",std::string("\"\", ") + std::to_string(input.size()), input.size(),input);
3000+
3001+
// Checking output size
3002+
EXPECT_EQ(output.size(), correct_output.size());
3003+
// Checking output
3004+
for (size_t i = 0; i < output.size(); ++i) {
3005+
EXPECT_EQ(output[i] , correct_output[i]);
3006+
}
3007+
}
29773008

tmva/sofie/test/input_models/IsInf.onnx

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
 onnx-example:S
2+

3+
inputoutput"IsInfTestZ
4+
input
5+

6+

7+
Nb
8+
output
9+
 
10+

11+
NB

tmva/sofie/test/input_models/NotIsNaN.onnx

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
 onnx-example:t
2+

3+
input temp_result"IsNaN
4+

5+
temp_resultoutput"NotTestZ
6+
input
7+

8+

9+
Nb
10+
output
11+
 
12+

13+
NB

tmva/sofie/test/test_helpers.h

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -45,9 +45,11 @@ bool includeModel(std::string const &modelName)
4545
template <class T>
4646
std::string toInterpreter(T const &ptr, std::string const &className, bool toRawPointer = false)
4747
{
48-
if constexpr (std::is_same_v<T, int>) {
48+
// for the integer arguments (shape values)
49+
if constexpr (std::is_same_v<T, int> || std::is_same_v<T, size_t>) {
4950
return std::to_string(ptr);
5051
}
52+
// for the data arguments
5153
std::string out =
5254
TString::Format("reinterpret_cast<%s*>(0x%zx)", className.c_str(), reinterpret_cast<std::size_t>(&ptr)).Data();
5355
if (toRawPointer) {
@@ -94,6 +96,8 @@ runModel(std::string outputTypeName, std::string const &modelName, std::string s
9496
auto type_name = []<typename T>() {
9597
if constexpr (std::is_same_v<T, int>)
9698
return "int";
99+
else if constexpr (std::is_same_v<T, size_t>)
100+
return "size_t";
97101
else if constexpr (std::is_same_v<T, std::vector<float>>)
98102
return "std::vector<float>";
99103
else if constexpr (std::is_same_v<T, std::vector<int>>)

0 commit comments

Comments
 (0)