Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/torch/nn.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
Diffstat (limited to 'test.lua')
-rwxr-xr-xtest.lua18
1 files changed, 18 insertions, 0 deletions
diff --git a/test.lua b/test.lua
index dbac512..4e3f627 100755
--- a/test.lua
+++ b/test.lua
@@ -2175,7 +2175,25 @@ function nntest.MarginRankingCriterion()
local v = torch.rand(2, batch_size)
local t = torch.Tensor(batch_size):random(0,1):mul(2):add(-1)
criterionJacobianTest1DTable(crit,v,t)
+end
+
+function nntest.ModuleCriterion()
+ local input = torch.randn(8,4)
+ local target = torch.randn(8,4)
+ local inputModule = nn.Tanh()
+ local criterion = nn.MSECriterion()
+ local mc = nn.ModuleCriterion(criterion, inputModule)
+
+ local err = mc:forward(input, target)
+ local gradInput = mc:backward(input, target)
+
+ local output = inputModule:forward(input)
+ local err2 = criterion:forward(output, target)
+ local gradOutput = criterion:backward(output, target)
+ local gradInput2 = inputModule:backward(input, gradOutput)
+ mytester:assert(err == err2, "ModuleCriterion backward err")
+ mytester:assertTensorEq(gradInput, gradInput2, 0.000001, "ModuleCriterion backward err")
end
function nntest.MaskedSelect()