github.com/torch/cutorch.git
Diffstat (limited to 'lib/THC/generic/THCTensorCopy.cu')
-rw-r--r--  lib/THC/generic/THCTensorCopy.cu  273
1 file changed, 33 insertions(+), 240 deletions(-)
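
This commit replaces the hand-rolled, per-type copy logic below with a single templated helper, THC_copyTensor, defined alongside this file in the same refactor. Only the call sites appear in this diff; as a minimal sketch, the interface those call sites imply (inferred, not shown in this diff) is:

    /* Inferred from the call sites below: templated over destination and
     * source tensor types; subsumes the contiguity checks, cross-device
     * synchronization, and kernel dispatch that the removed code
     * performed per-type. */
    template <typename TensorTypeDst, typename TensorTypeSrc>
    void THC_copyTensor(THCState* state, TensorTypeDst* dst, TensorTypeSrc* src);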
diff --git a/lib/THC/generic/THCTensorCopy.cu b/lib/THC/generic/THCTensorCopy.cu
index c5768e2..4198025 100644
--- a/lib/THC/generic/THCTensorCopy.cu
+++ b/lib/THC/generic/THCTensorCopy.cu
@@ -4,251 +4,44 @@
THC_API void
THCTensor_(copy)(THCState* state, THCTensor* dst, THCTensor* src) {
- long totalElements = THCTensor_(nElement)(state, dst);
-
- THArgCheck(totalElements == THCTensor_(nElement)(state, src), 2,
- "sizes do not match");
-
- if (THCTensor_(nDimension)(state, dst) == 0) {
- // Zero-dim tensor; copy nothing
- return;
- }
-
- // We can memcpy the memory if:
- // -both tensors are contiguous; or,
- // -there is only one element to copy; or,
- // -FIXME: if both tensors have matching size and stride arrays, and no
- // holes within (in other words, there is some permutation that can be applied
- // to the size/strides such that the resulting tensor is contiguous).
- bool srcContig = THCTensor_(isContiguous)(state, src);
- bool dstContig = THCTensor_(isContiguous)(state, dst);
- bool memcpyEligible = (srcContig && dstContig) || (totalElements == 1);
-
- int srcDev = THCTensor_(getDevice)(state, src);
- int dstDev = THCTensor_(getDevice)(state, dst);
- int oldDev = curGPU();
-
- // We always perform the copy on the source device, using the
- // current stream on the source device.
- // If the copy is on the default stream, then we fully synchronize
- // both src and dst's default streams for completion of the
- // copy. We have to explicitly do this for non-contig copies.
- // This mimics the behavior of cross-device cudaMemcpyAsync on
- // the default stream.
- // If the copy is not on the default stream, then it is up to the
- // user to add needed synchronization on the dst device, since the
- // stream on the dst device that wishes to synchronize may not be
- // the same index as the one on the src device.
- int copyStreamIndex =
- THCState_getCurrentStreamIndex(state);
- cudaStream_t copyStream =
- THCState_getDeviceStream(state, srcDev, copyStreamIndex);
-
- if (srcDev != dstDev && copyStreamIndex == 0) {
- // This is a cross-device copy on the default stream. We perform a
- // two-way barrier between both devices' default streams before
- // the copy. This ensures that any write-after-write and
- // write-after-read dependencies on the destination side are
- // handled, so that no one is operating on the dst memory when
- // we perform the copy.
- // src waits on dst barrier (src already waits on src)
- cudaEvent_t dstReady;
- THCudaCheck(cudaSetDevice(dstDev));
- THCudaCheck(cudaEventCreateWithFlags(&dstReady, cudaEventDisableTiming));
- THCudaCheck(cudaEventRecord(dstReady, NULL));
-
- THCudaCheck(cudaSetDevice(srcDev));
- THCudaCheck(cudaStreamWaitEvent(NULL, dstReady, 0));
- THCudaCheck(cudaEventDestroy(dstReady));
- } else if (srcDev != oldDev) {
- THCudaCheck(cudaSetDevice(srcDev));
- }
-
- // We are now on srcDev
- if (memcpyEligible) {
- // Perform the copy
- THCudaCheck(cudaMemcpyAsync(THCTensor_(data)(state, dst),
- THCTensor_(data)(state, src),
- totalElements * sizeof(real),
- cudaMemcpyDeviceToDevice,
- copyStream));
- } else {
-#if defined(THC_REAL_IS_FLOAT)
- // Non-contiguous copy
-
- // We avoid creating temporary memory copies if possible.
- // If both src and dst are on the same device, or if they are on
- // different devices and p2p access is enabled, perform the copy
- // by a pointwise copy kernel.
- // Otherwise, we'll have to make the tensors contiguous (which will in
- // fact invoke copy() again), and then perform the copy.
- // FIXME: might want to consider only running the pointwise kernel
- // if both src and dst innermost dimensions are contiguous. If
- // they are not, then taking the hit of the memory allocation/free
- // might be worth it to avoid non-coalesced reads or writes.
-
- // A device always has access to itself, so this also handles the
- // case srcDev == dstDev
- if (THCState_getPeerToPeerAccess(state, srcDev, dstDev)) {
- // Make sure we have the current stream set in THCState, since
- // pointwise uses that
- if (srcDev != oldDev) {
- THCState_setStream(state, srcDev, copyStreamIndex);
- }
-
- bool succ =
- THCudaTensor_pointwiseApply2(state, dst, src, CopyOp<float>());
- THArgCheck(succ, 2, CUTORCH_DIM_WARNING);
-
- // Restore prior THCState stream
- if (srcDev != oldDev) {
- THCState_setStream(state, oldDev, copyStreamIndex);
- }
- } else {
- // GPUs can't access each other directly; fall back to
- // newContiguous and memcpy
- THCudaTensor* srcContig = THCudaTensor_newContiguous(state, src);
- THCudaTensor* dstContig = dst;
-
- if (!THCudaTensor_isContiguous(state, dst)) {
- // We are copying over the contents of dst, so we don't need
- // to preserve its values. We just need a destination tensor
- // the same size as dst.
-
- // Allocate the tensor on the new device
- THCudaCheck(cudaSetDevice(dstDev));
-
- dstContig = THCudaTensor_new(state);
- THCudaTensor_resizeAs(state, dstContig, dst);
-
- THCudaCheck(cudaSetDevice(srcDev));
- }
-
- THCudaCheck(cudaMemcpyAsync(THCudaTensor_data(state, dstContig),
- THCudaTensor_data(state, srcContig),
- totalElements * sizeof(float),
- cudaMemcpyDeviceToDevice,
- copyStream));
-
- THCudaTensor_free(state, srcContig);
-
- if (dst != dstContig) {
- THCudaTensor_freeCopyTo(state, dstContig, dst);
- }
- }
-#else
-#define STRINGIFY(x) #x
- THError("Non-contiguous copy not implemented for Cuda%sTensor", STRINGIFY(Real));
-#undef STRINGIFY
-#endif
- }
-
- if (srcDev != dstDev && copyStreamIndex == 0) {
- // dst waits on src barrier (dst already waits on dst). We cannot
- // operate on dst's copy until the copy is complete.
-
- // Still on srcDev, record default stream event
- cudaEvent_t srcReady;
- THCudaCheck(cudaEventCreateWithFlags(&srcReady, cudaEventDisableTiming));
- THCudaCheck(cudaEventRecord(srcReady, NULL));
-
- THCudaCheck(cudaSetDevice(dstDev));
- THCudaCheck(cudaStreamWaitEvent(NULL, srcReady, 0));
- THCudaCheck(cudaEventDestroy(srcReady));
-
- // We are now on dstDev (right above). Restore prior device from dst
- if (dstDev != oldDev) {
- THCudaCheck(cudaSetDevice(oldDev));
- }
- } else {
- // We are still on srcDev. Restore prior device from src
- if (srcDev != oldDev) {
- THCudaCheck(cudaSetDevice(oldDev));
- }
- }
+ THC_copyTensor<THCTensor, THCTensor>(state, dst, src);
+}
- cudaError errcode = cudaGetLastError();
- if (errcode != cudaSuccess) {
- THError(cudaGetErrorString(errcode));
- }
+THC_API void
+THCTensor_(copyIgnoringOverlaps)(THCState* state, THCTensor* dst, THCTensor* src) {
+ // Called when we are copying into an overlapping index `dst`, but
+ // we don't care which writer wins. Hacky but it works.
+ // This is itself invoked by pointwiseApply2 / THCTensor_copy when
+ // there are write overlaps.
+ // FIXME: really, overlapping writes should be illegal/an error in Torch
+ THC_pointwiseApply2(
+ state, dst, src,
+ CopyOp<typename TensorUtils<THCTensor>::DataType,
+ typename TensorUtils<THCTensor>::DataType>(),
+ ReadOnly, /* ignore overwrites */
+ ReadOnly);
}
-// conversions are mediated by the CPU
-// yes, this is slow; feel free to write CUDA kernels for this
-#ifndef THC_REAL_IS_HALF
-#define THC_CUDA_TENSOR_IMPLEMENT_COPY(TYPEC,TYPECUDA) \
- void THCTensor_(copyCuda##TYPEC)(THCState *state, THCTensor *self, struct THCuda##TYPECUDA##Tensor *src) \
- { \
- if(THCTypeIdx_(Real) == THCTypeIdx_(TYPEC)) { \
- THCTensor_(copy)(state, self, (THCTensor*) src); /* cast just removes compiler warning */ \
- } else { \
- THArgCheck(THCTensor_(nElement)(state, self) == THCuda##TYPECUDA##Tensor_nElement(state, src), 2, "size does not match"); \
- THLongStorage *size = THCuda##TYPECUDA##Tensor_newSizeOf(state, src); \
- TH##TYPEC##Tensor *buffer1 = TH##TYPEC##Tensor_newWithSize(size, NULL); \
- THTensor *buffer2 = THTensor_(newWithSize)(size, NULL); \
- TH##TYPEC##Tensor_copyCuda(state, buffer1, src); \
- THTensor_(copy##TYPEC)(buffer2, buffer1); \
- THCTensor_(copyCPU)(state, self, buffer2); \
- THLongStorage_free(size); \
- TH##TYPEC##Tensor_free(buffer1); \
- THTensor_(free)(buffer2); \
- } \
- }
-#else
-#define THC_CUDA_TENSOR_IMPLEMENT_COPY(TYPEC,TYPECUDA) \
- void THCTensor_(copyCuda##TYPEC)(THCState *state, THCTensor *self, struct THCuda##TYPECUDA##Tensor *src) \
- { \
- THArgCheck(THCTensor_(nElement)(state, self) == THCuda##TYPECUDA##Tensor_nElement(state, src), 2, "size does not match"); \
- if (THCTypeIdx_(TYPEC) == THCTypeIdxFloat) { \
- THCudaTensor *csrc = THCudaTensor_newContiguous(state, (THCudaTensor*) src); /* cast removes compiler error */ \
- THCFloat2Half(state, \
- THCTensor_(data)(state, self), \
- THCudaTensor_data(state, csrc), \
- THCudaTensor_nElement(state, csrc)); \
- THCudaTensor_free(state, csrc); \
- } else { \
- THLongStorage *size = THCuda##TYPECUDA##Tensor_newSizeOf(state, src); \
- THCudaTensor *buffer = THCudaTensor_newWithSize(state, size, NULL); \
- THCudaTensor_copyCuda##TYPEC(state, buffer, src); \
- THCFloat2Half(state, \
- THCTensor_(data)(state, self), \
- THCudaTensor_data(state, buffer), \
- THCudaTensor_nElement(state, buffer)); \
- THCudaTensor_free(state, buffer); \
- THLongStorage_free(size); \
- } \
+#define IMPLEMENT_THC_CUDA_TENSOR_COPY(TYPEC, TYPECUDA) \
+ THC_API void \
+ THCTensor_(copyCuda##TYPEC)(THCState *state, \
+ THCTensor *self, \
+ THCuda##TYPECUDA##Tensor *src) { \
+ THC_copyTensor<THCTensor, THCuda##TYPECUDA##Tensor>(state, self, src); \
}
-#endif
-
-THC_CUDA_TENSOR_IMPLEMENT_COPY(Byte,Byte)
-THC_CUDA_TENSOR_IMPLEMENT_COPY(Char,Char)
-THC_CUDA_TENSOR_IMPLEMENT_COPY(Short,Short)
-THC_CUDA_TENSOR_IMPLEMENT_COPY(Int,Int)
-THC_CUDA_TENSOR_IMPLEMENT_COPY(Long,Long)
-THC_CUDA_TENSOR_IMPLEMENT_COPY(Float,) // i.e. float
-THC_CUDA_TENSOR_IMPLEMENT_COPY(Double,Double)
-#if CUDA_VERSION >= 7050
-#define FLOAT_COPY(TYPE) TH_CONCAT_3(TH, CReal, Tensor_copyCudaFloat)
-void THCTensor_(copyCudaHalf)(THCState *state, THCTensor *self, struct THCudaHalfTensor *src)
-{
- if(THCTypeIdx_(Real) == THCTypeIdxHalf) {
- THCTensor_(copy)(state, self, (THCTensor*) src); /* cast removes compiler error */
- } else {
- THArgCheck(THCTensor_(nElement)(state, self) == THCudaHalfTensor_nElement(state, src), 2, "size does not match");
- src = THCudaHalfTensor_newContiguous(state, src);
- THLongStorage *size = THCudaHalfTensor_newSizeOf(state, src);
- THCudaTensor *buffer = THCudaTensor_newWithSize(state, size, NULL);
- THCHalf2Float(state, THCudaTensor_data(state, buffer), THCudaHalfTensor_data(state, src), THCudaHalfTensor_nElement(state, src));
- FLOAT_COPY(Real)(state, self, buffer);
- THCudaTensor_free(state, buffer);
- THCudaHalfTensor_free(state, src);
- THLongStorage_free(size);
- }
-}
-#undef FLOAT_COPY
-#endif // CUDA_VERSION >= 7050
+IMPLEMENT_THC_CUDA_TENSOR_COPY(Byte, Byte)
+IMPLEMENT_THC_CUDA_TENSOR_COPY(Char, Char)
+IMPLEMENT_THC_CUDA_TENSOR_COPY(Short, Short)
+IMPLEMENT_THC_CUDA_TENSOR_COPY(Int, Int)
+IMPLEMENT_THC_CUDA_TENSOR_COPY(Long, Long)
+// THCudaTensor aka the non-existent THCudaFloatTensor
+IMPLEMENT_THC_CUDA_TENSOR_COPY(Float, )
+IMPLEMENT_THC_CUDA_TENSOR_COPY(Double, Double)
+#ifdef CUDA_HALF_TENSOR
+IMPLEMENT_THC_CUDA_TENSOR_COPY(Half, Half)
+#endif
-#undef THC_CUDA_TENSOR_IMPLEMENT_COPY
+#undef IMPLEMENT_THC_CUDA_TENSOR_COPY
#endif
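
The FIXME in the removed memcpy-eligibility comment above describes a broader condition: a copy is also memcpy-safe when some permutation of the size/stride arrays yields a contiguous tensor (no holes). A hypothetical standalone checker, purely illustrative and not part of the cutorch API:

    #include <stdbool.h>

    /* Hypothetical checker for the FIXME above (names invented, not
     * cutorch API): returns true if some permutation of dimensions makes
     * a tensor with these sizes/strides contiguous, i.e. no holes. */
    static bool isPermutedContiguous(const long* size, const long* stride, int dims) {
      int order[16]; /* assumes dims <= 16, typical for Torch tensors */
      if (dims > 16) return false;

      /* Insertion-sort dimension indices by stride, ascending. */
      for (int i = 0; i < dims; ++i) {
        int j = i;
        while (j > 0 && stride[order[j - 1]] > stride[i]) {
          order[j] = order[j - 1];
          --j;
        }
        order[j] = i;
      }

      /* Walk from the fastest-moving dimension outward: each stride must
       * equal the product of the sizes of all faster-moving dimensions. */
      long expected = 1;
      for (int k = 0; k < dims; ++k) {
        int d = order[k];
        if (size[d] == 1) continue; /* size-1 dims impose no constraint */
        if (stride[d] != expected) return false;
        expected *= size[d];
      }
      return true;
    }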
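The removed cross-device path above synchronizes via a two-way barrier between both devices' default streams. A minimal standalone CUDA sketch of that event pattern, assuming at least two visible devices and omitting error checking for brevity:

    #include <cuda_runtime.h>

    /* Sketch of the two-way default-stream barrier used by the removed
     * code for cross-device copies. Leaves the current device set to
     * dstDev; the original code restored the caller's device afterwards. */
    static void crossDeviceCopy(void* dst, int dstDev,
                                const void* src, int srcDev, size_t bytes) {
      cudaEvent_t dstReady, srcReady;

      /* src waits on dst: record an event on dst's default stream and
       * make src's default stream wait on it, so nothing is still
       * operating on the dst memory when the copy starts. */
      cudaSetDevice(dstDev);
      cudaEventCreateWithFlags(&dstReady, cudaEventDisableTiming);
      cudaEventRecord(dstReady, NULL);

      cudaSetDevice(srcDev);
      cudaStreamWaitEvent(NULL, dstReady, 0);
      cudaEventDestroy(dstReady); /* wait is enqueued; safe to destroy */

      /* Perform the copy on the source device's default stream. */
      cudaMemcpyAsync(dst, src, bytes, cudaMemcpyDeviceToDevice, NULL);

      /* dst waits on src: the destination device may not touch the
       * copied memory until the copy is complete. */
      cudaEventCreateWithFlags(&srcReady, cudaEventDisableTiming);
      cudaEventRecord(srcReady, NULL);

      cudaSetDevice(dstDev);
      cudaStreamWaitEvent(NULL, srcReady, 0);
      cudaEventDestroy(srcReady);
    }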
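With this change, each IMPLEMENT_THC_CUDA_TENSOR_COPY expansion produces a device-side cross-type copy (for example, THCudaDoubleTensor_copyCudaFloat for Real=Double) in place of the CPU-mediated conversion that was removed. A rough usage sketch, assuming the usual cutorch state setup via THCState_alloc/THCudaInit:

    #include "THC/THC.h" /* header path as installed; "THC.h" in-tree */

    int main(void) {
      THCState* state = THCState_alloc();
      THCudaInit(state);

      THCudaTensor* src = THCudaTensor_newWithSize1d(state, 16);
      THCudaTensor_fill(state, src, 1.0f);

      THCudaDoubleTensor* dst = THCudaDoubleTensor_newWithSize1d(state, 16);

      /* Generated by IMPLEMENT_THC_CUDA_TENSOR_COPY(Float, ) for the
       * double tensor type; dispatches to THC_copyTensor on the device
       * rather than round-tripping through CPU buffers. */
      THCudaDoubleTensor_copyCudaFloat(state, dst, src);

      THCudaTensor_free(state, src);
      THCudaDoubleTensor_free(state, dst);
      THCudaShutdown(state);
      THCState_free(state);
      return 0;
    }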