Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/soumith/cudnn.torch.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorsoumith <soumith@fb.com>2015-08-03 13:13:29 +0300
committersoumith <soumith@fb.com>2015-08-03 13:13:29 +0300
commit7968df76ef9ec088b68ec5fe536f3804e8fb0ca3 (patch)
tree0cbea0b1df4c4912bfda20eed09fd8c768e7f2c4 /SpatialDivisiveNormalization.lua
parentf6f22a3bf2ee4b920b7a38a61d0be911377f0d47 (diff)
new module interface
Diffstat (limited to 'SpatialDivisiveNormalization.lua')
-rw-r--r--SpatialDivisiveNormalization.lua100
1 files changed, 100 insertions, 0 deletions
diff --git a/SpatialDivisiveNormalization.lua b/SpatialDivisiveNormalization.lua
new file mode 100644
index 0000000..b9a2a16
--- /dev/null
+++ b/SpatialDivisiveNormalization.lua
@@ -0,0 +1,100 @@
+local DivisiveNorm, parent = torch.class('cudnn.SpatialDivisiveNormalization', 'nn.Module')
+local ffi = require 'ffi'
+local errcheck = cudnn.errcheck
+
+function DivisiveNorm:__init(size, alpha, beta, K)
+ parent.__init(self)
+ self.size = size or 5
+ self.alpha = alpha or 1e-4
+ self.beta = beta or 0.75
+ self.K = K or 2.0
+ assert(self.size >= 1 and self.size <= 16, "size has to be between 1 and 16")
+ assert(self.K >= 1e-5, "K has to be greater than 1e-5")
+ assert(self.beta >= 0.01, "Beta has to be > 0.01")
+end
+
+function DivisiveNorm:resetDescriptors()
+ -- create DivisiveNorm descriptor
+ self.DivisiveNormDesc = ffi.new('struct cudnnDivisiveNormDescriptor_t*[1]')
+ errcheck('cudnnCreateDivisiveNormDescriptor', self.DivisiveNormDesc)
+ errcheck('cudnnSetDivisiveNormDescriptor', self.DivisiveNormDesc[0], self.size,
+ self.alpha, self.beta, self.K);
+ local function destroyDesc(d)
+ errcheck('cudnnDestroyDivisiveNormDescriptor', d[0]);
+ end
+ ffi.gc(self.DivisiveNormDesc, destroyPoolDesc)
+end
+
+function DivisiveNorm:createIODescriptors(input)
+ local batch = true
+ if input:dim() == 3 then
+ input = input:view(1, input:size(1), input:size(2), input:size(3))
+ batch = false
+ end
+ assert(input:dim() == 4 and input:isContiguous());
+ if not self.iDesc or
+ input:size(1) ~= self.iSize[1] or input:size(2) ~= self.iSize[2]
+ or input:size(3) ~= self.iSize[3] or input:size(4) ~= self.iSize[4] then
+ self.iSize = input:size()
+ self.gradInput:resizeAs(input)
+ self.output:resizeAs(input)
+
+ -- create input/output descriptor
+ self.iDesc = cudnn.toDescriptor(input)
+ if not batch then
+ self.gradInput = self.gradInput:view(self.gradInput:size(2),
+ self.gradInput:size(3),
+ self.gradInput:size(4))
+ self.output = self.output:view(self.output:size(2),
+ self.output:size(3),
+ self.output:size(4))
+ end
+ end
+end
+
+local one = torch.FloatTensor({1});
+local zero = torch.FloatTensor({0});
+
+function DivisiveNorm:updateOutput(input)
+ if not self.DivisiveNormDesc then self:resetPoolDescriptors() end
+ self:createIODescriptors(input)
+ errcheck('cudnnDivisiveNormCrossChannelForward', cudnn.getHandle(),
+ self.DivisiveNormDesc[0],
+ 'CUDNN_DivisiveNorm_CROSS_CHANNEL_DIM1',
+ one:data(),
+ self.iDesc[0], input:data(),
+ zero:data(),
+ self.iDesc[0], self.output:data());
+ return self.output
+end
+
+function DivisiveNorm:updateGradInput(input, gradOutput)
+ assert(gradOutput:dim() == 3 or gradOutput:dim() == 4);
+ if not gradOutput:isContiguous() then
+ self._gradOutput = self._gradOutput or gradOutput.new()
+ self._gradOutput:resizeAs(gradOutput):copy(gradOutput)
+ gradOutput = self._gradOutput
+ end
+ if not self.DivisiveNormDesc then self:resetPoolDescriptors() end
+ self:createIODescriptors(input)
+ errcheck('cudnnDivisiveNormCrossChannelBackward',
+ cudnn.getHandle(), self.DivisiveNormDesc[0],
+ 'CUDNN_DivisiveNorm_CROSS_CHANNEL_DIM1',
+ one:data(),
+ self.iDesc[0], self.output:data(),
+ self.iDesc[0], gradOutput:data(),
+ self.iDesc[0], input:data(),
+ zero:data(),
+ self.iDesc[0], self.gradInput:data());
+ return self.gradInput
+end
+
+function DivisiveNorm:write(f)
+ self.DivisiveNormDesc = nil
+ self.iDesc = nil
+ local var = {}
+ for k,v in pairs(self) do
+ var[k] = v
+ end
+ f:writeObject(var)
+end