From 09a5ea059f17527dd76812824f13ef8d2db33fc1 Mon Sep 17 00:00:00 2001 From: Hans Goudey Date: Sat, 28 Aug 2021 00:40:58 -0500 Subject: Add a slightly more complicated test This one doesn't pass, I'll need to debug it --- source/blender/functions/FN_field.hh | 115 +++++++++++------------- source/blender/functions/intern/field.cc | 28 +++--- source/blender/functions/tests/FN_field_test.cc | 34 ++++++- 3 files changed, 95 insertions(+), 82 deletions(-) diff --git a/source/blender/functions/FN_field.hh b/source/blender/functions/FN_field.hh index 0416eb11d50..65220b6a4a8 100644 --- a/source/blender/functions/FN_field.hh +++ b/source/blender/functions/FN_field.hh @@ -30,6 +30,7 @@ * the embedded multi-function. */ +#include "BLI_string_ref.hh" #include "BLI_vector.hh" #include "FN_multi_function_procedure.hh" @@ -38,64 +39,8 @@ namespace blender::fn { -class Field; - -/** - * An operation acting on data described by fields. Generally corresponds - * to a node or a subset of a node in a node graph. - */ -class FieldFunction { - /** - * The function used to calculate the - */ - std::unique_ptr function_; - - /** - * References to descriptions of the results from the functions this function depends on. - */ - blender::Vector inputs_; - - std::string name_; - - public: - FieldFunction(std::unique_ptr function, - Span inputs, - std::string &&name = "") - : function_(std::move(function)), inputs_(inputs), name_(std::move(name)) - { - } - - Span inputs() const - { - return inputs_; - } - - const MultiFunction &multi_function() const - { - return *function_; - } - - blender::StringRef name() const - { - return name_; - } -}; - -class FieldInput { - std::string name_; - - public: - FieldInput(std::string &&name = "") : name_(std::move(name)) - { - } - - virtual GVArrayPtr retrieve_data(IndexMask mask) const = 0; - - blender::StringRef name() const - { - return name_; - } -}; +class FieldInput; +class FieldFunction; /** * Descibes the output of a function. Generally corresponds to the combination of an output socket @@ -121,13 +66,19 @@ class Field { std::shared_ptr input_; + StringRef name_; + public: - Field(const fn::CPPType &type, std::shared_ptr function, const int output_index) - : type_(&type), function_(function), output_index_(output_index) + Field(const fn::CPPType &type, + std::shared_ptr function, + const int output_index, + StringRef name = "") + : type_(&type), function_(function), output_index_(output_index), name_(name) { } - Field(const fn::CPPType &type, std::shared_ptr input) : type_(&type), input_(input) + Field(const fn::CPPType &type, std::shared_ptr input, StringRef name = "") + : type_(&type), input_(input), name_(name) { } @@ -168,13 +119,47 @@ class Field { blender::StringRef name() const { - if (this->is_function()) { - return function_->name(); - } - return input_->name(); + return name_; } }; +/** + * An operation acting on data described by fields. Generally corresponds + * to a node or a subset of a node in a node graph. + */ +class FieldFunction { + /** + * The function used to calculate the + */ + std::unique_ptr function_; + + /** + * References to descriptions of the results from the functions this function depends on. + */ + blender::Vector inputs_; + + public: + FieldFunction(std::unique_ptr function, Vector &&inputs) + : function_(std::move(function)), inputs_(std::move(inputs)) + { + } + + Span inputs() const + { + return inputs_; + } + + const MultiFunction &multi_function() const + { + return *function_; + } +}; + +class FieldInput { + public: + virtual GVArrayPtr retrieve_data(IndexMask mask) const = 0; +}; + /** * Evaluate more than one field at a time, as an optimization * in case they share inputs or various intermediate values. diff --git a/source/blender/functions/intern/field.cc b/source/blender/functions/intern/field.cc index c4148d0fdc1..39a56b1aeaa 100644 --- a/source/blender/functions/intern/field.cc +++ b/source/blender/functions/intern/field.cc @@ -33,15 +33,15 @@ using VariableMap = Map>; */ using ComputedInputMap = Map; -static MFVariable *get_field_variable(const Field &field, const VariableMap &variable_map) +static MFVariable &get_field_variable(const Field &field, const VariableMap &variable_map) { if (field.is_input()) { const FieldInput &input = field.input(); - return variable_map.lookup(&input).first(); + return *variable_map.lookup(&input).first(); } const FieldFunction &function = field.function(); const Span function_outputs = variable_map.lookup(&function); - return function_outputs[field.function_output_index()]; + return *function_outputs[field.function_output_index()]; } /** @@ -62,17 +62,17 @@ static void add_field_variables_recursive(const Field &field, } else { const FieldFunction &function = field.function(); - for (const Field *input_field : function.inputs()) { - add_field_variables_recursive(*input_field, builder, variable_map); + for (const Field &input_field : function.inputs()) { + add_field_variables_recursive(input_field, builder, variable_map); } /* Add the immediate inputs to this field, which were added earlier in the * recursive call. This will be skipped for functions with no inputs. */ Vector inputs; - for (const Field *input_field : function.inputs()) { - MFVariable *input = get_field_variable(*input_field, variable_map); - builder.add_input_parameter(input->data_type(), input_field->name()); - inputs.append(input); + for (const Field &input_field : function.inputs()) { + MFVariable &input = get_field_variable(input_field, variable_map); + builder.add_input_parameter(input.data_type()); + inputs.append(&input); } Vector outputs = builder.add_call(function.multi_function(), inputs); @@ -96,8 +96,8 @@ static void build_procedure(const Span fields, builder.add_return(); for (const Field &field : fields) { - MFVariable *input = get_field_variable(field, variable_map); - builder.add_output_parameter(*input); + MFVariable &input = get_field_variable(field, variable_map); + builder.add_output_parameter(input); } std::cout << procedure.to_dot(); @@ -122,14 +122,14 @@ static void gather_inputs_recursive(const Field &field, if (!computed_inputs.contains(variable)) { GVArrayPtr data = input.retrieve_data(mask); computed_inputs.add_new(variable); - params.add_readonly_single_input(*data, input.name()); + params.add_readonly_single_input(*data, field.name()); r_inputs.append(std::move(data)); } } else { const FieldFunction &function = field.function(); - for (const Field *input_field : function.inputs()) { - gather_inputs_recursive(*input_field, variable_map, mask, params, computed_inputs, r_inputs); + for (const Field &input_field : function.inputs()) { + gather_inputs_recursive(input_field, variable_map, mask, params, computed_inputs, r_inputs); } } } diff --git a/source/blender/functions/tests/FN_field_test.cc b/source/blender/functions/tests/FN_field_test.cc index 029527249ef..b6b9e84f2f1 100644 --- a/source/blender/functions/tests/FN_field_test.cc +++ b/source/blender/functions/tests/FN_field_test.cc @@ -37,7 +37,7 @@ class IndexFieldInput final : public FieldInput { TEST(field, VArrayInput) { - Field index_field = Field(CPPType::get(), std::make_shared()); + Field index_field = Field(CPPType::get(), std::make_shared(), "Index"); Array result_1(4); GMutableSpan result_generic_1(result_1.as_mutable_span()); @@ -60,8 +60,8 @@ TEST(field, VArrayInput) TEST(field, VArrayInputMultipleOutputs) { std::shared_ptr index_input = std::make_shared(); - Field field_1 = Field(CPPType::get(), index_input); - Field field_2 = Field(CPPType::get(), index_input); + Field field_1 = Field(CPPType::get(), index_input, "Index"); + Field field_2 = Field(CPPType::get(), index_input, "Index"); Array result_1(10); Array result_2(10); @@ -69,10 +69,38 @@ TEST(field, VArrayInputMultipleOutputs) GMutableSpan result_generic_2(result_2.as_mutable_span()); evaluate_fields({field_1, field_2}, {2, 4, 6, 8}, {result_generic_1, result_generic_2}); + EXPECT_EQ(result_1[2], 2); + EXPECT_EQ(result_1[4], 4); + EXPECT_EQ(result_1[6], 6); + EXPECT_EQ(result_1[8], 8); EXPECT_EQ(result_2[2], 2); EXPECT_EQ(result_2[4], 4); EXPECT_EQ(result_2[6], 6); EXPECT_EQ(result_2[8], 8); } +TEST(field, InputAndFunction) +{ + Field index_field = Field(CPPType::get(), std::make_shared(), "Index"); + + Field output_field = Field(CPPType::get(), + std::make_shared( + FieldFunction(std::make_unique>( + "add", [](int a, int b) { return a + b; }), + {index_field, index_field})), + 0); + + std::shared_ptr index_input = std::make_shared(); + Field field_1 = Field(CPPType::get(), index_input); + Field field_2 = Field(CPPType::get(), index_input); + + Array result(10); + GMutableSpan result_generic(result.as_mutable_span()); + evaluate_fields({output_field}, {2, 4, 6, 8}, {result_generic}); + EXPECT_EQ(result[2], 4); + EXPECT_EQ(result[4], 8); + EXPECT_EQ(result[6], 12); + EXPECT_EQ(result[8], 16); +} + } // namespace blender::fn::tests -- cgit v1.2.3