diff options
Diffstat (limited to 'source/blender/nodes/geometry/nodes/node_geo_input_normal.cc')
-rw-r--r-- | source/blender/nodes/geometry/nodes/node_geo_input_normal.cc | 45 |
1 files changed, 18 insertions, 27 deletions
diff --git a/source/blender/nodes/geometry/nodes/node_geo_input_normal.cc b/source/blender/nodes/geometry/nodes/node_geo_input_normal.cc index 6e63354a31a..1cc508d9d9d 100644 --- a/source/blender/nodes/geometry/nodes/node_geo_input_normal.cc +++ b/source/blender/nodes/geometry/nodes/node_geo_input_normal.cc @@ -93,8 +93,7 @@ static VArray<float3> mesh_vertex_normals(const Mesh &mesh, static VArray<float3> construct_mesh_normals_gvarray(const MeshComponent &mesh_component, const Mesh &mesh, const IndexMask mask, - const AttributeDomain domain, - ResourceScope &UNUSED(scope)) + const AttributeDomain domain) { Span<MVert> verts{mesh.mvert, mesh.totvert}; Span<MEdge> edges{mesh.medge, mesh.totedge}; @@ -199,8 +198,7 @@ static Array<float3> curve_normal_point_domain(const CurveEval &curve) } static VArray<float3> construct_curve_normal_gvarray(const CurveComponent &component, - const AttributeDomain domain, - ResourceScope &UNUSED(scope)) + const AttributeDomain domain) { const CurveEval *curve = component.get_for_read(); if (curve == nullptr) { @@ -231,36 +229,29 @@ static VArray<float3> construct_curve_normal_gvarray(const CurveComponent &compo return nullptr; } -class NormalFieldInput final : public fn::FieldInput { +class NormalFieldInput final : public GeometryFieldInput { public: - NormalFieldInput() : fn::FieldInput(CPPType::get<float3>(), "Normal node") + NormalFieldInput() : GeometryFieldInput(CPPType::get<float3>(), "Normal node") { category_ = Category::Generated; } - GVArray get_varray_for_context(const fn::FieldContext &context, - IndexMask mask, - ResourceScope &scope) const final + GVArray get_varray_for_context(const GeometryComponent &component, + const AttributeDomain domain, + IndexMask mask) const final { - if (const GeometryComponentFieldContext *geometry_context = - dynamic_cast<const GeometryComponentFieldContext *>(&context)) { - - const GeometryComponent &component = geometry_context->geometry_component(); - const AttributeDomain domain = geometry_context->domain(); - - if (component.type() == GEO_COMPONENT_TYPE_MESH) { - const MeshComponent &mesh_component = static_cast<const MeshComponent &>(component); - const Mesh *mesh = mesh_component.get_for_read(); - if (mesh == nullptr) { - return {}; - } - - return construct_mesh_normals_gvarray(mesh_component, *mesh, mask, domain, scope); - } - if (component.type() == GEO_COMPONENT_TYPE_CURVE) { - const CurveComponent &curve_component = static_cast<const CurveComponent &>(component); - return construct_curve_normal_gvarray(curve_component, domain, scope); + if (component.type() == GEO_COMPONENT_TYPE_MESH) { + const MeshComponent &mesh_component = static_cast<const MeshComponent &>(component); + const Mesh *mesh = mesh_component.get_for_read(); + if (mesh == nullptr) { + return {}; } + + return construct_mesh_normals_gvarray(mesh_component, *mesh, mask, domain); + } + if (component.type() == GEO_COMPONENT_TYPE_CURVE) { + const CurveComponent &curve_component = static_cast<const CurveComponent &>(component); + return construct_curve_normal_gvarray(curve_component, domain); } return {}; } |