Welcome to mirror list, hosted at ThFree Co, Russian Federation.

test_ModuleFromCriterion.lua « test - github.com/torch/nngraph.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
blob: 9206905636841b0ab4bac84ff3c49437be000ba4 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57

require 'totem'
require 'nngraph'
-- Table collecting the test functions; it is handed to the tester at the
-- bottom of this file.
local test = {}
-- Tester instance passed to the totem.nn checks and used for the
-- assertions in the test functions below.
local tester = totem.Tester()

-- Builds a graph whose only computation is an MSECriterion applied to two
-- graph inputs, scaled by 1/ln(2), and compares its forward and backward
-- results against a stand-alone MSECriterion.
function test.test_call()
   local ln2 = math.log(2)

   -- Graph: (prediction, target) -> MSE -> multiply by 1/ln(2).
   local predictionNode = nn.Identity()()
   local targetNode = nn.Identity()()
   local mseNode = nn.MSECriterion()({predictionNode, targetNode})
   local bitsNode = nn.MulConstant(1/ln2)(mseNode)
   local graph = nn.gModule({predictionNode, targetNode}, {bitsNode})

   -- Random sample; the prediction is drawn before the target.
   local predicted = torch.randn(3, 5)
   local expected = torch.rand(3, 5)
   local input = {predicted, expected}

   -- Forward pass: the graph output must match the reference loss / ln(2).
   local reference = nn.MSECriterion()
   local graphOutput = graph:forward(input)
   reference:forward(predicted, expected)
   tester:eq(graphOutput[1], reference.output/ln2, "output", 1e-14)

   -- Backward pass: the prediction gradient is the reference gradient
   -- scaled by the upstream gradient and by 1/ln(2).
   local upstream = torch.randn(1)
   local graphGrad = graph:backward(input, upstream)
   reference:backward(predicted, expected)
   local wantedGradPrediction = reference.gradInput:clone()
   wantedGradPrediction:mul(upstream[1]/ln2)
   tester:eq(graphGrad[1], wantedGradPrediction, "gradPrediction", 1e-14)

   -- The gradient with respect to the target is expected to be all zeros.
   local wantedGradTarget = torch.zeros(expected:size())
   tester:eq(graphGrad[2], wantedGradTarget, "gradTarget")
end

-- Runs totem's gradient check on a single-input graph that feeds an
-- MSECriterion with a prediction and an internally generated target.
function test.test_grad()
   local predictionNode = nn.Identity()()
   -- The target is derived from the prediction inside the graph
   -- (prediction * 0 + 1.23), so the zero gradTarget is ignored
   -- rather than being exposed as a graph input gradient.
   local zeroedNode = nn.MulConstant(0)(predictionNode)
   local targetNode = nn.AddConstant(1.23)(zeroedNode)
   local lossNode = nn.MSECriterion()({predictionNode, targetNode})
   local graph = nn.gModule({predictionNode}, {lossNode})

   totem.nn.checkGradients(tester, graph, torch.randn(4, 7))
end

-- Creates the fixture shared by the serialization and type-cast tests.
-- Returns an nn.ModuleFromCriterion wrapping an MSECriterion, followed by
-- a table holding two random 3x5 tensors to use as its input.
local function module()
   local wrapped = nn.ModuleFromCriterion(nn.MSECriterion())
   local first = torch.randn(3, 5)
   local second = torch.randn(3, 5)
   return wrapped, {first, second}
end

-- Passes the fixture module and its sample input to totem's
-- serialization check.
function test.test_serializable()
   -- module() is the last argument, so both of its return values
   -- (the module and its input) are forwarded to the check.
   totem.nn.checkSerializable(tester, module())
end

-- Passes the fixture module and its sample input to totem's
-- type-cast check.
function test.test_typeCastable()
   -- module() is the last argument, so both of its return values
   -- (the module and its input) are forwarded to the check.
   totem.nn.checkTypeCastable(tester, module())
end


-- Register every function collected in `test`, then run whatever
-- `add` returned.
local suite = tester:add(test)
suite:run()