Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/torch/nn.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorAngela Fan <angelafan@fb.com>2017-01-24 09:51:39 +0300
committerAngela Fan <angelafan@fb.com>2017-01-24 09:51:53 +0300
commit59e08237f3aa26df80a46331f742c28b9a93e752 (patch)
treedf1e8bc9c983781d7d3a91879fb57aab123da787 /Narrow.lua
parent7742eba2218446057f0414c9df4879c5e14481a8 (diff)
changed narrow to standardize negative length and negative offset behavior,
modified documentation to reflect new behavior
Diffstat (limited to 'Narrow.lua')
-rw-r--r--Narrow.lua28
1 files changed, 16 insertions, 12 deletions
diff --git a/Narrow.lua b/Narrow.lua
index a6ebaa3..01be934 100644
--- a/Narrow.lua
+++ b/Narrow.lua
@@ -13,14 +13,16 @@ end
function Narrow:updateOutput(input)
local dim = self.dimension < 0 and input:dim() + self.dimension + 1 or self.dimension
local length = self.length
- if length < 0 then
- length = input:size(dim) - self.index + self.length + 2
- end
local index = self.index
- if self.index < 0 then
- index = 1
- length = input:size(dim) - length
+
+ if index < 0 then
+ index = input:size(dim) + self.index + 1
end
+
+ if length < 0 then
+ length = input:size(dim) - index + 1 - torch.abs(length)
+ end
+
local output=input:narrow(dim, index, length)
self.output = self.output:typeAs(output)
self.output:resizeAs(output):copy(output)
@@ -30,14 +32,16 @@ end
function Narrow:updateGradInput(input, gradOutput)
local dim = self.dimension < 0 and input:dim() + self.dimension + 1 or self.dimension
local length = self.length
- if length < 0 then
- length = input:size(dim) - self.index + self.length + 2
- end
local index = self.index
- if self.index < 0 then
- index = 1
- length = input:size(dim) - length
+
+ if index < 0 then
+ index = input:size(dim) + self.index + 1
end
+
+ if length < 0 then
+ length = input:size(dim) - index + 1 - torch.abs(length)
+ end
+
self.gradInput = self.gradInput:typeAs(input)
self.gradInput:resizeAs(input):zero()
self.gradInput:narrow(dim,index,length):copy(gradOutput)