diff options
author | Ivo Danihelka <ivo@danihelka.net> | 2014-08-19 14:55:44 +0400 |
---|---|---|
committer | Ivo Danihelka <ivo@danihelka.net> | 2014-08-19 14:55:44 +0400 |
commit | 8d23e818135cf748851bc133b70c136b038adb58 (patch) | |
tree | 4e4e25e5df754385c32e2a3991b7039b739f6036 /test | |
parent | aa18c943807fb1d071fe66cc8e319736f7ede7d5 (diff) |
Added nn.ModuleFromCriterion to wrap a called Criterion.
Diffstat (limited to 'test')
-rw-r--r-- | test/test_ModuleFromCriterion.lua | 57 |
1 files changed, 57 insertions, 0 deletions
diff --git a/test/test_ModuleFromCriterion.lua b/test/test_ModuleFromCriterion.lua new file mode 100644 index 0000000..78d3cd2 --- /dev/null +++ b/test/test_ModuleFromCriterion.lua @@ -0,0 +1,57 @@ + +require 'totem' +require 'nngraph' +local test = {} +local tester = totem.Tester() + +function test.test_call() + local prediction = nn.Identity()() + local target = nn.Identity()() + local mse = nn.MSECriterion()({prediction, target}) + local costBits = nn.MulConstant(1/math.log(2))(mse) + local net = nn.gModule({prediction, target}, {costBits}) + + local input = {torch.randn(3, 5), torch.rand(3, 5)} + local criterion = nn.MSECriterion() + local output = net:forward(input) + criterion:forward(input[1], input[2]) + tester:eq(output[1], criterion.output/math.log(2), "output", 1e-14) + + local gradOutput = torch.randn(1) + local gradInput = net:backward(input, gradOutput) + criterion:backward(input[1], input[2]) + tester:eq(gradInput[1], criterion.gradInput:clone():mul(gradOutput[1]/math.log(2)), "gradPrediction", 1e-14) + tester:eq(gradInput[2], torch.zeros(input[2]:size()), "gradTarget") +end + +function test.test_grad() + local prediction = nn.Identity()() + local zero = nn.MulConstant(0)(prediction) + -- The target is created inside of the nngraph + -- to ignore the zero gradTarget. + local target = nn.AddConstant(1.23)(zero) + local mse = nn.MSECriterion()({prediction, target}) + local net = nn.gModule({prediction}, {mse}) + + local input = torch.randn(4, 7) + totem.nn.checkGradients(tester, net, input) +end + +local function module() + local module = nn.ModuleFromCriterion(nn.MSECriterion()) + local input = {torch.randn(3, 5), torch.randn(3, 5)} + return module, input +end + +function test.test_serializable() + local module, input = module() + totem.nn.checkSerializable(tester, module, input) +end + +function test.test_typeCastable() + local module, input = module() + totem.nn.checkTypeCastable(tester, module, input) +end + + +tester:add(test):run() |