diff options
author | soumith <soumith@gmail.com> | 2015-07-20 10:30:39 +0300 |
---|---|---|
committer | soumith <soumith@gmail.com> | 2015-07-20 10:30:39 +0300 |
commit | 0634fac2079fbd0540f62dbd04bcc64412670829 (patch) | |
tree | 9cefe3584a3a81f38b38555435e05e73cee5e4a8 /JoinTable.lua | |
parent | b87e803d3d51666959ae999bdb126da901360357 (diff) |
fixes jointable backward bug
Diffstat (limited to 'JoinTable.lua')
-rw-r--r-- | JoinTable.lua | 17 |
1 files changed, 11 insertions, 6 deletions
diff --git a/JoinTable.lua b/JoinTable.lua index 03b4606..c143bd4 100644 --- a/JoinTable.lua +++ b/JoinTable.lua @@ -8,7 +8,7 @@ function JoinTable:__init(dimension, nInputDims) self.nInputDims = nInputDims end -function JoinTable:updateOutput(input) +function JoinTable:updateOutput(input) local dimension = self.dimension if self.nInputDims and input[1]:dim()==(self.nInputDims+1) then dimension = dimension + 1 @@ -21,11 +21,11 @@ function JoinTable:updateOutput(input) else self.size[dimension] = self.size[dimension] + currentOutput:size(dimension) - end + end end self.output:resize(self.size) - - local offset = 1 + + local offset = 1 for i=1,#input do local currentOutput = input[i] self.output:narrow(dimension, offset, @@ -41,16 +41,21 @@ function JoinTable:updateGradInput(input, gradOutput) dimension = dimension + 1 end - for i=1,#input do + for i=1,#input do if self.gradInput[i] == nil then self.gradInput[i] = input[i].new() end self.gradInput[i]:resizeAs(input[i]) end + -- clear out invalid gradInputs + for i=#input+1, #self.gradInput do + self.gradInput[i] = nil + end + local offset = 1 for i=1,#input do - local currentOutput = input[i] + local currentOutput = input[i] local currentGradInput = gradOutput:narrow(dimension, offset, currentOutput:size(dimension)) self.gradInput[i]:copy(currentGradInput) |