diff options
author | Gregory Chanan <gchanan@fb.com> | 2017-04-26 23:34:56 +0300 |
---|---|---|
committer | Soumith Chintala <soumith@gmail.com> | 2017-06-11 11:19:37 +0300 |
commit | c796ad19d5ce4e4bb8d88a7aa83a7c88d0557e0b (patch) | |
tree | 1b9accca7453f7d2a558922959efbfe6c66b3c64 | |
parent | 036989cc99b48ef4ecbf7604ab46f8461dd264fb (diff) |
Expand improvements
1) Rename calculateExpandGeometry to inferExpandGeometry for consistency
2) Simplify inferExpandGeometry implementation by using a single pass
through dimensions
3) Implement a two-operand expansion, expand2.
4) Implement versions that return an error code (controlled by raiseErrors),
for use when falling back to equal nElem support.
-rw-r--r-- | lib/TH/THStorage.c | 83 | ||||
-rw-r--r-- | lib/TH/THStorage.h | 3 | ||||
-rw-r--r-- | lib/TH/generic/THTensor.c | 73 | ||||
-rw-r--r-- | lib/TH/generic/THTensor.h | 5 |
4 files changed, 132 insertions, 32 deletions
diff --git a/lib/TH/THStorage.c b/lib/TH/THStorage.c index cea0f95..7636e4b 100644 --- a/lib/TH/THStorage.c +++ b/lib/TH/THStorage.c @@ -66,39 +66,76 @@ TH_API THLongStorage *THLongStorage_newInferSize(THLongStorage *size, ptrdiff_t return copy; } -TH_API void THLongStorage_calculateExpandGeometry(long *tensorSizes, long *tensorStrides, long tensorDim, THLongStorage *sizes, long **esz, long **est) { - ptrdiff_t ndim = THLongStorage_size(sizes); - long numUnsqueezed = ndim - tensorDim; +TH_API int THLongStorage_inferSize2(THLongStorage *output, long *sizesA, long dimsA, long *sizesB, long dimsB, int raiseErrors) { + THArgCheck(sizesA != NULL, 1, "sizesA must not be null"); + THArgCheck(sizesB != NULL, 2, "sizesB must not be null"); + THArgCheck(dimsA, 1, "Can't expand empty tensor a"); + THArgCheck(dimsB, 1, "Can't expand empty tensor b"); + ptrdiff_t ndim = dimsA > dimsB ? dimsA : dimsB; long *expandedSizes = THAlloc(sizeof(long)*ndim); - long *expandedStrides = THAlloc(sizeof(long)*ndim); - for (long i = numUnsqueezed; i < ndim; ++i) { - expandedSizes[i] = tensorSizes[i - numUnsqueezed]; - expandedStrides[i] = tensorStrides[i - numUnsqueezed]; + for (long i = ndim - 1; i >= 0; --i) { + long offset = ndim - 1 - i; + long dimA = dimsA - 1 - offset; + long dimB = dimsB - 1 - offset; + long sizeA = (dimA >= 0) ? sizesA[dimA] : 1; + long sizeB = (dimB >= 0) ? 
sizesB[dimB] : 1; + if (sizeA != sizeB) { + if (sizeA == 1) { + sizeA = sizeB; + } + else if (sizeB == 1) { + } + else { + THFree(expandedSizes); + if (raiseErrors) { + THError("The size of tensor a (%ld) must match the size of tensor b (%ld) at " + "non-singleton dimension %ld.", sizeA, sizeB, i); + } + return -1; + } + } + expandedSizes[ i ] = sizeA; } + THLongStorage_resize(output, ndim); + memcpy(THLongStorage_data(output), expandedSizes, sizeof(long)*ndim); + THFree(expandedSizes); + return 0; +} - for (long i = numUnsqueezed - 1; i > -1; --i) { - expandedSizes[i] = 1; - expandedStrides[i] = expandedSizes[i+1] * expandedStrides[i+1]; - } +TH_API int THLongStorage_inferExpandGeometry(long *tensorSizes, long *tensorStrides, long tensorDim, THLongStorage *sizes, long **esz, long **est, int raiseErrors) { + ptrdiff_t ndim = THLongStorage_size(sizes); + + long *expandedSizes = THAlloc(sizeof(long)*ndim); + long *expandedStrides = THAlloc(sizeof(long)*ndim); - // create a new geometry for the tensor - for (long i = 0; i < ndim; ++i) { - long size = expandedSizes[i]; + // create a new geometry for the tensors + for (long i = ndim - 1; i >= 0; --i) { + long offset = ndim - 1 - i; + long dim = tensorDim - 1 - offset; + long size = (dim >= 0) ? tensorSizes[dim] : 1; + long stride = (dim >= 0) ? 
+ tensorStrides[dim] : expandedSizes[i + 1] * expandedStrides[i+1]; long targetSize = THLongStorage_data(sizes)[i]; - if (size == 1) { - if (targetSize != 1) { - expandedSizes[i] = targetSize; - expandedStrides[i] = 0; + if (size != targetSize) { + if (size == 1) { + size = targetSize; + stride = 0; + } else { + THFree(expandedSizes); + THFree(expandedStrides); + if (raiseErrors) { + THError("The expanded size of the tensor (%d) must match the existing size (%d) at " + "non-singleton dimension %ld.", targetSize, size, i); + } + return -1; } - } else if (size != targetSize) { - THFree(expandedSizes); - THFree(expandedStrides); - THError("The expanded size of the tensor (%d) must match the existing size (%d) at \ - non-singleton dimension %ld.", targetSize, size, i); } + expandedSizes[i] = size; + expandedStrides[i] = stride; } *esz = expandedSizes; *est = expandedStrides; + return 0; } diff --git a/lib/TH/THStorage.h b/lib/TH/THStorage.h index f81926c..9cbc452 100644 --- a/lib/TH/THStorage.h +++ b/lib/TH/THStorage.h @@ -30,6 +30,7 @@ typedef struct { TH_API THDescBuff THLongStorage_sizeDesc(const THLongStorage *size); TH_API THLongStorage *THLongStorage_newInferSize(THLongStorage *size, ptrdiff_t nElement); -TH_API void THLongStorage_calculateExpandGeometry(long *tensorSizes, long *tensorStrides, long tensorDim, THLongStorage *sizes, long **esz, long **est); +TH_API int THLongStorage_inferSize2(THLongStorage *output, long *sizesA, long dimsA, long *sizesB, long dimsB, int raiseErrors); +TH_API int THLongStorage_inferExpandGeometry(long *tensorSizes, long *tensorStrides, long tensorDim, THLongStorage *sizes, long **esz, long **est, int raiseErrors); #endif diff --git a/lib/TH/generic/THTensor.c b/lib/TH/generic/THTensor.c index eab7231..4b5c302 100644 --- a/lib/TH/generic/THTensor.c +++ b/lib/TH/generic/THTensor.c @@ -285,21 +285,80 @@ void THTensor_(resize5d)(THTensor *self, long size0, long size1, long size2, lon THTensor_(resizeNd)(self, 5, size, NULL); } 
-THTensor* THTensor_(newExpand)(THTensor *tensor, THLongStorage *sizes) { - THArgCheck(THLongStorage_size(sizes) >= THTensor_(nDimension)(tensor), 1, "the number of sizes provided \ - must be greater or equal to the number of dimensions in the tensor"); +THTensor* THTensor_(newExpand)(THTensor *tensor, THLongStorage *sizes, int raiseErrors) { + THTensor *result = THTensor_(new)(); + THTensor_(expand)(result, tensor, sizes, raiseErrors); + return result; +} + +int THTensor_(expand)(THTensor *r, THTensor *tensor, THLongStorage *sizes, int raiseErrors) { + if (raiseErrors) { + THArgCheck(THLongStorage_size(sizes) >= THTensor_(nDimension)(tensor), 1, + "the number of sizes provided must be greater or equal to the " + "number of dimensions in the tensor"); + } THArgCheck(THTensor_(nDimension)(tensor) > 0, 0, "can't expand an empty tensor"); long *expandedSizes; long *expandedStrides; - THLongStorage_calculateExpandGeometry(tensor->size, tensor->stride, THTensor_(nDimension)(tensor), sizes, &expandedSizes, &expandedStrides); + int ret = + THLongStorage_inferExpandGeometry(tensor->size, tensor->stride, THTensor_(nDimension)(tensor), sizes, &expandedSizes, &expandedStrides, raiseErrors); + if (ret != 0) { + return ret; + } + THTensor_(setStorageNd)(r, THTensor_(storage)(tensor), THTensor_(storageOffset)(tensor), THLongStorage_size(sizes), expandedSizes, expandedStrides); + THFree(expandedSizes); + THFree(expandedStrides); - THTensor *result = THTensor_(new)(); - THTensor_(setStorageNd)(result, THTensor_(storage)(tensor), THTensor_(storageOffset)(tensor), THLongStorage_size(sizes), expandedSizes, expandedStrides); + return 0; +} + +int THTensor_(expand2)(THTensor *ra, THTensor *rb, THTensor *opa, THTensor *opb, int raiseErrors) { + THArgCheck(THTensor_(nDimension)(opa) > 0, 0, "can't expand empty tensor opa"); + THArgCheck(THTensor_(nDimension)(opb) > 0, 0, "can't expand empty tensor opb"); + + THLongStorage *sizes = THLongStorage_new(); + int ret = 
THLongStorage_inferSize2(sizes, + opa->size, THTensor_(nDimension)(opa), + opb->size, THTensor_(nDimension)(opb), + raiseErrors); + if(ret != 0) { + return ret; + } + + long *expandedSizes; + long *expandedStrides; + ret = THLongStorage_inferExpandGeometry(opa->size, opa->stride, + THTensor_(nDimension)(opa), sizes, + &expandedSizes, &expandedStrides, + raiseErrors); + THAssert(ret == 0); // since we inferred this already, it must be valid + + THTensor_(setStorageNd)(ra, + THTensor_(storage)(opa), + THTensor_(storageOffset)(opa), + THLongStorage_size(sizes), + expandedSizes, + expandedStrides); THFree(expandedSizes); THFree(expandedStrides); - return result; + ret = THLongStorage_inferExpandGeometry(opb->size, opb->stride, + THTensor_(nDimension)(opb), sizes, + &expandedSizes, &expandedStrides, + raiseErrors); + THAssert(ret == 0); // since we inferred this already, it must be valid + + THTensor_(setStorageNd)(rb, + THTensor_(storage)(opb), + THTensor_(storageOffset)(opb), + THLongStorage_size(sizes), + expandedSizes, + expandedStrides); + THFree(expandedSizes); + THFree(expandedStrides); + THLongStorage_free(sizes); + return 0; } void THTensor_(set)(THTensor *self, THTensor *src) diff --git a/lib/TH/generic/THTensor.h b/lib/TH/generic/THTensor.h index 8754796..bafe00a 100644 --- a/lib/TH/generic/THTensor.h +++ b/lib/TH/generic/THTensor.h @@ -69,7 +69,10 @@ TH_API THTensor *THTensor_(newNarrow)(THTensor *tensor, int dimension_, long fir TH_API THTensor *THTensor_(newTranspose)(THTensor *tensor, int dimension1_, int dimension2_); TH_API THTensor *THTensor_(newUnfold)(THTensor *tensor, int dimension_, long size_, long step_); TH_API THTensor *THTensor_(newView)(THTensor *tensor, THLongStorage *size); -TH_API THTensor *THTensor_(newExpand)(THTensor *tensor, THLongStorage *size); +TH_API THTensor *THTensor_(newExpand)(THTensor *tensor, THLongStorage *size, int raiseErrors); + +TH_API int THTensor_(expand)(THTensor *r, THTensor *tensor, THLongStorage *size, int 
raiseErrors); +TH_API int THTensor_(expand2)(THTensor *ra, THTensor *rb, THTensor *opa, THTensor *opb, int raiseErrors); TH_API void THTensor_(resize)(THTensor *tensor, THLongStorage *size, THLongStorage *stride); TH_API void THTensor_(resizeAs)(THTensor *tensor, THTensor *src); |