#ifndef THC_TENSORMATH_COMPARE_CUH
#define THC_TENSORMATH_COMPARE_CUH

#include "THCTensorMath.h"
#include "THCGeneral.h"
#include "THCTensorCopy.h"
#include "THCApply.cuh"
#include "THCNumerics.cuh"

// Element-wise "tensor <cmp> scalar" comparison functors.
//
// Each functor captures a scalar `value` of the input element type T. When
// invoked on the device with an (out, in) element pair, it compares *in
// against `value` using THCNumerics<T> and stores the result in *out,
// converted from bool to the output element type TOut via ScalarConvert.
//
// NOTE(review): assumes every THCNumerics<T>::{lt,gt,le,ge,eq,ne} returns
// bool (hence ScalarConvert<bool, TOut>) -- confirm against THCNumerics.cuh.

// *out = (*in < value)
template <typename T, typename TOut>
struct TensorLTValueOp {
  TensorLTValueOp(T v) : value(v) {}

  __device__ __forceinline__ void operator()(TOut* out, const T* in) const {
    *out = ScalarConvert<bool, TOut>::to(THCNumerics<T>::lt(*in, value));
  }

  const T value;
};

// *out = (*in > value)
template <typename T, typename TOut>
struct TensorGTValueOp {
  TensorGTValueOp(T v) : value(v) {}

  __device__ __forceinline__ void operator()(TOut* out, const T* in) const {
    *out = ScalarConvert<bool, TOut>::to(THCNumerics<T>::gt(*in, value));
  }

  const T value;
};

// *out = (*in <= value)
template <typename T, typename TOut>
struct TensorLEValueOp {
  TensorLEValueOp(T v) : value(v) {}

  __device__ __forceinline__ void operator()(TOut* out, const T* in) const {
    *out = ScalarConvert<bool, TOut>::to(THCNumerics<T>::le(*in, value));
  }

  const T value;
};

// *out = (*in >= value)
template <typename T, typename TOut>
struct TensorGEValueOp {
  TensorGEValueOp(T v) : value(v) {}

  __device__ __forceinline__ void operator()(TOut* out, const T* in) const {
    *out = ScalarConvert<bool, TOut>::to(THCNumerics<T>::ge(*in, value));
  }

  const T value;
};

// *out = (*in == value)
template <typename T, typename TOut>
struct TensorEQValueOp {
  TensorEQValueOp(T v) : value(v) {}

  __device__ __forceinline__ void operator()(TOut* out, const T* in) const {
    *out = ScalarConvert<bool, TOut>::to(THCNumerics<T>::eq(*in, value));
  }

  const T value;
};

// *out = (*in != value)
template <typename T, typename TOut>
struct TensorNEValueOp {
  TensorNEValueOp(T v) : value(v) {}

  __device__ __forceinline__ void operator()(TOut* out, const T* in) const {
    *out = ScalarConvert<bool, TOut>::to(THCNumerics<T>::ne(*in, value));
  }

  const T value;
};

// Applies a per-element functor from src into self_.
//
// self_ is first resized to the size of src. Then op is applied pointwise
// over the (self_, src) pair via THC_pointwiseApply2. op is called with a
// pointer to an output element of self_ and a pointer to the corresponding
// element of src, e.g. one of the Tensor*ValueOp functors above.
//
// Arguments:
//   state - THC state handle forwarded to every THC call.
//   self_ - output tensor; resized to match src.
//   src   - input tensor whose elements are passed to op.
//   op    - element functor applied on the device.
//
// Fails through THArgCheck (argument index 2, CUTORCH_DIM_WARNING) when
// THC_pointwiseApply2 returns false. Presumably this happens when the
// tensors have more dimensions than the apply kernels support -- confirm
// in THCApply.cuh. Any pending CUDA launch error is checked afterwards.
template <typename TensorType, typename TensorTypeOut, class Op>
void THC_logicalValue(THCState *state,
                      TensorTypeOut *self_,
                      TensorType *src,
                      Op op) {
  // Shape the output like the input. The size storage obtained from src is
  // freed here once resize has used it.
  THLongStorage* st = TensorUtils<TensorType>::newSizeOf(state, src);
  TensorUtils<TensorTypeOut>::resize(state, self_, st, NULL);
  THLongStorage_free(st);

  if (!THC_pointwiseApply2(state, self_, src, op)) {
    THArgCheck(false, 2, CUTORCH_DIM_WARNING);
  }

  // Surface any error from the kernel launch performed by the apply above.
  THCudaCheck(cudaGetLastError());
}

#endif // THC_TENSORMATH_COMPARE_CUH