
github.com/torch/nn.git
author     nicholas-leonard <nick@nikopia.org>   2015-01-04 20:49:21 +0300
committer  nicholas-leonard <nick@nikopia.org>   2015-01-10 00:43:06 +0300
commit     517c6c0e36046f13167a4b77b06c464747ddc0fc (patch)
tree       563a836e9232debec0ba88742bf8398b15ac2267
parent     27acf6315e30181936a309e9831e18baec1a3f28 (diff)
WeightedEuclidean optimizations
-rw-r--r--  WeightedEuclidean.lua  | 308
-rw-r--r--  test.lua               |  93
2 files changed, 256 insertions, 145 deletions
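
For context, a minimal usage sketch of the module this commit optimizes (assuming Torch7 with the nn package; sizes are illustrative):

    require 'nn'

    -- nn.WeightedEuclidean(inputSize, outputSize) holds outputSize template
    -- vectors w_j (self.weight) plus per-dimension scales c_j (self.diagCov),
    -- and outputs the scaled distances y_j = || c_j * (w_j - x) ||.
    local module = nn.WeightedEuclidean(5, 3)

    local y  = module:forward(torch.randn(5))     -- 1D input: 3 distances
    local yb = module:forward(torch.randn(8, 5))  -- 2D batch: 8 x 3 distances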
diff --git a/WeightedEuclidean.lua b/WeightedEuclidean.lua
index 5a3af27..071203e 100644
--- a/WeightedEuclidean.lua
+++ b/WeightedEuclidean.lua
@@ -3,20 +3,12 @@ local WeightedEuclidean, parent = torch.class('nn.WeightedEuclidean', 'nn.Module
function WeightedEuclidean:__init(inputSize,outputSize)
parent.__init(self)
- self.templates = torch.Tensor(inputSize,outputSize)
- self.gradTemplates = torch.Tensor(inputSize,outputSize)
+ self.weight = torch.Tensor(inputSize,outputSize)
+ self.gradWeight = torch.Tensor(inputSize,outputSize)
-- each template (output dim) has its own diagonal covariance matrix
self.diagCov = torch.Tensor(inputSize,outputSize)
self.gradDiagCov = torch.Tensor(inputSize,outputSize)
-
- -- for compat with Torch's modules (it's bad we have to do that)
- do
- self.weight = self.templates
- self.gradWeight = self.gradTemplates
- self.bias = self.diagCov
- self.gradBias = self.gradDiagCov
- end
self:reset()
end
@@ -25,55 +17,71 @@ function WeightedEuclidean:reset(stdv)
if stdv then
stdv = stdv * math.sqrt(3)
else
- stdv = 1./math.sqrt(self.templates:size(1))
+ stdv = 1./math.sqrt(self.weight:size(1))
end
- if nn.oldSeed then
- for i=1,self.templates:size(2) do
- self.templates:select(2, i):apply(function()
- return torch.uniform(-stdv, stdv)
- end)
- end
+ self.weight:uniform(-stdv, stdv)
+ self.diagCov:fill(1)
+end
+
+local function view(res, src, ...)
+ local args = {...}
+ if src:isContiguous() then
+ res:view(src, unpack(args))
else
- self.templates:uniform(-stdv, stdv)
+ res:reshape(src, unpack(args))
end
- self.diagCov:fill(1)
end
function WeightedEuclidean:updateOutput(input)
-- lazy-initialize
- self._temp = self._temp or self.output.new()
- self._ones = self._ones or self.output.new{1}
self._diagCov = self._diagCov or self.output.new()
+
+ self._input = self._input or input.new()
+ self._weight = self._weight or self.weight.new()
+ self._expand = self._expand or self.output.new()
+ self._expand2 = self._expand2 or self.output.new()
+ self._expand3 = self._expand3 or self.output.new()
self._repeat = self._repeat or self.output.new()
- self._sum = self._sum or self.output.new()
- self._temp:resizeAs(input)
+ self._repeat2 = self._repeat2 or self.output.new()
+ self._repeat3 = self._repeat3 or self.output.new()
+
+ local inputSize, outputSize = self.weight:size(1), self.weight:size(2)
+
+ -- y_j = || c_j * (w_j - x) ||
if input:dim() == 1 then
- self.output:resize(self.templates:size(2))
- for outIdx = 1,self.templates:size(2) do
- self._temp:copy(input):add(-1,self.templates:select(2,outIdx))
- self._temp:cmul(self._temp)
- local diagCov = self.diagCov:select(2,outIdx)
- self._temp:cmul(diagCov):cmul(diagCov)
- self.output[outIdx] = math.sqrt(self._temp:sum())
- end
+ view(self._input, input, inputSize, 1)
+ self._expand:expandAs(self._input, self.weight)
+ self._repeat:resizeAs(self._expand):copy(self._expand)
+ self._repeat:add(-1, self.weight)
+ self._repeat:cmul(self.diagCov)
+ self.output:norm(self._repeat, 2, 1)
+ self.output:resize(outputSize)
elseif input:dim() == 2 then
- self.output:resize(input:size(1), self.templates:size(2))
- if self._ones:size(1) ~= input:size(1) then
- self._ones:resize(input:size(1)):fill(1)
- end
- for outIdx = 1,self.templates:size(2) do
- self._temp:copy(input)
- self._temp:addr(-1, self._ones, self.templates:select(2,outIdx))
- self._temp:cmul(self._temp)
- local diagCov = self.diagCov:select(2,outIdx)
- self._diagCov:resizeAs(diagCov):copy(diagCov)
- self._diagCov:pow(2)
- self._diagCov:resize(1,self._diagCov:size(1))
- self._repeat:repeatTensor(self._diagCov, input:size(1), 1)
- self._temp:cmul(self._temp, self._repeat)
- self._sum:sum(self._temp, 2):sqrt()
- self.output:select(2,outIdx):copy(self._sum)
+ local batchSize = input:size(1)
+
+ view(self._input, input, batchSize, inputSize, 1)
+ self._expand:expand(self._input, batchSize, inputSize, outputSize)
+ -- make the expanded tensor contiguous (requires lots of memory)
+ self._repeat:resizeAs(self._expand):copy(self._expand)
+
+ self._weight:view(self.weight, 1, inputSize, outputSize)
+ self._expand2:expandAs(self._weight, self._repeat)
+
+ self._diagCov:view(self.diagCov, 1, inputSize, outputSize)
+ self._expand3:expandAs(self._diagCov, self._repeat)
+ if torch.type(input) == 'torch.CudaTensor' then
+ -- requires lots of memory, but minimizes cudaMallocs and loops
+ self._repeat2:resizeAs(self._expand2):copy(self._expand2)
+ self._repeat:add(-1, self._repeat2)
+ self._repeat3:resizeAs(self._expand3):copy(self._expand3)
+ self._repeat:cmul(self._repeat3)
+ else
+ self._repeat:add(-1, self._expand2)
+ self._repeat:cmul(self._expand3)
end
+
+ self.output:norm(self._repeat, 2, 2)
+ self.output:resize(batchSize, outputSize)
else
error"1D or 2D input expected"
end
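
The batched branch above replaces the old per-template loop with broadcasting: view() adds a singleton dimension, expand() stretches it as a zero-stride view, and a single copy makes it contiguous so the remaining add/cmul/norm calls each run once over the whole batch. A sketch of the idiom in isolation (sizes are illustrative):

    -- Broadcast a batch of inputs against all templates at once.
    local batchSize, inputSize, outputSize = 4, 5, 3
    local input   = torch.randn(batchSize, inputSize)
    local weight  = torch.randn(inputSize, outputSize)
    local diagCov = torch.Tensor(inputSize, outputSize):fill(1)

    local expanded = input:view(batchSize, inputSize, 1)
                          :expand(batchSize, inputSize, outputSize) -- zero-stride, no copy
    local rep = expanded:clone()                                    -- one contiguous copy
    rep:add(-1, weight:view(1, inputSize, outputSize):expandAs(rep))  -- x - w_j
    rep:cmul(diagCov:view(1, inputSize, outputSize):expandAs(rep))    -- c_j * (x - w_j)
    local y = rep:norm(2, 2):resize(batchSize, outputSize)  -- y_j = || c_j * (x - w_j) ||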
@@ -81,96 +89,156 @@ function WeightedEuclidean:updateOutput(input)
end
function WeightedEuclidean:updateGradInput(input, gradOutput)
- self._gradTemp = self._gradTemp or self.output.new()
- self.gradInput:resizeAs(input):zero()
- self._temp:resizeAs(input)
- self._gradTemp:cdiv(gradOutput, self.output)
+ if not self.gradInput then
+ return
+ end
+
+ self._div = self._div or input.new()
+ self._output = self._output or self.output.new()
+ self._expand4 = self._expand4 or input.new()
+ self._gradOutput = self._gradOutput or input.new()
+
+ if not self.fastBackward then
+ self:updateOutput(input)
+ end
+
+ local inputSize, outputSize = self.weight:size(1), self.weight:size(2)
+
+   --[[
+   dy_j   -2 * c_j * c_j * (w_j - x)   c_j * c_j * (x - w_j)
+   ---- = -------------------------- = ---------------------
+    dx     2 || c_j * (w_j - x) ||             y_j
+   --]]
+
+ -- to prevent div by zero (NaN) bugs
+ self._output:resizeAs(self.output):copy(self.output):add(0.0000001)
+ view(self._gradOutput, gradOutput, gradOutput:size())
+ self._div:cdiv(gradOutput, self._output)
if input:dim() == 1 then
- for outIdx = 1,self.templates:size(2) do
- self._temp:copy(input):add(-1,self.templates:select(2,outIdx))
- local diagCov = self.diagCov:select(2,outIdx)
- self._temp:cmul(diagCov):cmul(diagCov)
-
- self._temp:mul(self._gradTemp[outIdx])
- self.gradInput:add(self._temp)
+ self._div:resize(1, outputSize)
+ self._expand4:expandAs(self._div, self.weight)
+
+ if torch.type(input) == 'torch.CudaTensor' then
+ self._repeat2:resizeAs(self._expand4):copy(self._expand4)
+ self._repeat2:cmul(self._repeat)
+ else
+ self._repeat2:cmul(self._repeat, self._expand4)
end
+
+ self._repeat2:cmul(self.diagCov)
+ self.gradInput:sum(self._repeat2, 2)
+ self.gradInput:resizeAs(input)
elseif input:dim() == 2 then
- if self._ones:size(1) ~= input:size(1) then
- self._ones:resize(input:size(1)):fill(1)
- end
- for outIdx = 1,self.templates:size(2) do
- self._temp:copy(input)
- self._temp:addr(-1, self._ones, self.templates:select(2,outIdx))
- local diagCov = self.diagCov:select(2,outIdx)
- self._diagCov:resizeAs(diagCov):copy(diagCov)
- self._diagCov:pow(2)
- self._diagCov:resize(1,self._diagCov:size(1))
- self._repeat:repeatTensor(self._diagCov, input:size(1), 1)
- self._temp:cmul(self._temp, self._repeat)
-
- local gradTemp = self._gradTemp:select(2, outIdx)
- gradTemp = gradTemp:reshape(1,gradTemp:size(1))
- self._repeat:repeatTensor(gradTemp, input:size(2), 1)
- self.gradInput:addcmul(1, self._temp, self._repeat)
+ local batchSize = input:size(1)
+
+ self._div:resize(batchSize, 1, outputSize)
+ self._expand4:expand(self._div, batchSize, inputSize, outputSize)
+
+ if torch.type(input) == 'torch.CudaTensor' then
+ self._repeat2:resizeAs(self._expand4):copy(self._expand4)
+ self._repeat2:cmul(self._repeat)
+ self._repeat2:cmul(self._repeat3)
+ else
+ self._repeat2:cmul(self._repeat, self._expand4)
+ self._repeat2:cmul(self._expand3)
end
+
+ self.gradInput:sum(self._repeat2, 3)
+ self.gradInput:resizeAs(input)
else
error"1D or 2D input expected"
end
+
return self.gradInput
end
function WeightedEuclidean:accGradParameters(input, gradOutput, scale)
+ local inputSize, outputSize = self.weight:size(1), self.weight:size(2)
scale = scale or 1
- self._temp:resizeAs(input)
- self._gradTemp:cdiv(gradOutput, self.output)
+
+   --[[
+   dy_j    2 * c_j * c_j * (w_j - x)   c_j * c_j * (w_j - x)
+   ---- = --------------------------- = ---------------------
+   dw_j    2 || c_j * (w_j - x) ||              y_j
+
+   dy_j    2 * c_j * (w_j - x)^2     c_j * (w_j - x)^2
+   ---- = ----------------------- = -----------------
+   dc_j   2 || c_j * (w_j - x) ||          y_j
+   --]]
+ -- assumes a preceding call to updateGradInput
if input:dim() == 1 then
- for outIdx = 1,self.templates:size(2) do
- self._temp:copy(self.templates:select(2,outIdx)):add(-1,input)
- local diagCov = self.diagCov:select(2,outIdx)
- self._temp:cmul(diagCov):cmul(diagCov)
-
- self._temp:mul(self._gradTemp[outIdx])
- self.gradTemplates:select(2,outIdx):add(scale, self._temp)
-
- self._temp:copy(self.templates:select(2,outIdx)):add(-1,input)
- self._temp:pow(2)
- self._temp:cmul(self.diagCov:select(2,outIdx))
- self._temp:mul(self._gradTemp[outIdx])
- self.gradDiagCov:select(2,outIdx):add(scale, self._temp)
+ self.gradWeight:add(-scale, self._repeat2)
+
+ self._repeat:cdiv(self.diagCov)
+ self._repeat:cmul(self._repeat)
+ self._repeat:cmul(self.diagCov)
+
+ if torch.type(input) == 'torch.CudaTensor' then
+ self._repeat2:resizeAs(self._expand4):copy(self._expand4)
+ self._repeat2:cmul(self._repeat)
+ else
+ self._repeat2:cmul(self._repeat, self._expand4)
end
+
+ self.gradDiagCov:add(scale, self._repeat2)
elseif input:dim() == 2 then
- for outIdx = 1,self.templates:size(2) do
- -- gradTemplates
- self._temp:copy(input)
- self._temp:addr(-1, self._ones, self.templates:select(2,outIdx))
- local diagCov = self.diagCov:select(2,outIdx)
- self._diagCov:resizeAs(diagCov):copy(diagCov)
- self._diagCov:pow(2)
- self._diagCov:resize(1,self._diagCov:size(1))
- self._repeat:repeatTensor(self._diagCov, input:size(1), 1)
- self._temp:cmul(self._temp, self._repeat)
-
- local gradTemp = self._gradTemp:select(2, outIdx)
- gradTemp = gradTemp:reshape(1,gradTemp:size(1))
- self._repeat:repeatTensor(gradTemp, input:size(2), 1)
- self._temp:cmul(self._repeat)
- self._sum:sum(self._temp, 1)
- self.gradTemplates:select(2,outIdx):add(scale, self._sum)
-
- -- gradDiagCov
- local template = self.templates:select(2,outIdx)
- template = template:reshape(1, template:size(1))
- self._temp:repeatTensor(template, input:size(1), 1)
- self._temp:add(-1,input)
- self._temp:pow(2)
- self._temp:cmul(self._repeat)
- diagCov = diagCov:reshape(1, diagCov:size(1))
- self._repeat:repeatTensor(self._diagCov, input:size(1), 1)
- self._temp:cmul(self._repeat)
- self._sum:sum(self._temp, 1)
- self.gradDiagCov:select(2,outIdx):add(scale, self._sum)
+ self._sum = self._sum or input.new()
+ self._sum:sum(self._repeat2, 1)
+ self._sum:resize(inputSize, outputSize)
+ self.gradWeight:add(-scale, self._sum)
+
+ if torch.type(input) == 'torch.CudaTensor' then
+ -- requires lots of memory, but minimizes cudaMallocs and loops
+ self._repeat:cdiv(self._repeat3)
+ self._repeat:cmul(self._repeat)
+ self._repeat:cmul(self._repeat3)
+ self._repeat2:resizeAs(self._expand4):copy(self._expand4)
+ self._repeat:cmul(self._repeat2)
+ else
+ self._repeat:cdiv(self._expand3)
+ self._repeat:cmul(self._repeat)
+ self._repeat:cmul(self._expand3)
+ self._repeat:cmul(self._expand4)
end
+
+ self._sum:sum(self._repeat, 1)
+ self._sum:resize(inputSize, outputSize)
+ self.gradDiagCov:add(scale, self._sum)
else
error"1D or 2D input expected"
end
end
+
+function WeightedEuclidean:type(type)
+ if type then
+ -- prevent premature memory allocations
+ self._input = nil
+ self._output = nil
+ self._gradOutput = nil
+ self._weight = nil
+ self._div = nil
+ self._sum = nil
+ self._expand = nil
+ self._expand2 = nil
+ self._expand3 = nil
+ self._expand4 = nil
+ self._repeat = nil
+ self._repeat2 = nil
+ self._repeat3 = nil
+ end
+ return parent.type(self, type)
+end
+
+function WeightedEuclidean:parameters()
+ return {self.weight, self.diagCov}, {self.gradWeight, self.gradDiagCov}
+end
+
+function WeightedEuclidean:accUpdateGradParameters(input, gradOutput, lr)
+ local gradWeight = self.gradWeight
+ local gradDiagCov = self.gradDiagCov
+ self.gradWeight = self.weight
+ self.gradDiagCov = self.diagCov
+ self:accGradParameters(input, gradOutput, -lr)
+ self.gradWeight = gradWeight
+ self.gradDiagCov = gradDiagCov
+end
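
The parameters() override exposes the weight/diagCov pairs to Torch's generic parameter flatteners, and accUpdateGradParameters() uses a standard Torch trick: aliasing the gradient buffers to the parameters themselves so that accGradParameters called with scale = -lr writes the SGD step directly into the weights. A hedged sketch of the expected call sequence (note that accGradParameters assumes updateGradInput ran first, per the comment above):

    local module = nn.WeightedEuclidean(5, 3)
    local x, gradOut = torch.randn(5), torch.randn(3)

    module:forward(x)
    module:updateGradInput(x, gradOut)              -- fills the cached buffers
    module:accUpdateGradParameters(x, gradOut, 0.1) -- weight and diagCov step in place;
                                                    -- gradWeight/gradDiagCov untouched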
diff --git a/test.lua b/test.lua
index 27c3dde..ad48648 100644
--- a/test.lua
+++ b/test.lua
@@ -495,44 +495,87 @@ function nntest.Euclidean()
end
function nntest.WeightedEuclidean()
- local ini = math.random(3,5)
- local inj = math.random(13,5)
- local input = torch.Tensor(ini):zero()
+ local ini = math.random(5,7)
+ local inj = math.random(5,7)
+ local input = torch.randn(ini)
+ local gradOutput = torch.randn(inj)
local module = nn.WeightedEuclidean(ini,inj)
+ local output = module:forward(input):clone()
+
+ local output2 = torch.Tensor(inj):zero()
+ local temp = input:clone()
+ for o = 1,module.weight:size(2) do
+ temp:copy(input):add(-1,module.weight:select(2,o))
+ temp:cmul(temp)
+ temp:cmul(module.diagCov:select(2,o)):cmul(module.diagCov:select(2,o))
+ output2[o] = math.sqrt(temp:sum())
+ end
+ mytester:assertTensorEq(output, output2, 0.000001, 'WeightedEuclidean forward 1D err')
+
+ local input2 = torch.randn(8, ini)
+ input2[2]:copy(input)
+ local output2 = module:forward(input2)
+ mytester:assertTensorEq(output2[2], output, 0.000001, 'WeightedEuclidean forward 2D err')
+
+ local output = module:forward(input):clone()
+ module:zeroGradParameters()
+ local gradInput = module:backward(input, gradOutput, 1):clone()
+ local gradInput2 = torch.zeros(ini)
+ for o = 1,module.weight:size(2) do
+ temp:copy(input)
+ temp:add(-1,module.weight:select(2,o))
+ temp:cmul(module.diagCov:select(2,o)):cmul(module.diagCov:select(2,o))
+ temp:mul(gradOutput[o]/output[o])
+ gradInput2:add(temp)
+ end
+ mytester:assertTensorEq(gradInput, gradInput2, 0.000001, 'WeightedEuclidean updateGradInput 1D err')
+
+ local gradWeight = module.gradWeight:clone():zero()
+ local gradDiagCov = module.gradDiagCov:clone():zero()
+ for o = 1,module.weight:size(2) do
+ if output[o] ~= 0 then
+ temp:copy(module.weight:select(2,o)):add(-1,input)
+ temp:cmul(module.diagCov:select(2,o)):cmul(module.diagCov:select(2,o))
+ temp:mul(gradOutput[o]/output[o])
+ gradWeight:select(2,o):add(temp)
+
+ temp:copy(module.weight:select(2,o)):add(-1,input)
+ temp:cmul(temp)
+ temp:cmul(module.diagCov:select(2,o))
+ temp:mul(gradOutput[o]/output[o])
+ gradDiagCov:select(2,o):add(temp)
+ end
+ end
+ mytester:assertTensorEq(gradWeight, module.gradWeight, 0.000001, 'WeightedEuclidean accGradParameters gradWeight 1D err')
+ mytester:assertTensorEq(gradDiagCov, module.gradDiagCov, 0.000001, 'WeightedEuclidean accGradParameters gradDiagCov 1D err')
+
+ local input2 = input:view(1, -1):repeatTensor(8, 1)
+ local gradOutput2 = gradOutput:view(1, -1):repeatTensor(8, 1)
+ local output2 = module:forward(input2)
+ module:zeroGradParameters()
+ local gradInput2 = module:backward(input2, gradOutput2, 1/8)
+ mytester:assertTensorEq(gradInput2[2], gradInput, 0.000001, 'WeightedEuclidean updateGradInput 2D err')
+
+ mytester:assertTensorEq(gradWeight, module.gradWeight, 0.000001, 'WeightedEuclidean accGradParameters gradWeight 2D err')
+ mytester:assertTensorEq(gradDiagCov, module.gradDiagCov, 0.000001, 'WeightedEuclidean accGradParameters gradDiagCov 2D err')
+
+ input:zero()
+ module.fastBackward = false
+
local err = jac.testJacobian(module,input)
mytester:assertlt(err,precision, 'error on state ')
local err = jac.testJacobianParameters(module, input, module.weight, module.gradWeight)
mytester:assertlt(err,precision, 'error on weight ')
- local err = jac.testJacobianParameters(module, input, module.bias, module.gradBias)
+ local err = jac.testJacobianParameters(module, input, module.diagCov, module.gradDiagCov)
mytester:assertlt(err,precision, 'error on bias ')
local ferr,berr = jac.testIO(module,input)
mytester:asserteq(ferr, 0, torch.typename(module) .. ' - i/o forward err ')
mytester:asserteq(berr, 0, torch.typename(module) .. ' - i/o backward err ')
- -- test batch
- local bs = math.random(3,5)
- input:uniform(0,1)
- local output = module:forward(input):clone()
- module:zeroGradParameters()
- local gradInput = module:backward(input, output):clone()
- local params, gradParams = module:parameters()
- for i=1,#params do
- params[i] = params[i]:clone()
- end
- local input2 = input:view(1, -1):repeatTensor(bs, 1)
- local output2 = module:forward(input2)
- module:zeroGradParameters()
- local gradInput2 = module:backward(input2, output2, 1/bs)
- local params2, gradParams2 = module:parameters()
- mytester:assertTensorEq(output2[bs-1], output, 0.000001, "error in batch updateOutput")
- mytester:assertTensorEq(gradInput2[bs-1], gradInput, 0.000001, "error in batch updateGradInput")
- mytester:assertTensorEq(gradParams[1], gradParams2[1], 0.000001, "error in batch accGradParameters (gradTemplates)")
- mytester:assertTensorEq(gradParams[2], gradParams2[2], 0.000001, "error in batch accGradParameters (gradDiagCov)")
-
input:zero()
module:zeroGradParameters()
local err = jac.testJacobian(module,input)
@@ -541,7 +584,7 @@ function nntest.WeightedEuclidean()
local err = jac.testJacobianParameters(module, input, module.weight, module.gradWeight)
mytester:assertlt(err,precision, 'error on weight ')
- local err = jac.testJacobianParameters(module, input, module.bias, module.gradBias)
+ local err = jac.testJacobianParameters(module, input, module.diagCov, module.gradDiagCov)
mytester:assertlt(err,precision, 'error on bias ')
local ferr,berr = jac.testIO(module,input2)
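
Beyond the jac harness used in these tests, a bare-bones finite-difference spot check of the analytic gradient would look like this (a sketch; epsilon and tolerance are illustrative):

    local module = nn.WeightedEuclidean(5, 3)
    local x, eps = torch.randn(5), 1e-6

    local y0 = module:forward(x):clone()
    local gradOut = torch.zeros(3); gradOut[1] = 1   -- probe dy_1/dx
    local analytic = module:updateGradInput(x, gradOut):clone()

    x[1] = x[1] + eps                                -- perturb one coordinate
    local numeric = (module:forward(x)[1] - y0[1]) / eps
    assert(math.abs(numeric - analytic[1]) < 1e-3, 'gradient mismatch')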