diff options
author | soumith <soumith@fb.com> | 2016-11-01 04:22:49 +0300 |
---|---|---|
committer | soumith <soumith@fb.com> | 2016-11-01 04:22:49 +0300 |
commit | 6f0ed600c788252ffec0b7c3b1017213342bec9a (patch) | |
tree | e3d3302741789822c1122b6536aa8f9b0e07d8d5 | |
parent | deb77aec69bd0f206c9fd40188c46e625121a06d (diff) |
implement torch.nonzero
-rw-r--r-- | TensorMath.lua | 10 | ||||
-rw-r--r-- | lib/THC/THCTensorMath.cu | 92 | ||||
-rw-r--r-- | lib/THC/generic/THCTensorMath.cu | 66 | ||||
-rw-r--r-- | lib/THC/generic/THCTensorMath.h | 2 | ||||
-rw-r--r-- | test/test.lua | 23 |
5 files changed, 191 insertions, 2 deletions
diff --git a/TensorMath.lua b/TensorMath.lua index f6803d1..6163925 100644 --- a/TensorMath.lua +++ b/TensorMath.lua @@ -886,6 +886,11 @@ for k, Tensor_ in pairs(handledTypenames) do {name=Tensor .. "Array"}, {name="index", default=lastdimarray(2)}}) + wrap("nonzero", + cname("nonzero"), + {{name="CudaLongTensor", default=true, returned=true}, + {name=Tensor}}) + if real == 'float' or real == 'double' or real == 'half' then for _,name in ipairs({"log", "log1p", "exp", "cos", "acos", "cosh", @@ -1620,6 +1625,11 @@ wrap("cat", {name=Tensor .. "Array"}, {name="index", default=lastdimarray(2)}}) +wrap("nonzero", + cname("nonzero"), + {{name="CudaLongTensor", default=true, returned=true}, + {name=Tensor}}) + for _,f in ipairs({{name='geometric'}, {name='bernoulli', a=0.5}}) do diff --git a/lib/THC/THCTensorMath.cu b/lib/THC/THCTensorMath.cu index f0bbd9c..8d3d95e 100644 --- a/lib/THC/THCTensorMath.cu +++ b/lib/THC/THCTensorMath.cu @@ -4,6 +4,18 @@ #include "THCApply.cuh" #include "THCNumerics.cuh" +#include <thrust/copy.h> +#include <thrust/count.h> +#include <thrust/device_ptr.h> +#include <thrust/device_vector.h> +#include <thrust/execution_policy.h> +#include <thrust/functional.h> +#include <thrust/sequence.h> +#include <thrust/iterator/transform_iterator.h> +#include <thrust/transform.h> +#if CUDA_VERSION >= 7000 +#include <thrust/system/cuda/execution_policy.h> +#endif #include <cfloat> template <typename T> @@ -14,5 +26,85 @@ struct TensorFillOp { const T val; }; +// copypasta from https://github.com/thrust/thrust/blob/master/examples/strided_range.cu +template <typename Iterator> +class strided_range +{ + public: + + typedef typename thrust::iterator_difference<Iterator>::type difference_type; + + struct stride_functor : public thrust::unary_function<difference_type, + difference_type> + { + difference_type stride; + + stride_functor(difference_type stride) + : stride(stride) {} + + __host__ __device__ + difference_type operator()(const difference_type& i) const + { + return stride * i; + } + }; + + typedef typename thrust::counting_iterator<difference_type> CountingIterator; + typedef typename thrust::transform_iterator<stride_functor, CountingIterator> TransformIterator; + typedef typename thrust::permutation_iterator<Iterator,TransformIterator> PermutationIterator; + + // type of the strided_range iterator + typedef PermutationIterator iterator; + + // construct strided_range for the range [first,last) + strided_range(Iterator first, Iterator last, difference_type stride) + : first(first), last(last), stride(stride) {} + + iterator begin(void) const + { + return PermutationIterator(first, + TransformIterator(CountingIterator(0), + stride_functor(stride))); + } + + iterator end(void) const + { + return begin() + ((last - first) + (stride - 1)) / stride; + } + + protected: + Iterator first; + Iterator last; + difference_type stride; +}; + +struct idx_functor +{ + long div; + long size; + + __host__ __device__ + idx_functor(long div, long size) : div(div), size(size) {} + + __host__ __device__ + long operator()(long val) { + return (val / div) % size + 1; + } +}; + +template <typename T> +struct NonZeroOp +{ + NonZeroOp() {} + __host__ __device__ bool operator()(T lhs) const { + if (THCNumerics<T>::ne(lhs, ScalarConvert<float, T>::to(0.0))) { + return true; + } else { + return false; + } + } +}; + + #include "generic/THCTensorMath.cu" #include "THCGenerateAllTypes.h" diff --git a/lib/THC/generic/THCTensorMath.cu b/lib/THC/generic/THCTensorMath.cu index 231695d..67243cf 100644 --- a/lib/THC/generic/THCTensorMath.cu +++ b/lib/THC/generic/THCTensorMath.cu @@ -136,4 +136,70 @@ void THCTensor_(catArray)(THCState *state, THCTensor *result, } } +void THCTensor_(nonzero)(THCState* state, THCudaLongTensor *tensor, + THCTensor *self) +{ + THAssert(THCTensor_(checkGPU)(state, 1, self )); + THAssert(THCudaLongTensor_checkGPU(state, 1, tensor)); + + using namespace thrust::placeholders; + + self = THCTensor_(newContiguous)(state, self); + thrust::device_ptr<real> self_data(THCTensor_(data)(state, self)); + + int num_dim = THCTensor_(nDimension)(state, self); + long N = THCTensor_(nElement)(state, self); + + THCudaLongTensor_resize2d(state, tensor, N, num_dim); + tensor = THCudaLongTensor_newContiguous(state, tensor); + thrust::device_ptr<long> tensor_data(THCudaLongTensor_data(state, tensor)); + + thrust::counting_iterator<long> idxfirst(0); + thrust::counting_iterator<long> idxlast = idxfirst + N; + + typedef thrust::device_ptr<long> Iter; + strided_range<Iter> strided_tensor(tensor_data, + tensor_data+N*num_dim, num_dim); + +#if CUDA_VERSION >= 7000 + cudaStream_t stream = THCState_getCurrentStream(state); +#endif + + strided_range<Iter>::iterator dend = thrust::copy_if( +#if CUDA_VERSION >= 7000 + thrust::cuda::par.on(stream), +#endif + idxfirst, + idxlast, + self_data, + strided_tensor.begin(), + NonZeroOp<real>() + ); + + long num_nonzeros = thrust::distance(strided_tensor.begin(), dend); + + long div = 1; + for (int dim = num_dim-1; dim >= 0; dim--) { + strided_range<Iter> stride_dim(tensor_data+dim, + tensor_data+N*num_dim, num_dim); + thrust::transform( +#if CUDA_VERSION >= 7000 + thrust::cuda::par.on(stream), +#endif + strided_tensor.begin(), + strided_tensor.end(), + stride_dim.begin(), + idx_functor(div, self->size[dim]) + ); + div *= self->size[dim]; + } + + THCudaLongTensor_resize2d(state, tensor, num_nonzeros, num_dim); + + THCTensor_(free)(state, self); + THCudaLongTensor_free(state, tensor); + + THCudaCheck(cudaGetLastError()); +} + #endif diff --git a/lib/THC/generic/THCTensorMath.h b/lib/THC/generic/THCTensorMath.h index 5f4f8ee..0335a62 100644 --- a/lib/THC/generic/THCTensorMath.h +++ b/lib/THC/generic/THCTensorMath.h @@ -11,7 +11,7 @@ THC_API void THCTensor_(reshape)(THCState *state, THCTensor *r_, THCTensor *t, T THC_API ptrdiff_t THCTensor_(numel)(THCState *state, THCTensor *t); THC_API void THCTensor_(cat)(THCState *state, THCTensor *result, THCTensor *ta, THCTensor *tb, int dimension); THC_API void THCTensor_(catArray)(THCState *state, THCTensor *result, THCTensor **inputs, int numInputs, int dimension); +THC_API void THCTensor_(nonzero)(THCState* state, THCudaLongTensor *tensor, THCTensor *self); - #endif diff --git a/test/test.lua b/test/test.lua index 058103d..b5eef6b 100644 --- a/test/test.lua +++ b/test/test.lua @@ -896,6 +896,27 @@ function test.cpow() checkMultiDevice(x, 'cpow', y) end +function test.nonzero() + local minsize = 10 + local maxsize = 20 + local dims = {chooseInt(minsize, maxsize)} + local threshold = 1 / 3 + local flip = math.random() + while flip > threshold do + dims[#dims + 1] = chooseInt(minsize, maxsize) + flip = math.random() + end + local x = createTestTensorWithSizes(true, true, dims) + local randMask = torch.ByteTensor(unpack(dims)):bernoulli() + x:maskedFill(randMask, 0) + for k, typename in ipairs(typenames) do + local ctype = t2cpu[typename] + local x = x:type(ctype) + compareCPUAndCUDATypeTensorArgs(typename, nil, x, 'nonzero') + end + checkMultiDevice(x, 'nonzero') +end + function test.cdiv() local sz1 = chooseInt(minsize, maxsize) local sz2 = chooseInt(minsize, maxsize) @@ -3261,7 +3282,7 @@ function test.cat() end function test.catArray() - for k, typename in ipairs(typenames) do + for k, typename in ipairs(typenames) do for dim = 1, 3 do local x = torch.Tensor(13, minsize, minsize):uniform() :type(typename):transpose(1, dim) |