Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/torch/nn.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
Diffstat (limited to 'SpatialFullConvolution.lua')
-rw-r--r--SpatialFullConvolution.lua44
1 files changed, 41 insertions, 3 deletions
diff --git a/SpatialFullConvolution.lua b/SpatialFullConvolution.lua
index dc9d944..121a07e 100644
--- a/SpatialFullConvolution.lua
+++ b/SpatialFullConvolution.lua
@@ -71,7 +71,20 @@ function SpatialFullConvolution:updateOutput(input)
self:backCompatibility()
input = makeContiguous(self, input)
- return input.nn.SpatialFullConvolution_updateOutput(self, input)
+ input.THNN.SpatialFullConvolution_updateOutput(
+ input:cdata(),
+ self.output:cdata(),
+ self.weight:cdata(),
+ self.bias:cdata(),
+ self.finput:cdata(),
+ self.fgradInput:cdata(),
+ self.kW, self.kH,
+ self.dW, self.dH,
+ self.padW, self.padH,
+ self.adjW, self.adjH
+ )
+
+ return self.output
end
function SpatialFullConvolution:updateGradInput(input, gradOutput)
@@ -79,15 +92,40 @@ function SpatialFullConvolution:updateGradInput(input, gradOutput)
if self.gradInput then
input, gradOutput = makeContiguous(self, input, gradOutput)
- return input.nn.SpatialFullConvolution_updateGradInput(self, input, gradOutput)
+ input.THNN.SpatialFullConvolution_updateGradInput(
+ input:cdata(),
+ gradOutput:cdata(),
+ self.gradInput:cdata(),
+ self.weight:cdata(),
+ self.finput:cdata(),
+ self.kW, self.kH,
+ self.dW, self.dH,
+ self.padW, self.padH,
+ self.adjW, self.adjH
+ )
+
+ return self.gradInput
end
end
function SpatialFullConvolution:accGradParameters(input, gradOutput, scale)
+ scale = scale or 1
self:backCompatibility()
input, gradOutput = makeContiguous(self, input, gradOutput)
- return input.nn.SpatialFullConvolution_accGradParameters(self, input, gradOutput, scale)
+ input.THNN.SpatialFullConvolution_accGradParameters(
+ input:cdata(),
+ gradOutput:cdata(),
+ self.gradWeight:cdata(),
+ self.gradBias:cdata(),
+ self.finput:cdata(),
+ self.fgradInput:cdata(),
+ self.kW, self.kH,
+ self.dW, self.dH,
+ self.padW, self.padH,
+ self.adjW, self.adjH,
+ scale
+ )
end
function SpatialFullConvolution:type(type, tensorCache)