Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/torch/torch7.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorGregory Chanan <gchanan@fb.com>2017-05-09 23:50:15 +0300
committerSoumith Chintala <soumith@gmail.com>2017-06-11 11:32:08 +0300
commitd34673d67bc8cefd7055edf55a2c88d28b3ab158 (patch)
tree3ec0baef7cbe8a007bbb9b732e11f4b4491d97ae
parentc796ad19d5ce4e4bb8d88a7aa83a7c88d0557e0b (diff)
Add Infer Size N, for expansion of fused operations.
-rw-r--r--lib/TH/THStorage.c43
-rw-r--r--lib/TH/THStorage.h1
2 files changed, 44 insertions, 0 deletions
diff --git a/lib/TH/THStorage.c b/lib/TH/THStorage.c
index 7636e4b..b7d4a5d 100644
--- a/lib/TH/THStorage.c
+++ b/lib/TH/THStorage.c
@@ -104,6 +104,49 @@ TH_API int THLongStorage_inferSize2(THLongStorage *output, long *sizesA, long di
return 0;
}
+TH_API int THLongStorage_inferSizeN(THLongStorage *output, int n, long **sizes, long *dims, int raiseErrors) {
+ THArgCheck(n > 0, 2, "n must be greater than 0");
+ THArgCheck(sizes != NULL, 1, "sizesA must not be null");
+ THArgCheck(dims != NULL, 1, "dims must not be null");
+
+ ptrdiff_t ndim = 0;
+ for (int j = 0; j < n; ++j) {
+ THArgCheck(sizes[ j ] != NULL, 1, "size %d must not be null", j);
+ THArgCheck(dims[ j ], 1, "Can't expand empty tensor %d", j);
+ ptrdiff_t ndim = dims[ j ] > ndim ? dims[ j ] : ndim;
+ }
+
+ long *expandedSizes = THAlloc(sizeof(long)*ndim);
+
+ for (long i = ndim - 1; i >= 0; --i) {
+ long max_dim_size = 1;
+ long offset = ndim - 1 - i;
+ for (int j = 0; j < n; ++j) {
+ long dim = dims[ j ] - 1 - offset;
+ long size = (dim >= 0) ? sizes[ i ][ dim ] : 1;
+ if (size != max_dim_size) {
+ if (max_dim_size == 1){
+ max_dim_size = size;
+ } else if (size == 1) {
+ // we'll expand, nothing to do
+ } else {
+ THFree(expandedSizes);
+ if (raiseErrors) {
+ THError("The size of tensor %i (%ld) must match the expanded size of tensor (%ld) at "
+ "non-singleton dimension %ld.", j, size, max_dim_size, i);
+ }
+ return -1;
+ }
+ }
+ }
+ expandedSizes[ i ] = max_dim_size;
+ }
+ THLongStorage_resize(output, ndim);
+ memcpy(THLongStorage_data(output), expandedSizes, sizeof(long)*ndim);
+ THFree(expandedSizes);
+ return 0;
+}
+
TH_API int THLongStorage_inferExpandGeometry(long *tensorSizes, long *tensorStrides, long tensorDim, THLongStorage *sizes, long **esz, long **est, int raiseErrors) {
ptrdiff_t ndim = THLongStorage_size(sizes);
diff --git a/lib/TH/THStorage.h b/lib/TH/THStorage.h
index 9cbc452..7aab593 100644
--- a/lib/TH/THStorage.h
+++ b/lib/TH/THStorage.h
@@ -31,6 +31,7 @@ typedef struct {
TH_API THDescBuff THLongStorage_sizeDesc(const THLongStorage *size);
TH_API THLongStorage *THLongStorage_newInferSize(THLongStorage *size, ptrdiff_t nElement);
TH_API int THLongStorage_inferSize2(THLongStorage *output, long *sizesA, long dimsA, long *sizesB, long dimsB, int raiseErrors);
+TH_API int THLongStorage_inferSizeN(THLongStorage *output, int n, long **sizes, long *dims, int raiseErrors);
TH_API int THLongStorage_inferExpandGeometry(long *tensorSizes, long *tensorStrides, long tensorDim, THLongStorage *sizes, long **esz, long **est, int raiseErrors);
#endif