diff options
author | Gregory Chanan <gchanan@fb.com> | 2017-05-10 21:33:24 +0300 |
---|---|---|
committer | Soumith Chintala <soumith@gmail.com> | 2017-06-11 11:33:54 +0300 |
commit | 99dbaecbe142c4eaa133b9d476ed1816d55b91be (patch) | |
tree | e5f114974963cb64f89ec6714288a08bccff1706 | |
parent | 4817e53af421d4921effd9ab11ccd24c44c5eeca (diff) |
Support "fused" ops: addcmul/addcdiv.
-rw-r--r-- | lib/THC/generic/THCTensor.c | 78 | ||||
-rw-r--r-- | lib/THC/generic/THCTensor.h | 1 |
2 files changed, 43 insertions, 36 deletions
diff --git a/lib/THC/generic/THCTensor.c b/lib/THC/generic/THCTensor.c index ccd6277..57ff028 100644 --- a/lib/THC/generic/THCTensor.c +++ b/lib/THC/generic/THCTensor.c @@ -303,12 +303,14 @@ THCTensor* THCTensor_(newExpand)(THCState *state, THCTensor *tensor, THLongStora int THCTensor_(expand)(THCState *state, THCTensor *r, THCTensor *tensor, THLongStorage *sizes, int raiseErrors) { + THArgCheck(THCTensor_(nDimension)(state, tensor) > 0, 0, "can't expand an empty tensor"); if (raiseErrors) { THArgCheck(THLongStorage_size(sizes) >= THCTensor_(nDimension)(state, tensor), 1, "the number of sizes provided must be greater or equal to the " "number of dimensions in the tensor"); + } else if (THLongStorage_size(sizes) < THCTensor_(nDimension)(state, tensor)) { + return -1; } - THArgCheck(THCTensor_(nDimension)(state, tensor) > 0, 0, "can't expand an empty tensor"); long *expandedSizes; long *expandedStrides; @@ -343,46 +345,50 @@ int THCTensor_(expand2)(THCState *state, THCTensor *ra, THCTensor *rb, THCTensor if(ret != 0) { return ret; } - - long *expandedSizes; - long *expandedStrides; - ret = THLongStorage_inferExpandGeometry(opa->size, - opa->stride, - THCTensor_(nDimension)(state, opa), - sizes, - &expandedSizes, - &expandedStrides, - raiseErrors); + ret = THCTensor_(expand)(state, ra, opa, sizes, raiseErrors); + THAssert(ret == 0); // since we inferred this already, it must be valid + ret = THCTensor_(expand)(state, rb, opb, sizes, raiseErrors); THAssert(ret == 0); // since we inferred this already, it must be valid - THCTensor_(setStorageNd)(state, - ra, - THCTensor_(storage)(state, opa), - THCTensor_(storageOffset)(state, opa), - THLongStorage_size(sizes), - expandedSizes, - expandedStrides); - THFree(expandedSizes); - THFree(expandedStrides); + THLongStorage_free(sizes); + return 0; +} + +THC_API int THCTensor_(expand3)(THCState *state, THCTensor *ra, THCTensor *rb, THCTensor *rc, THCTensor *opa, THCTensor *opb, THCTensor *opc, int raiseErrors) { + THArgCheck(THCTensor_(nDimension)(state, opa) > 0, 0, "can't expand empty tensor opa"); + THArgCheck(THCTensor_(nDimension)(state, opb) > 0, 0, "can't expand empty tensor opb"); + THArgCheck(THCTensor_(nDimension)(state, opc) > 0, 0, "can't expand empty tensor opc"); - ret = THLongStorage_inferExpandGeometry(opb->size, - opb->stride, - THCTensor_(nDimension)(state, opb), - sizes, - &expandedSizes, - &expandedStrides, - raiseErrors); + const int op_n = 3; + long **op_sizes = THAlloc(sizeof(long**)*op_n); + long *op_dims = THAlloc(sizeof(long*)*op_n); + + op_sizes[ 0 ] = opa->size; + op_sizes[ 1 ] = opb->size; + op_sizes[ 2 ] = opc->size; + op_dims[ 0 ] = opa->nDimension; + op_dims[ 1 ] = opb->nDimension; + op_dims[ 2 ] = opc->nDimension; + + THLongStorage *sizes = THLongStorage_new(); + int ret = THLongStorage_inferSizeN(sizes, + op_n, + op_sizes, + op_dims, + raiseErrors); + + if(ret != 0) { + return ret; + } + + ret = THCTensor_(expand)(state, ra, opa, sizes, raiseErrors); + THAssert(ret == 0); // since we inferred this already, it must be valid + ret = THCTensor_(expand)(state, rb, opb, sizes, raiseErrors); + THAssert(ret == 0); // since we inferred this already, it must be valid + ret = THCTensor_(expand)(state, rc, opc, sizes, raiseErrors); THAssert(ret == 0); // since we inferred this already, it must be valid - THCTensor_(setStorageNd)(state, - rb, - THCTensor_(storage)(state, opb), - THCTensor_(storageOffset)(state, opb), - THLongStorage_size(sizes), - expandedSizes, - expandedStrides); - THFree(expandedSizes); - THFree(expandedStrides); + THLongStorage_free(sizes); return 0; } diff --git a/lib/THC/generic/THCTensor.h b/lib/THC/generic/THCTensor.h index c1ce52d..a19646e 100644 --- a/lib/THC/generic/THCTensor.h +++ b/lib/THC/generic/THCTensor.h @@ -71,6 +71,7 @@ THC_API THCTensor *THCTensor_(newExpand)(THCState *state, THCTensor *tensor, THL THC_API int THCTensor_(expand)(THCState *state, THCTensor *r, THCTensor *tensor, THLongStorage *sizes, int raiseErrors); THC_API int THCTensor_(expand2)(THCState *state, THCTensor *ra, THCTensor *rb, THCTensor *opa, THCTensor *opb, int raiseErrors); +THC_API int THCTensor_(expand3)(THCState *state, THCTensor *ra, THCTensor *rb, THCTensor *rc, THCTensor *opa, THCTensor *opb, THCTensor *opc, int raiseErrors); THC_API void THCTensor_(resize)(THCState *state, THCTensor *tensor, THLongStorage *size, THLongStorage *stride); THC_API void THCTensor_(resizeAs)(THCState *state, THCTensor *tensor, THCTensor *src); |