Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/torch/torch7.git - Unnamed repository; edit this file 'description' to name the repository.
summary | refs | log | tree | commit | diff
path: root/lib
diff options
context:
space:
mode:
author: Adam Lerer <alerer@fb.com> 2017-05-21 01:45:25 +0300
committer: Soumith Chintala <soumith@gmail.com> 2017-05-25 22:39:07 +0300
commit: ed5aa2dceedf2f75c90cde637befe2e0a60e367d (patch)
tree: 8142dd1049813e56c6b09d8db5861a686c10343d /lib
parent: ae1a805b370b25fb99703e3c3eccfb7dee97ccd7 (diff)
Fast transposed copy
Diffstat (limited to 'lib')
-rw-r--r-- lib/TH/generic/THTensorCopy.c | 70
1 files changed, 69 insertions, 1 deletions
diff --git a/lib/TH/generic/THTensorCopy.c b/lib/TH/generic/THTensorCopy.c
index e909728..71ccfdd 100644
--- a/lib/TH/generic/THTensorCopy.c
+++ b/lib/TH/generic/THTensorCopy.c
@@ -2,6 +2,70 @@
#define TH_GENERIC_FILE "generic/THTensorCopy.c"
#else
+/*
+ * Decides whether copying src into tensor may use the blocked transpose
+ * routine THTensor_(copyTranspose).
+ *
+ * Returns nonzero only when all of the following hold:
+ *   - tensor is contiguous;
+ *   - src is 2-dimensional with stride 1 along dimension 0 and stride
+ *     size(src, 0) along dimension 1, i.e. the strides of a transposed
+ *     contiguous matrix;
+ *   - tensor and src hold the same number of elements;
+ *   - tensor holds at least MIN_SZ (60 * 60) elements.
+ *
+ * The element-count comparison is required for memory safety: copyTranspose
+ * writes size(src, 0) * size(src, 1) elements into tensor's data with plain
+ * memcpy calls and performs no bounds checking of its own, so a destination
+ * smaller than src would be written past its end.
+ */
+int THTensor_(copyTransposeValid)(THTensor *tensor, THTensor *src) {
+ const int MIN_SZ = 60 * 60;
+ return THTensor_(isContiguous)(tensor) &&
+ THTensor_(nDimension)(src) == 2 &&
+ THTensor_(stride)(src, 0) == 1 &&
+ THTensor_(stride)(src, 1) == THTensor_(size)(src, 0) &&
+ THTensor_(nElement)(tensor) == THTensor_(nElement)(src) &&
+ THTensor_(nElement)(tensor) >= MIN_SZ;
+}
+
+// Special-case copy where tensor is contiguous and src is a transposed matrix.
+/* This can be generalized to most copies, but it's trickier.
+ *
+ * src is processed in tiles of at most BLOCK_SZ x BLOCK_SZ elements. Each
+ * tile is gathered into a scratch tensor, transposed there in place, and
+ * then written to tensor one row at a time, so every read from src and every
+ * write to tensor is a single contiguous memcpy run.
+ *
+ * The index arithmetic addresses src with stride 1 along dimension 0 and
+ * stride NR (= size(src, 0)) along dimension 1, and addresses tensor as a
+ * dense row-major NR x NC matrix. Nothing in this function checks that
+ * layout, or that tensor has room for NR * NC elements.
+ * NOTE(review): callers are presumably expected to gate on
+ * THTensor_(copyTransposeValid) and on matching element counts -- confirm at
+ * every call site.
+ */
+void THTensor_(copyTranspose)(THTensor *tensor, THTensor *src) {
+ #define MIN(x, y) (((x) < (y)) ? (x) : (y))
+ #define MAX(x, y) (((x) > (y)) ? (x) : (y))
+ // MIN/MAX are helpers for this function only; both are #undef'd again just before the closing brace.
+#ifdef TH_REAL_IS_BYTE
+ const int BLOCK_SZ = 120;
+#else
+ const int BLOCK_SZ = 60;
+#endif
+ // BLOCK_SZ is the tile edge length in elements: 120 when TH_REAL_IS_BYTE is defined, 60 otherwise.
+ THTensor *buf = THTensor_(newWithSize2d)(BLOCK_SZ, BLOCK_SZ); // scratch tile, released at the end of the function
+ real *sp = THTensor_(data)(src);
+ real *rp = THTensor_(data)(tensor);
+ real *bp = THTensor_(data)(buf);
+ // NR / NC: extent of src along dimension 0 (rows) and dimension 1 (columns).
+ long NR = THTensor_(size)(src, 0);
+ long NC = THTensor_(size)(src, 1);
+ for (long R = 0; R < NR; R += BLOCK_SZ) {
+ for (long C = 0; C < NC; C += BLOCK_SZ) {
+ real *spo = sp + R + C * NR; // element (R, C) of src: first element of this tile
+ real *rpo = rp + C + R * NC; // element (R, C) of the destination, laid out row-major
+ // nr x nc is the size of this tile; tiles on the last row/column of the grid are clipped to what is left.
+ int nr = MIN(NR - R, BLOCK_SZ);
+ int nc = MIN(NC - C, BLOCK_SZ);
+
+ // 1. copy columns from src to buf: the nr elements of tile column c are contiguous in src and land in row c of buf
+ for (int c = 0; c < nc; c++) {
+ memcpy(bp + c * BLOCK_SZ, spo + c * NR, nr * sizeof(real));
+ }
+
+ // 2. transpose buf in place: swap buf(c, r) with buf(r, c) for every c < r, c < min(nr, nc), r < max(nr, nc). For a non-square tile this also swaps slots step 1 did not write this iteration; only slots filled by step 1 end up in the nr x nc region read by step 3.
+ int rc_max = MAX(nr, nc);
+ int rc_min = MIN(nr, nc);
+ for (int r = 0; r < rc_max; r++) {
+ int end = MIN(r, rc_min); // stay below the diagonal and inside the shorter tile dimension
+ for (int c = 0; c < end; c++) {
+ real tmp = bp[r + BLOCK_SZ * c];
+ bp[r + BLOCK_SZ * c] = bp[r * BLOCK_SZ + c];
+ bp[r * BLOCK_SZ + c] = tmp;
+ }
+ }
+
+ // 3. copy rows from buf to dst: row r of buf now holds the nc elements of destination row R + r, starting at column C
+ for (int r = 0; r < nr; r++) {
+ memcpy(rpo + r * NC, bp + r * BLOCK_SZ, nc * sizeof(real));
+ }
+ }
+ }
+ THTensor_(free)(buf);
+ #undef MIN
+ #undef MAX
+}
+
void THTensor_(copy)(THTensor *tensor, THTensor *src)
{
if (THTensor_(isContiguous)(tensor) && THTensor_(isContiguous)(src) && THTensor_(nElement)(tensor) == THTensor_(nElement)(src)) {
@@ -9,10 +73,14 @@ void THTensor_(copy)(THTensor *tensor, THTensor *src)
real *rp = THTensor_(data)(tensor);
ptrdiff_t sz = THTensor_(nElement)(tensor);
#ifndef TH_REAL_IS_HALF
- THVector_(copy)(rp, sp, sz);
+ THVector_(copy)(rp, sp, sz);
#else
memcpy(rp, sp, sz * sizeof(real));
#endif
+#ifndef TH_REAL_IS_HALF
+ } else if (THTensor_(copyTransposeValid)(tensor, src)) {
+ THTensor_(copyTranspose)(tensor, src);
+#endif
} else {
TH_TENSOR_APPLY2(real, tensor, real, src, *tensor_data = *src_data;)
}