diff options
author | Soumith Chintala <soumith@gmail.com> | 2016-12-09 04:37:36 +0300 |
---|---|---|
committer | GitHub <noreply@github.com> | 2016-12-09 04:37:36 +0300 |
commit | ff37975147f8c16e5bd2f92ed62f7f34c3c4a012 (patch) | |
tree | c70fee23b4f8e44cce8591a4a6392b2a0505913d | |
parent | 8d35db45bbb2ad35d3a045d7ebe185b1f9efc505 (diff) | |
parent | 92def6c8ed0234268e0afc6ba8e6ddb679834816 (diff) |
Merge pull request #397 from lukeyeager/fix-slice-range
Fix sliceRange for when nElem < splits
-rw-r--r-- | DataParallelTable.lua | 12 |
1 files changed, 7 insertions, 5 deletions
diff --git a/DataParallelTable.lua b/DataParallelTable.lua index 75e8213..102be72 100644 --- a/DataParallelTable.lua +++ b/DataParallelTable.lua @@ -463,12 +463,14 @@ function DataParallelTable:apply(callback) end local function sliceRange(nElem, idx, splits) - local eltsPerMod = nElem / splits - local rangeStart = math.ceil((idx - 1) * eltsPerMod) + 1 - if idx == splits then - return rangeStart, nElem - rangeStart + 1 + local eltsPerMod = math.floor(nElem / splits) + local numExtra = nElem - eltsPerMod * splits + if idx <= numExtra then + rangeStart = (idx - 1) * (eltsPerMod + 1) + 1 + return rangeStart, eltsPerMod + 1 else - return rangeStart, math.ceil(idx * eltsPerMod) - rangeStart + 1 + rangeStart = numExtra * (eltsPerMod + 1) + (idx - 1 - numExtra) * eltsPerMod + 1 + return rangeStart, eltsPerMod end end |