diff options
Diffstat (limited to 'lib/THC/generic/THCTensor.c')
-rw-r--r-- | lib/THC/generic/THCTensor.c | 78 |
1 files changed, 42 insertions, 36 deletions
diff --git a/lib/THC/generic/THCTensor.c b/lib/THC/generic/THCTensor.c index ccd6277..57ff028 100644 --- a/lib/THC/generic/THCTensor.c +++ b/lib/THC/generic/THCTensor.c @@ -303,12 +303,14 @@ THCTensor* THCTensor_(newExpand)(THCState *state, THCTensor *tensor, THLongStora int THCTensor_(expand)(THCState *state, THCTensor *r, THCTensor *tensor, THLongStorage *sizes, int raiseErrors) { + THArgCheck(THCTensor_(nDimension)(state, tensor) > 0, 0, "can't expand an empty tensor"); if (raiseErrors) { THArgCheck(THLongStorage_size(sizes) >= THCTensor_(nDimension)(state, tensor), 1, "the number of sizes provided must be greater or equal to the " "number of dimensions in the tensor"); + } else if (THLongStorage_size(sizes) < THCTensor_(nDimension)(state, tensor)) { + return -1; } - THArgCheck(THCTensor_(nDimension)(state, tensor) > 0, 0, "can't expand an empty tensor"); long *expandedSizes; long *expandedStrides; @@ -343,46 +345,50 @@ int THCTensor_(expand2)(THCState *state, THCTensor *ra, THCTensor *rb, THCTensor if(ret != 0) { return ret; } - - long *expandedSizes; - long *expandedStrides; - ret = THLongStorage_inferExpandGeometry(opa->size, - opa->stride, - THCTensor_(nDimension)(state, opa), - sizes, - &expandedSizes, - &expandedStrides, - raiseErrors); + ret = THCTensor_(expand)(state, ra, opa, sizes, raiseErrors); + THAssert(ret == 0); // since we inferred this already, it must be valid + ret = THCTensor_(expand)(state, rb, opb, sizes, raiseErrors); THAssert(ret == 0); // since we inferred this already, it must be valid - THCTensor_(setStorageNd)(state, - ra, - THCTensor_(storage)(state, opa), - THCTensor_(storageOffset)(state, opa), - THLongStorage_size(sizes), - expandedSizes, - expandedStrides); - THFree(expandedSizes); - THFree(expandedStrides); + THLongStorage_free(sizes); + return 0; +} + +THC_API int THCTensor_(expand3)(THCState *state, THCTensor *ra, THCTensor *rb, THCTensor *rc, THCTensor *opa, THCTensor *opb, THCTensor *opc, int raiseErrors) { + THArgCheck(THCTensor_(nDimension)(state, opa) > 0, 0, "can't expand empty tensor opa"); + THArgCheck(THCTensor_(nDimension)(state, opb) > 0, 0, "can't expand empty tensor opb"); + THArgCheck(THCTensor_(nDimension)(state, opc) > 0, 0, "can't expand empty tensor opc"); - ret = THLongStorage_inferExpandGeometry(opb->size, - opb->stride, - THCTensor_(nDimension)(state, opb), - sizes, - &expandedSizes, - &expandedStrides, - raiseErrors); + const int op_n = 3; + long **op_sizes = THAlloc(sizeof(long**)*op_n); + long *op_dims = THAlloc(sizeof(long*)*op_n); + + op_sizes[ 0 ] = opa->size; + op_sizes[ 1 ] = opb->size; + op_sizes[ 2 ] = opc->size; + op_dims[ 0 ] = opa->nDimension; + op_dims[ 1 ] = opb->nDimension; + op_dims[ 2 ] = opc->nDimension; + + THLongStorage *sizes = THLongStorage_new(); + int ret = THLongStorage_inferSizeN(sizes, + op_n, + op_sizes, + op_dims, + raiseErrors); + + if(ret != 0) { + return ret; + } + + ret = THCTensor_(expand)(state, ra, opa, sizes, raiseErrors); + THAssert(ret == 0); // since we inferred this already, it must be valid + ret = THCTensor_(expand)(state, rb, opb, sizes, raiseErrors); + THAssert(ret == 0); // since we inferred this already, it must be valid + ret = THCTensor_(expand)(state, rc, opc, sizes, raiseErrors); THAssert(ret == 0); // since we inferred this already, it must be valid - THCTensor_(setStorageNd)(state, - rb, - THCTensor_(storage)(state, opb), - THCTensor_(storageOffset)(state, opb), - THLongStorage_size(sizes), - expandedSizes, - expandedStrides); - THFree(expandedSizes); - THFree(expandedStrides); + THLongStorage_free(sizes); return 0; } |