Skip to content

Commit 508f906

Browse files
committed
[tmva][sofie] Add tests for Not, IsInf and IsNaN operators
Remove also some commented code in the NonZero operator
1 parent 8da3cfd commit 508f906

6 files changed

Lines changed: 61 additions & 14 deletions

File tree

tmva/sofie/inc/TMVA/ROperator_NonZero.hxx

Lines changed: 0 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -92,15 +92,6 @@ public:
9292
fShapeY.resize(2);
9393
fShapeY[0] = fShapeX.size();
9494

95-
// identify as -1 since we will declare maximum as size of input
96-
// auto inputLength = ConvertDimShapeToLength(fShapeX);
97-
// // case X is Dim, becomes complicated to know the maximum. Shuld be allocated dynamically
98-
// size_t inputLength = 0;
99-
// if (!model.IsDynamicTensor(fNX)) {
100-
// inputLength = ConvertShapeToLength(ConvertShapeToInt(fShapeX));
101-
// else
102-
// inputLength = static_cast<size_t>(-1); // flag -1 to define shape correctly
103-
10495
// flag -1 to define the shape variable in the constructor code and not in the constructor signature
10596
fShapeY[1] = Dim{std::string("v_NonZero_") + fNX, static_cast<size_t>(-1) };
10697

tmva/sofie/test/TestCustomModelsFromONNX.cxx

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2973,5 +2973,36 @@ TEST(ONNX, NonZero_Constant)
29732973
EXPECT_EQ(output[i] , correct_output[i]);
29742974
}
29752975
}
2976+
TEST(ONNX, IsInf)
2977+
{
2978+
// expected input
2979+
std::vector<float> input = { 1, static_cast<float>(1./0.), 2.};
2980+
std::vector<uint8_t> correct_output = { 0,1,0 };
2981+
2982+
// not cannot use input.size() in string because input symbol will not be visible when running inference
2983+
ASSERT_INCLUDE_AND_RUN_SESSION_ARGS(std::vector<uint8_t>, "IsInf",std::string("\"\", ") + std::to_string(input.size()), input.size(),input);
2984+
2985+
// Checking output size
2986+
EXPECT_EQ(output.size(), correct_output.size());
2987+
// Checking output
2988+
for (size_t i = 0; i < output.size(); ++i) {
2989+
EXPECT_EQ(output[i] , correct_output[i]);
2990+
}
2991+
}
29762992

2993+
TEST(ONNX, NotIsNaN)
2994+
{
2995+
// expected input
2996+
std::vector<float> input = { 1, static_cast<float>(0./0.), 2.};
2997+
std::vector<uint8_t> correct_output = { 1,0,1 };
2998+
2999+
ASSERT_INCLUDE_AND_RUN_SESSION_ARGS(std::vector<uint8_t>, "NotIsNaN",std::string("\"\", ") + std::to_string(input.size()), input.size(),input);
3000+
3001+
// Checking output size
3002+
EXPECT_EQ(output.size(), correct_output.size());
3003+
// Checking output
3004+
for (size_t i = 0; i < output.size(); ++i) {
3005+
EXPECT_EQ(output[i] , correct_output[i]);
3006+
}
3007+
}
29773008

tmva/sofie/test/input_models/IsInf.onnx

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
 onnx-example:S
2+

3+
inputoutput"IsInfTestZ
4+
input
5+

6+

7+
Nb
8+
output
9+
 
10+

11+
NB

tmva/sofie/test/input_models/NotIsNaN.onnx

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
 onnx-example:t
2+

3+
input temp_result"IsNaN
4+

5+
temp_resultoutput"NotTestZ
6+
input
7+

8+

9+
Nb
10+
output
11+
 
12+

13+
NB

tmva/sofie/test/test_helpers.h

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -45,9 +45,11 @@ bool includeModel(std::string const &modelName)
4545
template <class T>
4646
std::string toInterpreter(T const &ptr, std::string const &className, bool toRawPointer = false)
4747
{
48-
if constexpr (std::is_same_v<T, int>) {
48+
// for the integer arguments (shape values)
49+
if constexpr (std::is_same_v<T, int> || std::is_same_v<T, size_t>) {
4950
return std::to_string(ptr);
5051
}
52+
// for the data arguments
5153
std::string out =
5254
TString::Format("reinterpret_cast<%s*>(0x%zx)", className.c_str(), reinterpret_cast<std::size_t>(&ptr)).Data();
5355
if (toRawPointer) {
@@ -94,6 +96,8 @@ runModel(std::string outputTypeName, std::string const &modelName, std::string s
9496
auto type_name = []<typename T>() {
9597
if constexpr (std::is_same_v<T, int>)
9698
return "int";
99+
else if constexpr (std::is_same_v<T, size_t>)
100+
return "size_t";
97101
else if constexpr (std::is_same_v<T, std::vector<float>>)
98102
return "std::vector<float>";
99103
else if constexpr (std::is_same_v<T, std::vector<int>>)

tmva/sofie_parsers/src/ParseBasicIs.cxx

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -9,12 +9,9 @@ namespace SOFIE {
99
template <EBasicIsOperator Op>
1010
std::unique_ptr<ROperator> ParseBasicIs(RModelParser_ONNX &parser, const onnx::NodeProto &nodeproto)
1111
{
12-
ETensorType input_type = ETensorType::UNDEFINED;
1312

1413
std::string input_name = nodeproto.input(0);
15-
if (parser.IsRegisteredTensorType(input_name)) {
16-
input_type = parser.GetTensorType(input_name);
17-
} else {
14+
if (!parser.IsRegisteredTensorType(input_name)) {
1815
throw
1916
std::runtime_error("TMVA::SOFIE ONNX Parser " + IsOpTraits<Op>::Name() + " op has input tensor " + input_name +
2017
" but its type is not yet registered");

0 commit comments

Comments
 (0)