#ifndef THC_TENSORMATH_COMPARET_CUH
#define THC_TENSORMATH_COMPARET_CUH

#include "THCTensorMath.h"
#include "THCGeneral.h"
#include "THCTensorCopy.h"
#include "THCApply.cuh"
#include "THCNumerics.cuh"
#include "THCReduce.cuh"

// Elementwise tensor-tensor comparison functors.
//
// Each functor compares *a with *b using the matching THCNumerics<T>
// predicate and stores the boolean outcome in *out, converted to the output
// element type TOut via ScalarConvert<bool, TOut>. They are intended to be
// passed as the `op` argument of THC_logicalTensor (defined below), which
// applies them pointwise through THC_pointwiseApply3.

// *out = (*a < *b)
template <typename T, typename TOut>
struct TensorLTOp {
  __device__ inline void operator()(TOut* out, T* a, T* b) {
    *out = ScalarConvert<bool, TOut>::to(THCNumerics<T>::lt(*a, *b));
  }
};

// *out = (*a > *b)
template <typename T, typename TOut>
struct TensorGTOp {
  __device__ inline void operator()(TOut* out, T* a, T* b) {
    *out = ScalarConvert<bool, TOut>::to(THCNumerics<T>::gt(*a, *b));
  }
};

// *out = (*a <= *b)
template <typename T, typename TOut>
struct TensorLEOp {
  __device__ inline void operator()(TOut* out, T* a, T* b) {
    *out = ScalarConvert<bool, TOut>::to(THCNumerics<T>::le(*a, *b));
  }
};

// *out = (*a >= *b)
template <typename T, typename TOut>
struct TensorGEOp {
  __device__ inline void operator()(TOut* out, T* a, T* b) {
    *out = ScalarConvert<bool, TOut>::to(THCNumerics<T>::ge(*a, *b));
  }
};

// *out = (*a == *b)
template <typename T, typename TOut>
struct TensorEQOp {
  __device__ inline void operator()(TOut* out, T* a, T* b) {
    *out = ScalarConvert<bool, TOut>::to(THCNumerics<T>::eq(*a, *b));
  }
};

// *out = (*a != *b)
template <typename T, typename TOut>
struct TensorNEOp {
  __device__ inline void operator()(TOut* out, T* a, T* b) {
    *out = ScalarConvert<bool, TOut>::to(THCNumerics<T>::ne(*a, *b));
  }
};

// Computes self_ = op(src1, src2) elementwise.
//
//   state - THC state handle, forwarded to every TensorUtils / apply call.
//   self_ - output tensor; resized to the size of src1.
//   src1  - left-hand operand.
//   src2  - right-hand operand; must have as many elements as src1 (only the
//           element counts are compared, not the shapes).
//   op    - per-element functor, e.g. TensorLTOp<T, TOut>.
//
// Errors (reported through THArgCheck):
//   argument 3, "sizes do not match" - src1 and src2 element counts differ.
//   argument 2, CUTORCH_DIM_WARNING  - THC_pointwiseApply3 returned false.
// Launch errors from the apply are checked with cudaGetLastError().
template <typename TensorType, typename TensorTypeOut, class Op>
void THC_logicalTensor(THCState *state,
                       TensorTypeOut *self_,
                       TensorType *src1,
                       TensorType *src2,
                       Op op) {
  // Validate before mutating self_. If self_ aliases src2, resizing first
  // would overwrite src2's size with src1's and let a mismatch pass the check.
  THArgCheck(TensorUtils<TensorType>::getNumElements(state, src1) ==
             TensorUtils<TensorType>::getNumElements(state, src2), 3,
             "sizes do not match");

  THLongStorage* st = TensorUtils<TensorType>::newSizeOf(state, src1);
  TensorUtils<TensorTypeOut>::resize(state, self_, st, NULL);
  THLongStorage_free(st);

  if (!THC_pointwiseApply3(state, self_, src1, src2, op)) {
    THArgCheck(false, 2, CUTORCH_DIM_WARNING);
  }

  THCudaCheck(cudaGetLastError());
}

#endif // THC_TENSORMATH_COMPARET_CUH