diff options
author | Trevor Killeen <killeentm@gmail.com> | 2017-05-02 22:32:09 +0300 |
---|---|---|
committer | Soumith Chintala <soumith@gmail.com> | 2017-06-23 00:21:30 +0300 |
commit | 9db5057877c6ffa7df59727cbada13318d7e3eaf (patch) | |
tree | 2cb490ce2b4740799e8f67f1368b2633c58e30e9 | |
parent | c394a5615794be6572efd68763d680ea70f0aad8 (diff) |
Advanced Indexing Part 1 -- Purely Integer Array Indexing
-rw-r--r-- | lib/THC/generic/THCTensor.c | 35 | ||||
-rw-r--r-- | lib/THC/generic/THCTensor.h | 1 |
2 files changed, 36 insertions, 0 deletions
diff --git a/lib/THC/generic/THCTensor.c b/lib/THC/generic/THCTensor.c index d4bb012..6a4051f 100644 --- a/lib/THC/generic/THCTensor.c +++ b/lib/THC/generic/THCTensor.c @@ -329,6 +329,41 @@ void THCTensor_(expand)(THCState *state, THCTensor *r, THCTensor *tensor, THLong THFree(expandedStrides); } +void THCTensor_(expandNd)(THCState *state, THCTensor **rets, THCTensor **ops, int count) { + for (int i = 0; i < count; ++i) { + THArgCheck(THCTensor_(nDimension)(state, ops[i]) > 0, i, "can't expand empty tensor %d", i); + } + + long *op_sizes[count]; + long op_dims[count]; + + for (int i = 0; i < count; ++i) { + op_sizes[i] = ops[i]->size; + op_dims[i] = ops[i]->nDimension; + } + + THLongStorage *sizes = THLongStorage_new(); + char error_buffer[1024]; + int ret = THLongStorage_inferSizeN(sizes, + count, + op_sizes, + op_dims, + error_buffer, + 1024); + + if(ret != 0) { + THLongStorage_free(sizes); + THError(error_buffer); + return; + } + + for (int i = 0; i < count; ++i) { + THCTensor_(expand)(state, rets[i], ops[i], sizes); + } + + THLongStorage_free(sizes); +} + void THCTensor_(set)(THCState *state, THCTensor *self, THCTensor *src) { if(self != src) diff --git a/lib/THC/generic/THCTensor.h b/lib/THC/generic/THCTensor.h index 8059cad..dd7e6e6 100644 --- a/lib/THC/generic/THCTensor.h +++ b/lib/THC/generic/THCTensor.h @@ -70,6 +70,7 @@ THC_API THCTensor *THCTensor_(newView)(THCState *state, THCTensor *tensor, THLon THC_API THCTensor *THCTensor_(newExpand)(THCState *state, THCTensor *tensor, THLongStorage *size); THC_API void THCTensor_(expand)(THCState *state, THCTensor *r, THCTensor *tensor, THLongStorage *sizes); +THC_API void THCTensor_(expandNd)(THCState *state, THCTensor **rets, THCTensor **ops, int count); THC_API void THCTensor_(resize)(THCState *state, THCTensor *tensor, THLongStorage *size, THLongStorage *stride); THC_API void THCTensor_(resizeAs)(THCState *state, THCTensor *tensor, THCTensor *src); |