github.com/torch/nn.git
author     Soumith Chintala <soumith@gmail.com>  2015-01-07 09:18:11 +0300
committer  Soumith Chintala <soumith@gmail.com>  2015-01-07 09:18:11 +0300
commit     75fd6934fa876dc3ee191fbf29e9ad689af037f4 (patch)
tree       fae2bb8098f90de4760098a11bd2d987e84733be
parent     8c896314e9aa8540132f32d8e2577a57c35f39cd (diff)
parent     4e0a96d801060121521ccc46f7294aeb3b247965 (diff)
Merge branch 'master' of github.com:torch/nn
-rw-r--r--  CMul.lua                          43
-rw-r--r--  ClassNLLCriterion.lua             22
-rw-r--r--  Concat.lua                        66
-rw-r--r--  Container.lua                     80
-rw-r--r--  Jacobian.lua                      52
-rw-r--r--  Linear.lua                        16
-rw-r--r--  Mul.lua                           10
-rw-r--r--  Parallel.lua                      41
-rw-r--r--  Sequential.lua                    72
-rw-r--r--  SparseLinear.lua                  37
-rw-r--r--  VolumetricConvolution.lua          4
-rw-r--r--  WeightedEuclidean.lua            151
-rw-r--r--  doc/simple.md                     26
-rw-r--r--  generic/SparseLinear.c           249
-rw-r--r--  generic/VolumetricConvolution.c  160
-rw-r--r--  init.lua                           1
-rw-r--r--  test.lua                          82
17 files changed, 748 insertions, 364 deletions
diff --git a/CMul.lua b/CMul.lua
index 317365b..7cedfd2 100644
--- a/CMul.lua
+++ b/CMul.lua
@@ -13,25 +13,56 @@ function CMul:__init(inputSize)
self:reset()
end
-function CMul:reset()
- self.weight:fill(1)
+function CMul:reset(stdv)
+ if stdv then
+ stdv = stdv * math.sqrt(3)
+ else
+ stdv = 1./math.sqrt(self.weight:size(1))
+ end
+ self.weight:uniform(-stdv,stdv)
end
function CMul:updateOutput(input)
- self.output:copy(input);
- self.output:cmul(self.weight);
+ self.output:resizeAs(input):copy(input)
+ if input:nElement() == self.weight:nElement() then
+ self.output:view(-1):cmul(self.weight:view(-1));
+ else
+ if input:isSameSizeAs(self.weight) then
+ self.output:cmul(self.weight)
+ else
+ local batchSize = input:size(1)
+ self.output:view(batchSize, -1):cmul(self.weight:view(1,-1):expandAs(input:view(batchSize, -1)))
+ end
+ end
return self.output
end
function CMul:updateGradInput(input, gradOutput)
if self.gradInput then
+ local nElement = self.gradInput:nElement()
self.gradInput:resizeAs(input)
self.gradInput:zero()
- self.gradInput:addcmul(1, self.weight, gradOutput)
+ if self.weight:nElement() == gradOutput:nElement() then
+ self.gradInput:addcmul(1, self.weight, gradOutput)
+ else
+ local gradOutput = gradOutput:view(input:size(1), -1)
+ local gradInput = self.gradInput:view(input:size(1), -1)
+ gradInput:addcmul(1, self.weight:view(1,-1):expandAs(gradOutput), gradOutput)
+ end
return self.gradInput
end
end
function CMul:accGradParameters(input, gradOutput, scale)
- self.gradWeight:addcmul(scale or 1, input, gradOutput)
+ if self.weight:nElement() == gradOutput:nElement() then
+ self.gradWeight:addcmul(scale or 1, input, gradOutput)
+ else
+ local batchSize = input:size(1)
+ local input = input:view(batchSize, -1)
+ local gradOutput = gradOutput:view(batchSize, -1)
+ local gradWeight = self.gradWeight:view(1, -1)
+ for i=1,batchSize do
+ gradWeight:addcmul(scale or 1, input[i], gradOutput[i])
+ end
+ end
end
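The CMul change above adds batch support: when the input has more elements than the weight, the weight is viewed as a 1 x n row and expanded over the batch dimension. A minimal usage sketch of the post-patch behaviour (sizes are illustrative, not taken from the commit):

```lua
require 'nn'

-- one learnable scale per feature; the same weight vector is broadcast
-- across every sample of a batch
local cmul = nn.CMul(4)
local single = torch.randn(4)      -- 1D sample: sizes match the weight exactly
local batch  = torch.randn(5, 4)   -- 2D batch: each row scaled element-wise

print(cmul:forward(single):size()) -- 4
print(cmul:forward(batch):size())  -- 5 x 4
```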
diff --git a/ClassNLLCriterion.lua b/ClassNLLCriterion.lua
index 926e707..997ecef 100644
--- a/ClassNLLCriterion.lua
+++ b/ClassNLLCriterion.lua
@@ -12,9 +12,15 @@ end
function ClassNLLCriterion:updateOutput(input, target)
if input:type() == 'torch.CudaTensor' and not self.weights then
- input.nn.ClassNLLCriterion_updateOutput(self, input, target)
- self.output = self.outputTensor[1]
- return self.output
+ if input:dim() == 1 then
+ self._target = self._target or input.new(1)
+ self._target[1] = target
+ input.nn.ClassNLLCriterion_updateOutput(self, input, self._target)
+ else
+ input.nn.ClassNLLCriterion_updateOutput(self, input, target)
+ end
+ self.output = self.outputTensor[1]
+ return self.output
end
if input:dim() == 1 then
@@ -46,8 +52,14 @@ function ClassNLLCriterion:updateGradInput(input, target)
self.gradInput:zero()
if input:type() == 'torch.CudaTensor' and not self.weights then
- input.nn.ClassNLLCriterion_updateGradInput(self, input, target)
- return self.gradInput
+ if input:dim() == 1 then
+ self._target = self._target or input.new(1)
+ self._target[1] = target
+ input.nn.ClassNLLCriterion_updateGradInput(self, input, self._target)
+ else
+ input.nn.ClassNLLCriterion_updateGradInput(self, input, target)
+ end
+ return self.gradInput
end
if input:dim() == 1 then
diff --git a/Concat.lua b/Concat.lua
index c94808d..b0436a5 100644
--- a/Concat.lua
+++ b/Concat.lua
@@ -1,21 +1,11 @@
-local Concat, parent = torch.class('nn.Concat', 'nn.Module')
+local Concat, parent = torch.class('nn.Concat', 'nn.Container')
function Concat:__init(dimension)
- parent.__init(self)
- self.modules = {}
+ parent.__init(self, dimension)
self.size = torch.LongStorage()
self.dimension = dimension
end
-function Concat:add(module)
- table.insert(self.modules, module)
- return self
-end
-
-function Concat:get(index)
- return self.modules[index]
-end
-
function Concat:updateOutput(input)
local outs = {}
for i=1,#self.modules do
@@ -83,58 +73,6 @@ function Concat:accUpdateGradParameters(input, gradOutput, lr)
end
end
-function Concat:zeroGradParameters()
- for _,module in ipairs(self.modules) do
- module:zeroGradParameters()
- end
-end
-
-function Concat:updateParameters(learningRate)
- for _,module in ipairs(self.modules) do
- module:updateParameters(learningRate)
- end
-end
-
-function Concat:training()
- for i=1,#self.modules do
- self.modules[i]:training()
- end
-end
-
-function Concat:evaluate()
- for i=1,#self.modules do
- self.modules[i]:evaluate()
- end
-end
-
-function Concat:share(mlp,...)
- for i=1,#self.modules do
- self.modules[i]:share(mlp.modules[i],...);
- end
-end
-
-function Concat:parameters()
- local function tinsert(to, from)
- if type(from) == 'table' then
- for i=1,#from do
- tinsert(to,from[i])
- end
- else
- table.insert(to,from)
- end
- end
- local w = {}
- local gw = {}
- for i=1,#self.modules do
- local mw,mgw = self.modules[i]:parameters()
- if mw then
- tinsert(w,mw)
- tinsert(gw,mgw)
- end
- end
- return w,gw
-end
-
function Concat:__tostring__()
local tab = ' '
local line = '\n'
diff --git a/Container.lua b/Container.lua
new file mode 100644
index 0000000..125ab98
--- /dev/null
+++ b/Container.lua
@@ -0,0 +1,80 @@
+-- This is code common to container modules, which are collections of
+-- smaller constituent modules like Parallel, Sequential, etc.
+local Container, parent =
+ torch.class('nn.Container', 'nn.Module')
+
+function Container:__init(...)
+ parent.__init(self, ...)
+ self.modules = {}
+end
+
+function Container:add(module)
+ table.insert(self.modules, module)
+ return self
+end
+
+function Container:get(index)
+ return self.modules[index]
+end
+
+function Container:size()
+ return #self.modules
+end
+
+function Container:zeroGradParameters()
+ for i=1,#self.modules do
+ self.modules[i]:zeroGradParameters()
+ end
+end
+
+function Container:updateParameters(learningRate)
+ for _,module in ipairs(self.modules) do
+ module:updateParameters(learningRate)
+ end
+end
+
+function Container:training()
+ for i=1,#self.modules do
+ self.modules[i]:training()
+ end
+end
+
+function Container:evaluate()
+ for i=1,#self.modules do
+ self.modules[i]:evaluate()
+ end
+end
+
+function Container:share(mlp, ...)
+ for i=1,#self.modules do
+ self.modules[i]:share(mlp.modules[i], ...);
+ end
+end
+
+function Container:reset(stdv)
+ for i=1,#self.modules do
+ self.modules[i]:reset(stdv)
+ end
+end
+
+function Container:parameters()
+ local function tinsert(to, from)
+ if type(from) == 'table' then
+ for i=1,#from do
+ tinsert(to,from[i])
+ end
+ else
+ table.insert(to,from)
+ end
+ end
+ local w = {}
+ local gw = {}
+ for i=1,#self.modules do
+ local mw,mgw = self.modules[i]:parameters()
+ if mw then
+ tinsert(w,mw)
+ tinsert(gw,mgw)
+ end
+ end
+ return w,gw
+end
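The new nn.Container factors the module-list bookkeeping (add, get, size, zeroGradParameters, updateParameters, parameters, ...) out of the individual containers; Concat, Parallel and Sequential now inherit it. A short sketch of what the shared base class provides, assuming the post-patch hierarchy:

```lua
require 'nn'

local seq = nn.Sequential()        -- now derives from nn.Container
seq:add(nn.Linear(10, 5))
seq:add(nn.Tanh())

print(seq:size())                  -- 2, provided by Container:size()
print(seq:get(1))                  -- nn.Linear(10 -> 5), via Container:get()
seq:zeroGradParameters()           -- Container loops over the children
```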
diff --git a/Jacobian.lua b/Jacobian.lua
index c3797bd..25e8cf0 100644
--- a/Jacobian.lua
+++ b/Jacobian.lua
@@ -64,53 +64,55 @@ function nn.Jacobian.backwardUpdate(module, input, param)
return jacobian
end
-function nn.Jacobian.forward(module, input, param)
+function nn.Jacobian.forward(module, input, param, perturbation)
param = param or input
-- perturbation amount
- local small = 1e-6
+ perturbation = perturbation or 1e-6
-- 1D view of input
--local tst = param:storage()
local sin = param.new(param):resize(param:nElement())--param.new(tst,1,tst:size())
-- jacobian matrix to calculate
local jacobian = torch.Tensor():resize(param:nElement(),module:forward(input):nElement())
-
+
local outa = torch.Tensor(jacobian:size(2))
local outb = torch.Tensor(jacobian:size(2))
-
- for i=1,sin:nElement() do
- sin[i] = sin[i] - small
+
+ for i=1,sin:nElement() do
+ local orig = sin[i]
+ sin[i] = orig - perturbation
outa:copy(module:forward(input))
- sin[i] = sin[i] + 2*small
+ sin[i] = orig + perturbation
outb:copy(module:forward(input))
- sin[i] = sin[i] - small
+ sin[i] = orig
- outb:add(-1,outa):div(2*small)
+ outb:add(-1,outa):div(2*perturbation)
jacobian:select(1,i):copy(outb)
end
return jacobian
end
-function nn.Jacobian.forwardUpdate(module, input, param)
+function nn.Jacobian.forwardUpdate(module, input, param, perturbation)
-- perturbation amount
- local small = 1e-6
+ perturbation = perturbation or 1e-6
-- 1D view of input
--local tst = param:storage()
local sin = param.new(param):resize(param:nElement())--param.new(tst,1,tst:size())
-- jacobian matrix to calculate
local jacobian = torch.Tensor():resize(param:nElement(),module:forward(input):nElement())
-
+
local outa = torch.Tensor(jacobian:size(2))
local outb = torch.Tensor(jacobian:size(2))
-
- for i=1,sin:nElement() do
- sin[i] = sin[i] - small
+
+ for i=1,sin:nElement() do
+ local orig = sin[i]
+ sin[i] = orig - perturbation
outa:copy(module:forward(input))
- sin[i] = sin[i] + 2*small
+ sin[i] = orig + perturbation
outb:copy(module:forward(input))
- sin[i] = sin[i] - small
+ sin[i] = orig
- outb:add(-1,outa):div(2*small)
+ outb:add(-1,outa):div(2*perturbation)
jacobian:select(1,i):copy(outb)
jacobian:select(1,i):mul(-1)
jacobian:select(1,i):add(sin[i])
@@ -118,37 +120,37 @@ function nn.Jacobian.forwardUpdate(module, input, param)
return jacobian
end
-function nn.Jacobian.testJacobian (module, input, minval, maxval)
+function nn.Jacobian.testJacobian(module, input, minval, maxval, perturbation)
minval = minval or -2
maxval = maxval or 2
local inrange = maxval - minval
input:copy(torch.rand(input:nElement()):mul(inrange):add(minval))
- local jac_fprop = nn.Jacobian.forward(module,input)
- local jac_bprop = nn.Jacobian.backward(module,input)
+ local jac_fprop = nn.Jacobian.forward(module, input, input, perturbation)
+ local jac_bprop = nn.Jacobian.backward(module, input)
local error = jac_fprop-jac_bprop
return error:abs():max()
end
-function nn.Jacobian.testJacobianParameters (module, input, param, dparam, minval, maxval)
+function nn.Jacobian.testJacobianParameters(module, input, param, dparam, minval, maxval, perturbation)
minval = minval or -2
maxval = maxval or 2
local inrange = maxval - minval
input:copy(torch.rand(input:nElement()):mul(inrange):add(minval))
param:copy(torch.rand(param:nElement()):mul(inrange):add(minval))
local jac_bprop = nn.Jacobian.backward(module, input, param, dparam)
- local jac_fprop = nn.Jacobian.forward(module, input, param)
+ local jac_fprop = nn.Jacobian.forward(module, input, param, perturbation)
local error = jac_fprop - jac_bprop
return error:abs():max()
end
-function nn.Jacobian.testJacobianUpdateParameters (module, input, param, minval, maxval)
+function nn.Jacobian.testJacobianUpdateParameters(module, input, param, minval, maxval, perturbation)
minval = minval or -2
maxval = maxval or 2
local inrange = maxval - minval
input:copy(torch.rand(input:nElement()):mul(inrange):add(minval))
param:copy(torch.rand(param:nElement()):mul(inrange):add(minval))
local params_bprop = nn.Jacobian.backwardUpdate(module, input, param)
- local params_fprop = nn.Jacobian.forwardUpdate(module, input, param)
+ local params_fprop = nn.Jacobian.forwardUpdate(module, input, param, perturbation)
local error = params_fprop - params_bprop
return error:abs():max()
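The Jacobian helpers now take the finite-difference step as an explicit `perturbation` argument and restore each perturbed entry to its original value instead of repeatedly adding and subtracting `small`. The check itself is a central difference, roughly (f(x+eps) - f(x-eps)) / (2*eps). A hedged usage sketch with the new optional argument:

```lua
require 'nn'

local module = nn.Linear(3, 2)
local input  = torch.randn(3)

-- compare the analytic Jacobian (updateGradInput) with the numeric one,
-- passing the perturbation explicitly (1e-6 is the old hard-coded default)
local err = nn.Jacobian.testJacobian(module, input, -2, 2, 1e-6)
print(err)   -- should be close to zero for a correct module
```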
diff --git a/Linear.lua b/Linear.lua
index 2ed8a3e..5e05c2f 100644
--- a/Linear.lua
+++ b/Linear.lua
@@ -38,14 +38,16 @@ function Linear:updateOutput(input)
elseif input:dim() == 2 then
local nframe = input:size(1)
local nunit = self.bias:size(1)
-
self.output:resize(nframe, nunit)
+ if not self.addBuffer or self.addBuffer:size(1) ~= nframe then
+ self.addBuffer = input.new(nframe):fill(1)
+ end
if nunit == 1 then
-- Special case to fix output size of 1 bug:
- self.output:zero():add(self.bias[1])
+ self.output:copy(self.bias:view(1,nunit):expand(#self.output))
self.output:select(2,1):addmv(1, input, self.weight:select(1,1))
else
- self.output:zero():addr(1, input.new(nframe):fill(1), self.bias)
+ self.output:zero():addr(1, self.addBuffer, self.bias)
self.output:addmm(1, input, self.weight:t())
end
else
@@ -78,18 +80,18 @@ function Linear:accGradParameters(input, gradOutput, scale)
if input:dim() == 1 then
self.gradWeight:addr(scale, gradOutput, input)
- self.gradBias:add(scale, gradOutput)
+ self.gradBias:add(scale, gradOutput)
elseif input:dim() == 2 then
local nframe = input:size(1)
local nunit = self.bias:size(1)
-
+
if nunit == 1 then
-- Special case to fix output size of 1 bug:
self.gradWeight:select(1,1):addmv(scale, input:t(), gradOutput:select(2,1))
- self.gradBias:addmv(scale, gradOutput:t(), input.new(nframe):fill(1))
+ self.gradBias:addmv(scale, gradOutput:t(), self.addBuffer)
else
self.gradWeight:addmm(scale, gradOutput:t(), input)
- self.gradBias:addmv(scale, gradOutput:t(), input.new(nframe):fill(1))
+ self.gradBias:addmv(scale, gradOutput:t(), self.addBuffer)
end
end
diff --git a/Mul.lua b/Mul.lua
index 7841470..289d83a 100644
--- a/Mul.lua
+++ b/Mul.lua
@@ -1,15 +1,11 @@
local Mul, parent = torch.class('nn.Mul', 'nn.Module')
-function Mul:__init(inputSize)
+function Mul:__init()
parent.__init(self)
self.weight = torch.Tensor(1)
self.gradWeight = torch.Tensor(1)
- -- state
- self.gradInput:resize(inputSize)
- self.output:resize(inputSize)
-
self:reset()
end
@@ -25,13 +21,13 @@ function Mul:reset(stdv)
end
function Mul:updateOutput(input)
- self.output:copy(input);
+ self.output:resizeAs(input):copy(input);
self.output:mul(self.weight[1]);
return self.output
end
function Mul:updateGradInput(input, gradOutput)
- self.gradInput:zero()
+ self.gradInput:resizeAs(input):zero()
self.gradInput:add(self.weight[1], gradOutput)
return self.gradInput
end
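nn.Mul no longer takes an input size; its output and gradInput buffers are resized from the input on the fly, so the same scalar weight applies to any input shape. A minimal sketch of the post-patch behaviour:

```lua
require 'nn'

local m = nn.Mul()                          -- no inputSize argument any more
print(m:forward(torch.randn(7)):size())     -- 7
print(m:forward(torch.randn(3, 4)):size())  -- 3 x 4, same scalar weight
```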
diff --git a/Parallel.lua b/Parallel.lua
index 3057ba2..ef42723 100644
--- a/Parallel.lua
+++ b/Parallel.lua
@@ -1,4 +1,4 @@
-local Parallel, parent = torch.class('nn.Parallel', 'nn.Module')
+local Parallel, parent = torch.class('nn.Parallel', 'nn.Container')
function Parallel:__init(inputDimension,outputDimension)
parent.__init(self)
@@ -8,15 +8,6 @@ function Parallel:__init(inputDimension,outputDimension)
self.outputDimension = outputDimension
end
-function Parallel:add(module)
- table.insert(self.modules, module)
- return self
-end
-
-function Parallel:get(index)
- return self.modules[index]
-end
-
function Parallel:updateOutput(input)
local modules=input:size(self.inputDimension)
@@ -99,36 +90,6 @@ function Parallel:accUpdateGradParameters(input, gradOutput, lr)
end
end
-function Parallel:zeroGradParameters()
- for _,module in ipairs(self.modules) do
- module:zeroGradParameters()
- end
-end
-
-function Parallel:updateParameters(learningRate)
- for _,module in ipairs(self.modules) do
- module:updateParameters(learningRate)
- end
-end
-
-function Parallel:training()
- for i=1,#self.modules do
- self.modules[i]:training()
- end
-end
-
-function Parallel:evaluate()
- for i=1,#self.modules do
- self.modules[i]:evaluate()
- end
-end
-
-function Parallel:share(mlp,...)
- for i=1,#self.modules do
- self.modules[i]:share(mlp.modules[i],...);
- end
-end
-
function Parallel:parameters()
local function tinsert(to, from)
if type(from) == 'table' then
diff --git a/Sequential.lua b/Sequential.lua
index 97554b3..3288e6d 100644
--- a/Sequential.lua
+++ b/Sequential.lua
@@ -1,9 +1,4 @@
-local Sequential, parent = torch.class('nn.Sequential', 'nn.Module')
-
-function Sequential:__init()
- parent.__init(self)
- self.modules = {}
-end
+local Sequential, _ = torch.class('nn.Sequential', 'nn.Container')
function Sequential:add(module)
if #self.modules == 0 then
@@ -24,14 +19,6 @@ function Sequential:insert(module, index)
self.gradInput = self.modules[1].gradInput
end
-function Sequential:size()
- return #self.modules
-end
-
-function Sequential:get(index)
- return self.modules[index]
-end
-
function Sequential:updateOutput(input)
local currentOutput = input
for i=1,#self.modules do
@@ -82,63 +69,6 @@ function Sequential:accUpdateGradParameters(input, gradOutput, lr)
currentModule:accUpdateGradParameters(input, currentGradOutput, lr)
end
-function Sequential:zeroGradParameters()
- for i=1,#self.modules do
- self.modules[i]:zeroGradParameters()
- end
-end
-
-function Sequential:updateParameters(learningRate)
- for i=1,#self.modules do
- self.modules[i]:updateParameters(learningRate)
- end
-end
-
-function Sequential:training()
- for i=1,#self.modules do
- self.modules[i]:training()
- end
-end
-
-function Sequential:evaluate()
- for i=1,#self.modules do
- self.modules[i]:evaluate()
- end
-end
-
-function Sequential:share(mlp,...)
- for i=1,#self.modules do
- self.modules[i]:share(mlp.modules[i],...);
- end
-end
-
-function Sequential:reset(stdv)
- for i=1,#self.modules do
- self.modules[i]:reset(stdv)
- end
-end
-
-function Sequential:parameters()
- local function tinsert(to, from)
- if type(from) == 'table' then
- for i=1,#from do
- tinsert(to,from[i])
- end
- else
- table.insert(to,from)
- end
- end
- local w = {}
- local gw = {}
- for i=1,#self.modules do
- local mw,mgw = self.modules[i]:parameters()
- if mw then
- tinsert(w,mw)
- tinsert(gw,mgw)
- end
- end
- return w,gw
-end
function Sequential:__tostring__()
local tab = ' '
diff --git a/SparseLinear.lua b/SparseLinear.lua
index 735d0ed..ca15be6 100644
--- a/SparseLinear.lua
+++ b/SparseLinear.lua
@@ -4,11 +4,16 @@ function SparseLinear:__init(inputSize, outputSize)
parent.__init(self)
self.weightDecay = 0
- self.weight = torch.Tensor(outputSize, inputSize)
- self.bias = torch.Tensor(outputSize)
- self.gradWeight = torch.Tensor(outputSize, inputSize)
- self.gradBias = torch.Tensor(outputSize)
- self.lastInput = torch.Tensor()
+ self.weight = torch.Tensor(outputSize, inputSize):zero()
+ self.bias = torch.Tensor(outputSize):zero()
+ self.gradWeight = torch.Tensor(outputSize, inputSize):zero()
+ self.gradBias = torch.Tensor(outputSize):zero()
+ self.lastInput = nil
+
+ if torch.getnumthreads() > 1 and outputSize >= 128 then
+ self.shardBuffer = torch.Tensor(outputSize, torch.getnumthreads())
+ end
+
-- state
self.gradInput:resize(inputSize)
self.output:resize(outputSize)
@@ -20,7 +25,7 @@ function SparseLinear:reset(stdv)
if stdv then
stdv = stdv * math.sqrt(3)
else
- stdv = 1./math.sqrt(self.weight:size(1))
+ stdv = 1./math.sqrt(self.weight:size(2))
end
if nn.oldSeed then
for i=1,self.weight:size(1) do
@@ -40,22 +45,18 @@ function SparseLinear:updateOutput(input)
end
function SparseLinear:accGradParameters(input, gradOutput, scale)
+ if not self.lastInput then
+ self.lastInput = input:clone()
+ else
+ self.lastInput:resizeAs(input):copy(input)
+ end
+
return input.nn.SparseLinear_accGradParameters(self, input, gradOutput, scale)
end
function SparseLinear:updateGradInput(input, gradOutput)
if self.gradInput then
- self.gradInput:resize(input:size())
- self.gradInput:copy(input)
- local numNonzero = self.gradInput:size(1)
- for e=1,numNonzero do
- local g = 0
- local i = self.gradInput[{e,1}]
- for j=1,self.output:size(1) do
- g = g + self.weight[{j,i}] * gradOutput[j]
- end
- self.gradInput[{e,2}] = g
- end
+ input.nn.SparseLinear_updateGradInput(self, input, gradOutput)
return self.gradInput
end
-end
\ No newline at end of file
+end
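SparseLinear keeps its sparse input convention of an nnz x 2 tensor, where each row is {featureIndex, featureValue}; the Lua side now clones `lastInput` lazily and pre-allocates a shard buffer for the threaded path. A small sketch of the expected input format (sizes are illustrative):

```lua
require 'nn'

local m = nn.SparseLinear(1000, 30)
-- three non-zero entries of a 1000-dimensional input: {index, value} pairs
local x = torch.Tensor{{12, 0.5}, {47, -1.2}, {901, 2.0}}
local y = m:forward(x)             -- dense output of size 30
print(y:size())
```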
diff --git a/VolumetricConvolution.lua b/VolumetricConvolution.lua
index 4dec9e3..aea94a5 100644
--- a/VolumetricConvolution.lua
+++ b/VolumetricConvolution.lua
@@ -20,7 +20,9 @@ function VolumetricConvolution:__init(nInputPlane, nOutputPlane, kT, kW, kH, dT,
self.bias = torch.Tensor(nOutputPlane)
self.gradWeight = torch.Tensor(nOutputPlane, nInputPlane, kT, kH, kW)
self.gradBias = torch.Tensor(nOutputPlane)
-
+ -- temporary buffers for unfolding (CUDA)
+ self.finput = torch.Tensor()
+ self.fgradInput = torch.Tensor()
self:reset()
end
diff --git a/WeightedEuclidean.lua b/WeightedEuclidean.lua
index 3808db6..5a3af27 100644
--- a/WeightedEuclidean.lua
+++ b/WeightedEuclidean.lua
@@ -6,13 +6,10 @@ function WeightedEuclidean:__init(inputSize,outputSize)
self.templates = torch.Tensor(inputSize,outputSize)
self.gradTemplates = torch.Tensor(inputSize,outputSize)
+ -- each template (output dim) has its own diagonal covariance matrix
self.diagCov = torch.Tensor(inputSize,outputSize)
self.gradDiagCov = torch.Tensor(inputSize,outputSize)
-
- self.gradInput:resize(inputSize)
- self.output:resize(outputSize)
- self.temp = torch.Tensor(inputSize)
-
+
-- for compat with Torch's modules (it's bad we have to do that)
do
self.weight = self.templates
@@ -43,45 +40,137 @@ function WeightedEuclidean:reset(stdv)
end
function WeightedEuclidean:updateOutput(input)
- self.output:zero()
- for o = 1,self.templates:size(2) do
- self.temp:copy(input):add(-1,self.templates:select(2,o))
- self.temp:cmul(self.temp)
- self.temp:cmul(self.diagCov:select(2,o)):cmul(self.diagCov:select(2,o))
- self.output[o] = math.sqrt(self.temp:sum())
+ -- lazy-initialize
+ self._temp = self._temp or self.output.new()
+ self._ones = self._ones or self.output.new{1}
+ self._diagCov = self._diagCov or self.output.new()
+ self._repeat = self._repeat or self.output.new()
+ self._sum = self._sum or self.output.new()
+ self._temp:resizeAs(input)
+ if input:dim() == 1 then
+ self.output:resize(self.templates:size(2))
+ for outIdx = 1,self.templates:size(2) do
+ self._temp:copy(input):add(-1,self.templates:select(2,outIdx))
+ self._temp:cmul(self._temp)
+ local diagCov = self.diagCov:select(2,outIdx)
+ self._temp:cmul(diagCov):cmul(diagCov)
+ self.output[outIdx] = math.sqrt(self._temp:sum())
+ end
+ elseif input:dim() == 2 then
+ self.output:resize(input:size(1), self.templates:size(2))
+ if self._ones:size(1) ~= input:size(1) then
+ self._ones:resize(input:size(1)):fill(1)
+ end
+ for outIdx = 1,self.templates:size(2) do
+ self._temp:copy(input)
+ self._temp:addr(-1, self._ones, self.templates:select(2,outIdx))
+ self._temp:cmul(self._temp)
+ local diagCov = self.diagCov:select(2,outIdx)
+ self._diagCov:resizeAs(diagCov):copy(diagCov)
+ self._diagCov:pow(2)
+ self._diagCov:resize(1,self._diagCov:size(1))
+ self._repeat:repeatTensor(self._diagCov, input:size(1), 1)
+ self._temp:cmul(self._temp, self._repeat)
+ self._sum:sum(self._temp, 2):sqrt()
+ self.output:select(2,outIdx):copy(self._sum)
+ end
+ else
+ error"1D or 2D input expected"
end
return self.output
end
function WeightedEuclidean:updateGradInput(input, gradOutput)
- self:forward(input)
- self.gradInput:zero()
- for o = 1,self.templates:size(2) do
- if self.output[o] ~= 0 then
- self.temp:copy(input):add(-1,self.templates:select(2,o))
- self.temp:cmul(self.diagCov:select(2,o)):cmul(self.diagCov:select(2,o))
- self.temp:mul(gradOutput[o]/self.output[o])
- self.gradInput:add(self.temp)
+ self._gradTemp = self._gradTemp or self.output.new()
+ self.gradInput:resizeAs(input):zero()
+ self._temp:resizeAs(input)
+ self._gradTemp:cdiv(gradOutput, self.output)
+ if input:dim() == 1 then
+ for outIdx = 1,self.templates:size(2) do
+ self._temp:copy(input):add(-1,self.templates:select(2,outIdx))
+ local diagCov = self.diagCov:select(2,outIdx)
+ self._temp:cmul(diagCov):cmul(diagCov)
+
+ self._temp:mul(self._gradTemp[outIdx])
+ self.gradInput:add(self._temp)
+ end
+ elseif input:dim() == 2 then
+ if self._ones:size(1) ~= input:size(1) then
+ self._ones:resize(input:size(1)):fill(1)
end
+ for outIdx = 1,self.templates:size(2) do
+ self._temp:copy(input)
+ self._temp:addr(-1, self._ones, self.templates:select(2,outIdx))
+ local diagCov = self.diagCov:select(2,outIdx)
+ self._diagCov:resizeAs(diagCov):copy(diagCov)
+ self._diagCov:pow(2)
+ self._diagCov:resize(1,self._diagCov:size(1))
+ self._repeat:repeatTensor(self._diagCov, input:size(1), 1)
+ self._temp:cmul(self._temp, self._repeat)
+
+ local gradTemp = self._gradTemp:select(2, outIdx)
+ gradTemp = gradTemp:reshape(1,gradTemp:size(1))
+ self._repeat:repeatTensor(gradTemp, input:size(2), 1)
+ self.gradInput:addcmul(1, self._temp, self._repeat)
+ end
+ else
+ error"1D or 2D input expected"
end
return self.gradInput
end
function WeightedEuclidean:accGradParameters(input, gradOutput, scale)
- self:forward(input)
scale = scale or 1
- for o = 1,self.templates:size(2) do
- if self.output[o] ~= 0 then
- self.temp:copy(self.templates:select(2,o)):add(-1,input)
- self.temp:cmul(self.diagCov:select(2,o)):cmul(self.diagCov:select(2,o))
- self.temp:mul(gradOutput[o]/self.output[o])
- self.gradTemplates:select(2,o):add(scale, self.temp)
+ self._temp:resizeAs(input)
+ self._gradTemp:cdiv(gradOutput, self.output)
+ if input:dim() == 1 then
+ for outIdx = 1,self.templates:size(2) do
+ self._temp:copy(self.templates:select(2,outIdx)):add(-1,input)
+ local diagCov = self.diagCov:select(2,outIdx)
+ self._temp:cmul(diagCov):cmul(diagCov)
+
+ self._temp:mul(self._gradTemp[outIdx])
+ self.gradTemplates:select(2,outIdx):add(scale, self._temp)
- self.temp:copy(self.templates:select(2,o)):add(-1,input)
- self.temp:cmul(self.temp)
- self.temp:cmul(self.diagCov:select(2,o))
- self.temp:mul(gradOutput[o]/self.output[o])
- self.gradDiagCov:select(2,o):add(scale, self.temp)
+ self._temp:copy(self.templates:select(2,outIdx)):add(-1,input)
+ self._temp:pow(2)
+ self._temp:cmul(self.diagCov:select(2,outIdx))
+ self._temp:mul(self._gradTemp[outIdx])
+ self.gradDiagCov:select(2,outIdx):add(scale, self._temp)
end
+ elseif input:dim() == 2 then
+ for outIdx = 1,self.templates:size(2) do
+ -- gradTemplates
+ self._temp:copy(input)
+ self._temp:addr(-1, self._ones, self.templates:select(2,outIdx))
+ local diagCov = self.diagCov:select(2,outIdx)
+ self._diagCov:resizeAs(diagCov):copy(diagCov)
+ self._diagCov:pow(2)
+ self._diagCov:resize(1,self._diagCov:size(1))
+ self._repeat:repeatTensor(self._diagCov, input:size(1), 1)
+ self._temp:cmul(self._temp, self._repeat)
+
+ local gradTemp = self._gradTemp:select(2, outIdx)
+ gradTemp = gradTemp:reshape(1,gradTemp:size(1))
+ self._repeat:repeatTensor(gradTemp, input:size(2), 1)
+ self._temp:cmul(self._repeat)
+ self._sum:sum(self._temp, 1)
+ self.gradTemplates:select(2,outIdx):add(scale, self._sum)
+
+ -- gradDiagCov
+ local template = self.templates:select(2,outIdx)
+ template = template:reshape(1, template:size(1))
+ self._temp:repeatTensor(template, input:size(1), 1)
+ self._temp:add(-1,input)
+ self._temp:pow(2)
+ self._temp:cmul(self._repeat)
+ diagCov = diagCov:reshape(1, diagCov:size(1))
+ self._repeat:repeatTensor(self._diagCov, input:size(1), 1)
+ self._temp:cmul(self._repeat)
+ self._sum:sum(self._temp, 1)
+ self.gradDiagCov:select(2,outIdx):add(scale, self._sum)
+ end
+ else
+ error"1D or 2D input expected"
end
end
diff --git a/doc/simple.md b/doc/simple.md
index aa1a94d..8cbb017 100644
--- a/doc/simple.md
+++ b/doc/simple.md
@@ -262,7 +262,7 @@ to produce the output _y_.
<a name="nn.Mul"/>
## Mul ##
-`module` = `Mul(inputDimension)`
+`module` = `Mul()`
Applies a _single_ scaling factor to the incoming data, i.e.
_y= w x_, where _w_ is a scalar.
@@ -271,7 +271,7 @@ Example:
```lua
y=torch.Tensor(5);
mlp=nn.Sequential()
-mlp:add(nn.Mul(5))
+mlp:add(nn.Mul())
function gradUpdate(mlp, x, y, criterion, learningRate)
local pred = mlp:forward(x)
@@ -387,22 +387,30 @@ then an `nxq` matrix would be output.
<a name="nn.Euclidean"/>
## Euclidean ##
-`module` = `Euclidean(inputDimension,outputDimension)`
+`module` = `Euclidean(inputSize,outputSize)`
-Outputs the Euclidean distance of the input to `outputDimension` centers,
-i.e. this layer has the weights `c_i`, `i` = `1`,..,`outputDimension`, where
-`c_i` are vectors of dimension `inputDimension`. Output dimension `j` is
-`|| c_j - x ||`, where `x` is the input.
+Outputs the Euclidean distance of the input to `outputSize` centers,
+i.e. this layer has the weights `w_j`, for `j` = `1`,..,`outputSize`, where
+`w_j` are vectors of dimension `inputSize`.
+
+The distance `y_j` between center `j` and input `x` is formulated as
+`y_j = || w_j - x ||`.
<a name="nn.WeightedEuclidean"/>
## WeightedEuclidean ##
-`module` = `WeightedEuclidean(inputDimension,outputDimension)`
+`module` = `WeightedEuclidean(inputSize,outputSize)`
This module is similar to [Euclidean](#nn.Euclidean), but
additionally learns a separate diagonal covariance matrix across the
-features of the input space for each center.
+features of the input space _for each center_.
+
+In other words, for each of the `outputSize` centers `w_j`, there is
+a diagonal covariance matrix `c_j`, for `j` = `1`,..,`outputSize`,
+where each `c_j` is stored as a vector of size `inputSize`.
+The distance `y_j` between center `j` and input `x` is formulated as
+`y_j = || c_j * (w_j - x) ||`.
<a name="nn.Identity"/>
## Identity ##
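The reworked documentation above states the WeightedEuclidean distance as `y_j = || c_j * (w_j - x) ||`. A hedged by-hand check of that formula against the module, assuming `weight` aliases the templates and `bias` aliases `diagCov` as in the compatibility block of WeightedEuclidean.lua:

```lua
require 'nn'

local inputSize, outputSize = 4, 3
local m = nn.WeightedEuclidean(inputSize, outputSize)
local x = torch.randn(inputSize)
local y = m:forward(x)

local j = 2
local wj = m.weight:select(2, j)   -- center j (templates)
local cj = m.bias:select(2, j)     -- its diagonal covariance (diagCov)
local byHand = torch.cmul(cj, wj - x):norm()
print(y[j], byHand)                -- should agree up to floating-point error
```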
diff --git a/generic/SparseLinear.c b/generic/SparseLinear.c
index f39791b..b3ccbf1 100644
--- a/generic/SparseLinear.c
+++ b/generic/SparseLinear.c
@@ -2,6 +2,18 @@
#define TH_GENERIC_FILE "generic/SparseLinear.c"
#else
+static int nn_(checkInput)(THTensor* t) {
+ return t->nDimension == 2 && t->size[1] == 2;
+}
+
+static int nn_(checkSize2D)(THTensor* t, long size0, long size1) {
+ return t->nDimension == 2 && t->size[0] == size0 && t->size[1] == size1;
+}
+
+static int nn_(checkSize1D)(THTensor* t, long size0) {
+ return t->nDimension == 1 && t->size[0] == size0;
+}
+
static int nn_(SparseLinear_updateOutput)(lua_State *L)
{
long i;
@@ -9,27 +21,72 @@ static int nn_(SparseLinear_updateOutput)(lua_State *L)
THTensor * weight = luaT_getfieldcheckudata(L, 1, "weight", torch_Tensor);
THTensor * bias = luaT_getfieldcheckudata(L, 1, "bias", torch_Tensor);
THTensor * output = luaT_getfieldcheckudata(L, 1, "output", torch_Tensor);
- long dim = weight->size[1]; /* number of weights.. */
+
+ long outDim = weight->size[0];
+ long inDim = weight->size[1];
+
+ luaL_argcheck(L, nn_(checkInput)(input), 2, "input size must be nnz x 2");
+ luaL_argcheck(L, nn_(checkSize1D)(output, outDim), 1, "output size wrong");
+ luaL_argcheck(L, nn_(checkSize1D)(bias, outDim), 1, "bias size wrong");
+
+ lua_getfield(L, 1, "shardBuffer");
+ if (!lua_isnil(L, -1)) {
+ THTensor *buffer =
+ luaT_getfieldcheckudata(L, 1, "shardBuffer", torch_Tensor);
+ long num_shards = buffer->size[1];
+ luaL_argcheck(L,
+ buffer->nDimension == 2 && buffer->size[0] == outDim &&
+ num_shards > 0,
+ 1,
+ "shardBuffer size wrong");
+
+ THTensor_(zero)(buffer);
+ #pragma omp parallel for private(i) schedule(static) num_threads(num_shards)
+ for (i = 0; i < input->size[0]; i++) {
+ int shardId = omp_get_thread_num();
+ long offset = (long)(THTensor_(get2d)(input, i, 0)) - 1;
+
+ if (offset >= 0 && offset < inDim) {
+ THBlas_(axpy)(outDim,
+ THTensor_(get2d)(input, i, 1),
+ THTensor_(data)(weight) + offset * weight->stride[1],
+ weight->stride[0],
+ THTensor_(data)(buffer) + shardId * buffer->stride[1],
+ buffer->stride[0]);
+ } else {
+ luaL_error(L, "index out of bound. updateOutput: \
+%ld not between 1 and %ld", offset + 1, inDim);
+ }
+ }
+
+ THTensor_(sum)(output, buffer, 1);
+ THTensor_(cadd)(output, bias, 1.0, output);
+
+ lua_getfield(L, 1, "output");
+ return 1;
+ }
THTensor_(copy)(output, bias);
for(i = 0; i < input->size[0]; i++)
{
long offset = (long)(THTensor_(get2d)(input, i, 0)) - 1;
- if(offset >= 0 && offset < dim) /* make sure indices are in bounds.. */
+ if(offset >= 0 && offset < inDim) /* make sure indices are in bounds.. */
{
real val = THTensor_(get2d)(input, i, 1);
- THBlas_(axpy)(output->size[0],
- val,
+ THBlas_(axpy)(output->size[0],
+ val,
THTensor_(data)(weight)+offset*weight->stride[1],
- weight->stride[0],
- THTensor_(data)(output),
+ weight->stride[0],
+ THTensor_(data)(output),
output->stride[0]);
}
else {
- printf("\nupdateOutput: %ld not between 1 and %ld\n", offset+1, dim);
- luaL_error(L, "index out of bound");
+ luaL_error(L, "index out of bound. updateOutput: \
+%ld not between 1 and %ld", offset + 1, inDim);
}
}
+
+ lua_getfield(L, 1, "output");
return 1;
}
@@ -42,39 +99,47 @@ static int nn_(SparseLinear_accGradParameters)(lua_State *L)
THTensor * weight = luaT_getfieldcheckudata(L, 1, "weight", torch_Tensor);
THTensor * gradBias = luaT_getfieldcheckudata(L, 1, "gradBias", torch_Tensor);
THTensor * gradWeight = luaT_getfieldcheckudata(L, 1, "gradWeight", torch_Tensor);
- THTensor * lastInput = luaT_getfieldcheckudata(L, 1, "lastInput", torch_Tensor);
real weightDecay = luaT_getfieldchecknumber(L, 1, "weightDecay");
- long dim = gradWeight->size[1]; /* number of weights.. */
- for(i = 0; i < input->size[0]; i++)
+ long nnz = input->size[0];
+ long outDim = weight->size[0];
+ long inDim = weight->size[1];
+
+ luaL_argcheck(L, nn_(checkInput)(input), 2, "input size must be nnz x 2");
+ luaL_argcheck(
+ L, nn_(checkSize1D)(gradOutput, outDim), 3, "gradOutput size wrong");
+ luaL_argcheck(
+ L, nn_(checkSize2D)(gradWeight, outDim, inDim), 1, "gradWeight size wrong");
+ luaL_argcheck(
+ L, nn_(checkSize1D)(gradBias, outDim), 1, "gradBias size wrong");
+
+ #pragma omp parallel for private(i) schedule(static) if(outDim * nnz > 100000)
+ for(i = 0; i < nnz; i++)
{
long offset = (long)(THTensor_(get2d)(input, i, 0)) - 1;
- if(offset >= 0 && offset < dim) /* make sure indices are in bounds.. */
+ if(offset >= 0 && offset < inDim) /* make sure indices are in bounds.. */
{
real val = scale*THTensor_(get2d)(input, i, 1);
-
- THBlas_(axpy)(gradOutput->size[0],
- val,
- THTensor_(data)(gradOutput),
- gradOutput->stride[0],
- THTensor_(data)(gradWeight)+offset*gradWeight->stride[1],
+
+ THBlas_(axpy)(outDim,
+ val,
+ THTensor_(data)(gradOutput),
+ gradOutput->stride[0],
+ THTensor_(data)(gradWeight)+offset*gradWeight->stride[1],
gradWeight->stride[0]);
}
else {
- printf("\naccGradParameters: %ld not between 1 and %ld\n", offset+1, dim);
- luaL_error(L, "index out of bound");
+ luaL_error(L, "index out of bound. accGradParameters: \
+%ld not between 1 and %ld", offset + 1, inDim);
}
}
-
- THTensor_(cadd)(gradBias, gradBias, scale, gradOutput);
-
+
+ THTensor_(cadd)(gradBias, gradBias, scale, gradOutput);
+
if(weightDecay != 0)
THTensor_(cadd)(gradWeight, gradWeight, weightDecay, weight);
-
- THTensor_(resizeAs)(lastInput, input);
- THTensor_(copy)(lastInput, input);
-
+
return 0;
}
@@ -85,37 +150,137 @@ int nn_(SparseLinear_updateParameters)(lua_State *L)
THTensor * weight = luaT_getfieldcheckudata(L, 1, "weight", torch_Tensor);
THTensor * bias = luaT_getfieldcheckudata(L, 1, "bias", torch_Tensor);
THTensor * gradBias = luaT_getfieldcheckudata(L, 1, "gradBias", torch_Tensor);
- THTensor * gradWeight = luaT_getfieldcheckudata(L, 1, "gradWeight", torch_Tensor);
- THTensor * lastInput = luaT_getfieldcheckudata(L, 1, "lastInput", torch_Tensor);
-
- long dim = weight->size[1]; /* number of weights.. */
+ THTensor * gradWeight = luaT_getfieldcheckudata(
+ L, 1, "gradWeight", torch_Tensor);
+ THTensor * lastInput = luaT_getfieldcheckudata(
+ L, 1, "lastInput", torch_Tensor);
+
+ long nnz = lastInput->size[0];
+ long outDim = weight->size[0];
+ long inDim = weight->size[1];
+
+ luaL_argcheck(
+ L, nn_(checkSize2D)(gradWeight, outDim, inDim), 1, "gradWeight size wrong");
+ luaL_argcheck(
+ L, nn_(checkSize1D)(bias, outDim), 1, "bias size wrong");
+ luaL_argcheck(
+ L, nn_(checkSize1D)(gradBias, outDim), 1, "gradBias size wrong");
+
THTensor_(cadd)(bias, bias, -learningRate, gradBias);
-
- for(i = 0; i < lastInput->size[0]; i++)
+
+ #pragma omp parallel for private(i) schedule(static) if(outDim * nnz > 50000)
+ for(i = 0; i < nnz; i++)
{
long offset = (long)(THTensor_(get2d)(lastInput, i, 0)) - 1;
-
- if(offset >= 0 && offset < dim) /* make sure indices are in bounds.. */
+
+ if(offset >= 0 && offset < inDim) /* make sure indices are in bounds.. */
{
- THBlas_(axpy)(bias->size[0],
- -learningRate,
- THTensor_(data)(gradWeight)+offset*gradWeight->stride[1],
- gradWeight->stride[0],
- THTensor_(data)(weight)+offset*weight->stride[1],
+ real* pGradWeight =
+ THTensor_(data)(gradWeight)+offset*gradWeight->stride[1];
+ THBlas_(axpy)(outDim,
+ -learningRate,
+ pGradWeight,
+ gradWeight->stride[0],
+ THTensor_(data)(weight)+offset*weight->stride[1],
weight->stride[0]);
}
else {
- printf("\nupdateParameters: %ld not between 1 and %ld\n", offset+1, dim);
- luaL_error(L, "index out of bound");
+ luaL_error(L, "index out of bound. updateParameters: \
+%ld not between 1 and %ld", offset + 1, inDim);
+ }
+ }
+ return 0;
+}
+
+int nn_(SparseLinear_zeroGradParameters)(lua_State *L)
+{
+ long i;
+ THTensor * gradBias = luaT_getfieldcheckudata(L, 1, "gradBias", torch_Tensor);
+ THTensor * gradWeight = luaT_getfieldcheckudata(
+ L, 1, "gradWeight", torch_Tensor);
+ THTensor * lastInput = luaT_getfieldcheckudata(
+ L, 1, "lastInput", torch_Tensor);
+
+ long nnz = lastInput->size[0];
+ long outDim = gradWeight->size[0];
+ long inDim = gradWeight->size[1];
+
+ luaL_argcheck(
+ L, nn_(checkSize1D)(gradBias, outDim), 1, "gradBias size wrong");
+
+ THTensor_(zero)(gradBias);
+ #pragma omp parallel for private(i) schedule(static) if(outDim * nnz > 50000)
+ for(i = 0; i < nnz; i++)
+ {
+ long offset = (long)(THTensor_(get2d)(lastInput, i, 0)) - 1;
+
+ if(offset >= 0 && offset < inDim) /* make sure indices are in bounds.. */
+ {
+ real* pGradWeight =
+ THTensor_(data)(gradWeight)+offset*gradWeight->stride[1];
+ if(gradWeight->stride[0] == 1) {
+ THVector_(fill)(pGradWeight, 0, outDim);
+ } else {
+ long j;
+ for(j = 0; j < outDim; ++j) {
+ pGradWeight[j * gradWeight->stride[0]] = 0;
+ }
+ }
+ }
+ else {
+ luaL_error(L, "index out of bound. zeroGradParameters: \
+%ld not between 1 and %ld", offset + 1, inDim);
}
}
return 0;
}
+static int nn_(SparseLinear_updateGradInput)(lua_State *L) {
+ THTensor *weight = luaT_getfieldcheckudata(L, 1, "weight", torch_Tensor);
+ THTensor *gradInput =
+ luaT_getfieldcheckudata(L, 1, "gradInput", torch_Tensor);
+ THTensor *input = luaT_checkudata(L, 2, torch_Tensor);
+ THTensor *gradOutput = luaT_checkudata(L, 3, torch_Tensor);
+
+ long i;
+ long nnz = input->size[0];
+ long outDim = weight->size[0];
+ long inDim = weight->size[1];
+
+ luaL_argcheck(
+ L, nn_(checkInput)(input), 2, "input must be an nnz x 2 tensor");
+ luaL_argcheck(
+ L, nn_(checkSize1D)(gradOutput, outDim), 3, "gradOutput size wrong");
+
+ THTensor_(resize2d)(gradInput, input->size[0], input->size[1]);
+
+ #pragma omp parallel for private(i) schedule(static) if(outDim * nnz > 100000)
+ for (i = 0; i < nnz; ++i) {
+ long offset = (long)(THTensor_(get2d)(input, i, 0)) - 1;
+ THTensor_(set2d)(gradInput, i, 0, offset + 1);
+
+ if (offset >= 0 && offset < inDim) {
+ real val =
+ THBlas_(dot)(outDim,
+ THTensor_(data)(gradOutput),
+ gradOutput->stride[0],
+ THTensor_(data)(weight) + offset * weight->stride[1],
+ weight->stride[0]);
+ THTensor_(set2d)(gradInput, i, 1, val);
+ } else {
+ luaL_error(L, "index out of bound. updateGradInput: \
+%ld not between 1 and %ld", offset + 1, inDim);
+ }
+ }
+ return 0;
+}
+
static const struct luaL_Reg nn_(SparseLinear__) [] = {
{"SparseLinear_updateOutput", nn_(SparseLinear_updateOutput)},
{"SparseLinear_accGradParameters", nn_(SparseLinear_accGradParameters)},
{"SparseLinear_updateParameters", nn_(SparseLinear_updateParameters)},
+ {"SparseLinear_zeroGradParameters", nn_(SparseLinear_zeroGradParameters)},
+ {"SparseLinear_updateGradInput", nn_(SparseLinear_updateGradInput)},
{NULL, NULL}
};
diff --git a/generic/VolumetricConvolution.c b/generic/VolumetricConvolution.c
index feeaf05..bb30a70 100644
--- a/generic/VolumetricConvolution.c
+++ b/generic/VolumetricConvolution.c
@@ -13,22 +13,31 @@ static int nn_(VolumetricConvolution_updateOutput)(lua_State *L)
THTensor *bias = luaT_getfieldcheckudata(L, 1, "bias", torch_Tensor);
THTensor *output = luaT_getfieldcheckudata(L, 1, "output", torch_Tensor);
- luaL_argcheck(L, input->nDimension == 4, 2, "4D tensor expected");
-
- {
- long nOutputPlane = weight->size[0];
- long kT = weight->size[2];
- long kH = weight->size[3];
- long kW = weight->size[4];
- long inputDepth = input->size[1];
- long inputHeight = input->size[2];
- long inputWidth = input->size[3];
- long outputDepth = (inputDepth - kT) / dT + 1;
- long outputWidth = (inputWidth - kW) / dW + 1;
- long outputHeight = (inputHeight - kH) / dH + 1;
- THTensor *outn = THTensor_(new)();
- long i;
+ luaL_argcheck(L, input->nDimension == 4 || input->nDimension == 5,
+ 2, "4D or 5D (batch-mode) tensor expected");
+ int dimt = 1;
+ int dimh = 2;
+ int dimw = 3;
+
+ if (input->nDimension == 5) {
+ dimt++;
+ dimh++;
+ dimw++;
+ }
+ long nOutputPlane = weight->size[0];
+ long kT = weight->size[2];
+ long kH = weight->size[3];
+ long kW = weight->size[4];
+ long inputDepth = input->size[dimt];
+ long inputHeight = input->size[dimh];
+ long inputWidth = input->size[dimw];
+ long outputDepth = (inputDepth - kT) / dT + 1;
+ long outputWidth = (inputWidth - kW) / dW + 1;
+ long outputHeight = (inputHeight - kH) / dH + 1;
+ THTensor *outn = THTensor_(new)();
+ long i,j;
+ if (input->nDimension == 4) { /* non-batch mode */
THTensor_(resize4d)(output, nOutputPlane, outputDepth, outputHeight, outputWidth);
/* add bias */
@@ -37,18 +46,41 @@ static int nn_(VolumetricConvolution_updateOutput)(lua_State *L)
THTensor_(fill)(outn, THTensor_(get1d)(bias, i));
}
- THTensor_(free)(outn);
-
/* do convolutions */
THTensor_(conv3Dmv)(output, 1.0, 1.0, input, weight, dT, dH, dW, "V", "X");
- }
+ } else { /* batch mode */
+ long nBatch = input->size[0];
+ THTensor_(resize5d)(output, nBatch, nOutputPlane,
+ outputDepth, outputHeight, outputWidth);
+ THTensor *inb = THTensor_(new)();
+ THTensor *outb = THTensor_(new)();
+
+ for (j=0; j<nBatch; j++) { /* loop over batches */
+ THTensor_(select)(inb,input,0,j);
+ THTensor_(select)(outb,output,0,j);
+
+ /* add bias */
+ for (i=0; i<bias->size[0]; i++) {
+ THTensor_(select)(outn,outb,0,i);
+ THTensor_(fill)(outn, THTensor_(get1d)(bias, i));
+ }
+
+ /* do convolutions */
+ THTensor_(conv3Dmv)(outb, 1.0, 1.0, inb, weight, dT, dH, dW, "V", "X");
+ }
+ THTensor_(free)(inb);
+ THTensor_(free)(outb);
+ }
+ THTensor_(free)(outn);
+
return 1;
}
static int nn_(VolumetricConvolution_updateGradInput)(lua_State *L)
{
+ THTensor *input = luaT_checkudata(L, 2, torch_Tensor);
THTensor *gradOutput = luaT_checkudata(L, 3, torch_Tensor);
int dT = luaT_getfieldcheckint(L, 1, "dT");
int dW = luaT_getfieldcheckint(L, 1, "dW");
@@ -58,12 +90,37 @@ static int nn_(VolumetricConvolution_updateGradInput)(lua_State *L)
THTensor *weight = luaT_getfieldcheckudata(L, 1, "weight", torch_Tensor);
THTensor *gradInput = luaT_getfieldcheckudata(L, 1, "gradInput", torch_Tensor);
THTensor *tweight;
-
- THArgCheck( nOutputPlane == gradOutput->size[0], 1, "Number of output features is not equal to nOutputPlane" );
+
+ luaL_argcheck(L, gradOutput->nDimension == 4 || gradOutput->nDimension == 5,
+ 3, "4D or 5D (batch-mode) tensor expected");
+ int dimPlane = 0;
+ if (gradOutput->nDimension == 5) {
+ dimPlane++;
+ }
+ THArgCheck( nOutputPlane == gradOutput->size[dimPlane], 1,
+ "Number of output features is not equal to nOutputPlane" );
/* gradient to input */
tweight = THTensor_(newTranspose)(weight,0,1);
- THTensor_(conv3Dmv)(gradInput, 0.0, 1.0, gradOutput, tweight, dT, dH, dW, "F", "C");
+ if (gradOutput->nDimension == 4) { /* non-batch mode */
+ THTensor_(conv3Dmv)(gradInput, 0.0, 1.0, gradOutput, tweight, dT, dH, dW, "F", "C");
+ } else { /* batch mode */
+ long nBatch = gradOutput->size[0];
+ THTensor *ginpb = THTensor_(new)();
+ THTensor *goutb = THTensor_(new)();
+ long j;
+ THTensor_(resize5d)(gradInput, input->size[0], input->size[1], input->size[2],
+ input->size[3], input->size[4]);
+
+ for (j=0; j<nBatch; j++) { /* loop over batches */
+ THTensor_(select)(ginpb,gradInput,0,j);
+ THTensor_(select)(goutb,gradOutput,0,j);
+ THTensor_(conv3Dmv)(ginpb, 0.0, 1.0, goutb, tweight, dT, dH, dW, "F", "C");
+ }
+ THTensor_(free)(ginpb);
+ THTensor_(free)(goutb);
+ }
+
THTensor_(free)(tweight);
return 1;
@@ -72,7 +129,7 @@ static int nn_(VolumetricConvolution_updateGradInput)(lua_State *L)
static int nn_(VolumetricConvolution_accGradParameters)(lua_State *L)
{
THTensor *input = luaT_checkudata(L, 2, torch_Tensor);
- THTensor *gradOutput = luaT_checkudata(L, 3, torch_Tensor);
+ THTensor *gradOutput = luaT_checkudata(L, 3, torch_Tensor);
real scale = luaL_optnumber(L, 4, 1);
int dT = luaT_getfieldcheckint(L, 1, "dT");
int dW = luaT_getfieldcheckint(L, 1, "dW");
@@ -85,21 +142,54 @@ static int nn_(VolumetricConvolution_accGradParameters)(lua_State *L)
long k;
real *gradBias_data;
THTensor* gradOutSlice;
-
- THArgCheck( nOutputPlane == gradOutput->size[0], 1, "Number of output features is not equal to nOutputPlane" );
-
- /* gradient to bias */
- gradBias_data = THTensor_(data)(gradBias);
- gradOutSlice = THTensor_(new)();
- for(k = 0; k < nOutputPlane; k++)
- {
- THTensor_(select)(gradOutSlice, gradOutput, 0, k);
- gradBias_data[k] += scale*THTensor_(sumall)(gradOutSlice);
+ int dimPlane = 0;
+ if (gradOutput->nDimension == 5) {
+ dimPlane++;
}
- THTensor_(free)(gradOutSlice);
+
+ THArgCheck( nOutputPlane == gradOutput->size[dimPlane], 1,
+ "Number of output features is not equal to nOutputPlane" );
- /* gradient to kernels */
- THTensor_(conv3DRevger)(gradWeight, 1.0, scale, input, gradOutput, dT, dH, dW);
+
+ if (gradOutput->nDimension == 4) { /* non-batch mode */
+ /* gradient to bias */
+ gradBias_data = THTensor_(data)(gradBias);
+ gradOutSlice = THTensor_(new)();
+ for(k = 0; k < nOutputPlane; k++)
+ {
+ THTensor_(select)(gradOutSlice, gradOutput, 0, k);
+ gradBias_data[k] += scale*THTensor_(sumall)(gradOutSlice);
+ }
+ THTensor_(free)(gradOutSlice);
+
+ /* gradient to kernels */
+ THTensor_(conv3DRevger)(gradWeight, 1.0, scale, input, gradOutput, dT, dH, dW);
+ } else { /* batch mode */
+ long nBatch = gradOutput->size[0];
+ THTensor *inpb = THTensor_(new)();
+ THTensor *goutb = THTensor_(new)();
+ long j;
+
+ for (j=0; j<nBatch; j++) { /* loop over batches */
+ THTensor_(select)(inpb,input,0,j);
+ THTensor_(select)(goutb,gradOutput,0,j);
+
+ /* gradient to bias */
+ gradBias_data = THTensor_(data)(gradBias);
+ gradOutSlice = THTensor_(new)();
+ for(k = 0; k < nOutputPlane; k++)
+ {
+ THTensor_(select)(gradOutSlice, goutb, 0, k);
+ gradBias_data[k] += scale*THTensor_(sumall)(gradOutSlice);
+ }
+ THTensor_(free)(gradOutSlice);
+
+ /* gradient to kernels */
+ THTensor_(conv3DRevger)(gradWeight, 1.0, scale, inpb, goutb, dT, dH, dW);
+ }
+ THTensor_(free)(inpb);
+ THTensor_(free)(goutb);
+ }
return 0;
}
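VolumetricConvolution now accepts 5D (batch) input on the CPU path as well, looping conv3Dmv over the leading batch dimension; output sizes follow the formula in the code, out = (in - k)/d + 1. A size sketch under the assumption that the strides default to 1 (sizes below are illustrative):

```lua
require 'nn'

local conv = nn.VolumetricConvolution(2, 4, 3, 3, 3)   -- kT = kW = kH = 3
local single = torch.randn(2, 8, 8, 8)      -- nInputPlane x T x H x W
local batch  = torch.randn(5, 2, 8, 8, 8)   -- nBatch x nInputPlane x T x H x W

print(conv:forward(single):size())  -- 4 x 6 x 6 x 6   ((8 - 3)/1 + 1 = 6)
print(conv:forward(batch):size())   -- 5 x 4 x 6 x 6 x 6
```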
diff --git a/init.lua b/init.lua
index 0ab72a8..c556321 100644
--- a/init.lua
+++ b/init.lua
@@ -4,6 +4,7 @@ require('libnn')
include('ErrorMessages.lua')
include('Module.lua')
+include('Container.lua')
include('Concat.lua')
include('Parallel.lua')
include('Sequential.lua')
diff --git a/test.lua b/test.lua
index 613d31c..a541a1b 100644
--- a/test.lua
+++ b/test.lua
@@ -76,6 +76,7 @@ function nntest.CMul()
local input = torch.Tensor(ini,inj,ink):zero()
local module = nn.CMul(ini*inj*ink)
+ -- 1D
local err = jac.testJacobian(module,input)
mytester:assertlt(err,precision, 'error on state ')
@@ -90,6 +91,26 @@ function nntest.CMul()
'error on weight [%s]', t))
end
+ -- 2D
+ local nframe = math.random(50,70)
+ local nframe = 5
+ local input = torch.Tensor(nframe, ini,inj,ink):zero()
+
+ local err = jac.testJacobian(module,input)
+ mytester:assertlt(err,precision, 'error on state ')
+
+ local err = jac.testJacobianParameters(module, input, module.weight, module.gradWeight)
+ mytester:assertlt(err,precision, 'error on weight ')
+
+ local err = jac.testJacobianUpdateParameters(module, input, module.weight)
+ mytester:assertlt(err,precision, 'error on weight [direct update] ')
+
+ for t,err in pairs(jac.testAllUpdate(module, input, 'weight', 'gradWeight')) do
+ mytester:assertlt(err, precision, string.format('error on weight [%s]', t))
+ end
+
+
+ -- IO
local ferr,berr = jac.testIO(module,input)
mytester:asserteq(ferr, 0, torch.typename(module) .. ' - i/o forward err ')
mytester:asserteq(berr, 0, torch.typename(module) .. ' - i/o backward err ')
@@ -469,6 +490,41 @@ function nntest.WeightedEuclidean()
local ferr,berr = jac.testIO(module,input)
mytester:asserteq(ferr, 0, torch.typename(module) .. ' - i/o forward err ')
mytester:asserteq(berr, 0, torch.typename(module) .. ' - i/o backward err ')
+
+ -- test batch
+ local bs = math.random(3,5)
+ input:uniform(0,1)
+ local output = module:forward(input):clone()
+ module:zeroGradParameters()
+ local gradInput = module:backward(input, output):clone()
+ local params, gradParams = module:parameters()
+ for i=1,#params do
+ params[i] = params[i]:clone()
+ end
+ local input2 = input:view(1, -1):repeatTensor(bs, 1)
+ local output2 = module:forward(input2)
+ module:zeroGradParameters()
+ local gradInput2 = module:backward(input2, output2, 1/bs)
+ local params2, gradParams2 = module:parameters()
+ mytester:assertTensorEq(output2[bs-1], output, 0.000001, "error in batch updateOutput")
+ mytester:assertTensorEq(gradInput2[bs-1], gradInput, 0.000001, "error in batch updateGradInput")
+ mytester:assertTensorEq(gradParams[1], gradParams2[1], 0.000001, "error in batch accGradParameters (gradTemplates)")
+ mytester:assertTensorEq(gradParams[2], gradParams2[2], 0.000001, "error in batch accGradParameters (gradDiagCov)")
+
+ input:zero()
+ module:zeroGradParameters()
+ local err = jac.testJacobian(module,input)
+ mytester:assertlt(err,precision, 'error on state ')
+
+ local err = jac.testJacobianParameters(module, input, module.weight, module.gradWeight)
+ mytester:assertlt(err,precision, 'error on weight ')
+
+ local err = jac.testJacobianParameters(module, input, module.bias, module.gradBias)
+ mytester:assertlt(err,precision, 'error on bias ')
+
+ local ferr,berr = jac.testIO(module,input2)
+ mytester:asserteq(ferr, 0, torch.typename(module) .. ' - i/o forward err ')
+ mytester:asserteq(berr, 0, torch.typename(module) .. ' - i/o backward err ')
end
local function criterionJacobianTest1D(cri, input, target)
@@ -671,7 +727,7 @@ function nntest.Mul()
local inj = math.random(3,5)
local ink = math.random(3,5)
local input = torch.Tensor(ini,inj,ink):zero()
- local module = nn.Mul(ini*inj*ink)
+ local module = nn.Mul()
local err = jac.testJacobian(module,input)
mytester:assertlt(err,precision, 'error on state ')
@@ -1242,9 +1298,9 @@ function nntest.SpatialFullConvolutionCompare()
end
local function batchcompare(smod, sin, plist)
- local bs = torch.LongStorage(sin:size():size()+1)
+ local bs = torch.LongStorage(sin:dim()+1)
bs[1] = 1
- for i=1,sin:size():size() do bs[i+1] = sin:size()[i] end
+ for i=1,sin:dim() do bs[i+1] = sin:size()[i] end
local bin = torch.Tensor(bs):copy(sin)
local bmod = smod:clone()
@@ -1769,6 +1825,26 @@ function nntest.VolumetricConvolution()
mytester:asserteq(0, berr, torch.typename(module) .. ' - i/o backward err ')
end
+function nntest.VolumetricConvolutionBatchCompare()
+ local from = math.random(2,3)
+ local to = math.random(2,3)
+ local kt = math.random(3,4)
+ local ki = math.random(3,4)
+ local kj = math.random(3,4)
+ local st = math.random(2,3)
+ local si = math.random(2,3)
+ local sj = math.random(2,3)
+ local outt = math.random(3,4)
+ local outi = math.random(3,4)
+ local outj = math.random(3,4)
+ local int = (outt-1)*st+kt
+ local ini = (outi-1)*si+ki
+ local inj = (outj-1)*sj+kj
+ local module = nn.VolumetricConvolution(from, to, kt, ki, kj, st, si, sj)
+ local input = torch.randn(from, int, inj, ini)
+ batchcompare(module,input, {'weight','bias','gradWeight','gradBias'})
+end
+
function nntest.VolumetricMaxPooling()
local from = math.random(2,3)
local kt = math.random(3,4)