/* Apache License, Version 2.0 */

#include "testing/testing.h"

#include "FN_multi_function_builder.hh"
#include "FN_multi_function_network.hh"
#include "FN_multi_function_network_evaluation.hh"

namespace blender::fn::tests {

/* Network: Input -> "add 10" -> "multiply" (both operands linked to the "add 10" output) -> Output,
 * so every index in the mask evaluates to (x + 10) * (x + 10). Indices that are not in the mask
 * are expected to keep the value the output buffer was pre-filled with. */
TEST(multi_function_network, Test1)
{
  CustomMF_SI_SO<int, int> add_10_fn("add 10", [](int value) { return value + 10; });
  CustomMF_SI_SI_SO<int, int, int> multiply_fn("multiply", [](int a, int b) { return a * b; });

  MFNetwork network;

  MFNode &node1 = network.add_function(add_10_fn);
  MFNode &node2 = network.add_function(multiply_fn);
  MFOutputSocket &input_socket = network.add_input("Input", MFDataType::ForSingle<int>());
  MFInputSocket &output_socket = network.add_output("Output", MFDataType::ForSingle<int>());
  network.add_link(node1.output(0), node2.input(0));
  network.add_link(node1.output(0), node2.input(1));
  network.add_link(node2.output(0), output_socket);
  network.add_link(input_socket, node1.input(0));

  MFNetworkEvaluator network_fn{{&input_socket}, {&output_socket}};

  {
    /* One input value per index; index 1 is deliberately left out of the mask. */
    Array<int> values = {4, 6, 1, 2, 0};
    Array<int> results(values.size(), 0);

    MFParamsBuilder params(network_fn, values.size());
    params.add_readonly_single_input(values.as_span());
    params.add_uninitialized_single_output(results.as_mutable_span());

    MFContextBuilder context;

    network_fn.call({0, 2, 3, 4}, params, context);

    EXPECT_EQ(results[0], 14 * 14);
    EXPECT_EQ(results[1], 0);
    EXPECT_EQ(results[2], 11 * 11);
    EXPECT_EQ(results[3], 12 * 12);
    EXPECT_EQ(results[4], 10 * 10);
  }
  {
    /* A single input value, passed by pointer, that every masked index reads. */
    int value = 3;
    Array<int> results(5, 0);

    MFParamsBuilder params(network_fn, results.size());
    params.add_readonly_single_input(&value);
    params.add_uninitialized_single_output(results.as_mutable_span());

    MFContextBuilder context;

    network_fn.call({1, 2, 4}, params, context);

    EXPECT_EQ(results[0], 0);
    EXPECT_EQ(results[1], 13 * 13);
    EXPECT_EQ(results[2], 13 * 13);
    EXPECT_EQ(results[3], 0);
    EXPECT_EQ(results[4], 13 * 13);
  }
}

/* For every masked index, appends all elements of vector "B" onto the mutable vector "A". */
class ConcatVectorsFunction : public MultiFunction {
 public:
  ConcatVectorsFunction()
  {
    MFSignatureBuilder signature = this->get_builder("Concat Vectors");
    signature.vector_mutable<int>("A");
    signature.vector_input<int>("B");
  }

  void call(IndexMask mask, MFParams params, MFContext UNUSED(context)) const override
  {
    GVectorArrayRef<int> a = params.vector_mutable<int>(0);
    VArraySpan<int> b = params.readonly_vector_input<int>(1);

    for (int64_t i : mask) {
      a.extend(i, b[i]);
    }
  }
};

/* For every masked index, appends the single input "Value" to the mutable vector "Vector". */
class AppendFunction : public MultiFunction {
 public:
  AppendFunction()
  {
    MFSignatureBuilder signature = this->get_builder("Append");
    signature.vector_mutable<int>("Vector");
    signature.single_input<int>("Value");
  }

  void call(IndexMask mask, MFParams params, MFContext UNUSED(context)) const override
  {
    GVectorArrayRef<int> vectors = params.vector_mutable<int>(0);
    VSpan<int> values = params.readonly_single_input<int>(1);

    for (int64_t i : mask) {
      vectors.append(i, values[i]);
    }
  }
};

/* For every masked index, writes the sum of the elements of "Vector" into "Sum". */
class SumVectorFunction : public MultiFunction {
 public:
  SumVectorFunction()
  {
    MFSignatureBuilder signature = this->get_builder("Sum Vector");
    signature.vector_input<int>("Vector");
    signature.single_output<int>("Sum");
  }

  void call(IndexMask mask, MFParams params, MFContext UNUSED(context)) const override
  {
    VArraySpan<int> vectors = params.readonly_vector_input<int>(0);
    MutableSpan<int> sums = params.uninitialized_single_output<int>(1);

    for (int64_t i : mask) {
      int sum = 0;
      VSpan<int> vector = vectors[i];
      /* int64_t index to match the type used for the mask indices. */
      for (int64_t j = 0; j < vector.size(); j++) {
        sum += vector[j];
      }
      sums[i] = sum;
    }
  }
};

/* For every masked index, outputs the vector [0, 1, ..., Size - 1] into "Range". */
class CreateRangeFunction : public MultiFunction {
 public:
  CreateRangeFunction()
  {
    MFSignatureBuilder builder = this->get_builder("Create Range");
    builder.single_input<int>("Size");
    builder.vector_output<int>("Range");
  }

  void call(IndexMask mask, MFParams params, MFContext UNUSED(context)) const override
  {
    VSpan<int> sizes = params.readonly_single_input<int>(0, "Size");
    GVectorArrayRef<int> ranges = params.vector_output<int>(1, "Range");

    for (int64_t i : mask) {
      int size = sizes[i];
      for (int j : IndexRange(size)) {
        ranges.append(i, j);
      }
    }
  }
};

/* Network built from the links below:
 *   Output 1 = Input 1 ++ range(Input 2 + 3)                 (node1 -> node2 -> node3)
 *   Output 2 = sum(Output 1 ++ [sum(Input 1)])               (node4 -> node5 -> node6)
 * node3's output feeds both Output 1 and the mutable vector of node5; the expected sizes of
 * Output 1 check that node5's append does not show up there. */
TEST(multi_function_network, Test2)
{
  CustomMF_SI_SO<int, int> add_3_fn("add 3", [](int value) { return value + 3; });
  ConcatVectorsFunction concat_vectors_fn;
  AppendFunction append_fn;
  SumVectorFunction sum_fn;
  CreateRangeFunction create_range_fn;

  MFNetwork network;

  MFOutputSocket &input1 = network.add_input("Input 1", MFDataType::ForVector<int>());
  MFOutputSocket &input2 = network.add_input("Input 2", MFDataType::ForSingle<int>());
  MFInputSocket &output1 = network.add_output("Output 1", MFDataType::ForVector<int>());
  MFInputSocket &output2 = network.add_output("Output 2", MFDataType::ForSingle<int>());

  MFNode &node1 = network.add_function(add_3_fn);
  MFNode &node2 = network.add_function(create_range_fn);
  MFNode &node3 = network.add_function(concat_vectors_fn);
  MFNode &node4 = network.add_function(sum_fn);
  MFNode &node5 = network.add_function(append_fn);
  MFNode &node6 = network.add_function(sum_fn);

  network.add_link(input2, node1.input(0));
  network.add_link(node1.output(0), node2.input(0));
  network.add_link(node2.output(0), node3.input(1));
  network.add_link(input1, node3.input(0));
  network.add_link(input1, node4.input(0));
  network.add_link(node4.output(0), node5.input(1));
  network.add_link(node3.output(0), node5.input(0));
  network.add_link(node5.output(0), node6.input(0));
  network.add_link(node3.output(0), output1);
  network.add_link(node6.output(0), output2);

  // std::cout << network.to_dot() << "\n\n";

  MFNetworkEvaluator network_fn{{&input1, &input2}, {&output1, &output2}};

  {
    /* The vector {3, 6} is used for all 5 indices and Input 2 is a single value 4, so
     * Output 1 = {3, 6} ++ range(7) has 9 elements and Output 2 = 9 + 21 + 9 = 39. */
    Array<int> input_value_1 = {3, 6};
    int input_value_2 = 4;

    GVectorArray output_value_1(CPPType::get<int>(), 5);
    Array<int> output_value_2(5, -1);

    MFParamsBuilder params(network_fn, 5);
    params.add_readonly_vector_input(GVArraySpan(input_value_1.as_span(), 5));
    params.add_readonly_single_input(&input_value_2);
    params.add_vector_output(output_value_1);
    params.add_uninitialized_single_output(output_value_2.as_mutable_span());

    MFContextBuilder context;

    network_fn.call({1, 2, 4}, params, context);

    EXPECT_EQ(output_value_1[0].size(), 0);
    EXPECT_EQ(output_value_1[1].size(), 9);
    EXPECT_EQ(output_value_1[2].size(), 9);
    EXPECT_EQ(output_value_1[3].size(), 0);
    EXPECT_EQ(output_value_1[4].size(), 9);

    EXPECT_EQ(output_value_2[0], -1);
    EXPECT_EQ(output_value_2[1], 39);
    EXPECT_EQ(output_value_2[2], 39);
    EXPECT_EQ(output_value_2[3], -1);
    EXPECT_EQ(output_value_2[4], 39);
  }
  {
    /* Per-index inputs: {3, 4, 5}, {1, 2}, {} with sizes 4, 2, 3.
     *   index 0: 3 + 7 = 10 elements, sum 12 + 21 + 12 = 45
     *   index 1: 2 + 5 =  7 elements, sum  3 + 10 +  3 = 16
     *   index 2: 0 + 6 =  6 elements, sum  0 + 15 +  0 = 15 */
    GVectorArray input_value_1(CPPType::get<int>(), 3);
    GVectorArrayRef<int> input_value_ref_1 = input_value_1;
    input_value_ref_1.extend(0, {3, 4, 5});
    input_value_ref_1.extend(1, {1, 2});

    Array<int> input_value_2 = {4, 2, 3};

    GVectorArray output_value_1(CPPType::get<int>(), 3);
    Array<int> output_value_2(3, -1);

    MFParamsBuilder params(network_fn, 3);
    params.add_readonly_vector_input(input_value_1);
    params.add_readonly_single_input(input_value_2.as_span());
    params.add_vector_output(output_value_1);
    params.add_uninitialized_single_output(output_value_2.as_mutable_span());

    MFContextBuilder context;

    network_fn.call({0, 1, 2}, params, context);

    EXPECT_EQ(output_value_1[0].size(), 10);
    EXPECT_EQ(output_value_1[1].size(), 7);
    EXPECT_EQ(output_value_1[2].size(), 6);

    EXPECT_EQ(output_value_2[0], 45);
    EXPECT_EQ(output_value_2[1], 16);
    EXPECT_EQ(output_value_2[2], 15);
  }
}

}  // namespace blender::fn::tests