diff options
author | Gregory Chanan <gchanan@fb.com> | 2017-05-30 21:10:16 +0300 |
---|---|---|
committer | Soumith Chintala <soumith@gmail.com> | 2017-06-11 11:32:08 +0300 |
commit | f622e9ae0e01c89c9ff01ab79807fb49de4c31e0 (patch) | |
tree | eef03540869a4a923d69773b4a599c8fcea046a0 | |
parent | 96b4c256818398ce8275ccb1a8f584dda6b504ec (diff) |
Add broadcasting support for copy_, simplify code generation by moving a lot of currently generated code to expand_utils.
-rw-r--r-- | lib/TH/THSize.c | 15 | ||||
-rw-r--r-- | lib/TH/THSize.h | 5 | ||||
-rw-r--r-- | lib/TH/generic/THTensor.c | 2 |
3 files changed, 21 insertions, 1 deletions
diff --git a/lib/TH/THSize.c b/lib/TH/THSize.c index 7d680f3..ccf1f61 100644 --- a/lib/TH/THSize.c +++ b/lib/TH/THSize.c @@ -1,3 +1,5 @@ +#include "THSize.h" + int THSize_isSameSizeAs(const long *sizeA, long dimsA, const long *sizeB, long dimsB) { int d; if (dimsA != dimsB) @@ -9,3 +11,16 @@ int THSize_isSameSizeAs(const long *sizeA, long dimsA, const long *sizeB, long d } return 1; } + +ptrdiff_t THSize_nElement(long dims, long *size) { + if(dims == 0) + return 0; + else + { + ptrdiff_t nElement = 1; + int d; + for(d = 0; d < dims; d++) + nElement *= size[d]; + return nElement; + } +} diff --git a/lib/TH/THSize.h b/lib/TH/THSize.h index e582977..3d39696 100644 --- a/lib/TH/THSize.h +++ b/lib/TH/THSize.h @@ -2,7 +2,12 @@ #define TH_SIZE_INC #include "THGeneral.h" +#include <stddef.h> + +// THTensor functions that would work on a THSize if we had such a class in C++, +// i.e. THTensor functions that depend only on the shape of the tensor, not the type. TH_API int THSize_isSameSizeAs(const long *sizeA, long dimsA, const long *sizeB, long dimsB); +TH_API ptrdiff_t THSize_nElement(long dims, long *size); #endif diff --git a/lib/TH/generic/THTensor.c b/lib/TH/generic/THTensor.c index 2331cdd..61363ac 100644 --- a/lib/TH/generic/THTensor.c +++ b/lib/TH/generic/THTensor.c @@ -292,8 +292,8 @@ THTensor* THTensor_(newExpand)(THTensor *tensor, THLongStorage *sizes, int raise } int THTensor_(expand)(THTensor *r, THTensor *tensor, THLongStorage *sizes, int raiseErrors) { - THArgCheck(THTensor_(nDimension)(tensor) > 0, 0, "can't expand an empty tensor"); if (raiseErrors) { + THArgCheck(THTensor_(nDimension)(tensor) > 0, 0, "can't expand an empty tensor"); THArgCheck(THLongStorage_size(sizes) >= THTensor_(nDimension)(tensor), 1, "the number of sizes provided must be greater or equal to the " "number of dimensions in the tensor"); |