diff options
author | Conrado Miranda <cmiranda@twitter.com> | 2016-12-17 02:16:13 +0300 |
---|---|---|
committer | Pavan Yalamanchili <pyalamanchili@twitter.com> | 2017-02-10 00:28:36 +0300 |
commit | 90ba83121de74216fc41356ddfe57e3baa43ed5e (patch) | |
tree | c8033d604b3a0e98dfdc269daa4303f067a322bc | |
parent | 1a3380f09e79e783c069df33d8ec1aab1e8bad51 (diff) |
Added shift operations.
-rw-r--r-- | TensorMath.lua | 12 | ||||
-rw-r--r-- | lib/TH/generic/THTensorMath.c | 68 | ||||
-rw-r--r-- | lib/TH/generic/THTensorMath.h | 2 | ||||
-rw-r--r-- | test/test.lua | 68 |
4 files changed, 150 insertions, 0 deletions
diff --git a/TensorMath.lua b/TensorMath.lua index d816740..048250a 100644 --- a/TensorMath.lua +++ b/TensorMath.lua @@ -311,6 +311,18 @@ for _,Tensor in ipairs({"ByteTensor", "CharTensor", {name=Tensor, method={default=1}}, {name=real}}) + wrap("lsh", + cname("lsh"), + {{name=Tensor, default=true, returned=true, method={default='nil'}}, + {name=Tensor, method={default=1}}, + {name=real}}) + + wrap("rsh", + cname("rsh"), + {{name=Tensor, default=true, returned=true, method={default='nil'}}, + {name=Tensor, method={default=1}}, + {name=real}}) + wrap("fmod", cname("fmod"), {{name=Tensor, default=true, returned=true, method={default='nil'}}, diff --git a/lib/TH/generic/THTensorMath.c b/lib/TH/generic/THTensorMath.c index 59d6239..4f7f5f9 100644 --- a/lib/TH/generic/THTensorMath.c +++ b/lib/TH/generic/THTensorMath.c @@ -513,6 +513,74 @@ void THTensor_(div)(THTensor *r_, THTensor *t, real value) } } +void THTensor_(lsh)(THTensor *r_, THTensor *t, real value) +{ +#if defined(TH_REAL_IS_FLOAT) + const real temp = powf(2, value); +#elif defined(TH_REAL_IS_DOUBLE) + const real temp = pow(2, value); +#endif + THTensor_(resizeAs)(r_, t); + if (THTensor_(isContiguous)(r_) && THTensor_(isContiguous)(t) && THTensor_(nElement)(r_) == THTensor_(nElement)(t)) { + real *tp = THTensor_(data)(t); + real *rp = THTensor_(data)(r_); + long sz = THTensor_(nElement)(t); + long i; + #pragma omp parallel for if(sz > TH_OMP_OVERHEAD_THRESHOLD * 100) private(i) + for (i=0; i<sz; i++) { +#if defined(TH_REAL_IS_FLOAT) + rp[i] = tp[i] * temp; +#elif defined(TH_REAL_IS_DOUBLE) + rp[i] = tp[i] * temp; +#else + rp[i] = ((unsigned real) tp[i]) << value; +#endif + } + } else { +#if defined(TH_REAL_IS_FLOAT) + TH_TENSOR_APPLY2(real, r_, real, t, *r__data = (*t_data * temp);); +#elif defined(TH_REAL_IS_DOUBLE) + TH_TENSOR_APPLY2(real, r_, real, t, *r__data = (*t_data * temp);); +#else + TH_TENSOR_APPLY2(real, r_, real, t, *r__data = (((unsigned real) *t_data) << value);); +#endif + } +} + +void 
THTensor_(rsh)(THTensor *r_, THTensor *t, real value) +{ +#if defined(TH_REAL_IS_FLOAT) + const real temp = powf(2, value); +#elif defined(TH_REAL_IS_DOUBLE) + const real temp = pow(2, value); +#endif + THTensor_(resizeAs)(r_, t); + if (THTensor_(isContiguous)(r_) && THTensor_(isContiguous)(t) && THTensor_(nElement)(r_) == THTensor_(nElement)(t)) { + real *tp = THTensor_(data)(t); + real *rp = THTensor_(data)(r_); + long sz = THTensor_(nElement)(t); + long i; + #pragma omp parallel for if(sz > TH_OMP_OVERHEAD_THRESHOLD * 100) private(i) + for (i=0; i<sz; i++) { +#if defined(TH_REAL_IS_FLOAT) + rp[i] = tp[i] / temp; +#elif defined(TH_REAL_IS_DOUBLE) + rp[i] = tp[i] / temp; +#else + rp[i] = ((unsigned real) tp[i]) >> value; +#endif + } + } else { +#if defined(TH_REAL_IS_FLOAT) + TH_TENSOR_APPLY2(real, r_, real, t, *r__data = (*t_data / temp);); +#elif defined(TH_REAL_IS_DOUBLE) + TH_TENSOR_APPLY2(real, r_, real, t, *r__data = (*t_data / temp);); +#else + TH_TENSOR_APPLY2(real, r_, real, t, *r__data = (((unsigned real) *t_data) >> value);); +#endif + } +} + void THTensor_(fmod)(THTensor *r_, THTensor *t, real value) { THTensor_(resizeAs)(r_, t); diff --git a/lib/TH/generic/THTensorMath.h b/lib/TH/generic/THTensorMath.h index c656dfd..294fb32 100644 --- a/lib/TH/generic/THTensorMath.h +++ b/lib/TH/generic/THTensorMath.h @@ -34,6 +34,8 @@ TH_API void THTensor_(add)(THTensor *r_, THTensor *t, real value); TH_API void THTensor_(sub)(THTensor *self, THTensor *src, real value); TH_API void THTensor_(mul)(THTensor *r_, THTensor *t, real value); TH_API void THTensor_(div)(THTensor *r_, THTensor *t, real value); +TH_API void THTensor_(lsh)(THTensor *r_, THTensor *t, real value); +TH_API void THTensor_(rsh)(THTensor *r_, THTensor *t, real value); TH_API void THTensor_(fmod)(THTensor *r_, THTensor *t, real value); TH_API void THTensor_(remainder)(THTensor *r_, THTensor *t, real value); TH_API void THTensor_(clamp)(THTensor *r_, THTensor *t, real min_value, real max_value); diff 
--git a/test/test.lua b/test/test.lua
index e7e26e4..2fe049c 100644
--- a/test/test.lua
+++ b/test/test.lua
@@ -730,6 +730,74 @@ function torchtest.div()
    mytester:assertlt(err, precision, 'error in torch.div - scalar, non contiguous')
 end
 
+-- lsh on a non-contiguous column view, then on a full contiguous tensor;
+-- expected result is multiplication by 2^q computed element-by-element.
+function torchtest.lsh()
+   local m1 = torch.LongTensor(10,10):random(0,100)
+   local res1 = m1:clone()
+
+   local q = 2
+   res1[{ {},3 }]:lsh(q)
+
+   local res2 = m1:clone()
+   for i = 1,m1:size(1) do
+      res2[{ i,3 }] = res2[{ i,3 }] * 4  -- 4 == 2^q
+   end
+
+   local err = (res1-res2):abs():max()
+
+   mytester:assertlt(err, precision, 'error in torch.lsh - scalar, non contiguous')
+
+   local m1 = torch.LongTensor(10,10):random(0,100)
+   local res1 = m1:clone()
+
+   local q = 2
+   res1:lsh(q)
+
+   local res2 = m1:clone()
+   for i = 1,m1:size(1) do
+      for j = 1,m1:size(2) do  -- iterate columns with size(2), not size(1)
+         res2[{ i,j }] = res2[{ i,j }] * 4  -- 4 == 2^q
+      end
+   end
+
+   local err = (res1-res2):abs():max()
+
+   mytester:assertlt(err, precision, 'error in torch.lsh - scalar, contiguous')
+end
+
+-- rsh mirrors lsh: expected result is floor-division by 2^q.
+function torchtest.rsh()
+   local m1 = torch.LongTensor(10,10):random(0,100)
+   local res1 = m1:clone()
+
+   local q = 2
+   res1[{ {},3 }]:rsh(q)
+
+   local res2 = m1:clone()
+   for i = 1,m1:size(1) do
+      res2[{ i,3 }] = math.floor(res2[{ i,3 }] / 4)  -- 4 == 2^q
+   end
+
+   local err = (res1-res2):abs():max()
+
+   mytester:assertlt(err, precision, 'error in torch.rsh - scalar, non contiguous')
+
+   local m1 = torch.LongTensor(10,10):random(0,100)
+   local res1 = m1:clone()
+
+   local q = 2
+   res1:rsh(q)
+
+   local res2 = m1:clone()
+   for i = 1,m1:size(1) do
+      for j = 1,m1:size(2) do  -- iterate columns with size(2), not size(1)
+         res2[{ i,j }] = math.floor(res2[{ i,j }] / 4)  -- 4 == 2^q
+      end
+   end
+
+   local err = (res1-res2):abs():max()
+
+   mytester:assertlt(err, precision, 'error in torch.rsh - scalar, contiguous')
+end
+
 function torchtest.fmod()
    local m1 = torch.Tensor(10,10):uniform(-10, 10)
    local res1 = m1:clone() |