diff options
author | Angela Fan <angelafan@fb.com> | 2017-01-24 09:51:39 +0300 |
---|---|---|
committer | Angela Fan <angelafan@fb.com> | 2017-01-24 09:51:53 +0300 |
commit | 59e08237f3aa26df80a46331f742c28b9a93e752 (patch) | |
tree | df1e8bc9c983781d7d3a91879fb57aab123da787 /Narrow.lua | |
parent | 7742eba2218446057f0414c9df4879c5e14481a8 (diff) |
changed narrow to standardize negative length and negative offset behavior,
modified documentation to reflect new behavior
Diffstat (limited to 'Narrow.lua')
-rw-r--r-- | Narrow.lua | 28 |
1 files changed, 16 insertions, 12 deletions
@@ -13,14 +13,16 @@ end function Narrow:updateOutput(input) local dim = self.dimension < 0 and input:dim() + self.dimension + 1 or self.dimension local length = self.length - if length < 0 then - length = input:size(dim) - self.index + self.length + 2 - end local index = self.index - if self.index < 0 then - index = 1 - length = input:size(dim) - length + + if index < 0 then + index = input:size(dim) + self.index + 1 end + + if length < 0 then + length = input:size(dim) - index + 1 - torch.abs(length) + end + local output=input:narrow(dim, index, length) self.output = self.output:typeAs(output) self.output:resizeAs(output):copy(output) @@ -30,14 +32,16 @@ end function Narrow:updateGradInput(input, gradOutput) local dim = self.dimension < 0 and input:dim() + self.dimension + 1 or self.dimension local length = self.length - if length < 0 then - length = input:size(dim) - self.index + self.length + 2 - end local index = self.index - if self.index < 0 then - index = 1 - length = input:size(dim) - length + + if index < 0 then + index = input:size(dim) + self.index + 1 end + + if length < 0 then + length = input:size(dim) - index + 1 - torch.abs(length) + end + self.gradInput = self.gradInput:typeAs(input) self.gradInput:resizeAs(input):zero() self.gradInput:narrow(dim,index,length):copy(gradOutput) |