diff options
Diffstat (limited to 'source/blender/functions/tests/FN_multi_function_network_test.cc')
-rw-r--r-- | source/blender/functions/tests/FN_multi_function_network_test.cc | 255 |
1 files changed, 255 insertions, 0 deletions
diff --git a/source/blender/functions/tests/FN_multi_function_network_test.cc b/source/blender/functions/tests/FN_multi_function_network_test.cc new file mode 100644 index 00000000000..f226e0eac2e --- /dev/null +++ b/source/blender/functions/tests/FN_multi_function_network_test.cc @@ -0,0 +1,255 @@ +/* Apache License, Version 2.0 */ + +#include "testing/testing.h" + +#include "FN_multi_function_builder.hh" +#include "FN_multi_function_network.hh" +#include "FN_multi_function_network_evaluation.hh" + +namespace blender::fn::tests { +namespace { + +TEST(multi_function_network, Test1) +{ + CustomMF_SI_SO<int, int> add_10_fn("add 10", [](int value) { return value + 10; }); + CustomMF_SI_SI_SO<int, int, int> multiply_fn("multiply", [](int a, int b) { return a * b; }); + + MFNetwork network; + + MFNode &node1 = network.add_function(add_10_fn); + MFNode &node2 = network.add_function(multiply_fn); + MFOutputSocket &input_socket = network.add_input("Input", MFDataType::ForSingle<int>()); + MFInputSocket &output_socket = network.add_output("Output", MFDataType::ForSingle<int>()); + network.add_link(node1.output(0), node2.input(0)); + network.add_link(node1.output(0), node2.input(1)); + network.add_link(node2.output(0), output_socket); + network.add_link(input_socket, node1.input(0)); + + MFNetworkEvaluator network_fn{{&input_socket}, {&output_socket}}; + + { + Array<int> values = {4, 6, 1, 2, 0}; + Array<int> results(values.size(), 0); + + MFParamsBuilder params(network_fn, values.size()); + params.add_readonly_single_input(values.as_span()); + params.add_uninitialized_single_output(results.as_mutable_span()); + + MFContextBuilder context; + + network_fn.call({0, 2, 3, 4}, params, context); + + EXPECT_EQ(results[0], 14 * 14); + EXPECT_EQ(results[1], 0); + EXPECT_EQ(results[2], 11 * 11); + EXPECT_EQ(results[3], 12 * 12); + EXPECT_EQ(results[4], 10 * 10); + } + { + int value = 3; + Array<int> results(5, 0); + + MFParamsBuilder params(network_fn, results.size()); + params.add_readonly_single_input(&value); + params.add_uninitialized_single_output(results.as_mutable_span()); + + MFContextBuilder context; + + network_fn.call({1, 2, 4}, params, context); + + EXPECT_EQ(results[0], 0); + EXPECT_EQ(results[1], 13 * 13); + EXPECT_EQ(results[2], 13 * 13); + EXPECT_EQ(results[3], 0); + EXPECT_EQ(results[4], 13 * 13); + } +} + +class ConcatVectorsFunction : public MultiFunction { + public: + ConcatVectorsFunction() + { + MFSignatureBuilder signature = this->get_builder("Concat Vectors"); + signature.vector_mutable<int>("A"); + signature.vector_input<int>("B"); + } + + void call(IndexMask mask, MFParams params, MFContext UNUSED(context)) const override + { + GVectorArrayRef<int> a = params.vector_mutable<int>(0); + VArraySpan<int> b = params.readonly_vector_input<int>(1); + + for (int64_t i : mask) { + a.extend(i, b[i]); + } + } +}; + +class AppendFunction : public MultiFunction { + public: + AppendFunction() + { + MFSignatureBuilder signature = this->get_builder("Append"); + signature.vector_mutable<int>("Vector"); + signature.single_input<int>("Value"); + } + + void call(IndexMask mask, MFParams params, MFContext UNUSED(context)) const override + { + GVectorArrayRef<int> vectors = params.vector_mutable<int>(0); + VSpan<int> values = params.readonly_single_input<int>(1); + + for (int64_t i : mask) { + vectors.append(i, values[i]); + } + } +}; + +class SumVectorFunction : public MultiFunction { + public: + SumVectorFunction() + { + MFSignatureBuilder signature = this->get_builder("Sum Vector"); + signature.vector_input<int>("Vector"); + signature.single_output<int>("Sum"); + } + + void call(IndexMask mask, MFParams params, MFContext UNUSED(context)) const override + { + VArraySpan<int> vectors = params.readonly_vector_input<int>(0); + MutableSpan<int> sums = params.uninitialized_single_output<int>(1); + + for (int64_t i : mask) { + int sum = 0; + VSpan<int> vector = vectors[i]; + for (int j = 0; j < vector.size(); j++) { + sum += vector[j]; + } + sums[i] = sum; + } + } +}; + +class CreateRangeFunction : public MultiFunction { + public: + CreateRangeFunction() + { + MFSignatureBuilder builder = this->get_builder("Create Range"); + builder.single_input<int>("Size"); + builder.vector_output<int>("Range"); + } + + void call(IndexMask mask, MFParams params, MFContext UNUSED(context)) const override + { + VSpan<int> sizes = params.readonly_single_input<int>(0, "Size"); + GVectorArrayRef<int> ranges = params.vector_output<int>(1, "Range"); + + for (int64_t i : mask) { + int size = sizes[i]; + for (int j : IndexRange(size)) { + ranges.append(i, j); + } + } + } +}; + +TEST(multi_function_network, Test2) +{ + CustomMF_SI_SO<int, int> add_3_fn("add 3", [](int value) { return value + 3; }); + + ConcatVectorsFunction concat_vectors_fn; + AppendFunction append_fn; + SumVectorFunction sum_fn; + CreateRangeFunction create_range_fn; + + MFNetwork network; + + MFOutputSocket &input1 = network.add_input("Input 1", MFDataType::ForVector<int>()); + MFOutputSocket &input2 = network.add_input("Input 2", MFDataType::ForSingle<int>()); + MFInputSocket &output1 = network.add_output("Output 1", MFDataType::ForVector<int>()); + MFInputSocket &output2 = network.add_output("Output 2", MFDataType::ForSingle<int>()); + + MFNode &node1 = network.add_function(add_3_fn); + MFNode &node2 = network.add_function(create_range_fn); + MFNode &node3 = network.add_function(concat_vectors_fn); + MFNode &node4 = network.add_function(sum_fn); + MFNode &node5 = network.add_function(append_fn); + MFNode &node6 = network.add_function(sum_fn); + + network.add_link(input2, node1.input(0)); + network.add_link(node1.output(0), node2.input(0)); + network.add_link(node2.output(0), node3.input(1)); + network.add_link(input1, node3.input(0)); + network.add_link(input1, node4.input(0)); + network.add_link(node4.output(0), node5.input(1)); + network.add_link(node3.output(0), node5.input(0)); + network.add_link(node5.output(0), node6.input(0)); + network.add_link(node3.output(0), output1); + network.add_link(node6.output(0), output2); + + // std::cout << network.to_dot() << "\n\n"; + + MFNetworkEvaluator network_fn{{&input1, &input2}, {&output1, &output2}}; + + { + Array<int> input_value_1 = {3, 6}; + int input_value_2 = 4; + + GVectorArray output_value_1(CPPType::get<int32_t>(), 5); + Array<int> output_value_2(5, -1); + + MFParamsBuilder params(network_fn, 5); + params.add_readonly_vector_input(GVArraySpan(input_value_1.as_span(), 5)); + params.add_readonly_single_input(&input_value_2); + params.add_vector_output(output_value_1); + params.add_uninitialized_single_output(output_value_2.as_mutable_span()); + + MFContextBuilder context; + + network_fn.call({1, 2, 4}, params, context); + + EXPECT_EQ(output_value_1[0].size(), 0); + EXPECT_EQ(output_value_1[1].size(), 9); + EXPECT_EQ(output_value_1[2].size(), 9); + EXPECT_EQ(output_value_1[3].size(), 0); + EXPECT_EQ(output_value_1[4].size(), 9); + + EXPECT_EQ(output_value_2[0], -1); + EXPECT_EQ(output_value_2[1], 39); + EXPECT_EQ(output_value_2[2], 39); + EXPECT_EQ(output_value_2[3], -1); + EXPECT_EQ(output_value_2[4], 39); + } + { + GVectorArray input_value_1(CPPType::get<int32_t>(), 3); + GVectorArrayRef<int> input_value_ref_1 = input_value_1; + input_value_ref_1.extend(0, {3, 4, 5}); + input_value_ref_1.extend(1, {1, 2}); + + Array<int> input_value_2 = {4, 2, 3}; + + GVectorArray output_value_1(CPPType::get<int32_t>(), 3); + Array<int> output_value_2(3, -1); + + MFParamsBuilder params(network_fn, 3); + params.add_readonly_vector_input(input_value_1); + params.add_readonly_single_input(input_value_2.as_span()); + params.add_vector_output(output_value_1); + params.add_uninitialized_single_output(output_value_2.as_mutable_span()); + + MFContextBuilder context; + + network_fn.call({0, 1, 2}, params, context); + + EXPECT_EQ(output_value_1[0].size(), 10); + EXPECT_EQ(output_value_1[1].size(), 7); + EXPECT_EQ(output_value_1[2].size(), 6); + + EXPECT_EQ(output_value_2[0], 45); + EXPECT_EQ(output_value_2[1], 16); + EXPECT_EQ(output_value_2[2], 15); + } +} + +} // namespace +} // namespace blender::fn::tests |