diff options
author | Campbell Barton <ideasman42@gmail.com> | 2015-12-06 13:33:39 +0300 |
---|---|---|
committer | Campbell Barton <ideasman42@gmail.com> | 2015-12-06 13:35:32 +0300 |
commit | 9964eed9ac7547db4c58bf5eabb786440236b138 (patch) | |
tree | 09d17c49983f6df79e7d271f5be17587e598f13c | |
parent | 54b95c30ae8300c7633c0526f7ab5da310c5f93a (diff) |
PyAPI: add optional filter argument to KDTree.find
-rw-r--r-- | source/blender/python/mathutils/mathutils_kdtree.c | 61 | ||||
-rw-r--r-- | tests/python/bl_pyapi_mathutils.py | 72 |
2 files changed, 122 insertions, 11 deletions
diff --git a/source/blender/python/mathutils/mathutils_kdtree.c b/source/blender/python/mathutils/mathutils_kdtree.c index dc1e82a744b..ca66c1906b4 100644 --- a/source/blender/python/mathutils/mathutils_kdtree.c +++ b/source/blender/python/mathutils/mathutils_kdtree.c @@ -189,26 +189,57 @@ static PyObject *py_kdtree_balance(PyKDTree *self) Py_RETURN_NONE; } +struct PyKDTree_NearestData { + PyObject *py_filter; + bool is_error; +}; + +static int py_find_nearest_cb(void *user_data, int index, const float co[3], float dist_sq) +{ + UNUSED_VARS(co, dist_sq); + + struct PyKDTree_NearestData *data = user_data; + + PyObject *py_args = PyTuple_New(1); + PyTuple_SET_ITEM(py_args, 0, PyLong_FromLong(index)); + PyObject *result = PyObject_CallObject(data->py_filter, py_args); + Py_DECREF(py_args); + + if (result) { + bool use_node; + int ok = PyC_ParseBool(result, &use_node); + Py_DECREF(result); + if (ok) { + return (int)use_node; + } + } + + data->is_error = true; + return -1; +} + PyDoc_STRVAR(py_kdtree_find_doc, -".. method:: find(co)\n" +".. method:: find(co, filter=None)\n" "\n" " Find nearest point to ``co``.\n" "\n" " :arg co: 3d coordinates.\n" " :type co: float triplet\n" +" :arg filter: function which takes an index and returns True for indices to include in the search.\n" +" :type filter: callable\n" " :return: Returns (:class:`Vector`, index, distance).\n" " :rtype: :class:`tuple`\n" ); static PyObject *py_kdtree_find(PyKDTree *self, PyObject *args, PyObject *kwargs) { - PyObject *py_co; + PyObject *py_co, *py_filter = NULL; float co[3]; KDTreeNearest nearest; - const char *keywords[] = {"co", NULL}; + const char *keywords[] = {"co", "filter", NULL}; if (!PyArg_ParseTupleAndKeywords( - args, kwargs, (char *) "O:find", (char **)keywords, - &py_co)) + args, kwargs, (char *) "O|O:find", (char **)keywords, + &py_co, &py_filter)) { return NULL; } @@ -221,10 +252,26 @@ static PyObject *py_kdtree_find(PyKDTree *self, PyObject *args, PyObject *kwargs return NULL; } - nearest.index = -1; - BLI_kdtree_find_nearest(self->obj, co, &nearest); + if (py_filter == NULL) { + BLI_kdtree_find_nearest(self->obj, co, &nearest); + } + else { + struct PyKDTree_NearestData data = {0}; + + data.py_filter = py_filter; + data.is_error = false; + + BLI_kdtree_find_nearest_cb( + self->obj, co, + py_find_nearest_cb, &data, + &nearest); + + if (data.is_error) { + return NULL; + } + } return kdtree_nearest_to_py_and_check(&nearest); } diff --git a/tests/python/bl_pyapi_mathutils.py b/tests/python/bl_pyapi_mathutils.py index b7f61df0e40..7761b6cb7b1 100644 --- a/tests/python/bl_pyapi_mathutils.py +++ b/tests/python/bl_pyapi_mathutils.py @@ -240,17 +240,23 @@ class QuaternionTesting(unittest.TestCase): class KDTreeTesting(unittest.TestCase): - @staticmethod - def kdtree_create_grid_3d(tot): - k = kdtree.KDTree(tot * tot * tot) + def kdtree_create_grid_3d_data(tot): index = 0 mul = 1.0 / (tot - 1) for x in range(tot): for y in range(tot): for z in range(tot): - k.insert((x * mul, y * mul, z * mul), index) + yield (x * mul, y * mul, z * mul), index index += 1 + + @staticmethod + def kdtree_create_grid_3d(tot, *, filter_fn=None): + k = kdtree.KDTree(tot * tot * tot) + for co, index in KDTreeTesting.kdtree_create_grid_3d_data(tot): + if (filter_fn is not None) and (not filter_fn(co, index)): + continue + k.insert(co, index) k.balance() return k @@ -327,6 +333,49 @@ class KDTreeTesting(unittest.TestCase): ret = k.find_n((1.0,) * 3, tot) self.assertEqual(len(ret), tot) + def test_kdtree_grid_filter_simple(self): + size = 10 + k = self.kdtree_create_grid_3d(size) + + # filter exact index + ret_regular = k.find((1.0,) * 3) + ret_filter = k.find((1.0,) * 3, filter=lambda i: i == ret_regular[1]) + self.assertEqual(ret_regular, ret_filter) + ret_filter = k.find((-1.0,) * 3, filter=lambda i: i == ret_regular[1]) + self.assertEqual(ret_regular[:2], ret_filter[:2]) # ignore distance + + def test_kdtree_grid_filter_pairs(self): + size = 10 + k_all = self.kdtree_create_grid_3d(size) + k_odd = self.kdtree_create_grid_3d(size, filter_fn=lambda co, i: (i % 2) == 1) + k_evn = self.kdtree_create_grid_3d(size, filter_fn=lambda co, i: (i % 2) == 0) + + samples = 5 + mul = 1 / (samples - 1) + for x in range(samples): + for y in range(samples): + for z in range(samples): + co = (x * mul, y * mul, z * mul) + + ret_regular = k_odd.find(co) + self.assertEqual(ret_regular[1] % 2, 1) + ret_filter = k_all.find(co, lambda i: (i % 2) == 1) + self.assertEqual(ret_regular, ret_filter) + + ret_regular = k_evn.find(co) + self.assertEqual(ret_regular[1] % 2, 0) + ret_filter = k_all.find(co, lambda i: (i % 2) == 0) + self.assertEqual(ret_regular, ret_filter) + + + # filter out all values (search odd tree for even values and the reverse) + co = (0,) * 3 + ret_filter = k_odd.find(co, lambda i: (i % 2) == 0) + self.assertEqual(ret_filter[1], None) + + ret_filter = k_evn.find(co, lambda i: (i % 2) == 1) + self.assertEqual(ret_filter[1], None) + def test_kdtree_invalid_size(self): with self.assertRaises(ValueError): kdtree.KDTree(-1) @@ -342,6 +391,21 @@ class KDTreeTesting(unittest.TestCase): with self.assertRaises(RuntimeError): k.find(co) + def test_kdtree_invalid_filter(self): + k = kdtree.KDTree(1) + k.insert((0,) * 3, 0) + k.balance() + # not callable + with self.assertRaises(TypeError): + k.find((0,) * 3, filter=None) + # no args + with self.assertRaises(TypeError): + k.find((0,) * 3, filter=lambda: None) + # bad return value + with self.assertRaises(ValueError): + k.find((0,) * 3, filter=lambda i: None) + + if __name__ == '__main__': import sys sys.argv = [__file__] + (sys.argv[sys.argv.index("--") + 1:] if "--" in sys.argv else []) |