Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/clementfarabet/lua---nnx.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
path: root/test
diff options
context:
space:
mode:
authornicholas-leonard <nick@nikopia.org>2014-07-19 04:38:31 +0400
committernicholas-leonard <nick@nikopia.org>2014-07-19 04:38:31 +0400
commitd12fd6b3c516430f5d7f1eabbbca19565346be40 (patch)
tree4275ba2e94e67d7fa1469987aaa56bc2293e03e9 /test
parent8cf1aa463a4e23ea6946a1856bf474ce4017eec6 (diff)
Balance
Diffstat (limited to 'test')
-rw-r--r--test/test-all.lua31
1 files changed, 31 insertions, 0 deletions
diff --git a/test/test-all.lua b/test/test-all.lua
index 3d04c4e..33af2b7 100644
--- a/test/test-all.lua
+++ b/test/test-all.lua
@@ -478,6 +478,37 @@ function nnxtest.SoftMaxTree()
mytester:assertTensorEq(bias3, bias, 0.000001)
end
+local function blur(mean, stdv, size)
+ local range = torch.range(1,size):float()
+ local a = 1/(stdv*math.sqrt(2*math.pi))
+ local b = -1/(2*stdv*stdv)
+ return range:add(-mean):pow(2):mul(b):exp():mul(a)
+end
+
+function nnxtest.Balance()
+ local inputSize = 7
+ local batchSize = 3
+ local nBatch = 1
+
+ local input = torch.randn(batchSize, inputSize):mul(0.1):float()
+ for i=1,batchSize do
+ input[i]:add(blur(3, 1, inputSize):float())
+ end
+ local sm = nn.SoftMax()
+ sm:float()
+ input = sm:forward(input)
+ local gradOutput = torch.randn(batchSize, inputSize):float()
+ local bl = nn.Balance(nBatch)
+ bl:float()
+
+ local output = bl:forward(input)
+ local p_y = output:sum(1):div(output:sum())
+ mytester:assert(p_y:std() < 0.02)
+ mytester:assert(math.abs(p_y:sum() - 1) < 0.000001)
+
+ local gradInput = bl:backward(input, gradOutput)
+end
+
function nnx.test(tests)
xlua.require('image',true)
mytester = torch.Tester()