
github.com/torch/nn.git
author     Ronan Collobert <ronan@collobert.com>   2014-02-05 22:16:32 +0400
committer  Ronan Collobert <ronan@collobert.com>   2014-02-05 22:16:32 +0400
commit     08e6a9444a26432fd222e9135f0836a7879402a5 (patch)
tree       f7a23af8050ebf6dd73c44abeed3db53af3f8542
parent     b55dbaa78b80b3845498001895bb8f1d9fa3f579 (diff)
parent     14a4930523ca7c65f05a8d35787e15c97e45c74f (diff)

Merge branch 'split/nn' into standalone/nnlegacy

Conflicts:
	CMakeLists.txt
	Module.lua
	PairwiseDistance.lua
	SparseLinear.lua
	generic/SoftPlus.c
	generic/SparseLinear.c
	init.lua
	test/test.lua
-rw-r--r--   PairwiseDistance.lua      48
-rw-r--r--   SparseJacobian.lua       277
-rw-r--r--   SparseLinear.lua          17
-rw-r--r--   generic/SoftPlus.c         2
-rw-r--r--   generic/SparseLinear.c   102
-rw-r--r--   init.lua                   1
-rw-r--r--   test/test.lua            124
7 files changed, 510 insertions(+), 61 deletions(-)
diff --git a/PairwiseDistance.lua b/PairwiseDistance.lua
index d9e6f81..c8334d0 100644
--- a/PairwiseDistance.lua
+++ b/PairwiseDistance.lua
@@ -6,6 +6,7 @@ function PairwiseDistance:__init(p)
-- state
self.gradInput = {torch.Tensor(), torch.Tensor()}
self.output = torch.Tensor(1)
+ self.diff = torch.Tensor()
self.norm=p
end
@@ -17,8 +18,8 @@ function PairwiseDistance:updateOutput(input)
self.diff:resizeAs(input[1])
local diff = self.diff:zero()
- --local diff = torch.add(input[1], -1, input[2])
diff:add(input[1], -1, input[2])
+ diff:abs()
self.output:resize(input[1]:size(1))
self.output:zero()
@@ -27,7 +28,6 @@ function PairwiseDistance:updateOutput(input)
else
error('input must be vector or matrix')
end
-
return self.output
end
@@ -37,16 +37,39 @@ local function mathsign(x)
end
function PairwiseDistance:updateGradInput(input, gradOutput)
+ if input[1]:dim() > 2 then
+ error('input must be vector or matrix')
+ end
+
self.gradInput[1]:resize(input[1]:size())
self.gradInput[2]:resize(input[2]:size())
self.gradInput[1]:copy(input[1])
- self.gradInput[1]:add(-1, input[2])
+ self.gradInput[1]:add(-1, input[2])
+
if self.norm==1 then
self.gradInput[1]:apply(mathsign)
+ else
+ -- Note: derivative of p-norm:
+ -- d/dx_k(||x||_p) = (x_k * abs(x_k)^(p-2)) / (||x||_p)^(p-1)
+ if (self.norm > 2) then
+ self.gradInput[1]:cmul(self.gradInput[1]:clone():abs():pow(self.norm-2))
+ end
+
+ if (input[1]:dim() > 1) then
+ self.outExpand = self.outExpand or self.output.new()
+ self.outExpand:resize(self.output:size(1), 1)
+ self.outExpand:copy(self.output)
+ self.outExpand:add(1.0e-6) -- Prevent divide by zero errors
+ self.outExpand:pow(-(self.norm-1))
+ self.gradInput[1]:cmul(self.outExpand:expand(self.gradInput[1]:size(1),
+ self.gradInput[1]:size(2)))
+ else
+ self.gradInput[1]:mul(math.pow(self.output[1] + 1e-6, -(self.norm-1)))
+ end
end
if input[1]:dim() == 1 then
self.gradInput[1]:mul(gradOutput[1])
- elseif input[1]:dim() == 2 then
+ else
self.grad = self.grad or gradOutput.new()
self.ones = self.ones or gradOutput.new()
@@ -55,9 +78,22 @@ function PairwiseDistance:updateGradInput(input, gradOutput)
self.grad:addr(gradOutput, self.ones)
self.gradInput[1]:cmul(self.grad)
- else
- error('input must be vector or matrix')
end
self.gradInput[2]:zero():add(-1, self.gradInput[1])
return self.gradInput
end
+
+-- save away Module:type(type) for later use.
+PairwiseDistance._parent_type = parent.type
+
+-- Fix the bug where tmp = nn.PairwiseDistance:cuda() fails to convert table
+-- contents. We could, and probably should, change Module.lua to loop over
+-- and convert all the table elements in a module, but that might have
+-- repercussions, so this is a safer solution.
+function PairwiseDistance:type(type)
+ self:_parent_type(type) -- Call the parent (Module) type function
+ -- Now convert the left over table elements
+ self.gradInput[1] = self.gradInput[1]:type(type)
+ self.gradInput[2] = self.gradInput[2]:type(type)
+ return self
+end
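
Annotation (not part of the patch): the new updateGradInput implements the p-norm derivative quoted in the comment above. A minimal sketch that checks it numerically, assuming Torch7 with the nn package and this change applied:

-- Sketch only: confirm d/dx_k(||x||_p) = x_k * |x_k|^(p-2) / ||x||_p^(p-1)
-- against the module's gradient for a single vector pair.
require 'nn'
local p = 3
local pd = nn.PairwiseDistance(p)
local x, y = torch.rand(5), torch.rand(5)
local out  = pd:forward({x, y})                  -- ||x - y||_p as a 1-element tensor
local grad = pd:backward({x, y}, torch.ones(1))  -- table {dL/dx, dL/dy}
local d = x - y
local manual = d:clone():abs():pow(p - 2):cmul(d):div(math.pow(out[1], p - 1))
print((grad[1] - manual):abs():max())            -- ~0, up to the 1e-6 divide-by-zero guard
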
diff --git a/SparseJacobian.lua b/SparseJacobian.lua
new file mode 100644
index 0000000..b778e67
--- /dev/null
+++ b/SparseJacobian.lua
@@ -0,0 +1,277 @@
+nn.SparseJacobian = {}
+
+function nn.SparseJacobian.backward (module, input, param, dparam)
+ local doparam = 0
+ if param then
+ doparam = 1
+ end
+
+ -- output deriv
+ module:forward(input)
+ local dout = module.output.new():resizeAs(module.output)
+ -- 1D view
+ local sdout = module.output.new(dout:storage(), 1, dout:nElement())
+ -- jacobian matrix to calculate
+ local jacobian
+ if doparam == 1 then
+ jacobian = torch.Tensor(param:nElement(), dout:nElement()):zero()
+ else
+ jacobian = torch.Tensor(input:size(1), dout:nElement()):zero()
+ end
+
+ for i=1,sdout:nElement() do
+ dout:zero()
+ sdout[i] = 1
+ module:zeroGradParameters()
+ local din = module:updateGradInput(input, dout)
+ module:accGradParameters(input, dout)
+ if doparam == 1 then
+ jacobian:select(2,i):copy(dparam)
+ else
+ jacobian:select(2,i):copy(din:select(2,2))
+ end
+ end
+
+ return jacobian
+end
+
+
+function nn.SparseJacobian.backwardUpdate (module, input, param)
+
+ -- output deriv
+ module:forward(input)
+ local dout = module.output.new():resizeAs(module.output)
+ -- 1D view
+ local sdout = module.output.new(dout:storage(),1,dout:nElement())
+ -- jacobian matrix to calculate
+ local jacobian = torch.Tensor(param:nElement(),dout:nElement()):zero()
+
+ -- original param
+ local params = module:parameters()
+ local origparams = {}
+ for j=1,#params do
+ table.insert(origparams, params[j]:clone())
+ end
+
+ for i=1,sdout:nElement() do
+ -- Reset parameters
+ for j=1,#params do
+ params[j]:copy(origparams[j])
+ end
+ dout:zero()
+ sdout[i] = 1
+ module:zeroGradParameters()
+ local din = module:updateGradInput(input, dout)
+ module:accUpdateGradParameters(input, dout, 1)
+ jacobian:select(2,i):copy(param)
+ end
+
+ for j=1,#params do
+ params[j]:copy(origparams[j])
+ end
+
+ return jacobian
+end
+
+function nn.SparseJacobian.forward(module, input, param)
+ local doparam = 0
+ if param then
+ doparam = 1
+ end
+ param = param or input
+
+ -- perturbation amount
+ local small = 1e-6
+ -- 1D view of input
+ --local tst = param:storage()
+ local sin
+ if doparam == 1 then
+ sin = param.new(param):resize(param:nElement())
+ else
+ sin = input.new(input):select(2,2)
+ end
+
+ local out = module:forward(input)
+ -- jacobian matrix to calculate
+ local jacobian
+ if doparam == 1 then
+ jacobian = torch.Tensor():resize(param:nElement(),
+ out:nElement())
+ else
+ jacobian = torch.Tensor():resize(input:size(1),
+ out:nElement())
+ end
+
+ local outa = torch.Tensor(jacobian:size(2))
+ local outb = torch.Tensor(jacobian:size(2))
+
+ for i=1,sin:nElement() do
+ sin[i] = sin[i] - small
+ outa:copy(module:forward(input))
+ sin[i] = sin[i] + 2*small
+ outb:copy(module:forward(input))
+ sin[i] = sin[i] - small
+
+ outb:add(-1,outa):div(2*small)
+ jacobian:select(1,i):copy(outb)
+ end
+
+ return jacobian
+end
+
+function nn.SparseJacobian.forwardUpdate(module, input, param)
+ -- perturbation amount
+ local small = 1e-6
+ -- 1D view of input
+ --local tst = param:storage()
+ local sin = param.new(param):resize(param:nElement())--param.new(tst,1,tst:size())
+ -- jacobian matrix to calculate
+ local jacobian = torch.Tensor():resize(param:nElement(),module:forward(input):nElement())
+
+ local outa = torch.Tensor(jacobian:size(2))
+ local outb = torch.Tensor(jacobian:size(2))
+
+ for i=1,sin:nElement() do
+ sin[i] = sin[i] - small
+ outa:copy(module:forward(input))
+ sin[i] = sin[i] + 2*small
+ outb:copy(module:forward(input))
+ sin[i] = sin[i] - small
+
+ outb:add(-1,outa):div(2*small)
+ jacobian:select(1,i):copy(outb)
+ jacobian:select(1,i):mul(-1)
+ jacobian:select(1,i):add(sin[i])
+ end
+ return jacobian
+end
+
+function nn.SparseJacobian.testJacobian (module, input, minval, maxval)
+ minval = minval or -2
+ maxval = maxval or 2
+ local inrange = maxval - minval
+ input:select(2,2):copy(torch.rand(input:size(1)):mul(inrange):add(minval))
+ local jac_fprop = nn.SparseJacobian.forward(module,input)
+ local jac_bprop = nn.SparseJacobian.backward(module,input)
+ local error = jac_fprop-jac_bprop
+ return error:abs():max()
+end
+
+function nn.SparseJacobian.testJacobianParameters (module, input, param, dparam, minval, maxval)
+ minval = minval or -2
+ maxval = maxval or 2
+ local inrange = maxval - minval
+ input:select(2,2):copy(torch.rand(input:size(1)):mul(inrange):add(minval))
+ param:copy(torch.rand(param:nElement()):mul(inrange):add(minval))
+ local jac_bprop = nn.SparseJacobian.backward(module, input, param, dparam)
+ local jac_fprop = nn.SparseJacobian.forward(module, input, param)
+ local error = jac_fprop - jac_bprop
+ return error:abs():max()
+end
+
+function nn.SparseJacobian.testJacobianUpdateParameters (module, input, param, minval, maxval)
+ minval = minval or -2
+ maxval = maxval or 2
+ local inrange = maxval - minval
+ input:select(2,2):copy(torch.rand(input:size(1)):mul(inrange):add(minval))
+ param:copy(torch.rand(param:nElement()):mul(inrange):add(minval))
+ local params_bprop = nn.SparseJacobian.backwardUpdate(module, input, param)
+ local params_fprop = nn.SparseJacobian.forwardUpdate(module, input, param)
+
+ local error = params_fprop - params_bprop
+ return error:abs():max()
+end
+
+function nn.SparseJacobian.testIO(module,input, minval, maxval)
+ minval = minval or -2
+ maxval = maxval or 2
+ local inrange = maxval - minval
+
+ -- run module
+ module:forward(input)
+ local go = module.output:clone():copy(torch.rand(module.output:nElement()):mul(inrange):add(minval))
+ module:zeroGradParameters()
+ module:updateGradInput(input,go)
+ module:accGradParameters(input,go)
+
+ local fo = module.output:clone()
+ local bo = module.gradInput:clone()
+
+ -- write module
+ local f = torch.DiskFile('tmp.bin','w'):binary()
+ f:writeObject(module)
+ f:close()
+ -- read module
+ local m = torch.DiskFile('tmp.bin'):binary():readObject()
+ m:forward(input)
+ m:zeroGradParameters()
+ m:updateGradInput(input,go)
+ m:accGradParameters(input,go)
+ -- cleanup
+ os.remove('tmp.bin')
+
+ local fo2 = m.output:clone()
+ local bo2 = m.gradInput:clone()
+
+ local errf = fo - fo2
+ local errb = bo - bo2
+ return errf:abs():max(), errb:abs():max()
+end
+
+function nn.SparseJacobian.testAllUpdate(module, input, weight, gradWeight)
+ local gradOutput
+ local lr = torch.uniform(0.1, 1)
+ local errors = {}
+
+ -- accGradParameters
+ local maccgp = module:clone()
+ local weightc = maccgp[weight]:clone()
+ maccgp:forward(input)
+ gradOutput = torch.rand(maccgp.output:size())
+ maccgp:zeroGradParameters()
+ maccgp:updateGradInput(input, gradOutput)
+ maccgp:accGradParameters(input, gradOutput)
+ maccgp:updateParameters(lr)
+ errors["accGradParameters"] = (weightc-maccgp[gradWeight]*lr-maccgp[weight]):norm()
+
+ -- accUpdateGradParameters
+ local maccugp = module:clone()
+ maccugp:forward(input)
+ maccugp:updateGradInput(input, gradOutput)
+ maccugp:accUpdateGradParameters(input, gradOutput, lr)
+ errors["accUpdateGradParameters"] = (maccugp[weight]-maccgp[weight]):norm()
+
+ -- shared, accGradParameters
+ local macsh1 = module:clone()
+ local macsh2 = module:clone()
+ macsh2:share(macsh1, weight)
+ macsh1:forward(input)
+ macsh2:forward(input)
+ macsh1:zeroGradParameters()
+ macsh2:zeroGradParameters()
+ macsh1:updateGradInput(input, gradOutput)
+ macsh2:updateGradInput(input, gradOutput)
+ macsh1:accGradParameters(input, gradOutput)
+ macsh2:accGradParameters(input, gradOutput)
+ macsh1:updateParameters(lr)
+ macsh2:updateParameters(lr)
+ local err = (weightc-maccgp[gradWeight]*(lr*2)-macsh1[weight]):norm()
+ err = err + (weightc-maccgp[gradWeight]*(lr*2)-macsh2[weight]):norm()
+ errors["accGradParameters [shared]"] = err
+
+ -- shared, accUpdateGradParameters
+ local macshu1 = module:clone()
+ local macshu2 = module:clone()
+ macshu2:share(macshu1, weight)
+ macshu1:forward(input)
+ macshu2:forward(input)
+ macshu1:updateGradInput(input, gradOutput)
+ macshu2:updateGradInput(input, gradOutput)
+ macshu1:accUpdateGradParameters(input, gradOutput, lr)
+ macshu2:accUpdateGradParameters(input, gradOutput, lr)
+ local err = (weightc-maccgp[gradWeight]*(lr*2)-macshu1[weight]):norm()
+ err = err + (weightc-maccgp[gradWeight]*(lr*2)-macshu2[weight]):norm()
+ errors["accUpdateGradParameters [shared]"] = err
+
+ return errors
+end
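
Annotation (not part of the patch): SparseJacobian mirrors nn.Jacobian but perturbs only the value column (column 2) of a sparse {index, value} input. A short usage sketch, mirroring the new SparseLinear test below; the dimensions and the randperm-based index choice are illustrative only:

-- Sketch only: finite-difference check of a sparse module via nn.SparseJacobian.
-- Assumes Torch7 with this patch applied (nn.SparseJacobian and nn.SparseLinear).
require 'nn'
local inDim, outDim, nnz = 100, 10, 5
local module = nn.SparseLinear(inDim, outDim)
-- sparse input: one {index, value} pair per row (the layout this patch expects)
local input = torch.Tensor(nnz, 2)
input:select(2, 1):copy(torch.randperm(inDim):narrow(1, 1, nnz))  -- distinct indices
input:select(2, 2):copy(torch.rand(nnz))                          -- values
print(nn.SparseJacobian.testJacobian(module, input))              -- max fprop/bprop mismatch
print(nn.SparseJacobian.testJacobianParameters(module, input,
                                               module.weight, module.gradWeight))
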
diff --git a/SparseLinear.lua b/SparseLinear.lua
index f1a2be5..db3e540 100644
--- a/SparseLinear.lua
+++ b/SparseLinear.lua
@@ -42,3 +42,20 @@ end
function SparseLinear:accGradParameters(input, gradOutput, scale)
return input.nn.SparseLinear_accGradParameters(self, input, gradOutput, scale)
end
+
+function SparseLinear:updateGradInput(input, gradOutput)
+ if self.gradInput then
+ self.gradInput:resize(input:size())
+ self.gradInput:copy(input)
+ local numNonzero = self.gradInput:size(1)
+ for e=1,numNonzero do
+ local g = 0
+ local i = self.gradInput[{e,1}]
+ for j=1,self.output:size(1) do
+ g = g + self.weight[{j,i}] * gradOutput[j]
+ end
+ self.gradInput[{e,2}] = g
+ end
+ return self.gradInput
+ end
+end
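
Annotation (not part of the patch): the loop in updateGradInput above computes, for each nonzero entry e with feature index i_e, gradInput[e][2] = sum_j weight[j][i_e] * gradOutput[j], i.e. the i_e-th column of the weight matrix dotted with gradOutput. An equivalent tensor-op sketch; the function name is hypothetical:

-- Sketch only: the same gradient written with a per-column dot product.
local function sparseGradInput(weight, input, gradOutput)
   local g = input:clone()               -- keep the {index, value} layout
   for e = 1, input:size(1) do
      local i = input[{e, 1}]            -- feature index of the e-th nonzero
      g[{e, 2}] = weight:select(2, i):dot(gradOutput)
   end
   return g
end
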
diff --git a/generic/SoftPlus.c b/generic/SoftPlus.c
index b4f62f7..49f50a7 100644
--- a/generic/SoftPlus.c
+++ b/generic/SoftPlus.c
@@ -10,7 +10,7 @@ static int nn_(SoftPlus_updateOutput)(lua_State *L)
THTensor_(resizeAs)(output, input);
TH_TENSOR_APPLY2(real, output, real, input, \
- *output_data = log1p(exp(*input_data));)
+ *output_data = THLog1p(exp(*input_data));)
return 1;
}
diff --git a/generic/SparseLinear.c b/generic/SparseLinear.c
index c602c2a..1cfcd45 100644
--- a/generic/SparseLinear.c
+++ b/generic/SparseLinear.c
@@ -9,25 +9,26 @@ static int nn_(SparseLinear_updateOutput)(lua_State *L)
THTensor * weight = luaT_getfieldcheckudata(L, 1, "weight", torch_Tensor);
THTensor * bias = luaT_getfieldcheckudata(L, 1, "bias", torch_Tensor);
THTensor * output = luaT_getfieldcheckudata(L, 1, "output", torch_Tensor);
- long dim = weight->size[0]; /* number of weights.. */
+ long dim = weight->size[1]; /* number of weights.. */
THTensor_(copy)(output, bias);
- for(i = 0; i < input->size[1]; i++)
+ for(i = 0; i < input->size[0]; i++)
{
- long offset = (long)(THTensor_(get2d)(input, 0, i))-1;
-
+ long offset = (long)(THTensor_(get2d)(input, i, 0)) - 1;
if(offset >= 0 && offset < dim) /* make sure indices are in bounds.. */
{
- real val = THTensor_(get2d)(input, 1, i);
- THBlas_(axpy)(output->size[0],
- val,
- THTensor_(data)(weight)+offset*weight->stride[0],
- weight->stride[1],
- THTensor_(data)(output),
- output->stride[0]);
+ real val = THTensor_(get2d)(input, i, 1);
+ THBlas_(axpy)(output->size[0],
+ val,
+ THTensor_(data)(weight)+offset*weight->stride[1],
+ weight->stride[0],
+ THTensor_(data)(output),
+ output->stride[0]);
+ }
+ else {
+ printf("\nupdateOutput: %ld not between 1 and %ld\n", offset+1, dim);
+ luaL_error(L, "index out of bound");
}
- else
- luaL_error(L, "index out of bound");
}
return 1;
}
@@ -43,32 +44,30 @@ static int nn_(SparseLinear_accGradParameters)(lua_State *L)
THTensor * gradWeight = luaT_getfieldcheckudata(L, 1, "gradWeight", torch_Tensor);
THTensor * lastInput = luaT_getfieldcheckudata(L, 1, "lastInput", torch_Tensor);
real weightDecay = luaT_getfieldchecknumber(L, 1, "weightDecay");
- long dim = gradWeight->size[0]; /* number of weights.. */
+ long dim = gradWeight->size[1]; /* number of weights.. */
- for(i = 0; i < input->size[1]; i++)
+ for(i = 0; i < input->size[0]; i++)
{
- long offset = (long)(THTensor_(get2d)(input, 0, i))-1;
+ long offset = (long)(THTensor_(get2d)(input, i, 0)) - 1;
- if(offset >= 0 && offset < dim) /* make sure indices are in bounds.. */
- {
- real val = scale*THTensor_(get2d)(input, 1, i);
- THBlas_(scal)(gradOutput->size[0],
- 0,
- THTensor_(data)(gradWeight)+offset*gradWeight->stride[0],
- gradWeight->stride[1]); /* zero */
-
- THBlas_(axpy)(gradOutput->size[0],
- val,
- THTensor_(data)(gradOutput),
- gradOutput->stride[0],
- THTensor_(data)(gradWeight)+offset*gradWeight->stride[0],
- gradWeight->stride[1]);
- }
- else
- luaL_error(L, "index out of bound");
+ if(offset >= 0 && offset < dim) /* make sure indices are in bounds.. */
+ {
+ real val = scale*THTensor_(get2d)(input, i, 1);
+
+ THBlas_(axpy)(gradOutput->size[0],
+ val,
+ THTensor_(data)(gradOutput),
+ gradOutput->stride[0],
+ THTensor_(data)(gradWeight)+offset*gradWeight->stride[1],
+ gradWeight->stride[0]);
+ }
+ else {
+ printf("\naccGradParameters: %ld not between 1 and %ld\n", offset+1, dim);
+ luaL_error(L, "index out of bound");
+ }
}
- THTensor_(cadd)(gradBias, gradBias, 1, gradOutput);
+ THTensor_(cadd)(gradBias, gradBias, scale, gradOutput);
if(weightDecay != 0)
THTensor_(cadd)(gradWeight, gradWeight, weightDecay, weight);
@@ -87,26 +86,27 @@ int nn_(SparseLinear_updateParameters)(lua_State *L)
THTensor * bias = luaT_getfieldcheckudata(L, 1, "bias", torch_Tensor);
THTensor * gradBias = luaT_getfieldcheckudata(L, 1, "gradBias", torch_Tensor);
THTensor * gradWeight = luaT_getfieldcheckudata(L, 1, "gradWeight", torch_Tensor);
- THTensor * lastInput = luaT_getfieldcheckudata(L, 1, "lastInput", torch_Tensor);
-
- long dim = weight->size[0]; /* number of weights.. */
+ THTensor * lastInput = luaT_getfieldcheckudata(L, 1, "lastInput", torch_Tensor);
+ long dim = weight->size[1]; /* number of weights.. */
THTensor_(cadd)(bias, bias, -learningRate, gradBias);
- for(i = 0; i < lastInput->size[1]; i++)
+ for(i = 0; i < lastInput->size[0]; i++)
{
- long offset = (long)(THTensor_(get2d)(lastInput, 0, i))-1;
-
- if(offset >= 0 && offset < dim) /* make sure indices are in bounds.. */
- {
- THBlas_(axpy)(bias->size[0],
- -learningRate,
- THTensor_(data)(gradWeight)+offset*gradWeight->stride[0],
- gradWeight->stride[1],
- THTensor_(data)(weight)+offset*weight->stride[0],
- weight->stride[1]);
- }
- else
- luaL_error(L, "index out of bound");
+ long offset = (long)(THTensor_(get2d)(lastInput, i, 0)) - 1;
+
+ if(offset >= 0 && offset < dim) /* make sure indices are in bounds.. */
+ {
+ THBlas_(axpy)(bias->size[0],
+ -learningRate,
+ THTensor_(data)(gradWeight)+offset*gradWeight->stride[1],
+ gradWeight->stride[0],
+ THTensor_(data)(weight)+offset*weight->stride[1],
+ weight->stride[0]);
+ }
+ else {
+ printf("\nUpdateParameters: %ld not between 1 and %ld\n", offset+1, dim);
+ luaL_error(L, "index out of bound");
+ }
}
return 0;
}
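
Annotation (not part of the patch): taken together, the C changes transpose the expected sparse-input layout. The old code read get2d(input, 0, i) / get2d(input, 1, i) from a 2 x N tensor (a row of indices and a row of values); the new code reads get2d(input, i, 0) / get2d(input, i, 1) from an N x 2 tensor (one {index, value} pair per row), and the weight matrix is now indexed as outDim x inDim, presumably to match the nn.Linear convention. A tiny layout sketch with made-up values:

-- Sketch only: old vs. new sparse input layout for SparseLinear.
-- old layout: 2 x N, row 1 = feature indices, row 2 = values
local old = torch.Tensor({{3, 7, 42}, {0.5, -1.2, 2.0}})
-- new layout (this patch): N x 2, each row is an {index, value} pair
local new = old:t():clone()
-- new[{1,1}] == 3, new[{1,2}] == 0.5, and so on
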
diff --git a/init.lua b/init.lua
index 87f080e..7071298 100644
--- a/init.lua
+++ b/init.lua
@@ -103,5 +103,6 @@ include('WeightedMSECriterion.lua')
include('StochasticGradient.lua')
include('Jacobian.lua')
+include('SparseJacobian.lua')
include('hessian.lua')
include('test.lua')
diff --git a/test/test.lua b/test/test.lua
index dd6be22..df80616 100644
--- a/test/test.lua
+++ b/test/test.lua
@@ -2,6 +2,7 @@ require 'torch'
local mytester = torch.Tester()
local jac
+local sjac
local precision = 1e-5
local expprecision = 1e-4
@@ -237,8 +238,8 @@ function nntest.Sqrt()
end
function nntest.Linear()
- local ini = math.random(50,70)
- local inj = math.random(50,70)
+ local ini = math.random(5,7)
+ local inj = math.random(5,7)
local input = torch.Tensor(ini):zero()
local module = nn.Linear(ini,inj)
@@ -303,6 +304,68 @@ function nntest.Linear()
mytester:asserteq(berr, 0, torch.typename(module) .. ' - i/o backward err ')
end
+function nntest.SparseLinear()
+ local ini = math.random(5000,10000)
+ local inj = math.random(50,100)
+ local numNonzero = math.random(5,20)
+
+ local module = nn.SparseLinear(ini,inj)
+
+ -- Create a random sparse vector
+ N = {}
+ for i = 1, ini do N[i] = i end
+ for i = 1, numNonzero do
+ local j = math.random(i,ini)
+ N[i], N[j] = N[j], N[i]
+ end
+ local input = torch.Tensor(numNonzero, 2):zero()
+ for i = 1, numNonzero do input[{i,1}] = N[i] end
+ local values = input:select(2,2)
+ values:copy(torch.rand(values:nElement())):mul(2):add(-1)
+
+ -- Check output
+ local actual = module:forward(input)
+ local expected = torch.Tensor(inj)
+ for j = 1, inj do
+ expected[j] = 0
+ for i = 1,numNonzero do
+ expected[j] = expected[j] + values[i] * module.weight[{j, N[i]}]
+ end
+ end
+ local err = (expected - actual):abs():max()
+ mytester:assertle(err, precision, 'error on result')
+
+ -- Jacobian 1D
+ local err = sjac.testJacobian(module,input)
+ mytester:assertlt(err,precision, 'error on state ')
+
+ local err = sjac.testJacobianParameters(module, input, module.weight, module.gradWeight)
+ mytester:assertlt(err,precision, 'error on weight ')
+
+ local err = sjac.testJacobianParameters(module, input, module.bias, module.gradBias)
+ mytester:assertlt(err,precision, 'error on bias ')
+
+ local err = sjac.testJacobianUpdateParameters(module, input, module.weight)
+ mytester:assertlt(err,precision, 'error on weight [direct update] ')
+
+ local err = sjac.testJacobianUpdateParameters(module, input, module.bias)
+ mytester:assertlt(err,precision, 'error on bias [direct update] ')
+
+ for t,err in pairs(sjac.testAllUpdate(module, input, 'weight', 'gradWeight')) do
+ mytester:assertlt(err, precision, string.format(
+ 'error on weight [%s]', t))
+ end
+
+ for t,err in pairs(sjac.testAllUpdate(module, input, 'bias', 'gradBias')) do
+ mytester:assertlt(err, precision, string.format(
+ 'error on bias [%s]', t))
+ end
+
+ local ferr, berr = sjac.testIO(module, input)
+ mytester:asserteq(0, ferr, torch.typename(module) .. ' - i/o forward err ')
+ mytester:asserteq(0, berr, torch.typename(module) .. ' - i/o backward err ')
+end
+
function nntest.Euclidean()
local ini = math.random(50,70)
local inj = math.random(50,70)
@@ -347,7 +410,6 @@ end
-- local weight = torch.randn(from)
-- local cri = nn.WeightedMSECriterion(weight)
-- local module = nn.CriterionModule(cri,target)
-
-- local err = jac.testJacobian(module, input)
-- mytester:assertlt(err, precision, 'error on state ')
@@ -1530,14 +1592,70 @@ function nntest.Module_getParameters_7()
mytester:asserteq(p:nElement(), 121, 'error: incorrect number of elements in flat vector')
end
+function nntest.PairwiseDistance()
+ -- Note: testJacobian doesn't support table inputs, and rather than re-write
+ -- it so that it does, I'll just use a split table module on the input.
+ -- I assume both SplitTable and Sequential do not have bugs, otherwise this
+ -- test will break.
+ for p = 1,4 do -- test a few Lp norms
+ -- TEST CASE 1: non-batch input, same code path but includes a resize
+ local ini = math.random(10,20)
+ local input = torch.Tensor(2, ini):zero()
+ local module = nn.Sequential()
+ module:add(nn.SplitTable(1))
+ module:add(nn.PairwiseDistance(p))
+
+ local err = jac.testJacobian(module,input)
+ mytester:assertlt(err,precision, ' error on state ')
+
+ local ferr,berr = jac.testIO(module,input)
+ mytester:asserteq(ferr, 0, torch.typename(module)..' - i/o forward err ')
+ mytester:asserteq(berr, 0, torch.typename(module)..' - i/o backward err ')
+
+ -- Also check that the forward prop result is correct.
+ input = torch.rand(2, ini)
+ err = torch.dist(input:select(1,1), input:select(1,2), p) -
+ module:forward(input)[1]
+ mytester:assertlt(err,precision, ' error on non-batch fprop ')
+
+ -- TEST CASE 2: batch input
+ local inj = math.random(10,20)
+ input = torch.Tensor(2, inj, ini):zero()
+
+ -- (Rebuild the module to avoid correlated tests)
+ module = nn.Sequential()
+ module:add(nn.SplitTable(1))
+ module:add(nn.PairwiseDistance(p))
+
+ err = jac.testJacobian(module,input)
+ mytester:assertlt(err,precision, ' error on state ')
+
+ -- Also check that the forward prop result is correct.
+ -- manually calculate each distance separately
+ local inputa = torch.rand(inj,ini)
+ local inputb = torch.rand(inj,ini)
+ local dist_manual = torch.Tensor(inj)
+ for i=1, inputa:size(1) do
+ dist_manual[i] = torch.dist(inputa:select(1,i), inputb:select(1,i),p)
+ end
+ -- compare the distances to the module's fprop
+ local dist = module:forward(torch.cat(inputa,inputb,1):resize(2,inj,ini))
+ err = dist - dist_manual
+ mytester:assertlt(err:norm(), precision, torch.typename(module) ..
+ ' error on batch fprop ')
+ end
+end
+
mytester:add(nntest)
if not nn then
require 'nn'
jac = nn.Jacobian
+ sjac = nn.SparseJacobian
mytester:run()
else
jac = nn.Jacobian
+ sjac = nn.SparseJacobian
function nn.test(tests)
-- randomize stuff
math.randomseed(os.time())