Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/torch/nngraph.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
path: root/test
diff options
context:
space:
mode:
authorIvo Danihelka <ivo@danihelka.net>2014-08-19 14:55:44 +0400
committerIvo Danihelka <ivo@danihelka.net>2014-08-19 14:55:44 +0400
commit8d23e818135cf748851bc133b70c136b038adb58 (patch)
tree4e4e25e5df754385c32e2a3991b7039b739f6036 /test
parentaa18c943807fb1d071fe66cc8e319736f7ede7d5 (diff)
Added nn.ModuleFromCriterion to wrap a called Criterion.
Diffstat (limited to 'test')
-rw-r--r--test/test_ModuleFromCriterion.lua57
1 files changed, 57 insertions, 0 deletions
diff --git a/test/test_ModuleFromCriterion.lua b/test/test_ModuleFromCriterion.lua
new file mode 100644
index 0000000..78d3cd2
--- /dev/null
+++ b/test/test_ModuleFromCriterion.lua
@@ -0,0 +1,57 @@
+
+require 'totem'
+require 'nngraph'
+local test = {}
+local tester = totem.Tester()
+
+function test.test_call()
+ local prediction = nn.Identity()()
+ local target = nn.Identity()()
+ local mse = nn.MSECriterion()({prediction, target})
+ local costBits = nn.MulConstant(1/math.log(2))(mse)
+ local net = nn.gModule({prediction, target}, {costBits})
+
+ local input = {torch.randn(3, 5), torch.rand(3, 5)}
+ local criterion = nn.MSECriterion()
+ local output = net:forward(input)
+ criterion:forward(input[1], input[2])
+ tester:eq(output[1], criterion.output/math.log(2), "output", 1e-14)
+
+ local gradOutput = torch.randn(1)
+ local gradInput = net:backward(input, gradOutput)
+ criterion:backward(input[1], input[2])
+ tester:eq(gradInput[1], criterion.gradInput:clone():mul(gradOutput[1]/math.log(2)), "gradPrediction", 1e-14)
+ tester:eq(gradInput[2], torch.zeros(input[2]:size()), "gradTarget")
+end
+
+function test.test_grad()
+ local prediction = nn.Identity()()
+ local zero = nn.MulConstant(0)(prediction)
+ -- The target is created inside of the nngraph
+ -- to ignore the zero gradTarget.
+ local target = nn.AddConstant(1.23)(zero)
+ local mse = nn.MSECriterion()({prediction, target})
+ local net = nn.gModule({prediction}, {mse})
+
+ local input = torch.randn(4, 7)
+ totem.nn.checkGradients(tester, net, input)
+end
+
+local function module()
+ local module = nn.ModuleFromCriterion(nn.MSECriterion())
+ local input = {torch.randn(3, 5), torch.randn(3, 5)}
+ return module, input
+end
+
+function test.test_serializable()
+ local module, input = module()
+ totem.nn.checkSerializable(tester, module, input)
+end
+
+function test.test_typeCastable()
+ local module, input = module()
+ totem.nn.checkTypeCastable(tester, module, input)
+end
+
+
+tester:add(test):run()