Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/torch/nn.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorSoumith Chintala <soumith@gmail.com>2015-07-28 20:42:24 +0300
committerSoumith Chintala <soumith@gmail.com>2015-07-28 20:42:24 +0300
commit0c2416f347159248048537d8540cc53429d742b0 (patch)
tree923052f6a7a9c515a5b0399312cfb5a98eadb7a7
parent194522f1ba96432ab19c176e23a0b9b981174770 (diff)
parent841227a5dd30d407c254941dd36c5ec17bd83daf (diff)
Merge pull request #333 from torch/tablefixgrad
fixing table modules to return correct number of gradInputs
-rw-r--r--CAddTable.lua5
-rw-r--r--CDivTable.lua5
-rw-r--r--CMulTable.lua11
-rw-r--r--CSubTable.lua5
4 files changed, 25 insertions, 1 deletions
diff --git a/CAddTable.lua b/CAddTable.lua
index 05f804c..42e77fd 100644
--- a/CAddTable.lua
+++ b/CAddTable.lua
@@ -20,5 +20,10 @@ function CAddTable:updateGradInput(input, gradOutput)
self.gradInput[i]:resizeAs(input[i])
self.gradInput[i]:copy(gradOutput)
end
+
+ for i=#input+1, #self.gradInput do
+ self.gradInput[i] = nil
+ end
+
return self.gradInput
end
diff --git a/CDivTable.lua b/CDivTable.lua
index 582b362..bf044c9 100644
--- a/CDivTable.lua
+++ b/CDivTable.lua
@@ -17,5 +17,10 @@ function CDivTable:updateGradInput(input, gradOutput)
self.gradInput[2] = self.gradInput[2] or input[1].new()
self.gradInput[1]:resizeAs(input[1]):copy(gradOutput):cdiv(input[2])
self.gradInput[2]:resizeAs(input[2]):zero():addcdiv(-1,self.gradInput[1],input[2]):cmul(input[1])
+
+ for i=#input+1, #self.gradInput do
+ self.gradInput[i] = nil
+ end
+
return self.gradInput
end
diff --git a/CMulTable.lua b/CMulTable.lua
index 0e327be..0689f33 100644
--- a/CMulTable.lua
+++ b/CMulTable.lua
@@ -23,6 +23,11 @@ function CMulTable:updateGradInput_efficient(input, gradOutput)
self.tout:copy(self.output):cdiv(input[i])
self.gradInput[i]:cmul(self.tout)
end
+
+ for i=#input+1, #self.gradInput do
+ self.gradInput[i] = nil
+ end
+
return self.gradInput
end
@@ -36,6 +41,10 @@ function CMulTable:updateGradInput(input, gradOutput)
end
end
end
+
+ for i=#input+1, #self.gradInput do
+ self.gradInput[i] = nil
+ end
+
return self.gradInput
end
-
diff --git a/CSubTable.lua b/CSubTable.lua
index b929e48..eb74920 100644
--- a/CSubTable.lua
+++ b/CSubTable.lua
@@ -17,5 +17,10 @@ function CSubTable:updateGradInput(input, gradOutput)
self.gradInput[2] = self.gradInput[2] or input[1].new()
self.gradInput[1]:resizeAs(input[1]):copy(gradOutput)
self.gradInput[2]:resizeAs(input[2]):copy(gradOutput):mul(-1)
+
+ for i=#input+1, #self.gradInput do
+ self.gradInput[i] = nil
+ end
+
return self.gradInput
end