diff options
author | Adam Lerer <alerer@fb.com> | 2017-05-21 01:45:25 +0300 |
---|---|---|
committer | Soumith Chintala <soumith@gmail.com> | 2017-05-25 22:39:07 +0300 |
commit | ed5aa2dceedf2f75c90cde637befe2e0a60e367d (patch) | |
tree | 8142dd1049813e56c6b09d8db5861a686c10343d /lib | |
parent | ae1a805b370b25fb99703e3c3eccfb7dee97ccd7 (diff) |
Fast transposed copy
Diffstat (limited to 'lib')
-rw-r--r-- | lib/TH/generic/THTensorCopy.c | 70 |
1 files changed, 69 insertions, 1 deletions
diff --git a/lib/TH/generic/THTensorCopy.c b/lib/TH/generic/THTensorCopy.c index e909728..71ccfdd 100644 --- a/lib/TH/generic/THTensorCopy.c +++ b/lib/TH/generic/THTensorCopy.c @@ -2,6 +2,70 @@ #define TH_GENERIC_FILE "generic/THTensorCopy.c" #else +int THTensor_(copyTransposeValid)(THTensor *tensor, THTensor *src) { + const int MIN_SZ = 60 * 60; + return THTensor_(isContiguous)(tensor) && + THTensor_(nDimension)(src) == 2 && + THTensor_(stride)(src, 0) == 1 && + THTensor_(stride)(src, 1) == THTensor_(size)(src, 0) && + THTensor_(nElement)(tensor) >= MIN_SZ; +} + +// special case copy where tensor is contiguous and src is a transposed matrix +// This can be generalized to most copies, but it's tricker +void THTensor_(copyTranspose)(THTensor *tensor, THTensor *src) { + #define MIN(x, y) (((x) < (y)) ? (x) : (y)) + #define MAX(x, y) (((x) > (y)) ? (x) : (y)) + +#ifdef TH_REAL_IS_BYTE + const int BLOCK_SZ = 120; +#else + const int BLOCK_SZ = 60; +#endif + + THTensor *buf = THTensor_(newWithSize2d)(BLOCK_SZ, BLOCK_SZ); + real *sp = THTensor_(data)(src); + real *rp = THTensor_(data)(tensor); + real *bp = THTensor_(data)(buf); + + long NR = THTensor_(size)(src, 0); + long NC = THTensor_(size)(src, 1); + for (long R = 0; R < NR; R += BLOCK_SZ) { + for (long C = 0; C < NC; C += BLOCK_SZ) { + real *spo = sp + R + C * NR; + real *rpo = rp + C + R * NC; + + int nr = MIN(NR - R, BLOCK_SZ); + int nc = MIN(NC - C, BLOCK_SZ); + + // 1. copy columns from src to buf + for (int c = 0; c < nc; c++) { + memcpy(bp + c * BLOCK_SZ, spo + c * NR, nr * sizeof(real)); + } + + // 2. transpose buf in place + int rc_max = MAX(nr, nc); + int rc_min = MIN(nr, nc); + for (int r = 0; r < rc_max; r++) { + int end = MIN(r, rc_min); + for (int c = 0; c < end; c++) { + real tmp = bp[r + BLOCK_SZ * c]; + bp[r + BLOCK_SZ * c] = bp[r * BLOCK_SZ + c]; + bp[r * BLOCK_SZ + c] = tmp; + } + } + + // 3. copy rows from buf to dst + for (int r = 0; r < nr; r++) { + memcpy(rpo + r * NC, bp + r * BLOCK_SZ, nc * sizeof(real)); + } + } + } + THTensor_(free)(buf); + #undef MIN + #undef MAX +} + void THTensor_(copy)(THTensor *tensor, THTensor *src) { if (THTensor_(isContiguous)(tensor) && THTensor_(isContiguous)(src) && THTensor_(nElement)(tensor) == THTensor_(nElement)(src)) { @@ -9,10 +73,14 @@ void THTensor_(copy)(THTensor *tensor, THTensor *src) real *rp = THTensor_(data)(tensor); ptrdiff_t sz = THTensor_(nElement)(tensor); #ifndef TH_REAL_IS_HALF - THVector_(copy)(rp, sp, sz); + THVector_(copy)(rp, sp, sz); #else memcpy(rp, sp, sz * sizeof(real)); #endif +#ifndef TH_REAL_IS_HALF + } else if (THTensor_(copyTransposeValid)(tensor, src)) { + THTensor_(copyTranspose)(tensor, src); +#endif } else { TH_TENSOR_APPLY2(real, tensor, real, src, *tensor_data = *src_data;) } |