Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/torch/nn.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorsoumith <soumith@gmail.com>2015-07-20 10:30:39 +0300
committersoumith <soumith@gmail.com>2015-07-20 10:30:39 +0300
commit0634fac2079fbd0540f62dbd04bcc64412670829 (patch)
tree9cefe3584a3a81f38b38555435e05e73cee5e4a8 /JoinTable.lua
parentb87e803d3d51666959ae999bdb126da901360357 (diff)
fixes jointable backward bug
Diffstat (limited to 'JoinTable.lua')
-rw-r--r--JoinTable.lua17
1 files changed, 11 insertions, 6 deletions
diff --git a/JoinTable.lua b/JoinTable.lua
index 03b4606..c143bd4 100644
--- a/JoinTable.lua
+++ b/JoinTable.lua
@@ -8,7 +8,7 @@ function JoinTable:__init(dimension, nInputDims)
self.nInputDims = nInputDims
end
-function JoinTable:updateOutput(input)
+function JoinTable:updateOutput(input)
local dimension = self.dimension
if self.nInputDims and input[1]:dim()==(self.nInputDims+1) then
dimension = dimension + 1
@@ -21,11 +21,11 @@ function JoinTable:updateOutput(input)
else
self.size[dimension] = self.size[dimension]
+ currentOutput:size(dimension)
- end
+ end
end
self.output:resize(self.size)
-
- local offset = 1
+
+ local offset = 1
for i=1,#input do
local currentOutput = input[i]
self.output:narrow(dimension, offset,
@@ -41,16 +41,21 @@ function JoinTable:updateGradInput(input, gradOutput)
dimension = dimension + 1
end
- for i=1,#input do
+ for i=1,#input do
if self.gradInput[i] == nil then
self.gradInput[i] = input[i].new()
end
self.gradInput[i]:resizeAs(input[i])
end
+ -- clear out invalid gradInputs
+ for i=#input+1, #self.gradInput do
+ self.gradInput[i] = nil
+ end
+
local offset = 1
for i=1,#input do
- local currentOutput = input[i]
+ local currentOutput = input[i]
local currentGradInput = gradOutput:narrow(dimension, offset,
currentOutput:size(dimension))
self.gradInput[i]:copy(currentGradInput)