Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/torch/cunn.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorSoumith Chintala <soumith@gmail.com>2016-12-09 04:37:36 +0300
committerGitHub <noreply@github.com>2016-12-09 04:37:36 +0300
commitff37975147f8c16e5bd2f92ed62f7f34c3c4a012 (patch)
treec70fee23b4f8e44cce8591a4a6392b2a0505913d
parent8d35db45bbb2ad35d3a045d7ebe185b1f9efc505 (diff)
parent92def6c8ed0234268e0afc6ba8e6ddb679834816 (diff)
Merge pull request #397 from lukeyeager/fix-slice-range
Fix sliceRange for when nElem < splits
-rw-r--r--DataParallelTable.lua12
1 files changed, 7 insertions, 5 deletions
diff --git a/DataParallelTable.lua b/DataParallelTable.lua
index 75e8213..102be72 100644
--- a/DataParallelTable.lua
+++ b/DataParallelTable.lua
@@ -463,12 +463,14 @@ function DataParallelTable:apply(callback)
end
local function sliceRange(nElem, idx, splits)
- local eltsPerMod = nElem / splits
- local rangeStart = math.ceil((idx - 1) * eltsPerMod) + 1
- if idx == splits then
- return rangeStart, nElem - rangeStart + 1
+ local eltsPerMod = math.floor(nElem / splits)
+ local numExtra = nElem - eltsPerMod * splits
+ if idx <= numExtra then
+ rangeStart = (idx - 1) * (eltsPerMod + 1) + 1
+ return rangeStart, eltsPerMod + 1
else
- return rangeStart, math.ceil(idx * eltsPerMod) - rangeStart + 1
+ rangeStart = numExtra * (eltsPerMod + 1) + (idx - 1 - numExtra) * eltsPerMod + 1
+ return rangeStart, eltsPerMod
end
end