Welcome to mirror list, hosted at ThFree Co, Russian Federation.

git.blender.org/blender.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorCampbell Barton <ideasman42@gmail.com>2015-12-06 13:33:39 +0300
committerCampbell Barton <ideasman42@gmail.com>2015-12-06 13:35:32 +0300
commit9964eed9ac7547db4c58bf5eabb786440236b138 (patch)
tree09d17c49983f6df79e7d271f5be17587e598f13c
parent54b95c30ae8300c7633c0526f7ab5da310c5f93a (diff)
PyAPI: add optional filter argument to KDTree.find
-rw-r--r--source/blender/python/mathutils/mathutils_kdtree.c61
-rw-r--r--tests/python/bl_pyapi_mathutils.py72
2 files changed, 122 insertions, 11 deletions
diff --git a/source/blender/python/mathutils/mathutils_kdtree.c b/source/blender/python/mathutils/mathutils_kdtree.c
index dc1e82a744b..ca66c1906b4 100644
--- a/source/blender/python/mathutils/mathutils_kdtree.c
+++ b/source/blender/python/mathutils/mathutils_kdtree.c
@@ -189,26 +189,57 @@ static PyObject *py_kdtree_balance(PyKDTree *self)
Py_RETURN_NONE;
}
+struct PyKDTree_NearestData {
+ PyObject *py_filter;
+ bool is_error;
+};
+
+static int py_find_nearest_cb(void *user_data, int index, const float co[3], float dist_sq)
+{
+ UNUSED_VARS(co, dist_sq);
+
+ struct PyKDTree_NearestData *data = user_data;
+
+ PyObject *py_args = PyTuple_New(1);
+ PyTuple_SET_ITEM(py_args, 0, PyLong_FromLong(index));
+ PyObject *result = PyObject_CallObject(data->py_filter, py_args);
+ Py_DECREF(py_args);
+
+ if (result) {
+ bool use_node;
+ int ok = PyC_ParseBool(result, &use_node);
+ Py_DECREF(result);
+ if (ok) {
+ return (int)use_node;
+ }
+ }
+
+ data->is_error = true;
+ return -1;
+}
+
PyDoc_STRVAR(py_kdtree_find_doc,
-".. method:: find(co)\n"
+".. method:: find(co, filter=None)\n"
"\n"
" Find nearest point to ``co``.\n"
"\n"
" :arg co: 3d coordinates.\n"
" :type co: float triplet\n"
+" :arg filter: function which takes an index and returns True for indices to include in the search.\n"
+" :type filter: callable\n"
" :return: Returns (:class:`Vector`, index, distance).\n"
" :rtype: :class:`tuple`\n"
);
static PyObject *py_kdtree_find(PyKDTree *self, PyObject *args, PyObject *kwargs)
{
- PyObject *py_co;
+ PyObject *py_co, *py_filter = NULL;
float co[3];
KDTreeNearest nearest;
- const char *keywords[] = {"co", NULL};
+ const char *keywords[] = {"co", "filter", NULL};
if (!PyArg_ParseTupleAndKeywords(
- args, kwargs, (char *) "O:find", (char **)keywords,
- &py_co))
+ args, kwargs, (char *) "O|O:find", (char **)keywords,
+ &py_co, &py_filter))
{
return NULL;
}
@@ -221,10 +252,26 @@ static PyObject *py_kdtree_find(PyKDTree *self, PyObject *args, PyObject *kwargs
return NULL;
}
-
nearest.index = -1;
- BLI_kdtree_find_nearest(self->obj, co, &nearest);
+ if (py_filter == NULL) {
+ BLI_kdtree_find_nearest(self->obj, co, &nearest);
+ }
+ else {
+ struct PyKDTree_NearestData data = {0};
+
+ data.py_filter = py_filter;
+ data.is_error = false;
+
+ BLI_kdtree_find_nearest_cb(
+ self->obj, co,
+ py_find_nearest_cb, &data,
+ &nearest);
+
+ if (data.is_error) {
+ return NULL;
+ }
+ }
return kdtree_nearest_to_py_and_check(&nearest);
}
diff --git a/tests/python/bl_pyapi_mathutils.py b/tests/python/bl_pyapi_mathutils.py
index b7f61df0e40..7761b6cb7b1 100644
--- a/tests/python/bl_pyapi_mathutils.py
+++ b/tests/python/bl_pyapi_mathutils.py
@@ -240,17 +240,23 @@ class QuaternionTesting(unittest.TestCase):
class KDTreeTesting(unittest.TestCase):
-
@staticmethod
- def kdtree_create_grid_3d(tot):
- k = kdtree.KDTree(tot * tot * tot)
+ def kdtree_create_grid_3d_data(tot):
index = 0
mul = 1.0 / (tot - 1)
for x in range(tot):
for y in range(tot):
for z in range(tot):
- k.insert((x * mul, y * mul, z * mul), index)
+ yield (x * mul, y * mul, z * mul), index
index += 1
+
+ @staticmethod
+ def kdtree_create_grid_3d(tot, *, filter_fn=None):
+ k = kdtree.KDTree(tot * tot * tot)
+ for co, index in KDTreeTesting.kdtree_create_grid_3d_data(tot):
+ if (filter_fn is not None) and (not filter_fn(co, index)):
+ continue
+ k.insert(co, index)
k.balance()
return k
@@ -327,6 +333,49 @@ class KDTreeTesting(unittest.TestCase):
ret = k.find_n((1.0,) * 3, tot)
self.assertEqual(len(ret), tot)
+ def test_kdtree_grid_filter_simple(self):
+ size = 10
+ k = self.kdtree_create_grid_3d(size)
+
+ # filter exact index
+ ret_regular = k.find((1.0,) * 3)
+ ret_filter = k.find((1.0,) * 3, filter=lambda i: i == ret_regular[1])
+ self.assertEqual(ret_regular, ret_filter)
+ ret_filter = k.find((-1.0,) * 3, filter=lambda i: i == ret_regular[1])
+ self.assertEqual(ret_regular[:2], ret_filter[:2]) # ignore distance
+
+ def test_kdtree_grid_filter_pairs(self):
+ size = 10
+ k_all = self.kdtree_create_grid_3d(size)
+ k_odd = self.kdtree_create_grid_3d(size, filter_fn=lambda co, i: (i % 2) == 1)
+ k_evn = self.kdtree_create_grid_3d(size, filter_fn=lambda co, i: (i % 2) == 0)
+
+ samples = 5
+ mul = 1 / (samples - 1)
+ for x in range(samples):
+ for y in range(samples):
+ for z in range(samples):
+ co = (x * mul, y * mul, z * mul)
+
+ ret_regular = k_odd.find(co)
+ self.assertEqual(ret_regular[1] % 2, 1)
+ ret_filter = k_all.find(co, lambda i: (i % 2) == 1)
+ self.assertEqual(ret_regular, ret_filter)
+
+ ret_regular = k_evn.find(co)
+ self.assertEqual(ret_regular[1] % 2, 0)
+ ret_filter = k_all.find(co, lambda i: (i % 2) == 0)
+ self.assertEqual(ret_regular, ret_filter)
+
+
+ # filter out all values (search odd tree for even values and the reverse)
+ co = (0,) * 3
+ ret_filter = k_odd.find(co, lambda i: (i % 2) == 0)
+ self.assertEqual(ret_filter[1], None)
+
+ ret_filter = k_evn.find(co, lambda i: (i % 2) == 1)
+ self.assertEqual(ret_filter[1], None)
+
def test_kdtree_invalid_size(self):
with self.assertRaises(ValueError):
kdtree.KDTree(-1)
@@ -342,6 +391,21 @@ class KDTreeTesting(unittest.TestCase):
with self.assertRaises(RuntimeError):
k.find(co)
+ def test_kdtree_invalid_filter(self):
+ k = kdtree.KDTree(1)
+ k.insert((0,) * 3, 0)
+ k.balance()
+ # not callable
+ with self.assertRaises(TypeError):
+ k.find((0,) * 3, filter=None)
+ # no args
+ with self.assertRaises(TypeError):
+ k.find((0,) * 3, filter=lambda: None)
+ # bad return value
+ with self.assertRaises(ValueError):
+ k.find((0,) * 3, filter=lambda i: None)
+
+
if __name__ == '__main__':
import sys
sys.argv = [__file__] + (sys.argv[sys.argv.index("--") + 1:] if "--" in sys.argv else [])