From 9964eed9ac7547db4c58bf5eabb786440236b138 Mon Sep 17 00:00:00 2001 From: Campbell Barton Date: Sun, 6 Dec 2015 21:33:39 +1100 Subject: PyAPI: add optional filter argument to KDTree.find --- tests/python/bl_pyapi_mathutils.py | 72 +++++++++++++++++++++++++++++++++++--- 1 file changed, 68 insertions(+), 4 deletions(-) (limited to 'tests/python') diff --git a/tests/python/bl_pyapi_mathutils.py b/tests/python/bl_pyapi_mathutils.py index b7f61df0e40..7761b6cb7b1 100644 --- a/tests/python/bl_pyapi_mathutils.py +++ b/tests/python/bl_pyapi_mathutils.py @@ -240,17 +240,23 @@ class QuaternionTesting(unittest.TestCase): class KDTreeTesting(unittest.TestCase): - @staticmethod - def kdtree_create_grid_3d(tot): - k = kdtree.KDTree(tot * tot * tot) + def kdtree_create_grid_3d_data(tot): index = 0 mul = 1.0 / (tot - 1) for x in range(tot): for y in range(tot): for z in range(tot): - k.insert((x * mul, y * mul, z * mul), index) + yield (x * mul, y * mul, z * mul), index index += 1 + + @staticmethod + def kdtree_create_grid_3d(tot, *, filter_fn=None): + k = kdtree.KDTree(tot * tot * tot) + for co, index in KDTreeTesting.kdtree_create_grid_3d_data(tot): + if (filter_fn is not None) and (not filter_fn(co, index)): + continue + k.insert(co, index) k.balance() return k @@ -327,6 +333,49 @@ class KDTreeTesting(unittest.TestCase): ret = k.find_n((1.0,) * 3, tot) self.assertEqual(len(ret), tot) + def test_kdtree_grid_filter_simple(self): + size = 10 + k = self.kdtree_create_grid_3d(size) + + # filter exact index + ret_regular = k.find((1.0,) * 3) + ret_filter = k.find((1.0,) * 3, filter=lambda i: i == ret_regular[1]) + self.assertEqual(ret_regular, ret_filter) + ret_filter = k.find((-1.0,) * 3, filter=lambda i: i == ret_regular[1]) + self.assertEqual(ret_regular[:2], ret_filter[:2]) # ignore distance + + def test_kdtree_grid_filter_pairs(self): + size = 10 + k_all = self.kdtree_create_grid_3d(size) + k_odd = self.kdtree_create_grid_3d(size, filter_fn=lambda co, i: (i % 2) == 1) + k_evn = self.kdtree_create_grid_3d(size, filter_fn=lambda co, i: (i % 2) == 0) + + samples = 5 + mul = 1 / (samples - 1) + for x in range(samples): + for y in range(samples): + for z in range(samples): + co = (x * mul, y * mul, z * mul) + + ret_regular = k_odd.find(co) + self.assertEqual(ret_regular[1] % 2, 1) + ret_filter = k_all.find(co, lambda i: (i % 2) == 1) + self.assertEqual(ret_regular, ret_filter) + + ret_regular = k_evn.find(co) + self.assertEqual(ret_regular[1] % 2, 0) + ret_filter = k_all.find(co, lambda i: (i % 2) == 0) + self.assertEqual(ret_regular, ret_filter) + + + # filter out all values (search odd tree for even values and the reverse) + co = (0,) * 3 + ret_filter = k_odd.find(co, lambda i: (i % 2) == 0) + self.assertEqual(ret_filter[1], None) + + ret_filter = k_evn.find(co, lambda i: (i % 2) == 1) + self.assertEqual(ret_filter[1], None) + def test_kdtree_invalid_size(self): with self.assertRaises(ValueError): kdtree.KDTree(-1) @@ -342,6 +391,21 @@ class KDTreeTesting(unittest.TestCase): with self.assertRaises(RuntimeError): k.find(co) + def test_kdtree_invalid_filter(self): + k = kdtree.KDTree(1) + k.insert((0,) * 3, 0) + k.balance() + # not callable + with self.assertRaises(TypeError): + k.find((0,) * 3, filter=None) + # no args + with self.assertRaises(TypeError): + k.find((0,) * 3, filter=lambda: None) + # bad return value + with self.assertRaises(ValueError): + k.find((0,) * 3, filter=lambda i: None) + + if __name__ == '__main__': import sys sys.argv = [__file__] + (sys.argv[sys.argv.index("--") + 1:] if "--" in sys.argv else []) -- cgit v1.2.3