diff options
author | Clement Farabet <clement.farabet@gmail.com> | 2012-10-22 23:33:13 +0400 |
---|---|---|
committer | Clement Farabet <clement.farabet@gmail.com> | 2012-10-22 23:33:13 +0400 |
commit | 1d5b59b9f202f64c854d2267b79e97596a3c1d26 (patch) | |
tree | 2d064db8666b373856c9059a3425ecd743d8bf31 | |
parent | b8addac4e7fe7171e42886df7984cff3fca35ef8 (diff) |
Fixed Tensor alloc for some modules (For CUDA)
-rw-r--r-- | CAddTable.lua | 2 | ||||
-rw-r--r-- | CDivTable.lua | 4 | ||||
-rw-r--r-- | CMulTable.lua | 9 | ||||
-rw-r--r-- | CSubTable.lua | 4 |
4 files changed, 10 insertions, 9 deletions
diff --git a/CAddTable.lua b/CAddTable.lua index afe3568..05f804c 100644 --- a/CAddTable.lua +++ b/CAddTable.lua @@ -16,7 +16,7 @@ end function CAddTable:updateGradInput(input, gradOutput) for i=1,#input do - self.gradInput[i] = self.gradInput[i] or torch.Tensor() + self.gradInput[i] = self.gradInput[i] or input[1].new() self.gradInput[i]:resizeAs(input[i]) self.gradInput[i]:copy(gradOutput) end diff --git a/CDivTable.lua b/CDivTable.lua index f91d024..582b362 100644 --- a/CDivTable.lua +++ b/CDivTable.lua @@ -13,8 +13,8 @@ function CDivTable:updateOutput(input) end function CDivTable:updateGradInput(input, gradOutput) - self.gradInput[1] = self.gradInput[1] or torch.Tensor() - self.gradInput[2] = self.gradInput[2] or torch.Tensor() + self.gradInput[1] = self.gradInput[1] or input[1].new() + self.gradInput[2] = self.gradInput[2] or input[1].new() self.gradInput[1]:resizeAs(input[1]):copy(gradOutput):cdiv(input[2]) self.gradInput[2]:resizeAs(input[2]):zero():addcdiv(-1,self.gradInput[1],input[2]):cmul(input[1]) return self.gradInput diff --git a/CMulTable.lua b/CMulTable.lua index 4c058b6..f82776b 100644 --- a/CMulTable.lua +++ b/CMulTable.lua @@ -15,12 +15,13 @@ function CMulTable:updateOutput(input) end function CMulTable:updateGradInput(input, gradOutput) - local tout = torch.Tensor():resizeAs(self.output) + self.tout = self.tout or input[1].new() + self.tout:resizeAs(self.output) for i=1,#input do - self.gradInput[i] = self.gradInput[i] or torch.Tensor() + self.gradInput[i] = self.gradInput[i] or input[1].new() self.gradInput[i]:resizeAs(input[i]):copy(gradOutput) - tout:copy(self.output):cdiv(input[i]) - self.gradInput[i]:cmul(tout) + self.tout:copy(self.output):cdiv(input[i]) + self.gradInput[i]:cmul(self.tout) end return self.gradInput end diff --git a/CSubTable.lua b/CSubTable.lua index ffc495b..060dee5 100644 --- a/CSubTable.lua +++ b/CSubTable.lua @@ -13,8 +13,8 @@ function CSubTable:updateOutput(input) end function CSubTable:updateGradInput(input, gradOutput) - self.gradInput[1] = self.gradInput[1] or torch.Tensor() - self.gradInput[2] = self.gradInput[2] or torch.Tensor() + self.gradInput[1] = self.gradInput[1] or input[1].new() + self.gradInput[2] = self.gradInput[2] or input[1].new() self.gradInput[1]:resizeAs(input[1]):copy(gradOutput) self.gradInput[2]:resizeAs(input[1]):copy(gradOutput):mul(-1) return self.gradInput |