Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/torch/cutorch.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorTrevor Killeen <killeentm@gmail.com>2017-07-18 23:24:04 +0300
committerSoumith Chintala <soumith@gmail.com>2017-07-19 18:04:49 +0300
commit5227de5b0fc14d77de227b3d9e58fd1d0355740c (patch)
treed25eed75ab273d40b58cd1ea7848814a96188f0c
parenta0799aaaf809985a3be272c5d0e482d5f9d04136 (diff)
move to model with cuda indexing tensors for cuda tensor adv indexing
-rw-r--r--lib/THC/THCTensorIndex.cu3
-rw-r--r--lib/THC/generic/THCTensorIndex.cu1
2 files changed, 4 insertions, 0 deletions
diff --git a/lib/THC/THCTensorIndex.cu b/lib/THC/THCTensorIndex.cu
index e58caa0..41231b0 100644
--- a/lib/THC/THCTensorIndex.cu
+++ b/lib/THC/THCTensorIndex.cu
@@ -334,6 +334,8 @@ __global__ void indexSelectLargeIndex(TensorInfo<T, IndexType> dst,
template <typename IndexType, unsigned int Dims>
struct LinearIndexCalcData {
+ // sizes for the Tensor dims (from the Tensor, for bounds checking)
+ IndexType baseSizes[Dims];
// sizes for Tensor dims (either from the Tensor, or the size of the adv indexer at that dim)
IndexType sizes[Dims];
// strides for the Tensor we are indexing into
@@ -373,6 +375,7 @@ __device__ __forceinline__ long calculateOffset(
indexAtDim = index - nextIndex * sizeAtDim;
}
+ assert(indexAtDim < data.baseSizes[dim]);
offset += indexAtDim * strideAtDim;
index = nextIndex;
}
diff --git a/lib/THC/generic/THCTensorIndex.cu b/lib/THC/generic/THCTensorIndex.cu
index a9ed28e..ac42cf5 100644
--- a/lib/THC/generic/THCTensorIndex.cu
+++ b/lib/THC/generic/THCTensorIndex.cu
@@ -535,6 +535,7 @@ void THCTensor_(calculateAdvancedIndexingOffsets)(
{ \
LinearIndexCalcData<INDEX_TYPE, DIMS> data; \
for (int i = 0; i < DIMS; ++i) { \
+ data.baseSizes[i] = THCTensor_(size)(state, indexed, i); \
data.sizes[i] = indexers[i] != NULL ? \
THCudaLongTensor_nElement(state, indexers[i]) : \
THCTensor_(size)(state, indexed, i); \