author     Soumith Chintala <soumith@gmail.com>    2015-01-07 09:18:11 +0300
committer  Soumith Chintala <soumith@gmail.com>    2015-01-07 09:18:11 +0300
commit     75fd6934fa876dc3ee191fbf29e9ad689af037f4 (patch)
tree       fae2bb8098f90de4760098a11bd2d987e84733be
parent     8c896314e9aa8540132f32d8e2577a57c35f39cd (diff)
parent     4e0a96d801060121521ccc46f7294aeb3b247965 (diff)
Merge branch 'master' of github.com:torch/nn
-rw-r--r--  CMul.lua                        |  43
-rw-r--r--  ClassNLLCriterion.lua           |  22
-rw-r--r--  Concat.lua                      |  66
-rw-r--r--  Container.lua                   |  80
-rw-r--r--  Jacobian.lua                    |  52
-rw-r--r--  Linear.lua                      |  16
-rw-r--r--  Mul.lua                         |  10
-rw-r--r--  Parallel.lua                    |  41
-rw-r--r--  Sequential.lua                  |  72
-rw-r--r--  SparseLinear.lua                |  37
-rw-r--r--  VolumetricConvolution.lua       |   4
-rw-r--r--  WeightedEuclidean.lua           | 151
-rw-r--r--  doc/simple.md                   |  26
-rw-r--r--  generic/SparseLinear.c          | 249
-rw-r--r--  generic/VolumetricConvolution.c | 160
-rw-r--r--  init.lua                        |   1
-rw-r--r--  test.lua                        |  82
17 files changed, 748 insertions, 364 deletions
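The headline change in this merge is a new nn.Container base class that factors the child-module bookkeeping (add, get, size, zeroGradParameters, updateParameters, training, evaluate, share, reset, parameters) out of Concat, Parallel, and Sequential, which previously each carried their own copies. A minimal sketch of what the consolidation buys a custom container (the MyContainer name is hypothetical, for illustration only; torch7 class semantics assumed):

```lua
require 'nn'

-- Hypothetical container that inherits the shared bookkeeping from
-- nn.Container instead of re-implementing it, as Concat, Parallel and
-- Sequential now do after this merge.
local MyContainer, parent = torch.class('nn.MyContainer', 'nn.Container')

function MyContainer:__init()
   parent.__init(self) -- initializes self.modules = {}
end

function MyContainer:updateOutput(input)
   -- trivial example behavior: forward through the first child only
   self.output = self.modules[1]:updateOutput(input)
   return self.output
end

-- add/get/size (and parameter management) come for free from nn.Container:
local c = nn.MyContainer()
c:add(nn.Linear(10, 10))
print(c:size()) -- 1
print(c:get(1)) -- nn.Linear(10 -> 10)
```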
diff --git a/CMul.lua b/CMul.lua
--- a/CMul.lua
+++ b/CMul.lua
@@ -13,25 +13,56 @@ function CMul:__init(inputSize)
    self:reset()
 end

-function CMul:reset()
-   self.weight:fill(1)
+function CMul:reset(stdv)
+   if stdv then
+      stdv = stdv * math.sqrt(3)
+   else
+      stdv = 1./math.sqrt(self.weight:size(1))
+   end
+   self.weight:uniform(-stdv,stdv)
 end

 function CMul:updateOutput(input)
-   self.output:copy(input);
-   self.output:cmul(self.weight);
+   self.output:resizeAs(input):copy(input)
+   if input:nElement() == self.weight:nElement() then
+      self.output:view(-1):cmul(self.weight:view(-1));
+   else
+      if input:isSameSizeAs(self.weight) then
+         self.output:cmul(self.weight)
+      else
+         local batchSize = input:size(1)
+         self.output:view(batchSize, -1):cmul(self.weight:view(1,-1):expandAs(input:view(batchSize, -1)))
+      end
+   end
    return self.output
 end

 function CMul:updateGradInput(input, gradOutput)
    if self.gradInput then
+      local nElement = self.gradInput:nElement()
       self.gradInput:resizeAs(input)
       self.gradInput:zero()
-      self.gradInput:addcmul(1, self.weight, gradOutput)
+      if self.weight:nElement() == gradOutput:nElement() then
+         self.gradInput:addcmul(1, self.weight, gradOutput)
+      else
+         local gradOutput = gradOutput:view(input:size(1), -1)
+         local gradInput = self.gradInput:view(input:size(1), -1)
+         gradInput:addcmul(1, self.weight:view(1,-1):expandAs(gradOutput), gradOutput)
+      end
       return self.gradInput
    end
 end

 function CMul:accGradParameters(input, gradOutput, scale)
-   self.gradWeight:addcmul(scale or 1, input, gradOutput)
+   if self.weight:nElement() == gradOutput:nElement() then
+      self.gradWeight:addcmul(scale or 1, input, gradOutput)
+   else
+      local batchSize = input:size(1)
+      local input = input:view(batchSize, -1)
+      local gradOutput = gradOutput:view(batchSize, -1)
+      local gradWeight = self.gradWeight:view(1, -1)
+      for i=1,batchSize do
+         gradWeight:addcmul(scale or 1, input[i], gradOutput[i])
+      end
+   end
 end
diff --git a/ClassNLLCriterion.lua b/ClassNLLCriterion.lua
index 926e707..997ecef 100644
--- a/ClassNLLCriterion.lua
+++ b/ClassNLLCriterion.lua
@@ -12,9 +12,15 @@ end
 function ClassNLLCriterion:updateOutput(input, target)
    if input:type() == 'torch.CudaTensor' and not self.weights then
-      input.nn.ClassNLLCriterion_updateOutput(self, input, target)
-      self.output = self.outputTensor[1]
-      return self.output
+      if input:dim() == 1 then
+         self._target = self._target or input.new(1)
+         self._target[1] = target
+         input.nn.ClassNLLCriterion_updateOutput(self, input, self._target)
+      else
+         input.nn.ClassNLLCriterion_updateOutput(self, input, target)
+      end
+      self.output = self.outputTensor[1]
+      return self.output
    end

    if input:dim() == 1 then
@@ -46,8 +52,14 @@ function ClassNLLCriterion:updateGradInput(input, target)
    self.gradInput:zero()

    if input:type() == 'torch.CudaTensor' and not self.weights then
-      input.nn.ClassNLLCriterion_updateGradInput(self, input, target)
-      return self.gradInput
+      if input:dim() == 1 then
+         self._target = self._target or input.new(1)
+         self._target[1] = target
+         input.nn.ClassNLLCriterion_updateGradInput(self, input, self._target)
+      else
+         input.nn.ClassNLLCriterion_updateGradInput(self, input, target)
+      end
+      return self.gradInput
    end

    if input:dim() == 1 then
diff --git a/Concat.lua b/Concat.lua
--- a/Concat.lua
+++ b/Concat.lua
@@ -1,21 +1,11 @@
-local Concat, parent = torch.class('nn.Concat', 'nn.Module')
+local Concat, parent = torch.class('nn.Concat', 'nn.Container')

 function Concat:__init(dimension)
-   parent.__init(self)
-   self.modules = {}
+   parent.__init(self, dimension)
    self.size = torch.LongStorage()
    self.dimension = dimension
 end

-function Concat:add(module)
-   table.insert(self.modules, module)
-   return self
-end
-
-function Concat:get(index)
-   return self.modules[index]
-end
-
 function Concat:updateOutput(input)
    local outs = {}
    for i=1,#self.modules do
@@ -83,58 +73,6 @@ function Concat:accUpdateGradParameters(input, gradOutput, lr)
    end
 end

-function Concat:zeroGradParameters()
-   for _,module in ipairs(self.modules) do
-      module:zeroGradParameters()
-   end
-end
-
-function Concat:updateParameters(learningRate)
-   for _,module in ipairs(self.modules) do
-      module:updateParameters(learningRate)
-   end
-end
-
-function Concat:training()
-   for i=1,#self.modules do
-      self.modules[i]:training()
-   end
-end
-
-function Concat:evaluate()
-   for i=1,#self.modules do
-      self.modules[i]:evaluate()
-   end
-end
-
-function Concat:share(mlp,...)
-   for i=1,#self.modules do
-      self.modules[i]:share(mlp.modules[i],...);
-   end
-end
-
-function Concat:parameters()
-   local function tinsert(to, from)
-      if type(from) == 'table' then
-         for i=1,#from do
-            tinsert(to,from[i])
-         end
-      else
-         table.insert(to,from)
-      end
-   end
-   local w = {}
-   local gw = {}
-   for i=1,#self.modules do
-      local mw,mgw = self.modules[i]:parameters()
-      if mw then
-         tinsert(w,mw)
-         tinsert(gw,mgw)
-      end
-   end
-   return w,gw
-end
-
 function Concat:__tostring__()
    local tab = '  '
    local line = '\n'
diff --git a/Container.lua b/Container.lua
new file mode 100644
index 0000000..125ab98
--- /dev/null
+++ b/Container.lua
@@ -0,0 +1,80 @@
+-- This is code common to container modules, which are collections of
+-- smaller constituent modules like Parallel, Sequential, etc.
+local Container, parent =
+   torch.class('nn.Container', 'nn.Module')
+
+function Container:__init(...)
+   parent.__init(self, ...)
+   self.modules = {}
+end
+
+function Container:add(module)
+   table.insert(self.modules, module)
+   return self
+end
+
+function Container:get(index)
+   return self.modules[index]
+end
+
+function Container:size()
+   return #self.modules
+end
+
+function Container:zeroGradParameters()
+   for i=1,#self.modules do
+      self.modules[i]:zeroGradParameters()
+   end
+end
+
+function Container:updateParameters(learningRate)
+   for _,module in ipairs(self.modules) do
+      module:updateParameters(learningRate)
+   end
+end
+
+function Container:training()
+   for i=1,#self.modules do
+      self.modules[i]:training()
+   end
+end
+
+function Container:evaluate()
+   for i=1,#self.modules do
+      self.modules[i]:evaluate()
+   end
+end
+
+function Container:share(mlp, ...)
+   for i=1,#self.modules do
+      self.modules[i]:share(mlp.modules[i], ...);
+   end
+end
+
+function Container:reset(stdv)
+   for i=1,#self.modules do
+      self.modules[i]:reset(stdv)
+   end
+end
+
+function Container:parameters()
+   local function tinsert(to, from)
+      if type(from) == 'table' then
+         for i=1,#from do
+            tinsert(to,from[i])
+         end
+      else
+         table.insert(to,from)
+      end
+   end
+   local w = {}
+   local gw = {}
+   for i=1,#self.modules do
+      local mw,mgw = self.modules[i]:parameters()
+      if mw then
+         tinsert(w,mw)
+         tinsert(gw,mgw)
+      end
+   end
+   return w,gw
+end
diff --git a/Jacobian.lua b/Jacobian.lua
index c3797bd..25e8cf0 100644
--- a/Jacobian.lua
+++ b/Jacobian.lua
@@ -64,53 +64,55 @@ function nn.Jacobian.backwardUpdate(module, input, param)
    return jacobian
 end

-function nn.Jacobian.forward(module, input, param)
+function nn.Jacobian.forward(module, input, param, perturbation)
    param = param or input
    -- perturbation amount
-   local small = 1e-6
+   perturbation = perturbation or 1e-6
    -- 1D view of input
    --local tst = param:storage()
    local sin = param.new(param):resize(param:nElement())--param.new(tst,1,tst:size())
    -- jacobian matrix to calculate
    local jacobian = torch.Tensor():resize(param:nElement(),module:forward(input):nElement())
-
+
    local outa = torch.Tensor(jacobian:size(2))
    local outb = torch.Tensor(jacobian:size(2))
-
-   for i=1,sin:nElement() do
-      sin[i] = sin[i] - small
+
+   for i=1,sin:nElement() do
+      local orig = sin[i]
+      sin[i] = orig - perturbation
       outa:copy(module:forward(input))
-      sin[i] = sin[i] + 2*small
+      sin[i] = orig + perturbation
       outb:copy(module:forward(input))
-      sin[i] = sin[i] - small
+      sin[i] = orig

-      outb:add(-1,outa):div(2*small)
+      outb:add(-1,outa):div(2*perturbation)
       jacobian:select(1,i):copy(outb)
    end

    return jacobian
 end

-function nn.Jacobian.forwardUpdate(module, input, param)
+function nn.Jacobian.forwardUpdate(module, input, param, perturbation)
    -- perturbation amount
-   local small = 1e-6
+   perturbation = perturbation or 1e-6
    -- 1D view of input
    --local tst = param:storage()
    local sin = param.new(param):resize(param:nElement())--param.new(tst,1,tst:size())
    -- jacobian matrix to calculate
    local jacobian = torch.Tensor():resize(param:nElement(),module:forward(input):nElement())
-
+
    local outa = torch.Tensor(jacobian:size(2))
    local outb = torch.Tensor(jacobian:size(2))
-
-   for i=1,sin:nElement() do
-      sin[i] = sin[i] - small
+
+   for i=1,sin:nElement() do
+      local orig = sin[i]
+      sin[i] = orig - perturbation
       outa:copy(module:forward(input))
-      sin[i] = sin[i] + 2*small
+      sin[i] = orig + perturbation
       outb:copy(module:forward(input))
-      sin[i] = sin[i] - small
+      sin[i] = orig

-      outb:add(-1,outa):div(2*small)
+      outb:add(-1,outa):div(2*perturbation)
       jacobian:select(1,i):copy(outb)
       jacobian:select(1,i):mul(-1)
       jacobian:select(1,i):add(sin[i])
@@ -118,37 +120,37 @@ function nn.Jacobian.forwardUpdate(module, input, param)
    return jacobian
 end

-function nn.Jacobian.testJacobian (module, input, minval, maxval)
+function nn.Jacobian.testJacobian(module, input, minval, maxval, perturbation)
    minval = minval or -2
    maxval = maxval or 2
    local inrange = maxval - minval
    input:copy(torch.rand(input:nElement()):mul(inrange):add(minval))
-   local jac_fprop = nn.Jacobian.forward(module,input)
-   local jac_bprop = nn.Jacobian.backward(module,input)
+   local jac_fprop = nn.Jacobian.forward(module, input, input, perturbation)
+   local jac_bprop = nn.Jacobian.backward(module, input)
    local error = jac_fprop-jac_bprop
    return error:abs():max()
 end

-function nn.Jacobian.testJacobianParameters (module, input, param, dparam, minval, maxval)
+function nn.Jacobian.testJacobianParameters(module, input, param, dparam, minval, maxval, perturbation)
    minval = minval or -2
    maxval = maxval or 2
    local inrange = maxval - minval
    input:copy(torch.rand(input:nElement()):mul(inrange):add(minval))
    param:copy(torch.rand(param:nElement()):mul(inrange):add(minval))
    local jac_bprop = nn.Jacobian.backward(module, input, param, dparam)
-   local jac_fprop = nn.Jacobian.forward(module, input, param)
+   local jac_fprop = nn.Jacobian.forward(module, input, param, perturbation)
    local error = jac_fprop - jac_bprop
    return error:abs():max()
 end

-function nn.Jacobian.testJacobianUpdateParameters (module, input, param, minval, maxval)
+function nn.Jacobian.testJacobianUpdateParameters(module, input, param, minval, maxval, perturbation)
    minval = minval or -2
    maxval = maxval or 2
    local inrange = maxval - minval
    input:copy(torch.rand(input:nElement()):mul(inrange):add(minval))
    param:copy(torch.rand(param:nElement()):mul(inrange):add(minval))
    local params_bprop = nn.Jacobian.backwardUpdate(module, input, param)
-   local params_fprop = nn.Jacobian.forwardUpdate(module, input, param)
+   local params_fprop = nn.Jacobian.forwardUpdate(module, input, param, perturbation)

    local error = params_fprop - params_bprop
    return error:abs():max()
diff --git a/Linear.lua b/Linear.lua
--- a/Linear.lua
+++ b/Linear.lua
@@ -38,14 +38,16 @@ function Linear:updateOutput(input)
    elseif input:dim() == 2 then
       local nframe = input:size(1)
       local nunit = self.bias:size(1)
-      self.output:resize(nframe, nunit)
+      if not self.addBuffer or self.addBuffer:size(1) ~= nframe then
+         self.addBuffer = input.new(nframe):fill(1)
+      end
       if nunit == 1 then
          -- Special case to fix output size of 1 bug:
-         self.output:zero():add(self.bias[1])
+         self.output:copy(self.bias:view(1,nunit):expand(#self.output))
          self.output:select(2,1):addmv(1, input, self.weight:select(1,1))
       else
-         self.output:zero():addr(1, input.new(nframe):fill(1), self.bias)
+         self.output:zero():addr(1, self.addBuffer, self.bias)
         self.output:addmm(1, input, self.weight:t())
      end
   else
@@ -78,18 +80,18 @@ function Linear:accGradParameters(input, gradOutput, scale)

    if input:dim() == 1 then
       self.gradWeight:addr(scale, gradOutput, input)
-      self.gradBias:add(scale, gradOutput)
+      self.gradBias:add(scale, gradOutput)
    elseif input:dim() == 2 then
       local nframe = input:size(1)
       local nunit = self.bias:size(1)
-
+
       if nunit == 1 then
          -- Special case to fix output size of 1 bug:
         self.gradWeight:select(1,1):addmv(scale, input:t(), gradOutput:select(2,1))
-         self.gradBias:addmv(scale, gradOutput:t(), input.new(nframe):fill(1))
+         self.gradBias:addmv(scale, gradOutput:t(), self.addBuffer)
      else
         self.gradWeight:addmm(scale, gradOutput:t(), input)
-         self.gradBias:addmv(scale, gradOutput:t(), input.new(nframe):fill(1))
+         self.gradBias:addmv(scale, gradOutput:t(), self.addBuffer)
      end
   end
diff --git a/Mul.lua b/Mul.lua
--- a/Mul.lua
+++ b/Mul.lua
@@ -1,15 +1,11 @@
 local Mul, parent = torch.class('nn.Mul', 'nn.Module')

-function Mul:__init(inputSize)
+function Mul:__init()
    parent.__init(self)

    self.weight = torch.Tensor(1)
    self.gradWeight = torch.Tensor(1)
-
-   -- state
-   self.gradInput:resize(inputSize)
-   self.output:resize(inputSize)
-
    self:reset()
 end

@@ -25,13 +21,13 @@ function Mul:reset(stdv)
 end

 function Mul:updateOutput(input)
-   self.output:copy(input);
+   self.output:resizeAs(input):copy(input);
    self.output:mul(self.weight[1]);
    return self.output
 end

 function Mul:updateGradInput(input, gradOutput)
-   self.gradInput:zero()
+   self.gradInput:resizeAs(input):zero()
    self.gradInput:add(self.weight[1], gradOutput)
    return self.gradInput
 end
diff --git a/Parallel.lua b/Parallel.lua
index 3057ba2..ef42723 100644
--- a/Parallel.lua
+++ b/Parallel.lua
@@ -1,4 +1,4 @@
-local Parallel, parent = torch.class('nn.Parallel', 'nn.Module')
+local Parallel, parent = torch.class('nn.Parallel', 'nn.Container')

 function Parallel:__init(inputDimension,outputDimension)
    parent.__init(self)
@@ -8,15 +8,6 @@ function Parallel:__init(inputDimension,outputDimension)
    self.outputDimension = outputDimension
 end

-function Parallel:add(module)
-   table.insert(self.modules, module)
-   return self
-end
-
-function Parallel:get(index)
-   return self.modules[index]
-end
-
 function Parallel:updateOutput(input)
    local modules=input:size(self.inputDimension)
@@ -99,36 +90,6 @@ function Parallel:accUpdateGradParameters(input, gradOutput, lr)
    end
 end

-function Parallel:zeroGradParameters()
-   for _,module in ipairs(self.modules) do
-      module:zeroGradParameters()
-   end
-end
-
-function Parallel:updateParameters(learningRate)
-   for _,module in ipairs(self.modules) do
-      module:updateParameters(learningRate)
-   end
-end
-
-function Parallel:training()
-   for i=1,#self.modules do
-      self.modules[i]:training()
-   end
-end
-
-function Parallel:evaluate()
-   for i=1,#self.modules do
-      self.modules[i]:evaluate()
-   end
-end
-
-function Parallel:share(mlp,...)
-   for i=1,#self.modules do
-      self.modules[i]:share(mlp.modules[i],...);
-   end
-end
-
 function Parallel:parameters()
    local function tinsert(to, from)
       if type(from) == 'table' then
diff --git a/Sequential.lua b/Sequential.lua
index 97554b3..3288e6d 100644
--- a/Sequential.lua
+++ b/Sequential.lua
@@ -1,9 +1,4 @@
-local Sequential, parent = torch.class('nn.Sequential', 'nn.Module')
-
-function Sequential:__init()
-   parent.__init(self)
-   self.modules = {}
-end
+local Sequential, _ = torch.class('nn.Sequential', 'nn.Container')

 function Sequential:add(module)
    if #self.modules == 0 then
@@ -24,14 +19,6 @@ function Sequential:insert(module, index)
    self.gradInput = self.modules[1].gradInput
 end

-function Sequential:size()
-   return #self.modules
-end
-
-function Sequential:get(index)
-   return self.modules[index]
-end
-
 function Sequential:updateOutput(input)
    local currentOutput = input
    for i=1,#self.modules do
@@ -82,63 +69,6 @@ function Sequential:accUpdateGradParameters(input, gradOutput, lr)
    currentModule:accUpdateGradParameters(input, currentGradOutput, lr)
 end

-function Sequential:zeroGradParameters()
-   for i=1,#self.modules do
-      self.modules[i]:zeroGradParameters()
-   end
-end
-
-function Sequential:updateParameters(learningRate)
-   for i=1,#self.modules do
-      self.modules[i]:updateParameters(learningRate)
-   end
-end
-
-function Sequential:training()
-   for i=1,#self.modules do
-      self.modules[i]:training()
-   end
-end
-
-function Sequential:evaluate()
-   for i=1,#self.modules do
-      self.modules[i]:evaluate()
-   end
-end
-
-function Sequential:share(mlp,...)
-   for i=1,#self.modules do
-      self.modules[i]:share(mlp.modules[i],...);
-   end
-end
-
-function Sequential:reset(stdv)
-   for i=1,#self.modules do
-      self.modules[i]:reset(stdv)
-   end
-end
-
-function Sequential:parameters()
-   local function tinsert(to, from)
-      if type(from) == 'table' then
-         for i=1,#from do
-            tinsert(to,from[i])
-         end
-      else
-         table.insert(to,from)
-      end
-   end
-   local w = {}
-   local gw = {}
-   for i=1,#self.modules do
-      local mw,mgw = self.modules[i]:parameters()
-      if mw then
-         tinsert(w,mw)
-         tinsert(gw,mgw)
-      end
-   end
-   return w,gw
-end

 function Sequential:__tostring__()
    local tab = '  '
diff --git a/SparseLinear.lua b/SparseLinear.lua
index 735d0ed..ca15be6 100644
--- a/SparseLinear.lua
+++ b/SparseLinear.lua
@@ -4,11 +4,16 @@ function SparseLinear:__init(inputSize, outputSize)
    parent.__init(self)

    self.weightDecay = 0
-   self.weight = torch.Tensor(outputSize, inputSize)
-   self.bias = torch.Tensor(outputSize)
-   self.gradWeight = torch.Tensor(outputSize, inputSize)
-   self.gradBias = torch.Tensor(outputSize)
-   self.lastInput = torch.Tensor()
+   self.weight = torch.Tensor(outputSize, inputSize):zero()
+   self.bias = torch.Tensor(outputSize):zero()
+   self.gradWeight = torch.Tensor(outputSize, inputSize):zero()
+   self.gradBias = torch.Tensor(outputSize):zero()
+   self.lastInput = nil
+
+   if torch.getnumthreads() > 1 and outputSize >= 128 then
+      self.shardBuffer = torch.Tensor(outputSize, torch.getnumthreads())
+   end
+
    -- state
    self.gradInput:resize(inputSize)
    self.output:resize(outputSize)
@@ -20,7 +25,7 @@ function SparseLinear:reset(stdv)
    if stdv then
       stdv = stdv * math.sqrt(3)
    else
-      stdv = 1./math.sqrt(self.weight:size(1))
+      stdv = 1./math.sqrt(self.weight:size(2))
    end
    if nn.oldSeed then
       for i=1,self.weight:size(1) do
@@ -40,22 +45,18 @@ function SparseLinear:updateOutput(input)
 end

 function SparseLinear:accGradParameters(input, gradOutput, scale)
+   if not self.lastInput then
+      self.lastInput = input:clone()
+   else
+      self.lastInput:resizeAs(input):copy(input)
+   end
+
    return input.nn.SparseLinear_accGradParameters(self, input, gradOutput, scale)
 end

 function SparseLinear:updateGradInput(input, gradOutput)
    if self.gradInput then
-      self.gradInput:resize(input:size())
-      self.gradInput:copy(input)
-      local numNonzero = self.gradInput:size(1)
-      for e=1,numNonzero do
-         local g = 0
-         local i = self.gradInput[{e,1}]
-         for j=1,self.output:size(1) do
-            g = g + self.weight[{j,i}] * gradOutput[j]
-         end
-         self.gradInput[{e,2}] = g
-      end
+      input.nn.SparseLinear_updateGradInput(self, input, gradOutput)
       return self.gradInput
    end
-end
\ No newline at end of file
+end
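For context on the `nnz x 2` layout that the reworked SparseLinear code (and the argument checks added to generic/SparseLinear.c below) expects: each row of the input is an (index, value) pair for one non-zero feature. A minimal usage sketch, assuming standard torch7 semantics:

```lua
require 'nn'

local module = nn.SparseLinear(10000, 2) -- inputSize, outputSize

-- sparse input: an nnz x 2 tensor of (index, value) pairs;
-- here two non-zero features out of 10000
local input = torch.Tensor{{1, 0.1}, {800, 0.3}}

local output = module:forward(input) -- dense 1D tensor of size 2
print(output)
```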
diff --git a/VolumetricConvolution.lua b/VolumetricConvolution.lua
index 4dec9e3..aea94a5 100644
--- a/VolumetricConvolution.lua
+++ b/VolumetricConvolution.lua
@@ -20,7 +20,9 @@ function VolumetricConvolution:__init(nInputPlane, nOutputPlane, kT, kW, kH, dT,
    self.bias = torch.Tensor(nOutputPlane)
    self.gradWeight = torch.Tensor(nOutputPlane, nInputPlane, kT, kH, kW)
    self.gradBias = torch.Tensor(nOutputPlane)
-
+   -- temporary buffers for unfolding (CUDA)
+   self.finput = torch.Tensor()
+   self.fgradInput = torch.Tensor()
    self:reset()
 end
diff --git a/WeightedEuclidean.lua b/WeightedEuclidean.lua
index 3808db6..5a3af27 100644
--- a/WeightedEuclidean.lua
+++ b/WeightedEuclidean.lua
@@ -6,13 +6,10 @@ function WeightedEuclidean:__init(inputSize,outputSize)

    self.templates = torch.Tensor(inputSize,outputSize)
    self.gradTemplates = torch.Tensor(inputSize,outputSize)
+   -- each template (output dim) has its own diagonal covariance matrix
    self.diagCov = torch.Tensor(inputSize,outputSize)
    self.gradDiagCov = torch.Tensor(inputSize,outputSize)
-
-   self.gradInput:resize(inputSize)
-   self.output:resize(outputSize)
-   self.temp = torch.Tensor(inputSize)
-
+
    -- for compat with Torch's modules (it's bad we have to do that)
    do
       self.weight = self.templates
@@ -43,45 +40,137 @@ function WeightedEuclidean:reset(stdv)
 end

 function WeightedEuclidean:updateOutput(input)
-   self.output:zero()
-   for o = 1,self.templates:size(2) do
-      self.temp:copy(input):add(-1,self.templates:select(2,o))
-      self.temp:cmul(self.temp)
-      self.temp:cmul(self.diagCov:select(2,o)):cmul(self.diagCov:select(2,o))
-      self.output[o] = math.sqrt(self.temp:sum())
+   -- lazy-initialize
+   self._temp = self._temp or self.output.new()
+   self._ones = self._ones or self.output.new{1}
+   self._diagCov = self._diagCov or self.output.new()
+   self._repeat = self._repeat or self.output.new()
+   self._sum = self._sum or self.output.new()
+
+   self._temp:resizeAs(input)
+   if input:dim() == 1 then
+      self.output:resize(self.templates:size(2))
+      for outIdx = 1,self.templates:size(2) do
+         self._temp:copy(input):add(-1,self.templates:select(2,outIdx))
+         self._temp:cmul(self._temp)
+         local diagCov = self.diagCov:select(2,outIdx)
+         self._temp:cmul(diagCov):cmul(diagCov)
+         self.output[outIdx] = math.sqrt(self._temp:sum())
+      end
+   elseif input:dim() == 2 then
+      self.output:resize(input:size(1), self.templates:size(2))
+      if self._ones:size(1) ~= input:size(1) then
+         self._ones:resize(input:size(1)):fill(1)
+      end
+      for outIdx = 1,self.templates:size(2) do
+         self._temp:copy(input)
+         self._temp:addr(-1, self._ones, self.templates:select(2,outIdx))
+         self._temp:cmul(self._temp)
+         local diagCov = self.diagCov:select(2,outIdx)
+         self._diagCov:resizeAs(diagCov):copy(diagCov)
+         self._diagCov:pow(2)
+         self._diagCov:resize(1,self._diagCov:size(1))
+         self._repeat:repeatTensor(self._diagCov, input:size(1), 1)
+         self._temp:cmul(self._temp, self._repeat)
+         self._sum:sum(self._temp, 2):sqrt()
+         self.output:select(2,outIdx):copy(self._sum)
+      end
+   else
+      error"1D or 2D input expected"
    end
    return self.output
 end

 function WeightedEuclidean:updateGradInput(input, gradOutput)
-   self:forward(input)
-   self.gradInput:zero()
-   for o = 1,self.templates:size(2) do
-      if self.output[o] ~= 0 then
-         self.temp:copy(input):add(-1,self.templates:select(2,o))
-         self.temp:cmul(self.diagCov:select(2,o)):cmul(self.diagCov:select(2,o))
-         self.temp:mul(gradOutput[o]/self.output[o])
-         self.gradInput:add(self.temp)
+   self._gradTemp = self._gradTemp or self.output.new()
+   self.gradInput:resizeAs(input):zero()
+   self._temp:resizeAs(input)
+   self._gradTemp:cdiv(gradOutput, self.output)
+   if input:dim() == 1 then
+      for outIdx = 1,self.templates:size(2) do
+         self._temp:copy(input):add(-1,self.templates:select(2,outIdx))
+         local diagCov = self.diagCov:select(2,outIdx)
+         self._temp:cmul(diagCov):cmul(diagCov)
+
+         self._temp:mul(self._gradTemp[outIdx])
+         self.gradInput:add(self._temp)
+      end
+   elseif input:dim() == 2 then
+      if self._ones:size(1) ~= input:size(1) then
+         self._ones:resize(input:size(1)):fill(1)
       end
+      for outIdx = 1,self.templates:size(2) do
+         self._temp:copy(input)
+         self._temp:addr(-1, self._ones, self.templates:select(2,outIdx))
+         local diagCov = self.diagCov:select(2,outIdx)
+         self._diagCov:resizeAs(diagCov):copy(diagCov)
+         self._diagCov:pow(2)
+         self._diagCov:resize(1,self._diagCov:size(1))
+         self._repeat:repeatTensor(self._diagCov, input:size(1), 1)
+         self._temp:cmul(self._temp, self._repeat)
+
+         local gradTemp = self._gradTemp:select(2, outIdx)
+         gradTemp = gradTemp:reshape(1,gradTemp:size(1))
+         self._repeat:repeatTensor(gradTemp, input:size(2), 1)
+         self.gradInput:addcmul(1, self._temp, self._repeat)
+      end
+   else
+      error"1D or 2D input expected"
    end
    return self.gradInput
 end

 function WeightedEuclidean:accGradParameters(input, gradOutput, scale)
-   self:forward(input)
    scale = scale or 1
-   for o = 1,self.templates:size(2) do
-      if self.output[o] ~= 0 then
-         self.temp:copy(self.templates:select(2,o)):add(-1,input)
-         self.temp:cmul(self.diagCov:select(2,o)):cmul(self.diagCov:select(2,o))
-         self.temp:mul(gradOutput[o]/self.output[o])
-         self.gradTemplates:select(2,o):add(scale, self.temp)
+   self._temp:resizeAs(input)
+   self._gradTemp:cdiv(gradOutput, self.output)
+   if input:dim() == 1 then
+      for outIdx = 1,self.templates:size(2) do
+         self._temp:copy(self.templates:select(2,outIdx)):add(-1,input)
+         local diagCov = self.diagCov:select(2,outIdx)
+         self._temp:cmul(diagCov):cmul(diagCov)
+
+         self._temp:mul(self._gradTemp[outIdx])
+         self.gradTemplates:select(2,outIdx):add(scale, self._temp)

-         self.temp:copy(self.templates:select(2,o)):add(-1,input)
-         self.temp:cmul(self.temp)
-         self.temp:cmul(self.diagCov:select(2,o))
-         self.temp:mul(gradOutput[o]/self.output[o])
-         self.gradDiagCov:select(2,o):add(scale, self.temp)
+         self._temp:copy(self.templates:select(2,outIdx)):add(-1,input)
+         self._temp:pow(2)
+         self._temp:cmul(self.diagCov:select(2,outIdx))
+         self._temp:mul(self._gradTemp[outIdx])
+         self.gradDiagCov:select(2,outIdx):add(scale, self._temp)
       end
+   elseif input:dim() == 2 then
+      for outIdx = 1,self.templates:size(2) do
+         -- gradTemplates
+         self._temp:copy(input)
+         self._temp:addr(-1, self._ones, self.templates:select(2,outIdx))
+         local diagCov = self.diagCov:select(2,outIdx)
+         self._diagCov:resizeAs(diagCov):copy(diagCov)
+         self._diagCov:pow(2)
+         self._diagCov:resize(1,self._diagCov:size(1))
+         self._repeat:repeatTensor(self._diagCov, input:size(1), 1)
+         self._temp:cmul(self._temp, self._repeat)
+
+         local gradTemp = self._gradTemp:select(2, outIdx)
+         gradTemp = gradTemp:reshape(1,gradTemp:size(1))
+         self._repeat:repeatTensor(gradTemp, input:size(2), 1)
+         self._temp:cmul(self._repeat)
+         self._sum:sum(self._temp, 1)
+         self.gradTemplates:select(2,outIdx):add(scale, self._sum)
+
+         -- gradDiagCov
+         local template = self.templates:select(2,outIdx)
+         template = template:reshape(1, template:size(1))
+         self._temp:repeatTensor(template, input:size(1), 1)
+         self._temp:add(-1,input)
+         self._temp:pow(2)
+         self._temp:cmul(self._repeat)
+         diagCov = diagCov:reshape(1, diagCov:size(1))
+         self._repeat:repeatTensor(self._diagCov, input:size(1), 1)
+         self._temp:cmul(self._repeat)
+         self._sum:sum(self._temp, 1)
+         self.gradDiagCov:select(2,outIdx):add(scale, self._sum)
+      end
+   else
+      error"1D or 2D input expected"
    end
 end
diff --git a/doc/simple.md b/doc/simple.md
index aa1a94d..8cbb017 100644
--- a/doc/simple.md
+++ b/doc/simple.md
@@ -262,7 +262,7 @@ to produce the output _y_.
 <a name="nn.Mul"/>
 ## Mul ##

-`module` = `Mul(inputDimension)`
+`module` = `Mul()`

 Applies a _single_ scaling factor to the incoming data, i.e.
 _y= w x_, where _w_ is a scalar.
@@ -271,7 +271,7 @@ Example:
 ```lua
 y=torch.Tensor(5);
 mlp=nn.Sequential()
-mlp:add(nn.Mul(5))
+mlp:add(nn.Mul())

 function gradUpdate(mlp, x, y, criterion, learningRate)
    local pred = mlp:forward(x)
@@ -387,22 +387,30 @@ then an `nxq` matrix would be output.
 <a name="nn.Euclidean"/>
 ## Euclidean ##

-`module` = `Euclidean(inputDimension,outputDimension)`
+`module` = `Euclidean(inputSize,outputSize)`

-Outputs the Euclidean distance of the input to `outputDimension` centers,
-i.e. this layer has the weights `c_i`, `i` = `1`,..,`outputDimension`, where
-`c_i` are vectors of dimension `inputDimension`. Output dimension `j` is
-`|| c_j - x ||`, where `x` is the input.
+Outputs the Euclidean distance of the input to `outputSize` centers,
+i.e. this layer has the weights `w_j`, for `j` = `1`,..,`outputSize`, where
+`w_j` are vectors of dimension `inputSize`.
+
+The distance `y_j` between center `j` and input `x` is formulated as
+`y_j = || w_j - x ||`.

 <a name="nn.WeightedEuclidean"/>
 ## WeightedEuclidean ##

-`module` = `WeightedEuclidean(inputDimension,outputDimension)`
+`module` = `WeightedEuclidean(inputSize,outputSize)`

 This module is similar to [Euclidean](#nn.Euclidean), but
 additionally learns a separate diagonal covariance matrix across the
-features of the input space for each center.
+features of the input space _for each center_.
+
+In other words, for each of the `outputSize` centers `w_j`, there is
+a diagonal covariance matrices `c_j`, for `j` = `1`,..,`outputSize`,
+where `c_j` are stored as vectors of size `inputSize`.
+
+The distance `y_j` between center `j` and input `x` is formulated as
+`y_j = || c_j * (w_j - x) ||`.

 <a name="nn.Identity"/>
 ## Identity ##
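To make the documented WeightedEuclidean formula concrete: `y_j = || c_j * (w_j - x) ||` multiplies the residual element-wise by the per-center diagonal covariance vector before taking the norm. A small numeric sketch (the values are made up for illustration):

```lua
local torch = require 'torch'

local x   = torch.Tensor{1.0, 2.0, 0.0} -- input
local w_j = torch.Tensor{0.0, 2.0, 1.0} -- center j
local c_j = torch.Tensor{2.0, 1.0, 0.5} -- diagonal covariance for center j

-- element-wise product of c_j with the residual (w_j - x)
local d = torch.cmul(c_j, w_j - x)      -- {-2.0, 0.0, 0.5}
print(d:norm())                         -- sqrt(4 + 0 + 0.25), about 2.06
```

With `c_j` fixed at all ones this reduces to the plain Euclidean distance computed by the Euclidean module.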
diff --git a/generic/SparseLinear.c b/generic/SparseLinear.c
index f39791b..b3ccbf1 100644
--- a/generic/SparseLinear.c
+++ b/generic/SparseLinear.c
@@ -2,6 +2,18 @@
 #define TH_GENERIC_FILE "generic/SparseLinear.c"
 #else

+static int nn_(checkInput)(THTensor* t) {
+  return t->nDimension == 2 && t->size[1] == 2;
+}
+
+static int nn_(checkSize2D)(THTensor* t, long size0, long size1) {
+  return t->nDimension == 2 && t->size[0] == size0 && t->size[1] == size1;
+}
+
+static int nn_(checkSize1D)(THTensor* t, long size0) {
+  return t->nDimension == 1 && t->size[0] == size0;
+}
+
 static int nn_(SparseLinear_updateOutput)(lua_State *L)
 {
   long i;
@@ -9,27 +21,72 @@ static int nn_(SparseLinear_updateOutput)(lua_State *L)
   THTensor * weight = luaT_getfieldcheckudata(L, 1, "weight", torch_Tensor);
   THTensor * bias = luaT_getfieldcheckudata(L, 1, "bias", torch_Tensor);
   THTensor * output = luaT_getfieldcheckudata(L, 1, "output", torch_Tensor);
-  long dim = weight->size[1]; /* number of weights.. */
+
+  long outDim = weight->size[0];
+  long inDim = weight->size[1];
+
+  luaL_argcheck(L, nn_(checkInput)(input), 2, "input size must be nnz x 2");
+  luaL_argcheck(L, nn_(checkSize1D)(output, outDim), 1, "output size wrong");
+  luaL_argcheck(L, nn_(checkSize1D)(bias, outDim), 1, "bias size wrong");
+
+  lua_getfield(L, 1, "shardBuffer");
+  if (!lua_isnil(L, -1)) {
+    THTensor *buffer =
+        luaT_getfieldcheckudata(L, 1, "shardBuffer", torch_Tensor);
+    long num_shards = buffer->size[1];
+    luaL_argcheck(L,
+                  buffer->nDimension == 2 && buffer->size[0] == outDim &&
+                      num_shards > 0,
+                  1,
+                  "shardBuffer size wrong");
+
+    THTensor_(zero)(buffer);
+    #pragma omp parallel for private(i) schedule(static) num_threads(num_shards)
+    for (i = 0; i < input->size[0]; i++) {
+      int shardId = omp_get_thread_num();
+      long offset = (long)(THTensor_(get2d)(input, i, 0)) - 1;
+
+      if (offset >= 0 && offset < inDim) {
+        THBlas_(axpy)(outDim,
+                      THTensor_(get2d)(input, i, 1),
+                      THTensor_(data)(weight) + offset * weight->stride[1],
+                      weight->stride[0],
+                      THTensor_(data)(buffer) + shardId * buffer->stride[1],
+                      buffer->stride[0]);
+      } else {
+        luaL_error(L, "index out of bound. updateOutput: \
+%ld not between 1 and %ld", offset + 1, inDim);
+      }
+    }
+
+    THTensor_(sum)(output, buffer, 1);
+    THTensor_(cadd)(output, bias, 1.0, output);
+
+    lua_getfield(L, 1, "output");
+    return 1;
+  }

   THTensor_(copy)(output, bias);
   for(i = 0; i < input->size[0]; i++)
   {
     long offset = (long)(THTensor_(get2d)(input, i, 0)) - 1;
-    if(offset >= 0 && offset < dim) /* make sure indices are in bounds.. */
+    if(offset >= 0 && offset < inDim) /* make sure indices are in bounds.. */
     {
       real val = THTensor_(get2d)(input, i, 1);
-      THBlas_(axpy)(output->size[0],
-                    val,
+      THBlas_(axpy)(output->size[0],
+                    val,
                     THTensor_(data)(weight)+offset*weight->stride[1],
-                    weight->stride[0],
-                    THTensor_(data)(output),
+                    weight->stride[0],
+                    THTensor_(data)(output),
                     output->stride[0]);
     }
     else {
-      printf("\nupdateOutput: %ld not between 1 and %ld\n", offset+1, dim);
-      luaL_error(L, "index out of bound");
+      luaL_error(L, "index out of bound. updateOutput: \
+%ld not between 1 and %ld", offset + 1, inDim);
     }
   }
+
+  lua_getfield(L, 1, "output");
   return 1;
 }
@@ -42,39 +99,47 @@ static int nn_(SparseLinear_accGradParameters)(lua_State *L)
   THTensor * weight = luaT_getfieldcheckudata(L, 1, "weight", torch_Tensor);
   THTensor * gradBias = luaT_getfieldcheckudata(L, 1, "gradBias", torch_Tensor);
   THTensor * gradWeight = luaT_getfieldcheckudata(L, 1, "gradWeight", torch_Tensor);
-  THTensor * lastInput = luaT_getfieldcheckudata(L, 1, "lastInput", torch_Tensor);
   real weightDecay = luaT_getfieldchecknumber(L, 1, "weightDecay");
-  long dim = gradWeight->size[1]; /* number of weights.. */

-  for(i = 0; i < input->size[0]; i++)
+  long nnz = input->size[0];
+  long outDim = weight->size[0];
+  long inDim = weight->size[1];
+
+  luaL_argcheck(L, nn_(checkInput)(input), 2, "input size must be nnz x 2");
+  luaL_argcheck(
+      L, nn_(checkSize1D)(gradOutput, outDim), 3, "gradOutput size wrong");
+  luaL_argcheck(
+      L, nn_(checkSize2D)(gradWeight, outDim, inDim), 1, "gradWeight size wrong");
+  luaL_argcheck(
+      L, nn_(checkSize1D)(gradBias, outDim), 1, "gradBias size wrong");
+
+  #pragma omp parallel for private(i) schedule(static) if(outDim * nnz > 100000)
+  for(i = 0; i < nnz; i++)
   {
     long offset = (long)(THTensor_(get2d)(input, i, 0)) - 1;
-    if(offset >= 0 && offset < dim) /* make sure indices are in bounds.. */
+    if(offset >= 0 && offset < inDim) /* make sure indices are in bounds.. */
     {
       real val = scale*THTensor_(get2d)(input, i, 1);
-
-      THBlas_(axpy)(gradOutput->size[0],
-                    val,
-                    THTensor_(data)(gradOutput),
-                    gradOutput->stride[0],
-                    THTensor_(data)(gradWeight)+offset*gradWeight->stride[1],
+
+      THBlas_(axpy)(outDim,
+                    val,
+                    THTensor_(data)(gradOutput),
+                    gradOutput->stride[0],
+                    THTensor_(data)(gradWeight)+offset*gradWeight->stride[1],
                     gradWeight->stride[0]);
     }
     else {
-      printf("\naccGradParameters: %ld not between 1 and %ld\n", offset+1, dim);
-      luaL_error(L, "index out of bound");
+      luaL_error(L, "index out of bound. accGradParameters: \
+%ld not between 1 and %ld", offset + 1, inDim);
     }
   }
-
-  THTensor_(cadd)(gradBias, gradBias, scale, gradOutput);
-
+
+  THTensor_(cadd)(gradBias, gradBias, scale, gradOutput);
+
   if(weightDecay != 0)
     THTensor_(cadd)(gradWeight, gradWeight, weightDecay, weight);
-
-  THTensor_(resizeAs)(lastInput, input);
-  THTensor_(copy)(lastInput, input);
-
+
   return 0;
 }
@@ -85,37 +150,137 @@ int nn_(SparseLinear_updateParameters)(lua_State *L)
   THTensor * weight = luaT_getfieldcheckudata(L, 1, "weight", torch_Tensor);
   THTensor * bias = luaT_getfieldcheckudata(L, 1, "bias", torch_Tensor);
   THTensor * gradBias = luaT_getfieldcheckudata(L, 1, "gradBias", torch_Tensor);
-  THTensor * gradWeight = luaT_getfieldcheckudata(L, 1, "gradWeight", torch_Tensor);
-  THTensor * lastInput = luaT_getfieldcheckudata(L, 1, "lastInput", torch_Tensor);
-
-  long dim = weight->size[1]; /* number of weights.. */
+  THTensor * gradWeight = luaT_getfieldcheckudata(
+      L, 1, "gradWeight", torch_Tensor);
+  THTensor * lastInput = luaT_getfieldcheckudata(
+      L, 1, "lastInput", torch_Tensor);
+
+  long nnz = lastInput->size[0];
+  long outDim = weight->size[0];
+  long inDim = weight->size[1];
+
+  luaL_argcheck(
+      L, nn_(checkSize2D)(gradWeight, outDim, inDim), 1, "gradWeight size wrong");
+  luaL_argcheck(
+      L, nn_(checkSize1D)(bias, outDim), 1, "bias size wrong");
+  luaL_argcheck(
+      L, nn_(checkSize1D)(gradBias, outDim), 1, "gradBias size wrong");
+
   THTensor_(cadd)(bias, bias, -learningRate, gradBias);
-
-  for(i = 0; i < lastInput->size[0]; i++)
+
+  #pragma omp parallel for private(i) schedule(static) if(outDim * nnz > 50000)
+  for(i = 0; i < nnz; i++)
   {
     long offset = (long)(THTensor_(get2d)(lastInput, i, 0)) - 1;
-
-    if(offset >= 0 && offset < dim) /* make sure indices are in bounds.. */
+
+    if(offset >= 0 && offset < inDim) /* make sure indices are in bounds.. */
     {
-      THBlas_(axpy)(bias->size[0],
-                    -learningRate,
-                    THTensor_(data)(gradWeight)+offset*gradWeight->stride[1],
-                    gradWeight->stride[0],
-                    THTensor_(data)(weight)+offset*weight->stride[1],
+      real* pGradWeight =
+          THTensor_(data)(gradWeight)+offset*gradWeight->stride[1];
+      THBlas_(axpy)(outDim,
+                    -learningRate,
+                    pGradWeight,
+                    gradWeight->stride[0],
+                    THTensor_(data)(weight)+offset*weight->stride[1],
                     weight->stride[0]);
     }
     else {
-      printf("\nupdateParameters: %ld not between 1 and %ld\n", offset+1, dim);
-      luaL_error(L, "index out of bound");
+      luaL_error(L, "index out of bound. updateParameters: \
+%ld not between 1 and %ld", offset + 1, inDim);
+    }
+  }
+  return 0;
+}
+
+int nn_(SparseLinear_zeroGradParameters)(lua_State *L)
+{
+  long i;
+  THTensor * gradBias = luaT_getfieldcheckudata(L, 1, "gradBias", torch_Tensor);
+  THTensor * gradWeight = luaT_getfieldcheckudata(
+      L, 1, "gradWeight", torch_Tensor);
+  THTensor * lastInput = luaT_getfieldcheckudata(
+      L, 1, "lastInput", torch_Tensor);
+
+  long nnz = lastInput->size[0];
+  long outDim = gradWeight->size[0];
+  long inDim = gradWeight->size[1];
+
+  luaL_argcheck(
+      L, nn_(checkSize1D)(gradBias, outDim), 1, "gradBias size wrong");
+
+  THTensor_(zero)(gradBias);
+  #pragma omp parallel for private(i) schedule(static) if(outDim * nnz > 50000)
+  for(i = 0; i < nnz; i++)
+  {
+    long offset = (long)(THTensor_(get2d)(lastInput, i, 0)) - 1;
+
+    if(offset >= 0 && offset < inDim) /* make sure indices are in bounds.. */
+    {
+      real* pGradWeight =
+          THTensor_(data)(gradWeight)+offset*gradWeight->stride[1];
+      if(gradWeight->stride[0] == 1) {
+        THVector_(fill)(pGradWeight, 0, outDim);
+      } else {
+        long j;
+        for(j = 0; j < outDim; ++j) {
+          pGradWeight[j * gradWeight->stride[0]] = 0;
+        }
+      }
+    }
+    else {
+      luaL_error(L, "index out of bound. zeroGradParameters: \
+%ld not between 1 and %ld", offset + 1, inDim);
     }
   }
   return 0;
 }

+static int nn_(SparseLinear_updateGradInput)(lua_State *L) {
+  THTensor *weight = luaT_getfieldcheckudata(L, 1, "weight", torch_Tensor);
+  THTensor *gradInput =
+      luaT_getfieldcheckudata(L, 1, "gradInput", torch_Tensor);
+  THTensor *input = luaT_checkudata(L, 2, torch_Tensor);
+  THTensor *gradOutput = luaT_checkudata(L, 3, torch_Tensor);
+
+  long i;
+  long nnz = input->size[0];
+  long outDim = weight->size[0];
+  long inDim = weight->size[1];
+
+  luaL_argcheck(
+      L, nn_(checkInput)(input), 2, "input must be an nnz x 2 tensor");
+  luaL_argcheck(
+      L, nn_(checkSize1D)(gradOutput, outDim), 3, "gradOutput size wrong");
+
+  THTensor_(resize2d)(gradInput, input->size[0], input->size[1]);
+
+  #pragma omp parallel for private(i) schedule(static) if(outDim * nnz > 100000)
+  for (i = 0; i < nnz; ++i) {
+    long offset = (long)(THTensor_(get2d)(input, i, 0)) - 1;
+    THTensor_(set2d)(gradInput, i, 0, offset + 1);
+
+    if (offset >= 0 && offset < inDim) {
+      real val =
+          THBlas_(dot)(outDim,
+                       THTensor_(data)(gradOutput),
+                       gradOutput->stride[0],
+                       THTensor_(data)(weight) + offset * weight->stride[1],
+                       weight->stride[0]);
+      THTensor_(set2d)(gradInput, i, 1, val);
+    } else {
+      luaL_error(L, "index out of bound. updateGradInput: \
+%ld not between 1 and %ld", offset + 1, inDim);
+    }
+  }
+  return 0;
+}
+
 static const struct luaL_Reg nn_(SparseLinear__) [] = {
   {"SparseLinear_updateOutput", nn_(SparseLinear_updateOutput)},
   {"SparseLinear_accGradParameters", nn_(SparseLinear_accGradParameters)},
   {"SparseLinear_updateParameters", nn_(SparseLinear_updateParameters)},
+  {"SparseLinear_zeroGradParameters", nn_(SparseLinear_zeroGradParameters)},
+  {"SparseLinear_updateGradInput", nn_(SparseLinear_updateGradInput)},
   {NULL, NULL}
 };
diff --git a/generic/VolumetricConvolution.c b/generic/VolumetricConvolution.c
index feeaf05..bb30a70 100644
--- a/generic/VolumetricConvolution.c
+++ b/generic/VolumetricConvolution.c
@@ -13,22 +13,31 @@ static int nn_(VolumetricConvolution_updateOutput)(lua_State *L)
   THTensor *bias = luaT_getfieldcheckudata(L, 1, "bias", torch_Tensor);
   THTensor *output = luaT_getfieldcheckudata(L, 1, "output", torch_Tensor);

-  luaL_argcheck(L, input->nDimension == 4, 2, "4D tensor expected");
-
-  {
-    long nOutputPlane = weight->size[0];
-    long kT           = weight->size[2];
-    long kH           = weight->size[3];
-    long kW           = weight->size[4];
-    long inputDepth   = input->size[1];
-    long inputHeight  = input->size[2];
-    long inputWidth   = input->size[3];
-    long outputDepth  = (inputDepth - kT) / dT + 1;
-    long outputWidth  = (inputWidth - kW) / dW + 1;
-    long outputHeight = (inputHeight - kH) / dH + 1;
-    THTensor *outn = THTensor_(new)();
-    long i;
+  luaL_argcheck(L, input->nDimension == 4 || input->nDimension == 5,
+                2, "4D or 5D (batch-mode) tensor expected");
+
+  int dimt = 1;
+  int dimh = 2;
+  int dimw = 3;
+
+  if (input->nDimension == 5) {
+    dimt++;
+    dimh++;
+    dimw++;
+  }
+
+  long nOutputPlane = weight->size[0];
+  long kT           = weight->size[2];
+  long kH           = weight->size[3];
+  long kW           = weight->size[4];
+  long inputDepth   = input->size[dimt];
+  long inputHeight  = input->size[dimh];
+  long inputWidth   = input->size[dimw];
+  long outputDepth  = (inputDepth - kT) / dT + 1;
+  long outputWidth  = (inputWidth - kW) / dW + 1;
+  long outputHeight = (inputHeight - kH) / dH + 1;
+  THTensor *outn = THTensor_(new)();
+  long i,j;
+
+  if (input->nDimension == 4) { /* non-batch mode */
     THTensor_(resize4d)(output, nOutputPlane, outputDepth, outputHeight, outputWidth);

     /* add bias */
@@ -37,18 +46,41 @@ static int nn_(VolumetricConvolution_updateOutput)(lua_State *L)
       THTensor_(fill)(outn, THTensor_(get1d)(bias, i));
     }

-    THTensor_(free)(outn);
-
     /* do convolutions */
     THTensor_(conv3Dmv)(output, 1.0, 1.0, input, weight, dT, dH, dW, "V", "X");
-  }
+  } else { /* batch mode */
+    long nBatch = input->size[0];
+    THTensor_(resize5d)(output, nBatch, nOutputPlane,
+                        outputDepth, outputHeight, outputWidth);
+    THTensor *inb = THTensor_(new)();
+    THTensor *outb = THTensor_(new)();
+
+    for (j=0; j<nBatch; j++) { /* loop over batches */
+      THTensor_(select)(inb,input,0,j);
+      THTensor_(select)(outb,output,0,j);
+
+      /* add bias */
+      for (i=0; i<bias->size[0]; i++) {
+        THTensor_(select)(outn,outb,0,i);
+        THTensor_(fill)(outn, THTensor_(get1d)(bias, i));
+      }
+
+      /* do convolutions */
+      THTensor_(conv3Dmv)(outb, 1.0, 1.0, inb, weight, dT, dH, dW, "V", "X");
+    }
+
+    THTensor_(free)(inb);
+    THTensor_(free)(outb);
+  }
+
+  THTensor_(free)(outn);

   return 1;
 }

 static int nn_(VolumetricConvolution_updateGradInput)(lua_State *L)
 {
+  THTensor *input = luaT_checkudata(L, 2, torch_Tensor);
   THTensor *gradOutput = luaT_checkudata(L, 3, torch_Tensor);
   int dT = luaT_getfieldcheckint(L, 1, "dT");
   int dW = luaT_getfieldcheckint(L, 1, "dW");
@@ -58,12 +90,37 @@ static int nn_(VolumetricConvolution_updateGradInput)(lua_State *L)
   THTensor *weight = luaT_getfieldcheckudata(L, 1, "weight", torch_Tensor);
   THTensor *gradInput = luaT_getfieldcheckudata(L, 1, "gradInput", torch_Tensor);
   THTensor *tweight;
-
-  THArgCheck( nOutputPlane == gradOutput->size[0], 1, "Number of output features is not equal to nOutputPlane" );
+
+  luaL_argcheck(L, gradOutput->nDimension == 4 || gradOutput->nDimension == 5,
+                3, "4D or 5D (batch-mode) tensor expected");
+
+  int dimPlane = 0;
+  if (gradOutput->nDimension == 5) {
+    dimPlane++;
+  }
+
+  THArgCheck( nOutputPlane == gradOutput->size[dimPlane], 1,
+              "Number of output features is not equal to nOutputPlane" );

   /* gradient to input */
   tweight = THTensor_(newTranspose)(weight,0,1);
-  THTensor_(conv3Dmv)(gradInput, 0.0, 1.0, gradOutput, tweight, dT, dH, dW, "F", "C");
+
+  if (gradOutput->nDimension == 4) { /* non-batch mode */
+    THTensor_(conv3Dmv)(gradInput, 0.0, 1.0, gradOutput, tweight, dT, dH, dW, "F", "C");
+  } else { /* batch mode */
+    long nBatch = gradOutput->size[0];
+    THTensor *ginpb = THTensor_(new)();
+    THTensor *goutb = THTensor_(new)();
+    long j;
+
+    THTensor_(resize5d)(gradInput, input->size[0], input->size[1], input->size[2],
+                        input->size[3], input->size[4]);
+
+    for (j=0; j<nBatch; j++) { /* loop over batches */
+      THTensor_(select)(ginpb,gradInput,0,j);
+      THTensor_(select)(goutb,gradOutput,0,j);
+      THTensor_(conv3Dmv)(ginpb, 0.0, 1.0, goutb, tweight, dT, dH, dW, "F", "C");
+    }
+
+    THTensor_(free)(ginpb);
+    THTensor_(free)(goutb);
+  }
+
   THTensor_(free)(tweight);

   return 1;
@@ -72,7 +129,7 @@ static int nn_(VolumetricConvolution_updateGradInput)(lua_State *L)
 static int nn_(VolumetricConvolution_accGradParameters)(lua_State *L)
 {
   THTensor *input = luaT_checkudata(L, 2, torch_Tensor);
-  THTensor *gradOutput = luaT_checkudata(L, 3, torch_Tensor);
+  THTensor *gradOutput = luaT_checkudata(L, 3, torch_Tensor);
   real scale = luaL_optnumber(L, 4, 1);
   int dT = luaT_getfieldcheckint(L, 1, "dT");
   int dW = luaT_getfieldcheckint(L, 1, "dW");
@@ -85,21 +142,54 @@ static int nn_(VolumetricConvolution_accGradParameters)(lua_State *L)
   long k;
   real *gradBias_data;
   THTensor* gradOutSlice;
-
-  THArgCheck( nOutputPlane == gradOutput->size[0], 1, "Number of output features is not equal to nOutputPlane" );
-
-  /* gradient to bias */
-  gradBias_data = THTensor_(data)(gradBias);
-  gradOutSlice = THTensor_(new)();
-  for(k = 0; k < nOutputPlane; k++)
-  {
-    THTensor_(select)(gradOutSlice, gradOutput, 0, k);
-    gradBias_data[k] += scale*THTensor_(sumall)(gradOutSlice);
+  int dimPlane = 0;
+  if (gradOutput->nDimension == 5) {
+    dimPlane++;
   }
-  THTensor_(free)(gradOutSlice);
+
+  THArgCheck( nOutputPlane == gradOutput->size[dimPlane], 1,
+              "Number of output features is not equal to nOutputPlane" );

-  /* gradient to kernels */
-  THTensor_(conv3DRevger)(gradWeight, 1.0, scale, input, gradOutput, dT, dH, dW);
+
+  if (gradOutput->nDimension == 4) { /* non-batch mode */
+    /* gradient to bias */
+    gradBias_data = THTensor_(data)(gradBias);
+    gradOutSlice = THTensor_(new)();
+    for(k = 0; k < nOutputPlane; k++)
+    {
+      THTensor_(select)(gradOutSlice, gradOutput, 0, k);
+      gradBias_data[k] += scale*THTensor_(sumall)(gradOutSlice);
+    }
+    THTensor_(free)(gradOutSlice);
+
+    /* gradient to kernels */
+    THTensor_(conv3DRevger)(gradWeight, 1.0, scale, input, gradOutput, dT, dH, dW);
+  } else { /* batch mode */
+    long nBatch = gradOutput->size[0];
+    THTensor *inpb = THTensor_(new)();
+    THTensor *goutb = THTensor_(new)();
+    long j;
+
+    for (j=0; j<nBatch; j++) { /* loop over batches */
+      THTensor_(select)(inpb,input,0,j);
+      THTensor_(select)(goutb,gradOutput,0,j);
+
+      /* gradient to bias */
+      gradBias_data = THTensor_(data)(gradBias);
+      gradOutSlice = THTensor_(new)();
+      for(k = 0; k < nOutputPlane; k++)
+      {
+        THTensor_(select)(gradOutSlice, goutb, 0, k);
+        gradBias_data[k] += scale*THTensor_(sumall)(gradOutSlice);
+      }
+      THTensor_(free)(gradOutSlice);
+
+      /* gradient to kernels */
+      THTensor_(conv3DRevger)(gradWeight, 1.0, scale, inpb, goutb, dT, dH, dW);
+    }
+
+    THTensor_(free)(inpb);
+    THTensor_(free)(goutb);
+  }

   return 0;
 }
diff --git a/init.lua b/init.lua
--- a/init.lua
+++ b/init.lua
@@ -4,6 +4,7 @@ require('libnn')

 include('ErrorMessages.lua')
 include('Module.lua')
+include('Container.lua')
 include('Concat.lua')
 include('Parallel.lua')
 include('Sequential.lua')
diff --git a/test.lua b/test.lua
--- a/test.lua
+++ b/test.lua
@@ -76,6 +76,7 @@ function nntest.CMul()
    local input = torch.Tensor(ini,inj,ink):zero()
    local module = nn.CMul(ini*inj*ink)

+   -- 1D
    local err = jac.testJacobian(module,input)
    mytester:assertlt(err,precision, 'error on state ')

@@ -90,6 +91,26 @@ function nntest.CMul()
                        'error on weight [%s]', t))
    end

+   -- 2D
+   local nframe = math.random(50,70)
+   local nframe = 5
+   local input = torch.Tensor(nframe, ini,inj,ink):zero()
+
+   local err = jac.testJacobian(module,input)
+   mytester:assertlt(err,precision, 'error on state ')
+
+   local err = jac.testJacobianParameters(module, input, module.weight, module.gradWeight)
+   mytester:assertlt(err,precision, 'error on weight ')
+
+   local err = jac.testJacobianUpdateParameters(module, input, module.weight)
+   mytester:assertlt(err,precision, 'error on weight [direct update] ')
+
+   for t,err in pairs(jac.testAllUpdate(module, input, 'weight', 'gradWeight')) do
+      mytester:assertlt(err, precision, string.format('error on weight [%s]', t))
+   end
+
+
+   -- IO
    local ferr,berr = jac.testIO(module,input)
    mytester:asserteq(ferr, 0, torch.typename(module) .. ' - i/o forward err ')
    mytester:asserteq(berr, 0, torch.typename(module) .. ' - i/o backward err ')
@@ -469,6 +490,41 @@ function nntest.WeightedEuclidean()
    local ferr,berr = jac.testIO(module,input)
    mytester:asserteq(ferr, 0, torch.typename(module) .. ' - i/o forward err ')
    mytester:asserteq(berr, 0, torch.typename(module) .. ' - i/o backward err ')
+
+   -- test batch
+   local bs = math.random(3,5)
+   input:uniform(0,1)
+   local output = module:forward(input):clone()
+   module:zeroGradParameters()
+   local gradInput = module:backward(input, output):clone()
+   local params, gradParams = module:parameters()
+   for i=1,#params do
+      params[i] = params[i]:clone()
+   end
+   local input2 = input:view(1, -1):repeatTensor(bs, 1)
+   local output2 = module:forward(input2)
+   module:zeroGradParameters()
+   local gradInput2 = module:backward(input2, output2, 1/bs)
+   local params2, gradParams2 = module:parameters()
+   mytester:assertTensorEq(output2[bs-1], output, 0.000001, "error in batch updateOutput")
+   mytester:assertTensorEq(gradInput2[bs-1], gradInput, 0.000001, "error in batch updateGradInput")
+   mytester:assertTensorEq(gradParams[1], gradParams2[1], 0.000001, "error in batch accGradParameters (gradTemplates)")
+   mytester:assertTensorEq(gradParams[2], gradParams2[2], 0.000001, "error in batch accGradParameters (gradDiagCov)")
+
+   input:zero()
+   module:zeroGradParameters()
+   local err = jac.testJacobian(module,input)
+   mytester:assertlt(err,precision, 'error on state ')
+
+   local err = jac.testJacobianParameters(module, input, module.weight, module.gradWeight)
+   mytester:assertlt(err,precision, 'error on weight ')
+
+   local err = jac.testJacobianParameters(module, input, module.bias, module.gradBias)
+   mytester:assertlt(err,precision, 'error on bias ')
+
+   local ferr,berr = jac.testIO(module,input2)
+   mytester:asserteq(ferr, 0, torch.typename(module) .. ' - i/o forward err ')
+   mytester:asserteq(berr, 0, torch.typename(module) .. ' - i/o backward err ')
 end

 local function criterionJacobianTest1D(cri, input, target)
@@ -671,7 +727,7 @@ function nntest.Mul()
    local inj = math.random(3,5)
    local ink = math.random(3,5)
    local input = torch.Tensor(ini,inj,ink):zero()
-   local module = nn.Mul(ini*inj*ink)
+   local module = nn.Mul()

    local err = jac.testJacobian(module,input)
    mytester:assertlt(err,precision, 'error on state ')
@@ -1242,9 +1298,9 @@ function nntest.SpatialFullConvolutionCompare()
 end

 local function batchcompare(smod, sin, plist)
-   local bs = torch.LongStorage(sin:size():size()+1)
+   local bs = torch.LongStorage(sin:dim()+1)
    bs[1] = 1
-   for i=1,sin:size():size() do bs[i+1] = sin:size()[i] end
+   for i=1,sin:dim() do bs[i+1] = sin:size()[i] end
    local bin = torch.Tensor(bs):copy(sin)

    local bmod = smod:clone()
@@ -1769,6 +1825,26 @@ function nntest.VolumetricConvolution()
    mytester:asserteq(0, berr, torch.typename(module) .. ' - i/o backward err ')
 end

+function nntest.VolumetricConvolutionBatchCompare()
+   local from = math.random(2,3)
+   local to = math.random(2,3)
+   local kt = math.random(3,4)
+   local ki = math.random(3,4)
+   local kj = math.random(3,4)
+   local st = math.random(2,3)
+   local si = math.random(2,3)
+   local sj = math.random(2,3)
+   local outt = math.random(3,4)
+   local outi = math.random(3,4)
+   local outj = math.random(3,4)
+   local int = (outt-1)*st+kt
+   local ini = (outi-1)*si+ki
+   local inj = (outj-1)*sj+kj
+   local module = nn.VolumetricConvolution(from, to, kt, ki, kj, st, si, sj)
+   local input = torch.randn(from, int, inj, ini)
+   batchcompare(module,input, {'weight','bias','gradWeight','gradBias'})
+end
+
 function nntest.VolumetricMaxPooling()
    local from = math.random(2,3)
    local kt = math.random(3,4)
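The Jacobian changes above thread an optional `perturbation` argument through the finite-difference gradient checks, which compare the central difference `(f(x+h) - f(x-h)) / 2h` against the analytic backward pass. A minimal sketch of how a test might exercise it (the error threshold is chosen arbitrarily for illustration):

```lua
require 'nn'

local jac = nn.Jacobian
local module = nn.Linear(4, 3)
local input = torch.Tensor(4) -- testJacobian fills it with random values

-- default perturbation (1e-6)
local err = jac.testJacobian(module, input)

-- explicit perturbation h = 1e-4, over inputs drawn from [-2, 2]
local err2 = jac.testJacobian(module, input, -2, 2, 1e-4)
assert(err2 < 1e-5, 'gradient check failed')
```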