github.com/torch/cutorch.git
Diffstat (limited to 'lib/THC/generic/THCTensorCopy.cu')
-rw-r--r--  lib/THC/generic/THCTensorCopy.cu | 69
1 file changed, 51 insertions(+), 18 deletions(-)
diff --git a/lib/THC/generic/THCTensorCopy.cu b/lib/THC/generic/THCTensorCopy.cu
index 5aa7ee5..304e52c 100644
--- a/lib/THC/generic/THCTensorCopy.cu
+++ b/lib/THC/generic/THCTensorCopy.cu
@@ -1,19 +1,15 @@
-#include "THCApply.cuh"
-
-static inline int curGPU() {
-  int curDev;
-  THCudaCheck(cudaGetDevice(&curDev));
-  return curDev;
-}
+#ifndef THC_GENERIC_FILE
+#define THC_GENERIC_FILE "generic/THCTensorCopy.cu"
+#else
 
 THC_API void
-THCudaTensor_copy(THCState* state, THCudaTensor* dst, THCudaTensor* src) {
-  long totalElements = THCudaTensor_nElement(state, dst);
+THCTensor_(copy)(THCState* state, THCTensor* dst, THCTensor* src) {
+  long totalElements = THCTensor_(nElement)(state, dst);
 
-  THArgCheck(totalElements == THCudaTensor_nElement(state, src), 2,
+  THArgCheck(totalElements == THCTensor_(nElement)(state, src), 2,
              "sizes do not match");
 
-  if (THCudaTensor_nDimension(state, dst) == 0) {
+  if (THCTensor_(nDimension)(state, dst) == 0) {
     // Zero-dim tensor; copy nothing
     return;
   }
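
This first hunk moves the file to the TH "generic" convention: the body is compiled only when THC_GENERIC_FILE is defined, so the build can re-include the file once per scalar type, and THCTensor_(NAME) expands to a per-type symbol (THCudaTensor_copy, THCudaByteTensor_copy, and so on). A minimal sketch of the token-pasting mechanism behind that, using simplified stand-in names rather than the real TH/THC macro definitions:

/* Sketch only: simplified stand-ins for the TH/THC generic macros. */
#define CONCAT_4_EXPAND(a,b,c,d) a##b##c##d
#define CONCAT_4(a,b,c,d) CONCAT_4_EXPAND(a,b,c,d)  /* expand args first */

#define CReal CudaByte  /* set by the per-type generator header */
#define THCTensor_(NAME) CONCAT_4(TH, CReal, Tensor_, NAME)

/* THCTensor_(copy) now names THCudaByteTensor_copy; redefining CReal
 * (Cuda, CudaInt, ...) and re-including the file stamps out one copy
 * function per scalar type. */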
@@ -24,12 +20,12 @@ THCudaTensor_copy(THCState* state, THCudaTensor* dst, THCudaTensor* src) {
   // -FIXME: if both tensors have matching size and stride arrays, and no
   // holes within (in other words, there is some permutation that can be applied
   // to the size/strides such that the resulting tensor is contiguous).
-  bool srcContig = THCudaTensor_isContiguous(state, src);
-  bool dstContig = THCudaTensor_isContiguous(state, dst);
+  bool srcContig = THCTensor_(isContiguous)(state, src);
+  bool dstContig = THCTensor_(isContiguous)(state, dst);
   bool memcpyEligible = (srcContig && dstContig) || (totalElements == 1);
 
-  int srcDev = THCudaTensor_getDevice(state, src);
-  int dstDev = THCudaTensor_getDevice(state, dst);
+  int srcDev = THCTensor_(getDevice)(state, src);
+  int dstDev = THCTensor_(getDevice)(state, dst);
   int oldDev = curGPU();
 
   // We always perform the copy on the source device, using the
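
The refactor above leaves the fast-path decision unchanged: a raw memcpy is legal only when both tensors are contiguous (or there is a single element), since only then do the elements of src and dst each occupy one dense block. A sketch of the invariant behind isContiguous, as a hypothetical helper rather than the THC API:

/* A tensor is contiguous when each dimension's stride equals the
 * product of all later sizes (size-1 dims don't constrain layout).
 * Hypothetical illustration, not a THC function. */
static int is_contiguous(const long *size, const long *stride, int nDim) {
  long expected = 1;
  for (int d = nDim - 1; d >= 0; --d) {
    if (size[d] != 1) {
      if (stride[d] != expected) return 0;
      expected *= size[d];
    }
  }
  return 1;
}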
@@ -71,12 +67,13 @@ THCudaTensor_copy(THCState* state, THCudaTensor* dst, THCudaTensor* src) {
   // We are now on srcDev
   if (memcpyEligible) {
     // Perform the copy
-    THCudaCheck(cudaMemcpyAsync(THCudaTensor_data(state, dst),
-                                THCudaTensor_data(state, src),
-                                totalElements * sizeof(float),
+    THCudaCheck(cudaMemcpyAsync(THCTensor_(data)(state, dst),
+                                THCTensor_(data)(state, src),
+                                totalElements * sizeof(real),
                                 cudaMemcpyDeviceToDevice,
                                 copyStream));
   } else {
+#ifdef THC_REAL_IS_FLOAT
    // Non-contiguous copy
    // We avoid creating temporary memory copies if possible.
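
Two type-generic substitutions matter in this hunk: the data pointers come from THCTensor_(data), and the byte count uses sizeof(real), where real names the scalar type currently being generated (unsigned char, float, double, ...). Stripped of the THC wrappers, the fast path is just an asynchronous device-to-device copy; a standalone sketch:

/* Standalone sketch of the memcpy fast path: copy nElements items of
 * elemSize bytes between device buffers on a caller-supplied stream. */
#include <cuda_runtime.h>

static cudaError_t copy_d2d_async(void *dst, const void *src,
                                  size_t nElements, size_t elemSize,
                                  cudaStream_t stream) {
  return cudaMemcpyAsync(dst, src, nElements * elemSize,
                         cudaMemcpyDeviceToDevice, stream);
}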
@@ -139,6 +136,11 @@ THCudaTensor_copy(THCState* state, THCudaTensor* dst, THCudaTensor* src) {
         THCudaTensor_freeCopyTo(state, dstContig, dst);
       }
     }
+#else
+#define STRINGIFY(x) #x
+    THError("Non-contiguous copy not implemented for Cuda%sTensor", STRINGIFY(Real));
+#undef STRINGIFY
+#endif
   }
 
   if (srcDev != dstDev && copyStreamIndex == 0) {
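
One caveat in this hunk: with the single-level #define STRINGIFY(x) #x, the # operator stringifies its argument without macro-expanding it, so STRINGIFY(Real) yields the literal string "Real" and the error reads "CudaRealTensor" for every type. The standard two-level idiom expands Real (Byte, Char, ...) before stringifying:

/* Two-level stringification: the outer macro expands Real before the
 * inner # operator is applied, so the message names the concrete type. */
#define STRINGIFY_EXPAND(x) #x
#define STRINGIFY(x) STRINGIFY_EXPAND(x)  /* STRINGIFY(Real) -> "Byte", ... */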
@@ -170,3 +172,34 @@ THCudaTensor_copy(THCState* state, THCudaTensor* dst, THCudaTensor* src) {
     THError(cudaGetErrorString(errcode));
   }
 }
+
+// conversions are mediated by the CPU
+// yes, this is slow; feel free to write CUDA kernels for this
+#define THC_CUDA_TENSOR_IMPLEMENT_COPY(TYPEC,TYPECUDA) \
+  void THCTensor_(copyCuda##TYPEC)(THCState *state, THCTensor *self, struct THCuda##TYPECUDA##Tensor *src) \
+  { \
+    if(THCTypeIdx_(Real) == THCTypeIdx_(TYPEC)) { \
+      THCTensor_(copy)(state, self, (THCTensor*) src); /* cast just removes compiler warning */ \
+    } else { \
+      THArgCheck(THCTensor_(nElement)(state, self) == THCuda##TYPECUDA##Tensor_nElement(state, src), 2, "size does not match"); \
+      THLongStorage *size = THCuda##TYPECUDA##Tensor_newSizeOf(state, src); \
+      TH##TYPEC##Tensor *buffer1 = TH##TYPEC##Tensor_newWithSize(size, NULL); \
+      THTensor *buffer2 = THTensor_(newWithSize)(size, NULL); \
+      TH##TYPEC##Tensor_copyCuda(state, buffer1, src); \
+      THTensor_(copy##TYPEC)(buffer2, buffer1); \
+      THCTensor_(copyCPU)(state, self, buffer2); \
+      THLongStorage_free(size); \
+      TH##TYPEC##Tensor_free(buffer1); \
+      THTensor_(free)(buffer2); \
+    } \
+  }
+
+THC_CUDA_TENSOR_IMPLEMENT_COPY(Byte,Byte)
+THC_CUDA_TENSOR_IMPLEMENT_COPY(Char,Char)
+THC_CUDA_TENSOR_IMPLEMENT_COPY(Short,Short)
+THC_CUDA_TENSOR_IMPLEMENT_COPY(Int,Int)
+THC_CUDA_TENSOR_IMPLEMENT_COPY(Long,Long)
+THC_CUDA_TENSOR_IMPLEMENT_COPY(Float,)  // i.e. float
+THC_CUDA_TENSOR_IMPLEMENT_COPY(Double,Double)
+
+#endif
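
Note that the (Float,) instantiation passes an empty TYPECUDA, so THCuda##TYPECUDA##Tensor collapses to THCudaTensor, the float tensor type. For one concrete case, here is roughly what the macro produces in the float generation pass (Real = Float, so THCTensor_ names THCudaTensor_ functions), taking THC_CUDA_TENSOR_IMPLEMENT_COPY(Byte,Byte). This is a reconstruction for illustration, with the always-false same-type branch omitted, not the preprocessor's verbatim output:

/* Approximate expansion of THC_CUDA_TENSOR_IMPLEMENT_COPY(Byte,Byte)
 * with Real = Float: types differ, so the copy bounces through the CPU. */
void THCudaTensor_copyCudaByte(THCState *state, THCudaTensor *self,
                               struct THCudaByteTensor *src)
{
  THArgCheck(THCudaTensor_nElement(state, self) ==
             THCudaByteTensor_nElement(state, src), 2, "size does not match");
  THLongStorage *size = THCudaByteTensor_newSizeOf(state, src);
  THByteTensor *buffer1 = THByteTensor_newWithSize(size, NULL);   /* CPU bytes  */
  THFloatTensor *buffer2 = THFloatTensor_newWithSize(size, NULL); /* CPU floats */
  THByteTensor_copyCuda(state, buffer1, src);    /* GPU byte  -> CPU byte  */
  THFloatTensor_copyByte(buffer2, buffer1);      /* CPU byte  -> CPU float */
  THCudaTensor_copyCPU(state, self, buffer2);    /* CPU float -> GPU float */
  THLongStorage_free(size);
  THByteTensor_free(buffer1);
  THFloatTensor_free(buffer2);
}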