Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/torch/cutorch.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
Diffstat (limited to 'lib/THC/generic')
-rw-r--r--lib/THC/generic/THCStorage.c25
-rw-r--r--lib/THC/generic/THCStorage.cu3
-rw-r--r--lib/THC/generic/THCStorage.h14
-rw-r--r--lib/THC/generic/THCStorageCopy.cu4
-rw-r--r--lib/THC/generic/THCStorageCopy.h2
-rw-r--r--lib/THC/generic/THCTensor.c16
-rw-r--r--lib/THC/generic/THCTensor.h18
-rw-r--r--lib/THC/generic/THCTensorCopy.c1
-rw-r--r--lib/THC/generic/THCTensorCopy.cu273
-rw-r--r--lib/THC/generic/THCTensorCopy.h3
-rw-r--r--lib/THC/generic/THCTensorMath.cu68
-rw-r--r--lib/THC/generic/THCTensorMath.h13
-rw-r--r--lib/THC/generic/THCTensorMathPairwise.cu74
-rw-r--r--lib/THC/generic/THCTensorMathPairwise.h10
-rw-r--r--lib/THC/generic/THCTensorMathPointwise.cu157
-rw-r--r--lib/THC/generic/THCTensorMathPointwise.h11
16 files changed, 403 insertions, 289 deletions
diff --git a/lib/THC/generic/THCStorage.c b/lib/THC/generic/THCStorage.c
index d8fec77..61ba125 100644
--- a/lib/THC/generic/THCStorage.c
+++ b/lib/THC/generic/THCStorage.c
@@ -17,29 +17,18 @@ int THCStorage_(elementSize)(THCState *state)
return sizeof(real);
}
-void THCStorage_(set)(THCState *state, THCStorage *self, long index, hostreal _value)
+void THCStorage_(set)(THCState *state, THCStorage *self, long index, real value)
{
THArgCheck((index >= 0) && (index < self->size), 2, "index out of bounds");
- real value = hostrealToReal(_value);
THCudaCheck(cudaMemcpy(self->data + index, &value, sizeof(real), cudaMemcpyHostToDevice));
}
-hostreal THCStorage_(get)(THCState *state, const THCStorage *self, long index)
+real THCStorage_(get)(THCState *state, const THCStorage *self, long index)
{
THArgCheck((index >= 0) && (index < self->size), 2, "index out of bounds");
-#ifndef THC_REAL_IS_HALF
real value;
THCudaCheck(cudaMemcpy(&value, self->data + index, sizeof(real), cudaMemcpyDeviceToHost));
- return realToHostreal(value);
-#else
- float *ret_d;
- float ret;
- THCudaCheck(THCudaMalloc(state, (void**)&ret_d, sizeof(float)));
- THCHalf2Float(state, ret_d, self->data + index, 1);
- THCudaCheck(cudaMemcpy(&ret, ret_d, sizeof(float), cudaMemcpyDeviceToHost));
- THCudaFree(state, ret_d);
- return ret;
-#endif
+ return value;
}
THCStorage* THCStorage_(new)(THCState *state)
@@ -80,14 +69,14 @@ THCStorage* THCStorage_(newWithSize)(THCState *state, long size)
}
}
-THCStorage* THCStorage_(newWithSize1)(THCState *state, hostreal data0)
+THCStorage* THCStorage_(newWithSize1)(THCState *state, real data0)
{
THCStorage *self = THCStorage_(newWithSize)(state, 1);
THCStorage_(set)(state, self, 0, data0);
return self;
}
-THCStorage* THCStorage_(newWithSize2)(THCState *state, hostreal data0, hostreal data1)
+THCStorage* THCStorage_(newWithSize2)(THCState *state, real data0, real data1)
{
THCStorage *self = THCStorage_(newWithSize)(state, 2);
THCStorage_(set)(state, self, 0, data0);
@@ -95,7 +84,7 @@ THCStorage* THCStorage_(newWithSize2)(THCState *state, hostreal data0, hostreal
return self;
}
-THCStorage* THCStorage_(newWithSize3)(THCState *state, hostreal data0, hostreal data1, hostreal data2)
+THCStorage* THCStorage_(newWithSize3)(THCState *state, real data0, real data1, real data2)
{
THCStorage *self = THCStorage_(newWithSize)(state, 3);
THCStorage_(set)(state, self, 0, data0);
@@ -104,7 +93,7 @@ THCStorage* THCStorage_(newWithSize3)(THCState *state, hostreal data0, hostreal
return self;
}
-THCStorage* THCStorage_(newWithSize4)(THCState *state, hostreal data0, hostreal data1, hostreal data2, hostreal data3)
+THCStorage* THCStorage_(newWithSize4)(THCState *state, real data0, real data1, real data2, real data3)
{
THCStorage *self = THCStorage_(newWithSize)(state, 4);
THCStorage_(set)(state, self, 0, data0);
diff --git a/lib/THC/generic/THCStorage.cu b/lib/THC/generic/THCStorage.cu
index ff0d3c9..17924f5 100644
--- a/lib/THC/generic/THCStorage.cu
+++ b/lib/THC/generic/THCStorage.cu
@@ -2,10 +2,9 @@
#define THC_GENERIC_FILE "generic/THCStorage.cu"
#else
-void THCStorage_(fill)(THCState *state, THCStorage *self, hostreal _value)
+void THCStorage_(fill)(THCState *state, THCStorage *self, real value)
{
thrust::device_ptr<real> self_data(self->data);
- real value = hostrealToReal(_value);
thrust::fill(
#if CUDA_VERSION >= 7000
thrust::cuda::par.on(THCState_getCurrentStream(state)),
diff --git a/lib/THC/generic/THCStorage.h b/lib/THC/generic/THCStorage.h
index f161d5f..a8c5f5f 100644
--- a/lib/THC/generic/THCStorage.h
+++ b/lib/THC/generic/THCStorage.h
@@ -23,15 +23,15 @@ THC_API long THCStorage_(size)(THCState *state, const THCStorage*);
THC_API int THCStorage_(elementSize)(THCState *state);
/* slow access -- checks everything */
-THC_API void THCStorage_(set)(THCState *state, THCStorage*, long, hostreal);
-THC_API hostreal THCStorage_(get)(THCState *state, const THCStorage*, long);
+THC_API void THCStorage_(set)(THCState *state, THCStorage*, long, real);
+THC_API real THCStorage_(get)(THCState *state, const THCStorage*, long);
THC_API THCStorage* THCStorage_(new)(THCState *state);
THC_API THCStorage* THCStorage_(newWithSize)(THCState *state, long size);
-THC_API THCStorage* THCStorage_(newWithSize1)(THCState *state, hostreal);
-THC_API THCStorage* THCStorage_(newWithSize2)(THCState *state, hostreal, hostreal);
-THC_API THCStorage* THCStorage_(newWithSize3)(THCState *state, hostreal, hostreal, hostreal);
-THC_API THCStorage* THCStorage_(newWithSize4)(THCState *state, hostreal, hostreal, hostreal, hostreal);
+THC_API THCStorage* THCStorage_(newWithSize1)(THCState *state, real);
+THC_API THCStorage* THCStorage_(newWithSize2)(THCState *state, real, real);
+THC_API THCStorage* THCStorage_(newWithSize3)(THCState *state, real, real, real);
+THC_API THCStorage* THCStorage_(newWithSize4)(THCState *state, real, real, real, real);
THC_API THCStorage* THCStorage_(newWithMapping)(THCState *state, const char *filename, long size, int shared);
/* takes ownership of data */
@@ -49,6 +49,6 @@ THC_API void THCStorage_(retain)(THCState *state, THCStorage *storage);
THC_API void THCStorage_(free)(THCState *state, THCStorage *storage);
THC_API void THCStorage_(resize)(THCState *state, THCStorage *storage, long size);
-THC_API void THCStorage_(fill)(THCState *state, THCStorage *storage, hostreal value);
+THC_API void THCStorage_(fill)(THCState *state, THCStorage *storage, real value);
#endif
diff --git a/lib/THC/generic/THCStorageCopy.cu b/lib/THC/generic/THCStorageCopy.cu
index b00c122..298f717 100644
--- a/lib/THC/generic/THCStorageCopy.cu
+++ b/lib/THC/generic/THCStorageCopy.cu
@@ -62,7 +62,7 @@ THC_CUDA_STORAGE_IMPLEMENT_COPY(Long,Long)
THC_CUDA_STORAGE_IMPLEMENT_COPY(Float,) // i.e. float
THC_CUDA_STORAGE_IMPLEMENT_COPY(Double,Double)
-#if CUDA_VERSION >= 7050
+#ifdef CUDA_HALF_TENSOR
#define FLOAT_COPY(TYPE) TH_CONCAT_3(TH, CReal, Storage_copyCudaFloat)
void THCStorage_(copyCudaHalf)(THCState *state, THCStorage *self, struct THCudaHalfStorage *src)
{
@@ -77,7 +77,7 @@ void THCStorage_(copyCudaHalf)(THCState *state, THCStorage *self, struct THCudaH
}
}
#undef FLOAT_COPY
-#endif // CUDA_VERSION >= 7050
+#endif // CUDA_HALF_TENSOR
#undef THC_CUDA_STORAGE_IMPLEMENT_COPY
diff --git a/lib/THC/generic/THCStorageCopy.h b/lib/THC/generic/THCStorageCopy.h
index 3313290..c3e5601 100644
--- a/lib/THC/generic/THCStorageCopy.h
+++ b/lib/THC/generic/THCStorageCopy.h
@@ -21,7 +21,7 @@ THC_API void THCStorage_(copyCudaInt)(THCState *state, THCStorage *storage, stru
THC_API void THCStorage_(copyCudaLong)(THCState *state, THCStorage *storage, struct THCudaLongStorage *src);
THC_API void THCStorage_(copyCudaFloat)(THCState *state, THCStorage *storage, struct THCudaStorage *src);
THC_API void THCStorage_(copyCudaDouble)(THCState *state, THCStorage *storage, struct THCudaDoubleStorage *src);
-#if CUDA_VERSION >= 7050
+#ifdef CUDA_HALF_TENSOR
THC_API void THCStorage_(copyCudaHalf)(THCState *state, THCStorage *storage, struct THCudaHalfStorage *src);
#endif
diff --git a/lib/THC/generic/THCTensor.c b/lib/THC/generic/THCTensor.c
index 2f87f1a..e18044d 100644
--- a/lib/THC/generic/THCTensor.c
+++ b/lib/THC/generic/THCTensor.c
@@ -730,56 +730,56 @@ void THCTensor_(rawResize)(THCState *state, THCTensor *self, int nDimension, lon
self->nDimension = 0;
}
-void THCTensor_(set1d)(THCState *state, THCTensor *tensor, long x0, hostreal value)
+void THCTensor_(set1d)(THCState *state, THCTensor *tensor, long x0, real value)
{
THArgCheck(tensor->nDimension == 1, 1, "tensor must have one dimension");
THArgCheck( (x0 >= 0) && (x0 < tensor->size[0]), 2, "out of range");
THCStorage_(set)(state, tensor->storage, tensor->storageOffset+x0*tensor->stride[0], value);
}
-hostreal THCTensor_(get1d)(THCState *state, const THCTensor *tensor, long x0)
+real THCTensor_(get1d)(THCState *state, const THCTensor *tensor, long x0)
{
THArgCheck(tensor->nDimension == 1, 1, "tensor must have one dimension");
THArgCheck( (x0 >= 0) && (x0 < tensor->size[0]), 2, "out of range");
return THCStorage_(get)(state, tensor->storage, tensor->storageOffset+x0*tensor->stride[0]);
}
-void THCTensor_(set2d)(THCState *state, THCTensor *tensor, long x0, long x1, hostreal value)
+void THCTensor_(set2d)(THCState *state, THCTensor *tensor, long x0, long x1, real value)
{
THArgCheck(tensor->nDimension == 2, 1, "tensor must have two dimensions");
THArgCheck((x0 >= 0) && (x0 < tensor->size[0]) && (x1 >= 0) && (x1 < tensor->size[1]), 2, "out of range");
THCStorage_(set)(state, tensor->storage, tensor->storageOffset+x0*tensor->stride[0]+x1*tensor->stride[1], value);
}
-hostreal THCTensor_(get2d)(THCState *state, const THCTensor *tensor, long x0, long x1)
+real THCTensor_(get2d)(THCState *state, const THCTensor *tensor, long x0, long x1)
{
THArgCheck(tensor->nDimension == 2, 1, "tensor must have two dimensions");
THArgCheck((x0 >= 0) && (x0 < tensor->size[0]) && (x1 >= 0) && (x1 < tensor->size[1]), 2, "out of range");
return THCStorage_(get)(state, tensor->storage, tensor->storageOffset+x0*tensor->stride[0]+x1*tensor->stride[1]);
}
-void THCTensor_(set3d)(THCState *state, THCTensor *tensor, long x0, long x1, long x2, hostreal value)
+void THCTensor_(set3d)(THCState *state, THCTensor *tensor, long x0, long x1, long x2, real value)
{
THArgCheck(tensor->nDimension == 3, 1, "tensor must have three dimensions");
THArgCheck( (x0 >= 0) && (x0 < tensor->size[0]) && (x1 >= 0) && (x1 < tensor->size[1]) && (x2 >= 0) && (x2 < tensor->size[2]), 2, "out of range");
THCStorage_(set)(state, tensor->storage, tensor->storageOffset+x0*tensor->stride[0]+x1*tensor->stride[1]+x2*tensor->stride[2], value);
}
-hostreal THCTensor_(get3d)(THCState *state, const THCTensor *tensor, long x0, long x1, long x2)
+real THCTensor_(get3d)(THCState *state, const THCTensor *tensor, long x0, long x1, long x2)
{
THArgCheck(tensor->nDimension == 3, 1, "tensor must have three dimensions");
THArgCheck( (x0 >= 0) && (x0 < tensor->size[0]) && (x1 >= 0) && (x1 < tensor->size[1]) && (x2 >= 0) && (x2 < tensor->size[2]), 2, "out of range");
return THCStorage_(get)(state, tensor->storage, tensor->storageOffset+x0*tensor->stride[0]+x1*tensor->stride[1]+x2*tensor->stride[2]);
}
-void THCTensor_(set4d)(THCState *state, THCTensor *tensor, long x0, long x1, long x2, long x3, hostreal value)
+void THCTensor_(set4d)(THCState *state, THCTensor *tensor, long x0, long x1, long x2, long x3, real value)
{
THArgCheck(tensor->nDimension == 4, 1, "tensor must have four dimensions");
THArgCheck((x0 >= 0) && (x0 < tensor->size[0]) && (x1 >= 0) && (x1 < tensor->size[1]) && (x2 >= 0) && (x2 < tensor->size[2]) && (x3 >= 0) && (x3 < tensor->size[3]), 2, "out of range");
THCStorage_(set)(state, tensor->storage, tensor->storageOffset+x0*tensor->stride[0]+x1*tensor->stride[1]+x2*tensor->stride[2]+x3*tensor->stride[3], value);
}
-hostreal THCTensor_(get4d)(THCState *state, const THCTensor *tensor, long x0, long x1, long x2, long x3)
+real THCTensor_(get4d)(THCState *state, const THCTensor *tensor, long x0, long x1, long x2, long x3)
{
THArgCheck(tensor->nDimension == 4, 1, "tensor must have four dimensions");
THArgCheck((x0 >= 0) && (x0 < tensor->size[0]) && (x1 >= 0) && (x1 < tensor->size[1]) && (x2 >= 0) && (x2 < tensor->size[2]) && (x3 >= 0) && (x3 < tensor->size[3]), 2, "out of range");
diff --git a/lib/THC/generic/THCTensor.h b/lib/THC/generic/THCTensor.h
index 8e4d1a4..175eaee 100644
--- a/lib/THC/generic/THCTensor.h
+++ b/lib/THC/generic/THCTensor.h
@@ -112,15 +112,15 @@ THC_API void THCTensor_(free)(THCState *state, THCTensor *self);
THC_API void THCTensor_(freeCopyTo)(THCState *state, THCTensor *self, THCTensor *dst);
/* Slow access methods [check everything] */
-THC_API void THCTensor_(set1d)(THCState *state, THCTensor *tensor, long x0, hostreal value);
-THC_API void THCTensor_(set2d)(THCState *state, THCTensor *tensor, long x0, long x1, hostreal value);
-THC_API void THCTensor_(set3d)(THCState *state, THCTensor *tensor, long x0, long x1, long x2, hostreal value);
-THC_API void THCTensor_(set4d)(THCState *state, THCTensor *tensor, long x0, long x1, long x2, long x3, hostreal value);
-
-THC_API hostreal THCTensor_(get1d)(THCState *state, const THCTensor *tensor, long x0);
-THC_API hostreal THCTensor_(get2d)(THCState *state, const THCTensor *tensor, long x0, long x1);
-THC_API hostreal THCTensor_(get3d)(THCState *state, const THCTensor *tensor, long x0, long x1, long x2);
-THC_API hostreal THCTensor_(get4d)(THCState *state, const THCTensor *tensor, long x0, long x1, long x2, long x3);
+THC_API void THCTensor_(set1d)(THCState *state, THCTensor *tensor, long x0, real value);
+THC_API void THCTensor_(set2d)(THCState *state, THCTensor *tensor, long x0, long x1, real value);
+THC_API void THCTensor_(set3d)(THCState *state, THCTensor *tensor, long x0, long x1, long x2, real value);
+THC_API void THCTensor_(set4d)(THCState *state, THCTensor *tensor, long x0, long x1, long x2, long x3, real value);
+
+THC_API real THCTensor_(get1d)(THCState *state, const THCTensor *tensor, long x0);
+THC_API real THCTensor_(get2d)(THCState *state, const THCTensor *tensor, long x0, long x1);
+THC_API real THCTensor_(get3d)(THCState *state, const THCTensor *tensor, long x0, long x1, long x2);
+THC_API real THCTensor_(get4d)(THCState *state, const THCTensor *tensor, long x0, long x1, long x2, long x3);
/* CUDA-specific functions */
THC_API cudaTextureObject_t THCTensor_(getTextureObject)(THCState *state, THCTensor *self);
diff --git a/lib/THC/generic/THCTensorCopy.c b/lib/THC/generic/THCTensorCopy.c
index 68f57bf..e0bcadd 100644
--- a/lib/THC/generic/THCTensorCopy.c
+++ b/lib/THC/generic/THCTensorCopy.c
@@ -126,7 +126,6 @@ IMPLEMENT_TH_CUDA_TENSOR_COPY_TO(Long)
IMPLEMENT_TH_CUDA_TENSOR_COPY_TO(Float)
IMPLEMENT_TH_CUDA_TENSOR_COPY_TO(Double)
-// FIXME: add within-CUDA conversions
void THCTensor_(copyCuda)(THCState *state, THCTensor *self, THCTensor *src)
{
THCTensor_(copy)(state, self, src);
diff --git a/lib/THC/generic/THCTensorCopy.cu b/lib/THC/generic/THCTensorCopy.cu
index c5768e2..4198025 100644
--- a/lib/THC/generic/THCTensorCopy.cu
+++ b/lib/THC/generic/THCTensorCopy.cu
@@ -4,251 +4,44 @@
THC_API void
THCTensor_(copy)(THCState* state, THCTensor* dst, THCTensor* src) {
- long totalElements = THCTensor_(nElement)(state, dst);
-
- THArgCheck(totalElements == THCTensor_(nElement)(state, src), 2,
- "sizes do not match");
-
- if (THCTensor_(nDimension)(state, dst) == 0) {
- // Zero-dim tensor; copy nothing
- return;
- }
-
- // We can memcpy the memory if:
- // -both tensors are contiguous; or,
- // -there is only one element to copy; or,
- // -FIXME: if both tensors have matching size and stride arrays, and no
- // holes within (in other words, there is some permutation that can be applied
- // to the size/strides such that the resulting tensor is contiguous).
- bool srcContig = THCTensor_(isContiguous)(state, src);
- bool dstContig = THCTensor_(isContiguous)(state, dst);
- bool memcpyEligible = (srcContig && dstContig) || (totalElements == 1);
-
- int srcDev = THCTensor_(getDevice)(state, src);
- int dstDev = THCTensor_(getDevice)(state, dst);
- int oldDev = curGPU();
-
- // We always perform the copy on the source device, using the
- // current stream on the source device.
- // If the copy is on the default stream, then we fully synchronize
- // both src and dst's default streams for completion of the
- // copy. We have to explicitly do this for non-contig copies.
- // This mimics the behavior of cross-device cudaMemcpyAsync on
- // the default stream.
- // If the copy is not on the default stream, then it is up to the
- // user to add needed synchronization on the dst device, since the
- // stream on the dst device that wishes to synchronize may not be
- // the same index as the one on the src device.
- int copyStreamIndex =
- THCState_getCurrentStreamIndex(state);
- cudaStream_t copyStream =
- THCState_getDeviceStream(state, srcDev, copyStreamIndex);
-
- if (srcDev != dstDev && copyStreamIndex == 0) {
- // This is a cross-device copy on the default stream. We perform a
- // two-way barrier between both devices' default streams before
- // the copy. This ensures that any write-after-write and
- // write-after-read dependencies on the destination side are
- // handled, so that no one is operating on the dst memory when
- // we perform the copy.
- // src waits on dst barrier (src already waits on src)
- cudaEvent_t dstReady;
- THCudaCheck(cudaSetDevice(dstDev));
- THCudaCheck(cudaEventCreateWithFlags(&dstReady, cudaEventDisableTiming));
- THCudaCheck(cudaEventRecord(dstReady, NULL));
-
- THCudaCheck(cudaSetDevice(srcDev));
- THCudaCheck(cudaStreamWaitEvent(NULL, dstReady, 0));
- THCudaCheck(cudaEventDestroy(dstReady));
- } else if (srcDev != oldDev) {
- THCudaCheck(cudaSetDevice(srcDev));
- }
-
- // We are now on srcDev
- if (memcpyEligible) {
- // Perform the copy
- THCudaCheck(cudaMemcpyAsync(THCTensor_(data)(state, dst),
- THCTensor_(data)(state, src),
- totalElements * sizeof(real),
- cudaMemcpyDeviceToDevice,
- copyStream));
- } else {
-#if defined(THC_REAL_IS_FLOAT)
- // Non-contiguous copy
-
- // We avoid creating temporary memory copies if possible.
- // If both src and dst are on the same device, or if they are on
- // different devices and p2p access is enabled, perform the copy
- // by a pointwise copy kernel.
- // Otherwise, we'll have to make contiguous (which will in fact
- // invoke copy() again), and then perform the copy.
- // FIXME: might want to consider only running the pointwise kernel
- // if both src and dst innermost dimensions are contiguous. If
- // they are not, then taking the hit of the memory allocation/free
- // might be worth it to avoid non-coalesced reads or writes.
-
- // A device always has access to itself, so this also handles the
- // case srcDev == dstDev
- if (THCState_getPeerToPeerAccess(state, srcDev, dstDev)) {
- // Make sure we have the current stream set in THCState, since
- // pointwise uses that
- if (srcDev != oldDev) {
- THCState_setStream(state, srcDev, copyStreamIndex);
- }
-
- bool succ =
- THCudaTensor_pointwiseApply2(state, dst, src, CopyOp<float>());
- THArgCheck(succ, 2, CUTORCH_DIM_WARNING);
-
- // Restore prior THCState stream
- if (srcDev != oldDev) {
- THCState_setStream(state, oldDev, copyStreamIndex);
- }
- } else {
- // GPUs can't access each other directly; fall back to
- // newContiguous and memcpy
- THCudaTensor* srcContig = THCudaTensor_newContiguous(state, src);
- THCudaTensor* dstContig = dst;
-
- if (!THCudaTensor_isContiguous(state, dst)) {
- // We are copying over the contents of dst, so we don't need
- // to preserve its values. We just need a destination tensor
- // the same size as dst.
-
- // Allocate the tensor on the new device
- THCudaCheck(cudaSetDevice(dstDev));
-
- dstContig = THCudaTensor_new(state);
- THCudaTensor_resizeAs(state, dstContig, dst);
-
- THCudaCheck(cudaSetDevice(srcDev));
- }
-
- THCudaCheck(cudaMemcpyAsync(THCudaTensor_data(state, dstContig),
- THCudaTensor_data(state, srcContig),
- totalElements * sizeof(float),
- cudaMemcpyDeviceToDevice,
- copyStream));
-
- THCudaTensor_free(state, srcContig);
-
- if (dst != dstContig) {
- THCudaTensor_freeCopyTo(state, dstContig, dst);
- }
- }
-#else
-#define STRINGIFY(x) #x
- THError("Non-contiguous copy not implemented for Cuda%sTensor", STRINGIFY(Real));
-#undef STRINGIFY
-#endif
- }
-
- if (srcDev != dstDev && copyStreamIndex == 0) {
- // dst waits on src barrier (dst already waits on dst). We cannot
- // operate on dst's copy until the copy is complete.
-
- // Still on srcDev, record default stream event
- cudaEvent_t srcReady;
- THCudaCheck(cudaEventCreateWithFlags(&srcReady, cudaEventDisableTiming));
- THCudaCheck(cudaEventRecord(srcReady, NULL));
-
- THCudaCheck(cudaSetDevice(dstDev));
- THCudaCheck(cudaStreamWaitEvent(NULL, srcReady, 0));
- THCudaCheck(cudaEventDestroy(srcReady));
-
- // We are now on dstDev (right above). Restore prior device from dst
- if (dstDev != oldDev) {
- THCudaCheck(cudaSetDevice(oldDev));
- }
- } else {
- // We are still on srcDev. Restore prior device from src
- if (srcDev != oldDev) {
- THCudaCheck(cudaSetDevice(oldDev));
- }
- }
+ THC_copyTensor<THCTensor, THCTensor>(state, dst, src);
+}
- cudaError errcode = cudaGetLastError();
- if (errcode != cudaSuccess) {
- THError(cudaGetErrorString(errcode));
- }
+THC_API void
+THCTensor_(copyIgnoringOverlaps)(THCState* state, THCTensor* dst, THCTensor* src) {
+ // Called when we are copying into an overlapping index `dst`, but
+ // we don't care which writer wins. Hacky but it works.
+ // This is itself invoked by pointwiseApply2 / THCTensor_copy in
+ // case that there are write overlaps.
+ // FIXME: really, overlapping writes should be illegal/an error in Torch
+ THC_pointwiseApply2(
+ state, dst, src,
+ CopyOp<typename TensorUtils<THCTensor>::DataType,
+ typename TensorUtils<THCTensor>::DataType>(),
+ ReadOnly, /* ignore overwrites */
+ ReadOnly);
}
-// conversions are mediated by the CPU
-// yes, this is slow; feel free to write CUDA kernels for this
-#ifndef THC_REAL_IS_HALF
-#define THC_CUDA_TENSOR_IMPLEMENT_COPY(TYPEC,TYPECUDA) \
- void THCTensor_(copyCuda##TYPEC)(THCState *state, THCTensor *self, struct THCuda##TYPECUDA##Tensor *src) \
- { \
- if(THCTypeIdx_(Real) == THCTypeIdx_(TYPEC)) { \
- THCTensor_(copy)(state, self, (THCTensor*) src); /* cast just removes compiler warning */ \
- } else { \
- THArgCheck(THCTensor_(nElement)(state, self) == THCuda##TYPECUDA##Tensor_nElement(state, src), 2, "size does not match"); \
- THLongStorage *size = THCuda##TYPECUDA##Tensor_newSizeOf(state, src); \
- TH##TYPEC##Tensor *buffer1 = TH##TYPEC##Tensor_newWithSize(size, NULL); \
- THTensor *buffer2 = THTensor_(newWithSize)(size, NULL); \
- TH##TYPEC##Tensor_copyCuda(state, buffer1, src); \
- THTensor_(copy##TYPEC)(buffer2, buffer1); \
- THCTensor_(copyCPU)(state, self, buffer2); \
- THLongStorage_free(size); \
- TH##TYPEC##Tensor_free(buffer1); \
- THTensor_(free)(buffer2); \
- } \
- }
-#else
-#define THC_CUDA_TENSOR_IMPLEMENT_COPY(TYPEC,TYPECUDA) \
- void THCTensor_(copyCuda##TYPEC)(THCState *state, THCTensor *self, struct THCuda##TYPECUDA##Tensor *src) \
- { \
- THArgCheck(THCTensor_(nElement)(state, self) == THCuda##TYPECUDA##Tensor_nElement(state, src), 2, "size does not match"); \
- if (THCTypeIdx_(TYPEC) == THCTypeIdxFloat) { \
- THCudaTensor *csrc = THCudaTensor_newContiguous(state, (THCudaTensor*) src); /* cast removes compiler error */ \
- THCFloat2Half(state, \
- THCTensor_(data)(state, self), \
- THCudaTensor_data(state, csrc), \
- THCudaTensor_nElement(state, csrc)); \
- THCudaTensor_free(state, csrc); \
- } else { \
- THLongStorage *size = THCuda##TYPECUDA##Tensor_newSizeOf(state, src); \
- THCudaTensor *buffer = THCudaTensor_newWithSize(state, size, NULL); \
- THCudaTensor_copyCuda##TYPEC(state, buffer, src); \
- THCFloat2Half(state, \
- THCTensor_(data)(state, self), \
- THCudaTensor_data(state, buffer), \
- THCudaTensor_nElement(state, buffer)); \
- THCudaTensor_free(state, buffer); \
- THLongStorage_free(size); \
- } \
+#define IMPLEMENT_THC_CUDA_TENSOR_COPY(TYPEC, TYPECUDA) \
+ THC_API void \
+ THCTensor_(copyCuda##TYPEC)(THCState *state, \
+ THCTensor *self, \
+ THCuda##TYPECUDA##Tensor *src) { \
+ THC_copyTensor<THCTensor, THCuda##TYPECUDA##Tensor>(state, self, src); \
}
-#endif
-
-THC_CUDA_TENSOR_IMPLEMENT_COPY(Byte,Byte)
-THC_CUDA_TENSOR_IMPLEMENT_COPY(Char,Char)
-THC_CUDA_TENSOR_IMPLEMENT_COPY(Short,Short)
-THC_CUDA_TENSOR_IMPLEMENT_COPY(Int,Int)
-THC_CUDA_TENSOR_IMPLEMENT_COPY(Long,Long)
-THC_CUDA_TENSOR_IMPLEMENT_COPY(Float,) // i.e. float
-THC_CUDA_TENSOR_IMPLEMENT_COPY(Double,Double)
-#if CUDA_VERSION >= 7050
-#define FLOAT_COPY(TYPE) TH_CONCAT_3(TH, CReal, Tensor_copyCudaFloat)
-void THCTensor_(copyCudaHalf)(THCState *state, THCTensor *self, struct THCudaHalfTensor *src)
-{
- if(THCTypeIdx_(Real) == THCTypeIdxHalf) {
- THCTensor_(copy)(state, self, (THCTensor*) src); /* cast removes compiler error */
- } else {
- THArgCheck(THCTensor_(nElement)(state, self) == THCudaHalfTensor_nElement(state, src), 2, "size does not match");
- src = THCudaHalfTensor_newContiguous(state, src);
- THLongStorage *size = THCudaHalfTensor_newSizeOf(state, src);
- THCudaTensor *buffer = THCudaTensor_newWithSize(state, size, NULL);
- THCHalf2Float(state, THCudaTensor_data(state, buffer), THCudaHalfTensor_data(state, src), THCudaHalfTensor_nElement(state, src));
- FLOAT_COPY(Real)(state, self, buffer);
- THCudaTensor_free(state, buffer);
- THCudaHalfTensor_free(state, src);
- THLongStorage_free(size);
- }
-}
-#undef FLOAT_COPY
-#endif // CUDA_VERSION >= 7050
+IMPLEMENT_THC_CUDA_TENSOR_COPY(Byte, Byte)
+IMPLEMENT_THC_CUDA_TENSOR_COPY(Char, Char)
+IMPLEMENT_THC_CUDA_TENSOR_COPY(Short, Short)
+IMPLEMENT_THC_CUDA_TENSOR_COPY(Int, Int)
+IMPLEMENT_THC_CUDA_TENSOR_COPY(Long, Long)
+// THCudaTensor aka the non-existent THCudaFloatTensor
+IMPLEMENT_THC_CUDA_TENSOR_COPY(Float, )
+IMPLEMENT_THC_CUDA_TENSOR_COPY(Double, Double)
+#ifdef CUDA_HALF_TENSOR
+IMPLEMENT_THC_CUDA_TENSOR_COPY(Half, Half)
+#endif
-#undef THC_CUDA_TENSOR_IMPLEMENT_COPY
+#undef IMPLEMENT_THC_CUDA_TENSOR_COPY
#endif
diff --git a/lib/THC/generic/THCTensorCopy.h b/lib/THC/generic/THCTensorCopy.h
index 8a2837f..3c7649b 100644
--- a/lib/THC/generic/THCTensorCopy.h
+++ b/lib/THC/generic/THCTensorCopy.h
@@ -3,6 +3,7 @@
#else
THC_API void THCTensor_(copy)(THCState *state, THCTensor *self, THCTensor *src);
+THC_API void THCTensor_(copyIgnoringOverlaps)(THCState *state, THCTensor *self, THCTensor *src);
THC_API void THCTensor_(copyByte)(THCState *state, THCTensor *self, THByteTensor *src);
THC_API void THCTensor_(copyChar)(THCState *state, THCTensor *self, THCharTensor *src);
THC_API void THCTensor_(copyShort)(THCState *state, THCTensor *self, THShortTensor *src);
@@ -18,7 +19,7 @@ THC_API void THCTensor_(copyCudaInt)(THCState *state, THCTensor *storage, struct
THC_API void THCTensor_(copyCudaLong)(THCState *state, THCTensor *storage, struct THCudaLongTensor *src);
THC_API void THCTensor_(copyCudaFloat)(THCState *state, THCTensor *storage, struct THCudaTensor *src);
THC_API void THCTensor_(copyCudaDouble)(THCState *state, THCTensor *storage, struct THCudaDoubleTensor *src);
-#if CUDA_VERSION >= 7050
+#ifdef CUDA_HALF_TENSOR
THC_API void THCTensor_(copyCudaHalf)(THCState *state, THCTensor *storage, struct THCudaHalfTensor *src);
#endif
diff --git a/lib/THC/generic/THCTensorMath.cu b/lib/THC/generic/THCTensorMath.cu
new file mode 100644
index 0000000..a0e550a
--- /dev/null
+++ b/lib/THC/generic/THCTensorMath.cu
@@ -0,0 +1,68 @@
+#ifndef THC_GENERIC_FILE
+#define THC_GENERIC_FILE "generic/THCTensorMath.cu"
+#else
+
+THC_API void
+THCTensor_(fill)(THCState* state, THCTensor *self_, real value)
+{
+ THAssert(THCTensor_(checkGPU)(state, 1, self_));
+
+ if (!THC_pointwiseApply1(
+ state, self_, TensorFillOp<real>(value))) {
+ THArgCheck(false, 1, CUTORCH_DIM_WARNING);
+ }
+
+ THCudaCheck(cudaGetLastError());
+}
+
+THC_API void
+THCTensor_(zero)(THCState *state, THCTensor *self_)
+{
+ THAssert(THCTensor_(checkGPU)(state, 1, self_));
+ if (THCTensor_(isContiguous)(state, self_)) {
+ THCudaCheck(cudaMemsetAsync(THCTensor_(data)(state, self_),
+ 0,
+ sizeof(real) * THCTensor_(nElement)(state, self_),
+ THCState_getCurrentStream(state)));
+ } else {
+ if (!THC_pointwiseApply1(
+ state, self_,
+ TensorFillOp<real>(ScalarConvert<int, real>::to(0)))) {
+ THArgCheck(false, 1, CUTORCH_DIM_WARNING);
+ }
+ }
+
+ THCudaCheck(cudaGetLastError());
+}
+
+THC_API void
+THCTensor_(zeros)(THCState *state, THCTensor *r_, THLongStorage *size)
+{
+ THAssert(THCTensor_(checkGPU)(state, 1, r_));
+ THCTensor_(resize)(state, r_, size, NULL);
+ THCTensor_(zero)(state, r_);
+}
+
+THC_API void
+THCTensor_(ones)(THCState *state, THCTensor *r_, THLongStorage *size)
+{
+ THAssert(THCTensor_(checkGPU)(state, 1, r_));
+ THCTensor_(resize)(state, r_, size, NULL);
+ THCTensor_(fill)(state, r_, ScalarConvert<int, real>::to(1));
+}
+
+THC_API void
+THCTensor_(reshape)(THCState *state, THCTensor *r_, THCTensor *t, THLongStorage *size)
+{
+ THAssert(THCTensor_(checkGPU)(state, 2, r_, t));
+ THCTensor_(resize)(state, r_, size, NULL);
+ THCTensor_(copy)(state, r_, t);
+}
+
+long
+THCTensor_(numel)(THCState *state, THCTensor *t)
+{
+ return THCTensor_(nElement)(state, t);
+}
+
+#endif
diff --git a/lib/THC/generic/THCTensorMath.h b/lib/THC/generic/THCTensorMath.h
new file mode 100644
index 0000000..5c9e66d
--- /dev/null
+++ b/lib/THC/generic/THCTensorMath.h
@@ -0,0 +1,13 @@
+#ifndef THC_GENERIC_FILE
+#define THC_GENERIC_FILE "generic/THCTensorMath.h"
+#else
+
+THC_API void THCTensor_(fill)(THCState *state, THCTensor *self, real value);
+THC_API void THCTensor_(zero)(THCState *state, THCTensor *self);
+
+THC_API void THCTensor_(zeros)(THCState *state, THCTensor *r_, THLongStorage *size);
+THC_API void THCTensor_(ones)(THCState *state, THCTensor *r_, THLongStorage *size);
+THC_API void THCTensor_(reshape)(THCState *state, THCTensor *r_, THCTensor *t, THLongStorage *size);
+THC_API long THCTensor_(numel)(THCState *state, THCTensor *t);
+
+#endif
diff --git a/lib/THC/generic/THCTensorMathPairwise.cu b/lib/THC/generic/THCTensorMathPairwise.cu
new file mode 100644
index 0000000..4a5c09d
--- /dev/null
+++ b/lib/THC/generic/THCTensorMathPairwise.cu
@@ -0,0 +1,74 @@
+#ifndef THC_GENERIC_FILE
+#define THC_GENERIC_FILE "generic/THCTensorMathPairwise.cu"
+#else
+
+THC_API void
+THCTensor_(add)(THCState *state, THCTensor *self_, THCTensor *src_, real value)
+{
+ THAssert(THCTensor_(checkGPU)(state, 2, self_, src_));
+ if (self_ == src_) {
+ if (!THC_pointwiseApply1(state, self_, TensorAddConstantOp<real>(value))) {
+ THArgCheck(false, 2, CUTORCH_DIM_WARNING);
+ }
+ } else {
+ THCTensor_(resizeAs)(state, self_, src_);
+
+ if (!THC_pointwiseApply2(state, self_, src_, TensorAddConstantOp<real>(value))) {
+ THArgCheck(false, 2, CUTORCH_DIM_WARNING);
+ }
+ }
+
+ THCudaCheck(cudaGetLastError());
+}
+
+THC_API void
+THCTensor_(sub)(THCState *state, THCTensor *self_, THCTensor *src_, real value)
+{
+ THCTensor_(add)(state, self_, src_, ScalarNegate<real>::to(value));
+}
+
+THC_API void
+THCTensor_(mul)(THCState *state, THCTensor *self_, THCTensor *src_, real value)
+{
+ THAssert(THCTensor_(checkGPU)(state, 2, self_, src_));
+ if (self_ == src_) {
+ if (!THC_pointwiseApply1(state, self_, TensorMulConstantOp<real>(value))) {
+ THArgCheck(false, 2, CUTORCH_DIM_WARNING);
+ }
+ } else {
+ THCTensor_(resizeAs)(state, self_, src_);
+
+ if (!THC_pointwiseApply2(state, self_, src_, TensorMulConstantOp<real>(value))) {
+ THArgCheck(false, 2, CUTORCH_DIM_WARNING);
+ }
+ }
+
+ THCudaCheck(cudaGetLastError());
+}
+
+THC_API void
+THCTensor_(div)(THCState* state, THCTensor *self_, THCTensor *src_, real value)
+{
+ THAssert(THCTensor_(checkGPU)(state, 2, self_, src_));
+ THArgCheck(value != ScalarConvert<int, real>::to(0), 3, "divide by zero");
+
+ if (self_ == src_) {
+ if (!THC_pointwiseApply1(state, self_,
+ TensorMulConstantOp<real>(
+ ScalarInv<real>::to(value)))) {
+ THArgCheck(false, 2, CUTORCH_DIM_WARNING);
+ }
+ } else {
+ THCTensor_(resizeAs)(state, self_, src_);
+
+ if (!THC_pointwiseApply2(state, self_, src_,
+ TensorMulConstantOp<real>(
+ ScalarInv<real>::to(value)))) {
+ THArgCheck(false, 2, CUTORCH_DIM_WARNING);
+ }
+ }
+
+ THCudaCheck(cudaGetLastError());
+}
+
+#endif
diff --git a/lib/THC/generic/THCTensorMathPairwise.h b/lib/THC/generic/THCTensorMathPairwise.h
new file mode 100644
index 0000000..9a83293
--- /dev/null
+++ b/lib/THC/generic/THCTensorMathPairwise.h
@@ -0,0 +1,10 @@
+#ifndef THC_GENERIC_FILE
+#define THC_GENERIC_FILE "generic/THCTensorMathPairwise.h"
+#else
+
+// Tensor-with-scalar operations. In each implementation, when self and src
+// are different tensors self is resized like src before the op is applied;
+// when they are the same pointer the op is applied to that one tensor.
+
+// Applies TensorAddConstantOp<real>(value).
+THC_API void THCTensor_(add)(THCState *state, THCTensor *self, THCTensor *src, real value);
+// Forwards to THCTensor_(add) with ScalarNegate<real>::to(value).
+THC_API void THCTensor_(sub)(THCState *state, THCTensor *self, THCTensor *src, real value);
+// Applies TensorMulConstantOp<real>(value).
+THC_API void THCTensor_(mul)(THCState *state, THCTensor *self, THCTensor *src, real value);
+// Rejects a zero value ("divide by zero"), then applies
+// TensorMulConstantOp<real>(ScalarInv<real>::to(value)).
+THC_API void THCTensor_(div)(THCState *state, THCTensor *self, THCTensor *src, real value);
+
+#endif
diff --git a/lib/THC/generic/THCTensorMathPointwise.cu b/lib/THC/generic/THCTensorMathPointwise.cu
new file mode 100644
index 0000000..b6679cd
--- /dev/null
+++ b/lib/THC/generic/THCTensorMathPointwise.cu
@@ -0,0 +1,157 @@
+#ifndef THC_GENERIC_FILE
+#define THC_GENERIC_FILE "generic/THCTensorMathPointwise.cu"
+#else
+
+// Element-wise scaled addition: self = src1 + value * src2.
+//   state - THC state handed to every helper call
+//   self_ - destination; resized like src1 when it is a different tensor,
+//           updated in place when it is the same pointer as src1
+//   src1  - first operand
+//   value - scalar applied to src2; when it equals
+//           ScalarConvert<int, real>::to(1) TensorAddOp is used, otherwise
+//           TensorCAddOp<real>(value)
+//   src2  - second operand; must hold as many elements as src1
+//           ("sizes do not match", argument 3)
+// An apply helper returning false is raised as an argument error (argument 2,
+// CUTORCH_DIM_WARNING); cudaGetLastError() is checked before returning.
+THC_API void
+THCTensor_(cadd)(THCState *state, THCTensor *self_, THCTensor* src1, real value, THCTensor *src2)
+{
+  THAssert(THCTensor_(checkGPU)(state, 3, self_, src1, src2));
+  THArgCheck(THCTensor_(nElement)(state, src1) ==
+             THCTensor_(nElement)(state, src2), 3, "sizes do not match");
+
+  const bool unitScale = (value == ScalarConvert<int, real>::to(1));
+  bool applied;
+
+  if (self_ != src1) {
+    THCTensor_(resizeAs)(state, self_, src1);
+
+    if (unitScale) {
+      // self = src1 + src2
+      applied = THC_pointwiseApply3(state, self_, src1, src2,
+                                    TensorAddOp<real>());
+    } else {
+      // self = src1 + value * src2
+      applied = THC_pointwiseApply3(state, self_, src1, src2,
+                                    TensorCAddOp<real>(value));
+    }
+  } else {
+    if (unitScale) {
+      // self += src2
+      applied = THC_pointwiseApply2(state, self_, src2,
+                                    TensorAddOp<real>());
+    } else {
+      // self += value * src2
+      applied = THC_pointwiseApply2(state, self_, src2,
+                                    TensorCAddOp<real>(value));
+    }
+  }
+  THArgCheck(applied, 2, CUTORCH_DIM_WARNING);
+
+  THCudaCheck(cudaGetLastError());
+}
+
+// Element-wise scaled subtraction: self = src1 - value * src2.
+//   state - THC state handed to every helper call
+//   self_ - destination; resized like src1 when it is a different tensor,
+//           updated in place when it is the same pointer as src1
+//   src1  - first operand
+//   value - scalar applied to src2; when it equals
+//           ScalarConvert<int, real>::to(1) TensorSubOp is used, otherwise
+//           TensorCAddOp is used with ScalarNegate<real>::to(value)
+//   src2  - second operand; must hold as many elements as src1
+//           ("sizes do not match", argument 3)
+// An apply helper returning false is raised as an argument error (argument 2,
+// CUTORCH_DIM_WARNING); cudaGetLastError() is checked before returning.
+THC_API void
+THCTensor_(csub)(THCState *state, THCTensor *self_, THCTensor* src1, real value, THCTensor *src2)
+{
+  THAssert(THCTensor_(checkGPU)(state, 3, self_, src1, src2));
+  THArgCheck(THCTensor_(nElement)(state, src1) ==
+             THCTensor_(nElement)(state, src2), 3, "sizes do not match");
+
+  const bool unitScale = (value == ScalarConvert<int, real>::to(1));
+  bool applied;
+
+  if (self_ != src1) {
+    THCTensor_(resizeAs)(state, self_, src1);
+
+    if (unitScale) {
+      // self = src1 - src2
+      applied = THC_pointwiseApply3(state, self_, src1, src2,
+                                    TensorSubOp<real>());
+    } else {
+      // self = src1 - value * src2, written as an add of the negated scalar
+      applied = THC_pointwiseApply3(state, self_, src1, src2,
+                                    TensorCAddOp<real>(
+                                      ScalarNegate<real>::to(value)));
+    }
+  } else {
+    if (unitScale) {
+      // self -= src2
+      applied = THC_pointwiseApply2(state, self_, src2,
+                                    TensorSubOp<real>());
+    } else {
+      // self += -value * src2
+      applied = THC_pointwiseApply2(state, self_, src2,
+                                    TensorCAddOp<real>(
+                                      ScalarNegate<real>::to(value)));
+    }
+  }
+  THArgCheck(applied, 2, CUTORCH_DIM_WARNING);
+
+  THCudaCheck(cudaGetLastError());
+}
+
+// Element-wise product: self = src1 * src2, using TensorMulOp<real>.
+//   state - THC state handed to every helper call
+//   self_ - destination; resized like src1 when it is a different tensor,
+//           updated in place when it is the same pointer as src1
+//   src1  - first operand
+//   src2  - second operand; must hold as many elements as src1
+//           ("sizes do not match", argument 3)
+// An apply helper returning false is raised as an argument error (argument 2,
+// CUTORCH_DIM_WARNING); cudaGetLastError() is checked before returning.
+THC_API void
+THCTensor_(cmul)(THCState *state, THCTensor *self_, THCTensor *src1, THCTensor *src2)
+{
+  THAssert(THCTensor_(checkGPU)(state, 3, self_, src1, src2));
+  THArgCheck(THCTensor_(nElement)(state, src1) ==
+             THCTensor_(nElement)(state, src2), 3, "sizes do not match");
+
+  bool applied;
+  if (self_ != src1) {
+    THCTensor_(resizeAs)(state, self_, src1);
+
+    // self = src1 * src2
+    applied = THC_pointwiseApply3(state, self_, src1, src2,
+                                  TensorMulOp<real>());
+  } else {
+    // self *= src2
+    applied = THC_pointwiseApply2(state, self_, src2, TensorMulOp<real>());
+  }
+  THArgCheck(applied, 2, CUTORCH_DIM_WARNING);
+
+  THCudaCheck(cudaGetLastError());
+}
+
+// Element-wise power: self = pow(src1, src2), using TensorCPowOp<real>.
+//   state - THC state handed to every helper call
+//   self_ - destination; resized like src1 when it is a different tensor,
+//           updated in place when it is the same pointer as src1
+//   src1  - bases
+//   src2  - exponents; must hold as many elements as src1
+//           ("sizes do not match", argument 3)
+// An apply helper returning false is raised as an argument error (argument 2,
+// CUTORCH_DIM_WARNING); cudaGetLastError() is checked before returning.
+THC_API void
+THCTensor_(cpow)(THCState *state, THCTensor *self_, THCTensor *src1, THCTensor *src2)
+{
+  THAssert(THCTensor_(checkGPU)(state, 3, self_, src1, src2));
+  THArgCheck(THCTensor_(nElement)(state, src1) ==
+             THCTensor_(nElement)(state, src2), 3, "sizes do not match");
+
+  bool applied;
+  if (self_ != src1) {
+    THCTensor_(resizeAs)(state, self_, src1);
+
+    // self = pow(src1, src2)
+    applied = THC_pointwiseApply3(state, self_, src1, src2,
+                                  TensorCPowOp<real>());
+  } else {
+    // self = pow(self, src2)
+    applied = THC_pointwiseApply2(state, self_, src2, TensorCPowOp<real>());
+  }
+  THArgCheck(applied, 2, CUTORCH_DIM_WARNING);
+
+  THCudaCheck(cudaGetLastError());
+}
+
+// Element-wise quotient: self = src1 / src2, using TensorDivOp<real>.
+//   state - THC state handed to every helper call
+//   self_ - destination; resized like src1 when it is a different tensor,
+//           updated in place when it is the same pointer as src1
+//   src1  - dividends
+//   src2  - divisors; must hold as many elements as src1
+//           ("sizes do not match", argument 3)
+// An apply helper returning false is raised as an argument error (argument 2,
+// CUTORCH_DIM_WARNING); cudaGetLastError() is checked before returning.
+THC_API void
+THCTensor_(cdiv)(THCState* state, THCTensor *self_, THCTensor *src1, THCTensor *src2)
+{
+  THAssert(THCTensor_(checkGPU)(state, 3, self_, src1, src2));
+  THArgCheck(THCTensor_(nElement)(state, src1) ==
+             THCTensor_(nElement)(state, src2), 3, "sizes do not match");
+
+  bool applied;
+  if (self_ != src1) {
+    THCTensor_(resizeAs)(state, self_, src1);
+
+    // self = src1 / src2
+    applied = THC_pointwiseApply3(state, self_, src1, src2,
+                                  TensorDivOp<real>());
+  } else {
+    // self /= src2
+    applied = THC_pointwiseApply2(state, self_, src2, TensorDivOp<real>());
+  }
+  THArgCheck(applied, 2, CUTORCH_DIM_WARNING);
+
+  THCudaCheck(cudaGetLastError());
+}
+
+#endif
diff --git a/lib/THC/generic/THCTensorMathPointwise.h b/lib/THC/generic/THCTensorMathPointwise.h
new file mode 100644
index 0000000..cfb3b14
--- /dev/null
+++ b/lib/THC/generic/THCTensorMathPointwise.h
@@ -0,0 +1,11 @@
+#ifndef THC_GENERIC_FILE
+#define THC_GENERIC_FILE "generic/THCTensorMathPointwise.h"
+#else
+
+// Tensor-with-tensor element-wise operations. Each implementation requires
+// src1 and src2 to hold the same number of elements ("sizes do not match");
+// self is resized like src1 unless it is the same pointer as src1, in which
+// case it is updated in place.
+
+// self = src1 + value * src2
+THC_API void THCTensor_(cadd)(THCState *state, THCTensor *self, THCTensor *src1, real value, THCTensor *src2);
+// self = src1 - value * src2
+THC_API void THCTensor_(csub)(THCState *state, THCTensor *self, THCTensor *src1, real value, THCTensor *src2);
+// self = src1 * src2
+THC_API void THCTensor_(cmul)(THCState *state, THCTensor *self, THCTensor *src1, THCTensor *src2);
+// self = pow(src1, src2)
+THC_API void THCTensor_(cpow)(THCState *state, THCTensor *self, THCTensor *src1, THCTensor *src2);
+// Applies TensorDivOp<real> to src1 and src2.
+THC_API void THCTensor_(cdiv)(THCState *state, THCTensor *self, THCTensor *src1, THCTensor *src2);
+
+#endif