Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/torch/cutorch.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
Diffstat (limited to 'lib/THC/generic')
-rw-r--r--lib/THC/generic/THCTensorMasked.cu191
-rw-r--r--lib/THC/generic/THCTensorMasked.h38
-rw-r--r--lib/THC/generic/THCTensorMathBlas.cu563
-rw-r--r--lib/THC/generic/THCTensorMathBlas.h13
-rw-r--r--lib/THC/generic/THCTensorMathCompare.cu101
-rw-r--r--lib/THC/generic/THCTensorMathCompare.h20
-rw-r--r--lib/THC/generic/THCTensorMathCompareT.cu113
-rw-r--r--lib/THC/generic/THCTensorMathCompareT.h19
-rw-r--r--lib/THC/generic/THCTensorMathReduce.cu135
-rw-r--r--lib/THC/generic/THCTensorMathReduce.h23
10 files changed, 1216 insertions, 0 deletions
diff --git a/lib/THC/generic/THCTensorMasked.cu b/lib/THC/generic/THCTensorMasked.cu
new file mode 100644
index 0000000..e6a5704
--- /dev/null
+++ b/lib/THC/generic/THCTensorMasked.cu
@@ -0,0 +1,191 @@
+#ifndef THC_GENERIC_FILE
+#define THC_GENERIC_FILE "generic/THCTensorMasked.cu"
+#else
+
+
+// In-place fill: sets tensor[i] = value wherever mask[i] == 1; other
+// elements are left untouched.  `mask` and `tensor` must contain the
+// same number of elements (checked below); both must live on the GPU.
+THC_API void
+THCTensor_(maskedFill)(THCState* state,
+ THCTensor *tensor, THCudaByteTensor *mask, real value)
+{
+ THAssert(THCTensor_(checkGPU)(state, 2, tensor, mask));
+ THArgCheck(THCTensor_(nElement)(state, tensor) ==
+ THCudaByteTensor_nElement(state, mask),
+ 2, "sizes do not match");
+
+ // Fused elementwise kernel over (tensor, mask); the apply only fails
+ // when the tensors exceed the supported dimensionality.
+ if (!THC_pointwiseApply2(state, tensor, mask,
+ TensorMaskedFillOp<real, unsigned char>(value))) {
+ THArgCheck(false, 2, CUTORCH_DIM_WARNING);
+ }
+
+ THCudaCheck(cudaGetLastError());
+}
+
+// CPU-mask convenience wrapper around maskedFill: uploads the host-side
+// THByteTensor mask to a temporary GPU byte tensor of the same size,
+// delegates, then frees the temporary.
+THC_API void
+THCTensor_(maskedFillByte)(THCState* state,
+ THCTensor *tensor, THByteTensor *mask, real value)
+{
+ THAssert(THCTensor_(checkGPU)(state, 1, tensor));
+ THLongStorage* maskSizes = THByteTensor_newSizeOf(mask);
+ THCudaByteTensor* maskCuda = THCudaByteTensor_newWithSize(state, maskSizes, NULL);
+ THLongStorage_free(maskSizes);
+ // Host -> device copy of the mask contents.
+ THCudaByteTensor_copyByte(state, maskCuda, mask);
+ THCTensor_(maskedFill)(state, tensor, maskCuda, value);
+ THCudaByteTensor_free(state, maskCuda);
+}
+
+// Scatters elements of `src` into `tensor` at the positions where
+// `mask` == 1: the k-th `1` in `mask` (flattened order) receives the
+// k-th element of `src`.  `mask` and `tensor` must have equal element
+// counts; `src` must supply at least as many elements as there are
+// `1`s in `mask` (extra `src` elements are ignored).
+THC_API void
+THCTensor_(maskedCopy)(THCState* state,
+ THCTensor *tensor, THCudaByteTensor *mask, THCTensor *src)
+{
+ THAssert(THCTensor_(checkGPU)(state, 3, tensor, src, mask));
+ long maskSize = THCudaByteTensor_nElement(state, mask);
+ long tensorSize = THCTensor_(nElement)(state, tensor);
+ long srcSize = THCTensor_(nElement)(state, src);
+
+ // `mask` and `tensor` must have the same number of elements
+ THArgCheck(maskSize == tensorSize, 2,
+ "mask and tensor must have the same number of elements");
+
+ // Determine our output size
+ long totalElements = THCudaByteTensor_sumall(state, mask);
+
+ // The number of `1` elements present in the mask must be <= the
+ // number of elements available in `src`
+ if (totalElements > srcSize) {
+ // NOTE: message matches the check above — src only needs *at least*
+ // as many elements as there are 1s in the mask, not exactly as many.
+ THArgCheck(false, 2, "source nElements must be >= mask `1` elements");
+ }
+
+ // FIXME: there appears to be a bug in Thrust (CUDA 7.0) for mixed
+ // iterator prefix sums? Convert `mask` to the same datatype as what
+ // we're accumulating the prefix sum in (long) to get around it
+ THCudaLongTensor* maskLong = THCudaLongTensor_new(state);
+ THLongStorage* maskSizes = THCudaByteTensor_newSizeOf(state, mask);
+ THCudaLongTensor_resize(state, maskLong, maskSizes, NULL);
+ THCudaLongTensor_copyCudaByte(state, maskLong, mask);
+
+ // Use a prefix sum to determine the output locations of the masked elements
+ THCudaLongTensor* maskPrefixSum = THCudaLongTensor_new(state);
+ THCudaLongTensor_resize(state, maskPrefixSum, maskSizes, NULL);
+ THLongStorage_free(maskSizes);
+
+ thrust::device_ptr<long>
+ maskData(THCudaLongTensor_data(state, maskLong));
+ thrust::device_ptr<long>
+ maskPrefixSumData(THCudaLongTensor_data(state, maskPrefixSum));
+
+ // Exclusive scan: maskPrefixSum[i] = number of 1s strictly before i,
+ // i.e. the read offset into `src` for position i.  Run on the current
+ // cutorch stream where the execution-policy overload is available.
+ thrust::exclusive_scan(
+#if CUDA_VERSION >= 7000
+ thrust::cuda::par.on(THCState_getCurrentStream(state)),
+#endif
+ maskData,
+ maskData + THCudaLongTensor_nElement(state, maskLong),
+ maskPrefixSumData);
+
+ // We are getting elements from `src` based on an offset from
+ // `maskPrefixSum`, so that should be made contiguous too
+ THCTensor* contigSrc = THCTensor_(newContiguous)(state, src);
+
+ // update `tensor` where `mask` == 1 but pull from `src` at
+ // maskPrefixSum
+ bool status = THC_pointwiseApply3(
+ state, tensor, mask, maskPrefixSum,
+ TensorMaskedCopyOp<real, unsigned char, long>(
+ THCTensor_(data)(state, contigSrc)));
+
+ THCTensor_(free)(state, contigSrc);
+ THCudaLongTensor_free(state, maskLong);
+ THCudaLongTensor_free(state, maskPrefixSum);
+
+ THArgCheck(status, 2, CUTORCH_DIM_WARNING);
+ THCudaCheck(cudaGetLastError());
+}
+
+// CPU-mask convenience wrapper around maskedCopy: uploads the host-side
+// THByteTensor mask to a temporary GPU byte tensor, delegates, frees it.
+THC_API void
+THCTensor_(maskedCopyByte)(THCState* state,
+ THCTensor *tensor, THByteTensor *mask, THCTensor *src) {
+ THAssert(THCTensor_(checkGPU)(state, 2, tensor, src));
+ THLongStorage* maskSizes = THByteTensor_newSizeOf(mask);
+ THCudaByteTensor* maskCuda = THCudaByteTensor_newWithSize(state, maskSizes, NULL);
+ THLongStorage_free(maskSizes);
+ // Host -> device copy of the mask contents.
+ THCudaByteTensor_copyByte(state, maskCuda, mask);
+ THCTensor_(maskedCopy)(state, tensor, maskCuda, src);
+ THCudaByteTensor_free(state, maskCuda);
+}
+
+// Gathers the elements of `src` where `mask` == 1 into `tensor`, which
+// is resized to a 1-D tensor holding exactly that many elements, in
+// flattened `src` order.  `mask` and `src` must have equal element
+// counts.
+THC_API void
+THCTensor_(maskedSelect)(THCState* state,
+ THCTensor* tensor, THCTensor* src, THCudaByteTensor* mask) {
+ THAssert(THCTensor_(checkGPU)(state, 3, tensor, src, mask));
+ THArgCheck(THCudaByteTensor_nElement(state, mask) ==
+ THCTensor_(nElement)(state, src),
+ 2, "sizes do not match");
+
+ // Determine our output size
+ long totalElements = THCudaByteTensor_sumall(state, mask);
+ // When `tensor` is already contiguous this aliases `tensor` itself;
+ // otherwise it is a separate contiguous buffer we write into and copy
+ // back at the end.
+ THCTensor* tensorContig = THCTensor_(newContiguous)(state, tensor);
+
+ THCTensor_(resize1d)(state, tensorContig, totalElements);
+ if (tensor != tensorContig) {
+ THCTensor_(resize1d)(state, tensor, totalElements);
+ }
+
+ // FIXME: there appears to be a bug in Thrust (CUDA 7.0) for mixed
+ // iterator prefix sums? Convert `mask` to the same datatype as what
+ // we're accumulating the prefix sum in (long) to get around it
+ THCudaLongTensor* maskLong = THCudaLongTensor_new(state);
+ THLongStorage* maskSizes = THCudaByteTensor_newSizeOf(state, mask);
+ THCudaLongTensor_resize(state, maskLong, maskSizes, NULL);
+ THCudaLongTensor_copyCudaByte(state, maskLong, mask);
+
+ // Use a prefix sum to determine the output locations of the masked elements
+ THCudaLongTensor* maskPrefixSum = THCudaLongTensor_new(state);
+ THCudaLongTensor_resize(state, maskPrefixSum, maskSizes, NULL);
+ THLongStorage_free(maskSizes);
+
+ thrust::device_ptr<long>
+ maskData(THCudaLongTensor_data(state, maskLong));
+ thrust::device_ptr<long>
+ maskPrefixSumData(THCudaLongTensor_data(state, maskPrefixSum));
+
+ // maskPrefixSum[i] = output index for element i when mask[i] == 1.
+ thrust::exclusive_scan(
+#if CUDA_VERSION >= 7000
+ thrust::cuda::par.on(THCState_getCurrentStream(state)),
+#endif
+ maskData,
+ maskData + THCudaLongTensor_nElement(state, maskLong),
+ maskPrefixSumData);
+
+ // Then copy over the masked elements at their desired output index.
+ // BUGFIX: the results must be written into `tensorContig`, not
+ // `tensor` — when they differ, the freeCopyTo() below copies
+ // `tensorContig` over `tensor`, which would otherwise clobber the
+ // output with the uninitialized contiguous buffer.
+ bool status = THC_pointwiseApply3(
+ state, mask, maskPrefixSum,
+ src, TensorMaskedSelectOp<real, unsigned char, long>(
+ THCTensor_(data)(state, tensorContig)));
+
+ THCudaLongTensor_free(state, maskLong);
+ THCudaLongTensor_free(state, maskPrefixSum);
+
+ if (tensor != tensorContig) {
+ THCTensor_(freeCopyTo)(state, tensorContig, tensor);
+ } else {
+ THCTensor_(free)(state, tensorContig);
+ }
+
+ THArgCheck(status, 2, CUTORCH_DIM_WARNING);
+ THCudaCheck(cudaGetLastError());
+}
+
+// FIXME: remove now that we have THCudaByteTensor?
+// CPU-mask convenience wrapper around maskedSelect: uploads the
+// host-side THByteTensor mask to a temporary GPU byte tensor,
+// delegates, then frees the temporary.
+THC_API void
+THCTensor_(maskedSelectByte)(THCState* state,
+ THCTensor *tensor, THCTensor *src, THByteTensor *mask)
+{
+ THAssert(THCTensor_(checkGPU)(state, 2, tensor, src));
+ THLongStorage* maskSizes = THByteTensor_newSizeOf(mask);
+ THCudaByteTensor* maskCuda = THCudaByteTensor_newWithSize(state, maskSizes, NULL);
+ THLongStorage_free(maskSizes);
+ // Host -> device copy of the mask contents.
+ THCudaByteTensor_copyByte(state, maskCuda, mask);
+ THCTensor_(maskedSelect)(state, tensor, src, maskCuda);
+ THCudaByteTensor_free(state, maskCuda);
+}
+
+#endif
diff --git a/lib/THC/generic/THCTensorMasked.h b/lib/THC/generic/THCTensorMasked.h
new file mode 100644
index 0000000..98f5aee
--- /dev/null
+++ b/lib/THC/generic/THCTensorMasked.h
@@ -0,0 +1,38 @@
+#ifndef THC_GENERIC_FILE
+#define THC_GENERIC_FILE "generic/THCTensorMasked.h"
+#else
+
+// tensor[i] = value where mask[i] == 1 (GPU byte mask).
+THC_API void THCTensor_(maskedFill)(THCState *state,
+ THCTensor *tensor,
+ THCudaByteTensor *mask,
+ real value);
+
+// FIXME: remove now that we have THCudaByteTensor?
+// As maskedFill, but the mask is a CPU THByteTensor (uploaded internally).
+THC_API void THCTensor_(maskedFillByte)(THCState *state,
+ THCTensor *tensor,
+ THByteTensor *mask,
+ real value);
+
+// Scatter src elements into tensor at positions where mask == 1.
+THC_API void THCTensor_(maskedCopy)(THCState *state,
+ THCTensor *tensor,
+ THCudaByteTensor *mask,
+ THCTensor *src);
+
+// FIXME: remove now that we have THCudaByteTensor?
+// As maskedCopy, but the mask is a CPU THByteTensor (uploaded internally).
+THC_API void THCTensor_(maskedCopyByte)(THCState *state,
+ THCTensor *tensor,
+ THByteTensor *mask,
+ THCTensor *src);
+
+// Gather src elements where mask == 1 into tensor (resized to 1-D).
+THC_API void THCTensor_(maskedSelect)(THCState *state,
+ THCTensor *tensor,
+ THCTensor *src,
+ THCudaByteTensor *mask);
+
+// FIXME: remove now that we have THCudaByteTensor?
+// As maskedSelect, but the mask is a CPU THByteTensor (uploaded internally).
+THC_API void THCTensor_(maskedSelectByte)(THCState *state,
+ THCTensor *tensor,
+ THCTensor *src,
+ THByteTensor *mask);
+
+#endif
diff --git a/lib/THC/generic/THCTensorMathBlas.cu b/lib/THC/generic/THCTensorMathBlas.cu
new file mode 100644
index 0000000..91e922c
--- /dev/null
+++ b/lib/THC/generic/THCTensorMathBlas.cu
@@ -0,0 +1,563 @@
+#ifndef THC_GENERIC_FILE
+#define THC_GENERIC_FILE "generic/THCTensorMathBlas.cu"
+#else
+
+// Dot product of `self` and `src` (flattened); element counts must
+// match.  Only implemented for float/double — other instantiations of
+// this generic file raise "unimplemented data type".
+THC_API real
+THCTensor_(dot)(THCState *state, THCTensor *self, THCTensor *src)
+{
+#if defined(THC_REAL_IS_FLOAT) || defined(THC_REAL_IS_DOUBLE)
+ THAssert(THCTensor_(checkGPU)(state, 2, self, src));
+ THArgCheck(THCTensor_(nElement)(state, self) ==
+ THCTensor_(nElement)(state, src), 2, "sizes do not match");
+
+ // cuBLAS dot requires dense inputs; these are no-ops (retain only)
+ // when the tensors are already contiguous.
+ self = THCTensor_(newContiguous)(state, self);
+ src = THCTensor_(newContiguous)(state, src);
+
+#ifdef THC_REAL_IS_FLOAT
+ real result = THCudaBlas_Sdot(state,
+ THCTensor_(nElement)(state, self),
+ THCTensor_(data)(state, self), 1,
+ THCTensor_(data)(state, src), 1);
+#elif defined(THC_REAL_IS_DOUBLE)
+ real result = THCudaBlas_Ddot(state,
+ THCTensor_(nElement)(state, self),
+ THCTensor_(data)(state, self), 1,
+ THCTensor_(data)(state, src), 1);
+#endif
+
+ // Release the references taken by newContiguous above.
+ THCTensor_(free)(state, src);
+ THCTensor_(free)(state, self);
+ return result;
+
+#else
+ THError("unimplemented data type");
+ return ScalarConvert<int, real>::to(0);
+#endif
+}
+
+// r_ = beta * t + alpha * (mat @ vec).  mat must be 2-D, vec and t 1-D
+// with matching sizes.  Dispatches to cuBLAS gemv, picking 'n'/'t' so
+// that the column-major BLAS view matches mat's stride layout; falls
+// back to a contiguous copy of mat when neither stride is 1.
+// Float/double only.
+THC_API void
+THCTensor_(addmv)(THCState *state, THCTensor *r_, real beta, THCTensor *t, real alpha, THCTensor *mat, THCTensor *vec)
+{
+#if defined(THC_REAL_IS_FLOAT) || defined(THC_REAL_IS_DOUBLE)
+ THAssert(THCTensor_(checkGPU)(state, 4, r_, t, mat, vec));
+ if( (mat->nDimension != 2) || (vec->nDimension != 1) )
+ THError("matrix and vector expected");
+
+ if( mat->size[1] != vec->size[0] )
+ THError("size mismatch");
+
+ if(t->nDimension != 1)
+ THError("size mismatch");
+
+ if(t->size[0] != mat->size[0])
+ THError("size mismatch");
+
+ // Accumulate into r_; seed it with t when they are distinct tensors.
+ if(r_ != t)
+ {
+ THCTensor_(resizeAs)(state, r_, t);
+ THCTensor_(copy)(state, r_, t);
+ }
+
+ // Column-major layout (stride[0] == 1): use gemv 'n' with
+ // lda = stride[1].
+ if(mat->stride[0] == 1)
+ {
+#ifdef THC_REAL_IS_FLOAT
+ THCudaBlas_Sgemv(state, 'n', mat->size[0], mat->size[1],
+ alpha, THCTensor_(data)(state, mat), mat->stride[1],
+ THCTensor_(data)(state, vec), vec->stride[0],
+ beta, THCTensor_(data)(state, r_), r_->stride[0]);
+#elif defined(THC_REAL_IS_DOUBLE)
+ THCudaBlas_Dgemv(state, 'n', mat->size[0], mat->size[1],
+ alpha, THCTensor_(data)(state, mat), mat->stride[1],
+ THCTensor_(data)(state, vec), vec->stride[0],
+ beta, THCTensor_(data)(state, r_), r_->stride[0]);
+#endif
+ }
+ // Row-major layout (stride[1] == 1): present the transpose to BLAS
+ // ('t') with swapped dimensions and lda = stride[0].
+ else if(mat->stride[1] == 1)
+ {
+#ifdef THC_REAL_IS_FLOAT
+ THCudaBlas_Sgemv(state, 't', mat->size[1], mat->size[0],
+ alpha, THCTensor_(data)(state, mat), mat->stride[0],
+ THCTensor_(data)(state, vec), vec->stride[0],
+ beta, THCTensor_(data)(state, r_), r_->stride[0]);
+#elif defined(THC_REAL_IS_DOUBLE)
+ THCudaBlas_Dgemv(state, 't', mat->size[1], mat->size[0],
+ alpha, THCTensor_(data)(state, mat), mat->stride[0],
+ THCTensor_(data)(state, vec), vec->stride[0],
+ beta, THCTensor_(data)(state, r_), r_->stride[0]);
+#endif
+ }
+ // Neither stride is 1: materialize a contiguous (row-major) copy and
+ // use the transposed path as above.
+ else
+ {
+ THCTensor *cmat = THCTensor_(newContiguous)(state, mat);
+
+#ifdef THC_REAL_IS_FLOAT
+ THCudaBlas_Sgemv(state, 't', mat->size[1], mat->size[0],
+ alpha, THCTensor_(data)(state, cmat), cmat->stride[0],
+ THCTensor_(data)(state, vec), vec->stride[0],
+ beta, THCTensor_(data)(state, r_), r_->stride[0]);
+#elif defined(THC_REAL_IS_DOUBLE)
+ THCudaBlas_Dgemv(state, 't', mat->size[1], mat->size[0],
+ alpha, THCTensor_(data)(state, cmat), cmat->stride[0],
+ THCTensor_(data)(state, vec), vec->stride[0],
+ beta, THCTensor_(data)(state, r_), r_->stride[0]);
+#endif
+
+ THCTensor_(free)(state, cmat);
+ }
+
+#else
+ THError("unimplemented data type");
+#endif
+}
+
+// Outer-product update: r_ = beta * t + alpha * (vec1 (x) vec2), with
+// vec1, vec2 1-D and t 2-D of size [vec1->size[0], vec2->size[0]].
+// The beta scaling is applied up front with a mul, then cuBLAS ger
+// accumulates alpha * vec1 * vec2^T; operand order is chosen per
+// r_'s stride layout.  Float/double only.
+THC_API void
+THCTensor_(addr)(THCState *state, THCTensor *r_, real beta, THCTensor *t, real alpha, THCTensor *vec1, THCTensor *vec2)
+{
+#if defined(THC_REAL_IS_FLOAT) || defined(THC_REAL_IS_DOUBLE)
+ THAssert(THCTensor_(checkGPU)(state, 4, r_, t, vec1, vec2));
+ if ( (vec1->nDimension != 1) || (vec2->nDimension != 1) ) {
+ THError("vector and vector expected");
+ }
+
+ if (t->nDimension != 2) {
+ THError("size mismatch");
+ }
+
+ if ( (t->size[0] != vec1->size[0]) || (t->size[1] != vec2->size[0]) ) {
+ THError("size mismatch");
+ }
+
+ if (r_ != t) {
+ THCTensor_(resizeAs)(state, r_, t);
+ THCTensor_(copy)(state, r_, t);
+ }
+
+ // ger only accumulates (A += alpha*x*y^T), so beta is applied here.
+ if(beta != 1) {
+ THCTensor_(mul)(state, r_, r_, beta);
+ }
+
+ // Column-major r_: ger(vec1, vec2) directly, lda = stride[1].
+ if(r_->stride[0] == 1)
+ {
+#ifdef THC_REAL_IS_FLOAT
+ THCudaBlas_Sger(state, vec1->size[0], vec2->size[0],
+ alpha, THCTensor_(data)(state, vec1), vec1->stride[0],
+ THCTensor_(data)(state, vec2), vec2->stride[0],
+ THCTensor_(data)(state, r_), r_->stride[1]);
+#elif defined(THC_REAL_IS_DOUBLE)
+ THCudaBlas_Dger(state, vec1->size[0], vec2->size[0],
+ alpha, THCTensor_(data)(state, vec1), vec1->stride[0],
+ THCTensor_(data)(state, vec2), vec2->stride[0],
+ THCTensor_(data)(state, r_), r_->stride[1]);
+#endif
+ }
+ // Row-major r_: compute the transposed outer product (vec2, vec1)
+ // into the column-major view, lda = stride[0].
+ else if(r_->stride[1] == 1)
+ {
+#ifdef THC_REAL_IS_FLOAT
+ THCudaBlas_Sger(state, vec2->size[0], vec1->size[0],
+ alpha, THCTensor_(data)(state, vec2), vec2->stride[0],
+ THCTensor_(data)(state, vec1), vec1->stride[0],
+ THCTensor_(data)(state, r_), r_->stride[0]);
+#elif defined(THC_REAL_IS_DOUBLE)
+ THCudaBlas_Dger(state, vec2->size[0], vec1->size[0],
+ alpha, THCTensor_(data)(state, vec2), vec2->stride[0],
+ THCTensor_(data)(state, vec1), vec1->stride[0],
+ THCTensor_(data)(state, r_), r_->stride[0]);
+#endif
+ }
+ // Neither stride is 1: work on a contiguous clone, then copy back.
+ else
+ {
+ THCTensor *cr = THCTensor_(newClone)(state, r_);
+
+#ifdef THC_REAL_IS_FLOAT
+ THCudaBlas_Sger(state, vec2->size[0], vec1->size[0],
+ alpha, THCTensor_(data)(state, vec2), vec2->stride[0],
+ THCTensor_(data)(state, vec1), vec1->stride[0],
+ THCTensor_(data)(state, cr), cr->stride[0]);
+#elif defined(THC_REAL_IS_DOUBLE)
+ THCudaBlas_Dger(state, vec2->size[0], vec1->size[0],
+ alpha, THCTensor_(data)(state, vec2), vec2->stride[0],
+ THCTensor_(data)(state, vec1), vec1->stride[0],
+ THCTensor_(data)(state, cr), cr->stride[0]);
+#endif
+
+ // Copy the clone's result back into r_ and release the clone.
+ THCTensor_(freeCopyTo)(state, cr, r_);
+ }
+#else
+ THError("unimplemented data type");
+#endif
+}
+
+// r_ = beta * t + alpha * (m1 @ m2) via cuBLAS gemm.  All of t, m1,
+// m2 must be 2-D with compatible sizes.  The stride inspection below
+// maps each operand onto cuBLAS's column-major convention: transpose
+// flags are chosen per operand, and when r_ itself is row-major the
+// whole product is computed transposed with m1/m2 swapped.  Operands
+// with no unit stride are copied to contiguous temporaries.
+// Half/float/double only.
+THC_API void
+THCTensor_(addmm)(THCState *state, THCTensor *r_, real beta, THCTensor *t, real alpha, THCTensor *m1, THCTensor *m2)
+{
+#if defined(THC_REAL_IS_HALF) || defined(THC_REAL_IS_FLOAT) || defined(THC_REAL_IS_DOUBLE)
+
+ THAssert(THCTensor_(checkGPU)(state, 4, r_, t, m1, m2));
+ char transpose_r, transpose_m1, transpose_m2;
+ THCTensor *r__, *m1_, *m2_;
+
+ if( (m1->nDimension != 2) || (m2->nDimension != 2) )
+ THError("matrix and matrix expected");
+
+ if(t->nDimension != 2)
+ THError("size mismatch");
+
+ if( (t->size[0] != m1->size[0]) || (t->size[1] != m2->size[1]) || (m1->size[1] != m2->size[0]) )
+ THError("size mismatch");
+
+ if(t != r_)
+ {
+ THCTensor_(resizeAs)(state, r_, t);
+ THCTensor_(copy)(state, r_, t);
+ }
+
+ /* r_ */
+ // Column-major output: gemm can write it directly.
+ if(r_->stride[0] == 1 &&
+ r_->stride[1] != 0)
+ {
+ transpose_r = 'n';
+ r__ = r_;
+ }
+ // Row-major output: compute (m2^T @ m1^T)^T by swapping the operands
+ // and flagging the result as transposed.
+ else if(r_->stride[1] == 1 &&
+ r_->stride[0] != 0)
+ {
+ THCTensor *swap = m2;
+ m2 = m1;
+ m1 = swap;
+ transpose_r = 't';
+ r__ = r_;
+ }
+ // No unit stride: gemm into a contiguous temporary (built so its
+ // layout is column-major for gemm); copied back to r_ at the end.
+ else
+ {
+ transpose_r = 'n';
+
+ THCTensor *transp_r_ = THCTensor_(newTranspose)(state, r_, 0, 1);
+ r__ = THCTensor_(newClone)(state, transp_r_);
+ THCTensor_(free)(state, transp_r_);
+ THCTensor_(transpose)(state, r__, NULL, 0, 1);
+ }
+
+ /* m1 */
+ // Pick 'n'/'t' so the column-major view of m1 matches its strides;
+ // indices flip when the result itself is transposed.
+ if(m1->stride[(transpose_r == 'n' ? 0 : 1)] == 1 &&
+ m1->stride[(transpose_r == 'n' ? 1 : 0)] != 0)
+ {
+ transpose_m1 = 'n';
+ m1_ = m1;
+ }
+ else if(m1->stride[(transpose_r == 'n' ? 1 : 0)] == 1 &&
+ m1->stride[(transpose_r == 'n' ? 0 : 1)] != 0)
+ {
+ transpose_m1 = 't';
+ m1_ = m1;
+ }
+ else
+ {
+ transpose_m1 = (transpose_r == 'n' ? 't' : 'n');
+ m1_ = THCTensor_(newContiguous)(state, m1);
+ }
+
+ /* m2 */
+ // Same stride analysis for the second operand.
+ if(m2->stride[(transpose_r == 'n' ? 0 : 1)] == 1 &&
+ m2->stride[(transpose_r == 'n' ? 1 : 0)] != 0)
+ {
+ transpose_m2 = 'n';
+ m2_ = m2;
+ }
+ else if(m2->stride[(transpose_r == 'n' ? 1 : 0)] == 1 &&
+ m2->stride[(transpose_r == 'n' ? 0 : 1)] != 0)
+ {
+ transpose_m2 = 't';
+ m2_ = m2;
+ }
+ else
+ {
+ transpose_m2 = (transpose_r == 'n' ? 't' : 'n');
+ m2_ = THCTensor_(newContiguous)(state, m2);
+ }
+
+#ifdef THC_REAL_IS_HALF
+ THCudaBlas_Hgemm(state,
+ transpose_m1,
+ transpose_m2,
+ r__->size[(transpose_r == 'n' ? 0 : 1)],
+ r__->size[(transpose_r == 'n' ? 1 : 0)],
+ m1_->size[(transpose_r == 'n' ? 1 : 0)],
+ alpha,
+ THCTensor_(data)(state, m1_),
+ (transpose_m1 == 'n' ? m1_->stride[(transpose_r == 'n' ? 1 : 0)] : m1_->stride[(transpose_r == 'n' ? 0 : 1)]),
+ THCTensor_(data)(state, m2_),
+ (transpose_m2 == 'n' ? m2_->stride[(transpose_r == 'n' ? 1 : 0)] : m2_->stride[(transpose_r == 'n' ? 0 : 1)]),
+ beta,
+ THCTensor_(data)(state, r__),
+ r__->stride[(transpose_r == 'n' ? 1 : 0)]);
+#elif defined(THC_REAL_IS_FLOAT)
+ THCudaBlas_Sgemm(state,
+ transpose_m1,
+ transpose_m2,
+ r__->size[(transpose_r == 'n' ? 0 : 1)],
+ r__->size[(transpose_r == 'n' ? 1 : 0)],
+ m1_->size[(transpose_r == 'n' ? 1 : 0)],
+ alpha,
+ THCTensor_(data)(state, m1_),
+ (transpose_m1 == 'n' ? m1_->stride[(transpose_r == 'n' ? 1 : 0)] : m1_->stride[(transpose_r == 'n' ? 0 : 1)]),
+ THCTensor_(data)(state, m2_),
+ (transpose_m2 == 'n' ? m2_->stride[(transpose_r == 'n' ? 1 : 0)] : m2_->stride[(transpose_r == 'n' ? 0 : 1)]),
+ beta,
+ THCTensor_(data)(state, r__),
+ r__->stride[(transpose_r == 'n' ? 1 : 0)]);
+#elif defined(THC_REAL_IS_DOUBLE)
+ THCudaBlas_Dgemm(state,
+ transpose_m1,
+ transpose_m2,
+ r__->size[(transpose_r == 'n' ? 0 : 1)],
+ r__->size[(transpose_r == 'n' ? 1 : 0)],
+ m1_->size[(transpose_r == 'n' ? 1 : 0)],
+ alpha,
+ THCTensor_(data)(state, m1_),
+ (transpose_m1 == 'n' ? m1_->stride[(transpose_r == 'n' ? 1 : 0)] : m1_->stride[(transpose_r == 'n' ? 0 : 1)]),
+ THCTensor_(data)(state, m2_),
+ (transpose_m2 == 'n' ? m2_->stride[(transpose_r == 'n' ? 1 : 0)] : m2_->stride[(transpose_r == 'n' ? 0 : 1)]),
+ beta,
+ THCTensor_(data)(state, r__),
+ r__->stride[(transpose_r == 'n' ? 1 : 0)]);
+#endif
+
+ /* free intermediate variables */
+ if(m1_ != m1) {
+ THCTensor_(free)(state, m1_);
+ }
+
+ if(m2_ != m2) {
+ THCTensor_(free)(state, m2_);
+ }
+
+ if(r__ != r_) {
+ THCTensor_(freeCopyTo)(state, r__, r_);
+ }
+#else
+ THError("unimplemented data type");
+#endif
+}
+
+// result = beta * t + alpha * sum_i (batch1[i] @ batch2[i]).
+// t is 2-D [m1d1 x m2d2]; batch1 is 3-D [b x m1d1 x innerdim];
+// batch2 is 3-D [b x innerdim x m2d2].  Implemented as a loop of
+// addmm calls over batch slices, with beta forced to 1 after the
+// first iteration so later products accumulate rather than rescale.
+// Half/float/double only.
+THC_API void
+THCTensor_(addbmm)(THCState *state, THCTensor *result, real beta, THCTensor *t,
+ real alpha, THCTensor *batch1, THCTensor *batch2) {
+#if defined(THC_REAL_IS_HALF) || defined(THC_REAL_IS_FLOAT) || defined(THC_REAL_IS_DOUBLE)
+ THAssert(THCTensor_(checkGPU)(state, 4, result, t, batch1, batch2));
+ THArgCheck(THCTensor_(nDimension)(state, t) == 2, 4, "expected 2D tensor");
+ THArgCheck(THCTensor_(nDimension)(state, batch1) == 3, 6, "expected 3D tensor");
+ THArgCheck(THCTensor_(nDimension)(state, batch2) == 3, 7, "expected 3D tensor");
+
+ long batchnum = THCTensor_(size)(state, batch1, 0);
+ long m1d1 = THCTensor_(size)(state, batch1, 1);
+ long innerdim = THCTensor_(size)(state, batch1, 2);
+ long m2d2 = THCTensor_(size)(state, batch2, 2);
+
+ THArgCheck(batchnum == THCTensor_(size)(state, batch2, 0), 7,
+ "equal number of batches expected");
+ // M is t, as listed in the docs under addbmm
+ THArgCheck(m1d1 == THCTensor_(size)(state, t, 0), 6,
+ "first dimension must match first dimension of M");
+ THArgCheck(m2d2 == THCTensor_(size)(state, t, 1), 7,
+ "second dimension must match second dimension of M");
+ THArgCheck(innerdim == THCTensor_(size)(state, batch2, 1), 6,
+ "second dimension must match first dimension of batch2");
+
+ if (t != result) {
+ THCTensor_(resizeAs)(state, result, t);
+ THCTensor_(copy)(state, result, t);
+ }
+
+ // Reusable views onto each batch slice (select rewrites them in place).
+ THCTensor *slice1 = THCTensor_(new)(state);
+ THCTensor *slice2 = THCTensor_(new)(state);
+ for (long i=0; i<batchnum; i++) {
+ THCTensor_(select)(state, slice1, batch1, 0, i);
+ THCTensor_(select)(state, slice2, batch2, 0, i);
+
+ THCTensor_(addmm)(state, result, beta, result, alpha, slice1, slice2);
+ // After the first product, accumulate: beta must be 1 so previous
+ // partial sums are preserved.
+ beta = ScalarConvert<int, real>::to(1);
+ }
+ THCTensor_(free)(state, slice1);
+ THCTensor_(free)(state, slice2);
+#else
+ THError("unimplemented data type");
+#endif
+}
+
+// Batched matrix multiply: result[i] = beta * t[i] + alpha *
+// (batch1[i] @ batch2[i]) for every batch i, via cuBLAS gemmBatched.
+// All three batch tensors are 3-D with matching batch counts and
+// matrix sizes.  As in addmm, stride inspection maps each operand
+// onto cuBLAS's column-major convention (swapping batch1/batch2 when
+// the result is row-major); per-batch matrix pointers are built on
+// the host and copied to the device for the batched call.
+// Float/double only.
+THC_API void
+THCTensor_(baddbmm)(THCState *state, THCTensor *result, real beta, THCTensor *t,
+ real alpha, THCTensor *batch1, THCTensor *batch2) {
+#if defined(THC_REAL_IS_FLOAT) || defined(THC_REAL_IS_DOUBLE)
+ THAssert(THCTensor_(checkGPU)(state, 4, result, t, batch1, batch2));
+ THArgCheck(THCTensor_(nDimension)(state, t) == 3, 4, "expected 3D tensor");
+ THArgCheck(THCTensor_(nDimension)(state, batch1) == 3, 6, "expected 3D tensor");
+ THArgCheck(THCTensor_(nDimension)(state, batch2) == 3, 7, "expected 3D tensor");
+ THArgCheck(THCTensor_(size)(state, t, 0) == THCTensor_(size)(state, batch1, 0), 6,
+ "equal number of batches expected");
+ THArgCheck(THCTensor_(size)(state, t, 0) == THCTensor_(size)(state, batch2, 0), 7,
+ "equal number of batches expected");
+ THArgCheck(THCTensor_(size)(state, t, 1) == THCTensor_(size)(state, batch1, 1), 6,
+ "wrong matrix size");
+ THArgCheck(THCTensor_(size)(state, t, 2) == THCTensor_(size)(state, batch2, 2), 7,
+ "wrong matrix size");
+ THArgCheck(THCTensor_(size)(state, batch1, 2) == THCTensor_(size)(state, batch2, 1), 6,
+ "wrong matrix size");
+
+ if (t != result) {
+ THCTensor_(resizeAs)(state, result, t);
+ THCTensor_(copy)(state, result, t);
+ }
+
+ bool transpose_result;
+ char transpose_batch1, transpose_batch2;
+ long lda, ldb, ldc;
+ THCTensor *result_, *batch1_, *batch2_;
+ // Column-major matrices within each batch: write results directly.
+ if (result->stride[1] == 1)
+ {
+ transpose_result = false;
+ result_ = result;
+ ldc = result_->stride[2];
+ }
+ // Row-major matrices: compute the transposed products with the
+ // operands swapped.
+ else if (result->stride[2] == 1)
+ {
+ transpose_result = true;
+
+ THCTensor *swap = batch2;
+ batch2 = batch1;
+ batch1 = swap;
+
+ result_ = result;
+ ldc = result_->stride[1];
+ }
+ // No unit stride: gemm into a contiguous temporary, copy back at end.
+ else
+ {
+ transpose_result = false;
+
+ THCTensor *transp_r_ = THCTensor_(newTranspose)(state, result, 1, 2);
+ result_ = THCTensor_(newClone)(state, transp_r_);
+ THCTensor_(free)(state, transp_r_);
+ THCTensor_(transpose)(state, result_, NULL, 1, 2);
+
+ ldc = result_->stride[2];
+ }
+
+ // Choose 'n'/'t' and leading dimension for batch1 from its strides.
+ if (batch1->stride[transpose_result ? 2 : 1] == 1)
+ {
+ transpose_batch1 = 'n';
+ batch1_ = batch1;
+ lda = batch1_->stride[transpose_result ? 1 : 2];
+ }
+ else if (batch1->stride[transpose_result ? 1 : 2] == 1)
+ {
+ transpose_batch1 = 't';
+ batch1_ = batch1;
+ lda = batch1_->stride[transpose_result ? 2 : 1];
+ }
+ else
+ {
+ transpose_batch1 = transpose_result ? 'n' : 't';
+ batch1_ = THCTensor_(newContiguous)(state, batch1);
+ lda = batch1_->stride[1];
+ }
+
+ // Same analysis for batch2.
+ if (batch2->stride[transpose_result ? 2 : 1] == 1)
+ {
+ transpose_batch2 = 'n';
+ batch2_ = batch2;
+ ldb = batch2_->stride[transpose_result ? 1 : 2];
+ }
+ else if (batch2->stride[transpose_result ? 1 : 2] == 1)
+ {
+ transpose_batch2 = 't';
+ batch2_ = batch2;
+ ldb = batch2_->stride[transpose_result ? 2 : 1];
+ }
+ else
+ {
+ transpose_batch2 = transpose_result ? 'n' : 't';
+ batch2_ = THCTensor_(newContiguous)(state, batch2);
+ ldb = batch2_->stride[1];
+ }
+
+ // Compute pointers to matrices in each batch.
+ long num_batches = result_->size[0];
+ size_t matrices_size = num_batches * sizeof(real*);
+ const real **matrices1 = (const real **)THAlloc(matrices_size);
+ const real **matrices2 = (const real **)THAlloc(matrices_size);
+ real **result_matrices = (real **)THAlloc(matrices_size);
+ // Use a `long` index: num_batches is long, so an `int` counter could
+ // silently truncate for very large batch counts.
+ for (long i = 0; i < num_batches; ++i)
+ {
+ matrices1[i] = THCTensor_(data)(state, batch1_) + i * batch1_->stride[0];
+ matrices2[i] = THCTensor_(data)(state, batch2_) + i * batch2_->stride[0];
+ result_matrices[i] = THCTensor_(data)(state, result_) + i * result_->stride[0];
+ }
+
+ // Copy pointers to device.
+ const real **d_matrices1, **d_matrices2;
+ real **d_result_matrices;
+ THCudaCheck(THCudaMalloc(state, (void**)&d_matrices1, matrices_size));
+ THCudaCheck(THCudaMalloc(state, (void**)&d_matrices2, matrices_size));
+ THCudaCheck(THCudaMalloc(state, (void**)&d_result_matrices, matrices_size));
+
+ THCudaCheck(cudaMemcpyAsync(d_matrices1, matrices1, matrices_size,
+ cudaMemcpyHostToDevice, THCState_getCurrentStream(state)));
+ THCudaCheck(cudaMemcpyAsync(d_matrices2, matrices2, matrices_size,
+ cudaMemcpyHostToDevice, THCState_getCurrentStream(state)));
+ THCudaCheck(cudaMemcpyAsync(d_result_matrices, result_matrices, matrices_size,
+ cudaMemcpyHostToDevice, THCState_getCurrentStream(state)));
+
+#ifdef THC_REAL_IS_FLOAT
+ THCudaBlas_SgemmBatched(
+ state,
+ transpose_batch1,
+ transpose_batch2,
+ result_->size[transpose_result ? 2 : 1],
+ result_->size[transpose_result ? 1 : 2],
+ batch1_->size[transpose_result ? 1 : 2],
+ alpha,
+ d_matrices1, lda,
+ d_matrices2, ldb,
+ beta,
+ d_result_matrices, ldc,
+ num_batches);
+#elif defined(THC_REAL_IS_DOUBLE)
+ THCudaBlas_DgemmBatched(
+ state,
+ transpose_batch1,
+ transpose_batch2,
+ result_->size[transpose_result ? 2 : 1],
+ result_->size[transpose_result ? 1 : 2],
+ batch1_->size[transpose_result ? 1 : 2],
+ alpha,
+ d_matrices1, lda,
+ d_matrices2, ldb,
+ beta,
+ d_result_matrices, ldc,
+ num_batches);
+#endif
+
+ THCudaFree(state, d_matrices1);
+ THCudaFree(state, d_matrices2);
+ THCudaFree(state, d_result_matrices);
+ THFree(matrices1);
+ THFree(matrices2);
+ THFree(result_matrices);
+
+ if (batch1_ != batch1) {
+ THCTensor_(free)(state, batch1_);
+ }
+
+ if (batch2_ != batch2) {
+ THCTensor_(free)(state, batch2_);
+ }
+
+ if (result_ != result) {
+ THCTensor_(freeCopyTo)(state, result_, result);
+ }
+
+#else
+ THError("unimplemented data type");
+#endif
+}
+
+#endif
diff --git a/lib/THC/generic/THCTensorMathBlas.h b/lib/THC/generic/THCTensorMathBlas.h
new file mode 100644
index 0000000..68f95e3
--- /dev/null
+++ b/lib/THC/generic/THCTensorMathBlas.h
@@ -0,0 +1,13 @@
+#ifndef THC_GENERIC_FILE
+#define THC_GENERIC_FILE "generic/THCTensorMathBlas.h"
+#else
+
+// Dot product of flattened self and src (float/double only).
+THC_API real THCTensor_(dot)(THCState *state, THCTensor *self, THCTensor *src);
+// self = beta * t + alpha * (mat @ vec)
+THC_API void THCTensor_(addmv)(THCState *state, THCTensor *self, real beta, THCTensor *t, real alpha, THCTensor *mat, THCTensor *vec);
+// self = beta * t + alpha * (mat1 @ mat2)
+THC_API void THCTensor_(addmm)(THCState *state, THCTensor *self, real beta, THCTensor *t, real alpha, THCTensor *mat1, THCTensor *mat2);
+// self = beta * t + alpha * (vec1 outer vec2)
+THC_API void THCTensor_(addr)(THCState *state, THCTensor *self, real beta, THCTensor *t, real alpha, THCTensor *vec1, THCTensor *vec2);
+// result = beta * t + alpha * sum_i (batch1[i] @ batch2[i])
+THC_API void THCTensor_(addbmm)(THCState *state, THCTensor *result, real beta, THCTensor *t, real alpha, THCTensor *batch1, THCTensor *batch2);
+// result[i] = beta * t[i] + alpha * (batch1[i] @ batch2[i])
+THC_API void THCTensor_(baddbmm)(THCState *state, THCTensor *result, real beta, THCTensor *t, real alpha, THCTensor *batch1, THCTensor *batch2);
+
+
+#endif
diff --git a/lib/THC/generic/THCTensorMathCompare.cu b/lib/THC/generic/THCTensorMathCompare.cu
new file mode 100644
index 0000000..77f1ab5
--- /dev/null
+++ b/lib/THC/generic/THCTensorMathCompare.cu
@@ -0,0 +1,101 @@
+#ifndef THC_GENERIC_FILE
+#define THC_GENERIC_FILE "generic/THCTensorMathCompare.cu"
+#else
+
+// Elementwise comparisons of `src` against a scalar `value`, writing
+// a 0/1 byte mask into `self_` (one function per operator; all
+// delegate to THC_logicalValue with the matching comparison functor).
+THC_API void THCTensor_(ltValue)(THCState *state, THCudaByteTensor *self_, THCTensor *src, real value)
+{
+ THAssert(THCTensor_(checkGPU)(state, 2, self_, src));
+ THC_logicalValue(state, self_, src,
+ TensorLTValueOp<typename TensorUtils<THCTensor>::DataType,
+ unsigned char>(value));
+}
+
+THC_API void THCTensor_(gtValue)(THCState *state, THCudaByteTensor *self_, THCTensor *src, real value)
+{
+ THAssert(THCTensor_(checkGPU)(state, 2, self_, src));
+ THC_logicalValue(state, self_, src,
+ TensorGTValueOp<typename TensorUtils<THCTensor>::DataType,
+ unsigned char>(value));
+}
+
+THC_API void THCTensor_(leValue)(THCState *state, THCudaByteTensor *self_, THCTensor *src, real value)
+{
+ THAssert(THCTensor_(checkGPU)(state, 2, self_, src));
+ THC_logicalValue(state, self_, src,
+ TensorLEValueOp<typename TensorUtils<THCTensor>::DataType,
+ unsigned char>(value));
+}
+
+THC_API void THCTensor_(geValue)(THCState *state, THCudaByteTensor *self_, THCTensor *src, real value)
+{
+ THAssert(THCTensor_(checkGPU)(state, 2, self_, src));
+ THC_logicalValue(state, self_, src,
+ TensorGEValueOp<typename TensorUtils<THCTensor>::DataType,
+ unsigned char>(value));
+}
+
+THC_API void THCTensor_(eqValue)(THCState *state, THCudaByteTensor *self_, THCTensor *src, real value)
+{
+ THAssert(THCTensor_(checkGPU)(state, 2, self_, src));
+ THC_logicalValue(state, self_, src,
+ TensorEQValueOp<typename TensorUtils<THCTensor>::DataType,
+ unsigned char>(value));
+}
+
+THC_API void THCTensor_(neValue)(THCState *state, THCudaByteTensor *self_, THCTensor *src, real value)
+{
+ THAssert(THCTensor_(checkGPU)(state, 2, self_, src));
+ THC_logicalValue(state, self_, src,
+ TensorNEValueOp<typename TensorUtils<THCTensor>::DataType,
+ unsigned char>(value));
+}
+
+// Same scalar comparisons as above, but the 0/1 result is written into
+// a tensor of the source's own data type (`self_` is a THCTensor)
+// rather than a byte tensor.
+THC_API void THCTensor_(ltValueT)(THCState *state, THCTensor *self_, THCTensor *src, real value)
+{
+ THAssert(THCTensor_(checkGPU)(state, 2, self_, src));
+ THC_logicalValue(state, self_, src,
+ TensorLTValueOp<typename TensorUtils<THCTensor>::DataType,
+ typename TensorUtils<THCTensor>::DataType>(value));
+}
+
+THC_API void THCTensor_(gtValueT)(THCState *state, THCTensor *self_, THCTensor *src, real value)
+{
+ THAssert(THCTensor_(checkGPU)(state, 2, self_, src));
+ THC_logicalValue(state, self_, src,
+ TensorGTValueOp<typename TensorUtils<THCTensor>::DataType,
+ typename TensorUtils<THCTensor>::DataType>(value));
+}
+
+THC_API void THCTensor_(leValueT)(THCState *state, THCTensor *self_, THCTensor *src, real value)
+{
+ THAssert(THCTensor_(checkGPU)(state, 2, self_, src));
+ THC_logicalValue(state, self_, src,
+ TensorLEValueOp<typename TensorUtils<THCTensor>::DataType,
+ typename TensorUtils<THCTensor>::DataType>(value));
+}
+
+THC_API void THCTensor_(geValueT)(THCState *state, THCTensor *self_, THCTensor *src, real value)
+{
+ THAssert(THCTensor_(checkGPU)(state, 2, self_, src));
+ THC_logicalValue(state, self_, src,
+ TensorGEValueOp<typename TensorUtils<THCTensor>::DataType,
+ typename TensorUtils<THCTensor>::DataType>(value));
+}
+
+THC_API void THCTensor_(eqValueT)(THCState *state, THCTensor *self_, THCTensor *src, real value)
+{
+ THAssert(THCTensor_(checkGPU)(state, 2, self_, src));
+ THC_logicalValue(state, self_, src,
+ TensorEQValueOp<typename TensorUtils<THCTensor>::DataType,
+ typename TensorUtils<THCTensor>::DataType>(value));
+}
+
+THC_API void THCTensor_(neValueT)(THCState *state, THCTensor *self_, THCTensor *src, real value)
+{
+ THAssert(THCTensor_(checkGPU)(state, 2, self_, src));
+ THC_logicalValue(state, self_, src,
+ TensorNEValueOp<typename TensorUtils<THCTensor>::DataType,
+ typename TensorUtils<THCTensor>::DataType>(value));
+}
+
+#endif
diff --git a/lib/THC/generic/THCTensorMathCompare.h b/lib/THC/generic/THCTensorMathCompare.h
new file mode 100644
index 0000000..7b8837c
--- /dev/null
+++ b/lib/THC/generic/THCTensorMathCompare.h
@@ -0,0 +1,20 @@
+#ifndef THC_GENERIC_FILE
+#define THC_GENERIC_FILE "generic/THCTensorMathCompare.h"
+#else
+
+// Tensor-vs-scalar comparisons. The plain variants write a byte (0/1) mask
+// into a THCudaByteTensor.
+THC_API void THCTensor_(ltValue)(THCState *state, THCudaByteTensor *self_, THCTensor *src, real value);
+THC_API void THCTensor_(gtValue)(THCState *state, THCudaByteTensor *self_, THCTensor *src, real value);
+THC_API void THCTensor_(leValue)(THCState *state, THCudaByteTensor *self_, THCTensor *src, real value);
+THC_API void THCTensor_(geValue)(THCState *state, THCudaByteTensor *self_, THCTensor *src, real value);
+THC_API void THCTensor_(eqValue)(THCState *state, THCudaByteTensor *self_, THCTensor *src, real value);
+THC_API void THCTensor_(neValue)(THCState *state, THCudaByteTensor *self_, THCTensor *src, real value);
+
+// "...T" variants: the result tensor has the same scalar type as `src`.
+THC_API void THCTensor_(ltValueT)(THCState *state, THCTensor *self_, THCTensor *src, real value);
+THC_API void THCTensor_(gtValueT)(THCState *state, THCTensor *self_, THCTensor *src, real value);
+THC_API void THCTensor_(leValueT)(THCState *state, THCTensor *self_, THCTensor *src, real value);
+THC_API void THCTensor_(geValueT)(THCState *state, THCTensor *self_, THCTensor *src, real value);
+THC_API void THCTensor_(eqValueT)(THCState *state, THCTensor *self_, THCTensor *src, real value);
+THC_API void THCTensor_(neValueT)(THCState *state, THCTensor *self_, THCTensor *src, real value);
+
+
+#endif
diff --git a/lib/THC/generic/THCTensorMathCompareT.cu b/lib/THC/generic/THCTensorMathCompareT.cu
new file mode 100644
index 0000000..4b59abf
--- /dev/null
+++ b/lib/THC/generic/THCTensorMathCompareT.cu
@@ -0,0 +1,113 @@
+#ifndef THC_GENERIC_FILE
+#define THC_GENERIC_FILE "generic/THCTensorMathCompareT.cu"
+#else
+
+// Element-wise tensor-vs-tensor comparisons writing an unsigned-char (byte)
+// result into `self_`. All six dispatch to the generic THC_logicalTensor with
+// a per-element-pair functor parameterized on <input type, output type>.
+
+// self_[i] = (src1[i] < src2[i]) as a byte.
+THC_API void
+THCTensor_(ltTensor)(THCState *state, THCudaByteTensor *self_, THCTensor *src1, THCTensor *src2)
+{
+ THAssert(THCTensor_(checkGPU)(state, 3, self_, src1, src2));
+ THC_logicalTensor(state, self_, src1, src2,
+ TensorLTOp<typename TensorUtils<THCTensor>::DataType,
+ unsigned char>());
+}
+
+// self_[i] = (src1[i] > src2[i]) as a byte.
+THC_API void
+THCTensor_(gtTensor)(THCState *state, THCudaByteTensor *self_, THCTensor *src1, THCTensor *src2)
+{
+ THAssert(THCTensor_(checkGPU)(state, 3, self_, src1, src2));
+ THC_logicalTensor(state, self_, src1, src2,
+ TensorGTOp<typename TensorUtils<THCTensor>::DataType,
+ unsigned char>());
+}
+
+// self_[i] = (src1[i] <= src2[i]) as a byte.
+THC_API void
+THCTensor_(leTensor)(THCState *state, THCudaByteTensor *self_, THCTensor *src1, THCTensor *src2)
+{
+ THAssert(THCTensor_(checkGPU)(state, 3, self_, src1, src2));
+ THC_logicalTensor(state, self_, src1, src2,
+ TensorLEOp<typename TensorUtils<THCTensor>::DataType,
+ unsigned char>());
+}
+
+// self_[i] = (src1[i] >= src2[i]) as a byte.
+THC_API void
+THCTensor_(geTensor)(THCState *state, THCudaByteTensor *self_, THCTensor *src1, THCTensor *src2)
+{
+ THAssert(THCTensor_(checkGPU)(state, 3, self_, src1, src2));
+ THC_logicalTensor(state, self_, src1, src2,
+ TensorGEOp<typename TensorUtils<THCTensor>::DataType,
+ unsigned char>());
+}
+
+// self_[i] = (src1[i] == src2[i]) as a byte.
+THC_API void
+THCTensor_(eqTensor)(THCState *state, THCudaByteTensor *self_, THCTensor *src1, THCTensor *src2)
+{
+ THAssert(THCTensor_(checkGPU)(state, 3, self_, src1, src2));
+ THC_logicalTensor(state, self_, src1, src2,
+ TensorEQOp<typename TensorUtils<THCTensor>::DataType,
+ unsigned char>());
+}
+
+// self_[i] = (src1[i] != src2[i]) as a byte.
+THC_API void
+THCTensor_(neTensor)(THCState *state, THCudaByteTensor *self_, THCTensor *src1, THCTensor *src2)
+{
+ THAssert(THCTensor_(checkGPU)(state, 3, self_, src1, src2));
+ THC_logicalTensor(state, self_, src1, src2,
+ TensorNEOp<typename TensorUtils<THCTensor>::DataType,
+ unsigned char>());
+}
+
+// "...T" variants of the tensor-vs-tensor comparisons: identical dispatch, but
+// the output type template argument is the tensors' own scalar type, so the
+// result is written into a THCTensor of the same type instead of a byte mask.
+
+// self_[i] = (src1[i] < src2[i]), same scalar type as the inputs.
+THC_API void
+THCTensor_(ltTensorT)(THCState *state, THCTensor *self_, THCTensor *src1, THCTensor *src2)
+{
+ THAssert(THCTensor_(checkGPU)(state, 3, self_, src1, src2));
+ THC_logicalTensor(state, self_, src1, src2,
+ TensorLTOp<typename TensorUtils<THCTensor>::DataType,
+ typename TensorUtils<THCTensor>::DataType>());
+}
+
+// self_[i] = (src1[i] > src2[i]), same scalar type as the inputs.
+THC_API void
+THCTensor_(gtTensorT)(THCState *state, THCTensor *self_, THCTensor *src1, THCTensor *src2)
+{
+ THAssert(THCTensor_(checkGPU)(state, 3, self_, src1, src2));
+ THC_logicalTensor(state, self_, src1, src2,
+ TensorGTOp<typename TensorUtils<THCTensor>::DataType,
+ typename TensorUtils<THCTensor>::DataType>());
+}
+
+// self_[i] = (src1[i] <= src2[i]), same scalar type as the inputs.
+THC_API void
+THCTensor_(leTensorT)(THCState *state, THCTensor *self_, THCTensor *src1, THCTensor *src2)
+{
+ THAssert(THCTensor_(checkGPU)(state, 3, self_, src1, src2));
+ THC_logicalTensor(state, self_, src1, src2,
+ TensorLEOp<typename TensorUtils<THCTensor>::DataType,
+ typename TensorUtils<THCTensor>::DataType>());
+}
+
+// self_[i] = (src1[i] >= src2[i]), same scalar type as the inputs.
+THC_API void
+THCTensor_(geTensorT)(THCState *state, THCTensor *self_, THCTensor *src1, THCTensor *src2)
+{
+ THAssert(THCTensor_(checkGPU)(state, 3, self_, src1, src2));
+ THC_logicalTensor(state, self_, src1, src2,
+ TensorGEOp<typename TensorUtils<THCTensor>::DataType,
+ typename TensorUtils<THCTensor>::DataType>());
+}
+
+// self_[i] = (src1[i] == src2[i]), same scalar type as the inputs.
+THC_API void
+THCTensor_(eqTensorT)(THCState *state, THCTensor *self_, THCTensor *src1, THCTensor *src2)
+{
+ THAssert(THCTensor_(checkGPU)(state, 3, self_, src1, src2));
+ THC_logicalTensor(state, self_, src1, src2,
+ TensorEQOp<typename TensorUtils<THCTensor>::DataType,
+ typename TensorUtils<THCTensor>::DataType>());
+}
+
+// self_[i] = (src1[i] != src2[i]), same scalar type as the inputs.
+THC_API void
+THCTensor_(neTensorT)(THCState *state, THCTensor *self_, THCTensor *src1, THCTensor *src2)
+{
+ THAssert(THCTensor_(checkGPU)(state, 3, self_, src1, src2));
+ THC_logicalTensor(state, self_, src1, src2,
+ TensorNEOp<typename TensorUtils<THCTensor>::DataType,
+ typename TensorUtils<THCTensor>::DataType>());
+}
+
+#endif
diff --git a/lib/THC/generic/THCTensorMathCompareT.h b/lib/THC/generic/THCTensorMathCompareT.h
new file mode 100644
index 0000000..0d76835
--- /dev/null
+++ b/lib/THC/generic/THCTensorMathCompareT.h
@@ -0,0 +1,19 @@
+#ifndef THC_GENERIC_FILE
+#define THC_GENERIC_FILE "generic/THCTensorMathCompareT.h"
+#else
+
+// Element-wise tensor-vs-tensor comparisons writing a byte (0/1) mask.
+THC_API void THCTensor_(ltTensor)(THCState *state, THCudaByteTensor *self_, THCTensor *src1, THCTensor *src2);
+THC_API void THCTensor_(gtTensor)(THCState *state, THCudaByteTensor *self_, THCTensor *src1, THCTensor *src2);
+THC_API void THCTensor_(leTensor)(THCState *state, THCudaByteTensor *self_, THCTensor *src1, THCTensor *src2);
+THC_API void THCTensor_(geTensor)(THCState *state, THCudaByteTensor *self_, THCTensor *src1, THCTensor *src2);
+THC_API void THCTensor_(eqTensor)(THCState *state, THCudaByteTensor *self_, THCTensor *src1, THCTensor *src2);
+THC_API void THCTensor_(neTensor)(THCState *state, THCudaByteTensor *self_, THCTensor *src1, THCTensor *src2);
+
+// "...T" variants: the result tensor has the same scalar type as the inputs.
+THC_API void THCTensor_(ltTensorT)(THCState *state, THCTensor *self_, THCTensor *src1, THCTensor *src2);
+THC_API void THCTensor_(gtTensorT)(THCState *state, THCTensor *self_, THCTensor *src1, THCTensor *src2);
+THC_API void THCTensor_(leTensorT)(THCState *state, THCTensor *self_, THCTensor *src1, THCTensor *src2);
+THC_API void THCTensor_(geTensorT)(THCState *state, THCTensor *self_, THCTensor *src1, THCTensor *src2);
+THC_API void THCTensor_(eqTensorT)(THCState *state, THCTensor *self_, THCTensor *src1, THCTensor *src2);
+THC_API void THCTensor_(neTensorT)(THCState *state, THCTensor *self_, THCTensor *src1, THCTensor *src2);
+
+#endif
diff --git a/lib/THC/generic/THCTensorMathReduce.cu b/lib/THC/generic/THCTensorMathReduce.cu
new file mode 100644
index 0000000..e17013c
--- /dev/null
+++ b/lib/THC/generic/THCTensorMathReduce.cu
@@ -0,0 +1,135 @@
+#ifndef THC_GENERIC_FILE
+#define THC_GENERIC_FILE "generic/THCTensorMathReduce.cu"
+#else
+
+// self = sum of src along `dimension`. Dispatches to the generic THC_reduceDim
+// with an identity per-element transform, ReduceAdd as the combining op, and
+// 0 (converted to `real`) as the reduction's neutral element.
+THC_API void
+THCTensor_(sum)(THCState* state, THCTensor *self, THCTensor *src, long dimension) {
+ THAssert(THCTensor_(checkGPU)(state, 2, self, src));
+ if (!THC_reduceDim(state, self, src,
+ thrust::identity<real>(),
+ ReduceAdd<real, real>(),
+ ScalarConvert<int, real>::to(0),
+ dimension)) {
+ // THC_reduceDim returned failure: report an unsupported-dimension error.
+ THArgCheck(false, 2, CUTORCH_DIM_WARNING);
+ }
+
+ THCudaCheck(cudaGetLastError());
+}
+
+// self = product of src along `dimension`. Same dispatch as sum() but with
+// ReduceMultiply and 1 as the neutral element.
+THC_API void
+THCTensor_(prod)(THCState* state, THCTensor *self, THCTensor *src, long dimension) {
+ THAssert(THCTensor_(checkGPU)(state, 2, self, src));
+ if (!THC_reduceDim(state, self, src,
+ thrust::identity<real>(),
+ ReduceMultiply<real, real>(),
+ ScalarConvert<int, real>::to(1),
+ dimension)) {
+ // THC_reduceDim returned failure: report an unsupported-dimension error.
+ THArgCheck(false, 2, CUTORCH_DIM_WARNING);
+ }
+
+ THCudaCheck(cudaGetLastError());
+}
+
+// Returns the sum of all elements of self, accumulated in the wider `accreal`
+// type (e.g. float accumulates in double for some instantiations).
+THC_API accreal
+THCTensor_(sumall)(THCState *state, THCTensor *self) {
+ THAssert(THCTensor_(checkGPU)(state, 1, self));
+ accreal val;
+ if (!THC_reduceAll(state, self,
+ thrust::identity<real>(),
+ ReduceAdd<real, accreal>(),
+ ReduceAdd<accreal, accreal>(),
+ ScalarConvert<int, accreal>::to(0),
+ // trailing 0: presumably `val` is a host pointer — confirm
+ // against THC_reduceAll's out-on-device flag.
+ &val, 0)) {
+ THArgCheck(false, 1, CUTORCH_DIM_WARNING);
+ }
+
+ THCudaCheck(cudaGetLastError());
+ return val;
+}
+
+// Returns the product of all elements of self, accumulated in `accreal`,
+// starting from the multiplicative identity 1.
+THC_API accreal
+THCTensor_(prodall)(THCState *state, THCTensor *self) {
+ THAssert(THCTensor_(checkGPU)(state, 1, self));
+ accreal val;
+ if (!THC_reduceAll(state, self,
+ thrust::identity<real>(),
+ ReduceMultiply<real, accreal>(),
+ ReduceMultiply<accreal, accreal>(),
+ ScalarConvert<int, accreal>::to(1),
+ // trailing 0: presumably `val` is a host pointer — confirm
+ // against THC_reduceAll's out-on-device flag.
+ &val, 0)) {
+ THArgCheck(false, 1, CUTORCH_DIM_WARNING);
+ }
+
+ THCudaCheck(cudaGetLastError());
+ return val;
+}
+
+// Returns the minimum element of self. The reduction starts from the type's
+// maximum value (so for an empty tensor the result would presumably be
+// THCNumerics<real>::max() — confirm THC_reduceAll's empty-input behavior).
+THC_API real
+THCTensor_(minall)(THCState *state, THCTensor *self) {
+ THAssert(THCTensor_(checkGPU)(state, 1, self));
+ real val;
+ if (!THC_reduceAll(state, self,
+ thrust::identity<real>(),
+ ReduceMin<real>(),
+ ReduceMin<real>(),
+ THCNumerics<real>::max(), &val, 0)) {
+ THArgCheck(false, 1, CUTORCH_DIM_WARNING);
+ }
+
+ THCudaCheck(cudaGetLastError());
+ return val;
+}
+
+// Returns the maximum element of self. The reduction starts from the type's
+// minimum value (mirror image of minall()).
+THC_API real
+THCTensor_(maxall)(THCState *state, THCTensor *self) {
+ THAssert(THCTensor_(checkGPU)(state, 1, self));
+ real val;
+ if (!THC_reduceAll(state, self,
+ thrust::identity<real>(),
+ ReduceMax<real>(),
+ ReduceMax<real>(),
+ THCNumerics<real>::min(), &val, 0)) {
+ THArgCheck(false, 1, CUTORCH_DIM_WARNING);
+ }
+
+ THCudaCheck(cudaGetLastError());
+ return val;
+}
+
+// Computes, along `dimension`, the per-slice maximum of src into `values` and
+// the position of that maximum into `indices`. The running pair is seeded with
+// (type minimum, index 1) — index 1 presumably reflects Torch's 1-based
+// indexing; confirm against THC_reduceDimIndex.
+// NOTE(review): unlike the other reductions in this file, there is no
+// THCudaCheck(cudaGetLastError()) after the dispatch, and `return <call>` in a
+// void function suggests THC_reduceDimIndex returns void here — verify both.
+THC_API void
+THCTensor_(max)(THCState *state,
+ THCTensor *values,
+ THCudaLongTensor *indices,
+ THCTensor *src,
+ long dimension) {
+ THAssert(THCTensor_(checkGPU)(state, 3, values, indices, src));
+
+ thrust::pair<typename TensorUtils<THCTensor>::DataType, long>
+ init =
+ thrust::make_pair<typename TensorUtils<THCTensor>::DataType, long>(
+ THCNumerics<typename TensorUtils<THCTensor>::DataType>::min(), 1);
+
+ return THC_reduceDimIndex(
+ state, values, indices, src, dimension, init,
+ MaxValuePair<typename TensorUtils<THCTensor>::DataType, long>());
+}
+
+// Computes, along `dimension`, the per-slice minimum of src into `values` and
+// the position of that minimum into `indices`. Mirror image of max(): the
+// running pair is seeded with (type maximum, index 1) — index 1 presumably
+// reflects Torch's 1-based indexing; confirm against THC_reduceDimIndex.
+// NOTE(review): no THCudaCheck(cudaGetLastError()) after the dispatch, unlike
+// the other reductions in this file — verify whether that check is intended.
+THC_API void
+THCTensor_(min)(THCState *state,
+ THCTensor *values,
+ THCudaLongTensor *indices,
+ THCTensor *src,
+ long dimension) {
+ THAssert(THCTensor_(checkGPU)(state, 3, values, indices, src));
+
+ thrust::pair<typename TensorUtils<THCTensor>::DataType, long>
+ init =
+ thrust::make_pair<typename TensorUtils<THCTensor>::DataType, long>(
+ THCNumerics<typename TensorUtils<THCTensor>::DataType>::max(), 1);
+
+ return THC_reduceDimIndex(
+ state, values, indices, src, dimension, init,
+ MinValuePair<typename TensorUtils<THCTensor>::DataType, long>());
+}
+
+#endif
diff --git a/lib/THC/generic/THCTensorMathReduce.h b/lib/THC/generic/THCTensorMathReduce.h
new file mode 100644
index 0000000..f584d68
--- /dev/null
+++ b/lib/THC/generic/THCTensorMathReduce.h
@@ -0,0 +1,23 @@
+#ifndef THC_GENERIC_FILE
+#define THC_GENERIC_FILE "generic/THCTensorMathReduce.h"
+#else
+
+// Dimension-wise reductions: self = reduce(src) along `dim`.
+THC_API void THCTensor_(sum)(THCState *state, THCTensor *self, THCTensor *src, long dim);
+THC_API void THCTensor_(prod)(THCState *state, THCTensor *self, THCTensor *src, long dim);
+
+// Whole-tensor reductions accumulated in the wider `accreal` type.
+THC_API accreal THCTensor_(sumall)(THCState *state, THCTensor *self);
+THC_API accreal THCTensor_(prodall)(THCState *state, THCTensor *self);
+
+// Dimension-wise min/max returning both the extreme values and their indices.
+THC_API void THCTensor_(min)(THCState *state,
+ THCTensor *values,
+ THCudaLongTensor *indices,
+ THCTensor *src, long dim);
+THC_API void THCTensor_(max)(THCState *state,
+ THCTensor *values,
+ THCudaLongTensor *indices,
+ THCTensor *src, long dim);
+
+// Whole-tensor min/max returning a single scalar.
+THC_API real THCTensor_(minall)(THCState *state, THCTensor *self);
+THC_API real THCTensor_(maxall)(THCState *state, THCTensor *self);
+
+#endif