Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/torch/cutorch.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorGregory Chanan <gchanan@fb.com>2017-05-10 21:33:24 +0300
committerSoumith Chintala <soumith@gmail.com>2017-06-11 11:33:54 +0300
commit99dbaecbe142c4eaa133b9d476ed1816d55b91be (patch)
treee5f114974963cb64f89ec6714288a08bccff1706
parent4817e53af421d4921effd9ab11ccd24c44c5eeca (diff)
Support "fused" ops: addcmul/addcdiv.
-rw-r--r--lib/THC/generic/THCTensor.c78
-rw-r--r--lib/THC/generic/THCTensor.h1
2 files changed, 43 insertions, 36 deletions
diff --git a/lib/THC/generic/THCTensor.c b/lib/THC/generic/THCTensor.c
index ccd6277..57ff028 100644
--- a/lib/THC/generic/THCTensor.c
+++ b/lib/THC/generic/THCTensor.c
@@ -303,12 +303,14 @@ THCTensor* THCTensor_(newExpand)(THCState *state, THCTensor *tensor, THLongStora
int THCTensor_(expand)(THCState *state, THCTensor *r, THCTensor *tensor, THLongStorage *sizes, int raiseErrors)
{
+ THArgCheck(THCTensor_(nDimension)(state, tensor) > 0, 0, "can't expand an empty tensor");
if (raiseErrors) {
THArgCheck(THLongStorage_size(sizes) >= THCTensor_(nDimension)(state, tensor), 1,
"the number of sizes provided must be greater or equal to the "
"number of dimensions in the tensor");
+ } else if (THLongStorage_size(sizes) < THCTensor_(nDimension)(state, tensor)) {
+ return -1;
}
- THArgCheck(THCTensor_(nDimension)(state, tensor) > 0, 0, "can't expand an empty tensor");
long *expandedSizes;
long *expandedStrides;
@@ -343,46 +345,50 @@ int THCTensor_(expand2)(THCState *state, THCTensor *ra, THCTensor *rb, THCTensor
if(ret != 0) {
return ret;
}
-
- long *expandedSizes;
- long *expandedStrides;
- ret = THLongStorage_inferExpandGeometry(opa->size,
- opa->stride,
- THCTensor_(nDimension)(state, opa),
- sizes,
- &expandedSizes,
- &expandedStrides,
- raiseErrors);
+ ret = THCTensor_(expand)(state, ra, opa, sizes, raiseErrors);
+ THAssert(ret == 0); // since we inferred this already, it must be valid
+ ret = THCTensor_(expand)(state, rb, opb, sizes, raiseErrors);
THAssert(ret == 0); // since we inferred this already, it must be valid
- THCTensor_(setStorageNd)(state,
- ra,
- THCTensor_(storage)(state, opa),
- THCTensor_(storageOffset)(state, opa),
- THLongStorage_size(sizes),
- expandedSizes,
- expandedStrides);
- THFree(expandedSizes);
- THFree(expandedStrides);
+ THLongStorage_free(sizes);
+ return 0;
+}
+
+THC_API int THCTensor_(expand3)(THCState *state, THCTensor *ra, THCTensor *rb, THCTensor *rc, THCTensor *opa, THCTensor *opb, THCTensor *opc, int raiseErrors) {
+ THArgCheck(THCTensor_(nDimension)(state, opa) > 0, 0, "can't expand empty tensor opa");
+ THArgCheck(THCTensor_(nDimension)(state, opb) > 0, 0, "can't expand empty tensor opb");
+ THArgCheck(THCTensor_(nDimension)(state, opc) > 0, 0, "can't expand empty tensor opc");
- ret = THLongStorage_inferExpandGeometry(opb->size,
- opb->stride,
- THCTensor_(nDimension)(state, opb),
- sizes,
- &expandedSizes,
- &expandedStrides,
- raiseErrors);
+ const int op_n = 3;
+ long **op_sizes = THAlloc(sizeof(long**)*op_n);
+ long *op_dims = THAlloc(sizeof(long*)*op_n);
+
+ op_sizes[ 0 ] = opa->size;
+ op_sizes[ 1 ] = opb->size;
+ op_sizes[ 2 ] = opc->size;
+ op_dims[ 0 ] = opa->nDimension;
+ op_dims[ 1 ] = opb->nDimension;
+ op_dims[ 2 ] = opc->nDimension;
+
+ THLongStorage *sizes = THLongStorage_new();
+ int ret = THLongStorage_inferSizeN(sizes,
+ op_n,
+ op_sizes,
+ op_dims,
+ raiseErrors);
+
+ if(ret != 0) {
+ return ret;
+ }
+
+ ret = THCTensor_(expand)(state, ra, opa, sizes, raiseErrors);
+ THAssert(ret == 0); // since we inferred this already, it must be valid
+ ret = THCTensor_(expand)(state, rb, opb, sizes, raiseErrors);
+ THAssert(ret == 0); // since we inferred this already, it must be valid
+ ret = THCTensor_(expand)(state, rc, opc, sizes, raiseErrors);
THAssert(ret == 0); // since we inferred this already, it must be valid
- THCTensor_(setStorageNd)(state,
- rb,
- THCTensor_(storage)(state, opb),
- THCTensor_(storageOffset)(state, opb),
- THLongStorage_size(sizes),
- expandedSizes,
- expandedStrides);
- THFree(expandedSizes);
- THFree(expandedStrides);
+ THLongStorage_free(sizes);
return 0;
}
diff --git a/lib/THC/generic/THCTensor.h b/lib/THC/generic/THCTensor.h
index c1ce52d..a19646e 100644
--- a/lib/THC/generic/THCTensor.h
+++ b/lib/THC/generic/THCTensor.h
@@ -71,6 +71,7 @@ THC_API THCTensor *THCTensor_(newExpand)(THCState *state, THCTensor *tensor, THL
THC_API int THCTensor_(expand)(THCState *state, THCTensor *r, THCTensor *tensor, THLongStorage *sizes, int raiseErrors);
THC_API int THCTensor_(expand2)(THCState *state, THCTensor *ra, THCTensor *rb, THCTensor *opa, THCTensor *opb, int raiseErrors);
+THC_API int THCTensor_(expand3)(THCState *state, THCTensor *ra, THCTensor *rb, THCTensor *rc, THCTensor *opa, THCTensor *opb, THCTensor *opc, int raiseErrors);
THC_API void THCTensor_(resize)(THCState *state, THCTensor *tensor, THLongStorage *size, THLongStorage *stride);
THC_API void THCTensor_(resizeAs)(THCState *state, THCTensor *tensor, THCTensor *src);