diff options
author | Gregory Chanan <gchanan@fb.com> | 2017-06-05 17:51:58 +0300 |
---|---|---|
committer | Soumith Chintala <soumith@gmail.com> | 2017-06-11 11:32:08 +0300 |
commit | 52109da83e460c7671f4f27eea843c35f5b2ea56 (patch) | |
tree | a65fe33257c0080a0da3ab7e7ef69660da7e06eb | |
parent | 67aa876ef3afdcae972d7999db9a8f4d5d42ec98 (diff) |
Incorporate review comments:
1) Line up trailing dimensions in broadcast docs.
2) remove unnecessary expand_as in common_nn test.
3) use view in tensor_str instead of resize_.
4) newExpand remove raiseErrors change.
5) clarify expandedSizes/expandedStrides parameters in inferExpandGeometry.
6) simplify inferSize2/inferSizeN implementations.
7) use new-style classes for warning.
-rw-r--r-- | lib/TH/THStorage.c | 68 | ||||
-rw-r--r-- | lib/TH/THStorage.h | 4 | ||||
-rw-r--r-- | lib/TH/generic/THTensor.c | 4 | ||||
-rw-r--r-- | lib/TH/generic/THTensor.h | 2 |
4 files changed, 35 insertions, 43 deletions
diff --git a/lib/TH/THStorage.c b/lib/TH/THStorage.c index 54ecc84..b1a2a93 100644 --- a/lib/TH/THStorage.c +++ b/lib/TH/THStorage.c @@ -81,22 +81,16 @@ TH_API int THLongStorage_inferSize2(THLongStorage *output, long *sizesA, long di long dimB = dimsB - 1 - offset; long sizeA = (dimA >= 0) ? sizesA[dimA] : 1; long sizeB = (dimB >= 0) ? sizesB[dimB] : 1; - if (sizeA != sizeB) { - if (sizeA == 1) { - sizeA = sizeB; - } - else if (sizeB == 1) { - } - else { - THFree(expandedSizes); - if (raiseErrors) { - THError("The size of tensor a (%ld) must match the size of tensor b (%ld) at " - "non-singleton dimension %ld.", sizeA, sizeB, i); - } - return -1; + if (sizeA == sizeB || sizeA == 1 || sizeB == 1) { + expandedSizes[i] = THMax(sizeA, sizeB); + } else { + THFree(expandedSizes); + if (raiseErrors) { + THError("The size of tensor a (%ld) must match the size of tensor b (%ld) at " + "non-singleton dimension %ld.", sizeA, sizeB, i); } + return -1; } - expandedSizes[ i ] = sizeA; } THLongStorage_resize(output, ndim); memcpy(THLongStorage_data(output), expandedSizes, sizeof(long)*ndim); @@ -119,28 +113,22 @@ TH_API int THLongStorage_inferSizeN(THLongStorage *output, int n, long **sizes, long *expandedSizes = THAlloc(sizeof(long)*ndim); for (long i = ndim - 1; i >= 0; --i) { - long max_dim_size = 1; + expandedSizes[ i ] = 1; long offset = ndim - 1 - i; for (int j = 0; j < n; ++j) { long dim = dims[ j ] - 1 - offset; long size = (dim >= 0) ? sizes[ j ][ dim ] : 1; - if (size != max_dim_size) { - if (max_dim_size == 1){ - max_dim_size = size; - } else if (size == 1) { - // we'll expand, nothing to do - } else { - THFree(expandedSizes); - if (raiseErrors) { - THError("The size of tensor %i (%ld) must match the expanded size of tensor (%ld) at " - "non-singleton dimension %ld.", j, size, max_dim_size, i); - } - return -1; + if (size == expandedSizes[ i ] || size == 1 || expandedSizes[ i ] == 1) { + expandedSizes[ i ] = THMax(expandedSizes[ i ], size); + } else { + THFree(expandedSizes); + if (raiseErrors) { + THError("The size of tensor %i (%ld) must match the expanded size of tensor (%ld) at " + "non-singleton dimension %ld.", j, size, expandedSizes[ i ], i); } + return -1; } } - - expandedSizes[ i ] = max_dim_size; } THLongStorage_resize(output, ndim); memcpy(THLongStorage_data(output), expandedSizes, sizeof(long)*ndim); @@ -148,11 +136,13 @@ TH_API int THLongStorage_inferSizeN(THLongStorage *output, int n, long **sizes, return 0; } -TH_API int THLongStorage_inferExpandGeometry(long *tensorSizes, long *tensorStrides, long tensorDim, THLongStorage *sizes, long **esz, long **est, int raiseErrors) { +TH_API int THLongStorage_inferExpandGeometry(long *tensorSizes, long *tensorStrides, long tensorDim, + THLongStorage *sizes, long **expandedSizes, long **expandedStrides, + int raiseErrors) { ptrdiff_t ndim = THLongStorage_size(sizes); - long *expandedSizes = THAlloc(sizeof(long)*ndim); - long *expandedStrides = THAlloc(sizeof(long)*ndim); + long *expandedSizesCalc = THAlloc(sizeof(long)*ndim); + long *expandedStridesCalc = THAlloc(sizeof(long)*ndim); // create a new geometry for the tensors for (long i = ndim - 1; i >= 0; --i) { @@ -160,15 +150,15 @@ TH_API int THLongStorage_inferExpandGeometry(long *tensorSizes, long *tensorStri long dim = tensorDim - 1 - offset; long size = (dim >= 0) ? tensorSizes[dim] : 1; long stride = (dim >= 0) ? - tensorStrides[dim] : expandedSizes[i + 1] * expandedStrides[i+1]; + tensorStrides[dim] : expandedSizesCalc[i + 1] * expandedStridesCalc[i+1]; long targetSize = THLongStorage_data(sizes)[i]; if (size != targetSize) { if (size == 1) { size = targetSize; stride = 0; } else { - THFree(expandedSizes); - THFree(expandedStrides); + THFree(expandedSizesCalc); + THFree(expandedStridesCalc); if (raiseErrors) { THError("The expanded size of the tensor (%d) must match the existing size (%d) at " "non-singleton dimension %ld.", targetSize, size, i); @@ -176,10 +166,10 @@ TH_API int THLongStorage_inferExpandGeometry(long *tensorSizes, long *tensorStri return -1; } } - expandedSizes[i] = size; - expandedStrides[i] = stride; + expandedSizesCalc[i] = size; + expandedStridesCalc[i] = stride; } - *esz = expandedSizes; - *est = expandedStrides; + *expandedSizes = expandedSizesCalc; + *expandedStrides = expandedStridesCalc; return 0; } diff --git a/lib/TH/THStorage.h b/lib/TH/THStorage.h index c8ab484..0b61ff8 100644 --- a/lib/TH/THStorage.h +++ b/lib/TH/THStorage.h @@ -35,6 +35,8 @@ TH_API THLongStorage *THLongStorage_newInferSize(THLongStorage *size, ptrdiff_t TH_API int THLongStorage_inferSize2(THLongStorage *output, long *sizesA, long dimsA, long *sizesB, long dimsB, int raiseErrors); TH_API int THLongStorage_inferSizeN(THLongStorage *output, int n, long **sizes, long *dims, int raiseErrors); -TH_API int THLongStorage_inferExpandGeometry(long *tensorSizes, long *tensorStrides, long tensorDim, THLongStorage *sizes, long **esz, long **est, int raiseErrors); +TH_API int THLongStorage_inferExpandGeometry(long *tensorSizes, long *tensorStrides, long tensorDim, + THLongStorage *sizes, long **expandedSizes, long **expandedStrides, + int raiseErrors); #endif diff --git a/lib/TH/generic/THTensor.c b/lib/TH/generic/THTensor.c index c2f7359..19cd46a 100644 --- a/lib/TH/generic/THTensor.c +++ b/lib/TH/generic/THTensor.c @@ -285,9 +285,9 @@ void THTensor_(resize5d)(THTensor *self, long size0, long size1, long size2, lon THTensor_(resizeNd)(self, 5, size, NULL); } -THTensor* THTensor_(newExpand)(THTensor *tensor, THLongStorage *sizes, int raiseErrors) { +THTensor* THTensor_(newExpand)(THTensor *tensor, THLongStorage *sizes) { THTensor *result = THTensor_(new)(); - THTensor_(expand)(result, tensor, sizes, raiseErrors); + THTensor_(expand)(result, tensor, sizes, 1); return result; } diff --git a/lib/TH/generic/THTensor.h b/lib/TH/generic/THTensor.h index bc33ed5..27a2cd1 100644 --- a/lib/TH/generic/THTensor.h +++ b/lib/TH/generic/THTensor.h @@ -69,7 +69,7 @@ TH_API THTensor *THTensor_(newNarrow)(THTensor *tensor, int dimension_, long fir TH_API THTensor *THTensor_(newTranspose)(THTensor *tensor, int dimension1_, int dimension2_); TH_API THTensor *THTensor_(newUnfold)(THTensor *tensor, int dimension_, long size_, long step_); TH_API THTensor *THTensor_(newView)(THTensor *tensor, THLongStorage *size); -TH_API THTensor *THTensor_(newExpand)(THTensor *tensor, THLongStorage *size, int raiseErrors); +TH_API THTensor *THTensor_(newExpand)(THTensor *tensor, THLongStorage *size); TH_API int THTensor_(expand)(THTensor *r, THTensor *tensor, THLongStorage *size, int raiseErrors); TH_API int THTensor_(expand2)(THTensor *ra, THTensor *rb, THTensor *opa, THTensor *opb, int raiseErrors); |