Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/torch/cutorch.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorsoumith <soumith@fb.com>2016-11-01 04:22:49 +0300
committersoumith <soumith@fb.com>2016-11-01 04:22:49 +0300
commit6f0ed600c788252ffec0b7c3b1017213342bec9a (patch)
treee3d3302741789822c1122b6536aa8f9b0e07d8d5
parentdeb77aec69bd0f206c9fd40188c46e625121a06d (diff)
implement torch.nonzero
-rw-r--r--TensorMath.lua10
-rw-r--r--lib/THC/THCTensorMath.cu92
-rw-r--r--lib/THC/generic/THCTensorMath.cu66
-rw-r--r--lib/THC/generic/THCTensorMath.h2
-rw-r--r--test/test.lua23
5 files changed, 191 insertions, 2 deletions
diff --git a/TensorMath.lua b/TensorMath.lua
index f6803d1..6163925 100644
--- a/TensorMath.lua
+++ b/TensorMath.lua
@@ -886,6 +886,11 @@ for k, Tensor_ in pairs(handledTypenames) do
{name=Tensor .. "Array"},
{name="index", default=lastdimarray(2)}})
+ wrap("nonzero",
+ cname("nonzero"),
+ {{name="CudaLongTensor", default=true, returned=true},
+ {name=Tensor}})
+
if real == 'float' or real == 'double' or real == 'half' then
for _,name in ipairs({"log", "log1p", "exp",
"cos", "acos", "cosh",
@@ -1620,6 +1625,11 @@ wrap("cat",
{name=Tensor .. "Array"},
{name="index", default=lastdimarray(2)}})
+wrap("nonzero",
+ cname("nonzero"),
+ {{name="CudaLongTensor", default=true, returned=true},
+ {name=Tensor}})
+
for _,f in ipairs({{name='geometric'},
{name='bernoulli', a=0.5}}) do
diff --git a/lib/THC/THCTensorMath.cu b/lib/THC/THCTensorMath.cu
index f0bbd9c..8d3d95e 100644
--- a/lib/THC/THCTensorMath.cu
+++ b/lib/THC/THCTensorMath.cu
@@ -4,6 +4,18 @@
#include "THCApply.cuh"
#include "THCNumerics.cuh"
+#include <thrust/copy.h>
+#include <thrust/count.h>
+#include <thrust/device_ptr.h>
+#include <thrust/device_vector.h>
+#include <thrust/execution_policy.h>
+#include <thrust/functional.h>
+#include <thrust/sequence.h>
+#include <thrust/iterator/transform_iterator.h>
+#include <thrust/transform.h>
+#if CUDA_VERSION >= 7000
+#include <thrust/system/cuda/execution_policy.h>
+#endif
#include <cfloat>
template <typename T>
@@ -14,5 +26,85 @@ struct TensorFillOp {
const T val;
};
+// copypasta from https://github.com/thrust/thrust/blob/master/examples/strided_range.cu
+template <typename Iterator>
+class strided_range
+{
+ public:
+
+ typedef typename thrust::iterator_difference<Iterator>::type difference_type;
+
+ struct stride_functor : public thrust::unary_function<difference_type,
+ difference_type>
+ {
+ difference_type stride;
+
+ stride_functor(difference_type stride)
+ : stride(stride) {}
+
+ __host__ __device__
+ difference_type operator()(const difference_type& i) const
+ {
+ return stride * i;
+ }
+ };
+
+ typedef typename thrust::counting_iterator<difference_type> CountingIterator;
+ typedef typename thrust::transform_iterator<stride_functor, CountingIterator> TransformIterator;
+ typedef typename thrust::permutation_iterator<Iterator,TransformIterator> PermutationIterator;
+
+ // type of the strided_range iterator
+ typedef PermutationIterator iterator;
+
+ // construct strided_range for the range [first,last)
+ strided_range(Iterator first, Iterator last, difference_type stride)
+ : first(first), last(last), stride(stride) {}
+
+ iterator begin(void) const
+ {
+ return PermutationIterator(first,
+ TransformIterator(CountingIterator(0),
+ stride_functor(stride)));
+ }
+
+ iterator end(void) const
+ {
+ return begin() + ((last - first) + (stride - 1)) / stride;
+ }
+
+ protected:
+ Iterator first;
+ Iterator last;
+ difference_type stride;
+};
+
+struct idx_functor
+{
+ long div;
+ long size;
+
+ __host__ __device__
+ idx_functor(long div, long size) : div(div), size(size) {}
+
+ __host__ __device__
+ long operator()(long val) {
+ return (val / div) % size + 1;
+ }
+};
+
+template <typename T>
+struct NonZeroOp
+{
+ NonZeroOp() {}
+ __host__ __device__ bool operator()(T lhs) const {
+ if (THCNumerics<T>::ne(lhs, ScalarConvert<float, T>::to(0.0))) {
+ return true;
+ } else {
+ return false;
+ }
+ }
+};
+
+
#include "generic/THCTensorMath.cu"
#include "THCGenerateAllTypes.h"
diff --git a/lib/THC/generic/THCTensorMath.cu b/lib/THC/generic/THCTensorMath.cu
index 231695d..67243cf 100644
--- a/lib/THC/generic/THCTensorMath.cu
+++ b/lib/THC/generic/THCTensorMath.cu
@@ -136,4 +136,70 @@ void THCTensor_(catArray)(THCState *state, THCTensor *result,
}
}
+void THCTensor_(nonzero)(THCState* state, THCudaLongTensor *tensor,
+ THCTensor *self)
+{
+ THAssert(THCTensor_(checkGPU)(state, 1, self ));
+ THAssert(THCudaLongTensor_checkGPU(state, 1, tensor));
+
+ using namespace thrust::placeholders;
+
+ self = THCTensor_(newContiguous)(state, self);
+ thrust::device_ptr<real> self_data(THCTensor_(data)(state, self));
+
+ int num_dim = THCTensor_(nDimension)(state, self);
+ long N = THCTensor_(nElement)(state, self);
+
+ THCudaLongTensor_resize2d(state, tensor, N, num_dim);
+ tensor = THCudaLongTensor_newContiguous(state, tensor);
+ thrust::device_ptr<long> tensor_data(THCudaLongTensor_data(state, tensor));
+
+ thrust::counting_iterator<long> idxfirst(0);
+ thrust::counting_iterator<long> idxlast = idxfirst + N;
+
+ typedef thrust::device_ptr<long> Iter;
+ strided_range<Iter> strided_tensor(tensor_data,
+ tensor_data+N*num_dim, num_dim);
+
+#if CUDA_VERSION >= 7000
+ cudaStream_t stream = THCState_getCurrentStream(state);
+#endif
+
+ strided_range<Iter>::iterator dend = thrust::copy_if(
+#if CUDA_VERSION >= 7000
+ thrust::cuda::par.on(stream),
+#endif
+ idxfirst,
+ idxlast,
+ self_data,
+ strided_tensor.begin(),
+ NonZeroOp<real>()
+ );
+
+ long num_nonzeros = thrust::distance(strided_tensor.begin(), dend);
+
+ long div = 1;
+ for (int dim = num_dim-1; dim >= 0; dim--) {
+ strided_range<Iter> stride_dim(tensor_data+dim,
+ tensor_data+N*num_dim, num_dim);
+ thrust::transform(
+#if CUDA_VERSION >= 7000
+ thrust::cuda::par.on(stream),
+#endif
+ strided_tensor.begin(),
+ strided_tensor.end(),
+ stride_dim.begin(),
+ idx_functor(div, self->size[dim])
+ );
+ div *= self->size[dim];
+ }
+
+ THCudaLongTensor_resize2d(state, tensor, num_nonzeros, num_dim);
+
+ THCTensor_(free)(state, self);
+ THCudaLongTensor_free(state, tensor);
+
+ THCudaCheck(cudaGetLastError());
+}
+
#endif
diff --git a/lib/THC/generic/THCTensorMath.h b/lib/THC/generic/THCTensorMath.h
index 5f4f8ee..0335a62 100644
--- a/lib/THC/generic/THCTensorMath.h
+++ b/lib/THC/generic/THCTensorMath.h
@@ -11,7 +11,7 @@ THC_API void THCTensor_(reshape)(THCState *state, THCTensor *r_, THCTensor *t, T
THC_API ptrdiff_t THCTensor_(numel)(THCState *state, THCTensor *t);
THC_API void THCTensor_(cat)(THCState *state, THCTensor *result, THCTensor *ta, THCTensor *tb, int dimension);
THC_API void THCTensor_(catArray)(THCState *state, THCTensor *result, THCTensor **inputs, int numInputs, int dimension);
+THC_API void THCTensor_(nonzero)(THCState* state, THCudaLongTensor *tensor, THCTensor *self);
-
#endif
diff --git a/test/test.lua b/test/test.lua
index 058103d..b5eef6b 100644
--- a/test/test.lua
+++ b/test/test.lua
@@ -896,6 +896,27 @@ function test.cpow()
checkMultiDevice(x, 'cpow', y)
end
+function test.nonzero()
+ local minsize = 10
+ local maxsize = 20
+ local dims = {chooseInt(minsize, maxsize)}
+ local threshold = 1 / 3
+ local flip = math.random()
+ while flip > threshold do
+ dims[#dims + 1] = chooseInt(minsize, maxsize)
+ flip = math.random()
+ end
+ local x = createTestTensorWithSizes(true, true, dims)
+ local randMask = torch.ByteTensor(unpack(dims)):bernoulli()
+ x:maskedFill(randMask, 0)
+ for k, typename in ipairs(typenames) do
+ local ctype = t2cpu[typename]
+ local x = x:type(ctype)
+ compareCPUAndCUDATypeTensorArgs(typename, nil, x, 'nonzero')
+ end
+ checkMultiDevice(x, 'nonzero')
+end
+
function test.cdiv()
local sz1 = chooseInt(minsize, maxsize)
local sz2 = chooseInt(minsize, maxsize)
@@ -3261,7 +3282,7 @@ function test.cat()
end
function test.catArray()
- for k, typename in ipairs(typenames) do
+ for k, typename in ipairs(typenames) do
for dim = 1, 3 do
local x = torch.Tensor(13, minsize, minsize):uniform()
:type(typename):transpose(1, dim)