Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/torch/torch7.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorGregory Chanan <gchanan@fb.com>2017-05-10 21:33:24 +0300
committerSoumith Chintala <soumith@gmail.com>2017-06-11 11:32:08 +0300
commit26629b188ad6c7c9a175f9f6115125948a765eb0 (patch)
tree3260944a93d32c4fb7ccf9c363d52e85bff06e29
parentd34673d67bc8cefd7055edf55a2c88d28b3ab158 (diff)
Support "fused" ops: addcmul/addcdiv.
-rw-r--r--lib/TH/THStorage.c7
-rw-r--r--lib/TH/generic/THTensor.c69
-rw-r--r--lib/TH/generic/THTensor.h1
3 files changed, 47 insertions, 30 deletions
diff --git a/lib/TH/THStorage.c b/lib/TH/THStorage.c
index b7d4a5d..54ecc84 100644
--- a/lib/TH/THStorage.c
+++ b/lib/TH/THStorage.c
@@ -106,14 +106,14 @@ TH_API int THLongStorage_inferSize2(THLongStorage *output, long *sizesA, long di
TH_API int THLongStorage_inferSizeN(THLongStorage *output, int n, long **sizes, long *dims, int raiseErrors) {
THArgCheck(n > 0, 2, "n must be greater than 0");
- THArgCheck(sizes != NULL, 1, "sizesA must not be null");
+ THArgCheck(sizes != NULL, 1, "sizes must not be null");
THArgCheck(dims != NULL, 1, "dims must not be null");
ptrdiff_t ndim = 0;
for (int j = 0; j < n; ++j) {
THArgCheck(sizes[ j ] != NULL, 1, "size %d must not be null", j);
THArgCheck(dims[ j ], 1, "Can't expand empty tensor %d", j);
- ptrdiff_t ndim = dims[ j ] > ndim ? dims[ j ] : ndim;
+ ndim = dims[ j ] > ndim ? dims[ j ] : ndim;
}
long *expandedSizes = THAlloc(sizeof(long)*ndim);
@@ -123,7 +123,7 @@ TH_API int THLongStorage_inferSizeN(THLongStorage *output, int n, long **sizes,
long offset = ndim - 1 - i;
for (int j = 0; j < n; ++j) {
long dim = dims[ j ] - 1 - offset;
- long size = (dim >= 0) ? sizes[ i ][ dim ] : 1;
+ long size = (dim >= 0) ? sizes[ j ][ dim ] : 1;
if (size != max_dim_size) {
if (max_dim_size == 1){
max_dim_size = size;
@@ -139,6 +139,7 @@ TH_API int THLongStorage_inferSizeN(THLongStorage *output, int n, long **sizes,
}
}
}
+
expandedSizes[ i ] = max_dim_size;
}
THLongStorage_resize(output, ndim);
diff --git a/lib/TH/generic/THTensor.c b/lib/TH/generic/THTensor.c
index 4b5c302..38bbffd 100644
--- a/lib/TH/generic/THTensor.c
+++ b/lib/TH/generic/THTensor.c
@@ -292,12 +292,14 @@ THTensor* THTensor_(newExpand)(THTensor *tensor, THLongStorage *sizes, int raise
}
int THTensor_(expand)(THTensor *r, THTensor *tensor, THLongStorage *sizes, int raiseErrors) {
+ THArgCheck(THTensor_(nDimension)(tensor) > 0, 0, "can't expand an empty tensor");
if (raiseErrors) {
THArgCheck(THLongStorage_size(sizes) >= THTensor_(nDimension)(tensor), 1,
"the number of sizes provided must be greater or equal to the "
"number of dimensions in the tensor");
+ } else if (THLongStorage_size(sizes) < THTensor_(nDimension)(tensor)) {
+ return -1;
}
- THArgCheck(THTensor_(nDimension)(tensor) > 0, 0, "can't expand an empty tensor");
long *expandedSizes;
long *expandedStrides;
@@ -326,41 +328,54 @@ int THTensor_(expand2)(THTensor *ra, THTensor *rb, THTensor *opa, THTensor *opb,
return ret;
}
- long *expandedSizes;
- long *expandedStrides;
- ret = THLongStorage_inferExpandGeometry(opa->size, opa->stride,
- THTensor_(nDimension)(opa), sizes,
- &expandedSizes, &expandedStrides,
- raiseErrors);
+ ret = THTensor_(expand)(ra, opa, sizes, raiseErrors);
+ THAssert(ret == 0); // since we inferred this already, it must be valid
+ ret = THTensor_(expand)(rb, opb, sizes, raiseErrors);
THAssert(ret == 0); // since we inferred this already, it must be valid
- THTensor_(setStorageNd)(ra,
- THTensor_(storage)(opa),
- THTensor_(storageOffset)(opa),
- THLongStorage_size(sizes),
- expandedSizes,
- expandedStrides);
- THFree(expandedSizes);
- THFree(expandedStrides);
+ THLongStorage_free(sizes);
+ return 0;
+}
+
+TH_API int THTensor_(expand3)(THTensor *ra, THTensor *rb, THTensor *rc, THTensor *opa, THTensor *opb, THTensor *opc, int raiseErrors) {
+ THArgCheck(THTensor_(nDimension)(opa) > 0, 0, "can't expand empty tensor opa");
+ THArgCheck(THTensor_(nDimension)(opb) > 0, 0, "can't expand empty tensor opb");
+ THArgCheck(THTensor_(nDimension)(opc) > 0, 0, "can't expand empty tensor opc");
+
+ const int op_n = 3;
+ long **op_sizes = THAlloc(sizeof(long**)*op_n);
+ long *op_dims = THAlloc(sizeof(long*)*op_n);
- ret = THLongStorage_inferExpandGeometry(opb->size, opb->stride,
- THTensor_(nDimension)(opb), sizes,
- &expandedSizes, &expandedStrides,
- raiseErrors);
+ op_sizes[ 0 ] = opa->size;
+ op_sizes[ 1 ] = opb->size;
+ op_sizes[ 2 ] = opc->size;
+ op_dims[ 0 ] = opa->nDimension;
+ op_dims[ 1 ] = opb->nDimension;
+ op_dims[ 2 ] = opc->nDimension;
+
+ THLongStorage *sizes = THLongStorage_new();
+ int ret = THLongStorage_inferSizeN(sizes,
+ op_n,
+ op_sizes,
+ op_dims,
+ raiseErrors);
+
+ if(ret != 0) {
+ return ret;
+ }
+
+ ret = THTensor_(expand)(ra, opa, sizes, raiseErrors);
+ THAssert(ret == 0); // since we inferred this already, it must be valid
+ ret = THTensor_(expand)(rb, opb, sizes, raiseErrors);
+ THAssert(ret == 0); // since we inferred this already, it must be valid
+ ret = THTensor_(expand)(rc, opc, sizes, raiseErrors);
THAssert(ret == 0); // since we inferred this already, it must be valid
- THTensor_(setStorageNd)(rb,
- THTensor_(storage)(opb),
- THTensor_(storageOffset)(opb),
- THLongStorage_size(sizes),
- expandedSizes,
- expandedStrides);
- THFree(expandedSizes);
- THFree(expandedStrides);
THLongStorage_free(sizes);
return 0;
}
+
void THTensor_(set)(THTensor *self, THTensor *src)
{
if(self != src)
diff --git a/lib/TH/generic/THTensor.h b/lib/TH/generic/THTensor.h
index bafe00a..bc33ed5 100644
--- a/lib/TH/generic/THTensor.h
+++ b/lib/TH/generic/THTensor.h
@@ -73,6 +73,7 @@ TH_API THTensor *THTensor_(newExpand)(THTensor *tensor, THLongStorage *size, int
TH_API int THTensor_(expand)(THTensor *r, THTensor *tensor, THLongStorage *size, int raiseErrors);
TH_API int THTensor_(expand2)(THTensor *ra, THTensor *rb, THTensor *opa, THTensor *opb, int raiseErrors);
+TH_API int THTensor_(expand3)(THTensor *ra, THTensor *rb, THTensor *rc, THTensor *opa, THTensor *opb, THTensor *opc, int raiseErrors);
TH_API void THTensor_(resize)(THTensor *tensor, THLongStorage *size, THLongStorage *stride);
TH_API void THTensor_(resizeAs)(THTensor *tensor, THTensor *src);