Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/torch/nn.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorNicholas Leonard <nick@nikopia.org>2015-05-14 21:42:11 +0300
committerNicholas Leonard <nick@nikopia.org>2015-05-14 22:26:49 +0300
commitdaf9b6caca19906059e05ae57d7dfd3be0c0d6a6 (patch)
treec2982b27d0568fcb6d70a2dd9d097e2e63a80b5c /SpatialConvolution.lua
parent28bb486ea0ee94f6007d781a01fe9c6e90fca20b (diff)
SpatialConvolution empty gradParams fix
Diffstat (limited to 'SpatialConvolution.lua')
-rw-r--r--SpatialConvolution.lua8
1 files changed, 6 insertions, 2 deletions
diff --git a/SpatialConvolution.lua b/SpatialConvolution.lua
index ac98d9d..4a42b63 100644
--- a/SpatialConvolution.lua
+++ b/SpatialConvolution.lua
@@ -73,12 +73,16 @@ end
-- function to re-view the weight layout in a way that would make the MM ops happy
local function viewWeight(self)
self.weight = self.weight:view(self.nOutputPlane, self.nInputPlane * self.kH * self.kW)
- self.gradWeight = self.gradWeight and self.gradWeight:view(self.nOutputPlane, self.nInputPlane * self.kH * self.kW)
+ if self.gradWeight and self.gradWeight:dim() > 0 then
+ self.gradWeight = self.gradWeight:view(self.nOutputPlane, self.nInputPlane * self.kH * self.kW)
+ end
end
local function unviewWeight(self)
self.weight = self.weight:view(self.nOutputPlane, self.nInputPlane, self.kH, self.kW)
- self.gradWeight = self.gradWeight and self.gradWeight:view(self.nOutputPlane, self.nInputPlane, self.kH, self.kW)
+ if self.gradWeight and self.gradWeight:dim() > 0 then
+ self.gradWeight = self.gradWeight:view(self.nOutputPlane, self.nInputPlane, self.kH, self.kW)
+ end
end
function SpatialConvolution:updateOutput(input)