diff options
Diffstat (limited to 'source/blender/functions/intern')
4 files changed, 650 insertions, 24 deletions
diff --git a/source/blender/functions/intern/multi_function_builder.cc b/source/blender/functions/intern/multi_function_builder.cc new file mode 100644 index 00000000000..889a2595aab --- /dev/null +++ b/source/blender/functions/intern/multi_function_builder.cc @@ -0,0 +1,90 @@ +/* + * This program is free software; you can redistribute it and/or + * modify it under the terms of the GNU General Public License + * as published by the Free Software Foundation; either version 2 + * of the License, or (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with this program; if not, write to the Free Software Foundation, + * Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301, USA. + */ + +#include "FN_multi_function_builder.hh" + +#include "BLI_hash.hh" + +namespace blender::fn { + +CustomMF_GenericConstant::CustomMF_GenericConstant(const CPPType &type, const void *value) + : type_(type), value_(value) +{ + MFSignatureBuilder signature = this->get_builder("Constant " + type.name()); + std::stringstream ss; + type.debug_print(value, ss); + signature.single_output(ss.str(), type); +} + +void CustomMF_GenericConstant::call(IndexMask mask, + MFParams params, + MFContext UNUSED(context)) const +{ + GMutableSpan output = params.uninitialized_single_output(0); + type_.fill_uninitialized_indices(value_, output.buffer(), mask); +} + +uint CustomMF_GenericConstant::hash() const +{ + return type_.hash(value_); +} + +bool CustomMF_GenericConstant::equals(const MultiFunction &other) const +{ + const CustomMF_GenericConstant *_other = dynamic_cast<const CustomMF_GenericConstant *>(&other); + if (_other == nullptr) { + return false; + } + if (type_ != _other->type_) { + return false; + } + return type_.is_equal(value_, _other->value_); +} + +static std::string gspan_to_string(GSpan array) +{ + std::stringstream ss; + ss << "["; + uint max_amount = 5; + for (uint i : IndexRange(std::min(max_amount, array.size()))) { + array.type().debug_print(array[i], ss); + ss << ", "; + } + if (max_amount < array.size()) { + ss << "..."; + } + ss << "]"; + return ss.str(); +} + +CustomMF_GenericConstantArray::CustomMF_GenericConstantArray(GSpan array) : array_(array) +{ + const CPPType &type = array.type(); + MFSignatureBuilder signature = this->get_builder("Constant " + type.name() + " Vector"); + signature.vector_output(gspan_to_string(array), type); +} + +void CustomMF_GenericConstantArray::call(IndexMask mask, + MFParams params, + MFContext UNUSED(context)) const +{ + GVectorArray &vectors = params.vector_output(0); + for (uint i : mask) { + vectors.extend(i, array_); + } +} + +} // namespace blender::fn diff --git a/source/blender/functions/intern/multi_function_network.cc b/source/blender/functions/intern/multi_function_network.cc index 5df70d92a4e..11c9c065f51 100644 --- a/source/blender/functions/intern/multi_function_network.cc +++ b/source/blender/functions/intern/multi_function_network.cc @@ -15,6 +15,8 @@ */ #include "BLI_dot_export.hh" +#include "BLI_stack.hh" + #include "FN_multi_function_network.hh" namespace blender::fn { @@ -184,17 +186,18 @@ void MFNetwork::add_link(MFOutputSocket &from, MFInputSocket &to) MFOutputSocket &MFNetwork::add_input(StringRef name, MFDataType data_type) { - return this->add_dummy(name, {}, {data_type}, {}, {name}).output(0); + return this->add_dummy(name, {}, {data_type}, {}, {"Value"}).output(0); } MFInputSocket &MFNetwork::add_output(StringRef name, MFDataType data_type) { - return this->add_dummy(name, {data_type}, {}, {name}, {}).input(0); + return this->add_dummy(name, {data_type}, {}, {"Value"}, {}).input(0); } void MFNetwork::relink(MFOutputSocket &old_output, MFOutputSocket &new_output) { BLI_assert(&old_output != &new_output); + BLI_assert(old_output.data_type_ == new_output.data_type_); for (MFInputSocket *input : old_output.targets()) { input->origin_ = &new_output; } @@ -230,7 +233,43 @@ void MFNetwork::remove(MFNode &node) node_or_null_by_id_[node.id_] = nullptr; } -std::string MFNetwork::to_dot() const +void MFNetwork::remove(Span<MFNode *> nodes) +{ + for (MFNode *node : nodes) { + this->remove(*node); + } +} + +void MFNetwork::find_dependencies(Span<const MFInputSocket *> sockets, + VectorSet<const MFOutputSocket *> &r_dummy_sockets, + VectorSet<const MFInputSocket *> &r_unlinked_inputs) const +{ + Set<const MFNode *> visited_nodes; + Stack<const MFInputSocket *> sockets_to_check; + sockets_to_check.push_multiple(sockets); + + while (!sockets_to_check.is_empty()) { + const MFInputSocket &socket = *sockets_to_check.pop(); + const MFOutputSocket *origin_socket = socket.origin(); + if (origin_socket == nullptr) { + r_unlinked_inputs.add(&socket); + continue; + } + + const MFNode &origin_node = origin_socket->node(); + + if (origin_node.is_dummy()) { + r_dummy_sockets.add(origin_socket); + continue; + } + + if (visited_nodes.add(&origin_node)) { + sockets_to_check.push_multiple(origin_node.inputs()); + } + } +} + +std::string MFNetwork::to_dot(Span<const MFNode *> marked_nodes) const { dot::DirectedGraph digraph; digraph.set_rankdir(dot::Attr_rankdir::LeftToRight); @@ -256,6 +295,13 @@ std::string MFNetwork::to_dot() const dot_nodes.add_new(node, dot_node_ref); } + for (const MFDummyNode *node : dummy_nodes_) { + dot_nodes.lookup(node).node().set_background_color("#77EE77"); + } + for (const MFNode *node : marked_nodes) { + dot_nodes.lookup(node).node().set_background_color("#7777EE"); + } + for (const MFNode *to_node : all_nodes) { dot::NodeWithSocketsRef to_dot_node = dot_nodes.lookup(to_node); diff --git a/source/blender/functions/intern/multi_function_network_evaluation.cc b/source/blender/functions/intern/multi_function_network_evaluation.cc index 08a254dc300..b59cbc6a1a2 100644 --- a/source/blender/functions/intern/multi_function_network_evaluation.cc +++ b/source/blender/functions/intern/multi_function_network_evaluation.cc @@ -58,7 +58,7 @@ class MFNetworkEvaluationStorage { uint min_array_size_; public: - MFNetworkEvaluationStorage(IndexMask mask, uint max_socket_id); + MFNetworkEvaluationStorage(IndexMask mask, uint socket_id_amount); ~MFNetworkEvaluationStorage(); /* Add the values that have been provided by the caller of the multi-function network. */ @@ -106,30 +106,30 @@ MFNetworkEvaluator::MFNetworkEvaluator(Vector<const MFOutputSocket *> inputs, BLI_assert(outputs_.size() > 0); MFSignatureBuilder signature = this->get_builder("Function Tree"); - for (auto socket : inputs_) { + for (const MFOutputSocket *socket : inputs_) { BLI_assert(socket->node().is_dummy()); MFDataType type = socket->data_type(); switch (type.category()) { case MFDataType::Single: - signature.single_input("Input", type.single_type()); + signature.single_input(socket->name(), type.single_type()); break; case MFDataType::Vector: - signature.vector_input("Input", type.vector_base_type()); + signature.vector_input(socket->name(), type.vector_base_type()); break; } } - for (auto socket : outputs_) { + for (const MFInputSocket *socket : outputs_) { BLI_assert(socket->node().is_dummy()); MFDataType type = socket->data_type(); switch (type.category()) { case MFDataType::Single: - signature.single_output("Output", type.single_type()); + signature.single_output(socket->name(), type.single_type()); break; case MFDataType::Vector: - signature.vector_output("Output", type.vector_base_type()); + signature.vector_output(socket->name(), type.vector_base_type()); break; } } @@ -142,7 +142,7 @@ void MFNetworkEvaluator::call(IndexMask mask, MFParams params, MFContext context } const MFNetwork &network = outputs_[0]->node().network(); - Storage storage(mask, network.max_socket_id()); + Storage storage(mask, network.socket_id_amount()); Vector<const MFInputSocket *> outputs_to_initialize_in_the_end; @@ -219,8 +219,6 @@ BLI_NOINLINE void MFNetworkEvaluator::evaluate_network_to_compute_outputs( sockets_to_compute.push(socket->origin()); } - Vector<const MFOutputSocket *, 32> missing_sockets; - /* This is the main loop that traverses the MFNetwork. */ while (!sockets_to_compute.is_empty()) { const MFOutputSocket &socket = *sockets_to_compute.peek(); @@ -235,17 +233,18 @@ BLI_NOINLINE void MFNetworkEvaluator::evaluate_network_to_compute_outputs( BLI_assert(node.all_inputs_have_origin()); const MFFunctionNode &function_node = node.as_function(); - missing_sockets.clear(); - function_node.foreach_origin_socket([&](const MFOutputSocket &origin) { - if (!storage.socket_is_computed(origin)) { - missing_sockets.append(&origin); + bool all_origins_are_computed = true; + for (const MFInputSocket *input_socket : function_node.inputs()) { + const MFOutputSocket *origin = input_socket->origin(); + if (origin != nullptr) { + if (!storage.socket_is_computed(*origin)) { + sockets_to_compute.push(origin); + all_origins_are_computed = false; + } } - }); - - sockets_to_compute.push_multiple(missing_sockets); + } - bool all_inputs_are_computed = missing_sockets.size() == 0; - if (all_inputs_are_computed) { + if (all_origins_are_computed) { this->evaluate_function(global_context, function_node, storage); sockets_to_compute.pop(); } @@ -507,9 +506,9 @@ struct OwnVectorValue : public Value { /** \name Storage methods * \{ */ -MFNetworkEvaluationStorage::MFNetworkEvaluationStorage(IndexMask mask, uint max_socket_id) +MFNetworkEvaluationStorage::MFNetworkEvaluationStorage(IndexMask mask, uint socket_id_amount) : mask_(mask), - value_per_output_id_(max_socket_id + 1, nullptr), + value_per_output_id_(socket_id_amount, nullptr), min_array_size_(mask.min_array_size()) { } diff --git a/source/blender/functions/intern/multi_function_network_optimization.cc b/source/blender/functions/intern/multi_function_network_optimization.cc new file mode 100644 index 00000000000..849b24a318f --- /dev/null +++ b/source/blender/functions/intern/multi_function_network_optimization.cc @@ -0,0 +1,491 @@ +/* + * This program is free software; you can redistribute it and/or + * modify it under the terms of the GNU General Public License + * as published by the Free Software Foundation; either version 2 + * of the License, or (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with this program; if not, write to the Free Software Foundation, + * Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301, USA. + */ + +/** \file + * \ingroup fn + */ + +/* Used to check if two multi-functions have the exact same type. */ +#include <typeinfo> + +#include "FN_multi_function_builder.hh" +#include "FN_multi_function_network_evaluation.hh" +#include "FN_multi_function_network_optimization.hh" + +#include "BLI_disjoint_set.hh" +#include "BLI_ghash.h" +#include "BLI_map.hh" +#include "BLI_rand.h" +#include "BLI_stack.hh" + +namespace blender::fn::mf_network_optimization { + +/* -------------------------------------------------------------------- */ +/** \name Utility functions to find nodes in a network. + * + * \{ */ + +static bool set_tag_and_check_if_modified(bool &tag, bool new_value) +{ + if (tag != new_value) { + tag = new_value; + return true; + } + else { + return false; + } +} + +static Array<bool> mask_nodes_to_the_left(MFNetwork &network, Span<MFNode *> nodes) +{ + Array<bool> is_to_the_left(network.node_id_amount(), false); + Stack<MFNode *> nodes_to_check; + + for (MFNode *node : nodes) { + is_to_the_left[node->id()] = true; + nodes_to_check.push(node); + } + + while (!nodes_to_check.is_empty()) { + MFNode &node = *nodes_to_check.pop(); + + for (MFInputSocket *input_socket : node.inputs()) { + MFOutputSocket *origin = input_socket->origin(); + if (origin != nullptr) { + MFNode &origin_node = origin->node(); + if (set_tag_and_check_if_modified(is_to_the_left[origin_node.id()], true)) { + nodes_to_check.push(&origin_node); + } + } + } + } + + return is_to_the_left; +} + +static Array<bool> mask_nodes_to_the_right(MFNetwork &network, Span<MFNode *> nodes) +{ + Array<bool> is_to_the_right(network.node_id_amount(), false); + Stack<MFNode *> nodes_to_check; + + for (MFNode *node : nodes) { + is_to_the_right[node->id()] = true; + nodes_to_check.push(node); + } + + while (!nodes_to_check.is_empty()) { + MFNode &node = *nodes_to_check.pop(); + + for (MFOutputSocket *output_socket : node.outputs()) { + for (MFInputSocket *target_socket : output_socket->targets()) { + MFNode &target_node = target_socket->node(); + if (set_tag_and_check_if_modified(is_to_the_right[target_node.id()], true)) { + nodes_to_check.push(&target_node); + } + } + } + } + + return is_to_the_right; +} + +static Vector<MFNode *> find_nodes_based_on_mask(MFNetwork &network, + Span<bool> id_mask, + bool mask_value) +{ + Vector<MFNode *> nodes; + for (uint id : id_mask.index_range()) { + if (id_mask[id] == mask_value) { + MFNode *node = network.node_or_null_by_id(id); + if (node != nullptr) { + nodes.append(node); + } + } + } + return nodes; +} + +/** \} */ + +/* -------------------------------------------------------------------- */ +/** \name Dead Node Removal + * + * \{ */ + +/** + * Unused nodes are all those nodes that no dummy node depends upon. + */ +void dead_node_removal(MFNetwork &network) +{ + Array<bool> node_is_used_mask = mask_nodes_to_the_left(network, network.dummy_nodes()); + Vector<MFNode *> nodes_to_remove = find_nodes_based_on_mask(network, node_is_used_mask, false); + network.remove(nodes_to_remove); +} + +/** \} */ + +/* -------------------------------------------------------------------- */ +/** \name Constant Folding + * + * \{ */ + +static Vector<MFNode *> find_non_constant_nodes(MFNetwork &network) +{ + Vector<MFNode *> non_constant_nodes; + non_constant_nodes.extend(network.dummy_nodes()); + + for (MFFunctionNode *node : network.function_nodes()) { + if (!node->all_inputs_have_origin()) { + non_constant_nodes.append(node); + } + } + return non_constant_nodes; +} + +static bool output_has_non_constant_target_node(MFOutputSocket *output_socket, + Span<bool> is_not_constant_mask) +{ + for (MFInputSocket *target_socket : output_socket->targets()) { + MFNode &target_node = target_socket->node(); + bool target_is_not_constant = is_not_constant_mask[target_node.id()]; + if (target_is_not_constant) { + return true; + } + } + return false; +} + +static MFInputSocket *try_find_dummy_target_socket(MFOutputSocket *output_socket) +{ + for (MFInputSocket *target_socket : output_socket->targets()) { + if (target_socket->node().is_dummy()) { + return target_socket; + } + } + return nullptr; +} + +static Vector<MFInputSocket *> find_constant_inputs_to_fold( + MFNetwork &network, Vector<MFDummyNode *> &r_temporary_nodes) +{ + Vector<MFNode *> non_constant_nodes = find_non_constant_nodes(network); + Array<bool> is_not_constant_mask = mask_nodes_to_the_right(network, non_constant_nodes); + Vector<MFNode *> constant_nodes = find_nodes_based_on_mask(network, is_not_constant_mask, false); + + Vector<MFInputSocket *> sockets_to_compute; + for (MFNode *node : constant_nodes) { + if (node->inputs().size() == 0) { + continue; + } + + for (MFOutputSocket *output_socket : node->outputs()) { + MFDataType data_type = output_socket->data_type(); + if (output_has_non_constant_target_node(output_socket, is_not_constant_mask)) { + MFInputSocket *dummy_target = try_find_dummy_target_socket(output_socket); + if (dummy_target == nullptr) { + dummy_target = &network.add_output("Dummy", data_type); + network.add_link(*output_socket, *dummy_target); + r_temporary_nodes.append(&dummy_target->node().as_dummy()); + } + + sockets_to_compute.append(dummy_target); + } + } + } + return sockets_to_compute; +} + +static void prepare_params_for_constant_folding(const MultiFunction &network_fn, + MFParamsBuilder ¶ms, + ResourceCollector &resources) +{ + for (uint param_index : network_fn.param_indices()) { + MFParamType param_type = network_fn.param_type(param_index); + MFDataType data_type = param_type.data_type(); + + switch (data_type.category()) { + case MFDataType::Single: { + /* Allocates memory for a single constant folded value. */ + const CPPType &cpp_type = data_type.single_type(); + void *buffer = resources.linear_allocator().allocate(cpp_type.size(), + cpp_type.alignment()); + GMutableSpan array{cpp_type, buffer, 1}; + params.add_uninitialized_single_output(array); + break; + } + case MFDataType::Vector: { + /* Allocates memory for a constant folded vector. */ + const CPPType &cpp_type = data_type.vector_base_type(); + GVectorArray &vector_array = resources.construct<GVectorArray>(AT, cpp_type, 1); + params.add_vector_output(vector_array); + break; + } + } + } +} + +static Array<MFOutputSocket *> add_constant_folded_sockets(const MultiFunction &network_fn, + MFParamsBuilder ¶ms, + ResourceCollector &resources, + MFNetwork &network) +{ + Array<MFOutputSocket *> folded_sockets{network_fn.param_indices().size(), nullptr}; + + for (uint param_index : network_fn.param_indices()) { + MFParamType param_type = network_fn.param_type(param_index); + MFDataType data_type = param_type.data_type(); + + const MultiFunction *constant_fn = nullptr; + + switch (data_type.category()) { + case MFDataType::Single: { + const CPPType &cpp_type = data_type.single_type(); + GMutableSpan array = params.computed_array(param_index); + void *buffer = array.buffer(); + resources.add(buffer, array.type().destruct_cb(), AT); + + constant_fn = &resources.construct<CustomMF_GenericConstant>(AT, cpp_type, buffer); + break; + } + case MFDataType::Vector: { + GVectorArray &vector_array = params.computed_vector_array(param_index); + GSpan array = vector_array[0]; + constant_fn = &resources.construct<CustomMF_GenericConstantArray>(AT, array); + break; + } + } + + MFFunctionNode &folded_node = network.add_function(*constant_fn); + folded_sockets[param_index] = &folded_node.output(0); + } + return folded_sockets; +} + +static Array<MFOutputSocket *> compute_constant_sockets_and_add_folded_nodes( + MFNetwork &network, + Span<const MFInputSocket *> sockets_to_compute, + ResourceCollector &resources) +{ + MFNetworkEvaluator network_fn{{}, sockets_to_compute}; + + MFContextBuilder context; + MFParamsBuilder params{network_fn, 1}; + prepare_params_for_constant_folding(network_fn, params, resources); + network_fn.call({0}, params, context); + return add_constant_folded_sockets(network_fn, params, resources, network); +} + +/** + * Find function nodes that always output the same value and replace those with constant nodes. + */ +void constant_folding(MFNetwork &network, ResourceCollector &resources) +{ + Vector<MFDummyNode *> temporary_nodes; + Vector<MFInputSocket *> inputs_to_fold = find_constant_inputs_to_fold(network, temporary_nodes); + if (inputs_to_fold.size() == 0) { + return; + } + + Array<MFOutputSocket *> folded_sockets = compute_constant_sockets_and_add_folded_nodes( + network, inputs_to_fold, resources); + + for (uint i : inputs_to_fold.index_range()) { + MFOutputSocket &original_socket = *inputs_to_fold[i]->origin(); + network.relink(original_socket, *folded_sockets[i]); + } + + network.remove(temporary_nodes); +} + +/** \} */ + +/* -------------------------------------------------------------------- */ +/** \name Common Sub-network Elimination + * + * \{ */ + +static uint32_t compute_node_hash(MFFunctionNode &node, RNG *rng, Span<uint32_t> node_hashes) +{ + uint32_t combined_inputs_hash = 394659347u; + for (MFInputSocket *input_socket : node.inputs()) { + MFOutputSocket *origin_socket = input_socket->origin(); + uint32_t input_hash; + if (origin_socket == nullptr) { + input_hash = BLI_rng_get_uint(rng); + } + else { + input_hash = BLI_ghashutil_combine_hash(node_hashes[origin_socket->node().id()], + origin_socket->index()); + } + combined_inputs_hash = BLI_ghashutil_combine_hash(combined_inputs_hash, input_hash); + } + + uint32_t function_hash = node.function().hash(); + uint32_t node_hash = BLI_ghashutil_combine_hash(combined_inputs_hash, function_hash); + return node_hash; +} + +/** + * Produces a hash for every node. Two nodes with the same hash should have a high probability of + * outputting the same values. + */ +static Array<uint32_t> compute_node_hashes(MFNetwork &network) +{ + RNG *rng = BLI_rng_new(0); + Array<uint32_t> node_hashes(network.node_id_amount()); + Array<bool> node_is_hashed(network.node_id_amount(), false); + + /* No dummy nodes are not assumed to output the same values. */ + for (MFDummyNode *node : network.dummy_nodes()) { + uint32_t node_hash = BLI_rng_get_uint(rng); + node_hashes[node->id()] = node_hash; + node_is_hashed[node->id()] = true; + } + + Stack<MFFunctionNode *> nodes_to_check; + nodes_to_check.push_multiple(network.function_nodes()); + + while (!nodes_to_check.is_empty()) { + MFFunctionNode &node = *nodes_to_check.peek(); + if (node_is_hashed[node.id()]) { + nodes_to_check.pop(); + continue; + } + + /* Make sure that origin nodes are hashed first. */ + bool all_dependencies_ready = true; + for (MFInputSocket *input_socket : node.inputs()) { + MFOutputSocket *origin_socket = input_socket->origin(); + if (origin_socket != nullptr) { + MFNode &origin_node = origin_socket->node(); + if (!node_is_hashed[origin_node.id()]) { + all_dependencies_ready = false; + nodes_to_check.push(&origin_node.as_function()); + } + } + } + if (!all_dependencies_ready) { + continue; + } + + uint32_t node_hash = compute_node_hash(node, rng, node_hashes); + node_hashes[node.id()] = node_hash; + node_is_hashed[node.id()] = true; + nodes_to_check.pop(); + } + + BLI_rng_free(rng); + return node_hashes; +} + +static Map<uint32_t, Vector<MFNode *, 1>> group_nodes_by_hash(MFNetwork &network, + Span<uint32_t> node_hashes) +{ + Map<uint32_t, Vector<MFNode *, 1>> nodes_by_hash; + for (uint id : IndexRange(network.node_id_amount())) { + MFNode *node = network.node_or_null_by_id(id); + if (node != nullptr) { + uint32_t node_hash = node_hashes[id]; + nodes_by_hash.lookup_or_add_default(node_hash).append(node); + } + } + return nodes_by_hash; +} + +static bool functions_are_equal(const MultiFunction &a, const MultiFunction &b) +{ + if (&a == &b) { + return true; + } + if (typeid(a) == typeid(b)) { + return a.equals(b); + } + return false; +} + +static bool nodes_output_same_values(DisjointSet &cache, const MFNode &a, const MFNode &b) +{ + if (cache.in_same_set(a.id(), b.id())) { + return true; + } + + if (a.is_dummy() || b.is_dummy()) { + return false; + } + if (!functions_are_equal(a.as_function().function(), b.as_function().function())) { + return false; + } + for (uint i : a.inputs().index_range()) { + const MFOutputSocket *origin_a = a.input(i).origin(); + const MFOutputSocket *origin_b = b.input(i).origin(); + if (origin_a == nullptr || origin_b == nullptr) { + return false; + } + if (!nodes_output_same_values(cache, origin_a->node(), origin_b->node())) { + return false; + } + } + + cache.join(a.id(), b.id()); + return true; +} + +static void relink_duplicate_nodes(MFNetwork &network, + Map<uint32_t, Vector<MFNode *, 1>> &nodes_by_hash) +{ + DisjointSet same_node_cache{network.node_id_amount()}; + + for (Span<MFNode *> nodes_with_same_hash : nodes_by_hash.values()) { + if (nodes_with_same_hash.size() <= 1) { + continue; + } + + Vector<MFNode *, 16> nodes_to_check = nodes_with_same_hash; + while (nodes_to_check.size() >= 2) { + Vector<MFNode *, 16> remaining_nodes; + + MFNode &deduplicated_node = *nodes_to_check[0]; + for (MFNode *node : nodes_to_check.as_span().drop_front(1)) { + /* This is true with fairly high probability, but hash collisions can happen. So we have to + * check if the node actually output the same values. */ + if (nodes_output_same_values(same_node_cache, deduplicated_node, *node)) { + for (uint i : deduplicated_node.outputs().index_range()) { + network.relink(node->output(i), deduplicated_node.output(i)); + } + } + else { + remaining_nodes.append(node); + } + } + nodes_to_check = std::move(remaining_nodes); + } + } +} + +/** + * Tries to detect duplicate sub-networks and eliminates them. This can help quite a lot when node + * groups were used to create the network. + */ +void common_subnetwork_elimination(MFNetwork &network) +{ + Array<uint32_t> node_hashes = compute_node_hashes(network); + Map<uint32_t, Vector<MFNode *, 1>> nodes_by_hash = group_nodes_by_hash(network, node_hashes); + relink_duplicate_nodes(network, nodes_by_hash); +} + +/** \} */ + +} // namespace blender::fn::mf_network_optimization |