diff options
author | nicholas-leonard <nick@nikopia.org> | 2014-07-19 04:38:31 +0400 |
---|---|---|
committer | nicholas-leonard <nick@nikopia.org> | 2014-07-19 04:38:31 +0400 |
commit | d12fd6b3c516430f5d7f1eabbbca19565346be40 (patch) | |
tree | 4275ba2e94e67d7fa1469987aaa56bc2293e03e9 /test | |
parent | 8cf1aa463a4e23ea6946a1856bf474ce4017eec6 (diff) |
Balance
Diffstat (limited to 'test')
-rw-r--r-- | test/test-all.lua | 31 |
1 files changed, 31 insertions, 0 deletions
diff --git a/test/test-all.lua b/test/test-all.lua index 3d04c4e..33af2b7 100644 --- a/test/test-all.lua +++ b/test/test-all.lua @@ -478,6 +478,37 @@ function nnxtest.SoftMaxTree() mytester:assertTensorEq(bias3, bias, 0.000001) end +local function blur(mean, stdv, size) + local range = torch.range(1,size):float() + local a = 1/(stdv*math.sqrt(2*math.pi)) + local b = -1/(2*stdv*stdv) + return range:add(-mean):pow(2):mul(b):exp():mul(a) +end + +function nnxtest.Balance() + local inputSize = 7 + local batchSize = 3 + local nBatch = 1 + + local input = torch.randn(batchSize, inputSize):mul(0.1):float() + for i=1,batchSize do + input[i]:add(blur(3, 1, inputSize):float()) + end + local sm = nn.SoftMax() + sm:float() + input = sm:forward(input) + local gradOutput = torch.randn(batchSize, inputSize):float() + local bl = nn.Balance(nBatch) + bl:float() + + local output = bl:forward(input) + local p_y = output:sum(1):div(output:sum()) + mytester:assert(p_y:std() < 0.02) + mytester:assert(math.abs(p_y:sum() - 1) < 0.000001) + + local gradInput = bl:backward(input, gradOutput) +end + function nnx.test(tests) xlua.require('image',true) mytester = torch.Tester() |