From 26629b188ad6c7c9a175f9f6115125948a765eb0 Mon Sep 17 00:00:00 2001 From: Gregory Chanan Date: Wed, 10 May 2017 11:33:24 -0700 Subject: Support "fused" ops: addcmul/addcdiv. --- lib/TH/THStorage.c | 7 ++--- lib/TH/generic/THTensor.c | 69 ++++++++++++++++++++++++++++------------------- lib/TH/generic/THTensor.h | 1 + 3 files changed, 47 insertions(+), 30 deletions(-) diff --git a/lib/TH/THStorage.c b/lib/TH/THStorage.c index b7d4a5d..54ecc84 100644 --- a/lib/TH/THStorage.c +++ b/lib/TH/THStorage.c @@ -106,14 +106,14 @@ TH_API int THLongStorage_inferSize2(THLongStorage *output, long *sizesA, long di TH_API int THLongStorage_inferSizeN(THLongStorage *output, int n, long **sizes, long *dims, int raiseErrors) { THArgCheck(n > 0, 2, "n must be greater than 0"); - THArgCheck(sizes != NULL, 1, "sizesA must not be null"); + THArgCheck(sizes != NULL, 1, "sizes must not be null"); THArgCheck(dims != NULL, 1, "dims must not be null"); ptrdiff_t ndim = 0; for (int j = 0; j < n; ++j) { THArgCheck(sizes[ j ] != NULL, 1, "size %d must not be null", j); THArgCheck(dims[ j ], 1, "Can't expand empty tensor %d", j); - ptrdiff_t ndim = dims[ j ] > ndim ? dims[ j ] : ndim; + ndim = dims[ j ] > ndim ? dims[ j ] : ndim; } long *expandedSizes = THAlloc(sizeof(long)*ndim); @@ -123,7 +123,7 @@ TH_API int THLongStorage_inferSizeN(THLongStorage *output, int n, long **sizes, long offset = ndim - 1 - i; for (int j = 0; j < n; ++j) { long dim = dims[ j ] - 1 - offset; - long size = (dim >= 0) ? sizes[ i ][ dim ] : 1; + long size = (dim >= 0) ? sizes[ j ][ dim ] : 1; if (size != max_dim_size) { if (max_dim_size == 1){ max_dim_size = size; @@ -139,6 +139,7 @@ TH_API int THLongStorage_inferSizeN(THLongStorage *output, int n, long **sizes, } } } + expandedSizes[ i ] = max_dim_size; } THLongStorage_resize(output, ndim); diff --git a/lib/TH/generic/THTensor.c b/lib/TH/generic/THTensor.c index 4b5c302..38bbffd 100644 --- a/lib/TH/generic/THTensor.c +++ b/lib/TH/generic/THTensor.c @@ -292,12 +292,14 @@ THTensor* THTensor_(newExpand)(THTensor *tensor, THLongStorage *sizes, int raise } int THTensor_(expand)(THTensor *r, THTensor *tensor, THLongStorage *sizes, int raiseErrors) { + THArgCheck(THTensor_(nDimension)(tensor) > 0, 0, "can't expand an empty tensor"); if (raiseErrors) { THArgCheck(THLongStorage_size(sizes) >= THTensor_(nDimension)(tensor), 1, "the number of sizes provided must be greater or equal to the " "number of dimensions in the tensor"); + } else if (THLongStorage_size(sizes) < THTensor_(nDimension)(tensor)) { + return -1; } - THArgCheck(THTensor_(nDimension)(tensor) > 0, 0, "can't expand an empty tensor"); long *expandedSizes; long *expandedStrides; @@ -326,41 +328,54 @@ int THTensor_(expand2)(THTensor *ra, THTensor *rb, THTensor *opa, THTensor *opb, return ret; } - long *expandedSizes; - long *expandedStrides; - ret = THLongStorage_inferExpandGeometry(opa->size, opa->stride, - THTensor_(nDimension)(opa), sizes, - &expandedSizes, &expandedStrides, - raiseErrors); + ret = THTensor_(expand)(ra, opa, sizes, raiseErrors); + THAssert(ret == 0); // since we inferred this already, it must be valid + ret = THTensor_(expand)(rb, opb, sizes, raiseErrors); THAssert(ret == 0); // since we inferred this already, it must be valid - THTensor_(setStorageNd)(ra, - THTensor_(storage)(opa), - THTensor_(storageOffset)(opa), - THLongStorage_size(sizes), - expandedSizes, - expandedStrides); - THFree(expandedSizes); - THFree(expandedStrides); + THLongStorage_free(sizes); + return 0; +} + +TH_API int THTensor_(expand3)(THTensor *ra, THTensor *rb, THTensor *rc, THTensor *opa, THTensor *opb, THTensor *opc, int raiseErrors) { + THArgCheck(THTensor_(nDimension)(opa) > 0, 0, "can't expand empty tensor opa"); + THArgCheck(THTensor_(nDimension)(opb) > 0, 0, "can't expand empty tensor opb"); + THArgCheck(THTensor_(nDimension)(opc) > 0, 0, "can't expand empty tensor opc"); + + const int op_n = 3; + long **op_sizes = THAlloc(sizeof(long**)*op_n); + long *op_dims = THAlloc(sizeof(long*)*op_n); - ret = THLongStorage_inferExpandGeometry(opb->size, opb->stride, - THTensor_(nDimension)(opb), sizes, - &expandedSizes, &expandedStrides, - raiseErrors); + op_sizes[ 0 ] = opa->size; + op_sizes[ 1 ] = opb->size; + op_sizes[ 2 ] = opc->size; + op_dims[ 0 ] = opa->nDimension; + op_dims[ 1 ] = opb->nDimension; + op_dims[ 2 ] = opc->nDimension; + + THLongStorage *sizes = THLongStorage_new(); + int ret = THLongStorage_inferSizeN(sizes, + op_n, + op_sizes, + op_dims, + raiseErrors); + + if(ret != 0) { + return ret; + } + + ret = THTensor_(expand)(ra, opa, sizes, raiseErrors); + THAssert(ret == 0); // since we inferred this already, it must be valid + ret = THTensor_(expand)(rb, opb, sizes, raiseErrors); + THAssert(ret == 0); // since we inferred this already, it must be valid + ret = THTensor_(expand)(rc, opc, sizes, raiseErrors); THAssert(ret == 0); // since we inferred this already, it must be valid - THTensor_(setStorageNd)(rb, - THTensor_(storage)(opb), - THTensor_(storageOffset)(opb), - THLongStorage_size(sizes), - expandedSizes, - expandedStrides); - THFree(expandedSizes); - THFree(expandedStrides); THLongStorage_free(sizes); return 0; } + void THTensor_(set)(THTensor *self, THTensor *src) { if(self != src) diff --git a/lib/TH/generic/THTensor.h b/lib/TH/generic/THTensor.h index bafe00a..bc33ed5 100644 --- a/lib/TH/generic/THTensor.h +++ b/lib/TH/generic/THTensor.h @@ -73,6 +73,7 @@ TH_API THTensor *THTensor_(newExpand)(THTensor *tensor, THLongStorage *size, int TH_API int THTensor_(expand)(THTensor *r, THTensor *tensor, THLongStorage *size, int raiseErrors); TH_API int THTensor_(expand2)(THTensor *ra, THTensor *rb, THTensor *opa, THTensor *opb, int raiseErrors); +TH_API int THTensor_(expand3)(THTensor *ra, THTensor *rb, THTensor *rc, THTensor *opa, THTensor *opb, THTensor *opc, int raiseErrors); TH_API void THTensor_(resize)(THTensor *tensor, THLongStorage *size, THLongStorage *stride); TH_API void THTensor_(resizeAs)(THTensor *tensor, THTensor *src); -- cgit v1.2.3