Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/torch/nn.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorSoumith Chintala <soumith@gmail.com>2015-05-07 22:06:06 +0300
committerSoumith Chintala <soumith@gmail.com>2015-05-07 22:06:06 +0300
commit28b0d2a80f302e39876002a3978bb4f70c4ee171 (patch)
treee97e5c9aecfa1ad561025b7d84a2c10dfec258a3
parented4653cbe8607c4f28501a401ea6d9a2d71acf23 (diff)
parent3e6a5d49c2d0ca60653ffdae949b41987b30b8d6 (diff)
Merge pull request #261 from torch/convfix
fixing typing in SpatialConvolution
-rw-r--r--SpatialConvolution.lua14
1 files changed, 10 insertions, 4 deletions
diff --git a/SpatialConvolution.lua b/SpatialConvolution.lua
index 1eff358..ac98d9d 100644
--- a/SpatialConvolution.lua
+++ b/SpatialConvolution.lua
@@ -19,7 +19,7 @@ function SpatialConvolution:__init(nInputPlane, nOutputPlane, kW, kH, dW, dH, pa
self.bias = torch.Tensor(nOutputPlane)
self.gradWeight = torch.Tensor(nOutputPlane, nInputPlane, kH, kW)
self.gradBias = torch.Tensor(nOutputPlane)
-
+
self:reset()
end
@@ -35,7 +35,7 @@ function SpatialConvolution:reset(stdv)
end)
self.bias:apply(function()
return torch.uniform(-stdv, stdv)
- end)
+ end)
else
self.weight:uniform(-stdv, stdv)
self.bias:uniform(-stdv, stdv)
@@ -46,10 +46,10 @@ local function backCompatibility(self)
self.finput = self.finput or self.weight.new()
self.fgradInput = self.fgradInput or self.weight.new()
self.padding = self.padding or 0
- if self.weight:dim() == 2 then
+ if self.weight:dim() == 2 then
self.weight = self.weight:view(self.nOutputPlane, self.nInputPlane, self.kH, self.kW)
end
- if self.gradWeight and self.gradWeight:dim() == 2 then
+ if self.gradWeight and self.gradWeight:dim() == 2 then
self.gradWeight = self.gradWeight:view(self.nOutputPlane, self.nInputPlane, self.kH, self.kW)
end
end
@@ -109,3 +109,9 @@ function SpatialConvolution:accGradParameters(input, gradOutput, scale)
unviewWeight(self)
return out
end
+
+function SpatialConvolution:type(type)
+ self.finput = torch.Tensor()
+ self.fgradInput = torch.Tensor()
+ return parent.type(self,type)
+end