diff options
author | soumith <soumith@fb.com> | 2015-07-28 20:41:02 +0300 |
---|---|---|
committer | soumith <soumith@fb.com> | 2015-07-28 20:41:02 +0300 |
commit | 841227a5dd30d407c254941dd36c5ec17bd83daf (patch) | |
tree | 923052f6a7a9c515a5b0399312cfb5a98eadb7a7 | |
parent | 194522f1ba96432ab19c176e23a0b9b981174770 (diff) |
fixing table modules to return correct number of gradInputs
-rw-r--r-- | CAddTable.lua | 5 | ||||
-rw-r--r-- | CDivTable.lua | 5 | ||||
-rw-r--r-- | CMulTable.lua | 11 | ||||
-rw-r--r-- | CSubTable.lua | 5 |
4 files changed, 25 insertions, 1 deletions
diff --git a/CAddTable.lua b/CAddTable.lua index 05f804c..42e77fd 100644 --- a/CAddTable.lua +++ b/CAddTable.lua @@ -20,5 +20,10 @@ function CAddTable:updateGradInput(input, gradOutput) self.gradInput[i]:resizeAs(input[i]) self.gradInput[i]:copy(gradOutput) end + + for i=#input+1, #self.gradInput do + self.gradInput[i] = nil + end + return self.gradInput end diff --git a/CDivTable.lua b/CDivTable.lua index 582b362..bf044c9 100644 --- a/CDivTable.lua +++ b/CDivTable.lua @@ -17,5 +17,10 @@ function CDivTable:updateGradInput(input, gradOutput) self.gradInput[2] = self.gradInput[2] or input[1].new() self.gradInput[1]:resizeAs(input[1]):copy(gradOutput):cdiv(input[2]) self.gradInput[2]:resizeAs(input[2]):zero():addcdiv(-1,self.gradInput[1],input[2]):cmul(input[1]) + + for i=#input+1, #self.gradInput do + self.gradInput[i] = nil + end + return self.gradInput end diff --git a/CMulTable.lua b/CMulTable.lua index 0e327be..0689f33 100644 --- a/CMulTable.lua +++ b/CMulTable.lua @@ -23,6 +23,11 @@ function CMulTable:updateGradInput_efficient(input, gradOutput) self.tout:copy(self.output):cdiv(input[i]) self.gradInput[i]:cmul(self.tout) end + + for i=#input+1, #self.gradInput do + self.gradInput[i] = nil + end + return self.gradInput end @@ -36,6 +41,10 @@ function CMulTable:updateGradInput(input, gradOutput) end end end + + for i=#input+1, #self.gradInput do + self.gradInput[i] = nil + end + return self.gradInput end - diff --git a/CSubTable.lua b/CSubTable.lua index b929e48..eb74920 100644 --- a/CSubTable.lua +++ b/CSubTable.lua @@ -17,5 +17,10 @@ function CSubTable:updateGradInput(input, gradOutput) self.gradInput[2] = self.gradInput[2] or input[1].new() self.gradInput[1]:resizeAs(input[1]):copy(gradOutput) self.gradInput[2]:resizeAs(input[2]):copy(gradOutput):mul(-1) + + for i=#input+1, #self.gradInput do + self.gradInput[i] = nil + end + return self.gradInput end |