Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/torch/torch7.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorGregory Chanan <gchanan@fb.com>2017-05-30 21:10:16 +0300
committerSoumith Chintala <soumith@gmail.com>2017-06-11 11:32:08 +0300
commitf622e9ae0e01c89c9ff01ab79807fb49de4c31e0 (patch)
treeeef03540869a4a923d69773b4a599c8fcea046a0
parent96b4c256818398ce8275ccb1a8f584dda6b504ec (diff)
Add broadcasting support for copy_, simplify code generation by moving a lot of currently generated code to expand_utils.
-rw-r--r--lib/TH/THSize.c15
-rw-r--r--lib/TH/THSize.h5
-rw-r--r--lib/TH/generic/THTensor.c2
3 files changed, 21 insertions, 1 deletions
diff --git a/lib/TH/THSize.c b/lib/TH/THSize.c
index 7d680f3..ccf1f61 100644
--- a/lib/TH/THSize.c
+++ b/lib/TH/THSize.c
@@ -1,3 +1,5 @@
+#include "THSize.h"
+
int THSize_isSameSizeAs(const long *sizeA, long dimsA, const long *sizeB, long dimsB) {
int d;
if (dimsA != dimsB)
@@ -9,3 +11,16 @@ int THSize_isSameSizeAs(const long *sizeA, long dimsA, const long *sizeB, long d
}
return 1;
}
+
+ptrdiff_t THSize_nElement(long dims, long *size) {
+ if(dims == 0)
+ return 0;
+ else
+ {
+ ptrdiff_t nElement = 1;
+ int d;
+ for(d = 0; d < dims; d++)
+ nElement *= size[d];
+ return nElement;
+ }
+}
diff --git a/lib/TH/THSize.h b/lib/TH/THSize.h
index e582977..3d39696 100644
--- a/lib/TH/THSize.h
+++ b/lib/TH/THSize.h
@@ -2,7 +2,12 @@
#define TH_SIZE_INC
#include "THGeneral.h"
+#include <stddef.h>
+
+// THTensor functions that would work on a THSize if we had such a class in C++,
+// i.e. THTensor functions that depend only on the shape of the tensor, not the type.
TH_API int THSize_isSameSizeAs(const long *sizeA, long dimsA, const long *sizeB, long dimsB);
+TH_API ptrdiff_t THSize_nElement(long dims, long *size);
#endif
diff --git a/lib/TH/generic/THTensor.c b/lib/TH/generic/THTensor.c
index 2331cdd..61363ac 100644
--- a/lib/TH/generic/THTensor.c
+++ b/lib/TH/generic/THTensor.c
@@ -292,8 +292,8 @@ THTensor* THTensor_(newExpand)(THTensor *tensor, THLongStorage *sizes, int raise
}
int THTensor_(expand)(THTensor *r, THTensor *tensor, THLongStorage *sizes, int raiseErrors) {
- THArgCheck(THTensor_(nDimension)(tensor) > 0, 0, "can't expand an empty tensor");
if (raiseErrors) {
+ THArgCheck(THTensor_(nDimension)(tensor) > 0, 0, "can't expand an empty tensor");
THArgCheck(THLongStorage_size(sizes) >= THTensor_(nDimension)(tensor), 1,
"the number of sizes provided must be greater or equal to the "
"number of dimensions in the tensor");