Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/torch/torch7.git - Unnamed repository; edit this file 'description' to name the repository.
summary | refs | log | tree | commit | diff
diff options
context:
space:
mode:
author: jokeren <robinho364@gmail.com>, 2017-01-26 19:06:53 +0300
committer: Soumith Chintala <soumith@gmail.com>, 2017-02-23 14:01:09 +0300
commit: 8caab28c25d03d7787ecc8347b8ab7db2b24262c (patch)
tree: 11270095736eacdb9d47a7bd155cae7e0bde7d30
parent: 0153e4d052923c589a8755d60e26dd001492847a (diff)
Add isTransposed judge and enable multithread of fill functions
-rw-r--r-- lib/TH/generic/THTensor.c 24
-rw-r--r-- lib/TH/generic/THTensorMath.c 24
2 files changed, 44 insertions, 4 deletions
diff --git a/lib/TH/generic/THTensor.c b/lib/TH/generic/THTensor.c
index 238da61..e8f1e18 100644
--- a/lib/TH/generic/THTensor.c
+++ b/lib/TH/generic/THTensor.c
@@ -546,6 +546,30 @@ void THTensor_(unsqueeze1d)(THTensor *self, THTensor *src, int dimension)
self->size[dimension] = 1;
}
+int THTensor_(isTransposed)(const THTensor *self) /* returns 1 when self is non-contiguous and its element count equals (largest stride * size of that dimension); otherwise 0 */
+{
+ if (THTensor_(isContiguous)(self)) { /* contiguous tensors are reported as not transposed */
+ return 0;
+ }
+ long max_stride = 1; /* largest stride seen so far (starts at 1) */
+ long size_max_stride = 1; /* size of the dimension that holds max_stride */
+ long z = 1; /* running product of all dimension sizes */
+ int d;
+ for (d = 0; d < self->nDimension; ++d) {
+ if (self->stride[d] == 0 && self->size[d] != 1) /* zero stride on a dimension whose size is not 1: reject */
+ return 0;
+ if (self->stride[d] > max_stride) { /* strict '>' keeps the first dimension among equal strides */
+ max_stride = self->stride[d];
+ size_max_stride = self->size[d];
+ }
+ z *= self->size[d];
+ }
+ if (z == max_stride * size_max_stride) { /* NOTE(review): size {2,2} with stride {2,2} also passes this test although its offsets 0,2,2,4 are not a dense range - confirm such layouts cannot reach here */
+ return 1;
+ }
+ return 0;
+}
+
int THTensor_(isContiguous)(const THTensor *self)
{
long z = 1;
diff --git a/lib/TH/generic/THTensorMath.c b/lib/TH/generic/THTensorMath.c
index dd19ce2..2c5813a 100644
--- a/lib/TH/generic/THTensorMath.c
+++ b/lib/TH/generic/THTensorMath.c
@@ -10,14 +10,30 @@
void THTensor_(fill)(THTensor *r_, real value) /* writes value into r_ through THVector_(fill) */
{
- TH_TENSOR_APPLY(real, r_,
- THVector_(fill)(r__data, value, r__size); break;);
+ if (THTensor_(isContiguous)(r_) || THTensor_(isTransposed)(r_)) { /* flat path: treat the data as one run of nElement values starting at the data pointer */
+ real *rp = THTensor_(data)(r_);
+ ptrdiff_t sz = THTensor_(nElement)(r_);
+ #pragma omp parallel if(sz > TH_OMP_OVERHEAD_THRESHOLD)
+ {
+ #ifdef _OPENMP
+ size_t num_threads = omp_get_num_threads();
+ size_t tid = omp_get_thread_num();
+ #else
+ size_t num_threads = 1; /* built without OpenMP: a single chunk covers everything */
+ size_t tid = 0;
+ #endif
+ ptrdiff_t i = tid * (sz / num_threads); /* start of this thread's chunk of sz / num_threads elements */
+ ptrdiff_t i_end = tid == num_threads - 1 ? sz : i + sz / num_threads; /* the last thread also takes the remainder up to sz */
+ THVector_(fill)(rp+i, value, i_end-i);
+ }
+ } else { /* any other layout keeps the previous TH_TENSOR_APPLY-based fill */
+ TH_TENSOR_APPLY(real, r_, THVector_(fill)(r__data, value, r__size); break;);
+ }
}
void THTensor_(zero)(THTensor *r_) /* fills r_ with 0 */
{
- TH_TENSOR_APPLY(real, r_,
- THVector_(fill)(r__data, 0, r__size); break;);
+ THTensor_(fill)(r_, 0); /* delegates to THTensor_(fill) instead of repeating the apply loop */
}
void THTensor_(maskedFill)(THTensor *tensor, THByteTensor *mask, real value)