Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/torch/torch7.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorTrevor Killeen <killeentm@gmail.com>2017-05-02 22:32:09 +0300
committerSoumith Chintala <soumith@gmail.com>2017-06-23 00:21:19 +0300
commit9d30c3a5a39ead4721bdeb902564c3df7356d270 (patch)
tree7ad163644f29b4fddcf2a6f546a6eac2afbb128e
parente3917893c4bdfc64e46167c4147f759166652e36 (diff)
Advanced Indexing Part 1 -- Purely Integer Array Indexing
-rw-r--r--lib/TH/generic/THTensor.c36
-rw-r--r--lib/TH/generic/THTensor.h1
2 files changed, 37 insertions, 0 deletions
diff --git a/lib/TH/generic/THTensor.c b/lib/TH/generic/THTensor.c
index 4a8738a..e44e06e 100644
--- a/lib/TH/generic/THTensor.c
+++ b/lib/TH/generic/THTensor.c
@@ -315,6 +315,42 @@ void THTensor_(expand)(THTensor *r, THTensor *tensor, THLongStorage *sizes) {
THFree(expandedStrides);
}
+
+void THTensor_(expandNd)(THTensor **rets, THTensor **ops, int count) {
+ for (int i = 0; i < count; ++i) {
+ THArgCheck(THTensor_(nDimension)(ops[i]) > 0, i, "can't expand empty tensor %d", i);
+ }
+
+ long *op_sizes[count];
+ long op_dims[count];
+
+ for (int i = 0; i < count; ++i) {
+ op_sizes[i] = ops[i]->size;
+ op_dims[i] = ops[i]->nDimension;
+ }
+
+ THLongStorage *sizes = THLongStorage_new();
+ char error_buffer[1024];
+ int ret = THLongStorage_inferSizeN(sizes,
+ count,
+ op_sizes,
+ op_dims,
+ error_buffer,
+ 1024);
+
+ if(ret != 0) {
+ THLongStorage_free(sizes);
+ THError(error_buffer);
+ return;
+ }
+
+ for (int i = 0; i < count; ++i) {
+ THTensor_(expand)(rets[i], ops[i], sizes);
+ }
+
+ THLongStorage_free(sizes);
+}
+
void THTensor_(set)(THTensor *self, THTensor *src)
{
if(self != src)
diff --git a/lib/TH/generic/THTensor.h b/lib/TH/generic/THTensor.h
index 9a2417f..9fb246c 100644
--- a/lib/TH/generic/THTensor.h
+++ b/lib/TH/generic/THTensor.h
@@ -72,6 +72,7 @@ TH_API THTensor *THTensor_(newView)(THTensor *tensor, THLongStorage *size);
TH_API THTensor *THTensor_(newExpand)(THTensor *tensor, THLongStorage *size);
TH_API void THTensor_(expand)(THTensor *r, THTensor *tensor, THLongStorage *size);
+TH_API void THTensor_(expandNd)(THTensor **rets, THTensor **ops, int count);
TH_API void THTensor_(resize)(THTensor *tensor, THLongStorage *size, THLongStorage *stride);
TH_API void THTensor_(resizeAs)(THTensor *tensor, THTensor *src);