diff options
Diffstat (limited to 'test.lua')
-rwxr-xr-x | test.lua | 18 |
1 files changed, 18 insertions, 0 deletions
@@ -2175,7 +2175,25 @@ function nntest.MarginRankingCriterion() local v = torch.rand(2, batch_size) local t = torch.Tensor(batch_size):random(0,1):mul(2):add(-1) criterionJacobianTest1DTable(crit,v,t) +end + +function nntest.ModuleCriterion() + local input = torch.randn(8,4) + local target = torch.randn(8,4) + local inputModule = nn.Tanh() + local criterion = nn.MSECriterion() + local mc = nn.ModuleCriterion(criterion, inputModule) + + local err = mc:forward(input, target) + local gradInput = mc:backward(input, target) + + local output = inputModule:forward(input) + local err2 = criterion:forward(output, target) + local gradOutput = criterion:backward(output, target) + local gradInput2 = inputModule:backward(input, gradOutput) + mytester:assert(err == err2, "ModuleCriterion backward err") + mytester:assertTensorEq(gradInput, gradInput2, 0.000001, "ModuleCriterion backward err") end function nntest.MaskedSelect() |