diff options
author | soumith <soumith@fb.com> | 2015-08-03 13:13:29 +0300 |
---|---|---|
committer | soumith <soumith@fb.com> | 2015-08-03 13:13:29 +0300 |
commit | 7968df76ef9ec088b68ec5fe536f3804e8fb0ca3 (patch) | |
tree | 0cbea0b1df4c4912bfda20eed09fd8c768e7f2c4 /SpatialDivisiveNormalization.lua | |
parent | f6f22a3bf2ee4b920b7a38a61d0be911377f0d47 (diff) |
new module interface
Diffstat (limited to 'SpatialDivisiveNormalization.lua')
-rw-r--r-- | SpatialDivisiveNormalization.lua | 100 |
1 files changed, 100 insertions, 0 deletions
-- SpatialDivisiveNormalization.lua
-- (added as a new file by this commit: mode 100644, index 0000000..b9a2a16)
--
-- cudnn.SpatialDivisiveNormalization: an nn.Module that forwards its
-- updateOutput / updateGradInput work to cuDNN through cudnn.errcheck.
--
-- NOTE(review): the cuDNN symbol names used below
-- ('cudnn*DivisiveNormDescriptor', 'cudnnDivisiveNormCrossChannel*',
-- 'CUDNN_DivisiveNorm_CROSS_CHANNEL_DIM1') look like they were derived
-- mechanically from an LRN module; confirm that they exist in the ffi
-- bindings before relying on this module.

local DivisiveNorm, parent = torch.class('cudnn.SpatialDivisiveNormalization', 'nn.Module')
local ffi = require 'ffi'
local errcheck = cudnn.errcheck

--- Construct the module.
-- @param size  normalization window size, defaults to 5 (must be in [1, 16])
-- @param alpha defaults to 1e-4
-- @param beta  defaults to 0.75 (must be >= 0.01)
-- @param K     defaults to 2.0 (must be >= 1e-5)
function DivisiveNorm:__init(size, alpha, beta, K)
   parent.__init(self)
   self.size = size or 5
   self.alpha = alpha or 1e-4
   self.beta = beta or 0.75
   self.K = K or 2.0
   assert(self.size >= 1 and self.size <= 16, "size has to be between 1 and 16")
   assert(self.K >= 1e-5, "K has to be greater than 1e-5")
   assert(self.beta >= 0.01, "Beta has to be > 0.01")
end

--- (Re)create the normalization descriptor from size/alpha/beta/K.
-- The descriptor is stored in self.DivisiveNormDesc and destroyed by a
-- finalizer when the cdata is garbage collected.
function DivisiveNorm:resetDescriptors()
   local function destroyDesc(d)
      errcheck('cudnnDestroyDivisiveNormDescriptor', d[0])
   end

   -- create DivisiveNorm descriptor
   self.DivisiveNormDesc = ffi.new('struct cudnnDivisiveNormDescriptor_t*[1]')
   errcheck('cudnnCreateDivisiveNormDescriptor', self.DivisiveNormDesc)
   -- Attach the finalizer right after creation so that the descriptor is
   -- released even if the Set call below raises. (The original passed an
   -- undefined name here, which left the descriptor without a finalizer.)
   ffi.gc(self.DivisiveNormDesc, destroyDesc)
   errcheck('cudnnSetDivisiveNormDescriptor', self.DivisiveNormDesc[0], self.size,
            self.alpha, self.beta, self.K)
end

--- Make sure self.iDesc, self.output and self.gradInput match `input`.
-- A 3D input is viewed as a batch of one; in that case output and
-- gradInput are viewed back to 3D after being resized.
-- The input must be contiguous.
function DivisiveNorm:createIODescriptors(input)
   local batch = true
   if input:dim() == 3 then
      input = input:view(1, input:size(1), input:size(2), input:size(3))
      batch = false
   end
   assert(input:dim() == 4 and input:isContiguous())

   -- Rebuild only when there is no descriptor yet or the 4D size changed.
   local sizeChanged = not self.iDesc
      or input:size(1) ~= self.iSize[1] or input:size(2) ~= self.iSize[2]
      or input:size(3) ~= self.iSize[3] or input:size(4) ~= self.iSize[4]
   if not sizeChanged then
      return
   end

   self.iSize = input:size()
   self.gradInput:resizeAs(input)
   self.output:resizeAs(input)

   -- create input/output descriptor (shared by input, output and gradients)
   self.iDesc = cudnn.toDescriptor(input)

   if not batch then
      -- drop the singleton batch dimension added above
      self.gradInput = self.gradInput:view(self.gradInput:size(2),
                                           self.gradInput:size(3),
                                           self.gradInput:size(4))
      self.output = self.output:view(self.output:size(2),
                                     self.output:size(3),
                                     self.output:size(4))
   end
end

-- Scaling factors passed to the cuDNN calls below.
local one = torch.FloatTensor({1})
local zero = torch.FloatTensor({0})

--- Forward pass; writes into and returns self.output.
function DivisiveNorm:updateOutput(input)
   -- Descriptors are created lazily (and after write() has cleared them).
   if not self.DivisiveNormDesc then self:resetDescriptors() end
   self:createIODescriptors(input)
   errcheck('cudnnDivisiveNormCrossChannelForward', cudnn.getHandle(),
            self.DivisiveNormDesc[0],
            'CUDNN_DivisiveNorm_CROSS_CHANNEL_DIM1',
            one:data(),
            self.iDesc[0], input:data(),
            zero:data(),
            self.iDesc[0], self.output:data())
   return self.output
end

--- Backward pass; writes into and returns self.gradInput.
-- A non-contiguous gradOutput is first copied into a contiguous buffer
-- (self._gradOutput).
function DivisiveNorm:updateGradInput(input, gradOutput)
   assert(gradOutput:dim() == 3 or gradOutput:dim() == 4)
   if not gradOutput:isContiguous() then
      self._gradOutput = self._gradOutput or gradOutput.new()
      self._gradOutput:resizeAs(gradOutput):copy(gradOutput)
      gradOutput = self._gradOutput
   end
   if not self.DivisiveNormDesc then self:resetDescriptors() end
   self:createIODescriptors(input)
   errcheck('cudnnDivisiveNormCrossChannelBackward',
            cudnn.getHandle(), self.DivisiveNormDesc[0],
            'CUDNN_DivisiveNorm_CROSS_CHANNEL_DIM1',
            one:data(),
            self.iDesc[0], self.output:data(),
            self.iDesc[0], gradOutput:data(),
            self.iDesc[0], input:data(),
            zero:data(),
            self.iDesc[0], self.gradInput:data())
   return self.gradInput
end

--- Serialize the module to `f`.
-- The descriptor fields are cleared on self before the remaining fields
-- are copied into a plain table and written; updateOutput /
-- updateGradInput recreate them on the next call.
function DivisiveNorm:write(f)
   self.DivisiveNormDesc = nil
   self.iDesc = nil
   local var = {}
   for k, v in pairs(self) do
      var[k] = v
   end
   f:writeObject(var)
end