diff options
Diffstat (limited to 'source/blender/functions/intern/multi_function_network_optimization.cc')
-rw-r--r-- | source/blender/functions/intern/multi_function_network_optimization.cc | 31 |
1 files changed, 18 insertions, 13 deletions
diff --git a/source/blender/functions/intern/multi_function_network_optimization.cc b/source/blender/functions/intern/multi_function_network_optimization.cc index f1e047f01a1..e24f157d4f9 100644 --- a/source/blender/functions/intern/multi_function_network_optimization.cc +++ b/source/blender/functions/intern/multi_function_network_optimization.cc @@ -28,6 +28,7 @@ #include "BLI_disjoint_set.hh" #include "BLI_ghash.h" #include "BLI_map.hh" +#include "BLI_multi_value_map.hh" #include "BLI_rand.h" #include "BLI_stack.hh" @@ -44,9 +45,8 @@ static bool set_tag_and_check_if_modified(bool &tag, bool new_value) tag = new_value; return true; } - else { - return false; - } + + return false; } static Array<bool> mask_nodes_to_the_left(MFNetwork &network, Span<MFNode *> nodes) @@ -130,7 +130,8 @@ static Vector<MFNode *> find_nodes_based_on_mask(MFNetwork &network, */ void dead_node_removal(MFNetwork &network) { - Array<bool> node_is_used_mask = mask_nodes_to_the_left(network, network.dummy_nodes()); + Array<bool> node_is_used_mask = mask_nodes_to_the_left(network, + network.dummy_nodes().cast<MFNode *>()); Vector<MFNode *> nodes_to_remove = find_nodes_based_on_mask(network, node_is_used_mask, false); network.remove(nodes_to_remove); } @@ -156,7 +157,7 @@ static bool function_node_can_be_constant(MFFunctionNode *node) static Vector<MFNode *> find_non_constant_nodes(MFNetwork &network) { Vector<MFNode *> non_constant_nodes; - non_constant_nodes.extend(network.dummy_nodes()); + non_constant_nodes.extend(network.dummy_nodes().cast<MFNode *>()); for (MFFunctionNode *node : network.function_nodes()) { if (!function_node_can_be_constant(node)) { @@ -265,7 +266,7 @@ static Array<MFOutputSocket *> add_constant_folded_sockets(const MultiFunction & case MFDataType::Single: { const CPPType &cpp_type = data_type.single_type(); GMutableSpan array = params.computed_array(param_index); - void *buffer = array.buffer(); + void *buffer = array.data(); resources.add(buffer, array.type().destruct_cb(), AT); constant_fn = &resources.construct<CustomMF_GenericConstant>(AT, cpp_type, buffer); @@ -299,6 +300,10 @@ static Array<MFOutputSocket *> compute_constant_sockets_and_add_folded_nodes( return add_constant_folded_sockets(network_fn, params, resources, network); } +class MyClass { + MFDummyNode node; +}; + /** * Find function nodes that always output the same value and replace those with constant nodes. */ @@ -318,7 +323,7 @@ void constant_folding(MFNetwork &network, ResourceCollector &resources) network.relink(original_socket, *folded_sockets[i]); } - network.remove(temporary_nodes); + network.remove(temporary_nodes.as_span().cast<MFNode *>()); } /** \} */ @@ -403,15 +408,15 @@ static Array<uint64_t> compute_node_hashes(MFNetwork &network) return node_hashes; } -static Map<uint64_t, Vector<MFNode *, 1>> group_nodes_by_hash(MFNetwork &network, - Span<uint64_t> node_hashes) +static MultiValueMap<uint64_t, MFNode *> group_nodes_by_hash(MFNetwork &network, + Span<uint64_t> node_hashes) { - Map<uint64_t, Vector<MFNode *, 1>> nodes_by_hash; + MultiValueMap<uint64_t, MFNode *> nodes_by_hash; for (int id : IndexRange(network.node_id_amount())) { MFNode *node = network.node_or_null_by_id(id); if (node != nullptr) { uint64_t node_hash = node_hashes[id]; - nodes_by_hash.lookup_or_add_default(node_hash).append(node); + nodes_by_hash.add(node_hash, node); } } return nodes_by_hash; @@ -456,7 +461,7 @@ static bool nodes_output_same_values(DisjointSet &cache, const MFNode &a, const } static void relink_duplicate_nodes(MFNetwork &network, - Map<uint64_t, Vector<MFNode *, 1>> &nodes_by_hash) + MultiValueMap<uint64_t, MFNode *> &nodes_by_hash) { DisjointSet same_node_cache{network.node_id_amount()}; @@ -494,7 +499,7 @@ static void relink_duplicate_nodes(MFNetwork &network, void common_subnetwork_elimination(MFNetwork &network) { Array<uint64_t> node_hashes = compute_node_hashes(network); - Map<uint64_t, Vector<MFNode *, 1>> nodes_by_hash = group_nodes_by_hash(network, node_hashes); + MultiValueMap<uint64_t, MFNode *> nodes_by_hash = group_nodes_by_hash(network, node_hashes); relink_duplicate_nodes(network, nodes_by_hash); } |