Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/torch/cutorch.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorTrevor Killeen <killeentm@gmail.com>2017-05-02 22:32:09 +0300
committerSoumith Chintala <soumith@gmail.com>2017-06-23 00:21:30 +0300
commit9db5057877c6ffa7df59727cbada13318d7e3eaf (patch)
tree2cb490ce2b4740799e8f67f1368b2633c58e30e9
parentc394a5615794be6572efd68763d680ea70f0aad8 (diff)
Advanced Indexing Part 1 -- Purely Integer Array Indexing
-rw-r--r--lib/THC/generic/THCTensor.c35
-rw-r--r--lib/THC/generic/THCTensor.h1
2 files changed, 36 insertions, 0 deletions
diff --git a/lib/THC/generic/THCTensor.c b/lib/THC/generic/THCTensor.c
index d4bb012..6a4051f 100644
--- a/lib/THC/generic/THCTensor.c
+++ b/lib/THC/generic/THCTensor.c
@@ -329,6 +329,41 @@ void THCTensor_(expand)(THCState *state, THCTensor *r, THCTensor *tensor, THLong
THFree(expandedStrides);
}
+void THCTensor_(expandNd)(THCState *state, THCTensor **rets, THCTensor **ops, int count) {
+ for (int i = 0; i < count; ++i) {
+ THArgCheck(THCTensor_(nDimension)(state, ops[i]) > 0, i, "can't expand empty tensor %d", i);
+ }
+
+ long *op_sizes[count];
+ long op_dims[count];
+
+ for (int i = 0; i < count; ++i) {
+ op_sizes[i] = ops[i]->size;
+ op_dims[i] = ops[i]->nDimension;
+ }
+
+ THLongStorage *sizes = THLongStorage_new();
+ char error_buffer[1024];
+ int ret = THLongStorage_inferSizeN(sizes,
+ count,
+ op_sizes,
+ op_dims,
+ error_buffer,
+ 1024);
+
+ if(ret != 0) {
+ THLongStorage_free(sizes);
+ THError(error_buffer);
+ return;
+ }
+
+ for (int i = 0; i < count; ++i) {
+ THCTensor_(expand)(state, rets[i], ops[i], sizes);
+ }
+
+ THLongStorage_free(sizes);
+}
+
void THCTensor_(set)(THCState *state, THCTensor *self, THCTensor *src)
{
if(self != src)
diff --git a/lib/THC/generic/THCTensor.h b/lib/THC/generic/THCTensor.h
index 8059cad..dd7e6e6 100644
--- a/lib/THC/generic/THCTensor.h
+++ b/lib/THC/generic/THCTensor.h
@@ -70,6 +70,7 @@ THC_API THCTensor *THCTensor_(newView)(THCState *state, THCTensor *tensor, THLon
THC_API THCTensor *THCTensor_(newExpand)(THCState *state, THCTensor *tensor, THLongStorage *size);
THC_API void THCTensor_(expand)(THCState *state, THCTensor *r, THCTensor *tensor, THLongStorage *sizes);
+THC_API void THCTensor_(expandNd)(THCState *state, THCTensor **rets, THCTensor **ops, int count);
THC_API void THCTensor_(resize)(THCState *state, THCTensor *tensor, THLongStorage *size, THLongStorage *stride);
THC_API void THCTensor_(resizeAs)(THCState *state, THCTensor *tensor, THCTensor *src);